feat(api): Pydantic schemas + Data Repositories

This commit is contained in:
2025-12-07 03:58:02 -05:00
parent fbe9fbba6e
commit a8fbce09c4
23 changed files with 3549 additions and 3 deletions

View File

@@ -0,0 +1,17 @@
"""Repository layer for database operations."""
from app.repositories.incident import IncidentRepository
from app.repositories.notification import NotificationRepository
from app.repositories.org import OrgRepository
from app.repositories.refresh_token import RefreshTokenRepository
from app.repositories.service import ServiceRepository
from app.repositories.user import UserRepository
__all__ = [
"IncidentRepository",
"NotificationRepository",
"OrgRepository",
"RefreshTokenRepository",
"ServiceRepository",
"UserRepository",
]

View File

@@ -0,0 +1,161 @@
"""Incident repository for database operations."""
from datetime import datetime
from typing import Any
from uuid import UUID
import asyncpg
class IncidentRepository:
"""Database operations for incidents."""
def __init__(self, conn: asyncpg.Connection) -> None:
self.conn = conn
async def create(
self,
incident_id: UUID,
org_id: UUID,
service_id: UUID,
title: str,
description: str | None,
severity: str,
) -> dict:
"""Create a new incident."""
row = await self.conn.fetchrow(
"""
INSERT INTO incidents (id, org_id, service_id, title, description, status, severity)
VALUES ($1, $2, $3, $4, $5, 'triggered', $6)
RETURNING id, org_id, service_id, title, description, status, severity,
version, created_at, updated_at
""",
incident_id,
org_id,
service_id,
title,
description,
severity,
)
return dict(row)
async def get_by_id(self, incident_id: UUID) -> dict | None:
"""Get incident by ID."""
row = await self.conn.fetchrow(
"""
SELECT id, org_id, service_id, title, description, status, severity,
version, created_at, updated_at
FROM incidents
WHERE id = $1
""",
incident_id,
)
return dict(row) if row else None
async def get_by_org(
self,
org_id: UUID,
status: str | None = None,
cursor: datetime | None = None,
limit: int = 20,
) -> list[dict]:
"""Get incidents for an organization with optional filtering and pagination."""
query = """
SELECT id, org_id, service_id, title, description, status, severity,
version, created_at, updated_at
FROM incidents
WHERE org_id = $1
"""
params: list[Any] = [org_id]
param_idx = 2
if status:
query += f" AND status = ${param_idx}"
params.append(status)
param_idx += 1
if cursor:
query += f" AND created_at < ${param_idx}"
params.append(cursor)
param_idx += 1
query += f" ORDER BY created_at DESC LIMIT ${param_idx}"
params.append(limit + 1) # Fetch one extra to check if there are more
rows = await self.conn.fetch(query, *params)
return [dict(row) for row in rows]
async def update_status(
self,
incident_id: UUID,
new_status: str,
expected_version: int,
) -> dict | None:
"""Update incident status with optimistic locking.
Returns updated incident if successful, None if version mismatch.
"""
row = await self.conn.fetchrow(
"""
UPDATE incidents
SET status = $2, version = version + 1, updated_at = now()
WHERE id = $1 AND version = $3
RETURNING id, org_id, service_id, title, description, status, severity,
version, created_at, updated_at
""",
incident_id,
new_status,
expected_version,
)
return dict(row) if row else None
async def add_event(
self,
event_id: UUID,
incident_id: UUID,
event_type: str,
actor_user_id: UUID | None,
payload: dict[str, Any] | None,
) -> dict:
"""Add an event to the incident timeline."""
import json
row = await self.conn.fetchrow(
"""
INSERT INTO incident_events (id, incident_id, event_type, actor_user_id, payload)
VALUES ($1, $2, $3, $4, $5)
RETURNING id, incident_id, event_type, actor_user_id, payload, created_at
""",
event_id,
incident_id,
event_type,
actor_user_id,
json.dumps(payload) if payload else None,
)
result = dict(row)
# Parse JSON payload back to dict
if result["payload"]:
result["payload"] = json.loads(result["payload"])
return result
async def get_events(self, incident_id: UUID) -> list[dict]:
"""Get all events for an incident."""
import json
rows = await self.conn.fetch(
"""
SELECT id, incident_id, event_type, actor_user_id, payload, created_at
FROM incident_events
WHERE incident_id = $1
ORDER BY created_at
""",
incident_id,
)
results = []
for row in rows:
result = dict(row)
if result["payload"]:
result["payload"] = json.loads(result["payload"])
results.append(result)
return results

