69 lines
2.5 KiB
Python
69 lines
2.5 KiB
Python
import asyncio
import inspect
import math
import threading
import time
from collections import defaultdict
from functools import wraps
from typing import Callable

from fastapi import HTTPException, Request, status
|
|
|
|
|
|
def get_client_ip(request: Request) -> str:
    """Return a best-effort client IP address for ``request``.

    Lookup order:

    1. the ``X-Real-IP`` header,
    2. the first non-empty entry of the ``X-Forwarded-For`` header,
    3. the transport peer address (``request.client.host``),
    4. the literal ``"unknown"`` when no peer address is available.

    Header values are stripped and blank values are skipped, so a header such
    as ``X-Forwarded-For: , 1.2.3.4`` no longer yields the empty string (which
    would lump unrelated clients into one rate-limit bucket).

    Security: both headers are taken from the incoming request as-is. Unless a
    reverse proxy you control overwrites them, a client can choose its own
    "IP" here, for example to evade per-IP rate limiting.
    NOTE(review): confirm the app is only reachable through such a proxy.
    """
    real_ip = (request.headers.get("X-Real-IP") or "").strip()
    if real_ip:
        return real_ip

    forwarded = request.headers.get("X-Forwarded-For") or ""
    for hop in forwarded.split(","):
        candidate = hop.strip()
        if candidate:
            return candidate

    return request.client.host if request.client else "unknown"
|
|
|
|
|
|
class RateLimiter:
    """In-memory sliding-window rate limiter keyed by an arbitrary string.

    A key may have at most ``max_attempts`` recorded attempts inside any
    trailing ``window_seconds`` span. Call :meth:`check` before an operation
    (it raises HTTP 429 once the budget is spent) and :meth:`record` to count
    the attempt.

    State is a plain dict on the instance, so limits are per process and are
    lost on restart.
    """

    def __init__(self, max_attempts: int = 5, window_seconds: int = 60):
        self.max_attempts = max_attempts
        self.window_seconds = window_seconds
        # key -> time.monotonic() stamps of attempts still in the window,
        # oldest first (record() only ever appends the current time).
        self._attempts: dict[str, list[float]] = defaultdict(list)
        # Reentrant because check() holds it while calling _cleanup(), which
        # takes it too. Sync endpoints may be invoked from several threads at
        # once (FastAPI runs plain `def` path operations in a threadpool).
        self._lock = threading.RLock()
        # Last time expired entries were purged for *all* keys. Without this
        # sweep, a key seen once (e.g. one client IP) would never be removed
        # and the dict would grow without bound.
        self._last_sweep = time.monotonic()

    def _prune(self, key: str, cutoff: float) -> None:
        """Drop ``key``'s stamps not newer than ``cutoff``; forget the key if none remain.

        Caller must hold ``self._lock``.
        """
        recent = [stamp for stamp in self._attempts.get(key, ()) if stamp > cutoff]
        if recent:
            self._attempts[key] = recent
        else:
            self._attempts.pop(key, None)

    def _cleanup(self, key: str) -> None:
        """Expire old attempts for ``key`` and, at most once per window, for every key."""
        with self._lock:
            now = time.monotonic()
            cutoff = now - self.window_seconds
            self._prune(key, cutoff)
            if now - self._last_sweep >= self.window_seconds:
                for other in list(self._attempts):
                    self._prune(other, cutoff)
                self._last_sweep = now

    def _retry_after(self, oldest: float, now: float) -> int:
        """Return whole seconds until ``oldest`` leaves the window (rounded up, at least 1)."""
        # Round up: truncating would advertise a wait that is still too short.
        return max(math.ceil(self.window_seconds - (now - oldest)), 1)

    def check(self, key: str) -> None:
        """Raise HTTP 429 if ``key`` has exhausted its budget for the current window.

        Raises:
            HTTPException: status 429, with a ``Retry-After`` header in seconds.
        """
        with self._lock:
            self._cleanup(key)
            # .get() rather than [] so that merely checking a key does not
            # re-create an entry for it.
            stamps = self._attempts.get(key, [])
            if len(stamps) < self.max_attempts:
                return
            now = time.monotonic()
            # stamps can only be empty here when max_attempts <= 0; nothing
            # will expire then, so advise waiting a whole window.
            retry_after = self._retry_after(stamps[0] if stamps else now, now)
        raise HTTPException(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            detail=f"Trop de tentatives. Réessayez dans {retry_after} secondes.",
            headers={"Retry-After": str(retry_after)},
        )

    def record(self, key: str) -> None:
        """Count one attempt for ``key`` at the current time."""
        with self._lock:
            # Stamp inside the lock so concurrent appends stay in time order.
            self._attempts[key].append(time.monotonic())
|
|
|
|
|
|
# Module-level limiter instances: one shared object per limit, living for
# the lifetime of this process.

# login_limiter / change_password_limiter: 5 attempts per 60-second window.
login_limiter = RateLimiter(window_seconds=60, max_attempts=5)
change_password_limiter = RateLimiter(window_seconds=60, max_attempts=5)

# register_limiter is stricter: 3 attempts per 5-minute (300 s) window.
register_limiter = RateLimiter(window_seconds=5 * 60, max_attempts=3)
|
|
|
|
|
|
def rate_limit(limiter: RateLimiter) -> Callable:
    """Build a decorator that applies ``limiter`` per client IP to an endpoint.

    Before each call, the wrapper resolves the client IP with
    ``get_client_ip``. It then calls ``limiter.check`` (which raises HTTP 429
    once the budget is spent) and ``limiter.record``. The attempt is recorded
    *before* the endpoint runs, so every call counts toward the limit whether
    or not the endpoint succeeds.

    The decorated endpoint must receive the incoming ``Request`` in one of
    these ways:

    - as its first positional argument,
    - as a keyword argument named ``request``,
    - as any keyword argument holding a ``Request`` instance.

    Coroutine functions get an async wrapper; plain functions get a sync one.
    ``functools.wraps`` sets ``__wrapped__``, so ``inspect.signature`` still
    reports the endpoint's own parameters.
    """

    def _request_from(args: tuple, kwargs: dict) -> Request:
        """Locate the incoming Request among the endpoint's call arguments."""
        if "request" in kwargs:
            return kwargs["request"]
        if args:
            # Original contract: the request is the first positional argument.
            return args[0]
        for value in kwargs.values():
            if isinstance(value, Request):
                return value
        raise TypeError(
            "rate_limit: the decorated endpoint must receive a fastapi Request argument"
        )

    def _consume(args: tuple, kwargs: dict) -> None:
        """Enforce the limit for the calling client, then count this attempt."""
        client_ip = get_client_ip(_request_from(args, kwargs))
        limiter.check(client_ip)
        limiter.record(client_ip)

    def decorator(func: Callable) -> Callable:
        # inspect.iscoroutinefunction: asyncio.iscoroutinefunction is
        # deprecated as of Python 3.14.
        if inspect.iscoroutinefunction(func):

            @wraps(func)
            async def async_wrapper(*args, **kwargs):
                _consume(args, kwargs)
                return await func(*args, **kwargs)

            return async_wrapper

        @wraps(func)
        def sync_wrapper(*args, **kwargs):
            _consume(args, kwargs)
            return func(*args, **kwargs)

        return sync_wrapper

    return decorator
|