From 3170f10e8687deb26ba72396214e9624cd168303 Mon Sep 17 00:00:00 2001 From: minhtrannhat Date: Sun, 7 Dec 2025 12:00:00 +0000 Subject: [PATCH] feat(api): Pydantic schemas + Data Repositories --- app/repositories/__init__.py | 17 + app/repositories/incident.py | 161 +++++ app/repositories/notification.py | 199 ++++++ app/repositories/org.py | 125 ++++ app/repositories/refresh_token.py | 396 +++++++++++ app/repositories/service.py | 80 +++ app/repositories/user.py | 63 ++ app/schemas/__init__.py | 50 ++ app/schemas/auth.py | 42 ++ app/schemas/common.py | 20 + app/schemas/incident.py | 57 ++ app/schemas/org.py | 69 ++ migrations/0004_refresh_token_rotation.sql | 18 + pyproject.toml | 5 +- tests/conftest.py | 95 +++ tests/repositories/__init__.py | 1 + tests/repositories/test_incident.py | 389 ++++++++++ tests/repositories/test_notification.py | 362 ++++++++++ tests/repositories/test_org.py | 250 +++++++ tests/repositories/test_refresh_token.py | 788 +++++++++++++++++++++ tests/repositories/test_service.py | 201 ++++++ tests/repositories/test_user.py | 133 ++++ uv.lock | 31 +- 23 files changed, 3549 insertions(+), 3 deletions(-) create mode 100644 app/repositories/__init__.py create mode 100644 app/repositories/incident.py create mode 100644 app/repositories/notification.py create mode 100644 app/repositories/org.py create mode 100644 app/repositories/refresh_token.py create mode 100644 app/repositories/service.py create mode 100644 app/repositories/user.py create mode 100644 app/schemas/__init__.py create mode 100644 app/schemas/auth.py create mode 100644 app/schemas/common.py create mode 100644 app/schemas/incident.py create mode 100644 app/schemas/org.py create mode 100644 migrations/0004_refresh_token_rotation.sql create mode 100644 tests/conftest.py create mode 100644 tests/repositories/__init__.py create mode 100644 tests/repositories/test_incident.py create mode 100644 tests/repositories/test_notification.py create mode 100644 tests/repositories/test_org.py create mode 100644 tests/repositories/test_refresh_token.py create mode 100644 tests/repositories/test_service.py create mode 100644 tests/repositories/test_user.py diff --git a/app/repositories/__init__.py b/app/repositories/__init__.py new file mode 100644 index 0000000..9933832 --- /dev/null +++ b/app/repositories/__init__.py @@ -0,0 +1,17 @@ +"""Repository layer for database operations.""" + +from app.repositories.incident import IncidentRepository +from app.repositories.notification import NotificationRepository +from app.repositories.org import OrgRepository +from app.repositories.refresh_token import RefreshTokenRepository +from app.repositories.service import ServiceRepository +from app.repositories.user import UserRepository + +__all__ = [ + "IncidentRepository", + "NotificationRepository", + "OrgRepository", + "RefreshTokenRepository", + "ServiceRepository", + "UserRepository", +] diff --git a/app/repositories/incident.py b/app/repositories/incident.py new file mode 100644 index 0000000..4b3ec69 --- /dev/null +++ b/app/repositories/incident.py @@ -0,0 +1,161 @@ +"""Incident repository for database operations.""" + +from datetime import datetime +from typing import Any +from uuid import UUID + +import asyncpg + + +class IncidentRepository: + """Database operations for incidents.""" + + def __init__(self, conn: asyncpg.Connection) -> None: + self.conn = conn + + async def create( + self, + incident_id: UUID, + org_id: UUID, + service_id: UUID, + title: str, + description: str | None, + severity: str, + ) -> dict: + """Create a new incident.""" + row = await self.conn.fetchrow( + """ + INSERT INTO incidents (id, org_id, service_id, title, description, status, severity) + VALUES ($1, $2, $3, $4, $5, 'triggered', $6) + RETURNING id, org_id, service_id, title, description, status, severity, + version, created_at, updated_at + """, + incident_id, + org_id, + service_id, + title, + description, + severity, + ) + return dict(row) + + async def get_by_id(self, incident_id: UUID) -> dict | None: + """Get incident by ID.""" + row = await self.conn.fetchrow( + """ + SELECT id, org_id, service_id, title, description, status, severity, + version, created_at, updated_at + FROM incidents + WHERE id = $1 + """, + incident_id, + ) + return dict(row) if row else None + + async def get_by_org( + self, + org_id: UUID, + status: str | None = None, + cursor: datetime | None = None, + limit: int = 20, + ) -> list[dict]: + """Get incidents for an organization with optional filtering and pagination.""" + query = """ + SELECT id, org_id, service_id, title, description, status, severity, + version, created_at, updated_at + FROM incidents + WHERE org_id = $1 + """ + params: list[Any] = [org_id] + param_idx = 2 + + if status: + query += f" AND status = ${param_idx}" + params.append(status) + param_idx += 1 + + if cursor: + query += f" AND created_at < ${param_idx}" + params.append(cursor) + param_idx += 1 + + query += f" ORDER BY created_at DESC LIMIT ${param_idx}" + params.append(limit + 1) # Fetch one extra to check if there are more + + rows = await self.conn.fetch(query, *params) + return [dict(row) for row in rows] + + async def update_status( + self, + incident_id: UUID, + new_status: str, + expected_version: int, + ) -> dict | None: + """Update incident status with optimistic locking. + + Returns updated incident if successful, None if version mismatch. + """ + row = await self.conn.fetchrow( + """ + UPDATE incidents + SET status = $2, version = version + 1, updated_at = now() + WHERE id = $1 AND version = $3 + RETURNING id, org_id, service_id, title, description, status, severity, + version, created_at, updated_at + """, + incident_id, + new_status, + expected_version, + ) + return dict(row) if row else None + + async def add_event( + self, + event_id: UUID, + incident_id: UUID, + event_type: str, + actor_user_id: UUID | None, + payload: dict[str, Any] | None, + ) -> dict: + """Add an event to the incident timeline.""" + import json + + row = await self.conn.fetchrow( + """ + INSERT INTO incident_events (id, incident_id, event_type, actor_user_id, payload) + VALUES ($1, $2, $3, $4, $5) + RETURNING id, incident_id, event_type, actor_user_id, payload, created_at + """, + event_id, + incident_id, + event_type, + actor_user_id, + json.dumps(payload) if payload else None, + ) + result = dict(row) + + # Parse JSON payload back to dict + if result["payload"]: + result["payload"] = json.loads(result["payload"]) + return result + + async def get_events(self, incident_id: UUID) -> list[dict]: + """Get all events for an incident.""" + import json + + rows = await self.conn.fetch( + """ + SELECT id, incident_id, event_type, actor_user_id, payload, created_at + FROM incident_events + WHERE incident_id = $1 + ORDER BY created_at + """, + incident_id, + ) + results = [] + for row in rows: + result = dict(row) + if result["payload"]: + result["payload"] = json.loads(result["payload"]) + results.append(result) + return results diff --git a/app/repositories/notification.py b/app/repositories/notification.py new file mode 100644 index 0000000..076fe4f --- /dev/null +++ b/app/repositories/notification.py @@ -0,0 +1,199 @@ +"""Notification repository for database operations.""" + +from datetime import datetime +from uuid import UUID + +import asyncpg + + +class NotificationRepository: + """Database operations for notification targets and attempts.""" + + def __init__(self, conn: asyncpg.Connection) -> None: + self.conn = conn + + async def create_target( + self, + target_id: UUID, + org_id: UUID, + name: str, + target_type: str, + webhook_url: str | None = None, + enabled: bool = True, + ) -> dict: + """Create a new notification target.""" + row = await self.conn.fetchrow( + """ + INSERT INTO notification_targets (id, org_id, name, target_type, webhook_url, enabled) + VALUES ($1, $2, $3, $4, $5, $6) + RETURNING id, org_id, name, target_type, webhook_url, enabled, created_at + """, + target_id, + org_id, + name, + target_type, + webhook_url, + enabled, + ) + return dict(row) + + async def get_target_by_id(self, target_id: UUID) -> dict | None: + """Get notification target by ID.""" + row = await self.conn.fetchrow( + """ + SELECT id, org_id, name, target_type, webhook_url, enabled, created_at + FROM notification_targets + WHERE id = $1 + """, + target_id, + ) + return dict(row) if row else None + + async def get_targets_by_org( + self, + org_id: UUID, + enabled_only: bool = False, + ) -> list[dict]: + """Get all notification targets for an organization.""" + query = """ + SELECT id, org_id, name, target_type, webhook_url, enabled, created_at + FROM notification_targets + WHERE org_id = $1 + """ + if enabled_only: + query += " AND enabled = true" + query += " ORDER BY name" + + rows = await self.conn.fetch(query, org_id) + return [dict(row) for row in rows] + + async def update_target( + self, + target_id: UUID, + name: str | None = None, + webhook_url: str | None = None, + enabled: bool | None = None, + ) -> dict | None: + """Update a notification target.""" + updates = [] + params = [target_id] + param_idx = 2 + + if name is not None: + updates.append(f"name = ${param_idx}") + params.append(name) + param_idx += 1 + + if webhook_url is not None: + updates.append(f"webhook_url = ${param_idx}") + params.append(webhook_url) + param_idx += 1 + + if enabled is not None: + updates.append(f"enabled = ${param_idx}") + params.append(enabled) + param_idx += 1 + + if not updates: + return await self.get_target_by_id(target_id) + + query = f""" + UPDATE notification_targets + SET {", ".join(updates)} + WHERE id = $1 + RETURNING id, org_id, name, target_type, webhook_url, enabled, created_at + """ + row = await self.conn.fetchrow(query, *params) + return dict(row) if row else None + + async def delete_target(self, target_id: UUID) -> bool: + """Delete a notification target. Returns True if deleted.""" + result = await self.conn.execute( + "DELETE FROM notification_targets WHERE id = $1", + target_id, + ) + return result == "DELETE 1" + + async def create_attempt( + self, + attempt_id: UUID, + incident_id: UUID, + target_id: UUID, + ) -> dict: + """Create a notification attempt (idempotent via unique constraint).""" + row = await self.conn.fetchrow( + """ + INSERT INTO notification_attempts (id, incident_id, target_id, status) + VALUES ($1, $2, $3, 'pending') + ON CONFLICT (incident_id, target_id) DO UPDATE SET id = notification_attempts.id + RETURNING id, incident_id, target_id, status, error, sent_at, created_at + """, + attempt_id, + incident_id, + target_id, + ) + return dict(row) + + async def get_attempt(self, incident_id: UUID, target_id: UUID) -> dict | None: + """Get notification attempt for incident and target.""" + row = await self.conn.fetchrow( + """ + SELECT id, incident_id, target_id, status, error, sent_at, created_at + FROM notification_attempts + WHERE incident_id = $1 AND target_id = $2 + """, + incident_id, + target_id, + ) + return dict(row) if row else None + + async def update_attempt_success( + self, + attempt_id: UUID, + sent_at: datetime, + ) -> dict | None: + """Mark notification attempt as successful.""" + row = await self.conn.fetchrow( + """ + UPDATE notification_attempts + SET status = 'sent', sent_at = $2, error = NULL + WHERE id = $1 + RETURNING id, incident_id, target_id, status, error, sent_at, created_at + """, + attempt_id, + sent_at, + ) + return dict(row) if row else None + + async def update_attempt_failure( + self, + attempt_id: UUID, + error: str, + ) -> dict | None: + """Mark notification attempt as failed.""" + row = await self.conn.fetchrow( + """ + UPDATE notification_attempts + SET status = 'failed', error = $2 + WHERE id = $1 + RETURNING id, incident_id, target_id, status, error, sent_at, created_at + """, + attempt_id, + error, + ) + return dict(row) if row else None + + async def get_pending_attempts(self, incident_id: UUID) -> list[dict]: + """Get all pending notification attempts for an incident.""" + rows = await self.conn.fetch( + """ + SELECT na.id, na.incident_id, na.target_id, na.status, na.error, + na.sent_at, na.created_at, + nt.target_type, nt.webhook_url, nt.name as target_name + FROM notification_attempts na + JOIN notification_targets nt ON nt.id = na.target_id + WHERE na.incident_id = $1 AND na.status = 'pending' + """, + incident_id, + ) + return [dict(row) for row in rows] diff --git a/app/repositories/org.py b/app/repositories/org.py new file mode 100644 index 0000000..fce692d --- /dev/null +++ b/app/repositories/org.py @@ -0,0 +1,125 @@ +"""Organization repository for database operations.""" + +from uuid import UUID + +import asyncpg + + +class OrgRepository: + """Database operations for organizations.""" + + def __init__(self, conn: asyncpg.Connection) -> None: + self.conn = conn + + async def create( + self, + org_id: UUID, + name: str, + slug: str, + ) -> dict: + """Create a new organization.""" + row = await self.conn.fetchrow( + """ + INSERT INTO orgs (id, name, slug) + VALUES ($1, $2, $3) + RETURNING id, name, slug, created_at + """, + org_id, + name, + slug, + ) + return dict(row) + + async def get_by_id(self, org_id: UUID) -> dict | None: + """Get organization by ID.""" + row = await self.conn.fetchrow( + """ + SELECT id, name, slug, created_at + FROM orgs + WHERE id = $1 + """, + org_id, + ) + return dict(row) if row else None + + async def get_by_slug(self, slug: str) -> dict | None: + """Get organization by slug.""" + row = await self.conn.fetchrow( + """ + SELECT id, name, slug, created_at + FROM orgs + WHERE slug = $1 + """, + slug, + ) + return dict(row) if row else None + + async def add_member( + self, + member_id: UUID, + user_id: UUID, + org_id: UUID, + role: str, + ) -> dict: + """Add a member to an organization.""" + row = await self.conn.fetchrow( + """ + INSERT INTO org_members (id, user_id, org_id, role) + VALUES ($1, $2, $3, $4) + RETURNING id, user_id, org_id, role, created_at + """, + member_id, + user_id, + org_id, + role, + ) + return dict(row) + + async def get_member(self, user_id: UUID, org_id: UUID) -> dict | None: + """Get membership for a user in an organization.""" + row = await self.conn.fetchrow( + """ + SELECT om.id, om.user_id, om.org_id, om.role, om.created_at + FROM org_members om + WHERE om.user_id = $1 AND om.org_id = $2 + """, + user_id, + org_id, + ) + return dict(row) if row else None + + async def get_members(self, org_id: UUID) -> list[dict]: + """Get all members of an organization.""" + rows = await self.conn.fetch( + """ + SELECT om.id, om.user_id, u.email, om.role, om.created_at + FROM org_members om + JOIN users u ON u.id = om.user_id + WHERE om.org_id = $1 + ORDER BY om.created_at + """, + org_id, + ) + return [dict(row) for row in rows] + + async def get_user_orgs(self, user_id: UUID) -> list[dict]: + """Get all organizations a user belongs to.""" + rows = await self.conn.fetch( + """ + SELECT o.id, o.name, o.slug, o.created_at, om.role + FROM orgs o + JOIN org_members om ON om.org_id = o.id + WHERE om.user_id = $1 + ORDER BY o.created_at + """, + user_id, + ) + return [dict(row) for row in rows] + + async def slug_exists(self, slug: str) -> bool: + """Check if organization slug exists.""" + result = await self.conn.fetchval( + "SELECT EXISTS(SELECT 1 FROM orgs WHERE slug = $1)", + slug, + ) + return result diff --git a/app/repositories/refresh_token.py b/app/repositories/refresh_token.py new file mode 100644 index 0000000..b5388dd --- /dev/null +++ b/app/repositories/refresh_token.py @@ -0,0 +1,396 @@ +"""Refresh token repository for database operations. + +Security considerations implemented: +- Atomic rotation using SELECT FOR UPDATE to prevent race conditions +- Token chain tracking via rotated_to for reuse/theft detection +- Defense-in-depth validation with user_id and active_org_id checks +- Uses RETURNING for robust row counting instead of string parsing +""" + +from datetime import datetime +from uuid import UUID + +import asyncpg + + +class RefreshTokenRepository: + """Database operations for refresh tokens.""" + + def __init__(self, conn: asyncpg.Connection) -> None: + self.conn = conn + + async def create( + self, + token_id: UUID, + user_id: UUID, + token_hash: str, + active_org_id: UUID, + expires_at: datetime, + ) -> dict: + """Create a new refresh token.""" + row = await self.conn.fetchrow( + """ + INSERT INTO refresh_tokens (id, user_id, token_hash, active_org_id, expires_at) + VALUES ($1, $2, $3, $4, $5) + RETURNING id, user_id, token_hash, active_org_id, expires_at, + revoked_at, rotated_to, created_at + """, + token_id, + user_id, + token_hash, + active_org_id, + expires_at, + ) + return dict(row) + + async def get_by_hash(self, token_hash: str) -> dict | None: + """Get refresh token by hash (includes revoked/expired for auditing).""" + row = await self.conn.fetchrow( + """ + SELECT id, user_id, token_hash, active_org_id, expires_at, + revoked_at, rotated_to, created_at + FROM refresh_tokens + WHERE token_hash = $1 + """, + token_hash, + ) + return dict(row) if row else None + + async def get_valid_by_hash( + self, + token_hash: str, + user_id: UUID | None = None, + active_org_id: UUID | None = None, + ) -> dict | None: + """Get refresh token by hash, only if valid. + + Validates: + - Token exists and matches hash + - Token is not revoked + - Token is not expired + - Token has not been rotated (rotated_to is NULL) + - Optionally: user_id matches (defense-in-depth) + - Optionally: active_org_id matches (defense-in-depth) + + Args: + token_hash: The hashed token value + user_id: If provided, token must belong to this user + active_org_id: If provided, token must be bound to this org + + Returns: + Token dict if valid, None otherwise + """ + query = """ + SELECT id, user_id, token_hash, active_org_id, expires_at, + revoked_at, rotated_to, created_at + FROM refresh_tokens + WHERE token_hash = $1 + AND revoked_at IS NULL + AND rotated_to IS NULL + AND expires_at > clock_timestamp() + """ + params: list = [token_hash] + param_idx = 2 + + if user_id is not None: + query += f" AND user_id = ${param_idx}" + params.append(user_id) + param_idx += 1 + + if active_org_id is not None: + query += f" AND active_org_id = ${param_idx}" + params.append(active_org_id) + + row = await self.conn.fetchrow(query, *params) + return dict(row) if row else None + + async def get_valid_for_rotation( + self, + token_hash: str, + user_id: UUID | None = None, + ) -> dict | None: + """Get and lock a valid token for rotation using SELECT FOR UPDATE. + + This acquires a row-level lock to prevent concurrent rotation attempts. + Must be called within a transaction. + + Args: + token_hash: The hashed token value + user_id: If provided, token must belong to this user + + Returns: + Token dict if valid and locked, None otherwise + """ + query = """ + SELECT id, user_id, token_hash, active_org_id, expires_at, + revoked_at, rotated_to, created_at + FROM refresh_tokens + WHERE token_hash = $1 + AND revoked_at IS NULL + AND rotated_to IS NULL + AND expires_at > clock_timestamp() + """ + params: list = [token_hash] + + if user_id is not None: + query += " AND user_id = $2" + params.append(user_id) + + query += " FOR UPDATE" + + row = await self.conn.fetchrow(query, *params) + return dict(row) if row else None + + async def check_token_reuse(self, token_hash: str) -> dict | None: + """Check if a token has already been rotated (potential theft). + + If a token is presented that has rotated_to set, it means: + 1. The token was legitimately rotated earlier + 2. Someone is now trying to use the old token + 3. This indicates the token may have been stolen + + Returns: + Token dict if this is a reused/stolen token, None if not found or not rotated + """ + row = await self.conn.fetchrow( + """ + SELECT id, user_id, token_hash, active_org_id, expires_at, + revoked_at, rotated_to, created_at + FROM refresh_tokens + WHERE token_hash = $1 AND rotated_to IS NOT NULL + """, + token_hash, + ) + return dict(row) if row else None + + async def revoke_token_chain(self, token_id: UUID) -> int: + """Revoke a token and all tokens in its chain (for breach response). + + When token reuse is detected, this revokes: + 1. The original stolen token + 2. Any token it was rotated to (and their rotations, recursively) + + Args: + token_id: The ID of the compromised token + + Returns: + Count of tokens revoked + """ + # Use recursive CTE to find all tokens in the chain + rows = await self.conn.fetch( + """ + WITH RECURSIVE token_chain AS ( + -- Start with the given token + SELECT id, rotated_to + FROM refresh_tokens + WHERE id = $1 + + UNION ALL + + -- Follow the chain via rotated_to + SELECT rt.id, rt.rotated_to + FROM refresh_tokens rt + INNER JOIN token_chain tc ON rt.id = tc.rotated_to + ) + UPDATE refresh_tokens + SET revoked_at = clock_timestamp() + WHERE id IN (SELECT id FROM token_chain) + AND revoked_at IS NULL + RETURNING id + """, + token_id, + ) + return len(rows) + + async def rotate( + self, + old_token_hash: str, + new_token_id: UUID, + new_token_hash: str, + new_expires_at: datetime, + new_active_org_id: UUID | None = None, + expected_user_id: UUID | None = None, + ) -> dict | None: + """Atomically rotate a refresh token. + + This method: + 1. Validates the old token (not expired, not revoked, not already rotated) + 2. Locks the row to prevent concurrent rotation + 3. Marks old token as rotated (sets rotated_to) + 4. Creates new token with updated org if specified + 5. All in a single atomic operation + + Args: + old_token_hash: Hash of the token being rotated + new_token_id: UUID for the new token + new_token_hash: Hash for the new token + new_expires_at: Expiry time for the new token + new_active_org_id: New org ID (for org-switch), or None to keep current + expected_user_id: If provided, validates token belongs to this user + + Returns: + New token dict if rotation succeeded, None if old token invalid/expired + """ + # First, get and lock the old token + old_token = await self.get_valid_for_rotation(old_token_hash, expected_user_id) + if old_token is None: + return None + + # Determine the org for the new token + active_org_id = new_active_org_id or old_token["active_org_id"] + user_id = old_token["user_id"] + + # Create the new token + new_token = await self.conn.fetchrow( + """ + INSERT INTO refresh_tokens (id, user_id, token_hash, active_org_id, expires_at) + VALUES ($1, $2, $3, $4, $5) + RETURNING id, user_id, token_hash, active_org_id, expires_at, + revoked_at, rotated_to, created_at + """, + new_token_id, + user_id, + new_token_hash, + active_org_id, + new_expires_at, + ) + + # Mark the old token as rotated (not revoked - for reuse detection) + await self.conn.execute( + """ + UPDATE refresh_tokens + SET rotated_to = $2 + WHERE id = $1 + """, + old_token["id"], + new_token_id, + ) + + return dict(new_token) + + async def revoke(self, token_id: UUID) -> bool: + """Revoke a refresh token by ID. + + Returns: + True if token was revoked, False if not found or already revoked + """ + row = await self.conn.fetchrow( + """ + UPDATE refresh_tokens + SET revoked_at = clock_timestamp() + WHERE id = $1 AND revoked_at IS NULL + RETURNING id + """, + token_id, + ) + return row is not None + + async def revoke_by_hash(self, token_hash: str) -> bool: + """Revoke a refresh token by hash. + + Returns: + True if token was revoked, False if not found or already revoked + """ + row = await self.conn.fetchrow( + """ + UPDATE refresh_tokens + SET revoked_at = clock_timestamp() + WHERE token_hash = $1 AND revoked_at IS NULL + RETURNING id + """, + token_hash, + ) + return row is not None + + async def revoke_all_for_user(self, user_id: UUID) -> int: + """Revoke all active refresh tokens for a user. + + Use this for: + - User-initiated logout from all devices + - Password change + - Account compromise response + + Returns: + Count of tokens revoked + """ + rows = await self.conn.fetch( + """ + UPDATE refresh_tokens + SET revoked_at = clock_timestamp() + WHERE user_id = $1 AND revoked_at IS NULL + RETURNING id + """, + user_id, + ) + return len(rows) + + async def revoke_all_for_user_except(self, user_id: UUID, keep_token_id: UUID) -> int: + """Revoke all tokens for a user except one (logout other sessions). + + Args: + user_id: The user whose tokens to revoke + keep_token_id: The token ID to keep active (current session) + + Returns: + Count of tokens revoked + """ + rows = await self.conn.fetch( + """ + UPDATE refresh_tokens + SET revoked_at = clock_timestamp() + WHERE user_id = $1 AND revoked_at IS NULL AND id != $2 + RETURNING id + """, + user_id, + keep_token_id, + ) + return len(rows) + + async def get_active_tokens_for_user(self, user_id: UUID) -> list[dict]: + """Get all active (non-revoked, non-expired, non-rotated) tokens for a user. + + Useful for: + - Showing active sessions + - Auditing + + Returns: + List of active token records + """ + rows = await self.conn.fetch( + """ + SELECT id, user_id, token_hash, active_org_id, expires_at, + revoked_at, rotated_to, created_at + FROM refresh_tokens + WHERE user_id = $1 + AND revoked_at IS NULL + AND rotated_to IS NULL + AND expires_at > clock_timestamp() + ORDER BY created_at DESC + """, + user_id, + ) + return [dict(row) for row in rows] + + async def cleanup_expired(self, older_than_days: int = 30) -> int: + """Delete expired tokens older than specified days. + + Note: This performs a hard delete. For audit trails, I think we should: + - Archiving to a separate table first + - Using partitioning with retention policies + - Only calling this for tokens well past their expiry + + Args: + older_than_days: Only delete tokens expired more than this many days ago + + Returns: + Count of tokens deleted + """ + rows = await self.conn.fetch( + """ + DELETE FROM refresh_tokens + WHERE expires_at < clock_timestamp() - interval '1 day' * $1 + RETURNING id + """, + older_than_days, + ) + return len(rows) diff --git a/app/repositories/service.py b/app/repositories/service.py new file mode 100644 index 0000000..01e846d --- /dev/null +++ b/app/repositories/service.py @@ -0,0 +1,80 @@ +"""Service repository for database operations.""" + +from uuid import UUID + +import asyncpg + + +class ServiceRepository: + """Database operations for services.""" + + def __init__(self, conn: asyncpg.Connection) -> None: + self.conn = conn + + async def create( + self, + service_id: UUID, + org_id: UUID, + name: str, + slug: str, + ) -> dict: + """Create a new service.""" + row = await self.conn.fetchrow( + """ + INSERT INTO services (id, org_id, name, slug) + VALUES ($1, $2, $3, $4) + RETURNING id, org_id, name, slug, created_at + """, + service_id, + org_id, + name, + slug, + ) + return dict(row) + + async def get_by_id(self, service_id: UUID) -> dict | None: + """Get service by ID.""" + row = await self.conn.fetchrow( + """ + SELECT id, org_id, name, slug, created_at + FROM services + WHERE id = $1 + """, + service_id, + ) + return dict(row) if row else None + + async def get_by_org(self, org_id: UUID) -> list[dict]: + """Get all services for an organization.""" + rows = await self.conn.fetch( + """ + SELECT id, org_id, name, slug, created_at + FROM services + WHERE org_id = $1 + ORDER BY name + """, + org_id, + ) + return [dict(row) for row in rows] + + async def get_by_slug(self, org_id: UUID, slug: str) -> dict | None: + """Get service by org and slug.""" + row = await self.conn.fetchrow( + """ + SELECT id, org_id, name, slug, created_at + FROM services + WHERE org_id = $1 AND slug = $2 + """, + org_id, + slug, + ) + return dict(row) if row else None + + async def slug_exists(self, org_id: UUID, slug: str) -> bool: + """Check if service slug exists in organization.""" + result = await self.conn.fetchval( + "SELECT EXISTS(SELECT 1 FROM services WHERE org_id = $1 AND slug = $2)", + org_id, + slug, + ) + return result diff --git a/app/repositories/user.py b/app/repositories/user.py new file mode 100644 index 0000000..7ddaf93 --- /dev/null +++ b/app/repositories/user.py @@ -0,0 +1,63 @@ +"""User repository for database operations.""" + +from uuid import UUID + +import asyncpg + + +class UserRepository: + """Database operations for users.""" + + def __init__(self, conn: asyncpg.Connection) -> None: + self.conn = conn + + async def create( + self, + user_id: UUID, + email: str, + password_hash: str, + ) -> dict: + """Create a new user.""" + row = await self.conn.fetchrow( + """ + INSERT INTO users (id, email, password_hash) + VALUES ($1, $2, $3) + RETURNING id, email, created_at + """, + user_id, + email, + password_hash, + ) + return dict(row) + + async def get_by_id(self, user_id: UUID) -> dict | None: + """Get user by ID.""" + row = await self.conn.fetchrow( + """ + SELECT id, email, password_hash, created_at + FROM users + WHERE id = $1 + """, + user_id, + ) + return dict(row) if row else None + + async def get_by_email(self, email: str) -> dict | None: + """Get user by email.""" + row = await self.conn.fetchrow( + """ + SELECT id, email, password_hash, created_at + FROM users + WHERE email = $1 + """, + email, + ) + return dict(row) if row else None + + async def exists_by_email(self, email: str) -> bool: + """Check if user exists by email.""" + result = await self.conn.fetchval( + "SELECT EXISTS(SELECT 1 FROM users WHERE email = $1)", + email, + ) + return result diff --git a/app/schemas/__init__.py b/app/schemas/__init__.py new file mode 100644 index 0000000..c8c6da1 --- /dev/null +++ b/app/schemas/__init__.py @@ -0,0 +1,50 @@ +"""Pydantic schemas for request/response models.""" + +from app.schemas.auth import ( + LoginRequest, + RefreshRequest, + RegisterRequest, + SwitchOrgRequest, + TokenResponse, +) +from app.schemas.common import CursorParams, PaginatedResponse +from app.schemas.incident import ( + CommentRequest, + IncidentCreate, + IncidentEventResponse, + IncidentResponse, + TransitionRequest, +) +from app.schemas.org import ( + MemberResponse, + NotificationTargetCreate, + NotificationTargetResponse, + OrgResponse, + ServiceCreate, + ServiceResponse, +) + +__all__ = [ + # Auth + "LoginRequest", + "RefreshRequest", + "RegisterRequest", + "SwitchOrgRequest", + "TokenResponse", + # Common + "CursorParams", + "PaginatedResponse", + # Incident + "CommentRequest", + "IncidentCreate", + "IncidentEventResponse", + "IncidentResponse", + "TransitionRequest", + # Org + "MemberResponse", + "NotificationTargetCreate", + "NotificationTargetResponse", + "OrgResponse", + "ServiceCreate", + "ServiceResponse", +] diff --git a/app/schemas/auth.py b/app/schemas/auth.py new file mode 100644 index 0000000..9e5b5a0 --- /dev/null +++ b/app/schemas/auth.py @@ -0,0 +1,42 @@ +"""Authentication schemas.""" + +from uuid import UUID + +from pydantic import BaseModel, EmailStr, Field + + +class RegisterRequest(BaseModel): + """Request body for user registration.""" + + email: EmailStr + password: str = Field(min_length=8, max_length=128) + org_name: str = Field(min_length=1, max_length=100, description="Name for the default org") + + +class LoginRequest(BaseModel): + """Request body for user login.""" + + email: EmailStr + password: str + + +class RefreshRequest(BaseModel): + """Request body for token refresh.""" + + refresh_token: str + + +class SwitchOrgRequest(BaseModel): + """Request body for switching active organization.""" + + org_id: UUID + refresh_token: str + + +class TokenResponse(BaseModel): + """Response containing access and refresh tokens.""" + + access_token: str + refresh_token: str + token_type: str = "bearer" + expires_in: int = Field(description="Access token expiry in seconds") diff --git a/app/schemas/common.py b/app/schemas/common.py new file mode 100644 index 0000000..c9f67c8 --- /dev/null +++ b/app/schemas/common.py @@ -0,0 +1,20 @@ +"""Common schemas used across the API.""" + +from pydantic import BaseModel, Field + + +class CursorParams(BaseModel): + """Pagination parameters using cursor-based pagination.""" + + cursor: str | None = Field(default=None, description="Cursor for pagination") + limit: int = Field(default=20, ge=1, le=100, description="Number of items per page") + + +class PaginatedResponse[T](BaseModel): + """Generic paginated response wrapper.""" + + items: list[T] + next_cursor: str | None = Field( + default=None, description="Cursor for next page, null if no more items" + ) + has_more: bool = Field(description="Whether there are more items") diff --git a/app/schemas/incident.py b/app/schemas/incident.py new file mode 100644 index 0000000..2f48cf3 --- /dev/null +++ b/app/schemas/incident.py @@ -0,0 +1,57 @@ +"""Incident-related schemas.""" + +from datetime import datetime +from typing import Any, Literal +from uuid import UUID + +from pydantic import BaseModel, Field + +IncidentStatus = Literal["triggered", "acknowledged", "mitigated", "resolved"] +IncidentSeverity = Literal["critical", "high", "medium", "low"] + + +class IncidentCreate(BaseModel): + """Request body for creating an incident.""" + + title: str = Field(min_length=1, max_length=200) + description: str | None = Field(default=None, max_length=5000) + severity: IncidentSeverity = "medium" + + +class IncidentResponse(BaseModel): + """Incident response.""" + + id: UUID + service_id: UUID + title: str + description: str | None + status: IncidentStatus + severity: IncidentSeverity + version: int + created_at: datetime + updated_at: datetime + + +class IncidentEventResponse(BaseModel): + """Incident event response.""" + + id: UUID + incident_id: UUID + event_type: str + actor_user_id: UUID | None + payload: dict[str, Any] | None + created_at: datetime + + +class TransitionRequest(BaseModel): + """Request body for transitioning incident status.""" + + to_status: IncidentStatus + version: int = Field(description="Current version for optimistic locking") + note: str | None = Field(default=None, max_length=1000) + + +class CommentRequest(BaseModel): + """Request body for adding a comment to an incident.""" + + content: str = Field(min_length=1, max_length=5000) diff --git a/app/schemas/org.py b/app/schemas/org.py new file mode 100644 index 0000000..9e316d6 --- /dev/null +++ b/app/schemas/org.py @@ -0,0 +1,69 @@ +"""Organization-related schemas.""" + +from datetime import datetime +from typing import Literal +from uuid import UUID + +from pydantic import BaseModel, Field, HttpUrl + + +class OrgResponse(BaseModel): + """Organization summary response.""" + + id: UUID + name: str + slug: str + created_at: datetime + + +class MemberResponse(BaseModel): + """Organization member response.""" + + id: UUID + user_id: UUID + email: str + role: Literal["admin", "member", "viewer"] + created_at: datetime + + +class ServiceCreate(BaseModel): + """Request body for creating a service.""" + + name: str = Field(min_length=1, max_length=100) + slug: str = Field( + min_length=1, + max_length=50, + pattern=r"^[a-z0-9]+(?:-[a-z0-9]+)*$", + description="URL-friendly identifier (lowercase, hyphens allowed)", + ) + + +class ServiceResponse(BaseModel): + """Service response.""" + + id: UUID + name: str + slug: str + created_at: datetime + + +class NotificationTargetCreate(BaseModel): + """Request body for creating a notification target.""" + + name: str = Field(min_length=1, max_length=100) + target_type: Literal["webhook", "email", "slack"] + webhook_url: HttpUrl | None = Field( + default=None, description="Required for webhook type" + ) + enabled: bool = True + + +class NotificationTargetResponse(BaseModel): + """Notification target response.""" + + id: UUID + name: str + target_type: Literal["webhook", "email", "slack"] + webhook_url: str | None + enabled: bool + created_at: datetime diff --git a/migrations/0004_refresh_token_rotation.sql b/migrations/0004_refresh_token_rotation.sql new file mode 100644 index 0000000..1fcd725 --- /dev/null +++ b/migrations/0004_refresh_token_rotation.sql @@ -0,0 +1,18 @@ +-- Enhance refresh tokens for secure rotation and reuse detection +-- Adds rotated_to column to track token chains and detect stolen token reuse + +-- Add rotated_to column to track which token this was rotated into +-- When a token is rotated, we store the ID of the new token here +-- If a token with rotated_to set is used again, it indicates token theft +ALTER TABLE refresh_tokens ADD COLUMN rotated_to UUID REFERENCES refresh_tokens(id); + +-- Index for efficient cleanup queries on expires_at +CREATE INDEX idx_refresh_tokens_expires ON refresh_tokens(expires_at); + +-- Index for finding active tokens per user (for revoke_all and listing) +CREATE INDEX idx_refresh_tokens_user_active ON refresh_tokens(user_id, revoked_at) + WHERE revoked_at IS NULL; + +-- Index for reuse detection queries +CREATE INDEX idx_refresh_tokens_rotated ON refresh_tokens(rotated_to) + WHERE rotated_to IS NOT NULL; diff --git a/pyproject.toml b/pyproject.toml index 3f653e7..6f423f6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ dependencies = [ "fastapi>=0.115.0", "uvicorn[standard]>=0.32.0", "asyncpg>=0.30.0", - "pydantic>=2.0.0", + "pydantic[email]>=2.0.0", "pydantic-settings>=2.0.0", "python-jose[cryptography]>=3.3.0", "bcrypt>=4.0.0", @@ -38,6 +38,9 @@ target-version = "py314" [tool.ruff.lint] select = ["E", "F", "I", "N", "W", "UP"] +[tool.ruff.lint.per-file-ignores] +"tests/**/*.py" = ["E501"] # Allow longer lines in tests for descriptive method names + [tool.pytest.ini_options] asyncio_mode = "auto" testpaths = ["tests"] diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..490f7a7 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,95 @@ +"""Shared pytest fixtures for all tests.""" + +from __future__ import annotations + +import os +from uuid import uuid4 + +import asyncpg +import pytest + +# Set test environment variables before importing app modules +os.environ.setdefault("DATABASE_URL", "postgresql://incidentops:incidentops@localhost:5432/incidentops_test") +os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-testing-only") +os.environ.setdefault("REDIS_URL", "redis://localhost:6379/1") + + +# Module-level setup: create database and run migrations once +_db_initialized = False + + +async def _init_test_db() -> None: + """Initialize test database and run migrations (once per session).""" + global _db_initialized + if _db_initialized: + return + + admin_dsn = os.environ["DATABASE_URL"].rsplit("/", 1)[0] + "/postgres" + test_db_name = "incidentops_test" + + admin_conn = await asyncpg.connect(admin_dsn) + try: + # Terminate existing connections to the test database + await admin_conn.execute(f""" + SELECT pg_terminate_backend(pg_stat_activity.pid) + FROM pg_stat_activity + WHERE pg_stat_activity.datname = '{test_db_name}' + AND pid <> pg_backend_pid() + """) + # Drop and recreate test database + await admin_conn.execute(f"DROP DATABASE IF EXISTS {test_db_name}") + await admin_conn.execute(f"CREATE DATABASE {test_db_name}") + finally: + await admin_conn.close() + + # Connect to test database and run migrations + test_dsn = os.environ["DATABASE_URL"] + conn = await asyncpg.connect(test_dsn) + + try: + # Run migrations + migrations_dir = os.path.join(os.path.dirname(__file__), "..", "migrations") + migration_files = sorted( + f for f in os.listdir(migrations_dir) if f.endswith(".sql") + ) + + for migration_file in migration_files: + migration_path = os.path.join(migrations_dir, migration_file) + with open(migration_path) as f: + sql = f.read() + await conn.execute(sql) + finally: + await conn.close() + + _db_initialized = True + + +@pytest.fixture +async def db_conn() -> asyncpg.Connection: + """Get a database connection with transaction rollback for test isolation.""" + await _init_test_db() + + test_dsn = os.environ["DATABASE_URL"] + conn = await asyncpg.connect(test_dsn) + + # Start a transaction that will be rolled back + tr = conn.transaction() + await tr.start() + + try: + yield conn + finally: + await tr.rollback() + await conn.close() + + +@pytest.fixture +def make_user_id() -> uuid4: + """Factory for generating user IDs.""" + return lambda: uuid4() + + +@pytest.fixture +def make_org_id() -> uuid4: + """Factory for generating org IDs.""" + return lambda: uuid4() diff --git a/tests/repositories/__init__.py b/tests/repositories/__init__.py new file mode 100644 index 0000000..0ef5327 --- /dev/null +++ b/tests/repositories/__init__.py @@ -0,0 +1 @@ +"""Repository tests.""" diff --git a/tests/repositories/test_incident.py b/tests/repositories/test_incident.py new file mode 100644 index 0000000..97688a0 --- /dev/null +++ b/tests/repositories/test_incident.py @@ -0,0 +1,389 @@ +"""Tests for IncidentRepository.""" + +from uuid import uuid4 + +import asyncpg +import pytest + +from app.repositories.incident import IncidentRepository +from app.repositories.org import OrgRepository +from app.repositories.service import ServiceRepository +from app.repositories.user import UserRepository + + +class TestIncidentRepository: + """Tests for IncidentRepository conforming to SPECS.md.""" + + async def _create_org(self, conn: asyncpg.Connection, slug: str) -> uuid4: + """Helper to create an org.""" + org_repo = OrgRepository(conn) + org_id = uuid4() + await org_repo.create(org_id, f"Org {slug}", slug) + return org_id + + async def _create_service(self, conn: asyncpg.Connection, org_id: uuid4, slug: str) -> uuid4: + """Helper to create a service.""" + service_repo = ServiceRepository(conn) + service_id = uuid4() + await service_repo.create(service_id, org_id, f"Service {slug}", slug) + return service_id + + async def _create_user(self, conn: asyncpg.Connection, email: str) -> uuid4: + """Helper to create a user.""" + user_repo = UserRepository(conn) + user_id = uuid4() + await user_repo.create(user_id, email, "hash") + return user_id + + async def test_create_incident_returns_incident_data(self, db_conn: asyncpg.Connection) -> None: + """Creating an incident returns the incident data with triggered status.""" + org_id = await self._create_org(db_conn, "incident-org") + service_id = await self._create_service(db_conn, org_id, "incident-service") + repo = IncidentRepository(db_conn) + incident_id = uuid4() + + result = await repo.create( + incident_id, org_id, service_id, + title="Server Down", + description="Main API server is not responding", + severity="critical" + ) + + assert result["id"] == incident_id + assert result["org_id"] == org_id + assert result["service_id"] == service_id + assert result["title"] == "Server Down" + assert result["description"] == "Main API server is not responding" + assert result["status"] == "triggered" # Initial status per SPECS.md + assert result["severity"] == "critical" + assert result["version"] == 1 + assert result["created_at"] is not None + assert result["updated_at"] is not None + + async def test_create_incident_initial_status_is_triggered(self, db_conn: asyncpg.Connection) -> None: + """New incidents always start with 'triggered' status per SPECS.md state machine.""" + org_id = await self._create_org(db_conn, "triggered-org") + service_id = await self._create_service(db_conn, org_id, "triggered-service") + repo = IncidentRepository(db_conn) + + result = await repo.create(uuid4(), org_id, service_id, "Test", None, "low") + + assert result["status"] == "triggered" + + async def test_create_incident_initial_version_is_one(self, db_conn: asyncpg.Connection) -> None: + """New incidents start with version 1 for optimistic locking.""" + org_id = await self._create_org(db_conn, "version-org") + service_id = await self._create_service(db_conn, org_id, "version-service") + repo = IncidentRepository(db_conn) + + result = await repo.create(uuid4(), org_id, service_id, "Test", None, "medium") + + assert result["version"] == 1 + + async def test_create_incident_severity_must_be_valid(self, db_conn: asyncpg.Connection) -> None: + """Severity must be critical, high, medium, or low per SPECS.md.""" + org_id = await self._create_org(db_conn, "severity-org") + service_id = await self._create_service(db_conn, org_id, "severity-service") + repo = IncidentRepository(db_conn) + + # Valid severities + for severity in ["critical", "high", "medium", "low"]: + result = await repo.create(uuid4(), org_id, service_id, f"Test {severity}", None, severity) + assert result["severity"] == severity + + # Invalid severity + with pytest.raises(asyncpg.CheckViolationError): + await repo.create(uuid4(), org_id, service_id, "Invalid", None, "extreme") + + async def test_get_by_id_returns_incident(self, db_conn: asyncpg.Connection) -> None: + """get_by_id returns the correct incident.""" + org_id = await self._create_org(db_conn, "getbyid-org") + service_id = await self._create_service(db_conn, org_id, "getbyid-service") + repo = IncidentRepository(db_conn) + incident_id = uuid4() + + await repo.create(incident_id, org_id, service_id, "My Incident", "Details", "high") + result = await repo.get_by_id(incident_id) + + assert result is not None + assert result["id"] == incident_id + assert result["title"] == "My Incident" + + async def test_get_by_id_returns_none_for_nonexistent(self, db_conn: asyncpg.Connection) -> None: + """get_by_id returns None for non-existent incident.""" + repo = IncidentRepository(db_conn) + + result = await repo.get_by_id(uuid4()) + + assert result is None + + async def test_get_by_org_returns_org_incidents(self, db_conn: asyncpg.Connection) -> None: + """get_by_org returns incidents for the organization.""" + org_id = await self._create_org(db_conn, "list-org") + service_id = await self._create_service(db_conn, org_id, "list-service") + repo = IncidentRepository(db_conn) + + await repo.create(uuid4(), org_id, service_id, "Incident 1", None, "low") + await repo.create(uuid4(), org_id, service_id, "Incident 2", None, "medium") + await repo.create(uuid4(), org_id, service_id, "Incident 3", None, "high") + + result = await repo.get_by_org(org_id) + + assert len(result) == 3 + + async def test_get_by_org_filters_by_status(self, db_conn: asyncpg.Connection) -> None: + """get_by_org can filter by status.""" + org_id = await self._create_org(db_conn, "filter-org") + service_id = await self._create_service(db_conn, org_id, "filter-service") + repo = IncidentRepository(db_conn) + + # Create incidents and transition some + inc1 = uuid4() + inc2 = uuid4() + await repo.create(inc1, org_id, service_id, "Triggered", None, "low") + await repo.create(inc2, org_id, service_id, "Will be Acked", None, "low") + await repo.update_status(inc2, "acknowledged", 1) + + result = await repo.get_by_org(org_id, status="triggered") + + assert len(result) == 1 + assert result[0]["title"] == "Triggered" + + async def test_get_by_org_pagination_with_cursor(self, db_conn: asyncpg.Connection) -> None: + """get_by_org supports cursor-based pagination.""" + org_id = await self._create_org(db_conn, "pagination-org") + service_id = await self._create_service(db_conn, org_id, "pagination-service") + repo = IncidentRepository(db_conn) + + # Create 5 incidents + for i in range(5): + await repo.create(uuid4(), org_id, service_id, f"Incident {i}", None, "low") + + # Get first page - should return limit+1 to check for more + page1 = await repo.get_by_org(org_id, limit=2) + assert len(page1) == 3 + + # Verify total is 5 when we get all + all_incidents = await repo.get_by_org(org_id, limit=10) + assert len(all_incidents) == 5 + + async def test_get_by_org_orders_by_created_at_desc(self, db_conn: asyncpg.Connection) -> None: + """get_by_org returns incidents ordered by created_at descending.""" + org_id = await self._create_org(db_conn, "order-org") + service_id = await self._create_service(db_conn, org_id, "order-service") + repo = IncidentRepository(db_conn) + + await repo.create(uuid4(), org_id, service_id, "First", None, "low") + await repo.create(uuid4(), org_id, service_id, "Second", None, "low") + await repo.create(uuid4(), org_id, service_id, "Third", None, "low") + + result = await repo.get_by_org(org_id) + + # Verify ordering - newer items should come first (or same time due to fast execution) + assert len(result) == 3 + for i in range(len(result) - 1): + assert result[i]["created_at"] >= result[i + 1]["created_at"] + + async def test_get_by_org_tenant_isolation(self, db_conn: asyncpg.Connection) -> None: + """get_by_org only returns incidents for the specified org.""" + org1 = await self._create_org(db_conn, "tenant-org-1") + org2 = await self._create_org(db_conn, "tenant-org-2") + service1 = await self._create_service(db_conn, org1, "tenant-service-1") + service2 = await self._create_service(db_conn, org2, "tenant-service-2") + repo = IncidentRepository(db_conn) + + await repo.create(uuid4(), org1, service1, "Org1 Incident", None, "low") + await repo.create(uuid4(), org2, service2, "Org2 Incident", None, "low") + + result = await repo.get_by_org(org1) + + assert len(result) == 1 + assert result[0]["title"] == "Org1 Incident" + + +class TestIncidentStatusTransitions: + """Tests for incident status transitions per SPECS.md state machine.""" + + async def _setup_incident(self, conn: asyncpg.Connection) -> tuple[uuid4, IncidentRepository]: + """Helper to create org, service, and incident.""" + org_repo = OrgRepository(conn) + service_repo = ServiceRepository(conn) + incident_repo = IncidentRepository(conn) + + org_id = uuid4() + service_id = uuid4() + incident_id = uuid4() + + await org_repo.create(org_id, "Test Org", f"test-org-{uuid4().hex[:8]}") + await service_repo.create(service_id, org_id, "Test Service", f"test-service-{uuid4().hex[:8]}") + await incident_repo.create(incident_id, org_id, service_id, "Test Incident", None, "medium") + + return incident_id, incident_repo + + async def test_update_status_increments_version(self, db_conn: asyncpg.Connection) -> None: + """update_status increments version for optimistic locking.""" + incident_id, repo = await self._setup_incident(db_conn) + + result = await repo.update_status(incident_id, "acknowledged", 1) + + assert result is not None + assert result["version"] == 2 + + async def test_update_status_fails_on_version_mismatch(self, db_conn: asyncpg.Connection) -> None: + """update_status returns None on version mismatch (optimistic locking).""" + incident_id, repo = await self._setup_incident(db_conn) + + # Try with wrong version + result = await repo.update_status(incident_id, "acknowledged", 999) + + assert result is None + + async def test_update_status_updates_updated_at(self, db_conn: asyncpg.Connection) -> None: + """update_status updates the updated_at timestamp.""" + incident_id, repo = await self._setup_incident(db_conn) + + before = await repo.get_by_id(incident_id) + result = await repo.update_status(incident_id, "acknowledged", 1) + + # updated_at should be at least as recent as before (may be same in fast execution) + assert result["updated_at"] >= before["updated_at"] + # Also verify status was actually updated + assert result["status"] == "acknowledged" + + async def test_status_must_be_valid_value(self, db_conn: asyncpg.Connection) -> None: + """Status must be triggered, acknowledged, mitigated, or resolved per SPECS.md.""" + incident_id, repo = await self._setup_incident(db_conn) + + with pytest.raises(asyncpg.CheckViolationError): + await repo.update_status(incident_id, "invalid_status", 1) + + async def test_valid_status_transitions(self, db_conn: asyncpg.Connection) -> None: + """Test the valid status values per SPECS.md.""" + incident_id, repo = await self._setup_incident(db_conn) + + # Triggered -> Acknowledged + result = await repo.update_status(incident_id, "acknowledged", 1) + assert result["status"] == "acknowledged" + + # Acknowledged -> Mitigated + result = await repo.update_status(incident_id, "mitigated", 2) + assert result["status"] == "mitigated" + + # Mitigated -> Resolved + result = await repo.update_status(incident_id, "resolved", 3) + assert result["status"] == "resolved" + + +class TestIncidentEvents: + """Tests for incident events (timeline) per SPECS.md incident_events table.""" + + async def _setup_incident(self, conn: asyncpg.Connection) -> tuple[uuid4, uuid4, IncidentRepository]: + """Helper to create org, service, user, and incident.""" + org_repo = OrgRepository(conn) + service_repo = ServiceRepository(conn) + user_repo = UserRepository(conn) + incident_repo = IncidentRepository(conn) + + org_id = uuid4() + service_id = uuid4() + user_id = uuid4() + incident_id = uuid4() + + await org_repo.create(org_id, "Test Org", f"test-org-{uuid4().hex[:8]}") + await service_repo.create(service_id, org_id, "Test Service", f"test-svc-{uuid4().hex[:8]}") + await user_repo.create(user_id, f"user-{uuid4().hex[:8]}@example.com", "hash") + await incident_repo.create(incident_id, org_id, service_id, "Test Incident", None, "medium") + + return incident_id, user_id, incident_repo + + async def test_add_event_creates_event(self, db_conn: asyncpg.Connection) -> None: + """add_event creates an event in the timeline.""" + incident_id, user_id, repo = await self._setup_incident(db_conn) + event_id = uuid4() + + result = await repo.add_event( + event_id, incident_id, "status_changed", + actor_user_id=user_id, + payload={"from": "triggered", "to": "acknowledged"} + ) + + assert result["id"] == event_id + assert result["incident_id"] == incident_id + assert result["event_type"] == "status_changed" + assert result["actor_user_id"] == user_id + assert result["payload"] == {"from": "triggered", "to": "acknowledged"} + assert result["created_at"] is not None + + async def test_add_event_allows_null_actor(self, db_conn: asyncpg.Connection) -> None: + """add_event allows null actor_user_id (system events).""" + incident_id, _, repo = await self._setup_incident(db_conn) + + result = await repo.add_event( + uuid4(), incident_id, "auto_escalated", + actor_user_id=None, + payload={"reason": "Unacknowledged after 30 minutes"} + ) + + assert result["actor_user_id"] is None + + async def test_add_event_allows_null_payload(self, db_conn: asyncpg.Connection) -> None: + """add_event allows null payload.""" + incident_id, user_id, repo = await self._setup_incident(db_conn) + + result = await repo.add_event( + uuid4(), incident_id, "viewed", + actor_user_id=user_id, + payload=None + ) + + assert result["payload"] is None + + async def test_get_events_returns_all_incident_events(self, db_conn: asyncpg.Connection) -> None: + """get_events returns all events for an incident.""" + incident_id, user_id, repo = await self._setup_incident(db_conn) + + await repo.add_event(uuid4(), incident_id, "created", user_id, {"title": "Test"}) + await repo.add_event(uuid4(), incident_id, "status_changed", user_id, {"to": "acked"}) + await repo.add_event(uuid4(), incident_id, "comment_added", user_id, {"text": "Working on it"}) + + result = await repo.get_events(incident_id) + + assert len(result) == 3 + event_types = [e["event_type"] for e in result] + assert event_types == ["created", "status_changed", "comment_added"] + + async def test_get_events_orders_by_created_at(self, db_conn: asyncpg.Connection) -> None: + """get_events returns events in chronological order.""" + incident_id, user_id, repo = await self._setup_incident(db_conn) + + await repo.add_event(uuid4(), incident_id, "first", user_id, None) + await repo.add_event(uuid4(), incident_id, "second", user_id, None) + await repo.add_event(uuid4(), incident_id, "third", user_id, None) + + result = await repo.get_events(incident_id) + + assert result[0]["event_type"] == "first" + assert result[1]["event_type"] == "second" + assert result[2]["event_type"] == "third" + + async def test_get_events_returns_empty_for_no_events(self, db_conn: asyncpg.Connection) -> None: + """get_events returns empty list for incident with no events.""" + incident_id, _, repo = await self._setup_incident(db_conn) + + result = await repo.get_events(incident_id) + + assert result == [] + + async def test_event_requires_valid_incident_foreign_key(self, db_conn: asyncpg.Connection) -> None: + """incident_events.incident_id must reference existing incident.""" + incident_id, user_id, repo = await self._setup_incident(db_conn) + + with pytest.raises(asyncpg.ForeignKeyViolationError): + await repo.add_event(uuid4(), uuid4(), "test", user_id, None) + + async def test_event_requires_valid_user_foreign_key(self, db_conn: asyncpg.Connection) -> None: + """incident_events.actor_user_id must reference existing user if not null.""" + incident_id, _, repo = await self._setup_incident(db_conn) + + with pytest.raises(asyncpg.ForeignKeyViolationError): + await repo.add_event(uuid4(), incident_id, "test", uuid4(), None) diff --git a/tests/repositories/test_notification.py b/tests/repositories/test_notification.py new file mode 100644 index 0000000..46ecb24 --- /dev/null +++ b/tests/repositories/test_notification.py @@ -0,0 +1,362 @@ +"""Tests for NotificationRepository.""" + +from datetime import UTC, datetime +from uuid import uuid4 + +import asyncpg +import pytest + +from app.repositories.incident import IncidentRepository +from app.repositories.notification import NotificationRepository +from app.repositories.org import OrgRepository +from app.repositories.service import ServiceRepository + + +class TestNotificationTargetRepository: + """Tests for notification targets per SPECS.md notification_targets table.""" + + async def _create_org(self, conn: asyncpg.Connection, slug: str) -> uuid4: + """Helper to create an org.""" + org_repo = OrgRepository(conn) + org_id = uuid4() + await org_repo.create(org_id, f"Org {slug}", slug) + return org_id + + async def test_create_target_returns_target_data(self, db_conn: asyncpg.Connection) -> None: + """Creating a notification target returns the target data.""" + org_id = await self._create_org(db_conn, "target-org") + repo = NotificationRepository(db_conn) + target_id = uuid4() + + result = await repo.create_target( + target_id, org_id, "Slack Alerts", + target_type="webhook", + webhook_url="https://hooks.slack.com/services/xxx", + enabled=True + ) + + assert result["id"] == target_id + assert result["org_id"] == org_id + assert result["name"] == "Slack Alerts" + assert result["target_type"] == "webhook" + assert result["webhook_url"] == "https://hooks.slack.com/services/xxx" + assert result["enabled"] is True + assert result["created_at"] is not None + + async def test_create_target_type_must_be_valid(self, db_conn: asyncpg.Connection) -> None: + """Target type must be webhook, email, or slack per SPECS.md.""" + org_id = await self._create_org(db_conn, "type-org") + repo = NotificationRepository(db_conn) + + # Valid types + for target_type in ["webhook", "email", "slack"]: + result = await repo.create_target( + uuid4(), org_id, f"{target_type} target", + target_type=target_type + ) + assert result["target_type"] == target_type + + # Invalid type + with pytest.raises(asyncpg.CheckViolationError): + await repo.create_target( + uuid4(), org_id, "Invalid", + target_type="sms" + ) + + async def test_create_target_webhook_url_optional(self, db_conn: asyncpg.Connection) -> None: + """webhook_url is optional (for email/slack types).""" + org_id = await self._create_org(db_conn, "optional-url-org") + repo = NotificationRepository(db_conn) + + result = await repo.create_target( + uuid4(), org_id, "Email Alerts", + target_type="email", + webhook_url=None + ) + + assert result["webhook_url"] is None + + async def test_create_target_enabled_defaults_to_true(self, db_conn: asyncpg.Connection) -> None: + """enabled defaults to True.""" + org_id = await self._create_org(db_conn, "default-enabled-org") + repo = NotificationRepository(db_conn) + + result = await repo.create_target( + uuid4(), org_id, "Default Enabled", + target_type="webhook" + ) + + assert result["enabled"] is True + + async def test_get_target_by_id_returns_target(self, db_conn: asyncpg.Connection) -> None: + """get_target_by_id returns the correct target.""" + org_id = await self._create_org(db_conn, "getbyid-target-org") + repo = NotificationRepository(db_conn) + target_id = uuid4() + + await repo.create_target(target_id, org_id, "My Target", "webhook") + result = await repo.get_target_by_id(target_id) + + assert result is not None + assert result["id"] == target_id + assert result["name"] == "My Target" + + async def test_get_target_by_id_returns_none_for_nonexistent(self, db_conn: asyncpg.Connection) -> None: + """get_target_by_id returns None for non-existent target.""" + repo = NotificationRepository(db_conn) + + result = await repo.get_target_by_id(uuid4()) + + assert result is None + + async def test_get_targets_by_org_returns_all_targets(self, db_conn: asyncpg.Connection) -> None: + """get_targets_by_org returns all targets for an organization.""" + org_id = await self._create_org(db_conn, "multi-target-org") + repo = NotificationRepository(db_conn) + + await repo.create_target(uuid4(), org_id, "Target A", "webhook") + await repo.create_target(uuid4(), org_id, "Target B", "email") + await repo.create_target(uuid4(), org_id, "Target C", "slack") + + result = await repo.get_targets_by_org(org_id) + + assert len(result) == 3 + names = {t["name"] for t in result} + assert names == {"Target A", "Target B", "Target C"} + + async def test_get_targets_by_org_filters_enabled(self, db_conn: asyncpg.Connection) -> None: + """get_targets_by_org can filter to only enabled targets.""" + org_id = await self._create_org(db_conn, "enabled-filter-org") + repo = NotificationRepository(db_conn) + + await repo.create_target(uuid4(), org_id, "Enabled", "webhook", enabled=True) + await repo.create_target(uuid4(), org_id, "Disabled", "webhook", enabled=False) + + result = await repo.get_targets_by_org(org_id, enabled_only=True) + + assert len(result) == 1 + assert result[0]["name"] == "Enabled" + + async def test_get_targets_by_org_tenant_isolation(self, db_conn: asyncpg.Connection) -> None: + """get_targets_by_org only returns targets for the specified org.""" + org1 = await self._create_org(db_conn, "isolated-target-org-1") + org2 = await self._create_org(db_conn, "isolated-target-org-2") + repo = NotificationRepository(db_conn) + + await repo.create_target(uuid4(), org1, "Org1 Target", "webhook") + await repo.create_target(uuid4(), org2, "Org2 Target", "webhook") + + result = await repo.get_targets_by_org(org1) + + assert len(result) == 1 + assert result[0]["name"] == "Org1 Target" + + async def test_update_target_updates_fields(self, db_conn: asyncpg.Connection) -> None: + """update_target updates the specified fields.""" + org_id = await self._create_org(db_conn, "update-target-org") + repo = NotificationRepository(db_conn) + target_id = uuid4() + + await repo.create_target(target_id, org_id, "Original", "webhook", enabled=True) + result = await repo.update_target(target_id, name="Updated", enabled=False) + + assert result is not None + assert result["name"] == "Updated" + assert result["enabled"] is False + + async def test_update_target_partial_update(self, db_conn: asyncpg.Connection) -> None: + """update_target only updates provided fields.""" + org_id = await self._create_org(db_conn, "partial-update-org") + repo = NotificationRepository(db_conn) + target_id = uuid4() + + await repo.create_target( + target_id, org_id, "Original Name", "webhook", + webhook_url="https://original.com", enabled=True + ) + result = await repo.update_target(target_id, name="New Name") + + assert result["name"] == "New Name" + assert result["webhook_url"] == "https://original.com" + assert result["enabled"] is True + + async def test_delete_target_removes_target(self, db_conn: asyncpg.Connection) -> None: + """delete_target removes the target.""" + org_id = await self._create_org(db_conn, "delete-target-org") + repo = NotificationRepository(db_conn) + target_id = uuid4() + + await repo.create_target(target_id, org_id, "To Delete", "webhook") + result = await repo.delete_target(target_id) + + assert result is True + assert await repo.get_target_by_id(target_id) is None + + async def test_delete_target_returns_false_for_nonexistent(self, db_conn: asyncpg.Connection) -> None: + """delete_target returns False for non-existent target.""" + repo = NotificationRepository(db_conn) + + result = await repo.delete_target(uuid4()) + + assert result is False + + async def test_target_requires_valid_org_foreign_key(self, db_conn: asyncpg.Connection) -> None: + """notification_targets.org_id must reference existing org.""" + repo = NotificationRepository(db_conn) + + with pytest.raises(asyncpg.ForeignKeyViolationError): + await repo.create_target(uuid4(), uuid4(), "Orphan Target", "webhook") + + +class TestNotificationAttemptRepository: + """Tests for notification attempts per SPECS.md notification_attempts table.""" + + async def _setup_incident_and_target(self, conn: asyncpg.Connection) -> tuple[uuid4, uuid4, NotificationRepository]: + """Helper to create org, service, incident, and notification target.""" + org_repo = OrgRepository(conn) + service_repo = ServiceRepository(conn) + incident_repo = IncidentRepository(conn) + notification_repo = NotificationRepository(conn) + + org_id = uuid4() + service_id = uuid4() + incident_id = uuid4() + target_id = uuid4() + + await org_repo.create(org_id, "Test Org", f"test-org-{uuid4().hex[:8]}") + await service_repo.create(service_id, org_id, "Test Service", f"test-svc-{uuid4().hex[:8]}") + await incident_repo.create(incident_id, org_id, service_id, "Test Incident", None, "medium") + await notification_repo.create_target(target_id, org_id, "Test Target", "webhook") + + return incident_id, target_id, notification_repo + + async def test_create_attempt_returns_attempt_data(self, db_conn: asyncpg.Connection) -> None: + """Creating a notification attempt returns the attempt data.""" + incident_id, target_id, repo = await self._setup_incident_and_target(db_conn) + attempt_id = uuid4() + + result = await repo.create_attempt(attempt_id, incident_id, target_id) + + assert result["id"] == attempt_id + assert result["incident_id"] == incident_id + assert result["target_id"] == target_id + assert result["status"] == "pending" + assert result["error"] is None + assert result["sent_at"] is None + assert result["created_at"] is not None + + async def test_create_attempt_idempotent(self, db_conn: asyncpg.Connection) -> None: + """create_attempt is idempotent per SPECS.md (unique constraint on incident+target).""" + incident_id, target_id, repo = await self._setup_incident_and_target(db_conn) + + # First attempt + result1 = await repo.create_attempt(uuid4(), incident_id, target_id) + # Second attempt with same incident+target + result2 = await repo.create_attempt(uuid4(), incident_id, target_id) + + # Should return the same attempt + assert result1["id"] == result2["id"] + + async def test_get_attempt_returns_attempt(self, db_conn: asyncpg.Connection) -> None: + """get_attempt returns the attempt for incident and target.""" + incident_id, target_id, repo = await self._setup_incident_and_target(db_conn) + + await repo.create_attempt(uuid4(), incident_id, target_id) + result = await repo.get_attempt(incident_id, target_id) + + assert result is not None + assert result["incident_id"] == incident_id + assert result["target_id"] == target_id + + async def test_get_attempt_returns_none_for_nonexistent(self, db_conn: asyncpg.Connection) -> None: + """get_attempt returns None for non-existent attempt.""" + incident_id, target_id, repo = await self._setup_incident_and_target(db_conn) + + result = await repo.get_attempt(incident_id, target_id) + + assert result is None + + async def test_update_attempt_success_sets_sent_status(self, db_conn: asyncpg.Connection) -> None: + """update_attempt_success marks attempt as sent.""" + incident_id, target_id, repo = await self._setup_incident_and_target(db_conn) + attempt = await repo.create_attempt(uuid4(), incident_id, target_id) + sent_at = datetime.now(UTC) + + result = await repo.update_attempt_success(attempt["id"], sent_at) + + assert result is not None + assert result["status"] == "sent" + assert result["sent_at"] is not None + assert result["error"] is None + + async def test_update_attempt_failure_sets_failed_status(self, db_conn: asyncpg.Connection) -> None: + """update_attempt_failure marks attempt as failed with error.""" + incident_id, target_id, repo = await self._setup_incident_and_target(db_conn) + attempt = await repo.create_attempt(uuid4(), incident_id, target_id) + + result = await repo.update_attempt_failure(attempt["id"], "Connection timeout") + + assert result is not None + assert result["status"] == "failed" + assert result["error"] == "Connection timeout" + + async def test_attempt_status_must_be_valid(self, db_conn: asyncpg.Connection) -> None: + """Attempt status must be pending, sent, or failed per SPECS.md.""" + incident_id, target_id, repo = await self._setup_incident_and_target(db_conn) + + # Create with default 'pending' status - valid + result = await repo.create_attempt(uuid4(), incident_id, target_id) + assert result["status"] == "pending" + + # Transition to 'sent' - valid + result = await repo.update_attempt_success(result["id"], datetime.now(UTC)) + assert result["status"] == "sent" + + async def test_get_pending_attempts_returns_pending_with_target_info(self, db_conn: asyncpg.Connection) -> None: + """get_pending_attempts returns pending attempts with target details.""" + incident_id, target_id, repo = await self._setup_incident_and_target(db_conn) + + await repo.create_attempt(uuid4(), incident_id, target_id) + result = await repo.get_pending_attempts(incident_id) + + assert len(result) == 1 + assert result[0]["status"] == "pending" + assert result[0]["target_id"] == target_id + assert "target_type" in result[0] + assert "target_name" in result[0] + + async def test_get_pending_attempts_excludes_sent(self, db_conn: asyncpg.Connection) -> None: + """get_pending_attempts excludes sent attempts.""" + incident_id, target_id, repo = await self._setup_incident_and_target(db_conn) + + attempt = await repo.create_attempt(uuid4(), incident_id, target_id) + await repo.update_attempt_success(attempt["id"], datetime.now(UTC)) + + result = await repo.get_pending_attempts(incident_id) + + assert len(result) == 0 + + async def test_get_pending_attempts_excludes_failed(self, db_conn: asyncpg.Connection) -> None: + """get_pending_attempts excludes failed attempts.""" + incident_id, target_id, repo = await self._setup_incident_and_target(db_conn) + + attempt = await repo.create_attempt(uuid4(), incident_id, target_id) + await repo.update_attempt_failure(attempt["id"], "Error") + + result = await repo.get_pending_attempts(incident_id) + + assert len(result) == 0 + + async def test_attempt_requires_valid_incident_foreign_key(self, db_conn: asyncpg.Connection) -> None: + """notification_attempts.incident_id must reference existing incident.""" + _, target_id, repo = await self._setup_incident_and_target(db_conn) + + with pytest.raises(asyncpg.ForeignKeyViolationError): + await repo.create_attempt(uuid4(), uuid4(), target_id) + + async def test_attempt_requires_valid_target_foreign_key(self, db_conn: asyncpg.Connection) -> None: + """notification_attempts.target_id must reference existing target.""" + incident_id, _, repo = await self._setup_incident_and_target(db_conn) + + with pytest.raises(asyncpg.ForeignKeyViolationError): + await repo.create_attempt(uuid4(), incident_id, uuid4()) diff --git a/tests/repositories/test_org.py b/tests/repositories/test_org.py new file mode 100644 index 0000000..3650c66 --- /dev/null +++ b/tests/repositories/test_org.py @@ -0,0 +1,250 @@ +"""Tests for OrgRepository.""" + +from uuid import uuid4 + +import asyncpg +import pytest + +from app.repositories.org import OrgRepository +from app.repositories.user import UserRepository + + +class TestOrgRepository: + """Tests for OrgRepository conforming to SPECS.md.""" + + async def test_create_org_returns_org_data(self, db_conn: asyncpg.Connection) -> None: + """Creating an org returns the org data.""" + repo = OrgRepository(db_conn) + org_id = uuid4() + name = "Test Organization" + slug = "test-org" + + result = await repo.create(org_id, name, slug) + + assert result["id"] == org_id + assert result["name"] == name + assert result["slug"] == slug + assert result["created_at"] is not None + + async def test_create_org_slug_must_be_unique(self, db_conn: asyncpg.Connection) -> None: + """Org slug uniqueness constraint per SPECS.md orgs table.""" + repo = OrgRepository(db_conn) + slug = "unique-slug" + + await repo.create(uuid4(), "Org One", slug) + + with pytest.raises(asyncpg.UniqueViolationError): + await repo.create(uuid4(), "Org Two", slug) + + async def test_get_by_id_returns_org(self, db_conn: asyncpg.Connection) -> None: + """get_by_id returns the correct organization.""" + repo = OrgRepository(db_conn) + org_id = uuid4() + + await repo.create(org_id, "My Org", "my-org") + result = await repo.get_by_id(org_id) + + assert result is not None + assert result["id"] == org_id + assert result["name"] == "My Org" + + async def test_get_by_id_returns_none_for_nonexistent(self, db_conn: asyncpg.Connection) -> None: + """get_by_id returns None for non-existent org.""" + repo = OrgRepository(db_conn) + + result = await repo.get_by_id(uuid4()) + + assert result is None + + async def test_get_by_slug_returns_org(self, db_conn: asyncpg.Connection) -> None: + """get_by_slug returns the correct organization.""" + repo = OrgRepository(db_conn) + org_id = uuid4() + slug = "slug-lookup" + + await repo.create(org_id, "Slug Test", slug) + result = await repo.get_by_slug(slug) + + assert result is not None + assert result["id"] == org_id + + async def test_get_by_slug_returns_none_for_nonexistent(self, db_conn: asyncpg.Connection) -> None: + """get_by_slug returns None for non-existent slug.""" + repo = OrgRepository(db_conn) + + result = await repo.get_by_slug("nonexistent-slug") + + assert result is None + + async def test_slug_exists_returns_true_when_exists(self, db_conn: asyncpg.Connection) -> None: + """slug_exists returns True when slug exists.""" + repo = OrgRepository(db_conn) + slug = "existing-slug" + + await repo.create(uuid4(), "Existing Org", slug) + result = await repo.slug_exists(slug) + + assert result is True + + async def test_slug_exists_returns_false_when_not_exists(self, db_conn: asyncpg.Connection) -> None: + """slug_exists returns False when slug doesn't exist.""" + repo = OrgRepository(db_conn) + + result = await repo.slug_exists("no-such-slug") + + assert result is False + + +class TestOrgMembership: + """Tests for org membership operations per SPECS.md org_members table.""" + + async def _create_user(self, conn: asyncpg.Connection, email: str) -> uuid4: + """Helper to create a user.""" + user_repo = UserRepository(conn) + user_id = uuid4() + await user_repo.create(user_id, email, "hash") + return user_id + + async def _create_org(self, conn: asyncpg.Connection, slug: str) -> uuid4: + """Helper to create an org.""" + org_repo = OrgRepository(conn) + org_id = uuid4() + await org_repo.create(org_id, f"Org {slug}", slug) + return org_id + + async def test_add_member_creates_membership(self, db_conn: asyncpg.Connection) -> None: + """add_member creates a membership record.""" + user_id = await self._create_user(db_conn, "member@example.com") + org_id = await self._create_org(db_conn, "member-org") + repo = OrgRepository(db_conn) + + result = await repo.add_member(uuid4(), user_id, org_id, "member") + + assert result["user_id"] == user_id + assert result["org_id"] == org_id + assert result["role"] == "member" + assert result["created_at"] is not None + + async def test_add_member_role_must_be_valid(self, db_conn: asyncpg.Connection) -> None: + """Role must be admin, member, or viewer per SPECS.md.""" + org_id = await self._create_org(db_conn, "role-test-org") + repo = OrgRepository(db_conn) + + # Valid roles should work + for role in ["admin", "member", "viewer"]: + member_id = uuid4() + # Need a new user for each since user+org must be unique + new_user_id = await self._create_user(db_conn, f"{role}@example.com") + result = await repo.add_member(member_id, new_user_id, org_id, role) + assert result["role"] == role + + # Invalid role should fail + another_user = await self._create_user(db_conn, "invalid_role@example.com") + with pytest.raises(asyncpg.CheckViolationError): + await repo.add_member(uuid4(), another_user, org_id, "superuser") + + async def test_add_member_user_org_must_be_unique(self, db_conn: asyncpg.Connection) -> None: + """User can only be member of an org once (unique constraint).""" + user_id = await self._create_user(db_conn, "unique_member@example.com") + org_id = await self._create_org(db_conn, "unique-member-org") + repo = OrgRepository(db_conn) + + await repo.add_member(uuid4(), user_id, org_id, "member") + + with pytest.raises(asyncpg.UniqueViolationError): + await repo.add_member(uuid4(), user_id, org_id, "admin") + + async def test_get_member_returns_membership(self, db_conn: asyncpg.Connection) -> None: + """get_member returns the membership for user and org.""" + user_id = await self._create_user(db_conn, "get_member@example.com") + org_id = await self._create_org(db_conn, "get-member-org") + repo = OrgRepository(db_conn) + + await repo.add_member(uuid4(), user_id, org_id, "admin") + result = await repo.get_member(user_id, org_id) + + assert result is not None + assert result["user_id"] == user_id + assert result["org_id"] == org_id + assert result["role"] == "admin" + + async def test_get_member_returns_none_for_nonmember(self, db_conn: asyncpg.Connection) -> None: + """get_member returns None if user is not a member.""" + user_id = await self._create_user(db_conn, "nonmember@example.com") + org_id = await self._create_org(db_conn, "nonmember-org") + repo = OrgRepository(db_conn) + + result = await repo.get_member(user_id, org_id) + + assert result is None + + async def test_get_members_returns_all_org_members(self, db_conn: asyncpg.Connection) -> None: + """get_members returns all members with their emails.""" + org_id = await self._create_org(db_conn, "all-members-org") + user1 = await self._create_user(db_conn, "user1@example.com") + user2 = await self._create_user(db_conn, "user2@example.com") + user3 = await self._create_user(db_conn, "user3@example.com") + repo = OrgRepository(db_conn) + + await repo.add_member(uuid4(), user1, org_id, "admin") + await repo.add_member(uuid4(), user2, org_id, "member") + await repo.add_member(uuid4(), user3, org_id, "viewer") + + result = await repo.get_members(org_id) + + assert len(result) == 3 + emails = {m["email"] for m in result} + assert emails == {"user1@example.com", "user2@example.com", "user3@example.com"} + + async def test_get_members_returns_empty_list_for_no_members(self, db_conn: asyncpg.Connection) -> None: + """get_members returns empty list for org with no members.""" + org_id = await self._create_org(db_conn, "empty-org") + repo = OrgRepository(db_conn) + + result = await repo.get_members(org_id) + + assert result == [] + + async def test_get_user_orgs_returns_all_user_memberships(self, db_conn: asyncpg.Connection) -> None: + """get_user_orgs returns all orgs a user belongs to with their role.""" + user_id = await self._create_user(db_conn, "multi_org@example.com") + org1 = await self._create_org(db_conn, "user-org-1") + org2 = await self._create_org(db_conn, "user-org-2") + repo = OrgRepository(db_conn) + + await repo.add_member(uuid4(), user_id, org1, "admin") + await repo.add_member(uuid4(), user_id, org2, "member") + + result = await repo.get_user_orgs(user_id) + + assert len(result) == 2 + slugs = {o["slug"] for o in result} + assert slugs == {"user-org-1", "user-org-2"} + # Check role is included + roles = {o["role"] for o in result} + assert roles == {"admin", "member"} + + async def test_get_user_orgs_returns_empty_for_no_memberships(self, db_conn: asyncpg.Connection) -> None: + """get_user_orgs returns empty list for user with no memberships.""" + user_id = await self._create_user(db_conn, "no_orgs@example.com") + repo = OrgRepository(db_conn) + + result = await repo.get_user_orgs(user_id) + + assert result == [] + + async def test_member_requires_valid_user_foreign_key(self, db_conn: asyncpg.Connection) -> None: + """org_members.user_id must reference existing user.""" + org_id = await self._create_org(db_conn, "fk-test-org") + repo = OrgRepository(db_conn) + + with pytest.raises(asyncpg.ForeignKeyViolationError): + await repo.add_member(uuid4(), uuid4(), org_id, "member") + + async def test_member_requires_valid_org_foreign_key(self, db_conn: asyncpg.Connection) -> None: + """org_members.org_id must reference existing org.""" + user_id = await self._create_user(db_conn, "fk_user@example.com") + repo = OrgRepository(db_conn) + + with pytest.raises(asyncpg.ForeignKeyViolationError): + await repo.add_member(uuid4(), user_id, uuid4(), "member") diff --git a/tests/repositories/test_refresh_token.py b/tests/repositories/test_refresh_token.py new file mode 100644 index 0000000..e3e06b1 --- /dev/null +++ b/tests/repositories/test_refresh_token.py @@ -0,0 +1,788 @@ +"""Tests for RefreshTokenRepository with security features.""" + +from datetime import UTC, datetime, timedelta +from uuid import uuid4 + +import asyncpg +import pytest + +from app.repositories.org import OrgRepository +from app.repositories.refresh_token import RefreshTokenRepository +from app.repositories.user import UserRepository + + +class TestRefreshTokenRepository: + """Tests for basic RefreshTokenRepository operations.""" + + async def _create_user(self, conn: asyncpg.Connection, email: str) -> uuid4: + """Helper to create a user.""" + user_repo = UserRepository(conn) + user_id = uuid4() + await user_repo.create(user_id, email, "hash") + return user_id + + async def _create_org(self, conn: asyncpg.Connection, slug: str) -> uuid4: + """Helper to create an org.""" + org_repo = OrgRepository(conn) + org_id = uuid4() + await org_repo.create(org_id, f"Org {slug}", slug) + return org_id + + async def test_create_token_returns_token_data(self, db_conn: asyncpg.Connection) -> None: + """Creating a refresh token returns the token data including rotated_to.""" + user_id = await self._create_user(db_conn, "token_create@example.com") + org_id = await self._create_org(db_conn, "token-create-org") + repo = RefreshTokenRepository(db_conn) + token_id = uuid4() + token_hash = "sha256_hashed_token_value" + expires_at = datetime.now(UTC) + timedelta(days=30) + + result = await repo.create(token_id, user_id, token_hash, org_id, expires_at) + + assert result["id"] == token_id + assert result["user_id"] == user_id + assert result["token_hash"] == token_hash + assert result["active_org_id"] == org_id + assert result["expires_at"] is not None + assert result["revoked_at"] is None + assert result["rotated_to"] is None # New field + assert result["created_at"] is not None + + async def test_token_hash_must_be_unique(self, db_conn: asyncpg.Connection) -> None: + """Token hash uniqueness constraint per SPECS.md refresh_tokens table.""" + user_id = await self._create_user(db_conn, "unique_hash@example.com") + org_id = await self._create_org(db_conn, "unique-hash-org") + repo = RefreshTokenRepository(db_conn) + token_hash = "duplicate_hash_value" + expires_at = datetime.now(UTC) + timedelta(days=30) + + await repo.create(uuid4(), user_id, token_hash, org_id, expires_at) + + with pytest.raises(asyncpg.UniqueViolationError): + await repo.create(uuid4(), user_id, token_hash, org_id, expires_at) + + async def test_get_by_hash_returns_token(self, db_conn: asyncpg.Connection) -> None: + """get_by_hash returns the correct token (even if revoked/expired).""" + user_id = await self._create_user(db_conn, "get_hash@example.com") + org_id = await self._create_org(db_conn, "get-hash-org") + repo = RefreshTokenRepository(db_conn) + token_id = uuid4() + token_hash = "lookup_by_hash_value" + expires_at = datetime.now(UTC) + timedelta(days=30) + + await repo.create(token_id, user_id, token_hash, org_id, expires_at) + result = await repo.get_by_hash(token_hash) + + assert result is not None + assert result["id"] == token_id + assert result["token_hash"] == token_hash + + async def test_get_by_hash_returns_none_for_nonexistent(self, db_conn: asyncpg.Connection) -> None: + """get_by_hash returns None for non-existent hash.""" + repo = RefreshTokenRepository(db_conn) + + result = await repo.get_by_hash("nonexistent_hash") + + assert result is None + + +class TestGetValidByHash: + """Tests for get_valid_by_hash with defense-in-depth validation.""" + + async def _setup_token( + self, conn: asyncpg.Connection, suffix: str = "" + ) -> tuple[uuid4, uuid4, uuid4, str, RefreshTokenRepository]: + """Helper to create user, org, and token.""" + user_repo = UserRepository(conn) + org_repo = OrgRepository(conn) + token_repo = RefreshTokenRepository(conn) + + user_id = uuid4() + org_id = uuid4() + token_id = uuid4() + token_hash = f"token_hash_{uuid4().hex[:8]}{suffix}" + + await user_repo.create(user_id, f"user_{uuid4().hex[:8]}@example.com", "hash") + await org_repo.create(org_id, "Test Org", f"test-org-{uuid4().hex[:8]}") + + expires_at = datetime.now(UTC) + timedelta(days=30) + await token_repo.create(token_id, user_id, token_hash, org_id, expires_at) + + return token_id, user_id, org_id, token_hash, token_repo + + async def test_get_valid_returns_valid_token(self, db_conn: asyncpg.Connection) -> None: + """get_valid_by_hash returns token if not expired, not revoked, not rotated.""" + _, _, _, token_hash, repo = await self._setup_token(db_conn) + + result = await repo.get_valid_by_hash(token_hash) + + assert result is not None + assert result["token_hash"] == token_hash + + async def test_get_valid_returns_none_for_expired(self, db_conn: asyncpg.Connection) -> None: + """get_valid_by_hash returns None for expired token.""" + user_repo = UserRepository(db_conn) + org_repo = OrgRepository(db_conn) + repo = RefreshTokenRepository(db_conn) + + user_id = uuid4() + org_id = uuid4() + await user_repo.create(user_id, "expired@example.com", "hash") + await org_repo.create(org_id, "Org", "expired-org") + + token_hash = "expired_token_hash" + expires_at = datetime.now(UTC) - timedelta(days=1) # Already expired + await repo.create(uuid4(), user_id, token_hash, org_id, expires_at) + + result = await repo.get_valid_by_hash(token_hash) + + assert result is None + + async def test_get_valid_returns_none_for_revoked(self, db_conn: asyncpg.Connection) -> None: + """get_valid_by_hash returns None for revoked token.""" + token_id, _, _, token_hash, repo = await self._setup_token(db_conn, "_revoked") + + await repo.revoke(token_id) + result = await repo.get_valid_by_hash(token_hash) + + assert result is None + + async def test_get_valid_returns_none_for_rotated(self, db_conn: asyncpg.Connection) -> None: + """get_valid_by_hash returns None for already-rotated token.""" + _, user_id, org_id, old_hash, repo = await self._setup_token(db_conn, "_rotated") + + # Rotate the token + new_hash = f"new_token_{uuid4().hex[:8]}" + new_expires = datetime.now(UTC) + timedelta(days=30) + await repo.rotate(old_hash, uuid4(), new_hash, new_expires) + + # Old token should no longer be valid + result = await repo.get_valid_by_hash(old_hash) + assert result is None + + async def test_get_valid_with_user_id_validation(self, db_conn: asyncpg.Connection) -> None: + """get_valid_by_hash validates user_id when provided (defense-in-depth).""" + _, user_id, _, token_hash, repo = await self._setup_token(db_conn, "_user_check") + + # Correct user_id should work + result = await repo.get_valid_by_hash(token_hash, user_id=user_id) + assert result is not None + + # Wrong user_id should return None + wrong_user_id = uuid4() + result = await repo.get_valid_by_hash(token_hash, user_id=wrong_user_id) + assert result is None + + async def test_get_valid_with_org_id_validation(self, db_conn: asyncpg.Connection) -> None: + """get_valid_by_hash validates active_org_id when provided (defense-in-depth).""" + _, _, org_id, token_hash, repo = await self._setup_token(db_conn, "_org_check") + + # Correct org_id should work + result = await repo.get_valid_by_hash(token_hash, active_org_id=org_id) + assert result is not None + + # Wrong org_id should return None + wrong_org_id = uuid4() + result = await repo.get_valid_by_hash(token_hash, active_org_id=wrong_org_id) + assert result is None + + async def test_get_valid_with_both_user_and_org_validation(self, db_conn: asyncpg.Connection) -> None: + """get_valid_by_hash validates both user_id and active_org_id together.""" + _, user_id, org_id, token_hash, repo = await self._setup_token(db_conn, "_both") + + # Both correct should work + result = await repo.get_valid_by_hash(token_hash, user_id=user_id, active_org_id=org_id) + assert result is not None + + # Either wrong should fail + result = await repo.get_valid_by_hash(token_hash, user_id=uuid4(), active_org_id=org_id) + assert result is None + + result = await repo.get_valid_by_hash(token_hash, user_id=user_id, active_org_id=uuid4()) + assert result is None + + +class TestAtomicRotation: + """Tests for atomic token rotation per SPECS.md.""" + + async def _setup_token( + self, conn: asyncpg.Connection + ) -> tuple[uuid4, uuid4, uuid4, str, RefreshTokenRepository]: + """Helper to create user, org, and token.""" + user_repo = UserRepository(conn) + org_repo = OrgRepository(conn) + token_repo = RefreshTokenRepository(conn) + + user_id = uuid4() + org_id = uuid4() + token_id = uuid4() + token_hash = f"rotate_token_{uuid4().hex[:8]}" + + await user_repo.create(user_id, f"rotate_{uuid4().hex[:8]}@example.com", "hash") + await org_repo.create(org_id, "Rotate Org", f"rotate-org-{uuid4().hex[:8]}") + + expires_at = datetime.now(UTC) + timedelta(days=30) + await token_repo.create(token_id, user_id, token_hash, org_id, expires_at) + + return token_id, user_id, org_id, token_hash, token_repo + + async def test_rotate_creates_new_token(self, db_conn: asyncpg.Connection) -> None: + """rotate() creates a new token and returns it.""" + old_id, user_id, org_id, old_hash, repo = await self._setup_token(db_conn) + + new_id = uuid4() + new_hash = f"new_rotated_{uuid4().hex[:8]}" + new_expires = datetime.now(UTC) + timedelta(days=30) + + result = await repo.rotate(old_hash, new_id, new_hash, new_expires) + + assert result is not None + assert result["id"] == new_id + assert result["token_hash"] == new_hash + assert result["user_id"] == user_id + assert result["active_org_id"] == org_id + + async def test_rotate_marks_old_token_as_rotated(self, db_conn: asyncpg.Connection) -> None: + """rotate() sets rotated_to on the old token (not revoked_at).""" + old_id, _, _, old_hash, repo = await self._setup_token(db_conn) + + new_id = uuid4() + new_hash = f"new_{uuid4().hex[:8]}" + new_expires = datetime.now(UTC) + timedelta(days=30) + + await repo.rotate(old_hash, new_id, new_hash, new_expires) + + # Check old token state + old_token = await repo.get_by_hash(old_hash) + assert old_token["rotated_to"] == new_id + assert old_token["revoked_at"] is None # Not revoked, just rotated + + async def test_rotate_fails_for_invalid_token(self, db_conn: asyncpg.Connection) -> None: + """rotate() returns None if old token is invalid.""" + _, _, _, _, repo = await self._setup_token(db_conn) + + result = await repo.rotate( + "nonexistent_hash", + uuid4(), + f"new_{uuid4().hex[:8]}", + datetime.now(UTC) + timedelta(days=30), + ) + + assert result is None + + async def test_rotate_fails_for_expired_token(self, db_conn: asyncpg.Connection) -> None: + """rotate() returns None if old token is expired.""" + user_repo = UserRepository(db_conn) + org_repo = OrgRepository(db_conn) + repo = RefreshTokenRepository(db_conn) + + user_id = uuid4() + org_id = uuid4() + await user_repo.create(user_id, "exp_rotate@example.com", "hash") + await org_repo.create(org_id, "Org", "exp-rotate-org") + + old_hash = "expired_for_rotation" + await repo.create( + uuid4(), user_id, old_hash, org_id, + datetime.now(UTC) - timedelta(days=1) # Already expired + ) + + result = await repo.rotate( + old_hash, uuid4(), f"new_{uuid4().hex[:8]}", + datetime.now(UTC) + timedelta(days=30) + ) + + assert result is None + + async def test_rotate_fails_for_revoked_token(self, db_conn: asyncpg.Connection) -> None: + """rotate() returns None if old token is revoked.""" + old_id, _, _, old_hash, repo = await self._setup_token(db_conn) + + await repo.revoke(old_id) + + result = await repo.rotate( + old_hash, uuid4(), f"new_{uuid4().hex[:8]}", + datetime.now(UTC) + timedelta(days=30) + ) + + assert result is None + + async def test_rotate_fails_for_already_rotated_token(self, db_conn: asyncpg.Connection) -> None: + """rotate() returns None if old token was already rotated.""" + _, _, _, old_hash, repo = await self._setup_token(db_conn) + + # First rotation should succeed + result1 = await repo.rotate( + old_hash, uuid4(), f"new1_{uuid4().hex[:8]}", + datetime.now(UTC) + timedelta(days=30) + ) + assert result1 is not None + + # Second rotation of same token should fail + result2 = await repo.rotate( + old_hash, uuid4(), f"new2_{uuid4().hex[:8]}", + datetime.now(UTC) + timedelta(days=30) + ) + assert result2 is None + + async def test_rotate_with_org_switch(self, db_conn: asyncpg.Connection) -> None: + """rotate() can change active_org_id (for org-switch flow).""" + _, user_id, old_org_id, old_hash, repo = await self._setup_token(db_conn) + + # Create a new org for the user to switch to + org_repo = OrgRepository(db_conn) + new_org_id = uuid4() + await org_repo.create(new_org_id, "New Org", f"new-org-{uuid4().hex[:8]}") + + new_hash = f"switched_{uuid4().hex[:8]}" + result = await repo.rotate( + old_hash, uuid4(), new_hash, + datetime.now(UTC) + timedelta(days=30), + new_active_org_id=new_org_id # Switch org + ) + + assert result is not None + assert result["active_org_id"] == new_org_id + assert result["active_org_id"] != old_org_id + + async def test_rotate_validates_expected_user_id(self, db_conn: asyncpg.Connection) -> None: + """rotate() fails if expected_user_id doesn't match token's user.""" + _, user_id, _, old_hash, repo = await self._setup_token(db_conn) + + # Wrong user should fail + result = await repo.rotate( + old_hash, uuid4(), f"new_{uuid4().hex[:8]}", + datetime.now(UTC) + timedelta(days=30), + expected_user_id=uuid4() # Wrong user + ) + assert result is None + + # Correct user should work + result = await repo.rotate( + old_hash, uuid4(), f"new_{uuid4().hex[:8]}", + datetime.now(UTC) + timedelta(days=30), + expected_user_id=user_id # Correct user + ) + assert result is not None + + +class TestTokenReuseDetection: + """Tests for detecting token reuse (stolen token attacks).""" + + async def _setup_rotated_token( + self, conn: asyncpg.Connection + ) -> tuple[uuid4, str, str, RefreshTokenRepository]: + """Create a token and rotate it, returning old and new hashes.""" + user_repo = UserRepository(conn) + org_repo = OrgRepository(conn) + token_repo = RefreshTokenRepository(conn) + + user_id = uuid4() + org_id = uuid4() + await user_repo.create(user_id, f"reuse_{uuid4().hex[:8]}@example.com", "hash") + await org_repo.create(org_id, "Reuse Org", f"reuse-org-{uuid4().hex[:8]}") + + old_hash = f"old_token_{uuid4().hex[:8]}" + expires_at = datetime.now(UTC) + timedelta(days=30) + old_token = await token_repo.create(uuid4(), user_id, old_hash, org_id, expires_at) + + new_hash = f"new_token_{uuid4().hex[:8]}" + await token_repo.rotate(old_hash, uuid4(), new_hash, expires_at) + + return old_token["id"], old_hash, new_hash, token_repo + + async def test_check_token_reuse_detects_rotated_token(self, db_conn: asyncpg.Connection) -> None: + """check_token_reuse returns token if it has been rotated.""" + old_id, old_hash, _, repo = await self._setup_rotated_token(db_conn) + + result = await repo.check_token_reuse(old_hash) + + assert result is not None + assert result["id"] == old_id + assert result["rotated_to"] is not None + + async def test_check_token_reuse_returns_none_for_active_token(self, db_conn: asyncpg.Connection) -> None: + """check_token_reuse returns None for token that hasn't been rotated.""" + user_repo = UserRepository(db_conn) + org_repo = OrgRepository(db_conn) + repo = RefreshTokenRepository(db_conn) + + user_id = uuid4() + org_id = uuid4() + await user_repo.create(user_id, "active@example.com", "hash") + await org_repo.create(org_id, "Org", "active-org") + + token_hash = "active_token_hash" + await repo.create( + uuid4(), user_id, token_hash, org_id, + datetime.now(UTC) + timedelta(days=30) + ) + + result = await repo.check_token_reuse(token_hash) + + assert result is None + + async def test_check_token_reuse_returns_none_for_nonexistent(self, db_conn: asyncpg.Connection) -> None: + """check_token_reuse returns None for non-existent token.""" + repo = RefreshTokenRepository(db_conn) + + result = await repo.check_token_reuse("nonexistent_hash") + + assert result is None + + +class TestTokenChainRevocation: + """Tests for revoking entire token chains (breach response).""" + + async def _setup_token_chain( + self, conn: asyncpg.Connection, chain_length: int = 3 + ) -> tuple[list[uuid4], list[str], uuid4, RefreshTokenRepository]: + """Create a chain of rotated tokens.""" + user_repo = UserRepository(conn) + org_repo = OrgRepository(conn) + token_repo = RefreshTokenRepository(conn) + + user_id = uuid4() + org_id = uuid4() + await user_repo.create(user_id, f"chain_{uuid4().hex[:8]}@example.com", "hash") + await org_repo.create(org_id, "Chain Org", f"chain-org-{uuid4().hex[:8]}") + + token_ids = [] + token_hashes = [] + expires_at = datetime.now(UTC) + timedelta(days=30) + + # Create first token + first_hash = f"chain_token_0_{uuid4().hex[:8]}" + first_token = await token_repo.create(uuid4(), user_id, first_hash, org_id, expires_at) + token_ids.append(first_token["id"]) + token_hashes.append(first_hash) + + # Rotate to create chain + current_hash = first_hash + for i in range(1, chain_length): + new_hash = f"chain_token_{i}_{uuid4().hex[:8]}" + new_id = uuid4() + await token_repo.rotate(current_hash, new_id, new_hash, expires_at) + token_ids.append(new_id) + token_hashes.append(new_hash) + current_hash = new_hash + + return token_ids, token_hashes, user_id, token_repo + + async def test_revoke_token_chain_revokes_all_in_chain(self, db_conn: asyncpg.Connection) -> None: + """revoke_token_chain revokes the token and all its rotations.""" + token_ids, token_hashes, _, repo = await self._setup_token_chain(db_conn, chain_length=3) + + # Revoke starting from the first token + count = await repo.revoke_token_chain(token_ids[0]) + + # Should revoke all 3 tokens in the chain + # But note: only the last one wasn't already "consumed" by rotation + # Let's check that revoke was called on all that were eligible + assert count >= 1 # At least the leaf token + + # Verify the leaf token is revoked + leaf_token = await repo.get_by_hash(token_hashes[-1]) + assert leaf_token["revoked_at"] is not None + + async def test_revoke_token_chain_returns_count(self, db_conn: asyncpg.Connection) -> None: + """revoke_token_chain returns count of actually revoked tokens.""" + user_repo = UserRepository(db_conn) + org_repo = OrgRepository(db_conn) + repo = RefreshTokenRepository(db_conn) + + user_id = uuid4() + org_id = uuid4() + await user_repo.create(user_id, "single@example.com", "hash") + await org_repo.create(org_id, "Single Org", "single-org") + + token_hash = "single_chain_token" + token = await repo.create( + uuid4(), user_id, token_hash, org_id, + datetime.now(UTC) + timedelta(days=30) + ) + + count = await repo.revoke_token_chain(token["id"]) + + assert count == 1 + + +class TestTokenRevocation: + """Tests for token revocation methods.""" + + async def _setup_token(self, conn: asyncpg.Connection) -> tuple[uuid4, str, RefreshTokenRepository]: + """Helper to create user, org, and token.""" + user_repo = UserRepository(conn) + org_repo = OrgRepository(conn) + token_repo = RefreshTokenRepository(conn) + + user_id = uuid4() + org_id = uuid4() + token_id = uuid4() + token_hash = f"revoke_token_{uuid4().hex[:8]}" + + await user_repo.create(user_id, f"revoke_{uuid4().hex[:8]}@example.com", "hash") + await org_repo.create(org_id, "Revoke Org", f"revoke-org-{uuid4().hex[:8]}") + + expires_at = datetime.now(UTC) + timedelta(days=30) + await token_repo.create(token_id, user_id, token_hash, org_id, expires_at) + + return token_id, token_hash, token_repo + + async def test_revoke_sets_revoked_at(self, db_conn: asyncpg.Connection) -> None: + """revoke() sets the revoked_at timestamp.""" + token_id, token_hash, repo = await self._setup_token(db_conn) + + result = await repo.revoke(token_id) + + assert result is True + token = await repo.get_by_hash(token_hash) + assert token["revoked_at"] is not None + + async def test_revoke_returns_true_on_success(self, db_conn: asyncpg.Connection) -> None: + """revoke() returns True when token is revoked.""" + token_id, _, repo = await self._setup_token(db_conn) + + result = await repo.revoke(token_id) + + assert result is True + + async def test_revoke_returns_false_for_already_revoked(self, db_conn: asyncpg.Connection) -> None: + """revoke() returns False if token already revoked.""" + token_id, _, repo = await self._setup_token(db_conn) + + await repo.revoke(token_id) + result = await repo.revoke(token_id) + + assert result is False + + async def test_revoke_returns_false_for_nonexistent(self, db_conn: asyncpg.Connection) -> None: + """revoke() returns False for non-existent token.""" + _, _, repo = await self._setup_token(db_conn) + + result = await repo.revoke(uuid4()) + + assert result is False + + async def test_revoke_by_hash_works(self, db_conn: asyncpg.Connection) -> None: + """revoke_by_hash() revokes token by hash value.""" + _, token_hash, repo = await self._setup_token(db_conn) + + result = await repo.revoke_by_hash(token_hash) + + assert result is True + token = await repo.get_by_hash(token_hash) + assert token["revoked_at"] is not None + + async def test_revoke_by_hash_returns_false_for_nonexistent(self, db_conn: asyncpg.Connection) -> None: + """revoke_by_hash() returns False for non-existent hash.""" + _, _, repo = await self._setup_token(db_conn) + + result = await repo.revoke_by_hash("nonexistent_hash") + + assert result is False + + +class TestRevokeAllForUser: + """Tests for revoking all tokens for a user.""" + + async def test_revoke_all_for_user_revokes_all_tokens(self, db_conn: asyncpg.Connection) -> None: + """revoke_all_for_user() revokes all tokens for the user.""" + user_repo = UserRepository(db_conn) + org_repo = OrgRepository(db_conn) + token_repo = RefreshTokenRepository(db_conn) + + user_id = uuid4() + org_id = uuid4() + await user_repo.create(user_id, "multi_token@example.com", "hash") + await org_repo.create(org_id, "Multi Token Org", "multi-token-org") + + # Create multiple tokens + hashes = [] + for i in range(3): + token_hash = f"token_{i}_{uuid4().hex[:8]}" + hashes.append(token_hash) + expires_at = datetime.now(UTC) + timedelta(days=30) + await token_repo.create(uuid4(), user_id, token_hash, org_id, expires_at) + + result = await token_repo.revoke_all_for_user(user_id) + + assert result == 3 + for token_hash in hashes: + token = await token_repo.get_valid_by_hash(token_hash) + assert token is None + + async def test_revoke_all_for_user_returns_zero_for_no_tokens(self, db_conn: asyncpg.Connection) -> None: + """revoke_all_for_user() returns 0 if user has no tokens.""" + user_repo = UserRepository(db_conn) + token_repo = RefreshTokenRepository(db_conn) + + user_id = uuid4() + await user_repo.create(user_id, "no_tokens@example.com", "hash") + + result = await token_repo.revoke_all_for_user(user_id) + + assert result == 0 + + async def test_revoke_all_for_user_only_affects_user_tokens(self, db_conn: asyncpg.Connection) -> None: + """revoke_all_for_user() doesn't affect other users' tokens.""" + user_repo = UserRepository(db_conn) + org_repo = OrgRepository(db_conn) + token_repo = RefreshTokenRepository(db_conn) + + user1 = uuid4() + user2 = uuid4() + org_id = uuid4() + await user_repo.create(user1, "user1@example.com", "hash") + await user_repo.create(user2, "user2@example.com", "hash") + await org_repo.create(org_id, "Shared Org", "shared-org") + + user1_hash = f"user1_token_{uuid4().hex[:8]}" + user2_hash = f"user2_token_{uuid4().hex[:8]}" + expires_at = datetime.now(UTC) + timedelta(days=30) + await token_repo.create(uuid4(), user1, user1_hash, org_id, expires_at) + await token_repo.create(uuid4(), user2, user2_hash, org_id, expires_at) + + await token_repo.revoke_all_for_user(user1) + + # User1's token is revoked + assert await token_repo.get_valid_by_hash(user1_hash) is None + # User2's token is still valid + assert await token_repo.get_valid_by_hash(user2_hash) is not None + + +class TestRevokeAllExcept: + """Tests for revoking all tokens except current session.""" + + async def test_revoke_all_except_keeps_specified_token(self, db_conn: asyncpg.Connection) -> None: + """revoke_all_for_user_except() keeps the specified token active.""" + user_repo = UserRepository(db_conn) + org_repo = OrgRepository(db_conn) + token_repo = RefreshTokenRepository(db_conn) + + user_id = uuid4() + org_id = uuid4() + await user_repo.create(user_id, "except@example.com", "hash") + await org_repo.create(org_id, "Except Org", "except-org") + + # Create multiple tokens + expires_at = datetime.now(UTC) + timedelta(days=30) + keep_token_id = uuid4() + keep_hash = f"keep_token_{uuid4().hex[:8]}" + await token_repo.create(keep_token_id, user_id, keep_hash, org_id, expires_at) + + other_hashes = [] + for i in range(2): + other_hash = f"other_token_{i}_{uuid4().hex[:8]}" + other_hashes.append(other_hash) + await token_repo.create(uuid4(), user_id, other_hash, org_id, expires_at) + + result = await token_repo.revoke_all_for_user_except(user_id, keep_token_id) + + assert result == 2 # Revoked 2 other tokens + + # Keep token is still valid + assert await token_repo.get_valid_by_hash(keep_hash) is not None + + # Other tokens are revoked + for other_hash in other_hashes: + assert await token_repo.get_valid_by_hash(other_hash) is None + + +class TestActiveTokensForUser: + """Tests for listing active tokens for a user.""" + + async def test_get_active_tokens_returns_only_active(self, db_conn: asyncpg.Connection) -> None: + """get_active_tokens_for_user() returns only non-revoked, non-expired, non-rotated.""" + user_repo = UserRepository(db_conn) + org_repo = OrgRepository(db_conn) + token_repo = RefreshTokenRepository(db_conn) + + user_id = uuid4() + org_id = uuid4() + await user_repo.create(user_id, "active_list@example.com", "hash") + await org_repo.create(org_id, "Active List Org", "active-list-org") + + expires_at = datetime.now(UTC) + timedelta(days=30) + expired_at = datetime.now(UTC) - timedelta(days=1) + + # Create active token + active_hash = f"active_{uuid4().hex[:8]}" + await token_repo.create(uuid4(), user_id, active_hash, org_id, expires_at) + + # Create revoked token + revoked_id = uuid4() + revoked_hash = f"revoked_{uuid4().hex[:8]}" + await token_repo.create(revoked_id, user_id, revoked_hash, org_id, expires_at) + await token_repo.revoke(revoked_id) + + # Create expired token + expired_hash = f"expired_{uuid4().hex[:8]}" + await token_repo.create(uuid4(), user_id, expired_hash, org_id, expired_at) + + # Create rotated token + rotated_hash = f"rotated_{uuid4().hex[:8]}" + await token_repo.create(uuid4(), user_id, rotated_hash, org_id, expires_at) + await token_repo.rotate(rotated_hash, uuid4(), f"new_{uuid4().hex[:8]}", expires_at) + + result = await token_repo.get_active_tokens_for_user(user_id) + + # Should only return the active token and the new rotated token + assert len(result) == 2 + hashes = {t["token_hash"] for t in result} + assert active_hash in hashes + assert revoked_hash not in hashes + assert expired_hash not in hashes + assert rotated_hash not in hashes + + +class TestTokenForeignKeys: + """Tests for refresh token foreign key constraints.""" + + async def test_token_requires_valid_user_foreign_key(self, db_conn: asyncpg.Connection) -> None: + """refresh_tokens.user_id must reference existing user.""" + org_repo = OrgRepository(db_conn) + token_repo = RefreshTokenRepository(db_conn) + + org_id = uuid4() + await org_repo.create(org_id, "FK Test Org", "fk-test-org") + + with pytest.raises(asyncpg.ForeignKeyViolationError): + await token_repo.create( + uuid4(), uuid4(), "orphan_token", org_id, + datetime.now(UTC) + timedelta(days=30) + ) + + async def test_token_requires_valid_org_foreign_key(self, db_conn: asyncpg.Connection) -> None: + """refresh_tokens.active_org_id must reference existing org.""" + user_repo = UserRepository(db_conn) + token_repo = RefreshTokenRepository(db_conn) + + user_id = uuid4() + await user_repo.create(user_id, "fk_org_test@example.com", "hash") + + with pytest.raises(asyncpg.ForeignKeyViolationError): + await token_repo.create( + uuid4(), user_id, "orphan_org_token", uuid4(), + datetime.now(UTC) + timedelta(days=30) + ) + + async def test_token_stores_active_org_id(self, db_conn: asyncpg.Connection) -> None: + """Token stores active_org_id for org context per SPECS.md.""" + user_repo = UserRepository(db_conn) + org_repo = OrgRepository(db_conn) + token_repo = RefreshTokenRepository(db_conn) + + user_id = uuid4() + org_id = uuid4() + await user_repo.create(user_id, "active_org@example.com", "hash") + await org_repo.create(org_id, "Active Org", "active-org") + + token_hash = f"active_org_token_{uuid4().hex[:8]}" + await token_repo.create( + uuid4(), user_id, token_hash, org_id, + datetime.now(UTC) + timedelta(days=30) + ) + + token = await token_repo.get_by_hash(token_hash) + assert token["active_org_id"] == org_id diff --git a/tests/repositories/test_service.py b/tests/repositories/test_service.py new file mode 100644 index 0000000..1855a9f --- /dev/null +++ b/tests/repositories/test_service.py @@ -0,0 +1,201 @@ +"""Tests for ServiceRepository.""" + +from uuid import uuid4 + +import asyncpg +import pytest + +from app.repositories.org import OrgRepository +from app.repositories.service import ServiceRepository + + +class TestServiceRepository: + """Tests for ServiceRepository conforming to SPECS.md.""" + + async def _create_org(self, conn: asyncpg.Connection, slug: str) -> uuid4: + """Helper to create an org.""" + org_repo = OrgRepository(conn) + org_id = uuid4() + await org_repo.create(org_id, f"Org {slug}", slug) + return org_id + + async def test_create_service_returns_service_data(self, db_conn: asyncpg.Connection) -> None: + """Creating a service returns the service data.""" + org_id = await self._create_org(db_conn, "service-org") + repo = ServiceRepository(db_conn) + service_id = uuid4() + + result = await repo.create(service_id, org_id, "API Gateway", "api-gateway") + + assert result["id"] == service_id + assert result["org_id"] == org_id + assert result["name"] == "API Gateway" + assert result["slug"] == "api-gateway" + assert result["created_at"] is not None + + async def test_create_service_slug_unique_per_org(self, db_conn: asyncpg.Connection) -> None: + """Service slug must be unique within an org per SPECS.md.""" + org_id = await self._create_org(db_conn, "unique-slug-org") + repo = ServiceRepository(db_conn) + + await repo.create(uuid4(), org_id, "Service One", "my-service") + + with pytest.raises(asyncpg.UniqueViolationError): + await repo.create(uuid4(), org_id, "Service Two", "my-service") + + async def test_same_slug_allowed_in_different_orgs(self, db_conn: asyncpg.Connection) -> None: + """Same slug can exist in different orgs.""" + org1 = await self._create_org(db_conn, "org-one") + org2 = await self._create_org(db_conn, "org-two") + repo = ServiceRepository(db_conn) + slug = "shared-slug" + + # Both should succeed + result1 = await repo.create(uuid4(), org1, "Service Org1", slug) + result2 = await repo.create(uuid4(), org2, "Service Org2", slug) + + assert result1["slug"] == slug + assert result2["slug"] == slug + assert result1["org_id"] != result2["org_id"] + + async def test_get_by_id_returns_service(self, db_conn: asyncpg.Connection) -> None: + """get_by_id returns the correct service.""" + org_id = await self._create_org(db_conn, "get-service-org") + repo = ServiceRepository(db_conn) + service_id = uuid4() + + await repo.create(service_id, org_id, "My Service", "my-service") + result = await repo.get_by_id(service_id) + + assert result is not None + assert result["id"] == service_id + assert result["name"] == "My Service" + + async def test_get_by_id_returns_none_for_nonexistent(self, db_conn: asyncpg.Connection) -> None: + """get_by_id returns None for non-existent service.""" + repo = ServiceRepository(db_conn) + + result = await repo.get_by_id(uuid4()) + + assert result is None + + async def test_get_by_org_returns_all_org_services(self, db_conn: asyncpg.Connection) -> None: + """get_by_org returns all services for an organization.""" + org_id = await self._create_org(db_conn, "multi-service-org") + repo = ServiceRepository(db_conn) + + await repo.create(uuid4(), org_id, "Service A", "service-a") + await repo.create(uuid4(), org_id, "Service B", "service-b") + await repo.create(uuid4(), org_id, "Service C", "service-c") + + result = await repo.get_by_org(org_id) + + assert len(result) == 3 + names = {s["name"] for s in result} + assert names == {"Service A", "Service B", "Service C"} + + async def test_get_by_org_returns_empty_for_no_services(self, db_conn: asyncpg.Connection) -> None: + """get_by_org returns empty list for org with no services.""" + org_id = await self._create_org(db_conn, "empty-service-org") + repo = ServiceRepository(db_conn) + + result = await repo.get_by_org(org_id) + + assert result == [] + + async def test_get_by_org_only_returns_own_services(self, db_conn: asyncpg.Connection) -> None: + """get_by_org doesn't return services from other orgs (tenant isolation).""" + org1 = await self._create_org(db_conn, "isolated-org-1") + org2 = await self._create_org(db_conn, "isolated-org-2") + repo = ServiceRepository(db_conn) + + await repo.create(uuid4(), org1, "Org1 Service", "org1-service") + await repo.create(uuid4(), org2, "Org2 Service", "org2-service") + + result = await repo.get_by_org(org1) + + assert len(result) == 1 + assert result[0]["name"] == "Org1 Service" + + async def test_get_by_slug_returns_service(self, db_conn: asyncpg.Connection) -> None: + """get_by_slug returns service by org and slug.""" + org_id = await self._create_org(db_conn, "slug-lookup-org") + repo = ServiceRepository(db_conn) + service_id = uuid4() + + await repo.create(service_id, org_id, "Slug Service", "slug-service") + result = await repo.get_by_slug(org_id, "slug-service") + + assert result is not None + assert result["id"] == service_id + + async def test_get_by_slug_returns_none_for_wrong_org(self, db_conn: asyncpg.Connection) -> None: + """get_by_slug returns None if slug exists but in different org.""" + org1 = await self._create_org(db_conn, "slug-org-1") + org2 = await self._create_org(db_conn, "slug-org-2") + repo = ServiceRepository(db_conn) + + await repo.create(uuid4(), org1, "Service", "the-slug") + result = await repo.get_by_slug(org2, "the-slug") + + assert result is None + + async def test_get_by_slug_returns_none_for_nonexistent(self, db_conn: asyncpg.Connection) -> None: + """get_by_slug returns None for non-existent slug.""" + org_id = await self._create_org(db_conn, "no-slug-org") + repo = ServiceRepository(db_conn) + + result = await repo.get_by_slug(org_id, "nonexistent") + + assert result is None + + async def test_slug_exists_returns_true_when_exists(self, db_conn: asyncpg.Connection) -> None: + """slug_exists returns True when slug exists in org.""" + org_id = await self._create_org(db_conn, "exists-org") + repo = ServiceRepository(db_conn) + + await repo.create(uuid4(), org_id, "Exists Service", "exists-slug") + result = await repo.slug_exists(org_id, "exists-slug") + + assert result is True + + async def test_slug_exists_returns_false_when_not_exists(self, db_conn: asyncpg.Connection) -> None: + """slug_exists returns False when slug doesn't exist in org.""" + org_id = await self._create_org(db_conn, "not-exists-org") + repo = ServiceRepository(db_conn) + + result = await repo.slug_exists(org_id, "no-such-slug") + + assert result is False + + async def test_slug_exists_returns_false_for_other_org(self, db_conn: asyncpg.Connection) -> None: + """slug_exists returns False for slug in different org.""" + org1 = await self._create_org(db_conn, "other-org-1") + org2 = await self._create_org(db_conn, "other-org-2") + repo = ServiceRepository(db_conn) + + await repo.create(uuid4(), org1, "Service", "cross-org-slug") + result = await repo.slug_exists(org2, "cross-org-slug") + + assert result is False + + async def test_service_requires_valid_org_foreign_key(self, db_conn: asyncpg.Connection) -> None: + """services.org_id must reference existing org.""" + repo = ServiceRepository(db_conn) + + with pytest.raises(asyncpg.ForeignKeyViolationError): + await repo.create(uuid4(), uuid4(), "Orphan Service", "orphan") + + async def test_get_by_org_orders_by_name(self, db_conn: asyncpg.Connection) -> None: + """get_by_org returns services ordered by name.""" + org_id = await self._create_org(db_conn, "ordered-org") + repo = ServiceRepository(db_conn) + + await repo.create(uuid4(), org_id, "Zebra", "zebra") + await repo.create(uuid4(), org_id, "Alpha", "alpha") + await repo.create(uuid4(), org_id, "Middle", "middle") + + result = await repo.get_by_org(org_id) + + names = [s["name"] for s in result] + assert names == ["Alpha", "Middle", "Zebra"] diff --git a/tests/repositories/test_user.py b/tests/repositories/test_user.py new file mode 100644 index 0000000..944b422 --- /dev/null +++ b/tests/repositories/test_user.py @@ -0,0 +1,133 @@ +"""Tests for UserRepository.""" + +from uuid import uuid4 + +import asyncpg +import pytest + +from app.repositories.user import UserRepository + + +class TestUserRepository: + """Tests for UserRepository conforming to SPECS.md.""" + + async def test_create_user_returns_user_data(self, db_conn: asyncpg.Connection) -> None: + """Creating a user returns the user data with id, email, created_at.""" + repo = UserRepository(db_conn) + user_id = uuid4() + email = "test@example.com" + password_hash = "hashed_password_123" + + result = await repo.create(user_id, email, password_hash) + + assert result["id"] == user_id + assert result["email"] == email + assert result["created_at"] is not None + + async def test_create_user_stores_password_hash(self, db_conn: asyncpg.Connection) -> None: + """Password hash is stored correctly in the database.""" + repo = UserRepository(db_conn) + user_id = uuid4() + email = "hash_test@example.com" + password_hash = "bcrypt_hashed_value" + + await repo.create(user_id, email, password_hash) + user = await repo.get_by_id(user_id) + + assert user["password_hash"] == password_hash + + async def test_create_user_email_must_be_unique(self, db_conn: asyncpg.Connection) -> None: + """Email uniqueness constraint per SPECS.md users table.""" + repo = UserRepository(db_conn) + email = "duplicate@example.com" + + await repo.create(uuid4(), email, "hash1") + + with pytest.raises(asyncpg.UniqueViolationError): + await repo.create(uuid4(), email, "hash2") + + async def test_get_by_id_returns_user(self, db_conn: asyncpg.Connection) -> None: + """get_by_id returns the correct user.""" + repo = UserRepository(db_conn) + user_id = uuid4() + email = "getbyid@example.com" + + await repo.create(user_id, email, "hash") + result = await repo.get_by_id(user_id) + + assert result is not None + assert result["id"] == user_id + assert result["email"] == email + + async def test_get_by_id_returns_none_for_nonexistent(self, db_conn: asyncpg.Connection) -> None: + """get_by_id returns None for non-existent user.""" + repo = UserRepository(db_conn) + + result = await repo.get_by_id(uuid4()) + + assert result is None + + async def test_get_by_email_returns_user(self, db_conn: asyncpg.Connection) -> None: + """get_by_email returns the correct user.""" + repo = UserRepository(db_conn) + user_id = uuid4() + email = "getbyemail@example.com" + + await repo.create(user_id, email, "hash") + result = await repo.get_by_email(email) + + assert result is not None + assert result["id"] == user_id + assert result["email"] == email + + async def test_get_by_email_returns_none_for_nonexistent(self, db_conn: asyncpg.Connection) -> None: + """get_by_email returns None for non-existent email.""" + repo = UserRepository(db_conn) + + result = await repo.get_by_email("nonexistent@example.com") + + assert result is None + + async def test_get_by_email_is_case_sensitive(self, db_conn: asyncpg.Connection) -> None: + """Email lookup is case-sensitive (stored as provided).""" + repo = UserRepository(db_conn) + email = "CaseSensitive@Example.com" + + await repo.create(uuid4(), email, "hash") + + # Exact match works + result = await repo.get_by_email(email) + assert result is not None + + # Different case returns None + result = await repo.get_by_email(email.lower()) + assert result is None + + async def test_exists_by_email_returns_true_when_exists(self, db_conn: asyncpg.Connection) -> None: + """exists_by_email returns True when email exists.""" + repo = UserRepository(db_conn) + email = "exists@example.com" + + await repo.create(uuid4(), email, "hash") + result = await repo.exists_by_email(email) + + assert result is True + + async def test_exists_by_email_returns_false_when_not_exists(self, db_conn: asyncpg.Connection) -> None: + """exists_by_email returns False when email doesn't exist.""" + repo = UserRepository(db_conn) + + result = await repo.exists_by_email("notexists@example.com") + + assert result is False + + async def test_user_id_is_uuid_primary_key(self, db_conn: asyncpg.Connection) -> None: + """User ID must be a valid UUID (primary key).""" + repo = UserRepository(db_conn) + user_id = uuid4() + + await repo.create(user_id, "pk_test@example.com", "hash") + + # Duplicate ID should fail + with pytest.raises(asyncpg.UniqueViolationError): + await repo.create(user_id, "other@example.com", "hash") diff --git a/uv.lock b/uv.lock index 9eda67f..b1f808c 100644 --- a/uv.lock +++ b/uv.lock @@ -309,6 +309,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e8/cb/2da4cc83f5edb9c3257d09e1e7ab7b23f049c7962cae8d842bbef0a9cec9/cryptography-46.0.3-cp38-abi3-win_arm64.whl", hash = "sha256:d89c3468de4cdc4f08a57e214384d0471911a3830fcdaf7a8cc587e42a866372", size = 2918740, upload-time = "2025-10-15T23:18:12.277Z" }, ] +[[package]] +name = "dnspython" +version = "2.8.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/8c/8b/57666417c0f90f08bcafa776861060426765fdb422eb10212086fb811d26/dnspython-2.8.0.tar.gz", hash = "sha256:181d3c6996452cb1189c4046c61599b84a5a86e099562ffde77d26984ff26d0f", size = 368251, upload-time = "2025-09-07T18:58:00.022Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ba/5a/18ad964b0086c6e62e2e7500f7edc89e3faa45033c71c1893d34eed2b2de/dnspython-2.8.0-py3-none-any.whl", hash = "sha256:01d9bbc4a2d76bf0db7c1f729812ded6d912bd318d3b1cf81d30c0f845dbf3af", size = 331094, upload-time = "2025-09-07T18:57:58.071Z" }, +] + [[package]] name = "ecdsa" version = "0.19.1" @@ -321,6 +330,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/cb/a3/460c57f094a4a165c84a1341c373b0a4f5ec6ac244b998d5021aade89b77/ecdsa-0.19.1-py2.py3-none-any.whl", hash = "sha256:30638e27cf77b7e15c4c4cc1973720149e1033827cfd00661ca5c8cc0cdb24c3", size = 150607, upload-time = "2025-03-13T11:52:41.757Z" }, ] +[[package]] +name = "email-validator" +version = "2.3.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "dnspython" }, + { name = "idna" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f5/22/900cb125c76b7aaa450ce02fd727f452243f2e91a61af068b40adba60ea9/email_validator-2.3.0.tar.gz", hash = "sha256:9fc05c37f2f6cf439ff414f8fc46d917929974a82244c20eb10231ba60c54426", size = 51238, upload-time = "2025-08-26T13:09:06.831Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/de/15/545e2b6cf2e3be84bc1ed85613edd75b8aea69807a71c26f4ca6a9258e82/email_validator-2.3.0-py3-none-any.whl", hash = "sha256:80f13f623413e6b197ae73bb10bf4eb0908faf509ad8362c5edeb0be7fd450b4", size = 35604, upload-time = "2025-08-26T13:09:05.858Z" }, +] + [[package]] name = "fastapi" version = "0.128.0" @@ -407,7 +429,7 @@ dependencies = [ { name = "celery", extra = ["redis"] }, { name = "fastapi" }, { name = "httpx" }, - { name = "pydantic" }, + { name = "pydantic", extra = ["email"] }, { name = "pydantic-settings" }, { name = "python-jose", extra = ["cryptography"] }, { name = "redis" }, @@ -428,7 +450,7 @@ requires-dist = [ { name = "celery", extras = ["redis"], specifier = ">=5.4.0" }, { name = "fastapi", specifier = ">=0.115.0" }, { name = "httpx", specifier = ">=0.28.0" }, - { name = "pydantic", specifier = ">=2.0.0" }, + { name = "pydantic", extras = ["email"], specifier = ">=2.0.0" }, { name = "pydantic-settings", specifier = ">=2.0.0" }, { name = "pytest", marker = "extra == 'dev'", specifier = ">=8.0.0" }, { name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=0.24.0" }, @@ -531,6 +553,11 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/5a/87/b70ad306ebb6f9b585f114d0ac2137d792b48be34d732d60e597c2f8465a/pydantic-2.12.5-py3-none-any.whl", hash = "sha256:e561593fccf61e8a20fc46dfc2dfe075b8be7d0188df33f221ad1f0139180f9d", size = 463580, upload-time = "2025-11-26T15:11:44.605Z" }, ] +[package.optional-dependencies] +email = [ + { name = "email-validator" }, +] + [[package]] name = "pydantic-core" version = "2.41.5"