View File

@@ -0,0 +1,199 @@
"""Notification repository for database operations."""
from datetime import datetime
from uuid import UUID
import asyncpg
class NotificationRepository:
"""Database operations for notification targets and attempts."""
def __init__(self, conn: asyncpg.Connection) -> None:
self.conn = conn
async def create_target(
self,
target_id: UUID,
org_id: UUID,
name: str,
target_type: str,
webhook_url: str | None = None,
enabled: bool = True,
) -> dict:
"""Create a new notification target."""
row = await self.conn.fetchrow(
"""
INSERT INTO notification_targets (id, org_id, name, target_type, webhook_url, enabled)
VALUES ($1, $2, $3, $4, $5, $6)
RETURNING id, org_id, name, target_type, webhook_url, enabled, created_at
""",
target_id,
org_id,
name,
target_type,
webhook_url,
enabled,
)
return dict(row)
async def get_target_by_id(self, target_id: UUID) -> dict | None:
"""Get notification target by ID."""
row = await self.conn.fetchrow(
"""
SELECT id, org_id, name, target_type, webhook_url, enabled, created_at
FROM notification_targets
WHERE id = $1
""",
target_id,
)
return dict(row) if row else None
async def get_targets_by_org(
self,
org_id: UUID,
enabled_only: bool = False,
) -> list[dict]:
"""Get all notification targets for an organization."""
query = """
SELECT id, org_id, name, target_type, webhook_url, enabled, created_at
FROM notification_targets
WHERE org_id = $1
"""
if enabled_only:
query += " AND enabled = true"
query += " ORDER BY name"
rows = await self.conn.fetch(query, org_id)
return [dict(row) for row in rows]
async def update_target(
self,
target_id: UUID,
name: str | None = None,
webhook_url: str | None = None,
enabled: bool | None = None,
) -> dict | None:
"""Update a notification target."""
updates = []
params = [target_id]
param_idx = 2
if name is not None:
updates.append(f"name = ${param_idx}")
params.append(name)
param_idx += 1
if webhook_url is not None:
updates.append(f"webhook_url = ${param_idx}")
params.append(webhook_url)
param_idx += 1
if enabled is not None:
updates.append(f"enabled = ${param_idx}")
params.append(enabled)
param_idx += 1
if not updates:
return await self.get_target_by_id(target_id)
query = f"""
UPDATE notification_targets
SET {", ".join(updates)}
WHERE id = $1
RETURNING id, org_id, name, target_type, webhook_url, enabled, created_at
"""
row = await self.conn.fetchrow(query, *params)
return dict(row) if row else None
async def delete_target(self, target_id: UUID) -> bool:
"""Delete a notification target. Returns True if deleted."""
result = await self.conn.execute(
"DELETE FROM notification_targets WHERE id = $1",
target_id,
)
return result == "DELETE 1"
async def create_attempt(
self,
attempt_id: UUID,
incident_id: UUID,
target_id: UUID,
) -> dict:
"""Create a notification attempt (idempotent via unique constraint)."""
row = await self.conn.fetchrow(
"""
INSERT INTO notification_attempts (id, incident_id, target_id, status)
VALUES ($1, $2, $3, 'pending')
ON CONFLICT (incident_id, target_id) DO UPDATE SET id = notification_attempts.id
RETURNING id, incident_id, target_id, status, error, sent_at, created_at
""",
attempt_id,
incident_id,
target_id,
)
return dict(row)
async def get_attempt(self, incident_id: UUID, target_id: UUID) -> dict | None:
"""Get notification attempt for incident and target."""
row = await self.conn.fetchrow(
"""
SELECT id, incident_id, target_id, status, error, sent_at, created_at
FROM notification_attempts
WHERE incident_id = $1 AND target_id = $2
""",
incident_id,
target_id,
)
return dict(row) if row else None
async def update_attempt_success(
self,
attempt_id: UUID,
sent_at: datetime,
) -> dict | None:
"""Mark notification attempt as successful."""
row = await self.conn.fetchrow(
"""
UPDATE notification_attempts
SET status = 'sent', sent_at = $2, error = NULL
WHERE id = $1
RETURNING id, incident_id, target_id, status, error, sent_at, created_at
""",
attempt_id,
sent_at,
)
return dict(row) if row else None
async def update_attempt_failure(
self,
attempt_id: UUID,
error: str,
) -> dict | None:
"""Mark notification attempt as failed."""
row = await self.conn.fetchrow(
"""
UPDATE notification_attempts
SET status = 'failed', error = $2
WHERE id = $1
RETURNING id, incident_id, target_id, status, error, sent_at, created_at
""",
attempt_id,
error,
)
return dict(row) if row else None
async def get_pending_attempts(self, incident_id: UUID) -> list[dict]:
"""Get all pending notification attempts for an incident."""
rows = await self.conn.fetch(
"""
SELECT na.id, na.incident_id, na.target_id, na.status, na.error,
na.sent_at, na.created_at,
nt.target_type, nt.webhook_url, nt.name as target_name
FROM notification_attempts na
JOIN notification_targets nt ON nt.id = na.target_id
WHERE na.incident_id = $1 AND na.status = 'pending'
""",
incident_id,
)
return [dict(row) for row in rows]

