You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
554 lines
19 KiB
Python
554 lines
19 KiB
Python
from __future__ import annotations
|
|
|
|
import logging
|
|
import time
|
|
import uuid
|
|
from html import escape
|
|
from importlib import resources
|
|
from http import HTTPStatus
|
|
from collections.abc import Iterable, Mapping
|
|
from contextvars import ContextVar
|
|
from contextlib import asynccontextmanager
|
|
from dataclasses import asdict, is_dataclass
|
|
from functools import wraps
|
|
from inspect import isawaitable
|
|
from typing import Any
|
|
|
|
from fastapi import FastAPI, Request
|
|
from fastapi.dependencies.utils import (
|
|
_should_embed_body_fields,
|
|
get_body_field,
|
|
get_dependant,
|
|
get_flat_dependant,
|
|
get_parameterless_sub_dependant,
|
|
)
|
|
from fastapi.exceptions import RequestValidationError
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
|
|
from fastapi.routing import APIRoute
|
|
from fastapi.routing import request_response
|
|
from fastapi.responses import HTMLResponse, JSONResponse
|
|
from pydantic import BaseModel
|
|
from starlette.exceptions import HTTPException as StarletteHTTPException
|
|
from starlette.responses import Response
|
|
from sqlalchemy.exc import SQLAlchemyError
|
|
|
|
from iti.auth.permissions import StaticPermissionProvider
|
|
from iti.audit import init_audit
|
|
from iti.cache import CacheManager
|
|
from iti.config import BaseConfig, get_config, get_env_name
|
|
from iti.db import configure_db
|
|
from iti.exceptions import ItiError
|
|
from iti.health import router as health_router
|
|
from iti.limiter import SimpleLimiter
|
|
from iti.logging_config import configure_logging, log_extra
|
|
from iti.modules import init_modules
|
|
from iti.exchange import get_exchange_registry
|
|
from iti.exchange import models as _exchange_models
|
|
from iti.responses.auto import is_envelope_payload, is_raw_response_request
|
|
from iti.responses import fail
|
|
from iti.service_client import init_service_clients
|
|
from iti.tasks import init_task_runner
|
|
|
|
# Access-log channel ("iti") and the channel used for unexpected-error reports ("iti.error").
logger = logging.getLogger("iti")
error_logger = logging.getLogger("iti.error")

# The request currently being served. Set and reset by the request-context
# middleware and read back when recording the business response code.
_current_request: ContextVar[Request | None] = ContextVar("iti_current_request", default=None)

# File names of HTML templates loaded from the ``iti.templates`` package.
DOCS_PICKER_TEMPLATE = "docs-picker.html"
SCALAR_TEMPLATE = "scalar.html"

