- Add CSRFMiddleware for form protection
- Implement token generation and validation
- Add CSRF meta tag to base.html
- Create tests for CSRF protection

Tests: 13 passing
201 lines
6.5 KiB
Python
201 lines
6.5 KiB
Python
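"""Tests for CSRF Protection Middleware.

Exercises CSRF cookie issuance on GET requests, rejection (403) of POST,
PUT and DELETE requests that lack a valid ``X-CSRF-Token`` header, and the
length and uniqueness of generated tokens.

TDD: RED → GREEN → REFACTOR
"""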
|
|
import pytest
|
|
from fastapi import FastAPI, Request
|
|
from fastapi.responses import JSONResponse
|
|
from fastapi.testclient import TestClient
|
|
|
|
from openrouter_monitor.middleware.csrf import CSRFMiddleware, get_csrf_token
|
|
|
|
|
|
class TestCSRFMiddleware:
    """Test CSRF middleware functionality.

    Each test builds its own ``TestClient``. A client's cookie jar keeps the
    ``csrf_token`` cookie set by a GET and sends it back on later requests
    made through the same client.
    """

    @pytest.fixture
    def app_with_csrf(self):
        """Create a FastAPI app with CSRF middleware.

        ``/test`` is routed for the safe methods GET, HEAD and OPTIONS and for
        the state-changing methods POST, PUT and DELETE. The GET handler
        echoes the value returned by ``get_csrf_token`` so tests can compare
        it with the cookie.
        """
        app = FastAPI()
        app.add_middleware(CSRFMiddleware)

        @app.get("/test")
        async def read_token(request: Request):
            return {"csrf_token": get_csrf_token(request)}

        # Explicit HEAD/OPTIONS routes so an allowed safe request ends in a
        # 200 from a handler; without them a 4xx from routing could not be
        # told apart from a CSRF rejection.
        @app.head("/test")
        async def head_endpoint(request: Request):
            return {"message": "success"}

        @app.options("/test")
        async def options_endpoint(request: Request):
            return {"message": "success"}

        @app.post("/test")
        async def post_endpoint(request: Request):
            return {"message": "success"}

        @app.put("/test")
        async def put_endpoint(request: Request):
            return {"message": "success"}

        @app.delete("/test")
        async def delete_endpoint(request: Request):
            return {"message": "success"}

        return app

    @staticmethod
    def _fetch_token(client):
        """GET ``/test`` through ``client`` and return the CSRF cookie value.

        The cookie is also stored in ``client``'s cookie jar, so subsequent
        requests from the same client send it automatically.
        """
        response = client.get("/test")
        assert response.status_code == 200
        return response.cookies["csrf_token"]

    def test_csrf_cookie_set_on_get_request(self, app_with_csrf):
        """Test that CSRF cookie is set on GET request."""
        client = TestClient(app_with_csrf)
        response = client.get("/test")

        assert response.status_code == 200
        assert "csrf_token" in response.cookies
        assert len(response.cookies["csrf_token"]) > 0

    def test_csrf_token_in_request_state(self, app_with_csrf):
        """Test that CSRF token is available in request state."""
        client = TestClient(app_with_csrf)
        response = client.get("/test")

        assert response.status_code == 200
        assert "csrf_token" in response.json()
        assert response.json()["csrf_token"] == response.cookies["csrf_token"]

    def test_post_without_csrf_token_fails(self, app_with_csrf):
        """Test that POST without CSRF token returns 403."""
        client = TestClient(app_with_csrf)
        response = client.post("/test")

        assert response.status_code == 403
        assert "CSRF" in response.json()["detail"]

    def test_post_with_cookie_but_without_header_fails(self, app_with_csrf):
        """Test that a CSRF cookie alone does not authorize a POST.

        The cookie is attached to requests automatically, so on its own it
        proves nothing; the header must be required even when the cookie is
        present.
        """
        client = TestClient(app_with_csrf)
        self._fetch_token(client)

        response = client.post("/test")

        assert response.status_code == 403

    def test_post_with_csrf_header_succeeds(self, app_with_csrf):
        """Test that POST with CSRF header succeeds."""
        client = TestClient(app_with_csrf)
        csrf_token = self._fetch_token(client)

        response = client.post("/test", headers={"X-CSRF-Token": csrf_token})

        assert response.status_code == 200
        assert response.json()["message"] == "success"

    def test_put_without_csrf_token_fails(self, app_with_csrf):
        """Test that PUT without CSRF token returns 403."""
        client = TestClient(app_with_csrf)
        response = client.put("/test")

        assert response.status_code == 403

    def test_put_with_csrf_header_succeeds(self, app_with_csrf):
        """Test that PUT with CSRF header succeeds."""
        client = TestClient(app_with_csrf)
        csrf_token = self._fetch_token(client)

        response = client.put("/test", headers={"X-CSRF-Token": csrf_token})

        assert response.status_code == 200

    def test_delete_without_csrf_token_fails(self, app_with_csrf):
        """Test that DELETE without CSRF token returns 403."""
        client = TestClient(app_with_csrf)
        response = client.delete("/test")

        assert response.status_code == 403

    def test_delete_with_csrf_header_succeeds(self, app_with_csrf):
        """Test that DELETE with CSRF header succeeds."""
        client = TestClient(app_with_csrf)
        csrf_token = self._fetch_token(client)

        response = client.delete("/test", headers={"X-CSRF-Token": csrf_token})

        assert response.status_code == 200

    def test_safe_methods_without_csrf_succeed(self, app_with_csrf):
        """Test that GET, HEAD, OPTIONS work without CSRF token."""
        for method in ("GET", "HEAD", "OPTIONS"):
            # Fresh client per method: no cookie and no header are sent.
            client = TestClient(app_with_csrf)
            response = client.request(method, "/test")

            assert response.status_code == 200, method

    def test_invalid_csrf_token_fails(self, app_with_csrf):
        """Test that invalid CSRF token returns 403.

        Checked twice: once with no CSRF cookie at all, and once with a valid
        cookie present but a header value that does not match it.
        """
        client = TestClient(app_with_csrf)
        bad_headers = {"X-CSRF-Token": "invalid-token"}

        # No cookie issued yet: a made-up header alone must be rejected.
        response = client.post("/test", headers=bad_headers)
        assert response.status_code == 403

        # Valid cookie now in the jar, header still wrong: must be rejected.
        self._fetch_token(client)
        response = client.post("/test", headers=bad_headers)
        assert response.status_code == 403

    def test_csrf_token_persists_across_requests(self, app_with_csrf):
        """Test that CSRF token persists across requests."""
        client = TestClient(app_with_csrf)

        token1 = client.get("/test").cookies["csrf_token"]
        # The second GET sends the first cookie back, so the same token is
        # expected rather than a freshly minted one.
        token2 = client.get("/test").cookies["csrf_token"]

        assert token1 == token2
|
|
|
|
|
|
class TestCSRFTokenGeneration:
    """Test CSRF token generation via ``CSRFMiddleware._get_or_create_token``."""

    class _CookielessRequest:
        """Minimal request stand-in that carries no cookies.

        Only ``cookies`` is provided; this assumes ``_get_or_create_token``
        reads nothing else from the request (TODO confirm if it changes).
        """

        def __init__(self):
            self.cookies = {}

    @pytest.fixture
    def middleware(self):
        """Build a CSRFMiddleware wrapping a bare FastAPI app."""
        return CSRFMiddleware(FastAPI())

    def test_token_has_sufficient_entropy(self, middleware):
        """Test that generated tokens have sufficient entropy."""
        token = middleware._get_or_create_token(self._CookielessRequest())

        # Token should be at least 32 characters (urlsafe base64 of 24 bytes)
        assert len(token) >= 32

    def test_token_is_unique(self, middleware):
        """Test that generated tokens are unique."""
        tokens = {
            middleware._get_or_create_token(self._CookielessRequest())
            for _ in range(10)
        }

        # All tokens should be unique
        assert len(tokens) == 10
|