"""Application factory for iti services.

Provides ``create_app``, which builds the FastAPI application and installs
the docs UI picker, the request-context middleware, the exception handlers
and the automatic response envelope.
"""

from __future__ import annotations

import logging
import time
import uuid
from html import escape
from importlib import resources
from http import HTTPStatus
from collections.abc import Iterable, Mapping
from contextvars import ContextVar
from contextlib import asynccontextmanager
from dataclasses import asdict, is_dataclass
from functools import wraps
from inspect import isawaitable
from typing import Any

from fastapi import FastAPI, Request
from fastapi.dependencies.utils import (
    _should_embed_body_fields,
    get_body_field,
    get_dependant,
    get_flat_dependant,
    get_parameterless_sub_dependant,
)
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
from fastapi.routing import APIRoute
from fastapi.routing import request_response
from fastapi.responses import HTMLResponse, JSONResponse
from pydantic import BaseModel
from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.responses import Response
from sqlalchemy.exc import SQLAlchemyError

from iti.auth.permissions import StaticPermissionProvider
from iti.audit import init_audit
from iti.cache import CacheManager
from iti.config import BaseConfig, get_config
from iti.db import configure_db
from iti.exceptions import ItiError
from iti.health import router as health_router
from iti.limiter import SimpleLimiter
from iti.logging_config import configure_logging, log_extra
from iti.modules import init_modules
from iti.exchange import get_exchange_registry
# NOTE(review): not referenced by name in this module; presumably imported for
# its side effects (e.g. model registration) — confirm before removing.
from iti.exchange import models as _exchange_models
from iti.responses.auto import is_envelope_payload, is_raw_response_request
from iti.responses import fail
from iti.service_client import init_service_clients
from iti.tasks import init_task_runner

# Logger used for the per-request access log lines.
logger = logging.getLogger("iti")
# Logger used by the exception handlers to record database / server errors.
error_logger = logging.getLogger("iti.error")
# Request currently being served in this context. The request-context
# middleware sets and resets it; response-code bookkeeping falls back to it
# when the endpoint did not receive the Request as an argument.
_current_request: ContextVar[Request | None] = ContextVar("iti_current_request", default=None)
# Template files looked up in the ``iti.templates`` package by _render_template.
DOCS_PICKER_TEMPLATE = "docs-picker.html"
SCALAR_TEMPLATE = "scalar.html"


def create_app(
    config_name: str | None = None,
    *,
    modules: Iterable[Any] | None = None,
    config_mapping: Mapping[str, type[BaseConfig] | BaseConfig] | type[BaseConfig] | BaseConfig | None = None,
    permission_provider: Any | None = None,
) -> FastAPI:
    """Build and fully wire an iti FastAPI application.

    Args:
        config_name: Environment/config name; resolved via ``_resolve_config``.
        modules: iti modules to register. When ``config.exchange_enabled`` is
            set and no module is named ``"exchange"``, the exchange module is
            appended automatically.
        config_mapping: Optional explicit configuration source; see
            ``_resolve_config``.
        permission_provider: Stored on ``app.state.permission_provider``;
            defaults to ``StaticPermissionProvider()``.

    Returns:
        The configured ``FastAPI`` instance.
    """
    config = _resolve_config(config_name, config_mapping)
    configure_logging(config)

    @asynccontextmanager
    async def lifespan(app: FastAPI):
        # Presumably populated by init_task_runner / init_audit below; both
        # are optional, hence the getattr defaults.
        runner = getattr(app.state, "iti_task_runner", None)
        audit_dispatcher = getattr(app.state, "audit_dispatcher", None)
        if audit_dispatcher:
            audit_dispatcher.start()
        try:
            if runner and config.tasks_enabled:
                runner.start()
            yield
        finally:
            # Always run shutdown, even if starting the runner or the serving
            # phase failed, so started workers and open clients are released.
            if runner:
                runner.stop()
            if audit_dispatcher:
                audit_dispatcher.stop()
            for client in getattr(app.state, "iti_service_clients", {}).values():
                client.close()

    app = FastAPI(
        title=config.app_name,
        debug=config.debug,
        lifespan=lifespan,
        # The built-in docs pages are replaced by install_docs' /docs picker.
        docs_url=None,
        redoc_url=None,
    )
    install_docs(app)
    app.state.config = config
    app.state.cache = CacheManager(default_timeout=config.cache_default_timeout)
    app.state.limiter = SimpleLimiter(enabled=config.ratelimit_enabled)
    app.state.permission_provider = permission_provider or StaticPermissionProvider()
    app.state.exchange_enabled = config.exchange_enabled
    init_middlewares(app)
    if config.cors_origins:
        app.add_middleware(
            CORSMiddleware,
            allow_origins=config.cors_origins,
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )
    engine, sessionmaker = configure_db(
        config.database_url,
        echo=config.sqlalchemy_echo,
        pool_pre_ping=config.sqlalchemy_pool_pre_ping,
    )
    app.state.db_engine = engine
    app.state.db_sessionmaker = sessionmaker
    init_error_handlers(app)
    init_service_clients(app, config.services)
    init_task_runner(app)
    get_exchange_registry(app)
    init_audit(app)

    module_list = list(modules or [])
    if config.exchange_enabled and not any(
        getattr(module, "name", None) == "exchange" for module in module_list
    ):
        from iti.exchange.module import create_exchange_module

        module_list.append(create_exchange_module())
    module_registry = init_modules(app, module_list)
    app.state.iti_modules = module_registry
    if config.health_enabled:
        app.include_router(health_router)
    module_registry.run_phase("register_routes", app)
    module_registry.run_phase("register_permissions", app)
    module_registry.run_phase("register_menu_seed", app)
    # Must stay last: install_auto_envelope only wraps routes that are already
    # present in app.routes.
    install_auto_envelope(app)
    return app


