You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
iTi-Flask/iti/app.py

313 lines
11 KiB
Python

from __future__ import annotations
import logging
import time
import uuid
from collections.abc import Iterable, Mapping
from contextvars import ContextVar
from contextlib import asynccontextmanager
from dataclasses import asdict, is_dataclass
from functools import wraps
from inspect import isawaitable
from typing import Any
from fastapi import FastAPI, Request
from fastapi.dependencies.utils import (
_should_embed_body_fields,
get_body_field,
get_dependant,
get_flat_dependant,
get_parameterless_sub_dependant,
)
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.routing import APIRoute
from fastapi.routing import request_response
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from starlette.responses import Response
from sqlalchemy.exc import SQLAlchemyError
from iti.auth.permissions import StaticPermissionProvider
from iti.audit import init_audit
from iti.cache import CacheManager
from iti.config import BaseConfig, get_config
from iti.db import configure_db
from iti.exceptions import ItiError
from iti.health import router as health_router
from iti.limiter import SimpleLimiter
from iti.logging_config import configure_logging, log_extra
from iti.modules import init_modules
from iti.responses.auto import is_envelope_payload, is_raw_response_request
from iti.responses import fail
from iti.service_client import init_service_clients
from iti.tasks import init_task_runner
# Package logger (access log lines) and its "error" child (handled exceptions).
logger = logging.getLogger("iti")
error_logger = logger.getChild("error")
# Request currently being served; set and reset by the request-context middleware.
_current_request: ContextVar[Request | None] = ContextVar(
    "iti_current_request",
    default=None,
)
def create_app(
    config_name: str | None = None,
    *,
    modules: Iterable[Any] | None = None,
    config_mapping: Mapping[str, type[BaseConfig] | BaseConfig] | type[BaseConfig] | BaseConfig | None = None,
    permission_provider: Any | None = None,
) -> FastAPI:
    """Build and wire a FastAPI application for the iti framework.

    Args:
        config_name: Config/environment name used by ``_resolve_config``.
        modules: Optional iterable handed to ``init_modules``; the resulting
            registry is stored on ``app.state.iti_modules`` and its
            ``register_routes``, ``register_permissions`` and
            ``register_menu_seed`` phases are run.
        config_mapping: A mapping of config name -> config class/instance, a
            single config class/instance, or ``None`` to use ``get_config``.
        permission_provider: Stored on ``app.state.permission_provider``;
            defaults to ``StaticPermissionProvider()``.

    Returns:
        The configured ``FastAPI`` application.
    """
    config = _resolve_config(config_name, config_mapping)
    configure_logging(config)

    @asynccontextmanager
    async def lifespan(app: FastAPI):
        # Looked up at startup because these attributes are (presumably) set on
        # app.state by init_task_runner / init_audit further down.
        runner = getattr(app.state, "iti_task_runner", None)
        audit_dispatcher = getattr(app.state, "audit_dispatcher", None)
        if audit_dispatcher:
            audit_dispatcher.start()
        if runner and config.tasks_enabled:
            runner.start()
        try:
            yield
        finally:
            # Shut down even if the application exits with an error, so
            # background workers and service clients are not left dangling.
            if runner:
                runner.stop()
            if audit_dispatcher:
                audit_dispatcher.stop()
            for client in getattr(app.state, "iti_service_clients", {}).values():
                client.close()

    app = FastAPI(
        title=config.app_name,
        debug=config.debug,
        lifespan=lifespan,
    )
    app.state.config = config
    app.state.cache = CacheManager(default_timeout=config.cache_default_timeout)
    app.state.limiter = SimpleLimiter(enabled=config.ratelimit_enabled)
    app.state.permission_provider = permission_provider or StaticPermissionProvider()

    init_middlewares(app)
    if config.cors_origins:
        app.add_middleware(
            CORSMiddleware,
            allow_origins=config.cors_origins,
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )

    engine, sessionmaker = configure_db(
        config.database_url,
        echo=config.sqlalchemy_echo,
        pool_pre_ping=config.sqlalchemy_pool_pre_ping,
    )
    app.state.db_engine = engine
    app.state.db_sessionmaker = sessionmaker

    init_error_handlers(app)
    init_service_clients(app, config.services)
    init_task_runner(app)
    init_audit(app)

    module_registry = init_modules(app, modules)
    app.state.iti_modules = module_registry
    if config.health_enabled:
        app.include_router(health_router)
    module_registry.run_phase("register_routes", app)
    module_registry.run_phase("register_permissions", app)
    module_registry.run_phase("register_menu_seed", app)

    # Must run after all routes are registered: only routes present now are wrapped.
    install_auto_envelope(app)
    return app
