"""Tiered rate limiting for API v2. Implements rate limiting with different tiers: - Free tier: 100 requests/minute - Premium tier: 1000 requests/minute - Enterprise tier: 10000 requests/minute Supports burst allowances and per-API-key limits. """ from typing import Optional from datetime import datetime from fastapi import Request, HTTPException, status from src.core.cache import cache_manager from src.core.logging_config import get_logger logger = get_logger(__name__) class RateLimitConfig: """Rate limit configuration per tier.""" TIERS = { "free": { "requests_per_minute": 100, "burst": 10, }, "premium": { "requests_per_minute": 1000, "burst": 50, }, "enterprise": { "requests_per_minute": 10000, "burst": 200, }, } class RateLimiter: """Simple in-memory rate limiter (use Redis in production).""" def __init__(self): self._storage = {} def _get_key(self, identifier: str, window: int = 60) -> str: """Generate rate limit key.""" timestamp = int(datetime.utcnow().timestamp()) // window return f"ratelimit:{identifier}:{timestamp}" async def is_allowed( self, identifier: str, limit: int, window: int = 60, ) -> tuple[bool, dict]: """Check if request is allowed. Returns: Tuple of (allowed, headers) """ key = self._get_key(identifier, window) try: # Try to use Redis await cache_manager.initialize() current = await cache_manager.redis.incr(key) if current == 1: # Set expiration on first request await cache_manager.redis.expire(key, window) remaining = max(0, limit - current) reset_time = (int(datetime.utcnow().timestamp()) // window + 1) * window headers = { "X-RateLimit-Limit": str(limit), "X-RateLimit-Remaining": str(remaining), "X-RateLimit-Reset": str(reset_time), } allowed = current <= limit return allowed, headers except Exception as e: # Fallback: allow request if Redis unavailable logger.warning(f"Rate limiting unavailable: {e}") return True, {} class TieredRateLimit: """Tiered rate limiting with burst support.""" def __init__(self): self.limiter = RateLimiter() def _get_client_identifier( self, request: Request, api_key: Optional[str] = None, ) -> str: """Get client identifier from request.""" if api_key: return f"apikey:{api_key}" # Use IP address as fallback forwarded = request.headers.get("X-Forwarded-For") if forwarded: return f"ip:{forwarded.split(',')[0].strip()}" client_host = request.client.host if request.client else "unknown" return f"ip:{client_host}" def _get_tier_for_key(self, api_key: Optional[str]) -> str: """Determine tier for API key. In production, this would lookup the tier from database. """ if not api_key: return "free" # For demo purposes, keys starting with 'mk_premium' are premium tier if api_key.startswith("mk_premium"): return "premium" elif api_key.startswith("mk_enterprise"): return "enterprise" return "free" async def check_rate_limit( self, request: Request, api_key: Optional[str] = None, tier: Optional[str] = None, burst: Optional[int] = None, ) -> dict: """Check rate limit and raise exception if exceeded. Args: request: FastAPI request object api_key: Optional API key tier: Override tier (free/premium/enterprise) burst: Override burst limit Returns: Rate limit headers Raises: HTTPException: If rate limit exceeded """ # Determine tier client_tier = tier or self._get_tier_for_key(api_key) config = RateLimitConfig.TIERS.get(client_tier, RateLimitConfig.TIERS["free"]) # Get client identifier identifier = self._get_client_identifier(request, api_key) # Calculate limit with burst limit = config["requests_per_minute"] if burst is not None: limit = burst # Check rate limit allowed, headers = await self.limiter.is_allowed(identifier, limit) if not allowed: logger.warning( "Rate limit exceeded", extra={ "identifier": identifier, "tier": client_tier, "limit": limit, }, ) raise HTTPException( status_code=status.HTTP_429_TOO_MANY_REQUESTS, detail="Rate limit exceeded. Please try again later.", headers={ **headers, "Retry-After": "60", }, ) # Store headers in request state for middleware request.state.rate_limit_headers = headers return headers class RateLimitMiddleware: """Middleware to add rate limit headers to responses.""" def __init__(self, app): self.app = app async def __call__(self, scope, receive, send): if scope["type"] != "http": await self.app(scope, receive, send) return from fastapi import Request request = Request(scope, receive) # Store original send original_send = send async def wrapped_send(message): if message["type"] == "http.response.start": # Add rate limit headers if available if hasattr(request.state, "rate_limit_headers"): headers = message.get("headers", []) for key, value in request.state.rate_limit_headers.items(): headers.append([key.encode(), value.encode()]) message["headers"] = headers await original_send(message) await self.app(scope, receive, wrapped_send)