You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
iTi-Flask/iti/audit.py

258 lines
7.7 KiB
Python

from __future__ import annotations

import logging
import queue
import threading
from collections.abc import Callable, Iterable
from dataclasses import asdict, dataclass, field
from datetime import datetime, timezone
from functools import wraps
from inspect import isawaitable
from typing import Any

from fastapi import Request

from iti.auth import Actor
from iti.service_client import ServiceClientError, service_client
# Module logger under a fixed name (not derived from __name__).
logger = logging.getLogger("iti.audit")

# Dict key names whose values sanitize()/sanitize_value() replace with "***".
SENSITIVE_KEYS = {
    "authorization",
    "password",
    "refreshToken",
    "refresh_token",
    "secret",
    "token",
}
def _utc_now_iso() -> str:
    """Return the current UTC time as an ISO-8601 string."""
    return datetime.now(timezone.utc).isoformat()


@dataclass(frozen=True)
class AuditEvent:
    """Immutable record of a single audit event.

    Every optional field defaults to ``None`` and is left out of
    :meth:`payload`. ``occurred_at`` is stamped with the UTC time at which
    the event object was created.
    """

    type: str
    title: str
    success: bool = True
    # Who acted.
    actor_id: str | None = None
    actor_type: str | None = None
    # HTTP request context.
    method: str | None = None
    path: str | None = None
    ip: str | None = None
    user_agent: str | None = None
    # What was acted on, and how it changed.
    target_type: str | None = None
    target_id: str | None = None
    diff: dict[str, Any] | None = None
    # Free-form details.
    desc: str | None = None
    error: str | None = None
    trace_id: str | None = None
    occurred_at: str = field(default_factory=_utc_now_iso)

    def payload(self) -> dict[str, Any]:
        """Return the event as a plain dict, omitting every field whose value is ``None``."""
        snapshot = asdict(self)
        return {name: value for name, value in snapshot.items() if value is not None}
class AuditDispatcher:
    """Buffer audit events in memory and post them to the audit service in batches.

    Configuration is read from ``app.state.config`` (``audit_enabled``,
    ``audit_service_name``, ``audit_batch_size``,
    ``audit_flush_interval_seconds``, ``audit_queue_size``). :meth:`emit`
    enqueues without blocking. A daemon thread started by :meth:`start` drains
    the queue every ``flush_interval`` seconds and posts each batch to
    ``/internal/audit/events``. When auditing is disabled, :meth:`emit` and
    :meth:`start` do nothing.
    """

    def __init__(self, app) -> None:
        self.app = app
        config = app.state.config
        self.enabled = bool(config.audit_enabled)
        self.service_name = config.audit_service_name
        # At least one event per batch so every drain makes progress.
        self.batch_size = max(int(config.audit_batch_size), 1)
        self.flush_interval = float(config.audit_flush_interval_seconds)
        # queue.Queue treats maxsize <= 0 as unbounded.
        self._queue: queue.Queue[AuditEvent] = queue.Queue(maxsize=config.audit_queue_size)
        self._stop = threading.Event()
        self._thread: threading.Thread | None = None

    def start(self) -> None:
        """Start the background sender thread (no-op if disabled or already running)."""
        if not self.enabled:
            return
        if self._thread and self._thread.is_alive():
            return
        # Reset the flag left by a previous stop(); otherwise a restarted
        # worker would skip its loop and exit after a single drain.
        self._stop.clear()
        self._thread = threading.Thread(target=self._loop, name="audit-dispatcher", daemon=True)
        self._thread.start()

    def stop(self) -> None:
        """Ask the worker to finish and wait up to 3 seconds for its final flush."""
        self._stop.set()
        if self._thread:
            self._thread.join(timeout=3)

    def emit(self, event: AuditEvent) -> None:
        """Enqueue *event* without blocking and without raising into the caller.

        When the queue is full, the oldest queued event is discarded to make
        room for *event*.
        """
        if not self.enabled:
            return
        try:
            self._queue.put_nowait(event)
        except queue.Full:
            self._replace_oldest(event)

    def _replace_oldest(self, event: AuditEvent) -> None:
        """Drop the oldest queued event and enqueue *event* in its place."""
        try:
            self._queue.get_nowait()
        except queue.Empty:
            # The worker drained the queue meanwhile, so there is room now.
            pass
        try:
            self._queue.put_nowait(event)
        except queue.Full:
            # Another producer took the freed slot first. Drop this event
            # rather than let queue.Full escape into the request handler.
            logger.warning("audit queue full and event dropped")

    def _loop(self) -> None:
        """Worker body: flush periodically until stop() is called, then flush once more."""
        while not self._stop.is_set():
            self._flush()
            self._stop.wait(self.flush_interval)
        self._flush()

    def _flush(self) -> None:
        """Send queued events in ``batch_size`` chunks until the queue is empty."""
        # Sending only one batch per interval would cap throughput at
        # batch_size / flush_interval and lose the remainder on shutdown.
        while True:
            batch = self._drain()
            if not batch:
                return
            self._send(batch)

    def _drain(self) -> list[AuditEvent]:
        """Pop up to ``batch_size`` events without blocking."""
        batch: list[AuditEvent] = []
        for _ in range(self.batch_size):
            try:
                batch.append(self._queue.get_nowait())
            except queue.Empty:
                break
        return batch

    def _send(self, batch: list[AuditEvent]) -> None:
        """POST *batch* to the audit service; on failure, log and drop the batch."""
        try:
            client = service_client(self.app, self.service_name)
            client.post("/internal/audit/events", json={"events": [item.payload() for item in batch]})
        except ServiceClientError as exc:
            logger.warning("audit send failed: %s", exc)
        except Exception:
            # Any other error (e.g. an unserializable payload) must not kill
            # the worker thread, or auditing would silently stop for good.
            logger.exception("audit send failed unexpectedly")
