feat(api): Pydantic schemas + Data Repositories
This commit is contained in:
17
app/repositories/__init__.py
Normal file
17
app/repositories/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""Repository layer for database operations."""
|
||||
|
||||
from app.repositories.incident import IncidentRepository
|
||||
from app.repositories.notification import NotificationRepository
|
||||
from app.repositories.org import OrgRepository
|
||||
from app.repositories.refresh_token import RefreshTokenRepository
|
||||
from app.repositories.service import ServiceRepository
|
||||
from app.repositories.user import UserRepository
|
||||
|
||||
__all__ = [
|
||||
"IncidentRepository",
|
||||
"NotificationRepository",
|
||||
"OrgRepository",
|
||||
"RefreshTokenRepository",
|
||||
"ServiceRepository",
|
||||
"UserRepository",
|
||||
]
|
||||
161
app/repositories/incident.py
Normal file
161
app/repositories/incident.py
Normal file
@@ -0,0 +1,161 @@
|
||||
"""Incident repository for database operations."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
import asyncpg
|
||||
|
||||
|
||||
class IncidentRepository:
    """Database operations for incidents.

    Thin data-access layer over asyncpg: each method executes a single SQL
    statement on the connection supplied at construction time. Transaction
    management is the caller's responsibility.
    """

    def __init__(self, conn: asyncpg.Connection) -> None:
        # Connection (or pool-acquired connection) used for every query.
        self.conn = conn

    async def create(
        self,
        incident_id: UUID,
        org_id: UUID,
        service_id: UUID,
        title: str,
        description: str | None,
        severity: str,
    ) -> dict:
        """Create a new incident.

        New incidents are always inserted with status 'triggered';
        version/created_at/updated_at are presumably filled by column
        defaults (not visible here — confirm against the schema).

        Returns:
            The inserted row as a dict.
        """
        row = await self.conn.fetchrow(
            """
            INSERT INTO incidents (id, org_id, service_id, title, description, status, severity)
            VALUES ($1, $2, $3, $4, $5, 'triggered', $6)
            RETURNING id, org_id, service_id, title, description, status, severity,
                      version, created_at, updated_at
            """,
            incident_id,
            org_id,
            service_id,
            title,
            description,
            severity,
        )
        return dict(row)

    async def get_by_id(self, incident_id: UUID) -> dict | None:
        """Get incident by ID.

        Returns:
            The row as a dict, or None when no incident matches.
        """
        row = await self.conn.fetchrow(
            """
            SELECT id, org_id, service_id, title, description, status, severity,
                   version, created_at, updated_at
            FROM incidents
            WHERE id = $1
            """,
            incident_id,
        )
        return dict(row) if row else None

    async def get_by_org(
        self,
        org_id: UUID,
        status: str | None = None,
        cursor: datetime | None = None,
        limit: int = 20,
    ) -> list[dict]:
        """Get incidents for an organization with optional filtering and pagination.

        Cursor-based pagination: *cursor* is the created_at of the last row
        of the previous page; only strictly older rows are returned, newest
        first.

        Note: up to ``limit + 1`` rows are returned — the extra row is the
        caller's signal that another page exists, and the caller is expected
        to trim it.

        Args:
            org_id: Organization to filter by.
            status: Optional exact-match status filter.
            cursor: Optional created_at upper bound (exclusive).
            limit: Page size (one extra row may be returned, see above).
        """
        # Query is built dynamically, but only the parameter placeholders
        # ($2, $3, ...) are interpolated — all values are bound, so this is
        # not injectable.
        query = """
            SELECT id, org_id, service_id, title, description, status, severity,
                   version, created_at, updated_at
            FROM incidents
            WHERE org_id = $1
        """
        params: list[Any] = [org_id]
        param_idx = 2

        if status:
            query += f" AND status = ${param_idx}"
            params.append(status)
            param_idx += 1

        if cursor:
            query += f" AND created_at < ${param_idx}"
            params.append(cursor)
            param_idx += 1

        query += f" ORDER BY created_at DESC LIMIT ${param_idx}"
        params.append(limit + 1)  # Fetch one extra to check if there are more

        rows = await self.conn.fetch(query, *params)
        return [dict(row) for row in rows]

    async def update_status(
        self,
        incident_id: UUID,
        new_status: str,
        expected_version: int,
    ) -> dict | None:
        """Update incident status with optimistic locking.

        The UPDATE only matches when the stored version equals
        *expected_version*; on success the version is incremented, so a
        concurrent writer holding a stale version gets None back.

        Returns updated incident if successful, None if version mismatch.
        """
        row = await self.conn.fetchrow(
            """
            UPDATE incidents
            SET status = $2, version = version + 1, updated_at = now()
            WHERE id = $1 AND version = $3
            RETURNING id, org_id, service_id, title, description, status, severity,
                      version, created_at, updated_at
            """,
            incident_id,
            new_status,
            expected_version,
        )
        return dict(row) if row else None

    async def add_event(
        self,
        event_id: UUID,
        incident_id: UUID,
        event_type: str,
        actor_user_id: UUID | None,
        payload: dict[str, Any] | None,
    ) -> dict:
        """Add an event to the incident timeline.

        The payload dict is stored as a JSON string and parsed back to a
        dict in the returned row. NOTE: an *empty* payload dict is falsy and
        is therefore stored as SQL NULL, indistinguishable from no payload.

        Args:
            event_id: Primary key for the new event row.
            incident_id: Incident this event belongs to.
            event_type: Free-form event type string.
            actor_user_id: User who caused the event, or None for system events.
            payload: Optional structured event data.
        """
        import json

        row = await self.conn.fetchrow(
            """
            INSERT INTO incident_events (id, incident_id, event_type, actor_user_id, payload)
            VALUES ($1, $2, $3, $4, $5)
            RETURNING id, incident_id, event_type, actor_user_id, payload, created_at
            """,
            event_id,
            incident_id,
            event_type,
            actor_user_id,
            json.dumps(payload) if payload else None,
        )
        result = dict(row)

        # Parse JSON payload back to dict
        if result["payload"]:
            result["payload"] = json.loads(result["payload"])
        return result

    async def get_events(self, incident_id: UUID) -> list[dict]:
        """Get all events for an incident.

        Events are returned oldest-first (ORDER BY created_at), with each
        JSON payload string parsed back to a dict (NULL payloads are left
        as-is).
        """
        import json

        rows = await self.conn.fetch(
            """
            SELECT id, incident_id, event_type, actor_user_id, payload, created_at
            FROM incident_events
            WHERE incident_id = $1
            ORDER BY created_at
            """,
            incident_id,
        )
        results = []
        for row in rows:
            result = dict(row)
            if result["payload"]:
                result["payload"] = json.loads(result["payload"])
            results.append(result)
        return results
|
||||
199
app/repositories/notification.py
Normal file
199
app/repositories/notification.py
Normal file
@@ -0,0 +1,199 @@
|
||||
"""Notification repository for database operations."""
|
||||
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
import asyncpg
|
||||
|
||||
|
||||
class NotificationRepository:
    """Database operations for notification targets and attempts.

    Thin data-access layer over asyncpg: targets describe where to notify
    (e.g. a webhook), attempts track per-incident delivery state. Each
    method runs a single SQL statement on the injected connection.
    """

    def __init__(self, conn: asyncpg.Connection) -> None:
        # Connection (or pool-acquired connection) used for every query.
        self.conn = conn

    async def create_target(
        self,
        target_id: UUID,
        org_id: UUID,
        name: str,
        target_type: str,
        webhook_url: str | None = None,
        enabled: bool = True,
    ) -> dict:
        """Create a new notification target.

        Args:
            target_id: Primary key for the new target.
            org_id: Owning organization.
            name: Human-readable target name.
            target_type: Kind of target (semantics defined by callers).
            webhook_url: Delivery URL; may be None for non-webhook types.
            enabled: Whether the target is active.

        Returns:
            The inserted row as a dict.
        """
        row = await self.conn.fetchrow(
            """
            INSERT INTO notification_targets (id, org_id, name, target_type, webhook_url, enabled)
            VALUES ($1, $2, $3, $4, $5, $6)
            RETURNING id, org_id, name, target_type, webhook_url, enabled, created_at
            """,
            target_id,
            org_id,
            name,
            target_type,
            webhook_url,
            enabled,
        )
        return dict(row)

    async def get_target_by_id(self, target_id: UUID) -> dict | None:
        """Get notification target by ID, or None when not found."""
        row = await self.conn.fetchrow(
            """
            SELECT id, org_id, name, target_type, webhook_url, enabled, created_at
            FROM notification_targets
            WHERE id = $1
            """,
            target_id,
        )
        return dict(row) if row else None

    async def get_targets_by_org(
        self,
        org_id: UUID,
        enabled_only: bool = False,
    ) -> list[dict]:
        """Get all notification targets for an organization.

        Args:
            org_id: Organization to list targets for.
            enabled_only: When True, filter out disabled targets.

        Returns:
            Targets ordered by name.
        """
        query = """
            SELECT id, org_id, name, target_type, webhook_url, enabled, created_at
            FROM notification_targets
            WHERE org_id = $1
        """
        # Only static SQL fragments are appended here — no user input is
        # interpolated into the query text.
        if enabled_only:
            query += " AND enabled = true"
        query += " ORDER BY name"

        rows = await self.conn.fetch(query, org_id)
        return [dict(row) for row in rows]

    async def update_target(
        self,
        target_id: UUID,
        name: str | None = None,
        webhook_url: str | None = None,
        enabled: bool | None = None,
    ) -> dict | None:
        """Update a notification target.

        Partial update: only the fields passed as non-None are changed.
        NOTE: because None means "leave unchanged", this method cannot be
        used to set webhook_url back to NULL.

        Returns:
            The updated row, the unchanged row when no fields were given,
            or None when the target does not exist.
        """
        # Build the SET clause dynamically; only $n placeholders are
        # interpolated — values are always bound, so this is not injectable.
        updates: list[str] = []
        params: list = [target_id]
        param_idx = 2

        if name is not None:
            updates.append(f"name = ${param_idx}")
            params.append(name)
            param_idx += 1

        if webhook_url is not None:
            updates.append(f"webhook_url = ${param_idx}")
            params.append(webhook_url)
            param_idx += 1

        if enabled is not None:
            updates.append(f"enabled = ${param_idx}")
            params.append(enabled)
            param_idx += 1

        if not updates:
            # Nothing to change — return current state instead of issuing
            # an invalid UPDATE with an empty SET list.
            return await self.get_target_by_id(target_id)

        query = f"""
            UPDATE notification_targets
            SET {", ".join(updates)}
            WHERE id = $1
            RETURNING id, org_id, name, target_type, webhook_url, enabled, created_at
        """
        row = await self.conn.fetchrow(query, *params)
        return dict(row) if row else None

    async def delete_target(self, target_id: UUID) -> bool:
        """Delete a notification target. Returns True if deleted."""
        result = await self.conn.execute(
            "DELETE FROM notification_targets WHERE id = $1",
            target_id,
        )
        # asyncpg's execute() returns the command status string, e.g.
        # "DELETE 1" when exactly one row was removed.
        return result == "DELETE 1"

    async def create_attempt(
        self,
        attempt_id: UUID,
        incident_id: UUID,
        target_id: UUID,
    ) -> dict:
        """Create a notification attempt (idempotent via unique constraint).

        Relies on a unique constraint on (incident_id, target_id). The
        ON CONFLICT clause performs a no-op update (id = id) so that
        RETURNING yields the pre-existing row instead of yielding nothing —
        calling this twice returns the same attempt row.
        """
        row = await self.conn.fetchrow(
            """
            INSERT INTO notification_attempts (id, incident_id, target_id, status)
            VALUES ($1, $2, $3, 'pending')
            ON CONFLICT (incident_id, target_id) DO UPDATE SET id = notification_attempts.id
            RETURNING id, incident_id, target_id, status, error, sent_at, created_at
            """,
            attempt_id,
            incident_id,
            target_id,
        )
        return dict(row)

    async def get_attempt(self, incident_id: UUID, target_id: UUID) -> dict | None:
        """Get notification attempt for incident and target, or None."""
        row = await self.conn.fetchrow(
            """
            SELECT id, incident_id, target_id, status, error, sent_at, created_at
            FROM notification_attempts
            WHERE incident_id = $1 AND target_id = $2
            """,
            incident_id,
            target_id,
        )
        return dict(row) if row else None

    async def update_attempt_success(
        self,
        attempt_id: UUID,
        sent_at: datetime,
    ) -> dict | None:
        """Mark notification attempt as successful.

        Sets status to 'sent', records *sent_at*, and clears any previous
        error. Returns the updated row, or None when the attempt is missing.
        """
        row = await self.conn.fetchrow(
            """
            UPDATE notification_attempts
            SET status = 'sent', sent_at = $2, error = NULL
            WHERE id = $1
            RETURNING id, incident_id, target_id, status, error, sent_at, created_at
            """,
            attempt_id,
            sent_at,
        )
        return dict(row) if row else None

    async def update_attempt_failure(
        self,
        attempt_id: UUID,
        error: str,
    ) -> dict | None:
        """Mark notification attempt as failed.

        Sets status to 'failed' and records the error message. Returns the
        updated row, or None when the attempt is missing.
        """
        row = await self.conn.fetchrow(
            """
            UPDATE notification_attempts
            SET status = 'failed', error = $2
            WHERE id = $1
            RETURNING id, incident_id, target_id, status, error, sent_at, created_at
            """,
            attempt_id,
            error,
        )
        return dict(row) if row else None

    async def get_pending_attempts(self, incident_id: UUID) -> list[dict]:
        """Get all pending notification attempts for an incident.

        Joins in the target's type, webhook URL, and name (as target_name)
        so the caller can deliver without a second lookup.
        """
        rows = await self.conn.fetch(
            """
            SELECT na.id, na.incident_id, na.target_id, na.status, na.error,
                   na.sent_at, na.created_at,
                   nt.target_type, nt.webhook_url, nt.name as target_name
            FROM notification_attempts na
            JOIN notification_targets nt ON nt.id = na.target_id
            WHERE na.incident_id = $1 AND na.status = 'pending'
            """,
            incident_id,
        )
        return [dict(row) for row in rows]
|
||||
125
app/repositories/org.py
Normal file
125
app/repositories/org.py
Normal file
@@ -0,0 +1,125 @@
|
||||
"""Organization repository for database operations."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
import asyncpg
|
||||
|
||||
|
||||
class OrgRepository:
    """Data-access layer for organizations and their memberships."""

    def __init__(self, conn: asyncpg.Connection) -> None:
        self.conn = conn

    async def create(
        self,
        org_id: UUID,
        name: str,
        slug: str,
    ) -> dict:
        """Insert a new organization row and return it as a dict."""
        record = await self.conn.fetchrow(
            """
            INSERT INTO orgs (id, name, slug)
            VALUES ($1, $2, $3)
            RETURNING id, name, slug, created_at
            """,
            org_id,
            name,
            slug,
        )
        return dict(record)

    async def get_by_id(self, org_id: UUID) -> dict | None:
        """Fetch a single organization by primary key; None when absent."""
        record = await self.conn.fetchrow(
            """
            SELECT id, name, slug, created_at
            FROM orgs
            WHERE id = $1
            """,
            org_id,
        )
        if record is None:
            return None
        return dict(record)

    async def get_by_slug(self, slug: str) -> dict | None:
        """Fetch a single organization by its slug; None when absent."""
        record = await self.conn.fetchrow(
            """
            SELECT id, name, slug, created_at
            FROM orgs
            WHERE slug = $1
            """,
            slug,
        )
        if record is None:
            return None
        return dict(record)

    async def add_member(
        self,
        member_id: UUID,
        user_id: UUID,
        org_id: UUID,
        role: str,
    ) -> dict:
        """Insert a membership row linking a user to an organization."""
        record = await self.conn.fetchrow(
            """
            INSERT INTO org_members (id, user_id, org_id, role)
            VALUES ($1, $2, $3, $4)
            RETURNING id, user_id, org_id, role, created_at
            """,
            member_id,
            user_id,
            org_id,
            role,
        )
        return dict(record)

    async def get_member(self, user_id: UUID, org_id: UUID) -> dict | None:
        """Look up a user's membership in an organization; None when absent."""
        record = await self.conn.fetchrow(
            """
            SELECT om.id, om.user_id, om.org_id, om.role, om.created_at
            FROM org_members om
            WHERE om.user_id = $1 AND om.org_id = $2
            """,
            user_id,
            org_id,
        )
        if record is None:
            return None
        return dict(record)

    async def get_members(self, org_id: UUID) -> list[dict]:
        """List an organization's members (with email), oldest first."""
        records = await self.conn.fetch(
            """
            SELECT om.id, om.user_id, u.email, om.role, om.created_at
            FROM org_members om
            JOIN users u ON u.id = om.user_id
            WHERE om.org_id = $1
            ORDER BY om.created_at
            """,
            org_id,
        )
        return [dict(r) for r in records]

    async def get_user_orgs(self, user_id: UUID) -> list[dict]:
        """List every organization a user belongs to, with the user's role."""
        records = await self.conn.fetch(
            """
            SELECT o.id, o.name, o.slug, o.created_at, om.role
            FROM orgs o
            JOIN org_members om ON om.org_id = o.id
            WHERE om.user_id = $1
            ORDER BY o.created_at
            """,
            user_id,
        )
        return [dict(r) for r in records]

    async def slug_exists(self, slug: str) -> bool:
        """Return True when an organization with this slug already exists."""
        exists = await self.conn.fetchval(
            "SELECT EXISTS(SELECT 1 FROM orgs WHERE slug = $1)",
            slug,
        )
        return exists
