from __future__ import annotations import asyncio import logging import threading import time from collections.abc import Coroutine from inspect import isawaitable from typing import Any from .errors import MQConfigError from .message import MQMessage from .registry import MQConsumerDefinition, MQRegistry from .serialization import decode_message_key, decode_message_value logger = logging.getLogger("iti.mq") class MQConsumerRunner: def __init__( self, app, backend, registry: MQRegistry, *, group_id: str | None = None, failure_backoff_seconds: float = 1.0, poll_timeout_seconds: float = 1.0, ) -> None: self.app = app self.backend = backend self.registry = registry self.group_id = group_id self.failure_backoff_seconds = failure_backoff_seconds self.poll_timeout_seconds = poll_timeout_seconds self._workers: list[_ConsumerWorker] = [] def start(self) -> None: if self._workers: return pending_workers: list[_ConsumerWorker] = [] try: for definition in self.registry.consumers.values(): group_id = definition.group_id or self.group_id if not group_id: raise MQConfigError(f"mq consumer {definition.name} missing group_id") consumer = self.backend.create_consumer(group_id, definition.config) consumer.subscribe(list(definition.topics)) worker = _ConsumerWorker( app=self.app, consumer=consumer, definition=definition, failure_backoff_seconds=( definition.failure_backoff_seconds if definition.failure_backoff_seconds is not None else self.failure_backoff_seconds ), poll_timeout_seconds=self.poll_timeout_seconds, ) pending_workers.append(worker) except Exception: for worker in pending_workers: worker.close() raise self._workers = pending_workers for worker in self._workers: worker.start() def stop(self) -> None: for worker in self._workers: worker.stop() for worker in self._workers: worker.join() self._workers.clear() class _ConsumerWorker: def __init__( self, *, app, consumer: Any, definition: MQConsumerDefinition, failure_backoff_seconds: float, poll_timeout_seconds: float, ) -> None: self.app = app self.consumer = consumer self.definition = definition self.failure_backoff_seconds = failure_backoff_seconds self.poll_timeout_seconds = poll_timeout_seconds self._stop = threading.Event() self._thread: threading.Thread | None = None def start(self) -> None: if self._thread and self._thread.is_alive(): return self._thread = threading.Thread(target=self._loop, daemon=True) self._thread.start() def stop(self) -> None: self._stop.set() def join(self) -> None: if self._thread: self._thread.join(timeout=3) self.close() def close(self) -> None: self.consumer.close() def _loop(self) -> None: while not self._stop.is_set(): raw_message = self.consumer.poll(self.poll_timeout_seconds) if raw_message is None: continue if raw_message.error(): logger.warning("mq consumer error: %s", raw_message.error()) continue self._handle_raw_message(raw_message) def _handle_raw_message(self, raw_message: Any) -> None: try: message = self._build_message(raw_message) result = self.definition.handler(message) if isawaitable(result): if not isinstance(result, Coroutine): raise TypeError("mq async handler must return a coroutine") asyncio.run(result) self.consumer.commit(raw_message, asynchronous=False) except Exception: logger.exception( "mq handler failed name=%s topic=%s partition=%s offset=%s", self.definition.name, _safe_call(raw_message, "topic"), _safe_call(raw_message, "partition"), _safe_call(raw_message, "offset"), ) self.consumer.seek(_seek_position(raw_message)) self._stop.wait(self.failure_backoff_seconds) def _build_message(self, raw_message: Any) -> MQMessage: raw_key = raw_message.key() raw_value = raw_message.value() return MQMessage( app=self.app, topic=raw_message.topic(), partition=raw_message.partition(), offset=raw_message.offset(), key=decode_message_key(raw_key), raw_key=raw_key, value=decode_message_value(raw_value, self.definition.value_format), raw_value=raw_value, headers=_headers_to_dict(raw_message.headers()), timestamp=raw_message.timestamp(), raw_message=raw_message, ) def _headers_to_dict(headers: list[tuple[str, bytes | None]] | None) -> dict[str, bytes | None]: return {key: value for key, value in headers or []} def _safe_call(value: Any, method: str) -> Any: try: return getattr(value, method)() except Exception: return "-" def _seek_position(raw_message: Any) -> Any: try: from confluent_kafka import TopicPartition except ImportError: return raw_message return TopicPartition( raw_message.topic(), raw_message.partition(), raw_message.offset(), )