feat(api): Pydantic schemas + Data Repositories

This commit is contained in:
2025-12-07 03:58:02 -05:00
parent fbe9fbba6e
commit a8fbce09c4
23 changed files with 3549 additions and 3 deletions

View File

@@ -0,0 +1,17 @@
"""Repository layer for database operations."""
from app.repositories.incident import IncidentRepository
from app.repositories.notification import NotificationRepository
from app.repositories.org import OrgRepository
from app.repositories.refresh_token import RefreshTokenRepository
from app.repositories.service import ServiceRepository
from app.repositories.user import UserRepository
__all__ = [
"IncidentRepository",
"NotificationRepository",
"OrgRepository",
"RefreshTokenRepository",
"ServiceRepository",
"UserRepository",
]

View File

@@ -0,0 +1,161 @@
"""Incident repository for database operations."""
from datetime import datetime
from typing import Any
from uuid import UUID
import asyncpg
class IncidentRepository:
"""Database operations for incidents."""
def __init__(self, conn: asyncpg.Connection) -> None:
self.conn = conn
async def create(
self,
incident_id: UUID,
org_id: UUID,
service_id: UUID,
title: str,
description: str | None,
severity: str,
) -> dict:
"""Create a new incident."""
row = await self.conn.fetchrow(
"""
INSERT INTO incidents (id, org_id, service_id, title, description, status, severity)
VALUES ($1, $2, $3, $4, $5, 'triggered', $6)
RETURNING id, org_id, service_id, title, description, status, severity,
version, created_at, updated_at
""",
incident_id,
org_id,
service_id,
title,
description,
severity,
)
return dict(row)
async def get_by_id(self, incident_id: UUID) -> dict | None:
"""Get incident by ID."""
row = await self.conn.fetchrow(
"""
SELECT id, org_id, service_id, title, description, status, severity,
version, created_at, updated_at
FROM incidents
WHERE id = $1
""",
incident_id,
)
return dict(row) if row else None
async def get_by_org(
self,
org_id: UUID,
status: str | None = None,
cursor: datetime | None = None,
limit: int = 20,
) -> list[dict]:
"""Get incidents for an organization with optional filtering and pagination."""
query = """
SELECT id, org_id, service_id, title, description, status, severity,
version, created_at, updated_at
FROM incidents
WHERE org_id = $1
"""
params: list[Any] = [org_id]
param_idx = 2
if status:
query += f" AND status = ${param_idx}"
params.append(status)
param_idx += 1
if cursor:
query += f" AND created_at < ${param_idx}"
params.append(cursor)
param_idx += 1
query += f" ORDER BY created_at DESC LIMIT ${param_idx}"
params.append(limit + 1) # Fetch one extra to check if there are more
rows = await self.conn.fetch(query, *params)
return [dict(row) for row in rows]
async def update_status(
self,
incident_id: UUID,
new_status: str,
expected_version: int,
) -> dict | None:
"""Update incident status with optimistic locking.
Returns updated incident if successful, None if version mismatch.
"""
row = await self.conn.fetchrow(
"""
UPDATE incidents
SET status = $2, version = version + 1, updated_at = now()
WHERE id = $1 AND version = $3
RETURNING id, org_id, service_id, title, description, status, severity,
version, created_at, updated_at
""",
incident_id,
new_status,
expected_version,
)
return dict(row) if row else None
async def add_event(
self,
event_id: UUID,
incident_id: UUID,
event_type: str,
actor_user_id: UUID | None,
payload: dict[str, Any] | None,
) -> dict:
"""Add an event to the incident timeline."""
import json
row = await self.conn.fetchrow(
"""
INSERT INTO incident_events (id, incident_id, event_type, actor_user_id, payload)
VALUES ($1, $2, $3, $4, $5)
RETURNING id, incident_id, event_type, actor_user_id, payload, created_at
""",
event_id,
incident_id,
event_type,
actor_user_id,
json.dumps(payload) if payload else None,
)
result = dict(row)
# Parse JSON payload back to dict
if result["payload"]:
result["payload"] = json.loads(result["payload"])
return result
async def get_events(self, incident_id: UUID) -> list[dict]:
"""Get all events for an incident."""
import json
rows = await self.conn.fetch(
"""
SELECT id, incident_id, event_type, actor_user_id, payload, created_at
FROM incident_events
WHERE incident_id = $1
ORDER BY created_at
""",
incident_id,
)
results = []
for row in rows:
result = dict(row)
if result["payload"]:
result["payload"] = json.loads(result["payload"])
results.append(result)
return results

View File

@@ -0,0 +1,199 @@
"""Notification repository for database operations."""
from datetime import datetime
from uuid import UUID
import asyncpg
class NotificationRepository:
"""Database operations for notification targets and attempts."""
def __init__(self, conn: asyncpg.Connection) -> None:
self.conn = conn
async def create_target(
self,
target_id: UUID,
org_id: UUID,
name: str,
target_type: str,
webhook_url: str | None = None,
enabled: bool = True,
) -> dict:
"""Create a new notification target."""
row = await self.conn.fetchrow(
"""
INSERT INTO notification_targets (id, org_id, name, target_type, webhook_url, enabled)
VALUES ($1, $2, $3, $4, $5, $6)
RETURNING id, org_id, name, target_type, webhook_url, enabled, created_at
""",
target_id,
org_id,
name,
target_type,
webhook_url,
enabled,
)
return dict(row)
async def get_target_by_id(self, target_id: UUID) -> dict | None:
"""Get notification target by ID."""
row = await self.conn.fetchrow(
"""
SELECT id, org_id, name, target_type, webhook_url, enabled, created_at
FROM notification_targets
WHERE id = $1
""",
target_id,
)
return dict(row) if row else None
async def get_targets_by_org(
self,
org_id: UUID,
enabled_only: bool = False,
) -> list[dict]:
"""Get all notification targets for an organization."""
query = """
SELECT id, org_id, name, target_type, webhook_url, enabled, created_at
FROM notification_targets
WHERE org_id = $1
"""
if enabled_only:
query += " AND enabled = true"
query += " ORDER BY name"
rows = await self.conn.fetch(query, org_id)
return [dict(row) for row in rows]
async def update_target(
self,
target_id: UUID,
name: str | None = None,
webhook_url: str | None = None,
enabled: bool | None = None,
) -> dict | None:
"""Update a notification target."""
updates = []
params = [target_id]
param_idx = 2
if name is not None:
updates.append(f"name = ${param_idx}")
params.append(name)
param_idx += 1
if webhook_url is not None:
updates.append(f"webhook_url = ${param_idx}")
params.append(webhook_url)
param_idx += 1
if enabled is not None:
updates.append(f"enabled = ${param_idx}")
params.append(enabled)
param_idx += 1
if not updates:
return await self.get_target_by_id(target_id)
query = f"""
UPDATE notification_targets
SET {", ".join(updates)}
WHERE id = $1
RETURNING id, org_id, name, target_type, webhook_url, enabled, created_at
"""
row = await self.conn.fetchrow(query, *params)
return dict(row) if row else None
async def delete_target(self, target_id: UUID) -> bool:
"""Delete a notification target. Returns True if deleted."""
result = await self.conn.execute(
"DELETE FROM notification_targets WHERE id = $1",
target_id,
)
return result == "DELETE 1"
async def create_attempt(
self,
attempt_id: UUID,
incident_id: UUID,
target_id: UUID,
) -> dict:
"""Create a notification attempt (idempotent via unique constraint)."""
row = await self.conn.fetchrow(
"""
INSERT INTO notification_attempts (id, incident_id, target_id, status)
VALUES ($1, $2, $3, 'pending')
ON CONFLICT (incident_id, target_id) DO UPDATE SET id = notification_attempts.id
RETURNING id, incident_id, target_id, status, error, sent_at, created_at
""",
attempt_id,
incident_id,
target_id,
)
return dict(row)
async def get_attempt(self, incident_id: UUID, target_id: UUID) -> dict | None:
"""Get notification attempt for incident and target."""
row = await self.conn.fetchrow(
"""
SELECT id, incident_id, target_id, status, error, sent_at, created_at
FROM notification_attempts
WHERE incident_id = $1 AND target_id = $2
""",
incident_id,
target_id,
)
return dict(row) if row else None
async def update_attempt_success(
self,
attempt_id: UUID,
sent_at: datetime,
) -> dict | None:
"""Mark notification attempt as successful."""
row = await self.conn.fetchrow(
"""
UPDATE notification_attempts
SET status = 'sent', sent_at = $2, error = NULL
WHERE id = $1
RETURNING id, incident_id, target_id, status, error, sent_at, created_at
""",
attempt_id,
sent_at,
)
return dict(row) if row else None
async def update_attempt_failure(
self,
attempt_id: UUID,
error: str,
) -> dict | None:
"""Mark notification attempt as failed."""
row = await self.conn.fetchrow(
"""
UPDATE notification_attempts
SET status = 'failed', error = $2
WHERE id = $1
RETURNING id, incident_id, target_id, status, error, sent_at, created_at
""",
attempt_id,
error,
)
return dict(row) if row else None
async def get_pending_attempts(self, incident_id: UUID) -> list[dict]:
"""Get all pending notification attempts for an incident."""
rows = await self.conn.fetch(
"""
SELECT na.id, na.incident_id, na.target_id, na.status, na.error,
na.sent_at, na.created_at,
nt.target_type, nt.webhook_url, nt.name as target_name
FROM notification_attempts na
JOIN notification_targets nt ON nt.id = na.target_id
WHERE na.incident_id = $1 AND na.status = 'pending'
""",
incident_id,
)
return [dict(row) for row in rows]

125
app/repositories/org.py Normal file
View File

@@ -0,0 +1,125 @@
"""Organization repository for database operations."""
from uuid import UUID
import asyncpg
class OrgRepository:
"""Database operations for organizations."""
def __init__(self, conn: asyncpg.Connection) -> None:
self.conn = conn
async def create(
self,
org_id: UUID,
name: str,
slug: str,
) -> dict:
"""Create a new organization."""
row = await self.conn.fetchrow(
"""
INSERT INTO orgs (id, name, slug)
VALUES ($1, $2, $3)
RETURNING id, name, slug, created_at
""",
org_id,
name,
slug,
)
return dict(row)
async def get_by_id(self, org_id: UUID) -> dict | None:
"""Get organization by ID."""
row = await self.conn.fetchrow(
"""
SELECT id, name, slug, created_at
FROM orgs
WHERE id = $1
""",
org_id,
)
return dict(row) if row else None
async def get_by_slug(self, slug: str) -> dict | None:
"""Get organization by slug."""
row = await self.conn.fetchrow(
"""
SELECT id, name, slug, created_at
FROM orgs
WHERE slug = $1
""",
slug,
)
return dict(row) if row else None
async def add_member(
self,
member_id: UUID,
user_id: UUID,
org_id: UUID,
role: str,
) -> dict:
"""Add a member to an organization."""
row = await self.conn.fetchrow(
"""
INSERT INTO org_members (id, user_id, org_id, role)
VALUES ($1, $2, $3, $4)
RETURNING id, user_id, org_id, role, created_at
""",
member_id,
user_id,
org_id,
role,
)
return dict(row)
async def get_member(self, user_id: UUID, org_id: UUID) -> dict | None:
"""Get membership for a user in an organization."""
row = await self.conn.fetchrow(
"""
SELECT om.id, om.user_id, om.org_id, om.role, om.created_at
FROM org_members om
WHERE om.user_id = $1 AND om.org_id = $2
""",
user_id,
org_id,
)
return dict(row) if row else None
async def get_members(self, org_id: UUID) -> list[dict]:
"""Get all members of an organization."""
rows = await self.conn.fetch(
"""
SELECT om.id, om.user_id, u.email, om.role, om.created_at
FROM org_members om
JOIN users u ON u.id = om.user_id
WHERE om.org_id = $1
ORDER BY om.created_at
""",
org_id,
)
return [dict(row) for row in rows]
async def get_user_orgs(self, user_id: UUID) -> list[dict]:
"""Get all organizations a user belongs to."""
rows = await self.conn.fetch(
"""
SELECT o.id, o.name, o.slug, o.created_at, om.role
FROM orgs o
JOIN org_members om ON om.org_id = o.id
WHERE om.user_id = $1
ORDER BY o.created_at
""",
user_id,
)
return [dict(row) for row in rows]
async def slug_exists(self, slug: str) -> bool:
"""Check if organization slug exists."""
result = await self.conn.fetchval(
"SELECT EXISTS(SELECT 1 FROM orgs WHERE slug = $1)",
slug,
)
return result

View File