def install_docs(app: FastAPI) -> None:
    """Register ``GET /docs``: a UI picker page, or one UI chosen via ``?ui=``.

    ``?ui=swagger``, ``?ui=scalar`` and ``?ui=redoc`` render that UI when it is
    enabled (see ``_enabled_doc_options``). Any other value, a disabled UI, or
    no ``ui`` parameter renders the picker page instead.
    """

    @app.get("/docs", include_in_schema=False)
    def docs(ui: str | None = None) -> HTMLResponse:
        doc_options = _enabled_doc_options(app)
        openapi_url = app.openapi_url or "/openapi.json"
        if ui == "swagger" and "swagger" in doc_options:
            return get_swagger_ui_html(
                openapi_url=openapi_url,
                title=f"{app.title} - Swagger UI",
            )
        if ui == "scalar" and "scalar" in doc_options:
            return _scalar_docs_html(app)
        if ui == "redoc" and "redoc" in doc_options:
            return get_redoc_html(
                openapi_url=openapi_url,
                title=f"{app.title} - ReDoc",
            )
        return _docs_picker_html(app, doc_options)


def _enabled_doc_options(app: FastAPI) -> dict[str, dict[str, str]]:
    """Return display metadata for each docs UI enabled in the config.

    Reads ``config.docs_ui_enabled`` (all three UIs when the attribute is
    missing). Unknown names are ignored, and the result keeps the configured
    order.
    """
    configured = getattr(app.state.config, "docs_ui_enabled", ["swagger", "scalar", "redoc"])
    all_options = {
        "swagger": {
            "class": "swagger",
            "label": "Swagger",
            "abbr": "SW",
            "description": "传统交互文档,适合快速试接口。",
        },
        "scalar": {
            "class": "scalar",
            "label": "Scalar",
            "abbr": "SC",
            "description": "现代接口参考,适合阅读和调试。",
        },
        "redoc": {
            "class": "redoc",
            "label": "ReDoc",
            "abbr": "RD",
            "description": "结构化阅读文档,适合查看模型关系。",
        },
    }
    return {name: all_options[name] for name in configured if name in all_options}


def _docs_picker_html(app: FastAPI, doc_options: Mapping[str, dict[str, str]]) -> HTMLResponse:
    """Render the docs picker page with one card per enabled UI."""
    title = escape(app.title)
    option_cards = "\n".join(
        _doc_option_card(name, option) for name, option in doc_options.items()
    )
    html = _render_template(
        DOCS_PICKER_TEMPLATE,
        {
            "title": title,
            "option_cards": option_cards,
        },
    )
    return HTMLResponse(html)


def _doc_option_card(name: str, option: Mapping[str, str]) -> str:
    """Return the HTML for one picker card linking to ``?ui=<name>``.

    Every interpolated value is HTML-escaped, because the result is inserted
    into the template verbatim.
    """
    label = escape(option["label"])
    class_name = escape(option["class"])
    abbr = escape(option["abbr"])
    description = escape(option["description"])
    href = escape(f"?ui={name}")
    # NOTE(review): class names must match the CSS in docs-picker.html —
    # confirm against that template.
    return (
        f'<a class="doc-card {class_name}" href="{href}">'
        f'<span class="doc-card-abbr">{abbr}</span>'
        f'<span class="doc-card-label">{label}</span>'
        f'<span class="doc-card-description">{description}</span>'
        "</a>"
    )


