211 lines
7.4 KiB
Python
211 lines
7.4 KiB
Python
"""
|
|
Test API endpoints
|
|
"""
|
|
|
|
import pytest
|
|
import requests
|
|
from unittest.mock import patch, MagicMock
|
|
|
|
def test_health_check(client):
    """The health endpoint reports ``"healthy"`` when the upstream GET answers 200."""
    upstream_ok = MagicMock(status_code=200)
    with patch("requests.get", return_value=upstream_ok):
        resp = client.get("/api/v1/health")

    assert resp.status_code == 200
    body = resp.json()
    assert "status" in body
    assert body["status"] == "healthy"


def test_ready_endpoint(client):
    """The readiness probe answers 200 with exactly ``{"status": "ready"}``."""
    with patch("requests.get") as fake_get:
        fake_get.return_value = MagicMock(status_code=200)
        resp = client.get("/api/v1/ready")

    expected_body = {"status": "ready"}
    assert resp.status_code == 200
    assert resp.json() == expected_body


def test_get_models(client, mock_models_response):
    """Listing models relays the upstream catalogue together with a ``total`` count."""
    upstream = MagicMock(status_code=200)
    upstream.json.return_value = mock_models_response
    with patch("requests.get", return_value=upstream):
        resp = client.get("/api/v1/models")

    assert resp.status_code == 200
    payload = resp.json()
    for key in ("models", "total"):
        assert key in payload
    assert payload["total"] == 2
    listed = payload["models"]
    assert len(listed) == 2
    assert listed[0]["name"] == "llama2"


def test_get_models_with_host_override(client, mock_models_response):
    """A ``host`` query parameter redirects the upstream tags request to that host."""
    override_host = "http://example-host:11434"
    upstream = MagicMock(status_code=200)
    upstream.json.return_value = mock_models_response

    with patch("requests.get", return_value=upstream) as fake_get:
        resp = client.get("/api/v1/models", params={"host": override_host})

    assert resp.status_code == 200
    # The first positional argument of the upstream GET is the target URL.
    requested_url = fake_get.call_args.args[0]
    assert requested_url == "http://example-host:11434/api/tags"


def test_health_with_invalid_host_returns_422(client):
    """Reject a malformed ``host`` query parameter on the health endpoint.

    ``requests.get`` (the upstream call the health endpoint makes, see
    ``test_health_check``) is patched so a validation regression can never
    reach a real host. The test also asserts it was not called, which pins
    that the bad host is rejected before any upstream request is attempted.
    """
    with patch("requests.get") as mock_get:
        response = client.get("/api/v1/health", params={"host": "not-a-url"})

    assert response.status_code == 422
    mock_get.assert_not_called()


def test_model_show_with_invalid_host_returns_422(client):
    """Reject a ``host`` value that lacks a URL scheme on the show endpoint.

    ``requests.post`` (the upstream call the show endpoint makes, see
    ``test_get_model_show``) is patched so a validation regression cannot
    reach the network. The test asserts it was never invoked, i.e. the
    invalid host is rejected before any upstream request.
    """
    with patch("requests.post") as mock_post:
        response = client.get(
            "/api/v1/models/llama2/show", params={"host": "localhost:11434"}
        )

    assert response.status_code == 422
    mock_post.assert_not_called()


def test_get_running_models(client):
    """The running-models endpoint (``ollama ps``) relays loaded models and a count."""
    loaded_models = {
        "models": [
            {
                "name": "llama3.2:3b",
                "size_vram": 2147483648,  # 2 GiB
                "expires_at": "2026-04-24T10:30:00Z",
            }
        ]
    }

    with patch("requests.get") as fake_get:
        fake_get.return_value = MagicMock(
            status_code=200, **{"json.return_value": loaded_models}
        )
        resp = client.get("/api/v1/models/running")

    assert resp.status_code == 200
    data = resp.json()
    assert "models" in data
    assert data["total"] == 1
    assert data["models"][0]["name"] == "llama3.2:3b"


def test_get_running_models_ollama_offline(client):
    """Simulate Ollama being offline for ``/models/running`` and expect HTTP 500.

    NOTE(review): the failure is simulated with a bare ``Exception``, whereas
    ``test_get_models_ollama_offline`` uses
    ``requests.exceptions.ConnectionError`` and expects 502. Confirm which
    status this endpoint should return for a real connection error.
    """
    upstream_failure = Exception("Connection refused")
    with patch("requests.get", side_effect=upstream_failure):
        resp = client.get("/api/v1/models/running")

    assert resp.status_code == 500


