feat: add support for local LLM providers (Ollama & LM Studio)
Implement local LLM inference support for Ollama and LM Studio.

New clients:
- OllamaClient: interface to the Ollama API (default: localhost:11434)
- LMStudioClient: interface to the LM Studio API (default: localhost:1234)

Factory updates:
- Added OLLAMA and LMSTUDIO to the LLMProvider enum
- Updated create_client() to instantiate the local clients
- Updated list_available_providers() with an is_local flag

Configuration:
- Added ollama_base_url and lmstudio_base_url settings
- Local providers return "configured" for the API-key check

Tests:
- Comprehensive test suite (250+ lines)
- Tests for client initialization and invocation
- Factory integration tests

Documentation:
- Added an "LLM Providers" section to SKILL.md
- Documented setup for Ollama and LM Studio
- Added usage examples and a configuration guide

Usage:
- provider: ollama, model: llama3.2
- provider: lmstudio, model: local-model
This commit is contained in:
248
tests/unit/test_agentic_rag/test_core/test_local_providers.py
Normal file
248
tests/unit/test_agentic_rag/test_core/test_local_providers.py
Normal file
@@ -0,0 +1,248 @@
|
||||
"""Tests for local LLM providers (Ollama and LMStudio)."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
|
||||
class TestOllamaClient:
    """Unit tests for the Ollama LLM client."""

    @pytest.fixture
    def ollama_client(self):
        """Build a default client pointed at a local Ollama daemon."""
        from agentic_rag.core.llm_factory import OllamaClient

        return OllamaClient(api_key="", model="llama3.2", base_url="http://localhost:11434")

    @pytest.mark.asyncio
    async def test_ollama_client_initialization(self, ollama_client):
        """The fixture client carries the model, base URL, and blank key."""
        expected = {
            "model": "llama3.2",
            "base_url": "http://localhost:11434",
            "api_key": "",
        }
        for attr, value in expected.items():
            assert getattr(ollama_client, attr) == value

    @pytest.mark.asyncio
    async def test_ollama_client_custom_base_url(self):
        """A non-default host/port must be stored verbatim."""
        from agentic_rag.core.llm_factory import OllamaClient

        remote = OllamaClient(api_key="", model="mistral", base_url="http://192.168.1.100:11434")

        assert remote.base_url == "http://192.168.1.100:11434"
        assert remote.model == "mistral"

    @pytest.mark.asyncio
    async def test_ollama_invoke_success(self, ollama_client):
        """invoke() surfaces the text, model name, and token usage."""
        fake_response = MagicMock()
        fake_response.json.return_value = {
            "choices": [{"message": {"content": "Test response from Ollama"}}],
            "usage": {"prompt_tokens": 10, "completion_tokens": 20},
        }
        fake_response.raise_for_status = MagicMock()

        # Replace the underlying HTTP client so no network traffic occurs.
        ollama_client.client = MagicMock()
        ollama_client.client.post = AsyncMock(return_value=fake_response)

        result = await ollama_client.invoke("Hello, how are you?")

        assert result.text == "Test response from Ollama"
        assert result.model == "llama3.2"
        assert result.usage["prompt_tokens"] == 10

    @pytest.mark.asyncio
    async def test_ollama_invoke_with_kwargs(self, ollama_client):
        """Extra keyword arguments must be forwarded in the request body."""
        fake_response = MagicMock()
        fake_response.json.return_value = {
            "choices": [{"message": {"content": "Response"}}],
            "usage": {},
        }
        fake_response.raise_for_status = MagicMock()

        ollama_client.client = MagicMock()
        ollama_client.client.post = AsyncMock(return_value=fake_response)

        await ollama_client.invoke("Hello", temperature=0.7, max_tokens=100)

        # Inspect the recorded call: positional arg 0 is the endpoint path,
        # and the JSON body must echo the forwarded sampling options.
        recorded = ollama_client.client.post.call_args
        assert recorded.args[0] == "/v1/chat/completions"
        body = recorded.kwargs["json"]
        assert body["temperature"] == 0.7
        assert body["max_tokens"] == 100
|
||||
|
||||
class TestLMStudioClient:
    """Unit tests for the LM Studio LLM client."""

    @pytest.fixture
    def lmstudio_client(self):
        """Build a default client pointed at a local LM Studio server."""
        from agentic_rag.core.llm_factory import LMStudioClient

        return LMStudioClient(api_key="", model="local-model", base_url="http://localhost:1234")

    @pytest.mark.asyncio
    async def test_lmstudio_client_initialization(self, lmstudio_client):
        """The fixture client carries the model, base URL, and blank key."""
        expected = {
            "model": "local-model",
            "base_url": "http://localhost:1234",
            "api_key": "",
        }
        for attr, value in expected.items():
            assert getattr(lmstudio_client, attr) == value

    @pytest.mark.asyncio
    async def test_lmstudio_client_custom_base_url(self):
        """A non-default host/port must be stored verbatim."""
        from agentic_rag.core.llm_factory import LMStudioClient

        remote = LMStudioClient(
            api_key="", model="custom-model", base_url="http://192.168.1.50:1234"
        )
        assert remote.base_url == "http://192.168.1.50:1234"
        assert remote.model == "custom-model"

    @pytest.mark.asyncio
    async def test_lmstudio_invoke_success(self, lmstudio_client):
        """invoke() surfaces the text, model name, and token usage."""
        fake_response = MagicMock()
        fake_response.json.return_value = {
            "choices": [{"message": {"content": "Test response from LM Studio"}}],
            "usage": {"prompt_tokens": 15, "completion_tokens": 25},
        }
        fake_response.raise_for_status = MagicMock()

        # Replace the underlying HTTP client so no network traffic occurs.
        lmstudio_client.client = MagicMock()
        lmstudio_client.client.post = AsyncMock(return_value=fake_response)

        result = await lmstudio_client.invoke("What is AI?")

        assert result.text == "Test response from LM Studio"
        assert result.model == "local-model"
        assert result.usage["prompt_tokens"] == 15
||||
|
||||
|
||||
class TestLLMClientFactoryLocalProviders:
    """Factory integration tests for the local (Ollama / LM Studio) providers."""

    def test_ollama_provider_creation(self):
        """create_client(OLLAMA) must return a configured OllamaClient."""
        from agentic_rag.core.llm_factory import (
            LLMClientFactory,
            LLMProvider,
            OllamaClient,
        )

        client = LLMClientFactory.create_client(
            provider=LLMProvider.OLLAMA,
            api_key="",
            model="llama3.2",
            base_url="http://localhost:11434",
        )

        assert isinstance(client, OllamaClient)
        assert client.model == "llama3.2"

    def test_lmstudio_provider_creation(self):
        """create_client(LMSTUDIO) must return a configured LMStudioClient."""
        from agentic_rag.core.llm_factory import (
            LLMClientFactory,
            LLMProvider,
            LMStudioClient,
        )

        client = LLMClientFactory.create_client(
            provider=LLMProvider.LMSTUDIO,
            api_key="",
            model="qwen2.5",
            base_url="http://localhost:1234",
        )

        assert isinstance(client, LMStudioClient)
        assert client.model == "qwen2.5"

    def test_list_providers_includes_local(self):
        """Both local providers are listed and flagged as local."""
        from agentic_rag.core.llm_factory import LLMClientFactory

        providers = LLMClientFactory.list_available_providers()
        provider_ids = [p["id"] for p in providers]

        assert "ollama" in provider_ids
        assert "lmstudio" in provider_ids

        # Check they are marked as local and ship an install hint.
        ollama_info = next(p for p in providers if p["id"] == "ollama")
        lmstudio_info = next(p for p in providers if p["id"] == "lmstudio")

        # Identity comparison (`is True`) avoids the E712 truthiness pitfall
        # and pins the flag to a real boolean, not a truthy value.
        assert ollama_info["is_local"] is True
        assert lmstudio_info["is_local"] is True
        assert "download" in ollama_info["install_command"]
        assert "download" in lmstudio_info["install_command"]

    def test_default_models_include_local(self):
        """Default model mapping covers both local providers."""
        from agentic_rag.core.llm_factory import LLMClientFactory

        defaults = LLMClientFactory.get_default_models()

        assert defaults["ollama"] == "llama3.2"
        assert defaults["lmstudio"] == "local-model"
||||
|
||||
|
||||
class TestConfigLocalProviders:
    """Settings-level tests for local-provider configuration."""

    def test_ollama_base_url_config(self):
        """ollama_base_url defaults to localhost and accepts overrides."""
        from agentic_rag.core.config import Settings

        settings = Settings()
        assert settings.ollama_base_url == "http://localhost:11434"

        # Test custom URL
        settings_custom = Settings(ollama_base_url="http://192.168.1.100:11434")
        assert settings_custom.ollama_base_url == "http://192.168.1.100:11434"

    def test_lmstudio_base_url_config(self):
        """lmstudio_base_url defaults to the LM Studio localhost port."""
        from agentic_rag.core.config import Settings

        settings = Settings()
        assert settings.lmstudio_base_url == "http://localhost:1234"

    def test_get_api_key_for_local_providers(self):
        """Local providers need no key and must report as configured."""
        from agentic_rag.core.config import Settings

        settings = Settings()

        # Local providers should return "configured" instead of empty string
        assert settings.get_api_key_for_provider("ollama") == "configured"
        assert settings.get_api_key_for_provider("lmstudio") == "configured"

        # Identity comparison (`is True`) avoids the E712 truthiness pitfall
        # and pins the result to a real boolean, not a truthy value.
        assert settings.is_provider_configured("ollama") is True
        assert settings.is_provider_configured("lmstudio") is True
||||
|
||||
|
||||
class TestIntegrationLocalProviders:
    """End-to-end checks that get_llm_client resolves local providers."""

    @pytest.mark.asyncio
    async def test_get_llm_client_ollama(self):
        """get_llm_client('ollama') should hand back an OllamaClient."""
        from agentic_rag.core.llm_factory import OllamaClient, get_llm_client

        # Start from an empty module-level cache so a fresh client is built.
        with patch("agentic_rag.core.llm_factory._client_cache", {}):
            client = await get_llm_client(provider="ollama")

        assert isinstance(client, OllamaClient)

    @pytest.mark.asyncio
    async def test_get_llm_client_lmstudio(self):
        """get_llm_client('lmstudio') should hand back an LMStudioClient."""
        from agentic_rag.core.llm_factory import LMStudioClient, get_llm_client

        # Start from an empty module-level cache so a fresh client is built.
        with patch("agentic_rag.core.llm_factory._client_cache", {}):
            client = await get_llm_client(provider="lmstudio")

        assert isinstance(client, LMStudioClient)
Reference in New Issue
Block a user