feat: add support for local LLM providers (Ollama & LM Studio)
Implement local LLM inference support for Ollama and LM Studio.

New Clients:
- OllamaClient: Interface to the Ollama API (default: localhost:11434)
- LMStudioClient: Interface to the LM Studio API (default: localhost:1234)

Factory Updates:
- Added OLLAMA and LMSTUDIO to the LLMProvider enum
- Updated create_client() to instantiate local clients
- Updated list_available_providers() with an is_local flag

Configuration:
- Added ollama_base_url and lmstudio_base_url settings
- Local providers return "configured" for the API key check

Tests:
- Comprehensive test suite (250+ lines)
- Tests for client initialization and invocation
- Factory integration tests

Documentation:
- Added an LLM Providers section to SKILL.md
- Documented setup for Ollama and LM Studio
- Added usage examples and a configuration guide

Usage:
- provider: ollama, model: llama3.2
- provider: lmstudio, model: local-model
SKILL.md
@@ -602,7 +602,113 @@ curl http://localhost:8000/api/v1/notebooks -H "X-API-Key: your-key"

---

**Skill Version:** 1.2.0
## LLM Providers

DocuMente supports multiple LLM providers, including local ones via **Ollama** and **LM Studio**.

### Cloud Providers

| Provider | API Key Required | Default Model |
|----------|------------------|---------------|
| **OpenAI** | ✅ `OPENAI_API_KEY` | gpt-4o-mini |
| **Anthropic** | ✅ `ANTHROPIC_API_KEY` | claude-3-sonnet |
| **Google** | ✅ `GOOGLE_API_KEY` | gemini-pro |
| **Mistral** | ✅ `MISTRAL_API_KEY` | mistral-medium |
| **Azure** | ✅ `AZURE_API_KEY` | gpt-4 |
| **OpenRouter** | ✅ `OPENROUTER_API_KEY` | openai/gpt-4o-mini |
| **Z.AI** | ✅ `ZAI_API_KEY` | zai-large |
| **OpenCode Zen** | ✅ `OPENCODE_ZEN_API_KEY` | zen-1 |

### Local Providers

| Provider | Default URL | Configuration |
|----------|-------------|---------------|
| **Ollama** | http://localhost:11434 | `OLLAMA_BASE_URL` |
| **LM Studio** | http://localhost:1234 | `LMSTUDIO_BASE_URL` |

#### Ollama Setup

```bash
# 1. Install Ollama
# macOS/Linux
curl -fsSL https://ollama.com/install.sh | sh

# 2. Pull a model
ollama pull llama3.2
ollama pull mistral
ollama pull qwen2.5

# 3. Start Ollama (in a separate terminal)
ollama serve

# 4. Verify it is running
curl http://localhost:11434/api/tags
```

**Usage with DocuMente:**

```bash
# Query with Ollama
curl -X POST http://localhost:8000/api/v1/query \
  -H "Content-Type: application/json" \
  -d '{
    "question": "Explain artificial intelligence",
    "provider": "ollama",
    "model": "llama3.2"
  }'
```

#### LM Studio Setup

```bash
# 1. Download LM Studio from https://lmstudio.ai/

# 2. Start LM Studio and load a model

# 3. Enable the local server (Settings > Local Server)
# Default URL: http://localhost:1234

# 4. Verify it is running
curl http://localhost:1234/v1/models
```

**Usage with DocuMente:**

```bash
# Query with LM Studio
curl -X POST http://localhost:8000/api/v1/query \
  -H "Content-Type: application/json" \
  -d '{
    "question": "What are notebooks?",
    "provider": "lmstudio",
    "model": "local-model"
  }'
```

#### Custom URL Configuration

To use Ollama/LM Studio on another machine on the network:

```env
# .env
OLLAMA_BASE_URL=http://192.168.1.100:11434
LMSTUDIO_BASE_URL=http://192.168.1.50:1234
```

#### Advantages of Local Providers

- 🔒 **Privacy**: Data never leaves your machine/network
- 💰 **Free**: No per-call API costs
- ⚡ **Offline**: Works without an internet connection
- 🔧 **Control**: You choose which models to run

#### Limitations

- Require adequate hardware (RAM; GPU recommended)
- Local models are generally less capable than GPT-4/Claude
- Longer response times on consumer hardware

---

**Skill Version:** 1.3.0
**API Version:** v1
**Last Updated:** 2026-04-06
@@ -667,3 +773,49 @@ curl http://localhost:8000/api/v1/notebooks -H "X-API-Key: your-key"
- ✅ Created docs/integration.md with full guide
- ✅ Updated SKILL.md with new capabilities
- ✅ API examples and best practices

---

## Changelog Sprint 3

### 2026-04-06 - Local LLM Providers (Ollama & LM Studio)

**Implemented:**

- ✅ `OllamaClient` - Support for Ollama local inference
- ✅ `LMStudioClient` - Support for LM Studio local inference
- ✅ Added `ollama` and `lmstudio` to `LLMProvider` enum
- ✅ Updated `LLMClientFactory` to create local provider clients
- ✅ Added configuration options `OLLAMA_BASE_URL` and `LMSTUDIO_BASE_URL`
- ✅ Local providers marked with `is_local: true` in provider list

**Features:**

- OpenAI-compatible API endpoints (/v1/chat/completions)
- Configurable base URLs for network deployments
- Longer timeouts (120s) for local inference
- No API key required for local providers
- Support for all Ollama models (llama3.2, mistral, qwen, etc.)
- Support for any model loaded in LM Studio
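Since both providers sit behind the same OpenAI-compatible endpoint, the request body the local clients send can be illustrated with a plain dict. The `build_chat_payload` helper below is hypothetical (it is not part of the codebase); it simply mirrors the JSON that `OllamaClient`/`LMStudioClient` POST to `/v1/chat/completions`:

```python
def build_chat_payload(model: str, prompt: str, **kwargs) -> dict:
    """Mirror the JSON body the local clients send to /v1/chat/completions.

    Extra kwargs (temperature, max_tokens, ...) are merged into the payload,
    matching how the clients forward **kwargs.
    """
    return {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "stream": False,
        **kwargs,
    }


payload = build_chat_payload("llama3.2", "Explain AI", temperature=0.7)
print(payload["messages"][0]["content"])  # Explain AI
print(payload["temperature"])             # 0.7
```

Because `**kwargs` is spread last, a caller-supplied `"stream": True` would override the default, which is worth keeping in mind if streaming is ever added.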

**Configuration:**

```env
# Optional: Custom URLs
OLLAMA_BASE_URL=http://localhost:11434
LMSTUDIO_BASE_URL=http://localhost:1234
```

**Usage:**

```bash
curl -X POST http://localhost:8000/api/v1/query \
  -H "Content-Type: application/json" \
  -d '{
    "question": "Explain AI",
    "provider": "ollama",
    "model": "llama3.2"
  }'
```

**Tests:**

- ✅ 250+ lines of tests for local providers
- ✅ Unit tests for OllamaClient and LMStudioClient
- ✅ Integration tests for factory creation
- ✅ Configuration tests
@@ -51,6 +51,10 @@ class Settings(BaseSettings):
    azure_endpoint: str = ""  # Azure OpenAI endpoint
    azure_api_version: str = "2024-02-01"

    # Local LLM providers (Ollama, LMStudio)
    ollama_base_url: str = "http://localhost:11434"
    lmstudio_base_url: str = "http://localhost:1234"

    # Embedding Configuration
    embedding_provider: str = "openai"
    embedding_model: str = "text-embedding-3-small"
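Pydantic `BaseSettings` fields like those above are overridable via environment variables with the same (uppercased) name. A minimal stand-in without pydantic, just to illustrate the override behavior for the two new fields (the function name is hypothetical):

```python
import os


def load_local_provider_settings() -> dict:
    # Each field falls back to its default unless the uppercased
    # environment variable is set, mimicking BaseSettings behavior.
    return {
        "ollama_base_url": os.environ.get("OLLAMA_BASE_URL", "http://localhost:11434"),
        "lmstudio_base_url": os.environ.get("LMSTUDIO_BASE_URL", "http://localhost:1234"),
    }


# Defaults apply when the variables are unset
os.environ.pop("OLLAMA_BASE_URL", None)
print(load_local_provider_settings()["ollama_base_url"])  # http://localhost:11434

# An override in the environment wins
os.environ["OLLAMA_BASE_URL"] = "http://192.168.1.100:11434"
print(load_local_provider_settings()["ollama_base_url"])  # http://192.168.1.100:11434
```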
@@ -67,10 +71,10 @@ class Settings(BaseSettings):
        """Get the API key for a specific provider.

        Args:
            provider: Provider name (e.g., 'openai', 'zai', 'openrouter', 'ollama', 'lmstudio')

        Returns:
            API key for the provider, or the sentinel "configured" for local providers
        """
        key_mapping = {
            "openai": self.openai_api_key,
@@ -83,6 +87,8 @@ class Settings(BaseSettings):
            "google": self.google_api_key,
            "mistral": self.mistral_api_key,
            "azure": self.azure_api_key,
            "ollama": "configured",  # Local provider, no API key needed
            "lmstudio": "configured",  # Local provider, no API key needed
        }

        return key_mapping.get(provider.lower(), "")
@@ -45,6 +45,8 @@ class LLMProvider(str, Enum):
    GOOGLE = "google"
    MISTRAL = "mistral"
    AZURE = "azure"
    OLLAMA = "ollama"
    LMSTUDIO = "lmstudio"


class BaseLLMClient(ABC):
@@ -162,6 +164,106 @@ class OpenRouterClient(BaseLLMClient):
        )()


class OllamaClient(BaseLLMClient):
    """Ollama client for local LLM inference.

    Ollama runs locally and provides an OpenAI-compatible API.
    Default URL: http://localhost:11434
    """

    def __init__(
        self,
        api_key: str = "",
        model: str = "llama3.2",
        base_url: str = "http://localhost:11434",
        **kwargs,
    ):
        super().__init__(api_key, model, **kwargs)
        self.base_url = base_url.rstrip("/")
        import httpx

        self.client = httpx.AsyncClient(
            base_url=self.base_url,
            headers={"Content-Type": "application/json"},
            timeout=120.0,  # Longer timeout for local inference
        )

    async def invoke(self, prompt: str, **kwargs) -> Any:
        """Call the Ollama API."""
        # Ollama exposes OpenAI-compatible endpoints
        response = await self.client.post(
            "/v1/chat/completions",
            json={
                "model": self.model,
                "messages": [{"role": "user", "content": prompt}],
                "stream": False,
                **kwargs,
            },
        )
        response.raise_for_status()
        data = response.json()

        return type(
            "Response",
            (),
            {
                "text": data["choices"][0]["message"]["content"],
                "model": self.model,
                "usage": data.get("usage", {}),
            },
        )()
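Both local clients wrap the raw JSON in a lightweight object built with `type(...)` rather than a dataclass, so callers can use attribute access (`result.text`) on the parsed payload. A standalone sketch of that pattern, with `data` standing in for a sample `/v1/chat/completions` response body:

```python
# `data` stands in for the JSON a local model server would return.
data = {
    "choices": [{"message": {"content": "Hello from a local model"}}],
    "usage": {"prompt_tokens": 5, "completion_tokens": 7},
}

# type(name, bases, namespace) builds a throwaway class whose class
# attributes hold the parsed fields; the trailing () instantiates it.
response = type(
    "Response",
    (),
    {
        "text": data["choices"][0]["message"]["content"],
        "model": "llama3.2",
        "usage": data.get("usage", {}),
    },
)()

print(response.text)   # Hello from a local model
print(response.usage)  # {'prompt_tokens': 5, 'completion_tokens': 7}
```

The upside is that no response class needs to be declared; the trade-off is that the object carries no type hints and each call site must know which attributes exist.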
class LMStudioClient(BaseLLMClient):
    """LM Studio client for local LLM inference.

    LM Studio runs locally and provides an OpenAI-compatible API.
    Default URL: http://localhost:1234
    """

    def __init__(
        self,
        api_key: str = "",
        model: str = "local-model",
        base_url: str = "http://localhost:1234",
        **kwargs,
    ):
        super().__init__(api_key, model, **kwargs)
        self.base_url = base_url.rstrip("/")
        import httpx

        self.client = httpx.AsyncClient(
            base_url=self.base_url,
            headers={"Content-Type": "application/json"},
            timeout=120.0,  # Longer timeout for local inference
        )

    async def invoke(self, prompt: str, **kwargs) -> Any:
        """Call the LM Studio API."""
        # LM Studio exposes OpenAI-compatible endpoints
        response = await self.client.post(
            "/v1/chat/completions",
            json={
                "model": self.model,
                "messages": [{"role": "user", "content": prompt}],
                "stream": False,
                **kwargs,
            },
        )
        response.raise_for_status()
        data = response.json()

        return type(
            "Response",
            (),
            {
                "text": data["choices"][0]["message"]["content"],
                "model": self.model,
                "usage": data.get("usage", {}),
            },
        )()


class LLMClientFactory:
    """Factory for creating LLM clients based on provider."""
@@ -224,6 +326,22 @@ class LLMClientFactory:
        elif provider == LLMProvider.OPENROUTER:
            return OpenRouterClient(api_key=api_key, model=model or "openai/gpt-4o-mini", **kwargs)

        elif provider == LLMProvider.OLLAMA:
            return OllamaClient(
                api_key=api_key,
                model=model or "llama3.2",
                # pop, not get: base_url must not also be forwarded via **kwargs
                base_url=kwargs.pop("base_url", "http://localhost:11434"),
                **kwargs,
            )

        elif provider == LLMProvider.LMSTUDIO:
            return LMStudioClient(
                api_key=api_key,
                model=model or "local-model",
                # pop, not get: base_url must not also be forwarded via **kwargs
                base_url=kwargs.pop("base_url", "http://localhost:1234"),
                **kwargs,
            )

        else:
            raise ValueError(f"Unknown provider: {provider}")
@@ -235,6 +353,7 @@ class LLMClientFactory:
        for provider in LLMProvider:
            is_available = True
            install_command = None
            is_local = False

            if provider == LLMProvider.OPENAI:
                is_available = OpenAIClient is not None
@@ -251,6 +370,12 @@ class LLMClientFactory:
            elif provider == LLMProvider.AZURE:
                is_available = AzureOpenAIClient is not None
                install_command = "pip install datapizza-ai-clients-azure"
            elif provider == LLMProvider.OLLAMA:
                is_local = True
                install_command = "https://ollama.com/download"
            elif provider == LLMProvider.LMSTUDIO:
                is_local = True
                install_command = "https://lmstudio.ai/download"

            providers.append(
                {
@@ -258,6 +383,7 @@ class LLMClientFactory:
                    "name": provider.name.replace("_", " ").title(),
                    "available": is_available,
                    "install_command": install_command,
                    "is_local": is_local,
                }
            )
@@ -275,10 +401,13 @@ class LLMClientFactory:
            LLMProvider.GOOGLE.value: "gemini-pro",
            LLMProvider.MISTRAL.value: "mistral-medium",
            LLMProvider.AZURE.value: "gpt-4",
            LLMProvider.OLLAMA.value: "llama3.2",
            LLMProvider.LMSTUDIO.value: "local-model",
        }


# Global client cache
_client_cache: dict[str, BaseLLMClient] = {}
@@ -311,8 +440,17 @@ async def get_llm_client(
    if not api_key:
        api_key = settings.get_api_key_for_provider(provider)

    # Prepare extra kwargs for local providers
    extra_kwargs = {}
    if provider == LLMProvider.OLLAMA.value:
        extra_kwargs["base_url"] = settings.ollama_base_url
    elif provider == LLMProvider.LMSTUDIO.value:
        extra_kwargs["base_url"] = settings.lmstudio_base_url

    # Create client
    client = LLMClientFactory.create_client(
        provider=LLMProvider(provider), api_key=api_key, **extra_kwargs
    )

    # Cache client
    _client_cache[cache_key] = client
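`get_llm_client` memoizes clients in the module-level `_client_cache`. The diff does not show how `cache_key` is built, so the sketch below assumes a `provider:model` key (the names `get_cached` and `fake_factory` are hypothetical); it only illustrates the caching behavior, where distinct provider/model pairs get distinct client instances:

```python
_cache: dict[str, object] = {}


def get_cached(provider: str, model: str, factory) -> object:
    # Assumed key format: provider and model combined, so e.g.
    # ollama/llama3.2 and ollama/mistral get separate client instances.
    cache_key = f"{provider}:{model}"
    if cache_key not in _cache:
        _cache[cache_key] = factory(provider, model)
    return _cache[cache_key]


made = []


def fake_factory(provider, model):
    # Stand-in for LLMClientFactory.create_client; records each creation.
    made.append((provider, model))
    return object()


a = get_cached("ollama", "llama3.2", fake_factory)
b = get_cached("ollama", "llama3.2", fake_factory)  # cache hit, no new client
c = get_cached("ollama", "mistral", fake_factory)   # different model, new client
print(a is b)     # True
print(len(made))  # 2
```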
tests/unit/test_agentic_rag/test_core/test_local_providers.py (new file)
@@ -0,0 +1,248 @@
"""Tests for local LLM providers (Ollama and LMStudio)."""

import pytest
from unittest.mock import AsyncMock, MagicMock, patch


class TestOllamaClient:
    """Test suite for the Ollama client."""

    @pytest.fixture
    def ollama_client(self):
        """Create an Ollama client instance."""
        from agentic_rag.core.llm_factory import OllamaClient

        return OllamaClient(api_key="", model="llama3.2", base_url="http://localhost:11434")

    @pytest.mark.asyncio
    async def test_ollama_client_initialization(self, ollama_client):
        """Test Ollama client initialization."""
        assert ollama_client.model == "llama3.2"
        assert ollama_client.base_url == "http://localhost:11434"
        assert ollama_client.api_key == ""

    @pytest.mark.asyncio
    async def test_ollama_client_custom_base_url(self):
        """Test the Ollama client with a custom base URL."""
        from agentic_rag.core.llm_factory import OllamaClient

        client = OllamaClient(api_key="", model="mistral", base_url="http://192.168.1.100:11434")
        assert client.base_url == "http://192.168.1.100:11434"
        assert client.model == "mistral"

    @pytest.mark.asyncio
    async def test_ollama_invoke_success(self, ollama_client):
        """Test a successful Ollama API call."""
        mock_response = MagicMock()
        mock_response.json.return_value = {
            "choices": [{"message": {"content": "Test response from Ollama"}}],
            "usage": {"prompt_tokens": 10, "completion_tokens": 20},
        }
        mock_response.raise_for_status = MagicMock()

        ollama_client.client = MagicMock()
        ollama_client.client.post = AsyncMock(return_value=mock_response)

        result = await ollama_client.invoke("Hello, how are you?")

        assert result.text == "Test response from Ollama"
        assert result.model == "llama3.2"
        assert result.usage["prompt_tokens"] == 10

    @pytest.mark.asyncio
    async def test_ollama_invoke_with_kwargs(self, ollama_client):
        """Test an Ollama API call with additional kwargs."""
        mock_response = MagicMock()
        mock_response.json.return_value = {
            "choices": [{"message": {"content": "Response"}}],
            "usage": {},
        }
        mock_response.raise_for_status = MagicMock()

        ollama_client.client = MagicMock()
        ollama_client.client.post = AsyncMock(return_value=mock_response)

        await ollama_client.invoke("Hello", temperature=0.7, max_tokens=100)

        # Verify the call was made with the correct parameters
        call_args = ollama_client.client.post.call_args
        assert call_args[0][0] == "/v1/chat/completions"
        json_data = call_args[1]["json"]
        assert json_data["temperature"] == 0.7
        assert json_data["max_tokens"] == 100


class TestLMStudioClient:
    """Test suite for the LM Studio client."""

    @pytest.fixture
    def lmstudio_client(self):
        """Create an LM Studio client instance."""
        from agentic_rag.core.llm_factory import LMStudioClient

        return LMStudioClient(api_key="", model="local-model", base_url="http://localhost:1234")

    @pytest.mark.asyncio
    async def test_lmstudio_client_initialization(self, lmstudio_client):
        """Test LM Studio client initialization."""
        assert lmstudio_client.model == "local-model"
        assert lmstudio_client.base_url == "http://localhost:1234"
        assert lmstudio_client.api_key == ""

    @pytest.mark.asyncio
    async def test_lmstudio_client_custom_base_url(self):
        """Test the LM Studio client with a custom base URL."""
        from agentic_rag.core.llm_factory import LMStudioClient

        client = LMStudioClient(
            api_key="", model="custom-model", base_url="http://192.168.1.50:1234"
        )
        assert client.base_url == "http://192.168.1.50:1234"
        assert client.model == "custom-model"

    @pytest.mark.asyncio
    async def test_lmstudio_invoke_success(self, lmstudio_client):
        """Test a successful LM Studio API call."""
        mock_response = MagicMock()
        mock_response.json.return_value = {
            "choices": [{"message": {"content": "Test response from LM Studio"}}],
            "usage": {"prompt_tokens": 15, "completion_tokens": 25},
        }
        mock_response.raise_for_status = MagicMock()

        lmstudio_client.client = MagicMock()
        lmstudio_client.client.post = AsyncMock(return_value=mock_response)

        result = await lmstudio_client.invoke("What is AI?")

        assert result.text == "Test response from LM Studio"
        assert result.model == "local-model"
        assert result.usage["prompt_tokens"] == 15


class TestLLMClientFactoryLocalProviders:
    """Test factory integration for local providers."""

    def test_ollama_provider_creation(self):
        """Test creating an Ollama client via the factory."""
        from agentic_rag.core.llm_factory import LLMClientFactory, LLMProvider, OllamaClient

        client = LLMClientFactory.create_client(
            provider=LLMProvider.OLLAMA,
            api_key="",
            model="llama3.2",
            base_url="http://localhost:11434",
        )

        assert isinstance(client, OllamaClient)
        assert client.model == "llama3.2"

    def test_lmstudio_provider_creation(self):
        """Test creating an LM Studio client via the factory."""
        from agentic_rag.core.llm_factory import LLMClientFactory, LLMProvider, LMStudioClient

        client = LLMClientFactory.create_client(
            provider=LLMProvider.LMSTUDIO,
            api_key="",
            model="qwen2.5",
            base_url="http://localhost:1234",
        )

        assert isinstance(client, LMStudioClient)
        assert client.model == "qwen2.5"

    def test_list_providers_includes_local(self):
        """Test that local providers are listed."""
        from agentic_rag.core.llm_factory import LLMClientFactory

        providers = LLMClientFactory.list_available_providers()
        provider_ids = [p["id"] for p in providers]

        assert "ollama" in provider_ids
        assert "lmstudio" in provider_ids

        # Check they are marked as local
        ollama_info = next(p for p in providers if p["id"] == "ollama")
        lmstudio_info = next(p for p in providers if p["id"] == "lmstudio")

        assert ollama_info["is_local"] is True
        assert lmstudio_info["is_local"] is True
        assert "download" in ollama_info["install_command"]
        assert "download" in lmstudio_info["install_command"]

    def test_default_models_include_local(self):
        """Test the default models for local providers."""
        from agentic_rag.core.llm_factory import LLMClientFactory

        defaults = LLMClientFactory.get_default_models()

        assert defaults["ollama"] == "llama3.2"
        assert defaults["lmstudio"] == "local-model"


class TestConfigLocalProviders:
    """Test configuration for local providers."""

    def test_ollama_base_url_config(self):
        """Test the Ollama base URL configuration."""
        from agentic_rag.core.config import Settings

        settings = Settings()
        assert settings.ollama_base_url == "http://localhost:11434"

        # Test a custom URL
        settings_custom = Settings(ollama_base_url="http://192.168.1.100:11434")
        assert settings_custom.ollama_base_url == "http://192.168.1.100:11434"

    def test_lmstudio_base_url_config(self):
        """Test the LM Studio base URL configuration."""
        from agentic_rag.core.config import Settings

        settings = Settings()
        assert settings.lmstudio_base_url == "http://localhost:1234"

    def test_get_api_key_for_local_providers(self):
        """Test API key retrieval for local providers."""
        from agentic_rag.core.config import Settings

        settings = Settings()

        # Local providers should return "configured" instead of an empty string
        assert settings.get_api_key_for_provider("ollama") == "configured"
        assert settings.get_api_key_for_provider("lmstudio") == "configured"

        # They should be considered configured
        assert settings.is_provider_configured("ollama") is True
        assert settings.is_provider_configured("lmstudio") is True


class TestIntegrationLocalProviders:
    """Integration tests for local providers."""

    @pytest.mark.asyncio
    async def test_get_llm_client_ollama(self):
        """Test getting an Ollama client via get_llm_client."""
        from agentic_rag.core.llm_factory import OllamaClient, get_llm_client

        with patch("agentic_rag.core.llm_factory._client_cache", {}):
            client = await get_llm_client(provider="ollama")

        assert isinstance(client, OllamaClient)

    @pytest.mark.asyncio
    async def test_get_llm_client_lmstudio(self):
        """Test getting an LM Studio client via get_llm_client."""
        from agentic_rag.core.llm_factory import LMStudioClient, get_llm_client

        with patch("agentic_rag.core.llm_factory._client_cache", {}):
            client = await get_llm_client(provider="lmstudio")

        assert isinstance(client, LMStudioClient)