"""CSRF Protection Middleware.

Provides CSRF token generation and validation for form submissions using the
double-submit cookie pattern: the token stored in a cookie must be echoed back
in a request header or form field on every state-changing request.
"""

import secrets
from typing import Optional
from urllib.parse import parse_qs

from fastapi import Request, Response
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import ClientDisconnect


class CSRFMiddleware(BaseHTTPMiddleware):
    """Middleware for CSRF protection.

    Issues a CSRF token per client (stored in a cookie) and validates it on
    state-changing requests, i.e. any method not in ``SAFE_METHODS``
    (POST, PUT, DELETE, PATCH, ...).
    """

    # Name of the form field that may carry the token.
    CSRF_TOKEN_NAME = "csrf_token"
    # Header checked first (used by HTMX/ajax requests).
    CSRF_HEADER_NAME = "X-CSRF-Token"
    SAFE_METHODS = {"GET", "HEAD", "OPTIONS", "TRACE"}
    # Default cookie lifetime in seconds: 7 days.
    DEFAULT_COOKIE_MAX_AGE = 3600 * 24 * 7

    def __init__(
        self,
        app,
        cookie_name: str = "csrf_token",
        cookie_secure: bool = False,
        cookie_max_age: int = DEFAULT_COOKIE_MAX_AGE,
    ):
        """Configure the middleware.

        Args:
            app: The wrapped ASGI application.
            cookie_name: Name of the cookie holding the CSRF token.
            cookie_secure: Whether the cookie is flagged ``Secure``.
            cookie_max_age: Cookie lifetime in seconds (default 7 days).
        """
        super().__init__(app)
        self.cookie_name = cookie_name
        self.cookie_secure = cookie_secure
        self.cookie_max_age = cookie_max_age

    async def dispatch(self, request: Request, call_next):
        """Process request and validate CSRF token if needed.

        Args:
            request: The incoming request
            call_next: Next middleware/handler in chain

        Returns:
            Response from the next handler (with the CSRF cookie set), or a
            403 JSON response when validation fails.
        """
        # Reuse the token from the cookie, or generate a fresh one.
        csrf_token = self._get_or_create_token(request)

        # Validate token on state-changing requests.
        if request.method not in self.SAFE_METHODS:
            if not await self._validate_token(request, csrf_token):
                return JSONResponse(
                    status_code=403,
                    content={"detail": "CSRF token missing or invalid"},
                )

        # Store token in request state for templates (see get_csrf_token).
        request.state.csrf_token = csrf_token

        response = await call_next(request)

        # (Re)set the cookie on every response, refreshing its max-age.
        response.set_cookie(
            key=self.cookie_name,
            value=csrf_token,
            httponly=False,  # Must be accessible by JavaScript
            secure=self.cookie_secure,
            samesite="lax",
            max_age=self.cookie_max_age,
        )
        return response

    def _get_or_create_token(self, request: Request) -> str:
        """Get existing token from cookie or create new one.

        Args:
            request: The incoming request

        Returns:
            CSRF token string
        """
        token = request.cookies.get(self.cookie_name)
        if token:
            return token
        return secrets.token_urlsafe(32)

    async def _validate_token(self, request: Request, expected_token: str) -> bool:
        """Validate the CSRF token submitted with the request.

        Checks the ``X-CSRF-Token`` header first, then a urlencoded form
        field named ``CSRF_TOKEN_NAME`` in the request body.

        Args:
            request: The incoming request
            expected_token: Expected token value (from the cookie)

        Returns:
            True if the submitted token matches, False otherwise.
        """
        token = request.headers.get(self.CSRF_HEADER_NAME)
        if not token:
            token = await self._token_from_form(request)
        if not token:
            return False
        # compare_digest raises TypeError for non-ASCII str arguments; both
        # values are client-controlled, so compare their UTF-8 bytes instead
        # so a malformed token yields a 403 rather than a server error.
        return secrets.compare_digest(
            token.encode("utf-8"), expected_token.encode("utf-8")
        )

    async def _token_from_form(self, request: Request) -> Optional[str]:
        """Return the CSRF token from a urlencoded request body, if present.

        Returns None (which makes validation fail closed) when the body is
        empty, cannot be read, is not valid UTF-8, or lacks the field.
        Multipart bodies are not parsed.
        """
        try:
            body = await request.body()
        except ClientDisconnect:
            return None
        if not body:
            return None
        try:
            text = body.decode("utf-8")
        except UnicodeDecodeError:
            return None
        # parse_qs yields str keys mapped to lists of str values.
        values = parse_qs(text).get(self.CSRF_TOKEN_NAME)
        return values[0] if values else None


def get_csrf_token(request: Request) -> Optional[str]:
    """Get CSRF token from request state.

    Use this in route handlers to pass token to templates.

    Args:
        request: The current request

    Returns:
        CSRF token, or None if the middleware has not set one on this request.
    """
    return getattr(request.state, "csrf_token", None)