def init_audit(app) -> AuditDispatcher:
    """Create the audit dispatcher for *app* and store it as ``app.state.audit_dispatcher``.

    The worker thread is not started here; call ``start()`` on the returned dispatcher.
    """
    app.state.audit_dispatcher = dispatcher = AuditDispatcher(app)
    return dispatcher
def audit_operation(
    request: Request,
    *,
    title: str,
    target_type: str | None = None,
    target_id: str | None = None,
    before: dict[str, Any] | None = None,
    after: dict[str, Any] | None = None,
    success: bool = True,
    desc: str | None = None,
    error: str | None = None,
) -> None:
    """Queue an ``operation`` audit event describing *request*.

    Nothing is recorded when no ``audit_dispatcher`` is registered on
    ``request.app.state``. The acting user comes from ``request.state.actor``
    and the trace id from ``request.state.trace_id``; both are optional. If
    *before* or *after* is supplied, the event carries
    ``build_diff(before, after)``; otherwise ``diff`` is left as ``None``.
    """
    dispatcher = getattr(request.app.state, "audit_dispatcher", None)
    if dispatcher is None:
        return

    actor = getattr(request.state, "actor", None)
    peer = request.client
    if before is None and after is None:
        changes = None
    else:
        changes = build_diff(before, after)

    event = AuditEvent(
        type="operation",
        title=title,
        success=success,
        actor_id=getattr(actor, "id", None),
        actor_type=getattr(actor, "type", None),
        method=request.method,
        path=request.url.path,
        ip=peer.host if peer else None,
        user_agent=request.headers.get("user-agent"),
        target_type=target_type,
        target_id=target_id,
        diff=changes,
        desc=desc,
        error=error,
        trace_id=getattr(request.state, "trace_id", None),
    )
    dispatcher.emit(event)
def operation_log(
    title: str,
    *,
    target_type: str | None = None,
) -> Callable:
    """Decorator factory that audits every call of the wrapped function as an ``operation``.

    The first ``Request`` found among the call's arguments is passed to
    :func:`audit_operation`; calls without a ``Request`` are not audited. A
    successful call is recorded with ``success=True``. A raised exception is
    recorded with ``success=False`` and its message, then re-raised unchanged.

    The returned wrapper is always a coroutine function: a sync *func* is
    called inline inside the coroutine, and awaitable results are awaited.

    Args:
        title: Event title passed to ``audit_operation``.
        target_type: Optional target type passed to ``audit_operation``.
    """

    def _record(request: Request | None, **fields: Any) -> None:
        # Auditing is best-effort: a failure here must neither turn a
        # successful call into an error nor mask the function's own exception.
        if request is None:
            return
        try:
            audit_operation(request, title=title, target_type=target_type, **fields)
        except Exception:
            logger.exception("failed to record audit event %r", title)

    def decorator(func: Callable) -> Callable:
        @wraps(func)
        async def async_wrapper(*args, **kwargs):
            request = _find_request(args, kwargs)
            try:
                result = func(*args, **kwargs)
                if isawaitable(result):
                    result = await result
            except Exception as exc:
                # str() of some exceptions is empty; fall back to the class name.
                _record(request, success=False, error=str(exc) or type(exc).__name__)
                raise
            # Recorded outside the try so that a problem while auditing is
            # never reported as a failure of the operation itself.
            _record(request)
            return result

        return async_wrapper

    return decorator