# Keys of an OpenAPI path item that denote operations; any other key
# (e.g. "parameters") is skipped when collecting operation tags.
OPENAPI_HTTP_METHODS = set("get put post delete options head patch trace".split())
|
|
|
|
|
|
def create_app(
    config_name: str | None = None,
    *,
    modules: Iterable[Any] | None = None,
    config_mapping: Mapping[str, type[BaseConfig] | BaseConfig] | type[BaseConfig] | BaseConfig | None = None,
    permission_provider: Any | None = None,
) -> FastAPI:
    """Build and wire a fully configured iti FastAPI application.

    Args:
        config_name: Configuration/environment name forwarded to ``_resolve_config``.
        modules: Module objects handed to ``init_modules``. When
            ``config.exchange_enabled`` is true and no module has ``name == "exchange"``,
            the exchange module is appended automatically.
        config_mapping: Explicit configuration source: a mapping of environment
            name to config class/instance (with optional ``"default"`` key), or a
            single config class/instance.
        permission_provider: Stored on ``app.state.permission_provider``; a
            ``StaticPermissionProvider()`` is used when it is falsy.

    Returns:
        The configured application. Module phases ``register_routes``,
        ``register_permissions`` and ``register_menu_seed`` have run, and the
        auto-envelope has been installed on every route registered by then.
    """
    config = _resolve_config(config_name, config_mapping)
    configure_logging(config)

    app = FastAPI(
        title=config.app_name,
        debug=config.debug,
        lifespan=_make_lifespan(config),
        docs_url=None,
        redoc_url=None,
    )
    install_docs(app)
    install_openapi_tag_groups(app)
    app.state.config = config
    app.state.cache = CacheManager(default_timeout=config.cache_default_timeout)
    app.state.limiter = SimpleLimiter(enabled=config.ratelimit_enabled)
    app.state.permission_provider = permission_provider or StaticPermissionProvider()
    app.state.exchange_enabled = config.exchange_enabled

    init_middlewares(app)

    if config.cors_origins:
        # NOTE(review): credentials are allowed for every configured origin; confirm
        # that ``cors_origins`` never contains "*" where cookies/auth headers matter.
        app.add_middleware(
            CORSMiddleware,
            allow_origins=config.cors_origins,
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )

    engine, sessionmaker = configure_db(
        config.database_url,
        echo=config.sqlalchemy_echo,
        pool_pre_ping=config.sqlalchemy_pool_pre_ping,
    )
    app.state.db_engine = engine
    app.state.db_sessionmaker = sessionmaker

    init_error_handlers(app)
    init_service_clients(app, config.services)
    init_task_runner(app)
    # Return value unused; presumably called to initialise the app's exchange registry.
    get_exchange_registry(app)
    init_audit(app)

    module_list = list(modules or [])
    if config.exchange_enabled and not any(
        getattr(module, "name", None) == "exchange" for module in module_list
    ):
        # Deferred import: only needed when the exchange module must be added implicitly.
        from iti.exchange.module import create_exchange_module

        module_list.append(create_exchange_module())
    module_registry = init_modules(app, module_list)
    app.state.iti_modules = module_registry

    if config.health_enabled:
        app.include_router(health_router)
    module_registry.run_phase("register_routes", app)
    module_registry.run_phase("register_permissions", app)
    module_registry.run_phase("register_menu_seed", app)
    # Runs last on purpose: only routes that exist at this point get wrapped.
    install_auto_envelope(app)
    return app


def _make_lifespan(config: BaseConfig):
    """Return the lifespan context manager that starts and stops background services."""

    @asynccontextmanager
    async def lifespan(app: FastAPI):
        runner = getattr(app.state, "iti_task_runner", None)
        audit_dispatcher = getattr(app.state, "audit_dispatcher", None)
        if audit_dispatcher:
            audit_dispatcher.start()
        # Everything after the audit dispatcher has started is covered by the
        # ``finally`` so a failing runner start, or an exception delivered at
        # ``yield``, still releases what was already acquired.
        try:
            if runner and config.tasks_enabled:
                runner.start()
            yield
        finally:
            _shutdown_lifespan_resources(app, runner, audit_dispatcher)

    return lifespan


def _shutdown_lifespan_resources(app: FastAPI, runner: Any, audit_dispatcher: Any) -> None:
    """Stop the task runner and audit dispatcher and close service clients.

    Every step is attempted even if an earlier one fails. Each failure is logged
    to ``error_logger`` and the first one is re-raised once all steps have run.
    """
    steps: list[tuple[str, Any]] = []
    if runner:
        steps.append(("task runner", runner.stop))
    if audit_dispatcher:
        steps.append(("audit dispatcher", audit_dispatcher.stop))
    for client in getattr(app.state, "iti_service_clients", {}).values():
        steps.append(("service client", client.close))

    first_error: BaseException | None = None
    for label, step in steps:
        try:
            step()
        except Exception as exc:  # keep shutting down the remaining resources
            error_logger.exception("shutdown step failed: %s", label)
            if first_error is None:
                first_error = exc
    if first_error is not None:
        raise first_error
|
|
|
|
|
|
def install_docs(app: FastAPI) -> None:
    """Register the ``GET /docs`` landing page on *app* (hidden from the schema).

    ``?ui=swagger``, ``?ui=scalar`` or ``?ui=redoc`` renders that documentation UI
    when it is enabled (see ``_enabled_doc_options``). A missing, unknown or
    disabled ``ui`` value renders the picker page listing the enabled UIs.
    """

    @app.get("/docs", include_in_schema=False)
    def docs(ui: str | None = None) -> HTMLResponse:
        enabled = _enabled_doc_options(app)
        if ui not in enabled:
            return _docs_picker_html(app, enabled)

        # Resolved per request, falling back to FastAPI's default schema path.
        schema_url = app.openapi_url or "/openapi.json"
        if ui == "swagger":
            return get_swagger_ui_html(
                openapi_url=schema_url,
                title=f"{app.title} - Swagger UI",
            )
        if ui == "redoc":
            return get_redoc_html(
                openapi_url=schema_url,
                title=f"{app.title} - ReDoc",
            )
        if ui == "scalar":
            return _scalar_docs_html(app)
        return _docs_picker_html(app, enabled)