|
||||
396
app/repositories/refresh_token.py
Normal file
396
app/repositories/refresh_token.py
Normal file
@@ -0,0 +1,396 @@
|
||||
"""Refresh token repository for database operations.
|
||||
|
||||
Security considerations implemented:
|
||||
- Atomic rotation using SELECT FOR UPDATE to prevent race conditions
|
||||
- Token chain tracking via rotated_to for reuse/theft detection
|
||||
- Defense-in-depth validation with user_id and active_org_id checks
|
||||
- Uses RETURNING for robust row counting instead of string parsing
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
import asyncpg
|
||||
|
||||
|
||||
class RefreshTokenRepository:
    """Database operations for refresh tokens.

    Tokens are stored hashed; all lookups are by hash. Validity is defined
    everywhere in this class as: not revoked, not rotated, and not expired
    (checked against clock_timestamp(), i.e. real time rather than
    transaction start time).
    """

    def __init__(self, conn: asyncpg.Connection) -> None:
        # Connection (or pool-acquired connection) used for every query.
        self.conn = conn

    async def create(
        self,
        token_id: UUID,
        user_id: UUID,
        token_hash: str,
        active_org_id: UUID,
        expires_at: datetime,
    ) -> dict:
        """Create a new refresh token.

        Args:
            token_id: Primary key for the token row.
            user_id: Owning user.
            token_hash: Hash of the opaque token value (never the raw token).
            active_org_id: Organization this session is bound to.
            expires_at: Absolute expiry time.

        Returns:
            The inserted row as a dict.
        """
        row = await self.conn.fetchrow(
            """
            INSERT INTO refresh_tokens (id, user_id, token_hash, active_org_id, expires_at)
            VALUES ($1, $2, $3, $4, $5)
            RETURNING id, user_id, token_hash, active_org_id, expires_at,
                      revoked_at, rotated_to, created_at
            """,
            token_id,
            user_id,
            token_hash,
            active_org_id,
            expires_at,
        )
        return dict(row)

    async def get_by_hash(self, token_hash: str) -> dict | None:
        """Get refresh token by hash (includes revoked/expired for auditing)."""
        row = await self.conn.fetchrow(
            """
            SELECT id, user_id, token_hash, active_org_id, expires_at,
                   revoked_at, rotated_to, created_at
            FROM refresh_tokens
            WHERE token_hash = $1
            """,
            token_hash,
        )
        return dict(row) if row else None

    async def get_valid_by_hash(
        self,
        token_hash: str,
        user_id: UUID | None = None,
        active_org_id: UUID | None = None,
    ) -> dict | None:
        """Get refresh token by hash, only if valid.

        Validates:
        - Token exists and matches hash
        - Token is not revoked
        - Token is not expired
        - Token has not been rotated (rotated_to is NULL)
        - Optionally: user_id matches (defense-in-depth)
        - Optionally: active_org_id matches (defense-in-depth)

        Args:
            token_hash: The hashed token value
            user_id: If provided, token must belong to this user
            active_org_id: If provided, token must be bound to this org

        Returns:
            Token dict if valid, None otherwise
        """
        # Dynamically extend the WHERE clause; only $n placeholders are
        # interpolated — all values are bound, so this is not injectable.
        query = """
            SELECT id, user_id, token_hash, active_org_id, expires_at,
                   revoked_at, rotated_to, created_at
            FROM refresh_tokens
            WHERE token_hash = $1
              AND revoked_at IS NULL
              AND rotated_to IS NULL
              AND expires_at > clock_timestamp()
        """
        params: list = [token_hash]
        param_idx = 2

        if user_id is not None:
            query += f" AND user_id = ${param_idx}"
            params.append(user_id)
            param_idx += 1

        if active_org_id is not None:
            query += f" AND active_org_id = ${param_idx}"
            params.append(active_org_id)

        row = await self.conn.fetchrow(query, *params)
        return dict(row) if row else None

    async def get_valid_for_rotation(
        self,
        token_hash: str,
        user_id: UUID | None = None,
    ) -> dict | None:
        """Get and lock a valid token for rotation using SELECT FOR UPDATE.

        This acquires a row-level lock to prevent concurrent rotation attempts.
        Must be called within a transaction.

        Args:
            token_hash: The hashed token value
            user_id: If provided, token must belong to this user

        Returns:
            Token dict if valid and locked, None otherwise
        """
        query = """
            SELECT id, user_id, token_hash, active_org_id, expires_at,
                   revoked_at, rotated_to, created_at
            FROM refresh_tokens
            WHERE token_hash = $1
              AND revoked_at IS NULL
              AND rotated_to IS NULL
              AND expires_at > clock_timestamp()
        """
        params: list = [token_hash]

        if user_id is not None:
            query += " AND user_id = $2"
            params.append(user_id)

        # FOR UPDATE must come after the full WHERE clause.
        query += " FOR UPDATE"

        row = await self.conn.fetchrow(query, *params)
        return dict(row) if row else None

    async def check_token_reuse(self, token_hash: str) -> dict | None:
        """Check if a token has already been rotated (potential theft).

        If a token is presented that has rotated_to set, it means:
        1. The token was legitimately rotated earlier
        2. Someone is now trying to use the old token
        3. This indicates the token may have been stolen

        Returns:
            Token dict if this is a reused/stolen token, None if not found or not rotated
        """
        row = await self.conn.fetchrow(
            """
            SELECT id, user_id, token_hash, active_org_id, expires_at,
                   revoked_at, rotated_to, created_at
            FROM refresh_tokens
            WHERE token_hash = $1 AND rotated_to IS NOT NULL
            """,
            token_hash,
        )
        return dict(row) if row else None

    async def revoke_token_chain(self, token_id: UUID) -> int:
        """Revoke a token and all tokens in its chain (for breach response).

        When token reuse is detected, this revokes:
        1. The original stolen token
        2. Any token it was rotated to (and their rotations, recursively)

        Args:
            token_id: The ID of the compromised token

        Returns:
            Count of tokens revoked
        """
        # Use recursive CTE to find all tokens in the chain; counting via
        # RETURNING avoids parsing asyncpg's status string.
        rows = await self.conn.fetch(
            """
            WITH RECURSIVE token_chain AS (
                -- Start with the given token
                SELECT id, rotated_to
                FROM refresh_tokens
                WHERE id = $1

                UNION ALL

                -- Follow the chain via rotated_to
                SELECT rt.id, rt.rotated_to
                FROM refresh_tokens rt
                INNER JOIN token_chain tc ON rt.id = tc.rotated_to
            )
            UPDATE refresh_tokens
            SET revoked_at = clock_timestamp()
            WHERE id IN (SELECT id FROM token_chain)
              AND revoked_at IS NULL
            RETURNING id
            """,
            token_id,
        )
        return len(rows)

    async def rotate(
        self,
        old_token_hash: str,
        new_token_id: UUID,
        new_token_hash: str,
        new_expires_at: datetime,
        new_active_org_id: UUID | None = None,
        expected_user_id: UUID | None = None,
    ) -> dict | None:
        """Atomically rotate a refresh token.

        This method:
        1. Validates the old token (not expired, not revoked, not already rotated)
        2. Locks the row to prevent concurrent rotation
        3. Marks old token as rotated (sets rotated_to)
        4. Creates new token with updated org if specified
        5. All in a single atomic operation

        NOTE: atomicity relies on the caller running this inside a
        transaction — get_valid_for_rotation's FOR UPDATE lock only spans a
        transaction, and the INSERT/UPDATE pair below is not otherwise
        wrapped here.

        Args:
            old_token_hash: Hash of the token being rotated
            new_token_id: UUID for the new token
            new_token_hash: Hash for the new token
            new_expires_at: Expiry time for the new token
            new_active_org_id: New org ID (for org-switch), or None to keep current
            expected_user_id: If provided, validates token belongs to this user

        Returns:
            New token dict if rotation succeeded, None if old token invalid/expired
        """
        # First, get and lock the old token
        old_token = await self.get_valid_for_rotation(old_token_hash, expected_user_id)
        if old_token is None:
            return None

        # Determine the org for the new token. `or` is safe here because a
        # UUID instance is never falsy — only None falls through.
        active_org_id = new_active_org_id or old_token["active_org_id"]
        user_id = old_token["user_id"]

        # Create the new token
        new_token = await self.conn.fetchrow(
            """
            INSERT INTO refresh_tokens (id, user_id, token_hash, active_org_id, expires_at)
            VALUES ($1, $2, $3, $4, $5)
            RETURNING id, user_id, token_hash, active_org_id, expires_at,
                      revoked_at, rotated_to, created_at
            """,
            new_token_id,
            user_id,
            new_token_hash,
            active_org_id,
            new_expires_at,
        )

        # Mark the old token as rotated (not revoked - for reuse detection)
        await self.conn.execute(
            """
            UPDATE refresh_tokens
            SET rotated_to = $2
            WHERE id = $1
            """,
            old_token["id"],
            new_token_id,
        )

        return dict(new_token)

    async def revoke(self, token_id: UUID) -> bool:
        """Revoke a refresh token by ID.

        Returns:
            True if token was revoked, False if not found or already revoked
        """
        row = await self.conn.fetchrow(
            """
            UPDATE refresh_tokens
            SET revoked_at = clock_timestamp()
            WHERE id = $1 AND revoked_at IS NULL
            RETURNING id
            """,
            token_id,
        )
        return row is not None

    async def revoke_by_hash(self, token_hash: str) -> bool:
        """Revoke a refresh token by hash.

        Returns:
            True if token was revoked, False if not found or already revoked
        """
        row = await self.conn.fetchrow(
            """
            UPDATE refresh_tokens
            SET revoked_at = clock_timestamp()
            WHERE token_hash = $1 AND revoked_at IS NULL
            RETURNING id
            """,
            token_hash,
        )
        return row is not None

    async def revoke_all_for_user(self, user_id: UUID) -> int:
        """Revoke all active refresh tokens for a user.

        Use this for:
        - User-initiated logout from all devices
        - Password change
        - Account compromise response

        Returns:
            Count of tokens revoked
        """
        rows = await self.conn.fetch(
            """
            UPDATE refresh_tokens
            SET revoked_at = clock_timestamp()
            WHERE user_id = $1 AND revoked_at IS NULL
            RETURNING id
            """,
            user_id,
        )
        return len(rows)

    async def revoke_all_for_user_except(self, user_id: UUID, keep_token_id: UUID) -> int:
        """Revoke all tokens for a user except one (logout other sessions).

        Args:
            user_id: The user whose tokens to revoke
            keep_token_id: The token ID to keep active (current session)

        Returns:
            Count of tokens revoked
        """
        rows = await self.conn.fetch(
            """
            UPDATE refresh_tokens
            SET revoked_at = clock_timestamp()
            WHERE user_id = $1 AND revoked_at IS NULL AND id != $2
            RETURNING id
            """,
            user_id,
            keep_token_id,
        )
        return len(rows)

    async def get_active_tokens_for_user(self, user_id: UUID) -> list[dict]:
        """Get all active (non-revoked, non-expired, non-rotated) tokens for a user.

        Useful for:
        - Showing active sessions
        - Auditing

        Returns:
            List of active token records, newest first
        """
        rows = await self.conn.fetch(
            """
            SELECT id, user_id, token_hash, active_org_id, expires_at,
                   revoked_at, rotated_to, created_at
            FROM refresh_tokens
            WHERE user_id = $1
              AND revoked_at IS NULL
              AND rotated_to IS NULL
              AND expires_at > clock_timestamp()
            ORDER BY created_at DESC
            """,
            user_id,
        )
        return [dict(row) for row in rows]

    async def cleanup_expired(self, older_than_days: int = 30) -> int:
        """Delete expired tokens older than specified days.

        Note: this performs a hard delete. Where an audit trail matters,
        consider:
        - Archiving to a separate table first
        - Using partitioning with retention policies
        - Only calling this for tokens well past their expiry

        Args:
            older_than_days: Only delete tokens expired more than this many days ago

        Returns:
            Count of tokens deleted
        """
        rows = await self.conn.fetch(
            """
            DELETE FROM refresh_tokens
            WHERE expires_at < clock_timestamp() - interval '1 day' * $1
            RETURNING id
            """,
            older_than_days,
        )
        return len(rows)
|
||||
80
app/repositories/service.py
Normal file
80
app/repositories/service.py
Normal file
@@ -0,0 +1,80 @@
|
||||
"""Service repository for database operations."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
import asyncpg
|
||||
|
||||
|
||||
class ServiceRepository:
    """Data-access layer for services owned by organizations."""

    def __init__(self, conn: asyncpg.Connection) -> None:
        self.conn = conn

    async def create(
        self,
        service_id: UUID,
        org_id: UUID,
        name: str,
        slug: str,
    ) -> dict:
        """Insert a new service row and return it as a dict."""
        record = await self.conn.fetchrow(
            """
            INSERT INTO services (id, org_id, name, slug)
            VALUES ($1, $2, $3, $4)
            RETURNING id, org_id, name, slug, created_at
            """,
            service_id,
            org_id,
            name,
            slug,
        )
        return dict(record)

    async def get_by_id(self, service_id: UUID) -> dict | None:
        """Fetch a single service by primary key; None when absent."""
        record = await self.conn.fetchrow(
            """
            SELECT id, org_id, name, slug, created_at
            FROM services
            WHERE id = $1
            """,
            service_id,
        )
        if record is None:
            return None
        return dict(record)

    async def get_by_org(self, org_id: UUID) -> list[dict]:
        """List every service in an organization, ordered by name."""
        records = await self.conn.fetch(
            """
            SELECT id, org_id, name, slug, created_at
            FROM services
            WHERE org_id = $1
            ORDER BY name
            """,
            org_id,
        )
        return [dict(r) for r in records]

    async def get_by_slug(self, org_id: UUID, slug: str) -> dict | None:
        """Fetch a service by (org, slug); None when absent."""
        record = await self.conn.fetchrow(
            """
            SELECT id, org_id, name, slug, created_at
            FROM services
            WHERE org_id = $1 AND slug = $2
            """,
            org_id,
            slug,
        )
        if record is None:
            return None
        return dict(record)

    async def slug_exists(self, org_id: UUID, slug: str) -> bool:
        """Return True when the slug is already taken within the organization."""
        exists = await self.conn.fetchval(
            "SELECT EXISTS(SELECT 1 FROM services WHERE org_id = $1 AND slug = $2)",
            org_id,
            slug,
        )
        return exists
|
||||
63
app/repositories/user.py
Normal file
63
app/repositories/user.py
Normal file
@@ -0,0 +1,63 @@
|
||||
"""User repository for database operations."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
import asyncpg
|
||||
|
||||
|
||||
class UserRepository:
    """Data-access layer for user accounts."""

    def __init__(self, conn: asyncpg.Connection) -> None:
        self.conn = conn

    async def create(
        self,
        user_id: UUID,
        email: str,
        password_hash: str,
    ) -> dict:
        """Insert a new user row.

        The returned dict deliberately omits password_hash (the INSERT's
        RETURNING clause only exposes id, email, created_at).
        """
        record = await self.conn.fetchrow(
            """
            INSERT INTO users (id, email, password_hash)
            VALUES ($1, $2, $3)
            RETURNING id, email, created_at
            """,
            user_id,
            email,
            password_hash,
        )
        return dict(record)

    async def get_by_id(self, user_id: UUID) -> dict | None:
        """Fetch a user (including password_hash) by primary key; None when absent."""
        record = await self.conn.fetchrow(
            """
            SELECT id, email, password_hash, created_at
            FROM users
            WHERE id = $1
            """,
            user_id,
        )
        if record is None:
            return None
        return dict(record)

    async def get_by_email(self, email: str) -> dict | None:
        """Fetch a user (including password_hash) by email; None when absent."""
        record = await self.conn.fetchrow(
            """
            SELECT id, email, password_hash, created_at
            FROM users
            WHERE email = $1
            """,
            email,
        )
        if record is None:
            return None
        return dict(record)

    async def exists_by_email(self, email: str) -> bool:
        """Return True when a user with this email already exists."""
        exists = await self.conn.fetchval(
            "SELECT EXISTS(SELECT 1 FROM users WHERE email = $1)",
            email,
        )
        return exists
|
||||
50
app/schemas/__init__.py
Normal file
50
app/schemas/__init__.py
Normal file
@@ -0,0 +1,50 @@
|
||||
"""Pydantic schemas for request/response models."""
|
||||
|
||||
from app.schemas.auth import (
|
||||
LoginRequest,
|
||||
RefreshRequest,
|
||||
RegisterRequest,
|
||||
SwitchOrgRequest,
|
||||
TokenResponse,
|
||||
)
|
||||
from app.schemas.common import CursorParams, PaginatedResponse
|
||||
from app.schemas.incident import (
|
||||
CommentRequest,
|
||||
IncidentCreate,
|
||||
IncidentEventResponse,
|
||||
IncidentResponse,
|
||||
TransitionRequest,
|
||||
)
|
||||
from app.schemas.org import (
|
||||
MemberResponse,
|
||||
NotificationTargetCreate,
|
||||
NotificationTargetResponse,
|
||||
OrgResponse,
|
||||
ServiceCreate,
|
||||
ServiceResponse,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Auth
|
||||
"LoginRequest",
|
||||
"RefreshRequest",
|
||||
"RegisterRequest",
|
||||
"SwitchOrgRequest",
|
||||
"TokenResponse",
|
||||
# Common
|
||||
"CursorParams",
|
||||
"PaginatedResponse",
|
||||
# Incident
|
||||
"CommentRequest",
|
||||
"IncidentCreate",
|
||||
"IncidentEventResponse",
|
||||
"IncidentResponse",
|
||||
"TransitionRequest",
|
||||
# Org
|
||||
"MemberResponse",
|
||||
"NotificationTargetCreate",
|
||||
"NotificationTargetResponse",
|
||||
"OrgResponse",
|
||||
"ServiceCreate",
|
||||
"ServiceResponse",
|
||||
]
|
||||
42
app/schemas/auth.py
Normal file
42
app/schemas/auth.py
Normal file
@@ -0,0 +1,42 @@
|
||||
"""Authentication schemas."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, EmailStr, Field
|
||||
|
||||
|
||||
class RegisterRequest(BaseModel):
    """Request body for user registration.

    Registration also provisions the user's default organization, so the
    org name is collected up front.
    """

    email: EmailStr
    # Plain-text password, hashed server-side; bounds guard against trivially
    # weak and absurdly long inputs.
    password: str = Field(min_length=8, max_length=128)
    org_name: str = Field(min_length=1, max_length=100, description="Name for the default org")
