feat(api): Pydantic schemas + Data Repositories
This commit is contained in:
17
app/repositories/__init__.py
Normal file
17
app/repositories/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""Repository layer for database operations."""
|
||||
|
||||
from app.repositories.incident import IncidentRepository
|
||||
from app.repositories.notification import NotificationRepository
|
||||
from app.repositories.org import OrgRepository
|
||||
from app.repositories.refresh_token import RefreshTokenRepository
|
||||
from app.repositories.service import ServiceRepository
|
||||
from app.repositories.user import UserRepository
|
||||
|
||||
__all__ = [
|
||||
"IncidentRepository",
|
||||
"NotificationRepository",
|
||||
"OrgRepository",
|
||||
"RefreshTokenRepository",
|
||||
"ServiceRepository",
|
||||
"UserRepository",
|
||||
]
|
||||
161
app/repositories/incident.py
Normal file
161
app/repositories/incident.py
Normal file
@@ -0,0 +1,161 @@
|
||||
"""Incident repository for database operations."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
import asyncpg
|
||||
|
||||
|
||||
class IncidentRepository:
|
||||
"""Database operations for incidents."""
|
||||
|
||||
def __init__(self, conn: asyncpg.Connection) -> None:
|
||||
self.conn = conn
|
||||
|
||||
async def create(
|
||||
self,
|
||||
incident_id: UUID,
|
||||
org_id: UUID,
|
||||
service_id: UUID,
|
||||
title: str,
|
||||
description: str | None,
|
||||
severity: str,
|
||||
) -> dict:
|
||||
"""Create a new incident."""
|
||||
row = await self.conn.fetchrow(
|
||||
"""
|
||||
INSERT INTO incidents (id, org_id, service_id, title, description, status, severity)
|
||||
VALUES ($1, $2, $3, $4, $5, 'triggered', $6)
|
||||
RETURNING id, org_id, service_id, title, description, status, severity,
|
||||
version, created_at, updated_at
|
||||
""",
|
||||
incident_id,
|
||||
org_id,
|
||||
service_id,
|
||||
title,
|
||||
description,
|
||||
severity,
|
||||
)
|
||||
return dict(row)
|
||||
|
||||
async def get_by_id(self, incident_id: UUID) -> dict | None:
|
||||
"""Get incident by ID."""
|
||||
row = await self.conn.fetchrow(
|
||||
"""
|
||||
SELECT id, org_id, service_id, title, description, status, severity,
|
||||
version, created_at, updated_at
|
||||
FROM incidents
|
||||
WHERE id = $1
|
||||
""",
|
||||
incident_id,
|
||||
)
|
||||
return dict(row) if row else None
|
||||
|
||||
async def get_by_org(
|
||||
self,
|
||||
org_id: UUID,
|
||||
status: str | None = None,
|
||||
cursor: datetime | None = None,
|
||||
limit: int = 20,
|
||||
) -> list[dict]:
|
||||
"""Get incidents for an organization with optional filtering and pagination."""
|
||||
query = """
|
||||
SELECT id, org_id, service_id, title, description, status, severity,
|
||||
version, created_at, updated_at
|
||||
FROM incidents
|
||||
WHERE org_id = $1
|
||||
"""
|
||||
params: list[Any] = [org_id]
|
||||
param_idx = 2
|
||||
|
||||
if status:
|
||||
query += f" AND status = ${param_idx}"
|
||||
params.append(status)
|
||||
param_idx += 1
|
||||
|
||||
if cursor:
|
||||
query += f" AND created_at < ${param_idx}"
|
||||
params.append(cursor)
|
||||
param_idx += 1
|
||||
|
||||
query += f" ORDER BY created_at DESC LIMIT ${param_idx}"
|
||||
params.append(limit + 1) # Fetch one extra to check if there are more
|
||||
|
||||
rows = await self.conn.fetch(query, *params)
|
||||
return [dict(row) for row in rows]
|
||||
|
||||
async def update_status(
|
||||
self,
|
||||
incident_id: UUID,
|
||||
new_status: str,
|
||||
expected_version: int,
|
||||
) -> dict | None:
|
||||
"""Update incident status with optimistic locking.
|
||||
|
||||
Returns updated incident if successful, None if version mismatch.
|
||||
"""
|
||||
row = await self.conn.fetchrow(
|
||||
"""
|
||||
UPDATE incidents
|
||||
SET status = $2, version = version + 1, updated_at = now()
|
||||
WHERE id = $1 AND version = $3
|
||||
RETURNING id, org_id, service_id, title, description, status, severity,
|
||||
version, created_at, updated_at
|
||||
""",
|
||||
incident_id,
|
||||
new_status,
|
||||
expected_version,
|
||||
)
|
||||
return dict(row) if row else None
|
||||
|
||||
async def add_event(
|
||||
self,
|
||||
event_id: UUID,
|
||||
incident_id: UUID,
|
||||
event_type: str,
|
||||
actor_user_id: UUID | None,
|
||||
payload: dict[str, Any] | None,
|
||||
) -> dict:
|
||||
"""Add an event to the incident timeline."""
|
||||
import json
|
||||
|
||||
row = await self.conn.fetchrow(
|
||||
"""
|
||||
INSERT INTO incident_events (id, incident_id, event_type, actor_user_id, payload)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
RETURNING id, incident_id, event_type, actor_user_id, payload, created_at
|
||||
""",
|
||||
event_id,
|
||||
incident_id,
|
||||
event_type,
|
||||
actor_user_id,
|
||||
json.dumps(payload) if payload else None,
|
||||
)
|
||||
result = dict(row)
|
||||
|
||||
# Parse JSON payload back to dict
|
||||
if result["payload"]:
|
||||
result["payload"] = json.loads(result["payload"])
|
||||
return result
|
||||
|
||||
async def get_events(self, incident_id: UUID) -> list[dict]:
|
||||
"""Get all events for an incident."""
|
||||
import json
|
||||
|
||||
rows = await self.conn.fetch(
|
||||
"""
|
||||
SELECT id, incident_id, event_type, actor_user_id, payload, created_at
|
||||
FROM incident_events
|
||||
WHERE incident_id = $1
|
||||
ORDER BY created_at
|
||||
""",
|
||||
incident_id,
|
||||
)
|
||||
results = []
|
||||
for row in rows:
|
||||
result = dict(row)
|
||||
if result["payload"]:
|
||||
result["payload"] = json.loads(result["payload"])
|
||||
results.append(result)
|
||||
return results
|
||||
199
app/repositories/notification.py
Normal file
199
app/repositories/notification.py
Normal file
@@ -0,0 +1,199 @@
|
||||
"""Notification repository for database operations."""
|
||||
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
import asyncpg
|
||||
|
||||
|
||||
class NotificationRepository:
|
||||
"""Database operations for notification targets and attempts."""
|
||||
|
||||
def __init__(self, conn: asyncpg.Connection) -> None:
|
||||
self.conn = conn
|
||||
|
||||
async def create_target(
|
||||
self,
|
||||
target_id: UUID,
|
||||
org_id: UUID,
|
||||
name: str,
|
||||
target_type: str,
|
||||
webhook_url: str | None = None,
|
||||
enabled: bool = True,
|
||||
) -> dict:
|
||||
"""Create a new notification target."""
|
||||
row = await self.conn.fetchrow(
|
||||
"""
|
||||
INSERT INTO notification_targets (id, org_id, name, target_type, webhook_url, enabled)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)
|
||||
RETURNING id, org_id, name, target_type, webhook_url, enabled, created_at
|
||||
""",
|
||||
target_id,
|
||||
org_id,
|
||||
name,
|
||||
target_type,
|
||||
webhook_url,
|
||||
enabled,
|
||||
)
|
||||
return dict(row)
|
||||
|
||||
async def get_target_by_id(self, target_id: UUID) -> dict | None:
|
||||
"""Get notification target by ID."""
|
||||
row = await self.conn.fetchrow(
|
||||
"""
|
||||
SELECT id, org_id, name, target_type, webhook_url, enabled, created_at
|
||||
FROM notification_targets
|
||||
WHERE id = $1
|
||||
""",
|
||||
target_id,
|
||||
)
|
||||
return dict(row) if row else None
|
||||
|
||||
async def get_targets_by_org(
|
||||
self,
|
||||
org_id: UUID,
|
||||
enabled_only: bool = False,
|
||||
) -> list[dict]:
|
||||
"""Get all notification targets for an organization."""
|
||||
query = """
|
||||
SELECT id, org_id, name, target_type, webhook_url, enabled, created_at
|
||||
FROM notification_targets
|
||||
WHERE org_id = $1
|
||||
"""
|
||||
if enabled_only:
|
||||
query += " AND enabled = true"
|
||||
query += " ORDER BY name"
|
||||
|
||||
rows = await self.conn.fetch(query, org_id)
|
||||
return [dict(row) for row in rows]
|
||||
|
||||
async def update_target(
|
||||
self,
|
||||
target_id: UUID,
|
||||
name: str | None = None,
|
||||
webhook_url: str | None = None,
|
||||
enabled: bool | None = None,
|
||||
) -> dict | None:
|
||||
"""Update a notification target."""
|
||||
updates = []
|
||||
params = [target_id]
|
||||
param_idx = 2
|
||||
|
||||
if name is not None:
|
||||
updates.append(f"name = ${param_idx}")
|
||||
params.append(name)
|
||||
param_idx += 1
|
||||
|
||||
if webhook_url is not None:
|
||||
updates.append(f"webhook_url = ${param_idx}")
|
||||
params.append(webhook_url)
|
||||
param_idx += 1
|
||||
|
||||
if enabled is not None:
|
||||
updates.append(f"enabled = ${param_idx}")
|
||||
params.append(enabled)
|
||||
param_idx += 1
|
||||
|
||||
if not updates:
|
||||
return await self.get_target_by_id(target_id)
|
||||
|
||||
query = f"""
|
||||
UPDATE notification_targets
|
||||
SET {", ".join(updates)}
|
||||
WHERE id = $1
|
||||
RETURNING id, org_id, name, target_type, webhook_url, enabled, created_at
|
||||
"""
|
||||
row = await self.conn.fetchrow(query, *params)
|
||||
return dict(row) if row else None
|
||||
|
||||
async def delete_target(self, target_id: UUID) -> bool:
|
||||
"""Delete a notification target. Returns True if deleted."""
|
||||
result = await self.conn.execute(
|
||||
"DELETE FROM notification_targets WHERE id = $1",
|
||||
target_id,
|
||||
)
|
||||
return result == "DELETE 1"
|
||||
|
||||
async def create_attempt(
|
||||
self,
|
||||
attempt_id: UUID,
|
||||
incident_id: UUID,
|
||||
target_id: UUID,
|
||||
) -> dict:
|
||||
"""Create a notification attempt (idempotent via unique constraint)."""
|
||||
row = await self.conn.fetchrow(
|
||||
"""
|
||||
INSERT INTO notification_attempts (id, incident_id, target_id, status)
|
||||
VALUES ($1, $2, $3, 'pending')
|
||||
ON CONFLICT (incident_id, target_id) DO UPDATE SET id = notification_attempts.id
|
||||
RETURNING id, incident_id, target_id, status, error, sent_at, created_at
|
||||
""",
|
||||
attempt_id,
|
||||
incident_id,
|
||||
target_id,
|
||||
)
|
||||
return dict(row)
|
||||
|
||||
async def get_attempt(self, incident_id: UUID, target_id: UUID) -> dict | None:
|
||||
"""Get notification attempt for incident and target."""
|
||||
row = await self.conn.fetchrow(
|
||||
"""
|
||||
SELECT id, incident_id, target_id, status, error, sent_at, created_at
|
||||
FROM notification_attempts
|
||||
WHERE incident_id = $1 AND target_id = $2
|
||||
""",
|
||||
incident_id,
|
||||
target_id,
|
||||
)
|
||||
return dict(row) if row else None
|
||||
|
||||
async def update_attempt_success(
|
||||
self,
|
||||
attempt_id: UUID,
|
||||
sent_at: datetime,
|
||||
) -> dict | None:
|
||||
"""Mark notification attempt as successful."""
|
||||
row = await self.conn.fetchrow(
|
||||
"""
|
||||
UPDATE notification_attempts
|
||||
SET status = 'sent', sent_at = $2, error = NULL
|
||||
WHERE id = $1
|
||||
RETURNING id, incident_id, target_id, status, error, sent_at, created_at
|
||||
""",
|
||||
attempt_id,
|
||||
sent_at,
|
||||
)
|
||||
return dict(row) if row else None
|
||||
|
||||
async def update_attempt_failure(
|
||||
self,
|
||||
attempt_id: UUID,
|
||||
error: str,
|
||||
) -> dict | None:
|
||||
"""Mark notification attempt as failed."""
|
||||
row = await self.conn.fetchrow(
|
||||
"""
|
||||
UPDATE notification_attempts
|
||||
SET status = 'failed', error = $2
|
||||
WHERE id = $1
|
||||
RETURNING id, incident_id, target_id, status, error, sent_at, created_at
|
||||
""",
|
||||
attempt_id,
|
||||
error,
|
||||
)
|
||||
return dict(row) if row else None
|
||||
|
||||
async def get_pending_attempts(self, incident_id: UUID) -> list[dict]:
|
||||
"""Get all pending notification attempts for an incident."""
|
||||
rows = await self.conn.fetch(
|
||||
"""
|
||||
SELECT na.id, na.incident_id, na.target_id, na.status, na.error,
|
||||
na.sent_at, na.created_at,
|
||||
nt.target_type, nt.webhook_url, nt.name as target_name
|
||||
FROM notification_attempts na
|
||||
JOIN notification_targets nt ON nt.id = na.target_id
|
||||
WHERE na.incident_id = $1 AND na.status = 'pending'
|
||||
""",
|
||||
incident_id,
|
||||
)
|
||||
return [dict(row) for row in rows]
|
||||
125
app/repositories/org.py
Normal file
125
app/repositories/org.py
Normal file
@@ -0,0 +1,125 @@
|
||||
"""Organization repository for database operations."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
import asyncpg
|
||||
|
||||
|
||||
class OrgRepository:
|
||||
"""Database operations for organizations."""
|
||||
|
||||
def __init__(self, conn: asyncpg.Connection) -> None:
|
||||
self.conn = conn
|
||||
|
||||
async def create(
|
||||
self,
|
||||
org_id: UUID,
|
||||
name: str,
|
||||
slug: str,
|
||||
) -> dict:
|
||||
"""Create a new organization."""
|
||||
row = await self.conn.fetchrow(
|
||||
"""
|
||||
INSERT INTO orgs (id, name, slug)
|
||||
VALUES ($1, $2, $3)
|
||||
RETURNING id, name, slug, created_at
|
||||
""",
|
||||
org_id,
|
||||
name,
|
||||
slug,
|
||||
)
|
||||
return dict(row)
|
||||
|
||||
async def get_by_id(self, org_id: UUID) -> dict | None:
|
||||
"""Get organization by ID."""
|
||||
row = await self.conn.fetchrow(
|
||||
"""
|
||||
SELECT id, name, slug, created_at
|
||||
FROM orgs
|
||||
WHERE id = $1
|
||||
""",
|
||||
org_id,
|
||||
)
|
||||
return dict(row) if row else None
|
||||
|
||||
async def get_by_slug(self, slug: str) -> dict | None:
|
||||
"""Get organization by slug."""
|
||||
row = await self.conn.fetchrow(
|
||||
"""
|
||||
SELECT id, name, slug, created_at
|
||||
FROM orgs
|
||||
WHERE slug = $1
|
||||
""",
|
||||
slug,
|
||||
)
|
||||
return dict(row) if row else None
|
||||
|
||||
async def add_member(
|
||||
self,
|
||||
member_id: UUID,
|
||||
user_id: UUID,
|
||||
org_id: UUID,
|
||||
role: str,
|
||||
) -> dict:
|
||||
"""Add a member to an organization."""
|
||||
row = await self.conn.fetchrow(
|
||||
"""
|
||||
INSERT INTO org_members (id, user_id, org_id, role)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
RETURNING id, user_id, org_id, role, created_at
|
||||
""",
|
||||
member_id,
|
||||
user_id,
|
||||
org_id,
|
||||
role,
|
||||
)
|
||||
return dict(row)
|
||||
|
||||
async def get_member(self, user_id: UUID, org_id: UUID) -> dict | None:
|
||||
"""Get membership for a user in an organization."""
|
||||
row = await self.conn.fetchrow(
|
||||
"""
|
||||
SELECT om.id, om.user_id, om.org_id, om.role, om.created_at
|
||||
FROM org_members om
|
||||
WHERE om.user_id = $1 AND om.org_id = $2
|
||||
""",
|
||||
user_id,
|
||||
org_id,
|
||||
)
|
||||
return dict(row) if row else None
|
||||
|
||||
async def get_members(self, org_id: UUID) -> list[dict]:
|
||||
"""Get all members of an organization."""
|
||||
rows = await self.conn.fetch(
|
||||
"""
|
||||
SELECT om.id, om.user_id, u.email, om.role, om.created_at
|
||||
FROM org_members om
|
||||
JOIN users u ON u.id = om.user_id
|
||||
WHERE om.org_id = $1
|
||||
ORDER BY om.created_at
|
||||
""",
|
||||
org_id,
|
||||
)
|
||||
return [dict(row) for row in rows]
|
||||
|
||||
async def get_user_orgs(self, user_id: UUID) -> list[dict]:
|
||||
"""Get all organizations a user belongs to."""
|
||||
rows = await self.conn.fetch(
|
||||
"""
|
||||
SELECT o.id, o.name, o.slug, o.created_at, om.role
|
||||
FROM orgs o
|
||||
JOIN org_members om ON om.org_id = o.id
|
||||
WHERE om.user_id = $1
|
||||
ORDER BY o.created_at
|
||||
""",
|
||||
user_id,
|
||||
)
|
||||
return [dict(row) for row in rows]
|
||||
|
||||
async def slug_exists(self, slug: str) -> bool:
|
||||
"""Check if organization slug exists."""
|
||||
result = await self.conn.fetchval(
|
||||
"SELECT EXISTS(SELECT 1 FROM orgs WHERE slug = $1)",
|
||||
slug,
|
||||
)
|
||||
return result
|
||||
396
app/repositories/refresh_token.py
Normal file
396
app/repositories/refresh_token.py
Normal file
@@ -0,0 +1,396 @@
|
||||
"""Refresh token repository for database operations.
|
||||
|
||||
Security considerations implemented:
|
||||
- Atomic rotation using SELECT FOR UPDATE to prevent race conditions
|
||||
- Token chain tracking via rotated_to for reuse/theft detection
|
||||
- Defense-in-depth validation with user_id and active_org_id checks
|
||||
- Uses RETURNING for robust row counting instead of string parsing
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
import asyncpg
|
||||
|
||||
|
||||
class RefreshTokenRepository:
|
||||
"""Database operations for refresh tokens."""
|
||||
|
||||
def __init__(self, conn: asyncpg.Connection) -> None:
|
||||
self.conn = conn
|
||||
|
||||
async def create(
|
||||
self,
|
||||
token_id: UUID,
|
||||
user_id: UUID,
|
||||
token_hash: str,
|
||||
active_org_id: UUID,
|
||||
expires_at: datetime,
|
||||
) -> dict:
|
||||
"""Create a new refresh token."""
|
||||
row = await self.conn.fetchrow(
|
||||
"""
|
||||
INSERT INTO refresh_tokens (id, user_id, token_hash, active_org_id, expires_at)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
RETURNING id, user_id, token_hash, active_org_id, expires_at,
|
||||
revoked_at, rotated_to, created_at
|
||||
""",
|
||||
token_id,
|
||||
user_id,
|
||||
token_hash,
|
||||
active_org_id,
|
||||
expires_at,
|
||||
)
|
||||
return dict(row)
|
||||
|
||||
async def get_by_hash(self, token_hash: str) -> dict | None:
|
||||
"""Get refresh token by hash (includes revoked/expired for auditing)."""
|
||||
row = await self.conn.fetchrow(
|
||||
"""
|
||||
SELECT id, user_id, token_hash, active_org_id, expires_at,
|
||||
revoked_at, rotated_to, created_at
|
||||
FROM refresh_tokens
|
||||
WHERE token_hash = $1
|
||||
""",
|
||||
token_hash,
|
||||
)
|
||||
return dict(row) if row else None
|
||||
|
||||
async def get_valid_by_hash(
|
||||
self,
|
||||
token_hash: str,
|
||||
user_id: UUID | None = None,
|
||||
active_org_id: UUID | None = None,
|
||||
) -> dict | None:
|
||||
"""Get refresh token by hash, only if valid.
|
||||
|
||||
Validates:
|
||||
- Token exists and matches hash
|
||||
- Token is not revoked
|
||||
- Token is not expired
|
||||
- Token has not been rotated (rotated_to is NULL)
|
||||
- Optionally: user_id matches (defense-in-depth)
|
||||
- Optionally: active_org_id matches (defense-in-depth)
|
||||
|
||||
Args:
|
||||
token_hash: The hashed token value
|
||||
user_id: If provided, token must belong to this user
|
||||
active_org_id: If provided, token must be bound to this org
|
||||
|
||||
Returns:
|
||||
Token dict if valid, None otherwise
|
||||
"""
|
||||
query = """
|
||||
SELECT id, user_id, token_hash, active_org_id, expires_at,
|
||||
revoked_at, rotated_to, created_at
|
||||
FROM refresh_tokens
|
||||
WHERE token_hash = $1
|
||||
AND revoked_at IS NULL
|
||||
AND rotated_to IS NULL
|
||||
AND expires_at > clock_timestamp()
|
||||
"""
|
||||
params: list = [token_hash]
|
||||
param_idx = 2
|
||||
|
||||
if user_id is not None:
|
||||
query += f" AND user_id = ${param_idx}"
|
||||
params.append(user_id)
|
||||
param_idx += 1
|
||||
|
||||
if active_org_id is not None:
|
||||
query += f" AND active_org_id = ${param_idx}"
|
||||
params.append(active_org_id)
|
||||
|
||||
row = await self.conn.fetchrow(query, *params)
|
||||
return dict(row) if row else None
|
||||
|
||||
async def get_valid_for_rotation(
|
||||
self,
|
||||
token_hash: str,
|
||||
user_id: UUID | None = None,
|
||||
) -> dict | None:
|
||||
"""Get and lock a valid token for rotation using SELECT FOR UPDATE.
|
||||
|
||||
This acquires a row-level lock to prevent concurrent rotation attempts.
|
||||
Must be called within a transaction.
|
||||
|
||||
Args:
|
||||
token_hash: The hashed token value
|
||||
user_id: If provided, token must belong to this user
|
||||
|
||||
Returns:
|
||||
Token dict if valid and locked, None otherwise
|
||||
"""
|
||||
query = """
|
||||
SELECT id, user_id, token_hash, active_org_id, expires_at,
|
||||
revoked_at, rotated_to, created_at
|
||||
FROM refresh_tokens
|
||||
WHERE token_hash = $1
|
||||
AND revoked_at IS NULL
|
||||
AND rotated_to IS NULL
|
||||
AND expires_at > clock_timestamp()
|
||||
"""
|
||||
params: list = [token_hash]
|
||||
|
||||
if user_id is not None:
|
||||
query += " AND user_id = $2"
|
||||
params.append(user_id)
|
||||
|
||||
query += " FOR UPDATE"
|
||||
|
||||
row = await self.conn.fetchrow(query, *params)
|
||||
return dict(row) if row else None
|
||||
|
||||
async def check_token_reuse(self, token_hash: str) -> dict | None:
|
||||
"""Check if a token has already been rotated (potential theft).
|
||||
|
||||
If a token is presented that has rotated_to set, it means:
|
||||
1. The token was legitimately rotated earlier
|
||||
2. Someone is now trying to use the old token
|
||||
3. This indicates the token may have been stolen
|
||||
|
||||
Returns:
|
||||
Token dict if this is a reused/stolen token, None if not found or not rotated
|
||||
"""
|
||||
row = await self.conn.fetchrow(
|
||||
"""
|
||||
SELECT id, user_id, token_hash, active_org_id, expires_at,
|
||||
revoked_at, rotated_to, created_at
|
||||
FROM refresh_tokens
|
||||
WHERE token_hash = $1 AND rotated_to IS NOT NULL
|
||||
""",
|
||||
token_hash,
|
||||
)
|
||||
return dict(row) if row else None
|
||||
|
||||
async def revoke_token_chain(self, token_id: UUID) -> int:
|
||||
"""Revoke a token and all tokens in its chain (for breach response).
|
||||
|
||||
When token reuse is detected, this revokes:
|
||||
1. The original stolen token
|
||||
2. Any token it was rotated to (and their rotations, recursively)
|
||||
|
||||
Args:
|
||||
token_id: The ID of the compromised token
|
||||
|
||||
Returns:
|
||||
Count of tokens revoked
|
||||
"""
|
||||
# Use recursive CTE to find all tokens in the chain
|
||||
rows = await self.conn.fetch(
|
||||
"""
|
||||
WITH RECURSIVE token_chain AS (
|
||||
-- Start with the given token
|
||||
SELECT id, rotated_to
|
||||
FROM refresh_tokens
|
||||
WHERE id = $1
|
||||
|
||||
UNION ALL
|
||||
|
||||
-- Follow the chain via rotated_to
|
||||
SELECT rt.id, rt.rotated_to
|
||||
FROM refresh_tokens rt
|
||||
INNER JOIN token_chain tc ON rt.id = tc.rotated_to
|
||||
)
|
||||
UPDATE refresh_tokens
|
||||
SET revoked_at = clock_timestamp()
|
||||
WHERE id IN (SELECT id FROM token_chain)
|
||||
AND revoked_at IS NULL
|
||||
RETURNING id
|
||||
""",
|
||||
token_id,
|
||||
)
|
||||
return len(rows)
|
||||
|
||||
async def rotate(
|
||||
self,
|
||||
old_token_hash: str,
|
||||
new_token_id: UUID,
|
||||
new_token_hash: str,
|
||||
new_expires_at: datetime,
|
||||
new_active_org_id: UUID | None = None,
|
||||
expected_user_id: UUID | None = None,
|
||||
) -> dict | None:
|
||||
"""Atomically rotate a refresh token.
|
||||
|
||||
This method:
|
||||
1. Validates the old token (not expired, not revoked, not already rotated)
|
||||
2. Locks the row to prevent concurrent rotation
|
||||
3. Marks old token as rotated (sets rotated_to)
|
||||
4. Creates new token with updated org if specified
|
||||
5. All in a single atomic operation
|
||||
|
||||
Args:
|
||||
old_token_hash: Hash of the token being rotated
|
||||
new_token_id: UUID for the new token
|
||||
new_token_hash: Hash for the new token
|
||||
new_expires_at: Expiry time for the new token
|
||||
new_active_org_id: New org ID (for org-switch), or None to keep current
|
||||
expected_user_id: If provided, validates token belongs to this user
|
||||
|
||||
Returns:
|
||||
New token dict if rotation succeeded, None if old token invalid/expired
|
||||
"""
|
||||
# First, get and lock the old token
|
||||
old_token = await self.get_valid_for_rotation(old_token_hash, expected_user_id)
|
||||
if old_token is None:
|
||||
return None
|
||||
|
||||
# Determine the org for the new token
|
||||
active_org_id = new_active_org_id or old_token["active_org_id"]
|
||||
user_id = old_token["user_id"]
|
||||
|
||||
# Create the new token
|
||||
new_token = await self.conn.fetchrow(
|
||||
"""
|
||||
INSERT INTO refresh_tokens (id, user_id, token_hash, active_org_id, expires_at)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
RETURNING id, user_id, token_hash, active_org_id, expires_at,
|
||||
revoked_at, rotated_to, created_at
|
||||
""",
|
||||
new_token_id,
|
||||
user_id,
|
||||
new_token_hash,
|
||||
active_org_id,
|
||||
new_expires_at,
|
||||
)
|
||||
|
||||
# Mark the old token as rotated (not revoked - for reuse detection)
|
||||
await self.conn.execute(
|
||||
"""
|
||||
UPDATE refresh_tokens
|
||||
SET rotated_to = $2
|
||||
WHERE id = $1
|
||||
""",
|
||||
old_token["id"],
|
||||
new_token_id,
|
||||
)
|
||||
|
||||
return dict(new_token)
|
||||
|
||||
async def revoke(self, token_id: UUID) -> bool:
|
||||
"""Revoke a refresh token by ID.
|
||||
|
||||
Returns:
|
||||
True if token was revoked, False if not found or already revoked
|
||||
"""
|
||||
row = await self.conn.fetchrow(
|
||||
"""
|
||||
UPDATE refresh_tokens
|
||||
SET revoked_at = clock_timestamp()
|
||||
WHERE id = $1 AND revoked_at IS NULL
|
||||
RETURNING id
|
||||
""",
|
||||
token_id,
|
||||
)
|
||||
return row is not None
|
||||
|
||||
async def revoke_by_hash(self, token_hash: str) -> bool:
|
||||
"""Revoke a refresh token by hash.
|
||||
|
||||
Returns:
|
||||
True if token was revoked, False if not found or already revoked
|
||||
"""
|
||||
row = await self.conn.fetchrow(
|
||||
"""
|
||||
UPDATE refresh_tokens
|
||||
SET revoked_at = clock_timestamp()
|
||||
WHERE token_hash = $1 AND revoked_at IS NULL
|
||||
RETURNING id
|
||||
""",
|
||||
token_hash,
|
||||
)
|
||||
return row is not None
|
||||
|
||||
async def revoke_all_for_user(self, user_id: UUID) -> int:
|
||||
"""Revoke all active refresh tokens for a user.
|
||||
|
||||
Use this for:
|
||||
- User-initiated logout from all devices
|
||||
- Password change
|
||||
- Account compromise response
|
||||
|
||||
Returns:
|
||||
Count of tokens revoked
|
||||
"""
|
||||
rows = await self.conn.fetch(
|
||||
"""
|
||||
UPDATE refresh_tokens
|
||||
SET revoked_at = clock_timestamp()
|
||||
WHERE user_id = $1 AND revoked_at IS NULL
|
||||
RETURNING id
|
||||
""",
|
||||
user_id,
|
||||
)
|
||||
return len(rows)
|
||||
|
||||
async def revoke_all_for_user_except(self, user_id: UUID, keep_token_id: UUID) -> int:
|
||||
"""Revoke all tokens for a user except one (logout other sessions).
|
||||
|
||||
Args:
|
||||
user_id: The user whose tokens to revoke
|
||||
keep_token_id: The token ID to keep active (current session)
|
||||
|
||||
Returns:
|
||||
Count of tokens revoked
|
||||
"""
|
||||
rows = await self.conn.fetch(
|
||||
"""
|
||||
UPDATE refresh_tokens
|
||||
SET revoked_at = clock_timestamp()
|
||||
WHERE user_id = $1 AND revoked_at IS NULL AND id != $2
|
||||
RETURNING id
|
||||
""",
|
||||
user_id,
|
||||
keep_token_id,
|
||||
)
|
||||
return len(rows)
|
||||
|
||||
async def get_active_tokens_for_user(self, user_id: UUID) -> list[dict]:
|
||||
"""Get all active (non-revoked, non-expired, non-rotated) tokens for a user.
|
||||
|
||||
Useful for:
|
||||
- Showing active sessions
|
||||
- Auditing
|
||||
|
||||
Returns:
|
||||
List of active token records
|
||||
"""
|
||||
rows = await self.conn.fetch(
|
||||
"""
|
||||
SELECT id, user_id, token_hash, active_org_id, expires_at,
|
||||
revoked_at, rotated_to, created_at
|
||||
FROM refresh_tokens
|
||||
WHERE user_id = $1
|
||||
AND revoked_at IS NULL
|
||||
AND rotated_to IS NULL
|
||||
AND expires_at > clock_timestamp()
|
||||
ORDER BY created_at DESC
|
||||
""",
|
||||
user_id,
|
||||
)
|
||||
return [dict(row) for row in rows]
|
||||
|
||||
async def cleanup_expired(self, older_than_days: int = 30) -> int:
|
||||
"""Delete expired tokens older than specified days.
|
||||
|
||||
Note: This performs a hard delete. For audit trails, I think we should:
|
||||
- Archiving to a separate table first
|
||||
- Using partitioning with retention policies
|
||||
- Only calling this for tokens well past their expiry
|
||||
|
||||
Args:
|
||||
older_than_days: Only delete tokens expired more than this many days ago
|
||||
|
||||
Returns:
|
||||
Count of tokens deleted
|
||||
"""
|
||||
rows = await self.conn.fetch(
|
||||
"""
|
||||
DELETE FROM refresh_tokens
|
||||
WHERE expires_at < clock_timestamp() - interval '1 day' * $1
|
||||
RETURNING id
|
||||
""",
|
||||
older_than_days,
|
||||
)
|
||||
return len(rows)
|
||||
80
app/repositories/service.py
Normal file
80
app/repositories/service.py
Normal file
@@ -0,0 +1,80 @@
|
||||
"""Service repository for database operations."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
import asyncpg
|
||||
|
||||
|
||||
class ServiceRepository:
|
||||
"""Database operations for services."""
|
||||
|
||||
def __init__(self, conn: asyncpg.Connection) -> None:
|
||||
self.conn = conn
|
||||
|
||||
async def create(
|
||||
self,
|
||||
service_id: UUID,
|
||||
org_id: UUID,
|
||||
name: str,
|
||||
slug: str,
|
||||
) -> dict:
|
||||
"""Create a new service."""
|
||||
row = await self.conn.fetchrow(
|
||||
"""
|
||||
INSERT INTO services (id, org_id, name, slug)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
RETURNING id, org_id, name, slug, created_at
|
||||
""",
|
||||
service_id,
|
||||
org_id,
|
||||
name,
|
||||
slug,
|
||||
)
|
||||
return dict(row)
|
||||
|
||||
async def get_by_id(self, service_id: UUID) -> dict | None:
|
||||
"""Get service by ID."""
|
||||
row = await self.conn.fetchrow(
|
||||
"""
|
||||
SELECT id, org_id, name, slug, created_at
|
||||
FROM services
|
||||
WHERE id = $1
|
||||
""",
|
||||
service_id,
|
||||
)
|
||||
return dict(row) if row else None
|
||||
|
||||
async def get_by_org(self, org_id: UUID) -> list[dict]:
|
||||
"""Get all services for an organization."""
|
||||
rows = await self.conn.fetch(
|
||||
"""
|
||||
SELECT id, org_id, name, slug, created_at
|
||||
FROM services
|
||||
WHERE org_id = $1
|
||||
ORDER BY name
|
||||
""",
|
||||
org_id,
|
||||
)
|
||||
return [dict(row) for row in rows]
|
||||
|
||||
async def get_by_slug(self, org_id: UUID, slug: str) -> dict | None:
|
||||
"""Get service by org and slug."""
|
||||
row = await self.conn.fetchrow(
|
||||
"""
|
||||
SELECT id, org_id, name, slug, created_at
|
||||
FROM services
|
||||
WHERE org_id = $1 AND slug = $2
|
||||
""",
|
||||
org_id,
|
||||
slug,
|
||||
)
|
||||
return dict(row) if row else None
|
||||
|
||||
async def slug_exists(self, org_id: UUID, slug: str) -> bool:
|
||||
"""Check if service slug exists in organization."""
|
||||
result = await self.conn.fetchval(
|
||||
"SELECT EXISTS(SELECT 1 FROM services WHERE org_id = $1 AND slug = $2)",
|
||||
org_id,
|
||||
slug,
|
||||
)
|
||||
return result
|
||||
63
app/repositories/user.py
Normal file
63
app/repositories/user.py
Normal file
@@ -0,0 +1,63 @@
|
||||
"""User repository for database operations."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
import asyncpg
|
||||
|
||||
|
||||
class UserRepository:
|
||||
"""Database operations for users."""
|
||||
|
||||
def __init__(self, conn: asyncpg.Connection) -> None:
|
||||
self.conn = conn
|
||||
|
||||
async def create(
|
||||
self,
|
||||
user_id: UUID,
|
||||
email: str,
|
||||
password_hash: str,
|
||||
) -> dict:
|
||||
"""Create a new user."""
|
||||
row = await self.conn.fetchrow(
|
||||
"""
|
||||
INSERT INTO users (id, email, password_hash)
|
||||
VALUES ($1, $2, $3)
|
||||
RETURNING id, email, created_at
|
||||
""",
|
||||
user_id,
|
||||
email,
|
||||
password_hash,
|
||||
)
|
||||
return dict(row)
|
||||
|
||||
async def get_by_id(self, user_id: UUID) -> dict | None:
|
||||
"""Get user by ID."""
|
||||
row = await self.conn.fetchrow(
|
||||
"""
|
||||
SELECT id, email, password_hash, created_at
|
||||
FROM users
|
||||
WHERE id = $1
|
||||
""",
|
||||
user_id,
|
||||
)
|
||||
return dict(row) if row else None
|
||||
|
||||
async def get_by_email(self, email: str) -> dict | None:
|
||||
"""Get user by email."""
|
||||
row = await self.conn.fetchrow(
|
||||
"""
|
||||
SELECT id, email, password_hash, created_at
|
||||
FROM users
|
||||
WHERE email = $1
|
||||
""",
|
||||
email,
|
||||
)
|
||||
return dict(row) if row else None
|
||||
|
||||
async def exists_by_email(self, email: str) -> bool:
|
||||
"""Check if user exists by email."""
|
||||
result = await self.conn.fetchval(
|
||||
"SELECT EXISTS(SELECT 1 FROM users WHERE email = $1)",
|
||||
email,
|
||||
)
|
||||
return result
|
||||
Reference in New Issue
Block a user