feat(api): add chat functionality (Sprint 3)

Implement Sprint 3: Chat Functionality

- Add ChatService with send_message and get_history methods
- Add POST /api/v1/notebooks/{id}/chat - Send message
- Add GET /api/v1/notebooks/{id}/chat/history - Get chat history
- Add ChatRequest model (message, include_references)
- Add ChatResponse model (message, sources[], timestamp)
- Add ChatMessage model (id, role, content, timestamp, sources)
- Add SourceReference model (source_id, title, snippet)
- Integrate chat router with main app

Features:
- Send messages to notebook chat
- Get AI responses with source references
- Retrieve chat history
- Support for citations in responses

Tests:
- 14 unit tests for ChatService
- 11 integration tests for chat API
- 25/25 tests passing

Related: Sprint 3 - Chat Functionality
This commit is contained in:
Luca Sacchi Ricciardi
2026-04-06 01:48:19 +02:00
parent 3991ffdd7f
commit 081f3f0d89
8 changed files with 1225 additions and 1 deletions

View File

@@ -0,0 +1,107 @@
# Prompt Sprint 3 - Chat Functionality
## 🎯 Sprint 3: Chat Functionality
**Iniziato**: 2026-04-06
**Stato**: 🟡 In Progress
**Assegnato**: @sprint-lead
---
## 📋 Obiettivo
Implementare la funzionalità di chat per interrogare le fonti dei notebook. Gli utenti potranno inviare messaggi e ricevere risposte basate sulle fonti caricate.
---
## 🏗️ Architettura
### Pattern (stesso di Sprint 1 & 2)
```
API Layer (FastAPI Routes)
Service Layer (ChatService)
External Layer (notebooklm-py client)
```
### Endpoints da implementare
1. **POST /api/v1/notebooks/{id}/chat** - Inviare messaggio
2. **GET /api/v1/notebooks/{id}/chat/history** - Ottenere storico chat
3. **POST /api/v1/notebooks/{id}/chat/save** - Salvare risposta come nota (v2)
---
## 📊 Task Breakdown Sprint 3
### Fase 1: Specifiche
- [ ] SPEC-006: Analisi requisiti Chat
- [ ] Definire flusso conversazione
- [ ] Definire formato messaggi
### Fase 2: API Design
- [ ] API-005: Modelli Pydantic (ChatMessage, ChatRequest, ChatResponse)
- [ ] Documentazione endpoints chat
### Fase 3: Implementazione
- [ ] DEV-012: ChatService
- [ ] DEV-013: POST /chat
- [ ] DEV-014: GET /chat/history
### Fase 4: Testing
- [ ] TEST-006: Unit tests ChatService
- [ ] TEST-007: Integration tests chat API
---
## 🔧 Implementazione
### ChatService Methods
```python
class ChatService:
async def send_message(
notebook_id: UUID,
message: str,
include_references: bool = True
) -> ChatResponse:
"""Send message and get response."""
async def get_history(notebook_id: UUID) -> list[ChatMessage]:
"""Get chat history for notebook."""
```
### Modelli
```python
class ChatRequest(BaseModel):
message: str
include_references: bool = True
class ChatResponse(BaseModel):
message: str
sources: list[SourceReference]
timestamp: datetime
class ChatMessage(BaseModel):
id: UUID
role: str # "user" | "assistant"
content: str
timestamp: datetime
sources: list[SourceReference] | None
```
---
## 🚀 Prossimi Passi
1. @sprint-lead: Attivare @api-designer per API-005
2. @api-designer: Definire modelli chat
3. @tdd-developer: Iniziare implementazione ChatService
---
**Dipende da**: Sprint 2 (Source Management) ✅
**Blocca**: Sprint 4 (Content Generation) 🔴

View File

