You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
346 lines
12 KiB
Python
346 lines
12 KiB
Python
from __future__ import annotations
|
|
|
|
import logging
|
|
import time
|
|
import uuid
|
|
from http import HTTPStatus
|
|
from collections.abc import Iterable, Mapping
|
|
from contextvars import ContextVar
|
|
from contextlib import asynccontextmanager
|
|
from dataclasses import asdict, is_dataclass
|
|
from functools import wraps
|
|
from inspect import isawaitable
|
|
from typing import Any
|
|
|
|
from fastapi import FastAPI, Request
|
|
from fastapi.dependencies.utils import (
|
|
_should_embed_body_fields,
|
|
get_body_field,
|
|
get_dependant,
|
|
get_flat_dependant,
|
|
get_parameterless_sub_dependant,
|
|
)
|
|
from fastapi.exceptions import RequestValidationError
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.routing import APIRoute
|
|
from fastapi.routing import request_response
|
|
from fastapi.responses import JSONResponse
|
|
from pydantic import BaseModel
|
|
from starlette.exceptions import HTTPException as StarletteHTTPException
|
|
from starlette.responses import Response
|
|
from sqlalchemy.exc import SQLAlchemyError
|
|
|
|
from iti.auth.permissions import StaticPermissionProvider
|
|
from iti.audit import init_audit
|
|
from iti.cache import CacheManager
|
|
from iti.config import BaseConfig, get_config
|
|
from iti.db import configure_db
|
|
from iti.exceptions import ItiError
|
|
from iti.health import router as health_router
|
|
from iti.limiter import SimpleLimiter
|
|
from iti.logging_config import configure_logging, log_extra
|
|
from iti.modules import init_modules
|
|
from iti.responses.auto import is_envelope_payload, is_raw_response_request
|
|
from iti.responses import fail
|
|
from iti.service_client import init_service_clients
|
|
from iti.tasks import init_task_runner
|
|
|
|
# Framework logger "iti"; the access-log line for every request goes here.
logger = logging.getLogger("iti")
# Child logger "iti.error" (same object as logging.getLogger("iti.error")),
# used by the exception handlers to record failures with tracebacks.
error_logger = logger.getChild("error")

# Request currently being handled. The request-context middleware sets it for
# the duration of a call so helpers without a Request argument can reach it.
_current_request: ContextVar[Request | None] = ContextVar("iti_current_request", default=None)
|
|
|
|
|
|
def create_app(
    config_name: str | None = None,
    *,
    modules: Iterable[Any] | None = None,
    config_mapping: Mapping[str, type[BaseConfig] | BaseConfig] | type[BaseConfig] | BaseConfig | None = None,
    permission_provider: Any | None = None,
) -> FastAPI:
    """Build and wire an iti FastAPI application.

    Steps, in order:

    * resolve the config and configure logging;
    * create the ``FastAPI`` instance with a lifespan that starts and stops the
      background components;
    * attach shared state (config, cache, limiter, permission provider);
    * install middlewares, CORS (when ``cors_origins`` is set) and the database
      engine/sessionmaker;
    * register error handlers, service clients, the task runner and audit;
    * load modules and run their registration phases;
    * finally wrap eligible routes with the response envelope.

    Args:
        config_name: Config/environment name used for config resolution.
        modules: Modules passed to ``init_modules``.
        config_mapping: Optional explicit config source (see ``_resolve_config``).
        permission_provider: Provider stored on ``app.state``. A
            ``StaticPermissionProvider`` is created when it is ``None``.

    Returns:
        The configured ``FastAPI`` application.
    """
    config = _resolve_config(config_name, config_mapping)
    configure_logging(config)

    app = FastAPI(
        title=config.app_name,
        debug=config.debug,
        lifespan=_build_lifespan(config),
    )
    app.state.config = config
    app.state.cache = CacheManager(default_timeout=config.cache_default_timeout)
    app.state.limiter = SimpleLimiter(enabled=config.ratelimit_enabled)
    # Explicit None check, so a caller-supplied provider is kept even if it is falsy.
    app.state.permission_provider = (
        permission_provider if permission_provider is not None else StaticPermissionProvider()
    )

    init_middlewares(app)

    if config.cors_origins:
        app.add_middleware(
            CORSMiddleware,
            allow_origins=config.cors_origins,
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )

    engine, sessionmaker = configure_db(
        config.database_url,
        echo=config.sqlalchemy_echo,
        pool_pre_ping=config.sqlalchemy_pool_pre_ping,
    )
    app.state.db_engine = engine
    app.state.db_sessionmaker = sessionmaker

    init_error_handlers(app)
    init_service_clients(app, config.services)
    init_task_runner(app)
    init_audit(app)

    module_registry = init_modules(app, modules)
    app.state.iti_modules = module_registry
    if config.health_enabled:
        app.include_router(health_router)
    for phase in ("register_routes", "register_permissions", "register_menu_seed"):
        module_registry.run_phase(phase, app)

    # Must run last: install_auto_envelope only wraps routes that exist on app.routes now.
    install_auto_envelope(app)
    return app


