You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
iTi-Flask/iti/app.py

452 lines
16 KiB
Python

from __future__ import annotations
import logging
import time
import uuid
from html import escape
from importlib import resources
from http import HTTPStatus
from collections.abc import Iterable, Mapping
from contextvars import ContextVar
from contextlib import asynccontextmanager
from dataclasses import asdict, is_dataclass
from functools import wraps
from inspect import isawaitable
from typing import Any
from fastapi import FastAPI, Request
from fastapi.dependencies.utils import (
_should_embed_body_fields,
get_body_field,
get_dependant,
get_flat_dependant,
get_parameterless_sub_dependant,
)
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
from fastapi.routing import APIRoute
from fastapi.routing import request_response
from fastapi.responses import HTMLResponse, JSONResponse
from pydantic import BaseModel
from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.responses import Response
from sqlalchemy.exc import SQLAlchemyError
from iti.auth.permissions import StaticPermissionProvider
from iti.audit import init_audit
from iti.cache import CacheManager
from iti.config import BaseConfig, get_config
from iti.db import configure_db
from iti.exceptions import ItiError
from iti.health import router as health_router
from iti.limiter import SimpleLimiter
from iti.logging_config import configure_logging, log_extra
from iti.modules import init_modules
from iti.responses.auto import is_envelope_payload, is_raw_response_request
from iti.responses import fail
from iti.service_client import init_service_clients
from iti.tasks import init_task_runner
logger = logging.getLogger("iti")
error_logger = logging.getLogger("iti.error")
_current_request: ContextVar[Request | None] = ContextVar("iti_current_request", default=None)
DOCS_PICKER_TEMPLATE = "docs-picker.html"
SCALAR_TEMPLATE = "scalar.html"
def create_app(
    config_name: str | None = None,
    *,
    modules: Iterable[Any] | None = None,
    config_mapping: Mapping[str, type[BaseConfig] | BaseConfig] | type[BaseConfig] | BaseConfig | None = None,
    permission_provider: Any | None = None,
) -> FastAPI:
    """Build a fully wired iTi FastAPI application.

    Args:
        config_name: Environment name used to select the configuration
            (see ``_resolve_config``).
        modules: iTi modules handed to ``init_modules``.
        config_mapping: Optional explicit configuration: a config class or
            instance, or a mapping of environment name to class/instance.
        permission_provider: Stored on ``app.state.permission_provider``;
            defaults to ``StaticPermissionProvider()``.

    Returns:
        The configured ``FastAPI`` instance with docs, middlewares, database,
        error handlers, service clients, task runner, audit, modules and the
        response auto-envelope installed.
    """
    config = _resolve_config(config_name, config_mapping)
    configure_logging(config)
    app = FastAPI(
        title=config.app_name,
        debug=config.debug,
        lifespan=_build_lifespan(config),
        # The default docs routes are replaced by install_docs' /docs picker.
        docs_url=None,
        redoc_url=None,
    )
    install_docs(app)
    app.state.config = config
    app.state.cache = CacheManager(default_timeout=config.cache_default_timeout)
    app.state.limiter = SimpleLimiter(enabled=config.ratelimit_enabled)
    app.state.permission_provider = permission_provider or StaticPermissionProvider()
    init_middlewares(app)
    if config.cors_origins:
        # Added after the request-context middleware, so Starlette places it
        # outside (it is the outermost middleware of the two).
        app.add_middleware(
            CORSMiddleware,
            allow_origins=config.cors_origins,
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )
    engine, sessionmaker = configure_db(
        config.database_url,
        echo=config.sqlalchemy_echo,
        pool_pre_ping=config.sqlalchemy_pool_pre_ping,
    )
    app.state.db_engine = engine
    app.state.db_sessionmaker = sessionmaker
    init_error_handlers(app)
    init_service_clients(app, config.services)
    init_task_runner(app)
    init_audit(app)
    module_registry = init_modules(app, modules)
    app.state.iti_modules = module_registry
    if config.health_enabled:
        app.include_router(health_router)
    module_registry.run_phase("register_routes", app)
    module_registry.run_phase("register_permissions", app)
    module_registry.run_phase("register_menu_seed", app)
    # Must run last: it wraps every APIRoute present at this point, including
    # the routes registered by modules above.
    install_auto_envelope(app)
    return app


