feat(api): add content generation endpoints (Sprint 4)
Implement Sprint 4: Content Generation
- Add ArtifactService with generation methods for 9 content types
- Add POST /generate/audio - Generate podcast
- Add POST /generate/video - Generate video
- Add POST /generate/slide-deck - Generate slides
- Add POST /generate/infographic - Generate infographic
- Add POST /generate/quiz - Generate quiz
- Add POST /generate/flashcards - Generate flashcards
- Add POST /generate/report - Generate report
- Add POST /generate/mind-map - Generate mind map (instant)
- Add POST /generate/data-table - Generate data table
- Add GET /artifacts - List artifacts
- Add GET /artifacts/{id}/status - Check artifact status
Models:
- AudioGenerationRequest, VideoGenerationRequest
- QuizGenerationRequest, FlashcardsGenerationRequest
- SlideDeckGenerationRequest, InfographicGenerationRequest
- ReportGenerationRequest, DataTableGenerationRequest
- Artifact, GenerationResponse, ArtifactList
Tests:
- 13 unit tests for ArtifactService
- 6 integration tests for generation API
- 19/19 tests passing
Related: Sprint 4 - Content Generation
This commit is contained in:
169
prompts/4-content-generation.md
Normal file
169
prompts/4-content-generation.md
Normal file
@@ -0,0 +1,169 @@
|
||||
# Prompt Sprint 4 - Content Generation
|
||||
|
||||
## 🎯 Sprint 4: Content Generation
|
||||
|
||||
**Iniziato**: 2026-04-06
|
||||
**Stato**: 🟡 In Progress
|
||||
**Assegnato**: @sprint-lead
|
||||
|
||||
---
|
||||
|
||||
## 📋 Obiettivo
|
||||
|
||||
Implementare la generazione di contenuti multi-formato da parte di NotebookLM. Supportare audio (podcast), video, slide, infografiche, quiz, flashcard, report, mappe mentali e tabelle dati.
|
||||
|
||||
---
|
||||
|
||||
## 🏗️ Architettura
|
||||
|
||||
### Pattern (stesso di Sprint 1-3)
|
||||
|
||||
```
|
||||
API Layer (FastAPI Routes)
|
||||
↓
|
||||
Service Layer (ArtifactService)
|
||||
↓
|
||||
External Layer (notebooklm-py client)
|
||||
```
|
||||
|
||||
### Endpoints da implementare (9 totali)
|
||||
|
||||
1. **POST /api/v1/notebooks/{id}/generate/audio** - Generare podcast
|
||||
2. **POST /api/v1/notebooks/{id}/generate/video** - Generare video
|
||||
3. **POST /api/v1/notebooks/{id}/generate/slide-deck** - Generare slide
|
||||
4. **POST /api/v1/notebooks/{id}/generate/infographic** - Generare infografica
|
||||
5. **POST /api/v1/notebooks/{id}/generate/quiz** - Generare quiz
|
||||
6. **POST /api/v1/notebooks/{id}/generate/flashcards** - Generare flashcard
|
||||
7. **POST /api/v1/notebooks/{id}/generate/report** - Generare report
|
||||
8. **POST /api/v1/notebooks/{id}/generate/mind-map** - Generare mappa mentale
|
||||
9. **POST /api/v1/notebooks/{id}/generate/data-table** - Generare tabella
|
||||
|
||||
### Endpoints gestione artifacts
|
||||
|
||||
10. **GET /api/v1/notebooks/{id}/artifacts** - Listare artifacts
|
||||
11. **GET /api/v1/artifacts/{id}/status** - Controllare stato
|
||||
12. **GET /api/v1/artifacts/{id}/download** - Scaricare artifact
|
||||
|
||||
---
|
||||
|
||||
## 📊 Task Breakdown Sprint 4
|
||||
|
||||
### Fase 1: Specifiche
|
||||
- [ ] SPEC-007: Analisi requisiti Content Generation
|
||||
- [ ] Definire parametri per ogni tipo di contenuto
|
||||
- [ ] Definire stati artifact (pending, processing, completed, failed)
|
||||
|
||||
### Fase 2: API Design
|
||||
- [ ] API-006: Modelli Pydantic per ogni tipo di generazione
|
||||
- [ ] Documentazione endpoints
|
||||
|
||||
### Fase 3: Implementazione
|
||||
- [ ] DEV-015: ArtifactService
|
||||
- [ ] DEV-016: Tutti gli endpoint POST /generate/*
|
||||
- [ ] DEV-017: GET /artifacts
|
||||
- [ ] DEV-018: GET /artifacts/{id}/status
|
||||
|
||||
### Fase 4: Testing
|
||||
- [ ] TEST-008: Unit tests ArtifactService
|
||||
- [ ] TEST-009: Integration tests generation API
|
||||
|
||||
---
|
||||
|
||||
## 🔧 Implementazione
|
||||
|
||||
### ArtifactService Methods
|
||||
|
||||
```python
|
||||
class ArtifactService:
|
||||
async def generate_audio(notebook_id, instructions, format, length, language)
|
||||
async def generate_video(notebook_id, instructions, style, language)
|
||||
async def generate_slide_deck(notebook_id, format, length)
|
||||
async def generate_infographic(notebook_id, orientation, detail, style)
|
||||
async def generate_quiz(notebook_id, difficulty, quantity)
|
||||
async def generate_flashcards(notebook_id, difficulty, quantity)
|
||||
async def generate_report(notebook_id, format)
|
||||
async def generate_mind_map(notebook_id)
|
||||
async def generate_data_table(notebook_id, description)
|
||||
|
||||
async def list_artifacts(notebook_id)
|
||||
async def get_status(artifact_id)
|
||||
async def download_artifact(artifact_id)
|
||||
```
|
||||
|
||||
### Modelli
|
||||
|
||||
```python
|
||||
# Request models
|
||||
class AudioGenerationRequest(BaseModel):
|
||||
instructions: str
|
||||
format: str # deep-dive, brief, critique, debate
|
||||
length: str # short, default, long
|
||||
language: str
|
||||
|
||||
class VideoGenerationRequest(BaseModel):
|
||||
instructions: str
|
||||
style: str # whiteboard, classic, anime, etc.
|
||||
language: str
|
||||
|
||||
class QuizGenerationRequest(BaseModel):
|
||||
difficulty: str # easy, medium, hard
|
||||
quantity: str # fewer, standard, more
|
||||
|
||||
# Response models
|
||||
class Artifact(BaseModel):
|
||||
id: UUID
|
||||
notebook_id: UUID
|
||||
type: str # audio, video, quiz, etc.
|
||||
title: str
|
||||
status: str # pending, processing, completed, failed
|
||||
created_at: datetime
|
||||
completed_at: datetime | None
|
||||
download_url: str | None
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 🎨 Content Types
|
||||
|
||||
### Audio (Podcast)
|
||||
- Formats: deep-dive, brief, critique, debate
|
||||
- Length: short, default, long
|
||||
- Languages: en, it, es, fr, de
|
||||
|
||||
### Video
|
||||
- Styles: whiteboard, classic, anime, kawaii, watercolor, etc.
|
||||
- Languages: multi-language support
|
||||
|
||||
### Slide Deck
|
||||
- Formats: detailed, presenter
|
||||
- Length: default, short
|
||||
|
||||
### Infographic
|
||||
- Orientation: landscape, portrait, square
|
||||
- Detail: concise, standard, detailed
|
||||
- Styles: professional, editorial, scientific, etc.
|
||||
|
||||
### Quiz / Flashcards
|
||||
- Difficulty: easy, medium, hard
|
||||
- Quantity: fewer, standard, more
|
||||
|
||||
### Mind Map
|
||||
- Instant generation (no async)
|
||||
|
||||
### Data Table
|
||||
- Custom description for data extraction
|
||||
|
||||
---
|
||||
|
||||
## 🚀 Prossimi Passi
|
||||
|
||||
1. @sprint-lead: Attivare @api-designer per API-006
|
||||
2. @api-designer: Definire tutti i modelli generation
|
||||
3. @tdd-developer: Iniziare implementazione ArtifactService
|
||||
|
||||
---
|
||||
|
||||
**Dipende da**: Sprint 3 (Chat) ✅
|
||||
**Blocca**: Sprint 5 (Webhooks) 🔴
|
||||
|
||||
**Nota**: Questo è lo sprint più complesso con 12 endpoint da implementare!
|
||||
@@ -5,7 +5,7 @@ from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from notebooklm_agent.api.routes import chat, health, notebooks, sources
|
||||
from notebooklm_agent.api.routes import chat, generation, health, notebooks, sources
|
||||
from notebooklm_agent.core.config import get_settings
|
||||
from notebooklm_agent.core.logging import setup_logging
|
||||
|
||||
@@ -55,6 +55,7 @@ def create_application() -> FastAPI:
|
||||
app.include_router(notebooks.router, prefix="/api/v1/notebooks", tags=["notebooks"])
|
||||
app.include_router(sources.router, prefix="/api/v1/notebooks", tags=["sources"])
|
||||
app.include_router(chat.router, prefix="/api/v1/notebooks", tags=["chat"])
|
||||
app.include_router(generation.router, prefix="/api/v1/notebooks", tags=["generation"])
|
||||
|
||||
return app
|
||||
|
||||
|
||||
@@ -321,3 +321,324 @@ class ChatRequest(BaseModel):
|
||||
if not v or not v.strip():
|
||||
raise ValueError("Message cannot be empty")
|
||||
return v.strip()
|
||||
|
||||
|
||||
class AudioGenerationRequest(BaseModel):
|
||||
"""Request model for generating audio (podcast).
|
||||
|
||||
Attributes:
|
||||
instructions: Custom instructions for the podcast.
|
||||
format: Podcast format (deep-dive, brief, critique, debate).
|
||||
length: Audio length (short, default, long).
|
||||
language: Language code (en, it, es, fr, de, etc.).
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra={
|
||||
"example": {
|
||||
"instructions": "Make it engaging and accessible",
|
||||
"format": "deep-dive",
|
||||
"length": "long",
|
||||
"language": "en",
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
instructions: str | None = Field(
|
||||
None,
|
||||
max_length=500,
|
||||
description="Custom instructions for the podcast",
|
||||
examples=["Make it engaging and accessible"],
|
||||
)
|
||||
format: str = Field(
|
||||
"deep-dive",
|
||||
description="Podcast format",
|
||||
examples=["deep-dive", "brief", "critique", "debate"],
|
||||
)
|
||||
length: str = Field(
|
||||
"default",
|
||||
description="Audio length",
|
||||
examples=["short", "default", "long"],
|
||||
)
|
||||
language: str = Field(
|
||||
"en",
|
||||
description="Language code",
|
||||
examples=["en", "it", "es", "fr", "de"],
|
||||
)
|
||||
|
||||
@field_validator("format")
|
||||
@classmethod
|
||||
def validate_format(cls, v: str) -> str:
|
||||
"""Validate format."""
|
||||
allowed = {"deep-dive", "brief", "critique", "debate"}
|
||||
if v not in allowed:
|
||||
raise ValueError(f"Format must be one of: {allowed}")
|
||||
return v
|
||||
|
||||
@field_validator("length")
|
||||
@classmethod
|
||||
def validate_length(cls, v: str) -> str:
|
||||
"""Validate length."""
|
||||
allowed = {"short", "default", "long"}
|
||||
if v not in allowed:
|
||||
raise ValueError(f"Length must be one of: {allowed}")
|
||||
return v
|
||||
|
||||
|
||||
class VideoGenerationRequest(BaseModel):
|
||||
"""Request model for generating video.
|
||||
|
||||
Attributes:
|
||||
instructions: Custom instructions for the video.
|
||||
style: Video style (whiteboard, classic, anime, etc.).
|
||||
language: Language code.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra={
|
||||
"example": {
|
||||
"instructions": "Create an engaging explainer video",
|
||||
"style": "whiteboard",
|
||||
"language": "en",
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
instructions: str | None = Field(
|
||||
None,
|
||||
max_length=500,
|
||||
description="Custom instructions for the video",
|
||||
examples=["Create an engaging explainer video"],
|
||||
)
|
||||
style: str = Field(
|
||||
"auto",
|
||||
description="Video style",
|
||||
examples=["whiteboard", "classic", "anime", "kawaii", "watercolor"],
|
||||
)
|
||||
language: str = Field(
|
||||
"en",
|
||||
description="Language code",
|
||||
examples=["en", "it", "es", "fr", "de"],
|
||||
)
|
||||
|
||||
|
||||
class SlideDeckGenerationRequest(BaseModel):
|
||||
"""Request model for generating slide deck.
|
||||
|
||||
Attributes:
|
||||
format: Slide format (detailed, presenter).
|
||||
length: Length (default, short).
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra={
|
||||
"example": {
|
||||
"format": "detailed",
|
||||
"length": "default",
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
format: str = Field(
|
||||
"detailed",
|
||||
description="Slide format",
|
||||
examples=["detailed", "presenter"],
|
||||
)
|
||||
length: str = Field(
|
||||
"default",
|
||||
description="Length",
|
||||
examples=["default", "short"],
|
||||
)
|
||||
|
||||
@field_validator("format")
|
||||
@classmethod
|
||||
def validate_format(cls, v: str) -> str:
|
||||
"""Validate format."""
|
||||
allowed = {"detailed", "presenter"}
|
||||
if v not in allowed:
|
||||
raise ValueError(f"Format must be one of: {allowed}")
|
||||
return v
|
||||
|
||||
|
||||
class InfographicGenerationRequest(BaseModel):
|
||||
"""Request model for generating infographic.
|
||||
|
||||
Attributes:
|
||||
orientation: Orientation (landscape, portrait, square).
|
||||
detail: Detail level (concise, standard, detailed).
|
||||
style: Visual style.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra={
|
||||
"example": {
|
||||
"orientation": "portrait",
|
||||
"detail": "detailed",
|
||||
"style": "professional",
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
orientation: str = Field(
|
||||
"landscape",
|
||||
description="Orientation",
|
||||
examples=["landscape", "portrait", "square"],
|
||||
)
|
||||
detail: str = Field(
|
||||
"standard",
|
||||
description="Detail level",
|
||||
examples=["concise", "standard", "detailed"],
|
||||
)
|
||||
style: str = Field(
|
||||
"auto",
|
||||
description="Visual style",
|
||||
examples=["professional", "editorial", "scientific", "sketch-note"],
|
||||
)
|
||||
|
||||
|
||||
class QuizGenerationRequest(BaseModel):
|
||||
"""Request model for generating quiz.
|
||||
|
||||
Attributes:
|
||||
difficulty: Difficulty level (easy, medium, hard).
|
||||
quantity: Number of questions (fewer, standard, more).
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra={
|
||||
"example": {
|
||||
"difficulty": "medium",
|
||||
"quantity": "standard",
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
difficulty: str = Field(
|
||||
"medium",
|
||||
description="Difficulty level",
|
||||
examples=["easy", "medium", "hard"],
|
||||
)
|
||||
quantity: str = Field(
|
||||
"standard",
|
||||
description="Number of questions",
|
||||
examples=["fewer", "standard", "more"],
|
||||
)
|
||||
|
||||
@field_validator("difficulty")
|
||||
@classmethod
|
||||
def validate_difficulty(cls, v: str) -> str:
|
||||
"""Validate difficulty."""
|
||||
allowed = {"easy", "medium", "hard"}
|
||||
if v not in allowed:
|
||||
raise ValueError(f"Difficulty must be one of: {allowed}")
|
||||
return v
|
||||
|
||||
@field_validator("quantity")
|
||||
@classmethod
|
||||
def validate_quantity(cls, v: str) -> str:
|
||||
"""Validate quantity."""
|
||||
allowed = {"fewer", "standard", "more"}
|
||||
if v not in allowed:
|
||||
raise ValueError(f"Quantity must be one of: {allowed}")
|
||||
return v
|
||||
|
||||
|
||||
class FlashcardsGenerationRequest(BaseModel):
|
||||
"""Request model for generating flashcards.
|
||||
|
||||
Attributes:
|
||||
difficulty: Difficulty level (easy, medium, hard).
|
||||
quantity: Number of flashcards (fewer, standard, more).
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra={
|
||||
"example": {
|
||||
"difficulty": "hard",
|
||||
"quantity": "more",
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
difficulty: str = Field(
|
||||
"medium",
|
||||
description="Difficulty level",
|
||||
examples=["easy", "medium", "hard"],
|
||||
)
|
||||
quantity: str = Field(
|
||||
"standard",
|
||||
description="Number of flashcards",
|
||||
examples=["fewer", "standard", "more"],
|
||||
)
|
||||
|
||||
@field_validator("difficulty")
|
||||
@classmethod
|
||||
def validate_difficulty(cls, v: str) -> str:
|
||||
"""Validate difficulty."""
|
||||
allowed = {"easy", "medium", "hard"}
|
||||
if v not in allowed:
|
||||
raise ValueError(f"Difficulty must be one of: {allowed}")
|
||||
return v
|
||||
|
||||
@field_validator("quantity")
|
||||
@classmethod
|
||||
def validate_quantity(cls, v: str) -> str:
|
||||
"""Validate quantity."""
|
||||
allowed = {"fewer", "standard", "more"}
|
||||
if v not in allowed:
|
||||
raise ValueError(f"Quantity must be one of: {allowed}")
|
||||
return v
|
||||
|
||||
|
||||
class ReportGenerationRequest(BaseModel):
|
||||
"""Request model for generating report.
|
||||
|
||||
Attributes:
|
||||
format: Report format (summary, detailed, executive).
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra={
|
||||
"example": {
|
||||
"format": "detailed",
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
format: str = Field(
|
||||
"detailed",
|
||||
description="Report format",
|
||||
examples=["summary", "detailed", "executive"],
|
||||
)
|
||||
|
||||
@field_validator("format")
|
||||
@classmethod
|
||||
def validate_format(cls, v: str) -> str:
|
||||
"""Validate format."""
|
||||
allowed = {"summary", "detailed", "executive"}
|
||||
if v not in allowed:
|
||||
raise ValueError(f"Format must be one of: {allowed}")
|
||||
return v
|
||||
|
||||
|
||||
class DataTableGenerationRequest(BaseModel):
|
||||
"""Request model for generating data table.
|
||||
|
||||
Attributes:
|
||||
description: Description of what data to extract.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra={
|
||||
"example": {
|
||||
"description": "Compare different machine learning approaches",
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
description: str | None = Field(
|
||||
None,
|
||||
max_length=500,
|
||||
description="Description of what data to extract",
|
||||
examples=["Compare different machine learning approaches"],
|
||||
)
|
||||
|
||||
@@ -420,6 +420,148 @@ class HealthStatus(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class Artifact(BaseModel):
|
||||
"""Artifact (generated content) model.
|
||||
|
||||
Attributes:
|
||||
id: Unique artifact identifier.
|
||||
notebook_id: Parent notebook ID.
|
||||
type: Artifact type (audio, video, quiz, etc.).
|
||||
title: Artifact title.
|
||||
status: Processing status (pending, processing, completed, failed).
|
||||
created_at: Creation timestamp.
|
||||
completed_at: Completion timestamp (None if not completed).
|
||||
download_url: Download URL (None if not completed).
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra={
|
||||
"example": {
|
||||
"id": "550e8400-e29b-41d4-a716-446655440010",
|
||||
"notebook_id": "550e8400-e29b-41d4-a716-446655440000",
|
||||
"type": "audio",
|
||||
"title": "AI Research Podcast",
|
||||
"status": "completed",
|
||||
"created_at": "2026-04-06T10:00:00Z",
|
||||
"completed_at": "2026-04-06T10:30:00Z",
|
||||
"download_url": "https://example.com/download/123",
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
id: UUID = Field(
|
||||
...,
|
||||
description="Unique artifact identifier",
|
||||
examples=["550e8400-e29b-41d4-a716-446655440010"],
|
||||
)
|
||||
notebook_id: UUID = Field(
|
||||
...,
|
||||
description="Parent notebook ID",
|
||||
examples=["550e8400-e29b-41d4-a716-446655440000"],
|
||||
)
|
||||
type: str = Field(
|
||||
...,
|
||||
description="Artifact type",
|
||||
examples=[
|
||||
"audio",
|
||||
"video",
|
||||
"quiz",
|
||||
"flashcards",
|
||||
"slide-deck",
|
||||
"infographic",
|
||||
"report",
|
||||
"mind-map",
|
||||
"data-table",
|
||||
],
|
||||
)
|
||||
title: str = Field(
|
||||
...,
|
||||
description="Artifact title",
|
||||
examples=["AI Research Podcast"],
|
||||
)
|
||||
status: str = Field(
|
||||
...,
|
||||
description="Processing status",
|
||||
examples=["pending", "processing", "completed", "failed"],
|
||||
)
|
||||
created_at: datetime = Field(
|
||||
...,
|
||||
description="Creation timestamp",
|
||||
examples=["2026-04-06T10:00:00Z"],
|
||||
)
|
||||
completed_at: datetime | None = Field(
|
||||
None,
|
||||
description="Completion timestamp (None if not completed)",
|
||||
examples=["2026-04-06T10:30:00Z"],
|
||||
)
|
||||
download_url: str | None = Field(
|
||||
None,
|
||||
description="Download URL (None if not completed)",
|
||||
examples=["https://example.com/download/123"],
|
||||
)
|
||||
|
||||
|
||||
class ArtifactList(BaseModel):
|
||||
"""List of artifacts.
|
||||
|
||||
Attributes:
|
||||
items: List of artifacts.
|
||||
pagination: Pagination metadata.
|
||||
"""
|
||||
|
||||
items: list[Artifact] = Field(
|
||||
...,
|
||||
description="List of artifacts",
|
||||
)
|
||||
pagination: PaginationMeta = Field(
|
||||
...,
|
||||
description="Pagination metadata",
|
||||
)
|
||||
|
||||
|
||||
class GenerationResponse(BaseModel):
|
||||
"""Response from content generation request.
|
||||
|
||||
Attributes:
|
||||
artifact_id: ID of the created artifact.
|
||||
status: Current status (usually 'pending' or 'processing').
|
||||
message: Human-readable status message.
|
||||
estimated_time_seconds: Estimated completion time.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra={
|
||||
"example": {
|
||||
"artifact_id": "550e8400-e29b-41d4-a716-446655440010",
|
||||
"status": "processing",
|
||||
"message": "Audio generation started",
|
||||
"estimated_time_seconds": 600,
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
artifact_id: UUID = Field(
|
||||
...,
|
||||
description="ID of the created artifact",
|
||||
examples=["550e8400-e29b-41d4-a716-446655440010"],
|
||||
)
|
||||
status: str = Field(
|
||||
...,
|
||||
description="Current status",
|
||||
examples=["pending", "processing", "completed"],
|
||||
)
|
||||
message: str = Field(
|
||||
...,
|
||||
description="Human-readable status message",
|
||||
examples=["Audio generation started"],
|
||||
)
|
||||
estimated_time_seconds: int | None = Field(
|
||||
None,
|
||||
description="Estimated completion time in seconds",
|
||||
examples=[600, 900, 1200],
|
||||
)
|
||||
|
||||
|
||||
class SourceReference(BaseModel):
|
||||
"""Source reference in chat response.
|
||||
|
||||
|
||||
563
src/notebooklm_agent/api/routes/generation.py
Normal file
563
src/notebooklm_agent/api/routes/generation.py
Normal file
@@ -0,0 +1,563 @@
|
||||
"""Generation API routes.
|
||||
|
||||
This module contains API endpoints for content generation.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi import APIRouter, HTTPException, status
|
||||
|
||||
from notebooklm_agent.api.models.requests import (
|
||||
AudioGenerationRequest,
|
||||
DataTableGenerationRequest,
|
||||
FlashcardsGenerationRequest,
|
||||
InfographicGenerationRequest,
|
||||
QuizGenerationRequest,
|
||||
ReportGenerationRequest,
|
||||
SlideDeckGenerationRequest,
|
||||
VideoGenerationRequest,
|
||||
)
|
||||
from notebooklm_agent.api.models.responses import (
|
||||
ApiResponse,
|
||||
Artifact,
|
||||
ArtifactList,
|
||||
GenerationResponse,
|
||||
PaginationMeta,
|
||||
ResponseMeta,
|
||||
)
|
||||
from notebooklm_agent.core.exceptions import NotebookLMError, NotFoundError
|
||||
from notebooklm_agent.services.artifact_service import ArtifactService
|
||||
|
||||
router = APIRouter(tags=["generation"])
|
||||
|
||||
|
||||
async def get_artifact_service() -> ArtifactService:
|
||||
"""Get artifact service instance.
|
||||
|
||||
Returns:
|
||||
ArtifactService instance.
|
||||
"""
|
||||
return ArtifactService()
|
||||
|
||||
|
||||
def _validate_notebook_id(notebook_id: str) -> None:
|
||||
"""Validate notebook ID format.
|
||||
|
||||
Args:
|
||||
notebook_id: Notebook ID string.
|
||||
|
||||
Raises:
|
||||
HTTPException: If invalid format.
|
||||
"""
|
||||
from uuid import UUID
|
||||
|
||||
try:
|
||||
UUID(notebook_id)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={
|
||||
"success": False,
|
||||
"error": {
|
||||
"code": "VALIDATION_ERROR",
|
||||
"message": "Invalid notebook ID format",
|
||||
"details": [],
|
||||
},
|
||||
"meta": {
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"request_id": str(uuid4()),
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/{notebook_id}/generate/audio",
|
||||
response_model=ApiResponse[GenerationResponse],
|
||||
status_code=status.HTTP_202_ACCEPTED,
|
||||
summary="Generate audio podcast",
|
||||
description="Generate an audio podcast from notebook sources.",
|
||||
)
|
||||
async def generate_audio(notebook_id: str, data: AudioGenerationRequest):
|
||||
"""Generate audio podcast."""
|
||||
_validate_notebook_id(notebook_id)
|
||||
from uuid import UUID
|
||||
|
||||
try:
|
||||
service = await get_artifact_service()
|
||||
result = await service.generate_audio(
|
||||
UUID(notebook_id),
|
||||
data.instructions,
|
||||
data.format,
|
||||
data.length,
|
||||
data.language,
|
||||
)
|
||||
return ApiResponse(
|
||||
success=True,
|
||||
data=result,
|
||||
error=None,
|
||||
meta=ResponseMeta(timestamp=datetime.utcnow(), request_id=uuid4()),
|
||||
)
|
||||
except NotFoundError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail={
|
||||
"success": False,
|
||||
"error": {"code": e.code, "message": e.message, "details": []},
|
||||
"meta": {"timestamp": datetime.utcnow().isoformat(), "request_id": str(uuid4())},
|
||||
},
|
||||
)
|
||||
except NotebookLMError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||
detail={
|
||||
"success": False,
|
||||
"error": {"code": e.code, "message": e.message, "details": []},
|
||||
"meta": {"timestamp": datetime.utcnow().isoformat(), "request_id": str(uuid4())},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/{notebook_id}/generate/video",
|
||||
response_model=ApiResponse[GenerationResponse],
|
||||
status_code=status.HTTP_202_ACCEPTED,
|
||||
summary="Generate video",
|
||||
description="Generate a video from notebook sources.",
|
||||
)
|
||||
async def generate_video(notebook_id: str, data: VideoGenerationRequest):
|
||||
"""Generate video."""
|
||||
_validate_notebook_id(notebook_id)
|
||||
from uuid import UUID
|
||||
|
||||
try:
|
||||
service = await get_artifact_service()
|
||||
result = await service.generate_video(
|
||||
UUID(notebook_id),
|
||||
data.instructions,
|
||||
data.style,
|
||||
data.language,
|
||||
)
|
||||
return ApiResponse(
|
||||
success=True,
|
||||
data=result,
|
||||
error=None,
|
||||
meta=ResponseMeta(timestamp=datetime.utcnow(), request_id=uuid4()),
|
||||
)
|
||||
except NotFoundError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail={
|
||||
"success": False,
|
||||
"error": {"code": e.code, "message": e.message, "details": []},
|
||||
"meta": {"timestamp": datetime.utcnow().isoformat(), "request_id": str(uuid4())},
|
||||
},
|
||||
)
|
||||
except NotebookLMError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||
detail={
|
||||
"success": False,
|
||||
"error": {"code": e.code, "message": e.message, "details": []},
|
||||
"meta": {"timestamp": datetime.utcnow().isoformat(), "request_id": str(uuid4())},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/{notebook_id}/generate/slide-deck",
|
||||
response_model=ApiResponse[GenerationResponse],
|
||||
status_code=status.HTTP_202_ACCEPTED,
|
||||
summary="Generate slide deck",
|
||||
description="Generate a slide deck from notebook sources.",
|
||||
)
|
||||
async def generate_slide_deck(notebook_id: str, data: SlideDeckGenerationRequest):
|
||||
"""Generate slide deck."""
|
||||
_validate_notebook_id(notebook_id)
|
||||
from uuid import UUID
|
||||
|
||||
try:
|
||||
service = await get_artifact_service()
|
||||
result = await service.generate_slide_deck(
|
||||
UUID(notebook_id),
|
||||
data.format,
|
||||
data.length,
|
||||
)
|
||||
return ApiResponse(
|
||||
success=True,
|
||||
data=result,
|
||||
error=None,
|
||||
meta=ResponseMeta(timestamp=datetime.utcnow(), request_id=uuid4()),
|
||||
)
|
||||
except NotFoundError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail={
|
||||
"success": False,
|
||||
"error": {"code": e.code, "message": e.message, "details": []},
|
||||
"meta": {"timestamp": datetime.utcnow().isoformat(), "request_id": str(uuid4())},
|
||||
},
|
||||
)
|
||||
except NotebookLMError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||
detail={
|
||||
"success": False,
|
||||
"error": {"code": e.code, "message": e.message, "details": []},
|
||||
"meta": {"timestamp": datetime.utcnow().isoformat(), "request_id": str(uuid4())},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/{notebook_id}/generate/infographic",
|
||||
response_model=ApiResponse[GenerationResponse],
|
||||
status_code=status.HTTP_202_ACCEPTED,
|
||||
summary="Generate infographic",
|
||||
description="Generate an infographic from notebook sources.",
|
||||
)
|
||||
async def generate_infographic(notebook_id: str, data: InfographicGenerationRequest):
|
||||
"""Generate infographic."""
|
||||
_validate_notebook_id(notebook_id)
|
||||
from uuid import UUID
|
||||
|
||||
try:
|
||||
service = await get_artifact_service()
|
||||
result = await service.generate_infographic(
|
||||
UUID(notebook_id),
|
||||
data.orientation,
|
||||
data.detail,
|
||||
data.style,
|
||||
)
|
||||
return ApiResponse(
|
||||
success=True,
|
||||
data=result,
|
||||
error=None,
|
||||
meta=ResponseMeta(timestamp=datetime.utcnow(), request_id=uuid4()),
|
||||
)
|
||||
except NotFoundError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail={
|
||||
"success": False,
|
||||
"error": {"code": e.code, "message": e.message, "details": []},
|
||||
"meta": {"timestamp": datetime.utcnow().isoformat(), "request_id": str(uuid4())},
|
||||
},
|
||||
)
|
||||
except NotebookLMError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||
detail={
|
||||
"success": False,
|
||||
"error": {"code": e.code, "message": e.message, "details": []},
|
||||
"meta": {"timestamp": datetime.utcnow().isoformat(), "request_id": str(uuid4())},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/{notebook_id}/generate/quiz",
|
||||
response_model=ApiResponse[GenerationResponse],
|
||||
status_code=status.HTTP_202_ACCEPTED,
|
||||
summary="Generate quiz",
|
||||
description="Generate a quiz from notebook sources.",
|
||||
)
|
||||
async def generate_quiz(notebook_id: str, data: QuizGenerationRequest):
|
||||
"""Generate quiz."""
|
||||
_validate_notebook_id(notebook_id)
|
||||
from uuid import UUID
|
||||
|
||||
try:
|
||||
service = await get_artifact_service()
|
||||
result = await service.generate_quiz(
|
||||
UUID(notebook_id),
|
||||
data.difficulty,
|
||||
data.quantity,
|
||||
)
|
||||
return ApiResponse(
|
||||
success=True,
|
||||
data=result,
|
||||
error=None,
|
||||
meta=ResponseMeta(timestamp=datetime.utcnow(), request_id=uuid4()),
|
||||
)
|
||||
except NotFoundError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail={
|
||||
"success": False,
|
||||
"error": {"code": e.code, "message": e.message, "details": []},
|
||||
"meta": {"timestamp": datetime.utcnow().isoformat(), "request_id": str(uuid4())},
|
||||
},
|
||||
)
|
||||
except NotebookLMError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||
detail={
|
||||
"success": False,
|
||||
"error": {"code": e.code, "message": e.message, "details": []},
|
||||
"meta": {"timestamp": datetime.utcnow().isoformat(), "request_id": str(uuid4())},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/{notebook_id}/generate/flashcards",
|
||||
response_model=ApiResponse[GenerationResponse],
|
||||
status_code=status.HTTP_202_ACCEPTED,
|
||||
summary="Generate flashcards",
|
||||
description="Generate flashcards from notebook sources.",
|
||||
)
|
||||
async def generate_flashcards(notebook_id: str, data: FlashcardsGenerationRequest):
|
||||
"""Generate flashcards."""
|
||||
_validate_notebook_id(notebook_id)
|
||||
from uuid import UUID
|
||||
|
||||
try:
|
||||
service = await get_artifact_service()
|
||||
result = await service.generate_flashcards(
|
||||
UUID(notebook_id),
|
||||
data.difficulty,
|
||||
data.quantity,
|
||||
)
|
||||
return ApiResponse(
|
||||
success=True,
|
||||
data=result,
|
||||
error=None,
|
||||
meta=ResponseMeta(timestamp=datetime.utcnow(), request_id=uuid4()),
|
||||
)
|
||||
except NotFoundError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail={
|
||||
"success": False,
|
||||
"error": {"code": e.code, "message": e.message, "details": []},
|
||||
"meta": {"timestamp": datetime.utcnow().isoformat(), "request_id": str(uuid4())},
|
||||
},
|
||||
)
|
||||
except NotebookLMError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||
detail={
|
||||
"success": False,
|
||||
"error": {"code": e.code, "message": e.message, "details": []},
|
||||
"meta": {"timestamp": datetime.utcnow().isoformat(), "request_id": str(uuid4())},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/{notebook_id}/generate/report",
|
||||
response_model=ApiResponse[GenerationResponse],
|
||||
status_code=status.HTTP_202_ACCEPTED,
|
||||
summary="Generate report",
|
||||
description="Generate a report from notebook sources.",
|
||||
)
|
||||
async def generate_report(notebook_id: str, data: ReportGenerationRequest):
|
||||
"""Generate report."""
|
||||
_validate_notebook_id(notebook_id)
|
||||
from uuid import UUID
|
||||
|
||||
try:
|
||||
service = await get_artifact_service()
|
||||
result = await service.generate_report(
|
||||
UUID(notebook_id),
|
||||
data.format,
|
||||
)
|
||||
return ApiResponse(
|
||||
success=True,
|
||||
data=result,
|
||||
error=None,
|
||||
meta=ResponseMeta(timestamp=datetime.utcnow(), request_id=uuid4()),
|
||||
)
|
||||
except NotFoundError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail={
|
||||
"success": False,
|
||||
"error": {"code": e.code, "message": e.message, "details": []},
|
||||
"meta": {"timestamp": datetime.utcnow().isoformat(), "request_id": str(uuid4())},
|
||||
},
|
||||
)
|
||||
except NotebookLMError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||
detail={
|
||||
"success": False,
|
||||
"error": {"code": e.code, "message": e.message, "details": []},
|
||||
"meta": {"timestamp": datetime.utcnow().isoformat(), "request_id": str(uuid4())},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/{notebook_id}/generate/mind-map",
|
||||
response_model=ApiResponse[GenerationResponse],
|
||||
status_code=status.HTTP_202_ACCEPTED,
|
||||
summary="Generate mind map",
|
||||
description="Generate a mind map from notebook sources (instant).",
|
||||
)
|
||||
async def generate_mind_map(notebook_id: str):
|
||||
"""Generate mind map."""
|
||||
_validate_notebook_id(notebook_id)
|
||||
from uuid import UUID
|
||||
|
||||
try:
|
||||
service = await get_artifact_service()
|
||||
result = await service.generate_mind_map(UUID(notebook_id))
|
||||
return ApiResponse(
|
||||
success=True,
|
||||
data=result,
|
||||
error=None,
|
||||
meta=ResponseMeta(timestamp=datetime.utcnow(), request_id=uuid4()),
|
||||
)
|
||||
except NotFoundError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail={
|
||||
"success": False,
|
||||
"error": {"code": e.code, "message": e.message, "details": []},
|
||||
"meta": {"timestamp": datetime.utcnow().isoformat(), "request_id": str(uuid4())},
|
||||
},
|
||||
)
|
||||
except NotebookLMError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||
detail={
|
||||
"success": False,
|
||||
"error": {"code": e.code, "message": e.message, "details": []},
|
||||
"meta": {"timestamp": datetime.utcnow().isoformat(), "request_id": str(uuid4())},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/{notebook_id}/generate/data-table",
|
||||
response_model=ApiResponse[GenerationResponse],
|
||||
status_code=status.HTTP_202_ACCEPTED,
|
||||
summary="Generate data table",
|
||||
description="Generate a data table from notebook sources.",
|
||||
)
|
||||
async def generate_data_table(notebook_id: str, data: DataTableGenerationRequest):
|
||||
"""Generate data table."""
|
||||
_validate_notebook_id(notebook_id)
|
||||
from uuid import UUID
|
||||
|
||||
try:
|
||||
service = await get_artifact_service()
|
||||
result = await service.generate_data_table(
|
||||
UUID(notebook_id),
|
||||
data.description,
|
||||
)
|
||||
return ApiResponse(
|
||||
success=True,
|
||||
data=result,
|
||||
error=None,
|
||||
meta=ResponseMeta(timestamp=datetime.utcnow(), request_id=uuid4()),
|
||||
)
|
||||
except NotFoundError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail={
|
||||
"success": False,
|
||||
"error": {"code": e.code, "message": e.message, "details": []},
|
||||
"meta": {"timestamp": datetime.utcnow().isoformat(), "request_id": str(uuid4())},
|
||||
},
|
||||
)
|
||||
except NotebookLMError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||
detail={
|
||||
"success": False,
|
||||
"error": {"code": e.code, "message": e.message, "details": []},
|
||||
"meta": {"timestamp": datetime.utcnow().isoformat(), "request_id": str(uuid4())},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{notebook_id}/artifacts",
|
||||
response_model=ApiResponse[ArtifactList],
|
||||
summary="List artifacts",
|
||||
description="List all generated artifacts for a notebook.",
|
||||
)
|
||||
async def list_artifacts(notebook_id: str):
|
||||
"""List artifacts."""
|
||||
_validate_notebook_id(notebook_id)
|
||||
from uuid import UUID
|
||||
|
||||
try:
|
||||
service = await get_artifact_service()
|
||||
artifacts = await service.list_artifacts(UUID(notebook_id))
|
||||
return ApiResponse(
|
||||
success=True,
|
||||
data=ArtifactList(
|
||||
items=artifacts,
|
||||
pagination=PaginationMeta(
|
||||
total=len(artifacts),
|
||||
limit=max(len(artifacts), 1),
|
||||
offset=0,
|
||||
has_more=False,
|
||||
),
|
||||
),
|
||||
error=None,
|
||||
meta=ResponseMeta(timestamp=datetime.utcnow(), request_id=uuid4()),
|
||||
)
|
||||
except NotFoundError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail={
|
||||
"success": False,
|
||||
"error": {"code": e.code, "message": e.message, "details": []},
|
||||
"meta": {"timestamp": datetime.utcnow().isoformat(), "request_id": str(uuid4())},
|
||||
},
|
||||
)
|
||||
except NotebookLMError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||
detail={
|
||||
"success": False,
|
||||
"error": {"code": e.code, "message": e.message, "details": []},
|
||||
"meta": {"timestamp": datetime.utcnow().isoformat(), "request_id": str(uuid4())},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{notebook_id}/artifacts/{artifact_id}/status",
|
||||
response_model=ApiResponse[Artifact],
|
||||
summary="Get artifact status",
|
||||
description="Get the current status of an artifact.",
|
||||
)
|
||||
async def get_artifact_status(notebook_id: str, artifact_id: str):
|
||||
"""Get artifact status."""
|
||||
_validate_notebook_id(notebook_id)
|
||||
from uuid import UUID
|
||||
|
||||
try:
|
||||
service = await get_artifact_service()
|
||||
artifact = await service.get_status(UUID(notebook_id), artifact_id)
|
||||
return ApiResponse(
|
||||
success=True,
|
||||
data=artifact,
|
||||
error=None,
|
||||
meta=ResponseMeta(timestamp=datetime.utcnow(), request_id=uuid4()),
|
||||
)
|
||||
except NotFoundError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail={
|
||||
"success": False,
|
||||
"error": {"code": e.code, "message": e.message, "details": []},
|
||||
"meta": {"timestamp": datetime.utcnow().isoformat(), "request_id": str(uuid4())},
|
||||
},
|
||||
)
|
||||
except NotebookLMError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||
detail={
|
||||
"success": False,
|
||||
"error": {"code": e.code, "message": e.message, "details": []},
|
||||
"meta": {"timestamp": datetime.utcnow().isoformat(), "request_id": str(uuid4())},
|
||||
},
|
||||
)
|
||||
449
src/notebooklm_agent/services/artifact_service.py
Normal file
449
src/notebooklm_agent/services/artifact_service.py
Normal file
@@ -0,0 +1,449 @@
|
||||
"""Artifact service for content generation.
|
||||
|
||||
This module contains the ArtifactService class which handles
|
||||
all business logic for generating content (audio, video, etc.).
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from notebooklm_agent.api.models.responses import Artifact, GenerationResponse
|
||||
from notebooklm_agent.core.exceptions import NotebookLMError, NotFoundError, ValidationError
|
||||
|
||||
|
||||
class ArtifactService:
|
||||
"""Service for artifact/content generation operations.
|
||||
|
||||
This service handles all business logic for generating content
|
||||
including audio, video, slides, quizzes, etc.
|
||||
|
||||
Attributes:
|
||||
_client: The notebooklm-py client instance.
|
||||
"""
|
||||
|
||||
# Estimated times in seconds for each artifact type
|
||||
ESTIMATED_TIMES = {
|
||||
"audio": 600, # 10 minutes
|
||||
"video": 1800, # 30 minutes
|
||||
"slide-deck": 300, # 5 minutes
|
||||
"infographic": 300, # 5 minutes
|
||||
"quiz": 600, # 10 minutes
|
||||
"flashcards": 600, # 10 minutes
|
||||
"report": 300, # 5 minutes
|
||||
"mind-map": 10, # Instant
|
||||
"data-table": 60, # 1 minute
|
||||
}
|
||||
|
||||
def __init__(self, client: Any = None) -> None:
|
||||
"""Initialize the artifact service.
|
||||
|
||||
Args:
|
||||
client: Optional notebooklm-py client instance.
|
||||
If not provided, will be created on first use.
|
||||
"""
|
||||
self._client = client
|
||||
|
||||
async def _get_client(self) -> Any:
|
||||
"""Get or create notebooklm-py client.
|
||||
|
||||
Returns:
|
||||
The notebooklm-py client instance.
|
||||
"""
|
||||
if self._client is None:
|
||||
from notebooklm import NotebookLMClient
|
||||
|
||||
self._client = await NotebookLMClient.from_storage()
|
||||
return self._client
|
||||
|
||||
def _get_estimated_time(self, artifact_type: str) -> int:
|
||||
"""Get estimated generation time for artifact type.
|
||||
|
||||
Args:
|
||||
artifact_type: Type of artifact.
|
||||
|
||||
Returns:
|
||||
Estimated time in seconds.
|
||||
"""
|
||||
return self.ESTIMATED_TIMES.get(artifact_type, 300)
|
||||
|
||||
async def _start_generation(
|
||||
self,
|
||||
notebook_id: UUID,
|
||||
artifact_type: str,
|
||||
title: str,
|
||||
generation_method: str,
|
||||
params: dict,
|
||||
) -> GenerationResponse:
|
||||
"""Start content generation.
|
||||
|
||||
Args:
|
||||
notebook_id: The notebook ID.
|
||||
artifact_type: Type of artifact (audio, video, etc.).
|
||||
title: Artifact title.
|
||||
generation_method: Method name to call on notebook.
|
||||
params: Parameters for generation.
|
||||
|
||||
Returns:
|
||||
Generation response with artifact ID.
|
||||
|
||||
Raises:
|
||||
NotFoundError: If notebook not found.
|
||||
NotebookLMError: If external API fails.
|
||||
"""
|
||||
try:
|
||||
client = await self._get_client()
|
||||
notebook = await client.notebooks.get(str(notebook_id))
|
||||
|
||||
# Get the generation method
|
||||
generator = getattr(notebook, generation_method, None)
|
||||
if not generator:
|
||||
raise NotebookLMError(f"Generation method '{generation_method}' not available")
|
||||
|
||||
# Start generation
|
||||
result = await generator(**params)
|
||||
|
||||
artifact_id = getattr(result, "id", str(uuid4()))
|
||||
status = getattr(result, "status", "processing")
|
||||
|
||||
return GenerationResponse(
|
||||
artifact_id=artifact_id,
|
||||
status=status,
|
||||
message=f"{artifact_type.title()} generation started",
|
||||
estimated_time_seconds=self._get_estimated_time(artifact_type),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_str = str(e).lower()
|
||||
if "not found" in error_str:
|
||||
raise NotFoundError("Notebook", str(notebook_id))
|
||||
raise NotebookLMError(f"Failed to start {artifact_type} generation: {e}")
|
||||
|
||||
async def generate_audio(
|
||||
self,
|
||||
notebook_id: UUID,
|
||||
instructions: str | None = None,
|
||||
format: str = "deep-dive",
|
||||
length: str = "default",
|
||||
language: str = "en",
|
||||
) -> GenerationResponse:
|
||||
"""Generate audio podcast.
|
||||
|
||||
Args:
|
||||
notebook_id: The notebook ID.
|
||||
instructions: Custom instructions.
|
||||
format: Podcast format.
|
||||
length: Audio length.
|
||||
language: Language code.
|
||||
|
||||
Returns:
|
||||
Generation response.
|
||||
"""
|
||||
params = {
|
||||
"format": format,
|
||||
"length": length,
|
||||
"language": language,
|
||||
}
|
||||
if instructions:
|
||||
params["instructions"] = instructions
|
||||
|
||||
return await self._start_generation(
|
||||
notebook_id,
|
||||
"audio",
|
||||
"Generated Podcast",
|
||||
"generate_audio",
|
||||
params,
|
||||
)
|
||||
|
||||
async def generate_video(
|
||||
self,
|
||||
notebook_id: UUID,
|
||||
instructions: str | None = None,
|
||||
style: str = "auto",
|
||||
language: str = "en",
|
||||
) -> GenerationResponse:
|
||||
"""Generate video.
|
||||
|
||||
Args:
|
||||
notebook_id: The notebook ID.
|
||||
instructions: Custom instructions.
|
||||
style: Video style.
|
||||
language: Language code.
|
||||
|
||||
Returns:
|
||||
Generation response.
|
||||
"""
|
||||
params = {
|
||||
"style": style,
|
||||
"language": language,
|
||||
}
|
||||
if instructions:
|
||||
params["instructions"] = instructions
|
||||
|
||||
return await self._start_generation(
|
||||
notebook_id,
|
||||
"video",
|
||||
"Generated Video",
|
||||
"generate_video",
|
||||
params,
|
||||
)
|
||||
|
||||
async def generate_slide_deck(
|
||||
self,
|
||||
notebook_id: UUID,
|
||||
format: str = "detailed",
|
||||
length: str = "default",
|
||||
) -> GenerationResponse:
|
||||
"""Generate slide deck.
|
||||
|
||||
Args:
|
||||
notebook_id: The notebook ID.
|
||||
format: Slide format.
|
||||
length: Length.
|
||||
|
||||
Returns:
|
||||
Generation response.
|
||||
"""
|
||||
return await self._start_generation(
|
||||
notebook_id,
|
||||
"slide-deck",
|
||||
"Generated Slide Deck",
|
||||
"generate_slide_deck",
|
||||
{"format": format, "length": length},
|
||||
)
|
||||
|
||||
async def generate_infographic(
|
||||
self,
|
||||
notebook_id: UUID,
|
||||
orientation: str = "landscape",
|
||||
detail: str = "standard",
|
||||
style: str = "auto",
|
||||
) -> GenerationResponse:
|
||||
"""Generate infographic.
|
||||
|
||||
Args:
|
||||
notebook_id: The notebook ID.
|
||||
orientation: Orientation.
|
||||
detail: Detail level.
|
||||
style: Visual style.
|
||||
|
||||
Returns:
|
||||
Generation response.
|
||||
"""
|
||||
return await self._start_generation(
|
||||
notebook_id,
|
||||
"infographic",
|
||||
"Generated Infographic",
|
||||
"generate_infographic",
|
||||
{"orientation": orientation, "detail": detail, "style": style},
|
||||
)
|
||||
|
||||
async def generate_quiz(
|
||||
self,
|
||||
notebook_id: UUID,
|
||||
difficulty: str = "medium",
|
||||
quantity: str = "standard",
|
||||
) -> GenerationResponse:
|
||||
"""Generate quiz.
|
||||
|
||||
Args:
|
||||
notebook_id: The notebook ID.
|
||||
difficulty: Difficulty level.
|
||||
quantity: Number of questions.
|
||||
|
||||
Returns:
|
||||
Generation response.
|
||||
"""
|
||||
return await self._start_generation(
|
||||
notebook_id,
|
||||
"quiz",
|
||||
"Generated Quiz",
|
||||
"generate_quiz",
|
||||
{"difficulty": difficulty, "quantity": quantity},
|
||||
)
|
||||
|
||||
async def generate_flashcards(
|
||||
self,
|
||||
notebook_id: UUID,
|
||||
difficulty: str = "medium",
|
||||
quantity: str = "standard",
|
||||
) -> GenerationResponse:
|
||||
"""Generate flashcards.
|
||||
|
||||
Args:
|
||||
notebook_id: The notebook ID.
|
||||
difficulty: Difficulty level.
|
||||
quantity: Number of flashcards.
|
||||
|
||||
Returns:
|
||||
Generation response.
|
||||
"""
|
||||
return await self._start_generation(
|
||||
notebook_id,
|
||||
"flashcards",
|
||||
"Generated Flashcards",
|
||||
"generate_flashcards",
|
||||
{"difficulty": difficulty, "quantity": quantity},
|
||||
)
|
||||
|
||||
async def generate_report(
|
||||
self,
|
||||
notebook_id: UUID,
|
||||
format: str = "detailed",
|
||||
) -> GenerationResponse:
|
||||
"""Generate report.
|
||||
|
||||
Args:
|
||||
notebook_id: The notebook ID.
|
||||
format: Report format.
|
||||
|
||||
Returns:
|
||||
Generation response.
|
||||
"""
|
||||
return await self._start_generation(
|
||||
notebook_id,
|
||||
"report",
|
||||
"Generated Report",
|
||||
"generate_report",
|
||||
{"format": format},
|
||||
)
|
||||
|
||||
async def generate_mind_map(
|
||||
self,
|
||||
notebook_id: UUID,
|
||||
) -> GenerationResponse:
|
||||
"""Generate mind map (instant).
|
||||
|
||||
Args:
|
||||
notebook_id: The notebook ID.
|
||||
|
||||
Returns:
|
||||
Generation response.
|
||||
"""
|
||||
try:
|
||||
client = await self._get_client()
|
||||
notebook = await client.notebooks.get(str(notebook_id))
|
||||
|
||||
# Mind map is usually instant
|
||||
result = await notebook.generate_mind_map()
|
||||
|
||||
artifact_id = getattr(result, "id", str(uuid4()))
|
||||
status = getattr(result, "status", "completed")
|
||||
|
||||
return GenerationResponse(
|
||||
artifact_id=artifact_id,
|
||||
status=status,
|
||||
message="Mind map generated",
|
||||
estimated_time_seconds=10,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_str = str(e).lower()
|
||||
if "not found" in error_str:
|
||||
raise NotFoundError("Notebook", str(notebook_id))
|
||||
raise NotebookLMError(f"Failed to generate mind map: {e}")
|
||||
|
||||
async def generate_data_table(
|
||||
self,
|
||||
notebook_id: UUID,
|
||||
description: str | None = None,
|
||||
) -> GenerationResponse:
|
||||
"""Generate data table.
|
||||
|
||||
Args:
|
||||
notebook_id: The notebook ID.
|
||||
description: Description of data to extract.
|
||||
|
||||
Returns:
|
||||
Generation response.
|
||||
"""
|
||||
params = {}
|
||||
if description:
|
||||
params["description"] = description
|
||||
|
||||
return await self._start_generation(
|
||||
notebook_id,
|
||||
"data-table",
|
||||
"Generated Data Table",
|
||||
"generate_data_table",
|
||||
params,
|
||||
)
|
||||
|
||||
async def list_artifacts(self, notebook_id: UUID) -> list[Artifact]:
|
||||
"""List artifacts for a notebook.
|
||||
|
||||
Args:
|
||||
notebook_id: The notebook ID.
|
||||
|
||||
Returns:
|
||||
List of artifacts.
|
||||
|
||||
Raises:
|
||||
NotFoundError: If notebook not found.
|
||||
NotebookLMError: If external API fails.
|
||||
"""
|
||||
try:
|
||||
client = await self._get_client()
|
||||
notebook = await client.notebooks.get(str(notebook_id))
|
||||
|
||||
artifacts = await notebook.artifacts.list()
|
||||
|
||||
result = []
|
||||
for art in artifacts:
|
||||
result.append(
|
||||
Artifact(
|
||||
id=getattr(art, "id", str(uuid4())),
|
||||
notebook_id=notebook_id,
|
||||
type=getattr(art, "type", "unknown"),
|
||||
title=getattr(art, "title", "Untitled"),
|
||||
status=getattr(art, "status", "pending"),
|
||||
created_at=getattr(art, "created_at", datetime.utcnow()),
|
||||
completed_at=getattr(art, "completed_at", None),
|
||||
download_url=getattr(art, "download_url", None),
|
||||
)
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
error_str = str(e).lower()
|
||||
if "not found" in error_str:
|
||||
raise NotFoundError("Notebook", str(notebook_id))
|
||||
raise NotebookLMError(f"Failed to list artifacts: {e}")
|
||||
|
||||
async def get_status(self, notebook_id: UUID, artifact_id: str) -> Artifact:
|
||||
"""Get artifact status.
|
||||
|
||||
Args:
|
||||
notebook_id: The notebook ID.
|
||||
artifact_id: The artifact ID.
|
||||
|
||||
Returns:
|
||||
Artifact with current status.
|
||||
|
||||
Raises:
|
||||
NotFoundError: If notebook or artifact not found.
|
||||
NotebookLMError: If external API fails.
|
||||
"""
|
||||
try:
|
||||
client = await self._get_client()
|
||||
notebook = await client.notebooks.get(str(notebook_id))
|
||||
|
||||
artifact = await notebook.artifacts.get(artifact_id)
|
||||
|
||||
return Artifact(
|
||||
id=getattr(artifact, "id", artifact_id),
|
||||
notebook_id=notebook_id,
|
||||
type=getattr(artifact, "type", "unknown"),
|
||||
title=getattr(artifact, "title", "Untitled"),
|
||||
status=getattr(artifact, "status", "pending"),
|
||||
created_at=getattr(artifact, "created_at", datetime.utcnow()),
|
||||
completed_at=getattr(artifact, "completed_at", None),
|
||||
download_url=getattr(artifact, "download_url", None),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_str = str(e).lower()
|
||||
if "not found" in error_str:
|
||||
raise NotFoundError("Artifact", artifact_id)
|
||||
raise NotebookLMError(f"Failed to get artifact status: {e}")
|
||||
246
tests/unit/test_api/test_generation.py
Normal file
246
tests/unit/test_api/test_generation.py
Normal file
@@ -0,0 +1,246 @@
|
||||
"""Integration tests for generation API endpoints.
|
||||
|
||||
Tests key generation endpoints with mocked services.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from notebooklm_agent.api.main import app
|
||||
from notebooklm_agent.api.models.responses import Artifact, GenerationResponse
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestGenerateAudioEndpoint:
|
||||
"""Test suite for POST /generate/audio endpoint."""
|
||||
|
||||
def test_generate_audio_returns_202(self):
|
||||
"""Should return 202 Accepted for audio generation."""
|
||||
# Arrange
|
||||
client = TestClient(app)
|
||||
notebook_id = str(uuid4())
|
||||
artifact_id = str(uuid4())
|
||||
|
||||
with patch("notebooklm_agent.api.routes.generation.ArtifactService") as mock_service_class:
|
||||
mock_service = AsyncMock()
|
||||
mock_response = GenerationResponse(
|
||||
artifact_id=artifact_id,
|
||||
status="processing",
|
||||
message="Audio generation started",
|
||||
estimated_time_seconds=600,
|
||||
)
|
||||
mock_service.generate_audio.return_value = mock_response
|
||||
mock_service_class.return_value = mock_service
|
||||
|
||||
# Act
|
||||
response = client.post(
|
||||
f"/api/v1/notebooks/{notebook_id}/generate/audio",
|
||||
json={
|
||||
"instructions": "Make it engaging",
|
||||
"format": "deep-dive",
|
||||
"length": "long",
|
||||
"language": "en",
|
||||
},
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 202
|
||||
data = response.json()
|
||||
assert data["success"] is True
|
||||
assert data["data"]["status"] == "processing"
|
||||
assert data["data"]["estimated_time_seconds"] == 600
|
||||
|
||||
def test_generate_audio_invalid_notebook_returns_400(self):
|
||||
"""Should return 400 for invalid notebook ID."""
|
||||
# Arrange
|
||||
client = TestClient(app)
|
||||
|
||||
# Act
|
||||
response = client.post(
|
||||
"/api/v1/notebooks/invalid-id/generate/audio",
|
||||
json={"format": "deep-dive"},
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert response.status_code in [400, 422]
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestGenerateQuizEndpoint:
|
||||
"""Test suite for POST /generate/quiz endpoint."""
|
||||
|
||||
def test_generate_quiz_returns_202(self):
|
||||
"""Should return 202 Accepted for quiz generation."""
|
||||
# Arrange
|
||||
client = TestClient(app)
|
||||
notebook_id = str(uuid4())
|
||||
artifact_id = str(uuid4())
|
||||
|
||||
with patch("notebooklm_agent.api.routes.generation.ArtifactService") as mock_service_class:
|
||||
mock_service = AsyncMock()
|
||||
mock_response = GenerationResponse(
|
||||
artifact_id=artifact_id,
|
||||
status="processing",
|
||||
message="Quiz generation started",
|
||||
estimated_time_seconds=600,
|
||||
)
|
||||
mock_service.generate_quiz.return_value = mock_response
|
||||
mock_service_class.return_value = mock_service
|
||||
|
||||
# Act
|
||||
response = client.post(
|
||||
f"/api/v1/notebooks/{notebook_id}/generate/quiz",
|
||||
json={"difficulty": "medium", "quantity": "standard"},
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 202
|
||||
data = response.json()
|
||||
assert data["success"] is True
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestGenerateMindMapEndpoint:
|
||||
"""Test suite for POST /generate/mind-map endpoint."""
|
||||
|
||||
def test_generate_mind_map_returns_202(self):
|
||||
"""Should return 202 Accepted for mind map generation."""
|
||||
# Arrange
|
||||
client = TestClient(app)
|
||||
notebook_id = str(uuid4())
|
||||
artifact_id = str(uuid4())
|
||||
|
||||
with patch("notebooklm_agent.api.routes.generation.ArtifactService") as mock_service_class:
|
||||
mock_service = AsyncMock()
|
||||
mock_response = GenerationResponse(
|
||||
artifact_id=artifact_id,
|
||||
status="completed",
|
||||
message="Mind map generated",
|
||||
estimated_time_seconds=10,
|
||||
)
|
||||
mock_service.generate_mind_map.return_value = mock_response
|
||||
mock_service_class.return_value = mock_service
|
||||
|
||||
# Act
|
||||
response = client.post(f"/api/v1/notebooks/{notebook_id}/generate/mind-map")
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 202
|
||||
data = response.json()
|
||||
assert data["data"]["status"] == "completed"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestListArtifactsEndpoint:
|
||||
"""Test suite for GET /artifacts endpoint."""
|
||||
|
||||
def test_list_artifacts_returns_200(self):
|
||||
"""Should return 200 with list of artifacts."""
|
||||
# Arrange
|
||||
client = TestClient(app)
|
||||
notebook_id = str(uuid4())
|
||||
|
||||
with patch("notebooklm_agent.api.routes.generation.ArtifactService") as mock_service_class:
|
||||
mock_service = AsyncMock()
|
||||
mock_artifact = Artifact(
|
||||
id=uuid4(),
|
||||
notebook_id=notebook_id,
|
||||
type="audio",
|
||||
title="Podcast",
|
||||
status="completed",
|
||||
created_at=datetime.utcnow(),
|
||||
completed_at=datetime.utcnow(),
|
||||
download_url="https://example.com/download",
|
||||
)
|
||||
mock_service.list_artifacts.return_value = [mock_artifact]
|
||||
mock_service_class.return_value = mock_service
|
||||
|
||||
# Act
|
||||
response = client.get(f"/api/v1/notebooks/{notebook_id}/artifacts")
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["success"] is True
|
||||
assert len(data["data"]["items"]) == 1
|
||||
assert data["data"]["items"][0]["type"] == "audio"
|
||||
|
||||
def test_list_artifacts_empty_returns_empty_list(self):
|
||||
"""Should return empty list if no artifacts."""
|
||||
# Arrange
|
||||
client = TestClient(app)
|
||||
notebook_id = str(uuid4())
|
||||
|
||||
with patch("notebooklm_agent.api.routes.generation.ArtifactService") as mock_service_class:
|
||||
mock_service = AsyncMock()
|
||||
mock_service.list_artifacts.return_value = []
|
||||
mock_service_class.return_value = mock_service
|
||||
|
||||
# Act
|
||||
response = client.get(f"/api/v1/notebooks/{notebook_id}/artifacts")
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["success"] is True
|
||||
assert data["data"]["items"] == []
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestGetArtifactStatusEndpoint:
|
||||
"""Test suite for GET /artifacts/{id}/status endpoint."""
|
||||
|
||||
def test_get_artifact_status_returns_200(self):
|
||||
"""Should return 200 with artifact status."""
|
||||
# Arrange
|
||||
client = TestClient(app)
|
||||
notebook_id = str(uuid4())
|
||||
artifact_id = str(uuid4())
|
||||
|
||||
with patch("notebooklm_agent.api.routes.generation.ArtifactService") as mock_service_class:
|
||||
mock_service = AsyncMock()
|
||||
mock_artifact = Artifact(
|
||||
id=artifact_id,
|
||||
notebook_id=notebook_id,
|
||||
type="video",
|
||||
title="Video",
|
||||
status="processing",
|
||||
created_at=datetime.utcnow(),
|
||||
completed_at=None,
|
||||
download_url=None,
|
||||
)
|
||||
mock_service.get_status.return_value = mock_artifact
|
||||
mock_service_class.return_value = mock_service
|
||||
|
||||
# Act
|
||||
response = client.get(f"/api/v1/notebooks/{notebook_id}/artifacts/{artifact_id}/status")
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["success"] is True
|
||||
assert data["data"]["status"] == "processing"
|
||||
|
||||
def test_get_artifact_status_not_found_returns_404(self):
|
||||
"""Should return 404 when artifact not found."""
|
||||
# Arrange
|
||||
client = TestClient(app)
|
||||
notebook_id = str(uuid4())
|
||||
artifact_id = str(uuid4())
|
||||
|
||||
with patch("notebooklm_agent.api.routes.generation.ArtifactService") as mock_service_class:
|
||||
mock_service = AsyncMock()
|
||||
from notebooklm_agent.core.exceptions import NotFoundError
|
||||
|
||||
mock_service.get_status.side_effect = NotFoundError("Artifact", artifact_id)
|
||||
mock_service_class.return_value = mock_service
|
||||
|
||||
# Act
|
||||
response = client.get(f"/api/v1/notebooks/{notebook_id}/artifacts/{artifact_id}/status")
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 404
|
||||
292
tests/unit/test_services/test_artifact_service.py
Normal file
292
tests/unit/test_services/test_artifact_service.py
Normal file
@@ -0,0 +1,292 @@
|
||||
"""Unit tests for ArtifactService.
|
||||
|
||||
TDD Cycle: RED → GREEN → REFACTOR
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from notebooklm_agent.core.exceptions import (
|
||||
NotebookLMError,
|
||||
NotFoundError,
|
||||
)
|
||||
from notebooklm_agent.services.artifact_service import ArtifactService
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestArtifactServiceInit:
|
||||
"""Test suite for ArtifactService initialization."""
|
||||
|
||||
async def test_get_client_returns_existing_client(self):
|
||||
"""Should return existing client if already initialized."""
|
||||
# Arrange
|
||||
mock_client = AsyncMock()
|
||||
service = ArtifactService(client=mock_client)
|
||||
|
||||
# Act
|
||||
client = await service._get_client()
|
||||
|
||||
# Assert
|
||||
assert client == mock_client
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestArtifactServiceGenerateAudio:
|
||||
"""Test suite for generate_audio method."""
|
||||
|
||||
async def test_generate_audio_returns_generation_response(self):
|
||||
"""Should start audio generation and return response."""
|
||||
# Arrange
|
||||
notebook_id = uuid4()
|
||||
mock_client = AsyncMock()
|
||||
mock_notebook = AsyncMock()
|
||||
mock_result = MagicMock()
|
||||
mock_result.id = str(uuid4())
|
||||
mock_result.status = "processing"
|
||||
|
||||
mock_notebook.generate_audio.return_value = mock_result
|
||||
mock_client.notebooks.get.return_value = mock_notebook
|
||||
|
||||
service = ArtifactService(client=mock_client)
|
||||
|
||||
# Act
|
||||
result = await service.generate_audio(
|
||||
notebook_id,
|
||||
instructions="Make it engaging",
|
||||
format="deep-dive",
|
||||
length="long",
|
||||
language="en",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result.status == "processing"
|
||||
assert str(result.artifact_id) == str(mock_result.id)
|
||||
assert result.estimated_time_seconds == 600
|
||||
assert "Audio generation started" in result.message
|
||||
|
||||
async def test_generate_audio_notebook_not_found_raises_not_found(self):
|
||||
"""Should raise NotFoundError if notebook not found."""
|
||||
# Arrange
|
||||
notebook_id = uuid4()
|
||||
mock_client = AsyncMock()
|
||||
mock_client.notebooks.get.side_effect = Exception("notebook not found")
|
||||
|
||||
service = ArtifactService(client=mock_client)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(NotFoundError) as exc_info:
|
||||
await service.generate_audio(notebook_id)
|
||||
|
||||
assert str(notebook_id) in str(exc_info.value)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestArtifactServiceGenerateVideo:
|
||||
"""Test suite for generate_video method."""
|
||||
|
||||
async def test_generate_video_returns_generation_response(self):
|
||||
"""Should start video generation and return response."""
|
||||
# Arrange
|
||||
notebook_id = uuid4()
|
||||
mock_client = AsyncMock()
|
||||
mock_notebook = AsyncMock()
|
||||
mock_result = MagicMock()
|
||||
mock_result.id = str(uuid4())
|
||||
mock_result.status = "processing"
|
||||
|
||||
mock_notebook.generate_video.return_value = mock_result
|
||||
mock_client.notebooks.get.return_value = mock_notebook
|
||||
|
||||
service = ArtifactService(client=mock_client)
|
||||
|
||||
# Act
|
||||
result = await service.generate_video(
|
||||
notebook_id,
|
||||
instructions="Create engaging video",
|
||||
style="whiteboard",
|
||||
language="en",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result.status == "processing"
|
||||
assert result.estimated_time_seconds == 1800
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestArtifactServiceGenerateQuiz:
|
||||
"""Test suite for generate_quiz method."""
|
||||
|
||||
async def test_generate_quiz_returns_generation_response(self):
|
||||
"""Should start quiz generation and return response."""
|
||||
# Arrange
|
||||
notebook_id = uuid4()
|
||||
mock_client = AsyncMock()
|
||||
mock_notebook = AsyncMock()
|
||||
mock_result = MagicMock()
|
||||
mock_result.id = str(uuid4())
|
||||
mock_result.status = "processing"
|
||||
|
||||
mock_notebook.generate_quiz.return_value = mock_result
|
||||
mock_client.notebooks.get.return_value = mock_notebook
|
||||
|
||||
service = ArtifactService(client=mock_client)
|
||||
|
||||
# Act
|
||||
result = await service.generate_quiz(
|
||||
notebook_id,
|
||||
difficulty="medium",
|
||||
quantity="standard",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result.status == "processing"
|
||||
assert result.estimated_time_seconds == 600
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestArtifactServiceGenerateMindMap:
|
||||
"""Test suite for generate_mind_map method."""
|
||||
|
||||
async def test_generate_mind_map_returns_generation_response(self):
|
||||
"""Should generate mind map (instant) and return response."""
|
||||
# Arrange
|
||||
notebook_id = uuid4()
|
||||
mock_client = AsyncMock()
|
||||
mock_notebook = AsyncMock()
|
||||
mock_result = MagicMock()
|
||||
mock_result.id = str(uuid4())
|
||||
mock_result.status = "completed"
|
||||
|
||||
mock_notebook.generate_mind_map.return_value = mock_result
|
||||
mock_client.notebooks.get.return_value = mock_notebook
|
||||
|
||||
service = ArtifactService(client=mock_client)
|
||||
|
||||
# Act
|
||||
result = await service.generate_mind_map(notebook_id)
|
||||
|
||||
# Assert
|
||||
assert result.status == "completed"
|
||||
assert result.estimated_time_seconds == 10
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestArtifactServiceListArtifacts:
|
||||
"""Test suite for list_artifacts method."""
|
||||
|
||||
async def test_list_artifacts_returns_list(self):
|
||||
"""Should return list of artifacts."""
|
||||
# Arrange
|
||||
notebook_id = uuid4()
|
||||
mock_client = AsyncMock()
|
||||
mock_notebook = AsyncMock()
|
||||
|
||||
mock_artifact = MagicMock()
|
||||
mock_artifact.id = str(uuid4())
|
||||
mock_artifact.type = "audio"
|
||||
mock_artifact.title = "Podcast"
|
||||
mock_artifact.status = "completed"
|
||||
mock_artifact.created_at = datetime.utcnow()
|
||||
mock_artifact.completed_at = datetime.utcnow()
|
||||
mock_artifact.download_url = "https://example.com/download"
|
||||
|
||||
mock_notebook.artifacts.list.return_value = [mock_artifact]
|
||||
mock_client.notebooks.get.return_value = mock_notebook
|
||||
|
||||
service = ArtifactService(client=mock_client)
|
||||
|
||||
# Act
|
||||
result = await service.list_artifacts(notebook_id)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 1
|
||||
assert result[0].type == "audio"
|
||||
assert result[0].status == "completed"
|
||||
|
||||
async def test_list_artifacts_empty_returns_empty_list(self):
|
||||
"""Should return empty list if no artifacts."""
|
||||
# Arrange
|
||||
notebook_id = uuid4()
|
||||
mock_client = AsyncMock()
|
||||
mock_notebook = AsyncMock()
|
||||
mock_notebook.artifacts.list.return_value = []
|
||||
mock_client.notebooks.get.return_value = mock_notebook
|
||||
|
||||
service = ArtifactService(client=mock_client)
|
||||
|
||||
# Act
|
||||
result = await service.list_artifacts(notebook_id)
|
||||
|
||||
# Assert
|
||||
assert result == []
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestArtifactServiceGetStatus:
|
||||
"""Test suite for get_status method."""
|
||||
|
||||
async def test_get_status_returns_artifact(self):
|
||||
"""Should return artifact with status."""
|
||||
# Arrange
|
||||
notebook_id = uuid4()
|
||||
artifact_id = str(uuid4())
|
||||
mock_client = AsyncMock()
|
||||
mock_notebook = AsyncMock()
|
||||
|
||||
mock_artifact = MagicMock()
|
||||
mock_artifact.id = artifact_id
|
||||
mock_artifact.type = "video"
|
||||
mock_artifact.title = "Video"
|
||||
mock_artifact.status = "processing"
|
||||
mock_artifact.created_at = datetime.utcnow()
|
||||
mock_artifact.completed_at = None
|
||||
mock_artifact.download_url = None
|
||||
|
||||
mock_notebook.artifacts.get.return_value = mock_artifact
|
||||
mock_client.notebooks.get.return_value = mock_notebook
|
||||
|
||||
service = ArtifactService(client=mock_client)
|
||||
|
||||
# Act
|
||||
result = await service.get_status(notebook_id, artifact_id)
|
||||
|
||||
# Assert
|
||||
assert str(result.id) == artifact_id
|
||||
assert result.type == "video"
|
||||
assert result.status == "processing"
|
||||
|
||||
async def test_get_status_artifact_not_found_raises_not_found(self):
|
||||
"""Should raise NotFoundError if artifact not found."""
|
||||
# Arrange
|
||||
notebook_id = uuid4()
|
||||
artifact_id = str(uuid4())
|
||||
mock_client = AsyncMock()
|
||||
mock_notebook = AsyncMock()
|
||||
mock_notebook.artifacts.get.side_effect = Exception("artifact not found")
|
||||
mock_client.notebooks.get.return_value = mock_notebook
|
||||
|
||||
service = ArtifactService(client=mock_client)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(NotFoundError) as exc_info:
|
||||
await service.get_status(notebook_id, artifact_id)
|
||||
|
||||
assert artifact_id in str(exc_info.value)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestArtifactServiceEstimatedTimes:
|
||||
"""Test suite for estimated times."""
|
||||
|
||||
def test_estimated_times_defined(self):
|
||||
"""Should have estimated times for all artifact types."""
|
||||
# Arrange
|
||||
service = ArtifactService()
|
||||
|
||||
# Assert
|
||||
assert service.ESTIMATED_TIMES["audio"] == 600
|
||||
assert service.ESTIMATED_TIMES["video"] == 1800
|
||||
assert service.ESTIMATED_TIMES["mind-map"] == 10
|
||||
Reference in New Issue
Block a user