feat(api): add chat functionality (Sprint 3)

Implement Sprint 3: Chat Functionality

- Add ChatService with send_message and get_history methods
- Add POST /api/v1/notebooks/{id}/chat - Send message
- Add GET /api/v1/notebooks/{id}/chat/history - Get chat history
- Add ChatRequest model (message, include_references)
- Add ChatResponse model (message, sources[], timestamp)
- Add ChatMessage model (id, role, content, timestamp, sources)
- Add SourceReference model (source_id, title, snippet)
- Integrate chat router with main app

Features:
- Send messages to notebook chat
- Get AI responses with source references
- Retrieve chat history
- Support for citations in responses

Tests:
- 14 unit tests for ChatService
- 11 integration tests for chat API
- 25/25 tests passing

Related: Sprint 3 - Chat Functionality
This commit is contained in:
Luca Sacchi Ricciardi
2026-04-06 01:48:19 +02:00
parent 3991ffdd7f
commit 081f3f0d89
8 changed files with 1225 additions and 1 deletions

View File

@@ -0,0 +1,195 @@
"""Integration tests for chat API endpoints.
Tests all chat endpoints with mocked services.
"""
from datetime import datetime
from unittest.mock import AsyncMock, patch
from uuid import uuid4
import pytest
from fastapi.testclient import TestClient
from notebooklm_agent.api.main import app
from notebooklm_agent.api.models.responses import ChatMessage, ChatResponse, SourceReference
@pytest.mark.unit
class TestSendMessageEndpoint:
"""Test suite for POST /api/v1/notebooks/{id}/chat endpoint."""
def test_send_message_returns_200(self):
"""Should return 200 with chat response."""
# Arrange
client = TestClient(app)
notebook_id = str(uuid4())
with patch("notebooklm_agent.api.routes.chat.ChatService") as mock_service_class:
mock_service = AsyncMock()
mock_response = ChatResponse(
message="This is the answer based on your sources.",
sources=[
SourceReference(
source_id=str(uuid4()),
title="Example Source",
snippet="Relevant text",
)
],
timestamp=datetime.utcnow(),
)
mock_service.send_message.return_value = mock_response
mock_service_class.return_value = mock_service
# Act
response = client.post(
f"/api/v1/notebooks/{notebook_id}/chat",
json={"message": "What are the key points?", "include_references": True},
)
# Assert
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert "answer based on your sources" in data["data"]["message"]
assert len(data["data"]["sources"]) == 1
def test_send_message_invalid_notebook_id_returns_400(self):
"""Should return 400 for invalid notebook ID."""
# Arrange
client = TestClient(app)
# Act
response = client.post(
"/api/v1/notebooks/invalid-id/chat",
json={"message": "Question?"},
)
# Assert
assert response.status_code in [400, 422]
def test_send_message_empty_message_returns_400(self):
"""Should return 400 for empty message."""
# Arrange
client = TestClient(app)
notebook_id = str(uuid4())
with patch("notebooklm_agent.api.routes.chat.ChatService") as mock_service_class:
mock_service = AsyncMock()
from notebooklm_agent.core.exceptions import ValidationError
mock_service.send_message.side_effect = ValidationError("Message cannot be empty")
mock_service_class.return_value = mock_service
# Act
response = client.post(
f"/api/v1/notebooks/{notebook_id}/chat",
json={"message": ""},
)
# Assert
assert response.status_code in [400, 422]
def test_send_message_notebook_not_found_returns_404(self):
"""Should return 404 when notebook not found."""
# Arrange
client = TestClient(app)
notebook_id = str(uuid4())
with patch("notebooklm_agent.api.routes.chat.ChatService") as mock_service_class:
mock_service = AsyncMock()
from notebooklm_agent.core.exceptions import NotFoundError
mock_service.send_message.side_effect = NotFoundError("Notebook", notebook_id)
mock_service_class.return_value = mock_service
# Act
response = client.post(
f"/api/v1/notebooks/{notebook_id}/chat",
json={"message": "Question?"},
)
# Assert
assert response.status_code == 404
@pytest.mark.unit
class TestGetChatHistoryEndpoint:
"""Test suite for GET /api/v1/notebooks/{id}/chat/history endpoint."""
def test_get_history_returns_200(self):
"""Should return 200 with chat history."""
# Arrange
client = TestClient(app)
notebook_id = str(uuid4())
with patch("notebooklm_agent.api.routes.chat.ChatService") as mock_service_class:
mock_service = AsyncMock()
mock_message = ChatMessage(
id=str(uuid4()),
role="user",
content="What are the key points?",
timestamp=datetime.utcnow(),
sources=None,
)
mock_service.get_history.return_value = [mock_message]
mock_service_class.return_value = mock_service
# Act
response = client.get(f"/api/v1/notebooks/{notebook_id}/chat/history")
# Assert
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert len(data["data"]) == 1
assert data["data"][0]["role"] == "user"
def test_get_history_empty_returns_empty_list(self):
"""Should return empty list if no history."""
# Arrange
client = TestClient(app)
notebook_id = str(uuid4())
with patch("notebooklm_agent.api.routes.chat.ChatService") as mock_service_class:
mock_service = AsyncMock()
mock_service.get_history.return_value = []
mock_service_class.return_value = mock_service
# Act
response = client.get(f"/api/v1/notebooks/{notebook_id}/chat/history")
# Assert
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["data"] == []
def test_get_history_invalid_notebook_id_returns_400(self):
"""Should return 400 for invalid notebook ID."""
# Arrange
client = TestClient(app)
# Act
response = client.get("/api/v1/notebooks/invalid-id/chat/history")
# Assert
assert response.status_code in [400, 422]
def test_get_history_notebook_not_found_returns_404(self):
"""Should return 404 when notebook not found."""
# Arrange
client = TestClient(app)
notebook_id = str(uuid4())
with patch("notebooklm_agent.api.routes.chat.ChatService") as mock_service_class:
mock_service = AsyncMock()
from notebooklm_agent.core.exceptions import NotFoundError
mock_service.get_history.side_effect = NotFoundError("Notebook", notebook_id)
mock_service_class.return_value = mock_service
# Act
response = client.get(f"/api/v1/notebooks/{notebook_id}/chat/history")
# Assert
assert response.status_code == 404