def init_middlewares(app: FastAPI) -> None:
    """Register the request-context HTTP middleware on *app*.

    For every request the middleware:
      * publishes the request in the ``_current_request`` ContextVar;
      * takes the trace/request ids from the ``X-Trace-Id`` / ``X-Request-Id``
        headers, or generates random hex ids, and stores them on
        ``request.state``;
      * echoes both ids on the response unless already present;
      * writes one access-log line via ``_log_request``.
    """

    @app.middleware("http")
    async def request_context_middleware(request: Request, call_next):
        token = _current_request.set(request)
        # finally guarantees the ContextVar is restored on every exit path,
        # including exceptions raised by header handling or logging.
        try:
            trace_id = request.headers.get("X-Trace-Id") or uuid.uuid4().hex
            request_id = request.headers.get("X-Request-Id") or uuid.uuid4().hex
            request.state.trace_id = trace_id
            request.state.request_id = request_id
            started_at = time.perf_counter()
            try:
                response = await call_next(request)
            except Exception:
                # Keep any code a handler already recorded; otherwise report 500.
                request.state.response_code = getattr(request.state, "response_code", 500)
                _log_request(request, started_at, 500)
                raise
            response.headers.setdefault("X-Trace-Id", trace_id)
            response.headers.setdefault("X-Request-Id", request_id)
            _log_request(request, started_at, response.status_code)
            return response
        finally:
            _current_request.reset(token)
def _log_request(request: Request, started_at: float, status_code: int) -> None:
    """Write one access-log line for *request* to the ``iti`` logger.

    Args:
        request: The request being logged.
        started_at: ``time.perf_counter()`` value captured when the request began.
        status_code: HTTP status to report.

    The ``code`` field is ``request.state.response_code`` (the envelope code
    recorded by handlers/wrappers), or ``"-"`` when none was recorded.
    """
    elapsed_ms = round(1000 * (time.perf_counter() - started_at), 2)
    peer = request.client
    client_ip = peer.host if peer else "-"
    envelope_code = getattr(request.state, "response_code", "-")
    logger.info(
        "request method=%s path=%s status=%s code=%s durationMs=%s ip=%s",
        request.method,
        request.url.path,
        status_code,
        envelope_code,
        elapsed_ms,
        client_ip,
        extra=log_extra(request),
    )
def _resolve_config(
    config_name: str | None,
    config_mapping: Mapping[str, type[BaseConfig] | BaseConfig] | type[BaseConfig] | BaseConfig | None,
) -> BaseConfig:
    """Pick the config object for ``create_app``.

    Resolution rules:
      * no mapping -> ``get_config(config_name)``;
      * a single class/instance -> instantiate the class, or return the instance;
      * a mapping -> look up ``config_name`` (``"dev"`` when not given), then
        ``"default"``; if neither yields a value, fall back to
        ``get_config(config_name)``.
    """

    def materialize(candidate):
        # Classes are instantiated with no arguments; instances are used as-is.
        return candidate() if isinstance(candidate, type) else candidate

    if config_mapping is None:
        return get_config(config_name)
    if not isinstance(config_mapping, Mapping):
        return materialize(config_mapping)

    fallback = config_mapping.get("default")
    chosen = config_mapping.get(config_name or "dev", fallback)
    if chosen is None:
        return get_config(config_name)
    return materialize(chosen)