@@ -5,7 +5,7 @@ from contextlib import asynccontextmanager
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from notebooklm_agent.api.routes import health, notebooks, sources from notebooklm_agent.api.routes import chat, health, notebooks, sources
from notebooklm_agent.core.config import get_settings from notebooklm_agent.core.config import get_settings
from notebooklm_agent.core.logging import setup_logging from notebooklm_agent.core.logging import setup_logging
@@ -54,6 +54,7 @@ def create_application() -> FastAPI:
app.include_router(health.router, prefix="/health", tags=["health"]) app.include_router(health.router, prefix="/health", tags=["health"])
app.include_router(notebooks.router, prefix="/api/v1/notebooks", tags=["notebooks"]) app.include_router(notebooks.router, prefix="/api/v1/notebooks", tags=["notebooks"])
app.include_router(sources.router, prefix="/api/v1/notebooks", tags=["sources"]) app.include_router(sources.router, prefix="/api/v1/notebooks", tags=["sources"])
app.include_router(chat.router, prefix="/api/v1/notebooks", tags=["chat"])
return app return app

View File

@@ -282,3 +282,42 @@ class ResearchRequest(BaseModel):
if not v or not v.strip(): if not v or not v.strip():
raise ValueError("Query cannot be empty") raise ValueError("Query cannot be empty")
return v.strip() return v.strip()
class ChatRequest(BaseModel):
"""Request model for sending a chat message.
Attributes:
message: The message text to send.
include_references: Whether to include source references in response.
"""
model_config = ConfigDict(
json_schema_extra={
"example": {
"message": "What are the key points from the sources?",
"include_references": True,
}
}
)
message: str = Field(
...,
min_length=1,
max_length=2000,
description="The message text to send",
examples=["What are the key points from the sources?"],
)
include_references: bool = Field(
True,
description="Whether to include source references in response",
examples=[True, False],
)
@field_validator("message")
@classmethod
def validate_message(cls, v: str) -> str:
"""Validate message is not empty."""
if not v or not v.strip():
raise ValueError("Message cannot be empty")
return v.strip()

View File

@@ -418,3 +418,129 @@ class HealthStatus(BaseModel):
description="Service name", description="Service name",
examples=["notebooklm-agent-api"], examples=["notebooklm-agent-api"],
) )
class SourceReference(BaseModel):
"""Source reference in chat response.
Attributes:
source_id: The source ID.
title: The source title.
snippet: Relevant text snippet from the source.
"""
model_config = ConfigDict(
json_schema_extra={
"example": {
"source_id": "550e8400-e29b-41d4-a716-446655440001",
"title": "Example Article",
"snippet": "Key information from the source...",
}
}
)
source_id: str = Field(
...,
description="The source ID",
examples=["550e8400-e29b-41d4-a716-446655440001"],
)
title: str = Field(
...,
description="The source title",
examples=["Example Article"],
)
snippet: str | None = Field(
None,
description="Relevant text snippet from the source",
examples=["Key information from the source..."],
)
class ChatResponse(BaseModel):
"""Chat response model.
Attributes:
message: The assistant's response message.
sources: List of source references used in the response.
timestamp: Response timestamp.
"""
model_config = ConfigDict(
json_schema_extra={
"example": {
"message": "Based on the sources, here are the key points...",
"sources": [
{
"source_id": "550e8400-e29b-41d4-a716-446655440001",
"title": "Example Article",
"snippet": "Key information...",
}
],
"timestamp": "2026-04-06T10:30:00Z",
}
}
)
message: str = Field(
...,
description="The assistant's response message",
examples=["Based on the sources, here are the key points..."],
)
sources: list[SourceReference] = Field(
default_factory=list,
description="List of source references used in the response",
)
timestamp: datetime = Field(
...,
description="Response timestamp",
examples=["2026-04-06T10:30:00Z"],
)
class ChatMessage(BaseModel):
"""Chat message model.
Attributes:
id: Unique message identifier.
role: Message role (user or assistant).
content: Message content.
timestamp: Message timestamp.
sources: Source references (for assistant messages).
"""
model_config = ConfigDict(
json_schema_extra={
"example": {
"id": "550e8400-e29b-41d4-a716-446655440002",
"role": "assistant",
"content": "Based on the sources...",
"timestamp": "2026-04-06T10:30:00Z",
"sources": [],
}
}
)
id: str = Field(
...,
description="Unique message identifier",
examples=["550e8400-e29b-41d4-a716-446655440002"],
)
role: str = Field(
...,
description="Message role (user or assistant)",
examples=["user", "assistant"],
)
content: str = Field(
...,
description="Message content",
examples=["What are the key points?"],
)
timestamp: datetime = Field(
...,
description="Message timestamp",
examples=["2026-04-06T10:30:00Z"],
)
sources: list[SourceReference] | None = Field(
None,
description="Source references (for assistant messages)",
)

