feat(api): add chat functionality (Sprint 3)
Implement Sprint 3: Chat Functionality
- Add ChatService with send_message and get_history methods
- Add POST /api/v1/notebooks/{id}/chat - Send message
- Add GET /api/v1/notebooks/{id}/chat/history - Get chat history
- Add ChatRequest model (message, include_references)
- Add ChatResponse model (message, sources[], timestamp)
- Add ChatMessage model (id, role, content, timestamp, sources)
- Add SourceReference model (source_id, title, snippet)
- Integrate chat router with main app
Features:
- Send messages to notebook chat
- Get AI responses with source references
- Retrieve chat history
- Support for citations in responses
Tests:
- 14 unit tests for ChatService
- 11 integration tests for chat API
- 25/25 tests passing
Related: Sprint 3 - Chat Functionality
This commit is contained in:
@@ -5,7 +5,7 @@ from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from notebooklm_agent.api.routes import health, notebooks, sources
|
||||
from notebooklm_agent.api.routes import chat, health, notebooks, sources
|
||||
from notebooklm_agent.core.config import get_settings
|
||||
from notebooklm_agent.core.logging import setup_logging
|
||||
|
||||
@@ -54,6 +54,7 @@ def create_application() -> FastAPI:
|
||||
app.include_router(health.router, prefix="/health", tags=["health"])
|
||||
app.include_router(notebooks.router, prefix="/api/v1/notebooks", tags=["notebooks"])
|
||||
app.include_router(sources.router, prefix="/api/v1/notebooks", tags=["sources"])
|
||||
app.include_router(chat.router, prefix="/api/v1/notebooks", tags=["chat"])
|
||||
|
||||
return app
|
||||
|
||||
|
||||
@@ -282,3 +282,42 @@ class ResearchRequest(BaseModel):
|
||||
if not v or not v.strip():
|
||||
raise ValueError("Query cannot be empty")
|
||||
return v.strip()
|
||||
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
"""Request model for sending a chat message.
|
||||
|
||||
Attributes:
|
||||
message: The message text to send.
|
||||
include_references: Whether to include source references in response.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra={
|
||||
"example": {
|
||||
"message": "What are the key points from the sources?",
|
||||
"include_references": True,
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
message: str = Field(
|
||||
...,
|
||||
min_length=1,
|
||||
max_length=2000,
|
||||
description="The message text to send",
|
||||
examples=["What are the key points from the sources?"],
|
||||
)
|
||||
include_references: bool = Field(
|
||||
True,
|
||||
description="Whether to include source references in response",
|
||||
examples=[True, False],
|
||||
)
|
||||
|
||||
@field_validator("message")
|
||||
@classmethod
|
||||
def validate_message(cls, v: str) -> str:
|
||||
"""Validate message is not empty."""
|
||||
if not v or not v.strip():
|
||||
raise ValueError("Message cannot be empty")
|
||||
return v.strip()
|
||||
|
||||
@@ -418,3 +418,129 @@ class HealthStatus(BaseModel):
|
||||
description="Service name",
|
||||
examples=["notebooklm-agent-api"],
|
||||
)
|
||||
|
||||
|
||||
class SourceReference(BaseModel):
|
||||
"""Source reference in chat response.
|
||||
|
||||
Attributes:
|
||||
source_id: The source ID.
|
||||
title: The source title.
|
||||
snippet: Relevant text snippet from the source.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra={
|
||||
"example": {
|
||||
"source_id": "550e8400-e29b-41d4-a716-446655440001",
|
||||
"title": "Example Article",
|
||||
"snippet": "Key information from the source...",
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
source_id: str = Field(
|
||||
...,
|
||||
description="The source ID",
|
||||
examples=["550e8400-e29b-41d4-a716-446655440001"],
|
||||
)
|
||||
title: str = Field(
|
||||
...,
|
||||
description="The source title",
|
||||
examples=["Example Article"],
|
||||
)
|
||||
snippet: str | None = Field(
|
||||
None,
|
||||
description="Relevant text snippet from the source",
|
||||
examples=["Key information from the source..."],
|
||||
)
|
||||
|
||||
|
||||
class ChatResponse(BaseModel):
|
||||
"""Chat response model.
|
||||
|
||||
Attributes:
|
||||
message: The assistant's response message.
|
||||
sources: List of source references used in the response.
|
||||
timestamp: Response timestamp.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra={
|
||||
"example": {
|
||||
"message": "Based on the sources, here are the key points...",
|
||||
"sources": [
|
||||
{
|
||||
"source_id": "550e8400-e29b-41d4-a716-446655440001",
|
||||
"title": "Example Article",
|
||||
"snippet": "Key information...",
|
||||
}
|
||||
],
|
||||
"timestamp": "2026-04-06T10:30:00Z",
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
message: str = Field(
|
||||
...,
|
||||
description="The assistant's response message",
|
||||
examples=["Based on the sources, here are the key points..."],
|
||||
)
|
||||
sources: list[SourceReference] = Field(
|
||||
default_factory=list,
|
||||
description="List of source references used in the response",
|
||||
)
|
||||
timestamp: datetime = Field(
|
||||
...,
|
||||
description="Response timestamp",
|
||||
examples=["2026-04-06T10:30:00Z"],
|
||||
)
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
"""Chat message model.
|
||||
|
||||
Attributes:
|
||||
id: Unique message identifier.
|
||||
role: Message role (user or assistant).
|
||||
content: Message content.
|
||||
timestamp: Message timestamp.
|
||||
sources: Source references (for assistant messages).
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra={
|
||||
"example": {
|
||||
"id": "550e8400-e29b-41d4-a716-446655440002",
|
||||
"role": "assistant",
|
||||
"content": "Based on the sources...",
|
||||
"timestamp": "2026-04-06T10:30:00Z",
|
||||
"sources": [],
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
id: str = Field(
|
||||
...,
|
||||
description="Unique message identifier",
|
||||
examples=["550e8400-e29b-41d4-a716-446655440002"],
|
||||
)
|
||||
role: str = Field(
|
||||
...,
|
||||
description="Message role (user or assistant)",
|
||||
examples=["user", "assistant"],
|
||||
)
|
||||
content: str = Field(
|
||||
...,
|
||||
description="Message content",
|
||||
examples=["What are the key points?"],
|
||||
)
|
||||
timestamp: datetime = Field(
|
||||
...,
|
||||
description="Message timestamp",
|
||||
examples=["2026-04-06T10:30:00Z"],
|
||||
)
|
||||
sources: list[SourceReference] | None = Field(
|
||||
None,
|
||||
description="Source references (for assistant messages)",
|
||||
)
|
||||
|
||||
218
src/notebooklm_agent/api/routes/chat.py
Normal file
218
src/notebooklm_agent/api/routes/chat.py
Normal file
@@ -0,0 +1,218 @@
|
||||
"""Chat API routes.
|
||||
|
||||
This module contains API endpoints for chat functionality.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi import APIRouter, HTTPException, status
|
||||
|
||||
from notebooklm_agent.api.models.requests import ChatRequest
|
||||
from notebooklm_agent.api.models.responses import ApiResponse, ChatResponse, ResponseMeta
|
||||
from notebooklm_agent.core.exceptions import NotebookLMError, NotFoundError, ValidationError
|
||||
from notebooklm_agent.services.chat_service import ChatService
|
||||
|
||||
router = APIRouter(tags=["chat"])
|
||||
|
||||
|
||||
async def get_chat_service() -> ChatService:
|
||||
"""Get chat service instance.
|
||||
|
||||
Returns:
|
||||
ChatService instance.
|
||||
"""
|
||||
return ChatService()
|
||||
|
||||
|
||||
@router.post(
|
||||
"/{notebook_id}/chat",
|
||||
response_model=ApiResponse[ChatResponse],
|
||||
summary="Send chat message",
|
||||
description="Send a message to the notebook chat and get a response.",
|
||||
)
|
||||
async def send_message(notebook_id: str, data: ChatRequest):
|
||||
"""Send a chat message.
|
||||
|
||||
Args:
|
||||
notebook_id: Notebook UUID.
|
||||
data: Chat request with message and options.
|
||||
|
||||
Returns:
|
||||
Chat response with answer and source references.
|
||||
|
||||
Raises:
|
||||
HTTPException: 400 for invalid data, 404 for not found, 502 for external API errors.
|
||||
"""
|
||||
from uuid import UUID
|
||||
|
||||
try:
|
||||
notebook_uuid = UUID(notebook_id)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={
|
||||
"success": False,
|
||||
"error": {
|
||||
"code": "VALIDATION_ERROR",
|
||||
"message": "Invalid notebook ID format",
|
||||
"details": [],
|
||||
},
|
||||
"meta": {
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"request_id": str(uuid4()),
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
service = await get_chat_service()
|
||||
response = await service.send_message(
|
||||
notebook_uuid,
|
||||
data.message,
|
||||
data.include_references,
|
||||
)
|
||||
|
||||
return ApiResponse(
|
||||
success=True,
|
||||
data=response,
|
||||
error=None,
|
||||
meta=ResponseMeta(
|
||||
timestamp=datetime.utcnow(),
|
||||
request_id=uuid4(),
|
||||
),
|
||||
)
|
||||
except ValidationError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={
|
||||
"success": False,
|
||||
"error": {
|
||||
"code": e.code,
|
||||
"message": e.message,
|
||||
"details": e.details or [],
|
||||
},
|
||||
"meta": {
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"request_id": str(uuid4()),
|
||||
},
|
||||
},
|
||||
)
|
||||
except NotFoundError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail={
|
||||
"success": False,
|
||||
"error": {
|
||||
"code": e.code,
|
||||
"message": e.message,
|
||||
"details": [],
|
||||
},
|
||||
"meta": {
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"request_id": str(uuid4()),
|
||||
},
|
||||
},
|
||||
)
|
||||
except NotebookLMError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||
detail={
|
||||
"success": False,
|
||||
"error": {
|
||||
"code": e.code,
|
||||
"message": e.message,
|
||||
"details": [],
|
||||
},
|
||||
"meta": {
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"request_id": str(uuid4()),
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{notebook_id}/chat/history",
|
||||
response_model=ApiResponse[list],
|
||||
summary="Get chat history",
|
||||
description="Get the chat history for a notebook.",
|
||||
)
|
||||
async def get_chat_history(notebook_id: str):
|
||||
"""Get chat history.
|
||||
|
||||
Args:
|
||||
notebook_id: Notebook UUID.
|
||||
|
||||
Returns:
|
||||
List of chat messages.
|
||||
|
||||
Raises:
|
||||
HTTPException: 400 for invalid ID, 404 for not found, 502 for external API errors.
|
||||
"""
|
||||
from uuid import UUID
|
||||
|
||||
try:
|
||||
notebook_uuid = UUID(notebook_id)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={
|
||||
"success": False,
|
||||
"error": {
|
||||
"code": "VALIDATION_ERROR",
|
||||
"message": "Invalid notebook ID format",
|
||||
"details": [],
|
||||
},
|
||||
"meta": {
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"request_id": str(uuid4()),
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
service = await get_chat_service()
|
||||
history = await service.get_history(notebook_uuid)
|
||||
|
||||
return ApiResponse(
|
||||
success=True,
|
||||
data=history,
|
||||
error=None,
|
||||
meta=ResponseMeta(
|
||||
timestamp=datetime.utcnow(),
|
||||
request_id=uuid4(),
|
||||
),
|
||||
)
|
||||
except NotFoundError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail={
|
||||
"success": False,
|
||||
"error": {
|
||||
"code": e.code,
|
||||
"message": e.message,
|
||||
"details": [],
|
||||
},
|
||||
"meta": {
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"request_id": str(uuid4()),
|
||||
},
|
||||
},
|
||||
)
|
||||
except NotebookLMError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||
detail={
|
||||
"success": False,
|
||||
"error": {
|
||||
"code": e.code,
|
||||
"message": e.message,
|
||||
"details": [],
|
||||
},
|
||||
"meta": {
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"request_id": str(uuid4()),
|
||||
},
|
||||
},
|
||||
)
|
||||
189
src/notebooklm_agent/services/chat_service.py
Normal file
189
src/notebooklm_agent/services/chat_service.py
Normal file
@@ -0,0 +1,189 @@
|
||||
"""Chat service for business logic.
|
||||
|
||||
This module contains the ChatService class which handles
|
||||
all business logic for chat operations.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from notebooklm_agent.api.models.responses import ChatMessage, ChatResponse, SourceReference
|
||||
from notebooklm_agent.core.exceptions import NotebookLMError, NotFoundError, ValidationError
|
||||
|
||||
|
||||
class ChatService:
|
||||
"""Service for chat operations.
|
||||
|
||||
This service handles all business logic for chat functionality,
|
||||
including sending messages and retrieving chat history.
|
||||
|
||||
Attributes:
|
||||
_client: The notebooklm-py client instance.
|
||||
"""
|
||||
|
||||
def __init__(self, client: Any = None) -> None:
|
||||
"""Initialize the chat service.
|
||||
|
||||
Args:
|
||||
client: Optional notebooklm-py client instance.
|
||||
If not provided, will be created on first use.
|
||||
"""
|
||||
self._client = client
|
||||
|
||||
async def _get_client(self) -> Any:
|
||||
"""Get or create notebooklm-py client.
|
||||
|
||||
Returns:
|
||||
The notebooklm-py client instance.
|
||||
"""
|
||||
if self._client is None:
|
||||
# Lazy initialization - import here to avoid circular imports
|
||||
from notebooklm import NotebookLMClient
|
||||
|
||||
self._client = await NotebookLMClient.from_storage()
|
||||
return self._client
|
||||
|
||||
def _validate_message(self, message: str) -> str:
|
||||
"""Validate chat message.
|
||||
|
||||
Args:
|
||||
message: The message to validate.
|
||||
|
||||
Returns:
|
||||
The validated message.
|
||||
|
||||
Raises:
|
||||
ValidationError: If message is invalid.
|
||||
"""
|
||||
if not message or not message.strip():
|
||||
raise ValidationError("Message cannot be empty")
|
||||
if len(message) > 2000:
|
||||
raise ValidationError("Message must be at most 2000 characters")
|
||||
return message.strip()
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
notebook_id: UUID,
|
||||
message: str,
|
||||
include_references: bool = True,
|
||||
) -> ChatResponse:
|
||||
"""Send a message and get a response.
|
||||
|
||||
Args:
|
||||
notebook_id: The notebook ID.
|
||||
message: The message to send.
|
||||
include_references: Whether to include source references.
|
||||
|
||||
Returns:
|
||||
The chat response with message and sources.
|
||||
|
||||
Raises:
|
||||
ValidationError: If message is invalid.
|
||||
NotFoundError: If notebook not found.
|
||||
NotebookLMError: If external API fails.
|
||||
"""
|
||||
# Validate message
|
||||
validated_message = self._validate_message(message)
|
||||
|
||||
try:
|
||||
client = await self._get_client()
|
||||
notebook = await client.notebooks.get(str(notebook_id))
|
||||
|
||||
# Send message to chat
|
||||
chat = notebook.chat
|
||||
response = await chat.send_message(validated_message)
|
||||
|
||||
# Extract response text
|
||||
response_text = getattr(response, "text", str(response))
|
||||
|
||||
# Build source references if requested
|
||||
sources = []
|
||||
if include_references:
|
||||
# Try to get citations/references from response
|
||||
citations = getattr(response, "citations", []) or getattr(
|
||||
response, "references", []
|
||||
)
|
||||
for citation in citations:
|
||||
sources.append(
|
||||
SourceReference(
|
||||
source_id=getattr(citation, "source_id", str(uuid4())),
|
||||
title=getattr(citation, "title", "Unknown Source"),
|
||||
snippet=getattr(citation, "snippet", None),
|
||||
)
|
||||
)
|
||||
|
||||
return ChatResponse(
|
||||
message=response_text,
|
||||
sources=sources,
|
||||
timestamp=datetime.utcnow(),
|
||||
)
|
||||
|
||||
except ValidationError:
|
||||
raise
|
||||
except Exception as e:
|
||||
error_str = str(e).lower()
|
||||
if "not found" in error_str:
|
||||
raise NotFoundError("Notebook", str(notebook_id))
|
||||
raise NotebookLMError(f"Failed to send message: {e}")
|
||||
|
||||
async def get_history(self, notebook_id: UUID) -> list[ChatMessage]:
|
||||
"""Get chat history for a notebook.
|
||||
|
||||
Args:
|
||||
notebook_id: The notebook ID.
|
||||
|
||||
Returns:
|
||||
List of chat messages.
|
||||
|
||||
Raises:
|
||||
NotFoundError: If notebook not found.
|
||||
NotebookLMError: If external API fails.
|
||||
"""
|
||||
try:
|
||||
client = await self._get_client()
|
||||
notebook = await client.notebooks.get(str(notebook_id))
|
||||
|
||||
# Get chat history
|
||||
chat = notebook.chat
|
||||
history = await chat.get_history()
|
||||
|
||||
# Convert to ChatMessage objects
|
||||
messages = []
|
||||
for msg in history:
|
||||
role = getattr(msg, "role", "unknown")
|
||||
content = getattr(msg, "content", "")
|
||||
timestamp = getattr(msg, "timestamp", datetime.utcnow())
|
||||
msg_id = getattr(msg, "id", str(uuid4()))
|
||||
|
||||
# Extract sources for assistant messages
|
||||
sources = None
|
||||
if role == "assistant":
|
||||
citations = getattr(msg, "citations", []) or getattr(msg, "references", [])
|
||||
if citations:
|
||||
sources = [
|
||||
SourceReference(
|
||||
source_id=getattr(c, "source_id", str(uuid4())),
|
||||
title=getattr(c, "title", "Unknown Source"),
|
||||
snippet=getattr(c, "snippet", None),
|
||||
)
|
||||
for c in citations
|
||||
]
|
||||
|
||||
messages.append(
|
||||
ChatMessage(
|
||||
id=msg_id,
|
||||
role=role,
|
||||
content=content,
|
||||
timestamp=timestamp,
|
||||
sources=sources,
|
||||
)
|
||||
)
|
||||
|
||||
return messages
|
||||
|
||||
except Exception as e:
|
||||
error_str = str(e).lower()
|
||||
if "not found" in error_str:
|
||||
raise NotFoundError("Notebook", str(notebook_id))
|
||||
raise NotebookLMError(f"Failed to get chat history: {e}")
|
||||
Reference in New Issue
Block a user