|
|
|
|
|
|
def install_openapi_tag_groups(app: FastAPI) -> None:
    """Wrap ``app.openapi`` so every generated schema gets tag-group post-processing.

    The previous ``app.openapi`` callable is captured and still produces the
    schema; ``_apply_openapi_tag_groups`` then edits that schema in place
    before it is returned.
    """
    base_openapi = app.openapi

    def _openapi_with_groups() -> dict[str, Any]:
        generated = base_openapi()
        _apply_openapi_tag_groups(generated)
        return generated

    app.openapi = _openapi_with_groups
|
|
|
|
|
|
def _enabled_doc_options(app: FastAPI) -> dict[str, dict[str, str]]:
    """Return display metadata for the documentation UIs enabled on *app*.

    The enabled names come from ``app.state.config.docs_ui_enabled`` (all three
    UIs when the attribute is absent). Unknown names are ignored. The result
    preserves the configured order, keyed by UI name, and each value carries
    ``class``, ``label``, ``abbr`` and ``description`` strings.
    """
    requested = getattr(app.state.config, "docs_ui_enabled", ["swagger", "scalar", "redoc"])
    catalog = {
        "swagger": {
            "class": "swagger",
            "label": "Swagger",
            "abbr": "SW",
            "description": "传统交互文档,适合快速试接口。",
        },
        "scalar": {
            "class": "scalar",
            "label": "Scalar",
            "abbr": "SC",
            "description": "现代接口参考,适合阅读和调试。",
        },
        "redoc": {
            "class": "redoc",
            "label": "ReDoc",
            "abbr": "RD",
            "description": "结构化阅读文档,适合查看模型关系。",
        },
    }
    enabled: dict[str, dict[str, str]] = {}
    for key in requested:
        if key in catalog:
            enabled[key] = catalog[key]
    return enabled
|
|
|
|
|
|
def _docs_picker_html(app: FastAPI, doc_options: Mapping[str, dict[str, str]]) -> HTMLResponse:
    """Render the documentation picker page with one card per enabled UI.

    The app title is HTML-escaped. Cards are newline-joined in ``doc_options``
    order and substituted into the ``DOCS_PICKER_TEMPLATE`` template.
    """
    cards = [_doc_option_card(key, meta) for key, meta in doc_options.items()]
    context = {
        "title": escape(app.title),
        "option_cards": "\n".join(cards),
    }
    return HTMLResponse(_render_template(DOCS_PICKER_TEMPLATE, context))
|
|
|
|
|
|
def _doc_option_card(name: str, option: Mapping[str, str]) -> str:
    """Build the HTML anchor card that links the picker to one documentation UI.

    Args:
        name: UI key used as the ``?ui=`` query value of the link.
        option: Display metadata with ``label``, ``class``, ``abbr`` and
            ``description`` keys (see ``_enabled_doc_options``).

    Returns:
        The card markup. Every interpolated value is HTML-escaped.

    The markup is assembled line by line. The previous triple-quoted literal
    carried stray ``|`` lines that would have been rendered as visible pipe
    characters inside every card.
    """
    label = escape(option["label"])
    class_name = escape(option["class"])
    abbr = escape(option["abbr"])
    description = escape(option["description"])
    href = escape(f"?ui={name}")
    return "\n".join(
        (
            f'<a class="doc-option {class_name}" href="{href}" aria-label="打开 {label} 文档">',
            f'  <span class="mark" aria-hidden="true">{abbr}</span>',
            '  <span class="copy">',
            f"    <strong>{label}</strong>",
            f"    <span>{description}</span>",
            "  </span>",
            '  <span class="arrow" aria-hidden="true">',
            '    <svg viewBox="0 0 24 24">',
            '      <path d="M5 12h14"></path>',
            '      <path d="m13 6 6 6-6 6"></path>',
            "    </svg>",
            "  </span>",
            "</a>",
        )
    )