|
||||
|
||||
|
||||
class LoginRequest(BaseModel):
    """Request body for user login."""

    email: EmailStr
    # No length constraints here: existing credentials must be accepted as-is.
    password: str
|
||||
|
||||
|
||||
class RefreshRequest(BaseModel):
    """Request body for token refresh (exchanges a refresh token for new tokens)."""

    refresh_token: str
|
||||
|
||||
|
||||
class SwitchOrgRequest(BaseModel):
    """Request body for switching the active organization.

    Requires the current refresh token so the token pair can be reissued
    scoped to the new org.
    """

    org_id: UUID
    refresh_token: str
|
||||
|
||||
|
||||
class TokenResponse(BaseModel):
    """Response containing access and refresh tokens."""

    access_token: str
    refresh_token: str
    # Fixed to the OAuth2-style "bearer" scheme.
    token_type: str = "bearer"
    expires_in: int = Field(description="Access token expiry in seconds")
|
||||
20
app/schemas/common.py
Normal file
20
app/schemas/common.py
Normal file
@@ -0,0 +1,20 @@
|
||||
"""Common schemas used across the API."""
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class CursorParams(BaseModel):
    """Pagination parameters using cursor-based pagination.

    A missing cursor means "start from the first page".
    """

    cursor: str | None = Field(default=None, description="Cursor for pagination")
    # Server-side clamp: 1..100 items per page, 20 by default.
    limit: int = Field(default=20, ge=1, le=100, description="Number of items per page")
|
||||
|
||||
|
||||
class PaginatedResponse[T](BaseModel):
    """Generic paginated response wrapper.

    Uses PEP 695 generics (Python 3.12+) so each endpoint can declare the
    concrete item type, e.g. ``PaginatedResponse[IncidentResponse]``.
    """

    items: list[T]
    next_cursor: str | None = Field(
        default=None, description="Cursor for next page, null if no more items"
    )
    has_more: bool = Field(description="Whether there are more items")
|
||||
57
app/schemas/incident.py
Normal file
57
app/schemas/incident.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""Incident-related schemas."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, Literal
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# Lifecycle states of an incident per the SPECS.md state machine:
# triggered -> acknowledged -> mitigated -> resolved.
IncidentStatus = Literal["triggered", "acknowledged", "mitigated", "resolved"]
# Severity levels, ordered highest to lowest impact.
IncidentSeverity = Literal["critical", "high", "medium", "low"]
|
||||
|
||||
|
||||
class IncidentCreate(BaseModel):
    """Request body for creating an incident.

    Status is not accepted here: new incidents always start as 'triggered'.
    """

    title: str = Field(min_length=1, max_length=200)
    description: str | None = Field(default=None, max_length=5000)
    severity: IncidentSeverity = "medium"
|
||||
|
||||
|
||||
class IncidentResponse(BaseModel):
    """Incident response.

    NOTE(review): org_id is not exposed — presumably responses are already
    scoped to the caller's active org; confirm against the route layer.
    """

    id: UUID
    service_id: UUID
    title: str
    description: str | None
    status: IncidentStatus
    severity: IncidentSeverity
    # Monotonic counter used for optimistic locking on status transitions.
    version: int
    created_at: datetime
    updated_at: datetime
|
||||
|
||||
|
||||
class IncidentEventResponse(BaseModel):
    """A single timeline event attached to an incident."""

    id: UUID
    incident_id: UUID
    event_type: str
    # None for system-generated events (e.g. automatic escalations).
    actor_user_id: UUID | None
    # Free-form event details; shape depends on event_type.
    payload: dict[str, Any] | None
    created_at: datetime
|
||||
|
||||
|
||||
class TransitionRequest(BaseModel):
    """Request body for transitioning incident status."""

    to_status: IncidentStatus
    # Client must echo the version it last saw; a mismatch means a
    # concurrent update happened and the transition is rejected.
    version: int = Field(description="Current version for optimistic locking")
    note: str | None = Field(default=None, max_length=1000)
|
||||
|
||||
|
||||
class CommentRequest(BaseModel):
    """Request body for adding a comment to an incident's timeline."""

    content: str = Field(min_length=1, max_length=5000)
|
||||
69
app/schemas/org.py
Normal file
69
app/schemas/org.py
Normal file
@@ -0,0 +1,69 @@
|
||||
"""Organization-related schemas."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Literal
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, Field, HttpUrl
|
||||
|
||||
|
||||
class OrgResponse(BaseModel):
    """Organization summary response."""

    id: UUID
    name: str
    # URL-friendly unique identifier for the org.
    slug: str
    created_at: datetime
|
||||
|
||||
|
||||
class MemberResponse(BaseModel):
    """Organization member response.

    ``id`` is the membership row's ID; ``user_id`` identifies the user,
    who may belong to several orgs.
    """

    id: UUID
    user_id: UUID
    email: str
    role: Literal["admin", "member", "viewer"]
    created_at: datetime
|
||||
|
||||
|
||||
class ServiceCreate(BaseModel):
    """Request body for creating a service."""

    name: str = Field(min_length=1, max_length=100)
    # Pattern forbids leading/trailing/double hyphens and uppercase.
    slug: str = Field(
        min_length=1,
        max_length=50,
        pattern=r"^[a-z0-9]+(?:-[a-z0-9]+)*$",
        description="URL-friendly identifier (lowercase, hyphens allowed)",
    )
|
||||
|
||||
|
||||
class ServiceResponse(BaseModel):
    """Service response."""

    id: UUID
    name: str
    slug: str
    created_at: datetime
|
||||
|
||||
|
||||
class NotificationTargetCreate(BaseModel):
    """Request body for creating a notification target."""

    name: str = Field(min_length=1, max_length=100)
    target_type: Literal["webhook", "email", "slack"]
    # NOTE(review): "required for webhook type" is documented but not
    # enforced by a validator here — confirm enforcement happens at the
    # service/route layer.
    webhook_url: HttpUrl | None = Field(
        default=None, description="Required for webhook type"
    )
    enabled: bool = True
|
||||
|
||||
|
||||
class NotificationTargetResponse(BaseModel):
    """Notification target response."""

    id: UUID
    name: str
    target_type: Literal["webhook", "email", "slack"]
    # Plain str on the way out (stored value), unlike HttpUrl on input.
    webhook_url: str | None
    enabled: bool
    created_at: datetime
|
||||
18
migrations/0004_refresh_token_rotation.sql
Normal file
18
migrations/0004_refresh_token_rotation.sql
Normal file
@@ -0,0 +1,18 @@
|
||||
-- Enhance refresh tokens for secure rotation and reuse detection
-- Adds rotated_to column to track token chains and detect stolen token reuse

-- Add rotated_to column to track which token this was rotated into.
-- When a token is rotated, we store the ID of the new token here.
-- If a token with rotated_to set is used again, it indicates token theft
-- (the legitimate client already holds the successor token).
ALTER TABLE refresh_tokens ADD COLUMN rotated_to UUID REFERENCES refresh_tokens(id);

-- Index for efficient cleanup queries on expires_at
CREATE INDEX idx_refresh_tokens_expires ON refresh_tokens(expires_at);

-- Partial index for finding active tokens per user (for revoke_all and listing);
-- the WHERE clause keeps it small since revoked rows are never matched.
CREATE INDEX idx_refresh_tokens_user_active ON refresh_tokens(user_id, revoked_at)
WHERE revoked_at IS NULL;

-- Partial index for reuse-detection queries (only rotated tokens are indexed)
CREATE INDEX idx_refresh_tokens_rotated ON refresh_tokens(rotated_to)
WHERE rotated_to IS NOT NULL;
|
||||
@@ -8,7 +8,7 @@ dependencies = [
|
||||
"fastapi>=0.115.0",
|
||||
"uvicorn[standard]>=0.32.0",
|
||||
"asyncpg>=0.30.0",
|
||||
"pydantic>=2.0.0",
|
||||
"pydantic[email]>=2.0.0",
|
||||
"pydantic-settings>=2.0.0",
|
||||
"python-jose[cryptography]>=3.3.0",
|
||||
"bcrypt>=4.0.0",
|
||||
@@ -38,6 +38,9 @@ target-version = "py314"
|
||||
[tool.ruff.lint]
|
||||
select = ["E", "F", "I", "N", "W", "UP"]
|
||||
|
||||
[tool.ruff.lint.per-file-ignores]
|
||||
"tests/**/*.py" = ["E501"] # Allow longer lines in tests for descriptive method names
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
asyncio_mode = "auto"
|
||||
testpaths = ["tests"]
|
||||
|
||||
95
tests/conftest.py
Normal file
95
tests/conftest.py
Normal file
@@ -0,0 +1,95 @@
|
||||
"""Shared pytest fixtures for all tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
from collections.abc import AsyncIterator, Callable
from uuid import UUID, uuid4

import asyncpg
import pytest
|
||||
|
||||
# Set test environment variables before importing app modules
|
||||
os.environ.setdefault("DATABASE_URL", "postgresql://incidentops:incidentops@localhost:5432/incidentops_test")
|
||||
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-testing-only")
|
||||
os.environ.setdefault("REDIS_URL", "redis://localhost:6379/1")
|
||||
|
||||
|
||||
# Module-level setup: create database and run migrations once
|
||||
_db_initialized = False
|
||||
|
||||
|
||||
async def _init_test_db() -> None:
    """Initialize the test database and run migrations (once per session).

    Drops and recreates the database named in DATABASE_URL, then applies
    every ``.sql`` migration in lexicographic (filename) order. Guarded by
    a module-level flag so the expensive setup runs at most once.
    """
    global _db_initialized
    if _db_initialized:
        return

    test_dsn = os.environ["DATABASE_URL"]
    # Derive both the admin DSN and the database name from DATABASE_URL so
    # the fixture follows the configured URL instead of a second, hard-coded
    # copy of the name that could drift out of sync.
    # NOTE(review): assumes the DSN carries no query string after the db name.
    base, _, test_db_name = test_dsn.rpartition("/")
    admin_dsn = base + "/postgres"

    admin_conn = await asyncpg.connect(admin_dsn)
    try:
        # Terminate lingering sessions so DROP DATABASE cannot block.
        # The name is bound as a parameter rather than interpolated.
        await admin_conn.execute(
            """
            SELECT pg_terminate_backend(pg_stat_activity.pid)
            FROM pg_stat_activity
            WHERE pg_stat_activity.datname = $1
              AND pid <> pg_backend_pid()
            """,
            test_db_name,
        )
        # Identifiers cannot be bound as parameters; the name comes from the
        # test environment, not user input, and is quoted defensively.
        await admin_conn.execute(f'DROP DATABASE IF EXISTS "{test_db_name}"')
        await admin_conn.execute(f'CREATE DATABASE "{test_db_name}"')
    finally:
        await admin_conn.close()

    # Connect to the fresh test database and apply migrations in order.
    conn = await asyncpg.connect(test_dsn)
    try:
        migrations_dir = os.path.join(os.path.dirname(__file__), "..", "migrations")
        migration_files = sorted(
            f for f in os.listdir(migrations_dir) if f.endswith(".sql")
        )
        for migration_file in migration_files:
            migration_path = os.path.join(migrations_dir, migration_file)
            with open(migration_path) as f:
                await conn.execute(f.read())
    finally:
        await conn.close()

    _db_initialized = True
|
||||
|
||||
|
||||
@pytest.fixture
async def db_conn() -> AsyncIterator[asyncpg.Connection]:
    """Yield a database connection wrapped in a transaction.

    The transaction is always rolled back afterwards, giving each test a
    clean database without paying for per-test re-migration.

    (Annotation fixed: the fixture is an async generator, so its return
    type is ``AsyncIterator[asyncpg.Connection]``, not the connection.)
    """
    await _init_test_db()

    conn = await asyncpg.connect(os.environ["DATABASE_URL"])

    # Open the transaction that the whole test will run inside.
    tr = conn.transaction()
    await tr.start()

    try:
        yield conn
    finally:
        # Roll back everything the test did, then release the connection.
        await tr.rollback()
        await conn.close()
|
||||
|
||||
|
||||
@pytest.fixture
def make_user_id() -> Callable[[], UUID]:
    """Factory fixture that returns a zero-argument generator of user IDs.

    (Annotation fixed: ``uuid4`` is a function, not a type; the fixture
    returns a ``Callable[[], UUID]``. ``uuid4`` itself is already such a
    callable, so no wrapping lambda is needed.)
    """
    return uuid4
|
||||
|
||||
|
||||
@pytest.fixture
def make_org_id() -> Callable[[], UUID]:
    """Factory fixture that returns a zero-argument generator of org IDs.

    (Annotation fixed: ``uuid4`` is a function, not a type; the fixture
    returns a ``Callable[[], UUID]``.)
    """
    return uuid4
