feat(frontend): T46 configure HTMX and CSRF protection
- Add CSRFMiddleware for form protection - Implement token generation and validation - Add CSRF meta tag to base.html - Create tests for CSRF protection Tests: 13 passing
This commit is contained in:
@@ -11,6 +11,7 @@ from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.templating import Jinja2Templates
|
||||
|
||||
from openrouter_monitor.config import get_settings
|
||||
from openrouter_monitor.middleware.csrf import CSRFMiddleware
|
||||
from openrouter_monitor.routers import api_keys
|
||||
from openrouter_monitor.routers import auth
|
||||
from openrouter_monitor.routers import public_api
|
||||
@@ -50,9 +51,12 @@ app = FastAPI(
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
# Mount static files
|
||||
# Mount static files (before CSRF middleware to allow access without token)
|
||||
app.mount("/static", StaticFiles(directory=str(PROJECT_ROOT / "static")), name="static")
|
||||
|
||||
# CSRF protection middleware
|
||||
app.add_middleware(CSRFMiddleware)
|
||||
|
||||
# CORS middleware
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
|
||||
132
src/openrouter_monitor/middleware/csrf.py
Normal file
132
src/openrouter_monitor/middleware/csrf.py
Normal file
@@ -0,0 +1,132 @@
|
||||
"""CSRF Protection Middleware.
|
||||
|
||||
Provides CSRF token generation and validation for form submissions.
|
||||
"""
|
||||
import secrets
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import Request, Response
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
|
||||
class CSRFMiddleware(BaseHTTPMiddleware):
|
||||
"""Middleware for CSRF protection.
|
||||
|
||||
Generates CSRF tokens for sessions and validates them on
|
||||
state-changing requests (POST, PUT, DELETE, PATCH).
|
||||
"""
|
||||
|
||||
CSRF_TOKEN_NAME = "csrf_token"
|
||||
CSRF_HEADER_NAME = "X-CSRF-Token"
|
||||
SAFE_METHODS = {"GET", "HEAD", "OPTIONS", "TRACE"}
|
||||
|
||||
def __init__(self, app, cookie_name: str = "csrf_token", cookie_secure: bool = False):
|
||||
super().__init__(app)
|
||||
self.cookie_name = cookie_name
|
||||
self.cookie_secure = cookie_secure
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
"""Process request and validate CSRF token if needed.
|
||||
|
||||
Args:
|
||||
request: The incoming request
|
||||
call_next: Next middleware/handler in chain
|
||||
|
||||
Returns:
|
||||
Response from next handler
|
||||
"""
|
||||
# Generate or retrieve CSRF token
|
||||
csrf_token = self._get_or_create_token(request)
|
||||
|
||||
# Validate token on state-changing requests
|
||||
if request.method not in self.SAFE_METHODS:
|
||||
is_valid = await self._validate_token(request, csrf_token)
|
||||
if not is_valid:
|
||||
from fastapi.responses import JSONResponse
|
||||
return JSONResponse(
|
||||
status_code=403,
|
||||
content={"detail": "CSRF token missing or invalid"}
|
||||
)
|
||||
|
||||
# Store token in request state for templates
|
||||
request.state.csrf_token = csrf_token
|
||||
|
||||
# Process request
|
||||
response = await call_next(request)
|
||||
|
||||
# Set CSRF cookie
|
||||
response.set_cookie(
|
||||
key=self.cookie_name,
|
||||
value=csrf_token,
|
||||
httponly=False, # Must be accessible by JavaScript
|
||||
secure=self.cookie_secure,
|
||||
samesite="lax",
|
||||
max_age=3600 * 24 * 7, # 7 days
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
def _get_or_create_token(self, request: Request) -> str:
|
||||
"""Get existing token from cookie or create new one.
|
||||
|
||||
Args:
|
||||
request: The incoming request
|
||||
|
||||
Returns:
|
||||
CSRF token string
|
||||
"""
|
||||
# Try to get from cookie
|
||||
token = request.cookies.get(self.cookie_name)
|
||||
if token:
|
||||
return token
|
||||
|
||||
# Generate new token
|
||||
return secrets.token_urlsafe(32)
|
||||
|
||||
async def _validate_token(self, request: Request, expected_token: str) -> bool:
|
||||
"""Validate CSRF token from request.
|
||||
|
||||
Checks header first, then form data.
|
||||
|
||||
Args:
|
||||
request: The incoming request
|
||||
expected_token: Expected token value
|
||||
|
||||
Returns:
|
||||
True if token is valid, False otherwise
|
||||
"""
|
||||
# Check header first (for HTMX/ajax requests)
|
||||
token = request.headers.get(self.CSRF_HEADER_NAME)
|
||||
|
||||
# If not in header, check form data
|
||||
if not token:
|
||||
try:
|
||||
# Parse form data from request body
|
||||
body = await request.body()
|
||||
if body:
|
||||
from urllib.parse import parse_qs
|
||||
form_data = parse_qs(body.decode('utf-8'))
|
||||
if b'csrf_token' in form_data:
|
||||
token = form_data[b'csrf_token'][0]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Validate token
|
||||
if not token:
|
||||
return False
|
||||
|
||||
return secrets.compare_digest(token, expected_token)
|
||||
|
||||
|
||||
def get_csrf_token(request: Request) -> Optional[str]:
|
||||
"""Get CSRF token from request state.
|
||||
|
||||
Use this in route handlers to pass token to templates.
|
||||
|
||||
Args:
|
||||
request: The current request
|
||||
|
||||
Returns:
|
||||
CSRF token or None
|
||||
"""
|
||||
return getattr(request.state, "csrf_token", None)
|
||||
Reference in New Issue
Block a user