"""Security headers and CORS middleware.

Implements security hardening:

- HSTS (HTTP Strict Transport Security)
- CSP (Content Security Policy)
- X-Frame-Options
- CORS strict configuration
- Additional security headers
"""

import html
from typing import Any, Dict, Optional

from fastapi import Request, Response
from fastapi.middleware.cors import CORSMiddleware
from starlette.middleware import Middleware
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.types import ASGIApp

from src.core.config import settings

# Security headers configuration, applied to every response by
# SecurityHeadersMiddleware (values overwrite any header of the same name).
SECURITY_HEADERS = {
    # HTTP Strict Transport Security
    "Strict-Transport-Security": "max-age=31536000; includeSubDomains; preload",
    # Content Security Policy
    # NOTE(review): 'unsafe-inline' and 'unsafe-eval' in script-src largely
    # defeat CSP's XSS protection; tighten once the frontend allows it.
    "Content-Security-Policy": (
        "default-src 'self'; "
        "script-src 'self' 'unsafe-inline' 'unsafe-eval'; "
        "style-src 'self' 'unsafe-inline'; "
        "img-src 'self' data: https:; "
        "font-src 'self' data:; "
        "connect-src 'self' https:; "
        "frame-ancestors 'none'; "
        "base-uri 'self'; "
        "form-action 'self';"
    ),
    # X-Frame-Options
    "X-Frame-Options": "DENY",
    # X-Content-Type-Options
    "X-Content-Type-Options": "nosniff",
    # Referrer Policy
    "Referrer-Policy": "strict-origin-when-cross-origin",
    # Permissions Policy
    "Permissions-Policy": (
        "accelerometer=(), "
        "camera=(), "
        "geolocation=(), "
        "gyroscope=(), "
        "magnetometer=(), "
        "microphone=(), "
        "payment=(), "
        "usb=()"
    ),
    # X-XSS-Protection (legacy browsers)
    # NOTE(review): current OWASP guidance recommends "0" here, because the
    # legacy XSS auditor this enables could itself be abused.
    "X-XSS-Protection": "1; mode=block",
    # Cache control for sensitive data
    "Cache-Control": "no-store, max-age=0",
}


class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """Middleware to add security headers to all responses."""

    async def dispatch(
        self, request: Request, call_next: RequestResponseEndpoint
    ) -> Response:
        """Forward the request, then stamp every SECURITY_HEADERS entry onto the response."""
        response = await call_next(request)

        # Add security headers (overwrites any value set by the route).
        for header, value in SECURITY_HEADERS.items():
            response.headers[header] = value

        return response


class CORSSecurityMiddleware:
    """CORS middleware with strict security configuration."""

    @staticmethod
    def get_options() -> Dict[str, Any]:
        """Build the keyword options for Starlette's ``CORSMiddleware``.

        Origins come from ``settings.cors_allowed_origins`` (default: the
        local dev servers). When ``settings.debug`` is falsy, the value of
        ``settings.cors_allowed_origins_production`` is used instead, if set.

        Returns:
            Dict of keyword arguments accepted by ``CORSMiddleware``.
        """
        # Get allowed origins from settings
        allowed_origins = getattr(
            settings,
            "cors_allowed_origins",
            ["http://localhost:3000", "http://localhost:5173"],
        )

        # In production, enforce strict origin checking.
        # NOTE(review): if no production list is configured, this falls back
        # to the dev origins above (possibly localhost) -- confirm intended.
        if not getattr(settings, "debug", False):
            allowed_origins = getattr(
                settings,
                "cors_allowed_origins_production",
                allowed_origins,
            )

        return {
            "allow_origins": allowed_origins,
            "allow_credentials": True,
            "allow_methods": ["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"],
            "allow_headers": [
                "Authorization",
                "Content-Type",
                "X-Request-ID",
                "X-Correlation-ID",
                "X-API-Key",
                "X-Scenario-ID",
            ],
            "expose_headers": [
                "X-Request-ID",
                "X-Correlation-ID",
                "X-RateLimit-Limit",
                "X-RateLimit-Remaining",
                "X-RateLimit-Reset",
            ],
            "max_age": 600,  # 10 minutes
        }

    @staticmethod
    def get_middleware(app: Optional[ASGIApp] = None):
        """Get CORS middleware with strict configuration.

        ``CORSMiddleware`` requires the ASGI app it wraps as its first
        argument, so it cannot be instantiated on its own.

        Args:
            app: ASGI application to wrap. When given, a ``CORSMiddleware``
                instance wrapping it is returned.

        Returns:
            A ``CORSMiddleware`` wrapping ``app`` if ``app`` is provided;
            otherwise a Starlette ``Middleware`` spec suitable for
            ``FastAPI(middleware=[...])``.
        """
        options = CORSSecurityMiddleware.get_options()
        if app is not None:
            return CORSMiddleware(app, **options)
        return Middleware(CORSMiddleware, **options)