|
|
|
|
|
|
def _scalar_docs_html(app: FastAPI) -> HTMLResponse:
    """Render the Scalar API reference page for *app*.

    The escaped app title and OpenAPI URL (``app.openapi_url``, defaulting to
    ``/openapi.json``) are substituted into ``SCALAR_TEMPLATE``.
    """
    spec_url = app.openapi_url or "/openapi.json"
    page = _render_template(
        SCALAR_TEMPLATE,
        {"title": escape(app.title), "openapi_url": escape(spec_url)},
    )
    return HTMLResponse(page)
|
|
|
|
|
|
def _render_template(name: str, values: Mapping[str, str]) -> str:
    """Load a template from the ``iti.templates`` package and fill ``{{ key }}`` placeholders.

    Args:
        name: Template file name inside ``iti.templates`` (read as UTF-8).
        values: Placeholder key -> replacement text. Callers are responsible for
            HTML-escaping the values.

    Returns:
        The rendered text. Placeholders without a matching key are left untouched.

    Substitution is a single pass over the original template. The previous
    loop of ``str.replace`` calls let a value that itself contained
    ``{{ other_key }}`` (e.g. an app title) be expanded by a later replacement.
    """
    import re  # local import: only this helper needs regular expressions

    template = resources.files("iti.templates").joinpath(name).read_text(encoding="utf-8")
    replacements = {"{{ " + key + " }}": value for key, value in values.items()}
    if not replacements:
        return template
    # Longest placeholders first so alternation never prefers a shorter overlapping match.
    alternatives = sorted(replacements, key=len, reverse=True)
    pattern = re.compile("|".join(re.escape(placeholder) for placeholder in alternatives))
    # A callable replacement inserts values literally (no backslash-escape processing).
    return pattern.sub(lambda match: replacements[match.group(0)], template)
|
|
|
|
|
|
def _apply_openapi_tag_groups(schema: dict[str, Any]) -> None:
    """Add an ``x-tagGroups`` section to *schema* when dotted tags are used (in place).

    Nothing changes unless at least one tag has the ``group.name`` form. In that
    case:
    * ``schema["tags"]`` is rewritten to list every used tag, adding
      ``x-displayName`` where applicable;
    * ``schema["x-tagGroups"]`` groups the tags by prefix, in first-seen order.
    """
    ordered_tags = _openapi_tag_names(schema)
    has_dotted_tag = any(_openapi_tag_display_name(tag_name) for tag_name in ordered_tags)
    if not has_dotted_tag:
        return

    schema["tags"] = _openapi_tag_objects(schema, ordered_tags)

    # dicts keep insertion order, so groups come out in first-seen order.
    buckets: dict[str, dict[str, Any]] = {}
    for tag_name in ordered_tags:
        group_name = _openapi_tag_group_name(tag_name)
        bucket = buckets.setdefault(group_name, {"name": group_name, "tags": []})
        bucket["tags"].append(tag_name)
    schema["x-tagGroups"] = list(buckets.values())
|
|
|
|
|
|
def _openapi_tag_names(schema: Mapping[str, Any]) -> list[str]:
    """Collect distinct tag names from an OpenAPI schema in first-seen order.

    Names declared in the top-level ``tags`` list come first, followed by tags
    used by operations under ``paths``. Only path-item keys listed in
    ``OPENAPI_HTTP_METHODS`` count as operations. Non-string names and
    malformed entries are skipped.
    """
    ordered: dict[str, None] = {}  # insertion-ordered set

    def remember(candidate: Any) -> None:
        if isinstance(candidate, str):
            ordered.setdefault(candidate, None)

    for declared in schema.get("tags") or []:
        if isinstance(declared, Mapping):
            remember(declared.get("name"))

    paths = schema.get("paths") or {}
    if isinstance(paths, Mapping):
        for path_item in paths.values():
            if not isinstance(path_item, Mapping):
                continue
            for method, operation in path_item.items():
                if method.lower() in OPENAPI_HTTP_METHODS and isinstance(operation, Mapping):
                    for used_tag in operation.get("tags") or []:
                        remember(used_tag)
    return list(ordered)
