from __future__ import annotations

from collections.abc import Callable, Mapping
from dataclasses import dataclass, field
from typing import Protocol

from fastapi import Depends, Request
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer

from iti.exceptions import PermissionDenied, Unauthorized

from .jwt import decode_token

# auto_error=False: when the Authorization header is absent the dependencies
# below receive ``credentials=None`` and decide for themselves how to react.
bearer_scheme = HTTPBearer(auto_error=False)


@dataclass(frozen=True)
class Principal:
    """An authenticated identity plus the grants attached to it.

    Instances are immutable. All fields are hashable, so the generated
    ``__hash__`` works and principals can be used in sets or as dict keys.
    """

    id: str
    type: str = "user"
    # Grant collections are frozensets to keep the frozen dataclass hashable.
    permissions: frozenset[str] = field(default_factory=frozenset)
    roles: frozenset[str] = field(default_factory=frozenset)
    scopes: frozenset[str] = field(default_factory=frozenset)


@dataclass(frozen=True)
class Actor:
    """Whoever is performing the current request: a user or a service.

    For user actors ``principal`` is set. For service actors
    ``service_name`` is set instead.
    """

    id: str
    type: str
    principal: Principal | None = None
    service_name: str | None = None


class PermissionProvider(Protocol):
    """Structural interface for resolving principals and checking grants."""

    def load_principal(self, principal_id: str, request: Request) -> Principal | None:
        """Return the principal for *principal_id*, or ``None`` if unknown/inactive."""

    def has_permission(self, principal: Principal, code: str) -> bool:
        """Return whether *principal* holds permission *code*."""

    def has_scope(self, principal: Principal, scope: str) -> bool:
        """Return whether *principal* holds service scope *scope*."""
class StaticPermissionProvider:
    """Fallback provider that trusts the token subject and grants nothing.

    ``load_principal`` builds a bare ``Principal`` from the id alone. Its
    permission and scope sets are therefore empty, so every permission or
    scope check fails until a real provider is installed.
    """

    def load_principal(self, principal_id: str, request: Request) -> Principal | None:
        return Principal(id=principal_id)

    def has_permission(self, principal: Principal, code: str) -> bool:
        return code in principal.permissions

    def has_scope(self, principal: Principal, scope: str) -> bool:
        return scope in principal.scopes


# Process-wide default provider. An ``app.state.permission_provider``
# attribute, when present, takes precedence over it.
permission_provider: PermissionProvider = StaticPermissionProvider()


def set_permission_provider(provider: PermissionProvider) -> None:
    """Replace the module-level default permission provider."""
    global permission_provider
    permission_provider = provider


def _get_permission_provider(request: Request) -> PermissionProvider:
    """Return the app-scoped provider if configured, else the module default."""
    # The default is read at call time, so set_permission_provider() takes
    # effect for subsequent requests.
    return getattr(request.app.state, "permission_provider", permission_provider)


def get_principal(
    request: Request,
    credentials: HTTPAuthorizationCredentials | None = Depends(bearer_scheme),
) -> Principal | None:
    """Resolve a bearer access token into a ``Principal``.

    Returns ``None`` when no bearer credentials were sent. On success, the
    principal and a matching user ``Actor`` are stored on ``request.state``.

    Raises:
        Unauthorized: the token has no ``sub`` claim, or the provider
            cannot load the referenced principal. Errors raised by
            ``decode_token`` propagate unchanged.
    """
    if credentials is None:
        return None
    config = request.app.state.config
    payload = decode_token(credentials.credentials, config, token_type="access")
    principal_id = payload.get("sub")
    if not principal_id:
        raise Unauthorized("无效的令牌")  # "invalid token"
    principal = _get_permission_provider(request).load_principal(principal_id, request)
    if principal is None:
        raise Unauthorized("用户不存在或已失效")  # "user does not exist or is inactive"
    request.state.principal = principal
    request.state.actor = Actor(id=principal.id, type="user", principal=principal)
    return principal


def require_user(
    principal: Principal | None = Depends(get_principal),
) -> Principal:
    """Dependency that demands an authenticated user.

    Raises:
        Unauthorized: no bearer token was supplied.
    """
    if principal is None:
        raise Unauthorized("缺少令牌参数 Authorization Bearer")  # "missing bearer token"
    return principal


def require_permission(code: str) -> Callable[..., Principal]:
    """Build a dependency that requires the user to hold permission *code*.

    The returned dependency raises ``PermissionDenied`` (code 403) when the
    active provider denies the permission.
    """

    def dependency(
        request: Request,
        principal: Principal = Depends(require_user),
    ) -> Principal:
        if not _get_permission_provider(request).has_permission(principal, code):
            raise PermissionDenied("权限不足", code=403)  # "insufficient permission"
        return principal

    return dependency