|
||||
1
tests/repositories/__init__.py
Normal file
1
tests/repositories/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Repository tests."""
|
||||
389
tests/repositories/test_incident.py
Normal file
389
tests/repositories/test_incident.py
Normal file
@@ -0,0 +1,389 @@
|
||||
"""Tests for IncidentRepository."""
|
||||
|
||||
from uuid import UUID, uuid4

import asyncpg
import pytest

from app.repositories.incident import IncidentRepository
from app.repositories.org import OrgRepository
from app.repositories.service import ServiceRepository
from app.repositories.user import UserRepository
|
||||
|
||||
|
||||
class TestIncidentRepository:
    """Tests for IncidentRepository conforming to SPECS.md.

    (Helper annotations fixed: ``uuid4`` is a function, not a type —
    IDs are ``uuid.UUID`` values.)
    """

    async def _create_org(self, conn: asyncpg.Connection, slug: str) -> UUID:
        """Helper to create an org."""
        org_repo = OrgRepository(conn)
        org_id = uuid4()
        await org_repo.create(org_id, f"Org {slug}", slug)
        return org_id

    async def _create_service(self, conn: asyncpg.Connection, org_id: UUID, slug: str) -> UUID:
        """Helper to create a service."""
        service_repo = ServiceRepository(conn)
        service_id = uuid4()
        await service_repo.create(service_id, org_id, f"Service {slug}", slug)
        return service_id

    async def _create_user(self, conn: asyncpg.Connection, email: str) -> UUID:
        """Helper to create a user."""
        user_repo = UserRepository(conn)
        user_id = uuid4()
        await user_repo.create(user_id, email, "hash")
        return user_id

    async def test_create_incident_returns_incident_data(self, db_conn: asyncpg.Connection) -> None:
        """Creating an incident returns the incident data with triggered status."""
        org_id = await self._create_org(db_conn, "incident-org")
        service_id = await self._create_service(db_conn, org_id, "incident-service")
        repo = IncidentRepository(db_conn)
        incident_id = uuid4()

        result = await repo.create(
            incident_id, org_id, service_id,
            title="Server Down",
            description="Main API server is not responding",
            severity="critical"
        )

        assert result["id"] == incident_id
        assert result["org_id"] == org_id
        assert result["service_id"] == service_id
        assert result["title"] == "Server Down"
        assert result["description"] == "Main API server is not responding"
        assert result["status"] == "triggered"  # Initial status per SPECS.md
        assert result["severity"] == "critical"
        assert result["version"] == 1
        assert result["created_at"] is not None
        assert result["updated_at"] is not None

    async def test_create_incident_initial_status_is_triggered(self, db_conn: asyncpg.Connection) -> None:
        """New incidents always start with 'triggered' status per SPECS.md state machine."""
        org_id = await self._create_org(db_conn, "triggered-org")
        service_id = await self._create_service(db_conn, org_id, "triggered-service")
        repo = IncidentRepository(db_conn)

        result = await repo.create(uuid4(), org_id, service_id, "Test", None, "low")

        assert result["status"] == "triggered"

    async def test_create_incident_initial_version_is_one(self, db_conn: asyncpg.Connection) -> None:
        """New incidents start with version 1 for optimistic locking."""
        org_id = await self._create_org(db_conn, "version-org")
        service_id = await self._create_service(db_conn, org_id, "version-service")
        repo = IncidentRepository(db_conn)

        result = await repo.create(uuid4(), org_id, service_id, "Test", None, "medium")

        assert result["version"] == 1

    async def test_create_incident_severity_must_be_valid(self, db_conn: asyncpg.Connection) -> None:
        """Severity must be critical, high, medium, or low per SPECS.md."""
        org_id = await self._create_org(db_conn, "severity-org")
        service_id = await self._create_service(db_conn, org_id, "severity-service")
        repo = IncidentRepository(db_conn)

        # Valid severities
        for severity in ["critical", "high", "medium", "low"]:
            result = await repo.create(uuid4(), org_id, service_id, f"Test {severity}", None, severity)
            assert result["severity"] == severity

        # Invalid severity
        with pytest.raises(asyncpg.CheckViolationError):
            await repo.create(uuid4(), org_id, service_id, "Invalid", None, "extreme")

    async def test_get_by_id_returns_incident(self, db_conn: asyncpg.Connection) -> None:
        """get_by_id returns the correct incident."""
        org_id = await self._create_org(db_conn, "getbyid-org")
        service_id = await self._create_service(db_conn, org_id, "getbyid-service")
        repo = IncidentRepository(db_conn)
        incident_id = uuid4()

        await repo.create(incident_id, org_id, service_id, "My Incident", "Details", "high")
        result = await repo.get_by_id(incident_id)

        assert result is not None
        assert result["id"] == incident_id
        assert result["title"] == "My Incident"

    async def test_get_by_id_returns_none_for_nonexistent(self, db_conn: asyncpg.Connection) -> None:
        """get_by_id returns None for non-existent incident."""
        repo = IncidentRepository(db_conn)

        result = await repo.get_by_id(uuid4())

        assert result is None

    async def test_get_by_org_returns_org_incidents(self, db_conn: asyncpg.Connection) -> None:
        """get_by_org returns incidents for the organization."""
        org_id = await self._create_org(db_conn, "list-org")
        service_id = await self._create_service(db_conn, org_id, "list-service")
        repo = IncidentRepository(db_conn)

        await repo.create(uuid4(), org_id, service_id, "Incident 1", None, "low")
        await repo.create(uuid4(), org_id, service_id, "Incident 2", None, "medium")
        await repo.create(uuid4(), org_id, service_id, "Incident 3", None, "high")

        result = await repo.get_by_org(org_id)

        assert len(result) == 3

    async def test_get_by_org_filters_by_status(self, db_conn: asyncpg.Connection) -> None:
        """get_by_org can filter by status."""
        org_id = await self._create_org(db_conn, "filter-org")
        service_id = await self._create_service(db_conn, org_id, "filter-service")
        repo = IncidentRepository(db_conn)

        # Create incidents and transition some
        inc1 = uuid4()
        inc2 = uuid4()
        await repo.create(inc1, org_id, service_id, "Triggered", None, "low")
        await repo.create(inc2, org_id, service_id, "Will be Acked", None, "low")
        await repo.update_status(inc2, "acknowledged", 1)

        result = await repo.get_by_org(org_id, status="triggered")

        assert len(result) == 1
        assert result[0]["title"] == "Triggered"

    async def test_get_by_org_pagination_with_cursor(self, db_conn: asyncpg.Connection) -> None:
        """get_by_org supports cursor-based pagination."""
        org_id = await self._create_org(db_conn, "pagination-org")
        service_id = await self._create_service(db_conn, org_id, "pagination-service")
        repo = IncidentRepository(db_conn)

        # Create 5 incidents
        for i in range(5):
            await repo.create(uuid4(), org_id, service_id, f"Incident {i}", None, "low")

        # Get first page - should return limit+1 to check for more
        page1 = await repo.get_by_org(org_id, limit=2)
        assert len(page1) == 3

        # Verify total is 5 when we get all
        all_incidents = await repo.get_by_org(org_id, limit=10)
        assert len(all_incidents) == 5

    async def test_get_by_org_orders_by_created_at_desc(self, db_conn: asyncpg.Connection) -> None:
        """get_by_org returns incidents ordered by created_at descending."""
        org_id = await self._create_org(db_conn, "order-org")
        service_id = await self._create_service(db_conn, org_id, "order-service")
        repo = IncidentRepository(db_conn)

        await repo.create(uuid4(), org_id, service_id, "First", None, "low")
        await repo.create(uuid4(), org_id, service_id, "Second", None, "low")
        await repo.create(uuid4(), org_id, service_id, "Third", None, "low")

        result = await repo.get_by_org(org_id)

        # Verify ordering - newer items should come first (or same time due to fast execution)
        assert len(result) == 3
        for i in range(len(result) - 1):
            assert result[i]["created_at"] >= result[i + 1]["created_at"]

    async def test_get_by_org_tenant_isolation(self, db_conn: asyncpg.Connection) -> None:
        """get_by_org only returns incidents for the specified org."""
        org1 = await self._create_org(db_conn, "tenant-org-1")
        org2 = await self._create_org(db_conn, "tenant-org-2")
        service1 = await self._create_service(db_conn, org1, "tenant-service-1")
        service2 = await self._create_service(db_conn, org2, "tenant-service-2")
        repo = IncidentRepository(db_conn)

        await repo.create(uuid4(), org1, service1, "Org1 Incident", None, "low")
        await repo.create(uuid4(), org2, service2, "Org2 Incident", None, "low")

        result = await repo.get_by_org(org1)

        assert len(result) == 1
        assert result[0]["title"] == "Org1 Incident"
|
||||
|
||||
|
||||
class TestIncidentStatusTransitions:
    """Tests for incident status transitions per SPECS.md state machine.

    (Helper annotation fixed: IDs are ``uuid.UUID`` values, not ``uuid4``.)
    """

    async def _setup_incident(self, conn: asyncpg.Connection) -> tuple[UUID, IncidentRepository]:
        """Helper to create org, service, and incident."""
        org_repo = OrgRepository(conn)
        service_repo = ServiceRepository(conn)
        incident_repo = IncidentRepository(conn)

        org_id = uuid4()
        service_id = uuid4()
        incident_id = uuid4()

        await org_repo.create(org_id, "Test Org", f"test-org-{uuid4().hex[:8]}")
        await service_repo.create(service_id, org_id, "Test Service", f"test-service-{uuid4().hex[:8]}")
        await incident_repo.create(incident_id, org_id, service_id, "Test Incident", None, "medium")

        return incident_id, incident_repo

    async def test_update_status_increments_version(self, db_conn: asyncpg.Connection) -> None:
        """update_status increments version for optimistic locking."""
        incident_id, repo = await self._setup_incident(db_conn)

        result = await repo.update_status(incident_id, "acknowledged", 1)

        assert result is not None
        assert result["version"] == 2

    async def test_update_status_fails_on_version_mismatch(self, db_conn: asyncpg.Connection) -> None:
        """update_status returns None on version mismatch (optimistic locking)."""
        incident_id, repo = await self._setup_incident(db_conn)

        # Try with wrong version
        result = await repo.update_status(incident_id, "acknowledged", 999)

        assert result is None

    async def test_update_status_updates_updated_at(self, db_conn: asyncpg.Connection) -> None:
        """update_status updates the updated_at timestamp."""
        incident_id, repo = await self._setup_incident(db_conn)

        before = await repo.get_by_id(incident_id)
        result = await repo.update_status(incident_id, "acknowledged", 1)

        # Guard against None before subscripting (version matched, so the
        # update must have succeeded).
        assert result is not None
        # updated_at should be at least as recent as before (may be same in fast execution)
        assert result["updated_at"] >= before["updated_at"]
        # Also verify status was actually updated
        assert result["status"] == "acknowledged"

    async def test_status_must_be_valid_value(self, db_conn: asyncpg.Connection) -> None:
        """Status must be triggered, acknowledged, mitigated, or resolved per SPECS.md."""
        incident_id, repo = await self._setup_incident(db_conn)

        with pytest.raises(asyncpg.CheckViolationError):
            await repo.update_status(incident_id, "invalid_status", 1)

    async def test_valid_status_transitions(self, db_conn: asyncpg.Connection) -> None:
        """Test the valid status values per SPECS.md."""
        incident_id, repo = await self._setup_incident(db_conn)

        # Triggered -> Acknowledged
        result = await repo.update_status(incident_id, "acknowledged", 1)
        assert result["status"] == "acknowledged"

        # Acknowledged -> Mitigated
        result = await repo.update_status(incident_id, "mitigated", 2)
        assert result["status"] == "mitigated"

        # Mitigated -> Resolved
        result = await repo.update_status(incident_id, "resolved", 3)
        assert result["status"] == "resolved"
|
||||
|
||||
|
||||
class TestIncidentEvents:
    """Tests for incident events (timeline) per SPECS.md incident_events table.

    (Helper annotation fixed: IDs are ``uuid.UUID`` values, not ``uuid4``.)
    """

    async def _setup_incident(self, conn: asyncpg.Connection) -> tuple[UUID, UUID, IncidentRepository]:
        """Helper to create org, service, user, and incident."""
        org_repo = OrgRepository(conn)
        service_repo = ServiceRepository(conn)
        user_repo = UserRepository(conn)
        incident_repo = IncidentRepository(conn)

        org_id = uuid4()
        service_id = uuid4()
        user_id = uuid4()
        incident_id = uuid4()

        await org_repo.create(org_id, "Test Org", f"test-org-{uuid4().hex[:8]}")
        await service_repo.create(service_id, org_id, "Test Service", f"test-svc-{uuid4().hex[:8]}")
        await user_repo.create(user_id, f"user-{uuid4().hex[:8]}@example.com", "hash")
        await incident_repo.create(incident_id, org_id, service_id, "Test Incident", None, "medium")

        return incident_id, user_id, incident_repo

    async def test_add_event_creates_event(self, db_conn: asyncpg.Connection) -> None:
        """add_event creates an event in the timeline."""
        incident_id, user_id, repo = await self._setup_incident(db_conn)
        event_id = uuid4()

        result = await repo.add_event(
            event_id, incident_id, "status_changed",
            actor_user_id=user_id,
            payload={"from": "triggered", "to": "acknowledged"}
        )

        assert result["id"] == event_id
        assert result["incident_id"] == incident_id
        assert result["event_type"] == "status_changed"
        assert result["actor_user_id"] == user_id
        assert result["payload"] == {"from": "triggered", "to": "acknowledged"}
        assert result["created_at"] is not None

    async def test_add_event_allows_null_actor(self, db_conn: asyncpg.Connection) -> None:
        """add_event allows null actor_user_id (system events)."""
        incident_id, _, repo = await self._setup_incident(db_conn)

        result = await repo.add_event(
            uuid4(), incident_id, "auto_escalated",
            actor_user_id=None,
            payload={"reason": "Unacknowledged after 30 minutes"}
        )

        assert result["actor_user_id"] is None

    async def test_add_event_allows_null_payload(self, db_conn: asyncpg.Connection) -> None:
        """add_event allows null payload."""
        incident_id, user_id, repo = await self._setup_incident(db_conn)

        result = await repo.add_event(
            uuid4(), incident_id, "viewed",
            actor_user_id=user_id,
            payload=None
        )

        assert result["payload"] is None

    async def test_get_events_returns_all_incident_events(self, db_conn: asyncpg.Connection) -> None:
        """get_events returns all events for an incident."""
        incident_id, user_id, repo = await self._setup_incident(db_conn)

        await repo.add_event(uuid4(), incident_id, "created", user_id, {"title": "Test"})
        await repo.add_event(uuid4(), incident_id, "status_changed", user_id, {"to": "acked"})
        await repo.add_event(uuid4(), incident_id, "comment_added", user_id, {"text": "Working on it"})

        result = await repo.get_events(incident_id)

        assert len(result) == 3
        event_types = [e["event_type"] for e in result]
        assert event_types == ["created", "status_changed", "comment_added"]

    async def test_get_events_orders_by_created_at(self, db_conn: asyncpg.Connection) -> None:
        """get_events returns events in chronological order."""
        incident_id, user_id, repo = await self._setup_incident(db_conn)

        await repo.add_event(uuid4(), incident_id, "first", user_id, None)
        await repo.add_event(uuid4(), incident_id, "second", user_id, None)
        await repo.add_event(uuid4(), incident_id, "third", user_id, None)

        result = await repo.get_events(incident_id)

        assert result[0]["event_type"] == "first"
        assert result[1]["event_type"] == "second"
        assert result[2]["event_type"] == "third"

    async def test_get_events_returns_empty_for_no_events(self, db_conn: asyncpg.Connection) -> None:
        """get_events returns empty list for incident with no events."""
        incident_id, _, repo = await self._setup_incident(db_conn)

        result = await repo.get_events(incident_id)

        assert result == []

    async def test_event_requires_valid_incident_foreign_key(self, db_conn: asyncpg.Connection) -> None:
        """incident_events.incident_id must reference existing incident."""
        # The created incident is deliberately unused: we point at a random
        # ID that does not exist.
        _, user_id, repo = await self._setup_incident(db_conn)

        with pytest.raises(asyncpg.ForeignKeyViolationError):
            await repo.add_event(uuid4(), uuid4(), "test", user_id, None)

    async def test_event_requires_valid_user_foreign_key(self, db_conn: asyncpg.Connection) -> None:
        """incident_events.actor_user_id must reference existing user if not null."""
        incident_id, _, repo = await self._setup_incident(db_conn)

        with pytest.raises(asyncpg.ForeignKeyViolationError):
            await repo.add_event(uuid4(), incident_id, "test", uuid4(), None)
|
||||
362
tests/repositories/test_notification.py
Normal file
362
tests/repositories/test_notification.py
Normal file
@@ -0,0 +1,362 @@
|
||||
"""Tests for NotificationRepository."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from uuid import uuid4
|
||||
|
||||
import asyncpg
|
||||
import pytest
|
||||
|
||||
from app.repositories.incident import IncidentRepository
|
||||
from app.repositories.notification import NotificationRepository
|
||||
from app.repositories.org import OrgRepository
|
||||
from app.repositories.service import ServiceRepository
|
||||
|
||||
|
||||
class TestNotificationTargetRepository:
    """Tests for notification targets per SPECS.md notification_targets table."""

    async def _create_org(self, conn: asyncpg.Connection, slug: str) -> UUID:
        """Helper to create an org and return its id.

        Fixed annotation: the helper returns a UUID instance, not the
        ``uuid4`` factory function.
        """
        org_repo = OrgRepository(conn)
        org_id = uuid4()
        await org_repo.create(org_id, f"Org {slug}", slug)
        return org_id

    async def test_create_target_returns_target_data(self, db_conn: asyncpg.Connection) -> None:
        """Creating a notification target returns the target data."""
        org_id = await self._create_org(db_conn, "target-org")
        repo = NotificationRepository(db_conn)
        target_id = uuid4()

        result = await repo.create_target(
            target_id, org_id, "Slack Alerts",
            target_type="webhook",
            webhook_url="https://hooks.slack.com/services/xxx",
            enabled=True
        )

        assert result["id"] == target_id
        assert result["org_id"] == org_id
        assert result["name"] == "Slack Alerts"
        assert result["target_type"] == "webhook"
        assert result["webhook_url"] == "https://hooks.slack.com/services/xxx"
        assert result["enabled"] is True
        assert result["created_at"] is not None

    async def test_create_target_type_must_be_valid(self, db_conn: asyncpg.Connection) -> None:
        """Target type must be webhook, email, or slack per SPECS.md."""
        org_id = await self._create_org(db_conn, "type-org")
        repo = NotificationRepository(db_conn)

        # Valid types
        for target_type in ["webhook", "email", "slack"]:
            result = await repo.create_target(
                uuid4(), org_id, f"{target_type} target",
                target_type=target_type
            )
            assert result["target_type"] == target_type

        # Invalid type
        with pytest.raises(asyncpg.CheckViolationError):
            await repo.create_target(
                uuid4(), org_id, "Invalid",
                target_type="sms"
            )

    async def test_create_target_webhook_url_optional(self, db_conn: asyncpg.Connection) -> None:
        """webhook_url is optional (for email/slack types)."""
        org_id = await self._create_org(db_conn, "optional-url-org")
        repo = NotificationRepository(db_conn)

        result = await repo.create_target(
            uuid4(), org_id, "Email Alerts",
            target_type="email",
            webhook_url=None
        )

        assert result["webhook_url"] is None

    async def test_create_target_enabled_defaults_to_true(self, db_conn: asyncpg.Connection) -> None:
        """enabled defaults to True."""
        org_id = await self._create_org(db_conn, "default-enabled-org")
        repo = NotificationRepository(db_conn)

        result = await repo.create_target(
            uuid4(), org_id, "Default Enabled",
            target_type="webhook"
        )

        assert result["enabled"] is True

    async def test_get_target_by_id_returns_target(self, db_conn: asyncpg.Connection) -> None:
        """get_target_by_id returns the correct target."""
        org_id = await self._create_org(db_conn, "getbyid-target-org")
        repo = NotificationRepository(db_conn)
        target_id = uuid4()

        await repo.create_target(target_id, org_id, "My Target", "webhook")
        result = await repo.get_target_by_id(target_id)

        assert result is not None
        assert result["id"] == target_id
        assert result["name"] == "My Target"

    async def test_get_target_by_id_returns_none_for_nonexistent(self, db_conn: asyncpg.Connection) -> None:
        """get_target_by_id returns None for non-existent target."""
        repo = NotificationRepository(db_conn)

        result = await repo.get_target_by_id(uuid4())

        assert result is None

    async def test_get_targets_by_org_returns_all_targets(self, db_conn: asyncpg.Connection) -> None:
        """get_targets_by_org returns all targets for an organization."""
        org_id = await self._create_org(db_conn, "multi-target-org")
        repo = NotificationRepository(db_conn)

        await repo.create_target(uuid4(), org_id, "Target A", "webhook")
        await repo.create_target(uuid4(), org_id, "Target B", "email")
        await repo.create_target(uuid4(), org_id, "Target C", "slack")

        result = await repo.get_targets_by_org(org_id)

        assert len(result) == 3
        names = {t["name"] for t in result}
        assert names == {"Target A", "Target B", "Target C"}

    async def test_get_targets_by_org_filters_enabled(self, db_conn: asyncpg.Connection) -> None:
        """get_targets_by_org can filter to only enabled targets."""
        org_id = await self._create_org(db_conn, "enabled-filter-org")
        repo = NotificationRepository(db_conn)

        await repo.create_target(uuid4(), org_id, "Enabled", "webhook", enabled=True)
        await repo.create_target(uuid4(), org_id, "Disabled", "webhook", enabled=False)

        result = await repo.get_targets_by_org(org_id, enabled_only=True)

        assert len(result) == 1
        assert result[0]["name"] == "Enabled"

    async def test_get_targets_by_org_tenant_isolation(self, db_conn: asyncpg.Connection) -> None:
        """get_targets_by_org only returns targets for the specified org."""
        org1 = await self._create_org(db_conn, "isolated-target-org-1")
        org2 = await self._create_org(db_conn, "isolated-target-org-2")
        repo = NotificationRepository(db_conn)

        await repo.create_target(uuid4(), org1, "Org1 Target", "webhook")
        await repo.create_target(uuid4(), org2, "Org2 Target", "webhook")

        result = await repo.get_targets_by_org(org1)

        assert len(result) == 1
        assert result[0]["name"] == "Org1 Target"

    async def test_update_target_updates_fields(self, db_conn: asyncpg.Connection) -> None:
        """update_target updates the specified fields."""
        org_id = await self._create_org(db_conn, "update-target-org")
        repo = NotificationRepository(db_conn)
        target_id = uuid4()

        await repo.create_target(target_id, org_id, "Original", "webhook", enabled=True)
        result = await repo.update_target(target_id, name="Updated", enabled=False)

        assert result is not None
        assert result["name"] == "Updated"
        assert result["enabled"] is False

    async def test_update_target_partial_update(self, db_conn: asyncpg.Connection) -> None:
        """update_target only updates provided fields."""
        org_id = await self._create_org(db_conn, "partial-update-org")
        repo = NotificationRepository(db_conn)
        target_id = uuid4()

        await repo.create_target(
            target_id, org_id, "Original Name", "webhook",
            webhook_url="https://original.com", enabled=True
        )
        result = await repo.update_target(target_id, name="New Name")

        assert result["name"] == "New Name"
        assert result["webhook_url"] == "https://original.com"
        assert result["enabled"] is True

    async def test_delete_target_removes_target(self, db_conn: asyncpg.Connection) -> None:
        """delete_target removes the target."""
        org_id = await self._create_org(db_conn, "delete-target-org")
        repo = NotificationRepository(db_conn)
        target_id = uuid4()

        await repo.create_target(target_id, org_id, "To Delete", "webhook")
        result = await repo.delete_target(target_id)

        assert result is True
        assert await repo.get_target_by_id(target_id) is None

    async def test_delete_target_returns_false_for_nonexistent(self, db_conn: asyncpg.Connection) -> None:
        """delete_target returns False for non-existent target."""
        repo = NotificationRepository(db_conn)

        result = await repo.delete_target(uuid4())

        assert result is False

    async def test_target_requires_valid_org_foreign_key(self, db_conn: asyncpg.Connection) -> None:
        """notification_targets.org_id must reference existing org."""
        repo = NotificationRepository(db_conn)

        with pytest.raises(asyncpg.ForeignKeyViolationError):
            await repo.create_target(uuid4(), uuid4(), "Orphan Target", "webhook")
