release: v0.5.0 - Authentication, API Keys & Advanced Features
Complete v0.5.0 implementation: Database (@db-engineer): - 3 migrations: users, api_keys, report_schedules tables - Foreign keys, indexes, constraints, enums Backend (@backend-dev): - JWT authentication service with bcrypt (cost=12) - Auth endpoints: /register, /login, /refresh, /me - API Keys service with hash storage and prefix validation - API Keys endpoints: CRUD + rotate - Security module with JWT HS256 Frontend (@frontend-dev): - Login/Register pages with validation - AuthContext with localStorage persistence - Protected routes implementation - API Keys management UI (create, revoke, rotate) - Header with user dropdown DevOps (@devops-engineer): - .env.example and .env.production.example - docker-compose.scheduler.yml - scripts/setup-secrets.sh - INFRASTRUCTURE_SETUP.md QA (@qa-engineer): - 85 E2E tests: auth.spec.ts, apikeys.spec.ts, scenarios.spec.ts, regression-v050.spec.ts - auth-helpers.ts with 20+ utility functions - Test plans and documentation Architecture (@spec-architect): - SECURITY.md with best practices - SECURITY-CHECKLIST.md pre-deployment - Updated architecture.md with auth flows - Updated README.md with v0.5.0 features Documentation: - Updated todo.md with v0.5.0 status - Added docs/README.md index - Complete setup instructions Dependencies added: - bcrypt, python-jose, passlib, email-validator Tested: JWT auth flow, API keys CRUD, protected routes, 85 E2E tests ready Closes: v0.5.0 milestone
This commit is contained in:
@@ -6,8 +6,12 @@ from src.api.v1.scenarios import router as scenarios_router
|
||||
from src.api.v1.ingest import router as ingest_router
|
||||
from src.api.v1.metrics import router as metrics_router
|
||||
from src.api.v1.reports import scenario_reports_router, reports_router
|
||||
from src.api.v1.auth import router as auth_router
|
||||
from src.api.v1.apikeys import router as apikeys_router
|
||||
|
||||
api_router = APIRouter()
|
||||
api_router.include_router(auth_router, tags=["authentication"])
|
||||
api_router.include_router(apikeys_router, tags=["api-keys"])
|
||||
api_router.include_router(scenarios_router, prefix="/scenarios", tags=["scenarios"])
|
||||
api_router.include_router(ingest_router, tags=["ingest"])
|
||||
api_router.include_router(metrics_router, prefix="/scenarios", tags=["metrics"])
|
||||
|
||||
@@ -0,0 +1,223 @@
|
||||
"""API Keys API endpoints."""
|
||||
|
||||
from typing import Annotated, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from src.core.database import get_db
|
||||
from src.schemas.user import UserResponse
|
||||
from src.schemas.api_key import (
|
||||
APIKeyCreate,
|
||||
APIKeyUpdate,
|
||||
APIKeyResponse,
|
||||
APIKeyCreateResponse,
|
||||
APIKeyList,
|
||||
)
|
||||
from src.api.v1.auth import get_current_user
|
||||
from src.services.apikey_service import (
|
||||
create_api_key,
|
||||
list_api_keys,
|
||||
revoke_api_key,
|
||||
rotate_api_key,
|
||||
update_api_key,
|
||||
APIKeyNotFoundError,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/api-keys", tags=["api-keys"])
|
||||
|
||||
|
||||
@router.post(
|
||||
"",
|
||||
response_model=APIKeyCreateResponse,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
)
|
||||
async def create_new_api_key(
|
||||
key_data: APIKeyCreate,
|
||||
current_user: Annotated[UserResponse, Depends(get_current_user)],
|
||||
session: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Create a new API key.
|
||||
|
||||
⚠️ WARNING: The full API key is shown ONLY at creation!
|
||||
Make sure to copy and save it immediately.
|
||||
|
||||
Args:
|
||||
key_data: API key creation data
|
||||
current_user: Current authenticated user
|
||||
session: Database session
|
||||
|
||||
Returns:
|
||||
APIKeyCreateResponse with full key (shown only once)
|
||||
"""
|
||||
api_key, full_key = await create_api_key(
|
||||
session=session,
|
||||
user_id=current_user.id,
|
||||
name=key_data.name,
|
||||
scopes=key_data.scopes,
|
||||
expires_days=key_data.expires_days,
|
||||
)
|
||||
|
||||
return APIKeyCreateResponse(
|
||||
id=api_key.id,
|
||||
name=api_key.name,
|
||||
key=full_key, # Full key shown ONLY ONCE!
|
||||
key_prefix=api_key.key_prefix,
|
||||
scopes=api_key.scopes,
|
||||
is_active=api_key.is_active,
|
||||
created_at=api_key.created_at,
|
||||
expires_at=api_key.expires_at,
|
||||
last_used_at=api_key.last_used_at,
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"",
|
||||
response_model=APIKeyList,
|
||||
)
|
||||
async def list_user_api_keys(
|
||||
current_user: Annotated[UserResponse, Depends(get_current_user)],
|
||||
session: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""List all API keys for the current user.
|
||||
|
||||
Args:
|
||||
current_user: Current authenticated user
|
||||
session: Database session
|
||||
|
||||
Returns:
|
||||
APIKeyList with user's API keys (without key_hash)
|
||||
"""
|
||||
api_keys = await list_api_keys(session, current_user.id)
|
||||
|
||||
return APIKeyList(
|
||||
items=[APIKeyResponse.model_validate(key) for key in api_keys],
|
||||
total=len(api_keys),
|
||||
)
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/{key_id}",
|
||||
response_model=APIKeyResponse,
|
||||
)
|
||||
async def update_api_key_endpoint(
|
||||
key_id: UUID,
|
||||
key_data: APIKeyUpdate,
|
||||
current_user: Annotated[UserResponse, Depends(get_current_user)],
|
||||
session: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Update an API key (name only).
|
||||
|
||||
Args:
|
||||
key_id: API key ID
|
||||
key_data: Update data
|
||||
current_user: Current authenticated user
|
||||
session: Database session
|
||||
|
||||
Returns:
|
||||
Updated APIKeyResponse
|
||||
|
||||
Raises:
|
||||
HTTPException: If key not found
|
||||
"""
|
||||
try:
|
||||
api_key = await update_api_key(
|
||||
session=session,
|
||||
api_key_id=key_id,
|
||||
user_id=current_user.id,
|
||||
name=key_data.name,
|
||||
)
|
||||
except APIKeyNotFoundError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="API key not found",
|
||||
)
|
||||
|
||||
return APIKeyResponse.model_validate(api_key)
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/{key_id}",
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
)
|
||||
async def revoke_user_api_key(
|
||||
key_id: UUID,
|
||||
current_user: Annotated[UserResponse, Depends(get_current_user)],
|
||||
session: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Revoke (delete) an API key.
|
||||
|
||||
Args:
|
||||
key_id: API key ID
|
||||
current_user: Current authenticated user
|
||||
session: Database session
|
||||
|
||||
Raises:
|
||||
HTTPException: If key not found
|
||||
"""
|
||||
try:
|
||||
await revoke_api_key(
|
||||
session=session,
|
||||
api_key_id=key_id,
|
||||
user_id=current_user.id,
|
||||
)
|
||||
except APIKeyNotFoundError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="API key not found",
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@router.post(
|
||||
"/{key_id}/rotate",
|
||||
response_model=APIKeyCreateResponse,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
)
|
||||
async def rotate_user_api_key(
|
||||
key_id: UUID,
|
||||
current_user: Annotated[UserResponse, Depends(get_current_user)],
|
||||
session: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Rotate (regenerate) an API key.
|
||||
|
||||
The old key is revoked and a new key is created with the same settings.
|
||||
|
||||
⚠️ WARNING: The new full API key is shown ONLY at creation!
|
||||
|
||||
Args:
|
||||
key_id: API key ID to rotate
|
||||
current_user: Current authenticated user
|
||||
session: Database session
|
||||
|
||||
Returns:
|
||||
APIKeyCreateResponse with new full key
|
||||
|
||||
Raises:
|
||||
HTTPException: If key not found
|
||||
"""
|
||||
try:
|
||||
new_key, full_key = await rotate_api_key(
|
||||
session=session,
|
||||
api_key_id=key_id,
|
||||
user_id=current_user.id,
|
||||
)
|
||||
except APIKeyNotFoundError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="API key not found",
|
||||
)
|
||||
|
||||
return APIKeyCreateResponse(
|
||||
id=new_key.id,
|
||||
name=new_key.name,
|
||||
key=full_key, # New full key shown ONLY ONCE!
|
||||
key_prefix=new_key.key_prefix,
|
||||
scopes=new_key.scopes,
|
||||
is_active=new_key.is_active,
|
||||
created_at=new_key.created_at,
|
||||
expires_at=new_key.expires_at,
|
||||
last_used_at=new_key.last_used_at,
|
||||
)
|
||||
@@ -0,0 +1,355 @@
|
||||
"""Authentication API endpoints."""
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from src.core.database import get_db
|
||||
from src.core.security import verify_access_token, verify_refresh_token
|
||||
from src.schemas.user import (
|
||||
UserCreate,
|
||||
UserLogin,
|
||||
UserResponse,
|
||||
AuthResponse,
|
||||
TokenRefresh,
|
||||
TokenResponse,
|
||||
PasswordChange,
|
||||
PasswordResetRequest,
|
||||
PasswordReset,
|
||||
)
|
||||
from src.services.auth_service import (
|
||||
register_user,
|
||||
authenticate_user,
|
||||
change_password,
|
||||
reset_password_request,
|
||||
reset_password,
|
||||
get_user_by_id,
|
||||
create_tokens_for_user,
|
||||
EmailAlreadyExistsError,
|
||||
InvalidCredentialsError,
|
||||
UserNotFoundError,
|
||||
InvalidPasswordError,
|
||||
InvalidTokenError,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/auth", tags=["authentication"])
|
||||
security = HTTPBearer()
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
credentials: Annotated[HTTPAuthorizationCredentials, Depends(security)],
|
||||
session: AsyncSession = Depends(get_db),
|
||||
) -> UserResponse:
|
||||
"""Get current authenticated user from JWT token.
|
||||
|
||||
Args:
|
||||
credentials: HTTP Authorization credentials with Bearer token
|
||||
session: Database session
|
||||
|
||||
Returns:
|
||||
UserResponse object
|
||||
|
||||
Raises:
|
||||
HTTPException: If token is invalid or user not found
|
||||
"""
|
||||
token = credentials.credentials
|
||||
payload = verify_access_token(token)
|
||||
|
||||
if not payload:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid or expired token",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
user_id = payload.get("sub")
|
||||
if not user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid token payload",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
user = await get_user_by_id(session, UUID(user_id))
|
||||
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="User not found",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
if not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="User account is disabled",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
return UserResponse.model_validate(user)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/register",
|
||||
response_model=AuthResponse,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
)
|
||||
async def register(
|
||||
user_data: UserCreate,
|
||||
session: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Register a new user.
|
||||
|
||||
Args:
|
||||
user_data: User registration data
|
||||
session: Database session
|
||||
|
||||
Returns:
|
||||
AuthResponse with user and tokens
|
||||
|
||||
Raises:
|
||||
HTTPException: If email already exists or validation fails
|
||||
"""
|
||||
try:
|
||||
user = await register_user(
|
||||
session=session,
|
||||
email=user_data.email,
|
||||
password=user_data.password,
|
||||
full_name=user_data.full_name,
|
||||
)
|
||||
except EmailAlreadyExistsError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Email already registered",
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
detail=str(e),
|
||||
)
|
||||
|
||||
# Create tokens
|
||||
access_token, refresh_token = create_tokens_for_user(user)
|
||||
|
||||
return AuthResponse(
|
||||
user=UserResponse.model_validate(user),
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/login",
|
||||
response_model=TokenResponse,
|
||||
)
|
||||
async def login(
|
||||
credentials: UserLogin,
|
||||
session: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Login with email and password.
|
||||
|
||||
Args:
|
||||
credentials: Login credentials
|
||||
session: Database session
|
||||
|
||||
Returns:
|
||||
TokenResponse with access and refresh tokens
|
||||
|
||||
Raises:
|
||||
HTTPException: If credentials are invalid
|
||||
"""
|
||||
user = await authenticate_user(
|
||||
session=session,
|
||||
email=credentials.email,
|
||||
password=credentials.password,
|
||||
)
|
||||
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid email or password",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
access_token, refresh_token = create_tokens_for_user(user)
|
||||
|
||||
return TokenResponse(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/refresh",
|
||||
response_model=TokenResponse,
|
||||
)
|
||||
async def refresh_token(
|
||||
token_data: TokenRefresh,
|
||||
session: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Refresh access token using refresh token.
|
||||
|
||||
Args:
|
||||
token_data: Refresh token data
|
||||
session: Database session
|
||||
|
||||
Returns:
|
||||
TokenResponse with new access and refresh tokens
|
||||
|
||||
Raises:
|
||||
HTTPException: If refresh token is invalid
|
||||
"""
|
||||
payload = verify_refresh_token(token_data.refresh_token)
|
||||
|
||||
if not payload:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid or expired refresh token",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
user_id = payload.get("sub")
|
||||
user = await get_user_by_id(session, UUID(user_id))
|
||||
|
||||
if not user or not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="User not found or inactive",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
access_token, refresh_token = create_tokens_for_user(user)
|
||||
|
||||
return TokenResponse(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/me",
|
||||
response_model=UserResponse,
|
||||
)
|
||||
async def get_me(
|
||||
current_user: Annotated[UserResponse, Depends(get_current_user)],
|
||||
):
|
||||
"""Get current user information.
|
||||
|
||||
Returns:
|
||||
UserResponse with current user data
|
||||
"""
|
||||
return current_user
|
||||
|
||||
|
||||
@router.post(
|
||||
"/change-password",
|
||||
status_code=status.HTTP_200_OK,
|
||||
)
|
||||
async def change_user_password(
|
||||
password_data: PasswordChange,
|
||||
current_user: Annotated[UserResponse, Depends(get_current_user)],
|
||||
session: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Change current user password.
|
||||
|
||||
Args:
|
||||
password_data: Old and new password
|
||||
current_user: Current authenticated user
|
||||
session: Database session
|
||||
|
||||
Returns:
|
||||
Success message
|
||||
|
||||
Raises:
|
||||
HTTPException: If old password is incorrect
|
||||
"""
|
||||
from uuid import UUID
|
||||
|
||||
try:
|
||||
await change_password(
|
||||
session=session,
|
||||
user_id=UUID(current_user.id),
|
||||
old_password=password_data.old_password,
|
||||
new_password=password_data.new_password,
|
||||
)
|
||||
except InvalidPasswordError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Current password is incorrect",
|
||||
)
|
||||
|
||||
return {"message": "Password changed successfully"}
|
||||
|
||||
|
||||
@router.post(
|
||||
"/reset-password-request",
|
||||
status_code=status.HTTP_200_OK,
|
||||
)
|
||||
async def request_password_reset(
|
||||
request_data: PasswordResetRequest,
|
||||
session: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Request a password reset.
|
||||
|
||||
Args:
|
||||
request_data: Email for password reset
|
||||
session: Database session
|
||||
|
||||
Returns:
|
||||
Success message (always returns success for security)
|
||||
"""
|
||||
# Always return success to prevent email enumeration
|
||||
await reset_password_request(
|
||||
session=session,
|
||||
email=request_data.email,
|
||||
)
|
||||
|
||||
return {
|
||||
"message": "If the email exists, a password reset link has been sent",
|
||||
}
|
||||
|
||||
|
||||
@router.post(
|
||||
"/reset-password",
|
||||
status_code=status.HTTP_200_OK,
|
||||
)
|
||||
async def reset_user_password(
|
||||
reset_data: PasswordReset,
|
||||
session: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Reset password using token.
|
||||
|
||||
Args:
|
||||
reset_data: Token and new password
|
||||
session: Database session
|
||||
|
||||
Returns:
|
||||
Success message
|
||||
|
||||
Raises:
|
||||
HTTPException: If token is invalid
|
||||
"""
|
||||
try:
|
||||
await reset_password(
|
||||
session=session,
|
||||
token=reset_data.token,
|
||||
new_password=reset_data.new_password,
|
||||
)
|
||||
except InvalidTokenError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid or expired token",
|
||||
)
|
||||
except UserNotFoundError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="User not found",
|
||||
)
|
||||
|
||||
return {"message": "Password reset successfully"}
|
||||
@@ -24,9 +24,19 @@ class Settings(BaseSettings):
|
||||
reports_cleanup_days: int = 30
|
||||
reports_rate_limit_per_minute: int = 10
|
||||
|
||||
# JWT Configuration
|
||||
jwt_secret_key: str = "super-secret-change-in-production"
|
||||
jwt_algorithm: str = "HS256"
|
||||
access_token_expire_minutes: int = 30
|
||||
refresh_token_expire_days: int = 7
|
||||
|
||||
# Security
|
||||
bcrypt_rounds: int = 12
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
case_sensitive = False
|
||||
extra = "ignore"
|
||||
|
||||
|
||||
@lru_cache()
|
||||
|
||||
@@ -0,0 +1,207 @@
|
||||
"""Security utilities - JWT and password hashing."""
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional
|
||||
import secrets
|
||||
import base64
|
||||
|
||||
import bcrypt
|
||||
from jose import JWTError, jwt
|
||||
from pydantic import EmailStr
|
||||
|
||||
from src.core.config import settings
|
||||
|
||||
|
||||
# JWT Configuration
|
||||
JWT_SECRET_KEY = getattr(
|
||||
settings, "jwt_secret_key", "super-secret-change-in-production"
|
||||
)
|
||||
JWT_ALGORITHM = getattr(settings, "jwt_algorithm", "HS256")
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES = getattr(settings, "access_token_expire_minutes", 30)
|
||||
REFRESH_TOKEN_EXPIRE_DAYS = getattr(settings, "refresh_token_expire_days", 7)
|
||||
|
||||
|
||||
# Password hashing
|
||||
BCRYPT_ROUNDS = getattr(settings, "bcrypt_rounds", 12)
|
||||
|
||||
|
||||
def hash_password(password: str) -> str:
|
||||
"""Hash a password using bcrypt.
|
||||
|
||||
Args:
|
||||
password: Plain text password
|
||||
|
||||
Returns:
|
||||
Hashed password string
|
||||
"""
|
||||
password_bytes = password.encode("utf-8")
|
||||
salt = bcrypt.gensalt(rounds=BCRYPT_ROUNDS)
|
||||
hashed = bcrypt.hashpw(password_bytes, salt)
|
||||
return hashed.decode("utf-8")
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
"""Verify a password against a hash.
|
||||
|
||||
Args:
|
||||
plain_password: Plain text password
|
||||
hashed_password: Hashed password string
|
||||
|
||||
Returns:
|
||||
True if password matches, False otherwise
|
||||
"""
|
||||
password_bytes = plain_password.encode("utf-8")
|
||||
hashed_bytes = hashed_password.encode("utf-8")
|
||||
return bcrypt.checkpw(password_bytes, hashed_bytes)
|
||||
|
||||
|
||||
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
|
||||
"""Create a JWT access token.
|
||||
|
||||
Args:
|
||||
data: Data to encode in the token
|
||||
expires_delta: Optional custom expiration time
|
||||
|
||||
Returns:
|
||||
JWT token string
|
||||
"""
|
||||
to_encode = data.copy()
|
||||
|
||||
if expires_delta:
|
||||
expire = datetime.now(timezone.utc) + expires_delta
|
||||
else:
|
||||
expire = datetime.now(timezone.utc) + timedelta(
|
||||
minutes=ACCESS_TOKEN_EXPIRE_MINUTES
|
||||
)
|
||||
|
||||
to_encode.update({"exp": expire, "type": "access"})
|
||||
encoded_jwt = jwt.encode(to_encode, JWT_SECRET_KEY, algorithm=JWT_ALGORITHM)
|
||||
return encoded_jwt
|
||||
|
||||
|
||||
def create_refresh_token(data: dict) -> str:
|
||||
"""Create a JWT refresh token.
|
||||
|
||||
Args:
|
||||
data: Data to encode in the token
|
||||
|
||||
Returns:
|
||||
JWT refresh token string
|
||||
"""
|
||||
to_encode = data.copy()
|
||||
expire = datetime.now(timezone.utc) + timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)
|
||||
to_encode.update({"exp": expire, "type": "refresh"})
|
||||
encoded_jwt = jwt.encode(to_encode, JWT_SECRET_KEY, algorithm=JWT_ALGORITHM)
|
||||
return encoded_jwt
|
||||
|
||||
|
||||
def verify_token(token: str) -> Optional[dict]:
|
||||
"""Verify and decode a JWT token.
|
||||
|
||||
Args:
|
||||
token: JWT token string
|
||||
|
||||
Returns:
|
||||
Decoded payload dict or None if invalid
|
||||
"""
|
||||
try:
|
||||
payload = jwt.decode(token, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM])
|
||||
return payload
|
||||
except JWTError:
|
||||
return None
|
||||
|
||||
|
||||
def verify_access_token(token: str) -> Optional[dict]:
|
||||
"""Verify an access token specifically.
|
||||
|
||||
Args:
|
||||
token: JWT access token string
|
||||
|
||||
Returns:
|
||||
Decoded payload dict or None if invalid
|
||||
"""
|
||||
payload = verify_token(token)
|
||||
if payload and payload.get("type") == "access":
|
||||
return payload
|
||||
return None
|
||||
|
||||
|
||||
def verify_refresh_token(token: str) -> Optional[dict]:
|
||||
"""Verify a refresh token specifically.
|
||||
|
||||
Args:
|
||||
token: JWT refresh token string
|
||||
|
||||
Returns:
|
||||
Decoded payload dict or None if invalid
|
||||
"""
|
||||
payload = verify_token(token)
|
||||
if payload and payload.get("type") == "refresh":
|
||||
return payload
|
||||
return None
|
||||
|
||||
|
||||
def generate_api_key() -> tuple[str, str]:
|
||||
"""Generate a new API key and its hash.
|
||||
|
||||
Returns:
|
||||
Tuple of (full_key, key_hash)
|
||||
- full_key: The complete API key to show once (mk_ + base64)
|
||||
- key_hash: SHA-256 hash to store in database
|
||||
"""
|
||||
# Generate 32 random bytes
|
||||
random_bytes = secrets.token_bytes(32)
|
||||
# Encode to base64 (URL-safe)
|
||||
key_part = base64.urlsafe_b64encode(random_bytes).decode("utf-8").rstrip("=")
|
||||
# Full key with prefix
|
||||
full_key = f"mk_{key_part}"
|
||||
# Create hash for storage (using bcrypt for security)
|
||||
key_hash = bcrypt.hashpw(
|
||||
full_key.encode("utf-8"), bcrypt.gensalt(rounds=12)
|
||||
).decode("utf-8")
|
||||
# Prefix for identification (first 8 chars after mk_)
|
||||
return full_key, key_hash
|
||||
|
||||
|
||||
def get_key_prefix(key: str) -> str:
|
||||
"""Extract prefix from API key for identification.
|
||||
|
||||
Args:
|
||||
key: Full API key
|
||||
|
||||
Returns:
|
||||
First 8 characters of the key part (after mk_)
|
||||
"""
|
||||
if key.startswith("mk_"):
|
||||
key_part = key[3:] # Remove "mk_" prefix
|
||||
return key_part[:8]
|
||||
return key[:8]
|
||||
|
||||
|
||||
def verify_api_key(key: str, key_hash: str) -> bool:
|
||||
"""Verify an API key against its stored hash.
|
||||
|
||||
Args:
|
||||
key: Full API key
|
||||
key_hash: Stored bcrypt hash
|
||||
|
||||
Returns:
|
||||
True if key matches, False otherwise
|
||||
"""
|
||||
return bcrypt.checkpw(key.encode("utf-8"), key_hash.encode("utf-8"))
|
||||
|
||||
|
||||
def validate_email_format(email: str) -> bool:
|
||||
"""Validate email format.
|
||||
|
||||
Args:
|
||||
email: Email string to validate
|
||||
|
||||
Returns:
|
||||
True if valid email format, False otherwise
|
||||
"""
|
||||
try:
|
||||
EmailStr._validate(email)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
+1
-1
@@ -3,7 +3,7 @@ from src.core.exceptions import setup_exception_handlers
|
||||
from src.api.v1 import api_router
|
||||
|
||||
app = FastAPI(
|
||||
title="mockupAWS", description="AWS Cost Simulation Platform", version="0.4.0"
|
||||
title="mockupAWS", description="AWS Cost Simulation Platform", version="0.5.0"
|
||||
)
|
||||
|
||||
# Setup exception handlers
|
||||
|
||||
@@ -6,6 +6,8 @@ from src.models.scenario_log import ScenarioLog
|
||||
from src.models.scenario_metric import ScenarioMetric
|
||||
from src.models.aws_pricing import AwsPricing
|
||||
from src.models.report import Report
|
||||
from src.models.user import User
|
||||
from src.models.api_key import APIKey
|
||||
|
||||
__all__ = [
|
||||
"Base",
|
||||
@@ -14,4 +16,6 @@ __all__ = [
|
||||
"ScenarioMetric",
|
||||
"AwsPricing",
|
||||
"Report",
|
||||
"User",
|
||||
"APIKey",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
"""API Key model."""
|
||||
|
||||
import uuid
|
||||
from sqlalchemy import Column, String, Boolean, DateTime, ForeignKey
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from src.models.base import Base
|
||||
|
||||
|
||||
class APIKey(Base):
|
||||
"""API Key model for programmatic access."""
|
||||
|
||||
__tablename__ = "api_keys"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
user_id = Column(
|
||||
UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
key_hash = Column(String(255), nullable=False, unique=True)
|
||||
key_prefix = Column(String(8), nullable=False)
|
||||
name = Column(String(255), nullable=True)
|
||||
scopes = Column(JSONB, default=list)
|
||||
last_used_at = Column(DateTime(timezone=True), nullable=True)
|
||||
expires_at = Column(DateTime(timezone=True), nullable=True)
|
||||
is_active = Column(Boolean, default=True, nullable=False)
|
||||
created_at = Column(DateTime(timezone=True), nullable=False)
|
||||
|
||||
# Relationships
|
||||
user = relationship("User", back_populates="api_keys")
|
||||
@@ -0,0 +1,27 @@
|
||||
"""User model."""
|
||||
|
||||
import uuid
|
||||
from sqlalchemy import Column, String, Boolean, DateTime
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from src.models.base import Base, TimestampMixin
|
||||
|
||||
|
||||
class User(Base, TimestampMixin):
|
||||
"""User model for authentication."""
|
||||
|
||||
__tablename__ = "users"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
email = Column(String(255), nullable=False, unique=True)
|
||||
password_hash = Column(String(255), nullable=False)
|
||||
full_name = Column(String(255), nullable=True)
|
||||
is_active = Column(Boolean, default=True, nullable=False)
|
||||
is_superuser = Column(Boolean, default=False, nullable=False)
|
||||
last_login = Column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
# Relationships
|
||||
api_keys = relationship(
|
||||
"APIKey", back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
@@ -25,6 +25,28 @@ from src.schemas.report import (
|
||||
ReportList,
|
||||
ReportGenerateResponse,
|
||||
)
|
||||
from src.schemas.user import (
|
||||
UserBase,
|
||||
UserCreate,
|
||||
UserUpdate,
|
||||
UserResponse,
|
||||
UserLogin,
|
||||
TokenResponse,
|
||||
TokenRefresh,
|
||||
PasswordChange,
|
||||
PasswordResetRequest,
|
||||
PasswordReset,
|
||||
AuthResponse,
|
||||
)
|
||||
from src.schemas.api_key import (
|
||||
APIKeyBase,
|
||||
APIKeyCreate,
|
||||
APIKeyUpdate,
|
||||
APIKeyResponse,
|
||||
APIKeyCreateResponse,
|
||||
APIKeyList,
|
||||
APIKeyValidation,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ScenarioBase",
|
||||
@@ -47,4 +69,22 @@ __all__ = [
|
||||
"ReportStatusResponse",
|
||||
"ReportList",
|
||||
"ReportGenerateResponse",
|
||||
"UserBase",
|
||||
"UserCreate",
|
||||
"UserUpdate",
|
||||
"UserResponse",
|
||||
"UserLogin",
|
||||
"TokenResponse",
|
||||
"TokenRefresh",
|
||||
"PasswordChange",
|
||||
"PasswordResetRequest",
|
||||
"PasswordReset",
|
||||
"AuthResponse",
|
||||
"APIKeyBase",
|
||||
"APIKeyCreate",
|
||||
"APIKeyUpdate",
|
||||
"APIKeyResponse",
|
||||
"APIKeyCreateResponse",
|
||||
"APIKeyList",
|
||||
"APIKeyValidation",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,60 @@
|
||||
"""API Key schemas."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Optional, List
|
||||
from uuid import UUID
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
|
||||
|
||||
class APIKeyBase(BaseModel):
|
||||
"""Base API key schema."""
|
||||
|
||||
name: Optional[str] = Field(None, max_length=255)
|
||||
scopes: List[str] = Field(default_factory=list)
|
||||
expires_days: Optional[int] = Field(None, ge=1, le=365)
|
||||
|
||||
|
||||
class APIKeyCreate(APIKeyBase):
|
||||
"""Schema for creating an API key."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class APIKeyUpdate(BaseModel):
|
||||
"""Schema for updating an API key."""
|
||||
|
||||
name: Optional[str] = Field(None, max_length=255)
|
||||
|
||||
|
||||
class APIKeyResponse(BaseModel):
|
||||
"""Schema for API key response (without key_hash)."""
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: UUID
|
||||
name: Optional[str]
|
||||
key_prefix: str
|
||||
scopes: List[str]
|
||||
is_active: bool
|
||||
created_at: datetime
|
||||
expires_at: Optional[datetime] = None
|
||||
last_used_at: Optional[datetime] = None
|
||||
|
||||
|
||||
class APIKeyCreateResponse(APIKeyResponse):
|
||||
"""Schema for API key creation response (includes full key, ONLY ONCE!)."""
|
||||
|
||||
key: str # Full key shown only at creation
|
||||
|
||||
|
||||
class APIKeyList(BaseModel):
|
||||
"""Schema for list of API keys."""
|
||||
|
||||
items: List[APIKeyResponse]
|
||||
total: int
|
||||
|
||||
|
||||
class APIKeyValidation(BaseModel):
|
||||
"""Schema for API key validation."""
|
||||
|
||||
key: str
|
||||
@@ -0,0 +1,94 @@
|
||||
"""User schemas."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
from pydantic import BaseModel, EmailStr, Field, ConfigDict
|
||||
|
||||
|
||||
class UserBase(BaseModel):
|
||||
"""Base user schema."""
|
||||
|
||||
email: EmailStr
|
||||
full_name: Optional[str] = Field(None, max_length=255)
|
||||
|
||||
|
||||
class UserCreate(UserBase):
|
||||
"""Schema for creating a user."""
|
||||
|
||||
password: str = Field(..., min_length=8, max_length=100)
|
||||
|
||||
|
||||
class UserUpdate(BaseModel):
|
||||
"""Schema for updating a user."""
|
||||
|
||||
full_name: Optional[str] = Field(None, max_length=255)
|
||||
|
||||
|
||||
class UserResponse(UserBase):
|
||||
"""Schema for user response (no password)."""
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: UUID
|
||||
is_active: bool
|
||||
is_superuser: bool
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
last_login: Optional[datetime] = None
|
||||
|
||||
|
||||
class UserInDB(UserResponse):
|
||||
"""Schema for user in DB (includes password_hash, internal use only)."""
|
||||
|
||||
password_hash: str
|
||||
|
||||
|
||||
class UserLogin(BaseModel):
|
||||
"""Schema for user login."""
|
||||
|
||||
email: EmailStr
|
||||
password: str
|
||||
|
||||
|
||||
class TokenResponse(BaseModel):
|
||||
"""Schema for token response."""
|
||||
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
token_type: str = "bearer"
|
||||
|
||||
|
||||
class TokenRefresh(BaseModel):
|
||||
"""Schema for token refresh."""
|
||||
|
||||
refresh_token: str
|
||||
|
||||
|
||||
class PasswordChange(BaseModel):
|
||||
"""Schema for password change."""
|
||||
|
||||
old_password: str
|
||||
new_password: str = Field(..., min_length=8, max_length=100)
|
||||
|
||||
|
||||
class PasswordResetRequest(BaseModel):
|
||||
"""Schema for password reset request."""
|
||||
|
||||
email: EmailStr
|
||||
|
||||
|
||||
class PasswordReset(BaseModel):
|
||||
"""Schema for password reset."""
|
||||
|
||||
token: str
|
||||
new_password: str = Field(..., min_length=8, max_length=100)
|
||||
|
||||
|
||||
class AuthResponse(BaseModel):
|
||||
"""Schema for auth response with user and tokens."""
|
||||
|
||||
user: UserResponse
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
token_type: str = "bearer"
|
||||
@@ -4,6 +4,35 @@ from src.services.pii_detector import PIIDetector, pii_detector, PIIDetectionRes
|
||||
from src.services.cost_calculator import CostCalculator, cost_calculator
|
||||
from src.services.ingest_service import IngestService, ingest_service
|
||||
from src.services.report_service import ReportService, report_service
|
||||
from src.services.auth_service import (
|
||||
register_user,
|
||||
authenticate_user,
|
||||
change_password,
|
||||
reset_password_request,
|
||||
reset_password,
|
||||
get_user_by_id,
|
||||
get_user_by_email,
|
||||
create_tokens_for_user,
|
||||
AuthenticationError,
|
||||
EmailAlreadyExistsError,
|
||||
InvalidCredentialsError,
|
||||
UserNotFoundError,
|
||||
InvalidPasswordError,
|
||||
InvalidTokenError,
|
||||
)
|
||||
from src.services.apikey_service import (
|
||||
create_api_key,
|
||||
validate_api_key,
|
||||
list_api_keys,
|
||||
get_api_key,
|
||||
revoke_api_key,
|
||||
rotate_api_key,
|
||||
update_api_key,
|
||||
APIKeyError,
|
||||
APIKeyNotFoundError,
|
||||
APIKeyRevokedError,
|
||||
APIKeyExpiredError,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"PIIDetector",
|
||||
@@ -15,4 +44,29 @@ __all__ = [
|
||||
"ingest_service",
|
||||
"ReportService",
|
||||
"report_service",
|
||||
"register_user",
|
||||
"authenticate_user",
|
||||
"change_password",
|
||||
"reset_password_request",
|
||||
"reset_password",
|
||||
"get_user_by_id",
|
||||
"get_user_by_email",
|
||||
"create_tokens_for_user",
|
||||
"create_api_key",
|
||||
"validate_api_key",
|
||||
"list_api_keys",
|
||||
"get_api_key",
|
||||
"revoke_api_key",
|
||||
"rotate_api_key",
|
||||
"update_api_key",
|
||||
"AuthenticationError",
|
||||
"EmailAlreadyExistsError",
|
||||
"InvalidCredentialsError",
|
||||
"UserNotFoundError",
|
||||
"InvalidPasswordError",
|
||||
"InvalidTokenError",
|
||||
"APIKeyError",
|
||||
"APIKeyNotFoundError",
|
||||
"APIKeyRevokedError",
|
||||
"APIKeyExpiredError",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,296 @@
|
||||
"""API Key service."""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional, List
|
||||
|
||||
from sqlalchemy import select, and_
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from src.models.api_key import APIKey
|
||||
from src.models.user import User
|
||||
from src.core.security import generate_api_key, get_key_prefix, verify_api_key
|
||||
|
||||
|
||||
class APIKeyError(Exception):
|
||||
"""Base API key error."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class APIKeyNotFoundError(APIKeyError):
|
||||
"""API key not found."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class APIKeyRevokedError(APIKeyError):
|
||||
"""API key has been revoked."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class APIKeyExpiredError(APIKeyError):
|
||||
"""API key has expired."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
async def create_api_key(
|
||||
session: AsyncSession,
|
||||
user_id: uuid.UUID,
|
||||
name: Optional[str] = None,
|
||||
scopes: Optional[List[str]] = None,
|
||||
expires_days: Optional[int] = None,
|
||||
) -> tuple[APIKey, str]:
|
||||
"""Create a new API key for a user.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
user_id: User ID
|
||||
name: Optional name for the key
|
||||
scopes: List of permission scopes
|
||||
expires_days: Optional expiration in days
|
||||
|
||||
Returns:
|
||||
Tuple of (APIKey object, full_key string)
|
||||
Note: full_key is shown ONLY ONCE at creation!
|
||||
"""
|
||||
# Generate key and hash
|
||||
full_key, key_hash = generate_api_key()
|
||||
key_prefix = get_key_prefix(full_key)
|
||||
|
||||
# Calculate expiration
|
||||
expires_at = None
|
||||
if expires_days:
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(days=expires_days)
|
||||
|
||||
# Create API key record
|
||||
api_key = APIKey(
|
||||
user_id=user_id,
|
||||
key_hash=key_hash,
|
||||
key_prefix=key_prefix,
|
||||
name=name,
|
||||
scopes=scopes or [],
|
||||
expires_at=expires_at,
|
||||
is_active=True,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
session.add(api_key)
|
||||
await session.commit()
|
||||
await session.refresh(api_key)
|
||||
|
||||
return api_key, full_key
|
||||
|
||||
|
||||
async def validate_api_key(
|
||||
session: AsyncSession,
|
||||
key: str,
|
||||
) -> Optional[User]:
|
||||
"""Validate an API key and return the associated user.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
key: Full API key
|
||||
|
||||
Returns:
|
||||
User object if key is valid, None otherwise
|
||||
"""
|
||||
if not key.startswith("mk_"):
|
||||
return None
|
||||
|
||||
# Extract prefix for initial lookup
|
||||
key_prefix = get_key_prefix(key)
|
||||
|
||||
# Find all active API keys with matching prefix
|
||||
result = await session.execute(
|
||||
select(APIKey).where(
|
||||
and_(
|
||||
APIKey.key_prefix == key_prefix,
|
||||
APIKey.is_active == True,
|
||||
)
|
||||
)
|
||||
)
|
||||
api_keys = result.scalars().all()
|
||||
|
||||
# Check each key's hash
|
||||
for api_key in api_keys:
|
||||
if verify_api_key(key, api_key.key_hash):
|
||||
# Check if expired
|
||||
if api_key.expires_at and api_key.expires_at < datetime.now(timezone.utc):
|
||||
return None
|
||||
|
||||
# Update last used
|
||||
api_key.last_used_at = datetime.now(timezone.utc)
|
||||
await session.commit()
|
||||
|
||||
# Return user
|
||||
result = await session.execute(
|
||||
select(User).where(User.id == api_key.user_id)
|
||||
)
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if user and user.is_active:
|
||||
return user
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def list_api_keys(
|
||||
session: AsyncSession,
|
||||
user_id: uuid.UUID,
|
||||
) -> List[APIKey]:
|
||||
"""List all API keys for a user (without key_hash).
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
user_id: User ID
|
||||
|
||||
Returns:
|
||||
List of APIKey objects
|
||||
"""
|
||||
result = await session.execute(
|
||||
select(APIKey)
|
||||
.where(APIKey.user_id == user_id)
|
||||
.order_by(APIKey.created_at.desc())
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
|
||||
|
||||
async def get_api_key(
|
||||
session: AsyncSession,
|
||||
api_key_id: uuid.UUID,
|
||||
user_id: Optional[uuid.UUID] = None,
|
||||
) -> Optional[APIKey]:
|
||||
"""Get a specific API key by ID.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
api_key_id: API key ID
|
||||
user_id: Optional user ID to verify ownership
|
||||
|
||||
Returns:
|
||||
APIKey object or None
|
||||
"""
|
||||
query = select(APIKey).where(APIKey.id == api_key_id)
|
||||
|
||||
if user_id:
|
||||
query = query.where(APIKey.user_id == user_id)
|
||||
|
||||
result = await session.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
|
||||
async def revoke_api_key(
|
||||
session: AsyncSession,
|
||||
api_key_id: uuid.UUID,
|
||||
user_id: uuid.UUID,
|
||||
) -> bool:
|
||||
"""Revoke an API key.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
api_key_id: API key ID
|
||||
user_id: User ID (for ownership verification)
|
||||
|
||||
Returns:
|
||||
True if revoked successfully
|
||||
|
||||
Raises:
|
||||
APIKeyNotFoundError: If key not found
|
||||
"""
|
||||
api_key = await get_api_key(session, api_key_id, user_id)
|
||||
|
||||
if not api_key:
|
||||
raise APIKeyNotFoundError("API key not found")
|
||||
|
||||
api_key.is_active = False
|
||||
await session.commit()
|
||||
|
||||
return True
|
||||
|
||||
|
||||
async def rotate_api_key(
|
||||
session: AsyncSession,
|
||||
api_key_id: uuid.UUID,
|
||||
user_id: uuid.UUID,
|
||||
) -> tuple[APIKey, str]:
|
||||
"""Rotate (regenerate) an API key.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
api_key_id: API key ID to rotate
|
||||
user_id: User ID (for ownership verification)
|
||||
|
||||
Returns:
|
||||
Tuple of (new APIKey object, new full_key string)
|
||||
|
||||
Raises:
|
||||
APIKeyNotFoundError: If key not found
|
||||
"""
|
||||
# Get existing key
|
||||
old_key = await get_api_key(session, api_key_id, user_id)
|
||||
|
||||
if not old_key:
|
||||
raise APIKeyNotFoundError("API key not found")
|
||||
|
||||
# Revoke old key
|
||||
old_key.is_active = False
|
||||
|
||||
# Generate new key
|
||||
full_key, key_hash = generate_api_key()
|
||||
key_prefix = get_key_prefix(full_key)
|
||||
|
||||
# Create new API key with same settings
|
||||
new_key = APIKey(
|
||||
user_id=user_id,
|
||||
key_hash=key_hash,
|
||||
key_prefix=key_prefix,
|
||||
name=old_key.name,
|
||||
scopes=old_key.scopes,
|
||||
expires_at=old_key.expires_at,
|
||||
is_active=True,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
session.add(new_key)
|
||||
await session.commit()
|
||||
await session.refresh(new_key)
|
||||
|
||||
return new_key, full_key
|
||||
|
||||
|
||||
async def update_api_key(
|
||||
session: AsyncSession,
|
||||
api_key_id: uuid.UUID,
|
||||
user_id: uuid.UUID,
|
||||
name: Optional[str] = None,
|
||||
) -> APIKey:
|
||||
"""Update API key metadata.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
api_key_id: API key ID
|
||||
user_id: User ID (for ownership verification)
|
||||
name: New name for the key
|
||||
|
||||
Returns:
|
||||
Updated APIKey object
|
||||
|
||||
Raises:
|
||||
APIKeyNotFoundError: If key not found
|
||||
"""
|
||||
api_key = await get_api_key(session, api_key_id, user_id)
|
||||
|
||||
if not api_key:
|
||||
raise APIKeyNotFoundError("API key not found")
|
||||
|
||||
if name is not None:
|
||||
api_key.name = name
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(api_key)
|
||||
|
||||
return api_key
|
||||
@@ -0,0 +1,307 @@
|
||||
"""Authentication service."""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
import secrets
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from src.models.user import User
|
||||
from src.schemas.user import UserCreate, UserResponse
|
||||
from src.core.security import (
|
||||
hash_password,
|
||||
verify_password,
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
validate_email_format,
|
||||
)
|
||||
|
||||
|
||||
class AuthenticationError(Exception):
|
||||
"""Base authentication error."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class EmailAlreadyExistsError(AuthenticationError):
|
||||
"""Email already registered."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class InvalidCredentialsError(AuthenticationError):
|
||||
"""Invalid email or password."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class UserNotFoundError(AuthenticationError):
|
||||
"""User not found."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class InvalidPasswordError(AuthenticationError):
|
||||
"""Invalid old password."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class InvalidTokenError(AuthenticationError):
|
||||
"""Invalid or expired token."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
# In-memory token store for password reset (in production, use Redis)
|
||||
_password_reset_tokens: dict[str, str] = {} # token -> email
|
||||
|
||||
|
||||
async def register_user(
|
||||
session: AsyncSession,
|
||||
email: str,
|
||||
password: str,
|
||||
full_name: Optional[str] = None,
|
||||
) -> User:
|
||||
"""Register a new user.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
email: User email
|
||||
password: User password (will be hashed)
|
||||
full_name: Optional full name
|
||||
|
||||
Returns:
|
||||
Created user object
|
||||
|
||||
Raises:
|
||||
EmailAlreadyExistsError: If email is already registered
|
||||
ValueError: If email format is invalid
|
||||
"""
|
||||
# Validate email format
|
||||
if not validate_email_format(email):
|
||||
raise ValueError("Invalid email format")
|
||||
|
||||
# Check if email already exists
|
||||
result = await session.execute(select(User).where(User.email == email))
|
||||
if result.scalar_one_or_none():
|
||||
raise EmailAlreadyExistsError(f"Email {email} is already registered")
|
||||
|
||||
# Hash password
|
||||
password_hash = hash_password(password)
|
||||
|
||||
# Create user
|
||||
user = User(
|
||||
email=email,
|
||||
password_hash=password_hash,
|
||||
full_name=full_name,
|
||||
is_active=True,
|
||||
is_superuser=False,
|
||||
)
|
||||
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
|
||||
return user
|
||||
|
||||
|
||||
async def authenticate_user(
|
||||
session: AsyncSession,
|
||||
email: str,
|
||||
password: str,
|
||||
) -> Optional[User]:
|
||||
"""Authenticate a user with email and password.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
email: User email
|
||||
password: User password
|
||||
|
||||
Returns:
|
||||
User object if authenticated, None otherwise
|
||||
"""
|
||||
# Find user by email
|
||||
result = await session.execute(select(User).where(User.email == email))
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if not user:
|
||||
return None
|
||||
|
||||
if not user.is_active:
|
||||
return None
|
||||
|
||||
# Verify password
|
||||
if not verify_password(password, user.password_hash):
|
||||
return None
|
||||
|
||||
# Update last login
|
||||
user.last_login = datetime.now(timezone.utc)
|
||||
await session.commit()
|
||||
|
||||
return user
|
||||
|
||||
|
||||
async def change_password(
|
||||
session: AsyncSession,
|
||||
user_id: uuid.UUID,
|
||||
old_password: str,
|
||||
new_password: str,
|
||||
) -> bool:
|
||||
"""Change user password.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
user_id: User ID
|
||||
old_password: Current password
|
||||
new_password: New password
|
||||
|
||||
Returns:
|
||||
True if password was changed successfully
|
||||
|
||||
Raises:
|
||||
UserNotFoundError: If user not found
|
||||
InvalidPasswordError: If old password is incorrect
|
||||
"""
|
||||
result = await session.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if not user:
|
||||
raise UserNotFoundError("User not found")
|
||||
|
||||
# Verify old password
|
||||
if not verify_password(old_password, user.password_hash):
|
||||
raise InvalidPasswordError("Current password is incorrect")
|
||||
|
||||
# Hash and set new password
|
||||
user.password_hash = hash_password(new_password)
|
||||
await session.commit()
|
||||
|
||||
return True
|
||||
|
||||
|
||||
async def reset_password_request(
|
||||
session: AsyncSession,
|
||||
email: str,
|
||||
) -> str:
|
||||
"""Request a password reset.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
email: User email
|
||||
|
||||
Returns:
|
||||
Reset token (to be sent via email)
|
||||
|
||||
Note:
|
||||
Always returns a token even if email doesn't exist (security)
|
||||
"""
|
||||
# Generate secure random token
|
||||
token = secrets.token_urlsafe(32)
|
||||
|
||||
# Check if user exists
|
||||
result = await session.execute(select(User).where(User.email == email))
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if user:
|
||||
# Store token (in production, use Redis with expiration)
|
||||
_password_reset_tokens[token] = email
|
||||
|
||||
return token
|
||||
|
||||
|
||||
async def reset_password(
|
||||
session: AsyncSession,
|
||||
token: str,
|
||||
new_password: str,
|
||||
) -> bool:
|
||||
"""Reset password using a token.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
token: Reset token
|
||||
new_password: New password
|
||||
|
||||
Returns:
|
||||
True if password was reset successfully
|
||||
|
||||
Raises:
|
||||
InvalidTokenError: If token is invalid or expired
|
||||
UserNotFoundError: If user not found
|
||||
"""
|
||||
# Verify token
|
||||
email = _password_reset_tokens.get(token)
|
||||
if not email:
|
||||
raise InvalidTokenError("Invalid or expired token")
|
||||
|
||||
# Find user
|
||||
result = await session.execute(select(User).where(User.email == email))
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if not user:
|
||||
raise UserNotFoundError("User not found")
|
||||
|
||||
# Update password
|
||||
user.password_hash = hash_password(new_password)
|
||||
await session.commit()
|
||||
|
||||
# Remove used token
|
||||
del _password_reset_tokens[token]
|
||||
|
||||
return True
|
||||
|
||||
|
||||
async def get_user_by_id(
|
||||
session: AsyncSession,
|
||||
user_id: uuid.UUID,
|
||||
) -> Optional[User]:
|
||||
"""Get user by ID.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
user_id: User ID
|
||||
|
||||
Returns:
|
||||
User object or None
|
||||
"""
|
||||
result = await session.execute(select(User).where(User.id == user_id))
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
|
||||
async def get_user_by_email(
|
||||
session: AsyncSession,
|
||||
email: str,
|
||||
) -> Optional[User]:
|
||||
"""Get user by email.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
email: User email
|
||||
|
||||
Returns:
|
||||
User object or None
|
||||
"""
|
||||
result = await session.execute(select(User).where(User.email == email))
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
|
||||
def create_tokens_for_user(user: User) -> tuple[str, str]:
|
||||
"""Create access and refresh tokens for a user.
|
||||
|
||||
Args:
|
||||
user: User object
|
||||
|
||||
Returns:
|
||||
Tuple of (access_token, refresh_token)
|
||||
"""
|
||||
token_data = {
|
||||
"sub": str(user.id),
|
||||
"email": user.email,
|
||||
}
|
||||
|
||||
access_token = create_access_token(token_data)
|
||||
refresh_token = create_refresh_token(token_data)
|
||||
|
||||
return access_token, refresh_token
|
||||
Reference in New Issue
Block a user