def _scalar_docs_html(app: FastAPI) -> HTMLResponse:
    """Render the Scalar API reference page for this app's OpenAPI schema."""
    html = _render_template(
        SCALAR_TEMPLATE,
        {
            "title": escape(app.title),
            "openapi_url": escape(app.openapi_url or "/openapi.json"),
        },
    )
    return HTMLResponse(html)


def _render_template(name: str, values: Mapping[str, str]) -> str:
    """Load template *name* from ``iti.templates`` and fill its placeholders.

    Values are inserted verbatim, so callers must escape them as needed.
    """
    template = resources.files("iti.templates").joinpath(name).read_text(encoding="utf-8")
    return _substitute_placeholders(template, values)


def _substitute_placeholders(template: str, values: Mapping[str, str]) -> str:
    """Replace every ``{{ key }}`` placeholder in *template* in a single pass.

    Inserted text is never re-scanned. A value that happens to contain another
    placeholder (for example, an app title of ``"{{ option_cards }}"``) is
    therefore emitted literally, not expanded. Placeholders whose key is not in
    *values* are left untouched.
    """
    import re

    if not values:
        return template
    replacements = {"{{ " + key + " }}": value for key, value in values.items()}
    # Longest first, so that no placeholder can shadow a longer one sharing its prefix.
    pattern = re.compile(
        "|".join(re.escape(token) for token in sorted(replacements, key=len, reverse=True))
    )
    # A function replacement avoids re's backslash processing of the values.
    return pattern.sub(lambda match: replacements[match.group(0)], template)


def init_middlewares(app: FastAPI) -> None:
    """Install the request-context middleware.

    For each request the middleware:

    * publishes the request in ``_current_request``;
    * takes trace/request ids from ``X-Trace-Id`` / ``X-Request-Id`` (or
      generates them) and stores them on ``request.state``;
    * echoes the ids on the response unless the handler already set them;
    * writes one access-log line, also when the downstream app raised.
    """

    @app.middleware("http")
    async def request_context_middleware(request: Request, call_next):
        token = _current_request.set(request)
        trace_id = request.headers.get("X-Trace-Id") or uuid.uuid4().hex
        request_id = request.headers.get("X-Request-Id") or uuid.uuid4().hex
        request.state.trace_id = trace_id
        request.state.request_id = request_id
        started_at = time.perf_counter()
        try:
            try:
                response = await call_next(request)
            except Exception:
                request.state.response_code = getattr(request.state, "response_code", 500)
                _log_request(request, started_at, 500)
                raise
            response.headers.setdefault("X-Trace-Id", trace_id)
            response.headers.setdefault("X-Request-Id", request_id)
            _log_request(request, started_at, response.status_code)
            return response
        finally:
            # Reset on every path, including when logging itself raises.
            _current_request.reset(token)


def _log_request(request: Request, started_at: float, status_code: int) -> None:
    """Emit one access-log line with method, path, status, envelope code and latency."""
    duration_ms = round((time.perf_counter() - started_at) * 1000, 2)
    logger.info(
        "request method=%s path=%s status=%s code=%s durationMs=%s ip=%s",
        request.method,
        request.url.path,
        status_code,
        getattr(request.state, "response_code", "-"),
        duration_ms,
        request.client.host if request.client else "-",
        extra=log_extra(request),
    )


def _resolve_config(
    config_name: str | None,
    config_mapping: Mapping[str, type[BaseConfig] | BaseConfig] | type[BaseConfig] | BaseConfig | None,
) -> BaseConfig:
    """Resolve the application config.

    * ``None`` means ``get_config(config_name)``.
    * A mapping is looked up by ``config_name`` (``"dev"`` when not given),
      then by ``"default"``. It falls back to ``get_config(config_name)`` when
      neither key is present.
    * A config class is instantiated; a config instance is returned as-is.
    """
    if config_mapping is None:
        return get_config(config_name)
    if isinstance(config_mapping, Mapping):
        env_name = config_name or "dev"
        value = config_mapping.get(env_name, config_mapping.get("default"))
        if value is None:
            return get_config(config_name)
        return value() if isinstance(value, type) else value
    return config_mapping() if isinstance(config_mapping, type) else config_mapping