# Content Security Policy for different contexts
CSP_POLICIES = {
    "default": SECURITY_HEADERS["Content-Security-Policy"],
    # JSON API responses never need to load sub-resources.
    "api": "default-src 'none'; frame-ancestors 'none'; base-uri 'none';",
    "reports": (
        "default-src 'self'; "
        "script-src 'self'; "
        "style-src 'self' 'unsafe-inline'; "
        "img-src 'self' data:; "
        "frame-ancestors 'none';"
    ),
}


def get_csp_header(context: str = "default") -> str:
    """Get Content Security Policy for specific context.

    Args:
        context: Context type (default, api, reports). Unknown contexts
            fall back to the default policy.

    Returns:
        CSP header value
    """
    return CSP_POLICIES.get(context, CSP_POLICIES["default"])


class SecurityContextMiddleware(BaseHTTPMiddleware):
    """Middleware to add context-aware security headers."""

    async def dispatch(
        self, request: Request, call_next: RequestResponseEndpoint
    ) -> Response:
        """Forward the request, then replace the CSP header based on the URL path."""
        response = await call_next(request)

        # Determine context based on path
        path = request.url.path
        if path.startswith("/api/"):
            context = "api"
        elif path.startswith("/reports/"):
            context = "reports"
        else:
            context = "default"

        # Set context-specific CSP
        response.headers["Content-Security-Policy"] = get_csp_header(context)

        return response


# Input validation security
class InputValidator:
    """Input validation helpers for security."""

    # Maximum allowed sizes
    MAX_STRING_LENGTH = 10000
    MAX_JSON_SIZE = 1024 * 1024  # 1MB
    MAX_QUERY_PARAMS = 50
    MAX_HEADER_SIZE = 8192  # 8KB

    # Lower-case substrings rejected by validate_string as likely XSS payloads.
    # NOTE(review): this list was reconstructed from a damaged copy of this
    # file -- confirm it matches the intended pattern set.
    _XSS_PATTERNS = (
        "<script",
        "javascript:",
        "onerror=",
        "onload=",
        "onclick=",
        "<iframe",
        "<object",
        "<embed",
    )

    @classmethod
    def validate_string(
        cls, value: str, field_name: str, max_length: Optional[int] = None
    ) -> str:
        """Validate string input.

        Args:
            value: String value to validate
            field_name: Name of the field for error messages
            max_length: Maximum allowed length; ``None`` means
                ``MAX_STRING_LENGTH``. An explicit ``0`` is honored.

        Returns:
            Validated string

        Raises:
            ValueError: If validation fails
        """
        if not isinstance(value, str):
            raise ValueError(f"{field_name} must be a string")

        # `is None` (not `or`) so an explicit max_length=0 is not replaced
        # by the default.
        max_len = cls.MAX_STRING_LENGTH if max_length is None else max_length
        if len(value) > max_len:
            raise ValueError(f"{field_name} exceeds maximum length of {max_len}")

        # Check for potential XSS
        if cls._contains_xss_patterns(value):
            raise ValueError(f"{field_name} contains invalid characters")

        return value

    @classmethod
    def _contains_xss_patterns(cls, value: str) -> bool:
        """Check if string contains potential XSS patterns (case-insensitive)."""
        lowered = value.lower()
        return any(pattern in lowered for pattern in cls._XSS_PATTERNS)

    @staticmethod
    def sanitize_html(value: str) -> str:
        """Sanitize HTML content to prevent XSS.

        This escapes HTML special characters (``& < > " '``); it does not
        keep any markup.

        Args:
            value: HTML string to sanitize

        Returns:
            Sanitized string
        """
        return html.escape(value)


def setup_security_middleware(app) -> None:
    """Setup all security middleware for FastAPI app.

    Starlette wraps each later-added middleware around the earlier ones.
    SecurityContextMiddleware is added last, so it handles the response
    after SecurityHeadersMiddleware. Its path-specific CSP therefore
    replaces the default CSP set by SecurityHeadersMiddleware.

    Args:
        app: FastAPI application instance
    """
    # Note: CORS middleware is configured in main.py

    # Add security headers middleware
    app.add_middleware(SecurityHeadersMiddleware)

    # Add context-aware security middleware (must be added after the above).
    app.add_middleware(SecurityContextMiddleware)