||||
class TestNotificationAttemptRepository:
    """Tests for notification attempts per SPECS.md notification_attempts table."""

    async def _setup_incident_and_target(self, conn: asyncpg.Connection) -> tuple[UUID, UUID, NotificationRepository]:
        """Helper to create org, service, incident, and notification target.

        Fixed annotation: the first two elements are UUID instances, not the
        ``uuid4`` factory function.
        """
        org_repo = OrgRepository(conn)
        service_repo = ServiceRepository(conn)
        incident_repo = IncidentRepository(conn)
        notification_repo = NotificationRepository(conn)

        org_id = uuid4()
        service_id = uuid4()
        incident_id = uuid4()
        target_id = uuid4()

        # Random hex suffixes keep slugs unique across tests sharing one database.
        await org_repo.create(org_id, "Test Org", f"test-org-{uuid4().hex[:8]}")
        await service_repo.create(service_id, org_id, "Test Service", f"test-svc-{uuid4().hex[:8]}")
        await incident_repo.create(incident_id, org_id, service_id, "Test Incident", None, "medium")
        await notification_repo.create_target(target_id, org_id, "Test Target", "webhook")

        return incident_id, target_id, notification_repo

    async def test_create_attempt_returns_attempt_data(self, db_conn: asyncpg.Connection) -> None:
        """Creating a notification attempt returns the attempt data."""
        incident_id, target_id, repo = await self._setup_incident_and_target(db_conn)
        attempt_id = uuid4()

        result = await repo.create_attempt(attempt_id, incident_id, target_id)

        assert result["id"] == attempt_id
        assert result["incident_id"] == incident_id
        assert result["target_id"] == target_id
        assert result["status"] == "pending"
        assert result["error"] is None
        assert result["sent_at"] is None
        assert result["created_at"] is not None

    async def test_create_attempt_idempotent(self, db_conn: asyncpg.Connection) -> None:
        """create_attempt is idempotent per SPECS.md (unique constraint on incident+target)."""
        incident_id, target_id, repo = await self._setup_incident_and_target(db_conn)

        # First attempt
        result1 = await repo.create_attempt(uuid4(), incident_id, target_id)
        # Second attempt with same incident+target
        result2 = await repo.create_attempt(uuid4(), incident_id, target_id)

        # Should return the same attempt
        assert result1["id"] == result2["id"]

    async def test_get_attempt_returns_attempt(self, db_conn: asyncpg.Connection) -> None:
        """get_attempt returns the attempt for incident and target."""
        incident_id, target_id, repo = await self._setup_incident_and_target(db_conn)

        await repo.create_attempt(uuid4(), incident_id, target_id)
        result = await repo.get_attempt(incident_id, target_id)

        assert result is not None
        assert result["incident_id"] == incident_id
        assert result["target_id"] == target_id

    async def test_get_attempt_returns_none_for_nonexistent(self, db_conn: asyncpg.Connection) -> None:
        """get_attempt returns None for non-existent attempt."""
        incident_id, target_id, repo = await self._setup_incident_and_target(db_conn)

        result = await repo.get_attempt(incident_id, target_id)

        assert result is None

    async def test_update_attempt_success_sets_sent_status(self, db_conn: asyncpg.Connection) -> None:
        """update_attempt_success marks attempt as sent."""
        incident_id, target_id, repo = await self._setup_incident_and_target(db_conn)
        attempt = await repo.create_attempt(uuid4(), incident_id, target_id)
        sent_at = datetime.now(UTC)

        result = await repo.update_attempt_success(attempt["id"], sent_at)

        assert result is not None
        assert result["status"] == "sent"
        assert result["sent_at"] is not None
        assert result["error"] is None

    async def test_update_attempt_failure_sets_failed_status(self, db_conn: asyncpg.Connection) -> None:
        """update_attempt_failure marks attempt as failed with error."""
        incident_id, target_id, repo = await self._setup_incident_and_target(db_conn)
        attempt = await repo.create_attempt(uuid4(), incident_id, target_id)

        result = await repo.update_attempt_failure(attempt["id"], "Connection timeout")

        assert result is not None
        assert result["status"] == "failed"
        assert result["error"] == "Connection timeout"

    async def test_attempt_status_must_be_valid(self, db_conn: asyncpg.Connection) -> None:
        """Attempt status must be pending, sent, or failed per SPECS.md."""
        incident_id, target_id, repo = await self._setup_incident_and_target(db_conn)

        # Create with default 'pending' status - valid
        result = await repo.create_attempt(uuid4(), incident_id, target_id)
        assert result["status"] == "pending"

        # Transition to 'sent' - valid
        result = await repo.update_attempt_success(result["id"], datetime.now(UTC))
        assert result["status"] == "sent"

    async def test_get_pending_attempts_returns_pending_with_target_info(self, db_conn: asyncpg.Connection) -> None:
        """get_pending_attempts returns pending attempts with target details."""
        incident_id, target_id, repo = await self._setup_incident_and_target(db_conn)

        await repo.create_attempt(uuid4(), incident_id, target_id)
        result = await repo.get_pending_attempts(incident_id)

        assert len(result) == 1
        assert result[0]["status"] == "pending"
        assert result[0]["target_id"] == target_id
        assert "target_type" in result[0]
        assert "target_name" in result[0]

    async def test_get_pending_attempts_excludes_sent(self, db_conn: asyncpg.Connection) -> None:
        """get_pending_attempts excludes sent attempts."""
        incident_id, target_id, repo = await self._setup_incident_and_target(db_conn)

        attempt = await repo.create_attempt(uuid4(), incident_id, target_id)
        await repo.update_attempt_success(attempt["id"], datetime.now(UTC))

        result = await repo.get_pending_attempts(incident_id)

        assert len(result) == 0

    async def test_get_pending_attempts_excludes_failed(self, db_conn: asyncpg.Connection) -> None:
        """get_pending_attempts excludes failed attempts."""
        incident_id, target_id, repo = await self._setup_incident_and_target(db_conn)

        attempt = await repo.create_attempt(uuid4(), incident_id, target_id)
        await repo.update_attempt_failure(attempt["id"], "Error")

        result = await repo.get_pending_attempts(incident_id)

        assert len(result) == 0

    async def test_attempt_requires_valid_incident_foreign_key(self, db_conn: asyncpg.Connection) -> None:
        """notification_attempts.incident_id must reference existing incident."""
        _, target_id, repo = await self._setup_incident_and_target(db_conn)

        with pytest.raises(asyncpg.ForeignKeyViolationError):
            await repo.create_attempt(uuid4(), uuid4(), target_id)

    async def test_attempt_requires_valid_target_foreign_key(self, db_conn: asyncpg.Connection) -> None:
        """notification_attempts.target_id must reference existing target."""
        incident_id, _, repo = await self._setup_incident_and_target(db_conn)

        with pytest.raises(asyncpg.ForeignKeyViolationError):
            await repo.create_attempt(uuid4(), incident_id, uuid4())
250
tests/repositories/test_org.py
Normal file
250
tests/repositories/test_org.py
Normal file
@@ -0,0 +1,250 @@
|
||||
"""Tests for OrgRepository."""
|
||||
|
||||
from uuid import UUID, uuid4

import asyncpg
import pytest

from app.repositories.org import OrgRepository
from app.repositories.user import UserRepository
|
||||
|
||||
|
||||
class TestOrgRepository:
    """Tests for OrgRepository conforming to SPECS.md."""

    async def test_create_org_returns_org_data(self, db_conn: asyncpg.Connection) -> None:
        """Creating an org returns the org data."""
        org_repo = OrgRepository(db_conn)
        new_org_id = uuid4()

        created = await org_repo.create(new_org_id, "Test Organization", "test-org")

        assert created["id"] == new_org_id
        assert created["name"] == "Test Organization"
        assert created["slug"] == "test-org"
        assert created["created_at"] is not None

    async def test_create_org_slug_must_be_unique(self, db_conn: asyncpg.Connection) -> None:
        """Org slug uniqueness constraint per SPECS.md orgs table."""
        org_repo = OrgRepository(db_conn)
        await org_repo.create(uuid4(), "Org One", "unique-slug")

        # A second org reusing the same slug must hit the unique constraint.
        with pytest.raises(asyncpg.UniqueViolationError):
            await org_repo.create(uuid4(), "Org Two", "unique-slug")

    async def test_get_by_id_returns_org(self, db_conn: asyncpg.Connection) -> None:
        """get_by_id returns the correct organization."""
        org_repo = OrgRepository(db_conn)
        wanted_id = uuid4()
        await org_repo.create(wanted_id, "My Org", "my-org")

        fetched = await org_repo.get_by_id(wanted_id)

        assert fetched is not None
        assert fetched["id"] == wanted_id
        assert fetched["name"] == "My Org"

    async def test_get_by_id_returns_none_for_nonexistent(self, db_conn: asyncpg.Connection) -> None:
        """get_by_id returns None for non-existent org."""
        org_repo = OrgRepository(db_conn)

        assert await org_repo.get_by_id(uuid4()) is None

    async def test_get_by_slug_returns_org(self, db_conn: asyncpg.Connection) -> None:
        """get_by_slug returns the correct organization."""
        org_repo = OrgRepository(db_conn)
        wanted_id = uuid4()
        await org_repo.create(wanted_id, "Slug Test", "slug-lookup")

        fetched = await org_repo.get_by_slug("slug-lookup")

        assert fetched is not None
        assert fetched["id"] == wanted_id

    async def test_get_by_slug_returns_none_for_nonexistent(self, db_conn: asyncpg.Connection) -> None:
        """get_by_slug returns None for non-existent slug."""
        org_repo = OrgRepository(db_conn)

        assert await org_repo.get_by_slug("nonexistent-slug") is None

    async def test_slug_exists_returns_true_when_exists(self, db_conn: asyncpg.Connection) -> None:
        """slug_exists returns True when slug exists."""
        org_repo = OrgRepository(db_conn)
        await org_repo.create(uuid4(), "Existing Org", "existing-slug")

        assert await org_repo.slug_exists("existing-slug") is True

    async def test_slug_exists_returns_false_when_not_exists(self, db_conn: asyncpg.Connection) -> None:
        """slug_exists returns False when slug doesn't exist."""
        org_repo = OrgRepository(db_conn)

        assert await org_repo.slug_exists("no-such-slug") is False
|
||||
class TestOrgMembership:
    """Tests for org membership operations per SPECS.md org_members table."""

    async def _create_user(self, conn: asyncpg.Connection, email: str) -> UUID:
        """Helper to create a user and return its id.

        Fixed annotation: the helper returns a UUID instance, not the
        ``uuid4`` factory function.
        """
        user_repo = UserRepository(conn)
        user_id = uuid4()
        await user_repo.create(user_id, email, "hash")
        return user_id

    async def _create_org(self, conn: asyncpg.Connection, slug: str) -> UUID:
        """Helper to create an org and return its id."""
        org_repo = OrgRepository(conn)
        org_id = uuid4()
        await org_repo.create(org_id, f"Org {slug}", slug)
        return org_id

    async def test_add_member_creates_membership(self, db_conn: asyncpg.Connection) -> None:
        """add_member creates a membership record."""
        user_id = await self._create_user(db_conn, "member@example.com")
        org_id = await self._create_org(db_conn, "member-org")
        repo = OrgRepository(db_conn)

        result = await repo.add_member(uuid4(), user_id, org_id, "member")

        assert result["user_id"] == user_id
        assert result["org_id"] == org_id
        assert result["role"] == "member"
        assert result["created_at"] is not None

    async def test_add_member_role_must_be_valid(self, db_conn: asyncpg.Connection) -> None:
        """Role must be admin, member, or viewer per SPECS.md."""
        org_id = await self._create_org(db_conn, "role-test-org")
        repo = OrgRepository(db_conn)

        # Valid roles should work
        for role in ["admin", "member", "viewer"]:
            member_id = uuid4()
            # Need a new user for each since user+org must be unique
            new_user_id = await self._create_user(db_conn, f"{role}@example.com")
            result = await repo.add_member(member_id, new_user_id, org_id, role)
            assert result["role"] == role

        # Invalid role should fail
        another_user = await self._create_user(db_conn, "invalid_role@example.com")
        with pytest.raises(asyncpg.CheckViolationError):
            await repo.add_member(uuid4(), another_user, org_id, "superuser")

    async def test_add_member_user_org_must_be_unique(self, db_conn: asyncpg.Connection) -> None:
        """User can only be member of an org once (unique constraint)."""
        user_id = await self._create_user(db_conn, "unique_member@example.com")
        org_id = await self._create_org(db_conn, "unique-member-org")
        repo = OrgRepository(db_conn)

        await repo.add_member(uuid4(), user_id, org_id, "member")

        with pytest.raises(asyncpg.UniqueViolationError):
            await repo.add_member(uuid4(), user_id, org_id, "admin")

    async def test_get_member_returns_membership(self, db_conn: asyncpg.Connection) -> None:
        """get_member returns the membership for user and org."""
        user_id = await self._create_user(db_conn, "get_member@example.com")
        org_id = await self._create_org(db_conn, "get-member-org")
        repo = OrgRepository(db_conn)

        await repo.add_member(uuid4(), user_id, org_id, "admin")
        result = await repo.get_member(user_id, org_id)

        assert result is not None
        assert result["user_id"] == user_id
        assert result["org_id"] == org_id
        assert result["role"] == "admin"

    async def test_get_member_returns_none_for_nonmember(self, db_conn: asyncpg.Connection) -> None:
        """get_member returns None if user is not a member."""
        user_id = await self._create_user(db_conn, "nonmember@example.com")
        org_id = await self._create_org(db_conn, "nonmember-org")
        repo = OrgRepository(db_conn)

        result = await repo.get_member(user_id, org_id)

        assert result is None

    async def test_get_members_returns_all_org_members(self, db_conn: asyncpg.Connection) -> None:
        """get_members returns all members with their emails."""
        org_id = await self._create_org(db_conn, "all-members-org")
        user1 = await self._create_user(db_conn, "user1@example.com")
        user2 = await self._create_user(db_conn, "user2@example.com")
        user3 = await self._create_user(db_conn, "user3@example.com")
        repo = OrgRepository(db_conn)

        await repo.add_member(uuid4(), user1, org_id, "admin")
        await repo.add_member(uuid4(), user2, org_id, "member")
        await repo.add_member(uuid4(), user3, org_id, "viewer")

        result = await repo.get_members(org_id)

        assert len(result) == 3
        emails = {m["email"] for m in result}
        assert emails == {"user1@example.com", "user2@example.com", "user3@example.com"}

    async def test_get_members_returns_empty_list_for_no_members(self, db_conn: asyncpg.Connection) -> None:
        """get_members returns empty list for org with no members."""
        org_id = await self._create_org(db_conn, "empty-org")
        repo = OrgRepository(db_conn)

        result = await repo.get_members(org_id)

        assert result == []

    async def test_get_user_orgs_returns_all_user_memberships(self, db_conn: asyncpg.Connection) -> None:
        """get_user_orgs returns all orgs a user belongs to with their role."""
        user_id = await self._create_user(db_conn, "multi_org@example.com")
        org1 = await self._create_org(db_conn, "user-org-1")
        org2 = await self._create_org(db_conn, "user-org-2")
        repo = OrgRepository(db_conn)

        await repo.add_member(uuid4(), user_id, org1, "admin")
        await repo.add_member(uuid4(), user_id, org2, "member")

        result = await repo.get_user_orgs(user_id)

        assert len(result) == 2
        slugs = {o["slug"] for o in result}
        assert slugs == {"user-org-1", "user-org-2"}
        # Check role is included
        roles = {o["role"] for o in result}
        assert roles == {"admin", "member"}

    async def test_get_user_orgs_returns_empty_for_no_memberships(self, db_conn: asyncpg.Connection) -> None:
        """get_user_orgs returns empty list for user with no memberships."""
        user_id = await self._create_user(db_conn, "no_orgs@example.com")
        repo = OrgRepository(db_conn)

        result = await repo.get_user_orgs(user_id)

        assert result == []

    async def test_member_requires_valid_user_foreign_key(self, db_conn: asyncpg.Connection) -> None:
        """org_members.user_id must reference existing user."""
        org_id = await self._create_org(db_conn, "fk-test-org")
        repo = OrgRepository(db_conn)

        with pytest.raises(asyncpg.ForeignKeyViolationError):
            await repo.add_member(uuid4(), uuid4(), org_id, "member")

    async def test_member_requires_valid_org_foreign_key(self, db_conn: asyncpg.Connection) -> None:
        """org_members.org_id must reference existing org."""
        user_id = await self._create_user(db_conn, "fk_user@example.com")
        repo = OrgRepository(db_conn)

        with pytest.raises(asyncpg.ForeignKeyViolationError):
            await repo.add_member(uuid4(), user_id, uuid4(), "member")
