You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
182 lines
5.8 KiB
Python
182 lines
5.8 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
import threading
|
|
import time
|
|
from collections.abc import Coroutine
|
|
from inspect import isawaitable
|
|
from typing import Any
|
|
|
|
from .errors import MQConfigError
|
|
from .message import MQMessage
|
|
from .registry import MQConsumerDefinition, MQRegistry
|
|
from .serialization import decode_message_key, decode_message_value
|
|
|
|
|
|
# Shared module-level logger, registered under the fixed name "iti.mq".
logger = logging.getLogger("iti.mq")
|
|
|
|
|
|
class MQConsumerRunner:
    """Create, start and stop one consumer worker per registered definition.

    ``start`` builds a backend consumer and a worker for every entry in
    ``registry.consumers``; ``stop`` signals, joins and forgets them.
    """

    def __init__(
        self,
        app,
        backend,
        registry: MQRegistry,
        *,
        group_id: str | None = None,
        failure_backoff_seconds: float = 1.0,
        poll_timeout_seconds: float = 1.0,
    ) -> None:
        """Store collaborators and defaults; no consumer is created here.

        Args:
            app: Application object handed to every worker.
            backend: Object exposing ``create_consumer(group_id, config)``.
            registry: Registry whose ``consumers`` mapping holds the definitions.
            group_id: Fallback group id for definitions that have none.
            failure_backoff_seconds: Fallback backoff for definitions whose
                own ``failure_backoff_seconds`` is ``None``.
            poll_timeout_seconds: Poll timeout given to every worker.
        """
        self.app = app
        self.backend = backend
        self.registry = registry
        self.group_id = group_id
        self.failure_backoff_seconds = failure_backoff_seconds
        self.poll_timeout_seconds = poll_timeout_seconds
        self._workers: list[_ConsumerWorker] = []

    def start(self) -> None:
        """Build a worker for every registered consumer, then start them all.

        Does nothing when workers already exist. Construction is
        all-or-nothing: if any definition fails, every consumer created so
        far is closed, the error is re-raised and no worker is started.

        Raises:
            MQConfigError: If a definition has no group id and no fallback
                group id was configured on the runner.
        """
        if self._workers:
            return
        pending_workers: list[_ConsumerWorker] = []
        try:
            for definition in self.registry.consumers.values():
                pending_workers.append(self._build_worker(definition))
        except Exception:
            # Roll back: close everything built so far. Each close is
            # best-effort so one failing close cannot leak the remaining
            # consumers or replace the error that caused the rollback.
            for worker in pending_workers:
                self._close_quietly(worker, worker.definition.name)
            raise
        self._workers = pending_workers
        for worker in self._workers:
            worker.start()

    def stop(self) -> None:
        """Signal every worker to stop, join each one, then forget them."""
        # Signal all workers first so they wind down concurrently instead of
        # each join waiting on a worker that has not been told to stop yet.
        for worker in self._workers:
            worker.stop()
        for worker in self._workers:
            worker.join()
        self._workers.clear()

    def _build_worker(self, definition: MQConsumerDefinition) -> _ConsumerWorker:
        """Create and subscribe a consumer for *definition* and wrap it in a worker."""
        group_id = definition.group_id or self.group_id
        if not group_id:
            raise MQConfigError(f"mq consumer {definition.name} missing group_id")
        consumer = self.backend.create_consumer(group_id, definition.config)
        try:
            consumer.subscribe(list(definition.topics))
            backoff = definition.failure_backoff_seconds
            if backoff is None:
                backoff = self.failure_backoff_seconds
            return _ConsumerWorker(
                app=self.app,
                consumer=consumer,
                definition=definition,
                failure_backoff_seconds=backoff,
                poll_timeout_seconds=self.poll_timeout_seconds,
            )
        except Exception:
            # No worker owns this consumer yet, so nothing else would close it.
            self._close_quietly(consumer, definition.name)
            raise

    @staticmethod
    def _close_quietly(target: Any, name: str) -> None:
        """Close *target*, logging instead of raising so the original error survives."""
        try:
            target.close()
        except Exception:
            logger.exception(
                "mq consumer close failed during startup rollback name=%s", name
            )