|
|
|
|
|
|
def _openapi_tag_objects(schema: Mapping[str, Any], tag_names: list[str]) -> list[dict[str, Any]]:
    """Build OpenAPI tag objects for *tag_names*, in that order.

    The first declaration of a tag in ``schema["tags"]`` is copied (later
    duplicates are ignored). Undeclared tags get a bare ``{"name": ...}``.
    Dotted tags receive an ``x-displayName`` unless one is already declared.
    """
    declared: dict[str, dict[str, Any]] = {}
    for entry in schema.get("tags") or []:
        if isinstance(entry, Mapping) and isinstance(entry.get("name"), str):
            declared.setdefault(entry["name"], dict(entry))

    result: list[dict[str, Any]] = []
    for tag_name in tag_names:
        tag_object = {**declared.get(tag_name, {"name": tag_name}), "name": tag_name}
        label = _openapi_tag_display_name(tag_name)
        if label and "x-displayName" not in tag_object:
            tag_object["x-displayName"] = label
        result.append(tag_object)
    return result
|
|
|
|
|
|
def _openapi_tag_group_name(tag: str) -> str:
    """Return the group for *tag*: the text before the first ``.``.

    Only a ``prefix.suffix`` tag with both parts non-empty is grouped under its
    prefix; any other tag forms a group named after itself.
    """
    head, dot, tail = tag.partition(".")
    return head if dot and head and tail else tag
|
|
|
|
|
|
def _openapi_tag_display_name(tag: str) -> str | None:
    """Return the text after the first ``.`` of a ``prefix.suffix`` tag, else ``None``.

    Both sides of the dot must be non-empty for a display name to exist.
    """
    group, dot, remainder = tag.partition(".")
    if not (dot and group and remainder):
        return None
    return remainder
|
|
|
|
|
|
def init_middlewares(app: FastAPI) -> None:
    """Install the per-request context middleware on *app*.

    For each HTTP request the middleware:
    * publishes the request in the module-level ``_current_request`` ContextVar;
    * takes ``X-Trace-Id`` / ``X-Request-Id`` from the request headers, or
      generates random hex ids, and stores them on ``request.state``;
    * echoes both ids on the response without overwriting values already set;
    * logs one access line through ``_log_request``. An exception escaping the
      downstream app is logged as status 500 and re-raised.
    """

    @app.middleware("http")
    async def request_context_middleware(request: Request, call_next):
        token = _current_request.set(request)
        # The reset lives in ``finally`` so it also runs when header handling or
        # logging raises; previously those paths skipped it.
        try:
            trace_id = request.headers.get("X-Trace-Id") or uuid.uuid4().hex
            request_id = request.headers.get("X-Request-Id") or uuid.uuid4().hex
            request.state.trace_id = trace_id
            request.state.request_id = request_id
            started_at = time.perf_counter()
            try:
                response = await call_next(request)
            except Exception:
                # Keep a business code a handler may already have recorded; default to 500.
                request.state.response_code = getattr(request.state, "response_code", 500)
                _log_request(request, started_at, 500)
                raise
            response.headers.setdefault("X-Trace-Id", trace_id)
            response.headers.setdefault("X-Request-Id", request_id)
            _log_request(request, started_at, response.status_code)
            return response
        finally:
            _current_request.reset(token)
|
|
|
|
|
|
def _log_request(request: Request, started_at: float, status_code: int) -> None:
    """Emit one access-log line on the ``iti`` logger.

    Args:
        request: The request being logged. Its ``state.response_code`` (the
            business code) is printed, or ``-`` when it was never set.
        started_at: ``time.perf_counter()`` reading taken when handling began.
        status_code: HTTP status reported for the request.
    """
    elapsed_ms = round((time.perf_counter() - started_at) * 1000, 2)
    business_code = getattr(request.state, "response_code", "-")
    client_host = request.client.host if request.client else "-"
    logger.info(
        "request method=%s path=%s status=%s code=%s durationMs=%s ip=%s",
        request.method,
        request.url.path,
        status_code,
        business_code,
        elapsed_ms,
        client_host,
        extra=log_extra(request),
    )
