"""Unit tests for ChatService.

TDD Cycle: RED → GREEN → REFACTOR
"""

from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock
from uuid import UUID, uuid4

import pytest

from notebooklm_agent.core.exceptions import (
    NotebookLMError,
    NotFoundError,
    ValidationError,
)
from notebooklm_agent.services.chat_service import ChatService


# ---------------------------------------------------------------------------
# Mock builders (leading underscore keeps pytest from collecting them)
# ---------------------------------------------------------------------------


def _make_client(chat):
    """Return a mock client whose awaited ``notebooks.get()`` yields a notebook exposing *chat*."""
    client = AsyncMock()
    notebook = AsyncMock()
    notebook.chat = chat
    client.notebooks.get.return_value = notebook
    return client


def _make_chat_with_response(text, citations):
    """Return a mock chat whose awaited ``send_message()`` yields a response.

    Args:
        text: Value exposed as ``response.text``.
        citations: Value exposed as ``response.citations``.
    """
    chat = AsyncMock()
    response = MagicMock()
    response.text = text
    response.citations = citations
    chat.send_message.return_value = response
    return chat


def _make_citation(title="Source 1", snippet="Relevant text"):
    """Return a mock citation with a fresh random ``source_id``."""
    citation = MagicMock()
    citation.source_id = str(uuid4())
    citation.title = title
    citation.snippet = snippet
    return citation


def _make_history_message(role, content, citations=None):
    """Return a mock chat-history entry with a fresh id and a current UTC timestamp."""
    message = MagicMock()
    message.id = str(uuid4())
    message.role = role
    message.content = content
    # Timezone-aware "now": datetime.utcnow() is deprecated since Python 3.12.
    message.timestamp = datetime.now(timezone.utc)
    message.citations = citations
    return message


@pytest.mark.unit
class TestChatServiceInit:
    """Test suite for ChatService initialization."""

    async def test_get_client_returns_existing_client(self):
        """Should return existing client if already initialized."""
        # Arrange
        mock_client = AsyncMock()
        service = ChatService(client=mock_client)

        # Act
        client = await service._get_client()

        # Assert: the very same object, not merely an equal one.
        assert client is mock_client


@pytest.mark.unit
class TestChatServiceValidateMessage:
    """Test suite for ChatService._validate_message()."""

    def test_validate_valid_message(self):
        """Should accept valid message."""
        # Arrange
        service = ChatService()
        message = "What are the key points?"

        # Act
        result = service._validate_message(message)

        # Assert
        assert result == message

    def test_validate_empty_message_raises_validation_error(self):
        """Should raise ValidationError for empty message."""
        # Arrange
        service = ChatService()

        # Act & Assert
        with pytest.raises(ValidationError) as exc_info:
            service._validate_message("")

        assert "Message cannot be empty" in str(exc_info.value)

    def test_validate_whitespace_message_raises_validation_error(self):
        """Should raise ValidationError for whitespace-only message."""
        # Arrange
        service = ChatService()

        # Act & Assert
        with pytest.raises(ValidationError) as exc_info:
            service._validate_message(" ")

        assert "Message cannot be empty" in str(exc_info.value)

    def test_validate_long_message_raises_validation_error(self):
        """Should raise ValidationError for message > 2000 chars."""
        # Arrange
        service = ChatService()
        long_message = "A" * 2001

        # Act & Assert
        with pytest.raises(ValidationError) as exc_info:
            service._validate_message(long_message)

        assert "at most 2000 characters" in str(exc_info.value)

    def test_validate_max_length_message_succeeds(self):
        """Should accept message with exactly 2000 characters."""
        # Arrange
        service = ChatService()
        message = "A" * 2000

        # Act
        result = service._validate_message(message)

        # Assert
        assert result == message