|
|
|
|
|
|
class _ConsumerWorker:
    """Poll one consumer on a daemon thread and feed its messages to a handler."""

    def __init__(
        self,
        *,
        app,
        consumer: Any,
        definition: MQConsumerDefinition,
        failure_backoff_seconds: float,
        poll_timeout_seconds: float,
    ) -> None:
        """Store the consumer and settings; the thread is created by ``start``.

        Args:
            app: Application object attached to every built message.
            consumer: Backend consumer; ``poll``, ``commit``, ``seek`` and
                ``close`` are called on it.
            definition: Consumer definition supplying ``name``, ``handler``
                and ``value_format``.
            failure_backoff_seconds: Wait applied after a failure.
            poll_timeout_seconds: Timeout passed to each ``poll`` call.
        """
        self.app = app
        self.consumer = consumer
        self.definition = definition
        self.failure_backoff_seconds = failure_backoff_seconds
        self.poll_timeout_seconds = poll_timeout_seconds
        self._stop = threading.Event()
        self._thread: threading.Thread | None = None

    def start(self) -> None:
        """Start the polling thread unless one is already alive."""
        if self._thread and self._thread.is_alive():
            return
        self._thread = threading.Thread(target=self._loop, daemon=True)
        self._thread.start()

    def stop(self) -> None:
        """Ask the polling loop to exit; does not wait for it."""
        self._stop.set()

    def join(self) -> None:
        """Wait up to three seconds for the thread, then close the consumer."""
        if self._thread:
            self._thread.join(timeout=3)
        # NOTE(review): the consumer is closed even when the thread did not
        # finish within the timeout -- confirm the backend tolerates close()
        # racing with an in-flight poll() or handler.
        self.close()

    def close(self) -> None:
        """Close the underlying consumer."""
        self.consumer.close()

    def _loop(self) -> None:
        """Poll until stop is requested, dispatching every valid message.

        Empty polls are skipped, and messages that report an error are
        logged and skipped. A poll that raises is logged and retried after
        the failure backoff, so a transient backend error does not silently
        end the thread and halt consumption.
        """
        while not self._stop.is_set():
            try:
                raw_message = self.consumer.poll(self.poll_timeout_seconds)
                if raw_message is None:
                    continue
                error = raw_message.error()
            except Exception:
                logger.exception(
                    "mq consumer poll failed name=%s", self.definition.name
                )
                # Event.wait doubles as an interruptible sleep: stop() cuts it short.
                self._stop.wait(self.failure_backoff_seconds)
                continue
            if error:
                logger.warning("mq consumer error: %s", error)
                continue
            self._handle_raw_message(raw_message)

    def _handle_raw_message(self, raw_message: Any) -> None:
        """Run the handler for one message and commit it on success.

        On any failure the error is logged, the consumer is asked to seek
        back to the message (presumably so it is delivered again) and the
        worker waits for the failure backoff. A failing seek is logged
        rather than raised so the polling thread keeps running.
        """
        try:
            message = self._build_message(raw_message)
            result = self.definition.handler(message)
            if isawaitable(result):
                if not isinstance(result, Coroutine):
                    raise TypeError("mq async handler must return a coroutine")
                # Each async handler call runs to completion on its own event loop.
                asyncio.run(result)
            self.consumer.commit(raw_message, asynchronous=False)
        except Exception:
            logger.exception(
                "mq handler failed name=%s topic=%s partition=%s offset=%s",
                self.definition.name,
                _safe_call(raw_message, "topic"),
                _safe_call(raw_message, "partition"),
                _safe_call(raw_message, "offset"),
            )
            try:
                self.consumer.seek(_seek_position(raw_message))
            except Exception:
                # The offset was not committed, so losing the rewind is not
                # fatal; letting this escape would kill the worker thread.
                logger.exception(
                    "mq consumer seek failed name=%s", self.definition.name
                )
            self._stop.wait(self.failure_backoff_seconds)

    def _build_message(self, raw_message: Any) -> MQMessage:
        """Wrap a raw backend message, decoding key, value and headers."""
        raw_key = raw_message.key()
        raw_value = raw_message.value()
        return MQMessage(
            app=self.app,
            topic=raw_message.topic(),
            partition=raw_message.partition(),
            offset=raw_message.offset(),
            key=decode_message_key(raw_key),
            raw_key=raw_key,
            value=decode_message_value(raw_value, self.definition.value_format),
            raw_value=raw_value,
            headers=_headers_to_dict(raw_message.headers()),
            timestamp=raw_message.timestamp(),
            raw_message=raw_message,
        )
|
|
|
|
|
|
def _headers_to_dict(headers: list[tuple[str, bytes | None]] | None) -> dict[str, bytes | None]:
|
|
return {key: value for key, value in headers or []}
|
|
|
|
|
|
def _safe_call(value: Any, method: str) -> Any:
    """Call the zero-argument method named *method* on *value*.

    Returns the call's result, or the placeholder ``"-"`` when the
    attribute lookup or the call raises any ``Exception``.
    """
    try:
        bound = getattr(value, method)
        outcome = bound()
    except Exception:
        return "-"
    return outcome
|
|
|
|
|
|
def _seek_position(raw_message: Any) -> Any:
    """Build the argument handed to ``consumer.seek`` for *raw_message*.

    When ``confluent_kafka`` can be imported, returns a ``TopicPartition``
    made from the message's topic, partition and offset. Otherwise the raw
    message itself is returned unchanged.
    """
    try:
        from confluent_kafka import TopicPartition
    except ImportError:
        return raw_message
    else:
        topic = raw_message.topic()
        partition = raw_message.partition()
        offset = raw_message.offset()
        return TopicPartition(topic, partition, offset)
|