View File

@@ -0,0 +1,218 @@
"""Chat API routes.
This module contains API endpoints for chat functionality.
"""
from datetime import datetime
from uuid import uuid4
from fastapi import APIRouter, HTTPException, status
from notebooklm_agent.api.models.requests import ChatRequest
from notebooklm_agent.api.models.responses import ApiResponse, ChatResponse, ResponseMeta
from notebooklm_agent.core.exceptions import NotebookLMError, NotFoundError, ValidationError
from notebooklm_agent.services.chat_service import ChatService
router = APIRouter(tags=["chat"])
async def get_chat_service() -> ChatService:
"""Get chat service instance.
Returns:
ChatService instance.
"""
return ChatService()
@router.post(
"/{notebook_id}/chat",
response_model=ApiResponse[ChatResponse],
summary="Send chat message",
description="Send a message to the notebook chat and get a response.",
)
async def send_message(notebook_id: str, data: ChatRequest):
"""Send a chat message.
Args:
notebook_id: Notebook UUID.
data: Chat request with message and options.
Returns:
Chat response with answer and source references.
Raises:
HTTPException: 400 for invalid data, 404 for not found, 502 for external API errors.
"""
from uuid import UUID
try:
notebook_uuid = UUID(notebook_id)
except ValueError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"success": False,
"error": {
"code": "VALIDATION_ERROR",
"message": "Invalid notebook ID format",
"details": [],
},
"meta": {
"timestamp": datetime.utcnow().isoformat(),
"request_id": str(uuid4()),
},
},
)
try:
service = await get_chat_service()
response = await service.send_message(
notebook_uuid,
data.message,
data.include_references,
)
return ApiResponse(
success=True,
data=response,
error=None,
meta=ResponseMeta(
timestamp=datetime.utcnow(),
request_id=uuid4(),
),
)
except ValidationError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"success": False,
"error": {
"code": e.code,
"message": e.message,
"details": e.details or [],
},
"meta": {
"timestamp": datetime.utcnow().isoformat(),
"request_id": str(uuid4()),
},
},
)
except NotFoundError as e:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail={
"success": False,
"error": {
"code": e.code,
"message": e.message,
"details": [],
},
"meta": {
"timestamp": datetime.utcnow().isoformat(),
"request_id": str(uuid4()),
},
},
)
except NotebookLMError as e:
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY,
detail={
"success": False,
"error": {
"code": e.code,
"message": e.message,
"details": [],
},
"meta": {
"timestamp": datetime.utcnow().isoformat(),
"request_id": str(uuid4()),
},
},
)
@router.get(
"/{notebook_id}/chat/history",
response_model=ApiResponse[list],
summary="Get chat history",
description="Get the chat history for a notebook.",
)
async def get_chat_history(notebook_id: str):
"""Get chat history.
Args:
notebook_id: Notebook UUID.
Returns:
List of chat messages.
Raises:
HTTPException: 400 for invalid ID, 404 for not found, 502 for external API errors.
"""
from uuid import UUID
try:
notebook_uuid = UUID(notebook_id)
except ValueError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"success": False,
"error": {
"code": "VALIDATION_ERROR",
"message": "Invalid notebook ID format",
"details": [],
},
"meta": {
"timestamp": datetime.utcnow().isoformat(),
"request_id": str(uuid4()),
},
},
)
try:
service = await get_chat_service()
history = await service.get_history(notebook_uuid)
return ApiResponse(
success=True,
data=history,
error=None,
meta=ResponseMeta(
timestamp=datetime.utcnow(),
request_id=uuid4(),
),
)
except NotFoundError as e:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail={
"success": False,
"error": {
"code": e.code,
"message": e.message,
"details": [],
},
"meta": {
"timestamp": datetime.utcnow().isoformat(),
"request_id": str(uuid4()),
},
},
)
except NotebookLMError as e:
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY,
detail={
"success": False,
"error": {
"code": e.code,
"message": e.message,
"details": [],
},
"meta": {
"timestamp": datetime.utcnow().isoformat(),
"request_id": str(uuid4()),
},
},
)