125
app/repositories/org.py Normal file
View File

@@ -0,0 +1,125 @@
"""Organization repository for database operations."""
from uuid import UUID
import asyncpg
class OrgRepository:
"""Database operations for organizations."""
def __init__(self, conn: asyncpg.Connection) -> None:
self.conn = conn
async def create(
self,
org_id: UUID,
name: str,
slug: str,
) -> dict:
"""Create a new organization."""
row = await self.conn.fetchrow(
"""
INSERT INTO orgs (id, name, slug)
VALUES ($1, $2, $3)
RETURNING id, name, slug, created_at
""",
org_id,
name,
slug,
)
return dict(row)
async def get_by_id(self, org_id: UUID) -> dict | None:
"""Get organization by ID."""
row = await self.conn.fetchrow(
"""
SELECT id, name, slug, created_at
FROM orgs
WHERE id = $1
""",
org_id,
)
return dict(row) if row else None
async def get_by_slug(self, slug: str) -> dict | None:
"""Get organization by slug."""
row = await self.conn.fetchrow(
"""
SELECT id, name, slug, created_at
FROM orgs
WHERE slug = $1
""",
slug,
)
return dict(row) if row else None
async def add_member(
self,
member_id: UUID,
user_id: UUID,
org_id: UUID,
role: str,
) -> dict:
"""Add a member to an organization."""
row = await self.conn.fetchrow(
"""
INSERT INTO org_members (id, user_id, org_id, role)
VALUES ($1, $2, $3, $4)
RETURNING id, user_id, org_id, role, created_at
""",
member_id,
user_id,
org_id,
role,
)
return dict(row)
async def get_member(self, user_id: UUID, org_id: UUID) -> dict | None:
"""Get membership for a user in an organization."""
row = await self.conn.fetchrow(
"""
SELECT om.id, om.user_id, om.org_id, om.role, om.created_at
FROM org_members om
WHERE om.user_id = $1 AND om.org_id = $2
""",
user_id,
org_id,
)
return dict(row) if row else None
async def get_members(self, org_id: UUID) -> list[dict]:
"""Get all members of an organization."""
rows = await self.conn.fetch(
"""
SELECT om.id, om.user_id, u.email, om.role, om.created_at
FROM org_members om
JOIN users u ON u.id = om.user_id
WHERE om.org_id = $1
ORDER BY om.created_at
""",
org_id,
)
return [dict(row) for row in rows]
async def get_user_orgs(self, user_id: UUID) -> list[dict]:
"""Get all organizations a user belongs to."""
rows = await self.conn.fetch(
"""
SELECT o.id, o.name, o.slug, o.created_at, om.role
FROM orgs o
JOIN org_members om ON om.org_id = o.id
WHERE om.user_id = $1
ORDER BY o.created_at
""",
user_id,
)
return [dict(row) for row in rows]
async def slug_exists(self, slug: str) -> bool:
"""Check if organization slug exists."""
result = await self.conn.fetchval(
"SELECT EXISTS(SELECT 1 FROM orgs WHERE slug = $1)",
slug,
)
return result

View File