|
|
|
|
|
|
def _resolve_config(
    config_name: str | None,
    config_mapping: Mapping[str, type[BaseConfig] | BaseConfig] | type[BaseConfig] | BaseConfig | None,
) -> BaseConfig:
    """Pick the configuration object the application should use.

    * ``config_mapping is None``: delegate to ``get_config(config_name)``.
    * A single class or instance: instantiate the class (no arguments) or use
      the instance as-is.
    * A mapping: look up ``config_name`` (or ``get_env_name()`` when no name is
      given), falling back to the ``"default"`` key. If neither yields a value,
      delegate to ``get_config(config_name)``.
    """

    def _materialize(candidate):
        # Classes are instantiated with no arguments; instances are used directly.
        return candidate() if isinstance(candidate, type) else candidate

    if config_mapping is None:
        return get_config(config_name)
    if not isinstance(config_mapping, Mapping):
        return _materialize(config_mapping)

    selected = config_mapping.get(config_name or get_env_name(), config_mapping.get("default"))
    if selected is None:
        return get_config(config_name)
    return _materialize(selected)
|
|
|
|
|
|
def init_error_handlers(app: FastAPI) -> None:
    """Register exception handlers that render errors as the iti failure envelope.

    Every handler:
    * stores the business error code on ``request.state.response_code`` (read by
      ``_log_request``);
    * replies with ``config.response_envelope_http_status`` and a ``fail(...)`` body.

    Database and unexpected errors are logged with traceback on the
    ``iti.error`` logger. Their exception text is echoed to the client only
    when ``config.debug`` is true; otherwise ``data`` is ``None``, so SQL
    fragments and internal details are not leaked.
    """
    # Local import keeps the file-level import block unchanged.
    from fastapi.encoders import jsonable_encoder

    def _envelope_status(request: Request) -> int:
        return request.app.state.config.response_envelope_http_status

    def _client_error_detail(request: Request, exc: Exception) -> str | None:
        # Raw exception text can include SQL statements, bound parameters or paths.
        return str(exc) if getattr(request.app.state.config, "debug", False) else None

    @app.exception_handler(ItiError)
    async def handle_iti_error(request: Request, exc: ItiError):
        request.state.response_code = exc.code
        return JSONResponse(
            status_code=_envelope_status(request),
            content=fail(exc.message, code=exc.code, data=exc.data),
        )

    @app.exception_handler(RequestValidationError)
    async def handle_validation_error(request: Request, exc: RequestValidationError):
        request.state.response_code = 422
        # errors() can hold non-JSON-serialisable objects (e.g. the exception raised
        # by a custom validator under "ctx"); encode them like FastAPI's default
        # handler does instead of letting JSONResponse fail to serialise.
        return JSONResponse(
            status_code=_envelope_status(request),
            content=fail("参数验证错误", code=422, data=jsonable_encoder(exc.errors())),
        )

    @app.exception_handler(StarletteHTTPException)
    async def handle_http_error(request: Request, exc: StarletteHTTPException):
        request.state.response_code = exc.status_code
        message, data = _http_error_payload(exc)
        return JSONResponse(
            status_code=_envelope_status(request),
            content=fail(message, code=exc.status_code, data=data),
            headers=exc.headers,
        )

    @app.exception_handler(SQLAlchemyError)
    async def handle_db_error(request: Request, exc: SQLAlchemyError):
        request.state.response_code = 500
        error_logger.exception("database error", extra=log_extra(request))
        return JSONResponse(
            status_code=_envelope_status(request),
            content=fail("数据库错误", code=500, data=_client_error_detail(request, exc)),
        )

    @app.exception_handler(Exception)
    async def handle_exception(request: Request, exc: Exception):
        request.state.response_code = 500
        error_logger.exception("server error", extra=log_extra(request))
        return JSONResponse(
            status_code=_envelope_status(request),
            content=fail("服务器错误", code=500, data=_client_error_detail(request, exc)),
        )
|
|
|
|
|
|
def _http_error_payload(exc: StarletteHTTPException) -> tuple[str, Any]:
    """Split an HTTP exception's ``detail`` into ``(message, data)`` for the envelope.

    * A string detail is the message, with no data.
    * A mapping supplies its first non-empty string among ``message``,
      ``detail`` and ``error`` as the message, and is itself returned as data.
    * Anything else (``None``, lists, mappings without such a key) uses the
      standard status phrase as the message, with the detail as data.
    """
    detail = exc.detail
    if isinstance(detail, str):
        return detail, None
    if isinstance(detail, Mapping):
        message = next(
            (
                text
                for text in map(detail.get, ("message", "detail", "error"))
                if isinstance(text, str) and text
            ),
            None,
        )
        if message is not None:
            return message, detail
    return _http_status_phrase(exc.status_code), detail
