feat(api): Pydantic schemas + Data Repositories

This commit is contained in:
2025-12-07 12:00:00 +00:00
parent 359291eec7
commit 3170f10e86
23 changed files with 3549 additions and 3 deletions

View File

@@ -0,0 +1,788 @@
"""Tests for RefreshTokenRepository with security features."""
from datetime import UTC, datetime, timedelta
from uuid import uuid4
import asyncpg
import pytest
from app.repositories.org import OrgRepository
from app.repositories.refresh_token import RefreshTokenRepository
from app.repositories.user import UserRepository
class TestRefreshTokenRepository:
"""Tests for basic RefreshTokenRepository operations."""
async def _create_user(self, conn: asyncpg.Connection, email: str) -> uuid4:
"""Helper to create a user."""
user_repo = UserRepository(conn)
user_id = uuid4()
await user_repo.create(user_id, email, "hash")
return user_id
async def _create_org(self, conn: asyncpg.Connection, slug: str) -> uuid4:
"""Helper to create an org."""
org_repo = OrgRepository(conn)
org_id = uuid4()
await org_repo.create(org_id, f"Org {slug}", slug)
return org_id
async def test_create_token_returns_token_data(self, db_conn: asyncpg.Connection) -> None:
"""Creating a refresh token returns the token data including rotated_to."""
user_id = await self._create_user(db_conn, "token_create@example.com")
org_id = await self._create_org(db_conn, "token-create-org")
repo = RefreshTokenRepository(db_conn)
token_id = uuid4()
token_hash = "sha256_hashed_token_value"
expires_at = datetime.now(UTC) + timedelta(days=30)
result = await repo.create(token_id, user_id, token_hash, org_id, expires_at)
assert result["id"] == token_id
assert result["user_id"] == user_id
assert result["token_hash"] == token_hash
assert result["active_org_id"] == org_id
assert result["expires_at"] is not None
assert result["revoked_at"] is None
assert result["rotated_to"] is None # New field
assert result["created_at"] is not None
async def test_token_hash_must_be_unique(self, db_conn: asyncpg.Connection) -> None:
"""Token hash uniqueness constraint per SPECS.md refresh_tokens table."""
user_id = await self._create_user(db_conn, "unique_hash@example.com")
org_id = await self._create_org(db_conn, "unique-hash-org")
repo = RefreshTokenRepository(db_conn)
token_hash = "duplicate_hash_value"
expires_at = datetime.now(UTC) + timedelta(days=30)
await repo.create(uuid4(), user_id, token_hash, org_id, expires_at)
with pytest.raises(asyncpg.UniqueViolationError):
await repo.create(uuid4(), user_id, token_hash, org_id, expires_at)
async def test_get_by_hash_returns_token(self, db_conn: asyncpg.Connection) -> None:
"""get_by_hash returns the correct token (even if revoked/expired)."""
user_id = await self._create_user(db_conn, "get_hash@example.com")
org_id = await self._create_org(db_conn, "get-hash-org")
repo = RefreshTokenRepository(db_conn)
token_id = uuid4()
token_hash = "lookup_by_hash_value"
expires_at = datetime.now(UTC) + timedelta(days=30)
await repo.create(token_id, user_id, token_hash, org_id, expires_at)
result = await repo.get_by_hash(token_hash)
assert result is not None
assert result["id"] == token_id
assert result["token_hash"] == token_hash
async def test_get_by_hash_returns_none_for_nonexistent(self, db_conn: asyncpg.Connection) -> None:
"""get_by_hash returns None for non-existent hash."""
repo = RefreshTokenRepository(db_conn)
result = await repo.get_by_hash("nonexistent_hash")
assert result is None
class TestGetValidByHash:
"""Tests for get_valid_by_hash with defense-in-depth validation."""
async def _setup_token(
self, conn: asyncpg.Connection, suffix: str = ""
) -> tuple[uuid4, uuid4, uuid4, str, RefreshTokenRepository]:
"""Helper to create user, org, and token."""
user_repo = UserRepository(conn)
org_repo = OrgRepository(conn)
token_repo = RefreshTokenRepository(conn)
user_id = uuid4()
org_id = uuid4()
token_id = uuid4()
token_hash = f"token_hash_{uuid4().hex[:8]}{suffix}"
await user_repo.create(user_id, f"user_{uuid4().hex[:8]}@example.com", "hash")
await org_repo.create(org_id, "Test Org", f"test-org-{uuid4().hex[:8]}")
expires_at = datetime.now(UTC) + timedelta(days=30)
await token_repo.create(token_id, user_id, token_hash, org_id, expires_at)
return token_id, user_id, org_id, token_hash, token_repo
async def test_get_valid_returns_valid_token(self, db_conn: asyncpg.Connection) -> None:
"""get_valid_by_hash returns token if not expired, not revoked, not rotated."""
_, _, _, token_hash, repo = await self._setup_token(db_conn)
result = await repo.get_valid_by_hash(token_hash)
assert result is not None
assert result["token_hash"] == token_hash
async def test_get_valid_returns_none_for_expired(self, db_conn: asyncpg.Connection) -> None:
"""get_valid_by_hash returns None for expired token."""
user_repo = UserRepository(db_conn)
org_repo = OrgRepository(db_conn)
repo = RefreshTokenRepository(db_conn)
user_id = uuid4()
org_id = uuid4()
await user_repo.create(user_id, "expired@example.com", "hash")
await org_repo.create(org_id, "Org", "expired-org")
token_hash = "expired_token_hash"
expires_at = datetime.now(UTC) - timedelta(days=1) # Already expired
await repo.create(uuid4(), user_id, token_hash, org_id, expires_at)
result = await repo.get_valid_by_hash(token_hash)
assert result is None
async def test_get_valid_returns_none_for_revoked(self, db_conn: asyncpg.Connection) -> None:
"""get_valid_by_hash returns None for revoked token."""
token_id, _, _, token_hash, repo = await self._setup_token(db_conn, "_revoked")
await repo.revoke(token_id)
result = await repo.get_valid_by_hash(token_hash)
assert result is None
async def test_get_valid_returns_none_for_rotated(self, db_conn: asyncpg.Connection) -> None:
"""get_valid_by_hash returns None for already-rotated token."""
_, user_id, org_id, old_hash, repo = await self._setup_token(db_conn, "_rotated")
# Rotate the token
new_hash = f"new_token_{uuid4().hex[:8]}"
new_expires = datetime.now(UTC) + timedelta(days=30)
await repo.rotate(old_hash, uuid4(), new_hash, new_expires)
# Old token should no longer be valid
result = await repo.get_valid_by_hash(old_hash)
assert result is None
async def test_get_valid_with_user_id_validation(self, db_conn: asyncpg.Connection) -> None:
"""get_valid_by_hash validates user_id when provided (defense-in-depth)."""
_, user_id, _, token_hash, repo = await self._setup_token(db_conn, "_user_check")
# Correct user_id should work
result = await repo.get_valid_by_hash(token_hash, user_id=user_id)
assert result is not None
# Wrong user_id should return None
wrong_user_id = uuid4()
result = await repo.get_valid_by_hash(token_hash, user_id=wrong_user_id)
assert result is None
async def test_get_valid_with_org_id_validation(self, db_conn: asyncpg.Connection) -> None:
"""get_valid_by_hash validates active_org_id when provided (defense-in-depth)."""
_, _, org_id, token_hash, repo = await self._setup_token(db_conn, "_org_check")
# Correct org_id should work
result = await repo.get_valid_by_hash(token_hash, active_org_id=org_id)
assert result is not None
# Wrong org_id should return None
wrong_org_id = uuid4()
result = await repo.get_valid_by_hash(token_hash, active_org_id=wrong_org_id)
assert result is None
async def test_get_valid_with_both_user_and_org_validation(self, db_conn: asyncpg.Connection) -> None:
"""get_valid_by_hash validates both user_id and active_org_id together."""
_, user_id, org_id, token_hash, repo = await self._setup_token(db_conn, "_both")
# Both correct should work
result = await repo.get_valid_by_hash(token_hash, user_id=user_id, active_org_id=org_id)
assert result is not None
# Either wrong should fail
result = await repo.get_valid_by_hash(token_hash, user_id=uuid4(), active_org_id=org_id)
assert result is None
result = await repo.get_valid_by_hash(token_hash, user_id=user_id, active_org_id=uuid4())
assert result is None
class TestAtomicRotation:
"""Tests for atomic token rotation per SPECS.md."""
async def _setup_token(
self, conn: asyncpg.Connection
) -> tuple[uuid4, uuid4, uuid4, str, RefreshTokenRepository]:
"""Helper to create user, org, and token."""
user_repo = UserRepository(conn)
org_repo = OrgRepository(conn)
token_repo = RefreshTokenRepository(conn)
user_id = uuid4()
org_id = uuid4()
token_id = uuid4()
token_hash = f"rotate_token_{uuid4().hex[:8]}"
await user_repo.create(user_id, f"rotate_{uuid4().hex[:8]}@example.com", "hash")
await org_repo.create(org_id, "Rotate Org", f"rotate-org-{uuid4().hex[:8]}")
expires_at = datetime.now(UTC) + timedelta(days=30)
await token_repo.create(token_id, user_id, token_hash, org_id, expires_at)
return token_id, user_id, org_id, token_hash, token_repo
async def test_rotate_creates_new_token(self, db_conn: asyncpg.Connection) -> None:
"""rotate() creates a new token and returns it."""
old_id, user_id, org_id, old_hash, repo = await self._setup_token(db_conn)
new_id = uuid4()
new_hash = f"new_rotated_{uuid4().hex[:8]}"
new_expires = datetime.now(UTC) + timedelta(days=30)
result = await repo.rotate(old_hash, new_id, new_hash, new_expires)
assert result is not None
assert result["id"] == new_id
assert result["token_hash"] == new_hash
assert result["user_id"] == user_id
assert result["active_org_id"] == org_id
async def test_rotate_marks_old_token_as_rotated(self, db_conn: asyncpg.Connection) -> None:
"""rotate() sets rotated_to on the old token (not revoked_at)."""
old_id, _, _, old_hash, repo = await self._setup_token(db_conn)
new_id = uuid4()
new_hash = f"new_{uuid4().hex[:8]}"
new_expires = datetime.now(UTC) + timedelta(days=30)
await repo.rotate(old_hash, new_id, new_hash, new_expires)
# Check old token state
old_token = await repo.get_by_hash(old_hash)
assert old_token["rotated_to"] == new_id
assert old_token["revoked_at"] is None # Not revoked, just rotated
async def test_rotate_fails_for_invalid_token(self, db_conn: asyncpg.Connection) -> None:
"""rotate() returns None if old token is invalid."""
_, _, _, _, repo = await self._setup_token(db_conn)
result = await repo.rotate(
"nonexistent_hash",
uuid4(),
f"new_{uuid4().hex[:8]}",
datetime.now(UTC) + timedelta(days=30),
)
assert result is None
async def test_rotate_fails_for_expired_token(self, db_conn: asyncpg.Connection) -> None:
"""rotate() returns None if old token is expired."""
user_repo = UserRepository(db_conn)
org_repo = OrgRepository(db_conn)
repo = RefreshTokenRepository(db_conn)
user_id = uuid4()
org_id = uuid4()
await user_repo.create(user_id, "exp_rotate@example.com", "hash")
await org_repo.create(org_id, "Org", "exp-rotate-org")
old_hash = "expired_for_rotation"
await repo.create(
uuid4(), user_id, old_hash, org_id,
datetime.now(UTC) - timedelta(days=1) # Already expired
)
result = await repo.rotate(
old_hash, uuid4(), f"new_{uuid4().hex[:8]}",
datetime.now(UTC) + timedelta(days=30)
)
assert result is None
async def test_rotate_fails_for_revoked_token(self, db_conn: asyncpg.Connection) -> None:
"""rotate() returns None if old token is revoked."""
old_id, _, _, old_hash, repo = await self._setup_token(db_conn)
await repo.revoke(old_id)
result = await repo.rotate(
old_hash, uuid4(), f"new_{uuid4().hex[:8]}",
datetime.now(UTC) + timedelta(days=30)
)
assert result is None
async def test_rotate_fails_for_already_rotated_token(self, db_conn: asyncpg.Connection) -> None:
"""rotate() returns None if old token was already rotated."""
_, _, _, old_hash, repo = await self._setup_token(db_conn)
# First rotation should succeed
result1 = await repo.rotate(
old_hash, uuid4(), f"new1_{uuid4().hex[:8]}",
datetime.now(UTC) + timedelta(days=30)
)
assert result1 is not None
# Second rotation of same token should fail
result2 = await repo.rotate(
old_hash, uuid4(), f"new2_{uuid4().hex[:8]}",
datetime.now(UTC) + timedelta(days=30)
)
assert result2 is None
async def test_rotate_with_org_switch(self, db_conn: asyncpg.Connection) -> None:
"""rotate() can change active_org_id (for org-switch flow)."""
_, user_id, old_org_id, old_hash, repo = await self._setup_token(db_conn)
# Create a new org for the user to switch to
org_repo = OrgRepository(db_conn)
new_org_id = uuid4()
await org_repo.create(new_org_id, "New Org", f"new-org-{uuid4().hex[:8]}")
new_hash = f"switched_{uuid4().hex[:8]}"
result = await repo.rotate(
old_hash, uuid4(), new_hash,
datetime.now(UTC) + timedelta(days=30),
new_active_org_id=new_org_id # Switch org
)
assert result is not None
assert result["active_org_id"] == new_org_id
assert result["active_org_id"] != old_org_id
async def test_rotate_validates_expected_user_id(self, db_conn: asyncpg.Connection) -> None:
"""rotate() fails if expected_user_id doesn't match token's user."""
_, user_id, _, old_hash, repo = await self._setup_token(db_conn)
# Wrong user should fail
result = await repo.rotate(
old_hash, uuid4(), f"new_{uuid4().hex[:8]}",
datetime.now(UTC) + timedelta(days=30),
expected_user_id=uuid4() # Wrong user
)
assert result is None
# Correct user should work
result = await repo.rotate(
old_hash, uuid4(), f"new_{uuid4().hex[:8]}",
datetime.now(UTC) + timedelta(days=30),
expected_user_id=user_id # Correct user
)
assert result is not None
class TestTokenReuseDetection:
"""Tests for detecting token reuse (stolen token attacks)."""
async def _setup_rotated_token(
self, conn: asyncpg.Connection
) -> tuple[uuid4, str, str, RefreshTokenRepository]:
"""Create a token and rotate it, returning old and new hashes."""
user_repo = UserRepository(conn)
org_repo = OrgRepository(conn)
token_repo = RefreshTokenRepository(conn)
user_id = uuid4()
org_id = uuid4()
await user_repo.create(user_id, f"reuse_{uuid4().hex[:8]}@example.com", "hash")
await org_repo.create(org_id, "Reuse Org", f"reuse-org-{uuid4().hex[:8]}")
old_hash = f"old_token_{uuid4().hex[:8]}"
expires_at = datetime.now(UTC) + timedelta(days=30)
old_token = await token_repo.create(uuid4(), user_id, old_hash, org_id, expires_at)
new_hash = f"new_token_{uuid4().hex[:8]}"
await token_repo.rotate(old_hash, uuid4(), new_hash, expires_at)
return old_token["id"], old_hash, new_hash, token_repo
async def test_check_token_reuse_detects_rotated_token(self, db_conn: asyncpg.Connection) -> None:
"""check_token_reuse returns token if it has been rotated."""
old_id, old_hash, _, repo = await self._setup_rotated_token(db_conn)
result = await repo.check_token_reuse(old_hash)
assert result is not None
assert result["id"] == old_id
assert result["rotated_to"] is not None
async def test_check_token_reuse_returns_none_for_active_token(self, db_conn: asyncpg.Connection) -> None:
"""check_token_reuse returns None for token that hasn't been rotated."""
user_repo = UserRepository(db_conn)
org_repo = OrgRepository(db_conn)
repo = RefreshTokenRepository(db_conn)
user_id = uuid4()
org_id = uuid4()
await user_repo.create(user_id, "active@example.com", "hash")
await org_repo.create(org_id, "Org", "active-org")
token_hash = "active_token_hash"
await repo.create(
uuid4(), user_id, token_hash, org_id,
datetime.now(UTC) + timedelta(days=30)
)
result = await repo.check_token_reuse(token_hash)
assert result is None
async def test_check_token_reuse_returns_none_for_nonexistent(self, db_conn: asyncpg.Connection) -> None:
"""check_token_reuse returns None for non-existent token."""
repo = RefreshTokenRepository(db_conn)
result = await repo.check_token_reuse("nonexistent_hash")
assert result is None
class TestTokenChainRevocation:
"""Tests for revoking entire token chains (breach response)."""
async def _setup_token_chain(
self, conn: asyncpg.Connection, chain_length: int = 3
) -> tuple[list[uuid4], list[str], uuid4, RefreshTokenRepository]:
"""Create a chain of rotated tokens."""
user_repo = UserRepository(conn)
org_repo = OrgRepository(conn)
token_repo = RefreshTokenRepository(conn)
user_id = uuid4()
org_id = uuid4()
await user_repo.create(user_id, f"chain_{uuid4().hex[:8]}@example.com", "hash")
await org_repo.create(org_id, "Chain Org", f"chain-org-{uuid4().hex[:8]}")
token_ids = []
token_hashes = []
expires_at = datetime.now(UTC) + timedelta(days=30)
# Create first token
first_hash = f"chain_token_0_{uuid4().hex[:8]}"
first_token = await token_repo.create(uuid4(), user_id, first_hash, org_id, expires_at)
token_ids.append(first_token["id"])
token_hashes.append(first_hash)
# Rotate to create chain
current_hash = first_hash
for i in range(1, chain_length):
new_hash = f"chain_token_{i}_{uuid4().hex[:8]}"
new_id = uuid4()
await token_repo.rotate(current_hash, new_id, new_hash, expires_at)
token_ids.append(new_id)
token_hashes.append(new_hash)
current_hash = new_hash
return token_ids, token_hashes, user_id, token_repo
async def test_revoke_token_chain_revokes_all_in_chain(self, db_conn: asyncpg.Connection) -> None:
"""revoke_token_chain revokes the token and all its rotations."""
token_ids, token_hashes, _, repo = await self._setup_token_chain(db_conn, chain_length=3)
# Revoke starting from the first token
count = await repo.revoke_token_chain(token_ids[0])
# Should revoke all 3 tokens in the chain
# But note: only the last one wasn't already "consumed" by rotation
# Let's check that revoke was called on all that were eligible
assert count >= 1 # At least the leaf token
# Verify the leaf token is revoked
leaf_token = await repo.get_by_hash(token_hashes[-1])
assert leaf_token["revoked_at"] is not None
async def test_revoke_token_chain_returns_count(self, db_conn: asyncpg.Connection) -> None:
"""revoke_token_chain returns count of actually revoked tokens."""
user_repo = UserRepository(db_conn)
org_repo = OrgRepository(db_conn)
repo = RefreshTokenRepository(db_conn)
user_id = uuid4()
org_id = uuid4()
await user_repo.create(user_id, "single@example.com", "hash")
await org_repo.create(org_id, "Single Org", "single-org")
token_hash = "single_chain_token"
token = await repo.create(
uuid4(), user_id, token_hash, org_id,
datetime.now(UTC) + timedelta(days=30)
)
count = await repo.revoke_token_chain(token["id"])
assert count == 1
class TestTokenRevocation:
"""Tests for token revocation methods."""
async def _setup_token(self, conn: asyncpg.Connection) -> tuple[uuid4, str, RefreshTokenRepository]:
"""Helper to create user, org, and token."""
user_repo = UserRepository(conn)
org_repo = OrgRepository(conn)
token_repo = RefreshTokenRepository(conn)
user_id = uuid4()
org_id = uuid4()
token_id = uuid4()
token_hash = f"revoke_token_{uuid4().hex[:8]}"
await user_repo.create(user_id, f"revoke_{uuid4().hex[:8]}@example.com", "hash")
await org_repo.create(org_id, "Revoke Org", f"revoke-org-{uuid4().hex[:8]}")
expires_at = datetime.now(UTC) + timedelta(days=30)
await token_repo.create(token_id, user_id, token_hash, org_id, expires_at)
return token_id, token_hash, token_repo
async def test_revoke_sets_revoked_at(self, db_conn: asyncpg.Connection) -> None:
"""revoke() sets the revoked_at timestamp."""
token_id, token_hash, repo = await self._setup_token(db_conn)
result = await repo.revoke(token_id)
assert result is True
token = await repo.get_by_hash(token_hash)
assert token["revoked_at"] is not None
async def test_revoke_returns_true_on_success(self, db_conn: asyncpg.Connection) -> None:
"""revoke() returns True when token is revoked."""
token_id, _, repo = await self._setup_token(db_conn)
result = await repo.revoke(token_id)
assert result is True
async def test_revoke_returns_false_for_already_revoked(self, db_conn: asyncpg.Connection) -> None:
"""revoke() returns False if token already revoked."""
token_id, _, repo = await self._setup_token(db_conn)
await repo.revoke(token_id)
result = await repo.revoke(token_id)
assert result is False
async def test_revoke_returns_false_for_nonexistent(self, db_conn: asyncpg.Connection) -> None:
"""revoke() returns False for non-existent token."""
_, _, repo = await self._setup_token(db_conn)
result = await repo.revoke(uuid4())
assert result is False
async def test_revoke_by_hash_works(self, db_conn: asyncpg.Connection) -> None:
"""revoke_by_hash() revokes token by hash value."""
_, token_hash, repo = await self._setup_token(db_conn)
result = await repo.revoke_by_hash(token_hash)
assert result is True
token = await repo.get_by_hash(token_hash)
assert token["revoked_at"] is not None
async def test_revoke_by_hash_returns_false_for_nonexistent(self, db_conn: asyncpg.Connection) -> None:
"""revoke_by_hash() returns False for non-existent hash."""
_, _, repo = await self._setup_token(db_conn)
result = await repo.revoke_by_hash("nonexistent_hash")
assert result is False
class TestRevokeAllForUser:
"""Tests for revoking all tokens for a user."""
async def test_revoke_all_for_user_revokes_all_tokens(self, db_conn: asyncpg.Connection) -> None:
"""revoke_all_for_user() revokes all tokens for the user."""
user_repo = UserRepository(db_conn)
org_repo = OrgRepository(db_conn)
token_repo = RefreshTokenRepository(db_conn)
user_id = uuid4()
org_id = uuid4()
await user_repo.create(user_id, "multi_token@example.com", "hash")
await org_repo.create(org_id, "Multi Token Org", "multi-token-org")
# Create multiple tokens
hashes = []
for i in range(3):
token_hash = f"token_{i}_{uuid4().hex[:8]}"
hashes.append(token_hash)
expires_at = datetime.now(UTC) + timedelta(days=30)
await token_repo.create(uuid4(), user_id, token_hash, org_id, expires_at)
result = await token_repo.revoke_all_for_user(user_id)
assert result == 3
for token_hash in hashes:
token = await token_repo.get_valid_by_hash(token_hash)
assert token is None
async def test_revoke_all_for_user_returns_zero_for_no_tokens(self, db_conn: asyncpg.Connection) -> None:
"""revoke_all_for_user() returns 0 if user has no tokens."""
user_repo = UserRepository(db_conn)
token_repo = RefreshTokenRepository(db_conn)
user_id = uuid4()
await user_repo.create(user_id, "no_tokens@example.com", "hash")
result = await token_repo.revoke_all_for_user(user_id)
assert result == 0
async def test_revoke_all_for_user_only_affects_user_tokens(self, db_conn: asyncpg.Connection) -> None:
"""revoke_all_for_user() doesn't affect other users' tokens."""
user_repo = UserRepository(db_conn)
org_repo = OrgRepository(db_conn)
token_repo = RefreshTokenRepository(db_conn)
user1 = uuid4()
user2 = uuid4()
org_id = uuid4()
await user_repo.create(user1, "user1@example.com", "hash")
await user_repo.create(user2, "user2@example.com", "hash")
await org_repo.create(org_id, "Shared Org", "shared-org")
user1_hash = f"user1_token_{uuid4().hex[:8]}"
user2_hash = f"user2_token_{uuid4().hex[:8]}"
expires_at = datetime.now(UTC) + timedelta(days=30)
await token_repo.create(uuid4(), user1, user1_hash, org_id, expires_at)
await token_repo.create(uuid4(), user2, user2_hash, org_id, expires_at)
await token_repo.revoke_all_for_user(user1)
# User1's token is revoked
assert await token_repo.get_valid_by_hash(user1_hash) is None
# User2's token is still valid
assert await token_repo.get_valid_by_hash(user2_hash) is not None
class TestRevokeAllExcept:
"""Tests for revoking all tokens except current session."""
async def test_revoke_all_except_keeps_specified_token(self, db_conn: asyncpg.Connection) -> None:
"""revoke_all_for_user_except() keeps the specified token active."""
user_repo = UserRepository(db_conn)
org_repo = OrgRepository(db_conn)
token_repo = RefreshTokenRepository(db_conn)
user_id = uuid4()
org_id = uuid4()
await user_repo.create(user_id, "except@example.com", "hash")
await org_repo.create(org_id, "Except Org", "except-org")
# Create multiple tokens
expires_at = datetime.now(UTC) + timedelta(days=30)
keep_token_id = uuid4()
keep_hash = f"keep_token_{uuid4().hex[:8]}"
await token_repo.create(keep_token_id, user_id, keep_hash, org_id, expires_at)
other_hashes = []
for i in range(2):
other_hash = f"other_token_{i}_{uuid4().hex[:8]}"
other_hashes.append(other_hash)
await token_repo.create(uuid4(), user_id, other_hash, org_id, expires_at)
result = await token_repo.revoke_all_for_user_except(user_id, keep_token_id)
assert result == 2 # Revoked 2 other tokens
# Keep token is still valid
assert await token_repo.get_valid_by_hash(keep_hash) is not None
# Other tokens are revoked
for other_hash in other_hashes:
assert await token_repo.get_valid_by_hash(other_hash) is None
class TestActiveTokensForUser:
"""Tests for listing active tokens for a user."""
async def test_get_active_tokens_returns_only_active(self, db_conn: asyncpg.Connection) -> None:
"""get_active_tokens_for_user() returns only non-revoked, non-expired, non-rotated."""
user_repo = UserRepository(db_conn)
org_repo = OrgRepository(db_conn)
token_repo = RefreshTokenRepository(db_conn)
user_id = uuid4()
org_id = uuid4()
await user_repo.create(user_id, "active_list@example.com", "hash")
await org_repo.create(org_id, "Active List Org", "active-list-org")
expires_at = datetime.now(UTC) + timedelta(days=30)
expired_at = datetime.now(UTC) - timedelta(days=1)
# Create active token
active_hash = f"active_{uuid4().hex[:8]}"
await token_repo.create(uuid4(), user_id, active_hash, org_id, expires_at)
# Create revoked token
revoked_id = uuid4()
revoked_hash = f"revoked_{uuid4().hex[:8]}"
await token_repo.create(revoked_id, user_id, revoked_hash, org_id, expires_at)
await token_repo.revoke(revoked_id)
# Create expired token
expired_hash = f"expired_{uuid4().hex[:8]}"
await token_repo.create(uuid4(), user_id, expired_hash, org_id, expired_at)
# Create rotated token
rotated_hash = f"rotated_{uuid4().hex[:8]}"
await token_repo.create(uuid4(), user_id, rotated_hash, org_id, expires_at)
await token_repo.rotate(rotated_hash, uuid4(), f"new_{uuid4().hex[:8]}", expires_at)
result = await token_repo.get_active_tokens_for_user(user_id)
# Should only return the active token and the new rotated token
assert len(result) == 2
hashes = {t["token_hash"] for t in result}
assert active_hash in hashes
assert revoked_hash not in hashes
assert expired_hash not in hashes
assert rotated_hash not in hashes
class TestTokenForeignKeys:
"""Tests for refresh token foreign key constraints."""
async def test_token_requires_valid_user_foreign_key(self, db_conn: asyncpg.Connection) -> None:
"""refresh_tokens.user_id must reference existing user."""
org_repo = OrgRepository(db_conn)
token_repo = RefreshTokenRepository(db_conn)
org_id = uuid4()
await org_repo.create(org_id, "FK Test Org", "fk-test-org")
with pytest.raises(asyncpg.ForeignKeyViolationError):
await token_repo.create(
uuid4(), uuid4(), "orphan_token", org_id,
datetime.now(UTC) + timedelta(days=30)
)
async def test_token_requires_valid_org_foreign_key(self, db_conn: asyncpg.Connection) -> None:
"""refresh_tokens.active_org_id must reference existing org."""
user_repo = UserRepository(db_conn)
token_repo = RefreshTokenRepository(db_conn)
user_id = uuid4()
await user_repo.create(user_id, "fk_org_test@example.com", "hash")
with pytest.raises(asyncpg.ForeignKeyViolationError):
await token_repo.create(
uuid4(), user_id, "orphan_org_token", uuid4(),
datetime.now(UTC) + timedelta(days=30)
)
async def test_token_stores_active_org_id(self, db_conn: asyncpg.Connection) -> None:
"""Token stores active_org_id for org context per SPECS.md."""
user_repo = UserRepository(db_conn)
org_repo = OrgRepository(db_conn)
token_repo = RefreshTokenRepository(db_conn)
user_id = uuid4()
org_id = uuid4()
await user_repo.create(user_id, "active_org@example.com", "hash")
await org_repo.create(org_id, "Active Org", "active-org")
token_hash = f"active_org_token_{uuid4().hex[:8]}"
await token_repo.create(
uuid4(), user_id, token_hash, org_id,
datetime.now(UTC) + timedelta(days=30)
)
token = await token_repo.get_by_hash(token_hash)
assert token["active_org_id"] == org_id