788
tests/repositories/test_refresh_token.py
Normal file
788
tests/repositories/test_refresh_token.py
Normal file
@@ -0,0 +1,788 @@
|
||||
"""Tests for RefreshTokenRepository with security features."""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
from uuid import UUID, uuid4

import asyncpg
import pytest

from app.repositories.org import OrgRepository
from app.repositories.refresh_token import RefreshTokenRepository
from app.repositories.user import UserRepository
|
||||
|
||||
|
||||
class TestRefreshTokenRepository:
    """Tests for basic RefreshTokenRepository operations."""

    async def _create_user(self, conn: asyncpg.Connection, email: str) -> UUID:
        """Helper to create a user and return its id.

        Fixed annotation: the helper returns a UUID instance, not the
        ``uuid4`` factory function.
        """
        user_repo = UserRepository(conn)
        user_id = uuid4()
        await user_repo.create(user_id, email, "hash")
        return user_id

    async def _create_org(self, conn: asyncpg.Connection, slug: str) -> UUID:
        """Helper to create an org and return its id."""
        org_repo = OrgRepository(conn)
        org_id = uuid4()
        await org_repo.create(org_id, f"Org {slug}", slug)
        return org_id

    async def test_create_token_returns_token_data(self, db_conn: asyncpg.Connection) -> None:
        """Creating a refresh token returns the token data including rotated_to."""
        user_id = await self._create_user(db_conn, "token_create@example.com")
        org_id = await self._create_org(db_conn, "token-create-org")
        repo = RefreshTokenRepository(db_conn)
        token_id = uuid4()
        token_hash = "sha256_hashed_token_value"
        expires_at = datetime.now(UTC) + timedelta(days=30)

        result = await repo.create(token_id, user_id, token_hash, org_id, expires_at)

        assert result["id"] == token_id
        assert result["user_id"] == user_id
        assert result["token_hash"] == token_hash
        assert result["active_org_id"] == org_id
        assert result["expires_at"] is not None
        assert result["revoked_at"] is None
        assert result["rotated_to"] is None  # New field
        assert result["created_at"] is not None

    async def test_token_hash_must_be_unique(self, db_conn: asyncpg.Connection) -> None:
        """Token hash uniqueness constraint per SPECS.md refresh_tokens table."""
        user_id = await self._create_user(db_conn, "unique_hash@example.com")
        org_id = await self._create_org(db_conn, "unique-hash-org")
        repo = RefreshTokenRepository(db_conn)
        token_hash = "duplicate_hash_value"
        expires_at = datetime.now(UTC) + timedelta(days=30)

        await repo.create(uuid4(), user_id, token_hash, org_id, expires_at)

        with pytest.raises(asyncpg.UniqueViolationError):
            await repo.create(uuid4(), user_id, token_hash, org_id, expires_at)

    async def test_get_by_hash_returns_token(self, db_conn: asyncpg.Connection) -> None:
        """get_by_hash returns the correct token (even if revoked/expired)."""
        user_id = await self._create_user(db_conn, "get_hash@example.com")
        org_id = await self._create_org(db_conn, "get-hash-org")
        repo = RefreshTokenRepository(db_conn)
        token_id = uuid4()
        token_hash = "lookup_by_hash_value"
        expires_at = datetime.now(UTC) + timedelta(days=30)

        await repo.create(token_id, user_id, token_hash, org_id, expires_at)
        result = await repo.get_by_hash(token_hash)

        assert result is not None
        assert result["id"] == token_id
        assert result["token_hash"] == token_hash

    async def test_get_by_hash_returns_none_for_nonexistent(self, db_conn: asyncpg.Connection) -> None:
        """get_by_hash returns None for non-existent hash."""
        repo = RefreshTokenRepository(db_conn)

        result = await repo.get_by_hash("nonexistent_hash")

        assert result is None