def init_error_handlers(app: FastAPI) -> None:
    """Register exception handlers that render errors as ``fail(...)`` envelopes.

    Each handler records the envelope code on ``request.state.response_code``
    (reported by the access log) and responds with the HTTP status from
    ``config.response_envelope_http_status``.

    Handled types: ``ItiError`` (its own message/code/data), request validation
    errors (code 422), ``SQLAlchemyError`` (code 500, logged), and any other
    ``Exception`` (code 500, logged). Raw exception text is included in the
    response ``data`` only when ``config.debug`` is true.
    """
    # Local import keeps this block self-contained; this is the same encoder
    # FastAPI's default validation handler applies to ``exc.errors()``.
    from fastapi.encoders import jsonable_encoder

    def _error_detail(request: Request, exc: Exception) -> str | None:
        # str(exc) can reveal SQL statements, driver messages or internal
        # state, so it is only sent to clients in debug mode.
        return str(exc) if request.app.state.config.debug else None

    @app.exception_handler(ItiError)
    async def handle_iti_error(request: Request, exc: ItiError):
        request.state.response_code = exc.code
        return JSONResponse(
            status_code=request.app.state.config.response_envelope_http_status,
            content=fail(exc.message, code=exc.code, data=exc.data),
        )

    @app.exception_handler(RequestValidationError)
    async def handle_validation_error(request: Request, exc: RequestValidationError):
        request.state.response_code = 422
        # exc.errors() may contain values json.dumps cannot handle (e.g. the
        # exception object stored under a pydantic v2 "ctx" key); encode them
        # so the error response itself cannot crash.
        return JSONResponse(
            status_code=request.app.state.config.response_envelope_http_status,
            content=fail("参数验证错误", code=422, data=jsonable_encoder(exc.errors())),
        )

    @app.exception_handler(SQLAlchemyError)
    async def handle_db_error(request: Request, exc: SQLAlchemyError):
        request.state.response_code = 500
        error_logger.exception("database error", exc_info=exc, extra=log_extra(request))
        return JSONResponse(
            status_code=request.app.state.config.response_envelope_http_status,
            content=fail("数据库错误", code=500, data=_error_detail(request, exc)),
        )

    @app.exception_handler(Exception)
    async def handle_exception(request: Request, exc: Exception):
        request.state.response_code = 500
        error_logger.exception("server error", exc_info=exc, extra=log_extra(request))
        return JSONResponse(
            status_code=request.app.state.config.response_envelope_http_status,
            content=fail("服务器错误", code=500, data=_error_detail(request, exc)),
        )
def to_plain_data(value: Any) -> Any:
    """Convert a dataclass *instance* to a plain ``dict``; return anything else unchanged.

    ``dataclasses.is_dataclass`` is also true for dataclass *types*, and
    ``asdict`` raises ``TypeError`` for those. A dataclass class is therefore
    returned as-is instead of crashing.
    """
    if is_dataclass(value) and not isinstance(value, type):
        return asdict(value)
    return value
def install_auto_envelope(app: FastAPI) -> None:
    """Wrap each eligible ``APIRoute`` endpoint so its result is envelope-formatted.

    This does nothing unless ``config.response_envelope_enabled`` is set. These
    routes are skipped:
      * routes already wrapped (marked with ``__iti_envelope_installed__``);
      * routes ``_is_route_raw`` reports as raw;
      * routes whose dependant has no callable.

    Each wrapped route gets its FastAPI metadata rebuilt and is then marked,
    so repeated calls are safe.
    """
    config = app.state.config
    if not config.response_envelope_enabled:
        return
    raw_paths = tuple(config.raw_response_paths)
    api_routes = (candidate for candidate in app.routes if isinstance(candidate, APIRoute))
    for route in api_routes:
        already_wrapped = getattr(route, "__iti_envelope_installed__", False)
        if already_wrapped or _is_route_raw(route, raw_paths):
            continue
        endpoint_call = route.dependant.call
        if endpoint_call is None:
            continue
        route.endpoint = _wrap_endpoint_with_envelope(endpoint_call)
        _rebuild_route_dependant(route)
        setattr(route, "__iti_envelope_installed__", True)
def _is_route_raw(route: APIRoute, raw_paths: Iterable[str]) -> bool:
    """Return True when *route* must bypass the auto envelope.

    A route is raw when either:
      * its endpoint carries a truthy ``__iti_raw_response__`` attribute; or
      * ``is_raw_response_request`` accepts a synthetic request built from the
        route's ``path_format`` or its ``path`` (checked in that order).
    """
    if getattr(route.endpoint, "__iti_raw_response__", False):
        return True
    return any(
        is_raw_response_request(_PathOnlyRequest(candidate_path), raw_paths)
        for candidate_path in (route.path_format, route.path)
    )
