fix: 兼容 FastAPI 嵌套路由响应包裹

main
917232558@qq.com 6 days ago
parent 193ab7cf16
commit 9359a96266

@ -471,12 +471,15 @@ def install_auto_envelope(app: FastAPI) -> None:
if not config.response_envelope_enabled:
return
raw_paths = tuple(config.raw_response_paths)
for route in app.routes:
if not isinstance(route, APIRoute):
continue
targets: dict[int, tuple[APIRoute, list[str]]] = {}
for route, paths in _iter_envelope_route_targets(app.routes):
targets.setdefault(id(route), (route, []))[1].extend(paths)
wrapped_any = False
for route, paths in targets.values():
if getattr(route, "__iti_envelope_installed__", False):
continue
if _is_route_raw(route, raw_paths):
if _is_route_raw(route, raw_paths, paths):
continue
original_call = route.dependant.call
if original_call is None:
@ -484,13 +487,53 @@ def install_auto_envelope(app: FastAPI) -> None:
route.endpoint = _wrap_endpoint_with_envelope(original_call)
_rebuild_route_dependant(route)
setattr(route, "__iti_envelope_installed__", True)
wrapped_any = True
if wrapped_any:
_invalidate_included_router_caches(app.routes)
def _iter_envelope_route_targets(routes: Iterable[Any]):
for route in routes:
effective_contexts = getattr(route, "effective_route_contexts", None)
if callable(effective_contexts):
for context in effective_contexts():
original_route = getattr(context, "original_route", None)
if isinstance(original_route, APIRoute):
yield original_route, (
getattr(context, "path_format", ""),
getattr(context, "path", ""),
)
continue
if isinstance(route, APIRoute):
yield route, (route.path_format, route.path)
def _invalidate_included_router_caches(routes: Iterable[Any]) -> None:
for route in routes:
if hasattr(route, "_effective_candidates_version"):
route._effective_candidates_version = None
route._effective_low_priority_routes_version = None
route._effective_candidates = []
route._effective_low_priority_routes = []
original_router = getattr(route, "original_router", None)
if original_router is not None:
_invalidate_included_router_caches(getattr(original_router, "routes", ()))
def _is_route_raw(route: APIRoute, raw_paths: Iterable[str]) -> bool:
def _is_route_raw(route: APIRoute, raw_paths: Iterable[str], paths: Iterable[str] = ()) -> bool:
endpoint = route.endpoint
if getattr(endpoint, "__iti_raw_response__", False):
return True
for path in route.path_format, route.path:
dependant_call = getattr(getattr(route, "dependant", None), "call", None)
if getattr(dependant_call, "__iti_raw_response__", False):
return True
seen_paths: set[str] = set()
for path in (*paths, route.path_format, route.path):
if not path or path in seen_paths:
continue
seen_paths.add(path)
request = _PathOnlyRequest(path)
if is_raw_response_request(request, raw_paths):
return True

@ -262,7 +262,7 @@ def test_parser_accepts_worker_command() -> None:
assert args.command == "worker"
assert args.name == "mq"
assert args.args == ["--", "--config", "dev"]
assert args.args == ["--config", "dev"]
def test_normalize_worker_name_accepts_dash_alias() -> None:

Loading…
Cancel
Save