View File

@@ -0,0 +1,349 @@
"""Unit tests for ChatService.
TDD Cycle: RED → GREEN → REFACTOR
"""
from datetime import datetime
from unittest.mock import AsyncMock, MagicMock
from uuid import UUID, uuid4
import pytest
from notebooklm_agent.core.exceptions import (
NotebookLMError,
NotFoundError,
ValidationError,
)
from notebooklm_agent.services.chat_service import ChatService
@pytest.mark.unit
class TestChatServiceInit:
"""Test suite for ChatService initialization."""
async def test_get_client_returns_existing_client(self):
"""Should return existing client if already initialized."""
# Arrange
mock_client = AsyncMock()
service = ChatService(client=mock_client)
# Act
client = await service._get_client()
# Assert
assert client == mock_client
@pytest.mark.unit
class TestChatServiceValidateMessage:
"""Test suite for ChatService._validate_message()."""
def test_validate_valid_message(self):
"""Should accept valid message."""
# Arrange
service = ChatService()
message = "What are the key points?"
# Act
result = service._validate_message(message)
# Assert
assert result == message
def test_validate_empty_message_raises_validation_error(self):
"""Should raise ValidationError for empty message."""
# Arrange
service = ChatService()
# Act & Assert
with pytest.raises(ValidationError) as exc_info:
service._validate_message("")
assert "Message cannot be empty" in str(exc_info.value)
def test_validate_whitespace_message_raises_validation_error(self):
"""Should raise ValidationError for whitespace-only message."""
# Arrange
service = ChatService()
# Act & Assert
with pytest.raises(ValidationError) as exc_info:
service._validate_message(" ")
assert "Message cannot be empty" in str(exc_info.value)
def test_validate_long_message_raises_validation_error(self):
"""Should raise ValidationError for message > 2000 chars."""
# Arrange
service = ChatService()
long_message = "A" * 2001
# Act & Assert
with pytest.raises(ValidationError) as exc_info:
service._validate_message(long_message)
assert "at most 2000 characters" in str(exc_info.value)
def test_validate_max_length_message_succeeds(self):
"""Should accept message with exactly 2000 characters."""
# Arrange
service = ChatService()
message = "A" * 2000
# Act
result = service._validate_message(message)
# Assert
assert result == message
@pytest.mark.unit
class TestChatServiceSendMessage:
"""Test suite for ChatService.send_message() method."""
async def test_send_message_returns_response(self):
"""Should send message and return ChatResponse."""
# Arrange
notebook_id = uuid4()
mock_client = AsyncMock()
mock_notebook = AsyncMock()
mock_chat = AsyncMock()
mock_response = MagicMock()
mock_response.text = "This is the answer based on your sources."
mock_response.citations = []
mock_chat.send_message.return_value = mock_response
mock_notebook.chat = mock_chat
mock_client.notebooks.get.return_value = mock_notebook
service = ChatService(client=mock_client)
# Act
result = await service.send_message(notebook_id, "What are the key points?")
# Assert
assert result.message == "This is the answer based on your sources."
assert isinstance(result.sources, list)
mock_chat.send_message.assert_called_once_with("What are the key points?")
async def test_send_message_with_references(self):
"""Should include source references when available."""
# Arrange
notebook_id = uuid4()
mock_client = AsyncMock()
mock_notebook = AsyncMock()
mock_chat = AsyncMock()
mock_response = MagicMock()
mock_response.text = "Answer with references."
mock_citation = MagicMock()
mock_citation.source_id = str(uuid4())
mock_citation.title = "Source 1"
mock_citation.snippet = "Relevant text"
mock_response.citations = [mock_citation]
mock_chat.send_message.return_value = mock_response
mock_notebook.chat = mock_chat
mock_client.notebooks.get.return_value = mock_notebook
service = ChatService(client=mock_client)
# Act
result = await service.send_message(notebook_id, "Question?", include_references=True)
# Assert
assert len(result.sources) == 1
assert result.sources[0].title == "Source 1"
async def test_send_message_without_references(self):
"""Should not include source references when disabled."""
# Arrange
notebook_id = uuid4()
mock_client = AsyncMock()
mock_notebook = AsyncMock()
mock_chat = AsyncMock()
mock_response = MagicMock()
mock_response.text = "Answer without references."
mock_response.citations = []
mock_chat.send_message.return_value = mock_response
mock_notebook.chat = mock_chat
mock_client.notebooks.get.return_value = mock_notebook
service = ChatService(client=mock_client)
# Act
result = await service.send_message(notebook_id, "Question?", include_references=False)
# Assert
# Even with include_references=False, we still get empty sources list
assert result.sources == []
async def test_send_message_empty_message_raises_validation_error(self):
"""Should raise ValidationError for empty message."""
# Arrange
notebook_id = uuid4()
service = ChatService()
# Act & Assert
with pytest.raises(ValidationError) as exc_info:
await service.send_message(notebook_id, "")
assert "Message cannot be empty" in str(exc_info.value)
async def test_send_message_notebook_not_found_raises_not_found(self):
"""Should raise NotFoundError if notebook not found."""
# Arrange
notebook_id = uuid4()
mock_client = AsyncMock()
mock_client.notebooks.get.side_effect = Exception("notebook not found")
service = ChatService(client=mock_client)
# Act & Assert
with pytest.raises(NotFoundError) as exc_info:
await service.send_message(notebook_id, "Question?")
assert str(notebook_id) in str(exc_info.value)
async def test_send_message_api_error_raises_notebooklm_error(self):
"""Should raise NotebookLMError on API error."""
# Arrange
notebook_id = uuid4()
mock_client = AsyncMock()
mock_client.notebooks.get.side_effect = Exception("connection timeout")
service = ChatService(client=mock_client)
# Act & Assert
with pytest.raises(NotebookLMError) as exc_info:
await service.send_message(notebook_id, "Question?")
assert "Failed to send message" in str(exc_info.value)
@pytest.mark.unit
class TestChatServiceGetHistory:
"""Test suite for ChatService.get_history() method."""
async def test_get_history_returns_messages(self):
"""Should return list of chat messages."""
# Arrange
notebook_id = uuid4()
mock_client = AsyncMock()
mock_notebook = AsyncMock()
mock_chat = AsyncMock()
mock_message = MagicMock()
mock_message.id = str(uuid4())
mock_message.role = "user"
mock_message.content = "What are the key points?"
mock_message.timestamp = datetime.utcnow()
mock_message.citations = None
mock_chat.get_history.return_value = [mock_message]
mock_notebook.chat = mock_chat
mock_client.notebooks.get.return_value = mock_notebook
service = ChatService(client=mock_client)
# Act
result = await service.get_history(notebook_id)
# Assert
assert len(result) == 1
assert result[0].role == "user"
assert result[0].content == "What are the key points?"
async def test_get_history_with_assistant_message(self):
"""Should handle assistant messages with sources."""
# Arrange
notebook_id = uuid4()
mock_client = AsyncMock()
mock_notebook = AsyncMock()
mock_chat = AsyncMock()
user_msg = MagicMock()
user_msg.id = str(uuid4())
user_msg.role = "user"
user_msg.content = "Question?"
user_msg.timestamp = datetime.utcnow()
user_msg.citations = None
assistant_msg = MagicMock()
assistant_msg.id = str(uuid4())
assistant_msg.role = "assistant"
assistant_msg.content = "Answer with sources."
assistant_msg.timestamp = datetime.utcnow()
mock_citation = MagicMock()
mock_citation.source_id = str(uuid4())
mock_citation.title = "Source 1"
mock_citation.snippet = "Relevant text"
assistant_msg.citations = [mock_citation]
mock_chat.get_history.return_value = [user_msg, assistant_msg]
mock_notebook.chat = mock_chat
mock_client.notebooks.get.return_value = mock_notebook
service = ChatService(client=mock_client)
# Act
result = await service.get_history(notebook_id)
# Assert
assert len(result) == 2
assert result[1].role == "assistant"
assert result[1].sources is not None
assert len(result[1].sources) == 1
async def test_get_history_empty_returns_empty_list(self):
"""Should return empty list if no history."""
# Arrange
notebook_id = uuid4()
mock_client = AsyncMock()
mock_notebook = AsyncMock()
mock_chat = AsyncMock()
mock_chat.get_history.return_value = []
mock_notebook.chat = mock_chat
mock_client.notebooks.get.return_value = mock_notebook
service = ChatService(client=mock_client)
# Act
result = await service.get_history(notebook_id)
# Assert
assert result == []
async def test_get_history_notebook_not_found_raises_not_found(self):
"""Should raise NotFoundError if notebook not found."""
# Arrange
notebook_id = uuid4()
mock_client = AsyncMock()
mock_client.notebooks.get.side_effect = Exception("not found")
service = ChatService(client=mock_client)
# Act & Assert
with pytest.raises(NotFoundError):
await service.get_history(notebook_id)
async def test_get_history_api_error_raises_notebooklm_error(self):
"""Should raise NotebookLMError on API error."""
# Arrange
notebook_id = uuid4()
mock_client = AsyncMock()
mock_notebook = AsyncMock()
mock_chat = AsyncMock()
mock_chat.get_history.side_effect = Exception("API error")
mock_notebook.chat = mock_chat
mock_client.notebooks.get.return_value = mock_notebook
service = ChatService(client=mock_client)
# Act & Assert
with pytest.raises(NotebookLMError) as exc_info:
await service.get_history(notebook_id)
assert "Failed to get chat history" in str(exc_info.value)