def _build_lifespan(config: BaseConfig):
    """Return the lifespan context manager that starts and stops background components.

    On startup the audit dispatcher is started when present. The task runner is
    started when present and ``config.tasks_enabled`` is set. Shutdown always
    runs, including when startup of the runner or serving fails.
    """

    @asynccontextmanager
    async def lifespan(app: FastAPI):
        runner = getattr(app.state, "iti_task_runner", None)
        audit_dispatcher = getattr(app.state, "audit_dispatcher", None)
        if audit_dispatcher:
            audit_dispatcher.start()
        try:
            if runner and config.tasks_enabled:
                runner.start()
            yield
        finally:
            # Without finally, an exception thrown in at `yield` (e.g. server
            # cancellation) would skip teardown entirely.
            _shutdown_components(app, runner, audit_dispatcher)

    return lifespan


def _shutdown_components(app: FastAPI, runner: Any, audit_dispatcher: Any) -> None:
    """Stop the task runner and audit dispatcher, then close service clients.

    Best-effort: every step runs even if an earlier one raises. Failures are
    logged through ``error_logger`` instead of propagating.
    """
    steps: list[tuple[str, Any]] = []
    if runner:
        steps.append(("task runner", runner.stop))
    if audit_dispatcher:
        steps.append(("audit dispatcher", audit_dispatcher.stop))
    for client in getattr(app.state, "iti_service_clients", {}).values():
        steps.append(("service client", client.close))
    for label, step in steps:
        try:
            step()
        except Exception:
            error_logger.exception("error while shutting down %s", label)
|
|
|
|
|
|
def init_middlewares(app: FastAPI) -> None:
    """Install the per-request context and access-log middleware on *app*.

    For every HTTP request the middleware:

    * exposes the request through the ``_current_request`` context variable
      for the duration of the call, and always resets it afterwards;
    * takes the incoming ``X-Trace-Id`` / ``X-Request-Id`` headers (or fresh
      random hex ids) and stores them on ``request.state``;
    * adds both ids to the response unless they are already present;
    * writes one access-log line through ``_log_request``, reporting status
      500 when the downstream app raised (the exception is re-raised).
    """

    @app.middleware("http")
    async def request_context_middleware(request: Request, call_next):
        token = _current_request.set(request)
        try:
            trace_id = request.headers.get("X-Trace-Id") or uuid.uuid4().hex
            request_id = request.headers.get("X-Request-Id") or uuid.uuid4().hex
            request.state.trace_id = trace_id
            request.state.request_id = request_id
            started_at = time.perf_counter()
            try:
                response = await call_next(request)
            except Exception:
                # Keep a code already recorded by an error handler; default to 500.
                request.state.response_code = getattr(request.state, "response_code", 500)
                _log_request(request, started_at, 500)
                raise
            response.headers.setdefault("X-Trace-Id", trace_id)
            response.headers.setdefault("X-Request-Id", request_id)
            _log_request(request, started_at, response.status_code)
            return response
        finally:
            # Reset even when header handling or logging raises, so the request
            # does not stay bound in the context variable.
            _current_request.reset(token)
|
|
|
|
|
|
def _log_request(request: Request, started_at: float, status_code: int) -> None:
    """Emit one access-log line for *request* on the ``iti`` logger.

    Args:
        request: The handled request. ``request.state.response_code`` is logged
            when set, otherwise ``"-"``. The client host is logged when known,
            otherwise ``"-"``.
        started_at: ``time.perf_counter()`` reading taken when handling began;
            the elapsed time is logged in milliseconds, rounded to 2 decimals.
        status_code: HTTP status to report.
    """
    elapsed_seconds = time.perf_counter() - started_at
    client = request.client
    client_ip = "-" if not client else client.host
    envelope_code = getattr(request.state, "response_code", "-")
    logger.info(
        "request method=%s path=%s status=%s code=%s durationMs=%s ip=%s",
        request.method,
        request.url.path,
        status_code,
        envelope_code,
        round(elapsed_seconds * 1000, 2),
        client_ip,
        extra=log_extra(request),
    )
