"""Threaded runner that drives the consumers registered in an ``MQRegistry``."""

from __future__ import annotations

import asyncio
import logging
import threading
import time
from collections.abc import Coroutine
from dataclasses import dataclass
from inspect import isawaitable
from pathlib import Path
from typing import Any

from .errors import MQConfigError
from .message import MQMessage
from .offset_store import MQOffsetStore, create_offset_store
from .registry import MQConsumerDefinition, MQRegistry
from .serialization import decode_message_key, decode_message_value

logger = logging.getLogger("iti.mq")


class MQConsumerRunner:
    """Create one background worker per registered consumer definition.

    Each definition runs either in ``subscribe`` mode (group-managed, offsets
    committed through the consumer) or in ``assign`` mode (explicit partition
    assignment, offsets persisted in an ``MQOffsetStore``).
    """

    def __init__(
        self,
        app,
        backend,
        registry: MQRegistry,
        *,
        group_id: str | None = None,
        consumer_mode: str = "subscribe",
        offset_store_config: dict[str, Any] | None = None,
        offset_store_path: str | Path | None = None,
        failure_backoff_seconds: float = 1.0,
        poll_timeout_seconds: float = 1.0,
    ) -> None:
        """Store runner configuration; no consumer is created until ``start``.

        Args:
            app: Application object handed to every ``MQMessage``.
            backend: Object providing ``create_consumer`` and a ``config`` mapping.
            registry: Registry whose ``consumers`` are started.
            group_id: Fallback group id for definitions that do not set one.
            consumer_mode: Fallback mode, ``"subscribe"`` or ``"assign"``.
            offset_store_config: Base offset-store settings; per-definition
                ``offset_store`` values are layered on top.
            offset_store_path: Default offset-store path; derived from the
                app's base directory (or the cwd) when ``None``.
            failure_backoff_seconds: Default wait after a handler failure.
            poll_timeout_seconds: Timeout passed to ``consumer.poll``.

        Raises:
            MQConfigError: If ``consumer_mode`` is not a supported value.
        """
        self.app = app
        self.backend = backend
        self.registry = registry
        self.group_id = group_id
        self.consumer_mode = _normalize_consumer_mode(consumer_mode)
        # Copy so later mutation of the caller's dict cannot affect the runner.
        self.offset_store_config = dict(offset_store_config or {})
        self.offset_store_path = offset_store_path
        self.failure_backoff_seconds = failure_backoff_seconds
        self.poll_timeout_seconds = poll_timeout_seconds
        self._workers: list[_ConsumerWorker] = []

    def start(self) -> None:
        """Build a worker for every registered consumer, then start them all.

        Does nothing when workers already exist. If any definition fails to
        initialise, every worker built so far is closed and the original
        exception is re-raised; no thread is started in that case.
        """
        if self._workers:
            return
        pending_workers: list[_ConsumerWorker] = []
        try:
            for definition in self.registry.consumers.values():
                pending_workers.append(self._build_worker(definition))
        except Exception:
            for worker in pending_workers:
                # Guard each close so one failing cleanup neither skips the
                # remaining workers nor masks the exception being re-raised.
                try:
                    worker.close()
                except Exception:
                    logger.exception(
                        "mq consumer cleanup failed name=%s",
                        worker.definition.name,
                    )
            raise
        self._workers = pending_workers
        for worker in self._workers:
            worker.start()

    def stop(self) -> None:
        """Signal every worker to stop, wait for them, and forget them."""
        # Signal all workers first so they wind down in parallel.
        for worker in self._workers:
            worker.stop()
        for worker in self._workers:
            worker.join()
        self._workers.clear()

    def _build_worker(self, definition: MQConsumerDefinition) -> _ConsumerWorker:
        """Create the consumer (and offset store) for one definition.

        Anything created before a failure is closed here, because the caller
        can only clean up workers that were fully constructed.

        Raises:
            MQConfigError: If a subscribe-mode consumer has no group id, or
                if offset-store / partition configuration is invalid.
        """
        group_id = definition.group_id or self.group_id
        mode = definition.mode or self.consumer_mode
        consumer: Any = None
        offset_store: MQOffsetStore | None = None
        try:
            if mode == "assign":
                consumer = self.backend.create_consumer(
                    group_id or definition.name,
                    definition.config,
                )
                offset_store = self._create_offset_store(definition)
                self._assign_consumer(consumer, definition, offset_store)
                logger.info(
                    "mq consumer started name=%s mode=assign topics=%s",
                    definition.name,
                    ",".join(definition.topics),
                )
            else:
                if not group_id:
                    raise MQConfigError(f"mq consumer {definition.name} missing group_id")
                consumer = self.backend.create_consumer(group_id, definition.config)
                consumer.subscribe(list(definition.topics))
                logger.info(
                    "mq consumer started name=%s mode=subscribe group_id=%s topics=%s",
                    definition.name,
                    group_id,
                    ",".join(definition.topics),
                )
            return _ConsumerWorker(
                app=self.app,
                consumer=consumer,
                definition=definition,
                offset_store=offset_store,
                auto_offset_reset=str(self.backend.config.get("auto_offset_reset", "earliest")),
                failure_backoff_seconds=(
                    definition.failure_backoff_seconds
                    if definition.failure_backoff_seconds is not None
                    else self.failure_backoff_seconds
                ),
                poll_timeout_seconds=self.poll_timeout_seconds,
            )
        except Exception:
            _close_quietly(consumer, offset_store)
            raise

    def _create_offset_store(self, definition: MQConsumerDefinition) -> MQOffsetStore:
        """Build the offset store for ``definition``.

        Runner-level settings are overridden by the definition's own
        ``offset_store`` mapping.

        Raises:
            MQConfigError: If ``create_offset_store`` rejects the settings.
        """
        config = dict(self.offset_store_config)
        config.update(definition.offset_store)
        default_path = self.offset_store_path
        if default_path is None:
            # Fall back to <app.state.config.base_dir or cwd>/runtime/mq_offsets.sqlite.
            base_dir = getattr(getattr(self.app, "state", None), "config", None)
            base_dir = getattr(base_dir, "base_dir", Path.cwd())
            default_path = Path(base_dir) / "runtime" / "mq_offsets.sqlite"
        try:
            return create_offset_store(
                config,
                default_path=default_path,
            )
        except ValueError as exc:
            raise MQConfigError(str(exc)) from exc

    def _assign_consumer(
        self,
        consumer: Any,
        definition: MQConsumerDefinition,
        offset_store: MQOffsetStore,
    ) -> None:
        """Assign the definition's partitions to ``consumer`` at stored offsets.

        Raises:
            MQConfigError: If the definition resolves to no partitions.
        """
        partitions = _resolve_partitions(consumer, definition)
        if not partitions:
            raise MQConfigError(f"mq consumer {definition.name} has no partitions to assign")
        consumer.assign(
            [
                self._topic_partition(definition.name, topic, partition, offset_store)
                for topic, partition in partitions
            ]
        )

    def _topic_partition(
        self,
        consumer_name: str,
        topic: str,
        partition: int,
        offset_store: MQOffsetStore,
    ) -> Any:
        """Return a partition object positioned at the stored offset.

        When nothing is stored, the logical offset from ``_auto_offset`` is
        used. Falls back to ``_AssignedPartition`` if ``confluent_kafka``
        cannot be imported.
        """
        offset = offset_store.get(consumer_name, topic, partition)
        try:
            from confluent_kafka import TopicPartition
        except ImportError:
            logger.debug("confluent-kafka unavailable; using test topic partition fallback")
            return _AssignedPartition(topic, partition, offset if offset is not None else _auto_offset(self))
        return TopicPartition(topic, partition, offset if offset is not None else _auto_offset(self))


