feat(api): add chat functionality (Sprint 3)
Implement Sprint 3: Chat Functionality
- Add ChatService with send_message and get_history methods
- Add POST /api/v1/notebooks/{id}/chat - Send message
- Add GET /api/v1/notebooks/{id}/chat/history - Get chat history
- Add ChatRequest model (message, include_references)
- Add ChatResponse model (message, sources[], timestamp)
- Add ChatMessage model (id, role, content, timestamp, sources)
- Add SourceReference model (source_id, title, snippet)
- Integrate chat router with main app
Features:
- Send messages to notebook chat
- Get AI responses with source references
- Retrieve chat history
- Support for citations in responses
Tests:
- 14 unit tests for ChatService
- 11 integration tests for chat API
- 25/25 tests passing
Related: Sprint 3 - Chat Functionality
This commit is contained in:
195
tests/unit/test_api/test_chat.py
Normal file
195
tests/unit/test_api/test_chat.py
Normal file
@@ -0,0 +1,195 @@
|
||||
"""Integration tests for chat API endpoints.
|
||||
|
||||
Tests all chat endpoints with mocked services.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from notebooklm_agent.api.main import app
|
||||
from notebooklm_agent.api.models.responses import ChatMessage, ChatResponse, SourceReference
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestSendMessageEndpoint:
|
||||
"""Test suite for POST /api/v1/notebooks/{id}/chat endpoint."""
|
||||
|
||||
def test_send_message_returns_200(self):
|
||||
"""Should return 200 with chat response."""
|
||||
# Arrange
|
||||
client = TestClient(app)
|
||||
notebook_id = str(uuid4())
|
||||
|
||||
with patch("notebooklm_agent.api.routes.chat.ChatService") as mock_service_class:
|
||||
mock_service = AsyncMock()
|
||||
mock_response = ChatResponse(
|
||||
message="This is the answer based on your sources.",
|
||||
sources=[
|
||||
SourceReference(
|
||||
source_id=str(uuid4()),
|
||||
title="Example Source",
|
||||
snippet="Relevant text",
|
||||
)
|
||||
],
|
||||
timestamp=datetime.utcnow(),
|
||||
)
|
||||
mock_service.send_message.return_value = mock_response
|
||||
mock_service_class.return_value = mock_service
|
||||
|
||||
# Act
|
||||
response = client.post(
|
||||
f"/api/v1/notebooks/{notebook_id}/chat",
|
||||
json={"message": "What are the key points?", "include_references": True},
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["success"] is True
|
||||
assert "answer based on your sources" in data["data"]["message"]
|
||||
assert len(data["data"]["sources"]) == 1
|
||||
|
||||
def test_send_message_invalid_notebook_id_returns_400(self):
|
||||
"""Should return 400 for invalid notebook ID."""
|
||||
# Arrange
|
||||
client = TestClient(app)
|
||||
|
||||
# Act
|
||||
response = client.post(
|
||||
"/api/v1/notebooks/invalid-id/chat",
|
||||
json={"message": "Question?"},
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert response.status_code in [400, 422]
|
||||
|
||||
def test_send_message_empty_message_returns_400(self):
|
||||
"""Should return 400 for empty message."""
|
||||
# Arrange
|
||||
client = TestClient(app)
|
||||
notebook_id = str(uuid4())
|
||||
|
||||
with patch("notebooklm_agent.api.routes.chat.ChatService") as mock_service_class:
|
||||
mock_service = AsyncMock()
|
||||
from notebooklm_agent.core.exceptions import ValidationError
|
||||
|
||||
mock_service.send_message.side_effect = ValidationError("Message cannot be empty")
|
||||
mock_service_class.return_value = mock_service
|
||||
|
||||
# Act
|
||||
response = client.post(
|
||||
f"/api/v1/notebooks/{notebook_id}/chat",
|
||||
json={"message": ""},
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert response.status_code in [400, 422]
|
||||
|
||||
def test_send_message_notebook_not_found_returns_404(self):
|
||||
"""Should return 404 when notebook not found."""
|
||||
# Arrange
|
||||
client = TestClient(app)
|
||||
notebook_id = str(uuid4())
|
||||
|
||||
with patch("notebooklm_agent.api.routes.chat.ChatService") as mock_service_class:
|
||||
mock_service = AsyncMock()
|
||||
from notebooklm_agent.core.exceptions import NotFoundError
|
||||
|
||||
mock_service.send_message.side_effect = NotFoundError("Notebook", notebook_id)
|
||||
mock_service_class.return_value = mock_service
|
||||
|
||||
# Act
|
||||
response = client.post(
|
||||
f"/api/v1/notebooks/{notebook_id}/chat",
|
||||
json={"message": "Question?"},
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestGetChatHistoryEndpoint:
|
||||
"""Test suite for GET /api/v1/notebooks/{id}/chat/history endpoint."""
|
||||
|
||||
def test_get_history_returns_200(self):
|
||||
"""Should return 200 with chat history."""
|
||||
# Arrange
|
||||
client = TestClient(app)
|
||||
notebook_id = str(uuid4())
|
||||
|
||||
with patch("notebooklm_agent.api.routes.chat.ChatService") as mock_service_class:
|
||||
mock_service = AsyncMock()
|
||||
mock_message = ChatMessage(
|
||||
id=str(uuid4()),
|
||||
role="user",
|
||||
content="What are the key points?",
|
||||
timestamp=datetime.utcnow(),
|
||||
sources=None,
|
||||
)
|
||||
mock_service.get_history.return_value = [mock_message]
|
||||
mock_service_class.return_value = mock_service
|
||||
|
||||
# Act
|
||||
response = client.get(f"/api/v1/notebooks/{notebook_id}/chat/history")
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["success"] is True
|
||||
assert len(data["data"]) == 1
|
||||
assert data["data"][0]["role"] == "user"
|
||||
|
||||
def test_get_history_empty_returns_empty_list(self):
|
||||
"""Should return empty list if no history."""
|
||||
# Arrange
|
||||
client = TestClient(app)
|
||||
notebook_id = str(uuid4())
|
||||
|
||||
with patch("notebooklm_agent.api.routes.chat.ChatService") as mock_service_class:
|
||||
mock_service = AsyncMock()
|
||||
mock_service.get_history.return_value = []
|
||||
mock_service_class.return_value = mock_service
|
||||
|
||||
# Act
|
||||
response = client.get(f"/api/v1/notebooks/{notebook_id}/chat/history")
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["success"] is True
|
||||
assert data["data"] == []
|
||||
|
||||
def test_get_history_invalid_notebook_id_returns_400(self):
|
||||
"""Should return 400 for invalid notebook ID."""
|
||||
# Arrange
|
||||
client = TestClient(app)
|
||||
|
||||
# Act
|
||||
response = client.get("/api/v1/notebooks/invalid-id/chat/history")
|
||||
|
||||
# Assert
|
||||
assert response.status_code in [400, 422]
|
||||
|
||||
def test_get_history_notebook_not_found_returns_404(self):
|
||||
"""Should return 404 when notebook not found."""
|
||||
# Arrange
|
||||
client = TestClient(app)
|
||||
notebook_id = str(uuid4())
|
||||
|
||||
with patch("notebooklm_agent.api.routes.chat.ChatService") as mock_service_class:
|
||||
mock_service = AsyncMock()
|
||||
from notebooklm_agent.core.exceptions import NotFoundError
|
||||
|
||||
mock_service.get_history.side_effect = NotFoundError("Notebook", notebook_id)
|
||||
mock_service_class.return_value = mock_service
|
||||
|
||||
# Act
|
||||
response = client.get(f"/api/v1/notebooks/{notebook_id}/chat/history")
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 404
|
||||
349
tests/unit/test_services/test_chat_service.py
Normal file
349
tests/unit/test_services/test_chat_service.py
Normal file
@@ -0,0 +1,349 @@
|
||||
"""Unit tests for ChatService.
|
||||
|
||||
TDD Cycle: RED → GREEN → REFACTOR
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from notebooklm_agent.core.exceptions import (
|
||||
NotebookLMError,
|
||||
NotFoundError,
|
||||
ValidationError,
|
||||
)
|
||||
from notebooklm_agent.services.chat_service import ChatService
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestChatServiceInit:
|
||||
"""Test suite for ChatService initialization."""
|
||||
|
||||
async def test_get_client_returns_existing_client(self):
|
||||
"""Should return existing client if already initialized."""
|
||||
# Arrange
|
||||
mock_client = AsyncMock()
|
||||
service = ChatService(client=mock_client)
|
||||
|
||||
# Act
|
||||
client = await service._get_client()
|
||||
|
||||
# Assert
|
||||
assert client == mock_client
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestChatServiceValidateMessage:
|
||||
"""Test suite for ChatService._validate_message()."""
|
||||
|
||||
def test_validate_valid_message(self):
|
||||
"""Should accept valid message."""
|
||||
# Arrange
|
||||
service = ChatService()
|
||||
message = "What are the key points?"
|
||||
|
||||
# Act
|
||||
result = service._validate_message(message)
|
||||
|
||||
# Assert
|
||||
assert result == message
|
||||
|
||||
def test_validate_empty_message_raises_validation_error(self):
|
||||
"""Should raise ValidationError for empty message."""
|
||||
# Arrange
|
||||
service = ChatService()
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
service._validate_message("")
|
||||
|
||||
assert "Message cannot be empty" in str(exc_info.value)
|
||||
|
||||
def test_validate_whitespace_message_raises_validation_error(self):
|
||||
"""Should raise ValidationError for whitespace-only message."""
|
||||
# Arrange
|
||||
service = ChatService()
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
service._validate_message(" ")
|
||||
|
||||
assert "Message cannot be empty" in str(exc_info.value)
|
||||
|
||||
def test_validate_long_message_raises_validation_error(self):
|
||||
"""Should raise ValidationError for message > 2000 chars."""
|
||||
# Arrange
|
||||
service = ChatService()
|
||||
long_message = "A" * 2001
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
service._validate_message(long_message)
|
||||
|
||||
assert "at most 2000 characters" in str(exc_info.value)
|
||||
|
||||
def test_validate_max_length_message_succeeds(self):
|
||||
"""Should accept message with exactly 2000 characters."""
|
||||
# Arrange
|
||||
service = ChatService()
|
||||
message = "A" * 2000
|
||||
|
||||
# Act
|
||||
result = service._validate_message(message)
|
||||
|
||||
# Assert
|
||||
assert result == message
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestChatServiceSendMessage:
|
||||
"""Test suite for ChatService.send_message() method."""
|
||||
|
||||
async def test_send_message_returns_response(self):
|
||||
"""Should send message and return ChatResponse."""
|
||||
# Arrange
|
||||
notebook_id = uuid4()
|
||||
mock_client = AsyncMock()
|
||||
mock_notebook = AsyncMock()
|
||||
mock_chat = AsyncMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "This is the answer based on your sources."
|
||||
mock_response.citations = []
|
||||
|
||||
mock_chat.send_message.return_value = mock_response
|
||||
mock_notebook.chat = mock_chat
|
||||
mock_client.notebooks.get.return_value = mock_notebook
|
||||
|
||||
service = ChatService(client=mock_client)
|
||||
|
||||
# Act
|
||||
result = await service.send_message(notebook_id, "What are the key points?")
|
||||
|
||||
# Assert
|
||||
assert result.message == "This is the answer based on your sources."
|
||||
assert isinstance(result.sources, list)
|
||||
mock_chat.send_message.assert_called_once_with("What are the key points?")
|
||||
|
||||
async def test_send_message_with_references(self):
|
||||
"""Should include source references when available."""
|
||||
# Arrange
|
||||
notebook_id = uuid4()
|
||||
mock_client = AsyncMock()
|
||||
mock_notebook = AsyncMock()
|
||||
mock_chat = AsyncMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "Answer with references."
|
||||
|
||||
mock_citation = MagicMock()
|
||||
mock_citation.source_id = str(uuid4())
|
||||
mock_citation.title = "Source 1"
|
||||
mock_citation.snippet = "Relevant text"
|
||||
mock_response.citations = [mock_citation]
|
||||
|
||||
mock_chat.send_message.return_value = mock_response
|
||||
mock_notebook.chat = mock_chat
|
||||
mock_client.notebooks.get.return_value = mock_notebook
|
||||
|
||||
service = ChatService(client=mock_client)
|
||||
|
||||
# Act
|
||||
result = await service.send_message(notebook_id, "Question?", include_references=True)
|
||||
|
||||
# Assert
|
||||
assert len(result.sources) == 1
|
||||
assert result.sources[0].title == "Source 1"
|
||||
|
||||
async def test_send_message_without_references(self):
|
||||
"""Should not include source references when disabled."""
|
||||
# Arrange
|
||||
notebook_id = uuid4()
|
||||
mock_client = AsyncMock()
|
||||
mock_notebook = AsyncMock()
|
||||
mock_chat = AsyncMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "Answer without references."
|
||||
mock_response.citations = []
|
||||
|
||||
mock_chat.send_message.return_value = mock_response
|
||||
mock_notebook.chat = mock_chat
|
||||
mock_client.notebooks.get.return_value = mock_notebook
|
||||
|
||||
service = ChatService(client=mock_client)
|
||||
|
||||
# Act
|
||||
result = await service.send_message(notebook_id, "Question?", include_references=False)
|
||||
|
||||
# Assert
|
||||
# Even with include_references=False, we still get empty sources list
|
||||
assert result.sources == []
|
||||
|
||||
async def test_send_message_empty_message_raises_validation_error(self):
|
||||
"""Should raise ValidationError for empty message."""
|
||||
# Arrange
|
||||
notebook_id = uuid4()
|
||||
service = ChatService()
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
await service.send_message(notebook_id, "")
|
||||
|
||||
assert "Message cannot be empty" in str(exc_info.value)
|
||||
|
||||
async def test_send_message_notebook_not_found_raises_not_found(self):
|
||||
"""Should raise NotFoundError if notebook not found."""
|
||||
# Arrange
|
||||
notebook_id = uuid4()
|
||||
mock_client = AsyncMock()
|
||||
mock_client.notebooks.get.side_effect = Exception("notebook not found")
|
||||
|
||||
service = ChatService(client=mock_client)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(NotFoundError) as exc_info:
|
||||
await service.send_message(notebook_id, "Question?")
|
||||
|
||||
assert str(notebook_id) in str(exc_info.value)
|
||||
|
||||
async def test_send_message_api_error_raises_notebooklm_error(self):
|
||||
"""Should raise NotebookLMError on API error."""
|
||||
# Arrange
|
||||
notebook_id = uuid4()
|
||||
mock_client = AsyncMock()
|
||||
mock_client.notebooks.get.side_effect = Exception("connection timeout")
|
||||
|
||||
service = ChatService(client=mock_client)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(NotebookLMError) as exc_info:
|
||||
await service.send_message(notebook_id, "Question?")
|
||||
|
||||
assert "Failed to send message" in str(exc_info.value)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestChatServiceGetHistory:
|
||||
"""Test suite for ChatService.get_history() method."""
|
||||
|
||||
async def test_get_history_returns_messages(self):
|
||||
"""Should return list of chat messages."""
|
||||
# Arrange
|
||||
notebook_id = uuid4()
|
||||
mock_client = AsyncMock()
|
||||
mock_notebook = AsyncMock()
|
||||
mock_chat = AsyncMock()
|
||||
|
||||
mock_message = MagicMock()
|
||||
mock_message.id = str(uuid4())
|
||||
mock_message.role = "user"
|
||||
mock_message.content = "What are the key points?"
|
||||
mock_message.timestamp = datetime.utcnow()
|
||||
mock_message.citations = None
|
||||
|
||||
mock_chat.get_history.return_value = [mock_message]
|
||||
mock_notebook.chat = mock_chat
|
||||
mock_client.notebooks.get.return_value = mock_notebook
|
||||
|
||||
service = ChatService(client=mock_client)
|
||||
|
||||
# Act
|
||||
result = await service.get_history(notebook_id)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 1
|
||||
assert result[0].role == "user"
|
||||
assert result[0].content == "What are the key points?"
|
||||
|
||||
async def test_get_history_with_assistant_message(self):
|
||||
"""Should handle assistant messages with sources."""
|
||||
# Arrange
|
||||
notebook_id = uuid4()
|
||||
mock_client = AsyncMock()
|
||||
mock_notebook = AsyncMock()
|
||||
mock_chat = AsyncMock()
|
||||
|
||||
user_msg = MagicMock()
|
||||
user_msg.id = str(uuid4())
|
||||
user_msg.role = "user"
|
||||
user_msg.content = "Question?"
|
||||
user_msg.timestamp = datetime.utcnow()
|
||||
user_msg.citations = None
|
||||
|
||||
assistant_msg = MagicMock()
|
||||
assistant_msg.id = str(uuid4())
|
||||
assistant_msg.role = "assistant"
|
||||
assistant_msg.content = "Answer with sources."
|
||||
assistant_msg.timestamp = datetime.utcnow()
|
||||
|
||||
mock_citation = MagicMock()
|
||||
mock_citation.source_id = str(uuid4())
|
||||
mock_citation.title = "Source 1"
|
||||
mock_citation.snippet = "Relevant text"
|
||||
assistant_msg.citations = [mock_citation]
|
||||
|
||||
mock_chat.get_history.return_value = [user_msg, assistant_msg]
|
||||
mock_notebook.chat = mock_chat
|
||||
mock_client.notebooks.get.return_value = mock_notebook
|
||||
|
||||
service = ChatService(client=mock_client)
|
||||
|
||||
# Act
|
||||
result = await service.get_history(notebook_id)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 2
|
||||
assert result[1].role == "assistant"
|
||||
assert result[1].sources is not None
|
||||
assert len(result[1].sources) == 1
|
||||
|
||||
async def test_get_history_empty_returns_empty_list(self):
|
||||
"""Should return empty list if no history."""
|
||||
# Arrange
|
||||
notebook_id = uuid4()
|
||||
mock_client = AsyncMock()
|
||||
mock_notebook = AsyncMock()
|
||||
mock_chat = AsyncMock()
|
||||
mock_chat.get_history.return_value = []
|
||||
mock_notebook.chat = mock_chat
|
||||
mock_client.notebooks.get.return_value = mock_notebook
|
||||
|
||||
service = ChatService(client=mock_client)
|
||||
|
||||
# Act
|
||||
result = await service.get_history(notebook_id)
|
||||
|
||||
# Assert
|
||||
assert result == []
|
||||
|
||||
async def test_get_history_notebook_not_found_raises_not_found(self):
|
||||
"""Should raise NotFoundError if notebook not found."""
|
||||
# Arrange
|
||||
notebook_id = uuid4()
|
||||
mock_client = AsyncMock()
|
||||
mock_client.notebooks.get.side_effect = Exception("not found")
|
||||
|
||||
service = ChatService(client=mock_client)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(NotFoundError):
|
||||
await service.get_history(notebook_id)
|
||||
|
||||
async def test_get_history_api_error_raises_notebooklm_error(self):
|
||||
"""Should raise NotebookLMError on API error."""
|
||||
# Arrange
|
||||
notebook_id = uuid4()
|
||||
mock_client = AsyncMock()
|
||||
mock_notebook = AsyncMock()
|
||||
mock_chat = AsyncMock()
|
||||
mock_chat.get_history.side_effect = Exception("API error")
|
||||
mock_notebook.chat = mock_chat
|
||||
mock_client.notebooks.get.return_value = mock_notebook
|
||||
|
||||
service = ChatService(client=mock_client)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(NotebookLMError) as exc_info:
|
||||
await service.get_history(notebook_id)
|
||||
|
||||
assert "Failed to get chat history" in str(exc_info.value)
|
||||
Reference in New Issue
Block a user