def _build_lifespan(config: BaseConfig):
    """Return the lifespan context manager that starts and stops background services."""

    @asynccontextmanager
    async def lifespan(app: FastAPI):
        runner = getattr(app.state, "iti_task_runner", None)
        audit_dispatcher = getattr(app.state, "audit_dispatcher", None)
        try:
            if audit_dispatcher:
                audit_dispatcher.start()
            if runner and config.tasks_enabled:
                runner.start()
            yield
        finally:
            # Runs on normal shutdown and also when startup or the app raises,
            # so background workers and client connections are not leaked.
            _stop_services(app, runner, audit_dispatcher)

    return lifespan


def _stop_services(app: FastAPI, runner: Any, audit_dispatcher: Any) -> None:
    """Stop the task runner, the audit dispatcher and every service client, in that order.

    Every step is attempted even if an earlier one fails; failures are logged
    to ``iti.error`` rather than aborting the remaining shutdown steps.
    """
    steps = []
    if runner:
        # stop() is called even when tasks were disabled (matching the
        # original shutdown), so the runner must tolerate an unstarted stop.
        steps.append(("task runner", runner.stop))
    if audit_dispatcher:
        steps.append(("audit dispatcher", audit_dispatcher.stop))
    for client in getattr(app.state, "iti_service_clients", {}).values():
        steps.append(("service client", client.close))
    for label, stop in steps:
        try:
            stop()
        except Exception:
            error_logger.exception("failed to stop %s", label)
def install_docs(app: FastAPI) -> None:
    """Register ``GET /docs`` (hidden from the OpenAPI schema) on *app*.

    ``/docs?ui=swagger``, ``/docs?ui=scalar`` and ``/docs?ui=redoc`` render the
    requested UI when it is enabled via ``config.docs_ui_enabled``; any other
    request (no ``ui``, unknown or disabled UI) renders the picker page.
    """

    def spec_url() -> str:
        return app.openapi_url or "/openapi.json"

    @app.get("/docs", include_in_schema=False)
    def docs(ui: str | None = None) -> HTMLResponse:
        available = _enabled_doc_options(app)
        if ui in available:
            if ui == "swagger":
                return get_swagger_ui_html(
                    openapi_url=spec_url(),
                    title=f"{app.title} - Swagger UI",
                )
            if ui == "scalar":
                return _scalar_docs_html(app)
            if ui == "redoc":
                return get_redoc_html(
                    openapi_url=spec_url(),
                    title=f"{app.title} - ReDoc",
                )
        return _docs_picker_html(app, available)
def _enabled_doc_options(app: FastAPI) -> dict[str, dict[str, str]]:
configured = getattr(app.state.config, "docs_ui_enabled", ["swagger", "scalar", "redoc"])
all_options = {
"swagger": {
"class": "swagger",
"label": "Swagger",
"abbr": "SW",
"description": "传统交互文档,适合快速试接口。",
},
"scalar": {
"class": "scalar",
"label": "Scalar",
"abbr": "SC",
"description": "现代接口参考,适合阅读和调试。",
},
"redoc": {
"class": "redoc",
"label": "ReDoc",
"abbr": "RD",
"description": "结构化阅读文档,适合查看模型关系。",
},
}
return {name: all_options[name] for name in configured if name in all_options}
def _docs_picker_html(app: FastAPI, doc_options: Mapping[str, dict[str, str]]) -> HTMLResponse:
    """Render the docs picker page: one card per enabled UI in *doc_options*.

    The app title is HTML-escaped before being inserted into the template.
    """
    safe_title = escape(app.title)
    cards = [_doc_option_card(key, meta) for key, meta in doc_options.items()]
    page = _render_template(
        DOCS_PICKER_TEMPLATE,
        {"title": safe_title, "option_cards": "\n".join(cards)},
    )
    return HTMLResponse(page)
def _doc_option_card(name: str, option: Mapping[str, str]) -> str:
label = escape(option["label"])
class_name = escape(option["class"])
abbr = escape(option["abbr"])
description = escape(option["description"])
href = escape(f"?ui={name}")
return f"""<a class="doc-option {class_name}" href="{href}" aria-label="打开 {label} 文档">
<span class="mark" aria-hidden="true">{abbr}</span>
<span class="copy">
<strong>{label}</strong>
<span>{description}</span>
</span>
<span class="arrow" aria-hidden="true">
<svg viewBox="0 0 24 24">
<path d="M5 12h14"></path>
<path d="m13 6 6 6-6 6"></path>
</svg>
</span>
</a>"""
def _scalar_docs_html(app: FastAPI) -> HTMLResponse:
    """Render the Scalar API reference page for *app*'s OpenAPI document."""
    spec_url = app.openapi_url or "/openapi.json"
    context = {
        "title": escape(app.title),
        "openapi_url": escape(spec_url),
    }
    return HTMLResponse(_render_template(SCALAR_TEMPLATE, context))