class _ConsumerWorker:
    """Poll one consumer on a daemon thread and dispatch messages to its handler."""

    def __init__(
        self,
        *,
        app,
        consumer: Any,
        definition: MQConsumerDefinition,
        offset_store: MQOffsetStore | None,
        auto_offset_reset: str,
        failure_backoff_seconds: float,
        poll_timeout_seconds: float,
    ) -> None:
        """Store the collaborators; the thread is created lazily by ``start``.

        ``offset_store`` is ``None`` for subscribe-mode consumers, in which
        case offsets are committed through the consumer instead.
        """
        self.app = app
        self.consumer = consumer
        self.definition = definition
        self.offset_store = offset_store
        self.auto_offset_reset = auto_offset_reset
        self.failure_backoff_seconds = failure_backoff_seconds
        self.poll_timeout_seconds = poll_timeout_seconds
        self._stop = threading.Event()
        self._thread: threading.Thread | None = None

    def start(self) -> None:
        """Start the polling thread unless one is already alive."""
        if self._thread and self._thread.is_alive():
            return
        self._thread = threading.Thread(target=self._loop, daemon=True)
        self._thread.start()

    def stop(self) -> None:
        """Ask the polling loop to exit; does not wait for it."""
        self._stop.set()

    def join(self) -> None:
        """Wait up to three seconds for the thread, then release resources."""
        if self._thread:
            self._thread.join(timeout=3)
        self.close()

    def close(self) -> None:
        """Close the consumer and, if present, the offset store.

        The offset store is closed even when closing the consumer raises; the
        consumer's exception still propagates.
        """
        try:
            self.consumer.close()
        finally:
            if self.offset_store is not None:
                self.offset_store.close()

    def _loop(self) -> None:
        """Poll until stopped, routing errors and messages to their handlers."""
        while not self._stop.is_set():
            raw_message = self.consumer.poll(self.poll_timeout_seconds)
            if raw_message is None:
                continue
            if raw_message.error():
                logger.warning("mq consumer error: %s", raw_message.error())
                _handle_offset_error(
                    self.consumer,
                    raw_message,
                    self.definition,
                    self.offset_store,
                    self.auto_offset_reset,
                )
                continue
            self._handle_raw_message(raw_message)

    def _handle_raw_message(self, raw_message: Any) -> None:
        """Run the handler for one message and record progress on success.

        On any failure the consumer is rewound to the failed message and the
        worker waits ``failure_backoff_seconds`` (or until stopped), so the
        same message is attempted again.
        """
        try:
            message = self._build_message(raw_message)
            result = self.definition.handler(message)
            if isawaitable(result):
                # asyncio.run only accepts coroutines, not arbitrary awaitables.
                if not isinstance(result, Coroutine):
                    raise TypeError("mq async handler must return a coroutine")
                asyncio.run(result)
            if self.offset_store is None:
                self.consumer.commit(raw_message, asynchronous=False)
            else:
                # Persist the offset of the *next* message to read.
                self.offset_store.set(
                    self.definition.name,
                    raw_message.topic(),
                    raw_message.partition(),
                    raw_message.offset() + 1,
                )
        except Exception:
            logger.exception(
                "mq handler failed name=%s topic=%s partition=%s offset=%s",
                self.definition.name,
                _safe_call(raw_message, "topic"),
                _safe_call(raw_message, "partition"),
                _safe_call(raw_message, "offset"),
            )
            self.consumer.seek(_seek_position(raw_message))
            # Event.wait doubles as an interruptible sleep: stop() ends it early.
            self._stop.wait(self.failure_backoff_seconds)

    def _build_message(self, raw_message: Any) -> MQMessage:
        """Wrap a raw consumer message in an ``MQMessage`` with decoded fields."""
        raw_key = raw_message.key()
        raw_value = raw_message.value()
        return MQMessage(
            app=self.app,
            topic=raw_message.topic(),
            partition=raw_message.partition(),
            offset=raw_message.offset(),
            key=decode_message_key(raw_key),
            raw_key=raw_key,
            value=decode_message_value(raw_value, self.definition.value_format),
            raw_value=raw_value,
            headers=_headers_to_dict(raw_message.headers()),
            timestamp=raw_message.timestamp(),
            raw_message=raw_message,
        )