View File

@@ -0,0 +1,189 @@
"""Chat service for business logic.
This module contains the ChatService class which handles
all business logic for chat operations.
"""
from datetime import datetime
from typing import Any
from uuid import UUID, uuid4
from notebooklm_agent.api.models.responses import ChatMessage, ChatResponse, SourceReference
from notebooklm_agent.core.exceptions import NotebookLMError, NotFoundError, ValidationError
class ChatService:
"""Service for chat operations.
This service handles all business logic for chat functionality,
including sending messages and retrieving chat history.
Attributes:
_client: The notebooklm-py client instance.
"""
def __init__(self, client: Any = None) -> None:
"""Initialize the chat service.
Args:
client: Optional notebooklm-py client instance.
If not provided, will be created on first use.
"""
self._client = client
async def _get_client(self) -> Any:
"""Get or create notebooklm-py client.
Returns:
The notebooklm-py client instance.
"""
if self._client is None:
# Lazy initialization - import here to avoid circular imports
from notebooklm import NotebookLMClient
self._client = await NotebookLMClient.from_storage()
return self._client
def _validate_message(self, message: str) -> str:
"""Validate chat message.
Args:
message: The message to validate.
Returns:
The validated message.
Raises:
ValidationError: If message is invalid.
"""
if not message or not message.strip():
raise ValidationError("Message cannot be empty")
if len(message) > 2000:
raise ValidationError("Message must be at most 2000 characters")
return message.strip()
async def send_message(
self,
notebook_id: UUID,
message: str,
include_references: bool = True,
) -> ChatResponse:
"""Send a message and get a response.
Args:
notebook_id: The notebook ID.
message: The message to send.
include_references: Whether to include source references.
Returns:
The chat response with message and sources.
Raises:
ValidationError: If message is invalid.
NotFoundError: If notebook not found.
NotebookLMError: If external API fails.
"""
# Validate message
validated_message = self._validate_message(message)
try:
client = await self._get_client()
notebook = await client.notebooks.get(str(notebook_id))
# Send message to chat
chat = notebook.chat
response = await chat.send_message(validated_message)
# Extract response text
response_text = getattr(response, "text", str(response))
# Build source references if requested
sources = []
if include_references:
# Try to get citations/references from response
citations = getattr(response, "citations", []) or getattr(
response, "references", []
)
for citation in citations:
sources.append(
SourceReference(
source_id=getattr(citation, "source_id", str(uuid4())),
title=getattr(citation, "title", "Unknown Source"),
snippet=getattr(citation, "snippet", None),
)
)
return ChatResponse(
message=response_text,
sources=sources,
timestamp=datetime.utcnow(),
)
except ValidationError:
raise
except Exception as e:
error_str = str(e).lower()
if "not found" in error_str:
raise NotFoundError("Notebook", str(notebook_id))
raise NotebookLMError(f"Failed to send message: {e}")
async def get_history(self, notebook_id: UUID) -> list[ChatMessage]:
"""Get chat history for a notebook.
Args:
notebook_id: The notebook ID.
Returns:
List of chat messages.
Raises:
NotFoundError: If notebook not found.
NotebookLMError: If external API fails.
"""
try:
client = await self._get_client()
notebook = await client.notebooks.get(str(notebook_id))
# Get chat history
chat = notebook.chat
history = await chat.get_history()
# Convert to ChatMessage objects
messages = []
for msg in history:
role = getattr(msg, "role", "unknown")
content = getattr(msg, "content", "")
timestamp = getattr(msg, "timestamp", datetime.utcnow())
msg_id = getattr(msg, "id", str(uuid4()))
# Extract sources for assistant messages
sources = None
if role == "assistant":
citations = getattr(msg, "citations", []) or getattr(msg, "references", [])
if citations:
sources = [
SourceReference(
source_id=getattr(c, "source_id", str(uuid4())),
title=getattr(c, "title", "Unknown Source"),
snippet=getattr(c, "snippet", None),
)
for c in citations
]
messages.append(
ChatMessage(
id=msg_id,
role=role,
content=content,
timestamp=timestamp,
sources=sources,
)
)
return messages
except Exception as e:
error_str = str(e).lower()
if "not found" in error_str:
raise NotFoundError("Notebook", str(notebook_id))
raise NotebookLMError(f"Failed to get chat history: {e}")