def require_service_scope(scope: str) -> Callable[..., Principal]:
    """Build a dependency that requires the user to hold service scope *scope*.

    The returned dependency raises ``PermissionDenied`` (code 403) when the
    active provider denies the scope.
    """

    def dependency(
        request: Request,
        principal: Principal = Depends(require_user),
    ) -> Principal:
        if not _get_permission_provider(request).has_scope(principal, scope):
            raise PermissionDenied("服务权限不足", code=403)  # "insufficient service permission"
        return principal

    return dependency


def get_service_actor(
    request: Request,
    credentials: HTTPAuthorizationCredentials | None = Depends(bearer_scheme),
) -> Actor | None:
    """Identify the caller as a service via its static bearer token.

    Returns ``None`` if no credentials were sent or the token matches no
    configured service. On a match, the service ``Actor`` is also stored
    on ``request.state.actor``.
    """
    if credentials is None:
        return None
    # ``or {}`` also covers a config that defines service_tokens = None,
    # which previously crashed on ``.items()``.
    tokens = getattr(request.app.state.config, "service_tokens", None) or {}
    service_name = match_service_token(tokens, credentials.credentials)
    if service_name is None:
        return None
    actor = Actor(id=service_name, type="service", service_name=service_name)
    request.state.actor = actor
    return actor


def match_service_token(tokens: Mapping[str, str], token: str) -> str | None:
    """Return the service name whose configured token equals *token*.

    Entries with an empty expected token are skipped, so an unset secret
    can never match. Each comparison uses ``hmac.compare_digest``, so the
    time it takes does not reveal how many leading characters of a guessed
    token were correct; a plain ``==`` would leak that.
    """
    import hmac

    # compare_digest rejects non-ASCII str, so compare UTF-8 bytes instead.
    presented = token.encode("utf-8")
    for name, expected in tokens.items():
        if expected and hmac.compare_digest(presented, expected.encode("utf-8")):
            return name
    return None


def require_service(service_name: str | None = None) -> Callable[..., Actor]:
    """Build a dependency that requires a valid service token.

    If *service_name* is given, only that service is accepted.

    The returned dependency raises:
        Unauthorized: no credentials, or the token matches no service.
        PermissionDenied: the token belongs to a different service.
    """

    def dependency(
        request: Request,
        credentials: HTTPAuthorizationCredentials | None = Depends(bearer_scheme),
    ) -> Actor:
        actor = get_service_actor(request, credentials)
        if actor is None:
            raise Unauthorized("无效的服务令牌")  # "invalid service token"
        if service_name is not None and actor.service_name != service_name:
            raise PermissionDenied("服务权限不足", code=403)
        return actor

    return dependency


def require_actor(
    *,
    permissions: list[str] | tuple[str, ...] | None = None,
    allow_service: bool = False,
    service_name: str | None = None,
) -> Callable[..., Actor]:
    """Build a dependency that accepts a user and, optionally, a service.

    The bearer token is tried as a service token first, then as a user
    access token.

    Args:
        permissions: Permission codes a *user* caller must all hold.
            Service callers are not checked against these.
        allow_service: Whether a valid service token is accepted at all.
        service_name: If set, only this service's token is accepted.

    The returned dependency raises:
        Unauthorized: no bearer token, or the user token is invalid.
        PermissionDenied: a service token is presented but not allowed, or
            the user lacks any of *permissions*.
    """
    required_permissions = tuple(permissions or ())

    def dependency(
        request: Request,
        credentials: HTTPAuthorizationCredentials | None = Depends(bearer_scheme),
    ) -> Actor:
        if credentials is None:
            raise Unauthorized("缺少令牌参数 Authorization Bearer")

        service_actor = get_service_actor(request, credentials)
        if service_actor is not None:
            if not allow_service:
                raise PermissionDenied("服务权限不足", code=403)
            if service_name is not None and service_actor.service_name != service_name:
                raise PermissionDenied("服务权限不足", code=403)
            return service_actor

        principal = get_principal(request, credentials)
        # Defensive: credentials is non-None here, so get_principal returns a
        # Principal or raises. The check also narrows the type.
        if principal is None:
            raise Unauthorized("缺少令牌参数 Authorization Bearer")

        provider = _get_permission_provider(request)
        if not all(provider.has_permission(principal, code) for code in required_permissions):
            raise PermissionDenied("权限不足", code=403)

        actor = Actor(id=principal.id, type="user", principal=principal)
        request.state.actor = actor
        return actor

    return dependency