from __future__ import annotations

import threading
import time
from collections import defaultdict, deque
from collections.abc import Callable
from typing import Any

from fastapi import Depends, Request

from iti.exceptions import BizError

# Window length in seconds for each supported rule unit. A rule's unit matches
# when it starts with the key, so singular and plural forms ("minute" /
# "minutes") are both accepted.
_UNIT_SECONDS: dict[str, int] = {
    "second": 1,
    "minute": 60,
    "hour": 3600,
    "day": 86400,
}


class SimpleLimiter:
    """In-memory sliding-window rate limiter that produces FastAPI dependencies.

    Hits are counted per ``client host : URL path : rule`` key. All state lives
    in this object, so limits are per process: separate worker processes keep
    separate counters.
    """

    # How often (in seconds) to drop keys that have no hits inside any known
    # window, so clients that stop sending requests do not stay in memory forever.
    sweep_interval: float = 60.0

    def __init__(self, *, enabled: bool = True) -> None:
        """Create a limiter; when ``enabled`` is false every request is allowed."""
        self.enabled = enabled
        self._hits: dict[str, deque[float]] = defaultdict(deque)
        # FastAPI runs plain ``def`` dependencies in a threadpool, so the
        # prune/check/append sequence must not interleave across threads.
        self._lock = threading.Lock()
        # Longest window of any rule handed out so far. A key whose newest hit
        # is older than this has expired under every rule and can be deleted.
        self._max_window = 0
        self._last_sweep = time.monotonic()

    def limit(self, rule: str) -> Callable[[Request], None]:
        """Return a dependency that enforces ``rule`` (e.g. ``"5 per minute"``).

        The dependency raises ``BizError`` (HTTP 429) once the caller has made
        ``count`` requests to the same path within the rule's window.

        Raises:
            ValueError: if ``rule`` cannot be parsed (see ``parse_rule``).
        """
        count, seconds = parse_rule(rule)
        with self._lock:
            self._max_window = max(self._max_window, seconds)

        def dependency(request: Request) -> None:
            if not self.enabled:
                return
            # NOTE(review): behind a reverse proxy this is presumably the
            # proxy's address unless proxy-header middleware rewrites it; confirm
            # for the deployment. The key uses the concrete path, so /items/1 and
            # /items/2 are limited separately.
            client = request.client.host if request.client else "unknown"
            key = f"{client}:{request.url.path}:{rule}"
            # Use a monotonic clock so system clock adjustments cannot shrink or
            # stretch the window.
            now = time.monotonic()
            with self._lock:
                if now - self._last_sweep >= self.sweep_interval:
                    self._sweep(now)
                hits = self._hits[key]
                while hits and hits[0] <= now - seconds:
                    hits.popleft()
                if len(hits) >= count:
                    raise BizError("请求过于频繁,请稍后再试", code=429, status_code=429)
                hits.append(now)

        return dependency

    def _sweep(self, now: float) -> None:
        """Delete keys whose newest hit lies outside every window; caller holds the lock."""
        horizon = now - self._max_window
        stale = [key for key, hits in self._hits.items() if not hits or hits[-1] <= horizon]
        for key in stale:
            del self._hits[key]
        self._last_sweep = now


def parse_rule(rule: str) -> tuple[int, int]:
    """Parse a rule such as ``"5 per minute"`` into ``(count, window_seconds)``.

    The rule must be exactly ``<count> per <unit>``. ``per`` and the unit are
    case-insensitive. Supported units are second, minute, hour and day, in
    singular or plural form.

    Raises:
        ValueError: if the rule is malformed, the count is not a positive
            integer, or the unit is not recognised.
    """
    parts = rule.strip().split()
    if len(parts) != 3 or parts[1].lower() != "per":
        raise ValueError(f"invalid rate limit rule: {rule}")
    try:
        count = int(parts[0])
    except ValueError:
        raise ValueError(f"invalid rate limit rule: {rule}") from None
    if count < 1:
        # A count of zero or less would reject every request, which is almost
        # certainly a configuration mistake rather than an intended limit.
        raise ValueError(f"rate limit count must be positive: {rule}")
    unit = parts[2].lower()
    for name, seconds in _UNIT_SECONDS.items():
        if unit.startswith(name):
            return count, seconds
    raise ValueError(f"invalid rate limit unit: {unit}")


def limit(rule: str) -> Any:
    """Return a FastAPI ``Depends`` that rate-limits a route according to ``rule``.

    The limiter is looked up on ``request.app.state.limiter`` for each request.
    If the app has no limiter configured, the request is not rate-limited.

    Raises:
        ValueError: when the dependency is declared, if ``rule`` cannot be parsed.
    """
    # Fail fast at import/decoration time instead of on the first request.
    parse_rule(rule)

    def dependency(request: Request) -> None:
        limiter = getattr(request.app.state, "limiter", None)
        if limiter is None:
            # Without a configured limiter there is no shared hit history, so
            # nothing can be enforced. A throwaway limiter per request would
            # never trip. Set ``app.state.limiter`` to enforce limits.
            return None
        return limiter.limit(rule)(request)

    return Depends(dependency)