You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
258 lines
7.7 KiB
Python
258 lines
7.7 KiB
Python
from __future__ import annotations
|
|
|
|
import logging
|
|
import queue
|
|
import threading
|
|
from collections.abc import Callable
|
|
from dataclasses import asdict, dataclass, field
|
|
from datetime import datetime, timezone
|
|
from functools import wraps
|
|
from inspect import isawaitable
|
|
from typing import Any
|
|
|
|
from fastapi import Request
|
|
|
|
from iti.auth import Actor
|
|
from iti.service_client import ServiceClientError, service_client
|
|
|
|
|
|
logger = logging.getLogger("iti.audit")

# Dictionary keys whose values are replaced with "***" by sanitize() /
# sanitize_value() when audit diffs are built.
SENSITIVE_KEYS = {
    "authorization",
    "password",
    "refreshToken",
    "refresh_token",
    "secret",
    "token",
}
|
|
|
|
|
|
def _utc_now_iso() -> str:
    """Return the current time in UTC as an ISO-8601 string."""
    return datetime.now(timezone.utc).isoformat()


@dataclass(frozen=True)
class AuditEvent:
    """Immutable record of a single auditable action.

    Instances are queued by the dispatcher and serialized with :meth:`payload`
    before being posted to the audit service.
    """

    # Event category; this module emits "operation" and "login".
    type: str
    # Human-readable title of the action.
    title: str
    # Whether the audited action succeeded.
    success: bool = True
    # Identity of the acting principal, when one is known.
    actor_id: str | None = None
    actor_type: str | None = None
    # HTTP request context.
    method: str | None = None
    path: str | None = None
    ip: str | None = None
    user_agent: str | None = None
    # Object the action was applied to.
    target_type: str | None = None
    target_id: str | None = None
    # Per-key before/after changes (see build_diff).
    diff: dict[str, Any] | None = None
    desc: str | None = None
    error: str | None = None
    trace_id: str | None = None
    # Creation time of the event, ISO-8601 in UTC.
    occurred_at: str = field(default_factory=_utc_now_iso)

    def payload(self) -> dict[str, Any]:
        """Return the event as a plain dict, leaving out fields that are ``None``."""
        fields_by_name = asdict(self)
        return {name: val for name, val in fields_by_name.items() if val is not None}
|
|
|
|
|
|
class AuditDispatcher:
    """Buffer audit events in memory and deliver them to the audit service in batches.

    :meth:`emit` enqueues without blocking; a background daemon thread started by
    :meth:`start` drains the queue and posts batches. Delivery is best-effort:
    when the queue is full the oldest buffered event is discarded, and send
    failures are logged rather than retried.
    """

    def __init__(self, app) -> None:
        self.app = app
        config = app.state.config
        self.enabled = bool(config.audit_enabled)
        self.service_name = config.audit_service_name
        # A batch must hold at least one event, otherwise nothing would ever be sent.
        self.batch_size = max(int(config.audit_batch_size), 1)
        self.flush_interval = float(config.audit_flush_interval_seconds)
        # queue.Queue treats maxsize <= 0 as unbounded.
        self._queue: queue.Queue[AuditEvent] = queue.Queue(maxsize=config.audit_queue_size)
        self._stop = threading.Event()
        self._thread: threading.Thread | None = None

    def start(self) -> None:
        """Start the background sender thread (no-op if disabled or already running)."""
        if not self.enabled:
            return
        if self._thread and self._thread.is_alive():
            return
        # Clear a flag left over from a previous stop() so the new loop actually runs.
        self._stop.clear()
        self._thread = threading.Thread(target=self._loop, name="audit-dispatcher", daemon=True)
        self._thread.start()

    def stop(self) -> None:
        """Ask the sender thread to flush remaining events and exit; wait up to 3 seconds."""
        self._stop.set()
        if self._thread:
            self._thread.join(timeout=3)

    def emit(self, event: AuditEvent) -> None:
        """Enqueue *event* without blocking; never raises because the queue is full.

        When the queue is full the oldest buffered event is discarded to make room.
        """
        if not self.enabled:
            return
        try:
            self._queue.put_nowait(event)
        except queue.Full:
            pass
        else:
            return
        # Queue is full: discard the oldest event, then retry once.
        try:
            self._queue.get_nowait()
        except queue.Empty:
            # The sender drained the queue concurrently; nothing needs discarding.
            pass
        else:
            logger.warning("audit queue full; oldest event dropped")
        try:
            self._queue.put_nowait(event)
        except queue.Full:
            # Other producers refilled the queue in between; give up on this event.
            logger.warning("audit queue full and event dropped")

    def _loop(self) -> None:
        """Thread body: flush every ``flush_interval`` seconds until stopped, then flush once more."""
        while not self._stop.is_set():
            self._flush()
            self._stop.wait(self.flush_interval)
        # Deliver whatever is still buffered before the thread exits.
        self._flush()

    def _flush(self) -> None:
        """Send queued events batch by batch until a short (or empty) batch is drained."""
        while True:
            batch = self._drain()
            if not batch:
                return
            self._send(batch)
            # A short batch means the queue was emptied; wait for the next tick.
            if len(batch) < self.batch_size:
                return

    def _drain(self) -> list[AuditEvent]:
        """Pop up to ``batch_size`` events from the queue without blocking."""
        batch: list[AuditEvent] = []
        for _ in range(self.batch_size):
            try:
                batch.append(self._queue.get_nowait())
            except queue.Empty:
                break
        return batch

    def _send(self, batch: list[AuditEvent]) -> None:
        """Post *batch* to the audit service; failures are logged, never raised."""
        try:
            client = service_client(self.app, self.service_name)
            client.post("/internal/audit/events", json={"events": [item.payload() for item in batch]})
        except ServiceClientError as exc:
            logger.warning("audit send failed: %s", exc)
        except Exception:
            # This is the sender thread's top-level boundary: an unexpected error
            # (e.g. a non-JSON-serializable diff) must not kill the thread, or
            # auditing would silently stop for the rest of the process lifetime.
            logger.exception("audit send failed unexpectedly")