def _render_template(name: str, values: Mapping[str, str]) -> str:
template = resources.files("iti.templates").joinpath(name).read_text(encoding="utf-8")
for key, value in values.items():
template = template.replace("{{ " + key + " }}", value)
return template
def init_middlewares(app: FastAPI) -> None:
    """Install the per-request context middleware on *app*.

    For each HTTP request the middleware:

    * publishes the request in ``_current_request`` so helpers without a
      ``request`` argument can reach it;
    * takes ``X-Trace-Id`` / ``X-Request-Id`` from the request headers or
      generates random hex IDs, stores them on ``request.state`` and adds them
      to the response unless those headers are already set;
    * writes one access-log line via ``_log_request``; if the downstream app
      raises, it logs status 500 and re-raises.
    """

    @app.middleware("http")
    async def request_context_middleware(request: Request, call_next):
        token = _current_request.set(request)
        # try/finally restores the context variable on every exit path,
        # including a failure inside logging or header handling.
        try:
            trace_id = request.headers.get("X-Trace-Id") or uuid.uuid4().hex
            request_id = request.headers.get("X-Request-Id") or uuid.uuid4().hex
            request.state.trace_id = trace_id
            request.state.request_id = request_id
            started_at = time.perf_counter()
            try:
                response = await call_next(request)
            except Exception:
                # Keep a code already recorded by an error handler; default to 500.
                request.state.response_code = getattr(request.state, "response_code", 500)
                _log_request(request, started_at, 500)
                raise
            response.headers.setdefault("X-Trace-Id", trace_id)
            response.headers.setdefault("X-Request-Id", request_id)
            _log_request(request, started_at, response.status_code)
            return response
        finally:
            _current_request.reset(token)
def _log_request(request: Request, started_at: float, status_code: int) -> None:
    """Write one access-log line for *request* to the ``iti`` logger.

    ``started_at`` is a ``time.perf_counter()`` reading taken when the request
    began; the elapsed time is logged in milliseconds rounded to 2 decimals.
    ``code`` is the envelope code from ``request.state.response_code`` or ``-``.
    """
    elapsed_ms = round(1000 * (time.perf_counter() - started_at), 2)
    client = request.client
    client_ip = client.host if client else "-"
    business_code = getattr(request.state, "response_code", "-")
    logger.info(
        "request method=%s path=%s status=%s code=%s durationMs=%s ip=%s",
        request.method,
        request.url.path,
        status_code,
        business_code,
        elapsed_ms,
        client_ip,
        extra=log_extra(request),
    )