@@ -0,0 +1,396 @@
"""Refresh token repository for database operations.
Security considerations implemented:
- Atomic rotation using SELECT FOR UPDATE to prevent race conditions
- Token chain tracking via rotated_to for reuse/theft detection
- Defense-in-depth validation with user_id and active_org_id checks
- Uses RETURNING for robust row counting instead of string parsing
"""
from datetime import datetime
from uuid import UUID
import asyncpg
class RefreshTokenRepository:
"""Database operations for refresh tokens."""
def __init__(self, conn: asyncpg.Connection) -> None:
self.conn = conn
async def create(
self,
token_id: UUID,
user_id: UUID,
token_hash: str,
active_org_id: UUID,
expires_at: datetime,
) -> dict:
"""Create a new refresh token."""
row = await self.conn.fetchrow(
"""
INSERT INTO refresh_tokens (id, user_id, token_hash, active_org_id, expires_at)
VALUES ($1, $2, $3, $4, $5)
RETURNING id, user_id, token_hash, active_org_id, expires_at,
revoked_at, rotated_to, created_at
""",
token_id,
user_id,
token_hash,
active_org_id,
expires_at,
)
return dict(row)
async def get_by_hash(self, token_hash: str) -> dict | None:
"""Get refresh token by hash (includes revoked/expired for auditing)."""
row = await self.conn.fetchrow(
"""
SELECT id, user_id, token_hash, active_org_id, expires_at,
revoked_at, rotated_to, created_at
FROM refresh_tokens
WHERE token_hash = $1
""",
token_hash,
)
return dict(row) if row else None
async def get_valid_by_hash(
self,
token_hash: str,
user_id: UUID | None = None,
active_org_id: UUID | None = None,
) -> dict | None:
"""Get refresh token by hash, only if valid.
Validates:
- Token exists and matches hash
- Token is not revoked
- Token is not expired
- Token has not been rotated (rotated_to is NULL)
- Optionally: user_id matches (defense-in-depth)
- Optionally: active_org_id matches (defense-in-depth)
Args:
token_hash: The hashed token value
user_id: If provided, token must belong to this user
active_org_id: If provided, token must be bound to this org
Returns:
Token dict if valid, None otherwise
"""
query = """
SELECT id, user_id, token_hash, active_org_id, expires_at,
revoked_at, rotated_to, created_at
FROM refresh_tokens
WHERE token_hash = $1
AND revoked_at IS NULL
AND rotated_to IS NULL
AND expires_at > clock_timestamp()
"""
params: list = [token_hash]
param_idx = 2
if user_id is not None:
query += f" AND user_id = ${param_idx}"
params.append(user_id)
param_idx += 1
if active_org_id is not None:
query += f" AND active_org_id = ${param_idx}"
params.append(active_org_id)
row = await self.conn.fetchrow(query, *params)
return dict(row) if row else None
async def get_valid_for_rotation(
self,
token_hash: str,
user_id: UUID | None = None,
) -> dict | None:
"""Get and lock a valid token for rotation using SELECT FOR UPDATE.
This acquires a row-level lock to prevent concurrent rotation attempts.
Must be called within a transaction.
Args:
token_hash: The hashed token value
user_id: If provided, token must belong to this user
Returns:
Token dict if valid and locked, None otherwise
"""
query = """
SELECT id, user_id, token_hash, active_org_id, expires_at,
revoked_at, rotated_to, created_at
FROM refresh_tokens
WHERE token_hash = $1
AND revoked_at IS NULL
AND rotated_to IS NULL
AND expires_at > clock_timestamp()
"""
params: list = [token_hash]
if user_id is not None:
query += " AND user_id = $2"
params.append(user_id)
query += " FOR UPDATE"
row = await self.conn.fetchrow(query, *params)
return dict(row) if row else None
async def check_token_reuse(self, token_hash: str) -> dict | None:
"""Check if a token has already been rotated (potential theft).
If a token is presented that has rotated_to set, it means:
1. The token was legitimately rotated earlier
2. Someone is now trying to use the old token
3. This indicates the token may have been stolen
Returns:
Token dict if this is a reused/stolen token, None if not found or not rotated
"""
row = await self.conn.fetchrow(
"""
SELECT id, user_id, token_hash, active_org_id, expires_at,
revoked_at, rotated_to, created_at
FROM refresh_tokens
WHERE token_hash = $1 AND rotated_to IS NOT NULL
""",
token_hash,
)
return dict(row) if row else None
async def revoke_token_chain(self, token_id: UUID) -> int:
"""Revoke a token and all tokens in its chain (for breach response).
When token reuse is detected, this revokes:
1. The original stolen token
2. Any token it was rotated to (and their rotations, recursively)
Args:
token_id: The ID of the compromised token
Returns:
Count of tokens revoked
"""
# Use recursive CTE to find all tokens in the chain
rows = await self.conn.fetch(
"""
WITH RECURSIVE token_chain AS (
-- Start with the given token
SELECT id, rotated_to
FROM refresh_tokens
WHERE id = $1
UNION ALL
-- Follow the chain via rotated_to
SELECT rt.id, rt.rotated_to
FROM refresh_tokens rt
INNER JOIN token_chain tc ON rt.id = tc.rotated_to
)
UPDATE refresh_tokens
SET revoked_at = clock_timestamp()
WHERE id IN (SELECT id FROM token_chain)
AND revoked_at IS NULL
RETURNING id
""",
token_id,
)
return len(rows)
async def rotate(
self,
old_token_hash: str,
new_token_id: UUID,
new_token_hash: str,
new_expires_at: datetime,
new_active_org_id: UUID | None = None,
expected_user_id: UUID | None = None,
) -> dict | None:
"""Atomically rotate a refresh token.
This method:
1. Validates the old token (not expired, not revoked, not already rotated)
2. Locks the row to prevent concurrent rotation
3. Marks old token as rotated (sets rotated_to)
4. Creates new token with updated org if specified
5. All in a single atomic operation
Args:
old_token_hash: Hash of the token being rotated
new_token_id: UUID for the new token
new_token_hash: Hash for the new token
new_expires_at: Expiry time for the new token
new_active_org_id: New org ID (for org-switch), or None to keep current
expected_user_id: If provided, validates token belongs to this user
Returns:
New token dict if rotation succeeded, None if old token invalid/expired
"""
# First, get and lock the old token
old_token = await self.get_valid_for_rotation(old_token_hash, expected_user_id)
if old_token is None:
return None
# Determine the org for the new token
active_org_id = new_active_org_id or old_token["active_org_id"]
user_id = old_token["user_id"]
# Create the new token
new_token = await self.conn.fetchrow(
"""
INSERT INTO refresh_tokens (id, user_id, token_hash, active_org_id, expires_at)
VALUES ($1, $2, $3, $4, $5)
RETURNING id, user_id, token_hash, active_org_id, expires_at,
revoked_at, rotated_to, created_at
""",
new_token_id,
user_id,
new_token_hash,
active_org_id,
new_expires_at,
)
# Mark the old token as rotated (not revoked - for reuse detection)
await self.conn.execute(
"""
UPDATE refresh_tokens
SET rotated_to = $2
WHERE id = $1
""",
old_token["id"],
new_token_id,
)
return dict(new_token)
async def revoke(self, token_id: UUID) -> bool:
"""Revoke a refresh token by ID.
Returns:
True if token was revoked, False if not found or already revoked
"""
row = await self.conn.fetchrow(
"""
UPDATE refresh_tokens
SET revoked_at = clock_timestamp()
WHERE id = $1 AND revoked_at IS NULL
RETURNING id
""",
token_id,
)
return row is not None
async def revoke_by_hash(self, token_hash: str) -> bool:
"""Revoke a refresh token by hash.
Returns:
True if token was revoked, False if not found or already revoked
"""
row = await self.conn.fetchrow(
"""
UPDATE refresh_tokens
SET revoked_at = clock_timestamp()
WHERE token_hash = $1 AND revoked_at IS NULL
RETURNING id
""",
token_hash,
)
return row is not None
async def revoke_all_for_user(self, user_id: UUID) -> int:
"""Revoke all active refresh tokens for a user.
Use this for:
- User-initiated logout from all devices
- Password change
- Account compromise response
Returns:
Count of tokens revoked
"""
rows = await self.conn.fetch(
"""
UPDATE refresh_tokens
SET revoked_at = clock_timestamp()
WHERE user_id = $1 AND revoked_at IS NULL
RETURNING id
""",
user_id,
)
return len(rows)
async def revoke_all_for_user_except(self, user_id: UUID, keep_token_id: UUID) -> int:
"""Revoke all tokens for a user except one (logout other sessions).
Args:
user_id: The user whose tokens to revoke
keep_token_id: The token ID to keep active (current session)
Returns:
Count of tokens revoked
"""
rows = await self.conn.fetch(
"""
UPDATE refresh_tokens
SET revoked_at = clock_timestamp()
WHERE user_id = $1 AND revoked_at IS NULL AND id != $2
RETURNING id
""",
user_id,
keep_token_id,
)
return len(rows)
async def get_active_tokens_for_user(self, user_id: UUID) -> list[dict]:
"""Get all active (non-revoked, non-expired, non-rotated) tokens for a user.
Useful for:
- Showing active sessions
- Auditing
Returns:
List of active token records
"""
rows = await self.conn.fetch(
"""
SELECT id, user_id, token_hash, active_org_id, expires_at,
revoked_at, rotated_to, created_at
FROM refresh_tokens
WHERE user_id = $1
AND revoked_at IS NULL
AND rotated_to IS NULL
AND expires_at > clock_timestamp()
ORDER BY created_at DESC
""",
user_id,
)
return [dict(row) for row in rows]
async def cleanup_expired(self, older_than_days: int = 30) -> int:
"""Delete expired tokens older than specified days.
Note: This performs a hard delete. For audit trails, I think we should:
- Archiving to a separate table first
- Using partitioning with retention policies
- Only calling this for tokens well past their expiry
Args:
older_than_days: Only delete tokens expired more than this many days ago
Returns:
Count of tokens deleted
"""
rows = await self.conn.fetch(
"""
DELETE FROM refresh_tokens
WHERE expires_at < clock_timestamp() - interval '1 day' * $1
RETURNING id
""",
older_than_days,
)
return len(rows)

View File

@@ -0,0 +1,80 @@
"""Service repository for database operations."""
from uuid import UUID
import asyncpg
class ServiceRepository:
"""Database operations for services."""
def __init__(self, conn: asyncpg.Connection) -> None:
self.conn = conn
async def create(
self,
service_id: UUID,
org_id: UUID,
name: str,
slug: str,
) -> dict:
"""Create a new service."""
row = await self.conn.fetchrow(
"""
INSERT INTO services (id, org_id, name, slug)
VALUES ($1, $2, $3, $4)
RETURNING id, org_id, name, slug, created_at
""",
service_id,
org_id,
name,
slug,
)
return dict(row)
async def get_by_id(self, service_id: UUID) -> dict | None:
"""Get service by ID."""
row = await self.conn.fetchrow(
"""
SELECT id, org_id, name, slug, created_at
FROM services
WHERE id = $1
""",
service_id,
)
return dict(row) if row else None
async def get_by_org(self, org_id: UUID) -> list[dict]:
"""Get all services for an organization."""
rows = await self.conn.fetch(
"""
SELECT id, org_id, name, slug, created_at
FROM services
WHERE org_id = $1
ORDER BY name
""",
org_id,
)
return [dict(row) for row in rows]
async def get_by_slug(self, org_id: UUID, slug: str) -> dict | None:
"""Get service by org and slug."""
row = await self.conn.fetchrow(
"""
SELECT id, org_id, name, slug, created_at
FROM services
WHERE org_id = $1 AND slug = $2
""",
org_id,
slug,
)
return dict(row) if row else None
async def slug_exists(self, org_id: UUID, slug: str) -> bool:
"""Check if service slug exists in organization."""
result = await self.conn.fetchval(
"SELECT EXISTS(SELECT 1 FROM services WHERE org_id = $1 AND slug = $2)",
org_id,
slug,
)
return result

63
app/repositories/user.py Normal file
View File

@@ -0,0 +1,63 @@
"""User repository for database operations."""
from uuid import UUID
import asyncpg
class UserRepository:
"""Database operations for users."""
def __init__(self, conn: asyncpg.Connection) -> None:
self.conn = conn
async def create(
self,
user_id: UUID,
email: str,
password_hash: str,
) -> dict:
"""Create a new user."""
row = await self.conn.fetchrow(
"""
INSERT INTO users (id, email, password_hash)
VALUES ($1, $2, $3)
RETURNING id, email, created_at
""",
user_id,
email,
password_hash,
)
return dict(row)
async def get_by_id(self, user_id: UUID) -> dict | None:
"""Get user by ID."""
row = await self.conn.fetchrow(
"""
SELECT id, email, password_hash, created_at
FROM users
WHERE id = $1
""",
user_id,
)
return dict(row) if row else None
async def get_by_email(self, email: str) -> dict | None:
"""Get user by email."""
row = await self.conn.fetchrow(
"""
SELECT id, email, password_hash, created_at
FROM users
WHERE email = $1
""",
email,
)
return dict(row) if row else None
async def exists_by_email(self, email: str) -> bool:
"""Check if user exists by email."""
result = await self.conn.fetchval(
"SELECT EXISTS(SELECT 1 FROM users WHERE email = $1)",
email,
)
return result

50
app/schemas/__init__.py Normal file
View File

@@ -0,0 +1,50 @@
"""Pydantic schemas for request/response models."""
from app.schemas.auth import (
LoginRequest,
RefreshRequest,
RegisterRequest,
SwitchOrgRequest,
TokenResponse,
)
from app.schemas.common import CursorParams, PaginatedResponse
from app.schemas.incident import (
CommentRequest,
IncidentCreate,
IncidentEventResponse,
IncidentResponse,
TransitionRequest,
)
from app.schemas.org import (
MemberResponse,
NotificationTargetCreate,
NotificationTargetResponse,
OrgResponse,
ServiceCreate,
ServiceResponse,
)
__all__ = [
# Auth
"LoginRequest",
"RefreshRequest",
"RegisterRequest",
"SwitchOrgRequest",
"TokenResponse",
# Common
"CursorParams",
"PaginatedResponse",
# Incident
"CommentRequest",
"IncidentCreate",
"IncidentEventResponse",
"IncidentResponse",
"TransitionRequest",
# Org
"MemberResponse",
"NotificationTargetCreate",
"NotificationTargetResponse",
"OrgResponse",
"ServiceCreate",
"ServiceResponse",
]

42
app/schemas/auth.py Normal file
View File

@@ -0,0 +1,42 @@
"""Authentication schemas."""
from uuid import UUID
from pydantic import BaseModel, EmailStr, Field
class RegisterRequest(BaseModel):
"""Request body for user registration."""
email: EmailStr
password: str = Field(min_length=8, max_length=128)
org_name: str = Field(min_length=1, max_length=100, description="Name for the default org")
class LoginRequest(BaseModel):
"""Request body for user login."""
email: EmailStr
password: str
class RefreshRequest(BaseModel):
"""Request body for token refresh."""
refresh_token: str
class SwitchOrgRequest(BaseModel):
"""Request body for switching active organization."""
org_id: UUID
refresh_token: str
class TokenResponse(BaseModel):
"""Response containing access and refresh tokens."""
access_token: str
refresh_token: str
token_type: str = "bearer"
expires_in: int = Field(description="Access token expiry in seconds")

20
app/schemas/common.py Normal file
View File

@@ -0,0 +1,20 @@
"""Common schemas used across the API."""
from pydantic import BaseModel, Field
class CursorParams(BaseModel):
"""Pagination parameters using cursor-based pagination."""
cursor: str | None = Field(default=None, description="Cursor for pagination")
limit: int = Field(default=20, ge=1, le=100, description="Number of items per page")
class PaginatedResponse[T](BaseModel):
"""Generic paginated response wrapper."""
items: list[T]
next_cursor: str | None = Field(
default=None, description="Cursor for next page, null if no more items"
)
has_more: bool = Field(description="Whether there are more items")

57
app/schemas/incident.py Normal file
View File

@@ -0,0 +1,57 @@
"""Incident-related schemas."""
from datetime import datetime
from typing import Any, Literal
from uuid import UUID
from pydantic import BaseModel, Field
IncidentStatus = Literal["triggered", "acknowledged", "mitigated", "resolved"]
IncidentSeverity = Literal["critical", "high", "medium", "low"]
class IncidentCreate(BaseModel):
"""Request body for creating an incident."""
title: str = Field(min_length=1, max_length=200)
description: str | None = Field(default=None, max_length=5000)
severity: IncidentSeverity = "medium"
class IncidentResponse(BaseModel):
"""Incident response."""
id: UUID
service_id: UUID
title: str
description: str | None
status: IncidentStatus
severity: IncidentSeverity
version: int
created_at: datetime
updated_at: datetime
class IncidentEventResponse(BaseModel):
"""Incident event response."""
id: UUID
incident_id: UUID
event_type: str
actor_user_id: UUID | None
payload: dict[str, Any] | None
created_at: datetime
class TransitionRequest(BaseModel):
"""Request body for transitioning incident status."""
to_status: IncidentStatus
version: int = Field(description="Current version for optimistic locking")
note: str | None = Field(default=None, max_length=1000)
class CommentRequest(BaseModel):
"""Request body for adding a comment to an incident."""
content: str = Field(min_length=1, max_length=5000)

69
app/schemas/org.py Normal file
View File