|
|
|
|
|
|
def _resolve_config(
    config_name: str | None,
    config_mapping: Mapping[str, type[BaseConfig] | BaseConfig] | type[BaseConfig] | BaseConfig | None,
) -> BaseConfig:
    """Turn ``create_app``'s config arguments into a config instance.

    Resolution rules:

    * no ``config_mapping``: ``get_config(config_name)``;
    * a mapping: look up ``config_name`` (``"dev"`` when it is empty/None),
      falling back to the ``"default"`` key. When neither yields a value, use
      ``get_config(config_name)``;
    * a single class or instance: use it directly.

    Classes are instantiated with no arguments; instances are returned as-is.
    """

    def _materialize(candidate):
        return candidate() if isinstance(candidate, type) else candidate

    if config_mapping is None:
        return get_config(config_name)
    if not isinstance(config_mapping, Mapping):
        return _materialize(config_mapping)
    fallback = config_mapping.get("default")
    selected = config_mapping.get(config_name or "dev", fallback)
    if selected is None:
        return get_config(config_name)
    return _materialize(selected)
|
|
|
|
|
|
def init_error_handlers(app: FastAPI) -> None:
    """Register exception handlers that answer with the ``fail`` envelope.

    Every handler:

    * records the envelope code on ``request.state.response_code`` (read by the
      access log);
    * responds with ``config.response_envelope_http_status`` as the HTTP status.

    Database and unexpected errors are also logged on ``error_logger`` with the
    traceback. Their ``str(exc)`` is returned to the client only when
    ``config.debug`` is set.
    """
    # Local import: only these handlers need it.
    from fastapi.encoders import jsonable_encoder

    def _respond(request: Request, message: str, code: int, data: Any = None, headers=None) -> JSONResponse:
        return JSONResponse(
            status_code=request.app.state.config.response_envelope_http_status,
            content=fail(message, code=code, data=data),
            headers=headers,
        )

    def _internal_detail(request: Request, exc: Exception) -> str | None:
        # str() of database/driver errors can include SQL text and bound
        # parameters; never send it to clients outside debug mode.
        return str(exc) if getattr(request.app.state.config, "debug", False) else None

    @app.exception_handler(ItiError)
    async def handle_iti_error(request: Request, exc: ItiError):
        request.state.response_code = exc.code
        return _respond(request, exc.message, exc.code, exc.data)

    @app.exception_handler(RequestValidationError)
    async def handle_validation_error(request: Request, exc: RequestValidationError):
        request.state.response_code = 422
        # exc.errors() may hold raw objects (e.g. the exception under "ctx")
        # that json.dumps cannot serialize; encode them first.
        return _respond(request, "参数验证错误", 422, jsonable_encoder(exc.errors()))

    @app.exception_handler(StarletteHTTPException)
    async def handle_http_error(request: Request, exc: StarletteHTTPException):
        request.state.response_code = exc.status_code
        message, data = _http_error_payload(exc)
        return _respond(request, message, exc.status_code, data, headers=exc.headers)

    @app.exception_handler(SQLAlchemyError)
    async def handle_db_error(request: Request, exc: SQLAlchemyError):
        request.state.response_code = 500
        error_logger.exception("database error", extra=log_extra(request))
        return _respond(request, "数据库错误", 500, _internal_detail(request, exc))

    @app.exception_handler(Exception)
    async def handle_exception(request: Request, exc: Exception):
        request.state.response_code = 500
        error_logger.exception("server error", extra=log_extra(request))
        return _respond(request, "服务器错误", 500, _internal_detail(request, exc))
|
|
|
|
|
|
def _http_error_payload(exc: StarletteHTTPException) -> tuple[str, Any]:
    """Split an HTTP exception's ``detail`` into an envelope ``(message, data)`` pair.

    * string detail: ``(detail, None)``;
    * mapping detail: the first non-empty string among its ``"message"``,
      ``"detail"`` and ``"error"`` keys becomes the message, and the mapping is
      the data;
    * anything else (including ``None`` or a mapping without a usable message):
      the standard status phrase is the message and ``detail`` is the data.
    """
    detail = exc.detail
    if isinstance(detail, str):
        return detail, None
    if isinstance(detail, Mapping):
        found = next(
            (
                text
                for text in map(detail.get, ("message", "detail", "error"))
                if isinstance(text, str) and text
            ),
            None,
        )
        if found is not None:
            return found, detail
    # detail is None here as well: returning it unchanged yields (phrase, None).
    return _http_status_phrase(exc.status_code), detail