@pytest.mark.unit
class TestChatServiceSendMessage:
    """Test suite for ChatService.send_message() method."""

    async def test_send_message_returns_response(self):
        """Should send message and return ChatResponse."""
        # Arrange
        notebook_id = uuid4()
        mock_chat = _make_chat_with_response(
            "This is the answer based on your sources.", []
        )
        service = ChatService(client=_make_client(mock_chat))

        # Act
        result = await service.send_message(notebook_id, "What are the key points?")

        # Assert
        assert result.message == "This is the answer based on your sources."
        assert isinstance(result.sources, list)
        mock_chat.send_message.assert_called_once_with("What are the key points?")

    async def test_send_message_with_references(self):
        """Should include source references when available."""
        # Arrange
        notebook_id = uuid4()
        mock_chat = _make_chat_with_response(
            "Answer with references.", [_make_citation()]
        )
        service = ChatService(client=_make_client(mock_chat))

        # Act
        result = await service.send_message(
            notebook_id, "Question?", include_references=True
        )

        # Assert
        assert len(result.sources) == 1
        assert result.sources[0].title == "Source 1"

    async def test_send_message_without_references(self):
        """Should not include source references when disabled."""
        # Arrange
        notebook_id = uuid4()
        mock_chat = _make_chat_with_response("Answer without references.", [])
        service = ChatService(client=_make_client(mock_chat))

        # Act
        result = await service.send_message(
            notebook_id, "Question?", include_references=False
        )

        # Assert
        # NOTE(review): the stubbed response carries no citations, so this test
        # cannot distinguish "flag honored" from "flag ignored". Consider
        # returning a citation here and still asserting an empty list.
        assert result.sources == []

    async def test_send_message_empty_message_raises_validation_error(self):
        """Should raise ValidationError for empty message."""
        # Arrange
        notebook_id = uuid4()
        service = ChatService()

        # Act & Assert
        with pytest.raises(ValidationError) as exc_info:
            await service.send_message(notebook_id, "")

        assert "Message cannot be empty" in str(exc_info.value)

    async def test_send_message_notebook_not_found_raises_not_found(self):
        """Should raise NotFoundError if notebook not found."""
        # Arrange
        notebook_id = uuid4()
        mock_client = AsyncMock()
        mock_client.notebooks.get.side_effect = Exception("notebook not found")
        service = ChatService(client=mock_client)

        # Act & Assert
        with pytest.raises(NotFoundError) as exc_info:
            await service.send_message(notebook_id, "Question?")

        assert str(notebook_id) in str(exc_info.value)

    async def test_send_message_api_error_raises_notebooklm_error(self):
        """Should raise NotebookLMError on API error."""
        # Arrange
        notebook_id = uuid4()
        mock_client = AsyncMock()
        mock_client.notebooks.get.side_effect = Exception("connection timeout")
        service = ChatService(client=mock_client)

        # Act & Assert
        with pytest.raises(NotebookLMError) as exc_info:
            await service.send_message(notebook_id, "Question?")

        assert "Failed to send message" in str(exc_info.value)


@pytest.mark.unit
class TestChatServiceGetHistory:
    """Test suite for ChatService.get_history() method."""

    async def test_get_history_returns_messages(self):
        """Should return list of chat messages."""
        # Arrange
        notebook_id = uuid4()
        mock_chat = AsyncMock()
        mock_chat.get_history.return_value = [
            _make_history_message("user", "What are the key points?"),
        ]
        service = ChatService(client=_make_client(mock_chat))

        # Act
        result = await service.get_history(notebook_id)

        # Assert
        assert len(result) == 1
        assert result[0].role == "user"
        assert result[0].content == "What are the key points?"

    async def test_get_history_with_assistant_message(self):
        """Should handle assistant messages with sources."""
        # Arrange
        notebook_id = uuid4()
        mock_chat = AsyncMock()
        mock_chat.get_history.return_value = [
            _make_history_message("user", "Question?"),
            _make_history_message(
                "assistant", "Answer with sources.", citations=[_make_citation()]
            ),
        ]
        service = ChatService(client=_make_client(mock_chat))

        # Act
        result = await service.get_history(notebook_id)

        # Assert
        assert len(result) == 2
        assert result[1].role == "assistant"
        assert result[1].sources is not None
        assert len(result[1].sources) == 1

    async def test_get_history_empty_returns_empty_list(self):
        """Should return empty list if no history."""
        # Arrange
        notebook_id = uuid4()
        mock_chat = AsyncMock()
        mock_chat.get_history.return_value = []
        service = ChatService(client=_make_client(mock_chat))

        # Act
        result = await service.get_history(notebook_id)

        # Assert
        assert result == []

    async def test_get_history_notebook_not_found_raises_not_found(self):
        """Should raise NotFoundError if notebook not found."""
        # Arrange
        notebook_id = uuid4()
        mock_client = AsyncMock()
        mock_client.notebooks.get.side_effect = Exception("not found")
        service = ChatService(client=mock_client)

        # Act & Assert
        with pytest.raises(NotFoundError):
            await service.get_history(notebook_id)

    async def test_get_history_api_error_raises_notebooklm_error(self):
        """Should raise NotebookLMError on API error."""
        # Arrange
        notebook_id = uuid4()
        mock_chat = AsyncMock()
        mock_chat.get_history.side_effect = Exception("API error")
        service = ChatService(client=_make_client(mock_chat))

        # Act & Assert
        with pytest.raises(NotebookLMError) as exc_info:
            await service.get_history(notebook_id)

        assert "Failed to get chat history" in str(exc_info.value)