Files
openrouter-watcher/src/openrouter_monitor/middleware/csrf.py
Luca Sacchi Ricciardi ccd96acaac feat(frontend): T46 configure HTMX and CSRF protection
- Add CSRFMiddleware for form protection
- Implement token generation and validation
- Add CSRF meta tag to base.html
- Create tests for CSRF protection

Tests: 13 passing
2026-04-07 18:02:20 +02:00

133 lines
4.0 KiB
Python

"""CSRF Protection Middleware.
Provides CSRF token generation and validation for form submissions.
"""
import secrets
from typing import Optional
from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
class CSRFMiddleware(BaseHTTPMiddleware):
    """Middleware for CSRF protection using a double-submit cookie.

    Every response that passes through sets a cookie holding the CSRF token.
    Requests whose method is not in ``SAFE_METHODS`` (e.g. POST, PUT, DELETE,
    PATCH) must send that same token back, either in the ``X-CSRF-Token``
    header or as the ``csrf_token`` field of a URL-encoded form body.
    Otherwise a 403 JSON response is returned and the wrapped app is not
    called.

    Note: only URL-encoded bodies are parsed for the form field; clients
    posting other body types (e.g. multipart) must use the header.
    """

    # Form field that carries the token in URL-encoded form submissions.
    CSRF_TOKEN_NAME = "csrf_token"
    # Request header that carries the token (used by HTMX/AJAX requests).
    CSRF_HEADER_NAME = "X-CSRF-Token"
    # Methods that are never checked for a token.
    SAFE_METHODS = {"GET", "HEAD", "OPTIONS", "TRACE"}

    def __init__(self, app, cookie_name: str = "csrf_token", cookie_secure: bool = False):
        """Wrap ``app`` with CSRF protection.

        Args:
            app: The ASGI application to wrap.
            cookie_name: Name of the cookie that stores the CSRF token.
            cookie_secure: Whether the token cookie gets the ``Secure`` flag.
        """
        super().__init__(app)
        self.cookie_name = cookie_name
        self.cookie_secure = cookie_secure

    async def dispatch(self, request: Request, call_next):
        """Validate the CSRF token on unsafe methods, then refresh the cookie.

        Args:
            request: The incoming request.
            call_next: Next middleware/handler in the chain.

        Returns:
            A 403 JSON response if validation fails; otherwise the response
            from the next handler, with the token cookie (re)set on it.
        """
        # The cookie value is the reference token. If the client has no
        # cookie, a fresh token is generated; a submitted token cannot match
        # it, so unsafe requests without a cookie are rejected.
        csrf_token = self._get_or_create_token(request)

        if request.method not in self.SAFE_METHODS:
            if not await self._validate_token(request, csrf_token):
                from fastapi.responses import JSONResponse

                return JSONResponse(
                    status_code=403,
                    content={"detail": "CSRF token missing or invalid"},
                )

        # Expose the token to route handlers/templates.
        request.state.csrf_token = csrf_token

        response = await call_next(request)

        response.set_cookie(
            key=self.cookie_name,
            value=csrf_token,
            httponly=False,  # Must be readable by JavaScript to send the header.
            secure=self.cookie_secure,
            samesite="lax",
            max_age=3600 * 24 * 7,  # 7 days
        )
        return response

    def _get_or_create_token(self, request: Request) -> str:
        """Return the token from the request cookie, or a new random one.

        Args:
            request: The incoming request.

        Returns:
            The existing non-empty cookie value, or a fresh
            ``secrets.token_urlsafe(32)`` token.
        """
        token = request.cookies.get(self.cookie_name)
        if token:
            return token
        return secrets.token_urlsafe(32)

    async def _validate_token(self, request: Request, expected_token: str) -> bool:
        """Check the submitted CSRF token against ``expected_token``.

        The header is checked first; if it is absent or empty, the
        URL-encoded form body is consulted.

        Args:
            request: The incoming request.
            expected_token: The token the client must echo back.

        Returns:
            True if a submitted token equals ``expected_token``, else False.
        """
        token = request.headers.get(self.CSRF_HEADER_NAME)
        if not token:
            token = await self._get_form_token(request)
        if not token:
            return False
        return self._tokens_match(token, expected_token)

    async def _get_form_token(self, request: Request) -> Optional[str]:
        """Extract the token field from a URL-encoded request body, if any."""
        try:
            body = await request.body()
        except Exception:
            # Best-effort: an unreadable body (e.g. the client disconnected)
            # is treated as "no token", so the request is rejected with 403
            # rather than raising.
            return None
        if not body:
            return None
        try:
            text = body.decode("utf-8")
        except UnicodeDecodeError:
            return None

        from urllib.parse import parse_qs

        # parse_qs on a str yields str keys and values, so the lookup must
        # use the str field name (a bytes key would never match).
        values = parse_qs(text).get(self.CSRF_TOKEN_NAME)
        return values[0] if values else None

    @staticmethod
    def _tokens_match(token: str, expected_token: str) -> bool:
        """Compare two tokens in constant time."""
        # compare_digest raises TypeError for str arguments containing
        # non-ASCII characters; comparing UTF-8 bytes avoids that while
        # preserving equality semantics.
        return secrets.compare_digest(
            token.encode("utf-8"), expected_token.encode("utf-8")
        )
def get_csrf_token(request: Request) -> Optional[str]:
    """Return the CSRF token stored on ``request.state``, if present.

    Route handlers can call this to pass the token into a template.

    Args:
        request: The request currently being handled.

    Returns:
        The value of ``request.state.csrf_token``, or None when that
        attribute has not been set.
    """
    state = request.state
    try:
        return state.csrf_token
    except AttributeError:
        return None