|
|
|
|
|
|
def _http_status_phrase(status_code: int) -> str:
    """Return the standard reason phrase for *status_code*.

    Falls back to ``"HTTP Error"`` when ``HTTPStatus`` does not recognize the code.
    """
    fallback = "HTTP Error"
    try:
        status = HTTPStatus(status_code)
    except ValueError:
        return fallback
    return status.phrase
|
|
|
|
|
|
def to_plain_data(value: Any) -> Any:
    """Convert a dataclass instance to a plain ``dict``; return anything else unchanged.

    Args:
        value: Any object.

    Returns:
        ``dataclasses.asdict(value)`` for dataclass instances. Otherwise *value*
        itself, including dataclass classes, which ``asdict`` cannot convert.
    """
    # is_dataclass() is also True for the dataclass type itself, and asdict()
    # raises TypeError on a type, so only instances are converted.
    if is_dataclass(value) and not isinstance(value, type):
        return asdict(value)
    return value
|
|
|
|
|
|
def install_auto_envelope(app: FastAPI) -> None:
    """Wrap the endpoint of every eligible API route with the response envelope.

    Nothing happens unless ``config.response_envelope_enabled`` is set. These
    routes are left untouched:

    * routes that are not ``APIRoute``;
    * routes already marked with ``__iti_envelope_installed__``, so repeated
      calls are harmless;
    * raw routes, per ``_is_route_raw``;
    * routes whose dependant has no callable.

    For each remaining route the endpoint is replaced, FastAPI's cached
    dependant/handler state is rebuilt, and the route is marked as done.
    """
    settings = app.state.config
    if not settings.response_envelope_enabled:
        return
    raw_paths = tuple(settings.raw_response_paths)
    api_routes = (candidate for candidate in app.routes if isinstance(candidate, APIRoute))
    for api_route in api_routes:
        already_wrapped = getattr(api_route, "__iti_envelope_installed__", False)
        if already_wrapped or _is_route_raw(api_route, raw_paths):
            continue
        endpoint_call = api_route.dependant.call
        if endpoint_call is None:
            continue
        api_route.endpoint = _wrap_endpoint_with_envelope(endpoint_call)
        _rebuild_route_dependant(api_route)
        api_route.__iti_envelope_installed__ = True
|
|
|
|
|
|
def _is_route_raw(route: APIRoute, raw_paths: Iterable[str]) -> bool:
    """Return True when *route* must bypass the response envelope.

    A route is raw when its endpoint has a truthy ``__iti_raw_response__``
    attribute, or when ``is_raw_response_request`` matches its ``path_format``
    or its ``path`` against *raw_paths*.
    """
    if getattr(route.endpoint, "__iti_raw_response__", False):
        return True
    candidate_paths = (route.path_format, route.path)
    return any(
        is_raw_response_request(_PathOnlyRequest(candidate), raw_paths)
        for candidate in candidate_paths
    )
|
|
|
|
|
|
def _wrap_endpoint_with_envelope(func):
    """Wrap a route endpoint so its return value is delivered in the response envelope.

    The returned ``async`` wrapper keeps *func*'s metadata (``functools.wraps``),
    calls *func*, awaits the result if it is awaitable, and then:

    * returns ``Response`` objects unchanged;
    * converts other results with ``_to_jsonable``;
    * returns payloads accepted by ``is_envelope_payload`` as-is, recording
      their ``"code"`` on the request;
    * otherwise returns ``{"data": payload, "code": 200, "message": "成功"}``
      and records code 200.

    Synchronous endpoints are executed in Starlette's threadpool. Async
    endpoints are called directly on the event loop.
    """
    from functools import partial

    from starlette.concurrency import run_in_threadpool

    # FastAPI awaits async endpoints directly and runs plain ``def`` endpoints
    # in a threadpool. The wrapper is async, so FastAPI will treat every
    # wrapped endpoint as async. Sync endpoints are dispatched to the threadpool
    # here so blocking code cannot stall the event loop.
    call_on_loop = _is_async_callable(func)

    @wraps(func)
    async def wrapper(*args, **kwargs):
        if call_on_loop:
            result = func(*args, **kwargs)
        else:
            # partial() binds the arguments so an endpoint parameter named
            # ``func`` cannot collide with run_in_threadpool's own parameter.
            result = await run_in_threadpool(partial(func, *args, **kwargs))
        if isawaitable(result):
            result = await result
        if isinstance(result, Response):
            return result
        payload = _to_jsonable(result)
        if is_envelope_payload(payload):
            _mark_response_code(args, kwargs, payload["code"])
            return payload
        _mark_response_code(args, kwargs, 200)
        return {"data": payload, "code": 200, "message": "成功"}

    return wrapper


