Push V1 app
This commit is contained in:
@@ -0,0 +1,8 @@
|
||||
node_modules
|
||||
dist
|
||||
__pycache__
|
||||
*.pyc
|
||||
*.db
|
||||
uploads/*
|
||||
!uploads/.gitkeep
|
||||
.env
|
||||
@@ -0,0 +1,9 @@
|
||||
__pycache__/
|
||||
*.pyc
|
||||
*.db
|
||||
*.db-journal
|
||||
backend/uploads/*
|
||||
!backend/uploads/.gitkeep
|
||||
.env
|
||||
venv/
|
||||
.venv/
|
||||
@@ -0,0 +1,32 @@
|
||||
FROM python:3.12-slim AS build
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
gcc libffi-dev \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
COPY requirements.txt .
|
||||
RUN pip install --no-cache-dir --prefix=/install -r requirements.txt
|
||||
|
||||
FROM python:3.12-slim
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
curl \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
COPY --from=build /install /usr/local
|
||||
|
||||
RUN useradd -r -s /bin/false appuser
|
||||
RUN mkdir -p /app/data /app/uploads && chown appuser:appuser /app/data /app/uploads
|
||||
|
||||
WORKDIR /app
|
||||
COPY . .
|
||||
|
||||
RUN chown -R appuser:appuser /app
|
||||
|
||||
USER appuser
|
||||
|
||||
EXPOSE 8000
|
||||
|
||||
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "1"]
|
||||
+104
@@ -0,0 +1,104 @@
|
||||
import os
|
||||
import secrets
|
||||
import sys
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import bcrypt
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from jose import JWTError, jwt
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
from models import User
|
||||
|
||||
PLACEHOLDER_KEY = "change-me-to-a-random-64-char-string"
|
||||
_raw_secret = os.environ.get("SECRET_KEY", "")
|
||||
if not _raw_secret or _raw_secret == PLACEHOLDER_KEY:
|
||||
print("╔══════════════════════════════════════════════════╗")
|
||||
print("║ ERREUR: SECRET_KEY non défini ou valeur par ║")
|
||||
print("║ défaut. Définissez une clé secrète dans .env ║")
|
||||
print("║ ou dans les variables d'environnement. ║")
|
||||
print("╚══════════════════════════════════════════════════╝")
|
||||
sys.exit(1)
|
||||
|
||||
SECRET_KEY = _raw_secret
|
||||
ALGORITHM = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24 # 24 hours
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/login")
|
||||
|
||||
|
||||
def hash_password(password: str) -> str:
|
||||
return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()
|
||||
|
||||
|
||||
def verify_password(plain: str, hashed: str) -> bool:
|
||||
return bcrypt.checkpw(plain.encode(), hashed.encode())
|
||||
|
||||
|
||||
def create_access_token(data: dict) -> str:
|
||||
to_encode = data.copy()
|
||||
expire = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
to_encode.update({"exp": expire})
|
||||
return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
||||
|
||||
|
||||
def get_current_user(
|
||||
token: str = Depends(oauth2_scheme),
|
||||
db: Session = Depends(get_db),
|
||||
) -> User:
|
||||
credentials_exception = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Token invalide",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
try:
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
sub = payload.get("sub")
|
||||
if sub is None:
|
||||
raise credentials_exception
|
||||
user_id = int(sub)
|
||||
token_version = payload.get("token_version", 0)
|
||||
except (JWTError, ValueError):
|
||||
raise credentials_exception
|
||||
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if user is None:
|
||||
raise credentials_exception
|
||||
if user.token_version != token_version:
|
||||
raise credentials_exception
|
||||
return user
|
||||
|
||||
|
||||
def get_user_from_token(token: str, db: Session) -> User:
|
||||
credentials_exception = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Token invalide",
|
||||
)
|
||||
try:
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
sub = payload.get("sub")
|
||||
if sub is None:
|
||||
raise credentials_exception
|
||||
user_id = int(sub)
|
||||
token_version = payload.get("token_version", 0)
|
||||
except (JWTError, ValueError):
|
||||
raise credentials_exception
|
||||
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if user is None:
|
||||
raise credentials_exception
|
||||
if user.token_version != token_version:
|
||||
raise credentials_exception
|
||||
return user
|
||||
|
||||
|
||||
def require_admin(user: User = Depends(get_current_user)) -> User:
|
||||
if not user.is_admin:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Accès réservé à l'administrateur",
|
||||
)
|
||||
return user
|
||||
@@ -0,0 +1,28 @@
|
||||
from sqlalchemy import create_engine, text
|
||||
from sqlalchemy.orm import DeclarativeBase, sessionmaker
|
||||
|
||||
DATABASE_URL = "sqlite:///./data/cellar.db"
|
||||
|
||||
engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False})
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
|
||||
def get_db():
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def migrate():
|
||||
with engine.connect() as conn:
|
||||
try:
|
||||
conn.execute(text("ALTER TABLE users ADD COLUMN token_version INTEGER NOT NULL DEFAULT 0"))
|
||||
conn.commit()
|
||||
except Exception:
|
||||
pass # Column already exists
|
||||
+103
@@ -0,0 +1,103 @@
|
||||
import os
|
||||
import secrets
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
# Load .env file if present (docker-compose handles this in production)
|
||||
_env_path = os.path.join(os.path.dirname(__file__), "..", ".env")
|
||||
if os.path.isfile(_env_path):
|
||||
with open(_env_path) as _f:
|
||||
for _line in _f:
|
||||
_line = _line.strip()
|
||||
if _line and not _line.startswith("#") and "=" in _line:
|
||||
_key, _, _value = _line.partition("=")
|
||||
os.environ.setdefault(_key.strip(), _value.strip())
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
|
||||
from auth import hash_password
|
||||
from database import Base, SessionLocal, engine, migrate
|
||||
from models import Invitation, User
|
||||
from routers.admin import router as admin_router
|
||||
from routers.auth import router as auth_router
|
||||
from routers.drinks import router as drinks_router
|
||||
|
||||
migrate()
|
||||
Base.metadata.create_all(bind=engine)
|
||||
|
||||
app = FastAPI(title="Cellar API", version="1.0.0")
|
||||
|
||||
ALLOWED_ORIGINS = [
|
||||
origin.strip()
|
||||
for origin in os.environ.get("CORS_ORIGINS", "http://localhost:5173,http://localhost:3000").split(",")
|
||||
if origin.strip()
|
||||
]
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=ALLOWED_ORIGINS,
|
||||
allow_credentials=True,
|
||||
allow_methods=["GET", "POST", "PUT", "DELETE"],
|
||||
allow_headers=["Authorization", "Content-Type"],
|
||||
)
|
||||
|
||||
app.include_router(admin_router)
|
||||
app.include_router(auth_router)
|
||||
app.include_router(drinks_router)
|
||||
|
||||
uploads_dir = os.path.join(os.path.dirname(__file__), "uploads")
|
||||
os.makedirs(uploads_dir, exist_ok=True)
|
||||
|
||||
|
||||
@app.get("/api/health")
|
||||
def health_check():
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
def create_default_user():
|
||||
admin_password = os.environ.get("ADMIN_PASSWORD")
|
||||
if not admin_password:
|
||||
admin_password = secrets.token_urlsafe(16)
|
||||
print("╔══════════════════════════════════════════════════╗")
|
||||
print("║ ADMIN_PASSWORD non défini dans l'environnement ║")
|
||||
print("║ Un mot de passe aléatoire a été généré ║")
|
||||
print("║ Configurez ADMIN_PASSWORD pour la production ║")
|
||||
print("╚══════════════════════════════════════════════════╝")
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
admin_exists = db.query(User).filter(
|
||||
(User.username == "admin") | (User.email == "admin@cellar.local")
|
||||
).first()
|
||||
if not admin_exists:
|
||||
try:
|
||||
user = User(
|
||||
username="admin",
|
||||
email="admin@cellar.local",
|
||||
hashed_password=hash_password(admin_password),
|
||||
is_admin=True,
|
||||
)
|
||||
db.add(user)
|
||||
db.flush()
|
||||
|
||||
invite = Invitation(
|
||||
token=secrets.token_urlsafe(32),
|
||||
created_by=user.id,
|
||||
expires_at=datetime.utcnow() + timedelta(days=7),
|
||||
)
|
||||
db.add(invite)
|
||||
db.commit()
|
||||
|
||||
print("╔══════════════════════════════════════════╗")
|
||||
print("║ Premier démarrage - Compte créé ║")
|
||||
print("║ Username : admin ║")
|
||||
print("║ Password : (depuis ADMIN_PASSWORD ou ║")
|
||||
print("║ mot de passe aléatoire) ║")
|
||||
print("╚══════════════════════════════════════════╝")
|
||||
except Exception:
|
||||
db.rollback()
|
||||
pass
|
||||
finally:
|
||||
db.close()
|
||||
@@ -0,0 +1,106 @@
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from sqlalchemy import Boolean, Column, DateTime, Enum, Float, ForeignKey, Integer, String, Text
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from database import Base
|
||||
|
||||
import enum
|
||||
|
||||
|
||||
class DrinkCategory(str, enum.Enum):
|
||||
WINE = "wine"
|
||||
BEER = "beer"
|
||||
SPIRIT = "spirit"
|
||||
|
||||
|
||||
class WineColor(str, enum.Enum):
|
||||
RED = "red"
|
||||
WHITE = "white"
|
||||
ROSE = "rose"
|
||||
SPARKLING = "sparkling"
|
||||
|
||||
|
||||
class SpiritType(str, enum.Enum):
|
||||
WHISKY = "whisky"
|
||||
VODKA = "vodka"
|
||||
RUM = "rum"
|
||||
GIN = "gin"
|
||||
TEQUILA = "tequila"
|
||||
COGNAC = "cognac"
|
||||
CALVADOS = "calvados"
|
||||
OTHER = "other"
|
||||
|
||||
|
||||
class BeerStyle(str, enum.Enum):
|
||||
IPA = "ipa"
|
||||
STOUT = "stout"
|
||||
LAGER = "lager"
|
||||
ALE = "ale"
|
||||
WHEAT = "wheat"
|
||||
SOUR = "sour"
|
||||
PILSNER = "pilsner"
|
||||
PORTER = "porter"
|
||||
BELGIAN = "belgian"
|
||||
OTHER = "other"
|
||||
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = "users"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
username = Column(String(100), unique=True, nullable=False, index=True)
|
||||
email = Column(String(255), unique=True, nullable=False, index=True)
|
||||
hashed_password = Column(String(255), nullable=False)
|
||||
is_admin = Column(Boolean, default=False)
|
||||
token_version = Column(Integer, default=0)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
|
||||
drinks = relationship("Drink", back_populates="owner")
|
||||
|
||||
|
||||
class Invitation(Base):
|
||||
__tablename__ = "invitations"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
token = Column(String(64), unique=True, nullable=False, index=True)
|
||||
created_by = Column(Integer, ForeignKey("users.id"), nullable=False)
|
||||
used_by = Column(Integer, ForeignKey("users.id"), nullable=True)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
used_at = Column(DateTime, nullable=True)
|
||||
expires_at = Column(DateTime, nullable=False)
|
||||
|
||||
|
||||
class Drink(Base):
|
||||
__tablename__ = "drinks"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
owner_id = Column(Integer, ForeignKey("users.id"), nullable=False)
|
||||
name = Column(String(255), nullable=False, index=True)
|
||||
category = Column(Enum(DrinkCategory), nullable=False)
|
||||
image_path = Column(String(500), nullable=True)
|
||||
rating = Column(Float, nullable=True)
|
||||
notes = Column(Text, nullable=True)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
|
||||
owner = relationship("User", back_populates="drinks")
|
||||
|
||||
# Wine fields
|
||||
grape_variety = Column(String(255), nullable=True)
|
||||
vintage = Column(Integer, nullable=True)
|
||||
region = Column(String(255), nullable=True)
|
||||
producer = Column(String(255), nullable=True)
|
||||
wine_color = Column(Enum(WineColor), nullable=True)
|
||||
|
||||
# Beer fields
|
||||
brewery = Column(String(255), nullable=True)
|
||||
beer_style = Column(Enum(BeerStyle), nullable=True)
|
||||
ibu = Column(Float, nullable=True)
|
||||
abv = Column(Float, nullable=True)
|
||||
|
||||
# Spirit fields
|
||||
spirit_type = Column(Enum(SpiritType), nullable=True)
|
||||
age_years = Column(Integer, nullable=True)
|
||||
distillery = Column(String(255), nullable=True)
|
||||
country = Column(String(255), nullable=True)
|
||||
@@ -0,0 +1,68 @@
|
||||
import asyncio
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from functools import wraps
|
||||
from typing import Callable
|
||||
|
||||
from fastapi import HTTPException, Request, status
|
||||
|
||||
|
||||
def get_client_ip(request: Request) -> str:
|
||||
real_ip = request.headers.get("X-Real-IP")
|
||||
if real_ip:
|
||||
return real_ip
|
||||
forwarded = request.headers.get("X-Forwarded-For")
|
||||
if forwarded:
|
||||
return forwarded.split(",")[0].strip()
|
||||
return request.client.host if request.client else "unknown"
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
def __init__(self, max_attempts: int = 5, window_seconds: int = 60):
|
||||
self.max_attempts = max_attempts
|
||||
self.window_seconds = window_seconds
|
||||
self._attempts: dict[str, list[float]] = defaultdict(list)
|
||||
|
||||
def _cleanup(self, key: str) -> None:
|
||||
now = time.monotonic()
|
||||
cutoff = now - self.window_seconds
|
||||
self._attempts[key] = [t for t in self._attempts[key] if t > cutoff]
|
||||
|
||||
def check(self, key: str) -> None:
|
||||
self._cleanup(key)
|
||||
if len(self._attempts[key]) >= self.max_attempts:
|
||||
retry_after = int(self.window_seconds - (time.monotonic() - self._attempts[key][0]))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
||||
detail=f"Trop de tentatives. Réessayez dans {max(retry_after, 1)} secondes.",
|
||||
headers={"Retry-After": str(max(retry_after, 1))},
|
||||
)
|
||||
|
||||
def record(self, key: str) -> None:
|
||||
self._attempts[key].append(time.monotonic())
|
||||
|
||||
|
||||
login_limiter = RateLimiter(max_attempts=5, window_seconds=60)
|
||||
change_password_limiter = RateLimiter(max_attempts=5, window_seconds=60)
|
||||
register_limiter = RateLimiter(max_attempts=3, window_seconds=300)
|
||||
|
||||
|
||||
def rate_limit(limiter: RateLimiter) -> Callable:
|
||||
def decorator(func: Callable) -> Callable:
|
||||
if asyncio.iscoroutinefunction(func):
|
||||
@wraps(func)
|
||||
async def async_wrapper(request: Request, *args, **kwargs):
|
||||
client_ip = get_client_ip(request)
|
||||
limiter.check(client_ip)
|
||||
limiter.record(client_ip)
|
||||
return await func(request, *args, **kwargs)
|
||||
return async_wrapper
|
||||
else:
|
||||
@wraps(func)
|
||||
def sync_wrapper(request: Request, *args, **kwargs):
|
||||
client_ip = get_client_ip(request)
|
||||
limiter.check(client_ip)
|
||||
limiter.record(client_ip)
|
||||
return func(request, *args, **kwargs)
|
||||
return sync_wrapper
|
||||
return decorator
|
||||
@@ -0,0 +1,2 @@
|
||||
pytest>=9.0
|
||||
httpx>=0.28.0
|
||||
@@ -0,0 +1,9 @@
|
||||
fastapi==0.138.1
|
||||
uvicorn[standard]==0.49.0
|
||||
SQLAlchemy==2.0.51
|
||||
pydantic==2.13.4
|
||||
python-multipart==0.0.32
|
||||
Pillow==12.2.0
|
||||
aiofiles==25.1.0
|
||||
python-jose[cryptography]==3.5.0
|
||||
bcrypt==5.0.0
|
||||
@@ -0,0 +1 @@
|
||||
# Cellar API - backend router
|
||||
@@ -0,0 +1,116 @@
|
||||
import os
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from auth import require_admin
|
||||
from database import get_db
|
||||
from models import Drink, Invitation, User
|
||||
from auth import hash_password
|
||||
from schemas import AdminResetPassword, AdminToggleAdmin, UserResponse
|
||||
|
||||
router = APIRouter(prefix="/api/admin", tags=["admin"])
|
||||
|
||||
|
||||
@router.get("/users", response_model=list[UserResponse])
|
||||
def list_users(
|
||||
admin: User = Depends(require_admin),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
users = db.query(User).order_by(User.created_at.desc()).all()
|
||||
return [UserResponse.model_validate(u) for u in users]
|
||||
|
||||
|
||||
@router.delete("/users/{user_id}")
|
||||
def delete_user(
|
||||
user_id: int,
|
||||
admin: User = Depends(require_admin),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
if user_id == admin.id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Vous ne pouvez pas vous supprimer vous-même",
|
||||
)
|
||||
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
uploads_dir = os.path.join(os.path.dirname(__file__), "..", "uploads")
|
||||
|
||||
drinks = db.query(Drink).filter(Drink.owner_id == user_id).all()
|
||||
for drink in drinks:
|
||||
if drink.image_path:
|
||||
filepath = os.path.join(os.path.dirname(__file__), "..", drink.image_path)
|
||||
resolved = os.path.realpath(filepath)
|
||||
if resolved.startswith(os.path.realpath(uploads_dir)) and os.path.exists(resolved):
|
||||
os.remove(resolved)
|
||||
|
||||
db.query(Drink).filter(Drink.owner_id == user_id).delete()
|
||||
db.query(Invitation).filter(
|
||||
(Invitation.created_by == user_id) | (Invitation.used_by == user_id)
|
||||
).delete()
|
||||
|
||||
db.delete(user)
|
||||
db.commit()
|
||||
return {"detail": "User deleted"}
|
||||
|
||||
|
||||
@router.post("/users/{user_id}/reset-password")
|
||||
def reset_user_password(
|
||||
user_id: int,
|
||||
body: AdminResetPassword,
|
||||
admin: User = Depends(require_admin),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
if user_id == admin.id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Utilisez /api/auth/change-password pour changer votre propre mot de passe",
|
||||
)
|
||||
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
user.hashed_password = hash_password(body.new_password)
|
||||
user.token_version += 1
|
||||
db.commit()
|
||||
return {"detail": "Password reset"}
|
||||
|
||||
|
||||
@router.post("/users/{user_id}/toggle-admin")
|
||||
def toggle_admin(
|
||||
user_id: int,
|
||||
body: AdminToggleAdmin,
|
||||
admin: User = Depends(require_admin),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
if user_id == admin.id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Vous ne pouvez pas modifier votre propre statut admin",
|
||||
)
|
||||
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
user.is_admin = body.is_admin
|
||||
user.token_version += 1
|
||||
db.commit()
|
||||
return {"detail": f"Admin status set to {body.is_admin}"}
|
||||
|
||||
|
||||
@router.get("/stats")
|
||||
def stats(
|
||||
admin: User = Depends(require_admin),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
return {
|
||||
"users": db.query(User).count(),
|
||||
"drinks": db.query(Drink).count(),
|
||||
"invitations": db.query(Invitation).count(),
|
||||
"invitations_used": db.query(Invitation).filter(Invitation.used_by.isnot(None)).count(),
|
||||
}
|
||||
@@ -0,0 +1,161 @@
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import secrets as secrets_mod
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from auth import (
|
||||
create_access_token,
|
||||
get_current_user,
|
||||
hash_password,
|
||||
require_admin,
|
||||
verify_password,
|
||||
)
|
||||
from database import get_db
|
||||
from models import Invitation, User
|
||||
from ratelimit import change_password_limiter, login_limiter, rate_limit, register_limiter
|
||||
from schemas import (
|
||||
InvitationCreate,
|
||||
InvitationResponse,
|
||||
PasswordChange,
|
||||
TokenResponse,
|
||||
UserCreate,
|
||||
UserLogin,
|
||||
UserResponse,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/api/auth", tags=["auth"])
|
||||
|
||||
INVITATION_EXPIRY_DAYS = 7
|
||||
|
||||
|
||||
@router.post("/register", response_model=TokenResponse)
|
||||
@rate_limit(register_limiter)
|
||||
def register(request: Request, data: UserCreate, db: Session = Depends(get_db)):
|
||||
invite = (
|
||||
db.query(Invitation)
|
||||
.filter(
|
||||
Invitation.token == data.invite_token,
|
||||
Invitation.used_by.is_(None),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not invite:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invitation invalide ou déjà utilisée",
|
||||
)
|
||||
|
||||
if invite.expires_at and invite.expires_at < datetime.utcnow():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invitation expirée",
|
||||
)
|
||||
|
||||
if db.query(User).filter(
|
||||
(User.username == data.username) | (User.email == data.email)
|
||||
).first():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Un compte avec ces informations existe déjà",
|
||||
)
|
||||
|
||||
user = User(
|
||||
username=data.username,
|
||||
email=data.email,
|
||||
hashed_password=hash_password(data.password),
|
||||
)
|
||||
db.add(user)
|
||||
db.flush()
|
||||
|
||||
invite.used_by = user.id
|
||||
invite.used_at = datetime.utcnow()
|
||||
db.flush()
|
||||
|
||||
db.commit()
|
||||
db.refresh(user)
|
||||
|
||||
token = create_access_token({"sub": str(user.id), "token_version": user.token_version})
|
||||
return TokenResponse(
|
||||
access_token=token,
|
||||
user=UserResponse.model_validate(user),
|
||||
)
|
||||
|
||||
|
||||
@router.post("/login", response_model=TokenResponse)
|
||||
@rate_limit(login_limiter)
|
||||
def login(request: Request, data: UserLogin, db: Session = Depends(get_db)):
|
||||
user = db.query(User).filter(User.username == data.username).first()
|
||||
if not user or not verify_password(data.password, user.hashed_password):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Identifiants incorrects",
|
||||
)
|
||||
|
||||
token = create_access_token({"sub": str(user.id), "token_version": user.token_version})
|
||||
return TokenResponse(
|
||||
access_token=token,
|
||||
user=UserResponse.model_validate(user),
|
||||
)
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserResponse)
|
||||
def me(user: User = Depends(get_current_user)):
|
||||
return UserResponse.model_validate(user)
|
||||
|
||||
|
||||
@router.post("/change-password")
|
||||
@rate_limit(change_password_limiter)
|
||||
def change_password(
|
||||
request: Request,
|
||||
data: PasswordChange,
|
||||
user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
if not verify_password(data.current_password, user.hashed_password):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Mot de passe actuel incorrect",
|
||||
)
|
||||
user.hashed_password = hash_password(data.new_password)
|
||||
user.token_version += 1
|
||||
db.commit()
|
||||
return {"detail": "Mot de passe modifié"}
|
||||
|
||||
|
||||
@router.post("/logout")
|
||||
def logout(user: User = Depends(get_current_user), db: Session = Depends(get_db)):
|
||||
user.token_version += 1
|
||||
db.commit()
|
||||
return {"detail": "Déconnexion réussie"}
|
||||
|
||||
|
||||
@router.post("/invitations", response_model=InvitationResponse)
|
||||
def create_invitation(
|
||||
admin: User = Depends(require_admin),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
invitation = Invitation(
|
||||
token=secrets_mod.token_urlsafe(32),
|
||||
created_by=admin.id,
|
||||
expires_at=datetime.utcnow() + timedelta(days=INVITATION_EXPIRY_DAYS),
|
||||
)
|
||||
db.add(invitation)
|
||||
db.commit()
|
||||
db.refresh(invitation)
|
||||
return InvitationResponse.model_validate(invitation)
|
||||
|
||||
|
||||
@router.get("/invitations", response_model=list[InvitationResponse])
|
||||
def list_invitations(
|
||||
admin: User = Depends(require_admin),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
invitations = (
|
||||
db.query(Invitation)
|
||||
.filter(Invitation.created_by == admin.id)
|
||||
.order_by(Invitation.created_at.desc())
|
||||
.all()
|
||||
)
|
||||
return [InvitationResponse.model_validate(i) for i in invitations]
|
||||
@@ -0,0 +1,221 @@
|
||||
import io
|
||||
import os
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, File, HTTPException, Query, UploadFile
|
||||
from fastapi.responses import FileResponse
|
||||
from PIL import Image
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from auth import get_current_user, get_user_from_token
|
||||
from database import get_db
|
||||
from models import Drink, DrinkCategory, User
|
||||
from schemas import DrinkCreate, DrinkResponse, DrinkUpdate
|
||||
|
||||
router = APIRouter(prefix="/api/drinks", tags=["drinks"])
|
||||
|
||||
UPLOAD_DIR = os.path.join(os.path.dirname(__file__), "..", "uploads")
|
||||
os.makedirs(UPLOAD_DIR, exist_ok=True)
|
||||
|
||||
ALLOWED_MIMES = {"image/jpeg", "image/png", "image/webp", "image/gif"}
|
||||
ALLOWED_EXTENSIONS = {".jpg", ".jpeg", ".png", ".webp", ".gif"}
|
||||
MAX_UPLOAD_SIZE = 10 * 1024 * 1024 # 10 MB
|
||||
|
||||
|
||||
def get_user_drink(drink_id: int, user: User, db: Session) -> Drink:
|
||||
drink = db.query(Drink).filter(Drink.id == drink_id, Drink.owner_id == user.id).first()
|
||||
if not drink:
|
||||
raise HTTPException(status_code=404, detail="Drink not found")
|
||||
return drink
|
||||
|
||||
|
||||
@router.post("", response_model=DrinkResponse)
|
||||
async def create_drink(
|
||||
drink: DrinkCreate,
|
||||
user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
db_drink = Drink(**drink.model_dump(), owner_id=user.id)
|
||||
db.add(db_drink)
|
||||
db.commit()
|
||||
db.refresh(db_drink)
|
||||
return db_drink
|
||||
|
||||
|
||||
@router.get("", response_model=list[DrinkResponse])
|
||||
def list_drinks(
|
||||
category: DrinkCategory | None = None,
|
||||
search: str | None = None,
|
||||
min_rating: float | None = Query(None, ge=0, le=5),
|
||||
skip: int = Query(0, ge=0),
|
||||
limit: int = Query(50, ge=1, le=200),
|
||||
user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
query = db.query(Drink).filter(Drink.owner_id == user.id)
|
||||
|
||||
if category:
|
||||
query = query.filter(Drink.category == category)
|
||||
|
||||
if search:
|
||||
search_term = f"%{search}%"
|
||||
query = query.filter(
|
||||
or_(
|
||||
Drink.name.ilike(search_term),
|
||||
Drink.region.ilike(search_term),
|
||||
Drink.producer.ilike(search_term),
|
||||
Drink.brewery.ilike(search_term),
|
||||
Drink.distillery.ilike(search_term),
|
||||
Drink.grape_variety.ilike(search_term),
|
||||
Drink.country.ilike(search_term),
|
||||
Drink.notes.ilike(search_term),
|
||||
)
|
||||
)
|
||||
|
||||
if min_rating is not None:
|
||||
query = query.filter(Drink.rating >= min_rating)
|
||||
|
||||
return query.order_by(Drink.created_at.desc()).offset(skip).limit(limit).all()
|
||||
|
||||
|
||||
@router.get("/{drink_id}", response_model=DrinkResponse)
|
||||
def get_drink(
|
||||
drink_id: int,
|
||||
user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
return get_user_drink(drink_id, user, db)
|
||||
|
||||
|
||||
@router.put("/{drink_id}", response_model=DrinkResponse)
|
||||
def update_drink(
|
||||
drink_id: int,
|
||||
updates: DrinkUpdate,
|
||||
user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
drink = get_user_drink(drink_id, user, db)
|
||||
|
||||
update_data = updates.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(drink, field, value)
|
||||
|
||||
db.commit()
|
||||
db.refresh(drink)
|
||||
return drink
|
||||
|
||||
|
||||
@router.delete("/{drink_id}")
|
||||
def delete_drink(
|
||||
drink_id: int,
|
||||
user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
drink = get_user_drink(drink_id, user, db)
|
||||
|
||||
if drink.image_path:
|
||||
filepath = os.path.join(os.path.dirname(__file__), "..", drink.image_path)
|
||||
resolved = os.path.realpath(filepath)
|
||||
if resolved.startswith(os.path.realpath(UPLOAD_DIR)) and os.path.exists(resolved):
|
||||
os.remove(resolved)
|
||||
|
||||
db.delete(drink)
|
||||
db.commit()
|
||||
return {"detail": "Drink deleted"}
|
||||
|
||||
|
||||
@router.post("/{drink_id}/upload-image", response_model=DrinkResponse)
|
||||
async def upload_image(
|
||||
drink_id: int,
|
||||
file: UploadFile = File(...),
|
||||
user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
drink = get_user_drink(drink_id, user, db)
|
||||
|
||||
if file.content_type not in ALLOWED_MIMES:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Type de fichier non autorisé. Utilisez: {', '.join(ALLOWED_MIMES)}",
|
||||
)
|
||||
|
||||
ext = os.path.splitext(file.filename or "")[1].lower()
|
||||
if ext not in ALLOWED_EXTENSIONS:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Extension non autorisée. Utilisez: {', '.join(ALLOWED_EXTENSIONS)}",
|
||||
)
|
||||
|
||||
content = await file.read()
|
||||
if len(content) > MAX_UPLOAD_SIZE:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Fichier trop volumineux. Taille maximale: {MAX_UPLOAD_SIZE // (1024 * 1024)} MB",
|
||||
)
|
||||
|
||||
try:
|
||||
img = Image.open(io.BytesIO(content))
|
||||
img.verify()
|
||||
real_format = img.format.lower() if img.format else ""
|
||||
if real_format not in {"jpeg", "png", "webp", "gif"}:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Format d'image non autorisé: {real_format}. Utilisez: JPEG, PNG, WebP ou GIF",
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Le fichier n'est pas une image valide",
|
||||
)
|
||||
|
||||
filename = f"{uuid.uuid4().hex}{ext}"
|
||||
filepath = os.path.join(UPLOAD_DIR, filename)
|
||||
|
||||
with open(filepath, "wb") as f:
|
||||
f.write(content)
|
||||
|
||||
if drink.image_path:
|
||||
old_path = os.path.join(os.path.dirname(__file__), "..", drink.image_path)
|
||||
resolved = os.path.realpath(old_path)
|
||||
if resolved.startswith(os.path.realpath(UPLOAD_DIR)) and os.path.exists(resolved):
|
||||
os.remove(resolved)
|
||||
|
||||
drink.image_path = f"uploads/{filename}"
|
||||
db.commit()
|
||||
db.refresh(drink)
|
||||
return drink
|
||||
|
||||
|
||||
MIME_MAP = {
|
||||
".jpg": "image/jpeg",
|
||||
".jpeg": "image/jpeg",
|
||||
".png": "image/png",
|
||||
".webp": "image/webp",
|
||||
".gif": "image/gif",
|
||||
}
|
||||
|
||||
|
||||
@router.get("/{drink_id}/image")
|
||||
def get_drink_image(
|
||||
drink_id: int,
|
||||
token: str = Query(...),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
user = get_user_from_token(token, db)
|
||||
drink = db.query(Drink).filter(Drink.id == drink_id, Drink.owner_id == user.id).first()
|
||||
if not drink or not drink.image_path:
|
||||
raise HTTPException(status_code=404, detail="Image non trouvée")
|
||||
|
||||
filename = os.path.basename(drink.image_path)
|
||||
filepath = os.path.realpath(os.path.join(UPLOAD_DIR, filename))
|
||||
|
||||
if not filepath.startswith(os.path.realpath(UPLOAD_DIR)) or not os.path.isfile(filepath):
|
||||
raise HTTPException(status_code=404, detail="Image non trouvée")
|
||||
|
||||
ext = os.path.splitext(filename)[1].lower()
|
||||
media_type = MIME_MAP.get(ext, "application/octet-stream")
|
||||
|
||||
return FileResponse(filepath, media_type=media_type)
|
||||
@@ -0,0 +1,162 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from models import BeerStyle, DrinkCategory, SpiritType, WineColor
|
||||
|
||||
|
||||
def validate_password_strength(v: str) -> str:
|
||||
if not any(c.isupper() for c in v):
|
||||
raise ValueError("Le mot de passe doit contenir au moins une majuscule")
|
||||
if not any(c.islower() for c in v):
|
||||
raise ValueError("Le mot de passe doit contenir au moins une minuscule")
|
||||
if not any(c.isdigit() for c in v):
|
||||
raise ValueError("Le mot de passe doit contenir au moins un chiffre")
|
||||
return v
|
||||
|
||||
|
||||
# ── Auth ──
|
||||
|
||||
class UserCreate(BaseModel):
|
||||
username: str = Field(..., min_length=3, max_length=50)
|
||||
email: str = Field(..., max_length=255)
|
||||
password: str = Field(..., min_length=8, max_length=128)
|
||||
invite_token: str = Field(..., min_length=1, max_length=128)
|
||||
|
||||
@field_validator("password")
|
||||
@classmethod
|
||||
def validate_password(cls, v: str) -> str:
|
||||
return validate_password_strength(v)
|
||||
|
||||
@field_validator("email")
|
||||
@classmethod
|
||||
def validate_email(cls, v: str) -> str:
|
||||
if "@" not in v or "." not in v.split("@")[-1]:
|
||||
raise ValueError("Adresse email invalide")
|
||||
return v.lower().strip()
|
||||
|
||||
|
||||
class UserLogin(BaseModel):
|
||||
username: str = Field(..., min_length=1, max_length=50)
|
||||
password: str = Field(..., min_length=1, max_length=128)
|
||||
|
||||
|
||||
class UserResponse(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: int
|
||||
username: str
|
||||
email: str
|
||||
is_admin: bool
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class TokenResponse(BaseModel):
|
||||
access_token: str
|
||||
token_type: str = "bearer"
|
||||
user: UserResponse
|
||||
|
||||
|
||||
class PasswordChange(BaseModel):
|
||||
current_password: str = Field(..., min_length=1, max_length=128)
|
||||
new_password: str = Field(..., min_length=8, max_length=128)
|
||||
|
||||
@field_validator("new_password")
|
||||
@classmethod
|
||||
def validate_password(cls, v: str) -> str:
|
||||
return validate_password_strength(v)
|
||||
|
||||
|
||||
class InvitationCreate(BaseModel):
|
||||
pass
|
||||
|
||||
|
||||
# ── Admin ──
|
||||
|
||||
|
||||
class AdminResetPassword(BaseModel):
|
||||
new_password: str = Field(..., min_length=8, max_length=128)
|
||||
|
||||
@field_validator("new_password")
|
||||
@classmethod
|
||||
def validate_password(cls, v: str) -> str:
|
||||
return validate_password_strength(v)
|
||||
|
||||
|
||||
class AdminToggleAdmin(BaseModel):
|
||||
is_admin: bool
|
||||
|
||||
|
||||
class InvitationResponse(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: int
|
||||
token: str
|
||||
created_by: int
|
||||
used_by: Optional[int] = None
|
||||
created_at: datetime
|
||||
used_at: Optional[datetime] = None
|
||||
expires_at: datetime
|
||||
|
||||
|
||||
# ── Drinks ──
|
||||
|
||||
class DrinkBase(BaseModel):
|
||||
name: str = Field(..., min_length=1, max_length=255)
|
||||
category: DrinkCategory
|
||||
rating: Optional[float] = Field(None, ge=0, le=5)
|
||||
notes: Optional[str] = Field(None, max_length=5000)
|
||||
|
||||
grape_variety: Optional[str] = Field(None, max_length=255)
|
||||
vintage: Optional[int] = Field(None, ge=1900, le=2100)
|
||||
region: Optional[str] = Field(None, max_length=255)
|
||||
producer: Optional[str] = Field(None, max_length=255)
|
||||
wine_color: Optional[WineColor] = None
|
||||
|
||||
brewery: Optional[str] = Field(None, max_length=255)
|
||||
beer_style: Optional[BeerStyle] = None
|
||||
ibu: Optional[float] = Field(None, ge=0, le=200)
|
||||
abv: Optional[float] = Field(None, ge=0, le=30)
|
||||
|
||||
spirit_type: Optional[SpiritType] = None
|
||||
age_years: Optional[int] = Field(None, ge=0, le=200)
|
||||
distillery: Optional[str] = Field(None, max_length=255)
|
||||
country: Optional[str] = Field(None, max_length=255)
|
||||
|
||||
|
||||
class DrinkCreate(DrinkBase):
|
||||
pass
|
||||
|
||||
|
||||
class DrinkUpdate(BaseModel):
|
||||
name: Optional[str] = Field(None, min_length=1, max_length=255)
|
||||
category: Optional[DrinkCategory] = None
|
||||
rating: Optional[float] = Field(None, ge=0, le=5)
|
||||
notes: Optional[str] = Field(None, max_length=5000)
|
||||
|
||||
grape_variety: Optional[str] = Field(None, max_length=255)
|
||||
vintage: Optional[int] = Field(None, ge=1900, le=2100)
|
||||
region: Optional[str] = Field(None, max_length=255)
|
||||
producer: Optional[str] = Field(None, max_length=255)
|
||||
wine_color: Optional[WineColor] = None
|
||||
|
||||
brewery: Optional[str] = Field(None, max_length=255)
|
||||
beer_style: Optional[BeerStyle] = None
|
||||
ibu: Optional[float] = Field(None, ge=0, le=200)
|
||||
abv: Optional[float] = Field(None, ge=0, le=30)
|
||||
|
||||
spirit_type: Optional[SpiritType] = None
|
||||
age_years: Optional[int] = Field(None, ge=0, le=200)
|
||||
distillery: Optional[str] = Field(None, max_length=255)
|
||||
country: Optional[str] = Field(None, max_length=255)
|
||||
|
||||
|
||||
class DrinkResponse(DrinkBase):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: int
|
||||
owner_id: int
|
||||
image_path: Optional[str] = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
@@ -0,0 +1,64 @@
|
||||
import os
|
||||
import secrets
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
os.environ["SECRET_KEY"] = secrets.token_hex(32)
|
||||
os.environ["ADMIN_PASSWORD"] = "TestPass123"
|
||||
|
||||
from fastapi.testclient import TestClient # noqa: E402
|
||||
|
||||
from main import app # noqa: E402
|
||||
from database import SessionLocal, Base, engine # noqa: E402
|
||||
from models import User, Invitation # noqa: E402
|
||||
from auth import hash_password, create_access_token # noqa: E402
|
||||
|
||||
ADMIN_PASSWORD = "TestPass123"
|
||||
|
||||
|
||||
def setup():
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
db = SessionLocal()
|
||||
try:
|
||||
admin = User(
|
||||
username="admin",
|
||||
email="admin@cellar.local",
|
||||
hashed_password=hash_password(ADMIN_PASSWORD),
|
||||
is_admin=True,
|
||||
)
|
||||
db.add(admin)
|
||||
db.commit()
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
setup()
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
def get_admin_token() -> str:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
admin = db.query(User).filter(User.username == "admin").first()
|
||||
token = create_access_token({"sub": str(admin.id), "token_version": admin.token_version})
|
||||
return token
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def get_invite_token() -> str:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
admin = db.query(User).filter(User.username == "admin").first()
|
||||
invite = Invitation(
|
||||
token=secrets.token_urlsafe(32),
|
||||
created_by=admin.id,
|
||||
expires_at=datetime.utcnow() + timedelta(days=7),
|
||||
)
|
||||
db.add(invite)
|
||||
db.commit()
|
||||
db.refresh(invite)
|
||||
return invite.token
|
||||
finally:
|
||||
db.close()
|
||||
@@ -0,0 +1,181 @@
|
||||
import os
|
||||
|
||||
from tests.conftest import client, get_admin_token, get_invite_token, ADMIN_PASSWORD
|
||||
|
||||
|
||||
# ── Auth tests ──
|
||||
|
||||
def test_login_success():
|
||||
resp = client.post("/api/auth/login", json={"username": "admin", "password": ADMIN_PASSWORD})
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "access_token" in data
|
||||
assert data["user"]["username"] == "admin"
|
||||
|
||||
|
||||
def test_login_wrong_password():
|
||||
resp = client.post("/api/auth/login", json={"username": "admin", "password": "wrong"})
|
||||
assert resp.status_code == 401
|
||||
|
||||
|
||||
def test_login_wrong_username():
|
||||
resp = client.post("/api/auth/login", json={"username": "nobody", "password": ADMIN_PASSWORD})
|
||||
assert resp.status_code == 401
|
||||
|
||||
|
||||
def test_me_unauthorized():
|
||||
resp = client.get("/api/auth/me")
|
||||
assert resp.status_code == 401
|
||||
|
||||
|
||||
def test_me_authorized():
|
||||
token = get_admin_token()
|
||||
resp = client.get("/api/auth/me", headers={"Authorization": f"Bearer {token}"})
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["username"] == "admin"
|
||||
|
||||
|
||||
# ── Invitation tests ──
|
||||
|
||||
def test_register_with_invite():
|
||||
invite_token = get_invite_token()
|
||||
resp = client.post("/api/auth/register", json={
|
||||
"username": "newuser",
|
||||
"email": "new@test.com",
|
||||
"password": "NewPass123",
|
||||
"invite_token": invite_token,
|
||||
})
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["user"]["username"] == "newuser"
|
||||
|
||||
|
||||
def test_register_invalid_invite():
|
||||
resp = client.post("/api/auth/register", json={
|
||||
"username": "user2",
|
||||
"email": "u2@test.com",
|
||||
"password": "Pass1234",
|
||||
"invite_token": "invalid-token",
|
||||
})
|
||||
assert resp.status_code == 400
|
||||
|
||||
|
||||
def test_invite_reuse():
|
||||
invite_token = get_invite_token()
|
||||
client.post("/api/auth/register", json={
|
||||
"username": "reuse1",
|
||||
"email": "r1@test.com",
|
||||
"password": "Pass1234",
|
||||
"invite_token": invite_token,
|
||||
})
|
||||
resp = client.post("/api/auth/register", json={
|
||||
"username": "reuse2",
|
||||
"email": "r2@test.com",
|
||||
"password": "Pass1234",
|
||||
"invite_token": invite_token,
|
||||
})
|
||||
# 400 = invitation déjà utilisée, 429 = rate limit (les deux empêchent la réutilisation)
|
||||
assert resp.status_code in (400, 429)
|
||||
|
||||
|
||||
# ── Drink CRUD tests ──
|
||||
|
||||
def test_create_drink():
|
||||
token = get_admin_token()
|
||||
resp = client.post("/api/drinks", json={
|
||||
"name": "Château Test",
|
||||
"category": "wine",
|
||||
"rating": 4.5,
|
||||
}, headers={"Authorization": f"Bearer {token}"})
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["name"] == "Château Test"
|
||||
assert data["category"] == "wine"
|
||||
|
||||
|
||||
def test_list_drinks():
|
||||
token = get_admin_token()
|
||||
resp = client.get("/api/drinks", headers={"Authorization": f"Bearer {token}"})
|
||||
assert resp.status_code == 200
|
||||
assert len(resp.json()) > 0
|
||||
|
||||
|
||||
def test_update_drink():
|
||||
token = get_admin_token()
|
||||
create_resp = client.post("/api/drinks", json={
|
||||
"name": "Château Test",
|
||||
"category": "wine",
|
||||
}, headers={"Authorization": f"Bearer {token}"})
|
||||
drink_id = create_resp.json()["id"]
|
||||
resp = client.put(f"/api/drinks/{drink_id}", json={
|
||||
"name": "Château Updated",
|
||||
"rating": 5.0,
|
||||
}, headers={"Authorization": f"Bearer {token}"})
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["name"] == "Château Updated"
|
||||
assert resp.json()["rating"] == 5.0
|
||||
|
||||
|
||||
def test_delete_drink():
|
||||
token = get_admin_token()
|
||||
create_resp = client.post("/api/drinks", json={
|
||||
"name": "Château Test",
|
||||
"category": "wine",
|
||||
}, headers={"Authorization": f"Bearer {token}"})
|
||||
drink_id = create_resp.json()["id"]
|
||||
resp = client.delete(f"/api/drinks/{drink_id}", headers={"Authorization": f"Bearer {token}"})
|
||||
assert resp.status_code == 200
|
||||
resp = client.get(f"/api/drinks/{drink_id}", headers={"Authorization": f"Bearer {token}"})
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
def test_drink_owner_isolation():
|
||||
token = get_admin_token()
|
||||
create_resp = client.post("/api/drinks", json={
|
||||
"name": "Château Test",
|
||||
"category": "wine",
|
||||
}, headers={"Authorization": f"Bearer {token}"})
|
||||
drink_id = create_resp.json()["id"]
|
||||
|
||||
from database import SessionLocal
|
||||
from models import User
|
||||
from auth import hash_password, create_access_token
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
other = User(username="other", email="other@test.com", hashed_password=hash_password("Other123"))
|
||||
db.add(other)
|
||||
db.commit()
|
||||
db.refresh(other)
|
||||
other_token = create_access_token({"sub": str(other.id), "token_version": other.token_version})
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
resp = client.get(f"/api/drinks/{drink_id}", headers={"Authorization": f"Bearer {other_token}"})
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
# ── Rate limiter test ──
|
||||
|
||||
def test_rate_limit_blocks():
|
||||
for _ in range(6):
|
||||
client.post("/api/auth/login", json={"username": "admin", "password": "wrong"})
|
||||
resp = client.post("/api/auth/login", json={"username": "admin", "password": "wrong"})
|
||||
assert resp.status_code == 429
|
||||
|
||||
|
||||
# ── Logout test ──
|
||||
|
||||
def test_logout_invalidates_token():
|
||||
token = get_admin_token()
|
||||
resp = client.post("/api/auth/logout", headers={"Authorization": f"Bearer {token}"})
|
||||
assert resp.status_code == 200
|
||||
resp = client.get("/api/auth/me", headers={"Authorization": f"Bearer {token}"})
|
||||
assert resp.status_code == 401
|
||||
|
||||
|
||||
# ── Health check ──
|
||||
|
||||
def test_health():
|
||||
resp = client.get("/api/health")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["status"] == "ok"
|
||||
Reference in New Issue
Block a user