|
|
|
|
|
|
def init_audit(app) -> AuditDispatcher:
    """Create an :class:`AuditDispatcher` for *app* and store it on ``app.state``.

    The dispatcher is registered as ``app.state.audit_dispatcher``, where
    audit_operation() and audit_login() look it up. This function does not call
    ``start()``; the caller is responsible for starting and stopping it.

    Returns:
        The newly created dispatcher.
    """
    app.state.audit_dispatcher = dispatcher = AuditDispatcher(app)
    return dispatcher
|
|
|
|
|
|
def audit_operation(
    request: Request,
    *,
    title: str,
    target_type: str | None = None,
    target_id: str | None = None,
    before: dict[str, Any] | None = None,
    after: dict[str, Any] | None = None,
    success: bool = True,
    desc: str | None = None,
    error: str | None = None,
) -> None:
    """Emit an ``operation`` audit event describing *request*.

    Does nothing when no dispatcher is registered on ``request.app.state``.
    The actor is read from ``request.state.actor`` and the trace id from
    ``request.state.trace_id`` (both optional). A diff is attached only when
    at least one of *before* / *after* is given.
    """
    dispatcher = getattr(request.app.state, "audit_dispatcher", None)
    if dispatcher is None:
        return

    actor = getattr(request.state, "actor", None)

    diff = None
    if not (before is None and after is None):
        diff = build_diff(before, after)

    client = request.client
    event = AuditEvent(
        type="operation",
        title=title,
        success=success,
        actor_id=getattr(actor, "id", None),
        actor_type=getattr(actor, "type", None),
        method=request.method,
        path=request.url.path,
        ip=client.host if client else None,
        user_agent=request.headers.get("user-agent"),
        target_type=target_type,
        target_id=target_id,
        diff=diff,
        desc=desc,
        error=error,
        trace_id=getattr(request.state, "trace_id", None),
    )
    dispatcher.emit(event)
|
|
|
|
|
|
def operation_log(
    title: str,
    *,
    target_type: str | None = None,
) -> Callable:
    """Return a decorator that records an ``operation`` audit event per call.

    The wrapper looks for a ``Request`` among the call's arguments; when none is
    found, no event is recorded. A successful call records a success event, and
    a call that raises records a failure event (with ``str(exc)`` as the error)
    and then re-raises the original exception. Auditing is best-effort: an error
    while recording is logged and never changes the call's outcome.

    Note: the returned wrapper is always a coroutine function, so the decorated
    callable must be awaited even when *func* is synchronous.

    Args:
        title: Title stored on every emitted event.
        target_type: Optional target type stored on every emitted event.
    """

    def decorator(func: Callable) -> Callable:
        def _record(request: Request | None, **outcome: Any) -> None:
            # Best-effort: a failing audit must never turn a successful call into
            # an error, nor mask the exception the wrapped function raised.
            if request is None:
                return
            try:
                audit_operation(request, title=title, target_type=target_type, **outcome)
            except Exception:
                logger.exception("failed to record audit event %r", title)

        @wraps(func)
        async def async_wrapper(*args, **kwargs):
            request = _find_request(args, kwargs)
            try:
                result = func(*args, **kwargs)
                if isawaitable(result):
                    result = await result
            except Exception as exc:
                _record(request, success=False, error=str(exc))
                raise
            # Recorded outside the try so an audit problem is never reported as
            # (or turned into) a failure of the wrapped function.
            _record(request)
            return result

        return async_wrapper

    return decorator
