feat: add mq
parent
627eb8a37a
commit
7cc0501c27
@ -0,0 +1,77 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .backend import KafkaBackend
|
||||
from .client import MQClient, MQSender, mq_client
|
||||
from .errors import MQConfigError, MQError, MQPublishError
|
||||
from .message import MQMessage
|
||||
from .registry import (
|
||||
MQConsumerDefinition,
|
||||
MQProducerDefinition,
|
||||
MQRegistry,
|
||||
mq_consumer,
|
||||
mq_registry,
|
||||
)
|
||||
from .runner import MQConsumerRunner
|
||||
|
||||
|
||||
def get_mq_registry(app) -> MQRegistry:
|
||||
registry = getattr(app.state, "iti_mq_registry", None)
|
||||
if registry is None:
|
||||
registry = mq_registry
|
||||
app.state.iti_mq_registry = registry
|
||||
return registry
|
||||
|
||||
|
||||
def init_mq(
|
||||
app,
|
||||
config: dict[str, Any] | None = None,
|
||||
*,
|
||||
registry: MQRegistry | None = None,
|
||||
producer_factory: Any | None = None,
|
||||
consumer_factory: Any | None = None,
|
||||
) -> None:
|
||||
config = dict(config or {})
|
||||
registry = registry or get_mq_registry(app)
|
||||
backend_name = str(config.get("backend", "kafka"))
|
||||
if backend_name != "kafka":
|
||||
raise MQConfigError(f"unsupported mq backend: {backend_name}")
|
||||
backend = KafkaBackend(
|
||||
config,
|
||||
producer_factory=producer_factory,
|
||||
consumer_factory=consumer_factory,
|
||||
)
|
||||
client = MQClient(backend.create_producer(), registry)
|
||||
runner = MQConsumerRunner(
|
||||
app,
|
||||
backend,
|
||||
registry,
|
||||
group_id=config.get("group_id"),
|
||||
failure_backoff_seconds=float(config.get("failure_backoff_seconds", 1.0)),
|
||||
poll_timeout_seconds=float(config.get("poll_timeout_seconds", 1.0)),
|
||||
)
|
||||
app.state.iti_mq_registry = registry
|
||||
app.state.iti_mq_backend = backend
|
||||
app.state.iti_mq_client = client
|
||||
app.state.iti_mq_runner = runner
|
||||
|
||||
|
||||
__all__ = [
|
||||
"KafkaBackend",
|
||||
"MQClient",
|
||||
"MQConfigError",
|
||||
"MQConsumerDefinition",
|
||||
"MQConsumerRunner",
|
||||
"MQError",
|
||||
"MQMessage",
|
||||
"MQProducerDefinition",
|
||||
"MQPublishError",
|
||||
"MQRegistry",
|
||||
"MQSender",
|
||||
"get_mq_registry",
|
||||
"init_mq",
|
||||
"mq_client",
|
||||
"mq_consumer",
|
||||
"mq_registry",
|
||||
]
|
||||
@ -0,0 +1,72 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .errors import MQConfigError
|
||||
|
||||
|
||||
class KafkaBackend:
|
||||
def __init__(
|
||||
self,
|
||||
config: dict[str, Any],
|
||||
*,
|
||||
producer_factory: Any | None = None,
|
||||
consumer_factory: Any | None = None,
|
||||
) -> None:
|
||||
self.config = config
|
||||
self._producer_factory = producer_factory
|
||||
self._consumer_factory = consumer_factory
|
||||
|
||||
def create_producer(self):
|
||||
if self._producer_factory is not None:
|
||||
return self._producer_factory(self.producer_config())
|
||||
return self._confluent_producer()(self.producer_config())
|
||||
|
||||
def create_consumer(self, group_id: str, config: dict[str, Any] | None = None):
|
||||
consumer_config = self.consumer_config(group_id, config)
|
||||
if self._consumer_factory is not None:
|
||||
return self._consumer_factory(consumer_config)
|
||||
return self._confluent_consumer()(consumer_config)
|
||||
|
||||
def producer_config(self) -> dict[str, Any]:
|
||||
base = self._common_config()
|
||||
base.update(dict(self.config.get("producer") or {}))
|
||||
return base
|
||||
|
||||
def consumer_config(self, group_id: str, config: dict[str, Any] | None = None) -> dict[str, Any]:
|
||||
base = self._common_config()
|
||||
base.update(dict(self.config.get("consumer") or {}))
|
||||
base.update(dict(config or {}))
|
||||
base["group.id"] = group_id
|
||||
base["enable.auto.commit"] = False
|
||||
base.setdefault("auto.offset.reset", self.config.get("auto_offset_reset", "earliest"))
|
||||
return base
|
||||
|
||||
def _common_config(self) -> dict[str, Any]:
|
||||
bootstrap_servers = self.config.get("bootstrap_servers")
|
||||
if not bootstrap_servers:
|
||||
raise MQConfigError("mq kafka bootstrap_servers is required")
|
||||
common = {"bootstrap.servers": bootstrap_servers}
|
||||
client_id = self.config.get("client_id")
|
||||
if client_id:
|
||||
common["client.id"] = client_id
|
||||
common.update(dict(self.config.get("common") or {}))
|
||||
return common
|
||||
|
||||
def _confluent_producer(self):
|
||||
try:
|
||||
from confluent_kafka import Producer
|
||||
except ImportError as exc:
|
||||
raise MQConfigError(
|
||||
"confluent-kafka is required for kafka mq; install iti-flask[mq-kafka]"
|
||||
) from exc
|
||||
return Producer
|
||||
|
||||
def _confluent_consumer(self):
|
||||
try:
|
||||
from confluent_kafka import Consumer
|
||||
except ImportError as exc:
|
||||
raise MQConfigError(
|
||||
"confluent-kafka is required for kafka mq; install iti-flask[mq-kafka]"
|
||||
) from exc
|
||||
return Consumer
|
||||
@ -0,0 +1,110 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .errors import MQConfigError
|
||||
from .registry import MQRegistry
|
||||
from .serialization import encode_message_key, encode_message_value
|
||||
|
||||
|
||||
class MQClient:
|
||||
def __init__(self, producer: Any, registry: MQRegistry) -> None:
|
||||
self._producer = producer
|
||||
self._registry = registry
|
||||
|
||||
def send_json(
|
||||
self,
|
||||
topic: str,
|
||||
value: Any,
|
||||
*,
|
||||
key: str | bytes | None = None,
|
||||
headers: dict[str, str | bytes | None] | None = None,
|
||||
) -> None:
|
||||
self.send(topic, value=value, key=key, headers=headers, value_format="json")
|
||||
|
||||
def send(
|
||||
self,
|
||||
topic: str,
|
||||
*,
|
||||
value: Any,
|
||||
key: str | bytes | None = None,
|
||||
headers: dict[str, str | bytes | None] | None = None,
|
||||
value_format: str = "bytes",
|
||||
) -> None:
|
||||
self._producer.poll(0)
|
||||
self._producer.produce(
|
||||
topic,
|
||||
value=encode_message_value(value, value_format),
|
||||
key=encode_message_key(key),
|
||||
headers=_encode_headers(headers),
|
||||
)
|
||||
|
||||
def sender(self, name: str) -> "MQSender":
|
||||
definition = self._registry.producers.get(name)
|
||||
if definition is None:
|
||||
raise MQConfigError(f"mq producer not registered: {name}")
|
||||
return MQSender(self, definition.topic, definition.value_format)
|
||||
|
||||
def flush(self, timeout: float | None = None) -> None:
|
||||
if timeout is None:
|
||||
self._producer.flush()
|
||||
else:
|
||||
self._producer.flush(timeout)
|
||||
|
||||
|
||||
class MQSender:
|
||||
def __init__(self, client: MQClient, topic: str, value_format: str) -> None:
|
||||
self._client = client
|
||||
self.topic = topic
|
||||
self.value_format = value_format
|
||||
|
||||
def send_json(
|
||||
self,
|
||||
value: Any,
|
||||
*,
|
||||
key: str | bytes | None = None,
|
||||
headers: dict[str, str | bytes | None] | None = None,
|
||||
) -> None:
|
||||
self._client.send(
|
||||
self.topic,
|
||||
value=value,
|
||||
key=key,
|
||||
headers=headers,
|
||||
value_format="json",
|
||||
)
|
||||
|
||||
def send(
|
||||
self,
|
||||
value: Any,
|
||||
*,
|
||||
key: str | bytes | None = None,
|
||||
headers: dict[str, str | bytes | None] | None = None,
|
||||
) -> None:
|
||||
self._client.send(
|
||||
self.topic,
|
||||
value=value,
|
||||
key=key,
|
||||
headers=headers,
|
||||
value_format=self.value_format,
|
||||
)
|
||||
|
||||
|
||||
def _encode_headers(
|
||||
headers: dict[str, str | bytes | None] | None,
|
||||
) -> list[tuple[str, bytes | None]] | None:
|
||||
if headers is None:
|
||||
return None
|
||||
result: list[tuple[str, bytes | None]] = []
|
||||
for key, value in headers.items():
|
||||
if value is None or isinstance(value, bytes):
|
||||
result.append((key, value))
|
||||
else:
|
||||
result.append((key, str(value).encode("utf-8")))
|
||||
return result
|
||||
|
||||
|
||||
def mq_client(app) -> MQClient:
|
||||
client = getattr(app.state, "iti_mq_client", None)
|
||||
if client is None:
|
||||
raise MQConfigError("mq client is not configured")
|
||||
return client
|
||||
@ -0,0 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
class MQError(RuntimeError):
|
||||
"""Base MQ error."""
|
||||
|
||||
|
||||
class MQConfigError(MQError):
|
||||
"""Raised when MQ configuration is missing or invalid."""
|
||||
|
||||
|
||||
class MQPublishError(MQError):
|
||||
"""Raised when a message cannot be published."""
|
||||
@ -0,0 +1,19 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MQMessage:
|
||||
app: Any
|
||||
topic: str
|
||||
partition: int
|
||||
offset: int
|
||||
key: Any
|
||||
raw_key: bytes | None
|
||||
value: Any
|
||||
raw_value: bytes | None
|
||||
headers: dict[str, bytes | None]
|
||||
timestamp: tuple[int, int] | None
|
||||
raw_message: Any
|
||||
@ -0,0 +1,116 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MQProducerDefinition:
|
||||
name: str
|
||||
topic: str
|
||||
value_format: str = "json"
|
||||
config: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MQConsumerDefinition:
|
||||
name: str
|
||||
topics: tuple[str, ...]
|
||||
handler: Callable
|
||||
group_id: str | None = None
|
||||
value_format: str = "json"
|
||||
failure_backoff_seconds: float | None = None
|
||||
config: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MQRegistry:
|
||||
producers: dict[str, MQProducerDefinition] = field(default_factory=dict)
|
||||
consumers: dict[str, MQConsumerDefinition] = field(default_factory=dict)
|
||||
|
||||
def register_producer(
|
||||
self,
|
||||
*,
|
||||
name: str,
|
||||
topic: str,
|
||||
value_format: str = "json",
|
||||
config: dict[str, Any] | None = None,
|
||||
) -> MQProducerDefinition:
|
||||
if not name:
|
||||
raise ValueError("mq producer name is required")
|
||||
if not topic:
|
||||
raise ValueError("mq producer topic is required")
|
||||
if name in self.producers:
|
||||
raise ValueError(f"mq producer already registered: {name}")
|
||||
definition = MQProducerDefinition(
|
||||
name=name,
|
||||
topic=topic,
|
||||
value_format=value_format,
|
||||
config=dict(config or {}),
|
||||
)
|
||||
self.producers[name] = definition
|
||||
return definition
|
||||
|
||||
def register_consumer(
|
||||
self,
|
||||
*,
|
||||
name: str,
|
||||
topics: list[str] | tuple[str, ...] | str,
|
||||
handler: Callable,
|
||||
group_id: str | None = None,
|
||||
value_format: str = "json",
|
||||
failure_backoff_seconds: float | None = None,
|
||||
config: dict[str, Any] | None = None,
|
||||
) -> MQConsumerDefinition:
|
||||
if not name:
|
||||
raise ValueError("mq consumer name is required")
|
||||
topic_values = _normalize_topics(topics)
|
||||
if not topic_values:
|
||||
raise ValueError("mq consumer topics are required")
|
||||
if name in self.consumers:
|
||||
raise ValueError(f"mq consumer already registered: {name}")
|
||||
definition = MQConsumerDefinition(
|
||||
name=name,
|
||||
topics=topic_values,
|
||||
handler=handler,
|
||||
group_id=group_id,
|
||||
value_format=value_format,
|
||||
failure_backoff_seconds=failure_backoff_seconds,
|
||||
config=dict(config or {}),
|
||||
)
|
||||
self.consumers[name] = definition
|
||||
return definition
|
||||
|
||||
|
||||
def _normalize_topics(topics: list[str] | tuple[str, ...] | str) -> tuple[str, ...]:
|
||||
if isinstance(topics, str):
|
||||
topics = (topics,)
|
||||
return tuple(topic for topic in topics if topic)
|
||||
|
||||
|
||||
mq_registry = MQRegistry()
|
||||
|
||||
|
||||
def mq_consumer(
|
||||
*topics: str,
|
||||
name: str | None = None,
|
||||
group_id: str | None = None,
|
||||
value_format: str = "json",
|
||||
failure_backoff_seconds: float | None = None,
|
||||
config: dict[str, Any] | None = None,
|
||||
):
|
||||
def decorator(func: Callable) -> Callable:
|
||||
consumer_name = name or ".".join(topics) or func.__name__
|
||||
mq_registry.register_consumer(
|
||||
name=consumer_name,
|
||||
topics=topics,
|
||||
group_id=group_id,
|
||||
handler=func,
|
||||
value_format=value_format,
|
||||
failure_backoff_seconds=failure_backoff_seconds,
|
||||
config=config,
|
||||
)
|
||||
return func
|
||||
|
||||
return decorator
|
||||
@ -0,0 +1,181 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Coroutine
|
||||
from inspect import isawaitable
|
||||
from typing import Any
|
||||
|
||||
from .errors import MQConfigError
|
||||
from .message import MQMessage
|
||||
from .registry import MQConsumerDefinition, MQRegistry
|
||||
from .serialization import decode_message_key, decode_message_value
|
||||
|
||||
|
||||
logger = logging.getLogger("iti.mq")
|
||||
|
||||
|
||||
class MQConsumerRunner:
|
||||
def __init__(
|
||||
self,
|
||||
app,
|
||||
backend,
|
||||
registry: MQRegistry,
|
||||
*,
|
||||
group_id: str | None = None,
|
||||
failure_backoff_seconds: float = 1.0,
|
||||
poll_timeout_seconds: float = 1.0,
|
||||
) -> None:
|
||||
self.app = app
|
||||
self.backend = backend
|
||||
self.registry = registry
|
||||
self.group_id = group_id
|
||||
self.failure_backoff_seconds = failure_backoff_seconds
|
||||
self.poll_timeout_seconds = poll_timeout_seconds
|
||||
self._workers: list[_ConsumerWorker] = []
|
||||
|
||||
def start(self) -> None:
|
||||
if self._workers:
|
||||
return
|
||||
pending_workers: list[_ConsumerWorker] = []
|
||||
try:
|
||||
for definition in self.registry.consumers.values():
|
||||
group_id = definition.group_id or self.group_id
|
||||
if not group_id:
|
||||
raise MQConfigError(f"mq consumer {definition.name} missing group_id")
|
||||
consumer = self.backend.create_consumer(group_id, definition.config)
|
||||
consumer.subscribe(list(definition.topics))
|
||||
worker = _ConsumerWorker(
|
||||
app=self.app,
|
||||
consumer=consumer,
|
||||
definition=definition,
|
||||
failure_backoff_seconds=(
|
||||
definition.failure_backoff_seconds
|
||||
if definition.failure_backoff_seconds is not None
|
||||
else self.failure_backoff_seconds
|
||||
),
|
||||
poll_timeout_seconds=self.poll_timeout_seconds,
|
||||
)
|
||||
pending_workers.append(worker)
|
||||
except Exception:
|
||||
for worker in pending_workers:
|
||||
worker.close()
|
||||
raise
|
||||
self._workers = pending_workers
|
||||
for worker in self._workers:
|
||||
worker.start()
|
||||
|
||||
def stop(self) -> None:
|
||||
for worker in self._workers:
|
||||
worker.stop()
|
||||
for worker in self._workers:
|
||||
worker.join()
|
||||
self._workers.clear()
|
||||
|
||||
|
||||
class _ConsumerWorker:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
app,
|
||||
consumer: Any,
|
||||
definition: MQConsumerDefinition,
|
||||
failure_backoff_seconds: float,
|
||||
poll_timeout_seconds: float,
|
||||
) -> None:
|
||||
self.app = app
|
||||
self.consumer = consumer
|
||||
self.definition = definition
|
||||
self.failure_backoff_seconds = failure_backoff_seconds
|
||||
self.poll_timeout_seconds = poll_timeout_seconds
|
||||
self._stop = threading.Event()
|
||||
self._thread: threading.Thread | None = None
|
||||
|
||||
def start(self) -> None:
|
||||
if self._thread and self._thread.is_alive():
|
||||
return
|
||||
self._thread = threading.Thread(target=self._loop, daemon=True)
|
||||
self._thread.start()
|
||||
|
||||
def stop(self) -> None:
|
||||
self._stop.set()
|
||||
|
||||
def join(self) -> None:
|
||||
if self._thread:
|
||||
self._thread.join(timeout=3)
|
||||
self.close()
|
||||
|
||||
def close(self) -> None:
|
||||
self.consumer.close()
|
||||
|
||||
def _loop(self) -> None:
|
||||
while not self._stop.is_set():
|
||||
raw_message = self.consumer.poll(self.poll_timeout_seconds)
|
||||
if raw_message is None:
|
||||
continue
|
||||
if raw_message.error():
|
||||
logger.warning("mq consumer error: %s", raw_message.error())
|
||||
continue
|
||||
self._handle_raw_message(raw_message)
|
||||
|
||||
def _handle_raw_message(self, raw_message: Any) -> None:
|
||||
try:
|
||||
message = self._build_message(raw_message)
|
||||
result = self.definition.handler(message)
|
||||
if isawaitable(result):
|
||||
if not isinstance(result, Coroutine):
|
||||
raise TypeError("mq async handler must return a coroutine")
|
||||
asyncio.run(result)
|
||||
self.consumer.commit(raw_message, asynchronous=False)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"mq handler failed name=%s topic=%s partition=%s offset=%s",
|
||||
self.definition.name,
|
||||
_safe_call(raw_message, "topic"),
|
||||
_safe_call(raw_message, "partition"),
|
||||
_safe_call(raw_message, "offset"),
|
||||
)
|
||||
self.consumer.seek(_seek_position(raw_message))
|
||||
self._stop.wait(self.failure_backoff_seconds)
|
||||
|
||||
def _build_message(self, raw_message: Any) -> MQMessage:
|
||||
raw_key = raw_message.key()
|
||||
raw_value = raw_message.value()
|
||||
return MQMessage(
|
||||
app=self.app,
|
||||
topic=raw_message.topic(),
|
||||
partition=raw_message.partition(),
|
||||
offset=raw_message.offset(),
|
||||
key=decode_message_key(raw_key),
|
||||
raw_key=raw_key,
|
||||
value=decode_message_value(raw_value, self.definition.value_format),
|
||||
raw_value=raw_value,
|
||||
headers=_headers_to_dict(raw_message.headers()),
|
||||
timestamp=raw_message.timestamp(),
|
||||
raw_message=raw_message,
|
||||
)
|
||||
|
||||
|
||||
def _headers_to_dict(headers: list[tuple[str, bytes | None]] | None) -> dict[str, bytes | None]:
|
||||
return {key: value for key, value in headers or []}
|
||||
|
||||
|
||||
def _safe_call(value: Any, method: str) -> Any:
|
||||
try:
|
||||
return getattr(value, method)()
|
||||
except Exception:
|
||||
return "-"
|
||||
|
||||
|
||||
def _seek_position(raw_message: Any) -> Any:
|
||||
try:
|
||||
from confluent_kafka import TopicPartition
|
||||
except ImportError:
|
||||
return raw_message
|
||||
return TopicPartition(
|
||||
raw_message.topic(),
|
||||
raw_message.partition(),
|
||||
raw_message.offset(),
|
||||
)
|
||||
@ -0,0 +1,47 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from .errors import MQConfigError
|
||||
|
||||
|
||||
SUPPORTED_VALUE_FORMATS = {"json", "bytes"}
|
||||
|
||||
|
||||
def encode_message_value(value: Any, value_format: str = "json") -> bytes | None:
|
||||
_validate_value_format(value_format)
|
||||
if value is None:
|
||||
return None
|
||||
if value_format == "bytes":
|
||||
if isinstance(value, bytes):
|
||||
return value
|
||||
raise TypeError("bytes mq value must be bytes")
|
||||
return json.dumps(value, ensure_ascii=False, separators=(",", ":")).encode("utf-8")
|
||||
|
||||
|
||||
def decode_message_value(value: bytes | None, value_format: str = "json") -> Any:
|
||||
_validate_value_format(value_format)
|
||||
if value is None:
|
||||
return None
|
||||
if value_format == "bytes":
|
||||
return value
|
||||
return json.loads(value.decode("utf-8"))
|
||||
|
||||
|
||||
def encode_message_key(key: str | bytes | None) -> bytes | None:
|
||||
if key is None or isinstance(key, bytes):
|
||||
return key
|
||||
return str(key).encode("utf-8")
|
||||
|
||||
|
||||
def decode_message_key(key: bytes | None) -> str | None:
|
||||
if key is None:
|
||||
return None
|
||||
return key.decode("utf-8")
|
||||
|
||||
|
||||
def _validate_value_format(value_format: str) -> None:
|
||||
if value_format not in SUPPORTED_VALUE_FORMATS:
|
||||
supported = ", ".join(sorted(SUPPORTED_VALUE_FORMATS))
|
||||
raise MQConfigError(f"unsupported mq value_format: {value_format!r}, supported: {supported}")
|
||||
@ -0,0 +1,383 @@
|
||||
import asyncio
|
||||
import builtins
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
from iti import create_app
|
||||
from iti.config import BaseConfig
|
||||
from iti.mq import MQConfigError, init_mq, mq_client, mq_consumer
|
||||
from iti.mq.backend import KafkaBackend
|
||||
from iti.mq.registry import MQRegistry as RegistryClass
|
||||
|
||||
|
||||
class FakeProducer:
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.produced = []
|
||||
self.polled = []
|
||||
self.flushed = []
|
||||
|
||||
def poll(self, timeout):
|
||||
self.polled.append(timeout)
|
||||
|
||||
def produce(self, topic, *, value=None, key=None, headers=None):
|
||||
self.produced.append(
|
||||
{"topic": topic, "value": value, "key": key, "headers": headers}
|
||||
)
|
||||
|
||||
def flush(self, timeout=None):
|
||||
self.flushed.append(timeout)
|
||||
|
||||
|
||||
class FakeConsumer:
|
||||
def __init__(self, messages=None):
|
||||
self.messages = list(messages or [])
|
||||
self.subscribed = []
|
||||
self.committed = []
|
||||
self.sought = []
|
||||
self.closed = False
|
||||
self.config = None
|
||||
|
||||
def subscribe(self, topics):
|
||||
self.subscribed.append(topics)
|
||||
|
||||
def poll(self, timeout):
|
||||
if self.messages:
|
||||
return self.messages.pop(0)
|
||||
return None
|
||||
|
||||
def commit(self, message, asynchronous=False):
|
||||
self.committed.append((message, asynchronous))
|
||||
|
||||
def seek(self, position):
|
||||
if all(hasattr(position, name) for name in ("topic", "partition", "offset")):
|
||||
self.sought.append(
|
||||
(
|
||||
call_or_value(position, "topic"),
|
||||
call_or_value(position, "partition"),
|
||||
call_or_value(position, "offset"),
|
||||
)
|
||||
)
|
||||
return
|
||||
self.sought.append(position)
|
||||
|
||||
def close(self):
|
||||
self.closed = True
|
||||
|
||||
|
||||
class FakeMessage:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
topic="demo.topic",
|
||||
partition=0,
|
||||
offset=1,
|
||||
key=b"k1",
|
||||
value=b'{"ok":true}',
|
||||
headers=None,
|
||||
):
|
||||
self._topic = topic
|
||||
self._partition = partition
|
||||
self._offset = offset
|
||||
self._key = key
|
||||
self._value = value
|
||||
self._headers = headers or [("source", b"test")]
|
||||
|
||||
def topic(self):
|
||||
return self._topic
|
||||
|
||||
def partition(self):
|
||||
return self._partition
|
||||
|
||||
def offset(self):
|
||||
return self._offset
|
||||
|
||||
def key(self):
|
||||
return self._key
|
||||
|
||||
def value(self):
|
||||
return self._value
|
||||
|
||||
def headers(self):
|
||||
return self._headers
|
||||
|
||||
def timestamp(self):
|
||||
return (0, 0)
|
||||
|
||||
def error(self):
|
||||
return None
|
||||
|
||||
|
||||
def test_mq_registry_registers_decorator_and_explicit_consumer():
|
||||
registry = RegistryClass()
|
||||
registry.register_consumer(
|
||||
name="explicit",
|
||||
topics=["demo.explicit"],
|
||||
group_id="g1",
|
||||
handler=lambda message: None,
|
||||
)
|
||||
|
||||
assert registry.consumers["explicit"].topics == ("demo.explicit",)
|
||||
|
||||
before = set(mq_consumer_registry_names())
|
||||
|
||||
try:
|
||||
@mq_consumer("demo.decorated", name="decorated-test", group_id="g1")
|
||||
def decorated(message):
|
||||
return None
|
||||
|
||||
assert "decorated-test" in mq_consumer_registry_names() - before
|
||||
assert decorated.__name__ == "decorated"
|
||||
finally:
|
||||
from iti.mq import mq_registry
|
||||
|
||||
mq_registry.consumers.pop("decorated-test", None)
|
||||
|
||||
|
||||
def test_mq_registry_rejects_duplicate_names():
|
||||
registry = RegistryClass()
|
||||
registry.register_producer(name="events", topic="demo.events")
|
||||
registry.register_consumer(
|
||||
name="consumer",
|
||||
topics="demo.events",
|
||||
handler=lambda message: None,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="producer already registered"):
|
||||
registry.register_producer(name="events", topic="other")
|
||||
with pytest.raises(ValueError, match="consumer already registered"):
|
||||
registry.register_consumer(
|
||||
name="consumer",
|
||||
topics="other",
|
||||
handler=lambda message: None,
|
||||
)
|
||||
|
||||
|
||||
def test_mq_enabled_false_does_not_import_or_configure_kafka(monkeypatch):
|
||||
sys.modules.pop("confluent_kafka", None)
|
||||
|
||||
app = create_app(
|
||||
config_mapping=BaseConfig(
|
||||
database_url="sqlite+pysqlite:///:memory:",
|
||||
testing=True,
|
||||
mq_enabled=False,
|
||||
exchange_enabled=False,
|
||||
)
|
||||
)
|
||||
|
||||
assert not hasattr(app.state, "iti_mq_client")
|
||||
assert "confluent_kafka" not in sys.modules
|
||||
|
||||
|
||||
def test_mq_enabled_true_without_dependency_raises_install_hint(monkeypatch):
|
||||
sys.modules.pop("confluent_kafka", None)
|
||||
real_import = builtins.__import__
|
||||
|
||||
def missing_confluent_kafka(name, *args, **kwargs):
|
||||
if name == "confluent_kafka":
|
||||
raise ImportError("missing")
|
||||
return real_import(name, *args, **kwargs)
|
||||
|
||||
monkeypatch.setattr(builtins, "__import__", missing_confluent_kafka)
|
||||
|
||||
with pytest.raises(MQConfigError, match=r"iti-flask\[mq-kafka\]"):
|
||||
create_app(
|
||||
config_mapping=BaseConfig(
|
||||
database_url="sqlite+pysqlite:///:memory:",
|
||||
testing=True,
|
||||
mq_enabled=True,
|
||||
mq={"backend": "kafka", "bootstrap_servers": "127.0.0.1:9092"},
|
||||
exchange_enabled=False,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def test_mq_client_sends_json_bytes_and_registered_sender():
|
||||
producer = FakeProducer({"bootstrap.servers": "localhost:9092"})
|
||||
registry = RegistryClass()
|
||||
registry.register_producer(name="events", topic="demo.events")
|
||||
app = create_app(
|
||||
config_mapping=BaseConfig(
|
||||
database_url="sqlite+pysqlite:///:memory:",
|
||||
testing=True,
|
||||
mq_enabled=False,
|
||||
exchange_enabled=False,
|
||||
)
|
||||
)
|
||||
init_mq(
|
||||
app,
|
||||
{"backend": "kafka", "bootstrap_servers": "localhost:9092"},
|
||||
registry=registry,
|
||||
producer_factory=lambda config: producer,
|
||||
)
|
||||
|
||||
client = mq_client(app)
|
||||
client.send_json("demo.raw", {"id": "1"}, key="k1", headers={"h": "v"})
|
||||
client.send("demo.bytes", value=b"raw", key=b"k2")
|
||||
client.sender("events").send_json({"id": "2"})
|
||||
client.flush(2)
|
||||
|
||||
assert producer.produced[0] == {
|
||||
"topic": "demo.raw",
|
||||
"value": b'{"id":"1"}',
|
||||
"key": b"k1",
|
||||
"headers": [("h", b"v")],
|
||||
}
|
||||
assert producer.produced[1]["value"] == b"raw"
|
||||
assert producer.produced[1]["key"] == b"k2"
|
||||
assert producer.produced[2]["topic"] == "demo.events"
|
||||
assert producer.flushed == [2]
|
||||
|
||||
|
||||
def test_runner_commits_after_successful_sync_handler():
|
||||
handled = []
|
||||
message = FakeMessage()
|
||||
fake_consumer = FakeConsumer([message])
|
||||
app = make_mq_app(
|
||||
registry_with_consumer(lambda item: handled.append(item.value)),
|
||||
fake_consumer,
|
||||
)
|
||||
runner = app.state.iti_mq_runner
|
||||
|
||||
runner.start()
|
||||
wait_until(lambda: bool(fake_consumer.committed))
|
||||
runner.stop()
|
||||
|
||||
assert handled == [{"ok": True}]
|
||||
assert fake_consumer.committed == [(message, False)]
|
||||
assert fake_consumer.sought == []
|
||||
assert fake_consumer.closed is True
|
||||
|
||||
|
||||
def test_runner_executes_async_handler():
|
||||
handled = []
|
||||
|
||||
async def handler(message):
|
||||
await asyncio.sleep(0)
|
||||
handled.append(message.key)
|
||||
|
||||
fake_consumer = FakeConsumer([FakeMessage()])
|
||||
app = make_mq_app(registry_with_consumer(handler), fake_consumer)
|
||||
runner = app.state.iti_mq_runner
|
||||
|
||||
runner.start()
|
||||
wait_until(lambda: bool(fake_consumer.committed))
|
||||
runner.stop()
|
||||
|
||||
assert handled == ["k1"]
|
||||
|
||||
|
||||
def test_runner_does_not_commit_and_seeks_after_handler_failure():
|
||||
message = FakeMessage()
|
||||
fake_consumer = FakeConsumer([message])
|
||||
|
||||
def handler(_message):
|
||||
raise RuntimeError("boom")
|
||||
|
||||
app = make_mq_app(registry_with_consumer(handler), fake_consumer)
|
||||
runner = app.state.iti_mq_runner
|
||||
|
||||
runner.start()
|
||||
wait_until(lambda: bool(fake_consumer.sought))
|
||||
runner.stop()
|
||||
|
||||
assert fake_consumer.committed == []
|
||||
assert fake_consumer.sought == [("demo.topic", 0, 1)]
|
||||
|
||||
|
||||
def test_runner_raises_when_group_id_missing():
|
||||
registry = RegistryClass()
|
||||
registry.register_consumer(
|
||||
name="demo",
|
||||
topics="demo.topic",
|
||||
handler=lambda message: None,
|
||||
)
|
||||
app = make_mq_app(registry, FakeConsumer())
|
||||
|
||||
with pytest.raises(MQConfigError, match="missing group_id"):
|
||||
app.state.iti_mq_runner.start()
|
||||
|
||||
|
||||
def test_kafka_backend_forces_manual_commit():
|
||||
backend = KafkaBackend(
|
||||
{
|
||||
"bootstrap_servers": "localhost:9092",
|
||||
"group_id": "global",
|
||||
"auto_offset_reset": "latest",
|
||||
"consumer": {"enable.auto.commit": True},
|
||||
},
|
||||
producer_factory=lambda config: FakeProducer(config),
|
||||
consumer_factory=lambda config: FakeConsumer(),
|
||||
)
|
||||
|
||||
config = backend.consumer_config("g1")
|
||||
|
||||
assert config["bootstrap.servers"] == "localhost:9092"
|
||||
assert config["group.id"] == "g1"
|
||||
assert config["auto.offset.reset"] == "latest"
|
||||
assert config["enable.auto.commit"] is False
|
||||
|
||||
|
||||
def mq_consumer_registry_names():
|
||||
from iti.mq import mq_registry
|
||||
|
||||
return set(mq_registry.consumers.keys())
|
||||
|
||||
|
||||
def registry_with_consumer(handler):
|
||||
registry = RegistryClass()
|
||||
registry.register_consumer(
|
||||
name="demo",
|
||||
topics="demo.topic",
|
||||
group_id="g1",
|
||||
handler=handler,
|
||||
)
|
||||
return registry
|
||||
|
||||
|
||||
def make_mq_app(registry, fake_consumer):
|
||||
app = create_app(
|
||||
config_mapping=BaseConfig(
|
||||
database_url="sqlite+pysqlite:///:memory:",
|
||||
testing=True,
|
||||
mq_enabled=False,
|
||||
exchange_enabled=False,
|
||||
)
|
||||
)
|
||||
|
||||
def consumer_factory(config):
|
||||
fake_consumer.config = config
|
||||
return fake_consumer
|
||||
|
||||
init_mq(
|
||||
app,
|
||||
{
|
||||
"backend": "kafka",
|
||||
"bootstrap_servers": "localhost:9092",
|
||||
"failure_backoff_seconds": 0.01,
|
||||
"poll_timeout_seconds": 0.01,
|
||||
},
|
||||
registry=registry,
|
||||
producer_factory=lambda config: FakeProducer(config),
|
||||
consumer_factory=consumer_factory,
|
||||
)
|
||||
return app
|
||||
|
||||
|
||||
def wait_until(predicate, timeout=1.0):
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
deadline = loop.time() + timeout
|
||||
while loop.time() < deadline:
|
||||
if predicate():
|
||||
return
|
||||
loop.run_until_complete(asyncio.sleep(0.01))
|
||||
finally:
|
||||
loop.close()
|
||||
raise AssertionError("condition not reached")
|
||||
|
||||
|
||||
def call_or_value(value, name):
|
||||
attr = getattr(value, name)
|
||||
return attr() if callable(attr) else attr
|
||||
Loading…
Reference in New Issue