def _wrap_endpoint_with_envelope(func):
    """Return an async wrapper around *func* that envelope-formats its result.

    How results are handled:
      * ``Response`` instances are returned untouched.
      * Payloads that already satisfy ``is_envelope_payload`` are returned as-is,
        and their ``code`` is recorded on the request.
      * Anything else becomes ``{"data": ..., "code": 200, "message": "成功"}``.

    ``functools.wraps`` keeps the original metadata on the wrapper, which is
    presumably what lets FastAPI keep reading the endpoint's parameters after
    the swap.
    """
    from inspect import iscoroutinefunction

    from starlette.concurrency import run_in_threadpool

    # FastAPI runs plain ``def`` endpoints in a worker thread, but it awaits
    # this async wrapper directly on the event loop. Sync endpoints must be
    # offloaded explicitly, or blocking work (DB calls, I/O) stalls the loop.
    is_async = iscoroutinefunction(func) or iscoroutinefunction(getattr(func, "__call__", None))

    @wraps(func)
    async def wrapper(*args, **kwargs):
        if is_async:
            result = await func(*args, **kwargs)
        else:
            result = await run_in_threadpool(func, *args, **kwargs)
            # A sync callable may still hand back an awaitable.
            if isawaitable(result):
                result = await result
        if isinstance(result, Response):
            return result
        payload = _to_jsonable(result)
        if is_envelope_payload(payload):
            _mark_response_code(args, kwargs, payload["code"])
            return payload
        _mark_response_code(args, kwargs, 200)
        return {"data": payload, "code": 200, "message": "成功"}

    return wrapper
def _to_jsonable(value: Any) -> Any:
    """Turn an endpoint result into plain data for the envelope.

    Conversions:
      * pydantic models -> ``model_dump(by_alias=True)``;
      * dataclass *instances* -> ``dataclasses.asdict``;
      * anything else (including ``None``) -> unchanged.

    ``is_dataclass`` is also true for dataclass *types*, on which ``asdict``
    raises ``TypeError``, so classes are passed through unchanged.
    """
    if value is None:
        return None
    if isinstance(value, BaseModel):
        return value.model_dump(by_alias=True)
    if is_dataclass(value) and not isinstance(value, type):
        return asdict(value)
    return value
def _mark_response_code(args: tuple[Any, ...], kwargs: dict[str, Any], code: int) -> None:
    """Record *code* on ``request.state.response_code`` for the active request.

    The request is taken from the endpoint call arguments when one is passed.
    Otherwise the ``_current_request`` ContextVar is used. If neither provides
    a request, this does nothing.
    """
    target = _request_from_call(args, kwargs)
    if not target:
        target = _current_request.get()
    if target is None:
        return
    target.state.response_code = code
def _request_from_call(args: tuple[Any, ...], kwargs: dict[str, Any]) -> Request | None:
    """Return the first ``Request`` among the call arguments, or ``None``.

    Positional arguments are searched before keyword values.
    """
    candidates = (*args, *kwargs.values())
    return next((item for item in candidates if isinstance(item, Request)), None)
def _rebuild_route_dependant(route: APIRoute) -> None:
    """Recompute a route's dependency/body metadata and ASGI app from ``route.endpoint``.

    ``install_auto_envelope`` replaces ``route.endpoint`` after the route was
    created. This refreshes the derived attributes (dependant, flat dependant,
    body-embedding flag, body field, ASGI app) against the new endpoint.

    NOTE(review): this mirrors the setup FastAPI's ``APIRoute.__init__``
    performs and relies on private attributes (``_flat_dependant``,
    ``_embed_body_fields``); re-check it whenever the FastAPI version changes.
    NOTE(review): ``response_field`` / response model are not recomputed here.
    """
    # Dependency graph for the (wrapped) endpoint; its parameters are presumably
    # resolved through the ``__wrapped__`` attribute set by functools.wraps.
    route.dependant = get_dependant(
        path=route.path_format,
        call=route.endpoint,
        scope="function",
    )
    # Route-level dependencies are prepended; iterating in reverse with
    # insert(0, ...) keeps them in their declared order.
    for depends in route.dependencies[::-1]:
        route.dependant.dependencies.insert(
            0,
            get_parameterless_sub_dependant(depends=depends, path=route.path_format),
        )
    route._flat_dependant = get_flat_dependant(route.dependant)
    route._embed_body_fields = _should_embed_body_fields(route._flat_dependant.body_params)
    route.body_field = get_body_field(
        flat_dependant=route._flat_dependant,
        name=route.unique_id,
        embed_body_fields=route._embed_body_fields,
    )
    # Rebuilt last so the handler is created from the refreshed metadata above.
    route.app = request_response(route.get_route_handler())
class _PathOnlyRequest:
    """Minimal request stand-in exposing only ``url.path`` and ``scope``.

    Lets ``is_raw_response_request`` be asked about a route path at startup,
    when no real request exists.
    """

    def __init__(self, path: str) -> None:
        # No endpoint is associated with this synthetic request.
        self.scope = {"endpoint": None}
        url_type = type("URL", (), {"path": path})
        self.url = url_type()