@@ -0,0 +1,396 @@
"""Refresh token repository for database operations.
Security considerations implemented:
- Atomic rotation using SELECT FOR UPDATE to prevent race conditions
- Token chain tracking via rotated_to for reuse/theft detection
- Defense-in-depth validation with user_id and active_org_id checks
- Uses RETURNING for robust row counting instead of string parsing
"""
from datetime import datetime
from uuid import UUID
import asyncpg
class RefreshTokenRepository:
"""Database operations for refresh tokens."""
def __init__(self, conn: asyncpg.Connection) -> None:
self.conn = conn
async def create(
self,
token_id: UUID,
user_id: UUID,
token_hash: str,
active_org_id: UUID,
expires_at: datetime,
) -> dict:
"""Create a new refresh token."""
row = await self.conn.fetchrow(
"""
INSERT INTO refresh_tokens (id, user_id, token_hash, active_org_id, expires_at)
VALUES ($1, $2, $3, $4, $5)
RETURNING id, user_id, token_hash, active_org_id, expires_at,
revoked_at, rotated_to, created_at
""",
token_id,
user_id,
token_hash,
active_org_id,
expires_at,
)
return dict(row)
async def get_by_hash(self, token_hash: str) -> dict | None:
"""Get refresh token by hash (includes revoked/expired for auditing)."""
row = await self.conn.fetchrow(
"""
SELECT id, user_id, token_hash, active_org_id, expires_at,
revoked_at, rotated_to, created_at
FROM refresh_tokens
WHERE token_hash = $1
""",
token_hash,
)
return dict(row) if row else None
async def get_valid_by_hash(
self,
token_hash: str,
user_id: UUID | None = None,
active_org_id: UUID | None = None,
) -> dict | None:
"""Get refresh token by hash, only if valid.
Validates:
- Token exists and matches hash
- Token is not revoked
- Token is not expired
- Token has not been rotated (rotated_to is NULL)
- Optionally: user_id matches (defense-in-depth)
- Optionally: active_org_id matches (defense-in-depth)
Args:
token_hash: The hashed token value
user_id: If provided, token must belong to this user
active_org_id: If provided, token must be bound to this org
Returns:
Token dict if valid, None otherwise
"""
query = """
SELECT id, user_id, token_hash, active_org_id, expires_at,
revoked_at, rotated_to, created_at
FROM refresh_tokens
WHERE token_hash = $1
AND revoked_at IS NULL
AND rotated_to IS NULL
AND expires_at > clock_timestamp()
"""
params: list = [token_hash]
param_idx = 2
if user_id is not None:
query += f" AND user_id = ${param_idx}"
params.append(user_id)
param_idx += 1
if active_org_id is not None:
query += f" AND active_org_id = ${param_idx}"
params.append(active_org_id)
row = await self.conn.fetchrow(query, *params)
return dict(row) if row else None
async def get_valid_for_rotation(
self,
token_hash: str,
user_id: UUID | None = None,
) -> dict | None:
"""Get and lock a valid token for rotation using SELECT FOR UPDATE.
This acquires a row-level lock to prevent concurrent rotation attempts.
Must be called within a transaction.
Args:
token_hash: The hashed token value
user_id: If provided, token must belong to this user
Returns:
Token dict if valid and locked, None otherwise
"""
query = """
SELECT id, user_id, token_hash, active_org_id, expires_at,
revoked_at, rotated_to, created_at
FROM refresh_tokens
WHERE token_hash = $1
AND revoked_at IS NULL
AND rotated_to IS NULL
AND expires_at > clock_timestamp()
"""
params: list = [token_hash]
if user_id is not None:
query += " AND user_id = $2"
params.append(user_id)
query += " FOR UPDATE"
row = await self.conn.fetchrow(query, *params)
return dict(row) if row else None
async def check_token_reuse(self, token_hash: str) -> dict | None:
"""Check if a token has already been rotated (potential theft).
If a token is presented that has rotated_to set, it means:
1. The token was legitimately rotated earlier
2. Someone is now trying to use the old token
3. This indicates the token may have been stolen
Returns:
Token dict if this is a reused/stolen token, None if not found or not rotated
"""
row = await self.conn.fetchrow(
"""
SELECT id, user_id, token_hash, active_org_id, expires_at,
revoked_at, rotated_to, created_at
FROM refresh_tokens
WHERE token_hash = $1 AND rotated_to IS NOT NULL
""",
token_hash,
)
return dict(row) if row else None
async def revoke_token_chain(self, token_id: UUID) -> int:
"""Revoke a token and all tokens in its chain (for breach response).
When token reuse is detected, this revokes:
1. The original stolen token
2. Any token it was rotated to (and their rotations, recursively)
Args:
token_id: The ID of the compromised token
Returns:
Count of tokens revoked
"""
# Use recursive CTE to find all tokens in the chain
rows = await self.conn.fetch(
"""
WITH RECURSIVE token_chain AS (
-- Start with the given token
SELECT id, rotated_to
FROM refresh_tokens
WHERE id = $1
UNION ALL
-- Follow the chain via rotated_to
SELECT rt.id, rt.rotated_to
FROM refresh_tokens rt
INNER JOIN token_chain tc ON rt.id = tc.rotated_to
)
UPDATE refresh_tokens
SET revoked_at = clock_timestamp()
WHERE id IN (SELECT id FROM token_chain)
AND revoked_at IS NULL
RETURNING id
""",
token_id,
)
return len(rows)
async def rotate(
self,
old_token_hash: str,
new_token_id: UUID,
new_token_hash: str,
new_expires_at: datetime,
new_active_org_id: UUID | None = None,
expected_user_id: UUID | None = None,
) -> dict | None:
"""Atomically rotate a refresh token.
This method:
1. Validates the old token (not expired, not revoked, not already rotated)
2. Locks the row to prevent concurrent rotation
3. Marks old token as rotated (sets rotated_to)
4. Creates new token with updated org if specified
5. All in a single atomic operation
Args:
old_token_hash: Hash of the token being rotated
new_token_id: UUID for the new token
new_token_hash: Hash for the new token
new_expires_at: Expiry time for the new token
new_active_org_id: New org ID (for org-switch), or None to keep current
expected_user_id: If provided, validates token belongs to this user
Returns:
New token dict if rotation succeeded, None if old token invalid/expired
"""
# First, get and lock the old token
old_token = await self.get_valid_for_rotation(old_token_hash, expected_user_id)
if old_token is None:
return None
# Determine the org for the new token
active_org_id = new_active_org_id or old_token["active_org_id"]
user_id = old_token["user_id"]
# Create the new token
new_token = await self.conn.fetchrow(
"""
INSERT INTO refresh_tokens (id, user_id, token_hash, active_org_id, expires_at)
VALUES ($1, $2, $3, $4, $5)
RETURNING id, user_id, token_hash, active_org_id, expires_at,
revoked_at, rotated_to, created_at
""",
new_token_id,
user_id,
new_token_hash,
active_org_id,
new_expires_at,
)
# Mark the old token as rotated (not revoked - for reuse detection)
await self.conn.execute(
"""
UPDATE refresh_tokens
SET rotated_to = $2
WHERE id = $1
""",
old_token["id"],
new_token_id,
)
return dict(new_token)
async def revoke(self, token_id: UUID) -> bool:
"""Revoke a refresh token by ID.
Returns:
True if token was revoked, False if not found or already revoked
"""
row = await self.conn.fetchrow(
"""
UPDATE refresh_tokens
SET revoked_at = clock_timestamp()
WHERE id = $1 AND revoked_at IS NULL
RETURNING id
""",
token_id,
)
return row is not None
async def revoke_by_hash(self, token_hash: str) -> bool:
"""Revoke a refresh token by hash.
Returns:
True if token was revoked, False if not found or already revoked
"""
row = await self.conn.fetchrow(
"""
UPDATE refresh_tokens
SET revoked_at = clock_timestamp()
WHERE token_hash = $1 AND revoked_at IS NULL
RETURNING id
""",
token_hash,
)
return row is not None
async def revoke_all_for_user(self, user_id: UUID) -> int:
"""Revoke all active refresh tokens for a user.
Use this for:
- User-initiated logout from all devices
- Password change
- Account compromise response
Returns:
Count of tokens revoked
"""
rows = await self.conn.fetch(
"""
UPDATE refresh_tokens
SET revoked_at = clock_timestamp()
WHERE user_id = $1 AND revoked_at IS NULL
RETURNING id
""",
user_id,
)
return len(rows)
async def revoke_all_for_user_except(self, user_id: UUID, keep_token_id: UUID) -> int:
"""Revoke all tokens for a user except one (logout other sessions).
Args:
user_id: The user whose tokens to revoke
keep_token_id: The token ID to keep active (current session)
Returns:
Count of tokens revoked
"""
rows = await self.conn.fetch(
"""
UPDATE refresh_tokens
SET revoked_at = clock_timestamp()
WHERE user_id = $1 AND revoked_at IS NULL AND id != $2
RETURNING id
""",
user_id,
keep_token_id,
)
return len(rows)
async def get_active_tokens_for_user(self, user_id: UUID) -> list[dict]:
"""Get all active (non-revoked, non-expired, non-rotated) tokens for a user.
Useful for:
- Showing active sessions
- Auditing
Returns:
List of active token records
"""
rows = await self.conn.fetch(
"""
SELECT id, user_id, token_hash, active_org_id, expires_at,
revoked_at, rotated_to, created_at
FROM refresh_tokens
WHERE user_id = $1
AND revoked_at IS NULL
AND rotated_to IS NULL
AND expires_at > clock_timestamp()
ORDER BY created_at DESC
""",
user_id,
)
return [dict(row) for row in rows]
async def cleanup_expired(self, older_than_days: int = 30) -> int:
"""Delete expired tokens older than specified days.
Note: This performs a hard delete. For audit trails, I think we should:
- Archiving to a separate table first
- Using partitioning with retention policies
- Only calling this for tokens well past their expiry
Args:
older_than_days: Only delete tokens expired more than this many days ago
Returns:
Count of tokens deleted
"""
rows = await self.conn.fetch(
"""
DELETE FROM refresh_tokens
WHERE expires_at < clock_timestamp() - interval '1 day' * $1
RETURNING id
""",
older_than_days,
)
return len(rows)