def init_error_handlers(app: FastAPI) -> None:
    """Register exception handlers that answer with the ``fail`` envelope.

    All handlers record the envelope code on ``request.state.response_code``.
    They respond with ``config.response_envelope_http_status`` as the HTTP
    status; the real error code travels in the envelope body.
    """
    from fastapi.encoders import jsonable_encoder

    @app.exception_handler(ItiError)
    async def handle_iti_error(request: Request, exc: ItiError):
        request.state.response_code = exc.code
        return _error_response(request, fail(exc.message, code=exc.code, data=exc.data))

    @app.exception_handler(RequestValidationError)
    async def handle_validation_error(request: Request, exc: RequestValidationError):
        request.state.response_code = 422
        # errors() can contain objects JSONResponse cannot serialize (e.g.
        # the exception a custom validator raised, under "ctx"). Encode them
        # first, so a bad request does not turn into a 500.
        return _error_response(
            request,
            fail("参数验证错误", code=422, data=jsonable_encoder(exc.errors())),
        )

    @app.exception_handler(StarletteHTTPException)
    async def handle_http_error(request: Request, exc: StarletteHTTPException):
        request.state.response_code = exc.status_code
        message, data = _http_error_payload(exc)
        return _error_response(
            request,
            fail(message, code=exc.status_code, data=data),
            headers=exc.headers,
        )

    @app.exception_handler(SQLAlchemyError)
    async def handle_db_error(request: Request, exc: SQLAlchemyError):
        request.state.response_code = 500
        error_logger.exception("database error", extra=log_extra(request))
        return _error_response(
            request,
            fail("数据库错误", code=500, data=_internal_error_data(request, exc)),
        )

    @app.exception_handler(Exception)
    async def handle_exception(request: Request, exc: Exception):
        request.state.response_code = 500
        error_logger.exception("server error", extra=log_extra(request))
        return _error_response(
            request,
            fail("服务器错误", code=500, data=_internal_error_data(request, exc)),
        )


def _error_response(
    request: Request, content: Any, headers: Mapping[str, str] | None = None
) -> JSONResponse:
    """Wrap an error envelope in a JSONResponse with the configured envelope HTTP status."""
    return JSONResponse(
        status_code=request.app.state.config.response_envelope_http_status,
        content=content,
        headers=headers,
    )


def _internal_error_data(request: Request, exc: Exception) -> str | None:
    """Return ``str(exc)`` only in debug mode, otherwise ``None``.

    Database and unexpected errors can carry SQL text, bound parameters or
    other internals. These are logged by the handlers, but only exposed to
    clients when ``config.debug`` is enabled.
    """
    return str(exc) if request.app.state.config.debug else None


def _http_error_payload(exc: StarletteHTTPException) -> tuple[str, Any]:
    """Split an HTTPException's ``detail`` into ``(message, data)``.

    * A string detail becomes the message, with no data.
    * A missing detail uses the standard status phrase.
    * A mapping detail takes its message from the first non-empty string
      under ``message``, ``detail`` or ``error`` and is returned as data.
    * Anything else is returned as data with the status phrase as message.
    """
    detail = exc.detail
    if isinstance(detail, str):
        return detail, None
    if detail is None:
        return _http_status_phrase(exc.status_code), None
    if isinstance(detail, Mapping):
        for key in ("message", "detail", "error"):
            value = detail.get(key)
            if isinstance(value, str) and value:
                return value, detail
    return _http_status_phrase(exc.status_code), detail


def _http_status_phrase(status_code: int) -> str:
    """Return the standard reason phrase, or ``"HTTP Error"`` for unknown codes."""
    try:
        return HTTPStatus(status_code).phrase
    except ValueError:
        return "HTTP Error"


def to_plain_data(value: Any) -> Any:
    """Convert a dataclass instance to a dict; return other values unchanged."""
    if is_dataclass(value):
        return asdict(value)
    return value


def install_auto_envelope(app: FastAPI) -> None:
    """Wrap every eligible APIRoute so its return value is enveloped.

    Does nothing when ``config.response_envelope_enabled`` is false. A route
    is skipped when it:

    * is already wrapped;
    * is marked raw, via ``__iti_raw_response__`` on the endpoint or a match
      in ``config.raw_response_paths``;
    * has no callable.

    Only routes present at call time are affected.
    """
    config = app.state.config
    if not config.response_envelope_enabled:
        return
    raw_paths = tuple(config.raw_response_paths)
    for route in app.routes:
        if not isinstance(route, APIRoute):
            continue
        if getattr(route, "__iti_envelope_installed__", False):
            continue
        if _is_route_raw(route, raw_paths):
            continue
        original_call = route.dependant.call
        if original_call is None:
            continue
        route.endpoint = _wrap_endpoint_with_envelope(original_call)
        _rebuild_route_dependant(route)
        setattr(route, "__iti_envelope_installed__", True)


