- Add CSRFMiddleware for form protection - Implement token generation and validation - Add CSRF meta tag to base.html - Create tests for CSRF protection Tests: 13 passing
133 lines
4.0 KiB
Python
133 lines
4.0 KiB
Python
"""CSRF Protection Middleware.
|
|
|
|
Provides CSRF token generation and validation for form submissions.
|
|
"""
|
|
import secrets
|
|
from typing import Optional
|
|
|
|
from fastapi import Request, Response
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
|
|
|
|
class CSRFMiddleware(BaseHTTPMiddleware):
    """Middleware for CSRF protection using a double-submit cookie.

    Generates a CSRF token per client, stores it in a cookie, and validates
    it on state-changing requests. These are all methods not listed in
    ``SAFE_METHODS`` (POST, PUT, DELETE, PATCH, ...).

    A request passes validation when the token it submits equals the token
    held in the CSRF cookie. The token may be submitted either in the
    ``X-CSRF-Token`` header or as a ``csrf_token`` field in a URL-encoded
    form body. Multipart bodies are not parsed, so multipart forms must send
    the header.
    """

    # Form field name that carries the token in URL-encoded form submissions.
    CSRF_TOKEN_NAME = "csrf_token"
    # Request header that carries the token (used by HTMX/AJAX requests).
    CSRF_HEADER_NAME = "X-CSRF-Token"
    # Methods that are treated as read-only and therefore not validated.
    SAFE_METHODS = {"GET", "HEAD", "OPTIONS", "TRACE"}

    def __init__(self, app, cookie_name: str = "csrf_token", cookie_secure: bool = False):
        """Wrap ``app`` with CSRF protection.

        Args:
            app: The ASGI application to wrap.
            cookie_name: Name of the cookie that stores the CSRF token.
            cookie_secure: Value passed as ``secure`` when setting the
                cookie, which restricts it to HTTPS connections.
        """
        super().__init__(app)
        self.cookie_name = cookie_name
        self.cookie_secure = cookie_secure

    async def dispatch(self, request: Request, call_next):
        """Process the request and validate its CSRF token if needed.

        Args:
            request: The incoming request.
            call_next: Next middleware/handler in the chain.

        Returns:
            The downstream response with the CSRF cookie (re)set, or a
            403 JSON response when a state-changing request fails
            validation.
        """
        # Reuse the token from the cookie, or mint a fresh one.
        csrf_token = self._get_or_create_token(request)

        # Validate the token on state-changing requests.
        if request.method not in self.SAFE_METHODS:
            is_valid = await self._validate_token(request, csrf_token)
            if not is_valid:
                from fastapi.responses import JSONResponse

                return JSONResponse(
                    status_code=403,
                    content={"detail": "CSRF token missing or invalid"},
                )

        # Expose the token to route handlers and templates.
        request.state.csrf_token = csrf_token

        response = await call_next(request)

        # Refresh the CSRF cookie on every passing response.
        response.set_cookie(
            key=self.cookie_name,
            value=csrf_token,
            httponly=False,  # Must be readable by JavaScript to fill the header
            secure=self.cookie_secure,
            samesite="lax",
            max_age=3600 * 24 * 7,  # 7 days
        )

        return response

    def _get_or_create_token(self, request: Request) -> str:
        """Return the token from the CSRF cookie, or generate a new one.

        Args:
            request: The incoming request.

        Returns:
            The CSRF token string.
        """
        token = request.cookies.get(self.cookie_name)
        if token:
            return token

        # No (or empty) cookie: issue a fresh random token.
        return secrets.token_urlsafe(32)

    async def _validate_token(self, request: Request, expected_token: str) -> bool:
        """Validate the CSRF token submitted with the request.

        Checks the ``X-CSRF-Token`` header first, then a ``csrf_token``
        field in a URL-encoded form body.

        Args:
            request: The incoming request.
            expected_token: The token the request must match.

        Returns:
            True if a submitted token matches ``expected_token``,
            False otherwise.
        """
        # Header first (HTMX/AJAX requests).
        token = request.headers.get(self.CSRF_HEADER_NAME)

        # Fall back to a classic form field.
        if not token:
            token = await self._token_from_form(request)

        if not token:
            return False

        # compare_digest raises TypeError for str arguments containing
        # non-ASCII characters. Header and cookie values are client-controlled,
        # so compare UTF-8 bytes to get False (-> 403) instead of a crash.
        return secrets.compare_digest(
            token.encode("utf-8"), expected_token.encode("utf-8")
        )

    async def _token_from_form(self, request: Request) -> Optional[str]:
        """Extract the ``csrf_token`` field from a URL-encoded request body.

        Args:
            request: The incoming request.

        Returns:
            The first value of the ``csrf_token`` field, or None if the body
            is empty, unreadable, or lacks that field.
        """
        import logging
        from urllib.parse import parse_qs

        # NOTE(review): reading the body inside BaseHTTPMiddleware relies on
        # Starlette replaying the cached body to the downstream app; very old
        # Starlette versions did not -- confirm against the pinned version.
        try:
            body = await request.body()
        except Exception:
            # Best-effort: an unreadable body just means "no form token".
            logging.getLogger(__name__).debug(
                "CSRF: could not read request body", exc_info=True
            )
            return None

        if not body:
            return None

        # Decoding with errors="replace" means arbitrary binary bodies cannot
        # raise here. parse_qs on a str returns str keys, so the lookup must
        # use the str field name, not bytes.
        form_data = parse_qs(body.decode("utf-8", errors="replace"))
        values = form_data.get(self.CSRF_TOKEN_NAME)
        return values[0] if values else None
|
|
|
|
|
|
def get_csrf_token(request: Request) -> Optional[str]:
    """Look up the CSRF token attached to the current request.

    ``CSRFMiddleware.dispatch`` stores the token on ``request.state`` before
    handing the request on. Route handlers can read it here and pass it to
    their templates.

    Args:
        request: The request being handled.

    Returns:
        The stored token string, or None when nothing was stored on
        ``request.state`` (for example, if the middleware did not run for
        this request).
    """
    state = request.state
    try:
        return state.csrf_token
    except AttributeError:
        # No token was recorded for this request.
        return None
|