397 lines
13 KiB
Python
397 lines
13 KiB
Python
"""Refresh token repository for database operations.

Security considerations implemented:

- Atomic rotation using SELECT FOR UPDATE to prevent race conditions
- Token chain tracking via rotated_to for reuse/theft detection
- Defense-in-depth validation with user_id and active_org_id checks
- Uses RETURNING for robust row counting instead of string parsing
"""
|
|
|
|
from datetime import datetime
|
|
from uuid import UUID
|
|
|
|
import asyncpg
|
|
|
|
|
|
class RefreshTokenRepository:
    """Database operations for refresh tokens.

    All queries run on the single asyncpg connection supplied to the
    constructor. Token rows are returned as plain ``dict`` objects with the
    keys ``id``, ``user_id``, ``token_hash``, ``active_org_id``,
    ``expires_at``, ``revoked_at``, ``rotated_to`` and ``created_at``.
    """

    def __init__(self, conn: asyncpg.Connection) -> None:
        self.conn = conn

    @staticmethod
    def _build_valid_token_query(
        token_hash: str,
        user_id: UUID | None,
        active_org_id: UUID | None,
    ) -> tuple[str, list]:
        """Build the SELECT that finds a still-usable token by hash.

        Shared by ``get_valid_by_hash`` and ``get_valid_for_rotation`` so the
        definition of a "valid" token cannot drift between the two.

        Returns:
            ``(query, params)`` ready to pass to ``fetchrow``.
        """
        query = """
            SELECT id, user_id, token_hash, active_org_id, expires_at,
                   revoked_at, rotated_to, created_at
            FROM refresh_tokens
            WHERE token_hash = $1
              AND revoked_at IS NULL
              AND rotated_to IS NULL
              AND expires_at > clock_timestamp()
        """
        params: list = [token_hash]

        # Placeholder numbers are derived from the params list itself; only
        # constant SQL fragments are ever concatenated into the query text.
        if user_id is not None:
            params.append(user_id)
            query += f" AND user_id = ${len(params)}"

        if active_org_id is not None:
            params.append(active_org_id)
            query += f" AND active_org_id = ${len(params)}"

        return query, params

    async def create(
        self,
        token_id: UUID,
        user_id: UUID,
        token_hash: str,
        active_org_id: UUID,
        expires_at: datetime,
    ) -> dict:
        """Insert a new refresh token.

        Args:
            token_id: Primary key for the new row.
            user_id: Owner of the token.
            token_hash: Hashed token value (the raw token is never stored).
            active_org_id: Organization the token is bound to.
            expires_at: Expiry timestamp.

        Returns:
            The inserted row as returned by ``RETURNING``.
        """
        row = await self.conn.fetchrow(
            """
            INSERT INTO refresh_tokens (id, user_id, token_hash, active_org_id, expires_at)
            VALUES ($1, $2, $3, $4, $5)
            RETURNING id, user_id, token_hash, active_org_id, expires_at,
                      revoked_at, rotated_to, created_at
            """,
            token_id,
            user_id,
            token_hash,
            active_org_id,
            expires_at,
        )
        return dict(row)

    async def get_by_hash(self, token_hash: str) -> dict | None:
        """Get a refresh token by hash, including revoked/expired ones (for auditing).

        Returns:
            Token dict, or None if no row has this hash.
        """
        row = await self.conn.fetchrow(
            """
            SELECT id, user_id, token_hash, active_org_id, expires_at,
                   revoked_at, rotated_to, created_at
            FROM refresh_tokens
            WHERE token_hash = $1
            """,
            token_hash,
        )
        return dict(row) if row else None

    async def get_valid_by_hash(
        self,
        token_hash: str,
        user_id: UUID | None = None,
        active_org_id: UUID | None = None,
    ) -> dict | None:
        """Get a refresh token by hash, only if it is currently valid.

        Validates:
            - Token exists and matches hash
            - Token is not revoked
            - Token is not expired
            - Token has not been rotated (rotated_to is NULL)
            - Optionally: user_id matches (defense-in-depth)
            - Optionally: active_org_id matches (defense-in-depth)

        Args:
            token_hash: The hashed token value.
            user_id: If provided, the token must belong to this user.
            active_org_id: If provided, the token must be bound to this org.

        Returns:
            Token dict if valid, None otherwise.
        """
        query, params = self._build_valid_token_query(token_hash, user_id, active_org_id)
        row = await self.conn.fetchrow(query, *params)
        return dict(row) if row else None

    async def get_valid_for_rotation(
        self,
        token_hash: str,
        user_id: UUID | None = None,
        active_org_id: UUID | None = None,
    ) -> dict | None:
        """Get and lock a valid token for rotation using SELECT ... FOR UPDATE.

        The row lock only lasts until the end of the enclosing transaction,
        so this must be called inside one; outside a transaction the lock is
        released as soon as the statement finishes. ``rotate`` takes care of
        this itself.

        Args:
            token_hash: The hashed token value.
            user_id: If provided, the token must belong to this user.
            active_org_id: If provided, the token must be bound to this org.

        Returns:
            Token dict if valid (and now locked), None otherwise.
        """
        query, params = self._build_valid_token_query(token_hash, user_id, active_org_id)
        row = await self.conn.fetchrow(query + " FOR UPDATE", *params)
        return dict(row) if row else None

    async def check_token_reuse(self, token_hash: str) -> dict | None:
        """Check whether a presented token has already been rotated (potential theft).

        If a token is presented that has rotated_to set, it means:
            1. The token was legitimately rotated earlier.
            2. Someone is now trying to use the old token.
            3. This indicates the token may have been stolen.

        Returns:
            Token dict if this is a reused token, None if not found or not rotated.
        """
        row = await self.conn.fetchrow(
            """
            SELECT id, user_id, token_hash, active_org_id, expires_at,
                   revoked_at, rotated_to, created_at
            FROM refresh_tokens
            WHERE token_hash = $1 AND rotated_to IS NOT NULL
            """,
            token_hash,
        )
        return dict(row) if row else None

    async def revoke_token_chain(self, token_id: UUID) -> int:
        """Revoke a token and every token it was rotated into (breach response).

        When token reuse is detected, this revokes:
            1. The original compromised token.
            2. Any token it was rotated to, and their rotations, recursively.

        Args:
            token_id: The ID of the compromised token.

        Returns:
            Count of tokens newly revoked (already-revoked rows are not counted).
        """
        # UNION (rather than UNION ALL) discards rows already produced by the
        # recursion, so a malformed rotated_to cycle terminates instead of
        # recursing forever.
        rows = await self.conn.fetch(
            """
            WITH RECURSIVE token_chain AS (
                -- Start with the given token
                SELECT id, rotated_to
                FROM refresh_tokens
                WHERE id = $1

                UNION

                -- Follow the chain via rotated_to
                SELECT rt.id, rt.rotated_to
                FROM refresh_tokens rt
                INNER JOIN token_chain tc ON rt.id = tc.rotated_to
            )
            UPDATE refresh_tokens
            SET revoked_at = clock_timestamp()
            WHERE id IN (SELECT id FROM token_chain)
              AND revoked_at IS NULL
            RETURNING id
            """,
            token_id,
        )
        return len(rows)

    async def rotate(
        self,
        old_token_hash: str,
        new_token_id: UUID,
        new_token_hash: str,
        new_expires_at: datetime,
        new_active_org_id: UUID | None = None,
        expected_user_id: UUID | None = None,
    ) -> dict | None:
        """Atomically rotate a refresh token.

        This method:
            1. Validates the old token (not expired, not revoked, not already rotated).
            2. Locks the old row to prevent concurrent rotation.
            3. Creates the new token, with the updated org if specified.
            4. Marks the old token as rotated (sets rotated_to, not revoked_at).

        All steps run in one transaction. If the caller is already inside a
        transaction, that transaction is used; otherwise one is opened here.
        Without it, the FOR UPDATE lock would be released immediately and two
        concurrent requests could both rotate the same token.

        Args:
            old_token_hash: Hash of the token being rotated.
            new_token_id: UUID for the new token.
            new_token_hash: Hash for the new token.
            new_expires_at: Expiry time for the new token.
            new_active_org_id: New org ID (for org-switch), or None to keep the current one.
            expected_user_id: If provided, the old token must belong to this user.

        Returns:
            New token dict if rotation succeeded, None if the old token is invalid.
        """
        args = (
            old_token_hash,
            new_token_id,
            new_token_hash,
            new_expires_at,
            new_active_org_id,
            expected_user_id,
        )
        if self.conn.is_in_transaction():
            # Reuse the caller's transaction. Opening conn.transaction() here
            # would fail if the caller started the transaction manually.
            return await self._rotate_locked(*args)

        async with self.conn.transaction():
            return await self._rotate_locked(*args)

    async def _rotate_locked(
        self,
        old_token_hash: str,
        new_token_id: UUID,
        new_token_hash: str,
        new_expires_at: datetime,
        new_active_org_id: UUID | None,
        expected_user_id: UUID | None,
    ) -> dict | None:
        """Perform the rotation steps; the caller must hold an open transaction."""
        old_token = await self.get_valid_for_rotation(old_token_hash, expected_user_id)
        if old_token is None:
            return None

        # Keep the current org unless an explicit org switch was requested.
        active_org_id = (
            new_active_org_id
            if new_active_org_id is not None
            else old_token["active_org_id"]
        )

        # Insert the new token before pointing the old one at it.
        new_token = await self.conn.fetchrow(
            """
            INSERT INTO refresh_tokens (id, user_id, token_hash, active_org_id, expires_at)
            VALUES ($1, $2, $3, $4, $5)
            RETURNING id, user_id, token_hash, active_org_id, expires_at,
                      revoked_at, rotated_to, created_at
            """,
            new_token_id,
            old_token["user_id"],
            new_token_hash,
            active_org_id,
            new_expires_at,
        )

        # Mark the old token as rotated rather than revoked, so a later
        # presentation of it is detectable as reuse (see check_token_reuse).
        await self.conn.execute(
            """
            UPDATE refresh_tokens
            SET rotated_to = $2
            WHERE id = $1
            """,
            old_token["id"],
            new_token_id,
        )

        return dict(new_token)

    async def revoke(self, token_id: UUID) -> bool:
        """Revoke a refresh token by ID.

        Returns:
            True if the token was revoked, False if not found or already revoked.
        """
        row = await self.conn.fetchrow(
            """
            UPDATE refresh_tokens
            SET revoked_at = clock_timestamp()
            WHERE id = $1 AND revoked_at IS NULL
            RETURNING id
            """,
            token_id,
        )
        return row is not None

    async def revoke_by_hash(self, token_hash: str) -> bool:
        """Revoke a refresh token by hash.

        Returns:
            True if the token was revoked, False if not found or already revoked.
        """
        row = await self.conn.fetchrow(
            """
            UPDATE refresh_tokens
            SET revoked_at = clock_timestamp()
            WHERE token_hash = $1 AND revoked_at IS NULL
            RETURNING id
            """,
            token_hash,
        )
        return row is not None

    async def revoke_all_for_user(self, user_id: UUID) -> int:
        """Revoke all not-yet-revoked refresh tokens for a user.

        Use this for:
            - User-initiated logout from all devices
            - Password change
            - Account compromise response

        Returns:
            Count of tokens revoked.
        """
        rows = await self.conn.fetch(
            """
            UPDATE refresh_tokens
            SET revoked_at = clock_timestamp()
            WHERE user_id = $1 AND revoked_at IS NULL
            RETURNING id
            """,
            user_id,
        )
        return len(rows)

    async def revoke_all_for_user_except(self, user_id: UUID, keep_token_id: UUID) -> int:
        """Revoke all tokens for a user except one (log out other sessions).

        Args:
            user_id: The user whose tokens to revoke.
            keep_token_id: The token ID to keep active (current session).

        Returns:
            Count of tokens revoked.
        """
        rows = await self.conn.fetch(
            """
            UPDATE refresh_tokens
            SET revoked_at = clock_timestamp()
            WHERE user_id = $1 AND revoked_at IS NULL AND id != $2
            RETURNING id
            """,
            user_id,
            keep_token_id,
        )
        return len(rows)

    async def get_active_tokens_for_user(self, user_id: UUID) -> list[dict]:
        """Get all active (non-revoked, non-expired, non-rotated) tokens for a user.

        Useful for:
            - Showing active sessions
            - Auditing

        Returns:
            Active token records, newest first (by created_at).
        """
        rows = await self.conn.fetch(
            """
            SELECT id, user_id, token_hash, active_org_id, expires_at,
                   revoked_at, rotated_to, created_at
            FROM refresh_tokens
            WHERE user_id = $1
              AND revoked_at IS NULL
              AND rotated_to IS NULL
              AND expires_at > clock_timestamp()
            ORDER BY created_at DESC
            """,
            user_id,
        )
        return [dict(row) for row in rows]

    async def cleanup_expired(self, older_than_days: int = 30) -> int:
        """Hard-delete tokens that expired more than ``older_than_days`` days ago.

        Note: this is a hard delete. To preserve audit trails, consider:
            - Archiving rows to a separate table first
            - Using partitioning with retention policies
            - Only calling this for tokens well past their expiry

        Args:
            older_than_days: Only delete tokens that expired more than this
                many days ago. Must be non-negative.

        Returns:
            Count of tokens deleted.

        Raises:
            ValueError: If ``older_than_days`` is negative.
        """
        if older_than_days < 0:
            # A negative window would move the cutoff into the future and
            # delete tokens that have not expired yet.
            raise ValueError("older_than_days must be non-negative")

        rows = await self.conn.fetch(
            """
            DELETE FROM refresh_tokens
            WHERE expires_at < clock_timestamp() - interval '1 day' * $1
            RETURNING id
            """,
            older_than_days,
        )
        return len(rows)