|
|
|
|
|
|
def _http_status_phrase(status_code: int) -> str:
    """Return the standard reason phrase for *status_code*, or ``"HTTP Error"`` if unknown."""
    try:
        status = HTTPStatus(status_code)
    except ValueError:
        return "HTTP Error"
    else:
        return status.phrase
|
|
|
|
|
|
def to_plain_data(value: Any) -> Any:
    """Convert a dataclass *instance* to a dict via ``asdict``; return anything else unchanged.

    ``dataclasses.is_dataclass`` is also true for dataclass *classes*, and
    ``asdict`` raises ``TypeError`` for those. Classes are therefore passed
    through untouched instead of crashing.
    """
    if is_dataclass(value) and not isinstance(value, type):
        return asdict(value)
    return value
|
|
|
|
|
|
def install_auto_envelope(app: FastAPI) -> None:
    """Wrap the endpoints of *app*'s API routes so results are returned in the envelope.

    Does nothing when ``config.response_envelope_enabled`` is false. Otherwise,
    each ``APIRoute`` registered so far has its endpoint wrapped and its
    dependant rebuilt, unless it:
    * is already marked as wrapped;
    * is treated as raw (see ``_is_route_raw`` and ``config.raw_response_paths``);
    * has no endpoint callable.
    Wrapped routes are marked so that a repeated call is a no-op for them.
    """
    config = app.state.config
    if not config.response_envelope_enabled:
        return

    raw_paths = tuple(config.raw_response_paths)
    api_routes = (candidate for candidate in app.routes if isinstance(candidate, APIRoute))
    for route in api_routes:
        already_wrapped = getattr(route, "__iti_envelope_installed__", False)
        if already_wrapped or _is_route_raw(route, raw_paths):
            continue
        endpoint_call = route.dependant.call
        if endpoint_call is None:
            continue
        route.endpoint = _wrap_endpoint_with_envelope(endpoint_call)
        _rebuild_route_dependant(route)
        setattr(route, "__iti_envelope_installed__", True)
|
|
|
|
|
|
def _is_route_raw(route: APIRoute, raw_paths: Iterable[str]) -> bool:
    """Return True when *route* must not be wrapped in the response envelope.

    A route is raw if its endpoint carries a truthy ``__iti_raw_response__``
    attribute, or if ``is_raw_response_request`` matches either its
    ``path_format`` or its ``path`` against *raw_paths*.
    """
    if getattr(route.endpoint, "__iti_raw_response__", False):
        return True
    return any(
        is_raw_response_request(_PathOnlyRequest(candidate), raw_paths)
        for candidate in (route.path_format, route.path)
    )
|
|
|
|
|
|
def _wrap_endpoint_with_envelope(func):
    """Return an async endpoint that wraps *func*'s result in the iti success envelope.

    The returned wrapper:
    * returns ``Response`` objects untouched;
    * converts other results with ``_to_jsonable``;
    * returns payloads already in envelope form (``is_envelope_payload``) as-is,
      recording their ``code`` on the request;
    * wraps everything else as ``{"data": payload, "code": 200, "message": "成功"}``,
      recording code 200.

    ``functools.wraps`` sets ``__wrapped__``, so signature introspection
    (``inspect.signature`` follows it) still sees *func*'s parameters.

    Synchronous callables are executed with ``run_in_threadpool``. FastAPI only
    offloads ``def`` endpoints it sees directly. Because the wrapper is
    ``async def``, calling a sync *func* inline would run its possibly blocking
    body on the event loop and stall all concurrent requests.
    """
    # Local imports keep the file-level import block unchanged.
    from inspect import iscoroutinefunction

    from starlette.concurrency import run_in_threadpool

    # Callable objects with an ``async def __call__`` also count as async.
    call_is_async = iscoroutinefunction(func) or iscoroutinefunction(getattr(func, "__call__", None))

    @wraps(func)
    async def wrapper(*args, **kwargs):
        if call_is_async:
            result = func(*args, **kwargs)
        else:
            result = await run_in_threadpool(func, *args, **kwargs)
        # A sync callable may still hand back an awaitable; resolve it on the loop.
        if isawaitable(result):
            result = await result
        if isinstance(result, Response):
            return result
        payload = _to_jsonable(result)
        if is_envelope_payload(payload):
            _mark_response_code(args, kwargs, payload["code"])
            return payload
        _mark_response_code(args, kwargs, 200)
        return {"data": payload, "code": 200, "message": "成功"}

    return wrapper
