You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
59 lines
1.7 KiB
Python
59 lines
1.7 KiB
Python
from __future__ import annotations
|
|
|
|
import time
|
|
from collections import defaultdict, deque
|
|
from collections.abc import Callable
|
|
|
|
from fastapi import Depends, Request
|
|
|
|
from iti.exceptions import BizError
|
|
|
|
|
|
class SimpleLimiter:
    """In-memory sliding-window rate limiter exposed as request dependencies.

    Hits are tracked per ``client host + request path + rule`` key in a deque
    of timestamps stored on the instance, so the counters live only inside
    this object.
    """

    def __init__(self, *, enabled: bool = True) -> None:
        """Create a limiter.

        Args:
            enabled: When false, every dependency produced by :meth:`limit`
                returns immediately without recording or rejecting anything.
        """
        self.enabled = enabled
        # key -> monotonic timestamps of the hits still inside the window,
        # oldest first.
        # NOTE(review): keys are never removed, so the mapping grows with the
        # number of distinct client/path/rule combinations seen.
        self._hits: dict[str, deque[float]] = defaultdict(deque)

    def limit(self, rule: str) -> Callable:
        """Build a dependency that enforces *rule* (e.g. ``"5 per minute"``).

        Args:
            rule: Rate expression understood by ``parse_rule``.

        Returns:
            A callable taking the current ``Request``. It returns ``None``
            when the request is allowed.

        Raises:
            ValueError: Propagated from ``parse_rule`` when *rule* is
                malformed (raised here, not when the dependency runs).
        """
        count, seconds = parse_rule(rule)

        def dependency(request: Request) -> None:
            """Record one hit, or raise ``BizError`` (429) when over the limit."""
            if not self.enabled:
                return

            client = request.client.host if request.client else "unknown"
            key = f"{client}:{request.url.path}:{rule}"

            # Use the monotonic clock: the window is an elapsed-time
            # measurement, and wall-clock adjustments (NTP, manual changes)
            # must not expire hits early or keep them alive too long.
            now = time.monotonic()
            cutoff = now - seconds

            hits = self._hits[key]
            # Drop hits that have fallen out of the window.
            while hits and hits[0] <= cutoff:
                hits.popleft()

            if len(hits) >= count:
                # Rejected requests are not recorded as hits.
                raise BizError("请求过于频繁,请稍后再试", code=429, status_code=429)
            hits.append(now)

        return dependency
|
|
|
|
|
|
def parse_rule(rule: str) -> tuple[int, int]:
    """Parse a rate rule such as ``"5 per minute"`` into ``(count, seconds)``.

    The expected shape is ``"<count> per <unit>"``. The unit is matched
    case-insensitively by prefix (``second``, ``minute``, ``hour``), so plural
    forms are accepted. Any tokens after the unit are ignored.

    Args:
        rule: The rule text; surrounding whitespace is ignored.

    Returns:
        A ``(count, seconds)`` tuple: the allowed number of hits and the
        window length in seconds (1, 60 or 3600).

    Raises:
        ValueError: If the rule has fewer than three tokens, the second token
            is not ``per``, the count is not an integer, or the unit is not
            recognised.
    """
    parts = rule.strip().split()
    if len(parts) < 3 or parts[1] != "per":
        raise ValueError(f"invalid rate limit rule: {rule}")

    try:
        count = int(parts[0])
    except ValueError as exc:
        # Report a bad count the same way as other malformed rules instead of
        # leaking the bare int() message without the offending rule.
        raise ValueError(f"invalid rate limit rule: {rule}") from exc

    unit = parts[2].lower()
    for prefix, seconds in (("second", 1), ("minute", 60), ("hour", 3600)):
        if unit.startswith(prefix):
            return count, seconds
    raise ValueError(f"invalid rate limit unit: {unit}")
|
|
|
|
|
|
def limit(rule: str) -> Callable:
    """Return a ``Depends`` that applies *rule* with the application's limiter.

    The limiter is looked up on ``request.app.state.limiter`` on every
    request, so it can be attached to the app after routes are declared.

    Args:
        rule: Rate expression such as ``"5 per minute"``; it is handed to the
            limiter's ``limit`` method when a request arrives.

    Returns:
        The result of ``Depends(dependency)``, for use in a route's
        dependency list.
    """

    def dependency(request: Request) -> None:
        """Resolve the limiter for this request and run its check."""
        limiter = getattr(request.app.state, "limiter", None)
        if limiter is None:
            # Build the fallback only when it is actually needed rather than
            # on every request.
            # NOTE(review): this fallback is a brand-new limiter each time, so
            # it holds no earlier hits and does not throttle in practice;
            # confirm whether an unconfigured app should really go unlimited.
            limiter = SimpleLimiter()
        return limiter.limit(rule)(request)

    return Depends(dependency)
|