def _is_route_raw(route: APIRoute, raw_paths: Iterable[str]) -> bool:
    """Return True when *route* must keep its unwrapped response.

    Checks the endpoint's ``__iti_raw_response__`` flag, then both the
    templated and the concrete route path against ``is_raw_response_request``.
    """
    endpoint = route.endpoint
    if getattr(endpoint, "__iti_raw_response__", False):
        return True
    for path in route.path_format, route.path:
        request = _PathOnlyRequest(path)
        if is_raw_response_request(request, raw_paths):
            return True
    return False


def _is_async_callable(call: Any) -> bool:
    """Return True when calling *call* produces a coroutine.

    Functions and methods are checked directly. Classes are never async, since
    calling them constructs an instance. Other callables are judged by their
    ``__call__``.
    """
    from inspect import isclass, iscoroutinefunction, isroutine

    if isroutine(call):
        return iscoroutinefunction(call)
    if isclass(call):
        return False
    return iscoroutinefunction(getattr(call, "__call__", None))


def _wrap_endpoint_with_envelope(func):
    """Wrap an endpoint so plain return values are sent inside the envelope.

    * ``Response`` objects pass through untouched.
    * Payloads already recognised by ``is_envelope_payload`` are returned
      as-is, and their ``code`` is recorded on ``request.state``.
    * Anything else becomes ``{"data": <payload>, "code": 200, "message": ...}``.

    The wrapper itself is always ``async``, so FastAPI awaits it on the event
    loop. A synchronous *func* is therefore run through Starlette's threadpool,
    where FastAPI would have run the unwrapped ``def`` endpoint, instead of
    blocking the event loop.
    """
    from starlette.concurrency import run_in_threadpool

    call_is_async = _is_async_callable(func)

    @wraps(func)
    async def wrapper(*args, **kwargs):
        if call_is_async:
            result = func(*args, **kwargs)
        else:
            result = await run_in_threadpool(func, *args, **kwargs)
        if isawaitable(result):
            result = await result
        if isinstance(result, Response):
            return result
        payload = _to_jsonable(result)
        if is_envelope_payload(payload):
            _mark_response_code(args, kwargs, payload["code"])
            return payload
        _mark_response_code(args, kwargs, 200)
        return {"data": payload, "code": 200, "message": "成功"}

    return wrapper


def _to_jsonable(value: Any) -> Any:
    """Convert pydantic models (dumped by alias) and dataclasses to plain data."""
    if isinstance(value, BaseModel):
        return value.model_dump(by_alias=True)
    return to_plain_data(value)


def _mark_response_code(args: tuple[Any, ...], kwargs: dict[str, Any], code: int) -> None:
    """Record the envelope *code* on the current request's state, if one is found.

    Uses a Request passed to the endpoint, falling back to ``_current_request``.
    """
    request = _request_from_call(args, kwargs) or _current_request.get()
    if request is not None:
        request.state.response_code = code


def _request_from_call(args: tuple[Any, ...], kwargs: dict[str, Any]) -> Request | None:
    """Return the first ``Request`` among the call's positional/keyword args."""
    for value in (*args, *kwargs.values()):
        if isinstance(value, Request):
            return value
    return None


def _rebuild_route_dependant(route: APIRoute) -> None:
    """Recompute the route's dependency graph and ASGI app after swapping its endpoint.

    NOTE(review): presumably mirrors the dependant/body-field setup performed
    in ``APIRoute.__init__``; re-check this against FastAPI when upgrading, as
    it relies on private helpers such as ``_should_embed_body_fields``.
    """
    route.dependant = get_dependant(
        path=route.path_format,
        call=route.endpoint,
        scope="function",
    )
    # Route-level dependencies go first, in their declared order.
    for depends in route.dependencies[::-1]:
        route.dependant.dependencies.insert(
            0,
            get_parameterless_sub_dependant(depends=depends, path=route.path_format),
        )
    route._flat_dependant = get_flat_dependant(route.dependant)
    route._embed_body_fields = _should_embed_body_fields(route._flat_dependant.body_params)
    route.body_field = get_body_field(
        flat_dependant=route._flat_dependant,
        name=route.unique_id,
        embed_body_fields=route._embed_body_fields,
    )
    route.app = request_response(route.get_route_handler())


class _PathOnlyURL:
    """URL stand-in exposing only ``path``."""

    __slots__ = ("path",)

    def __init__(self, path: str) -> None:
        self.path = path


class _PathOnlyRequest:
    """Minimal request stand-in used to test a route path against the raw-path rules.

    Exposes only ``url.path`` and a ``scope`` with no endpoint.
    """

    def __init__(self, path: str) -> None:
        self.url = _PathOnlyURL(path)
        self.scope = {"endpoint": None}