from __future__ import annotations import logging import time import uuid from collections.abc import Iterable, Mapping from contextvars import ContextVar from contextlib import asynccontextmanager from dataclasses import asdict, is_dataclass from functools import wraps from inspect import isawaitable from typing import Any from fastapi import FastAPI, Request from fastapi.dependencies.utils import ( _should_embed_body_fields, get_body_field, get_dependant, get_flat_dependant, get_parameterless_sub_dependant, ) from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware from fastapi.routing import APIRoute from fastapi.routing import request_response from fastapi.responses import JSONResponse from pydantic import BaseModel from starlette.responses import Response from sqlalchemy.exc import SQLAlchemyError from iti.auth.permissions import StaticPermissionProvider from iti.audit import init_audit from iti.cache import CacheManager from iti.config import BaseConfig, get_config from iti.db import configure_db from iti.exceptions import ItiError from iti.health import router as health_router from iti.limiter import SimpleLimiter from iti.logging_config import configure_logging, log_extra from iti.modules import init_modules from iti.responses.auto import is_envelope_payload, is_raw_response_request from iti.responses import fail from iti.service_client import init_service_clients from iti.tasks import init_task_runner logger = logging.getLogger("iti") error_logger = logging.getLogger("iti.error") _current_request: ContextVar[Request | None] = ContextVar("iti_current_request", default=None) def create_app( config_name: str | None = None, *, modules: Iterable[Any] | None = None, config_mapping: Mapping[str, type[BaseConfig] | BaseConfig] | type[BaseConfig] | BaseConfig | None = None, permission_provider: Any | None = None, ) -> FastAPI: config = _resolve_config(config_name, config_mapping) configure_logging(config) @asynccontextmanager async def lifespan(app: FastAPI): runner = getattr(app.state, "iti_task_runner", None) audit_dispatcher = getattr(app.state, "audit_dispatcher", None) if audit_dispatcher: audit_dispatcher.start() if runner and config.tasks_enabled: runner.start() yield if runner: runner.stop() if audit_dispatcher: audit_dispatcher.stop() for client in getattr(app.state, "iti_service_clients", {}).values(): client.close() app = FastAPI( title=config.app_name, debug=config.debug, lifespan=lifespan, ) app.state.config = config app.state.cache = CacheManager(default_timeout=config.cache_default_timeout) app.state.limiter = SimpleLimiter(enabled=config.ratelimit_enabled) app.state.permission_provider = permission_provider or StaticPermissionProvider() init_middlewares(app) if config.cors_origins: app.add_middleware( CORSMiddleware, allow_origins=config.cors_origins, allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) engine, sessionmaker = configure_db( config.database_url, echo=config.sqlalchemy_echo, pool_pre_ping=config.sqlalchemy_pool_pre_ping, ) app.state.db_engine = engine app.state.db_sessionmaker = sessionmaker init_error_handlers(app) init_service_clients(app, config.services) init_task_runner(app) init_audit(app) module_registry = init_modules(app, modules) app.state.iti_modules = module_registry if config.health_enabled: app.include_router(health_router) module_registry.run_phase("register_routes", app) module_registry.run_phase("register_permissions", app) module_registry.run_phase("register_menu_seed", app) install_auto_envelope(app) return app def init_middlewares(app: FastAPI) -> None: @app.middleware("http") async def request_context_middleware(request: Request, call_next): token = _current_request.set(request) trace_id = request.headers.get("X-Trace-Id") or uuid.uuid4().hex request_id = request.headers.get("X-Request-Id") or uuid.uuid4().hex request.state.trace_id = trace_id request.state.request_id = request_id started_at = time.perf_counter() response: Response | None try: response = await call_next(request) except Exception: request.state.response_code = getattr(request.state, "response_code", 500) _log_request(request, started_at, 500) _current_request.reset(token) raise response.headers.setdefault("X-Trace-Id", trace_id) response.headers.setdefault("X-Request-Id", request_id) _log_request(request, started_at, response.status_code) _current_request.reset(token) return response def _log_request(request: Request, started_at: float, status_code: int) -> None: duration_ms = round((time.perf_counter() - started_at) * 1000, 2) logger.info( "request method=%s path=%s status=%s code=%s durationMs=%s ip=%s", request.method, request.url.path, status_code, getattr(request.state, "response_code", "-"), duration_ms, request.client.host if request.client else "-", extra=log_extra(request), ) def _resolve_config( config_name: str | None, config_mapping: Mapping[str, type[BaseConfig] | BaseConfig] | type[BaseConfig] | BaseConfig | None, ) -> BaseConfig: if config_mapping is None: return get_config(config_name) if isinstance(config_mapping, Mapping): env_name = config_name or "dev" value = config_mapping.get(env_name, config_mapping.get("default")) if value is None: return get_config(config_name) return value() if isinstance(value, type) else value return config_mapping() if isinstance(config_mapping, type) else config_mapping def init_error_handlers(app: FastAPI) -> None: @app.exception_handler(ItiError) async def handle_iti_error(request: Request, exc: ItiError): request.state.response_code = exc.code return JSONResponse( status_code=request.app.state.config.response_envelope_http_status, content=fail(exc.message, code=exc.code, data=exc.data), ) @app.exception_handler(RequestValidationError) async def handle_validation_error(request: Request, exc: RequestValidationError): request.state.response_code = 422 return JSONResponse( status_code=request.app.state.config.response_envelope_http_status, content=fail("参数验证错误", code=422, data=exc.errors()), ) @app.exception_handler(SQLAlchemyError) async def handle_db_error(request: Request, exc: SQLAlchemyError): request.state.response_code = 500 error_logger.exception("database error", extra=log_extra(request)) return JSONResponse( status_code=request.app.state.config.response_envelope_http_status, content=fail("数据库错误", code=500, data=str(exc)), ) @app.exception_handler(Exception) async def handle_exception(request: Request, exc: Exception): request.state.response_code = 500 error_logger.exception("server error", extra=log_extra(request)) return JSONResponse( status_code=request.app.state.config.response_envelope_http_status, content=fail("服务器错误", code=500, data=str(exc)), ) def to_plain_data(value: Any) -> Any: if is_dataclass(value): return asdict(value) return value def install_auto_envelope(app: FastAPI) -> None: config = app.state.config if not config.response_envelope_enabled: return raw_paths = tuple(config.raw_response_paths) for route in app.routes: if not isinstance(route, APIRoute): continue if getattr(route, "__iti_envelope_installed__", False): continue if _is_route_raw(route, raw_paths): continue original_call = route.dependant.call if original_call is None: continue route.endpoint = _wrap_endpoint_with_envelope(original_call) _rebuild_route_dependant(route) setattr(route, "__iti_envelope_installed__", True) def _is_route_raw(route: APIRoute, raw_paths: Iterable[str]) -> bool: endpoint = route.endpoint if getattr(endpoint, "__iti_raw_response__", False): return True for path in route.path_format, route.path: request = _PathOnlyRequest(path) if is_raw_response_request(request, raw_paths): return True return False def _wrap_endpoint_with_envelope(func): @wraps(func) async def wrapper(*args, **kwargs): result = func(*args, **kwargs) if isawaitable(result): result = await result if isinstance(result, Response): return result payload = _to_jsonable(result) if is_envelope_payload(payload): _mark_response_code(args, kwargs, payload["code"]) return payload _mark_response_code(args, kwargs, 200) return {"data": payload, "code": 200, "message": "成功"} return wrapper def _to_jsonable(value: Any) -> Any: if value is None: return None if isinstance(value, BaseModel): return value.model_dump(by_alias=True) if is_dataclass(value): return asdict(value) return value def _mark_response_code(args: tuple[Any, ...], kwargs: dict[str, Any], code: int) -> None: request = _request_from_call(args, kwargs) or _current_request.get() if request is not None: request.state.response_code = code def _request_from_call(args: tuple[Any, ...], kwargs: dict[str, Any]) -> Request | None: for value in list(args) + list(kwargs.values()): if isinstance(value, Request): return value return None def _rebuild_route_dependant(route: APIRoute) -> None: route.dependant = get_dependant( path=route.path_format, call=route.endpoint, scope="function", ) for depends in route.dependencies[::-1]: route.dependant.dependencies.insert( 0, get_parameterless_sub_dependant(depends=depends, path=route.path_format), ) route._flat_dependant = get_flat_dependant(route.dependant) route._embed_body_fields = _should_embed_body_fields(route._flat_dependant.body_params) route.body_field = get_body_field( flat_dependant=route._flat_dependant, name=route.unique_id, embed_body_fields=route._embed_body_fields, ) route.app = request_response(route.get_route_handler()) class _PathOnlyRequest: def __init__(self, path: str) -> None: self.url = type("URL", (), {"path": path})() self.scope = {"endpoint": None}