You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
iTi-Flask/iti/mq/runner.py

182 lines
5.8 KiB
Python

from __future__ import annotations
import asyncio
import logging
import threading
import time
from collections.abc import Coroutine
from inspect import isawaitable
from typing import Any
from .errors import MQConfigError
from .message import MQMessage
from .registry import MQConsumerDefinition, MQRegistry
from .serialization import decode_message_key, decode_message_value
# Module logger for the MQ consumer runtime; the consumer worker threads report
# poll errors and handler failures through it.
logger = logging.getLogger("iti.mq")
class MQConsumerRunner:
    """Run one background consumer worker per registered MQ consumer definition.

    Each definition in ``registry`` gets its own backend consumer, subscribed to
    the definition's topics and driven by a dedicated worker thread.
    """

    def __init__(
        self,
        app,
        backend,
        registry: MQRegistry,
        *,
        group_id: str | None = None,
        failure_backoff_seconds: float = 1.0,
        poll_timeout_seconds: float = 1.0,
    ) -> None:
        """Store configuration; no consumers are created until ``start``.

        Args:
            app: Application object passed through to every worker.
            backend: Object providing ``create_consumer(group_id, config)``.
            registry: Source of the consumer definitions to run.
            group_id: Fallback consumer group for definitions without their own.
            failure_backoff_seconds: Pause after a failed message, used when a
                definition does not set its own ``failure_backoff_seconds``.
            poll_timeout_seconds: Timeout passed to each ``consumer.poll`` call.
        """
        self.app = app
        self.backend = backend
        self.registry = registry
        self.group_id = group_id
        self.failure_backoff_seconds = failure_backoff_seconds
        self.poll_timeout_seconds = poll_timeout_seconds
        self._workers: list[_ConsumerWorker] = []

    def start(self) -> None:
        """Create, subscribe and start a worker for every registered consumer.

        All consumers are created and subscribed before any thread starts, so a
        configuration or backend error leaves nothing running. Calling ``start``
        while workers exist is a no-op.

        Raises:
            MQConfigError: A definition has no group_id and no default is set.
            Exception: Anything the backend raises while creating/subscribing a
                consumer; every consumer created so far is closed first.
        """
        if self._workers:
            return
        pending_workers: list[_ConsumerWorker] = []
        try:
            for definition in self.registry.consumers.values():
                pending_workers.append(self._build_worker(definition))
        except Exception:
            # Close everything already created; a failing close() must not stop
            # the remaining cleanup or replace the original exception.
            for worker in pending_workers:
                self._close_quietly(worker, worker.definition.name)
            raise
        self._workers = pending_workers
        for worker in self._workers:
            worker.start()

    def stop(self) -> None:
        """Stop all workers, wait for them and close their consumers.

        Every worker is signalled before any is joined so they wind down
        concurrently. A failure while joining/closing one worker does not prevent
        the others from being closed; the first such error is re-raised after
        all workers have been processed.
        """
        workers = list(self._workers)
        # Forget the workers up front so a repeated stop() never double-closes.
        self._workers.clear()
        for worker in workers:
            worker.stop()
        first_error: Exception | None = None
        for worker in workers:
            try:
                worker.join()
            except Exception as exc:
                logger.exception(
                    "mq consumer shutdown failed name=%s", worker.definition.name
                )
                if first_error is None:
                    first_error = exc
        if first_error is not None:
            raise first_error

    def _build_worker(self, definition: MQConsumerDefinition) -> _ConsumerWorker:
        """Create and subscribe a consumer for ``definition`` and wrap it in a worker.

        If subscribing or constructing the worker fails, the freshly created
        consumer is closed before the exception propagates, so it is not leaked.
        """
        group_id = definition.group_id or self.group_id
        if not group_id:
            raise MQConfigError(f"mq consumer {definition.name} missing group_id")
        consumer = self.backend.create_consumer(group_id, definition.config)
        try:
            consumer.subscribe(list(definition.topics))
            return _ConsumerWorker(
                app=self.app,
                consumer=consumer,
                definition=definition,
                failure_backoff_seconds=(
                    definition.failure_backoff_seconds
                    if definition.failure_backoff_seconds is not None
                    else self.failure_backoff_seconds
                ),
                poll_timeout_seconds=self.poll_timeout_seconds,
            )
        except Exception:
            self._close_quietly(consumer, definition.name)
            raise

    @staticmethod
    def _close_quietly(closable: Any, name: str) -> None:
        """Call ``closable.close()``, logging instead of raising on failure.

        Used on cleanup paths where an error from ``close`` must not mask the
        exception that triggered the cleanup.
        """
        try:
            closable.close()
        except Exception:
            logger.exception("mq consumer close failed name=%s", name)
class _ConsumerWorker:
def __init__(
self,
*,
app,
consumer: Any,
definition: MQConsumerDefinition,
failure_backoff_seconds: float,
poll_timeout_seconds: float,
) -> None:
self.app = app
self.consumer = consumer
self.definition = definition
self.failure_backoff_seconds = failure_backoff_seconds
self.poll_timeout_seconds = poll_timeout_seconds
self._stop = threading.Event()
self._thread: threading.Thread | None = None
def start(self) -> None:
if self._thread and self._thread.is_alive():
return
self._thread = threading.Thread(target=self._loop, daemon=True)
self._thread.start()
def stop(self) -> None:
self._stop.set()
def join(self) -> None:
if self._thread:
self._thread.join(timeout=3)
self.close()
def close(self) -> None:
self.consumer.close()
def _loop(self) -> None:
while not self._stop.is_set():
raw_message = self.consumer.poll(self.poll_timeout_seconds)
if raw_message is None:
continue
if raw_message.error():
logger.warning("mq consumer error: %s", raw_message.error())
continue
self._handle_raw_message(raw_message)
def _handle_raw_message(self, raw_message: Any) -> None:
try:
message = self._build_message(raw_message)
result = self.definition.handler(message)
if isawaitable(result):
if not isinstance(result, Coroutine):
raise TypeError("mq async handler must return a coroutine")
asyncio.run(result)
self.consumer.commit(raw_message, asynchronous=False)
except Exception:
logger.exception(
"mq handler failed name=%s topic=%s partition=%s offset=%s",
self.definition.name,
_safe_call(raw_message, "topic"),
_safe_call(raw_message, "partition"),
_safe_call(raw_message, "offset"),
)
self.consumer.seek(_seek_position(raw_message))
self._stop.wait(self.failure_backoff_seconds)
def _build_message(self, raw_message: Any) -> MQMessage:
raw_key = raw_message.key()
raw_value = raw_message.value()
return MQMessage(
app=self.app,
topic=raw_message.topic(),
partition=raw_message.partition(),
offset=raw_message.offset(),
key=decode_message_key(raw_key),
raw_key=raw_key,
value=decode_message_value(raw_value, self.definition.value_format),
raw_value=raw_value,
headers=_headers_to_dict(raw_message.headers()),
timestamp=raw_message.timestamp(),
raw_message=raw_message,
)
def _headers_to_dict(headers: list[tuple[str, bytes | None]] | None) -> dict[str, bytes | None]:
return {key: value for key, value in headers or []}
def _safe_call(value: Any, method: str) -> Any:
try:
return getattr(value, method)()
except Exception:
return "-"
def _seek_position(raw_message: Any) -> Any:
try:
from confluent_kafka import TopicPartition
except ImportError:
return raw_message
return TopicPartition(
raw_message.topic(),
raw_message.partition(),
raw_message.offset(),
)