from __future__ import annotations

import logging
import queue
import threading
from collections.abc import Callable
from dataclasses import asdict, dataclass, field
from datetime import datetime, timezone
from functools import wraps
from inspect import isawaitable
from typing import Any

from fastapi import Request

from iti.auth import Actor
from iti.service_client import ServiceClientError, service_client

logger = logging.getLogger("iti.audit")

# Field names whose values are masked before an audit diff is emitted.
# Matching against these names is case-insensitive (see _is_sensitive_key).
SENSITIVE_KEYS = {"password", "token", "authorization", "secret", "refreshToken", "refresh_token"}

# Placeholder written in place of a sensitive value.
_MASK = "***"


@dataclass(frozen=True)
class AuditEvent:
    """A single audit record queued for delivery to the audit service.

    ``occurred_at`` defaults to the creation time as an ISO-8601 string in UTC.
    """

    type: str
    title: str
    success: bool = True
    actor_id: str | None = None
    actor_type: str | None = None
    method: str | None = None
    path: str | None = None
    ip: str | None = None
    user_agent: str | None = None
    target_type: str | None = None
    target_id: str | None = None
    diff: dict[str, Any] | None = None
    desc: str | None = None
    error: str | None = None
    trace_id: str | None = None
    occurred_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())

    def payload(self) -> dict[str, Any]:
        """Return the event as a dict, omitting every field whose value is ``None``."""
        return {key: value for key, value in asdict(self).items() if value is not None}


class AuditDispatcher:
    """Buffer audit events in memory and ship them in batches from a worker thread.

    Batches are POSTed to ``/internal/audit/events`` on the service named by
    ``config.audit_service_name``. The in-memory queue is bounded by
    ``config.audit_queue_size``; when it is full the oldest queued event is
    evicted so the newest one can be kept.
    """

    def __init__(self, app: Any) -> None:
        self.app = app
        config = app.state.config
        self.enabled = bool(config.audit_enabled)
        self.service_name = config.audit_service_name
        # At least one event per batch, whatever the config says.
        self.batch_size = max(int(config.audit_batch_size), 1)
        self.flush_interval = float(config.audit_flush_interval_seconds)
        self._queue: queue.Queue[AuditEvent] = queue.Queue(maxsize=config.audit_queue_size)
        self._stop = threading.Event()
        self._thread: threading.Thread | None = None

    def start(self) -> None:
        """Start the background sender thread.

        Does nothing when auditing is disabled or the thread is already running.
        """
        if not self.enabled:
            return
        if self._thread is not None and self._thread.is_alive():
            return
        # Reset the stop flag so the dispatcher can be restarted after stop();
        # otherwise a new worker would exit right after its first flush.
        self._stop.clear()
        self._thread = threading.Thread(target=self._loop, name="audit-dispatcher", daemon=True)
        self._thread.start()

    def stop(self) -> None:
        """Ask the worker to flush what is queued and wait up to 3 seconds for it to exit."""
        self._stop.set()
        if self._thread is not None:
            self._thread.join(timeout=3)

    def emit(self, event: AuditEvent) -> None:
        """Queue *event* for delivery without blocking the caller.

        If the queue is full, the oldest queued event is discarded to make room.
        The new event itself is dropped (with a warning) only if the queue is
        still full after that, e.g. because other producers refilled it first.
        Never raises ``queue.Full``/``queue.Empty`` to the caller.
        """
        if not self.enabled:
            return
        try:
            self._queue.put_nowait(event)
            return
        except queue.Full:
            pass
        try:
            # Evict the oldest event so the most recent activity is preserved.
            self._queue.get_nowait()
        except queue.Empty:
            # The worker drained the queue concurrently; there is room now.
            pass
        try:
            self._queue.put_nowait(event)
        except queue.Full:
            logger.warning("audit queue full and event dropped")

    def _loop(self) -> None:
        """Worker body: flush the queue every ``flush_interval`` seconds until stopped."""
        while not self._stop.is_set():
            self._flush()
            # Returns early as soon as stop() sets the event.
            self._stop.wait(self.flush_interval)
        # Final flush so events emitted just before shutdown are not lost.
        self._flush()

    def _flush(self) -> None:
        """Send queued events in ``batch_size`` chunks until the queue runs dry."""
        while True:
            batch = self._drain()
            if not batch:
                return
            self._send(batch)
            if len(batch) < self.batch_size:
                # A short batch means the queue was emptied.
                return

    def _drain(self) -> list[AuditEvent]:
        """Pop up to ``batch_size`` events from the queue without blocking."""
        batch: list[AuditEvent] = []
        for _ in range(self.batch_size):
            try:
                batch.append(self._queue.get_nowait())
            except queue.Empty:
                break
        return batch

    def _send(self, batch: list[AuditEvent]) -> None:
        """POST one batch to the audit service; failures are logged, never raised."""
        try:
            client = service_client(self.app, self.service_name)
            client.post("/internal/audit/events", json={"events": [item.payload() for item in batch]})
        except ServiceClientError as exc:
            logger.warning("audit send failed: %s", exc)
        except Exception:
            # Any other error (e.g. an unserializable diff) would otherwise
            # kill the worker thread and silently stop all auditing.
            logger.exception("audit send failed unexpectedly")


def init_audit(app) -> AuditDispatcher:
    """Create an ``AuditDispatcher`` and store it as ``app.state.audit_dispatcher``.

    The dispatcher is not started here; call ``start()`` on it separately.
    """
    dispatcher = AuditDispatcher(app)
    app.state.audit_dispatcher = dispatcher
    return dispatcher