@@ -0,0 +1,69 @@
"""Organization-related schemas."""
from datetime import datetime
from typing import Literal
from uuid import UUID
from pydantic import BaseModel, Field, HttpUrl
class OrgResponse(BaseModel):
"""Organization summary response."""
id: UUID
name: str
slug: str
created_at: datetime
class MemberResponse(BaseModel):
"""Organization member response."""
id: UUID
user_id: UUID
email: str
role: Literal["admin", "member", "viewer"]
created_at: datetime
class ServiceCreate(BaseModel):
"""Request body for creating a service."""
name: str = Field(min_length=1, max_length=100)
slug: str = Field(
min_length=1,
max_length=50,
pattern=r"^[a-z0-9]+(?:-[a-z0-9]+)*$",
description="URL-friendly identifier (lowercase, hyphens allowed)",
)
class ServiceResponse(BaseModel):
"""Service response."""
id: UUID
name: str
slug: str
created_at: datetime
class NotificationTargetCreate(BaseModel):
"""Request body for creating a notification target."""
name: str = Field(min_length=1, max_length=100)
target_type: Literal["webhook", "email", "slack"]
webhook_url: HttpUrl | None = Field(
default=None, description="Required for webhook type"
)
enabled: bool = True
class NotificationTargetResponse(BaseModel):
"""Notification target response."""
id: UUID
name: str
target_type: Literal["webhook", "email", "slack"]
webhook_url: str | None
enabled: bool
created_at: datetime

View File

@@ -0,0 +1,18 @@
-- Enhance refresh tokens for secure rotation and reuse detection
-- Adds rotated_to column to track token chains and detect stolen token reuse
-- Add rotated_to column to track which token this was rotated into
-- When a token is rotated, we store the ID of the new token here
-- If a token with rotated_to set is used again, it indicates token theft
ALTER TABLE refresh_tokens ADD COLUMN rotated_to UUID REFERENCES refresh_tokens(id);
-- Index for efficient cleanup queries on expires_at
CREATE INDEX idx_refresh_tokens_expires ON refresh_tokens(expires_at);
-- Index for finding active tokens per user (for revoke_all and listing)
CREATE INDEX idx_refresh_tokens_user_active ON refresh_tokens(user_id, revoked_at)
WHERE revoked_at IS NULL;
-- Index for reuse detection queries
CREATE INDEX idx_refresh_tokens_rotated ON refresh_tokens(rotated_to)
WHERE rotated_to IS NOT NULL;

View File

