You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
iTi-Flask/iti/mq/runner.py

362 lines
13 KiB
Python

from __future__ import annotations
import asyncio
import logging
import threading
import time
from collections.abc import Coroutine
from dataclasses import dataclass
from inspect import isawaitable
from pathlib import Path
from typing import Any
from .errors import MQConfigError
from .message import MQMessage
from .offset_store import MQOffsetStore, create_offset_store
from .registry import MQConsumerDefinition, MQRegistry
from .serialization import decode_message_key, decode_message_value
logger = logging.getLogger("iti.mq")
class MQConsumerRunner:
    """Build and supervise one background worker per registered consumer.

    For every definition in ``registry.consumers`` the runner creates a
    backend consumer, either subscribes it to the definition's topics or
    assigns explicit partitions (resuming from an offset store), and wraps
    it in a ``_ConsumerWorker``.

    Args:
        app: Application object handed to each worker; also consulted for
            ``app.state.config.base_dir`` when no offset store path is given.
        backend: Object providing ``create_consumer(group_id, config)`` and a
            ``config`` mapping (read for ``auto_offset_reset``).
        registry: Registry whose ``consumers`` mapping holds the definitions.
        group_id: Fallback group id for definitions that do not set one.
        consumer_mode: Fallback mode, ``"subscribe"`` or ``"assign"``.
        offset_store_config: Base offset store settings; each definition's
            ``offset_store`` mapping is layered on top.
        offset_store_path: Default offset store location. When ``None`` it is
            derived as ``<base_dir>/runtime/mq_offsets.sqlite``.
        failure_backoff_seconds: Fallback backoff for definitions that leave
            ``failure_backoff_seconds`` as ``None``.
        poll_timeout_seconds: Poll timeout handed to every worker.

    Raises:
        MQConfigError: If ``consumer_mode`` is not a supported mode.
    """

    def __init__(
        self,
        app,
        backend,
        registry: MQRegistry,
        *,
        group_id: str | None = None,
        consumer_mode: str = "subscribe",
        offset_store_config: dict[str, Any] | None = None,
        offset_store_path: str | Path | None = None,
        failure_backoff_seconds: float = 1.0,
        poll_timeout_seconds: float = 1.0,
    ) -> None:
        self.app = app
        self.backend = backend
        self.registry = registry
        self.group_id = group_id
        self.consumer_mode = _normalize_consumer_mode(consumer_mode)
        # Copy so later per-definition merges never mutate the caller's dict.
        self.offset_store_config = dict(offset_store_config or {})
        self.offset_store_path = offset_store_path
        self.failure_backoff_seconds = failure_backoff_seconds
        self.poll_timeout_seconds = poll_timeout_seconds
        self._workers: list[_ConsumerWorker] = []

    def start(self) -> None:
        """Create a worker for every registered consumer and start them all.

        Does nothing when workers already exist. Workers are only started
        once all of them were built successfully; if building any of them
        fails, everything created so far is closed and the error re-raised.

        Raises:
            MQConfigError: If a subscribe-mode consumer has no group id, or
                an assign-mode consumer resolves to no partitions.
        """
        if self._workers:
            return
        pending_workers: list[_ConsumerWorker] = []
        try:
            for definition in self.registry.consumers.values():
                pending_workers.append(self._build_worker(definition))
        except Exception:
            for worker in pending_workers:
                worker.close()
            raise
        self._workers = pending_workers
        for worker in self._workers:
            worker.start()

    def stop(self) -> None:
        """Signal every worker to stop, wait for each, then forget them."""
        # Signal all workers first so they wind down concurrently.
        for worker in self._workers:
            worker.stop()
        for worker in self._workers:
            worker.join()
        self._workers.clear()

    def _build_worker(self, definition: MQConsumerDefinition) -> _ConsumerWorker:
        """Create the consumer (and offset store) for one definition.

        If anything fails after the consumer has been created, the consumer
        and any offset store opened here are closed before the error
        propagates, so a failed start does not leak them.
        """
        group_id = definition.group_id or self.group_id
        mode = definition.mode or self.consumer_mode
        assign_mode = mode == "assign"
        if not assign_mode and not group_id:
            raise MQConfigError(f"mq consumer {definition.name} missing group_id")

        # Assign mode falls back to the consumer name as its group id.
        effective_group = (group_id or definition.name) if assign_mode else group_id
        consumer = self.backend.create_consumer(effective_group, definition.config)
        offset_store: MQOffsetStore | None = None
        try:
            if assign_mode:
                offset_store = self._create_offset_store(definition)
                self._assign_consumer(consumer, definition, offset_store)
                logger.info(
                    "mq consumer started name=%s mode=assign topics=%s",
                    definition.name,
                    ",".join(definition.topics),
                )
            else:
                consumer.subscribe(list(definition.topics))
                logger.info(
                    "mq consumer started name=%s mode=subscribe group_id=%s topics=%s",
                    definition.name,
                    group_id,
                    ",".join(definition.topics),
                )
            return _ConsumerWorker(
                app=self.app,
                consumer=consumer,
                definition=definition,
                offset_store=offset_store,
                auto_offset_reset=str(self.backend.config.get("auto_offset_reset", "earliest")),
                failure_backoff_seconds=(
                    definition.failure_backoff_seconds
                    if definition.failure_backoff_seconds is not None
                    else self.failure_backoff_seconds
                ),
                poll_timeout_seconds=self.poll_timeout_seconds,
            )
        except Exception:
            self._release_resources(consumer, offset_store)
            raise

    @staticmethod
    def _release_resources(consumer: Any, offset_store: Any) -> None:
        """Close resources of a half-built worker without masking the cause."""
        for resource in (consumer, offset_store):
            if resource is None:
                continue
            try:
                resource.close()
            except Exception:
                logger.exception("mq consumer cleanup failed")

    def _create_offset_store(self, definition: MQConsumerDefinition) -> MQOffsetStore:
        """Open the offset store for ``definition``.

        Runner-level settings are overridden by the definition's own
        ``offset_store`` mapping.

        Raises:
            MQConfigError: If ``create_offset_store`` rejects the settings
                with a ``ValueError``.
        """
        config = dict(self.offset_store_config)
        # Tolerate definitions that leave ``offset_store`` unset.
        config.update(definition.offset_store or {})
        default_path = self.offset_store_path
        if default_path is None:
            app_config = getattr(getattr(self.app, "state", None), "config", None)
            base_dir = getattr(app_config, "base_dir", Path.cwd())
            default_path = Path(base_dir) / "runtime" / "mq_offsets.sqlite"
        try:
            return create_offset_store(
                config,
                default_path=default_path,
            )
        except ValueError as exc:
            raise MQConfigError(str(exc)) from exc

    def _assign_consumer(
        self,
        consumer: Any,
        definition: MQConsumerDefinition,
        offset_store: MQOffsetStore,
    ) -> None:
        """Assign the definition's partitions, each at its stored offset.

        Raises:
            MQConfigError: If the definition resolves to no partitions.
        """
        partitions = _resolve_partitions(consumer, definition)
        if not partitions:
            raise MQConfigError(f"mq consumer {definition.name} has no partitions to assign")
        consumer.assign(
            [
                self._topic_partition(definition.name, topic, partition, offset_store)
                for topic, partition in partitions
            ]
        )

    def _topic_partition(
        self,
        consumer_name: str,
        topic: str,
        partition: int,
        offset_store: MQOffsetStore,
    ) -> Any:
        """Return the partition handle positioned at its starting offset.

        The stored offset is used when one exists; otherwise the position
        comes from ``_auto_offset``. When ``confluent_kafka`` cannot be
        imported, an ``_AssignedPartition`` stand-in is returned instead.
        """
        offset = offset_store.get(consumer_name, topic, partition)
        try:
            from confluent_kafka import TopicPartition
        except ImportError:
            logger.debug("confluent-kafka unavailable; using test topic partition fallback")
            partition_type: Any = _AssignedPartition
        else:
            partition_type = TopicPartition
        start_offset = offset if offset is not None else _auto_offset(self)
        return partition_type(topic, partition, start_offset)