def _request_fields(request: Request, actor: Any) -> dict[str, Any]:
    """Collect the actor/request fields shared by every request-bound audit event."""
    return {
        "actor_id": getattr(actor, "id", None),
        "actor_type": getattr(actor, "type", None),
        "method": request.method,
        "path": request.url.path,
        "ip": request.client.host if request.client else None,
        "user_agent": request.headers.get("user-agent"),
        "trace_id": getattr(request.state, "trace_id", None),
    }


def audit_operation(
    request: Request,
    *,
    title: str,
    target_type: str | None = None,
    target_id: str | None = None,
    before: dict[str, Any] | None = None,
    after: dict[str, Any] | None = None,
    success: bool = True,
    desc: str | None = None,
    error: str | None = None,
) -> None:
    """Emit an ``"operation"`` audit event for *request*.

    The actor is read from ``request.state.actor``. A field-level diff is
    attached only when *before* or *after* is given. Does nothing when no
    dispatcher has been registered on the app.
    """
    dispatcher = getattr(request.app.state, "audit_dispatcher", None)
    if dispatcher is None:
        return
    actor = getattr(request.state, "actor", None)
    diff = build_diff(before, after) if before is not None or after is not None else None
    dispatcher.emit(
        AuditEvent(
            type="operation",
            title=title,
            success=success,
            target_type=target_type,
            target_id=target_id,
            diff=diff,
            desc=desc,
            error=error,
            **_request_fields(request, actor),
        )
    )


def operation_log(
    title: str,
    *,
    target_type: str | None = None,
) -> Callable:
    """Decorate an endpoint so each call is recorded as an ``"operation"`` audit event.

    The ``Request`` is located among the call arguments. If none is found,
    the call still runs but nothing is audited. An exception raised by the
    endpoint is audited with ``success=False`` and then re-raised.

    NOTE(review): the wrapper is always ``async``, so a decorated sync function
    becomes a coroutine function. A framework that dispatches sync handlers to
    a threadpool will likely run it on the event loop instead. This is kept
    as-is for backward compatibility.
    """

    def decorator(func: Callable) -> Callable:
        @wraps(func)
        async def async_wrapper(*args, **kwargs):
            request = _find_request(args, kwargs)
            try:
                result = func(*args, **kwargs)
                # Support both sync and async wrapped callables.
                if isawaitable(result):
                    result = await result
            except Exception as exc:
                if request is not None:
                    audit_operation(
                        request,
                        title=title,
                        target_type=target_type,
                        success=False,
                        error=str(exc),
                    )
                raise
            # Audit success outside the try so an auditing error is never
            # misreported as a failure of the operation itself.
            if request is not None:
                audit_operation(
                    request,
                    title=title,
                    target_type=target_type,
                )
            return result

        return async_wrapper

    return decorator


def _find_request(args: tuple[Any, ...], kwargs: dict[str, Any]) -> Request | None:
    """Return the first ``Request`` among positional then keyword arguments, if any."""
    for value in (*args, *kwargs.values()):
        if isinstance(value, Request):
            return value
    return None


def audit_login(
    request: Request,
    *,
    title: str = "登录",
    actor: Actor | None = None,
    success: bool = True,
    desc: str | None = None,
    error: str | None = None,
) -> None:
    """Emit a ``"login"`` audit event for *request*.

    *actor* overrides ``request.state.actor``, which is typically not yet set
    during login. Does nothing when no dispatcher has been registered on the app.
    """
    dispatcher = getattr(request.app.state, "audit_dispatcher", None)
    if dispatcher is None:
        return
    actor = actor or getattr(request.state, "actor", None)
    dispatcher.emit(
        AuditEvent(
            type="login",
            title=title,
            success=success,
            desc=desc,
            error=error,
            **_request_fields(request, actor),
        )
    )


def build_diff(before: dict[str, Any] | None, after: dict[str, Any] | None) -> dict[str, Any]:
    """Return ``{key: {"before": old, "after": new}}`` for every key whose value changed.

    ``None`` inputs are treated as empty dicts, and a missing key compares as
    ``None``. Values are sanitized so secrets never reach the audit log.
    """
    before = before or {}
    after = after or {}
    changes: dict[str, Any] = {}
    for key in sorted(set(before) | set(after)):
        old = before.get(key)
        new = after.get(key)
        if old != new:
            changes[key] = {"before": sanitize_value(key, old), "after": sanitize_value(key, new)}
    return changes


def _is_sensitive_key(key: Any) -> bool:
    """Return True if *key* matches an entry of ``SENSITIVE_KEYS``, ignoring case."""
    lowered = str(key).lower()
    return any(lowered == name.lower() for name in SENSITIVE_KEYS)


def sanitize(value: Any) -> Any:
    """Recursively copy *value*, masking dict entries whose key is sensitive.

    Dicts, lists and tuples are walked recursively; other values are returned
    unchanged.
    """
    if isinstance(value, dict):
        return {key: _MASK if _is_sensitive_key(key) else sanitize(item) for key, item in value.items()}
    if isinstance(value, list):
        return [sanitize(item) for item in value]
    if isinstance(value, tuple):
        return tuple(sanitize(item) for item in value)
    return value


def sanitize_value(key: str, value: Any) -> Any:
    """Mask *value* entirely if *key* is sensitive; otherwise sanitize it recursively."""
    if _is_sensitive_key(key):
        return _MASK
    return sanitize(value)