def _headers_to_dict(headers: list[tuple[str, bytes | None]] | None) -> dict[str, bytes | None]:
    """Convert header pairs to a dict; ``None`` yields ``{}``, later duplicates win."""
    return {key: value for key, value in headers or []}


def _safe_call(value: Any, method: str) -> Any:
    """Call ``value.<method>()`` for logging, returning ``"-"`` if it fails."""
    try:
        return getattr(value, method)()
    except Exception:
        return "-"


def _close_quietly(consumer: Any, offset_store: MQOffsetStore | None) -> None:
    """Close whichever of the two resources exist, logging close failures.

    Used while unwinding a failed worker construction, where a cleanup error
    must not replace the exception that is already propagating.
    """
    for resource in (consumer, offset_store):
        if resource is None:
            continue
        try:
            resource.close()
        except Exception:
            logger.exception("mq consumer cleanup failed")


def _seek_position(raw_message: Any) -> Any:
    """Return the position of ``raw_message`` in a form accepted by ``seek``.

    Returns the message itself when ``confluent_kafka`` cannot be imported.
    """
    try:
        from confluent_kafka import TopicPartition
    except ImportError:
        return raw_message
    return TopicPartition(
        raw_message.topic(),
        raw_message.partition(),
        raw_message.offset(),
    )