class _ConsumerWorker:
def __init__(
self,
*,
app,
consumer: Any,
definition: MQConsumerDefinition,
offset_store: MQOffsetStore | None,
auto_offset_reset: str,
failure_backoff_seconds: float,
poll_timeout_seconds: float,
) -> None:
self.app = app
self.consumer = consumer
self.definition = definition
self.offset_store = offset_store
self.auto_offset_reset = auto_offset_reset
self.failure_backoff_seconds = failure_backoff_seconds
self.poll_timeout_seconds = poll_timeout_seconds
self._stop = threading.Event()
self._thread: threading.Thread | None = None
def start(self) -> None:
if self._thread and self._thread.is_alive():
return
self._thread = threading.Thread(target=self._loop, daemon=True)
self._thread.start()
def stop(self) -> None:
self._stop.set()
def join(self) -> None:
if self._thread:
self._thread.join(timeout=3)
self.close()
def close(self) -> None:
self.consumer.close()
if self.offset_store is not None:
self.offset_store.close()
def _loop(self) -> None:
while not self._stop.is_set():
raw_message = self.consumer.poll(self.poll_timeout_seconds)
if raw_message is None:
continue
if raw_message.error():
logger.warning("mq consumer error: %s", raw_message.error())
_handle_offset_error(
self.consumer,
raw_message,
self.definition,
self.offset_store,
self.auto_offset_reset,
)
continue
self._handle_raw_message(raw_message)
def _handle_raw_message(self, raw_message: Any) -> None:
try:
message = self._build_message(raw_message)
result = self.definition.handler(message)
if isawaitable(result):
if not isinstance(result, Coroutine):
raise TypeError("mq async handler must return a coroutine")
asyncio.run(result)
if self.offset_store is None:
self.consumer.commit(raw_message, asynchronous=False)
else:
self.offset_store.set(
self.definition.name,
raw_message.topic(),
raw_message.partition(),
raw_message.offset() + 1,
)
except Exception:
logger.exception(
"mq handler failed name=%s topic=%s partition=%s offset=%s",
self.definition.name,
_safe_call(raw_message, "topic"),
_safe_call(raw_message, "partition"),
_safe_call(raw_message, "offset"),
)
self.consumer.seek(_seek_position(raw_message))
self._stop.wait(self.failure_backoff_seconds)
def _build_message(self, raw_message: Any) -> MQMessage:
raw_key = raw_message.key()
raw_value = raw_message.value()
return MQMessage(
app=self.app,
topic=raw_message.topic(),
partition=raw_message.partition(),
offset=raw_message.offset(),
key=decode_message_key(raw_key),
raw_key=raw_key,
value=decode_message_value(raw_value, self.definition.value_format),
raw_value=raw_value,
headers=_headers_to_dict(raw_message.headers()),
timestamp=raw_message.timestamp(),
raw_message=raw_message,
)
def _headers_to_dict(headers: list[tuple[str, bytes | None]] | None) -> dict[str, bytes | None]:
return {key: value for key, value in headers or []}
def _safe_call(value: Any, method: str) -> Any:
try:
return getattr(value, method)()
except Exception:
return "-"
def _seek_position(raw_message: Any) -> Any:
try:
from confluent_kafka import TopicPartition
except ImportError:
return raw_message
return TopicPartition(
raw_message.topic(),
raw_message.partition(),
raw_message.offset(),
)
def _resolve_partitions(consumer: Any, definition: MQConsumerDefinition) -> list[tuple[str, int]]:
if definition.partitions == "all":
metadata = consumer.list_topics(timeout=10)
result: list[tuple[str, int]] = []
for topic in definition.topics:
topic_metadata = metadata.topics.get(topic)
if topic_metadata is None or getattr(topic_metadata, "error", None):
raise MQConfigError(f"mq topic metadata unavailable: {topic}")
result.extend((topic, partition) for partition in sorted(topic_metadata.partitions))
return result
if isinstance(definition.partitions, dict):
return [
(topic, int(partition))
for topic, partitions in definition.partitions.items()
for partition in partitions
]
result = []
for value in definition.partitions:
if isinstance(value, tuple) and len(value) == 2:
result.append((str(value[0]), int(value[1])))
continue
if isinstance(value, dict):
result.append((str(value["topic"]), int(value["partition"])))
continue
raise MQConfigError(f"unsupported mq partitions item: {value!r}")
return result
def _auto_offset(runner: MQConsumerRunner) -> int:
try:
from confluent_kafka import OFFSET_BEGINNING, OFFSET_END
except ImportError:
OFFSET_BEGINNING = -2
OFFSET_END = -1
return OFFSET_END if str(runner.backend.config.get("auto_offset_reset")) == "latest" else OFFSET_BEGINNING
def _handle_offset_error(
consumer: Any,
raw_message: Any,
definition: MQConsumerDefinition,
offset_store: MQOffsetStore | None,
auto_offset_reset: str,
) -> None:
if offset_store is None:
return
error = raw_message.error()
code = error.code() if hasattr(error, "code") else None
try:
from confluent_kafka import KafkaError, TopicPartition
except ImportError:
return
if code != KafkaError._OFFSET_OUT_OF_RANGE:
return
topic = raw_message.topic()
partition = raw_message.partition()
try:
low, high = consumer.get_watermark_offsets(
TopicPartition(topic, partition),
timeout=10,
cached=False,
)
except Exception:
low = 0
offset = high if auto_offset_reset == "latest" else low
offset_store.set(definition.name, topic, partition, offset)
consumer.seek(TopicPartition(topic, partition, offset))
def _normalize_consumer_mode(mode: str) -> str:
mode = str(mode or "subscribe").strip().lower()
if mode not in {"subscribe", "assign"}:
raise MQConfigError("mq consumer_mode must be 'subscribe' or 'assign'")
return mode
@dataclass(frozen=True)
class _AssignedPartition:
topic: str
partition: int
offset: int