def test_get_models_ollama_offline(client):
    """A refused connection to Ollama is reported as 502 on ``/models``."""
    refused = requests.exceptions.ConnectionError("Connection refused")
    with patch("requests.get", side_effect=refused):
        resp = client.get("/api/v1/models")

    assert resp.status_code == 502


def test_get_models_returns_502_when_upstream_is_unavailable(client):
    """A non-200 upstream reply (here 503) maps to 502 rather than a generic 500."""
    unavailable = MagicMock(status_code=503)
    with patch("requests.get", return_value=unavailable):
        resp = client.get("/api/v1/models")

    assert resp.status_code == 502


def test_get_specific_model(client, mock_models_response):
    """Fetching one model by name returns that model's catalogue entry."""
    catalogue = MagicMock(status_code=200)
    catalogue.json.return_value = mock_models_response

    with patch("requests.get") as fake_get:
        fake_get.return_value = catalogue
        resp = client.get("/api/v1/models/llama2")

    assert resp.status_code == 200
    assert resp.json()["name"] == "llama2"


def test_get_nonexistent_model(client, mock_models_response):
    """A model name absent from the upstream catalogue yields 404."""
    catalogue = MagicMock(
        status_code=200, **{"json.return_value": mock_models_response}
    )
    with patch("requests.get", return_value=catalogue):
        resp = client.get("/api/v1/models/nonexistent")

    assert resp.status_code == 404


def test_get_model_show(client):
    """The show endpoint relays model details obtained from an upstream POST."""
    show_payload = {
        "details": {"family": "llama", "parameter_size": "8B"},
        "model_info": {"general.architecture": "llama"},
    }
    upstream = MagicMock(status_code=200)
    upstream.json.return_value = show_payload

    with patch("requests.post", return_value=upstream):
        resp = client.get("/api/v1/models/llama2/show")

    assert resp.status_code == 200
    data = resp.json()
    assert "details" in data
    assert data["details"]["family"] == "llama"


def test_get_model_show_not_found(client):
    """An upstream 404 from the show call is propagated to the client as 404."""
    with patch("requests.post") as fake_post:
        fake_post.return_value = MagicMock(status_code=404)
        resp = client.get("/api/v1/models/nonexistent/show")

    assert resp.status_code == 404


def test_root_endpoint(client):
    """Check that the root URL either serves a page directly or redirects.

    Redirects are not followed, so the raw status is visible. A 200 is
    accepted as-is. A 307 redirect (presumably to the dashboard; the target
    is not pinned here) must carry a ``location`` header, otherwise the
    redirect would lead nowhere.
    """
    response = client.get("/", follow_redirects=False)

    assert response.status_code in (200, 307)
    if response.status_code == 307:
        assert response.headers.get("location")


def test_openapi_schema(client):
    """The OpenAPI document is served and advertises the read-only routes only."""
    resp = client.get("/openapi.json")
    assert resp.status_code == 200

    schema = resp.json()
    for section in ("info", "paths"):
        assert section in schema

    paths = schema["paths"]
    expected_routes = (
        "/api/v1/health",
        "/api/v1/models",
        "/api/v1/models/running",
        "/api/v1/models/{model_name}/show",
    )
    for route in expected_routes:
        assert route in paths

    # Write routes are disabled by default, so pull must not be advertised.
    assert "/api/v1/models/{model_name}/pull" not in paths


def test_write_endpoints_disabled_by_default(client):
    """POST/DELETE on models must be unavailable by default.

    Upstream ``requests.post`` / ``requests.delete`` are patched and asserted
    unused, so neither request may reach Ollama even if a route slips through.
    """
    with patch("requests.post") as mock_post, patch("requests.delete") as mock_delete:
        response_pull = client.post("/api/v1/models/llama2/pull")
        response_delete = client.delete("/api/v1/models/llama2")

    assert response_pull.status_code == 404

    # GET /api/v1/models/{model_name} exists (see test_get_specific_model).
    # Assuming a Starlette/FastAPI router (suggested by /openapi.json and 422
    # validation errors), a path that matches a route registered only for
    # other methods is answered with 405 rather than 404. Either status means
    # the delete endpoint is not exposed.
    assert response_delete.status_code in (404, 405)

    mock_post.assert_not_called()
    mock_delete.assert_not_called()