def _resolve_partitions(consumer: Any, definition: MQConsumerDefinition) -> list[tuple[str, int]]:
    """Expand ``definition.partitions`` into ``(topic, partition)`` pairs.

    Accepted forms: the string ``"all"`` (looked up from cluster metadata), a
    mapping of topic to partition ids, or an iterable of 2-tuples /
    ``{"topic": ..., "partition": ...}`` dicts.

    Raises:
        MQConfigError: If metadata for a topic is missing or flagged with an
            error, or if an item has an unsupported shape.
    """
    if definition.partitions == "all":
        metadata = consumer.list_topics(timeout=10)
        result: list[tuple[str, int]] = []
        for topic in definition.topics:
            topic_metadata = metadata.topics.get(topic)
            if topic_metadata is None or getattr(topic_metadata, "error", None):
                raise MQConfigError(f"mq topic metadata unavailable: {topic}")
            result.extend((topic, partition) for partition in sorted(topic_metadata.partitions))
        return result
    if isinstance(definition.partitions, dict):
        return [
            (topic, int(partition))
            for topic, partitions in definition.partitions.items()
            for partition in partitions
        ]
    result = []
    for value in definition.partitions:
        if isinstance(value, tuple) and len(value) == 2:
            result.append((str(value[0]), int(value[1])))
            continue
        if isinstance(value, dict):
            result.append((str(value["topic"]), int(value["partition"])))
            continue
        raise MQConfigError(f"unsupported mq partitions item: {value!r}")
    return result


def _auto_offset(runner: MQConsumerRunner) -> int:
    """Return the logical start offset implied by the backend's ``auto_offset_reset``.

    ``"latest"`` maps to the end offset; anything else maps to the beginning.
    """
    try:
        from confluent_kafka import OFFSET_BEGINNING, OFFSET_END
    except ImportError:
        # Fallback values used when confluent_kafka is not installed.
        OFFSET_BEGINNING = -2
        OFFSET_END = -1
    return OFFSET_END if str(runner.backend.config.get("auto_offset_reset")) == "latest" else OFFSET_BEGINNING


def _handle_offset_error(
    consumer: Any,
    raw_message: Any,
    definition: MQConsumerDefinition,
    offset_store: MQOffsetStore | None,
    auto_offset_reset: str,
) -> None:
    """Recover an assign-mode consumer from an out-of-range offset error.

    Only acts when an offset store is in use, ``confluent_kafka`` is
    importable, and the error code is ``_OFFSET_OUT_OF_RANGE``. The partition
    is then reset to the high watermark (``"latest"``) or the low watermark
    (otherwise); the new offset is stored and the consumer is sought to it.
    """
    if offset_store is None:
        return
    error = raw_message.error()
    code = error.code() if hasattr(error, "code") else None
    try:
        from confluent_kafka import KafkaError, TopicPartition
    except ImportError:
        return
    if code != KafkaError._OFFSET_OUT_OF_RANGE:
        return
    topic = raw_message.topic()
    partition = raw_message.partition()
    try:
        low, high = consumer.get_watermark_offsets(
            TopicPartition(topic, partition),
            timeout=10,
            cached=False,
        )
    except Exception:
        # Also reached when the call returns something that cannot be
        # unpacked into two values.
        logger.exception(
            "mq watermark lookup failed name=%s topic=%s partition=%s",
            definition.name,
            topic,
            partition,
        )
        if auto_offset_reset == "latest":
            # Without a high watermark there is no safe "latest" position;
            # leave the stored offset untouched instead of guessing.
            return
        low = 0
        high = 0
    offset = high if auto_offset_reset == "latest" else low
    offset_store.set(definition.name, topic, partition, offset)
    consumer.seek(TopicPartition(topic, partition, offset))


def _normalize_consumer_mode(mode: str) -> str:
    """Lower-case and validate a consumer mode; empty values mean ``"subscribe"``.

    Raises:
        MQConfigError: If the mode is neither ``"subscribe"`` nor ``"assign"``.
    """
    mode = str(mode or "subscribe").strip().lower()
    if mode not in {"subscribe", "assign"}:
        raise MQConfigError("mq consumer_mode must be 'subscribe' or 'assign'")
    return mode


@dataclass(frozen=True)
class _AssignedPartition:
    """Stand-in for ``TopicPartition`` used when ``confluent_kafka`` is unavailable."""

    topic: str
    partition: int
    offset: int