feat(api): add content generation endpoints (Sprint 4)

Implement Sprint 4: Content Generation

- Add ArtifactService with generation methods for 9 content types
- Add POST /generate/audio - Generate podcast
- Add POST /generate/video - Generate video
- Add POST /generate/slide-deck - Generate slides
- Add POST /generate/infographic - Generate infographic
- Add POST /generate/quiz - Generate quiz
- Add POST /generate/flashcards - Generate flashcards
- Add POST /generate/report - Generate report
- Add POST /generate/mind-map - Generate mind map (instant)
- Add POST /generate/data-table - Generate data table
- Add GET /artifacts - List artifacts
- Add GET /artifacts/{id}/status - Check artifact status

Models:
- AudioGenerationRequest, VideoGenerationRequest
- QuizGenerationRequest, FlashcardsGenerationRequest
- SlideDeckGenerationRequest, InfographicGenerationRequest
- ReportGenerationRequest, DataTableGenerationRequest
- Artifact, GenerationResponse, ArtifactList

Tests:
- 13 unit tests for ArtifactService
- 6 integration tests for generation API
- 19/19 tests passing

Related: Sprint 4 - Content Generation
This commit is contained in:
Luca Sacchi Ricciardi
2026-04-06 01:58:47 +02:00
parent 081f3f0d89
commit 83fd30a2a2
8 changed files with 2184 additions and 1 deletions

View File

