feat(auth): implement auth stack

This commit is contained in:
2025-12-29 09:55:30 +00:00
parent 3170f10e86
commit ad94833830
13 changed files with 1199 additions and 11 deletions

View File

@@ -0,0 +1,260 @@
"""Unit tests covering AuthService flows."""
from __future__ import annotations
from contextlib import asynccontextmanager
from datetime import UTC, datetime, timedelta
from uuid import UUID, uuid4
import pytest
from app.api.deps import CurrentUser
from app.core import security
from app.db import Database
from app.schemas.auth import (
LoginRequest,
LogoutRequest,
RefreshRequest,
RegisterRequest,
SwitchOrgRequest,
)
from app.services.auth import AuthService
pytestmark = pytest.mark.asyncio
class _SingleConnectionDatabase(Database):
"""Database stub that reuses a single asyncpg connection."""
def __init__(self, conn) -> None: # type: ignore[override]
self._conn = conn
@asynccontextmanager
async def connection(self): # type: ignore[override]
yield self._conn
@asynccontextmanager
async def transaction(self): # type: ignore[override]
tr = self._conn.transaction()
await tr.start()
try:
yield self._conn
except Exception:
await tr.rollback()
raise
else:
await tr.commit()
@pytest.fixture
async def auth_service(db_conn):
"""AuthService bound to the per-test database connection."""
return AuthService(database=_SingleConnectionDatabase(db_conn))
async def _create_user(conn, email: str, password: str) -> UUID:
user_id = uuid4()
password_hash = security.hash_password(password)
await conn.execute(
"INSERT INTO users (id, email, password_hash) VALUES ($1, $2, $3)",
user_id,
email,
password_hash,
)
return user_id
async def _create_org(
conn,
name: str,
slug: str | None = None,
*,
created_at: datetime | None = None,
) -> UUID:
org_id = uuid4()
slug_value = slug or f"{name.lower().replace(' ', '-')}-{org_id.hex[:6]}"
created = created_at or datetime.now(UTC)
await conn.execute(
"INSERT INTO orgs (id, name, slug, created_at) VALUES ($1, $2, $3, $4)",
org_id,
name,
slug_value,
created,
)
return org_id
async def _add_membership(conn, user_id: UUID, org_id: UUID, role: str) -> None:
await conn.execute(
"INSERT INTO org_members (id, user_id, org_id, role) VALUES ($1, $2, $3, $4)",
uuid4(),
user_id,
org_id,
role,
)
async def test_register_user_creates_admin_membership(auth_service, db_conn):
request = RegisterRequest(
email="founder@example.com",
password="SuperSecret1!",
org_name="Founders Inc",
)
response = await auth_service.register_user(request)
payload = security.decode_access_token(response.access_token)
assert payload["org_role"] == "admin"
user_id = UUID(payload["sub"])
org_id = UUID(payload["org_id"])
user = await db_conn.fetchrow("SELECT email FROM users WHERE id = $1", user_id)
assert user is not None and user["email"] == request.email
membership = await db_conn.fetchrow(
"SELECT role FROM org_members WHERE user_id = $1 AND org_id = $2",
user_id,
org_id,
)
assert membership is not None and membership["role"] == "admin"
refresh_hash = security.hash_token(response.refresh_token)
refresh_row = await db_conn.fetchrow(
"SELECT user_id, active_org_id FROM refresh_tokens WHERE token_hash = $1",
refresh_hash,
)
assert refresh_row is not None
assert refresh_row["user_id"] == user_id
assert refresh_row["active_org_id"] == org_id
async def test_login_user_returns_tokens_for_valid_credentials(auth_service, db_conn):
email = "member@example.com"
password = "Password123!"
user_id = await _create_user(db_conn, email, password)
org_id = await _create_org(
db_conn,
name="Member Org",
slug="member-org",
created_at=datetime.now(UTC) - timedelta(days=1),
)
await _add_membership(db_conn, user_id, org_id, "member")
response = await auth_service.login_user(LoginRequest(email=email, password=password))
payload = security.decode_access_token(response.access_token)
assert payload["sub"] == str(user_id)
assert payload["org_id"] == str(org_id)
refresh_hash = security.hash_token(response.refresh_token)
refresh_row = await db_conn.fetchrow(
"SELECT active_org_id FROM refresh_tokens WHERE token_hash = $1",
refresh_hash,
)
assert refresh_row is not None and refresh_row["active_org_id"] == org_id
async def test_refresh_tokens_rotates_existing_token(auth_service, db_conn):
email = "rotate@example.com"
password = "Rotate123!"
user_id = await _create_user(db_conn, email, password)
org_id = await _create_org(db_conn, name="Rotate Org", slug="rotate-org")
await _add_membership(db_conn, user_id, org_id, "member")
initial = await auth_service.login_user(LoginRequest(email=email, password=password))
rotated = await auth_service.refresh_tokens(
RefreshRequest(refresh_token=initial.refresh_token)
)
assert rotated.refresh_token != initial.refresh_token
old_hash = security.hash_token(initial.refresh_token)
old_row = await db_conn.fetchrow(
"SELECT rotated_to FROM refresh_tokens WHERE token_hash = $1",
old_hash,
)
assert old_row is not None and old_row["rotated_to"] is not None
new_hash = security.hash_token(rotated.refresh_token)
new_row = await db_conn.fetchrow(
"SELECT user_id FROM refresh_tokens WHERE token_hash = $1",
new_hash,
)
assert new_row is not None and new_row["user_id"] == user_id
async def test_switch_org_updates_active_org(auth_service, db_conn):
email = "switcher@example.com"
password = "Switch123!"
user_id = await _create_user(db_conn, email, password)
primary_org = await _create_org(
db_conn,
name="Primary Org",
slug="primary-org",
created_at=datetime.now(UTC) - timedelta(days=2),
)
await _add_membership(db_conn, user_id, primary_org, "member")
secondary_org = await _create_org(
db_conn,
name="Secondary Org",
slug="secondary-org",
created_at=datetime.now(UTC) - timedelta(days=1),
)
await _add_membership(db_conn, user_id, secondary_org, "admin")
initial = await auth_service.login_user(LoginRequest(email=email, password=password))
current_user = CurrentUser(
user_id=user_id,
email=email,
org_id=primary_org,
org_role="member",
token=initial.access_token,
)
switched = await auth_service.switch_org(
current_user,
SwitchOrgRequest(org_id=secondary_org, refresh_token=initial.refresh_token),
)
payload = security.decode_access_token(switched.access_token)
assert payload["org_id"] == str(secondary_org)
assert payload["org_role"] == "admin"
new_hash = security.hash_token(switched.refresh_token)
new_row = await db_conn.fetchrow(
"SELECT active_org_id FROM refresh_tokens WHERE token_hash = $1",
new_hash,
)
assert new_row is not None and new_row["active_org_id"] == secondary_org
async def test_logout_revokes_refresh_token(auth_service, db_conn):
email = "logout@example.com"
password = "Logout123!"
user_id = await _create_user(db_conn, email, password)
org_id = await _create_org(db_conn, name="Logout Org", slug="logout-org")
await _add_membership(db_conn, user_id, org_id, "member")
initial = await auth_service.login_user(LoginRequest(email=email, password=password))
current_user = CurrentUser(
user_id=user_id,
email=email,
org_id=org_id,
org_role="member",
token=initial.access_token,
)
await auth_service.logout(current_user, LogoutRequest(refresh_token=initial.refresh_token))
token_hash = security.hash_token(initial.refresh_token)
row = await db_conn.fetchrow(
"SELECT revoked_at FROM refresh_tokens WHERE token_hash = $1",
token_hash,
)
assert row is not None and row["revoked_at"] is not None