def _resolve_config(
config_name: str | None,
config_mapping: Mapping[str, type[BaseConfig] | BaseConfig] | type[BaseConfig] | BaseConfig | None,
) -> BaseConfig:
if config_mapping is None:
return get_config(config_name)
if isinstance(config_mapping, Mapping):
env_name = config_name or "dev"
value = config_mapping.get(env_name, config_mapping.get("default"))
if value is None:
return get_config(config_name)
return value() if isinstance(value, type) else value
return config_mapping() if isinstance(config_mapping, type) else config_mapping
def init_error_handlers(app: FastAPI) -> None:
    """Register handlers that render every error as the standard ``fail`` envelope.

    Every handler records the envelope code on ``request.state.response_code``
    (read by the access log) and answers with HTTP status
    ``config.response_envelope_http_status``, carrying the real error code
    inside the envelope.

    Envelope ``data`` is passed through ``jsonable_encoder`` so values that
    ``JSONResponse`` cannot serialise directly (e.g. exception objects in
    pydantic v2 validation-error ``ctx``, or models in ``ItiError.data``) do not
    turn the error response into a second, unhandled serialisation failure.

    Database and unexpected errors are logged with a traceback to ``iti.error``;
    their ``str(exc)`` is echoed to the client only when ``config.debug`` is
    true, because it can contain SQL statements, parameters or other internals.
    """
    from fastapi.encoders import jsonable_encoder

    def envelope(
        request: Request,
        message: str,
        code: int,
        data: Any = None,
        headers: Mapping[str, str] | None = None,
    ) -> JSONResponse:
        request.state.response_code = code
        return JSONResponse(
            status_code=request.app.state.config.response_envelope_http_status,
            content=fail(message, code=code, data=jsonable_encoder(data)),
            headers=headers,
        )

    def internal_detail(request: Request, exc: Exception) -> str | None:
        # Raw exception text is only safe to expose while debugging.
        return str(exc) if getattr(request.app.state.config, "debug", False) else None

    @app.exception_handler(ItiError)
    async def handle_iti_error(request: Request, exc: ItiError):
        return envelope(request, exc.message, exc.code, exc.data)

    @app.exception_handler(RequestValidationError)
    async def handle_validation_error(request: Request, exc: RequestValidationError):
        return envelope(request, "参数验证错误", 422, exc.errors())

    @app.exception_handler(StarletteHTTPException)
    async def handle_http_error(request: Request, exc: StarletteHTTPException):
        message, data = _http_error_payload(exc)
        return envelope(request, message, exc.status_code, data, headers=exc.headers)

    @app.exception_handler(SQLAlchemyError)
    async def handle_db_error(request: Request, exc: SQLAlchemyError):
        # Set before logging so log_extra sees the same state as before.
        request.state.response_code = 500
        error_logger.exception("database error", extra=log_extra(request))
        return envelope(request, "数据库错误", 500, internal_detail(request, exc))

    @app.exception_handler(Exception)
    async def handle_exception(request: Request, exc: Exception):
        request.state.response_code = 500
        error_logger.exception("server error", extra=log_extra(request))
        return envelope(request, "服务器错误", 500, internal_detail(request, exc))
def _http_error_payload(exc: StarletteHTTPException) -> tuple[str, Any]:
detail = exc.detail
if isinstance(detail, str):
return detail, None
if detail is None:
return _http_status_phrase(exc.status_code), None
if isinstance(detail, Mapping):
for key in ("message", "detail", "error"):
value = detail.get(key)
if isinstance(value, str) and value:
return value, detail
return _http_status_phrase(exc.status_code), detail
def _http_status_phrase(status_code: int) -> str:
try:
return HTTPStatus(status_code).phrase
except ValueError:
return "HTTP Error"
def to_plain_data(value: Any) -> Any:
    """Return ``asdict(value)`` for dataclass instances, otherwise *value* unchanged.

    ``dataclasses.is_dataclass`` is also true for dataclass *classes*, on which
    ``asdict`` raises ``TypeError``; such classes are returned unchanged.
    """
    if is_dataclass(value) and not isinstance(value, type):
        return asdict(value)
    return value
def install_auto_envelope(app: FastAPI) -> None:
    """Wrap every eligible ``APIRoute`` endpoint so its result is enveloped.

    Does nothing unless ``config.response_envelope_enabled`` is true. Routes
    already wrapped (``__iti_envelope_installed__``), routes classified as raw
    by ``_is_route_raw``, and routes whose dependant has no callable are
    skipped. For the others the endpoint is replaced by the envelope wrapper
    and FastAPI's cached request handling is rebuilt.
    """
    settings = app.state.config
    if not settings.response_envelope_enabled:
        return
    raw_paths = tuple(settings.raw_response_paths)
    api_routes = (candidate for candidate in app.routes if isinstance(candidate, APIRoute))
    for route in api_routes:
        already_wrapped = getattr(route, "__iti_envelope_installed__", False)
        if already_wrapped or _is_route_raw(route, raw_paths):
            continue
        target = route.dependant.call
        if target is None:
            continue
        route.endpoint = _wrap_endpoint_with_envelope(target)
        _rebuild_route_dependant(route)
        setattr(route, "__iti_envelope_installed__", True)
def _is_route_raw(route: APIRoute, raw_paths: Iterable[str]) -> bool:
    """Return True if *route* must not be wrapped in the response envelope.

    A route is raw when its endpoint carries ``__iti_raw_response__`` or when
    ``is_raw_response_request`` accepts either its ``path_format`` or ``path``.
    """
    if getattr(route.endpoint, "__iti_raw_response__", False):
        return True
    return any(
        is_raw_response_request(_PathOnlyRequest(candidate), raw_paths)
        for candidate in (route.path_format, route.path)
    )