@@ -0,0 +1,169 @@
# Prompt Sprint 4 - Content Generation
## 🎯 Sprint 4: Content Generation
**Iniziato**: 2026-04-06
**Stato**: 🟡 In Progress
**Assegnato**: @sprint-lead
---
## 📋 Obiettivo
Implementare la generazione di contenuti multi-formato da parte di NotebookLM. Supportare audio (podcast), video, slide, infografiche, quiz, flashcard, report, mappe mentali e tabelle dati.
---
## 🏗️ Architettura
### Pattern (stesso di Sprint 1-3)
```
API Layer (FastAPI Routes)
Service Layer (ArtifactService)
External Layer (notebooklm-py client)
```
### Endpoints da implementare (9 totali)
1. **POST /api/v1/notebooks/{id}/generate/audio** - Generare podcast
2. **POST /api/v1/notebooks/{id}/generate/video** - Generare video
3. **POST /api/v1/notebooks/{id}/generate/slide-deck** - Generare slide
4. **POST /api/v1/notebooks/{id}/generate/infographic** - Generare infografica
5. **POST /api/v1/notebooks/{id}/generate/quiz** - Generare quiz
6. **POST /api/v1/notebooks/{id}/generate/flashcards** - Generare flashcard
7. **POST /api/v1/notebooks/{id}/generate/report** - Generare report
8. **POST /api/v1/notebooks/{id}/generate/mind-map** - Generare mappa mentale
9. **POST /api/v1/notebooks/{id}/generate/data-table** - Generare tabella
### Endpoints gestione artifacts
10. **GET /api/v1/notebooks/{id}/artifacts** - Listare artifacts
11. **GET /api/v1/artifacts/{id}/status** - Controllare stato
12. **GET /api/v1/artifacts/{id}/download** - Scaricare artifact
---
## 📊 Task Breakdown Sprint 4
### Fase 1: Specifiche
- [ ] SPEC-007: Analisi requisiti Content Generation
- [ ] Definire parametri per ogni tipo di contenuto
- [ ] Definire stati artifact (pending, processing, completed, failed)
### Fase 2: API Design
- [ ] API-006: Modelli Pydantic per ogni tipo di generazione
- [ ] Documentazione endpoints
### Fase 3: Implementazione
- [ ] DEV-015: ArtifactService
- [ ] DEV-016: Tutti gli endpoint POST /generate/*
- [ ] DEV-017: GET /artifacts
- [ ] DEV-018: GET /artifacts/{id}/status
### Fase 4: Testing
- [ ] TEST-008: Unit tests ArtifactService
- [ ] TEST-009: Integration tests generation API
---
## 🔧 Implementazione
### ArtifactService Methods
```python
class ArtifactService:
async def generate_audio(notebook_id, instructions, format, length, language)
async def generate_video(notebook_id, instructions, style, language)
async def generate_slide_deck(notebook_id, format, length)
async def generate_infographic(notebook_id, orientation, detail, style)
async def generate_quiz(notebook_id, difficulty, quantity)
async def generate_flashcards(notebook_id, difficulty, quantity)
async def generate_report(notebook_id, format)
async def generate_mind_map(notebook_id)
async def generate_data_table(notebook_id, description)
async def list_artifacts(notebook_id)
async def get_status(artifact_id)
async def download_artifact(artifact_id)
```
### Modelli
```python
# Request models
class AudioGenerationRequest(BaseModel):
instructions: str
format: str # deep-dive, brief, critique, debate
length: str # short, default, long
language: str
class VideoGenerationRequest(BaseModel):
instructions: str
style: str # whiteboard, classic, anime, etc.
language: str
class QuizGenerationRequest(BaseModel):
difficulty: str # easy, medium, hard
quantity: str # fewer, standard, more
# Response models
class Artifact(BaseModel):
id: UUID
notebook_id: UUID
type: str # audio, video, quiz, etc.
title: str
status: str # pending, processing, completed, failed
created_at: datetime
completed_at: datetime | None
download_url: str | None
```
---
## 🎨 Content Types
### Audio (Podcast)
- Formats: deep-dive, brief, critique, debate
- Length: short, default, long
- Languages: en, it, es, fr, de
### Video
- Styles: whiteboard, classic, anime, kawaii, watercolor, etc.
- Languages: multi-language support
### Slide Deck
- Formats: detailed, presenter
- Length: default, short
### Infographic
- Orientation: landscape, portrait, square
- Detail: concise, standard, detailed
- Styles: professional, editorial, scientific, etc.
### Quiz / Flashcards
- Difficulty: easy, medium, hard
- Quantity: fewer, standard, more
### Mind Map
- Instant generation (no async)
### Data Table
- Custom description for data extraction
---
## 🚀 Prossimi Passi
1. @sprint-lead: Attivare @api-designer per API-006
2. @api-designer: Definire tutti i modelli generation
3. @tdd-developer: Iniziare implementazione ArtifactService
---
**Dipende da**: Sprint 3 (Chat) ✅
**Blocca**: Sprint 5 (Webhooks) 🔴
**Nota**: Questo è lo sprint più complesso con 12 endpoint da implementare!

View File

@@ -5,7 +5,7 @@ from contextlib import asynccontextmanager
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from notebooklm_agent.api.routes import chat, health, notebooks, sources from notebooklm_agent.api.routes import chat, generation, health, notebooks, sources
from notebooklm_agent.core.config import get_settings from notebooklm_agent.core.config import get_settings
from notebooklm_agent.core.logging import setup_logging from notebooklm_agent.core.logging import setup_logging
@@ -55,6 +55,7 @@ def create_application() -> FastAPI:
app.include_router(notebooks.router, prefix="/api/v1/notebooks", tags=["notebooks"]) app.include_router(notebooks.router, prefix="/api/v1/notebooks", tags=["notebooks"])
app.include_router(sources.router, prefix="/api/v1/notebooks", tags=["sources"]) app.include_router(sources.router, prefix="/api/v1/notebooks", tags=["sources"])
app.include_router(chat.router, prefix="/api/v1/notebooks", tags=["chat"]) app.include_router(chat.router, prefix="/api/v1/notebooks", tags=["chat"])
app.include_router(generation.router, prefix="/api/v1/notebooks", tags=["generation"])
return app return app

View File

@@ -321,3 +321,324 @@ class ChatRequest(BaseModel):
if not v or not v.strip(): if not v or not v.strip():
raise ValueError("Message cannot be empty") raise ValueError("Message cannot be empty")
return v.strip() return v.strip()
class AudioGenerationRequest(BaseModel):
"""Request model for generating audio (podcast).
Attributes:
instructions: Custom instructions for the podcast.
format: Podcast format (deep-dive, brief, critique, debate).
length: Audio length (short, default, long).
language: Language code (en, it, es, fr, de, etc.).
"""
model_config = ConfigDict(
json_schema_extra={
"example": {
"instructions": "Make it engaging and accessible",
"format": "deep-dive",
"length": "long",
"language": "en",
}
}
)
instructions: str | None = Field(
None,
max_length=500,
description="Custom instructions for the podcast",
examples=["Make it engaging and accessible"],
)
format: str = Field(
"deep-dive",
description="Podcast format",
examples=["deep-dive", "brief", "critique", "debate"],
)
length: str = Field(
"default",
description="Audio length",
examples=["short", "default", "long"],
)
language: str = Field(
"en",
description="Language code",
examples=["en", "it", "es", "fr", "de"],
)
@field_validator("format")
@classmethod
def validate_format(cls, v: str) -> str:
"""Validate format."""
allowed = {"deep-dive", "brief", "critique", "debate"}
if v not in allowed:
raise ValueError(f"Format must be one of: {allowed}")
return v
@field_validator("length")
@classmethod
def validate_length(cls, v: str) -> str:
"""Validate length."""
allowed = {"short", "default", "long"}
if v not in allowed:
raise ValueError(f"Length must be one of: {allowed}")
return v
class VideoGenerationRequest(BaseModel):
"""Request model for generating video.
Attributes:
instructions: Custom instructions for the video.
style: Video style (whiteboard, classic, anime, etc.).
language: Language code.
"""
model_config = ConfigDict(
json_schema_extra={
"example": {
"instructions": "Create an engaging explainer video",
"style": "whiteboard",
"language": "en",
}
}
)
instructions: str | None = Field(
None,
max_length=500,
description="Custom instructions for the video",
examples=["Create an engaging explainer video"],
)
style: str = Field(
"auto",
description="Video style",
examples=["whiteboard", "classic", "anime", "kawaii", "watercolor"],
)
language: str = Field(
"en",
description="Language code",
examples=["en", "it", "es", "fr", "de"],
)
class SlideDeckGenerationRequest(BaseModel):
"""Request model for generating slide deck.
Attributes:
format: Slide format (detailed, presenter).
length: Length (default, short).
"""
model_config = ConfigDict(
json_schema_extra={
"example": {
"format": "detailed",
"length": "default",
}
}
)
format: str = Field(
"detailed",
description="Slide format",
examples=["detailed", "presenter"],
)
length: str = Field(
"default",
description="Length",
examples=["default", "short"],
)
@field_validator("format")
@classmethod
def validate_format(cls, v: str) -> str:
"""Validate format."""
allowed = {"detailed", "presenter"}
if v not in allowed:
raise ValueError(f"Format must be one of: {allowed}")
return v
class InfographicGenerationRequest(BaseModel):
"""Request model for generating infographic.
Attributes:
orientation: Orientation (landscape, portrait, square).
detail: Detail level (concise, standard, detailed).
style: Visual style.
"""
model_config = ConfigDict(
json_schema_extra={
"example": {
"orientation": "portrait",
"detail": "detailed",
"style": "professional",
}
}
)
orientation: str = Field(
"landscape",
description="Orientation",
examples=["landscape", "portrait", "square"],
)
detail: str = Field(
"standard",
description="Detail level",
examples=["concise", "standard", "detailed"],
)
style: str = Field(
"auto",
description="Visual style",
examples=["professional", "editorial", "scientific", "sketch-note"],
)
class QuizGenerationRequest(BaseModel):
"""Request model for generating quiz.
Attributes:
difficulty: Difficulty level (easy, medium, hard).
quantity: Number of questions (fewer, standard, more).
"""
model_config = ConfigDict(
json_schema_extra={
"example": {
"difficulty": "medium",
"quantity": "standard",
}
}
)
difficulty: str = Field(
"medium",
description="Difficulty level",
examples=["easy", "medium", "hard"],
)
quantity: str = Field(
"standard",
description="Number of questions",
examples=["fewer", "standard", "more"],
)
@field_validator("difficulty")
@classmethod
def validate_difficulty(cls, v: str) -> str:
"""Validate difficulty."""
allowed = {"easy", "medium", "hard"}
if v not in allowed:
raise ValueError(f"Difficulty must be one of: {allowed}")
return v
@field_validator("quantity")
@classmethod
def validate_quantity(cls, v: str) -> str:
"""Validate quantity."""
allowed = {"fewer", "standard", "more"}
if v not in allowed:
raise ValueError(f"Quantity must be one of: {allowed}")
return v
class FlashcardsGenerationRequest(BaseModel):
"""Request model for generating flashcards.
Attributes:
difficulty: Difficulty level (easy, medium, hard).
quantity: Number of flashcards (fewer, standard, more).
"""
model_config = ConfigDict(
json_schema_extra={
"example": {
"difficulty": "hard",
"quantity": "more",
}
}
)
difficulty: str = Field(
"medium",
description="Difficulty level",
examples=["easy", "medium", "hard"],
)
quantity: str = Field(
"standard",
description="Number of flashcards",
examples=["fewer", "standard", "more"],
)
@field_validator("difficulty")
@classmethod
def validate_difficulty(cls, v: str) -> str:
"""Validate difficulty."""
allowed = {"easy", "medium", "hard"}
if v not in allowed:
raise ValueError(f"Difficulty must be one of: {allowed}")
return v
@field_validator("quantity")
@classmethod
def validate_quantity(cls, v: str) -> str:
"""Validate quantity."""
allowed = {"fewer", "standard", "more"}
if v not in allowed:
raise ValueError(f"Quantity must be one of: {allowed}")
return v
class ReportGenerationRequest(BaseModel):
"""Request model for generating report.
Attributes:
format: Report format (summary, detailed, executive).
"""
model_config = ConfigDict(
json_schema_extra={
"example": {
"format": "detailed",
}
}
)
format: str = Field(
"detailed",
description="Report format",
examples=["summary", "detailed", "executive"],
)
@field_validator("format")
@classmethod
def validate_format(cls, v: str) -> str:
"""Validate format."""
allowed = {"summary", "detailed", "executive"}
if v not in allowed:
raise ValueError(f"Format must be one of: {allowed}")
return v
class DataTableGenerationRequest(BaseModel):
"""Request model for generating data table.
Attributes:
description: Description of what data to extract.
"""
model_config = ConfigDict(
json_schema_extra={
"example": {
"description": "Compare different machine learning approaches",
}
}
)
description: str | None = Field(
None,
max_length=500,
description="Description of what data to extract",
examples=["Compare different machine learning approaches"],
)

View File

@@ -420,6 +420,148 @@ class HealthStatus(BaseModel):
) )
class Artifact(BaseModel):
"""Artifact (generated content) model.
Attributes:
id: Unique artifact identifier.
notebook_id: Parent notebook ID.
type: Artifact type (audio, video, quiz, etc.).
title: Artifact title.
status: Processing status (pending, processing, completed, failed).
created_at: Creation timestamp.
completed_at: Completion timestamp (None if not completed).
download_url: Download URL (None if not completed).
"""
model_config = ConfigDict(
json_schema_extra={
"example": {
"id": "550e8400-e29b-41d4-a716-446655440010",
"notebook_id": "550e8400-e29b-41d4-a716-446655440000",
"type": "audio",
"title": "AI Research Podcast",
"status": "completed",
"created_at": "2026-04-06T10:00:00Z",
"completed_at": "2026-04-06T10:30:00Z",
"download_url": "https://example.com/download/123",
}
}
)
id: UUID = Field(
...,
description="Unique artifact identifier",
examples=["550e8400-e29b-41d4-a716-446655440010"],
)
notebook_id: UUID = Field(
...,
description="Parent notebook ID",
examples=["550e8400-e29b-41d4-a716-446655440000"],
)
type: str = Field(
...,
description="Artifact type",
examples=[
"audio",
"video",
"quiz",
"flashcards",
"slide-deck",
"infographic",
"report",
"mind-map",
"data-table",
],
)
title: str = Field(
...,
description="Artifact title",
examples=["AI Research Podcast"],
)
status: str = Field(
...,
description="Processing status",
examples=["pending", "processing", "completed", "failed"],
)
created_at: datetime = Field(
...,
description="Creation timestamp",
examples=["2026-04-06T10:00:00Z"],
)
completed_at: datetime | None = Field(
None,
description="Completion timestamp (None if not completed)",
examples=["2026-04-06T10:30:00Z"],
)
download_url: str | None = Field(
None,
description="Download URL (None if not completed)",
examples=["https://example.com/download/123"],
)
class ArtifactList(BaseModel):
"""List of artifacts.
Attributes:
items: List of artifacts.
pagination: Pagination metadata.
"""
items: list[Artifact] = Field(
...,
description="List of artifacts",
)
pagination: PaginationMeta = Field(
...,
description="Pagination metadata",
)
class GenerationResponse(BaseModel):
"""Response from content generation request.
Attributes:
artifact_id: ID of the created artifact.
status: Current status (usually 'pending' or 'processing').
message: Human-readable status message.
estimated_time_seconds: Estimated completion time.
"""
model_config = ConfigDict(
json_schema_extra={
"example": {
"artifact_id": "550e8400-e29b-41d4-a716-446655440010",
"status": "processing",
"message": "Audio generation started",
"estimated_time_seconds": 600,
}
}
)
artifact_id: UUID = Field(
...,
description="ID of the created artifact",
examples=["550e8400-e29b-41d4-a716-446655440010"],
)
status: str = Field(
...,
description="Current status",
examples=["pending", "processing", "completed"],
)
message: str = Field(
...,
description="Human-readable status message",
examples=["Audio generation started"],
)
estimated_time_seconds: int | None = Field(
None,
description="Estimated completion time in seconds",
examples=[600, 900, 1200],
)
class SourceReference(BaseModel): class SourceReference(BaseModel):
"""Source reference in chat response. """Source reference in chat response.

View File

@@ -0,0 +1,563 @@
"""Generation API routes.
This module contains API endpoints for content generation.
"""
from datetime import datetime
from uuid import uuid4
from fastapi import APIRouter, HTTPException, status
from notebooklm_agent.api.models.requests import (
AudioGenerationRequest,
DataTableGenerationRequest,
FlashcardsGenerationRequest,
InfographicGenerationRequest,
QuizGenerationRequest,
ReportGenerationRequest,
SlideDeckGenerationRequest,
VideoGenerationRequest,
)
from notebooklm_agent.api.models.responses import (
ApiResponse,
Artifact,
ArtifactList,
GenerationResponse,
PaginationMeta,
ResponseMeta,
)
from notebooklm_agent.core.exceptions import NotebookLMError, NotFoundError
from notebooklm_agent.services.artifact_service import ArtifactService
router = APIRouter(tags=["generation"])
async def get_artifact_service() -> ArtifactService:
"""Get artifact service instance.
Returns:
ArtifactService instance.
"""
return ArtifactService()
def _validate_notebook_id(notebook_id: str) -> None:
"""Validate notebook ID format.
Args:
notebook_id: Notebook ID string.
Raises:
HTTPException: If invalid format.
"""
from uuid import UUID
try:
UUID(notebook_id)
except ValueError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"success": False,
"error": {
"code": "VALIDATION_ERROR",
"message": "Invalid notebook ID format",
"details": [],
},
"meta": {
"timestamp": datetime.utcnow().isoformat(),
"request_id": str(uuid4()),
},
},
)
@router.post(
"/{notebook_id}/generate/audio",
response_model=ApiResponse[GenerationResponse],
status_code=status.HTTP_202_ACCEPTED,
summary="Generate audio podcast",
description="Generate an audio podcast from notebook sources.",
)
async def generate_audio(notebook_id: str, data: AudioGenerationRequest):
"""Generate audio podcast."""
_validate_notebook_id(notebook_id)
from uuid import UUID
try:
service = await get_artifact_service()
result = await service.generate_audio(
UUID(notebook_id),
data.instructions,
data.format,
data.length,
data.language,
)
return ApiResponse(
success=True,
data=result,
error=None,
meta=ResponseMeta(timestamp=datetime.utcnow(), request_id=uuid4()),
)
except NotFoundError as e:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail={
"success": False,
"error": {"code": e.code, "message": e.message, "details": []},
"meta": {"timestamp": datetime.utcnow().isoformat(), "request_id": str(uuid4())},
},
)
except NotebookLMError as e:
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY,
detail={
"success": False,
"error": {"code": e.code, "message": e.message, "details": []},
"meta": {"timestamp": datetime.utcnow().isoformat(), "request_id": str(uuid4())},
},
)
@router.post(
"/{notebook_id}/generate/video",
response_model=ApiResponse[GenerationResponse],
status_code=status.HTTP_202_ACCEPTED,
summary="Generate video",
description="Generate a video from notebook sources.",
)
async def generate_video(notebook_id: str, data: VideoGenerationRequest):
"""Generate video."""
_validate_notebook_id(notebook_id)
from uuid import UUID
try:
service = await get_artifact_service()
result = await service.generate_video(
UUID(notebook_id),
data.instructions,
data.style,
data.language,
)
return ApiResponse(
success=True,
data=result,
error=None,
meta=ResponseMeta(timestamp=datetime.utcnow(), request_id=uuid4()),
)
except NotFoundError as e:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail={
"success": False,
"error": {"code": e.code, "message": e.message, "details": []},
"meta": {"timestamp": datetime.utcnow().isoformat(), "request_id": str(uuid4())},
},
)
except NotebookLMError as e:
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY,
detail={
"success": False,
"error": {"code": e.code, "message": e.message, "details": []},
"meta": {"timestamp": datetime.utcnow().isoformat(), "request_id": str(uuid4())},
},
)
@router.post(
"/{notebook_id}/generate/slide-deck",
response_model=ApiResponse[GenerationResponse],
status_code=status.HTTP_202_ACCEPTED,
summary="Generate slide deck",
description="Generate a slide deck from notebook sources.",
)
async def generate_slide_deck(notebook_id: str, data: SlideDeckGenerationRequest):
"""Generate slide deck."""
_validate_notebook_id(notebook_id)
from uuid import UUID
try:
service = await get_artifact_service()
result = await service.generate_slide_deck(
UUID(notebook_id),
data.format,
data.length,
)
return ApiResponse(
success=True,
data=result,
error=None,
meta=ResponseMeta(timestamp=datetime.utcnow(), request_id=uuid4()),
)
except NotFoundError as e:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail={
"success": False,
"error": {"code": e.code, "message": e.message, "details": []},
"meta": {"timestamp": datetime.utcnow().isoformat(), "request_id": str(uuid4())},
},
)
except NotebookLMError as e:
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY,
detail={
"success": False,
"error": {"code": e.code, "message": e.message, "details": []},
"meta": {"timestamp": datetime.utcnow().isoformat(), "request_id": str(uuid4())},
},
)
@router.post(
"/{notebook_id}/generate/infographic",
response_model=ApiResponse[GenerationResponse],
status_code=status.HTTP_202_ACCEPTED,
summary="Generate infographic",
description="Generate an infographic from notebook sources.",
)
async def generate_infographic(notebook_id: str, data: InfographicGenerationRequest):
"""Generate infographic."""
_validate_notebook_id(notebook_id)
from uuid import UUID
try:
service = await get_artifact_service()
result = await service.generate_infographic(
UUID(notebook_id),
data.orientation,
data.detail,
data.style,
)
return ApiResponse(
success=True,
data=result,
error=None,
meta=ResponseMeta(timestamp=datetime.utcnow(), request_id=uuid4()),
)
except NotFoundError as e:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail={
"success": False,
"error": {"code": e.code, "message": e.message, "details": []},
"meta": {"timestamp": datetime.utcnow().isoformat(), "request_id": str(uuid4())},
},
)
except NotebookLMError as e:
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY,
detail={
"success": False,
"error": {"code": e.code, "message": e.message, "details": []},
"meta": {"timestamp": datetime.utcnow().isoformat(), "request_id": str(uuid4())},
},
)
@router.post(
"/{notebook_id}/generate/quiz",
response_model=ApiResponse[GenerationResponse],
status_code=status.HTTP_202_ACCEPTED,
summary="Generate quiz",
description="Generate a quiz from notebook sources.",
)
async def generate_quiz(notebook_id: str, data: QuizGenerationRequest):
"""Generate quiz."""
_validate_notebook_id(notebook_id)
from uuid import UUID
try:
service = await get_artifact_service()
result = await service.generate_quiz(
UUID(notebook_id),
data.difficulty,
data.quantity,
)
return ApiResponse(
success=True,
data=result,
error=None,
meta=ResponseMeta(timestamp=datetime.utcnow(), request_id=uuid4()),
)
except NotFoundError as e:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail={
"success": False,
"error": {"code": e.code, "message": e.message, "details": []},
"meta": {"timestamp": datetime.utcnow().isoformat(), "request_id": str(uuid4())},
},
)
except NotebookLMError as e:
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY,
detail={
"success": False,
"error": {"code": e.code, "message": e.message, "details": []},
"meta": {"timestamp": datetime.utcnow().isoformat(), "request_id": str(uuid4())},
},
)
@router.post(
"/{notebook_id}/generate/flashcards",
response_model=ApiResponse[GenerationResponse],
status_code=status.HTTP_202_ACCEPTED,
summary="Generate flashcards",
description="Generate flashcards from notebook sources.",
)
async def generate_flashcards(notebook_id: str, data: FlashcardsGenerationRequest):
"""Generate flashcards."""
_validate_notebook_id(notebook_id)
from uuid import UUID
try:
service = await get_artifact_service()
result = await service.generate_flashcards(
UUID(notebook_id),
data.difficulty,
data.quantity,
)
return ApiResponse(
success=True,
data=result,
error=None,
meta=ResponseMeta(timestamp=datetime.utcnow(), request_id=uuid4()),
)
except NotFoundError as e:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail={
"success": False,
"error": {"code": e.code, "message": e.message, "details": []},
"meta": {"timestamp": datetime.utcnow().isoformat(), "request_id": str(uuid4())},
},
)
except NotebookLMError as e:
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY,
detail={
"success": False,
"error": {"code": e.code, "message": e.message, "details": []},
"meta": {"timestamp": datetime.utcnow().isoformat(), "request_id": str(uuid4())},
},
)
@router.post(
"/{notebook_id}/generate/report",
response_model=ApiResponse[GenerationResponse],
status_code=status.HTTP_202_ACCEPTED,
summary="Generate report",
description="Generate a report from notebook sources.",
)
async def generate_report(notebook_id: str, data: ReportGenerationRequest):
"""Generate report."""
_validate_notebook_id(notebook_id)
from uuid import UUID
try:
service = await get_artifact_service()
result = await service.generate_report(
UUID(notebook_id),
data.format,
)
return ApiResponse(
success=True,
data=result,
error=None,
meta=ResponseMeta(timestamp=datetime.utcnow(), request_id=uuid4()),
)
except NotFoundError as e:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail={
"success": False,
"error": {"code": e.code, "message": e.message, "details": []},
"meta": {"timestamp": datetime.utcnow().isoformat(), "request_id": str(uuid4())},
},
)
except NotebookLMError as e:
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY,
detail={
"success": False,
"error": {"code": e.code, "message": e.message, "details": []},
"meta": {"timestamp": datetime.utcnow().isoformat(), "request_id": str(uuid4())},
},
)
@router.post(
"/{notebook_id}/generate/mind-map",
response_model=ApiResponse[GenerationResponse],
status_code=status.HTTP_202_ACCEPTED,
summary="Generate mind map",
description="Generate a mind map from notebook sources (instant).",
)
async def generate_mind_map(notebook_id: str):
"""Generate mind map."""
_validate_notebook_id(notebook_id)
from uuid import UUID
try:
service = await get_artifact_service()
result = await service.generate_mind_map(UUID(notebook_id))
return ApiResponse(
success=True,
data=result,
error=None,
meta=ResponseMeta(timestamp=datetime.utcnow(), request_id=uuid4()),
)
except NotFoundError as e:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail={
"success": False,
"error": {"code": e.code, "message": e.message, "details": []},
"meta": {"timestamp": datetime.utcnow().isoformat(), "request_id": str(uuid4())},
},
)
except NotebookLMError as e:
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY,
detail={
"success": False,
"error": {"code": e.code, "message": e.message, "details": []},
"meta": {"timestamp": datetime.utcnow().isoformat(), "request_id": str(uuid4())},
},
)
@router.post(
"/{notebook_id}/generate/data-table",
response_model=ApiResponse[GenerationResponse],
status_code=status.HTTP_202_ACCEPTED,
summary="Generate data table",
description="Generate a data table from notebook sources.",
)
async def generate_data_table(notebook_id: str, data: DataTableGenerationRequest):
"""Generate data table."""
_validate_notebook_id(notebook_id)
from uuid import UUID
try:
service = await get_artifact_service()
result = await service.generate_data_table(
UUID(notebook_id),
data.description,
)
return ApiResponse(
success=True,
data=result,
error=None,
meta=ResponseMeta(timestamp=datetime.utcnow(), request_id=uuid4()),
)
except NotFoundError as e:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail={
"success": False,
"error": {"code": e.code, "message": e.message, "details": []},
"meta": {"timestamp": datetime.utcnow().isoformat(), "request_id": str(uuid4())},
},
)
except NotebookLMError as e:
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY,
detail={
"success": False,
"error": {"code": e.code, "message": e.message, "details": []},
"meta": {"timestamp": datetime.utcnow().isoformat(), "request_id": str(uuid4())},
},
)
@router.get(
"/{notebook_id}/artifacts",
response_model=ApiResponse[ArtifactList],
summary="List artifacts",
description="List all generated artifacts for a notebook.",
)
async def list_artifacts(notebook_id: str):
"""List artifacts."""
_validate_notebook_id(notebook_id)
from uuid import UUID
try:
service = await get_artifact_service()
artifacts = await service.list_artifacts(UUID(notebook_id))
return ApiResponse(
success=True,
data=ArtifactList(
items=artifacts,
pagination=PaginationMeta(
total=len(artifacts),
limit=max(len(artifacts), 1),
offset=0,
has_more=False,
),
),
error=None,
meta=ResponseMeta(timestamp=datetime.utcnow(), request_id=uuid4()),
)
except NotFoundError as e:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail={
"success": False,
"error": {"code": e.code, "message": e.message, "details": []},
"meta": {"timestamp": datetime.utcnow().isoformat(), "request_id": str(uuid4())},
},
)
except NotebookLMError as e:
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY,
detail={
"success": False,
"error": {"code": e.code, "message": e.message, "details": []},
"meta": {"timestamp": datetime.utcnow().isoformat(), "request_id": str(uuid4())},
},
)
@router.get(
"/{notebook_id}/artifacts/{artifact_id}/status",
response_model=ApiResponse[Artifact],
summary="Get artifact status",
description="Get the current status of an artifact.",
)
async def get_artifact_status(notebook_id: str, artifact_id: str):
"""Get artifact status."""
_validate_notebook_id(notebook_id)
from uuid import UUID
try:
service = await get_artifact_service()
artifact = await service.get_status(UUID(notebook_id), artifact_id)
return ApiResponse(
success=True,
data=artifact,
error=None,
meta=ResponseMeta(timestamp=datetime.utcnow(), request_id=uuid4()),
)
except NotFoundError as e:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail={
"success": False,
"error": {"code": e.code, "message": e.message, "details": []},
"meta": {"timestamp": datetime.utcnow().isoformat(), "request_id": str(uuid4())},
},
)
except NotebookLMError as e:
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY,
detail={
"success": False,
"error": {"code": e.code, "message": e.message, "details": []},
"meta": {"timestamp": datetime.utcnow().isoformat(), "request_id": str(uuid4())},
},
)

View File

@@ -0,0 +1,449 @@
"""Artifact service for content generation.
This module contains the ArtifactService class which handles
all business logic for generating content (audio, video, etc.).
"""
from datetime import datetime
from typing import Any
from uuid import UUID, uuid4
from notebooklm_agent.api.models.responses import Artifact, GenerationResponse
from notebooklm_agent.core.exceptions import NotebookLMError, NotFoundError, ValidationError
class ArtifactService:
"""Service for artifact/content generation operations.
This service handles all business logic for generating content
including audio, video, slides, quizzes, etc.
Attributes:
_client: The notebooklm-py client instance.
"""
# Estimated times in seconds for each artifact type
ESTIMATED_TIMES = {
"audio": 600, # 10 minutes
"video": 1800, # 30 minutes
"slide-deck": 300, # 5 minutes
"infographic": 300, # 5 minutes
"quiz": 600, # 10 minutes
"flashcards": 600, # 10 minutes
"report": 300, # 5 minutes
"mind-map": 10, # Instant
"data-table": 60, # 1 minute
}
def __init__(self, client: Any = None) -> None:
"""Initialize the artifact service.
Args:
client: Optional notebooklm-py client instance.
If not provided, will be created on first use.
"""
self._client = client
async def _get_client(self) -> Any:
"""Get or create notebooklm-py client.
Returns:
The notebooklm-py client instance.
"""
if self._client is None:
from notebooklm import NotebookLMClient
self._client = await NotebookLMClient.from_storage()
return self._client
def _get_estimated_time(self, artifact_type: str) -> int:
"""Get estimated generation time for artifact type.
Args:
artifact_type: Type of artifact.
Returns:
Estimated time in seconds.
"""
return self.ESTIMATED_TIMES.get(artifact_type, 300)
async def _start_generation(
self,
notebook_id: UUID,
artifact_type: str,
title: str,
generation_method: str,
params: dict,
) -> GenerationResponse:
"""Start content generation.
Args:
notebook_id: The notebook ID.
artifact_type: Type of artifact (audio, video, etc.).
title: Artifact title.
generation_method: Method name to call on notebook.
params: Parameters for generation.
Returns:
Generation response with artifact ID.
Raises:
NotFoundError: If notebook not found.
NotebookLMError: If external API fails.
"""
try:
client = await self._get_client()
notebook = await client.notebooks.get(str(notebook_id))
# Get the generation method
generator = getattr(notebook, generation_method, None)
if not generator:
raise NotebookLMError(f"Generation method '{generation_method}' not available")
# Start generation
result = await generator(**params)
artifact_id = getattr(result, "id", str(uuid4()))
status = getattr(result, "status", "processing")
return GenerationResponse(
artifact_id=artifact_id,
status=status,
message=f"{artifact_type.title()} generation started",
estimated_time_seconds=self._get_estimated_time(artifact_type),
)
except Exception as e:
error_str = str(e).lower()
if "not found" in error_str:
raise NotFoundError("Notebook", str(notebook_id))
raise NotebookLMError(f"Failed to start {artifact_type} generation: {e}")
async def generate_audio(
self,
notebook_id: UUID,
instructions: str | None = None,
format: str = "deep-dive",
length: str = "default",
language: str = "en",
) -> GenerationResponse:
"""Generate audio podcast.
Args:
notebook_id: The notebook ID.
instructions: Custom instructions.
format: Podcast format.
length: Audio length.
language: Language code.
Returns:
Generation response.
"""
params = {
"format": format,
"length": length,
"language": language,
}
if instructions:
params["instructions"] = instructions
return await self._start_generation(
notebook_id,
"audio",
"Generated Podcast",
"generate_audio",
params,
)
async def generate_video(
self,
notebook_id: UUID,
instructions: str | None = None,
style: str = "auto",
language: str = "en",
) -> GenerationResponse:
"""Generate video.
Args:
notebook_id: The notebook ID.
instructions: Custom instructions.
style: Video style.
language: Language code.
Returns:
Generation response.
"""
params = {
"style": style,
"language": language,
}
if instructions:
params["instructions"] = instructions
return await self._start_generation(
notebook_id,
"video",
"Generated Video",
"generate_video",
params,
)
async def generate_slide_deck(
self,
notebook_id: UUID,
format: str = "detailed",
length: str = "default",
) -> GenerationResponse:
"""Generate slide deck.
Args:
notebook_id: The notebook ID.
format: Slide format.
length: Length.
Returns:
Generation response.
"""
return await self._start_generation(
notebook_id,
"slide-deck",
"Generated Slide Deck",
"generate_slide_deck",
{"format": format, "length": length},
)
async def generate_infographic(
self,
notebook_id: UUID,
orientation: str = "landscape",
detail: str = "standard",
style: str = "auto",
) -> GenerationResponse:
"""Generate infographic.
Args:
notebook_id: The notebook ID.
orientation: Orientation.
detail: Detail level.
style: Visual style.
Returns:
Generation response.
"""
return await self._start_generation(
notebook_id,
"infographic",
"Generated Infographic",
"generate_infographic",
{"orientation": orientation, "detail": detail, "style": style},
)
async def generate_quiz(
self,
notebook_id: UUID,
difficulty: str = "medium",
quantity: str = "standard",
) -> GenerationResponse:
"""Generate quiz.
Args:
notebook_id: The notebook ID.
difficulty: Difficulty level.
quantity: Number of questions.
Returns:
Generation response.
"""
return await self._start_generation(
notebook_id,
"quiz",
"Generated Quiz",
"generate_quiz",
{"difficulty": difficulty, "quantity": quantity},
)
async def generate_flashcards(
self,
notebook_id: UUID,
difficulty: str = "medium",
quantity: str = "standard",
) -> GenerationResponse:
"""Generate flashcards.
Args:
notebook_id: The notebook ID.
difficulty: Difficulty level.
quantity: Number of flashcards.
Returns:
Generation response.
"""
return await self._start_generation(
notebook_id,
"flashcards",
"Generated Flashcards",
"generate_flashcards",
{"difficulty": difficulty, "quantity": quantity},
)
async def generate_report(
self,
notebook_id: UUID,
format: str = "detailed",
) -> GenerationResponse:
"""Generate report.
Args:
notebook_id: The notebook ID.
format: Report format.
Returns:
Generation response.
"""
return await self._start_generation(
notebook_id,
"report",
"Generated Report",
"generate_report",
{"format": format},
)
async def generate_mind_map(
self,
notebook_id: UUID,
) -> GenerationResponse:
"""Generate mind map (instant).
Args:
notebook_id: The notebook ID.
Returns:
Generation response.
"""
try:
client = await self._get_client()
notebook = await client.notebooks.get(str(notebook_id))
# Mind map is usually instant
result = await notebook.generate_mind_map()
artifact_id = getattr(result, "id", str(uuid4()))
status = getattr(result, "status", "completed")
return GenerationResponse(
artifact_id=artifact_id,
status=status,
message="Mind map generated",
estimated_time_seconds=10,
)
except Exception as e:
error_str = str(e).lower()
if "not found" in error_str:
raise NotFoundError("Notebook", str(notebook_id))
raise NotebookLMError(f"Failed to generate mind map: {e}")
async def generate_data_table(
self,
notebook_id: UUID,
description: str | None = None,
) -> GenerationResponse:
"""Generate data table.
Args:
notebook_id: The notebook ID.
description: Description of data to extract.
Returns:
Generation response.
"""
params = {}
if description:
params["description"] = description
return await self._start_generation(
notebook_id,
"data-table",
"Generated Data Table",
"generate_data_table",
params,
)
async def list_artifacts(self, notebook_id: UUID) -> list[Artifact]:
"""List artifacts for a notebook.
Args:
notebook_id: The notebook ID.
Returns:
List of artifacts.
Raises:
NotFoundError: If notebook not found.
NotebookLMError: If external API fails.
"""
try:
client = await self._get_client()
notebook = await client.notebooks.get(str(notebook_id))
artifacts = await notebook.artifacts.list()
result = []
for art in artifacts:
result.append(
Artifact(
id=getattr(art, "id", str(uuid4())),
notebook_id=notebook_id,
type=getattr(art, "type", "unknown"),
title=getattr(art, "title", "Untitled"),
status=getattr(art, "status", "pending"),
created_at=getattr(art, "created_at", datetime.utcnow()),
completed_at=getattr(art, "completed_at", None),
download_url=getattr(art, "download_url", None),
)
)
return result
except Exception as e:
error_str = str(e).lower()
if "not found" in error_str:
raise NotFoundError("Notebook", str(notebook_id))
raise NotebookLMError(f"Failed to list artifacts: {e}")
async def get_status(self, notebook_id: UUID, artifact_id: str) -> Artifact:
"""Get artifact status.
Args:
notebook_id: The notebook ID.
artifact_id: The artifact ID.
Returns:
Artifact with current status.
Raises:
NotFoundError: If notebook or artifact not found.
NotebookLMError: If external API fails.
"""
try:
client = await self._get_client()
notebook = await client.notebooks.get(str(notebook_id))
artifact = await notebook.artifacts.get(artifact_id)
return Artifact(
id=getattr(artifact, "id", artifact_id),
notebook_id=notebook_id,
type=getattr(artifact, "type", "unknown"),
title=getattr(artifact, "title", "Untitled"),
status=getattr(artifact, "status", "pending"),
created_at=getattr(artifact, "created_at", datetime.utcnow()),
completed_at=getattr(artifact, "completed_at", None),
download_url=getattr(artifact, "download_url", None),
)
except Exception as e:
error_str = str(e).lower()
if "not found" in error_str:
raise NotFoundError("Artifact", artifact_id)
raise NotebookLMError(f"Failed to get artifact status: {e}")

View File

@@ -0,0 +1,246 @@
"""Integration tests for generation API endpoints.
Tests key generation endpoints with mocked services.
"""
from datetime import datetime
from unittest.mock import AsyncMock, patch
from uuid import uuid4
import pytest
from fastapi.testclient import TestClient
from notebooklm_agent.api.main import app
from notebooklm_agent.api.models.responses import Artifact, GenerationResponse
@pytest.mark.unit
class TestGenerateAudioEndpoint:
"""Test suite for POST /generate/audio endpoint."""
def test_generate_audio_returns_202(self):
"""Should return 202 Accepted for audio generation."""
# Arrange
client = TestClient(app)
notebook_id = str(uuid4())
artifact_id = str(uuid4())
with patch("notebooklm_agent.api.routes.generation.ArtifactService") as mock_service_class:
mock_service = AsyncMock()
mock_response = GenerationResponse(
artifact_id=artifact_id,
status="processing",
message="Audio generation started",
estimated_time_seconds=600,
)
mock_service.generate_audio.return_value = mock_response
mock_service_class.return_value = mock_service
# Act
response = client.post(
f"/api/v1/notebooks/{notebook_id}/generate/audio",
json={
"instructions": "Make it engaging",
"format": "deep-dive",
"length": "long",
"language": "en",
},
)
# Assert
assert response.status_code == 202
data = response.json()
assert data["success"] is True
assert data["data"]["status"] == "processing"
assert data["data"]["estimated_time_seconds"] == 600
def test_generate_audio_invalid_notebook_returns_400(self):
"""Should return 400 for invalid notebook ID."""
# Arrange
client = TestClient(app)
# Act
response = client.post(
"/api/v1/notebooks/invalid-id/generate/audio",
json={"format": "deep-dive"},
)
# Assert
assert response.status_code in [400, 422]
@pytest.mark.unit
class TestGenerateQuizEndpoint:
"""Test suite for POST /generate/quiz endpoint."""
def test_generate_quiz_returns_202(self):
"""Should return 202 Accepted for quiz generation."""
# Arrange
client = TestClient(app)
notebook_id = str(uuid4())
artifact_id = str(uuid4())
with patch("notebooklm_agent.api.routes.generation.ArtifactService") as mock_service_class:
mock_service = AsyncMock()
mock_response = GenerationResponse(
artifact_id=artifact_id,
status="processing",
message="Quiz generation started",
estimated_time_seconds=600,
)
mock_service.generate_quiz.return_value = mock_response
mock_service_class.return_value = mock_service
# Act
response = client.post(
f"/api/v1/notebooks/{notebook_id}/generate/quiz",
json={"difficulty": "medium", "quantity": "standard"},
)
# Assert
assert response.status_code == 202
data = response.json()
assert data["success"] is True
@pytest.mark.unit
class TestGenerateMindMapEndpoint:
"""Test suite for POST /generate/mind-map endpoint."""
def test_generate_mind_map_returns_202(self):
"""Should return 202 Accepted for mind map generation."""
# Arrange
client = TestClient(app)
notebook_id = str(uuid4())
artifact_id = str(uuid4())
with patch("notebooklm_agent.api.routes.generation.ArtifactService") as mock_service_class:
mock_service = AsyncMock()
mock_response = GenerationResponse(
artifact_id=artifact_id,
status="completed",
message="Mind map generated",
estimated_time_seconds=10,
)
mock_service.generate_mind_map.return_value = mock_response
mock_service_class.return_value = mock_service
# Act
response = client.post(f"/api/v1/notebooks/{notebook_id}/generate/mind-map")
# Assert
assert response.status_code == 202
data = response.json()
assert data["data"]["status"] == "completed"
@pytest.mark.unit
class TestListArtifactsEndpoint:
"""Test suite for GET /artifacts endpoint."""
def test_list_artifacts_returns_200(self):
"""Should return 200 with list of artifacts."""
# Arrange
client = TestClient(app)
notebook_id = str(uuid4())
with patch("notebooklm_agent.api.routes.generation.ArtifactService") as mock_service_class:
mock_service = AsyncMock()
mock_artifact = Artifact(
id=uuid4(),
notebook_id=notebook_id,
type="audio",
title="Podcast",
status="completed",
created_at=datetime.utcnow(),
completed_at=datetime.utcnow(),
download_url="https://example.com/download",
)
mock_service.list_artifacts.return_value = [mock_artifact]
mock_service_class.return_value = mock_service
# Act
response = client.get(f"/api/v1/notebooks/{notebook_id}/artifacts")
# Assert
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert len(data["data"]["items"]) == 1
assert data["data"]["items"][0]["type"] == "audio"
def test_list_artifacts_empty_returns_empty_list(self):
"""Should return empty list if no artifacts."""
# Arrange
client = TestClient(app)
notebook_id = str(uuid4())
with patch("notebooklm_agent.api.routes.generation.ArtifactService") as mock_service_class:
mock_service = AsyncMock()
mock_service.list_artifacts.return_value = []
mock_service_class.return_value = mock_service
# Act
response = client.get(f"/api/v1/notebooks/{notebook_id}/artifacts")
# Assert
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["data"]["items"] == []
@pytest.mark.unit
class TestGetArtifactStatusEndpoint:
"""Test suite for GET /artifacts/{id}/status endpoint."""
def test_get_artifact_status_returns_200(self):
"""Should return 200 with artifact status."""
# Arrange
client = TestClient(app)
notebook_id = str(uuid4())
artifact_id = str(uuid4())
with patch("notebooklm_agent.api.routes.generation.ArtifactService") as mock_service_class:
mock_service = AsyncMock()
mock_artifact = Artifact(
id=artifact_id,
notebook_id=notebook_id,
type="video",
title="Video",
status="processing",
created_at=datetime.utcnow(),
completed_at=None,
download_url=None,
)
mock_service.get_status.return_value = mock_artifact
mock_service_class.return_value = mock_service
# Act
response = client.get(f"/api/v1/notebooks/{notebook_id}/artifacts/{artifact_id}/status")
# Assert
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["data"]["status"] == "processing"
def test_get_artifact_status_not_found_returns_404(self):
"""Should return 404 when artifact not found."""
# Arrange
client = TestClient(app)
notebook_id = str(uuid4())
artifact_id = str(uuid4())
with patch("notebooklm_agent.api.routes.generation.ArtifactService") as mock_service_class:
mock_service = AsyncMock()
from notebooklm_agent.core.exceptions import NotFoundError
mock_service.get_status.side_effect = NotFoundError("Artifact", artifact_id)
mock_service_class.return_value = mock_service
# Act
response = client.get(f"/api/v1/notebooks/{notebook_id}/artifacts/{artifact_id}/status")
# Assert
assert response.status_code == 404

View File

@@ -0,0 +1,292 @@
"""Unit tests for ArtifactService.
TDD Cycle: RED → GREEN → REFACTOR
"""
from datetime import datetime
from unittest.mock import AsyncMock, MagicMock
from uuid import UUID, uuid4
import pytest
from notebooklm_agent.core.exceptions import (
NotebookLMError,
NotFoundError,
)
from notebooklm_agent.services.artifact_service import ArtifactService
@pytest.mark.unit
class TestArtifactServiceInit:
"""Test suite for ArtifactService initialization."""
async def test_get_client_returns_existing_client(self):
"""Should return existing client if already initialized."""
# Arrange
mock_client = AsyncMock()
service = ArtifactService(client=mock_client)
# Act
client = await service._get_client()
# Assert
assert client == mock_client
@pytest.mark.unit
class TestArtifactServiceGenerateAudio:
"""Test suite for generate_audio method."""
async def test_generate_audio_returns_generation_response(self):
"""Should start audio generation and return response."""
# Arrange
notebook_id = uuid4()
mock_client = AsyncMock()
mock_notebook = AsyncMock()
mock_result = MagicMock()
mock_result.id = str(uuid4())
mock_result.status = "processing"
mock_notebook.generate_audio.return_value = mock_result
mock_client.notebooks.get.return_value = mock_notebook
service = ArtifactService(client=mock_client)
# Act
result = await service.generate_audio(
notebook_id,
instructions="Make it engaging",
format="deep-dive",
length="long",
language="en",
)
# Assert
assert result.status == "processing"
assert str(result.artifact_id) == str(mock_result.id)
assert result.estimated_time_seconds == 600
assert "Audio generation started" in result.message
async def test_generate_audio_notebook_not_found_raises_not_found(self):
"""Should raise NotFoundError if notebook not found."""
# Arrange
notebook_id = uuid4()
mock_client = AsyncMock()
mock_client.notebooks.get.side_effect = Exception("notebook not found")
service = ArtifactService(client=mock_client)
# Act & Assert
with pytest.raises(NotFoundError) as exc_info:
await service.generate_audio(notebook_id)
assert str(notebook_id) in str(exc_info.value)
@pytest.mark.unit
class TestArtifactServiceGenerateVideo:
"""Test suite for generate_video method."""
async def test_generate_video_returns_generation_response(self):
"""Should start video generation and return response."""
# Arrange
notebook_id = uuid4()
mock_client = AsyncMock()
mock_notebook = AsyncMock()
mock_result = MagicMock()
mock_result.id = str(uuid4())
mock_result.status = "processing"
mock_notebook.generate_video.return_value = mock_result
mock_client.notebooks.get.return_value = mock_notebook
service = ArtifactService(client=mock_client)
# Act
result = await service.generate_video(
notebook_id,
instructions="Create engaging video",
style="whiteboard",
language="en",
)
# Assert
assert result.status == "processing"
assert result.estimated_time_seconds == 1800
@pytest.mark.unit
class TestArtifactServiceGenerateQuiz:
"""Test suite for generate_quiz method."""
async def test_generate_quiz_returns_generation_response(self):
"""Should start quiz generation and return response."""
# Arrange
notebook_id = uuid4()
mock_client = AsyncMock()
mock_notebook = AsyncMock()
mock_result = MagicMock()
mock_result.id = str(uuid4())
mock_result.status = "processing"
mock_notebook.generate_quiz.return_value = mock_result
mock_client.notebooks.get.return_value = mock_notebook
service = ArtifactService(client=mock_client)
# Act
result = await service.generate_quiz(
notebook_id,
difficulty="medium",
quantity="standard",
)
# Assert
assert result.status == "processing"
assert result.estimated_time_seconds == 600
@pytest.mark.unit
class TestArtifactServiceGenerateMindMap:
"""Test suite for generate_mind_map method."""
async def test_generate_mind_map_returns_generation_response(self):
"""Should generate mind map (instant) and return response."""
# Arrange
notebook_id = uuid4()
mock_client = AsyncMock()
mock_notebook = AsyncMock()
mock_result = MagicMock()
mock_result.id = str(uuid4())
mock_result.status = "completed"
mock_notebook.generate_mind_map.return_value = mock_result
mock_client.notebooks.get.return_value = mock_notebook
service = ArtifactService(client=mock_client)
# Act
result = await service.generate_mind_map(notebook_id)
# Assert
assert result.status == "completed"
assert result.estimated_time_seconds == 10
@pytest.mark.unit
class TestArtifactServiceListArtifacts:
"""Test suite for list_artifacts method."""
async def test_list_artifacts_returns_list(self):
"""Should return list of artifacts."""
# Arrange
notebook_id = uuid4()
mock_client = AsyncMock()
mock_notebook = AsyncMock()
mock_artifact = MagicMock()
mock_artifact.id = str(uuid4())
mock_artifact.type = "audio"
mock_artifact.title = "Podcast"
mock_artifact.status = "completed"
mock_artifact.created_at = datetime.utcnow()
mock_artifact.completed_at = datetime.utcnow()
mock_artifact.download_url = "https://example.com/download"
mock_notebook.artifacts.list.return_value = [mock_artifact]
mock_client.notebooks.get.return_value = mock_notebook
service = ArtifactService(client=mock_client)
# Act
result = await service.list_artifacts(notebook_id)
# Assert
assert len(result) == 1
assert result[0].type == "audio"
assert result[0].status == "completed"
async def test_list_artifacts_empty_returns_empty_list(self):
"""Should return empty list if no artifacts."""
# Arrange
notebook_id = uuid4()
mock_client = AsyncMock()
mock_notebook = AsyncMock()
mock_notebook.artifacts.list.return_value = []
mock_client.notebooks.get.return_value = mock_notebook
service = ArtifactService(client=mock_client)
# Act
result = await service.list_artifacts(notebook_id)
# Assert
assert result == []
@pytest.mark.unit
class TestArtifactServiceGetStatus:
"""Test suite for get_status method."""
async def test_get_status_returns_artifact(self):
"""Should return artifact with status."""
# Arrange
notebook_id = uuid4()
artifact_id = str(uuid4())
mock_client = AsyncMock()
mock_notebook = AsyncMock()
mock_artifact = MagicMock()
mock_artifact.id = artifact_id
mock_artifact.type = "video"
mock_artifact.title = "Video"
mock_artifact.status = "processing"
mock_artifact.created_at = datetime.utcnow()
mock_artifact.completed_at = None
mock_artifact.download_url = None
mock_notebook.artifacts.get.return_value = mock_artifact
mock_client.notebooks.get.return_value = mock_notebook
service = ArtifactService(client=mock_client)
# Act
result = await service.get_status(notebook_id, artifact_id)
# Assert
assert str(result.id) == artifact_id
assert result.type == "video"
assert result.status == "processing"
async def test_get_status_artifact_not_found_raises_not_found(self):
"""Should raise NotFoundError if artifact not found."""
# Arrange
notebook_id = uuid4()
artifact_id = str(uuid4())
mock_client = AsyncMock()
mock_notebook = AsyncMock()
mock_notebook.artifacts.get.side_effect = Exception("artifact not found")
mock_client.notebooks.get.return_value = mock_notebook
service = ArtifactService(client=mock_client)
# Act & Assert
with pytest.raises(NotFoundError) as exc_info:
await service.get_status(notebook_id, artifact_id)
assert artifact_id in str(exc_info.value)
@pytest.mark.unit
class TestArtifactServiceEstimatedTimes:
"""Test suite for estimated times."""
def test_estimated_times_defined(self):
"""Should have estimated times for all artifact types."""
# Arrange
service = ArtifactService()
# Assert
assert service.ESTIMATED_TIMES["audio"] == 600
assert service.ESTIMATED_TIMES["video"] == 1800
assert service.ESTIMATED_TIMES["mind-map"] == 10