View File

@@ -0,0 +1,80 @@
"""Service repository for database operations."""
from uuid import UUID
import asyncpg
class ServiceRepository:
"""Database operations for services."""
def __init__(self, conn: asyncpg.Connection) -> None:
self.conn = conn
async def create(
self,
service_id: UUID,
org_id: UUID,
name: str,
slug: str,
) -> dict:
"""Create a new service."""
row = await self.conn.fetchrow(
"""
INSERT INTO services (id, org_id, name, slug)
VALUES ($1, $2, $3, $4)
RETURNING id, org_id, name, slug, created_at
""",
service_id,
org_id,
name,
slug,
)
return dict(row)
async def get_by_id(self, service_id: UUID) -> dict | None:
"""Get service by ID."""
row = await self.conn.fetchrow(
"""
SELECT id, org_id, name, slug, created_at
FROM services
WHERE id = $1
""",
service_id,
)
return dict(row) if row else None
async def get_by_org(self, org_id: UUID) -> list[dict]:
"""Get all services for an organization."""
rows = await self.conn.fetch(
"""
SELECT id, org_id, name, slug, created_at
FROM services
WHERE org_id = $1
ORDER BY name
""",
org_id,
)
return [dict(row) for row in rows]
async def get_by_slug(self, org_id: UUID, slug: str) -> dict | None:
"""Get service by org and slug."""
row = await self.conn.fetchrow(
"""
SELECT id, org_id, name, slug, created_at
FROM services
WHERE org_id = $1 AND slug = $2
""",
org_id,
slug,
)
return dict(row) if row else None
async def slug_exists(self, org_id: UUID, slug: str) -> bool:
"""Check if service slug exists in organization."""
result = await self.conn.fetchval(
"SELECT EXISTS(SELECT 1 FROM services WHERE org_id = $1 AND slug = $2)",
org_id,
slug,
)
return result