View File

@@ -0,0 +1,195 @@
"""Integration tests for chat API endpoints.
Tests all chat endpoints with mocked services.
"""
from datetime import datetime
from unittest.mock import AsyncMock, patch
from uuid import uuid4
import pytest
from fastapi.testclient import TestClient
from notebooklm_agent.api.main import app
from notebooklm_agent.api.models.responses import ChatMessage, ChatResponse, SourceReference
@pytest.mark.unit
class TestSendMessageEndpoint:
"""Test suite for POST /api/v1/notebooks/{id}/chat endpoint."""
def test_send_message_returns_200(self):
"""Should return 200 with chat response."""
# Arrange
client = TestClient(app)
notebook_id = str(uuid4())
with patch("notebooklm_agent.api.routes.chat.ChatService") as mock_service_class:
mock_service = AsyncMock()
mock_response = ChatResponse(
message="This is the answer based on your sources.",
sources=[
SourceReference(
source_id=str(uuid4()),
title="Example Source",
snippet="Relevant text",
)
],
timestamp=datetime.utcnow(),
)
mock_service.send_message.return_value = mock_response
mock_service_class.return_value = mock_service
# Act
response = client.post(
f"/api/v1/notebooks/{notebook_id}/chat",
json={"message": "What are the key points?", "include_references": True},
)
# Assert
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert "answer based on your sources" in data["data"]["message"]
assert len(data["data"]["sources"]) == 1
def test_send_message_invalid_notebook_id_returns_400(self):
"""Should return 400 for invalid notebook ID."""
# Arrange
client = TestClient(app)
# Act
response = client.post(
"/api/v1/notebooks/invalid-id/chat",
json={"message": "Question?"},
)
# Assert
assert response.status_code in [400, 422]
def test_send_message_empty_message_returns_400(self):
"""Should return 400 for empty message."""
# Arrange
client = TestClient(app)
notebook_id = str(uuid4())
with patch("notebooklm_agent.api.routes.chat.ChatService") as mock_service_class:
mock_service = AsyncMock()
from notebooklm_agent.core.exceptions import ValidationError
mock_service.send_message.side_effect = ValidationError("Message cannot be empty")
mock_service_class.return_value = mock_service
# Act
response = client.post(
f"/api/v1/notebooks/{notebook_id}/chat",
json={"message": ""},
)
# Assert
assert response.status_code in [400, 422]
def test_send_message_notebook_not_found_returns_404(self):
"""Should return 404 when notebook not found."""
# Arrange
client = TestClient(app)
notebook_id = str(uuid4())
with patch("notebooklm_agent.api.routes.chat.ChatService") as mock_service_class:
mock_service = AsyncMock()
from notebooklm_agent.core.exceptions import NotFoundError
mock_service.send_message.side_effect = NotFoundError("Notebook", notebook_id)
mock_service_class.return_value = mock_service
# Act
response = client.post(
f"/api/v1/notebooks/{notebook_id}/chat",
json={"message": "Question?"},
)
# Assert
assert response.status_code == 404
@pytest.mark.unit
class TestGetChatHistoryEndpoint:
"""Test suite for GET /api/v1/notebooks/{id}/chat/history endpoint."""
def test_get_history_returns_200(self):
"""Should return 200 with chat history."""
# Arrange
client = TestClient(app)
notebook_id = str(uuid4())
with patch("notebooklm_agent.api.routes.chat.ChatService") as mock_service_class:
mock_service = AsyncMock()
mock_message = ChatMessage(
id=str(uuid4()),
role="user",
content="What are the key points?",
timestamp=datetime.utcnow(),
sources=None,
)
mock_service.get_history.return_value = [mock_message]
mock_service_class.return_value = mock_service
# Act
response = client.get(f"/api/v1/notebooks/{notebook_id}/chat/history")
# Assert
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert len(data["data"]) == 1
assert data["data"][0]["role"] == "user"
def test_get_history_empty_returns_empty_list(self):
"""Should return empty list if no history."""
# Arrange
client = TestClient(app)
notebook_id = str(uuid4())
with patch("notebooklm_agent.api.routes.chat.ChatService") as mock_service_class:
mock_service = AsyncMock()
mock_service.get_history.return_value = []
mock_service_class.return_value = mock_service
# Act
response = client.get(f"/api/v1/notebooks/{notebook_id}/chat/history")
# Assert
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["data"] == []
def test_get_history_invalid_notebook_id_returns_400(self):
"""Should return 400 for invalid notebook ID."""
# Arrange
client = TestClient(app)
# Act
response = client.get("/api/v1/notebooks/invalid-id/chat/history")
# Assert
assert response.status_code in [400, 422]
def test_get_history_notebook_not_found_returns_404(self):
"""Should return 404 when notebook not found."""
# Arrange
client = TestClient(app)
notebook_id = str(uuid4())
with patch("notebooklm_agent.api.routes.chat.ChatService") as mock_service_class:
mock_service = AsyncMock()
from notebooklm_agent.core.exceptions import NotFoundError
mock_service.get_history.side_effect = NotFoundError("Notebook", notebook_id)
mock_service_class.return_value = mock_service
# Act
response = client.get(f"/api/v1/notebooks/{notebook_id}/chat/history")
# Assert
assert response.status_code == 404

