diff --git a/iti/app.py b/iti/app.py index 5ebc0d7..95c656e 100644 --- a/iti/app.py +++ b/iti/app.py @@ -471,12 +471,15 @@ def install_auto_envelope(app: FastAPI) -> None: if not config.response_envelope_enabled: return raw_paths = tuple(config.raw_response_paths) - for route in app.routes: - if not isinstance(route, APIRoute): - continue + targets: dict[int, tuple[APIRoute, list[str]]] = {} + for route, paths in _iter_envelope_route_targets(app.routes): + targets.setdefault(id(route), (route, []))[1].extend(paths) + + wrapped_any = False + for route, paths in targets.values(): if getattr(route, "__iti_envelope_installed__", False): continue - if _is_route_raw(route, raw_paths): + if _is_route_raw(route, raw_paths, paths): continue original_call = route.dependant.call if original_call is None: @@ -484,13 +487,53 @@ def install_auto_envelope(app: FastAPI) -> None: route.endpoint = _wrap_endpoint_with_envelope(original_call) _rebuild_route_dependant(route) setattr(route, "__iti_envelope_installed__", True) + wrapped_any = True + + if wrapped_any: + _invalidate_included_router_caches(app.routes) + + +def _iter_envelope_route_targets(routes: Iterable[Any]): + for route in routes: + effective_contexts = getattr(route, "effective_route_contexts", None) + if callable(effective_contexts): + for context in effective_contexts(): + original_route = getattr(context, "original_route", None) + if isinstance(original_route, APIRoute): + yield original_route, ( + getattr(context, "path_format", ""), + getattr(context, "path", ""), + ) + continue + if isinstance(route, APIRoute): + yield route, (route.path_format, route.path) + +def _invalidate_included_router_caches(routes: Iterable[Any]) -> None: + for route in routes: + if hasattr(route, "_effective_candidates_version"): + route._effective_candidates_version = None + route._effective_low_priority_routes_version = None + route._effective_candidates = [] + route._effective_low_priority_routes = [] + original_router = getattr(route, "original_router", None) + if original_router is not None: + _invalidate_included_router_caches(getattr(original_router, "routes", ())) -def _is_route_raw(route: APIRoute, raw_paths: Iterable[str]) -> bool: + +def _is_route_raw(route: APIRoute, raw_paths: Iterable[str], paths: Iterable[str] = ()) -> bool: endpoint = route.endpoint if getattr(endpoint, "__iti_raw_response__", False): return True - for path in route.path_format, route.path: + dependant_call = getattr(getattr(route, "dependant", None), "call", None) + if getattr(dependant_call, "__iti_raw_response__", False): + return True + + seen_paths: set[str] = set() + for path in (*paths, route.path_format, route.path): + if not path or path in seen_paths: + continue + seen_paths.add(path) request = _PathOnlyRequest(path) if is_raw_response_request(request, raw_paths): return True diff --git a/tests/test_cli.py b/tests/test_cli.py index 28a5dda..337cc38 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -262,7 +262,7 @@ def test_parser_accepts_worker_command() -> None: assert args.command == "worker" assert args.name == "mq" - assert args.args == ["--", "--config", "dev"] + assert args.args == ["--config", "dev"] def test_normalize_worker_name_accepts_dash_alias() -> None: