feat(api): add chat functionality (Sprint 3)

Implement Sprint 3: Chat Functionality

- Add ChatService with send_message and get_history methods
- Add POST /api/v1/notebooks/{id}/chat - Send message
- Add GET /api/v1/notebooks/{id}/chat/history - Get chat history
- Add ChatRequest model (message, include_references)
- Add ChatResponse model (message, sources[], timestamp)
- Add ChatMessage model (id, role, content, timestamp, sources)
- Add SourceReference model (source_id, title, snippet)
- Integrate chat router with main app

Features:
- Send messages to notebook chat
- Get AI responses with source references
- Retrieve chat history
- Support for citations in responses

Tests:
- 14 unit tests for ChatService
- 11 integration tests for chat API
- 25/25 tests passing

Related: Sprint 3 - Chat Functionality
This commit is contained in:
Luca Sacchi Ricciardi
2026-04-06 01:48:19 +02:00
parent 3991ffdd7f
commit 081f3f0d89
8 changed files with 1225 additions and 1 deletions

View File

@@ -5,7 +5,7 @@ from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from notebooklm_agent.api.routes import health, notebooks, sources
from notebooklm_agent.api.routes import chat, health, notebooks, sources
from notebooklm_agent.core.config import get_settings
from notebooklm_agent.core.logging import setup_logging
@@ -54,6 +54,7 @@ def create_application() -> FastAPI:
app.include_router(health.router, prefix="/health", tags=["health"])
app.include_router(notebooks.router, prefix="/api/v1/notebooks", tags=["notebooks"])
app.include_router(sources.router, prefix="/api/v1/notebooks", tags=["sources"])
app.include_router(chat.router, prefix="/api/v1/notebooks", tags=["chat"])
return app

View File

@@ -282,3 +282,42 @@ class ResearchRequest(BaseModel):
if not v or not v.strip():
raise ValueError("Query cannot be empty")
return v.strip()
class ChatRequest(BaseModel):
    """Request model for sending a chat message to a notebook.

    Attributes:
        message: The message text to send (1-2000 characters).
        include_references: Whether to include source references in the response.
    """
    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "message": "What are the key points from the sources?",
                "include_references": True,
            }
        }
    )
    # Length limits are enforced here (Field constraints) and the emptiness
    # check is repeated in the validator below so whitespace-only input is
    # rejected even though it satisfies min_length.
    message: str = Field(
        ...,
        min_length=1,
        max_length=2000,
        description="The message text to send",
        examples=["What are the key points from the sources?"],
    )
    # Defaults to True: responses carry citations unless the caller opts out.
    include_references: bool = Field(
        True,
        description="Whether to include source references in response",
        examples=[True, False],
    )
    @field_validator("message")
    @classmethod
    def validate_message(cls, v: str) -> str:
        """Reject blank/whitespace-only messages and strip surrounding whitespace."""
        if not v or not v.strip():
            raise ValueError("Message cannot be empty")
        return v.strip()

View File

@@ -418,3 +418,129 @@ class HealthStatus(BaseModel):
description="Service name",
examples=["notebooklm-agent-api"],
)
class SourceReference(BaseModel):
    """A single source citation attached to a chat response.

    Attributes:
        source_id: The source ID (UUID string).
        title: The source title.
        snippet: Relevant text snippet from the source, if available.
    """
    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "source_id": "550e8400-e29b-41d4-a716-446655440001",
                "title": "Example Article",
                "snippet": "Key information from the source...",
            }
        }
    )
    source_id: str = Field(
        ...,
        description="The source ID",
        examples=["550e8400-e29b-41d4-a716-446655440001"],
    )
    title: str = Field(
        ...,
        description="The source title",
        examples=["Example Article"],
    )
    # Optional: upstream citations do not always carry a text excerpt.
    snippet: str | None = Field(
        None,
        description="Relevant text snippet from the source",
        examples=["Key information from the source..."],
    )
class ChatResponse(BaseModel):
    """Response model for a single chat exchange.

    Attributes:
        message: The assistant's response message.
        sources: Source references used in the response (empty when the
            caller disabled references or none were returned).
        timestamp: Response timestamp.
    """
    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "message": "Based on the sources, here are the key points...",
                "sources": [
                    {
                        "source_id": "550e8400-e29b-41d4-a716-446655440001",
                        "title": "Example Article",
                        "snippet": "Key information...",
                    }
                ],
                "timestamp": "2026-04-06T10:30:00Z",
            }
        }
    )
    message: str = Field(
        ...,
        description="The assistant's response message",
        examples=["Based on the sources, here are the key points..."],
    )
    # default_factory keeps the mutable default safe (no shared list).
    sources: list[SourceReference] = Field(
        default_factory=list,
        description="List of source references used in the response",
    )
    timestamp: datetime = Field(
        ...,
        description="Response timestamp",
        examples=["2026-04-06T10:30:00Z"],
    )
class ChatMessage(BaseModel):
    """A single message in a notebook's chat history.

    Attributes:
        id: Unique message identifier.
        role: Message role (user or assistant).
        content: Message content.
        timestamp: Message timestamp.
        sources: Source references (populated only for assistant messages;
            None otherwise).
    """
    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "id": "550e8400-e29b-41d4-a716-446655440002",
                "role": "assistant",
                "content": "Based on the sources...",
                "timestamp": "2026-04-06T10:30:00Z",
                "sources": [],
            }
        }
    )
    id: str = Field(
        ...,
        description="Unique message identifier",
        examples=["550e8400-e29b-41d4-a716-446655440002"],
    )
    # NOTE(review): role is an open string; "user"/"assistant" are expected
    # but not enforced here — confirm against the upstream client before
    # tightening to a Literal.
    role: str = Field(
        ...,
        description="Message role (user or assistant)",
        examples=["user", "assistant"],
    )
    content: str = Field(
        ...,
        description="Message content",
        examples=["What are the key points?"],
    )
    timestamp: datetime = Field(
        ...,
        description="Message timestamp",
        examples=["2026-04-06T10:30:00Z"],
    )
    # None (not []) distinguishes "no sources apply" from "empty citations".
    sources: list[SourceReference] | None = Field(
        None,
        description="Source references (for assistant messages)",
    )

View File

@@ -0,0 +1,218 @@
"""Chat API routes.
This module contains API endpoints for chat functionality.
"""
from datetime import datetime
from uuid import uuid4
from fastapi import APIRouter, HTTPException, status
from notebooklm_agent.api.models.requests import ChatRequest
from notebooklm_agent.api.models.responses import ApiResponse, ChatResponse, ResponseMeta
from notebooklm_agent.core.exceptions import NotebookLMError, NotFoundError, ValidationError
from notebooklm_agent.services.chat_service import ChatService
router = APIRouter(tags=["chat"])
async def get_chat_service() -> ChatService:
    """Provide a ChatService instance for the chat route handlers.

    Returns:
        A freshly constructed ChatService.
    """
    service = ChatService()
    return service
@router.post(
    "/{notebook_id}/chat",
    response_model=ApiResponse[ChatResponse],
    summary="Send chat message",
    description="Send a message to the notebook chat and get a response.",
)
async def send_message(notebook_id: str, data: ChatRequest):
    """Send a chat message and return the assistant's reply.

    Args:
        notebook_id: Notebook UUID (path parameter, validated here).
        data: Chat request with the message text and options.

    Returns:
        ApiResponse wrapping a ChatResponse with the answer and any
        source references.

    Raises:
        HTTPException: 400 for invalid data, 404 for a missing notebook,
            502 when the external NotebookLM API fails.
    """
    # Local import: the module top level only imports uuid4.
    from uuid import UUID

    def _error_detail(code: str, message: str, details: list | None = None) -> dict:
        """Build the standard error envelope shared by all failure paths."""
        # TODO: datetime.utcnow() is deprecated since Python 3.12; migrate
        # to datetime.now(timezone.utc) once naive-vs-aware handling is
        # audited across the API.
        return {
            "success": False,
            "error": {
                "code": code,
                "message": message,
                "details": details or [],
            },
            "meta": {
                "timestamp": datetime.utcnow().isoformat(),
                "request_id": str(uuid4()),
            },
        }

    try:
        notebook_uuid = UUID(notebook_id)
    except ValueError:
        # from None: the raw ValueError adds nothing for API consumers.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=_error_detail("VALIDATION_ERROR", "Invalid notebook ID format"),
        ) from None
    try:
        service = await get_chat_service()
        response = await service.send_message(
            notebook_uuid,
            data.message,
            data.include_references,
        )
        return ApiResponse(
            success=True,
            data=response,
            error=None,
            meta=ResponseMeta(
                timestamp=datetime.utcnow(),
                request_id=uuid4(),
            ),
        )
    except ValidationError as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=_error_detail(e.code, e.message, e.details),
        ) from e
    except NotFoundError as e:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=_error_detail(e.code, e.message),
        ) from e
    except NotebookLMError as e:
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=_error_detail(e.code, e.message),
        ) from e
@router.get(
    "/{notebook_id}/chat/history",
    response_model=ApiResponse[list],
    summary="Get chat history",
    description="Get the chat history for a notebook.",
)
async def get_chat_history(notebook_id: str):
    """Return the chat history for a notebook.

    Args:
        notebook_id: Notebook UUID (path parameter, validated here).

    Returns:
        ApiResponse wrapping a list of chat messages.

    Raises:
        HTTPException: 400 for an invalid ID, 404 for a missing notebook,
            502 when the external NotebookLM API fails.
    """
    # Local import: the module top level only imports uuid4.
    from uuid import UUID

    def _error_detail(code: str, message: str) -> dict:
        """Build the standard error envelope shared by all failure paths."""
        # TODO: datetime.utcnow() is deprecated since Python 3.12; migrate
        # to datetime.now(timezone.utc) once naive-vs-aware handling is
        # audited across the API.
        return {
            "success": False,
            "error": {
                "code": code,
                "message": message,
                "details": [],
            },
            "meta": {
                "timestamp": datetime.utcnow().isoformat(),
                "request_id": str(uuid4()),
            },
        }

    try:
        notebook_uuid = UUID(notebook_id)
    except ValueError:
        # from None: the raw ValueError adds nothing for API consumers.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=_error_detail("VALIDATION_ERROR", "Invalid notebook ID format"),
        ) from None
    try:
        service = await get_chat_service()
        history = await service.get_history(notebook_uuid)
        return ApiResponse(
            success=True,
            data=history,
            error=None,
            meta=ResponseMeta(
                timestamp=datetime.utcnow(),
                request_id=uuid4(),
            ),
        )
    except NotFoundError as e:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=_error_detail(e.code, e.message),
        ) from e
    except NotebookLMError as e:
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=_error_detail(e.code, e.message),
        ) from e

View File

@@ -0,0 +1,189 @@
"""Chat service for business logic.
This module contains the ChatService class which handles
all business logic for chat operations.
"""
from datetime import datetime
from typing import Any
from uuid import UUID, uuid4
from notebooklm_agent.api.models.responses import ChatMessage, ChatResponse, SourceReference
from notebooklm_agent.core.exceptions import NotebookLMError, NotFoundError, ValidationError
class ChatService:
    """Service layer for chat operations on a NotebookLM notebook.

    Handles message validation, sending messages, and retrieving chat
    history, converting raw client objects into API response models.

    Attributes:
        _client: The notebooklm-py client instance (created lazily).
    """

    def __init__(self, client: Any = None) -> None:
        """Initialize the chat service.

        Args:
            client: Optional notebooklm-py client instance. If not
                provided, one is created lazily on first use.
        """
        self._client = client

    async def _get_client(self) -> Any:
        """Return the notebooklm-py client, creating it on first use.

        Returns:
            The notebooklm-py client instance.
        """
        if self._client is None:
            # Lazy import: avoids a circular import at module load time.
            from notebooklm import NotebookLMClient

            self._client = await NotebookLMClient.from_storage()
        return self._client

    def _validate_message(self, message: str) -> str:
        """Validate and normalize a chat message.

        Args:
            message: The raw message text.

        Returns:
            The message with surrounding whitespace stripped.

        Raises:
            ValidationError: If the message is empty/blank or longer
                than 2000 characters.
        """
        if not message or not message.strip():
            raise ValidationError("Message cannot be empty")
        # Length is checked on the raw input (before stripping), matching
        # the request model's max_length constraint.
        if len(message) > 2000:
            raise ValidationError("Message must be at most 2000 characters")
        return message.strip()

    def _build_sources(self, obj: Any) -> list[SourceReference]:
        """Extract source references from a client response/message object.

        Checks a non-empty ``citations`` attribute first, then falls back
        to ``references``; missing citation fields get safe defaults.

        Args:
            obj: A client chat response or history-message object.

        Returns:
            A (possibly empty) list of SourceReference models.
        """
        citations = getattr(obj, "citations", []) or getattr(obj, "references", [])
        return [
            SourceReference(
                source_id=getattr(c, "source_id", str(uuid4())),
                title=getattr(c, "title", "Unknown Source"),
                snippet=getattr(c, "snippet", None),
            )
            for c in citations
        ]

    async def send_message(
        self,
        notebook_id: UUID,
        message: str,
        include_references: bool = True,
    ) -> ChatResponse:
        """Send a message to a notebook's chat and return the response.

        Args:
            notebook_id: The notebook ID.
            message: The message to send.
            include_references: Whether to include source references.

        Returns:
            The chat response with message text and sources.

        Raises:
            ValidationError: If the message is invalid.
            NotFoundError: If the notebook was not found.
            NotebookLMError: If the external API fails.
        """
        # Validate before touching the network.
        validated_message = self._validate_message(message)
        try:
            client = await self._get_client()
            notebook = await client.notebooks.get(str(notebook_id))
            chat = notebook.chat
            response = await chat.send_message(validated_message)
            # The client may expose the answer as .text; fall back to str().
            response_text = getattr(response, "text", str(response))
            sources = self._build_sources(response) if include_references else []
            return ChatResponse(
                message=response_text,
                sources=sources,
                # TODO: datetime.utcnow() is deprecated since Python 3.12;
                # migrate to datetime.now(timezone.utc) alongside the API.
                timestamp=datetime.utcnow(),
            )
        except ValidationError:
            # Re-raise our own validation errors untouched.
            raise
        except Exception as e:
            # The client does not expose typed errors; classify by message.
            if "not found" in str(e).lower():
                raise NotFoundError("Notebook", str(notebook_id)) from e
            raise NotebookLMError(f"Failed to send message: {e}") from e

    async def get_history(self, notebook_id: UUID) -> list[ChatMessage]:
        """Get the chat history for a notebook.

        Args:
            notebook_id: The notebook ID.

        Returns:
            List of chat messages, oldest-to-newest as returned by the
            client.

        Raises:
            NotFoundError: If the notebook was not found.
            NotebookLMError: If the external API fails.
        """
        try:
            client = await self._get_client()
            notebook = await client.notebooks.get(str(notebook_id))
            chat = notebook.chat
            history = await chat.get_history()
            messages = []
            for msg in history:
                role = getattr(msg, "role", "unknown")
                # sources stays None (not []) when there are no citations,
                # matching the ChatMessage model's optional field.
                sources = None
                if role == "assistant":
                    refs = self._build_sources(msg)
                    if refs:
                        sources = refs
                messages.append(
                    ChatMessage(
                        id=getattr(msg, "id", str(uuid4())),
                        role=role,
                        content=getattr(msg, "content", ""),
                        timestamp=getattr(msg, "timestamp", datetime.utcnow()),
                        sources=sources,
                    )
                )
            return messages
        except Exception as e:
            # The client does not expose typed errors; classify by message.
            if "not found" in str(e).lower():
                raise NotFoundError("Notebook", str(notebook_id)) from e
            raise NotebookLMError(f"Failed to get chat history: {e}") from e