Files
2026-06-26 11:54:29 +02:00

69 lines
2.5 KiB
Python

import asyncio
import inspect
import math
import time
from collections import defaultdict
from functools import wraps
from typing import Callable

from fastapi import HTTPException, Request, status
def get_client_ip(request: Request) -> str:
    """Return the best-guess client IP address for *request*.

    Lookup order:

    1. the ``X-Real-IP`` header,
    2. the first non-blank comma-separated entry of ``X-Forwarded-For``,
    3. the socket peer address (``request.client.host``).

    Header values are stripped, and blank values are skipped instead of being
    returned, so an empty string is never used as a rate-limit key.

    NOTE(review): both headers are sent by whoever made the request. Unless a
    trusted reverse proxy overwrites them, a client can choose arbitrary
    values and thereby choose its own rate-limit key. Confirm the deployment
    strips or overwrites these headers at the edge.

    Args:
        request: The incoming request.

    Returns:
        The client IP as a string, or ``"unknown"`` if no source is available.
    """
    real_ip = (request.headers.get("X-Real-IP") or "").strip()
    if real_ip:
        return real_ip

    forwarded = request.headers.get("X-Forwarded-For") or ""
    for candidate in forwarded.split(","):
        candidate = candidate.strip()
        if candidate:
            return candidate

    return request.client.host if request.client else "unknown"
class RateLimiter:
    """In-memory sliding-window rate limiter keyed by an arbitrary string.

    A key may have at most ``max_attempts`` recorded attempts inside any
    ``window_seconds`` interval. ``check`` only inspects; ``record`` only
    records. State lives in this object, so it is not shared between
    processes.
    """

    def __init__(self, max_attempts: int = 5, window_seconds: int = 60):
        self.max_attempts = max_attempts
        self.window_seconds = window_seconds
        # key -> time.monotonic() timestamps of recorded attempts, oldest first
        # (record() only ever appends the current monotonic time).
        self._attempts: dict[str, list[float]] = defaultdict(list)
        # Monotonic time of the last full sweep of idle keys (see _sweep).
        self._last_sweep = time.monotonic()

    def _cleanup(self, key: str) -> None:
        """Drop *key*'s expired timestamps; forget the key entirely if none remain."""
        cutoff = time.monotonic() - self.window_seconds
        recent = [t for t in self._attempts.get(key, ()) if t > cutoff]
        if recent:
            self._attempts[key] = recent
        else:
            # Remove rather than store [] so idle keys do not pile up forever.
            self._attempts.pop(key, None)

    def _sweep(self) -> None:
        """At most once per window, forget every key whose attempts have all expired.

        _cleanup only touches the key being checked; without this, keys that
        are seen once and never again would stay in memory indefinitely.
        """
        now = time.monotonic()
        if now - self._last_sweep < self.window_seconds:
            return
        self._last_sweep = now
        cutoff = now - self.window_seconds
        # ts[-1] is the newest timestamp; if it has expired, all have.
        stale = [k for k, ts in self._attempts.items() if not ts or ts[-1] <= cutoff]
        for k in stale:
            del self._attempts[k]

    def check(self, key: str) -> None:
        """Raise HTTP 429 if *key* has used up its attempts in the current window.

        Does not record an attempt; call ``record`` for that.

        Raises:
            HTTPException: status 429 with a ``Retry-After`` header giving
                whole seconds (always >= 1).
        """
        self._sweep()
        self._cleanup(key)
        recent = self._attempts.get(key, [])
        if len(recent) < self.max_attempts:
            return

        if self.max_attempts > 0:
            # The key drops back under the limit once this attempt expires
            # (index 0 when exactly max_attempts are recorded).
            blocking = recent[len(recent) - self.max_attempts]
            wait = self.window_seconds - (time.monotonic() - blocking)
        else:
            # max_attempts <= 0 never admits anything; ask for a full window.
            wait = self.window_seconds
        # Round up: truncating would invite a retry that is still too early.
        retry_after = max(math.ceil(wait), 1)
        raise HTTPException(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            detail=f"Trop de tentatives. Réessayez dans {retry_after} secondes.",
            headers={"Retry-After": str(retry_after)},
        )

    def record(self, key: str) -> None:
        """Record one attempt for *key* at the current monotonic time."""
        self._attempts[key].append(time.monotonic())
# Module-level, in-memory limiters (one per protected action). Each Python
# process holds its own copies; arguments are (max_attempts, window_seconds).
login_limiter = RateLimiter(5, 60)  # 5 attempts per 60 s
change_password_limiter = RateLimiter(5, 60)  # 5 attempts per 60 s
register_limiter = RateLimiter(3, 300)  # 3 attempts per 300 s
def rate_limit(limiter: RateLimiter) -> Callable:
    """Build a decorator that rate-limits a request handler per client IP.

    On every call, the wrapper resolves the client IP of the handler's
    request via ``get_client_ip``, calls ``limiter.check`` (which raises when
    the IP is over its limit) and then ``limiter.record``, all *before* the
    handler runs. Every admitted call therefore counts, whether or not the
    handler later succeeds.

    The request object is taken from the ``request`` keyword argument when
    present, otherwise from the first positional argument. FastAPI passes
    endpoint parameters by name, so the handler's parameter should be called
    ``request``. Arguments are forwarded to the handler exactly as received,
    so the request's position in the handler signature does not matter.

    Args:
        limiter: The limiter whose ``check``/``record`` are applied.

    Returns:
        A decorator usable on both ``async def`` and plain ``def`` handlers.
    """

    def _enforce(args: tuple, kwargs: dict) -> None:
        # Reject (via limiter.check) or record one attempt for the caller's IP.
        if "request" in kwargs:
            request = kwargs["request"]
        elif args:
            request = args[0]
        else:
            raise TypeError(
                "rate_limit: the decorated function must receive a 'request' argument"
            )
        client_ip = get_client_ip(request)
        limiter.check(client_ip)
        limiter.record(client_ip)

    def decorator(func: Callable) -> Callable:
        # inspect.iscoroutinefunction: asyncio.iscoroutinefunction is
        # deprecated since Python 3.14.
        if inspect.iscoroutinefunction(func):

            @wraps(func)
            async def async_wrapper(*args, **kwargs):
                _enforce(args, kwargs)
                return await func(*args, **kwargs)

            return async_wrapper

        @wraps(func)
        def sync_wrapper(*args, **kwargs):
            _enforce(args, kwargs)
            return func(*args, **kwargs)

        return sync_wrapper

    return decorator