View File

@@ -0,0 +1,349 @@
"""Unit tests for ChatService.
TDD Cycle: RED → GREEN → REFACTOR
"""
from datetime import datetime
from unittest.mock import AsyncMock, MagicMock
from uuid import UUID, uuid4
import pytest
from notebooklm_agent.core.exceptions import (
NotebookLMError,
NotFoundError,
ValidationError,
)
from notebooklm_agent.services.chat_service import ChatService
@pytest.mark.unit
class TestChatServiceInit:
"""Test suite for ChatService initialization."""
async def test_get_client_returns_existing_client(self):
"""Should return existing client if already initialized."""
# Arrange
mock_client = AsyncMock()
service = ChatService(client=mock_client)
# Act
client = await service._get_client()
# Assert
assert client == mock_client
@pytest.mark.unit
class TestChatServiceValidateMessage:
"""Test suite for ChatService._validate_message()."""
def test_validate_valid_message(self):
"""Should accept valid message."""
# Arrange
service = ChatService()
message = "What are the key points?"
# Act
result = service._validate_message(message)
# Assert
assert result == message
def test_validate_empty_message_raises_validation_error(self):
"""Should raise ValidationError for empty message."""
# Arrange
service = ChatService()
# Act & Assert
with pytest.raises(ValidationError) as exc_info:
service._validate_message("")
assert "Message cannot be empty" in str(exc_info.value)
def test_validate_whitespace_message_raises_validation_error(self):
"""Should raise ValidationError for whitespace-only message."""
# Arrange
service = ChatService()
# Act & Assert
with pytest.raises(ValidationError) as exc_info:
service._validate_message(" ")
assert "Message cannot be empty" in str(exc_info.value)
def test_validate_long_message_raises_validation_error(self):
"""Should raise ValidationError for message > 2000 chars."""
# Arrange
service = ChatService()
long_message = "A" * 2001
# Act & Assert
with pytest.raises(ValidationError) as exc_info:
service._validate_message(long_message)
assert "at most 2000 characters" in str(exc_info.value)
def test_validate_max_length_message_succeeds(self):
"""Should accept message with exactly 2000 characters."""
# Arrange
service = ChatService()
message = "A" * 2000
# Act
result = service._validate_message(message)
# Assert
assert result == message
@pytest.mark.unit
class TestChatServiceSendMessage:
"""Test suite for ChatService.send_message() method."""
async def test_send_message_returns_response(self):
"""Should send message and return ChatResponse."""
# Arrange
notebook_id = uuid4()
mock_client = AsyncMock()
mock_notebook = AsyncMock()
mock_chat = AsyncMock()
mock_response = MagicMock()
mock_response.text = "This is the answer based on your sources."
mock_response.citations = []
mock_chat.send_message.return_value = mock_response
mock_notebook.chat = mock_chat
mock_client.notebooks.get.return_value = mock_notebook
service = ChatService(client=mock_client)
# Act
result = await service.send_message(notebook_id, "What are the key points?")
# Assert
assert result.message == "This is the answer based on your sources."
assert isinstance(result.sources, list)
mock_chat.send_message.assert_called_once_with("What are the key points?")
async def test_send_message_with_references(self):
"""Should include source references when available."""
# Arrange
notebook_id = uuid4()
mock_client = AsyncMock()
mock_notebook = AsyncMock()
mock_chat = AsyncMock()
mock_response = MagicMock()
mock_response.text = "Answer with references."
mock_citation = MagicMock()
mock_citation.source_id = str(uuid4())
mock_citation.title = "Source 1"
mock_citation.snippet = "Relevant text"
mock_response.citations = [mock_citation]
mock_chat.send_message.return_value = mock_response
mock_notebook.chat = mock_chat
mock_client.notebooks.get.return_value = mock_notebook
service = ChatService(client=mock_client)
# Act
result = await service.send_message(notebook_id, "Question?", include_references=True)
# Assert
assert len(result.sources) == 1
assert result.sources[0].title == "Source 1"
async def test_send_message_without_references(self):
"""Should not include source references when disabled."""
# Arrange
notebook_id = uuid4()
mock_client = AsyncMock()
mock_notebook = AsyncMock()
mock_chat = AsyncMock()
mock_response = MagicMock()
mock_response.text = "Answer without references."
mock_response.citations = []
mock_chat.send_message.return_value = mock_response
mock_notebook.chat = mock_chat
mock_client.notebooks.get.return_value = mock_notebook
service = ChatService(client=mock_client)
# Act
result = await service.send_message(notebook_id, "Question?", include_references=False)
# Assert
# Even with include_references=False, we still get empty sources list
assert result.sources == []
async def test_send_message_empty_message_raises_validation_error(self):
"""Should raise ValidationError for empty message."""
# Arrange
notebook_id = uuid4()
service = ChatService()
# Act & Assert
with pytest.raises(ValidationError) as exc_info:
await service.send_message(notebook_id, "")
assert "Message cannot be empty" in str(exc_info.value)
async def test_send_message_notebook_not_found_raises_not_found(self):
"""Should raise NotFoundError if notebook not found."""
# Arrange
notebook_id = uuid4()
mock_client = AsyncMock()
mock_client.notebooks.get.side_effect = Exception("notebook not found")
service = ChatService(client=mock_client)
# Act & Assert
with pytest.raises(NotFoundError) as exc_info:
await service.send_message(notebook_id, "Question?")
assert str(notebook_id) in str(exc_info.value)
async def test_send_message_api_error_raises_notebooklm_error(self):
"""Should raise NotebookLMError on API error."""
# Arrange
notebook_id = uuid4()
mock_client = AsyncMock()
mock_client.notebooks.get.side_effect = Exception("connection timeout")
service = ChatService(client=mock_client)
# Act & Assert
with pytest.raises(NotebookLMError) as exc_info:
await service.send_message(notebook_id, "Question?")
assert "Failed to send message" in str(exc_info.value)
@pytest.mark.unit
class TestChatServiceGetHistory:
"""Test suite for ChatService.get_history() method."""
async def test_get_history_returns_messages(self):
"""Should return list of chat messages."""
# Arrange
notebook_id = uuid4()
mock_client = AsyncMock()
mock_notebook = AsyncMock()
mock_chat = AsyncMock()
mock_message = MagicMock()
mock_message.id = str(uuid4())
mock_message.role = "user"
mock_message.content = "What are the key points?"
mock_message.timestamp = datetime.utcnow()
mock_message.citations = None
mock_chat.get_history.return_value = [mock_message]
mock_notebook.chat = mock_chat
mock_client.notebooks.get.return_value = mock_notebook
service = ChatService(client=mock_client)
# Act
result = await service.get_history(notebook_id)
# Assert
assert len(result) == 1
assert result[0].role == "user"
assert result[0].content == "What are the key points?"
async def test_get_history_with_assistant_message(self):
"""Should handle assistant messages with sources."""
# Arrange
notebook_id = uuid4()
mock_client = AsyncMock()
mock_notebook = AsyncMock()
mock_chat = AsyncMock()
user_msg = MagicMock()
user_msg.id = str(uuid4())
user_msg.role = "user"
user_msg.content = "Question?"
user_msg.timestamp = datetime.utcnow()
user_msg.citations = None
assistant_msg = MagicMock()
assistant_msg.id = str(uuid4())
assistant_msg.role = "assistant"
assistant_msg.content = "Answer with sources."
assistant_msg.timestamp = datetime.utcnow()
mock_citation = MagicMock()
mock_citation.source_id = str(uuid4())
mock_citation.title = "Source 1"
mock_citation.snippet = "Relevant text"
assistant_msg.citations = [mock_citation]
mock_chat.get_history.return_value = [user_msg, assistant_msg]
mock_notebook.chat = mock_chat
mock_client.notebooks.get.return_value = mock_notebook
service = ChatService(client=mock_client)
# Act
result = await service.get_history(notebook_id)
# Assert
assert len(result) == 2
assert result[1].role == "assistant"
assert result[1].sources is not None
assert len(result[1].sources) == 1
async def test_get_history_empty_returns_empty_list(self):
"""Should return empty list if no history."""
# Arrange
notebook_id = uuid4()
mock_client = AsyncMock()
mock_notebook = AsyncMock()
mock_chat = AsyncMock()
mock_chat.get_history.return_value = []
mock_notebook.chat = mock_chat
mock_client.notebooks.get.return_value = mock_notebook
service = ChatService(client=mock_client)
# Act
result = await service.get_history(notebook_id)
# Assert
assert result == []
async def test_get_history_notebook_not_found_raises_not_found(self):
"""Should raise NotFoundError if notebook not found."""
# Arrange
notebook_id = uuid4()
mock_client = AsyncMock()
mock_client.notebooks.get.side_effect = Exception("not found")
service = ChatService(client=mock_client)
# Act & Assert
with pytest.raises(NotFoundError):
await service.get_history(notebook_id)
async def test_get_history_api_error_raises_notebooklm_error(self):
"""Should raise NotebookLMError on API error."""
# Arrange
notebook_id = uuid4()
mock_client = AsyncMock()
mock_notebook = AsyncMock()
mock_chat = AsyncMock()
mock_chat.get_history.side_effect = Exception("API error")
mock_notebook.chat = mock_chat
mock_client.notebooks.get.return_value = mock_notebook
service = ChatService(client=mock_client)
# Act & Assert
with pytest.raises(NotebookLMError) as exc_info:
await service.get_history(notebook_id)
assert "Failed to get chat history" in str(exc_info.value)