|
|
|
|
|
|
def _find_request(args: tuple[Any, ...], kwargs: dict[str, Any]) -> Request | None:
|
|
for value in list(args) + list(kwargs.values()):
|
|
if isinstance(value, Request):
|
|
return value
|
|
return None
|
|
|
|
|
|
def audit_login(
    request: Request,
    *,
    title: str = "登录",
    actor: Actor | None = None,
    success: bool = True,
    desc: str | None = None,
    error: str | None = None,
) -> None:
    """Emit a ``login`` audit event describing *request*.

    Does nothing when no dispatcher is registered on ``request.app.state``.
    When *actor* is not supplied (or is falsy), ``request.state.actor`` is used
    if present. The default *title* is a localized "login" label.
    """
    dispatcher = getattr(request.app.state, "audit_dispatcher", None)
    if dispatcher is None:
        return

    who = actor if actor else getattr(request.state, "actor", None)
    client = request.client
    event = AuditEvent(
        type="login",
        title=title,
        success=success,
        actor_id=getattr(who, "id", None),
        actor_type=getattr(who, "type", None),
        method=request.method,
        path=request.url.path,
        ip=client.host if client else None,
        user_agent=request.headers.get("user-agent"),
        desc=desc,
        error=error,
        trace_id=getattr(request.state, "trace_id", None),
    )
    dispatcher.emit(event)
|
|
|
|
|
|
def build_diff(before: dict[str, Any] | None, after: dict[str, Any] | None) -> dict[str, Any]:
    """Return the per-key changes between two snapshots, with sensitive values masked.

    Args:
        before: State before the change; ``None`` is treated as empty.
        after: State after the change; ``None`` is treated as empty.

    Returns:
        A mapping from each key whose value differs (a missing key counts as
        ``None``) to ``{"before": old, "after": new}``, where both sides have
        been passed through sanitize_value(). Keys appear in sorted order.
    """
    before = before or {}
    after = after or {}
    changes: dict[str, Any] = {}
    # key=str keeps the ordering deterministic without raising TypeError when the
    # snapshots mix key types (e.g. int and str); all-str keys sort as before.
    for key in sorted(set(before) | set(after), key=str):
        old = before.get(key)
        new = after.get(key)
        if old != new:
            changes[key] = {"before": sanitize_value(key, old), "after": sanitize_value(key, new)}
    return changes
|
|
|
|
|
|
def _is_sensitive_key(key: Any) -> bool:
    """Return True if *key* matches an entry of SENSITIVE_KEYS, ignoring case."""
    folded = str(key).casefold()
    return any(folded == candidate.casefold() for candidate in SENSITIVE_KEYS)


def sanitize(value: Any) -> Any:
    """Return a copy of *value* with values stored under sensitive keys replaced by ``"***"``.

    Dicts are rebuilt with sensitive keys masked and other values sanitized
    recursively; lists and tuples are sanitized element-wise (keeping their
    container type); any other value is returned unchanged. Key matching is
    case-insensitive, so e.g. ``"Authorization"`` and ``"Password"`` are masked.
    """
    if isinstance(value, dict):
        return {key: "***" if _is_sensitive_key(key) else sanitize(item) for key, item in value.items()}
    if isinstance(value, list):
        return [sanitize(item) for item in value]
    if isinstance(value, tuple):
        # Tuples can also carry nested dicts with secrets.
        return tuple(sanitize(item) for item in value)
    return value


def sanitize_value(key: str, value: Any) -> Any:
    """Return ``"***"`` if *key* is sensitive (case-insensitive), else ``sanitize(value)``."""
    if _is_sensitive_key(key):
        return "***"
    return sanitize(value)
|