63
app/repositories/user.py Normal file
View File

@@ -0,0 +1,63 @@
"""User repository for database operations."""
from uuid import UUID
import asyncpg
class UserRepository:
"""Database operations for users."""
def __init__(self, conn: asyncpg.Connection) -> None:
self.conn = conn
async def create(
self,
user_id: UUID,
email: str,
password_hash: str,
) -> dict:
"""Create a new user."""
row = await self.conn.fetchrow(
"""
INSERT INTO users (id, email, password_hash)
VALUES ($1, $2, $3)
RETURNING id, email, created_at
""",
user_id,
email,
password_hash,
)
return dict(row)
async def get_by_id(self, user_id: UUID) -> dict | None:
"""Get user by ID."""
row = await self.conn.fetchrow(
"""
SELECT id, email, password_hash, created_at
FROM users
WHERE id = $1
""",
user_id,
)
return dict(row) if row else None
async def get_by_email(self, email: str) -> dict | None:
"""Get user by email."""
row = await self.conn.fetchrow(
"""
SELECT id, email, password_hash, created_at
FROM users
WHERE email = $1
""",
email,
)
return dict(row) if row else None
async def exists_by_email(self, email: str) -> bool:
"""Check if user exists by email."""
result = await self.conn.fetchval(
"SELECT EXISTS(SELECT 1 FROM users WHERE email = $1)",
email,
)
return result