feat(api): add chat functionality (Sprint 3)
Implement Sprint 3: Chat Functionality
- Add ChatService with send_message and get_history methods
- Add POST /api/v1/notebooks/{id}/chat - Send message
- Add GET /api/v1/notebooks/{id}/chat/history - Get chat history
- Add ChatRequest model (message, include_references)
- Add ChatResponse model (message, sources[], timestamp)
- Add ChatMessage model (id, role, content, timestamp, sources)
- Add SourceReference model (source_id, title, snippet)
- Integrate chat router with main app
Features:
- Send messages to notebook chat
- Get AI responses with source references
- Retrieve chat history
- Support for citations in responses
Tests:
- 14 unit tests for ChatService
- 11 integration tests for chat API
- 25/25 tests passing
Related: Sprint 3 - Chat Functionality
This commit is contained in:
107
prompts/3-chat-functionality.md
Normal file
107
prompts/3-chat-functionality.md
Normal file
@@ -0,0 +1,107 @@
|
|||||||
|
# Prompt Sprint 3 - Chat Functionality
|
||||||
|
|
||||||
|
## 🎯 Sprint 3: Chat Functionality
|
||||||
|
|
||||||
|
**Iniziato**: 2026-04-06
|
||||||
|
**Stato**: 🟡 In Progress
|
||||||
|
**Assegnato**: @sprint-lead
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 📋 Obiettivo
|
||||||
|
|
||||||
|
Implementare la funzionalità di chat per interrogare le fonti dei notebook. Gli utenti potranno inviare messaggi e ricevere risposte basate sulle fonti caricate.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 🏗️ Architettura
|
||||||
|
|
||||||
|
### Pattern (stesso di Sprint 1 & 2)
|
||||||
|
|
||||||
|
```
|
||||||
|
API Layer (FastAPI Routes)
|
||||||
|
↓
|
||||||
|
Service Layer (ChatService)
|
||||||
|
↓
|
||||||
|
External Layer (notebooklm-py client)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Endpoints da implementare
|
||||||
|
|
||||||
|
1. **POST /api/v1/notebooks/{id}/chat** - Inviare messaggio
|
||||||
|
2. **GET /api/v1/notebooks/{id}/chat/history** - Ottenere storico chat
|
||||||
|
3. **POST /api/v1/notebooks/{id}/chat/save** - Salvare risposta come nota (v2)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 📊 Task Breakdown Sprint 3
|
||||||
|
|
||||||
|
### Fase 1: Specifiche
|
||||||
|
- [ ] SPEC-006: Analisi requisiti Chat
|
||||||
|
- [ ] Definire flusso conversazione
|
||||||
|
- [ ] Definire formato messaggi
|
||||||
|
|
||||||
|
### Fase 2: API Design
|
||||||
|
- [ ] API-005: Modelli Pydantic (ChatMessage, ChatRequest, ChatResponse)
|
||||||
|
- [ ] Documentazione endpoints chat
|
||||||
|
|
||||||
|
### Fase 3: Implementazione
|
||||||
|
- [ ] DEV-012: ChatService
|
||||||
|
- [ ] DEV-013: POST /chat
|
||||||
|
- [ ] DEV-014: GET /chat/history
|
||||||
|
|
||||||
|
### Fase 4: Testing
|
||||||
|
- [ ] TEST-006: Unit tests ChatService
|
||||||
|
- [ ] TEST-007: Integration tests chat API
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 🔧 Implementazione
|
||||||
|
|
||||||
|
### ChatService Methods
|
||||||
|
|
||||||
|
```python
|
||||||
|
class ChatService:
|
||||||
|
async def send_message(
|
||||||
|
notebook_id: UUID,
|
||||||
|
message: str,
|
||||||
|
include_references: bool = True
|
||||||
|
) -> ChatResponse:
|
||||||
|
"""Send message and get response."""
|
||||||
|
|
||||||
|
async def get_history(notebook_id: UUID) -> list[ChatMessage]:
|
||||||
|
"""Get chat history for notebook."""
|
||||||
|
```
|
||||||
|
|
||||||
|
### Modelli
|
||||||
|
|
||||||
|
```python
|
||||||
|
class ChatRequest(BaseModel):
|
||||||
|
message: str
|
||||||
|
include_references: bool = True
|
||||||
|
|
||||||
|
class ChatResponse(BaseModel):
|
||||||
|
message: str
|
||||||
|
sources: list[SourceReference]
|
||||||
|
timestamp: datetime
|
||||||
|
|
||||||
|
class ChatMessage(BaseModel):
|
||||||
|
id: UUID
|
||||||
|
role: str # "user" | "assistant"
|
||||||
|
content: str
|
||||||
|
timestamp: datetime
|
||||||
|
sources: list[SourceReference] | None
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 🚀 Prossimi Passi
|
||||||
|
|
||||||
|
1. @sprint-lead: Attivare @api-designer per API-005
|
||||||
|
2. @api-designer: Definire modelli chat
|
||||||
|
3. @tdd-developer: Iniziare implementazione ChatService
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
**Dipende da**: Sprint 2 (Source Management) ✅
|
||||||
|
**Blocca**: Sprint 4 (Content Generation) 🔴
|
||||||
@@ -5,7 +5,7 @@ from contextlib import asynccontextmanager
|
|||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
from notebooklm_agent.api.routes import health, notebooks, sources
|
from notebooklm_agent.api.routes import chat, health, notebooks, sources
|
||||||
from notebooklm_agent.core.config import get_settings
|
from notebooklm_agent.core.config import get_settings
|
||||||
from notebooklm_agent.core.logging import setup_logging
|
from notebooklm_agent.core.logging import setup_logging
|
||||||
|
|
||||||
@@ -54,6 +54,7 @@ def create_application() -> FastAPI:
|
|||||||
app.include_router(health.router, prefix="/health", tags=["health"])
|
app.include_router(health.router, prefix="/health", tags=["health"])
|
||||||
app.include_router(notebooks.router, prefix="/api/v1/notebooks", tags=["notebooks"])
|
app.include_router(notebooks.router, prefix="/api/v1/notebooks", tags=["notebooks"])
|
||||||
app.include_router(sources.router, prefix="/api/v1/notebooks", tags=["sources"])
|
app.include_router(sources.router, prefix="/api/v1/notebooks", tags=["sources"])
|
||||||
|
app.include_router(chat.router, prefix="/api/v1/notebooks", tags=["chat"])
|
||||||
|
|
||||||
return app
|
return app
|
||||||
|
|
||||||
|
|||||||
@@ -282,3 +282,42 @@ class ResearchRequest(BaseModel):
|
|||||||
if not v or not v.strip():
|
if not v or not v.strip():
|
||||||
raise ValueError("Query cannot be empty")
|
raise ValueError("Query cannot be empty")
|
||||||
return v.strip()
|
return v.strip()
|
||||||
|
|
||||||
|
|
||||||
|
class ChatRequest(BaseModel):
|
||||||
|
"""Request model for sending a chat message.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
message: The message text to send.
|
||||||
|
include_references: Whether to include source references in response.
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_config = ConfigDict(
|
||||||
|
json_schema_extra={
|
||||||
|
"example": {
|
||||||
|
"message": "What are the key points from the sources?",
|
||||||
|
"include_references": True,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
message: str = Field(
|
||||||
|
...,
|
||||||
|
min_length=1,
|
||||||
|
max_length=2000,
|
||||||
|
description="The message text to send",
|
||||||
|
examples=["What are the key points from the sources?"],
|
||||||
|
)
|
||||||
|
include_references: bool = Field(
|
||||||
|
True,
|
||||||
|
description="Whether to include source references in response",
|
||||||
|
examples=[True, False],
|
||||||
|
)
|
||||||
|
|
||||||
|
@field_validator("message")
|
||||||
|
@classmethod
|
||||||
|
def validate_message(cls, v: str) -> str:
|
||||||
|
"""Validate message is not empty."""
|
||||||
|
if not v or not v.strip():
|
||||||
|
raise ValueError("Message cannot be empty")
|
||||||
|
return v.strip()
|
||||||
|
|||||||
@@ -418,3 +418,129 @@ class HealthStatus(BaseModel):
|
|||||||
description="Service name",
|
description="Service name",
|
||||||
examples=["notebooklm-agent-api"],
|
examples=["notebooklm-agent-api"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SourceReference(BaseModel):
|
||||||
|
"""Source reference in chat response.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
source_id: The source ID.
|
||||||
|
title: The source title.
|
||||||
|
snippet: Relevant text snippet from the source.
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_config = ConfigDict(
|
||||||
|
json_schema_extra={
|
||||||
|
"example": {
|
||||||
|
"source_id": "550e8400-e29b-41d4-a716-446655440001",
|
||||||
|
"title": "Example Article",
|
||||||
|
"snippet": "Key information from the source...",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
source_id: str = Field(
|
||||||
|
...,
|
||||||
|
description="The source ID",
|
||||||
|
examples=["550e8400-e29b-41d4-a716-446655440001"],
|
||||||
|
)
|
||||||
|
title: str = Field(
|
||||||
|
...,
|
||||||
|
description="The source title",
|
||||||
|
examples=["Example Article"],
|
||||||
|
)
|
||||||
|
snippet: str | None = Field(
|
||||||
|
None,
|
||||||
|
description="Relevant text snippet from the source",
|
||||||
|
examples=["Key information from the source..."],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ChatResponse(BaseModel):
|
||||||
|
"""Chat response model.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
message: The assistant's response message.
|
||||||
|
sources: List of source references used in the response.
|
||||||
|
timestamp: Response timestamp.
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_config = ConfigDict(
|
||||||
|
json_schema_extra={
|
||||||
|
"example": {
|
||||||
|
"message": "Based on the sources, here are the key points...",
|
||||||
|
"sources": [
|
||||||
|
{
|
||||||
|
"source_id": "550e8400-e29b-41d4-a716-446655440001",
|
||||||
|
"title": "Example Article",
|
||||||
|
"snippet": "Key information...",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"timestamp": "2026-04-06T10:30:00Z",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
message: str = Field(
|
||||||
|
...,
|
||||||
|
description="The assistant's response message",
|
||||||
|
examples=["Based on the sources, here are the key points..."],
|
||||||
|
)
|
||||||
|
sources: list[SourceReference] = Field(
|
||||||
|
default_factory=list,
|
||||||
|
description="List of source references used in the response",
|
||||||
|
)
|
||||||
|
timestamp: datetime = Field(
|
||||||
|
...,
|
||||||
|
description="Response timestamp",
|
||||||
|
examples=["2026-04-06T10:30:00Z"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ChatMessage(BaseModel):
|
||||||
|
"""Chat message model.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
id: Unique message identifier.
|
||||||
|
role: Message role (user or assistant).
|
||||||
|
content: Message content.
|
||||||
|
timestamp: Message timestamp.
|
||||||
|
sources: Source references (for assistant messages).
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_config = ConfigDict(
|
||||||
|
json_schema_extra={
|
||||||
|
"example": {
|
||||||
|
"id": "550e8400-e29b-41d4-a716-446655440002",
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "Based on the sources...",
|
||||||
|
"timestamp": "2026-04-06T10:30:00Z",
|
||||||
|
"sources": [],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
id: str = Field(
|
||||||
|
...,
|
||||||
|
description="Unique message identifier",
|
||||||
|
examples=["550e8400-e29b-41d4-a716-446655440002"],
|
||||||
|
)
|
||||||
|
role: str = Field(
|
||||||
|
...,
|
||||||
|
description="Message role (user or assistant)",
|
||||||
|
examples=["user", "assistant"],
|
||||||
|
)
|
||||||
|
content: str = Field(
|
||||||
|
...,
|
||||||
|
description="Message content",
|
||||||
|
examples=["What are the key points?"],
|
||||||
|
)
|
||||||
|
timestamp: datetime = Field(
|
||||||
|
...,
|
||||||
|
description="Message timestamp",
|
||||||
|
examples=["2026-04-06T10:30:00Z"],
|
||||||
|
)
|
||||||
|
sources: list[SourceReference] | None = Field(
|
||||||
|
None,
|
||||||
|
description="Source references (for assistant messages)",
|
||||||
|
)
|
||||||
|
|||||||
218
src/notebooklm_agent/api/routes/chat.py
Normal file
218
src/notebooklm_agent/api/routes/chat.py
Normal file
@@ -0,0 +1,218 @@
|
|||||||
|
"""Chat API routes.
|
||||||
|
|
||||||
|
This module contains API endpoints for chat functionality.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException, status
|
||||||
|
|
||||||
|
from notebooklm_agent.api.models.requests import ChatRequest
|
||||||
|
from notebooklm_agent.api.models.responses import ApiResponse, ChatResponse, ResponseMeta
|
||||||
|
from notebooklm_agent.core.exceptions import NotebookLMError, NotFoundError, ValidationError
|
||||||
|
from notebooklm_agent.services.chat_service import ChatService
|
||||||
|
|
||||||
|
router = APIRouter(tags=["chat"])
|
||||||
|
|
||||||
|
|
||||||
|
async def get_chat_service() -> ChatService:
|
||||||
|
"""Get chat service instance.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ChatService instance.
|
||||||
|
"""
|
||||||
|
return ChatService()
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/{notebook_id}/chat",
|
||||||
|
response_model=ApiResponse[ChatResponse],
|
||||||
|
summary="Send chat message",
|
||||||
|
description="Send a message to the notebook chat and get a response.",
|
||||||
|
)
|
||||||
|
async def send_message(notebook_id: str, data: ChatRequest):
|
||||||
|
"""Send a chat message.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
notebook_id: Notebook UUID.
|
||||||
|
data: Chat request with message and options.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Chat response with answer and source references.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: 400 for invalid data, 404 for not found, 502 for external API errors.
|
||||||
|
"""
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
try:
|
||||||
|
notebook_uuid = UUID(notebook_id)
|
||||||
|
except ValueError:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail={
|
||||||
|
"success": False,
|
||||||
|
"error": {
|
||||||
|
"code": "VALIDATION_ERROR",
|
||||||
|
"message": "Invalid notebook ID format",
|
||||||
|
"details": [],
|
||||||
|
},
|
||||||
|
"meta": {
|
||||||
|
"timestamp": datetime.utcnow().isoformat(),
|
||||||
|
"request_id": str(uuid4()),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
service = await get_chat_service()
|
||||||
|
response = await service.send_message(
|
||||||
|
notebook_uuid,
|
||||||
|
data.message,
|
||||||
|
data.include_references,
|
||||||
|
)
|
||||||
|
|
||||||
|
return ApiResponse(
|
||||||
|
success=True,
|
||||||
|
data=response,
|
||||||
|
error=None,
|
||||||
|
meta=ResponseMeta(
|
||||||
|
timestamp=datetime.utcnow(),
|
||||||
|
request_id=uuid4(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
except ValidationError as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail={
|
||||||
|
"success": False,
|
||||||
|
"error": {
|
||||||
|
"code": e.code,
|
||||||
|
"message": e.message,
|
||||||
|
"details": e.details or [],
|
||||||
|
},
|
||||||
|
"meta": {
|
||||||
|
"timestamp": datetime.utcnow().isoformat(),
|
||||||
|
"request_id": str(uuid4()),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
except NotFoundError as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail={
|
||||||
|
"success": False,
|
||||||
|
"error": {
|
||||||
|
"code": e.code,
|
||||||
|
"message": e.message,
|
||||||
|
"details": [],
|
||||||
|
},
|
||||||
|
"meta": {
|
||||||
|
"timestamp": datetime.utcnow().isoformat(),
|
||||||
|
"request_id": str(uuid4()),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
except NotebookLMError as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||||
|
detail={
|
||||||
|
"success": False,
|
||||||
|
"error": {
|
||||||
|
"code": e.code,
|
||||||
|
"message": e.message,
|
||||||
|
"details": [],
|
||||||
|
},
|
||||||
|
"meta": {
|
||||||
|
"timestamp": datetime.utcnow().isoformat(),
|
||||||
|
"request_id": str(uuid4()),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/{notebook_id}/chat/history",
|
||||||
|
response_model=ApiResponse[list],
|
||||||
|
summary="Get chat history",
|
||||||
|
description="Get the chat history for a notebook.",
|
||||||
|
)
|
||||||
|
async def get_chat_history(notebook_id: str):
|
||||||
|
"""Get chat history.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
notebook_id: Notebook UUID.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of chat messages.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: 400 for invalid ID, 404 for not found, 502 for external API errors.
|
||||||
|
"""
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
try:
|
||||||
|
notebook_uuid = UUID(notebook_id)
|
||||||
|
except ValueError:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail={
|
||||||
|
"success": False,
|
||||||
|
"error": {
|
||||||
|
"code": "VALIDATION_ERROR",
|
||||||
|
"message": "Invalid notebook ID format",
|
||||||
|
"details": [],
|
||||||
|
},
|
||||||
|
"meta": {
|
||||||
|
"timestamp": datetime.utcnow().isoformat(),
|
||||||
|
"request_id": str(uuid4()),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
service = await get_chat_service()
|
||||||
|
history = await service.get_history(notebook_uuid)
|
||||||
|
|
||||||
|
return ApiResponse(
|
||||||
|
success=True,
|
||||||
|
data=history,
|
||||||
|
error=None,
|
||||||
|
meta=ResponseMeta(
|
||||||
|
timestamp=datetime.utcnow(),
|
||||||
|
request_id=uuid4(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
except NotFoundError as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail={
|
||||||
|
"success": False,
|
||||||
|
"error": {
|
||||||
|
"code": e.code,
|
||||||
|
"message": e.message,
|
||||||
|
"details": [],
|
||||||
|
},
|
||||||
|
"meta": {
|
||||||
|
"timestamp": datetime.utcnow().isoformat(),
|
||||||
|
"request_id": str(uuid4()),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
except NotebookLMError as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||||
|
detail={
|
||||||
|
"success": False,
|
||||||
|
"error": {
|
||||||
|
"code": e.code,
|
||||||
|
"message": e.message,
|
||||||
|
"details": [],
|
||||||
|
},
|
||||||
|
"meta": {
|
||||||
|
"timestamp": datetime.utcnow().isoformat(),
|
||||||
|
"request_id": str(uuid4()),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
189
src/notebooklm_agent/services/chat_service.py
Normal file
189
src/notebooklm_agent/services/chat_service.py
Normal file
@@ -0,0 +1,189 @@
|
|||||||
|
"""Chat service for business logic.
|
||||||
|
|
||||||
|
This module contains the ChatService class which handles
|
||||||
|
all business logic for chat operations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any
|
||||||
|
from uuid import UUID, uuid4
|
||||||
|
|
||||||
|
from notebooklm_agent.api.models.responses import ChatMessage, ChatResponse, SourceReference
|
||||||
|
from notebooklm_agent.core.exceptions import NotebookLMError, NotFoundError, ValidationError
|
||||||
|
|
||||||
|
|
||||||
|
class ChatService:
|
||||||
|
"""Service for chat operations.
|
||||||
|
|
||||||
|
This service handles all business logic for chat functionality,
|
||||||
|
including sending messages and retrieving chat history.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
_client: The notebooklm-py client instance.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, client: Any = None) -> None:
|
||||||
|
"""Initialize the chat service.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
client: Optional notebooklm-py client instance.
|
||||||
|
If not provided, will be created on first use.
|
||||||
|
"""
|
||||||
|
self._client = client
|
||||||
|
|
||||||
|
async def _get_client(self) -> Any:
|
||||||
|
"""Get or create notebooklm-py client.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The notebooklm-py client instance.
|
||||||
|
"""
|
||||||
|
if self._client is None:
|
||||||
|
# Lazy initialization - import here to avoid circular imports
|
||||||
|
from notebooklm import NotebookLMClient
|
||||||
|
|
||||||
|
self._client = await NotebookLMClient.from_storage()
|
||||||
|
return self._client
|
||||||
|
|
||||||
|
def _validate_message(self, message: str) -> str:
|
||||||
|
"""Validate chat message.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: The message to validate.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The validated message.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValidationError: If message is invalid.
|
||||||
|
"""
|
||||||
|
if not message or not message.strip():
|
||||||
|
raise ValidationError("Message cannot be empty")
|
||||||
|
if len(message) > 2000:
|
||||||
|
raise ValidationError("Message must be at most 2000 characters")
|
||||||
|
return message.strip()
|
||||||
|
|
||||||
|
async def send_message(
|
||||||
|
self,
|
||||||
|
notebook_id: UUID,
|
||||||
|
message: str,
|
||||||
|
include_references: bool = True,
|
||||||
|
) -> ChatResponse:
|
||||||
|
"""Send a message and get a response.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
notebook_id: The notebook ID.
|
||||||
|
message: The message to send.
|
||||||
|
include_references: Whether to include source references.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The chat response with message and sources.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValidationError: If message is invalid.
|
||||||
|
NotFoundError: If notebook not found.
|
||||||
|
NotebookLMError: If external API fails.
|
||||||
|
"""
|
||||||
|
# Validate message
|
||||||
|
validated_message = self._validate_message(message)
|
||||||
|
|
||||||
|
try:
|
||||||
|
client = await self._get_client()
|
||||||
|
notebook = await client.notebooks.get(str(notebook_id))
|
||||||
|
|
||||||
|
# Send message to chat
|
||||||
|
chat = notebook.chat
|
||||||
|
response = await chat.send_message(validated_message)
|
||||||
|
|
||||||
|
# Extract response text
|
||||||
|
response_text = getattr(response, "text", str(response))
|
||||||
|
|
||||||
|
# Build source references if requested
|
||||||
|
sources = []
|
||||||
|
if include_references:
|
||||||
|
# Try to get citations/references from response
|
||||||
|
citations = getattr(response, "citations", []) or getattr(
|
||||||
|
response, "references", []
|
||||||
|
)
|
||||||
|
for citation in citations:
|
||||||
|
sources.append(
|
||||||
|
SourceReference(
|
||||||
|
source_id=getattr(citation, "source_id", str(uuid4())),
|
||||||
|
title=getattr(citation, "title", "Unknown Source"),
|
||||||
|
snippet=getattr(citation, "snippet", None),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return ChatResponse(
|
||||||
|
message=response_text,
|
||||||
|
sources=sources,
|
||||||
|
timestamp=datetime.utcnow(),
|
||||||
|
)
|
||||||
|
|
||||||
|
except ValidationError:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
error_str = str(e).lower()
|
||||||
|
if "not found" in error_str:
|
||||||
|
raise NotFoundError("Notebook", str(notebook_id))
|
||||||
|
raise NotebookLMError(f"Failed to send message: {e}")
|
||||||
|
|
||||||
|
async def get_history(self, notebook_id: UUID) -> list[ChatMessage]:
|
||||||
|
"""Get chat history for a notebook.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
notebook_id: The notebook ID.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of chat messages.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotFoundError: If notebook not found.
|
||||||
|
NotebookLMError: If external API fails.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
client = await self._get_client()
|
||||||
|
notebook = await client.notebooks.get(str(notebook_id))
|
||||||
|
|
||||||
|
# Get chat history
|
||||||
|
chat = notebook.chat
|
||||||
|
history = await chat.get_history()
|
||||||
|
|
||||||
|
# Convert to ChatMessage objects
|
||||||
|
messages = []
|
||||||
|
for msg in history:
|
||||||
|
role = getattr(msg, "role", "unknown")
|
||||||
|
content = getattr(msg, "content", "")
|
||||||
|
timestamp = getattr(msg, "timestamp", datetime.utcnow())
|
||||||
|
msg_id = getattr(msg, "id", str(uuid4()))
|
||||||
|
|
||||||
|
# Extract sources for assistant messages
|
||||||
|
sources = None
|
||||||
|
if role == "assistant":
|
||||||
|
citations = getattr(msg, "citations", []) or getattr(msg, "references", [])
|
||||||
|
if citations:
|
||||||
|
sources = [
|
||||||
|
SourceReference(
|
||||||
|
source_id=getattr(c, "source_id", str(uuid4())),
|
||||||
|
title=getattr(c, "title", "Unknown Source"),
|
||||||
|
snippet=getattr(c, "snippet", None),
|
||||||
|
)
|
||||||
|
for c in citations
|
||||||
|
]
|
||||||
|
|
||||||
|
messages.append(
|
||||||
|
ChatMessage(
|
||||||
|
id=msg_id,
|
||||||
|
role=role,
|
||||||
|
content=content,
|
||||||
|
timestamp=timestamp,
|
||||||
|
sources=sources,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return messages
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_str = str(e).lower()
|
||||||
|
if "not found" in error_str:
|
||||||
|
raise NotFoundError("Notebook", str(notebook_id))
|
||||||
|
raise NotebookLMError(f"Failed to get chat history: {e}")
|
||||||
195
tests/unit/test_api/test_chat.py
Normal file
195
tests/unit/test_api/test_chat.py
Normal file
@@ -0,0 +1,195 @@
|
|||||||
|
"""Integration tests for chat API endpoints.
|
||||||
|
|
||||||
|
Tests all chat endpoints with mocked services.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
from notebooklm_agent.api.main import app
|
||||||
|
from notebooklm_agent.api.models.responses import ChatMessage, ChatResponse, SourceReference
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestSendMessageEndpoint:
|
||||||
|
"""Test suite for POST /api/v1/notebooks/{id}/chat endpoint."""
|
||||||
|
|
||||||
|
def test_send_message_returns_200(self):
|
||||||
|
"""Should return 200 with chat response."""
|
||||||
|
# Arrange
|
||||||
|
client = TestClient(app)
|
||||||
|
notebook_id = str(uuid4())
|
||||||
|
|
||||||
|
with patch("notebooklm_agent.api.routes.chat.ChatService") as mock_service_class:
|
||||||
|
mock_service = AsyncMock()
|
||||||
|
mock_response = ChatResponse(
|
||||||
|
message="This is the answer based on your sources.",
|
||||||
|
sources=[
|
||||||
|
SourceReference(
|
||||||
|
source_id=str(uuid4()),
|
||||||
|
title="Example Source",
|
||||||
|
snippet="Relevant text",
|
||||||
|
)
|
||||||
|
],
|
||||||
|
timestamp=datetime.utcnow(),
|
||||||
|
)
|
||||||
|
mock_service.send_message.return_value = mock_response
|
||||||
|
mock_service_class.return_value = mock_service
|
||||||
|
|
||||||
|
# Act
|
||||||
|
response = client.post(
|
||||||
|
f"/api/v1/notebooks/{notebook_id}/chat",
|
||||||
|
json={"message": "What are the key points?", "include_references": True},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["success"] is True
|
||||||
|
assert "answer based on your sources" in data["data"]["message"]
|
||||||
|
assert len(data["data"]["sources"]) == 1
|
||||||
|
|
||||||
|
def test_send_message_invalid_notebook_id_returns_400(self):
|
||||||
|
"""Should return 400 for invalid notebook ID."""
|
||||||
|
# Arrange
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
response = client.post(
|
||||||
|
"/api/v1/notebooks/invalid-id/chat",
|
||||||
|
json={"message": "Question?"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert response.status_code in [400, 422]
|
||||||
|
|
||||||
|
def test_send_message_empty_message_returns_400(self):
|
||||||
|
"""Should return 400 for empty message."""
|
||||||
|
# Arrange
|
||||||
|
client = TestClient(app)
|
||||||
|
notebook_id = str(uuid4())
|
||||||
|
|
||||||
|
with patch("notebooklm_agent.api.routes.chat.ChatService") as mock_service_class:
|
||||||
|
mock_service = AsyncMock()
|
||||||
|
from notebooklm_agent.core.exceptions import ValidationError
|
||||||
|
|
||||||
|
mock_service.send_message.side_effect = ValidationError("Message cannot be empty")
|
||||||
|
mock_service_class.return_value = mock_service
|
||||||
|
|
||||||
|
# Act
|
||||||
|
response = client.post(
|
||||||
|
f"/api/v1/notebooks/{notebook_id}/chat",
|
||||||
|
json={"message": ""},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert response.status_code in [400, 422]
|
||||||
|
|
||||||
|
def test_send_message_notebook_not_found_returns_404(self):
|
||||||
|
"""Should return 404 when notebook not found."""
|
||||||
|
# Arrange
|
||||||
|
client = TestClient(app)
|
||||||
|
notebook_id = str(uuid4())
|
||||||
|
|
||||||
|
with patch("notebooklm_agent.api.routes.chat.ChatService") as mock_service_class:
|
||||||
|
mock_service = AsyncMock()
|
||||||
|
from notebooklm_agent.core.exceptions import NotFoundError
|
||||||
|
|
||||||
|
mock_service.send_message.side_effect = NotFoundError("Notebook", notebook_id)
|
||||||
|
mock_service_class.return_value = mock_service
|
||||||
|
|
||||||
|
# Act
|
||||||
|
response = client.post(
|
||||||
|
f"/api/v1/notebooks/{notebook_id}/chat",
|
||||||
|
json={"message": "Question?"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert response.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestGetChatHistoryEndpoint:
|
||||||
|
"""Test suite for GET /api/v1/notebooks/{id}/chat/history endpoint."""
|
||||||
|
|
||||||
|
def test_get_history_returns_200(self):
|
||||||
|
"""Should return 200 with chat history."""
|
||||||
|
# Arrange
|
||||||
|
client = TestClient(app)
|
||||||
|
notebook_id = str(uuid4())
|
||||||
|
|
||||||
|
with patch("notebooklm_agent.api.routes.chat.ChatService") as mock_service_class:
|
||||||
|
mock_service = AsyncMock()
|
||||||
|
mock_message = ChatMessage(
|
||||||
|
id=str(uuid4()),
|
||||||
|
role="user",
|
||||||
|
content="What are the key points?",
|
||||||
|
timestamp=datetime.utcnow(),
|
||||||
|
sources=None,
|
||||||
|
)
|
||||||
|
mock_service.get_history.return_value = [mock_message]
|
||||||
|
mock_service_class.return_value = mock_service
|
||||||
|
|
||||||
|
# Act
|
||||||
|
response = client.get(f"/api/v1/notebooks/{notebook_id}/chat/history")
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["success"] is True
|
||||||
|
assert len(data["data"]) == 1
|
||||||
|
assert data["data"][0]["role"] == "user"
|
||||||
|
|
||||||
|
def test_get_history_empty_returns_empty_list(self):
|
||||||
|
"""Should return empty list if no history."""
|
||||||
|
# Arrange
|
||||||
|
client = TestClient(app)
|
||||||
|
notebook_id = str(uuid4())
|
||||||
|
|
||||||
|
with patch("notebooklm_agent.api.routes.chat.ChatService") as mock_service_class:
|
||||||
|
mock_service = AsyncMock()
|
||||||
|
mock_service.get_history.return_value = []
|
||||||
|
mock_service_class.return_value = mock_service
|
||||||
|
|
||||||
|
# Act
|
||||||
|
response = client.get(f"/api/v1/notebooks/{notebook_id}/chat/history")
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["success"] is True
|
||||||
|
assert data["data"] == []
|
||||||
|
|
||||||
|
def test_get_history_invalid_notebook_id_returns_400(self):
|
||||||
|
"""Should return 400 for invalid notebook ID."""
|
||||||
|
# Arrange
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
response = client.get("/api/v1/notebooks/invalid-id/chat/history")
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert response.status_code in [400, 422]
|
||||||
|
|
||||||
|
def test_get_history_notebook_not_found_returns_404(self):
|
||||||
|
"""Should return 404 when notebook not found."""
|
||||||
|
# Arrange
|
||||||
|
client = TestClient(app)
|
||||||
|
notebook_id = str(uuid4())
|
||||||
|
|
||||||
|
with patch("notebooklm_agent.api.routes.chat.ChatService") as mock_service_class:
|
||||||
|
mock_service = AsyncMock()
|
||||||
|
from notebooklm_agent.core.exceptions import NotFoundError
|
||||||
|
|
||||||
|
mock_service.get_history.side_effect = NotFoundError("Notebook", notebook_id)
|
||||||
|
mock_service_class.return_value = mock_service
|
||||||
|
|
||||||
|
# Act
|
||||||
|
response = client.get(f"/api/v1/notebooks/{notebook_id}/chat/history")
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert response.status_code == 404
|
||||||
349
tests/unit/test_services/test_chat_service.py
Normal file
349
tests/unit/test_services/test_chat_service.py
Normal file
@@ -0,0 +1,349 @@
|
|||||||
|
"""Unit tests for ChatService.
|
||||||
|
|
||||||
|
TDD Cycle: RED → GREEN → REFACTOR
|
||||||
|
"""
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
from uuid import UUID, uuid4
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from notebooklm_agent.core.exceptions import (
|
||||||
|
NotebookLMError,
|
||||||
|
NotFoundError,
|
||||||
|
ValidationError,
|
||||||
|
)
|
||||||
|
from notebooklm_agent.services.chat_service import ChatService
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestChatServiceInit:
|
||||||
|
"""Test suite for ChatService initialization."""
|
||||||
|
|
||||||
|
async def test_get_client_returns_existing_client(self):
|
||||||
|
"""Should return existing client if already initialized."""
|
||||||
|
# Arrange
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
service = ChatService(client=mock_client)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
client = await service._get_client()
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert client == mock_client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestChatServiceValidateMessage:
|
||||||
|
"""Test suite for ChatService._validate_message()."""
|
||||||
|
|
||||||
|
def test_validate_valid_message(self):
|
||||||
|
"""Should accept valid message."""
|
||||||
|
# Arrange
|
||||||
|
service = ChatService()
|
||||||
|
message = "What are the key points?"
|
||||||
|
|
||||||
|
# Act
|
||||||
|
result = service._validate_message(message)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert result == message
|
||||||
|
|
||||||
|
def test_validate_empty_message_raises_validation_error(self):
|
||||||
|
"""Should raise ValidationError for empty message."""
|
||||||
|
# Arrange
|
||||||
|
service = ChatService()
|
||||||
|
|
||||||
|
# Act & Assert
|
||||||
|
with pytest.raises(ValidationError) as exc_info:
|
||||||
|
service._validate_message("")
|
||||||
|
|
||||||
|
assert "Message cannot be empty" in str(exc_info.value)
|
||||||
|
|
||||||
|
def test_validate_whitespace_message_raises_validation_error(self):
|
||||||
|
"""Should raise ValidationError for whitespace-only message."""
|
||||||
|
# Arrange
|
||||||
|
service = ChatService()
|
||||||
|
|
||||||
|
# Act & Assert
|
||||||
|
with pytest.raises(ValidationError) as exc_info:
|
||||||
|
service._validate_message(" ")
|
||||||
|
|
||||||
|
assert "Message cannot be empty" in str(exc_info.value)
|
||||||
|
|
||||||
|
def test_validate_long_message_raises_validation_error(self):
|
||||||
|
"""Should raise ValidationError for message > 2000 chars."""
|
||||||
|
# Arrange
|
||||||
|
service = ChatService()
|
||||||
|
long_message = "A" * 2001
|
||||||
|
|
||||||
|
# Act & Assert
|
||||||
|
with pytest.raises(ValidationError) as exc_info:
|
||||||
|
service._validate_message(long_message)
|
||||||
|
|
||||||
|
assert "at most 2000 characters" in str(exc_info.value)
|
||||||
|
|
||||||
|
def test_validate_max_length_message_succeeds(self):
|
||||||
|
"""Should accept message with exactly 2000 characters."""
|
||||||
|
# Arrange
|
||||||
|
service = ChatService()
|
||||||
|
message = "A" * 2000
|
||||||
|
|
||||||
|
# Act
|
||||||
|
result = service._validate_message(message)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert result == message
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestChatServiceSendMessage:
|
||||||
|
"""Test suite for ChatService.send_message() method."""
|
||||||
|
|
||||||
|
async def test_send_message_returns_response(self):
|
||||||
|
"""Should send message and return ChatResponse."""
|
||||||
|
# Arrange
|
||||||
|
notebook_id = uuid4()
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_notebook = AsyncMock()
|
||||||
|
mock_chat = AsyncMock()
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.text = "This is the answer based on your sources."
|
||||||
|
mock_response.citations = []
|
||||||
|
|
||||||
|
mock_chat.send_message.return_value = mock_response
|
||||||
|
mock_notebook.chat = mock_chat
|
||||||
|
mock_client.notebooks.get.return_value = mock_notebook
|
||||||
|
|
||||||
|
service = ChatService(client=mock_client)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
result = await service.send_message(notebook_id, "What are the key points?")
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert result.message == "This is the answer based on your sources."
|
||||||
|
assert isinstance(result.sources, list)
|
||||||
|
mock_chat.send_message.assert_called_once_with("What are the key points?")
|
||||||
|
|
||||||
|
async def test_send_message_with_references(self):
|
||||||
|
"""Should include source references when available."""
|
||||||
|
# Arrange
|
||||||
|
notebook_id = uuid4()
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_notebook = AsyncMock()
|
||||||
|
mock_chat = AsyncMock()
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.text = "Answer with references."
|
||||||
|
|
||||||
|
mock_citation = MagicMock()
|
||||||
|
mock_citation.source_id = str(uuid4())
|
||||||
|
mock_citation.title = "Source 1"
|
||||||
|
mock_citation.snippet = "Relevant text"
|
||||||
|
mock_response.citations = [mock_citation]
|
||||||
|
|
||||||
|
mock_chat.send_message.return_value = mock_response
|
||||||
|
mock_notebook.chat = mock_chat
|
||||||
|
mock_client.notebooks.get.return_value = mock_notebook
|
||||||
|
|
||||||
|
service = ChatService(client=mock_client)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
result = await service.send_message(notebook_id, "Question?", include_references=True)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert len(result.sources) == 1
|
||||||
|
assert result.sources[0].title == "Source 1"
|
||||||
|
|
||||||
|
async def test_send_message_without_references(self):
|
||||||
|
"""Should not include source references when disabled."""
|
||||||
|
# Arrange
|
||||||
|
notebook_id = uuid4()
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_notebook = AsyncMock()
|
||||||
|
mock_chat = AsyncMock()
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.text = "Answer without references."
|
||||||
|
mock_response.citations = []
|
||||||
|
|
||||||
|
mock_chat.send_message.return_value = mock_response
|
||||||
|
mock_notebook.chat = mock_chat
|
||||||
|
mock_client.notebooks.get.return_value = mock_notebook
|
||||||
|
|
||||||
|
service = ChatService(client=mock_client)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
result = await service.send_message(notebook_id, "Question?", include_references=False)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
# Even with include_references=False, we still get empty sources list
|
||||||
|
assert result.sources == []
|
||||||
|
|
||||||
|
async def test_send_message_empty_message_raises_validation_error(self):
|
||||||
|
"""Should raise ValidationError for empty message."""
|
||||||
|
# Arrange
|
||||||
|
notebook_id = uuid4()
|
||||||
|
service = ChatService()
|
||||||
|
|
||||||
|
# Act & Assert
|
||||||
|
with pytest.raises(ValidationError) as exc_info:
|
||||||
|
await service.send_message(notebook_id, "")
|
||||||
|
|
||||||
|
assert "Message cannot be empty" in str(exc_info.value)
|
||||||
|
|
||||||
|
async def test_send_message_notebook_not_found_raises_not_found(self):
|
||||||
|
"""Should raise NotFoundError if notebook not found."""
|
||||||
|
# Arrange
|
||||||
|
notebook_id = uuid4()
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.notebooks.get.side_effect = Exception("notebook not found")
|
||||||
|
|
||||||
|
service = ChatService(client=mock_client)
|
||||||
|
|
||||||
|
# Act & Assert
|
||||||
|
with pytest.raises(NotFoundError) as exc_info:
|
||||||
|
await service.send_message(notebook_id, "Question?")
|
||||||
|
|
||||||
|
assert str(notebook_id) in str(exc_info.value)
|
||||||
|
|
||||||
|
async def test_send_message_api_error_raises_notebooklm_error(self):
|
||||||
|
"""Should raise NotebookLMError on API error."""
|
||||||
|
# Arrange
|
||||||
|
notebook_id = uuid4()
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.notebooks.get.side_effect = Exception("connection timeout")
|
||||||
|
|
||||||
|
service = ChatService(client=mock_client)
|
||||||
|
|
||||||
|
# Act & Assert
|
||||||
|
with pytest.raises(NotebookLMError) as exc_info:
|
||||||
|
await service.send_message(notebook_id, "Question?")
|
||||||
|
|
||||||
|
assert "Failed to send message" in str(exc_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
class TestChatServiceGetHistory:
|
||||||
|
"""Test suite for ChatService.get_history() method."""
|
||||||
|
|
||||||
|
async def test_get_history_returns_messages(self):
|
||||||
|
"""Should return list of chat messages."""
|
||||||
|
# Arrange
|
||||||
|
notebook_id = uuid4()
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_notebook = AsyncMock()
|
||||||
|
mock_chat = AsyncMock()
|
||||||
|
|
||||||
|
mock_message = MagicMock()
|
||||||
|
mock_message.id = str(uuid4())
|
||||||
|
mock_message.role = "user"
|
||||||
|
mock_message.content = "What are the key points?"
|
||||||
|
mock_message.timestamp = datetime.utcnow()
|
||||||
|
mock_message.citations = None
|
||||||
|
|
||||||
|
mock_chat.get_history.return_value = [mock_message]
|
||||||
|
mock_notebook.chat = mock_chat
|
||||||
|
mock_client.notebooks.get.return_value = mock_notebook
|
||||||
|
|
||||||
|
service = ChatService(client=mock_client)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
result = await service.get_history(notebook_id)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert len(result) == 1
|
||||||
|
assert result[0].role == "user"
|
||||||
|
assert result[0].content == "What are the key points?"
|
||||||
|
|
||||||
|
async def test_get_history_with_assistant_message(self):
|
||||||
|
"""Should handle assistant messages with sources."""
|
||||||
|
# Arrange
|
||||||
|
notebook_id = uuid4()
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_notebook = AsyncMock()
|
||||||
|
mock_chat = AsyncMock()
|
||||||
|
|
||||||
|
user_msg = MagicMock()
|
||||||
|
user_msg.id = str(uuid4())
|
||||||
|
user_msg.role = "user"
|
||||||
|
user_msg.content = "Question?"
|
||||||
|
user_msg.timestamp = datetime.utcnow()
|
||||||
|
user_msg.citations = None
|
||||||
|
|
||||||
|
assistant_msg = MagicMock()
|
||||||
|
assistant_msg.id = str(uuid4())
|
||||||
|
assistant_msg.role = "assistant"
|
||||||
|
assistant_msg.content = "Answer with sources."
|
||||||
|
assistant_msg.timestamp = datetime.utcnow()
|
||||||
|
|
||||||
|
mock_citation = MagicMock()
|
||||||
|
mock_citation.source_id = str(uuid4())
|
||||||
|
mock_citation.title = "Source 1"
|
||||||
|
mock_citation.snippet = "Relevant text"
|
||||||
|
assistant_msg.citations = [mock_citation]
|
||||||
|
|
||||||
|
mock_chat.get_history.return_value = [user_msg, assistant_msg]
|
||||||
|
mock_notebook.chat = mock_chat
|
||||||
|
mock_client.notebooks.get.return_value = mock_notebook
|
||||||
|
|
||||||
|
service = ChatService(client=mock_client)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
result = await service.get_history(notebook_id)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert len(result) == 2
|
||||||
|
assert result[1].role == "assistant"
|
||||||
|
assert result[1].sources is not None
|
||||||
|
assert len(result[1].sources) == 1
|
||||||
|
|
||||||
|
async def test_get_history_empty_returns_empty_list(self):
|
||||||
|
"""Should return empty list if no history."""
|
||||||
|
# Arrange
|
||||||
|
notebook_id = uuid4()
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_notebook = AsyncMock()
|
||||||
|
mock_chat = AsyncMock()
|
||||||
|
mock_chat.get_history.return_value = []
|
||||||
|
mock_notebook.chat = mock_chat
|
||||||
|
mock_client.notebooks.get.return_value = mock_notebook
|
||||||
|
|
||||||
|
service = ChatService(client=mock_client)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
result = await service.get_history(notebook_id)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert result == []
|
||||||
|
|
||||||
|
async def test_get_history_notebook_not_found_raises_not_found(self):
|
||||||
|
"""Should raise NotFoundError if notebook not found."""
|
||||||
|
# Arrange
|
||||||
|
notebook_id = uuid4()
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.notebooks.get.side_effect = Exception("not found")
|
||||||
|
|
||||||
|
service = ChatService(client=mock_client)
|
||||||
|
|
||||||
|
# Act & Assert
|
||||||
|
with pytest.raises(NotFoundError):
|
||||||
|
await service.get_history(notebook_id)
|
||||||
|
|
||||||
|
async def test_get_history_api_error_raises_notebooklm_error(self):
|
||||||
|
"""Should raise NotebookLMError on API error."""
|
||||||
|
# Arrange
|
||||||
|
notebook_id = uuid4()
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_notebook = AsyncMock()
|
||||||
|
mock_chat = AsyncMock()
|
||||||
|
mock_chat.get_history.side_effect = Exception("API error")
|
||||||
|
mock_notebook.chat = mock_chat
|
||||||
|
mock_client.notebooks.get.return_value = mock_notebook
|
||||||
|
|
||||||
|
service = ChatService(client=mock_client)
|
||||||
|
|
||||||
|
# Act & Assert
|
||||||
|
with pytest.raises(NotebookLMError) as exc_info:
|
||||||
|
await service.get_history(notebook_id)
|
||||||
|
|
||||||
|
assert "Failed to get chat history" in str(exc_info.value)
|
||||||
Reference in New Issue
Block a user