"""Base repository with generic CRUD operations."""

from typing import Generic, TypeVar, Optional, List, Any
from uuid import UUID

from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, delete, update, func

from src.models.base import Base

ModelType = TypeVar("ModelType", bound=Base)


class BaseRepository(Generic[ModelType]):
    """Generic base repository with common CRUD operations.

    An instance is bound to one model class; every method receives the
    ``AsyncSession`` to run against. The write methods (``create``,
    ``update``, ``delete``) commit the session themselves and roll it
    back before re-raising if the write fails.
    """

    def __init__(self, model: type[ModelType]):
        """Bind the repository to ``model``, the class it queries."""
        self.model = model

    def _apply_filters(self, query, filters: dict):
        """Add an equality ``WHERE`` clause to ``query`` for each usable filter.

        A filter is skipped when its value is ``None`` or when the model
        has no attribute with that name.
        """
        for key, value in filters.items():
            if value is not None and hasattr(self.model, key):
                query = query.where(getattr(self.model, key) == value)
        return query

    @staticmethod
    async def _commit(db: AsyncSession) -> None:
        """Commit ``db``; on failure roll back and re-raise the error."""
        try:
            await db.commit()
        except Exception:
            # Roll back so the session is usable again, then let the
            # caller see the original error.
            await db.rollback()
            raise

    async def get(self, db: AsyncSession, id: UUID) -> Optional[ModelType]:
        """Get a single record by ID, or ``None`` when no row matches."""
        result = await db.execute(select(self.model).where(self.model.id == id))
        return result.scalar_one_or_none()

    async def get_multi(
        self,
        db: AsyncSession,
        *,
        skip: int = 0,
        limit: int = 100,
        **filters: Any,
    ) -> List[ModelType]:
        """Get multiple records with optional filtering.

        Args:
            db: Session used to execute the query.
            skip: Number of rows to skip (SQL ``OFFSET``).
            limit: Maximum number of rows to return (SQL ``LIMIT``).
            **filters: ``attribute=value`` equality filters. Filters whose
                value is ``None`` or whose name is not an attribute of
                the model are ignored.

        Returns:
            The matching records as a list.
        """
        query = self._apply_filters(select(self.model), filters)
        query = query.offset(skip).limit(limit)
        result = await db.execute(query)
        # Copy into a real list so the return value matches the annotation.
        return list(result.scalars().all())

    async def count(self, db: AsyncSession, **filters: Any) -> int:
        """Count records with optional filtering.

        Filters follow the same rules as in ``get_multi``.
        """
        query = self._apply_filters(select(func.count(self.model.id)), filters)
        result = await db.execute(query)
        # Fall back to 0 so the declared ``int`` return type always holds.
        return result.scalar() or 0

    async def create(self, db: AsyncSession, *, obj_in: dict) -> ModelType:
        """Create a new record.

        Args:
            db: Session used to persist the record.
            obj_in: Keyword arguments passed to the model constructor.

        Returns:
            The new instance, refreshed from the session after commit.

        Raises:
            Exception: Whatever the commit raised; the session is rolled
                back first.
        """
        db_obj = self.model(**obj_in)
        db.add(db_obj)
        await self._commit(db)
        await db.refresh(db_obj)
        return db_obj

    async def update(
        self, db: AsyncSession, *, db_obj: ModelType, obj_in: dict
    ) -> ModelType:
        """Update a record.

        Only entries of ``obj_in`` that name an existing attribute of
        ``db_obj`` and whose value is not ``None`` are applied, so this
        method cannot be used to set an attribute to ``None``.

        Returns:
            ``db_obj``, refreshed from the session after commit.

        Raises:
            Exception: Whatever the commit raised; the session is rolled
                back first.
        """
        for field, value in obj_in.items():
            if value is not None and hasattr(db_obj, field):
                setattr(db_obj, field, value)
        db.add(db_obj)
        await self._commit(db)
        await db.refresh(db_obj)
        return db_obj

    async def delete(self, db: AsyncSession, *, id: UUID) -> bool:
        """Delete a record by ID.

        Returns:
            ``True`` if at least one row was deleted, ``False`` otherwise.

        Raises:
            Exception: Whatever the statement or the commit raised; the
                session is rolled back first.
        """
        # ``delete`` below is the module-level SQLAlchemy construct:
        # names defined in a class body are not visible inside methods.
        statement = delete(self.model).where(self.model.id == id)
        try:
            result = await db.execute(statement)
            await db.commit()
        except Exception:
            await db.rollback()
            raise
        return result.rowcount > 0