def _find_request(args: tuple[Any, ...], kwargs: dict[str, Any]) -> Request | None:
for value in list(args) + list(kwargs.values()):
if isinstance(value, Request):
return value
return None
def audit_login(
    request: Request,
    *,
    title: str = "登录",
    actor: Actor | None = None,
    success: bool = True,
    desc: str | None = None,
    error: str | None = None,
) -> None:
    """Queue a ``login`` audit event for *request*.

    The default *title* means "login". When *actor* is not supplied (or is
    falsy), ``request.state.actor`` is used if present. Nothing is recorded
    when no ``audit_dispatcher`` is registered on ``request.app.state``.
    """
    dispatcher = getattr(request.app.state, "audit_dispatcher", None)
    if dispatcher is None:
        return

    if not actor:
        actor = getattr(request.state, "actor", None)
    peer = request.client

    event = AuditEvent(
        type="login",
        title=title,
        success=success,
        actor_id=getattr(actor, "id", None),
        actor_type=getattr(actor, "type", None),
        method=request.method,
        path=request.url.path,
        ip=peer.host if peer else None,
        user_agent=request.headers.get("user-agent"),
        desc=desc,
        error=error,
        trace_id=getattr(request.state, "trace_id", None),
    )
    dispatcher.emit(event)
def build_diff(before: dict[str, Any] | None, after: dict[str, Any] | None) -> dict[str, Any]:
    """Compare two snapshots and report every key whose value differs.

    ``None`` snapshots are treated as empty dicts. The result maps each changed
    key, in sorted key order, to ``{"before": old, "after": new}``. A key
    missing on one side reads as ``None`` there, so a key that is absent on one
    side and ``None`` on the other is not reported. Both values are passed
    through ``sanitize_value`` before being stored.
    """
    old_state = before or {}
    new_state = after or {}
    all_keys = sorted(set(old_state) | set(new_state))
    return {
        key: {
            "before": sanitize_value(key, old_state.get(key)),
            "after": sanitize_value(key, new_state.get(key)),
        }
        for key in all_keys
        if old_state.get(key) != new_state.get(key)
    }
def _sensitive_names(sensitive_keys: Iterable[str] | None) -> frozenset[str]:
    """Return the lower-cased key names to mask (``SENSITIVE_KEYS`` when *sensitive_keys* is None)."""
    names = SENSITIVE_KEYS if sensitive_keys is None else sensitive_keys
    return frozenset(str(name).lower() for name in names)


def _mask(value: Any, names: frozenset[str]) -> Any:
    """Recursively copy *value*, replacing dict values whose lower-cased key is in *names* with ``"***"``."""
    if isinstance(value, dict):
        return {
            key: "***" if str(key).lower() in names else _mask(item, names)
            for key, item in value.items()
        }
    if isinstance(value, list):
        return [_mask(item, names) for item in value]
    if isinstance(value, tuple):
        # Walk tuples as well, so secrets nested inside them are masked too.
        return tuple(_mask(item, names) for item in value)
    return value


def sanitize(value: Any, *, sensitive_keys: Iterable[str] | None = None) -> Any:
    """Return a copy of *value* with sensitive dict entries masked as ``"***"``.

    Dicts, lists and tuples are walked recursively; dicts and lists come back
    as plain ``dict``/``list`` copies, tuples as plain ``tuple``. Any other
    value is returned unchanged. Keys are compared case-insensitively (after
    ``str()``), so ``"Password"`` and ``"AUTHORIZATION"`` are masked as well.

    Args:
        value: Data to sanitize.
        sensitive_keys: Key names to mask. Defaults to ``SENSITIVE_KEYS``.
    """
    return _mask(value, _sensitive_names(sensitive_keys))


def sanitize_value(key: str, value: Any, *, sensitive_keys: Iterable[str] | None = None) -> Any:
    """Return ``"***"`` if *key* is sensitive, otherwise ``sanitize(value)``.

    *key* is compared case-insensitively (after ``str()``) against
    *sensitive_keys*, which defaults to ``SENSITIVE_KEYS``.
    """
    names = _sensitive_names(sensitive_keys)
    if str(key).lower() in names:
        return "***"
    return _mask(value, names)