diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..722aac8 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,50 @@ +# AgenticRAG Dockerfile +# Multi-stage build for production + +FROM python:3.11-slim as builder + +WORKDIR /app + +# Install system dependencies +RUN apt-get update && apt-get install -y \ + gcc \ + g++ \ + curl \ + && rm -rf /var/lib/apt/lists/* + +# Install Python dependencies +COPY requirements.txt . +RUN pip install --no-cache-dir --user -r requirements.txt + +# Production stage +FROM python:3.11-slim + +WORKDIR /app + +# Copy Python packages from builder +COPY --from=builder /root/.local /root/.local + +# Make sure scripts in .local are usable +ENV PATH=/root/.local/bin:$PATH + +# Copy application code +COPY src/ ./src/ +COPY static/ ./static/ + +# Create uploads directory +RUN mkdir -p uploads + +# Environment variables +ENV PYTHONPATH=/app/src +ENV PYTHONDONTWRITEBYTECODE=1 +ENV PYTHONUNBUFFERED=1 + +# Expose port +EXPOSE 8000 + +# Health check +HEALTHCHECK --interval=30s --timeout=30s --start-period=5s --retries=3 \ + CMD curl -f http://localhost:8000/api/health || exit 1 + +# Run the application +CMD ["uvicorn", "agentic_rag.api.main:app", "--host", "0.0.0.0", "--port", "8000"] diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..efaf853 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,110 @@ +# AgenticRAG - Docker Compose +# Complete stack with API, Qdrant, and optional services + +version: '3.8' + +services: + # Main API service + api: + build: + context: . 
+ dockerfile: Dockerfile + container_name: agenticrag-api + ports: + - "8000:8000" + environment: + - OPENAI_API_KEY=${OPENAI_API_KEY:-} + - ZAI_API_KEY=${ZAI_API_KEY:-} + - OPENCODE_ZEN_API_KEY=${OPENCODE_ZEN_API_KEY:-} + - OPENROUTER_API_KEY=${OPENROUTER_API_KEY:-} + - ANTHROPIC_API_KEY=${ANTHROPIC_API_KEY:-} + - GOOGLE_API_KEY=${GOOGLE_API_KEY:-} + - DEFAULT_LLM_PROVIDER=${DEFAULT_LLM_PROVIDER:-openai} + - DEFAULT_LLM_MODEL=${DEFAULT_LLM_MODEL:-gpt-4o-mini} + - QDRANT_HOST=qdrant + - QDRANT_PORT=6333 + - JWT_SECRET=${JWT_SECRET:-your-secret-key-change-in-production} + - JWT_ALGORITHM=${JWT_ALGORITHM:-HS256} + - ACCESS_TOKEN_EXPIRE_MINUTES=${ACCESS_TOKEN_EXPIRE_MINUTES:-30} + - ADMIN_API_KEY=${ADMIN_API_KEY:-admin-api-key-change-in-production} + - CORS_ORIGINS=${CORS_ORIGINS:-http://localhost:3000,http://localhost:5173} + - LOG_LEVEL=${LOG_LEVEL:-INFO} + volumes: + - ./uploads:/app/uploads + - ./data:/app/data + depends_on: + qdrant: + condition: service_healthy + networks: + - agenticrag-network + restart: unless-stopped + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:8000/api/health"] + interval: 30s + timeout: 10s + retries: 3 + start_period: 40s + + # Qdrant Vector Database + qdrant: + image: qdrant/qdrant:latest + container_name: agenticrag-qdrant + ports: + - "6333:6333" + - "6334:6334" + volumes: + - qdrant-storage:/qdrant/storage + networks: + - agenticrag-network + restart: unless-stopped + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:6333/healthz"] + interval: 10s + timeout: 5s + retries: 5 + start_period: 10s + + # Optional: Redis for caching + redis: + image: redis:7-alpine + container_name: agenticrag-redis + ports: + - "6379:6379" + volumes: + - redis-data:/data + networks: + - agenticrag-network + restart: unless-stopped + healthcheck: + test: ["CMD", "redis-cli", "ping"] + interval: 10s + timeout: 3s + retries: 5 + + # Optional: Nginx reverse proxy + nginx: + image: nginx:alpine + container_name: agenticrag-nginx + 
ports: + - "80:80" + - "443:443" + volumes: + - ./nginx/nginx.conf:/etc/nginx/nginx.conf:ro + - ./nginx/ssl:/etc/nginx/ssl:ro + depends_on: + - api + networks: + - agenticrag-network + restart: unless-stopped + profiles: + - production + +volumes: + qdrant-storage: + driver: local + redis-data: + driver: local + +networks: + agenticrag-network: + driver: bridge diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..b309fad --- /dev/null +++ b/requirements.txt @@ -0,0 +1,52 @@ +# AgenticRAG Requirements +# Core dependencies + +# FastAPI and web framework +fastapi>=0.104.0 +uvicorn[standard]>=0.24.0 +python-multipart>=0.0.6 +python-jose[cryptography]>=3.3.0 +passlib[bcrypt]>=1.7.4 + +# Datapizza AI framework +datapizza-ai>=0.1.0 +datapizza-ai-core>=0.1.0 + +# LLM Clients +datapizza-ai-clients-openai>=0.0.12 +# Additional providers will be installed via pip in Dockerfile + +# Embeddings +datapizza-ai-embedders-openai>=0.0.6 + +# Vector Store +datapizza-ai-vectorstores-qdrant>=0.0.9 +qdrant-client>=1.7.0 + +# Document Processing +datapizza-ai-modules-parsers-docling>=0.0.1 + +# Tools +datapizza-ai-tools-duckduckgo>=0.0.1 + +# Configuration and utilities +pydantic>=2.5.0 +pydantic-settings>=2.1.0 +python-dotenv>=1.0.0 +httpx>=0.25.0 +aiofiles>=23.2.0 + +# Observability +opentelemetry-api>=1.21.0 +opentelemetry-sdk>=1.21.0 +opentelemetry-instrumentation-fastapi>=0.42b0 + +# Testing +pytest>=7.4.0 +pytest-asyncio>=0.21.0 +httpx>=0.25.0 + +# Development +black>=23.0.0 +ruff>=0.1.0 +mypy>=1.7.0 diff --git a/src/agentic_rag/api/main.py b/src/agentic_rag/api/main.py index a77c2d8..a75628f 100644 --- a/src/agentic_rag/api/main.py +++ b/src/agentic_rag/api/main.py @@ -1,6 +1,6 @@ """AgenticRAG API - Backend powered by datapizza-ai. -This module contains the FastAPI application with RAG capabilities. 
+Multi-provider LLM support: OpenAI, Z.AI, OpenCode Zen, OpenRouter, Anthropic, Google, Mistral, Azure """ from contextlib import asynccontextmanager @@ -10,7 +10,13 @@ from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from fastapi.staticfiles import StaticFiles -from agentic_rag.api.routes import chat, documents, health, query +from agentic_rag.api.routes import ( + chat, + documents, + health, + providers, + query, +) from agentic_rag.core.config import get_settings from agentic_rag.core.logging import setup_logging @@ -24,22 +30,61 @@ async def lifespan(app: FastAPI) -> AsyncGenerator: setup_logging() # Initialize Qdrant vector store - from agentic_rag.services.vector_store import get_vector_store + try: + from agentic_rag.services.vector_store import get_vector_store - vector_store = await get_vector_store() - await vector_store.create_collection("documents") + vector_store = await get_vector_store() + await vector_store.create_collection("documents") + print("✅ Vector store initialized") + except Exception as e: + print(f"⚠️ Vector store initialization failed: {e}") + + # Log configured providers + configured = settings.list_configured_providers() + if configured: + print(f"✅ Configured LLM providers: {[p['id'] for p in configured]}") + else: + print("⚠️ No LLM providers configured. Set API keys in .env file.") yield # Shutdown - pass + print("👋 Shutting down...") def create_application() -> FastAPI: """Create and configure FastAPI application.""" app = FastAPI( title="AgenticRAG API", - description="Agentic Retrieval System powered by datapizza-ai", + description=""" + Agentic Retrieval System powered by datapizza-ai. 
+ + ## Multi-Provider LLM Support + + This API supports multiple LLM providers: + - **OpenAI** (GPT-4o, GPT-4, GPT-3.5) + - **Z.AI** (South Korea) + - **OpenCode Zen** + - **OpenRouter** (Multi-model access) + - **Anthropic** (Claude) + - **Google** (Gemini) + - **Mistral AI** + - **Azure OpenAI** + + ## Authentication + + Two methods supported: + 1. **API Key**: Header `X-API-Key: your-api-key` + 2. **JWT Token**: Header `Authorization: Bearer your-token` + + ## Features + + - 📄 Document upload (PDF, DOCX, TXT, MD) + - 🔍 Semantic search with embeddings + - 💬 Chat with your documents + - 🎯 RAG (Retrieval-Augmented Generation) + - 🚀 Multiple LLM providers + """, version="2.0.0", docs_url="/api/docs", redoc_url="/api/redoc", @@ -58,6 +103,7 @@ def create_application() -> FastAPI: # Include routers app.include_router(health.router, prefix="/api/v1", tags=["health"]) + app.include_router(providers.router, prefix="/api/v1", tags=["providers"]) app.include_router(documents.router, prefix="/api/v1", tags=["documents"]) app.include_router(query.router, prefix="/api/v1", tags=["query"]) app.include_router(chat.router, prefix="/api/v1", tags=["chat"]) @@ -78,9 +124,52 @@ app = create_application() @app.get("/api") async def api_root(): """API root endpoint.""" + settings = get_settings() + configured = settings.list_configured_providers() + return { "name": "AgenticRAG API", "version": "2.0.0", "docs": "/api/docs", "description": "Agentic Retrieval System powered by datapizza-ai", + "features": { + "multi_provider_llm": True, + "authentication": ["api_key", "jwt"], + "document_processing": True, + "rag": True, + "streaming": True, + }, + "configured_providers": [p["id"] for p in configured], + "default_provider": settings.default_llm_provider, + } + + +@app.get("/api/health/detailed") +async def detailed_health_check(): + """Detailed health check with provider status.""" + settings = get_settings() + + # Check vector store + try: + from agentic_rag.services.vector_store 
import get_vector_store + + vector_store = await get_vector_store() + vector_status = "healthy" + except Exception as e: + vector_status = f"unhealthy: {str(e)}" + + # Check configured providers + providers_status = {} + for provider in ["openai", "zai", "opencode-zen", "openrouter", "anthropic", "google"]: + api_key = settings.get_api_key_for_provider(provider) + providers_status[provider] = "configured" if api_key else "not_configured" + + return { + "status": "healthy", + "version": "2.0.0", + "components": { + "api": "healthy", + "vector_store": vector_status, + }, + "providers": providers_status, } diff --git a/src/agentic_rag/api/routes/documents.py b/src/agentic_rag/api/routes/documents.py index 72038ff..aca5c0b 100644 --- a/src/agentic_rag/api/routes/documents.py +++ b/src/agentic_rag/api/routes/documents.py @@ -1,12 +1,12 @@ -"""Documents API routes.""" +"""Documents API routes with authentication.""" -import os import shutil from pathlib import Path from uuid import uuid4 -from fastapi import APIRouter, File, HTTPException, UploadFile, status +from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status +from agentic_rag.core.auth import CurrentUser from agentic_rag.services.document_service import get_document_service router = APIRouter() @@ -20,15 +20,24 @@ UPLOAD_DIR.mkdir(exist_ok=True) "/documents", status_code=status.HTTP_201_CREATED, summary="Upload document", - description="Upload a document for indexing.", + description="Upload a document for indexing. 
Requires authentication.", ) -async def upload_document(file: UploadFile = File(...)): +async def upload_document(file: UploadFile = File(...), current_user: dict = CurrentUser): """Upload and process a document.""" try: # Validate file if not file.filename: raise HTTPException(status_code=400, detail="No file provided") + # Check file size (10MB limit) + max_size = 10 * 1024 * 1024 # 10MB + file.file.seek(0, 2) # Seek to end + file_size = file.file.tell() + file.file.seek(0) # Reset + + if file_size > max_size: + raise HTTPException(status_code=400, detail=f"File too large. Max size: 10MB") + # Save uploaded file doc_id = str(uuid4()) file_path = UPLOAD_DIR / f"{doc_id}_{file.filename}" @@ -38,7 +47,13 @@ async def upload_document(file: UploadFile = File(...)): # Process document service = await get_document_service() - result = await service.ingest_document(str(file_path)) + result = await service.ingest_document( + str(file_path), + metadata={ + "user_id": current_user.get("user_id", "anonymous"), + "filename": file.filename, + }, + ) return { "success": True, @@ -46,9 +61,12 @@ async def upload_document(file: UploadFile = File(...)): "id": doc_id, "filename": file.filename, "chunks": result["chunks_count"], + "user": current_user.get("user_id", "anonymous"), }, } + except HTTPException: + raise except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @@ -56,14 +74,19 @@ async def upload_document(file: UploadFile = File(...)): @router.get( "/documents", summary="List documents", - description="List all uploaded documents.", + description="List all uploaded documents for the current user.", ) -async def list_documents(): +async def list_documents(current_user: dict = CurrentUser): """List all documents.""" service = await get_document_service() documents = await service.list_documents() - return {"success": True, "data": documents} + # Filter by user if needed (for now, return all) + return { + "success": True, + "data": documents, + "user": 
current_user.get("user_id", "anonymous"), + } @router.delete( @@ -71,7 +94,7 @@ async def list_documents(): status_code=status.HTTP_204_NO_CONTENT, summary="Delete document", ) -async def delete_document(doc_id: str): +async def delete_document(doc_id: str, current_user: dict = CurrentUser): """Delete a document.""" service = await get_document_service() success = await service.delete_document(doc_id) diff --git a/src/agentic_rag/api/routes/providers.py b/src/agentic_rag/api/routes/providers.py new file mode 100644 index 0000000..f6e513a --- /dev/null +++ b/src/agentic_rag/api/routes/providers.py @@ -0,0 +1,167 @@ +"""Provider management API routes.""" + +from fastapi import APIRouter, Depends, HTTPException +from pydantic import BaseModel + +from agentic_rag.core.auth import CurrentUser +from agentic_rag.core.config import get_settings +from agentic_rag.core.llm_factory import LLMClientFactory + +router = APIRouter() + + +class ProviderConfig(BaseModel): + """Provider configuration model.""" + + provider: str + model: str + + +class ProviderInfo(BaseModel): + """Provider information model.""" + + id: str + name: str + available: bool + configured: bool + default_model: str + install_command: str | None + + +@router.get("/providers", summary="List available LLM providers", response_model=list[ProviderInfo]) +async def list_providers(current_user: CurrentUser): + """List all available LLM providers and their status.""" + settings = get_settings() + available = LLMClientFactory.list_available_providers() + default_models = LLMClientFactory.get_default_models() + + providers = [] + for provider in available: + providers.append( + ProviderInfo( + id=provider["id"], + name=provider["name"], + available=provider["available"], + configured=settings.is_provider_configured(provider["id"]), + default_model=default_models.get(provider["id"], "unknown"), + install_command=provider.get("install_command"), + ) + ) + + return providers + + +@router.get( + "/providers/configured", 
summary="List configured providers", response_model=list[ProviderInfo] +) +async def list_configured_providers(current_user: CurrentUser): + """List only providers that have API keys configured.""" + settings = get_settings() + return settings.list_configured_providers() + + +@router.get( + "/providers/{provider_id}/models", + summary="List available models for provider", +) +async def list_provider_models(provider_id: str, current_user: CurrentUser): + """List available models for a specific provider.""" + # Model lists for each provider + models = { + "openai": [ + {"id": "gpt-4o", "name": "GPT-4o"}, + {"id": "gpt-4o-mini", "name": "GPT-4o Mini"}, + {"id": "gpt-4-turbo", "name": "GPT-4 Turbo"}, + {"id": "gpt-3.5-turbo", "name": "GPT-3.5 Turbo"}, + ], + "zai": [ + {"id": "zai-large", "name": "Z.AI Large"}, + {"id": "zai-medium", "name": "Z.AI Medium"}, + ], + "opencode-zen": [ + {"id": "zen-1", "name": "Zen 1"}, + {"id": "zen-lite", "name": "Zen Lite"}, + ], + "openrouter": [ + {"id": "openai/gpt-4o", "name": "GPT-4o (via OpenRouter)"}, + {"id": "openai/gpt-4o-mini", "name": "GPT-4o Mini (via OpenRouter)"}, + {"id": "anthropic/claude-3.5-sonnet", "name": "Claude 3.5 Sonnet (via OpenRouter)"}, + {"id": "google/gemini-pro", "name": "Gemini Pro (via OpenRouter)"}, + {"id": "meta-llama/llama-3.1-70b", "name": "Llama 3.1 70B (via OpenRouter)"}, + ], + "anthropic": [ + {"id": "claude-3-5-sonnet-20241022", "name": "Claude 3.5 Sonnet"}, + {"id": "claude-3-opus-20240229", "name": "Claude 3 Opus"}, + {"id": "claude-3-sonnet-20240229", "name": "Claude 3 Sonnet"}, + {"id": "claude-3-haiku-20240307", "name": "Claude 3 Haiku"}, + ], + "google": [ + {"id": "gemini-1.5-pro", "name": "Gemini 1.5 Pro"}, + {"id": "gemini-1.5-flash", "name": "Gemini 1.5 Flash"}, + {"id": "gemini-pro", "name": "Gemini Pro"}, + ], + "mistral": [ + {"id": "mistral-large-latest", "name": "Mistral Large"}, + {"id": "mistral-medium", "name": "Mistral Medium"}, + {"id": "mistral-small", "name": "Mistral 
Small"}, + ], + "azure": [ + {"id": "gpt-4", "name": "GPT-4"}, + {"id": "gpt-4o", "name": "GPT-4o"}, + {"id": "gpt-35-turbo", "name": "GPT-3.5 Turbo"}, + ], + } + + if provider_id not in models: + raise HTTPException(status_code=404, detail=f"Provider {provider_id} not found") + + return {"provider": provider_id, "models": models[provider_id]} + + +@router.get( + "/config", + summary="Get current configuration", +) +async def get_config(current_user: CurrentUser): + """Get current system configuration (without sensitive data).""" + settings = get_settings() + + return { + "default_llm_provider": settings.default_llm_provider, + "default_llm_model": settings.default_llm_model, + "embedding_provider": settings.embedding_provider, + "embedding_model": settings.embedding_model, + "configured_providers": [p["id"] for p in settings.list_configured_providers()], + "qdrant_host": settings.qdrant_host, + "qdrant_port": settings.qdrant_port, + } + + +@router.put( + "/config/provider", + summary="Update default provider", +) +async def update_default_provider(config: ProviderConfig, current_user: CurrentUser): + """Update the default LLM provider and model. + + Note: This only updates the runtime configuration. + For persistent changes, update the .env file. + """ + settings = get_settings() + + # Validate provider + if not settings.is_provider_configured(config.provider): + raise HTTPException( + status_code=400, + detail=f"Provider {config.provider} is not configured. Please set the API key in .env file.", + ) + + # Update settings (runtime only) + settings.default_llm_provider = config.provider + settings.default_llm_model = config.model + + return { + "success": True, + "message": f"Default provider updated to {config.provider} with model {config.model}", + "note": "This change is temporary. 
Update .env file for permanent changes.", + } diff --git a/src/agentic_rag/api/routes/query.py b/src/agentic_rag/api/routes/query.py index a828481..f0c967e 100644 --- a/src/agentic_rag/api/routes/query.py +++ b/src/agentic_rag/api/routes/query.py @@ -1,33 +1,79 @@ -"""Query API routes.""" +"""Query API routes with multi-provider support.""" -from fastapi import APIRouter, HTTPException -from pydantic import BaseModel +from fastapi import APIRouter, Depends, HTTPException +from pydantic import BaseModel, Field +from agentic_rag.core.auth import CurrentUser +from agentic_rag.core.config import get_settings +from agentic_rag.core.llm_factory import get_llm_client from agentic_rag.services.rag_service import get_rag_service router = APIRouter() class QueryRequest(BaseModel): - """Query request model.""" + """Query request model with provider selection.""" + + question: str = Field(..., description="Question to ask") + k: int = Field(5, description="Number of chunks to retrieve", ge=1, le=20) + provider: str | None = Field( + None, description="LLM provider to use (defaults to system default)" + ) + model: str | None = Field(None, description="Model to use (provider-specific)") + + +class QueryResponse(BaseModel): + """Query response model.""" question: str - k: int = 5 + answer: str + provider: str + model: str + sources: list[dict] + user: str @router.post( "/query", summary="Query knowledge base", - description="Query the RAG system with a question.", + description="Query the RAG system with a question. 
Supports multiple LLM providers.", + response_model=QueryResponse, ) -async def query(request: QueryRequest): - """Execute a RAG query.""" +async def query(request: QueryRequest, current_user: dict = CurrentUser): + """Execute a RAG query with specified provider.""" try: + settings = get_settings() + + # Determine provider + provider = request.provider or settings.default_llm_provider + model = request.model or settings.default_llm_model + + # Check if provider is configured + if not settings.is_provider_configured(provider): + available = settings.list_configured_providers() + available_names = [p["id"] for p in available] + raise HTTPException( + status_code=400, + detail=f"Provider '{provider}' not configured. " + f"Available: {available_names}. " + f"Set API key in .env file.", + ) + + # Execute query service = await get_rag_service() - result = await service.query(request.question, k=request.k) + result = await service.query(request.question, k=request.k, provider=provider, model=model) - return {"success": True, "data": result} + return QueryResponse( + question=request.question, + answer=result["answer"], + provider=provider, + model=result.get("model", model), + sources=result["sources"], + user=current_user.get("user_id", "anonymous"), + ) + except HTTPException: + raise except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @@ -36,7 +82,33 @@ async def query(request: QueryRequest): "/chat", summary="Chat with documents", description="Send a message and get a response based on documents.", + response_model=QueryResponse, ) -async def chat(request: QueryRequest): - """Chat endpoint.""" - return await query(request) +async def chat(request: QueryRequest, current_user: dict = CurrentUser): + """Chat endpoint - alias for query.""" + return await query(request, current_user) + + +@router.get( + "/query/providers", + summary="List available query providers", +) +async def list_query_providers(current_user: dict = CurrentUser): + """List 
providers available for querying.""" + settings = get_settings() + configured = settings.list_configured_providers() + + return { + "default_provider": settings.default_llm_provider, + "default_model": settings.default_llm_model, + "available_providers": [ + { + "id": p["id"], + "name": p["name"], + "default_model": settings.default_llm_model + if p["id"] == settings.default_llm_provider + else None, + } + for p in configured + ], + } diff --git a/src/agentic_rag/core/auth.py b/src/agentic_rag/core/auth.py new file mode 100644 index 0000000..18ec92d --- /dev/null +++ b/src/agentic_rag/core/auth.py @@ -0,0 +1,143 @@ +"""Authentication and authorization module. + +Supports JWT tokens and API keys. +""" + +from datetime import datetime, timedelta +from typing import Optional + +from fastapi import Depends, HTTPException, Security, status +from fastapi.security import APIKeyHeader, HTTPAuthorizationCredentials, HTTPBearer +from jose import JWTError, jwt +from passlib.context import CryptContext + +from agentic_rag.core.config import get_settings + +settings = get_settings() + +# Password hashing +pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") + +# Security schemes +api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False) +bearer_scheme = HTTPBearer(auto_error=False) + + +def verify_password(plain_password: str, hashed_password: str) -> bool: + """Verify a password against its hash.""" + return pwd_context.verify(plain_password, hashed_password) + + +def get_password_hash(password: str) -> str: + """Hash a password.""" + return pwd_context.hash(password) + + +def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str: + """Create a JWT access token.""" + to_encode = data.copy() + if expires_delta: + expire = datetime.utcnow() + expires_delta + else: + expire = datetime.utcnow() + timedelta(minutes=settings.access_token_expire_minutes) + + to_encode.update({"exp": expire}) + encoded_jwt = jwt.encode(to_encode, 
settings.jwt_secret, algorithm=settings.jwt_algorithm) + return encoded_jwt + + +def decode_token(token: str) -> Optional[dict]: + """Decode and verify a JWT token.""" + try: + payload = jwt.decode(token, settings.jwt_secret, algorithms=[settings.jwt_algorithm]) + return payload + except JWTError: + return None + + +async def verify_api_key(api_key: str = Security(api_key_header)) -> str: + """Verify API key from header. + + In production, this should check against a database. + For now, we use a simple admin key. + """ + if not api_key: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="API Key header missing", + headers={"WWW-Authenticate": "ApiKey"}, + ) + + # Check admin key + if api_key == settings.admin_api_key: + return "admin" + + # TODO: Check user-specific API keys from database + # For now, reject unknown keys + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid API Key", + headers={"WWW-Authenticate": "ApiKey"}, + ) + + +async def verify_jwt_token( + credentials: HTTPAuthorizationCredentials = Security(bearer_scheme), +) -> dict: + """Verify JWT Bearer token.""" + if not credentials: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Authorization header missing", + headers={"WWW-Authenticate": "Bearer"}, + ) + + token = credentials.credentials + payload = decode_token(token) + + if not payload: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid or expired token", + headers={"WWW-Authenticate": "Bearer"}, + ) + + return payload + + +async def get_current_user( + api_key: Optional[str] = Security(api_key_header), + token: Optional[HTTPAuthorizationCredentials] = Security(bearer_scheme), +) -> dict: + """Get current user from either API key or JWT token. 
+ + This allows both authentication methods: + - API Key: X-API-Key header + - JWT: Authorization: Bearer header + """ + # Try API key first + if api_key: + try: + user_id = await verify_api_key(api_key) + return {"user_id": user_id, "auth_method": "api_key"} + except HTTPException: + pass # Fall through to JWT + + # Try JWT token + if token: + try: + payload = await verify_jwt_token(token) + return {**payload, "auth_method": "jwt"} + except HTTPException: + pass + + # No valid authentication + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Authentication required. Provide either X-API-Key header or Authorization: Bearer token", + headers={"WWW-Authenticate": "Bearer, ApiKey"}, + ) + + +# Dependency for protected routes +CurrentUser = Depends(get_current_user) diff --git a/src/agentic_rag/core/config.py b/src/agentic_rag/core/config.py index b80999e..cd4efe5 100644 --- a/src/agentic_rag/core/config.py +++ b/src/agentic_rag/core/config.py @@ -1,10 +1,10 @@ -"""Configuration management.""" +"""Configuration management with multi-provider support.""" from pydantic_settings import BaseSettings class Settings(BaseSettings): - """Application settings.""" + """Application settings with multi-provider LLM support.""" # API app_name: str = "AgenticRAG" @@ -12,14 +12,19 @@ class Settings(BaseSettings): debug: bool = True # CORS - cors_origins: list[str] = ["http://localhost:5173", "http://localhost:3000"] + cors_origins: list[str] = [ + "http://localhost:3000", + "http://localhost:5173", + "http://localhost:8000", + ] - # OpenAI - openai_api_key: str = "" - llm_model: str = "gpt-4o-mini" - embedding_model: str = "text-embedding-3-small" + # Authentication + jwt_secret: str = "your-secret-key-change-in-production" + jwt_algorithm: str = "HS256" + access_token_expire_minutes: int = 30 + admin_api_key: str = "admin-api-key-change-in-production" - # Qdrant + # Vector Store qdrant_host: str = "localhost" qdrant_port: int = 6333 @@ -27,10 +32,78 @@ 
class Settings(BaseSettings): max_file_size: int = 10 * 1024 * 1024 # 10MB upload_dir: str = "./uploads" + # LLM Provider Configuration + # Primary provider + default_llm_provider: str = "openai" + default_llm_model: str = "gpt-4o-mini" + + # Provider API Keys + openai_api_key: str = "" + zai_api_key: str = "" # Z.AI (South Korea) + opencode_zen_api_key: str = "" # OpenCode Zen + openrouter_api_key: str = "" # OpenRouter (multi-model) + anthropic_api_key: str = "" # Claude + google_api_key: str = "" # Gemini + mistral_api_key: str = "" # Mistral AI + azure_api_key: str = "" # Azure OpenAI + + # Provider-specific settings + azure_endpoint: str = "" # Azure OpenAI endpoint + azure_api_version: str = "2024-02-01" + + # Embedding Configuration + embedding_provider: str = "openai" + embedding_model: str = "text-embedding-3-small" + embedding_api_key: str = "" # If different from LLM key + + # Redis (optional caching) + redis_url: str = "redis://localhost:6379/0" + class Config: env_file = ".env" env_file_encoding = "utf-8" + def get_api_key_for_provider(self, provider: str) -> str: + """Get the API key for a specific provider. 
+ + Args: + provider: Provider name (e.g., 'openai', 'zai', 'openrouter') + + Returns: + API key for the provider + """ + key_mapping = { + "openai": self.openai_api_key, + "zai": self.zai_api_key, + "z.ai": self.zai_api_key, + "opencode-zen": self.opencode_zen_api_key, + "opencode_zen": self.opencode_zen_api_key, + "openrouter": self.openrouter_api_key, + "anthropic": self.anthropic_api_key, + "google": self.google_api_key, + "mistral": self.mistral_api_key, + "azure": self.azure_api_key, + } + + return key_mapping.get(provider.lower(), "") + + def is_provider_configured(self, provider: str) -> bool: + """Check if a provider has API key configured.""" + return bool(self.get_api_key_for_provider(provider)) + + def list_configured_providers(self) -> list[dict]: + """List all providers that have API keys configured.""" + from agentic_rag.core.llm_factory import LLMClientFactory + + available = LLMClientFactory.list_available_providers() + configured = [] + + for provider in available: + if self.is_provider_configured(provider["id"]): + configured.append(provider) + + return configured + # Singleton _settings = None diff --git a/src/agentic_rag/core/llm_factory.py b/src/agentic_rag/core/llm_factory.py new file mode 100644 index 0000000..b08be86 --- /dev/null +++ b/src/agentic_rag/core/llm_factory.py @@ -0,0 +1,320 @@ +"""Multi-provider LLM client factory. 
"""Multi-provider LLM client factory.

Supports: OpenAI, Z.AI, OpenCode Zen, OpenRouter, Anthropic, Google, Mistral, Azure
"""

from abc import ABC, abstractmethod
from enum import Enum
from types import SimpleNamespace
from typing import Any, Optional

# Optional datapizza clients: each name resolves to None when its package is
# not installed, so the factory can report availability instead of crashing
# at import time.
try:
    from datapizza.clients.openai import OpenAIClient
except ImportError:
    OpenAIClient = None

try:
    from datapizza.clients.anthropic import AnthropicClient
except ImportError:
    AnthropicClient = None

try:
    from datapizza.clients.google import GoogleClient
except ImportError:
    GoogleClient = None

try:
    from datapizza.clients.mistral import MistralClient
except ImportError:
    MistralClient = None

try:
    from datapizza.clients.azure import AzureOpenAIClient
except ImportError:
    AzureOpenAIClient = None


class LLMProvider(str, Enum):
    """Supported LLM providers."""

    OPENAI = "openai"
    ZAI = "zai"
    OPENCODE_ZEN = "opencode-zen"
    OPENROUTER = "openrouter"
    ANTHROPIC = "anthropic"
    GOOGLE = "google"
    MISTRAL = "mistral"
    AZURE = "azure"


class BaseLLMClient(ABC):
    """Abstract base class for LLM clients."""

    def __init__(self, api_key: str, model: Optional[str] = None, **kwargs):
        self.api_key = api_key
        self.model = model
        self.kwargs = kwargs

    @abstractmethod
    async def invoke(self, prompt: str, **kwargs) -> Any:
        """Invoke the LLM with a prompt."""


class _HTTPLLMClient(BaseLLMClient):
    """Shared plumbing for providers spoken to over a plain HTTP JSON API.

    Subclasses set ``base_url`` (and optionally ``extra_headers``) and
    implement :meth:`invoke` in terms of :meth:`_post_json`.
    """

    # Subclasses override these class attributes.
    base_url: str = ""
    extra_headers: dict = {}  # static headers merged into every request

    def __init__(self, api_key: str, model: str, **kwargs):
        super().__init__(api_key, model, **kwargs)
        import httpx  # lazy: httpx is only needed when an HTTP provider is used

        self.client = httpx.AsyncClient(
            base_url=self.base_url,
            headers={"Authorization": f"Bearer {api_key}", **self.extra_headers},
        )

    async def _post_json(self, path: str, payload: dict) -> dict:
        """POST ``payload`` to ``path`` and return the decoded JSON body.

        Raises:
            httpx.HTTPStatusError: on any non-2xx response.
        """
        response = await self.client.post(path, json=payload)
        response.raise_for_status()
        return response.json()

    async def aclose(self) -> None:
        """Release the underlying HTTP connection pool.

        The original code never closed the AsyncClient; call this on shutdown
        to avoid leaking connections.
        """
        await self.client.aclose()

    @staticmethod
    def _response(text: str, model: str, usage: dict) -> Any:
        """Build the lightweight response object exposing .text/.model/.usage."""
        return SimpleNamespace(text=text, model=model, usage=usage)


class ZAIClient(_HTTPLLMClient):
    """Z.AI (South Korea) client implementation."""

    base_url = "https://api.z.ai/v1"

    def __init__(self, api_key: str, model: str = "zai-large", **kwargs):
        super().__init__(api_key, model, **kwargs)

    async def invoke(self, prompt: str, **kwargs) -> Any:
        """Call Z.AI API (chat-completions style endpoint)."""
        data = await self._post_json(
            "/chat/completions",
            {"model": self.model, "messages": [{"role": "user", "content": prompt}], **kwargs},
        )
        return self._response(
            data["choices"][0]["message"]["content"], self.model, data.get("usage", {})
        )


class OpenCodeZenClient(_HTTPLLMClient):
    """OpenCode Zen client implementation."""

    base_url = "https://api.opencode.ai/v1"

    def __init__(self, api_key: str, model: str = "zen-1", **kwargs):
        super().__init__(api_key, model, **kwargs)

    async def invoke(self, prompt: str, **kwargs) -> Any:
        """Call OpenCode Zen API (plain completions style endpoint)."""
        data = await self._post_json(
            "/completions", {"model": self.model, "prompt": prompt, **kwargs}
        )
        return self._response(data["choices"][0]["text"], self.model, data.get("usage", {}))


class OpenRouterClient(_HTTPLLMClient):
    """OpenRouter client - provides access to multiple models."""

    base_url = "https://openrouter.ai/api/v1"
    extra_headers = {
        "HTTP-Referer": "https://agenticrag.app",  # Required by OpenRouter
        "X-Title": "AgenticRAG",
    }

    def __init__(self, api_key: str, model: str = "openai/gpt-4o-mini", **kwargs):
        super().__init__(api_key, model, **kwargs)

    async def invoke(self, prompt: str, **kwargs) -> Any:
        """Call OpenRouter API (chat-completions style endpoint)."""
        data = await self._post_json(
            "/chat/completions",
            {"model": self.model, "messages": [{"role": "user", "content": prompt}], **kwargs},
        )
        return self._response(
            data["choices"][0]["message"]["content"], self.model, data.get("usage", {})
        )
class LLMClientFactory:
    """Factory for creating LLM clients based on provider."""

    @staticmethod
    def _provider_specs() -> dict:
        """Per-provider spec: (client class or None, default model, display name, pip package).

        Built lazily so the module-level optional imports are read at call
        time, and kept in one place so create_client and
        list_available_providers cannot drift apart. ``client class`` is None
        when the optional datapizza package is not installed; ``pip package``
        is None for the built-in HTTP clients.
        """
        return {
            LLMProvider.OPENAI: (OpenAIClient, "gpt-4o-mini", "OpenAI", "datapizza-ai-clients-openai"),
            LLMProvider.ANTHROPIC: (AnthropicClient, "claude-3-sonnet", "Anthropic", "datapizza-ai-clients-anthropic"),
            LLMProvider.GOOGLE: (GoogleClient, "gemini-pro", "Google", "datapizza-ai-clients-google"),
            LLMProvider.MISTRAL: (MistralClient, "mistral-medium", "Mistral", "datapizza-ai-clients-mistral"),
            LLMProvider.AZURE: (AzureOpenAIClient, "gpt-4", "Azure", "datapizza-ai-clients-azure"),
            LLMProvider.ZAI: (ZAIClient, "zai-large", "Z.AI", None),
            LLMProvider.OPENCODE_ZEN: (OpenCodeZenClient, "zen-1", "OpenCode Zen", None),
            LLMProvider.OPENROUTER: (OpenRouterClient, "openai/gpt-4o-mini", "OpenRouter", None),
        }

    @staticmethod
    def create_client(
        provider: LLMProvider, api_key: str, model: Optional[str] = None, **kwargs
    ) -> BaseLLMClient:
        """Create an LLM client for the specified provider.

        Args:
            provider: The LLM provider to use
            api_key: API key for the provider
            model: Model name (provider-specific); falls back to the
                provider's default when omitted
            **kwargs: Additional provider-specific options

        Returns:
            Configured LLM client

        Raises:
            ImportError: if the provider's optional package is not installed.
            ValueError: if the provider is not recognized.
        """
        specs = LLMClientFactory._provider_specs()
        try:
            client_cls, default_model, display_name, pip_pkg = specs[provider]
        except KeyError:
            raise ValueError(f"Unknown provider: {provider}") from None

        if client_cls is None:
            # Optional datapizza client whose package is missing.
            raise ImportError(
                f"{display_name} client not installed. Run: pip install {pip_pkg}"
            )

        return client_cls(api_key=api_key, model=model or default_model, **kwargs)

    @staticmethod
    def list_available_providers() -> list[dict]:
        """List all available providers and their installation status."""
        specs = LLMClientFactory._provider_specs()
        providers = []

        for provider in LLMProvider:
            client_cls, _default, _display, pip_pkg = specs[provider]
            providers.append(
                {
                    "id": provider.value,
                    "name": provider.name.replace("_", " ").title(),
                    "available": client_cls is not None,
                    # HTTP-only providers ship with this package, so there is
                    # nothing extra to install for them.
                    "install_command": f"pip install {pip_pkg}" if pip_pkg else None,
                }
            )

        return providers

    @staticmethod
    def get_default_models() -> dict[str, str]:
        """Get default models for each provider.

        NOTE(review): the Anthropic entry here ("claude-3-sonnet-20240229")
        differs from create_client's fallback ("claude-3-sonnet"); both
        values are preserved as-is from the original — confirm which is
        intended.
        """
        return {
            LLMProvider.OPENAI.value: "gpt-4o-mini",
            LLMProvider.ZAI.value: "zai-large",
            LLMProvider.OPENCODE_ZEN.value: "zen-1",
            LLMProvider.OPENROUTER.value: "openai/gpt-4o-mini",
            LLMProvider.ANTHROPIC.value: "claude-3-sonnet-20240229",
            LLMProvider.GOOGLE.value: "gemini-pro",
            LLMProvider.MISTRAL.value: "mistral-medium",
            LLMProvider.AZURE.value: "gpt-4",
        }


# Global client cache, keyed by "provider:api_key:model".
_client_cache: dict[str, BaseLLMClient] = {}

# Alternate spellings accepted for provider names; kept in sync with
# Settings.get_api_key_for_provider so both entry points agree.
_PROVIDER_ALIASES = {"z.ai": "zai", "opencode_zen": "opencode-zen"}


async def get_llm_client(
    provider: Optional[str] = None,
    api_key: Optional[str] = None,
    model: Optional[str] = None,
) -> BaseLLMClient:
    """Get or create an LLM client.

    Args:
        provider: Provider name (uses default if not specified). Case is
            ignored and the aliases "z.ai" / "opencode_zen" are accepted,
            matching Settings.get_api_key_for_provider.
        api_key: API key (uses the configured settings key if not specified)
        model: Model name (uses the provider default if not specified)

    Returns:
        LLM client instance (cached per provider/key/model combination)

    Raises:
        ValueError: if the provider name is not recognized.
    """
    from agentic_rag.core.config import get_settings

    settings = get_settings()

    # Use default provider if not specified; normalize casing and aliases so
    # any spelling accepted by Settings.get_api_key_for_provider works here.
    provider = (provider or settings.default_llm_provider).lower()
    provider = _PROVIDER_ALIASES.get(provider, provider)

    # Check cache (model is part of the key so two models on one provider
    # don't share a client).
    cache_key = f"{provider}:{api_key or 'default'}:{model or 'default'}"
    if cache_key in _client_cache:
        return _client_cache[cache_key]

    # Get API key from settings if not provided
    if not api_key:
        api_key = settings.get_api_key_for_provider(provider)

    try:
        provider_enum = LLMProvider(provider)
    except ValueError:
        raise ValueError(f"Unknown provider: {provider!r}") from None

    # Create client
    client = LLMClientFactory.create_client(
        provider=provider_enum, api_key=api_key, model=model
    )

    # Cache client
    _client_cache[cache_key] = client

    return client
""" -from datapizza.clients.openai import OpenAIClient from datapizza.embedders.openai import OpenAIEmbedder -from datapizza.modules.prompt import ChatPromptTemplate -from datapizza.modules.rewriters import ToolRewriter -from datapizza.pipeline import DagPipeline from agentic_rag.core.config import get_settings +from agentic_rag.core.llm_factory import get_llm_client from agentic_rag.services.vector_store import get_vector_store settings = get_settings() class RAGService: - """Service for RAG queries.""" + """Service for RAG queries with multi-provider LLM support.""" def __init__(self): self.vector_store = None - self.llm_client = None self.embedder = None - self.pipeline = None - self._init_pipeline() + self._init_embedder() - def _init_pipeline(self): - """Initialize the RAG pipeline.""" - # Initialize LLM client - self.llm_client = OpenAIClient( - model=settings.llm_model, - api_key=settings.openai_api_key, - ) - - # Initialize embedder + def _init_embedder(self): + """Initialize the embedder.""" + # Use OpenAI for embeddings (can be configured separately) + embedding_key = settings.embedding_api_key or settings.openai_api_key self.embedder = OpenAIEmbedder( - api_key=settings.openai_api_key, + api_key=embedding_key, model_name=settings.embedding_model, ) - # Initialize pipeline - self.pipeline = DagPipeline() - - # Add modules - self.pipeline.add_module( - "rewriter", - ToolRewriter( - client=self.llm_client, - system_prompt="Rewrite user queries to improve retrieval accuracy.", - ), - ) - self.pipeline.add_module("embedder", self.embedder) - # Note: vector_store will be connected at query time - self.pipeline.add_module( - "prompt", - ChatPromptTemplate( - user_prompt_template="User question: {{user_prompt}}\n\nContext:\n{% for chunk in chunks %}{{ chunk.text }}\n{% endfor %}", - system_prompt="You are a helpful assistant. Answer the question based on the provided context. 
If you don't know the answer, say so.", - ), - ) - self.pipeline.add_module("generator", self.llm_client) - - # Connect modules - self.pipeline.connect("rewriter", "embedder", target_key="text") - self.pipeline.connect("embedder", "prompt", target_key="chunks") - self.pipeline.connect("prompt", "generator", target_key="memory") - - async def query(self, question: str, k: int = 5) -> dict: - """Execute a RAG query. + async def query( + self, question: str, k: int = 5, provider: str | None = None, model: str | None = None + ) -> dict: + """Execute a RAG query with specified provider. Args: question: User question k: Number of chunks to retrieve + provider: LLM provider to use + model: Model name Returns: Response with answer and sources @@ -88,16 +55,19 @@ class RAGService: # Format context from chunks context = self._format_context(chunks) - # Generate answer - response = await self.llm_client.invoke( - f"Context:\n{context}\n\nQuestion: {question}\n\nAnswer:" - ) + # Get LLM client for specified provider + llm_client = await get_llm_client(provider=provider) + + # Generate answer using the prompt + prompt = self._build_prompt(context, question) + response = await llm_client.invoke(prompt) return { "question": question, "answer": response.text, "sources": chunks, - "model": settings.llm_model, + "provider": provider or settings.default_llm_provider, + "model": model or getattr(response, "model", "unknown"), } async def _get_embedding(self, text: str) -> list[float]: @@ -110,9 +80,27 @@ class RAGService: context_parts = [] for i, chunk in enumerate(chunks, 1): text = chunk.get("text", "") - context_parts.append(f"[{i}] {text}") + if text: + context_parts.append(f"[{i}] {text}") return "\n\n".join(context_parts) + def _build_prompt(self, context: str, question: str) -> str: + """Build the RAG prompt.""" + return f"""You are a helpful AI assistant. Answer the question based on the provided context. 
+ +Context: +{context} + +Question: {question} + +Instructions: +- Answer based only on the provided context +- If the context doesn't contain the answer, say "I don't have enough information to answer this question" +- Be concise but complete +- Cite sources using [1], [2], etc. when referencing information + +Answer:""" + # Singleton _rag_service = None