def _is_async_callable(call) -> bool:
    """Return True when calling *call* produces a coroutine (``async def`` or async ``__call__``)."""
    import inspect

    if inspect.isroutine(call):
        return inspect.iscoroutinefunction(call)
    if inspect.isclass(call):
        # Calling a class constructs an instance, never a coroutine.
        return False
    return inspect.iscoroutinefunction(getattr(call, "__call__", None))
|
|
|
|
|
|
def _to_jsonable(value: Any) -> Any:
    """Convert an endpoint return value into plain data for the envelope.

    * pydantic models: ``model_dump(by_alias=True)``;
    * dataclass instances: ``dataclasses.asdict``;
    * anything else, including ``None`` and classes: returned unchanged.
    """
    if value is None or isinstance(value, type):
        # is_dataclass() is also True for dataclass classes, which asdict()
        # rejects with TypeError; classes are passed through untouched.
        return value
    if is_dataclass(value):
        return asdict(value)
    if isinstance(value, BaseModel):
        return value.model_dump(by_alias=True)
    return value
|
|
|
|
|
|
def _mark_response_code(args: tuple[Any, ...], kwargs: dict[str, Any], code: int) -> None:
    """Record *code* as ``request.state.response_code`` for the active request.

    The request is taken from the endpoint's own arguments when one was passed,
    otherwise from the ``_current_request`` context variable. When neither
    provides a request, nothing is recorded.
    """
    target = _request_from_call(args, kwargs)
    if not target:
        target = _current_request.get()
    if target is None:
        return
    target.state.response_code = code
|
|
|
|
|
|
def _request_from_call(args: tuple[Any, ...], kwargs: dict[str, Any]) -> Request | None:
    """Return the first ``Request`` among the positional, then keyword, arguments.

    Returns ``None`` when no argument is a ``Request``.
    """
    candidates = (*args, *kwargs.values())
    return next((item for item in candidates if isinstance(item, Request)), None)
|
|
|
|
|
|
def _rebuild_route_dependant(route: APIRoute) -> None:
    """Recompute FastAPI's cached request-handling state for *route*.

    Called after ``route.endpoint`` has been replaced, so that parameter
    parsing, route-level dependencies, the body field and the ASGI app all
    reflect the new endpoint. Each step depends on the previous one, so the
    order matters. This relies on FastAPI internals (``_flat_dependant``,
    ``_embed_body_fields``, ``_should_embed_body_fields``).
    NOTE(review): presumably mirrors what ``APIRoute.__init__`` does in the
    installed FastAPI version; re-check on FastAPI upgrades.
    """
    # Re-derive the dependant for the new endpoint callable.
    route.dependant = get_dependant(
        path=route.path_format,
        call=route.endpoint,
        scope="function",
    )
    # Re-attach route-level `dependencies` at the front. Iterating in reverse
    # while inserting at index 0 keeps their declared order.
    for depends in route.dependencies[::-1]:
        route.dependant.dependencies.insert(
            0,
            get_parameterless_sub_dependant(depends=depends, path=route.path_format),
        )
    # Derived caches computed from the rebuilt dependant.
    route._flat_dependant = get_flat_dependant(route.dependant)
    route._embed_body_fields = _should_embed_body_fields(route._flat_dependant.body_params)
    route.body_field = get_body_field(
        flat_dependant=route._flat_dependant,
        name=route.unique_id,
        embed_body_fields=route._embed_body_fields,
    )
    # Rebuild the ASGI app so requests are served by a handler built from the new state.
    route.app = request_response(route.get_route_handler())
|
|
|
|
|
|
class _PathOnlyRequest:
    """Minimal request stand-in that carries only a URL path.

    Used to ask ``is_raw_response_request`` about a route's path when no real
    request exists. Exposes ``url.path`` and a ``scope`` whose ``"endpoint"``
    is ``None``.
    """

    class URL:
        """Holder exposing only ``path``."""

        def __init__(self, path: str) -> None:
            self.path = path

    def __init__(self, path: str) -> None:
        self.scope = {"endpoint": None}
        self.url = self.URL(path)
|