def _wrap_endpoint_with_envelope(func):
    """Wrap endpoint *func* so its result is returned in the standard envelope.

    The wrapper:

    * returns ``Response`` objects untouched;
    * converts the result with ``_to_jsonable``;
    * returns payloads accepted by ``is_envelope_payload`` as-is, otherwise
      wraps them as ``{"data": ..., "code": 200, "message": "成功"}``;
    * records the envelope code on ``request.state.response_code``.

    The wrapper is always ``async``, so FastAPI awaits it on the event loop.
    A synchronous *func* is therefore dispatched through Starlette's thread
    pool, as FastAPI itself would do for an unwrapped ``def`` endpoint, so
    blocking work (e.g. synchronous database calls) does not stall the loop.
    ``functools.wraps`` keeps the original signature visible to FastAPI.
    """
    from inspect import iscoroutinefunction

    from starlette.concurrency import run_in_threadpool

    is_async = iscoroutinefunction(func)

    @wraps(func)
    async def wrapper(*args, **kwargs):
        if is_async:
            result = func(*args, **kwargs)
        else:
            result = await run_in_threadpool(func, *args, **kwargs)
        # Awaits coroutines and any awaitable a sync callable may return.
        if isawaitable(result):
            result = await result
        if isinstance(result, Response):
            return result
        payload = _to_jsonable(result)
        if is_envelope_payload(payload):
            _mark_response_code(args, kwargs, payload["code"])
            return payload
        _mark_response_code(args, kwargs, 200)
        return {"data": payload, "code": 200, "message": "成功"}

    return wrapper
def _to_jsonable(value: Any) -> Any:
if value is None:
return None
if isinstance(value, BaseModel):
return value.model_dump(by_alias=True)
if is_dataclass(value):
return asdict(value)
return value
def _mark_response_code(args: tuple[Any, ...], kwargs: dict[str, Any], code: int) -> None:
    """Store *code* on ``request.state.response_code`` for the current request.

    The request is taken from the endpoint call arguments when present,
    otherwise from the ``_current_request`` context variable; if neither is
    available, nothing happens.
    """
    target = _request_from_call(args, kwargs) or _current_request.get()
    if target is None:
        return
    target.state.response_code = code
def _request_from_call(args: tuple[Any, ...], kwargs: dict[str, Any]) -> Request | None:
for value in list(args) + list(kwargs.values()):
if isinstance(value, Request):
return value
return None
def _rebuild_route_dependant(route: APIRoute) -> None:
    """Recompute FastAPI's cached request-handling state after ``route.endpoint`` changed.

    ``install_auto_envelope`` replaces ``route.endpoint`` after the route was
    built; this helper redoes the dependency graph, flattened dependant, body
    field and ASGI app for the new endpoint.

    NOTE(review): this mirrors private internals of ``APIRoute.__init__``
    (``_flat_dependant``, ``_embed_body_fields``, ``_should_embed_body_fields``
    and the ``scope=`` keyword of ``get_dependant``); presumably it must be kept
    in sync with the pinned FastAPI version -- confirm on every upgrade.
    NOTE(review): response-model state (e.g. ``route.response_field``) is not
    rebuilt here -- confirm that is intended for enveloped routes.
    """
    # The envelope wrapper is built with functools.wraps, so signature
    # inspection should follow __wrapped__ back to the original parameters.
    route.dependant = get_dependant(
        path=route.path_format,
        call=route.endpoint,
        scope="function",
    )
    # Re-attach route-level dependencies at the front; iterating in reverse
    # while inserting at index 0 preserves their declared order.
    for depends in route.dependencies[::-1]:
        route.dependant.dependencies.insert(
            0,
            get_parameterless_sub_dependant(depends=depends, path=route.path_format),
        )
    route._flat_dependant = get_flat_dependant(route.dependant)
    route._embed_body_fields = _should_embed_body_fields(route._flat_dependant.body_params)
    route.body_field = get_body_field(
        flat_dependant=route._flat_dependant,
        name=route.unique_id,
        embed_body_fields=route._embed_body_fields,
    )
    # Rebuilt last so the new handler is created after the state above is refreshed.
    route.app = request_response(route.get_route_handler())
class _PathOnlyRequest:
def __init__(self, path: str) -> None:
self.url = type("URL", (), {"path": path})()
self.scope = {"endpoint": None}