@@ -8,7 +8,7 @@ dependencies = [
"fastapi>=0.115.0", "fastapi>=0.115.0",
"uvicorn[standard]>=0.32.0", "uvicorn[standard]>=0.32.0",
"asyncpg>=0.30.0", "asyncpg>=0.30.0",
"pydantic>=2.0.0", "pydantic[email]>=2.0.0",
"pydantic-settings>=2.0.0", "pydantic-settings>=2.0.0",
"python-jose[cryptography]>=3.3.0", "python-jose[cryptography]>=3.3.0",
"bcrypt>=4.0.0", "bcrypt>=4.0.0",
@@ -38,6 +38,9 @@ target-version = "py314"
[tool.ruff.lint] [tool.ruff.lint]
select = ["E", "F", "I", "N", "W", "UP"] select = ["E", "F", "I", "N", "W", "UP"]
[tool.ruff.lint.per-file-ignores]
"tests/**/*.py" = ["E501"] # Allow longer lines in tests for descriptive method names
[tool.pytest.ini_options] [tool.pytest.ini_options]
asyncio_mode = "auto" asyncio_mode = "auto"
testpaths = ["tests"] testpaths = ["tests"]

95
tests/conftest.py Normal file
View File

@@ -0,0 +1,95 @@
"""Shared pytest fixtures for all tests."""
from __future__ import annotations
import os
from uuid import uuid4
import asyncpg
import pytest
# Set test environment variables before importing app modules
os.environ.setdefault("DATABASE_URL", "postgresql://incidentops:incidentops@localhost:5432/incidentops_test")
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-testing-only")
os.environ.setdefault("REDIS_URL", "redis://localhost:6379/1")
# Module-level setup: create database and run migrations once
_db_initialized = False
async def _init_test_db() -> None:
"""Initialize test database and run migrations (once per session)."""
global _db_initialized
if _db_initialized:
return
admin_dsn = os.environ["DATABASE_URL"].rsplit("/", 1)[0] + "/postgres"
test_db_name = "incidentops_test"
admin_conn = await asyncpg.connect(admin_dsn)
try:
# Terminate existing connections to the test database
await admin_conn.execute(f"""
SELECT pg_terminate_backend(pg_stat_activity.pid)
FROM pg_stat_activity
WHERE pg_stat_activity.datname = '{test_db_name}'
AND pid <> pg_backend_pid()
""")
# Drop and recreate test database
await admin_conn.execute(f"DROP DATABASE IF EXISTS {test_db_name}")
await admin_conn.execute(f"CREATE DATABASE {test_db_name}")
finally:
await admin_conn.close()
# Connect to test database and run migrations
test_dsn = os.environ["DATABASE_URL"]
conn = await asyncpg.connect(test_dsn)
try:
# Run migrations
migrations_dir = os.path.join(os.path.dirname(__file__), "..", "migrations")
migration_files = sorted(
f for f in os.listdir(migrations_dir) if f.endswith(".sql")
)
for migration_file in migration_files:
migration_path = os.path.join(migrations_dir, migration_file)
with open(migration_path) as f:
sql = f.read()
await conn.execute(sql)
finally:
await conn.close()
_db_initialized = True
@pytest.fixture
async def db_conn() -> asyncpg.Connection:
"""Get a database connection with transaction rollback for test isolation."""
await _init_test_db()
test_dsn = os.environ["DATABASE_URL"]
conn = await asyncpg.connect(test_dsn)
# Start a transaction that will be rolled back
tr = conn.transaction()
await tr.start()
try:
yield conn
finally:
await tr.rollback()
await conn.close()
@pytest.fixture
def make_user_id() -> uuid4:
"""Factory for generating user IDs."""
return lambda: uuid4()
@pytest.fixture
def make_org_id() -> uuid4:
"""Factory for generating org IDs."""
return lambda: uuid4()

View File

@@ -0,0 +1 @@
"""Repository tests."""

View File

@@ -0,0 +1,389 @@
"""Tests for IncidentRepository."""
from uuid import uuid4
import asyncpg
import pytest
from app.repositories.incident import IncidentRepository
from app.repositories.org import OrgRepository
from app.repositories.service import ServiceRepository
from app.repositories.user import UserRepository
class TestIncidentRepository:
"""Tests for IncidentRepository conforming to SPECS.md."""
async def _create_org(self, conn: asyncpg.Connection, slug: str) -> uuid4:
"""Helper to create an org."""
org_repo = OrgRepository(conn)
org_id = uuid4()
await org_repo.create(org_id, f"Org {slug}", slug)
return org_id
async def _create_service(self, conn: asyncpg.Connection, org_id: uuid4, slug: str) -> uuid4:
"""Helper to create a service."""
service_repo = ServiceRepository(conn)
service_id = uuid4()
await service_repo.create(service_id, org_id, f"Service {slug}", slug)
return service_id
async def _create_user(self, conn: asyncpg.Connection, email: str) -> uuid4:
"""Helper to create a user."""
user_repo = UserRepository(conn)
user_id = uuid4()
await user_repo.create(user_id, email, "hash")
return user_id
async def test_create_incident_returns_incident_data(self, db_conn: asyncpg.Connection) -> None:
"""Creating an incident returns the incident data with triggered status."""
org_id = await self._create_org(db_conn, "incident-org")
service_id = await self._create_service(db_conn, org_id, "incident-service")
repo = IncidentRepository(db_conn)
incident_id = uuid4()
result = await repo.create(
incident_id, org_id, service_id,
title="Server Down",
description="Main API server is not responding",
severity="critical"
)
assert result["id"] == incident_id
assert result["org_id"] == org_id
assert result["service_id"] == service_id
assert result["title"] == "Server Down"
assert result["description"] == "Main API server is not responding"
assert result["status"] == "triggered" # Initial status per SPECS.md
assert result["severity"] == "critical"
assert result["version"] == 1
assert result["created_at"] is not None
assert result["updated_at"] is not None
async def test_create_incident_initial_status_is_triggered(self, db_conn: asyncpg.Connection) -> None:
"""New incidents always start with 'triggered' status per SPECS.md state machine."""
org_id = await self._create_org(db_conn, "triggered-org")
service_id = await self._create_service(db_conn, org_id, "triggered-service")
repo = IncidentRepository(db_conn)
result = await repo.create(uuid4(), org_id, service_id, "Test", None, "low")
assert result["status"] == "triggered"
async def test_create_incident_initial_version_is_one(self, db_conn: asyncpg.Connection) -> None:
"""New incidents start with version 1 for optimistic locking."""
org_id = await self._create_org(db_conn, "version-org")
service_id = await self._create_service(db_conn, org_id, "version-service")
repo = IncidentRepository(db_conn)
result = await repo.create(uuid4(), org_id, service_id, "Test", None, "medium")
assert result["version"] == 1
async def test_create_incident_severity_must_be_valid(self, db_conn: asyncpg.Connection) -> None:
"""Severity must be critical, high, medium, or low per SPECS.md."""
org_id = await self._create_org(db_conn, "severity-org")
service_id = await self._create_service(db_conn, org_id, "severity-service")
repo = IncidentRepository(db_conn)
# Valid severities
for severity in ["critical", "high", "medium", "low"]:
result = await repo.create(uuid4(), org_id, service_id, f"Test {severity}", None, severity)
assert result["severity"] == severity
# Invalid severity
with pytest.raises(asyncpg.CheckViolationError):
await repo.create(uuid4(), org_id, service_id, "Invalid", None, "extreme")
async def test_get_by_id_returns_incident(self, db_conn: asyncpg.Connection) -> None:
"""get_by_id returns the correct incident."""
org_id = await self._create_org(db_conn, "getbyid-org")
service_id = await self._create_service(db_conn, org_id, "getbyid-service")
repo = IncidentRepository(db_conn)
incident_id = uuid4()
await repo.create(incident_id, org_id, service_id, "My Incident", "Details", "high")
result = await repo.get_by_id(incident_id)
assert result is not None
assert result["id"] == incident_id
assert result["title"] == "My Incident"
async def test_get_by_id_returns_none_for_nonexistent(self, db_conn: asyncpg.Connection) -> None:
"""get_by_id returns None for non-existent incident."""
repo = IncidentRepository(db_conn)
result = await repo.get_by_id(uuid4())
assert result is None
async def test_get_by_org_returns_org_incidents(self, db_conn: asyncpg.Connection) -> None:
"""get_by_org returns incidents for the organization."""
org_id = await self._create_org(db_conn, "list-org")
service_id = await self._create_service(db_conn, org_id, "list-service")
repo = IncidentRepository(db_conn)
await repo.create(uuid4(), org_id, service_id, "Incident 1", None, "low")
await repo.create(uuid4(), org_id, service_id, "Incident 2", None, "medium")
await repo.create(uuid4(), org_id, service_id, "Incident 3", None, "high")
result = await repo.get_by_org(org_id)
assert len(result) == 3
async def test_get_by_org_filters_by_status(self, db_conn: asyncpg.Connection) -> None:
"""get_by_org can filter by status."""
org_id = await self._create_org(db_conn, "filter-org")
service_id = await self._create_service(db_conn, org_id, "filter-service")
repo = IncidentRepository(db_conn)
# Create incidents and transition some
inc1 = uuid4()
inc2 = uuid4()
await repo.create(inc1, org_id, service_id, "Triggered", None, "low")
await repo.create(inc2, org_id, service_id, "Will be Acked", None, "low")
await repo.update_status(inc2, "acknowledged", 1)
result = await repo.get_by_org(org_id, status="triggered")
assert len(result) == 1
assert result[0]["title"] == "Triggered"
async def test_get_by_org_pagination_with_cursor(self, db_conn: asyncpg.Connection) -> None:
"""get_by_org supports cursor-based pagination."""
org_id = await self._create_org(db_conn, "pagination-org")
service_id = await self._create_service(db_conn, org_id, "pagination-service")
repo = IncidentRepository(db_conn)
# Create 5 incidents
for i in range(5):
await repo.create(uuid4(), org_id, service_id, f"Incident {i}", None, "low")
# Get first page - should return limit+1 to check for more
page1 = await repo.get_by_org(org_id, limit=2)
assert len(page1) == 3
# Verify total is 5 when we get all
all_incidents = await repo.get_by_org(org_id, limit=10)
assert len(all_incidents) == 5
async def test_get_by_org_orders_by_created_at_desc(self, db_conn: asyncpg.Connection) -> None:
"""get_by_org returns incidents ordered by created_at descending."""
org_id = await self._create_org(db_conn, "order-org")
service_id = await self._create_service(db_conn, org_id, "order-service")
repo = IncidentRepository(db_conn)
await repo.create(uuid4(), org_id, service_id, "First", None, "low")
await repo.create(uuid4(), org_id, service_id, "Second", None, "low")
await repo.create(uuid4(), org_id, service_id, "Third", None, "low")
result = await repo.get_by_org(org_id)
# Verify ordering - newer items should come first (or same time due to fast execution)
assert len(result) == 3
for i in range(len(result) - 1):
assert result[i]["created_at"] >= result[i + 1]["created_at"]
async def test_get_by_org_tenant_isolation(self, db_conn: asyncpg.Connection) -> None:
"""get_by_org only returns incidents for the specified org."""
org1 = await self._create_org(db_conn, "tenant-org-1")
org2 = await self._create_org(db_conn, "tenant-org-2")
service1 = await self._create_service(db_conn, org1, "tenant-service-1")
service2 = await self._create_service(db_conn, org2, "tenant-service-2")
repo = IncidentRepository(db_conn)
await repo.create(uuid4(), org1, service1, "Org1 Incident", None, "low")
await repo.create(uuid4(), org2, service2, "Org2 Incident", None, "low")
result = await repo.get_by_org(org1)
assert len(result) == 1
assert result[0]["title"] == "Org1 Incident"
class TestIncidentStatusTransitions:
"""Tests for incident status transitions per SPECS.md state machine."""
async def _setup_incident(self, conn: asyncpg.Connection) -> tuple[uuid4, IncidentRepository]:
"""Helper to create org, service, and incident."""
org_repo = OrgRepository(conn)
service_repo = ServiceRepository(conn)
incident_repo = IncidentRepository(conn)
org_id = uuid4()
service_id = uuid4()
incident_id = uuid4()
await org_repo.create(org_id, "Test Org", f"test-org-{uuid4().hex[:8]}")
await service_repo.create(service_id, org_id, "Test Service", f"test-service-{uuid4().hex[:8]}")
await incident_repo.create(incident_id, org_id, service_id, "Test Incident", None, "medium")
return incident_id, incident_repo
async def test_update_status_increments_version(self, db_conn: asyncpg.Connection) -> None:
"""update_status increments version for optimistic locking."""
incident_id, repo = await self._setup_incident(db_conn)
result = await repo.update_status(incident_id, "acknowledged", 1)
assert result is not None
assert result["version"] == 2
async def test_update_status_fails_on_version_mismatch(self, db_conn: asyncpg.Connection) -> None:
"""update_status returns None on version mismatch (optimistic locking)."""
incident_id, repo = await self._setup_incident(db_conn)
# Try with wrong version
result = await repo.update_status(incident_id, "acknowledged", 999)
assert result is None
async def test_update_status_updates_updated_at(self, db_conn: asyncpg.Connection) -> None:
"""update_status updates the updated_at timestamp."""
incident_id, repo = await self._setup_incident(db_conn)
before = await repo.get_by_id(incident_id)
result = await repo.update_status(incident_id, "acknowledged", 1)
# updated_at should be at least as recent as before (may be same in fast execution)
assert result["updated_at"] >= before["updated_at"]
# Also verify status was actually updated
assert result["status"] == "acknowledged"
async def test_status_must_be_valid_value(self, db_conn: asyncpg.Connection) -> None:
"""Status must be triggered, acknowledged, mitigated, or resolved per SPECS.md."""
incident_id, repo = await self._setup_incident(db_conn)
with pytest.raises(asyncpg.CheckViolationError):
await repo.update_status(incident_id, "invalid_status", 1)
async def test_valid_status_transitions(self, db_conn: asyncpg.Connection) -> None:
"""Test the valid status values per SPECS.md."""
incident_id, repo = await self._setup_incident(db_conn)
# Triggered -> Acknowledged
result = await repo.update_status(incident_id, "acknowledged", 1)
assert result["status"] == "acknowledged"
# Acknowledged -> Mitigated
result = await repo.update_status(incident_id, "mitigated", 2)
assert result["status"] == "mitigated"
# Mitigated -> Resolved
result = await repo.update_status(incident_id, "resolved", 3)
assert result["status"] == "resolved"
class TestIncidentEvents:
"""Tests for incident events (timeline) per SPECS.md incident_events table."""
async def _setup_incident(self, conn: asyncpg.Connection) -> tuple[uuid4, uuid4, IncidentRepository]:
"""Helper to create org, service, user, and incident."""
org_repo = OrgRepository(conn)
service_repo = ServiceRepository(conn)
user_repo = UserRepository(conn)
incident_repo = IncidentRepository(conn)
org_id = uuid4()
service_id = uuid4()
user_id = uuid4()
incident_id = uuid4()
await org_repo.create(org_id, "Test Org", f"test-org-{uuid4().hex[:8]}")
await service_repo.create(service_id, org_id, "Test Service", f"test-svc-{uuid4().hex[:8]}")
await user_repo.create(user_id, f"user-{uuid4().hex[:8]}@example.com", "hash")
await incident_repo.create(incident_id, org_id, service_id, "Test Incident", None, "medium")
return incident_id, user_id, incident_repo
async def test_add_event_creates_event(self, db_conn: asyncpg.Connection) -> None:
"""add_event creates an event in the timeline."""
incident_id, user_id, repo = await self._setup_incident(db_conn)
event_id = uuid4()
result = await repo.add_event(
event_id, incident_id, "status_changed",
actor_user_id=user_id,
payload={"from": "triggered", "to": "acknowledged"}
)
assert result["id"] == event_id
assert result["incident_id"] == incident_id
assert result["event_type"] == "status_changed"
assert result["actor_user_id"] == user_id
assert result["payload"] == {"from": "triggered", "to": "acknowledged"}
assert result["created_at"] is not None
async def test_add_event_allows_null_actor(self, db_conn: asyncpg.Connection) -> None:
"""add_event allows null actor_user_id (system events)."""
incident_id, _, repo = await self._setup_incident(db_conn)
result = await repo.add_event(
uuid4(), incident_id, "auto_escalated",
actor_user_id=None,
payload={"reason": "Unacknowledged after 30 minutes"}
)
assert result["actor_user_id"] is None
async def test_add_event_allows_null_payload(self, db_conn: asyncpg.Connection) -> None:
"""add_event allows null payload."""
incident_id, user_id, repo = await self._setup_incident(db_conn)
result = await repo.add_event(
uuid4(), incident_id, "viewed",
actor_user_id=user_id,
payload=None
)
assert result["payload"] is None
async def test_get_events_returns_all_incident_events(self, db_conn: asyncpg.Connection) -> None:
"""get_events returns all events for an incident."""
incident_id, user_id, repo = await self._setup_incident(db_conn)
await repo.add_event(uuid4(), incident_id, "created", user_id, {"title": "Test"})
await repo.add_event(uuid4(), incident_id, "status_changed", user_id, {"to": "acked"})
await repo.add_event(uuid4(), incident_id, "comment_added", user_id, {"text": "Working on it"})
result = await repo.get_events(incident_id)
assert len(result) == 3
event_types = [e["event_type"] for e in result]
assert event_types == ["created", "status_changed", "comment_added"]
async def test_get_events_orders_by_created_at(self, db_conn: asyncpg.Connection) -> None:
"""get_events returns events in chronological order."""
incident_id, user_id, repo = await self._setup_incident(db_conn)
await repo.add_event(uuid4(), incident_id, "first", user_id, None)
await repo.add_event(uuid4(), incident_id, "second", user_id, None)
await repo.add_event(uuid4(), incident_id, "third", user_id, None)
result = await repo.get_events(incident_id)
assert result[0]["event_type"] == "first"
assert result[1]["event_type"] == "second"
assert result[2]["event_type"] == "third"
async def test_get_events_returns_empty_for_no_events(self, db_conn: asyncpg.Connection) -> None:
"""get_events returns empty list for incident with no events."""
incident_id, _, repo = await self._setup_incident(db_conn)
result = await repo.get_events(incident_id)
assert result == []
async def test_event_requires_valid_incident_foreign_key(self, db_conn: asyncpg.Connection) -> None:
"""incident_events.incident_id must reference existing incident."""
incident_id, user_id, repo = await self._setup_incident(db_conn)
with pytest.raises(asyncpg.ForeignKeyViolationError):
await repo.add_event(uuid4(), uuid4(), "test", user_id, None)
async def test_event_requires_valid_user_foreign_key(self, db_conn: asyncpg.Connection) -> None:
"""incident_events.actor_user_id must reference existing user if not null."""
incident_id, _, repo = await self._setup_incident(db_conn)
with pytest.raises(asyncpg.ForeignKeyViolationError):
await repo.add_event(uuid4(), incident_id, "test", uuid4(), None)

View File

@@ -0,0 +1,362 @@
"""Tests for NotificationRepository."""
from datetime import UTC, datetime
from uuid import uuid4
import asyncpg
import pytest
from app.repositories.incident import IncidentRepository
from app.repositories.notification import NotificationRepository
from app.repositories.org import OrgRepository
from app.repositories.service import ServiceRepository
class TestNotificationTargetRepository:
"""Tests for notification targets per SPECS.md notification_targets table."""
async def _create_org(self, conn: asyncpg.Connection, slug: str) -> uuid4:
"""Helper to create an org."""
org_repo = OrgRepository(conn)
org_id = uuid4()
await org_repo.create(org_id, f"Org {slug}", slug)
return org_id
async def test_create_target_returns_target_data(self, db_conn: asyncpg.Connection) -> None:
"""Creating a notification target returns the target data."""
org_id = await self._create_org(db_conn, "target-org")
repo = NotificationRepository(db_conn)
target_id = uuid4()
result = await repo.create_target(
target_id, org_id, "Slack Alerts",
target_type="webhook",
webhook_url="https://hooks.slack.com/services/xxx",
enabled=True
)
assert result["id"] == target_id
assert result["org_id"] == org_id
assert result["name"] == "Slack Alerts"
assert result["target_type"] == "webhook"
assert result["webhook_url"] == "https://hooks.slack.com/services/xxx"
assert result["enabled"] is True
assert result["created_at"] is not None
async def test_create_target_type_must_be_valid(self, db_conn: asyncpg.Connection) -> None:
"""Target type must be webhook, email, or slack per SPECS.md."""
org_id = await self._create_org(db_conn, "type-org")
repo = NotificationRepository(db_conn)
# Valid types
for target_type in ["webhook", "email", "slack"]:
result = await repo.create_target(
uuid4(), org_id, f"{target_type} target",
target_type=target_type
)
assert result["target_type"] == target_type
# Invalid type
with pytest.raises(asyncpg.CheckViolationError):
await repo.create_target(
uuid4(), org_id, "Invalid",
target_type="sms"
)
async def test_create_target_webhook_url_optional(self, db_conn: asyncpg.Connection) -> None:
"""webhook_url is optional (for email/slack types)."""
org_id = await self._create_org(db_conn, "optional-url-org")
repo = NotificationRepository(db_conn)
result = await repo.create_target(
uuid4(), org_id, "Email Alerts",
target_type="email",
webhook_url=None
)
assert result["webhook_url"] is None
async def test_create_target_enabled_defaults_to_true(self, db_conn: asyncpg.Connection) -> None:
"""enabled defaults to True."""
org_id = await self._create_org(db_conn, "default-enabled-org")
repo = NotificationRepository(db_conn)
result = await repo.create_target(
uuid4(), org_id, "Default Enabled",
target_type="webhook"
)
assert result["enabled"] is True
async def test_get_target_by_id_returns_target(self, db_conn: asyncpg.Connection) -> None:
"""get_target_by_id returns the correct target."""
org_id = await self._create_org(db_conn, "getbyid-target-org")
repo = NotificationRepository(db_conn)
target_id = uuid4()
await repo.create_target(target_id, org_id, "My Target", "webhook")
result = await repo.get_target_by_id(target_id)
assert result is not None
assert result["id"] == target_id
assert result["name"] == "My Target"
async def test_get_target_by_id_returns_none_for_nonexistent(self, db_conn: asyncpg.Connection) -> None:
"""get_target_by_id returns None for non-existent target."""
repo = NotificationRepository(db_conn)
result = await repo.get_target_by_id(uuid4())
assert result is None
async def test_get_targets_by_org_returns_all_targets(self, db_conn: asyncpg.Connection) -> None:
"""get_targets_by_org returns all targets for an organization."""
org_id = await self._create_org(db_conn, "multi-target-org")
repo = NotificationRepository(db_conn)
await repo.create_target(uuid4(), org_id, "Target A", "webhook")
await repo.create_target(uuid4(), org_id, "Target B", "email")
await repo.create_target(uuid4(), org_id, "Target C", "slack")
result = await repo.get_targets_by_org(org_id)
assert len(result) == 3
names = {t["name"] for t in result}
assert names == {"Target A", "Target B", "Target C"}
async def test_get_targets_by_org_filters_enabled(self, db_conn: asyncpg.Connection) -> None:
"""get_targets_by_org can filter to only enabled targets."""
org_id = await self._create_org(db_conn, "enabled-filter-org")
repo = NotificationRepository(db_conn)
await repo.create_target(uuid4(), org_id, "Enabled", "webhook", enabled=True)
await repo.create_target(uuid4(), org_id, "Disabled", "webhook", enabled=False)
result = await repo.get_targets_by_org(org_id, enabled_only=True)
assert len(result) == 1
assert result[0]["name"] == "Enabled"
async def test_get_targets_by_org_tenant_isolation(self, db_conn: asyncpg.Connection) -> None:
"""get_targets_by_org only returns targets for the specified org."""
org1 = await self._create_org(db_conn, "isolated-target-org-1")
org2 = await self._create_org(db_conn, "isolated-target-org-2")
repo = NotificationRepository(db_conn)
await repo.create_target(uuid4(), org1, "Org1 Target", "webhook")
await repo.create_target(uuid4(), org2, "Org2 Target", "webhook")
result = await repo.get_targets_by_org(org1)
assert len(result) == 1
assert result[0]["name"] == "Org1 Target"
async def test_update_target_updates_fields(self, db_conn: asyncpg.Connection) -> None:
"""update_target updates the specified fields."""
org_id = await self._create_org(db_conn, "update-target-org")
repo = NotificationRepository(db_conn)
target_id = uuid4()
await repo.create_target(target_id, org_id, "Original", "webhook", enabled=True)
result = await repo.update_target(target_id, name="Updated", enabled=False)
assert result is not None
assert result["name"] == "Updated"
assert result["enabled"] is False
async def test_update_target_partial_update(self, db_conn: asyncpg.Connection) -> None:
"""update_target only updates provided fields."""
org_id = await self._create_org(db_conn, "partial-update-org")
repo = NotificationRepository(db_conn)
target_id = uuid4()
await repo.create_target(
target_id, org_id, "Original Name", "webhook",
webhook_url="https://original.com", enabled=True
)
result = await repo.update_target(target_id, name="New Name")
assert result["name"] == "New Name"
assert result["webhook_url"] == "https://original.com"
assert result["enabled"] is True
async def test_delete_target_removes_target(self, db_conn: asyncpg.Connection) -> None:
"""delete_target removes the target."""
org_id = await self._create_org(db_conn, "delete-target-org")
repo = NotificationRepository(db_conn)
target_id = uuid4()
await repo.create_target(target_id, org_id, "To Delete", "webhook")
result = await repo.delete_target(target_id)
assert result is True
assert await repo.get_target_by_id(target_id) is None
async def test_delete_target_returns_false_for_nonexistent(self, db_conn: asyncpg.Connection) -> None:
"""delete_target returns False for non-existent target."""
repo = NotificationRepository(db_conn)
result = await repo.delete_target(uuid4())
assert result is False
async def test_target_requires_valid_org_foreign_key(self, db_conn: asyncpg.Connection) -> None:
"""notification_targets.org_id must reference existing org."""
repo = NotificationRepository(db_conn)
with pytest.raises(asyncpg.ForeignKeyViolationError):
await repo.create_target(uuid4(), uuid4(), "Orphan Target", "webhook")
class TestNotificationAttemptRepository:
"""Tests for notification attempts per SPECS.md notification_attempts table."""
async def _setup_incident_and_target(self, conn: asyncpg.Connection) -> tuple[uuid4, uuid4, NotificationRepository]:
"""Helper to create org, service, incident, and notification target."""
org_repo = OrgRepository(conn)
service_repo = ServiceRepository(conn)
incident_repo = IncidentRepository(conn)
notification_repo = NotificationRepository(conn)
org_id = uuid4()
service_id = uuid4()
incident_id = uuid4()
target_id = uuid4()
await org_repo.create(org_id, "Test Org", f"test-org-{uuid4().hex[:8]}")
await service_repo.create(service_id, org_id, "Test Service", f"test-svc-{uuid4().hex[:8]}")
await incident_repo.create(incident_id, org_id, service_id, "Test Incident", None, "medium")
await notification_repo.create_target(target_id, org_id, "Test Target", "webhook")
return incident_id, target_id, notification_repo
async def test_create_attempt_returns_attempt_data(self, db_conn: asyncpg.Connection) -> None:
"""Creating a notification attempt returns the attempt data."""
incident_id, target_id, repo = await self._setup_incident_and_target(db_conn)
attempt_id = uuid4()
result = await repo.create_attempt(attempt_id, incident_id, target_id)
assert result["id"] == attempt_id
assert result["incident_id"] == incident_id
assert result["target_id"] == target_id
assert result["status"] == "pending"
assert result["error"] is None
assert result["sent_at"] is None
assert result["created_at"] is not None
async def test_create_attempt_idempotent(self, db_conn: asyncpg.Connection) -> None:
"""create_attempt is idempotent per SPECS.md (unique constraint on incident+target)."""
incident_id, target_id, repo = await self._setup_incident_and_target(db_conn)
# First attempt
result1 = await repo.create_attempt(uuid4(), incident_id, target_id)
# Second attempt with same incident+target
result2 = await repo.create_attempt(uuid4(), incident_id, target_id)
# Should return the same attempt
assert result1["id"] == result2["id"]
async def test_get_attempt_returns_attempt(self, db_conn: asyncpg.Connection) -> None:
"""get_attempt returns the attempt for incident and target."""
incident_id, target_id, repo = await self._setup_incident_and_target(db_conn)
await repo.create_attempt(uuid4(), incident_id, target_id)
result = await repo.get_attempt(incident_id, target_id)
assert result is not None
assert result["incident_id"] == incident_id
assert result["target_id"] == target_id
async def test_get_attempt_returns_none_for_nonexistent(self, db_conn: asyncpg.Connection) -> None:
"""get_attempt returns None for non-existent attempt."""
incident_id, target_id, repo = await self._setup_incident_and_target(db_conn)
result = await repo.get_attempt(incident_id, target_id)
assert result is None
async def test_update_attempt_success_sets_sent_status(self, db_conn: asyncpg.Connection) -> None:
"""update_attempt_success marks attempt as sent."""
incident_id, target_id, repo = await self._setup_incident_and_target(db_conn)
attempt = await repo.create_attempt(uuid4(), incident_id, target_id)
sent_at = datetime.now(UTC)
result = await repo.update_attempt_success(attempt["id"], sent_at)
assert result is not None
assert result["status"] == "sent"
assert result["sent_at"] is not None
assert result["error"] is None
async def test_update_attempt_failure_sets_failed_status(self, db_conn: asyncpg.Connection) -> None:
"""update_attempt_failure marks attempt as failed with error."""
incident_id, target_id, repo = await self._setup_incident_and_target(db_conn)
attempt = await repo.create_attempt(uuid4(), incident_id, target_id)
result = await repo.update_attempt_failure(attempt["id"], "Connection timeout")
assert result is not None
assert result["status"] == "failed"
assert result["error"] == "Connection timeout"
async def test_attempt_status_must_be_valid(self, db_conn: asyncpg.Connection) -> None:
"""Attempt status must be pending, sent, or failed per SPECS.md."""
incident_id, target_id, repo = await self._setup_incident_and_target(db_conn)
# Create with default 'pending' status - valid
result = await repo.create_attempt(uuid4(), incident_id, target_id)
assert result["status"] == "pending"
# Transition to 'sent' - valid
result = await repo.update_attempt_success(result["id"], datetime.now(UTC))
assert result["status"] == "sent"
async def test_get_pending_attempts_returns_pending_with_target_info(self, db_conn: asyncpg.Connection) -> None:
"""get_pending_attempts returns pending attempts with target details."""
incident_id, target_id, repo = await self._setup_incident_and_target(db_conn)
await repo.create_attempt(uuid4(), incident_id, target_id)
result = await repo.get_pending_attempts(incident_id)
assert len(result) == 1
assert result[0]["status"] == "pending"
assert result[0]["target_id"] == target_id
assert "target_type" in result[0]
assert "target_name" in result[0]
async def test_get_pending_attempts_excludes_sent(self, db_conn: asyncpg.Connection) -> None:
"""get_pending_attempts excludes sent attempts."""
incident_id, target_id, repo = await self._setup_incident_and_target(db_conn)
attempt = await repo.create_attempt(uuid4(), incident_id, target_id)
await repo.update_attempt_success(attempt["id"], datetime.now(UTC))
result = await repo.get_pending_attempts(incident_id)
assert len(result) == 0
async def test_get_pending_attempts_excludes_failed(self, db_conn: asyncpg.Connection) -> None:
"""get_pending_attempts excludes failed attempts."""
incident_id, target_id, repo = await self._setup_incident_and_target(db_conn)
attempt = await repo.create_attempt(uuid4(), incident_id, target_id)
await repo.update_attempt_failure(attempt["id"], "Error")
result = await repo.get_pending_attempts(incident_id)
assert len(result) == 0
async def test_attempt_requires_valid_incident_foreign_key(self, db_conn: asyncpg.Connection) -> None:
"""notification_attempts.incident_id must reference existing incident."""
_, target_id, repo = await self._setup_incident_and_target(db_conn)
with pytest.raises(asyncpg.ForeignKeyViolationError):
await repo.create_attempt(uuid4(), uuid4(), target_id)
async def test_attempt_requires_valid_target_foreign_key(self, db_conn: asyncpg.Connection) -> None:
"""notification_attempts.target_id must reference existing target."""
incident_id, _, repo = await self._setup_incident_and_target(db_conn)
with pytest.raises(asyncpg.ForeignKeyViolationError):
await repo.create_attempt(uuid4(), incident_id, uuid4())

View File

@@ -0,0 +1,250 @@
"""Tests for OrgRepository."""
from uuid import uuid4
import asyncpg
import pytest
from app.repositories.org import OrgRepository
from app.repositories.user import UserRepository
class TestOrgRepository:
"""Tests for OrgRepository conforming to SPECS.md."""
async def test_create_org_returns_org_data(self, db_conn: asyncpg.Connection) -> None:
"""Creating an org returns the org data."""
repo = OrgRepository(db_conn)
org_id = uuid4()
name = "Test Organization"
slug = "test-org"
result = await repo.create(org_id, name, slug)
assert result["id"] == org_id
assert result["name"] == name
assert result["slug"] == slug
assert result["created_at"] is not None
async def test_create_org_slug_must_be_unique(self, db_conn: asyncpg.Connection) -> None:
"""Org slug uniqueness constraint per SPECS.md orgs table."""
repo = OrgRepository(db_conn)
slug = "unique-slug"
await repo.create(uuid4(), "Org One", slug)
with pytest.raises(asyncpg.UniqueViolationError):
await repo.create(uuid4(), "Org Two", slug)
async def test_get_by_id_returns_org(self, db_conn: asyncpg.Connection) -> None:
"""get_by_id returns the correct organization."""
repo = OrgRepository(db_conn)
org_id = uuid4()
await repo.create(org_id, "My Org", "my-org")
result = await repo.get_by_id(org_id)
assert result is not None
assert result["id"] == org_id
assert result["name"] == "My Org"
async def test_get_by_id_returns_none_for_nonexistent(self, db_conn: asyncpg.Connection) -> None:
"""get_by_id returns None for non-existent org."""
repo = OrgRepository(db_conn)
result = await repo.get_by_id(uuid4())
assert result is None
async def test_get_by_slug_returns_org(self, db_conn: asyncpg.Connection) -> None:
"""get_by_slug returns the correct organization."""
repo = OrgRepository(db_conn)
org_id = uuid4()
slug = "slug-lookup"
await repo.create(org_id, "Slug Test", slug)
result = await repo.get_by_slug(slug)
assert result is not None
assert result["id"] == org_id
async def test_get_by_slug_returns_none_for_nonexistent(self, db_conn: asyncpg.Connection) -> None:
"""get_by_slug returns None for non-existent slug."""
repo = OrgRepository(db_conn)
result = await repo.get_by_slug("nonexistent-slug")
assert result is None
async def test_slug_exists_returns_true_when_exists(self, db_conn: asyncpg.Connection) -> None:
"""slug_exists returns True when slug exists."""
repo = OrgRepository(db_conn)
slug = "existing-slug"
await repo.create(uuid4(), "Existing Org", slug)
result = await repo.slug_exists(slug)
assert result is True
async def test_slug_exists_returns_false_when_not_exists(self, db_conn: asyncpg.Connection) -> None:
"""slug_exists returns False when slug doesn't exist."""
repo = OrgRepository(db_conn)
result = await repo.slug_exists("no-such-slug")
assert result is False
class TestOrgMembership:
"""Tests for org membership operations per SPECS.md org_members table."""
async def _create_user(self, conn: asyncpg.Connection, email: str) -> uuid4:
"""Helper to create a user."""
user_repo = UserRepository(conn)
user_id = uuid4()
await user_repo.create(user_id, email, "hash")
return user_id
async def _create_org(self, conn: asyncpg.Connection, slug: str) -> uuid4:
"""Helper to create an org."""
org_repo = OrgRepository(conn)
org_id = uuid4()
await org_repo.create(org_id, f"Org {slug}", slug)
return org_id
async def test_add_member_creates_membership(self, db_conn: asyncpg.Connection) -> None:
"""add_member creates a membership record."""
user_id = await self._create_user(db_conn, "member@example.com")
org_id = await self._create_org(db_conn, "member-org")
repo = OrgRepository(db_conn)
result = await repo.add_member(uuid4(), user_id, org_id, "member")
assert result["user_id"] == user_id
assert result["org_id"] == org_id
assert result["role"] == "member"
assert result["created_at"] is not None
async def test_add_member_role_must_be_valid(self, db_conn: asyncpg.Connection) -> None:
"""Role must be admin, member, or viewer per SPECS.md."""
org_id = await self._create_org(db_conn, "role-test-org")
repo = OrgRepository(db_conn)
# Valid roles should work
for role in ["admin", "member", "viewer"]:
member_id = uuid4()
# Need a new user for each since user+org must be unique
new_user_id = await self._create_user(db_conn, f"{role}@example.com")
result = await repo.add_member(member_id, new_user_id, org_id, role)
assert result["role"] == role
# Invalid role should fail
another_user = await self._create_user(db_conn, "invalid_role@example.com")
with pytest.raises(asyncpg.CheckViolationError):
await repo.add_member(uuid4(), another_user, org_id, "superuser")
async def test_add_member_user_org_must_be_unique(self, db_conn: asyncpg.Connection) -> None:
"""User can only be member of an org once (unique constraint)."""
user_id = await self._create_user(db_conn, "unique_member@example.com")
org_id = await self._create_org(db_conn, "unique-member-org")
repo = OrgRepository(db_conn)
await repo.add_member(uuid4(), user_id, org_id, "member")
with pytest.raises(asyncpg.UniqueViolationError):
await repo.add_member(uuid4(), user_id, org_id, "admin")
async def test_get_member_returns_membership(self, db_conn: asyncpg.Connection) -> None:
"""get_member returns the membership for user and org."""
user_id = await self._create_user(db_conn, "get_member@example.com")
org_id = await self._create_org(db_conn, "get-member-org")
repo = OrgRepository(db_conn)
await repo.add_member(uuid4(), user_id, org_id, "admin")
result = await repo.get_member(user_id, org_id)
assert result is not None
assert result["user_id"] == user_id
assert result["org_id"] == org_id
assert result["role"] == "admin"
async def test_get_member_returns_none_for_nonmember(self, db_conn: asyncpg.Connection) -> None:
"""get_member returns None if user is not a member."""
user_id = await self._create_user(db_conn, "nonmember@example.com")
org_id = await self._create_org(db_conn, "nonmember-org")
repo = OrgRepository(db_conn)
result = await repo.get_member(user_id, org_id)
assert result is None
async def test_get_members_returns_all_org_members(self, db_conn: asyncpg.Connection) -> None:
"""get_members returns all members with their emails."""
org_id = await self._create_org(db_conn, "all-members-org")
user1 = await self._create_user(db_conn, "user1@example.com")
user2 = await self._create_user(db_conn, "user2@example.com")
user3 = await self._create_user(db_conn, "user3@example.com")
repo = OrgRepository(db_conn)
await repo.add_member(uuid4(), user1, org_id, "admin")
await repo.add_member(uuid4(), user2, org_id, "member")
await repo.add_member(uuid4(), user3, org_id, "viewer")
result = await repo.get_members(org_id)
assert len(result) == 3
emails = {m["email"] for m in result}
assert emails == {"user1@example.com", "user2@example.com", "user3@example.com"}
async def test_get_members_returns_empty_list_for_no_members(self, db_conn: asyncpg.Connection) -> None:
"""get_members returns empty list for org with no members."""
org_id = await self._create_org(db_conn, "empty-org")
repo = OrgRepository(db_conn)
result = await repo.get_members(org_id)
assert result == []
async def test_get_user_orgs_returns_all_user_memberships(self, db_conn: asyncpg.Connection) -> None:
"""get_user_orgs returns all orgs a user belongs to with their role."""
user_id = await self._create_user(db_conn, "multi_org@example.com")
org1 = await self._create_org(db_conn, "user-org-1")
org2 = await self._create_org(db_conn, "user-org-2")
repo = OrgRepository(db_conn)
await repo.add_member(uuid4(), user_id, org1, "admin")
await repo.add_member(uuid4(), user_id, org2, "member")
result = await repo.get_user_orgs(user_id)
assert len(result) == 2
slugs = {o["slug"] for o in result}
assert slugs == {"user-org-1", "user-org-2"}
# Check role is included
roles = {o["role"] for o in result}
assert roles == {"admin", "member"}
async def test_get_user_orgs_returns_empty_for_no_memberships(self, db_conn: asyncpg.Connection) -> None:
"""get_user_orgs returns empty list for user with no memberships."""
user_id = await self._create_user(db_conn, "no_orgs@example.com")
repo = OrgRepository(db_conn)
result = await repo.get_user_orgs(user_id)
assert result == []
async def test_member_requires_valid_user_foreign_key(self, db_conn: asyncpg.Connection) -> None:
"""org_members.user_id must reference existing user."""
org_id = await self._create_org(db_conn, "fk-test-org")
repo = OrgRepository(db_conn)
with pytest.raises(asyncpg.ForeignKeyViolationError):
await repo.add_member(uuid4(), uuid4(), org_id, "member")
async def test_member_requires_valid_org_foreign_key(self, db_conn: asyncpg.Connection) -> None:
"""org_members.org_id must reference existing org."""
user_id = await self._create_user(db_conn, "fk_user@example.com")
repo = OrgRepository(db_conn)
with pytest.raises(asyncpg.ForeignKeyViolationError):
await repo.add_member(uuid4(), user_id, uuid4(), "member")

View File

@@ -0,0 +1,788 @@
"""Tests for RefreshTokenRepository with security features."""
from datetime import UTC, datetime, timedelta
from uuid import uuid4
import asyncpg
import pytest
from app.repositories.org import OrgRepository
from app.repositories.refresh_token import RefreshTokenRepository
from app.repositories.user import UserRepository
class TestRefreshTokenRepository:
"""Tests for basic RefreshTokenRepository operations."""
async def _create_user(self, conn: asyncpg.Connection, email: str) -> uuid4:
"""Helper to create a user."""
user_repo = UserRepository(conn)
user_id = uuid4()
await user_repo.create(user_id, email, "hash")
return user_id
async def _create_org(self, conn: asyncpg.Connection, slug: str) -> uuid4:
"""Helper to create an org."""
org_repo = OrgRepository(conn)
org_id = uuid4()
await org_repo.create(org_id, f"Org {slug}", slug)
return org_id
async def test_create_token_returns_token_data(self, db_conn: asyncpg.Connection) -> None:
"""Creating a refresh token returns the token data including rotated_to."""
user_id = await self._create_user(db_conn, "token_create@example.com")
org_id = await self._create_org(db_conn, "token-create-org")
repo = RefreshTokenRepository(db_conn)
token_id = uuid4()
token_hash = "sha256_hashed_token_value"
expires_at = datetime.now(UTC) + timedelta(days=30)
result = await repo.create(token_id, user_id, token_hash, org_id, expires_at)
assert result["id"] == token_id
assert result["user_id"] == user_id
assert result["token_hash"] == token_hash
assert result["active_org_id"] == org_id
assert result["expires_at"] is not None
assert result["revoked_at"] is None
assert result["rotated_to"] is None # New field
assert result["created_at"] is not None
async def test_token_hash_must_be_unique(self, db_conn: asyncpg.Connection) -> None:
"""Token hash uniqueness constraint per SPECS.md refresh_tokens table."""
user_id = await self._create_user(db_conn, "unique_hash@example.com")
org_id = await self._create_org(db_conn, "unique-hash-org")
repo = RefreshTokenRepository(db_conn)
token_hash = "duplicate_hash_value"
expires_at = datetime.now(UTC) + timedelta(days=30)
await repo.create(uuid4(), user_id, token_hash, org_id, expires_at)
with pytest.raises(asyncpg.UniqueViolationError):
await repo.create(uuid4(), user_id, token_hash, org_id, expires_at)
async def test_get_by_hash_returns_token(self, db_conn: asyncpg.Connection) -> None:
"""get_by_hash returns the correct token (even if revoked/expired)."""
user_id = await self._create_user(db_conn, "get_hash@example.com")
org_id = await self._create_org(db_conn, "get-hash-org")
repo = RefreshTokenRepository(db_conn)
token_id = uuid4()
token_hash = "lookup_by_hash_value"
expires_at = datetime.now(UTC) + timedelta(days=30)
await repo.create(token_id, user_id, token_hash, org_id, expires_at)
result = await repo.get_by_hash(token_hash)
assert result is not None
assert result["id"] == token_id
assert result["token_hash"] == token_hash
async def test_get_by_hash_returns_none_for_nonexistent(self, db_conn: asyncpg.Connection) -> None:
"""get_by_hash returns None for non-existent hash."""
repo = RefreshTokenRepository(db_conn)
result = await repo.get_by_hash("nonexistent_hash")
assert result is None
class TestGetValidByHash:
"""Tests for get_valid_by_hash with defense-in-depth validation."""
async def _setup_token(
self, conn: asyncpg.Connection, suffix: str = ""
) -> tuple[uuid4, uuid4, uuid4, str, RefreshTokenRepository]:
"""Helper to create user, org, and token."""
user_repo = UserRepository(conn)
org_repo = OrgRepository(conn)
token_repo = RefreshTokenRepository(conn)
user_id = uuid4()
org_id = uuid4()
token_id = uuid4()
token_hash = f"token_hash_{uuid4().hex[:8]}{suffix}"
await user_repo.create(user_id, f"user_{uuid4().hex[:8]}@example.com", "hash")
await org_repo.create(org_id, "Test Org", f"test-org-{uuid4().hex[:8]}")
expires_at = datetime.now(UTC) + timedelta(days=30)
await token_repo.create(token_id, user_id, token_hash, org_id, expires_at)
return token_id, user_id, org_id, token_hash, token_repo
async def test_get_valid_returns_valid_token(self, db_conn: asyncpg.Connection) -> None:
"""get_valid_by_hash returns token if not expired, not revoked, not rotated."""
_, _, _, token_hash, repo = await self._setup_token(db_conn)
result = await repo.get_valid_by_hash(token_hash)
assert result is not None
assert result["token_hash"] == token_hash
async def test_get_valid_returns_none_for_expired(self, db_conn: asyncpg.Connection) -> None:
"""get_valid_by_hash returns None for expired token."""
user_repo = UserRepository(db_conn)
org_repo = OrgRepository(db_conn)
repo = RefreshTokenRepository(db_conn)
user_id = uuid4()
org_id = uuid4()
await user_repo.create(user_id, "expired@example.com", "hash")
await org_repo.create(org_id, "Org", "expired-org")
token_hash = "expired_token_hash"
expires_at = datetime.now(UTC) - timedelta(days=1) # Already expired
await repo.create(uuid4(), user_id, token_hash, org_id, expires_at)
result = await repo.get_valid_by_hash(token_hash)
assert result is None
async def test_get_valid_returns_none_for_revoked(self, db_conn: asyncpg.Connection) -> None:
"""get_valid_by_hash returns None for revoked token."""
token_id, _, _, token_hash, repo = await self._setup_token(db_conn, "_revoked")
await repo.revoke(token_id)
result = await repo.get_valid_by_hash(token_hash)
assert result is None
async def test_get_valid_returns_none_for_rotated(self, db_conn: asyncpg.Connection) -> None:
"""get_valid_by_hash returns None for already-rotated token."""
_, user_id, org_id, old_hash, repo = await self._setup_token(db_conn, "_rotated")
# Rotate the token
new_hash = f"new_token_{uuid4().hex[:8]}"
new_expires = datetime.now(UTC) + timedelta(days=30)
await repo.rotate(old_hash, uuid4(), new_hash, new_expires)
# Old token should no longer be valid
result = await repo.get_valid_by_hash(old_hash)
assert result is None
async def test_get_valid_with_user_id_validation(self, db_conn: asyncpg.Connection) -> None:
"""get_valid_by_hash validates user_id when provided (defense-in-depth)."""
_, user_id, _, token_hash, repo = await self._setup_token(db_conn, "_user_check")
# Correct user_id should work
result = await repo.get_valid_by_hash(token_hash, user_id=user_id)
assert result is not None
# Wrong user_id should return None
wrong_user_id = uuid4()
result = await repo.get_valid_by_hash(token_hash, user_id=wrong_user_id)
assert result is None
async def test_get_valid_with_org_id_validation(self, db_conn: asyncpg.Connection) -> None:
"""get_valid_by_hash validates active_org_id when provided (defense-in-depth)."""
_, _, org_id, token_hash, repo = await self._setup_token(db_conn, "_org_check")
# Correct org_id should work
result = await repo.get_valid_by_hash(token_hash, active_org_id=org_id)
assert result is not None
# Wrong org_id should return None
wrong_org_id = uuid4()
result = await repo.get_valid_by_hash(token_hash, active_org_id=wrong_org_id)
assert result is None
async def test_get_valid_with_both_user_and_org_validation(self, db_conn: asyncpg.Connection) -> None:
"""get_valid_by_hash validates both user_id and active_org_id together."""
_, user_id, org_id, token_hash, repo = await self._setup_token(db_conn, "_both")
# Both correct should work
result = await repo.get_valid_by_hash(token_hash, user_id=user_id, active_org_id=org_id)
assert result is not None
# Either wrong should fail
result = await repo.get_valid_by_hash(token_hash, user_id=uuid4(), active_org_id=org_id)
assert result is None
result = await repo.get_valid_by_hash(token_hash, user_id=user_id, active_org_id=uuid4())
assert result is None
class TestAtomicRotation:
"""Tests for atomic token rotation per SPECS.md."""
async def _setup_token(
self, conn: asyncpg.Connection
) -> tuple[uuid4, uuid4, uuid4, str, RefreshTokenRepository]:
"""Helper to create user, org, and token."""
user_repo = UserRepository(conn)
org_repo = OrgRepository(conn)
token_repo = RefreshTokenRepository(conn)
user_id = uuid4()
org_id = uuid4()
token_id = uuid4()
token_hash = f"rotate_token_{uuid4().hex[:8]}"
await user_repo.create(user_id, f"rotate_{uuid4().hex[:8]}@example.com", "hash")
await org_repo.create(org_id, "Rotate Org", f"rotate-org-{uuid4().hex[:8]}")
expires_at = datetime.now(UTC) + timedelta(days=30)
await token_repo.create(token_id, user_id, token_hash, org_id, expires_at)
return token_id, user_id, org_id, token_hash, token_repo
async def test_rotate_creates_new_token(self, db_conn: asyncpg.Connection) -> None:
"""rotate() creates a new token and returns it."""
old_id, user_id, org_id, old_hash, repo = await self._setup_token(db_conn)
new_id = uuid4()
new_hash = f"new_rotated_{uuid4().hex[:8]}"
new_expires = datetime.now(UTC) + timedelta(days=30)
result = await repo.rotate(old_hash, new_id, new_hash, new_expires)
assert result is not None
assert result["id"] == new_id
assert result["token_hash"] == new_hash
assert result["user_id"] == user_id
assert result["active_org_id"] == org_id
async def test_rotate_marks_old_token_as_rotated(self, db_conn: asyncpg.Connection) -> None:
"""rotate() sets rotated_to on the old token (not revoked_at)."""
old_id, _, _, old_hash, repo = await self._setup_token(db_conn)
new_id = uuid4()
new_hash = f"new_{uuid4().hex[:8]}"
new_expires = datetime.now(UTC) + timedelta(days=30)
await repo.rotate(old_hash, new_id, new_hash, new_expires)
# Check old token state
old_token = await repo.get_by_hash(old_hash)
assert old_token["rotated_to"] == new_id
assert old_token["revoked_at"] is None # Not revoked, just rotated
async def test_rotate_fails_for_invalid_token(self, db_conn: asyncpg.Connection) -> None:
"""rotate() returns None if old token is invalid."""
_, _, _, _, repo = await self._setup_token(db_conn)
result = await repo.rotate(
"nonexistent_hash",
uuid4(),
f"new_{uuid4().hex[:8]}",
datetime.now(UTC) + timedelta(days=30),
)
assert result is None
async def test_rotate_fails_for_expired_token(self, db_conn: asyncpg.Connection) -> None:
"""rotate() returns None if old token is expired."""
user_repo = UserRepository(db_conn)
org_repo = OrgRepository(db_conn)
repo = RefreshTokenRepository(db_conn)
user_id = uuid4()
org_id = uuid4()
await user_repo.create(user_id, "exp_rotate@example.com", "hash")
await org_repo.create(org_id, "Org", "exp-rotate-org")
old_hash = "expired_for_rotation"
await repo.create(
uuid4(), user_id, old_hash, org_id,
datetime.now(UTC) - timedelta(days=1) # Already expired
)
result = await repo.rotate(
old_hash, uuid4(), f"new_{uuid4().hex[:8]}",
datetime.now(UTC) + timedelta(days=30)
)
assert result is None
async def test_rotate_fails_for_revoked_token(self, db_conn: asyncpg.Connection) -> None:
"""rotate() returns None if old token is revoked."""
old_id, _, _, old_hash, repo = await self._setup_token(db_conn)
await repo.revoke(old_id)
result = await repo.rotate(
old_hash, uuid4(), f"new_{uuid4().hex[:8]}",
datetime.now(UTC) + timedelta(days=30)
)
assert result is None
async def test_rotate_fails_for_already_rotated_token(self, db_conn: asyncpg.Connection) -> None:
"""rotate() returns None if old token was already rotated."""
_, _, _, old_hash, repo = await self._setup_token(db_conn)
# First rotation should succeed
result1 = await repo.rotate(
old_hash, uuid4(), f"new1_{uuid4().hex[:8]}",
datetime.now(UTC) + timedelta(days=30)
)
assert result1 is not None
# Second rotation of same token should fail
result2 = await repo.rotate(
old_hash, uuid4(), f"new2_{uuid4().hex[:8]}",
datetime.now(UTC) + timedelta(days=30)
)
assert result2 is None
async def test_rotate_with_org_switch(self, db_conn: asyncpg.Connection) -> None:
"""rotate() can change active_org_id (for org-switch flow)."""
_, user_id, old_org_id, old_hash, repo = await self._setup_token(db_conn)
# Create a new org for the user to switch to
org_repo = OrgRepository(db_conn)
new_org_id = uuid4()
await org_repo.create(new_org_id, "New Org", f"new-org-{uuid4().hex[:8]}")
new_hash = f"switched_{uuid4().hex[:8]}"
result = await repo.rotate(
old_hash, uuid4(), new_hash,
datetime.now(UTC) + timedelta(days=30),
new_active_org_id=new_org_id # Switch org
)
assert result is not None
assert result["active_org_id"] == new_org_id
assert result["active_org_id"] != old_org_id
async def test_rotate_validates_expected_user_id(self, db_conn: asyncpg.Connection) -> None:
"""rotate() fails if expected_user_id doesn't match token's user."""
_, user_id, _, old_hash, repo = await self._setup_token(db_conn)
# Wrong user should fail
result = await repo.rotate(
old_hash, uuid4(), f"new_{uuid4().hex[:8]}",
datetime.now(UTC) + timedelta(days=30),
expected_user_id=uuid4() # Wrong user
)
assert result is None
# Correct user should work
result = await repo.rotate(
old_hash, uuid4(), f"new_{uuid4().hex[:8]}",
datetime.now(UTC) + timedelta(days=30),
expected_user_id=user_id # Correct user
)
assert result is not None
class TestTokenReuseDetection:
"""Tests for detecting token reuse (stolen token attacks)."""
async def _setup_rotated_token(
self, conn: asyncpg.Connection
) -> tuple[uuid4, str, str, RefreshTokenRepository]:
"""Create a token and rotate it, returning old and new hashes."""
user_repo = UserRepository(conn)
org_repo = OrgRepository(conn)
token_repo = RefreshTokenRepository(conn)
user_id = uuid4()
org_id = uuid4()
await user_repo.create(user_id, f"reuse_{uuid4().hex[:8]}@example.com", "hash")
await org_repo.create(org_id, "Reuse Org", f"reuse-org-{uuid4().hex[:8]}")
old_hash = f"old_token_{uuid4().hex[:8]}"
expires_at = datetime.now(UTC) + timedelta(days=30)
old_token = await token_repo.create(uuid4(), user_id, old_hash, org_id, expires_at)
new_hash = f"new_token_{uuid4().hex[:8]}"
await token_repo.rotate(old_hash, uuid4(), new_hash, expires_at)
return old_token["id"], old_hash, new_hash, token_repo
async def test_check_token_reuse_detects_rotated_token(self, db_conn: asyncpg.Connection) -> None:
"""check_token_reuse returns token if it has been rotated."""
old_id, old_hash, _, repo = await self._setup_rotated_token(db_conn)
result = await repo.check_token_reuse(old_hash)
assert result is not None
assert result["id"] == old_id
assert result["rotated_to"] is not None
async def test_check_token_reuse_returns_none_for_active_token(self, db_conn: asyncpg.Connection) -> None:
"""check_token_reuse returns None for token that hasn't been rotated."""
user_repo = UserRepository(db_conn)
org_repo = OrgRepository(db_conn)
repo = RefreshTokenRepository(db_conn)
user_id = uuid4()
org_id = uuid4()
await user_repo.create(user_id, "active@example.com", "hash")
await org_repo.create(org_id, "Org", "active-org")
token_hash = "active_token_hash"
await repo.create(
uuid4(), user_id, token_hash, org_id,
datetime.now(UTC) + timedelta(days=30)
)
result = await repo.check_token_reuse(token_hash)
assert result is None
async def test_check_token_reuse_returns_none_for_nonexistent(self, db_conn: asyncpg.Connection) -> None:
"""check_token_reuse returns None for non-existent token."""
repo = RefreshTokenRepository(db_conn)
result = await repo.check_token_reuse("nonexistent_hash")
assert result is None
class TestTokenChainRevocation:
"""Tests for revoking entire token chains (breach response)."""
async def _setup_token_chain(
self, conn: asyncpg.Connection, chain_length: int = 3
) -> tuple[list[uuid4], list[str], uuid4, RefreshTokenRepository]:
"""Create a chain of rotated tokens."""
user_repo = UserRepository(conn)
org_repo = OrgRepository(conn)
token_repo = RefreshTokenRepository(conn)
user_id = uuid4()
org_id = uuid4()
await user_repo.create(user_id, f"chain_{uuid4().hex[:8]}@example.com", "hash")
await org_repo.create(org_id, "Chain Org", f"chain-org-{uuid4().hex[:8]}")
token_ids = []
token_hashes = []
expires_at = datetime.now(UTC) + timedelta(days=30)
# Create first token
first_hash = f"chain_token_0_{uuid4().hex[:8]}"
first_token = await token_repo.create(uuid4(), user_id, first_hash, org_id, expires_at)
token_ids.append(first_token["id"])
token_hashes.append(first_hash)
# Rotate to create chain
current_hash = first_hash
for i in range(1, chain_length):
new_hash = f"chain_token_{i}_{uuid4().hex[:8]}"
new_id = uuid4()
await token_repo.rotate(current_hash, new_id, new_hash, expires_at)
token_ids.append(new_id)
token_hashes.append(new_hash)
current_hash = new_hash
return token_ids, token_hashes, user_id, token_repo
async def test_revoke_token_chain_revokes_all_in_chain(self, db_conn: asyncpg.Connection) -> None:
"""revoke_token_chain revokes the token and all its rotations."""
token_ids, token_hashes, _, repo = await self._setup_token_chain(db_conn, chain_length=3)
# Revoke starting from the first token
count = await repo.revoke_token_chain(token_ids[0])
# Should revoke all 3 tokens in the chain
# But note: only the last one wasn't already "consumed" by rotation
# Let's check that revoke was called on all that were eligible
assert count >= 1 # At least the leaf token
# Verify the leaf token is revoked
leaf_token = await repo.get_by_hash(token_hashes[-1])
assert leaf_token["revoked_at"] is not None
async def test_revoke_token_chain_returns_count(self, db_conn: asyncpg.Connection) -> None:
"""revoke_token_chain returns count of actually revoked tokens."""
user_repo = UserRepository(db_conn)
org_repo = OrgRepository(db_conn)
repo = RefreshTokenRepository(db_conn)
user_id = uuid4()
org_id = uuid4()
await user_repo.create(user_id, "single@example.com", "hash")
await org_repo.create(org_id, "Single Org", "single-org")
token_hash = "single_chain_token"
token = await repo.create(
uuid4(), user_id, token_hash, org_id,
datetime.now(UTC) + timedelta(days=30)
)
count = await repo.revoke_token_chain(token["id"])
assert count == 1
class TestTokenRevocation:
"""Tests for token revocation methods."""
async def _setup_token(self, conn: asyncpg.Connection) -> tuple[uuid4, str, RefreshTokenRepository]:
"""Helper to create user, org, and token."""
user_repo = UserRepository(conn)
org_repo = OrgRepository(conn)
token_repo = RefreshTokenRepository(conn)
user_id = uuid4()
org_id = uuid4()
token_id = uuid4()
token_hash = f"revoke_token_{uuid4().hex[:8]}"
await user_repo.create(user_id, f"revoke_{uuid4().hex[:8]}@example.com", "hash")
await org_repo.create(org_id, "Revoke Org", f"revoke-org-{uuid4().hex[:8]}")
expires_at = datetime.now(UTC) + timedelta(days=30)
await token_repo.create(token_id, user_id, token_hash, org_id, expires_at)
return token_id, token_hash, token_repo
async def test_revoke_sets_revoked_at(self, db_conn: asyncpg.Connection) -> None:
"""revoke() sets the revoked_at timestamp."""
token_id, token_hash, repo = await self._setup_token(db_conn)
result = await repo.revoke(token_id)
assert result is True
token = await repo.get_by_hash(token_hash)
assert token["revoked_at"] is not None
async def test_revoke_returns_true_on_success(self, db_conn: asyncpg.Connection) -> None:
"""revoke() returns True when token is revoked."""
token_id, _, repo = await self._setup_token(db_conn)
result = await repo.revoke(token_id)
assert result is True
async def test_revoke_returns_false_for_already_revoked(self, db_conn: asyncpg.Connection) -> None:
"""revoke() returns False if token already revoked."""
token_id, _, repo = await self._setup_token(db_conn)
await repo.revoke(token_id)
result = await repo.revoke(token_id)
assert result is False
async def test_revoke_returns_false_for_nonexistent(self, db_conn: asyncpg.Connection) -> None:
"""revoke() returns False for non-existent token."""
_, _, repo = await self._setup_token(db_conn)
result = await repo.revoke(uuid4())
assert result is False
async def test_revoke_by_hash_works(self, db_conn: asyncpg.Connection) -> None:
"""revoke_by_hash() revokes token by hash value."""
_, token_hash, repo = await self._setup_token(db_conn)
result = await repo.revoke_by_hash(token_hash)
assert result is True
token = await repo.get_by_hash(token_hash)
assert token["revoked_at"] is not None
async def test_revoke_by_hash_returns_false_for_nonexistent(self, db_conn: asyncpg.Connection) -> None:
"""revoke_by_hash() returns False for non-existent hash."""
_, _, repo = await self._setup_token(db_conn)
result = await repo.revoke_by_hash("nonexistent_hash")
assert result is False
class TestRevokeAllForUser:
"""Tests for revoking all tokens for a user."""
async def test_revoke_all_for_user_revokes_all_tokens(self, db_conn: asyncpg.Connection) -> None:
"""revoke_all_for_user() revokes all tokens for the user."""
user_repo = UserRepository(db_conn)
org_repo = OrgRepository(db_conn)
token_repo = RefreshTokenRepository(db_conn)
user_id = uuid4()
org_id = uuid4()
await user_repo.create(user_id, "multi_token@example.com", "hash")
await org_repo.create(org_id, "Multi Token Org", "multi-token-org")
# Create multiple tokens
hashes = []
for i in range(3):
token_hash = f"token_{i}_{uuid4().hex[:8]}"
hashes.append(token_hash)
expires_at = datetime.now(UTC) + timedelta(days=30)
await token_repo.create(uuid4(), user_id, token_hash, org_id, expires_at)
result = await token_repo.revoke_all_for_user(user_id)
assert result == 3
for token_hash in hashes:
token = await token_repo.get_valid_by_hash(token_hash)
assert token is None
async def test_revoke_all_for_user_returns_zero_for_no_tokens(self, db_conn: asyncpg.Connection) -> None:
"""revoke_all_for_user() returns 0 if user has no tokens."""
user_repo = UserRepository(db_conn)
token_repo = RefreshTokenRepository(db_conn)
user_id = uuid4()
await user_repo.create(user_id, "no_tokens@example.com", "hash")
result = await token_repo.revoke_all_for_user(user_id)
assert result == 0
async def test_revoke_all_for_user_only_affects_user_tokens(self, db_conn: asyncpg.Connection) -> None:
"""revoke_all_for_user() doesn't affect other users' tokens."""
user_repo = UserRepository(db_conn)
org_repo = OrgRepository(db_conn)
token_repo = RefreshTokenRepository(db_conn)
user1 = uuid4()
user2 = uuid4()
org_id = uuid4()
await user_repo.create(user1, "user1@example.com", "hash")
await user_repo.create(user2, "user2@example.com", "hash")
await org_repo.create(org_id, "Shared Org", "shared-org")
user1_hash = f"user1_token_{uuid4().hex[:8]}"
user2_hash = f"user2_token_{uuid4().hex[:8]}"
expires_at = datetime.now(UTC) + timedelta(days=30)
await token_repo.create(uuid4(), user1, user1_hash, org_id, expires_at)
await token_repo.create(uuid4(), user2, user2_hash, org_id, expires_at)
await token_repo.revoke_all_for_user(user1)
# User1's token is revoked
assert await token_repo.get_valid_by_hash(user1_hash) is None
# User2's token is still valid
assert await token_repo.get_valid_by_hash(user2_hash) is not None
class TestRevokeAllExcept:
"""Tests for revoking all tokens except current session."""
async def test_revoke_all_except_keeps_specified_token(self, db_conn: asyncpg.Connection) -> None:
"""revoke_all_for_user_except() keeps the specified token active."""
user_repo = UserRepository(db_conn)
org_repo = OrgRepository(db_conn)
token_repo = RefreshTokenRepository(db_conn)
user_id = uuid4()
org_id = uuid4()
await user_repo.create(user_id, "except@example.com", "hash")
await org_repo.create(org_id, "Except Org", "except-org")
# Create multiple tokens
expires_at = datetime.now(UTC) + timedelta(days=30)
keep_token_id = uuid4()
keep_hash = f"keep_token_{uuid4().hex[:8]}"
await token_repo.create(keep_token_id, user_id, keep_hash, org_id, expires_at)
other_hashes = []
for i in range(2):
other_hash = f"other_token_{i}_{uuid4().hex[:8]}"
other_hashes.append(other_hash)
await token_repo.create(uuid4(), user_id, other_hash, org_id, expires_at)
result = await token_repo.revoke_all_for_user_except(user_id, keep_token_id)
assert result == 2 # Revoked 2 other tokens
# Keep token is still valid
assert await token_repo.get_valid_by_hash(keep_hash) is not None
# Other tokens are revoked
for other_hash in other_hashes:
assert await token_repo.get_valid_by_hash(other_hash) is None
class TestActiveTokensForUser:
"""Tests for listing active tokens for a user."""
async def test_get_active_tokens_returns_only_active(self, db_conn: asyncpg.Connection) -> None:
"""get_active_tokens_for_user() returns only non-revoked, non-expired, non-rotated."""
user_repo = UserRepository(db_conn)
org_repo = OrgRepository(db_conn)
token_repo = RefreshTokenRepository(db_conn)
user_id = uuid4()
org_id = uuid4()
await user_repo.create(user_id, "active_list@example.com", "hash")
await org_repo.create(org_id, "Active List Org", "active-list-org")
expires_at = datetime.now(UTC) + timedelta(days=30)
expired_at = datetime.now(UTC) - timedelta(days=1)
# Create active token
active_hash = f"active_{uuid4().hex[:8]}"
await token_repo.create(uuid4(), user_id, active_hash, org_id, expires_at)
# Create revoked token
revoked_id = uuid4()
revoked_hash = f"revoked_{uuid4().hex[:8]}"
await token_repo.create(revoked_id, user_id, revoked_hash, org_id, expires_at)
await token_repo.revoke(revoked_id)
# Create expired token
expired_hash = f"expired_{uuid4().hex[:8]}"
await token_repo.create(uuid4(), user_id, expired_hash, org_id, expired_at)
# Create rotated token
rotated_hash = f"rotated_{uuid4().hex[:8]}"
await token_repo.create(uuid4(), user_id, rotated_hash, org_id, expires_at)
await token_repo.rotate(rotated_hash, uuid4(), f"new_{uuid4().hex[:8]}", expires_at)
result = await token_repo.get_active_tokens_for_user(user_id)
# Should only return the active token and the new rotated token
assert len(result) == 2
hashes = {t["token_hash"] for t in result}
assert active_hash in hashes
assert revoked_hash not in hashes
assert expired_hash not in hashes
assert rotated_hash not in hashes
class TestTokenForeignKeys:
"""Tests for refresh token foreign key constraints."""
async def test_token_requires_valid_user_foreign_key(self, db_conn: asyncpg.Connection) -> None:
"""refresh_tokens.user_id must reference existing user."""
org_repo = OrgRepository(db_conn)
token_repo = RefreshTokenRepository(db_conn)
org_id = uuid4()
await org_repo.create(org_id, "FK Test Org", "fk-test-org")
with pytest.raises(asyncpg.ForeignKeyViolationError):
await token_repo.create(
uuid4(), uuid4(), "orphan_token", org_id,
datetime.now(UTC) + timedelta(days=30)
)
async def test_token_requires_valid_org_foreign_key(self, db_conn: asyncpg.Connection) -> None:
"""refresh_tokens.active_org_id must reference existing org."""
user_repo = UserRepository(db_conn)
token_repo = RefreshTokenRepository(db_conn)
user_id = uuid4()
await user_repo.create(user_id, "fk_org_test@example.com", "hash")
with pytest.raises(asyncpg.ForeignKeyViolationError):
await token_repo.create(
uuid4(), user_id, "orphan_org_token", uuid4(),
datetime.now(UTC) + timedelta(days=30)
)
async def test_token_stores_active_org_id(self, db_conn: asyncpg.Connection) -> None:
"""Token stores active_org_id for org context per SPECS.md."""
user_repo = UserRepository(db_conn)
org_repo = OrgRepository(db_conn)
token_repo = RefreshTokenRepository(db_conn)
user_id = uuid4()
org_id = uuid4()
await user_repo.create(user_id, "active_org@example.com", "hash")
await org_repo.create(org_id, "Active Org", "active-org")
token_hash = f"active_org_token_{uuid4().hex[:8]}"
await token_repo.create(
uuid4(), user_id, token_hash, org_id,
datetime.now(UTC) + timedelta(days=30)
)
token = await token_repo.get_by_hash(token_hash)
assert token["active_org_id"] == org_id

View File

@@ -0,0 +1,201 @@
"""Tests for ServiceRepository."""
from uuid import uuid4
import asyncpg
import pytest
from app.repositories.org import OrgRepository
from app.repositories.service import ServiceRepository
class TestServiceRepository:
"""Tests for ServiceRepository conforming to SPECS.md."""
async def _create_org(self, conn: asyncpg.Connection, slug: str) -> uuid4:
"""Helper to create an org."""
org_repo = OrgRepository(conn)
org_id = uuid4()
await org_repo.create(org_id, f"Org {slug}", slug)
return org_id
async def test_create_service_returns_service_data(self, db_conn: asyncpg.Connection) -> None:
"""Creating a service returns the service data."""
org_id = await self._create_org(db_conn, "service-org")
repo = ServiceRepository(db_conn)
service_id = uuid4()
result = await repo.create(service_id, org_id, "API Gateway", "api-gateway")
assert result["id"] == service_id
assert result["org_id"] == org_id
assert result["name"] == "API Gateway"
assert result["slug"] == "api-gateway"
assert result["created_at"] is not None
async def test_create_service_slug_unique_per_org(self, db_conn: asyncpg.Connection) -> None:
"""Service slug must be unique within an org per SPECS.md."""
org_id = await self._create_org(db_conn, "unique-slug-org")
repo = ServiceRepository(db_conn)
await repo.create(uuid4(), org_id, "Service One", "my-service")
with pytest.raises(asyncpg.UniqueViolationError):
await repo.create(uuid4(), org_id, "Service Two", "my-service")
async def test_same_slug_allowed_in_different_orgs(self, db_conn: asyncpg.Connection) -> None:
"""Same slug can exist in different orgs."""
org1 = await self._create_org(db_conn, "org-one")
org2 = await self._create_org(db_conn, "org-two")
repo = ServiceRepository(db_conn)
slug = "shared-slug"
# Both should succeed
result1 = await repo.create(uuid4(), org1, "Service Org1", slug)
result2 = await repo.create(uuid4(), org2, "Service Org2", slug)
assert result1["slug"] == slug
assert result2["slug"] == slug
assert result1["org_id"] != result2["org_id"]
async def test_get_by_id_returns_service(self, db_conn: asyncpg.Connection) -> None:
"""get_by_id returns the correct service."""
org_id = await self._create_org(db_conn, "get-service-org")
repo = ServiceRepository(db_conn)
service_id = uuid4()
await repo.create(service_id, org_id, "My Service", "my-service")
result = await repo.get_by_id(service_id)
assert result is not None
assert result["id"] == service_id
assert result["name"] == "My Service"
async def test_get_by_id_returns_none_for_nonexistent(self, db_conn: asyncpg.Connection) -> None:
"""get_by_id returns None for non-existent service."""
repo = ServiceRepository(db_conn)
result = await repo.get_by_id(uuid4())
assert result is None
async def test_get_by_org_returns_all_org_services(self, db_conn: asyncpg.Connection) -> None:
"""get_by_org returns all services for an organization."""
org_id = await self._create_org(db_conn, "multi-service-org")
repo = ServiceRepository(db_conn)
await repo.create(uuid4(), org_id, "Service A", "service-a")
await repo.create(uuid4(), org_id, "Service B", "service-b")
await repo.create(uuid4(), org_id, "Service C", "service-c")
result = await repo.get_by_org(org_id)
assert len(result) == 3
names = {s["name"] for s in result}
assert names == {"Service A", "Service B", "Service C"}
async def test_get_by_org_returns_empty_for_no_services(self, db_conn: asyncpg.Connection) -> None:
"""get_by_org returns empty list for org with no services."""
org_id = await self._create_org(db_conn, "empty-service-org")
repo = ServiceRepository(db_conn)
result = await repo.get_by_org(org_id)
assert result == []
async def test_get_by_org_only_returns_own_services(self, db_conn: asyncpg.Connection) -> None:
"""get_by_org doesn't return services from other orgs (tenant isolation)."""
org1 = await self._create_org(db_conn, "isolated-org-1")
org2 = await self._create_org(db_conn, "isolated-org-2")
repo = ServiceRepository(db_conn)
await repo.create(uuid4(), org1, "Org1 Service", "org1-service")
await repo.create(uuid4(), org2, "Org2 Service", "org2-service")
result = await repo.get_by_org(org1)
assert len(result) == 1
assert result[0]["name"] == "Org1 Service"
async def test_get_by_slug_returns_service(self, db_conn: asyncpg.Connection) -> None:
"""get_by_slug returns service by org and slug."""
org_id = await self._create_org(db_conn, "slug-lookup-org")
repo = ServiceRepository(db_conn)
service_id = uuid4()
await repo.create(service_id, org_id, "Slug Service", "slug-service")
result = await repo.get_by_slug(org_id, "slug-service")
assert result is not None
assert result["id"] == service_id
async def test_get_by_slug_returns_none_for_wrong_org(self, db_conn: asyncpg.Connection) -> None:
"""get_by_slug returns None if slug exists but in different org."""
org1 = await self._create_org(db_conn, "slug-org-1")
org2 = await self._create_org(db_conn, "slug-org-2")
repo = ServiceRepository(db_conn)
await repo.create(uuid4(), org1, "Service", "the-slug")
result = await repo.get_by_slug(org2, "the-slug")
assert result is None
async def test_get_by_slug_returns_none_for_nonexistent(self, db_conn: asyncpg.Connection) -> None:
"""get_by_slug returns None for non-existent slug."""
org_id = await self._create_org(db_conn, "no-slug-org")
repo = ServiceRepository(db_conn)
result = await repo.get_by_slug(org_id, "nonexistent")
assert result is None
async def test_slug_exists_returns_true_when_exists(self, db_conn: asyncpg.Connection) -> None:
"""slug_exists returns True when slug exists in org."""
org_id = await self._create_org(db_conn, "exists-org")
repo = ServiceRepository(db_conn)
await repo.create(uuid4(), org_id, "Exists Service", "exists-slug")
result = await repo.slug_exists(org_id, "exists-slug")
assert result is True
async def test_slug_exists_returns_false_when_not_exists(self, db_conn: asyncpg.Connection) -> None:
"""slug_exists returns False when slug doesn't exist in org."""
org_id = await self._create_org(db_conn, "not-exists-org")
repo = ServiceRepository(db_conn)
result = await repo.slug_exists(org_id, "no-such-slug")
assert result is False
async def test_slug_exists_returns_false_for_other_org(self, db_conn: asyncpg.Connection) -> None:
"""slug_exists returns False for slug in different org."""
org1 = await self._create_org(db_conn, "other-org-1")
org2 = await self._create_org(db_conn, "other-org-2")
repo = ServiceRepository(db_conn)
await repo.create(uuid4(), org1, "Service", "cross-org-slug")
result = await repo.slug_exists(org2, "cross-org-slug")
assert result is False
async def test_service_requires_valid_org_foreign_key(self, db_conn: asyncpg.Connection) -> None:
"""services.org_id must reference existing org."""
repo = ServiceRepository(db_conn)
with pytest.raises(asyncpg.ForeignKeyViolationError):
await repo.create(uuid4(), uuid4(), "Orphan Service", "orphan")
async def test_get_by_org_orders_by_name(self, db_conn: asyncpg.Connection) -> None:
"""get_by_org returns services ordered by name."""
org_id = await self._create_org(db_conn, "ordered-org")
repo = ServiceRepository(db_conn)
await repo.create(uuid4(), org_id, "Zebra", "zebra")
await repo.create(uuid4(), org_id, "Alpha", "alpha")
await repo.create(uuid4(), org_id, "Middle", "middle")
result = await repo.get_by_org(org_id)
names = [s["name"] for s in result]
assert names == ["Alpha", "Middle", "Zebra"]

View File

@@ -0,0 +1,133 @@
"""Tests for UserRepository."""
from uuid import uuid4
import asyncpg
import pytest
from app.repositories.user import UserRepository
class TestUserRepository:
"""Tests for UserRepository conforming to SPECS.md."""
async def test_create_user_returns_user_data(self, db_conn: asyncpg.Connection) -> None:
"""Creating a user returns the user data with id, email, created_at."""
repo = UserRepository(db_conn)
user_id = uuid4()
email = "test@example.com"
password_hash = "hashed_password_123"
result = await repo.create(user_id, email, password_hash)
assert result["id"] == user_id
assert result["email"] == email
assert result["created_at"] is not None
async def test_create_user_stores_password_hash(self, db_conn: asyncpg.Connection) -> None:
"""Password hash is stored correctly in the database."""
repo = UserRepository(db_conn)
user_id = uuid4()
email = "hash_test@example.com"
password_hash = "bcrypt_hashed_value"
await repo.create(user_id, email, password_hash)
user = await repo.get_by_id(user_id)
assert user["password_hash"] == password_hash
async def test_create_user_email_must_be_unique(self, db_conn: asyncpg.Connection) -> None:
"""Email uniqueness constraint per SPECS.md users table."""
repo = UserRepository(db_conn)
email = "duplicate@example.com"
await repo.create(uuid4(), email, "hash1")
with pytest.raises(asyncpg.UniqueViolationError):
await repo.create(uuid4(), email, "hash2")
async def test_get_by_id_returns_user(self, db_conn: asyncpg.Connection) -> None:
"""get_by_id returns the correct user."""
repo = UserRepository(db_conn)
user_id = uuid4()
email = "getbyid@example.com"
await repo.create(user_id, email, "hash")
result = await repo.get_by_id(user_id)
assert result is not None
assert result["id"] == user_id
assert result["email"] == email
async def test_get_by_id_returns_none_for_nonexistent(self, db_conn: asyncpg.Connection) -> None:
"""get_by_id returns None for non-existent user."""
repo = UserRepository(db_conn)
result = await repo.get_by_id(uuid4())
assert result is None
async def test_get_by_email_returns_user(self, db_conn: asyncpg.Connection) -> None:
"""get_by_email returns the correct user."""
repo = UserRepository(db_conn)
user_id = uuid4()
email = "getbyemail@example.com"
await repo.create(user_id, email, "hash")
result = await repo.get_by_email(email)
assert result is not None
assert result["id"] == user_id
assert result["email"] == email
async def test_get_by_email_returns_none_for_nonexistent(self, db_conn: asyncpg.Connection) -> None:
"""get_by_email returns None for non-existent email."""
repo = UserRepository(db_conn)
result = await repo.get_by_email("nonexistent@example.com")
assert result is None
async def test_get_by_email_is_case_sensitive(self, db_conn: asyncpg.Connection) -> None:
"""Email lookup is case-sensitive (stored as provided)."""
repo = UserRepository(db_conn)
email = "CaseSensitive@Example.com"
await repo.create(uuid4(), email, "hash")
# Exact match works
result = await repo.get_by_email(email)
assert result is not None
# Different case returns None
result = await repo.get_by_email(email.lower())
assert result is None
async def test_exists_by_email_returns_true_when_exists(self, db_conn: asyncpg.Connection) -> None:
"""exists_by_email returns True when email exists."""
repo = UserRepository(db_conn)
email = "exists@example.com"
await repo.create(uuid4(), email, "hash")
result = await repo.exists_by_email(email)
assert result is True
async def test_exists_by_email_returns_false_when_not_exists(self, db_conn: asyncpg.Connection) -> None:
"""exists_by_email returns False when email doesn't exist."""
repo = UserRepository(db_conn)
result = await repo.exists_by_email("notexists@example.com")
assert result is False
async def test_user_id_is_uuid_primary_key(self, db_conn: asyncpg.Connection) -> None:
"""User ID must be a valid UUID (primary key)."""
repo = UserRepository(db_conn)
user_id = uuid4()
await repo.create(user_id, "pk_test@example.com", "hash")
# Duplicate ID should fail
with pytest.raises(asyncpg.UniqueViolationError):
await repo.create(user_id, "other@example.com", "hash")

31
uv.lock generated
View File

@@ -309,6 +309,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/e8/cb/2da4cc83f5edb9c3257d09e1e7ab7b23f049c7962cae8d842bbef0a9cec9/cryptography-46.0.3-cp38-abi3-win_arm64.whl", hash = "sha256:d89c3468de4cdc4f08a57e214384d0471911a3830fcdaf7a8cc587e42a866372", size = 2918740, upload-time = "2025-10-15T23:18:12.277Z" }, { url = "https://files.pythonhosted.org/packages/e8/cb/2da4cc83f5edb9c3257d09e1e7ab7b23f049c7962cae8d842bbef0a9cec9/cryptography-46.0.3-cp38-abi3-win_arm64.whl", hash = "sha256:d89c3468de4cdc4f08a57e214384d0471911a3830fcdaf7a8cc587e42a866372", size = 2918740, upload-time = "2025-10-15T23:18:12.277Z" },
] ]
[[package]]
name = "dnspython"
version = "2.8.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/8c/8b/57666417c0f90f08bcafa776861060426765fdb422eb10212086fb811d26/dnspython-2.8.0.tar.gz", hash = "sha256:181d3c6996452cb1189c4046c61599b84a5a86e099562ffde77d26984ff26d0f", size = 368251, upload-time = "2025-09-07T18:58:00.022Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/ba/5a/18ad964b0086c6e62e2e7500f7edc89e3faa45033c71c1893d34eed2b2de/dnspython-2.8.0-py3-none-any.whl", hash = "sha256:01d9bbc4a2d76bf0db7c1f729812ded6d912bd318d3b1cf81d30c0f845dbf3af", size = 331094, upload-time = "2025-09-07T18:57:58.071Z" },
]
[[package]] [[package]]
name = "ecdsa" name = "ecdsa"
version = "0.19.1" version = "0.19.1"
@@ -321,6 +330,19 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/cb/a3/460c57f094a4a165c84a1341c373b0a4f5ec6ac244b998d5021aade89b77/ecdsa-0.19.1-py2.py3-none-any.whl", hash = "sha256:30638e27cf77b7e15c4c4cc1973720149e1033827cfd00661ca5c8cc0cdb24c3", size = 150607, upload-time = "2025-03-13T11:52:41.757Z" }, { url = "https://files.pythonhosted.org/packages/cb/a3/460c57f094a4a165c84a1341c373b0a4f5ec6ac244b998d5021aade89b77/ecdsa-0.19.1-py2.py3-none-any.whl", hash = "sha256:30638e27cf77b7e15c4c4cc1973720149e1033827cfd00661ca5c8cc0cdb24c3", size = 150607, upload-time = "2025-03-13T11:52:41.757Z" },
] ]
[[package]]
name = "email-validator"
version = "2.3.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "dnspython" },
{ name = "idna" },
]
sdist = { url = "https://files.pythonhosted.org/packages/f5/22/900cb125c76b7aaa450ce02fd727f452243f2e91a61af068b40adba60ea9/email_validator-2.3.0.tar.gz", hash = "sha256:9fc05c37f2f6cf439ff414f8fc46d917929974a82244c20eb10231ba60c54426", size = 51238, upload-time = "2025-08-26T13:09:06.831Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/de/15/545e2b6cf2e3be84bc1ed85613edd75b8aea69807a71c26f4ca6a9258e82/email_validator-2.3.0-py3-none-any.whl", hash = "sha256:80f13f623413e6b197ae73bb10bf4eb0908faf509ad8362c5edeb0be7fd450b4", size = 35604, upload-time = "2025-08-26T13:09:05.858Z" },
]
[[package]] [[package]]
name = "fastapi" name = "fastapi"
version = "0.128.0" version = "0.128.0"
@@ -407,7 +429,7 @@ dependencies = [
{ name = "celery", extra = ["redis"] }, { name = "celery", extra = ["redis"] },
{ name = "fastapi" }, { name = "fastapi" },
{ name = "httpx" }, { name = "httpx" },
{ name = "pydantic" }, { name = "pydantic", extra = ["email"] },
{ name = "pydantic-settings" }, { name = "pydantic-settings" },
{ name = "python-jose", extra = ["cryptography"] }, { name = "python-jose", extra = ["cryptography"] },
{ name = "redis" }, { name = "redis" },
@@ -428,7 +450,7 @@ requires-dist = [
{ name = "celery", extras = ["redis"], specifier = ">=5.4.0" }, { name = "celery", extras = ["redis"], specifier = ">=5.4.0" },
{ name = "fastapi", specifier = ">=0.115.0" }, { name = "fastapi", specifier = ">=0.115.0" },
{ name = "httpx", specifier = ">=0.28.0" }, { name = "httpx", specifier = ">=0.28.0" },
{ name = "pydantic", specifier = ">=2.0.0" }, { name = "pydantic", extras = ["email"], specifier = ">=2.0.0" },
{ name = "pydantic-settings", specifier = ">=2.0.0" }, { name = "pydantic-settings", specifier = ">=2.0.0" },
{ name = "pytest", marker = "extra == 'dev'", specifier = ">=8.0.0" }, { name = "pytest", marker = "extra == 'dev'", specifier = ">=8.0.0" },
{ name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=0.24.0" }, { name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=0.24.0" },
@@ -531,6 +553,11 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/5a/87/b70ad306ebb6f9b585f114d0ac2137d792b48be34d732d60e597c2f8465a/pydantic-2.12.5-py3-none-any.whl", hash = "sha256:e561593fccf61e8a20fc46dfc2dfe075b8be7d0188df33f221ad1f0139180f9d", size = 463580, upload-time = "2025-11-26T15:11:44.605Z" }, { url = "https://files.pythonhosted.org/packages/5a/87/b70ad306ebb6f9b585f114d0ac2137d792b48be34d732d60e597c2f8465a/pydantic-2.12.5-py3-none-any.whl", hash = "sha256:e561593fccf61e8a20fc46dfc2dfe075b8be7d0188df33f221ad1f0139180f9d", size = 463580, upload-time = "2025-11-26T15:11:44.605Z" },
] ]
[package.optional-dependencies]
email = [
{ name = "email-validator" },
]
[[package]] [[package]]
name = "pydantic-core" name = "pydantic-core"
version = "2.41.5" version = "2.41.5"