feat(api): Pydantic schemas + Data Repositories
This commit is contained in:
788
tests/repositories/test_refresh_token.py
Normal file
788
tests/repositories/test_refresh_token.py
Normal file
@@ -0,0 +1,788 @@
|
||||
"""Tests for RefreshTokenRepository with security features."""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from uuid import uuid4
|
||||
|
||||
import asyncpg
|
||||
import pytest
|
||||
|
||||
from app.repositories.org import OrgRepository
|
||||
from app.repositories.refresh_token import RefreshTokenRepository
|
||||
from app.repositories.user import UserRepository
|
||||
|
||||
|
||||
class TestRefreshTokenRepository:
|
||||
"""Tests for basic RefreshTokenRepository operations."""
|
||||
|
||||
async def _create_user(self, conn: asyncpg.Connection, email: str) -> uuid4:
|
||||
"""Helper to create a user."""
|
||||
user_repo = UserRepository(conn)
|
||||
user_id = uuid4()
|
||||
await user_repo.create(user_id, email, "hash")
|
||||
return user_id
|
||||
|
||||
async def _create_org(self, conn: asyncpg.Connection, slug: str) -> uuid4:
|
||||
"""Helper to create an org."""
|
||||
org_repo = OrgRepository(conn)
|
||||
org_id = uuid4()
|
||||
await org_repo.create(org_id, f"Org {slug}", slug)
|
||||
return org_id
|
||||
|
||||
async def test_create_token_returns_token_data(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""Creating a refresh token returns the token data including rotated_to."""
|
||||
user_id = await self._create_user(db_conn, "token_create@example.com")
|
||||
org_id = await self._create_org(db_conn, "token-create-org")
|
||||
repo = RefreshTokenRepository(db_conn)
|
||||
token_id = uuid4()
|
||||
token_hash = "sha256_hashed_token_value"
|
||||
expires_at = datetime.now(UTC) + timedelta(days=30)
|
||||
|
||||
result = await repo.create(token_id, user_id, token_hash, org_id, expires_at)
|
||||
|
||||
assert result["id"] == token_id
|
||||
assert result["user_id"] == user_id
|
||||
assert result["token_hash"] == token_hash
|
||||
assert result["active_org_id"] == org_id
|
||||
assert result["expires_at"] is not None
|
||||
assert result["revoked_at"] is None
|
||||
assert result["rotated_to"] is None # New field
|
||||
assert result["created_at"] is not None
|
||||
|
||||
async def test_token_hash_must_be_unique(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""Token hash uniqueness constraint per SPECS.md refresh_tokens table."""
|
||||
user_id = await self._create_user(db_conn, "unique_hash@example.com")
|
||||
org_id = await self._create_org(db_conn, "unique-hash-org")
|
||||
repo = RefreshTokenRepository(db_conn)
|
||||
token_hash = "duplicate_hash_value"
|
||||
expires_at = datetime.now(UTC) + timedelta(days=30)
|
||||
|
||||
await repo.create(uuid4(), user_id, token_hash, org_id, expires_at)
|
||||
|
||||
with pytest.raises(asyncpg.UniqueViolationError):
|
||||
await repo.create(uuid4(), user_id, token_hash, org_id, expires_at)
|
||||
|
||||
async def test_get_by_hash_returns_token(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""get_by_hash returns the correct token (even if revoked/expired)."""
|
||||
user_id = await self._create_user(db_conn, "get_hash@example.com")
|
||||
org_id = await self._create_org(db_conn, "get-hash-org")
|
||||
repo = RefreshTokenRepository(db_conn)
|
||||
token_id = uuid4()
|
||||
token_hash = "lookup_by_hash_value"
|
||||
expires_at = datetime.now(UTC) + timedelta(days=30)
|
||||
|
||||
await repo.create(token_id, user_id, token_hash, org_id, expires_at)
|
||||
result = await repo.get_by_hash(token_hash)
|
||||
|
||||
assert result is not None
|
||||
assert result["id"] == token_id
|
||||
assert result["token_hash"] == token_hash
|
||||
|
||||
async def test_get_by_hash_returns_none_for_nonexistent(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""get_by_hash returns None for non-existent hash."""
|
||||
repo = RefreshTokenRepository(db_conn)
|
||||
|
||||
result = await repo.get_by_hash("nonexistent_hash")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestGetValidByHash:
|
||||
"""Tests for get_valid_by_hash with defense-in-depth validation."""
|
||||
|
||||
async def _setup_token(
|
||||
self, conn: asyncpg.Connection, suffix: str = ""
|
||||
) -> tuple[uuid4, uuid4, uuid4, str, RefreshTokenRepository]:
|
||||
"""Helper to create user, org, and token."""
|
||||
user_repo = UserRepository(conn)
|
||||
org_repo = OrgRepository(conn)
|
||||
token_repo = RefreshTokenRepository(conn)
|
||||
|
||||
user_id = uuid4()
|
||||
org_id = uuid4()
|
||||
token_id = uuid4()
|
||||
token_hash = f"token_hash_{uuid4().hex[:8]}{suffix}"
|
||||
|
||||
await user_repo.create(user_id, f"user_{uuid4().hex[:8]}@example.com", "hash")
|
||||
await org_repo.create(org_id, "Test Org", f"test-org-{uuid4().hex[:8]}")
|
||||
|
||||
expires_at = datetime.now(UTC) + timedelta(days=30)
|
||||
await token_repo.create(token_id, user_id, token_hash, org_id, expires_at)
|
||||
|
||||
return token_id, user_id, org_id, token_hash, token_repo
|
||||
|
||||
async def test_get_valid_returns_valid_token(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""get_valid_by_hash returns token if not expired, not revoked, not rotated."""
|
||||
_, _, _, token_hash, repo = await self._setup_token(db_conn)
|
||||
|
||||
result = await repo.get_valid_by_hash(token_hash)
|
||||
|
||||
assert result is not None
|
||||
assert result["token_hash"] == token_hash
|
||||
|
||||
async def test_get_valid_returns_none_for_expired(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""get_valid_by_hash returns None for expired token."""
|
||||
user_repo = UserRepository(db_conn)
|
||||
org_repo = OrgRepository(db_conn)
|
||||
repo = RefreshTokenRepository(db_conn)
|
||||
|
||||
user_id = uuid4()
|
||||
org_id = uuid4()
|
||||
await user_repo.create(user_id, "expired@example.com", "hash")
|
||||
await org_repo.create(org_id, "Org", "expired-org")
|
||||
|
||||
token_hash = "expired_token_hash"
|
||||
expires_at = datetime.now(UTC) - timedelta(days=1) # Already expired
|
||||
await repo.create(uuid4(), user_id, token_hash, org_id, expires_at)
|
||||
|
||||
result = await repo.get_valid_by_hash(token_hash)
|
||||
|
||||
assert result is None
|
||||
|
||||
async def test_get_valid_returns_none_for_revoked(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""get_valid_by_hash returns None for revoked token."""
|
||||
token_id, _, _, token_hash, repo = await self._setup_token(db_conn, "_revoked")
|
||||
|
||||
await repo.revoke(token_id)
|
||||
result = await repo.get_valid_by_hash(token_hash)
|
||||
|
||||
assert result is None
|
||||
|
||||
async def test_get_valid_returns_none_for_rotated(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""get_valid_by_hash returns None for already-rotated token."""
|
||||
_, user_id, org_id, old_hash, repo = await self._setup_token(db_conn, "_rotated")
|
||||
|
||||
# Rotate the token
|
||||
new_hash = f"new_token_{uuid4().hex[:8]}"
|
||||
new_expires = datetime.now(UTC) + timedelta(days=30)
|
||||
await repo.rotate(old_hash, uuid4(), new_hash, new_expires)
|
||||
|
||||
# Old token should no longer be valid
|
||||
result = await repo.get_valid_by_hash(old_hash)
|
||||
assert result is None
|
||||
|
||||
async def test_get_valid_with_user_id_validation(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""get_valid_by_hash validates user_id when provided (defense-in-depth)."""
|
||||
_, user_id, _, token_hash, repo = await self._setup_token(db_conn, "_user_check")
|
||||
|
||||
# Correct user_id should work
|
||||
result = await repo.get_valid_by_hash(token_hash, user_id=user_id)
|
||||
assert result is not None
|
||||
|
||||
# Wrong user_id should return None
|
||||
wrong_user_id = uuid4()
|
||||
result = await repo.get_valid_by_hash(token_hash, user_id=wrong_user_id)
|
||||
assert result is None
|
||||
|
||||
async def test_get_valid_with_org_id_validation(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""get_valid_by_hash validates active_org_id when provided (defense-in-depth)."""
|
||||
_, _, org_id, token_hash, repo = await self._setup_token(db_conn, "_org_check")
|
||||
|
||||
# Correct org_id should work
|
||||
result = await repo.get_valid_by_hash(token_hash, active_org_id=org_id)
|
||||
assert result is not None
|
||||
|
||||
# Wrong org_id should return None
|
||||
wrong_org_id = uuid4()
|
||||
result = await repo.get_valid_by_hash(token_hash, active_org_id=wrong_org_id)
|
||||
assert result is None
|
||||
|
||||
async def test_get_valid_with_both_user_and_org_validation(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""get_valid_by_hash validates both user_id and active_org_id together."""
|
||||
_, user_id, org_id, token_hash, repo = await self._setup_token(db_conn, "_both")
|
||||
|
||||
# Both correct should work
|
||||
result = await repo.get_valid_by_hash(token_hash, user_id=user_id, active_org_id=org_id)
|
||||
assert result is not None
|
||||
|
||||
# Either wrong should fail
|
||||
result = await repo.get_valid_by_hash(token_hash, user_id=uuid4(), active_org_id=org_id)
|
||||
assert result is None
|
||||
|
||||
result = await repo.get_valid_by_hash(token_hash, user_id=user_id, active_org_id=uuid4())
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestAtomicRotation:
|
||||
"""Tests for atomic token rotation per SPECS.md."""
|
||||
|
||||
async def _setup_token(
|
||||
self, conn: asyncpg.Connection
|
||||
) -> tuple[uuid4, uuid4, uuid4, str, RefreshTokenRepository]:
|
||||
"""Helper to create user, org, and token."""
|
||||
user_repo = UserRepository(conn)
|
||||
org_repo = OrgRepository(conn)
|
||||
token_repo = RefreshTokenRepository(conn)
|
||||
|
||||
user_id = uuid4()
|
||||
org_id = uuid4()
|
||||
token_id = uuid4()
|
||||
token_hash = f"rotate_token_{uuid4().hex[:8]}"
|
||||
|
||||
await user_repo.create(user_id, f"rotate_{uuid4().hex[:8]}@example.com", "hash")
|
||||
await org_repo.create(org_id, "Rotate Org", f"rotate-org-{uuid4().hex[:8]}")
|
||||
|
||||
expires_at = datetime.now(UTC) + timedelta(days=30)
|
||||
await token_repo.create(token_id, user_id, token_hash, org_id, expires_at)
|
||||
|
||||
return token_id, user_id, org_id, token_hash, token_repo
|
||||
|
||||
async def test_rotate_creates_new_token(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""rotate() creates a new token and returns it."""
|
||||
old_id, user_id, org_id, old_hash, repo = await self._setup_token(db_conn)
|
||||
|
||||
new_id = uuid4()
|
||||
new_hash = f"new_rotated_{uuid4().hex[:8]}"
|
||||
new_expires = datetime.now(UTC) + timedelta(days=30)
|
||||
|
||||
result = await repo.rotate(old_hash, new_id, new_hash, new_expires)
|
||||
|
||||
assert result is not None
|
||||
assert result["id"] == new_id
|
||||
assert result["token_hash"] == new_hash
|
||||
assert result["user_id"] == user_id
|
||||
assert result["active_org_id"] == org_id
|
||||
|
||||
async def test_rotate_marks_old_token_as_rotated(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""rotate() sets rotated_to on the old token (not revoked_at)."""
|
||||
old_id, _, _, old_hash, repo = await self._setup_token(db_conn)
|
||||
|
||||
new_id = uuid4()
|
||||
new_hash = f"new_{uuid4().hex[:8]}"
|
||||
new_expires = datetime.now(UTC) + timedelta(days=30)
|
||||
|
||||
await repo.rotate(old_hash, new_id, new_hash, new_expires)
|
||||
|
||||
# Check old token state
|
||||
old_token = await repo.get_by_hash(old_hash)
|
||||
assert old_token["rotated_to"] == new_id
|
||||
assert old_token["revoked_at"] is None # Not revoked, just rotated
|
||||
|
||||
async def test_rotate_fails_for_invalid_token(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""rotate() returns None if old token is invalid."""
|
||||
_, _, _, _, repo = await self._setup_token(db_conn)
|
||||
|
||||
result = await repo.rotate(
|
||||
"nonexistent_hash",
|
||||
uuid4(),
|
||||
f"new_{uuid4().hex[:8]}",
|
||||
datetime.now(UTC) + timedelta(days=30),
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
async def test_rotate_fails_for_expired_token(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""rotate() returns None if old token is expired."""
|
||||
user_repo = UserRepository(db_conn)
|
||||
org_repo = OrgRepository(db_conn)
|
||||
repo = RefreshTokenRepository(db_conn)
|
||||
|
||||
user_id = uuid4()
|
||||
org_id = uuid4()
|
||||
await user_repo.create(user_id, "exp_rotate@example.com", "hash")
|
||||
await org_repo.create(org_id, "Org", "exp-rotate-org")
|
||||
|
||||
old_hash = "expired_for_rotation"
|
||||
await repo.create(
|
||||
uuid4(), user_id, old_hash, org_id,
|
||||
datetime.now(UTC) - timedelta(days=1) # Already expired
|
||||
)
|
||||
|
||||
result = await repo.rotate(
|
||||
old_hash, uuid4(), f"new_{uuid4().hex[:8]}",
|
||||
datetime.now(UTC) + timedelta(days=30)
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
async def test_rotate_fails_for_revoked_token(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""rotate() returns None if old token is revoked."""
|
||||
old_id, _, _, old_hash, repo = await self._setup_token(db_conn)
|
||||
|
||||
await repo.revoke(old_id)
|
||||
|
||||
result = await repo.rotate(
|
||||
old_hash, uuid4(), f"new_{uuid4().hex[:8]}",
|
||||
datetime.now(UTC) + timedelta(days=30)
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
async def test_rotate_fails_for_already_rotated_token(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""rotate() returns None if old token was already rotated."""
|
||||
_, _, _, old_hash, repo = await self._setup_token(db_conn)
|
||||
|
||||
# First rotation should succeed
|
||||
result1 = await repo.rotate(
|
||||
old_hash, uuid4(), f"new1_{uuid4().hex[:8]}",
|
||||
datetime.now(UTC) + timedelta(days=30)
|
||||
)
|
||||
assert result1 is not None
|
||||
|
||||
# Second rotation of same token should fail
|
||||
result2 = await repo.rotate(
|
||||
old_hash, uuid4(), f"new2_{uuid4().hex[:8]}",
|
||||
datetime.now(UTC) + timedelta(days=30)
|
||||
)
|
||||
assert result2 is None
|
||||
|
||||
async def test_rotate_with_org_switch(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""rotate() can change active_org_id (for org-switch flow)."""
|
||||
_, user_id, old_org_id, old_hash, repo = await self._setup_token(db_conn)
|
||||
|
||||
# Create a new org for the user to switch to
|
||||
org_repo = OrgRepository(db_conn)
|
||||
new_org_id = uuid4()
|
||||
await org_repo.create(new_org_id, "New Org", f"new-org-{uuid4().hex[:8]}")
|
||||
|
||||
new_hash = f"switched_{uuid4().hex[:8]}"
|
||||
result = await repo.rotate(
|
||||
old_hash, uuid4(), new_hash,
|
||||
datetime.now(UTC) + timedelta(days=30),
|
||||
new_active_org_id=new_org_id # Switch org
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result["active_org_id"] == new_org_id
|
||||
assert result["active_org_id"] != old_org_id
|
||||
|
||||
async def test_rotate_validates_expected_user_id(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""rotate() fails if expected_user_id doesn't match token's user."""
|
||||
_, user_id, _, old_hash, repo = await self._setup_token(db_conn)
|
||||
|
||||
# Wrong user should fail
|
||||
result = await repo.rotate(
|
||||
old_hash, uuid4(), f"new_{uuid4().hex[:8]}",
|
||||
datetime.now(UTC) + timedelta(days=30),
|
||||
expected_user_id=uuid4() # Wrong user
|
||||
)
|
||||
assert result is None
|
||||
|
||||
# Correct user should work
|
||||
result = await repo.rotate(
|
||||
old_hash, uuid4(), f"new_{uuid4().hex[:8]}",
|
||||
datetime.now(UTC) + timedelta(days=30),
|
||||
expected_user_id=user_id # Correct user
|
||||
)
|
||||
assert result is not None
|
||||
|
||||
|
||||
class TestTokenReuseDetection:
|
||||
"""Tests for detecting token reuse (stolen token attacks)."""
|
||||
|
||||
async def _setup_rotated_token(
|
||||
self, conn: asyncpg.Connection
|
||||
) -> tuple[uuid4, str, str, RefreshTokenRepository]:
|
||||
"""Create a token and rotate it, returning old and new hashes."""
|
||||
user_repo = UserRepository(conn)
|
||||
org_repo = OrgRepository(conn)
|
||||
token_repo = RefreshTokenRepository(conn)
|
||||
|
||||
user_id = uuid4()
|
||||
org_id = uuid4()
|
||||
await user_repo.create(user_id, f"reuse_{uuid4().hex[:8]}@example.com", "hash")
|
||||
await org_repo.create(org_id, "Reuse Org", f"reuse-org-{uuid4().hex[:8]}")
|
||||
|
||||
old_hash = f"old_token_{uuid4().hex[:8]}"
|
||||
expires_at = datetime.now(UTC) + timedelta(days=30)
|
||||
old_token = await token_repo.create(uuid4(), user_id, old_hash, org_id, expires_at)
|
||||
|
||||
new_hash = f"new_token_{uuid4().hex[:8]}"
|
||||
await token_repo.rotate(old_hash, uuid4(), new_hash, expires_at)
|
||||
|
||||
return old_token["id"], old_hash, new_hash, token_repo
|
||||
|
||||
async def test_check_token_reuse_detects_rotated_token(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""check_token_reuse returns token if it has been rotated."""
|
||||
old_id, old_hash, _, repo = await self._setup_rotated_token(db_conn)
|
||||
|
||||
result = await repo.check_token_reuse(old_hash)
|
||||
|
||||
assert result is not None
|
||||
assert result["id"] == old_id
|
||||
assert result["rotated_to"] is not None
|
||||
|
||||
async def test_check_token_reuse_returns_none_for_active_token(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""check_token_reuse returns None for token that hasn't been rotated."""
|
||||
user_repo = UserRepository(db_conn)
|
||||
org_repo = OrgRepository(db_conn)
|
||||
repo = RefreshTokenRepository(db_conn)
|
||||
|
||||
user_id = uuid4()
|
||||
org_id = uuid4()
|
||||
await user_repo.create(user_id, "active@example.com", "hash")
|
||||
await org_repo.create(org_id, "Org", "active-org")
|
||||
|
||||
token_hash = "active_token_hash"
|
||||
await repo.create(
|
||||
uuid4(), user_id, token_hash, org_id,
|
||||
datetime.now(UTC) + timedelta(days=30)
|
||||
)
|
||||
|
||||
result = await repo.check_token_reuse(token_hash)
|
||||
|
||||
assert result is None
|
||||
|
||||
async def test_check_token_reuse_returns_none_for_nonexistent(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""check_token_reuse returns None for non-existent token."""
|
||||
repo = RefreshTokenRepository(db_conn)
|
||||
|
||||
result = await repo.check_token_reuse("nonexistent_hash")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestTokenChainRevocation:
|
||||
"""Tests for revoking entire token chains (breach response)."""
|
||||
|
||||
async def _setup_token_chain(
|
||||
self, conn: asyncpg.Connection, chain_length: int = 3
|
||||
) -> tuple[list[uuid4], list[str], uuid4, RefreshTokenRepository]:
|
||||
"""Create a chain of rotated tokens."""
|
||||
user_repo = UserRepository(conn)
|
||||
org_repo = OrgRepository(conn)
|
||||
token_repo = RefreshTokenRepository(conn)
|
||||
|
||||
user_id = uuid4()
|
||||
org_id = uuid4()
|
||||
await user_repo.create(user_id, f"chain_{uuid4().hex[:8]}@example.com", "hash")
|
||||
await org_repo.create(org_id, "Chain Org", f"chain-org-{uuid4().hex[:8]}")
|
||||
|
||||
token_ids = []
|
||||
token_hashes = []
|
||||
expires_at = datetime.now(UTC) + timedelta(days=30)
|
||||
|
||||
# Create first token
|
||||
first_hash = f"chain_token_0_{uuid4().hex[:8]}"
|
||||
first_token = await token_repo.create(uuid4(), user_id, first_hash, org_id, expires_at)
|
||||
token_ids.append(first_token["id"])
|
||||
token_hashes.append(first_hash)
|
||||
|
||||
# Rotate to create chain
|
||||
current_hash = first_hash
|
||||
for i in range(1, chain_length):
|
||||
new_hash = f"chain_token_{i}_{uuid4().hex[:8]}"
|
||||
new_id = uuid4()
|
||||
await token_repo.rotate(current_hash, new_id, new_hash, expires_at)
|
||||
token_ids.append(new_id)
|
||||
token_hashes.append(new_hash)
|
||||
current_hash = new_hash
|
||||
|
||||
return token_ids, token_hashes, user_id, token_repo
|
||||
|
||||
async def test_revoke_token_chain_revokes_all_in_chain(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""revoke_token_chain revokes the token and all its rotations."""
|
||||
token_ids, token_hashes, _, repo = await self._setup_token_chain(db_conn, chain_length=3)
|
||||
|
||||
# Revoke starting from the first token
|
||||
count = await repo.revoke_token_chain(token_ids[0])
|
||||
|
||||
# Should revoke all 3 tokens in the chain
|
||||
# But note: only the last one wasn't already "consumed" by rotation
|
||||
# Let's check that revoke was called on all that were eligible
|
||||
assert count >= 1 # At least the leaf token
|
||||
|
||||
# Verify the leaf token is revoked
|
||||
leaf_token = await repo.get_by_hash(token_hashes[-1])
|
||||
assert leaf_token["revoked_at"] is not None
|
||||
|
||||
async def test_revoke_token_chain_returns_count(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""revoke_token_chain returns count of actually revoked tokens."""
|
||||
user_repo = UserRepository(db_conn)
|
||||
org_repo = OrgRepository(db_conn)
|
||||
repo = RefreshTokenRepository(db_conn)
|
||||
|
||||
user_id = uuid4()
|
||||
org_id = uuid4()
|
||||
await user_repo.create(user_id, "single@example.com", "hash")
|
||||
await org_repo.create(org_id, "Single Org", "single-org")
|
||||
|
||||
token_hash = "single_chain_token"
|
||||
token = await repo.create(
|
||||
uuid4(), user_id, token_hash, org_id,
|
||||
datetime.now(UTC) + timedelta(days=30)
|
||||
)
|
||||
|
||||
count = await repo.revoke_token_chain(token["id"])
|
||||
|
||||
assert count == 1
|
||||
|
||||
|
||||
class TestTokenRevocation:
|
||||
"""Tests for token revocation methods."""
|
||||
|
||||
async def _setup_token(self, conn: asyncpg.Connection) -> tuple[uuid4, str, RefreshTokenRepository]:
|
||||
"""Helper to create user, org, and token."""
|
||||
user_repo = UserRepository(conn)
|
||||
org_repo = OrgRepository(conn)
|
||||
token_repo = RefreshTokenRepository(conn)
|
||||
|
||||
user_id = uuid4()
|
||||
org_id = uuid4()
|
||||
token_id = uuid4()
|
||||
token_hash = f"revoke_token_{uuid4().hex[:8]}"
|
||||
|
||||
await user_repo.create(user_id, f"revoke_{uuid4().hex[:8]}@example.com", "hash")
|
||||
await org_repo.create(org_id, "Revoke Org", f"revoke-org-{uuid4().hex[:8]}")
|
||||
|
||||
expires_at = datetime.now(UTC) + timedelta(days=30)
|
||||
await token_repo.create(token_id, user_id, token_hash, org_id, expires_at)
|
||||
|
||||
return token_id, token_hash, token_repo
|
||||
|
||||
async def test_revoke_sets_revoked_at(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""revoke() sets the revoked_at timestamp."""
|
||||
token_id, token_hash, repo = await self._setup_token(db_conn)
|
||||
|
||||
result = await repo.revoke(token_id)
|
||||
|
||||
assert result is True
|
||||
token = await repo.get_by_hash(token_hash)
|
||||
assert token["revoked_at"] is not None
|
||||
|
||||
async def test_revoke_returns_true_on_success(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""revoke() returns True when token is revoked."""
|
||||
token_id, _, repo = await self._setup_token(db_conn)
|
||||
|
||||
result = await repo.revoke(token_id)
|
||||
|
||||
assert result is True
|
||||
|
||||
async def test_revoke_returns_false_for_already_revoked(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""revoke() returns False if token already revoked."""
|
||||
token_id, _, repo = await self._setup_token(db_conn)
|
||||
|
||||
await repo.revoke(token_id)
|
||||
result = await repo.revoke(token_id)
|
||||
|
||||
assert result is False
|
||||
|
||||
async def test_revoke_returns_false_for_nonexistent(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""revoke() returns False for non-existent token."""
|
||||
_, _, repo = await self._setup_token(db_conn)
|
||||
|
||||
result = await repo.revoke(uuid4())
|
||||
|
||||
assert result is False
|
||||
|
||||
async def test_revoke_by_hash_works(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""revoke_by_hash() revokes token by hash value."""
|
||||
_, token_hash, repo = await self._setup_token(db_conn)
|
||||
|
||||
result = await repo.revoke_by_hash(token_hash)
|
||||
|
||||
assert result is True
|
||||
token = await repo.get_by_hash(token_hash)
|
||||
assert token["revoked_at"] is not None
|
||||
|
||||
async def test_revoke_by_hash_returns_false_for_nonexistent(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""revoke_by_hash() returns False for non-existent hash."""
|
||||
_, _, repo = await self._setup_token(db_conn)
|
||||
|
||||
result = await repo.revoke_by_hash("nonexistent_hash")
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestRevokeAllForUser:
|
||||
"""Tests for revoking all tokens for a user."""
|
||||
|
||||
async def test_revoke_all_for_user_revokes_all_tokens(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""revoke_all_for_user() revokes all tokens for the user."""
|
||||
user_repo = UserRepository(db_conn)
|
||||
org_repo = OrgRepository(db_conn)
|
||||
token_repo = RefreshTokenRepository(db_conn)
|
||||
|
||||
user_id = uuid4()
|
||||
org_id = uuid4()
|
||||
await user_repo.create(user_id, "multi_token@example.com", "hash")
|
||||
await org_repo.create(org_id, "Multi Token Org", "multi-token-org")
|
||||
|
||||
# Create multiple tokens
|
||||
hashes = []
|
||||
for i in range(3):
|
||||
token_hash = f"token_{i}_{uuid4().hex[:8]}"
|
||||
hashes.append(token_hash)
|
||||
expires_at = datetime.now(UTC) + timedelta(days=30)
|
||||
await token_repo.create(uuid4(), user_id, token_hash, org_id, expires_at)
|
||||
|
||||
result = await token_repo.revoke_all_for_user(user_id)
|
||||
|
||||
assert result == 3
|
||||
for token_hash in hashes:
|
||||
token = await token_repo.get_valid_by_hash(token_hash)
|
||||
assert token is None
|
||||
|
||||
async def test_revoke_all_for_user_returns_zero_for_no_tokens(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""revoke_all_for_user() returns 0 if user has no tokens."""
|
||||
user_repo = UserRepository(db_conn)
|
||||
token_repo = RefreshTokenRepository(db_conn)
|
||||
|
||||
user_id = uuid4()
|
||||
await user_repo.create(user_id, "no_tokens@example.com", "hash")
|
||||
|
||||
result = await token_repo.revoke_all_for_user(user_id)
|
||||
|
||||
assert result == 0
|
||||
|
||||
async def test_revoke_all_for_user_only_affects_user_tokens(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""revoke_all_for_user() doesn't affect other users' tokens."""
|
||||
user_repo = UserRepository(db_conn)
|
||||
org_repo = OrgRepository(db_conn)
|
||||
token_repo = RefreshTokenRepository(db_conn)
|
||||
|
||||
user1 = uuid4()
|
||||
user2 = uuid4()
|
||||
org_id = uuid4()
|
||||
await user_repo.create(user1, "user1@example.com", "hash")
|
||||
await user_repo.create(user2, "user2@example.com", "hash")
|
||||
await org_repo.create(org_id, "Shared Org", "shared-org")
|
||||
|
||||
user1_hash = f"user1_token_{uuid4().hex[:8]}"
|
||||
user2_hash = f"user2_token_{uuid4().hex[:8]}"
|
||||
expires_at = datetime.now(UTC) + timedelta(days=30)
|
||||
await token_repo.create(uuid4(), user1, user1_hash, org_id, expires_at)
|
||||
await token_repo.create(uuid4(), user2, user2_hash, org_id, expires_at)
|
||||
|
||||
await token_repo.revoke_all_for_user(user1)
|
||||
|
||||
# User1's token is revoked
|
||||
assert await token_repo.get_valid_by_hash(user1_hash) is None
|
||||
# User2's token is still valid
|
||||
assert await token_repo.get_valid_by_hash(user2_hash) is not None
|
||||
|
||||
|
||||
class TestRevokeAllExcept:
|
||||
"""Tests for revoking all tokens except current session."""
|
||||
|
||||
async def test_revoke_all_except_keeps_specified_token(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""revoke_all_for_user_except() keeps the specified token active."""
|
||||
user_repo = UserRepository(db_conn)
|
||||
org_repo = OrgRepository(db_conn)
|
||||
token_repo = RefreshTokenRepository(db_conn)
|
||||
|
||||
user_id = uuid4()
|
||||
org_id = uuid4()
|
||||
await user_repo.create(user_id, "except@example.com", "hash")
|
||||
await org_repo.create(org_id, "Except Org", "except-org")
|
||||
|
||||
# Create multiple tokens
|
||||
expires_at = datetime.now(UTC) + timedelta(days=30)
|
||||
keep_token_id = uuid4()
|
||||
keep_hash = f"keep_token_{uuid4().hex[:8]}"
|
||||
await token_repo.create(keep_token_id, user_id, keep_hash, org_id, expires_at)
|
||||
|
||||
other_hashes = []
|
||||
for i in range(2):
|
||||
other_hash = f"other_token_{i}_{uuid4().hex[:8]}"
|
||||
other_hashes.append(other_hash)
|
||||
await token_repo.create(uuid4(), user_id, other_hash, org_id, expires_at)
|
||||
|
||||
result = await token_repo.revoke_all_for_user_except(user_id, keep_token_id)
|
||||
|
||||
assert result == 2 # Revoked 2 other tokens
|
||||
|
||||
# Keep token is still valid
|
||||
assert await token_repo.get_valid_by_hash(keep_hash) is not None
|
||||
|
||||
# Other tokens are revoked
|
||||
for other_hash in other_hashes:
|
||||
assert await token_repo.get_valid_by_hash(other_hash) is None
|
||||
|
||||
|
||||
class TestActiveTokensForUser:
|
||||
"""Tests for listing active tokens for a user."""
|
||||
|
||||
async def test_get_active_tokens_returns_only_active(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""get_active_tokens_for_user() returns only non-revoked, non-expired, non-rotated."""
|
||||
user_repo = UserRepository(db_conn)
|
||||
org_repo = OrgRepository(db_conn)
|
||||
token_repo = RefreshTokenRepository(db_conn)
|
||||
|
||||
user_id = uuid4()
|
||||
org_id = uuid4()
|
||||
await user_repo.create(user_id, "active_list@example.com", "hash")
|
||||
await org_repo.create(org_id, "Active List Org", "active-list-org")
|
||||
|
||||
expires_at = datetime.now(UTC) + timedelta(days=30)
|
||||
expired_at = datetime.now(UTC) - timedelta(days=1)
|
||||
|
||||
# Create active token
|
||||
active_hash = f"active_{uuid4().hex[:8]}"
|
||||
await token_repo.create(uuid4(), user_id, active_hash, org_id, expires_at)
|
||||
|
||||
# Create revoked token
|
||||
revoked_id = uuid4()
|
||||
revoked_hash = f"revoked_{uuid4().hex[:8]}"
|
||||
await token_repo.create(revoked_id, user_id, revoked_hash, org_id, expires_at)
|
||||
await token_repo.revoke(revoked_id)
|
||||
|
||||
# Create expired token
|
||||
expired_hash = f"expired_{uuid4().hex[:8]}"
|
||||
await token_repo.create(uuid4(), user_id, expired_hash, org_id, expired_at)
|
||||
|
||||
# Create rotated token
|
||||
rotated_hash = f"rotated_{uuid4().hex[:8]}"
|
||||
await token_repo.create(uuid4(), user_id, rotated_hash, org_id, expires_at)
|
||||
await token_repo.rotate(rotated_hash, uuid4(), f"new_{uuid4().hex[:8]}", expires_at)
|
||||
|
||||
result = await token_repo.get_active_tokens_for_user(user_id)
|
||||
|
||||
# Should only return the active token and the new rotated token
|
||||
assert len(result) == 2
|
||||
hashes = {t["token_hash"] for t in result}
|
||||
assert active_hash in hashes
|
||||
assert revoked_hash not in hashes
|
||||
assert expired_hash not in hashes
|
||||
assert rotated_hash not in hashes
|
||||
|
||||
|
||||
class TestTokenForeignKeys:
|
||||
"""Tests for refresh token foreign key constraints."""
|
||||
|
||||
async def test_token_requires_valid_user_foreign_key(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""refresh_tokens.user_id must reference existing user."""
|
||||
org_repo = OrgRepository(db_conn)
|
||||
token_repo = RefreshTokenRepository(db_conn)
|
||||
|
||||
org_id = uuid4()
|
||||
await org_repo.create(org_id, "FK Test Org", "fk-test-org")
|
||||
|
||||
with pytest.raises(asyncpg.ForeignKeyViolationError):
|
||||
await token_repo.create(
|
||||
uuid4(), uuid4(), "orphan_token", org_id,
|
||||
datetime.now(UTC) + timedelta(days=30)
|
||||
)
|
||||
|
||||
async def test_token_requires_valid_org_foreign_key(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""refresh_tokens.active_org_id must reference existing org."""
|
||||
user_repo = UserRepository(db_conn)
|
||||
token_repo = RefreshTokenRepository(db_conn)
|
||||
|
||||
user_id = uuid4()
|
||||
await user_repo.create(user_id, "fk_org_test@example.com", "hash")
|
||||
|
||||
with pytest.raises(asyncpg.ForeignKeyViolationError):
|
||||
await token_repo.create(
|
||||
uuid4(), user_id, "orphan_org_token", uuid4(),
|
||||
datetime.now(UTC) + timedelta(days=30)
|
||||
)
|
||||
|
||||
async def test_token_stores_active_org_id(self, db_conn: asyncpg.Connection) -> None:
|
||||
"""Token stores active_org_id for org context per SPECS.md."""
|
||||
user_repo = UserRepository(db_conn)
|
||||
org_repo = OrgRepository(db_conn)
|
||||
token_repo = RefreshTokenRepository(db_conn)
|
||||
|
||||
user_id = uuid4()
|
||||
org_id = uuid4()
|
||||
await user_repo.create(user_id, "active_org@example.com", "hash")
|
||||
await org_repo.create(org_id, "Active Org", "active-org")
|
||||
|
||||
token_hash = f"active_org_token_{uuid4().hex[:8]}"
|
||||
await token_repo.create(
|
||||
uuid4(), user_id, token_hash, org_id,
|
||||
datetime.now(UTC) + timedelta(days=30)
|
||||
)
|
||||
|
||||
token = await token_repo.get_by_hash(token_hash)
|
||||
assert token["active_org_id"] == org_id
|
||||
Reference in New Issue
Block a user