||||
class TestGetValidByHash:
|
||||
"""Tests for get_valid_by_hash with defense-in-depth validation."""
|
||||
|
||||
async def _setup_token(
|
||||
self, conn: asyncpg.Connection, suffix: str = ""
|
||||
) -> tuple[uuid4, uuid4, uuid4, str, RefreshTokenRepository]:
|
||||
"""Helper to create user, org, and token."""
|
||||
user_repo = UserRepository(conn)
|
||||
org_repo = OrgRepository(conn)
|
||||
token_repo = RefreshTokenRepository(conn)
|
||||
|
||||
user_id = uuid4()
|
||||
org_id = uuid4()
|
||||
token_id = uuid4()
|
||||
token_hash = f"token_hash_{uuid4().hex[:8]}{suffix}"
|
||||
|
||||
await user_repo.create(user_id, f"user_{uuid4().hex[:8]}@example.com", "hash")
|
||||
await org_repo.create(org_id, "Test Org", f"test-org-{uuid4().hex[:8]}")
|
||||
|
||||
expires_at = datetime.now(UTC) + timedelta(days=30)
|
||||
await token_repo.create(token_id, user_id, token_hash, org_id, expires_at)
|
||||
|
||||
return token_id, user_id, org_id, token_hash, token_repo
|
||||
|
||||
async def test_get_valid_returns_valid_token(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""get_valid_by_hash returns token if not expired, not revoked, not rotated."""
|
||||
_, _, _, token_hash, repo = await self._setup_token(db_conn)
|
||||
|
||||
result = await repo.get_valid_by_hash(token_hash)
|
||||
|
||||
assert result is not None
|
||||
assert result["token_hash"] == token_hash
|
||||
|
||||
async def test_get_valid_returns_none_for_expired(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""get_valid_by_hash returns None for expired token."""
|
||||
user_repo = UserRepository(db_conn)
|
||||
org_repo = OrgRepository(db_conn)
|
||||
repo = RefreshTokenRepository(db_conn)
|
||||
|
||||
user_id = uuid4()
|
||||
org_id = uuid4()
|
||||
await user_repo.create(user_id, "expired@example.com", "hash")
|
||||
await org_repo.create(org_id, "Org", "expired-org")
|
||||
|
||||
token_hash = "expired_token_hash"
|
||||
expires_at = datetime.now(UTC) - timedelta(days=1) # Already expired
|
||||
await repo.create(uuid4(), user_id, token_hash, org_id, expires_at)
|
||||
|
||||
result = await repo.get_valid_by_hash(token_hash)
|
||||
|
||||
assert result is None
|
||||
|
||||
async def test_get_valid_returns_none_for_revoked(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""get_valid_by_hash returns None for revoked token."""
|
||||
token_id, _, _, token_hash, repo = await self._setup_token(db_conn, "_revoked")
|
||||
|
||||
await repo.revoke(token_id)
|
||||
result = await repo.get_valid_by_hash(token_hash)
|
||||
|
||||
assert result is None
|
||||
|
||||
async def test_get_valid_returns_none_for_rotated(self, db_conn: asyncpg.Connection) -> None:
    """An already-rotated token is never returned by get_valid_by_hash."""
    _, _, _, previous_hash, tokens = await self._setup_token(db_conn, "_rotated")

    # Rotating consumes the original hash and replaces it with a new one.
    replacement_hash = f"new_token_{uuid4().hex[:8]}"
    replacement_expiry = datetime.now(UTC) + timedelta(days=30)
    await tokens.rotate(previous_hash, uuid4(), replacement_hash, replacement_expiry)

    # The consumed hash must no longer validate.
    assert await tokens.get_valid_by_hash(previous_hash) is None
async def test_get_valid_with_user_id_validation(self, db_conn: asyncpg.Connection) -> None:
    """get_valid_by_hash validates user_id when provided (defense-in-depth)."""
    _, owner_id, _, stored_hash, tokens = await self._setup_token(db_conn, "_user_check")

    # The real owner passes the extra check.
    assert await tokens.get_valid_by_hash(stored_hash, user_id=owner_id) is not None

    # Any other user id is rejected even though the hash matches.
    assert await tokens.get_valid_by_hash(stored_hash, user_id=uuid4()) is None
async def test_get_valid_with_org_id_validation(self, db_conn: asyncpg.Connection) -> None:
    """get_valid_by_hash validates active_org_id when provided (defense-in-depth)."""
    _, _, tenant_id, stored_hash, tokens = await self._setup_token(db_conn, "_org_check")

    # The token's own org passes the extra check.
    assert await tokens.get_valid_by_hash(stored_hash, active_org_id=tenant_id) is not None

    # Any other org id is rejected even though the hash matches.
    assert await tokens.get_valid_by_hash(stored_hash, active_org_id=uuid4()) is None
async def test_get_valid_with_both_user_and_org_validation(self, db_conn: asyncpg.Connection) -> None:
    """get_valid_by_hash validates user_id and active_org_id together."""
    _, owner_id, tenant_id, stored_hash, tokens = await self._setup_token(db_conn, "_both")

    # Matching user and org both pass.
    assert await tokens.get_valid_by_hash(stored_hash, user_id=owner_id, active_org_id=tenant_id) is not None

    # A mismatch on either dimension rejects the token.
    assert await tokens.get_valid_by_hash(stored_hash, user_id=uuid4(), active_org_id=tenant_id) is None
    assert await tokens.get_valid_by_hash(stored_hash, user_id=owner_id, active_org_id=uuid4()) is None
class TestAtomicRotation:
    """Tests for atomic token rotation per SPECS.md."""

    # NOTE(review): db_conn is presumably a per-test asyncpg connection fixture
    # from conftest — confirm whether it rolls back between tests.

    async def _setup_token(
        self, conn: asyncpg.Connection
    ) -> tuple[uuid4, uuid4, uuid4, str, RefreshTokenRepository]:
        # NOTE(review): `uuid4` in this annotation is a function, not a type;
        # the intent is likely uuid.UUID — harmless at runtime, confirm and fix.
        """Helper to create user, org, and token.

        Returns (token_id, user_id, org_id, token_hash, repo).
        """
        user_repo = UserRepository(conn)
        org_repo = OrgRepository(conn)
        token_repo = RefreshTokenRepository(conn)

        user_id = uuid4()
        org_id = uuid4()
        token_id = uuid4()
        # Random suffixes keep fixtures unique across tests sharing one database.
        token_hash = f"rotate_token_{uuid4().hex[:8]}"

        await user_repo.create(user_id, f"rotate_{uuid4().hex[:8]}@example.com", "hash")
        await org_repo.create(org_id, "Rotate Org", f"rotate-org-{uuid4().hex[:8]}")

        expires_at = datetime.now(UTC) + timedelta(days=30)
        await token_repo.create(token_id, user_id, token_hash, org_id, expires_at)

        return token_id, user_id, org_id, token_hash, token_repo

    async def test_rotate_creates_new_token(self, db_conn: asyncpg.Connection) -> None:
        """rotate() creates a new token and returns it."""
        old_id, user_id, org_id, old_hash, repo = await self._setup_token(db_conn)

        new_id = uuid4()
        new_hash = f"new_rotated_{uuid4().hex[:8]}"
        new_expires = datetime.now(UTC) + timedelta(days=30)

        result = await repo.rotate(old_hash, new_id, new_hash, new_expires)

        # New row carries over user/org from the rotated token.
        assert result is not None
        assert result["id"] == new_id
        assert result["token_hash"] == new_hash
        assert result["user_id"] == user_id
        assert result["active_org_id"] == org_id

    async def test_rotate_marks_old_token_as_rotated(self, db_conn: asyncpg.Connection) -> None:
        """rotate() sets rotated_to on the old token (not revoked_at)."""
        old_id, _, _, old_hash, repo = await self._setup_token(db_conn)

        new_id = uuid4()
        new_hash = f"new_{uuid4().hex[:8]}"
        new_expires = datetime.now(UTC) + timedelta(days=30)

        await repo.rotate(old_hash, new_id, new_hash, new_expires)

        # Check old token state
        old_token = await repo.get_by_hash(old_hash)
        assert old_token["rotated_to"] == new_id
        assert old_token["revoked_at"] is None  # Not revoked, just rotated

    async def test_rotate_fails_for_invalid_token(self, db_conn: asyncpg.Connection) -> None:
        """rotate() returns None if old token is invalid."""
        _, _, _, _, repo = await self._setup_token(db_conn)

        result = await repo.rotate(
            "nonexistent_hash",
            uuid4(),
            f"new_{uuid4().hex[:8]}",
            datetime.now(UTC) + timedelta(days=30),
        )

        assert result is None

    async def test_rotate_fails_for_expired_token(self, db_conn: asyncpg.Connection) -> None:
        """rotate() returns None if old token is expired."""
        user_repo = UserRepository(db_conn)
        org_repo = OrgRepository(db_conn)
        repo = RefreshTokenRepository(db_conn)

        user_id = uuid4()
        org_id = uuid4()
        await user_repo.create(user_id, "exp_rotate@example.com", "hash")
        await org_repo.create(org_id, "Org", "exp-rotate-org")

        old_hash = "expired_for_rotation"
        await repo.create(
            uuid4(), user_id, old_hash, org_id,
            datetime.now(UTC) - timedelta(days=1)  # Already expired
        )

        result = await repo.rotate(
            old_hash, uuid4(), f"new_{uuid4().hex[:8]}",
            datetime.now(UTC) + timedelta(days=30)
        )

        assert result is None

    async def test_rotate_fails_for_revoked_token(self, db_conn: asyncpg.Connection) -> None:
        """rotate() returns None if old token is revoked."""
        old_id, _, _, old_hash, repo = await self._setup_token(db_conn)

        await repo.revoke(old_id)

        result = await repo.rotate(
            old_hash, uuid4(), f"new_{uuid4().hex[:8]}",
            datetime.now(UTC) + timedelta(days=30)
        )

        assert result is None

    async def test_rotate_fails_for_already_rotated_token(self, db_conn: asyncpg.Connection) -> None:
        """rotate() returns None if old token was already rotated."""
        # This is the single-use guarantee: one hash can be rotated at most once.
        _, _, _, old_hash, repo = await self._setup_token(db_conn)

        # First rotation should succeed
        result1 = await repo.rotate(
            old_hash, uuid4(), f"new1_{uuid4().hex[:8]}",
            datetime.now(UTC) + timedelta(days=30)
        )
        assert result1 is not None

        # Second rotation of same token should fail
        result2 = await repo.rotate(
            old_hash, uuid4(), f"new2_{uuid4().hex[:8]}",
            datetime.now(UTC) + timedelta(days=30)
        )
        assert result2 is None

    async def test_rotate_with_org_switch(self, db_conn: asyncpg.Connection) -> None:
        """rotate() can change active_org_id (for org-switch flow)."""
        _, user_id, old_org_id, old_hash, repo = await self._setup_token(db_conn)

        # Create a new org for the user to switch to
        org_repo = OrgRepository(db_conn)
        new_org_id = uuid4()
        await org_repo.create(new_org_id, "New Org", f"new-org-{uuid4().hex[:8]}")

        new_hash = f"switched_{uuid4().hex[:8]}"
        result = await repo.rotate(
            old_hash, uuid4(), new_hash,
            datetime.now(UTC) + timedelta(days=30),
            new_active_org_id=new_org_id  # Switch org
        )

        assert result is not None
        assert result["active_org_id"] == new_org_id
        assert result["active_org_id"] != old_org_id

    async def test_rotate_validates_expected_user_id(self, db_conn: asyncpg.Connection) -> None:
        """rotate() fails if expected_user_id doesn't match token's user."""
        _, user_id, _, old_hash, repo = await self._setup_token(db_conn)

        # Wrong user should fail
        result = await repo.rotate(
            old_hash, uuid4(), f"new_{uuid4().hex[:8]}",
            datetime.now(UTC) + timedelta(days=30),
            expected_user_id=uuid4()  # Wrong user
        )
        assert result is None

        # Correct user should work
        # (relies on the failed attempt above NOT consuming the token).
        result = await repo.rotate(
            old_hash, uuid4(), f"new_{uuid4().hex[:8]}",
            datetime.now(UTC) + timedelta(days=30),
            expected_user_id=user_id  # Correct user
        )
        assert result is not None
class TestTokenReuseDetection:
    """Tests for detecting token reuse (stolen token attacks)."""

    async def _setup_rotated_token(
        self, conn: asyncpg.Connection
    ) -> tuple[uuid4, str, str, RefreshTokenRepository]:
        # NOTE(review): `uuid4` in this annotation is a function, not a type;
        # likely meant uuid.UUID — confirm.
        """Create a token and rotate it, returning old and new hashes.

        Returns (old_token_id, old_hash, new_hash, repo).
        """
        user_repo = UserRepository(conn)
        org_repo = OrgRepository(conn)
        token_repo = RefreshTokenRepository(conn)

        user_id = uuid4()
        org_id = uuid4()
        await user_repo.create(user_id, f"reuse_{uuid4().hex[:8]}@example.com", "hash")
        await org_repo.create(org_id, "Reuse Org", f"reuse-org-{uuid4().hex[:8]}")

        old_hash = f"old_token_{uuid4().hex[:8]}"
        expires_at = datetime.now(UTC) + timedelta(days=30)
        old_token = await token_repo.create(uuid4(), user_id, old_hash, org_id, expires_at)

        # Rotation consumes old_hash — presenting it again is a reuse signal.
        new_hash = f"new_token_{uuid4().hex[:8]}"
        await token_repo.rotate(old_hash, uuid4(), new_hash, expires_at)

        return old_token["id"], old_hash, new_hash, token_repo

    async def test_check_token_reuse_detects_rotated_token(self, db_conn: asyncpg.Connection) -> None:
        """check_token_reuse returns token if it has been rotated."""
        old_id, old_hash, _, repo = await self._setup_rotated_token(db_conn)

        result = await repo.check_token_reuse(old_hash)

        assert result is not None
        assert result["id"] == old_id
        assert result["rotated_to"] is not None

    async def test_check_token_reuse_returns_none_for_active_token(self, db_conn: asyncpg.Connection) -> None:
        """check_token_reuse returns None for token that hasn't been rotated."""
        user_repo = UserRepository(db_conn)
        org_repo = OrgRepository(db_conn)
        repo = RefreshTokenRepository(db_conn)

        user_id = uuid4()
        org_id = uuid4()
        await user_repo.create(user_id, "active@example.com", "hash")
        await org_repo.create(org_id, "Org", "active-org")

        token_hash = "active_token_hash"
        await repo.create(
            uuid4(), user_id, token_hash, org_id,
            datetime.now(UTC) + timedelta(days=30)
        )

        result = await repo.check_token_reuse(token_hash)

        assert result is None

    async def test_check_token_reuse_returns_none_for_nonexistent(self, db_conn: asyncpg.Connection) -> None:
        """check_token_reuse returns None for non-existent token."""
        repo = RefreshTokenRepository(db_conn)

        result = await repo.check_token_reuse("nonexistent_hash")

        assert result is None
class TestTokenChainRevocation:
    """Tests for revoking entire token chains (breach response)."""

    async def _setup_token_chain(
        self, conn: asyncpg.Connection, chain_length: int = 3
    ) -> tuple[list[uuid4], list[str], uuid4, RefreshTokenRepository]:
        # NOTE(review): `uuid4` in this annotation is a function, not a type;
        # likely meant uuid.UUID — confirm.
        """Create a chain of rotated tokens.

        Returns (token_ids, token_hashes, user_id, repo), oldest first.
        """
        user_repo = UserRepository(conn)
        org_repo = OrgRepository(conn)
        token_repo = RefreshTokenRepository(conn)

        user_id = uuid4()
        org_id = uuid4()
        await user_repo.create(user_id, f"chain_{uuid4().hex[:8]}@example.com", "hash")
        await org_repo.create(org_id, "Chain Org", f"chain-org-{uuid4().hex[:8]}")

        token_ids = []
        token_hashes = []
        expires_at = datetime.now(UTC) + timedelta(days=30)

        # Create first token
        first_hash = f"chain_token_0_{uuid4().hex[:8]}"
        first_token = await token_repo.create(uuid4(), user_id, first_hash, org_id, expires_at)
        token_ids.append(first_token["id"])
        token_hashes.append(first_hash)

        # Rotate to create chain
        # Each rotation links the previous token to the next via rotated_to.
        current_hash = first_hash
        for i in range(1, chain_length):
            new_hash = f"chain_token_{i}_{uuid4().hex[:8]}"
            new_id = uuid4()
            await token_repo.rotate(current_hash, new_id, new_hash, expires_at)
            token_ids.append(new_id)
            token_hashes.append(new_hash)
            current_hash = new_hash

        return token_ids, token_hashes, user_id, token_repo

    async def test_revoke_token_chain_revokes_all_in_chain(self, db_conn: asyncpg.Connection) -> None:
        """revoke_token_chain revokes the token and all its rotations."""
        token_ids, token_hashes, _, repo = await self._setup_token_chain(db_conn, chain_length=3)

        # Revoke starting from the first token
        count = await repo.revoke_token_chain(token_ids[0])

        # Should revoke all 3 tokens in the chain
        # But note: only the last one wasn't already "consumed" by rotation
        # Let's check that revoke was called on all that were eligible
        assert count >= 1  # At least the leaf token

        # Verify the leaf token is revoked
        leaf_token = await repo.get_by_hash(token_hashes[-1])
        assert leaf_token["revoked_at"] is not None

    async def test_revoke_token_chain_returns_count(self, db_conn: asyncpg.Connection) -> None:
        """revoke_token_chain returns count of actually revoked tokens."""
        user_repo = UserRepository(db_conn)
        org_repo = OrgRepository(db_conn)
        repo = RefreshTokenRepository(db_conn)

        user_id = uuid4()
        org_id = uuid4()
        await user_repo.create(user_id, "single@example.com", "hash")
        await org_repo.create(org_id, "Single Org", "single-org")

        # A chain of length one: revoking it should report exactly one row.
        token_hash = "single_chain_token"
        token = await repo.create(
            uuid4(), user_id, token_hash, org_id,
            datetime.now(UTC) + timedelta(days=30)
        )

        count = await repo.revoke_token_chain(token["id"])

        assert count == 1
class TestTokenRevocation:
    """Tests for token revocation methods."""

    async def _setup_token(self, conn: asyncpg.Connection) -> tuple[uuid4, str, RefreshTokenRepository]:
        # NOTE(review): `uuid4` in this annotation is a function, not a type;
        # likely meant uuid.UUID — confirm.
        """Helper to create user, org, and token.

        Returns (token_id, token_hash, repo).
        """
        user_repo = UserRepository(conn)
        org_repo = OrgRepository(conn)
        token_repo = RefreshTokenRepository(conn)

        user_id = uuid4()
        org_id = uuid4()
        token_id = uuid4()
        token_hash = f"revoke_token_{uuid4().hex[:8]}"

        await user_repo.create(user_id, f"revoke_{uuid4().hex[:8]}@example.com", "hash")
        await org_repo.create(org_id, "Revoke Org", f"revoke-org-{uuid4().hex[:8]}")

        expires_at = datetime.now(UTC) + timedelta(days=30)
        await token_repo.create(token_id, user_id, token_hash, org_id, expires_at)

        return token_id, token_hash, token_repo

    async def test_revoke_sets_revoked_at(self, db_conn: asyncpg.Connection) -> None:
        """revoke() sets the revoked_at timestamp."""
        token_id, token_hash, repo = await self._setup_token(db_conn)

        result = await repo.revoke(token_id)

        assert result is True
        token = await repo.get_by_hash(token_hash)
        assert token["revoked_at"] is not None

    async def test_revoke_returns_true_on_success(self, db_conn: asyncpg.Connection) -> None:
        """revoke() returns True when token is revoked."""
        token_id, _, repo = await self._setup_token(db_conn)

        result = await repo.revoke(token_id)

        assert result is True

    async def test_revoke_returns_false_for_already_revoked(self, db_conn: asyncpg.Connection) -> None:
        """revoke() returns False if token already revoked."""
        # Idempotency check: the second call must be a no-op.
        token_id, _, repo = await self._setup_token(db_conn)

        await repo.revoke(token_id)
        result = await repo.revoke(token_id)

        assert result is False

    async def test_revoke_returns_false_for_nonexistent(self, db_conn: asyncpg.Connection) -> None:
        """revoke() returns False for non-existent token."""
        _, _, repo = await self._setup_token(db_conn)

        result = await repo.revoke(uuid4())

        assert result is False

    async def test_revoke_by_hash_works(self, db_conn: asyncpg.Connection) -> None:
        """revoke_by_hash() revokes token by hash value."""
        _, token_hash, repo = await self._setup_token(db_conn)

        result = await repo.revoke_by_hash(token_hash)

        assert result is True
        token = await repo.get_by_hash(token_hash)
        assert token["revoked_at"] is not None

    async def test_revoke_by_hash_returns_false_for_nonexistent(self, db_conn: asyncpg.Connection) -> None:
        """revoke_by_hash() returns False for non-existent hash."""
        _, _, repo = await self._setup_token(db_conn)

        result = await repo.revoke_by_hash("nonexistent_hash")

        assert result is False
class TestRevokeAllForUser:
    """Tests for revoking all tokens for a user."""

    async def test_revoke_all_for_user_revokes_all_tokens(self, db_conn: asyncpg.Connection) -> None:
        """revoke_all_for_user() revokes all tokens for the user."""
        user_repo = UserRepository(db_conn)
        org_repo = OrgRepository(db_conn)
        token_repo = RefreshTokenRepository(db_conn)

        user_id = uuid4()
        org_id = uuid4()
        await user_repo.create(user_id, "multi_token@example.com", "hash")
        await org_repo.create(org_id, "Multi Token Org", "multi-token-org")

        # Create multiple tokens
        hashes = []
        for i in range(3):
            token_hash = f"token_{i}_{uuid4().hex[:8]}"
            hashes.append(token_hash)
            expires_at = datetime.now(UTC) + timedelta(days=30)
            await token_repo.create(uuid4(), user_id, token_hash, org_id, expires_at)

        result = await token_repo.revoke_all_for_user(user_id)

        # Return value is the number of tokens revoked; none should validate after.
        assert result == 3
        for token_hash in hashes:
            token = await token_repo.get_valid_by_hash(token_hash)
            assert token is None

    async def test_revoke_all_for_user_returns_zero_for_no_tokens(self, db_conn: asyncpg.Connection) -> None:
        """revoke_all_for_user() returns 0 if user has no tokens."""
        user_repo = UserRepository(db_conn)
        token_repo = RefreshTokenRepository(db_conn)

        user_id = uuid4()
        await user_repo.create(user_id, "no_tokens@example.com", "hash")

        result = await token_repo.revoke_all_for_user(user_id)

        assert result == 0

    async def test_revoke_all_for_user_only_affects_user_tokens(self, db_conn: asyncpg.Connection) -> None:
        """revoke_all_for_user() doesn't affect other users' tokens."""
        user_repo = UserRepository(db_conn)
        org_repo = OrgRepository(db_conn)
        token_repo = RefreshTokenRepository(db_conn)

        # Two users sharing the same org, one token each.
        user1 = uuid4()
        user2 = uuid4()
        org_id = uuid4()
        await user_repo.create(user1, "user1@example.com", "hash")
        await user_repo.create(user2, "user2@example.com", "hash")
        await org_repo.create(org_id, "Shared Org", "shared-org")

        user1_hash = f"user1_token_{uuid4().hex[:8]}"
        user2_hash = f"user2_token_{uuid4().hex[:8]}"
        expires_at = datetime.now(UTC) + timedelta(days=30)
        await token_repo.create(uuid4(), user1, user1_hash, org_id, expires_at)
        await token_repo.create(uuid4(), user2, user2_hash, org_id, expires_at)

        await token_repo.revoke_all_for_user(user1)

        # User1's token is revoked
        assert await token_repo.get_valid_by_hash(user1_hash) is None
        # User2's token is still valid
        assert await token_repo.get_valid_by_hash(user2_hash) is not None
class TestRevokeAllExcept:
    """Tests for revoking all tokens except current session."""

    async def test_revoke_all_except_keeps_specified_token(self, db_conn: asyncpg.Connection) -> None:
        """revoke_all_for_user_except() keeps the specified token active."""
        user_repo = UserRepository(db_conn)
        org_repo = OrgRepository(db_conn)
        token_repo = RefreshTokenRepository(db_conn)

        user_id = uuid4()
        org_id = uuid4()
        await user_repo.create(user_id, "except@example.com", "hash")
        await org_repo.create(org_id, "Except Org", "except-org")

        # Create multiple tokens
        # One token to preserve (the "current session")...
        expires_at = datetime.now(UTC) + timedelta(days=30)
        keep_token_id = uuid4()
        keep_hash = f"keep_token_{uuid4().hex[:8]}"
        await token_repo.create(keep_token_id, user_id, keep_hash, org_id, expires_at)

        # ...and two others that should be revoked.
        other_hashes = []
        for i in range(2):
            other_hash = f"other_token_{i}_{uuid4().hex[:8]}"
            other_hashes.append(other_hash)
            await token_repo.create(uuid4(), user_id, other_hash, org_id, expires_at)

        result = await token_repo.revoke_all_for_user_except(user_id, keep_token_id)

        assert result == 2  # Revoked 2 other tokens

        # Keep token is still valid
        assert await token_repo.get_valid_by_hash(keep_hash) is not None

        # Other tokens are revoked
        for other_hash in other_hashes:
            assert await token_repo.get_valid_by_hash(other_hash) is None
class TestActiveTokensForUser:
    """Tests for listing active tokens for a user."""

    async def test_get_active_tokens_returns_only_active(self, db_conn: asyncpg.Connection) -> None:
        """get_active_tokens_for_user() returns only non-revoked, non-expired, non-rotated."""
        user_repo = UserRepository(db_conn)
        org_repo = OrgRepository(db_conn)
        token_repo = RefreshTokenRepository(db_conn)

        user_id = uuid4()
        org_id = uuid4()
        await user_repo.create(user_id, "active_list@example.com", "hash")
        await org_repo.create(org_id, "Active List Org", "active-list-org")

        expires_at = datetime.now(UTC) + timedelta(days=30)
        expired_at = datetime.now(UTC) - timedelta(days=1)

        # Create active token
        active_hash = f"active_{uuid4().hex[:8]}"
        await token_repo.create(uuid4(), user_id, active_hash, org_id, expires_at)

        # Create revoked token
        revoked_id = uuid4()
        revoked_hash = f"revoked_{uuid4().hex[:8]}"
        await token_repo.create(revoked_id, user_id, revoked_hash, org_id, expires_at)
        await token_repo.revoke(revoked_id)

        # Create expired token
        expired_hash = f"expired_{uuid4().hex[:8]}"
        await token_repo.create(uuid4(), user_id, expired_hash, org_id, expired_at)

        # Create rotated token
        # (rotation produces a fresh replacement token, which IS active)
        rotated_hash = f"rotated_{uuid4().hex[:8]}"
        await token_repo.create(uuid4(), user_id, rotated_hash, org_id, expires_at)
        await token_repo.rotate(rotated_hash, uuid4(), f"new_{uuid4().hex[:8]}", expires_at)

        result = await token_repo.get_active_tokens_for_user(user_id)

        # Should only return the active token and the new rotated token
        assert len(result) == 2
        hashes = {t["token_hash"] for t in result}
        assert active_hash in hashes
        assert revoked_hash not in hashes
        assert expired_hash not in hashes
        assert rotated_hash not in hashes
class TestTokenForeignKeys:
    """Tests for refresh token foreign key constraints."""

    async def test_token_requires_valid_user_foreign_key(self, db_conn: asyncpg.Connection) -> None:
        """refresh_tokens.user_id must reference existing user."""
        org_repo = OrgRepository(db_conn)
        token_repo = RefreshTokenRepository(db_conn)

        org_id = uuid4()
        await org_repo.create(org_id, "FK Test Org", "fk-test-org")

        # Random uuid4() as user_id references no users row -> FK violation.
        with pytest.raises(asyncpg.ForeignKeyViolationError):
            await token_repo.create(
                uuid4(), uuid4(), "orphan_token", org_id,
                datetime.now(UTC) + timedelta(days=30)
            )

    async def test_token_requires_valid_org_foreign_key(self, db_conn: asyncpg.Connection) -> None:
        """refresh_tokens.active_org_id must reference existing org."""
        user_repo = UserRepository(db_conn)
        token_repo = RefreshTokenRepository(db_conn)

        user_id = uuid4()
        await user_repo.create(user_id, "fk_org_test@example.com", "hash")

        # Random uuid4() as active_org_id references no orgs row -> FK violation.
        with pytest.raises(asyncpg.ForeignKeyViolationError):
            await token_repo.create(
                uuid4(), user_id, "orphan_org_token", uuid4(),
                datetime.now(UTC) + timedelta(days=30)
            )

    async def test_token_stores_active_org_id(self, db_conn: asyncpg.Connection) -> None:
        """Token stores active_org_id for org context per SPECS.md."""
        user_repo = UserRepository(db_conn)
        org_repo = OrgRepository(db_conn)
        token_repo = RefreshTokenRepository(db_conn)

        user_id = uuid4()
        org_id = uuid4()
        await user_repo.create(user_id, "active_org@example.com", "hash")
        await org_repo.create(org_id, "Active Org", "active-org")

        token_hash = f"active_org_token_{uuid4().hex[:8]}"
        await token_repo.create(
            uuid4(), user_id, token_hash, org_id,
            datetime.now(UTC) + timedelta(days=30)
        )

        token = await token_repo.get_by_hash(token_hash)
        assert token["active_org_id"] == org_id
201
tests/repositories/test_service.py
Normal file
201
tests/repositories/test_service.py
Normal file
@@ -0,0 +1,201 @@
|
||||
"""Tests for ServiceRepository."""
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
import asyncpg
|
||||
import pytest
|
||||
|
||||
from app.repositories.org import OrgRepository
|
||||
from app.repositories.service import ServiceRepository
|
||||
|
||||
|
||||
class TestServiceRepository:
|
||||
"""Tests for ServiceRepository conforming to SPECS.md."""
|
||||
|
||||
async def _create_org(self, conn: asyncpg.Connection, slug: str) -> uuid4:
    """Insert a fresh org with the given slug and return its generated id."""
    # NOTE(review): the `uuid4` return annotation is a function, not a type —
    # likely meant uuid.UUID; kept for signature compatibility.
    new_id = uuid4()
    await OrgRepository(conn).create(new_id, f"Org {slug}", slug)
    return new_id
async def test_create_service_returns_service_data(self, db_conn: asyncpg.Connection) -> None:
    """Creating a service echoes back the stored row."""
    owning_org = await self._create_org(db_conn, "service-org")
    services = ServiceRepository(db_conn)
    new_id = uuid4()

    row = await services.create(new_id, owning_org, "API Gateway", "api-gateway")

    assert row["id"] == new_id
    assert row["org_id"] == owning_org
    assert row["name"] == "API Gateway"
    assert row["slug"] == "api-gateway"
    assert row["created_at"] is not None
async def test_create_service_slug_unique_per_org(self, db_conn: asyncpg.Connection) -> None:
    """A duplicate slug inside a single org violates the unique constraint (SPECS.md)."""
    owning_org = await self._create_org(db_conn, "unique-slug-org")
    services = ServiceRepository(db_conn)

    await services.create(uuid4(), owning_org, "Service One", "my-service")

    # Second insert with the same (org, slug) pair must be rejected.
    with pytest.raises(asyncpg.UniqueViolationError):
        await services.create(uuid4(), owning_org, "Service Two", "my-service")
async def test_same_slug_allowed_in_different_orgs(self, db_conn: asyncpg.Connection) -> None:
    """The slug uniqueness constraint is scoped per org, not global."""
    first_org = await self._create_org(db_conn, "org-one")
    second_org = await self._create_org(db_conn, "org-two")
    services = ServiceRepository(db_conn)
    slug = "shared-slug"

    # Neither insert should raise.
    row_a = await services.create(uuid4(), first_org, "Service Org1", slug)
    row_b = await services.create(uuid4(), second_org, "Service Org2", slug)

    assert row_a["slug"] == slug
    assert row_b["slug"] == slug
    assert row_a["org_id"] != row_b["org_id"]
async def test_get_by_id_returns_service(self, db_conn: asyncpg.Connection) -> None:
    """get_by_id finds a service by its primary key."""
    owning_org = await self._create_org(db_conn, "get-service-org")
    services = ServiceRepository(db_conn)
    target_id = uuid4()

    await services.create(target_id, owning_org, "My Service", "my-service")
    found = await services.get_by_id(target_id)

    assert found is not None
    assert found["id"] == target_id
    assert found["name"] == "My Service"
async def test_get_by_id_returns_none_for_nonexistent(self, db_conn: asyncpg.Connection) -> None:
    """get_by_id yields None when no row matches the id."""
    services = ServiceRepository(db_conn)

    assert await services.get_by_id(uuid4()) is None
async def test_get_by_org_returns_all_org_services(self, db_conn: asyncpg.Connection) -> None:
    """get_by_org lists every service registered under the org."""
    owning_org = await self._create_org(db_conn, "multi-service-org")
    services = ServiceRepository(db_conn)

    # Insert three services in order A, B, C.
    fixtures = {"Service A": "service-a", "Service B": "service-b", "Service C": "service-c"}
    for label, slug in fixtures.items():
        await services.create(uuid4(), owning_org, label, slug)

    rows = await services.get_by_org(owning_org)

    assert len(rows) == 3
    assert {r["name"] for r in rows} == set(fixtures)
async def test_get_by_org_returns_empty_for_no_services(self, db_conn: asyncpg.Connection) -> None:
    """get_by_org yields an empty list when the org owns nothing."""
    lonely_org = await self._create_org(db_conn, "empty-service-org")
    services = ServiceRepository(db_conn)

    assert await services.get_by_org(lonely_org) == []
async def test_get_by_org_only_returns_own_services(self, db_conn: asyncpg.Connection) -> None:
    """get_by_org never leaks another org's services (tenant isolation)."""
    first_org = await self._create_org(db_conn, "isolated-org-1")
    second_org = await self._create_org(db_conn, "isolated-org-2")
    services = ServiceRepository(db_conn)

    await services.create(uuid4(), first_org, "Org1 Service", "org1-service")
    await services.create(uuid4(), second_org, "Org2 Service", "org2-service")

    rows = await services.get_by_org(first_org)

    # Only the first org's single service is visible.
    assert [r["name"] for r in rows] == ["Org1 Service"]
async def test_get_by_slug_returns_service(self, db_conn: asyncpg.Connection) -> None:
    """get_by_slug resolves a service within its org by slug."""
    owning_org = await self._create_org(db_conn, "slug-lookup-org")
    services = ServiceRepository(db_conn)
    target_id = uuid4()

    await services.create(target_id, owning_org, "Slug Service", "slug-service")
    found = await services.get_by_slug(owning_org, "slug-service")

    assert found is not None
    assert found["id"] == target_id
async def test_get_by_slug_returns_none_for_wrong_org(self, db_conn: asyncpg.Connection) -> None:
    """A slug that exists only in another org resolves to None."""
    org1 = await self._create_org(db_conn, "slug-org-1")
    org2 = await self._create_org(db_conn, "slug-org-2")
    repo = ServiceRepository(db_conn)
    await repo.create(uuid4(), org1, "Service", "the-slug")

    # Same slug, different tenant: must not resolve.
    assert await repo.get_by_slug(org2, "the-slug") is None
async def test_get_by_slug_returns_none_for_nonexistent(self, db_conn: asyncpg.Connection) -> None:
    """Unknown slugs resolve to None rather than raising."""
    org_id = await self._create_org(db_conn, "no-slug-org")
    repo = ServiceRepository(db_conn)

    assert await repo.get_by_slug(org_id, "nonexistent") is None
async def test_slug_exists_returns_true_when_exists(self, db_conn: asyncpg.Connection) -> None:
    """slug_exists reports True for a slug already present in the org."""
    org_id = await self._create_org(db_conn, "exists-org")
    repo = ServiceRepository(db_conn)
    await repo.create(uuid4(), org_id, "Exists Service", "exists-slug")

    # Strict identity check: the repository must return an actual bool.
    assert await repo.slug_exists(org_id, "exists-slug") is True
async def test_slug_exists_returns_false_when_not_exists(self, db_conn: asyncpg.Connection) -> None:
    """slug_exists reports False when the org has no such slug."""
    org_id = await self._create_org(db_conn, "not-exists-org")
    repo = ServiceRepository(db_conn)

    assert await repo.slug_exists(org_id, "no-such-slug") is False
async def test_slug_exists_returns_false_for_other_org(self, db_conn: asyncpg.Connection) -> None:
    """slug_exists is tenant-scoped: another org's slug does not count."""
    org1 = await self._create_org(db_conn, "other-org-1")
    org2 = await self._create_org(db_conn, "other-org-2")
    repo = ServiceRepository(db_conn)
    await repo.create(uuid4(), org1, "Service", "cross-org-slug")

    # The slug exists globally, but not within org2's namespace.
    assert await repo.slug_exists(org2, "cross-org-slug") is False
async def test_service_requires_valid_org_foreign_key(self, db_conn: asyncpg.Connection) -> None:
    """services.org_id must reference existing org."""
    repo = ServiceRepository(db_conn)

    # A random UUID matches no row in orgs, so the FK constraint must reject
    # the insert with asyncpg's specific violation error (not a generic one).
    with pytest.raises(asyncpg.ForeignKeyViolationError):
        await repo.create(uuid4(), uuid4(), "Orphan Service", "orphan")
async def test_get_by_org_orders_by_name(self, db_conn: asyncpg.Connection) -> None:
    """Services come back from get_by_org sorted alphabetically by name."""
    org_id = await self._create_org(db_conn, "ordered-org")
    repo = ServiceRepository(db_conn)

    # Insert deliberately out of alphabetical order.
    for name, slug in (("Zebra", "zebra"), ("Alpha", "alpha"), ("Middle", "middle")):
        await repo.create(uuid4(), org_id, name, slug)

    rows = await repo.get_by_org(org_id)

    assert [row["name"] for row in rows] == ["Alpha", "Middle", "Zebra"]
133
tests/repositories/test_user.py
Normal file
133
tests/repositories/test_user.py
Normal file
@@ -0,0 +1,133 @@
|
||||
"""Tests for UserRepository."""
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
import asyncpg
|
||||
import pytest
|
||||
|
||||
from app.repositories.user import UserRepository
|
||||
|
||||
|
||||
class TestUserRepository:
    """Tests for UserRepository conforming to SPECS.md."""

    async def test_create_user_returns_user_data(self, db_conn: asyncpg.Connection) -> None:
        """Creating a user returns the user data with id, email, created_at."""
        repo = UserRepository(db_conn)
        user_id = uuid4()
        email = "test@example.com"
        password_hash = "hashed_password_123"

        result = await repo.create(user_id, email, password_hash)

        # created_at is populated by the database, so only non-None is asserted.
        assert result["id"] == user_id
        assert result["email"] == email
        assert result["created_at"] is not None

    async def test_create_user_stores_password_hash(self, db_conn: asyncpg.Connection) -> None:
        """Password hash is stored correctly in the database."""
        repo = UserRepository(db_conn)
        user_id = uuid4()
        email = "hash_test@example.com"
        password_hash = "bcrypt_hashed_value"

        # Round-trip through create + get_by_id to confirm the hash persists verbatim.
        await repo.create(user_id, email, password_hash)
        user = await repo.get_by_id(user_id)

        assert user["password_hash"] == password_hash

    async def test_create_user_email_must_be_unique(self, db_conn: asyncpg.Connection) -> None:
        """Email uniqueness constraint per SPECS.md users table."""
        repo = UserRepository(db_conn)
        email = "duplicate@example.com"

        await repo.create(uuid4(), email, "hash1")

        # Second insert with the same email must hit the unique index.
        with pytest.raises(asyncpg.UniqueViolationError):
            await repo.create(uuid4(), email, "hash2")

    async def test_get_by_id_returns_user(self, db_conn: asyncpg.Connection) -> None:
        """get_by_id returns the correct user."""
        repo = UserRepository(db_conn)
        user_id = uuid4()
        email = "getbyid@example.com"

        await repo.create(user_id, email, "hash")
        result = await repo.get_by_id(user_id)

        assert result is not None
        assert result["id"] == user_id
        assert result["email"] == email

    async def test_get_by_id_returns_none_for_nonexistent(self, db_conn: asyncpg.Connection) -> None:
        """get_by_id returns None for non-existent user."""
        repo = UserRepository(db_conn)

        # A random UUID is effectively guaranteed not to exist.
        result = await repo.get_by_id(uuid4())

        assert result is None

    async def test_get_by_email_returns_user(self, db_conn: asyncpg.Connection) -> None:
        """get_by_email returns the correct user."""
        repo = UserRepository(db_conn)
        user_id = uuid4()
        email = "getbyemail@example.com"

        await repo.create(user_id, email, "hash")
        result = await repo.get_by_email(email)

        assert result is not None
        assert result["id"] == user_id
        assert result["email"] == email

    async def test_get_by_email_returns_none_for_nonexistent(self, db_conn: asyncpg.Connection) -> None:
        """get_by_email returns None for non-existent email."""
        repo = UserRepository(db_conn)

        result = await repo.get_by_email("nonexistent@example.com")

        assert result is None

    async def test_get_by_email_is_case_sensitive(self, db_conn: asyncpg.Connection) -> None:
        """Email lookup is case-sensitive (stored as provided)."""
        # NOTE(review): this pins exact-case matching as the contract; if
        # case-insensitive lookup is ever desired, the spec (not just this
        # test) needs to change.
        repo = UserRepository(db_conn)
        email = "CaseSensitive@Example.com"

        await repo.create(uuid4(), email, "hash")

        # Exact match works
        result = await repo.get_by_email(email)
        assert result is not None

        # Different case returns None
        result = await repo.get_by_email(email.lower())
        assert result is None

    async def test_exists_by_email_returns_true_when_exists(self, db_conn: asyncpg.Connection) -> None:
        """exists_by_email returns True when email exists."""
        repo = UserRepository(db_conn)
        email = "exists@example.com"

        await repo.create(uuid4(), email, "hash")
        result = await repo.exists_by_email(email)

        # Identity check: the repository must return an actual bool.
        assert result is True

    async def test_exists_by_email_returns_false_when_not_exists(self, db_conn: asyncpg.Connection) -> None:
        """exists_by_email returns False when email doesn't exist."""
        repo = UserRepository(db_conn)

        result = await repo.exists_by_email("notexists@example.com")

        assert result is False

    async def test_user_id_is_uuid_primary_key(self, db_conn: asyncpg.Connection) -> None:
        """User ID must be a valid UUID (primary key)."""
        repo = UserRepository(db_conn)
        user_id = uuid4()

        await repo.create(user_id, "pk_test@example.com", "hash")

        # Duplicate ID should fail
        with pytest.raises(asyncpg.UniqueViolationError):
            await repo.create(user_id, "other@example.com", "hash")
31
uv.lock
generated
31
uv.lock
generated
@@ -309,6 +309,15 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/e8/cb/2da4cc83f5edb9c3257d09e1e7ab7b23f049c7962cae8d842bbef0a9cec9/cryptography-46.0.3-cp38-abi3-win_arm64.whl", hash = "sha256:d89c3468de4cdc4f08a57e214384d0471911a3830fcdaf7a8cc587e42a866372", size = 2918740, upload-time = "2025-10-15T23:18:12.277Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "dnspython"
|
||||
version = "2.8.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/8c/8b/57666417c0f90f08bcafa776861060426765fdb422eb10212086fb811d26/dnspython-2.8.0.tar.gz", hash = "sha256:181d3c6996452cb1189c4046c61599b84a5a86e099562ffde77d26984ff26d0f", size = 368251, upload-time = "2025-09-07T18:58:00.022Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/ba/5a/18ad964b0086c6e62e2e7500f7edc89e3faa45033c71c1893d34eed2b2de/dnspython-2.8.0-py3-none-any.whl", hash = "sha256:01d9bbc4a2d76bf0db7c1f729812ded6d912bd318d3b1cf81d30c0f845dbf3af", size = 331094, upload-time = "2025-09-07T18:57:58.071Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ecdsa"
|
||||
version = "0.19.1"
|
||||
@@ -321,6 +330,19 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/cb/a3/460c57f094a4a165c84a1341c373b0a4f5ec6ac244b998d5021aade89b77/ecdsa-0.19.1-py2.py3-none-any.whl", hash = "sha256:30638e27cf77b7e15c4c4cc1973720149e1033827cfd00661ca5c8cc0cdb24c3", size = 150607, upload-time = "2025-03-13T11:52:41.757Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "email-validator"
|
||||
version = "2.3.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "dnspython" },
|
||||
{ name = "idna" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/f5/22/900cb125c76b7aaa450ce02fd727f452243f2e91a61af068b40adba60ea9/email_validator-2.3.0.tar.gz", hash = "sha256:9fc05c37f2f6cf439ff414f8fc46d917929974a82244c20eb10231ba60c54426", size = 51238, upload-time = "2025-08-26T13:09:06.831Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/de/15/545e2b6cf2e3be84bc1ed85613edd75b8aea69807a71c26f4ca6a9258e82/email_validator-2.3.0-py3-none-any.whl", hash = "sha256:80f13f623413e6b197ae73bb10bf4eb0908faf509ad8362c5edeb0be7fd450b4", size = 35604, upload-time = "2025-08-26T13:09:05.858Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fastapi"
|
||||
version = "0.128.0"
|
||||
@@ -407,7 +429,7 @@ dependencies = [
|
||||
{ name = "celery", extra = ["redis"] },
|
||||
{ name = "fastapi" },
|
||||
{ name = "httpx" },
|
||||
{ name = "pydantic" },
|
||||
{ name = "pydantic", extra = ["email"] },
|
||||
{ name = "pydantic-settings" },
|
||||
{ name = "python-jose", extra = ["cryptography"] },
|
||||
{ name = "redis" },
|
||||
@@ -428,7 +450,7 @@ requires-dist = [
|
||||
{ name = "celery", extras = ["redis"], specifier = ">=5.4.0" },
|
||||
{ name = "fastapi", specifier = ">=0.115.0" },
|
||||
{ name = "httpx", specifier = ">=0.28.0" },
|
||||
{ name = "pydantic", specifier = ">=2.0.0" },
|
||||
{ name = "pydantic", extras = ["email"], specifier = ">=2.0.0" },
|
||||
{ name = "pydantic-settings", specifier = ">=2.0.0" },
|
||||
{ name = "pytest", marker = "extra == 'dev'", specifier = ">=8.0.0" },
|
||||
{ name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=0.24.0" },
|
||||
@@ -531,6 +553,11 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/5a/87/b70ad306ebb6f9b585f114d0ac2137d792b48be34d732d60e597c2f8465a/pydantic-2.12.5-py3-none-any.whl", hash = "sha256:e561593fccf61e8a20fc46dfc2dfe075b8be7d0188df33f221ad1f0139180f9d", size = 463580, upload-time = "2025-11-26T15:11:44.605Z" },
|
||||
]
|
||||
|
||||
[package.optional-dependencies]
|
||||
email = [
|
||||
{ name = "email-validator" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pydantic-core"
|
||||
version = "2.41.5"
|
||||
|
||||
Reference in New Issue
Block a user