|
|
|
|
|
|
def _to_jsonable(value: Any) -> Any:
    """Convert an endpoint result into plain data before enveloping.

    * Pydantic models are dumped with ``model_dump(by_alias=True)``.
    * Dataclass *instances* are converted with ``asdict``.
    * Everything else, including ``None``, is returned unchanged.

    Dataclass *classes* are deliberately passed through. ``is_dataclass`` is
    also true for them, but ``asdict`` raises ``TypeError`` on a class.
    """
    if isinstance(value, BaseModel):
        return value.model_dump(by_alias=True)
    if is_dataclass(value) and not isinstance(value, type):
        return asdict(value)
    return value
|
|
|
|
|
|
def _mark_response_code(args: tuple[Any, ...], kwargs: dict[str, Any], code: int) -> None:
    """Record the business response *code* on the active request, if one can be found.

    The request is looked up among the endpoint call arguments first, then in
    the ``_current_request`` ContextVar. Without a request, nothing happens.
    """
    target = _request_from_call(args, kwargs) or _current_request.get()
    if target is None:
        return
    target.state.response_code = code
|
|
|
|
|
|
def _request_from_call(args: tuple[Any, ...], kwargs: dict[str, Any]) -> Request | None:
    """Return the first ``Request`` among positional then keyword argument values, or ``None``."""
    candidates = (*args, *kwargs.values())
    return next((item for item in candidates if isinstance(item, Request)), None)
|
|
|
|
|
|
def _rebuild_route_dependant(route: APIRoute) -> None:
    """Recompute the request-handling state of *route* after ``route.endpoint`` was replaced.

    ``install_auto_envelope`` swaps the endpoint for an envelope wrapper. This
    function then regenerates, from the new endpoint:
    * the dependant;
    * the route-level dependencies;
    * the flattened dependant;
    * the body-embedding flag and body field;
    * the ASGI ``route.app``.
    It presumably mirrors the steps ``APIRoute`` performs at construction time.

    NOTE(review): relies on FastAPI internals (``_flat_dependant``,
    ``_embed_body_fields``, ``_should_embed_body_fields`` and the ``scope=``
    argument of ``get_dependant``). Keep it in sync with the installed FastAPI
    version.
    NOTE(review): ``route.response_field`` / ``response_model`` are not
    recomputed here. If a response model was inferred from the original
    endpoint's return annotation, the enveloped payload may still be validated
    against it. Confirm this is handled elsewhere.
    """
    route.dependant = get_dependant(
        path=route.path_format,
        call=route.endpoint,
        scope="function",
    )
    # Re-attach route-level ``dependencies`` in front of the endpoint's own ones.
    # Iterating in reverse while inserting at index 0 keeps their declared order.
    for depends in route.dependencies[::-1]:
        route.dependant.dependencies.insert(
            0,
            get_parameterless_sub_dependant(depends=depends, path=route.path_format),
        )
    route._flat_dependant = get_flat_dependant(route.dependant)
    route._embed_body_fields = _should_embed_body_fields(route._flat_dependant.body_params)
    route.body_field = get_body_field(
        flat_dependant=route._flat_dependant,
        name=route.unique_id,
        embed_body_fields=route._embed_body_fields,
    )
    # Built last, presumably so the handler picks up the refreshed dependant/body field.
    route.app = request_response(route.get_route_handler())
|
|
|
|
|
|
class _PathOnlyRequest:
    """Minimal request stand-in used to test a route path against the raw-response rules.

    It exposes only ``url.path`` and a ``scope`` whose ``endpoint`` is ``None``.
    Presumably that is all ``is_raw_response_request`` reads — TODO confirm.
    """

    class URL:
        """Holder for the single ``path`` attribute."""

        def __init__(self, path: str) -> None:
            self.path = path

    def __init__(self, path: str) -> None:
        self.url = self.URL(path)
        self.scope = {"endpoint": None}
|