feat(auth): implement auth stack
This commit is contained in:
101
app/api/deps.py
Normal file
101
app/api/deps.py
Normal file
@@ -0,0 +1,101 @@
|
||||
"""Shared FastAPI dependencies (auth, RBAC, ownership)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import Depends
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
|
||||
from app.core import exceptions as exc, security
|
||||
from app.db import db
|
||||
from app.repositories import OrgRepository, UserRepository
|
||||
|
||||
|
||||
bearer_scheme = HTTPBearer(auto_error=False)
|
||||
|
||||
ROLE_RANKS: dict[str, int] = {"viewer": 0, "member": 1, "admin": 2}
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class CurrentUser:
|
||||
"""Authenticated user context derived from the access token."""
|
||||
|
||||
user_id: UUID
|
||||
email: str
|
||||
org_id: UUID
|
||||
org_role: str
|
||||
token: str
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
credentials: HTTPAuthorizationCredentials | None = Depends(bearer_scheme),
|
||||
) -> CurrentUser:
|
||||
"""Extract and validate the current user from the Authorization header."""
|
||||
|
||||
if credentials is None or credentials.scheme.lower() != "bearer":
|
||||
raise exc.UnauthorizedError("Missing bearer token")
|
||||
|
||||
try:
|
||||
payload = security.TokenPayload(security.decode_access_token(credentials.credentials))
|
||||
except security.JWTError as err: # pragma: no cover - jose error types
|
||||
raise exc.UnauthorizedError("Invalid access token") from err
|
||||
|
||||
async with db.connection() as conn:
|
||||
user_repo = UserRepository(conn)
|
||||
user = await user_repo.get_by_id(payload.user_id)
|
||||
if user is None:
|
||||
raise exc.UnauthorizedError("User not found")
|
||||
|
||||
org_repo = OrgRepository(conn)
|
||||
membership = await org_repo.get_member(payload.user_id, payload.org_id)
|
||||
if membership is None:
|
||||
raise exc.ForbiddenError("Organization access denied")
|
||||
|
||||
return CurrentUser(
|
||||
user_id=payload.user_id,
|
||||
email=user["email"],
|
||||
org_id=payload.org_id,
|
||||
org_role=membership["role"],
|
||||
token=credentials.credentials,
|
||||
)
|
||||
|
||||
|
||||
class RoleChecker:
|
||||
"""Dependency that enforces a minimum organization role."""
|
||||
|
||||
def __init__(self, minimum_role: str) -> None:
|
||||
if minimum_role not in ROLE_RANKS:
|
||||
raise ValueError(f"Unknown role '{minimum_role}'")
|
||||
self.minimum_role = minimum_role
|
||||
|
||||
def __call__(self, current_user: CurrentUser = Depends(get_current_user)) -> CurrentUser:
|
||||
if ROLE_RANKS[current_user.org_role] < ROLE_RANKS[self.minimum_role]:
|
||||
raise exc.ForbiddenError("Insufficient role for this operation")
|
||||
return current_user
|
||||
|
||||
|
||||
def require_role(min_role: str) -> Callable[[CurrentUser], CurrentUser]:
|
||||
"""Factory that returns a dependency enforcing the specified role."""
|
||||
|
||||
return RoleChecker(min_role)
|
||||
|
||||
|
||||
def ensure_org_access(resource_org_id: UUID, current_user: CurrentUser) -> None:
|
||||
"""Verify that the resource belongs to the active org in the token."""
|
||||
|
||||
if resource_org_id != current_user.org_id:
|
||||
raise exc.ForbiddenError("Resource does not belong to the active organization")
|
||||
|
||||
|
||||
__all__ = [
|
||||
"CurrentUser",
|
||||
"ROLE_RANKS",
|
||||
"RoleChecker",
|
||||
"bearer_scheme",
|
||||
"ensure_org_access",
|
||||
"get_current_user",
|
||||
"require_role",
|
||||
]
|
||||
59
app/api/v1/auth.py
Normal file
59
app/api/v1/auth.py
Normal file
@@ -0,0 +1,59 @@
|
||||
"""Authentication API endpoints."""
|
||||
|
||||
from fastapi import APIRouter, Depends, status
|
||||
|
||||
from app.api.deps import CurrentUser, get_current_user
|
||||
from app.schemas.auth import (
|
||||
LoginRequest,
|
||||
LogoutRequest,
|
||||
RefreshRequest,
|
||||
RegisterRequest,
|
||||
SwitchOrgRequest,
|
||||
TokenResponse,
|
||||
)
|
||||
from app.services import AuthService
|
||||
|
||||
|
||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
auth_service = AuthService()
|
||||
|
||||
|
||||
@router.post("/register", response_model=TokenResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def register_user(payload: RegisterRequest) -> TokenResponse:
|
||||
"""Register a new user and default org, returning auth tokens."""
|
||||
|
||||
return await auth_service.register_user(payload)
|
||||
|
||||
|
||||
@router.post("/login", response_model=TokenResponse)
|
||||
async def login_user(payload: LoginRequest) -> TokenResponse:
|
||||
"""Authenticate an existing user and issue tokens."""
|
||||
|
||||
return await auth_service.login_user(payload)
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=TokenResponse)
|
||||
async def refresh_tokens(payload: RefreshRequest) -> TokenResponse:
|
||||
"""Rotate refresh token and mint a new access token."""
|
||||
|
||||
return await auth_service.refresh_tokens(payload)
|
||||
|
||||
|
||||
@router.post("/switch-org", response_model=TokenResponse)
|
||||
async def switch_org(
|
||||
payload: SwitchOrgRequest,
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
) -> TokenResponse:
|
||||
"""Switch the active organization for the authenticated user."""
|
||||
|
||||
return await auth_service.switch_org(current_user, payload)
|
||||
|
||||
|
||||
@router.post("/logout", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def logout(
|
||||
payload: LogoutRequest,
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
) -> None:
|
||||
"""Revoke the provided refresh token for the current session."""
|
||||
|
||||
await auth_service.logout(current_user, payload)
|
||||
33
app/db.py
33
app/db.py
@@ -2,8 +2,10 @@
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
from contextvars import ContextVar
|
||||
|
||||
import asyncpg
|
||||
from asyncpg.pool import PoolConnectionProxy
|
||||
import redis.asyncio as redis
|
||||
|
||||
|
||||
@@ -27,7 +29,7 @@ class Database:
|
||||
await self.pool.close()
|
||||
|
||||
@asynccontextmanager
|
||||
async def connection(self) -> AsyncGenerator[asyncpg.Connection, None]:
|
||||
async def connection(self) -> AsyncGenerator[asyncpg.Connection | PoolConnectionProxy, None]:
|
||||
"""Acquire a connection from the pool."""
|
||||
if not self.pool:
|
||||
raise RuntimeError("Database not connected")
|
||||
@@ -35,7 +37,7 @@ class Database:
|
||||
yield conn
|
||||
|
||||
@asynccontextmanager
|
||||
async def transaction(self) -> AsyncGenerator[asyncpg.Connection, None]:
|
||||
async def transaction(self) -> AsyncGenerator[asyncpg.Connection | PoolConnectionProxy, None]:
|
||||
"""Acquire a connection with an active transaction."""
|
||||
if not self.pool:
|
||||
raise RuntimeError("Database not connected")
|
||||
@@ -74,7 +76,26 @@ db = Database()
|
||||
redis_client = RedisClient()
|
||||
|
||||
|
||||
async def get_conn() -> AsyncGenerator[asyncpg.Connection, None]:
|
||||
"""Dependency for getting a database connection."""
|
||||
async with db.connection() as conn:
|
||||
yield conn
|
||||
_connection_ctx: ContextVar[asyncpg.Connection | PoolConnectionProxy | None] = ContextVar(
|
||||
"db_connection",
|
||||
default=None,
|
||||
)
|
||||
|
||||
|
||||
async def get_conn() -> AsyncGenerator[asyncpg.Connection | PoolConnectionProxy, None]:
|
||||
"""Dependency that reuses the same DB connection within a request context."""
|
||||
|
||||
existing_conn = _connection_ctx.get()
|
||||
if existing_conn is not None:
|
||||
yield existing_conn
|
||||
return
|
||||
|
||||
if not db.pool:
|
||||
raise RuntimeError("Database not connected")
|
||||
|
||||
async with db.pool.acquire() as conn:
|
||||
token = _connection_ctx.set(conn)
|
||||
try:
|
||||
yield conn
|
||||
finally:
|
||||
_connection_ctx.reset(token)
|
||||
|
||||
39
app/main.py
39
app/main.py
@@ -4,8 +4,9 @@ from contextlib import asynccontextmanager
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.openapi.utils import get_openapi
|
||||
|
||||
from app.api.v1 import health
|
||||
from app.api.v1 import auth, health
|
||||
from app.config import settings
|
||||
from app.db import db, redis_client
|
||||
|
||||
@@ -26,8 +27,44 @@ app = FastAPI(
|
||||
title="IncidentOps",
|
||||
description="Incident management API with multi-tenant org support",
|
||||
version="0.1.0",
|
||||
docs_url="/docs",
|
||||
redoc_url="/redoc",
|
||||
openapi_url="/openapi.json",
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
app.openapi_tags = [
|
||||
{"name": "auth", "description": "Registration, login, token lifecycle"},
|
||||
{"name": "health", "description": "Service health probes"},
|
||||
]
|
||||
|
||||
|
||||
def custom_openapi() -> dict:
|
||||
"""Add JWT bearer security scheme to the generated OpenAPI schema."""
|
||||
|
||||
if app.openapi_schema:
|
||||
return app.openapi_schema
|
||||
|
||||
openapi_schema = get_openapi(
|
||||
title=app.title,
|
||||
version=app.version,
|
||||
description=app.description,
|
||||
routes=app.routes,
|
||||
)
|
||||
security_schemes = openapi_schema.setdefault("components", {}).setdefault("securitySchemes", {})
|
||||
security_schemes["BearerToken"] = {
|
||||
"type": "http",
|
||||
"scheme": "bearer",
|
||||
"bearerFormat": "JWT",
|
||||
"description": "Paste the JWT access token returned by /auth endpoints",
|
||||
}
|
||||
openapi_schema["security"] = [{"BearerToken": []}]
|
||||
app.openapi_schema = openapi_schema
|
||||
return app.openapi_schema
|
||||
|
||||
|
||||
app.openapi = custom_openapi # type: ignore[assignment]
|
||||
|
||||
# Include routers
|
||||
app.include_router(auth.router, prefix=settings.api_v1_prefix)
|
||||
app.include_router(health.router, prefix=settings.api_v1_prefix, tags=["health"])
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from app.schemas.auth import (
|
||||
LoginRequest,
|
||||
LogoutRequest,
|
||||
RefreshRequest,
|
||||
RegisterRequest,
|
||||
SwitchOrgRequest,
|
||||
@@ -27,6 +28,7 @@ from app.schemas.org import (
|
||||
__all__ = [
|
||||
# Auth
|
||||
"LoginRequest",
|
||||
"LogoutRequest",
|
||||
"RefreshRequest",
|
||||
"RegisterRequest",
|
||||
"SwitchOrgRequest",
|
||||
|
||||
@@ -33,6 +33,12 @@ class SwitchOrgRequest(BaseModel):
|
||||
refresh_token: str
|
||||
|
||||
|
||||
class LogoutRequest(BaseModel):
|
||||
"""Request body for logging out and revoking a refresh token."""
|
||||
|
||||
refresh_token: str
|
||||
|
||||
|
||||
class TokenResponse(BaseModel):
|
||||
"""Response containing access and refresh tokens."""
|
||||
|
||||
|
||||
5
app/services/__init__.py
Normal file
5
app/services/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Service layer entrypoints."""
|
||||
|
||||
from app.services.auth import AuthService
|
||||
|
||||
__all__ = ["AuthService"]
|
||||
269
app/services/auth.py
Normal file
269
app/services/auth.py
Normal file
@@ -0,0 +1,269 @@
|
||||
"""Authentication service providing business logic for auth flows."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import cast
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import asyncpg
|
||||
from asyncpg.pool import PoolConnectionProxy
|
||||
|
||||
from app.api.deps import CurrentUser
|
||||
from app.config import settings
|
||||
from app.core import exceptions as exc, security
|
||||
from app.db import Database, db
|
||||
from app.repositories import OrgRepository, RefreshTokenRepository, UserRepository
|
||||
from app.schemas.auth import (
|
||||
LoginRequest,
|
||||
LogoutRequest,
|
||||
RefreshRequest,
|
||||
RegisterRequest,
|
||||
SwitchOrgRequest,
|
||||
TokenResponse,
|
||||
)
|
||||
|
||||
|
||||
_SLUG_PATTERN = re.compile(r"[^a-z0-9]+")
|
||||
|
||||
|
||||
def _as_conn(conn: asyncpg.Connection | PoolConnectionProxy) -> asyncpg.Connection:
|
||||
"""Helper to satisfy typing when a pool proxy is returned."""
|
||||
|
||||
return cast(asyncpg.Connection, conn)
|
||||
|
||||
|
||||
class AuthService:
|
||||
"""Encapsulates authentication workflows (register/login/refresh/logout)."""
|
||||
|
||||
def __init__(self, database: Database | None = None) -> None:
|
||||
self.db = database or db
|
||||
self._access_token_expires_in = settings.access_token_expire_minutes * 60
|
||||
|
||||
async def register_user(self, data: RegisterRequest) -> TokenResponse:
|
||||
"""Create a new user, default org, membership, and token pair."""
|
||||
|
||||
async with self.db.transaction() as conn:
|
||||
db_conn = _as_conn(conn)
|
||||
user_repo = UserRepository(db_conn)
|
||||
org_repo = OrgRepository(db_conn)
|
||||
refresh_repo = RefreshTokenRepository(db_conn)
|
||||
|
||||
if await user_repo.exists_by_email(data.email):
|
||||
raise exc.ConflictError("Email already registered")
|
||||
|
||||
user_id = uuid4()
|
||||
org_id = uuid4()
|
||||
member_id = uuid4()
|
||||
password_hash = security.hash_password(data.password)
|
||||
|
||||
await user_repo.create(user_id, data.email, password_hash)
|
||||
slug = await self._generate_unique_org_slug(org_repo, data.org_name)
|
||||
await org_repo.create(org_id, data.org_name, slug)
|
||||
await org_repo.add_member(member_id, user_id, org_id, "admin")
|
||||
|
||||
return await self._issue_token_pair(
|
||||
refresh_repo,
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
role="admin",
|
||||
)
|
||||
|
||||
async def login_user(self, data: LoginRequest) -> TokenResponse:
|
||||
"""Authenticate a user and issue tokens for their first organization."""
|
||||
|
||||
async with self.db.connection() as conn:
|
||||
db_conn = _as_conn(conn)
|
||||
user_repo = UserRepository(db_conn)
|
||||
org_repo = OrgRepository(db_conn)
|
||||
refresh_repo = RefreshTokenRepository(db_conn)
|
||||
|
||||
user = await user_repo.get_by_email(data.email)
|
||||
if not user or not security.verify_password(data.password, user["password_hash"]):
|
||||
raise exc.UnauthorizedError("Invalid email or password")
|
||||
|
||||
orgs = await org_repo.get_user_orgs(user["id"])
|
||||
if not orgs:
|
||||
raise exc.ForbiddenError("User does not belong to any organization")
|
||||
|
||||
active_org = orgs[0]
|
||||
return await self._issue_token_pair(
|
||||
refresh_repo,
|
||||
user_id=user["id"],
|
||||
org_id=active_org["id"],
|
||||
role=active_org["role"],
|
||||
)
|
||||
|
||||
async def refresh_tokens(self, data: RefreshRequest) -> TokenResponse:
|
||||
"""Rotate refresh token and mint a new access token."""
|
||||
|
||||
old_hash = security.hash_token(data.refresh_token)
|
||||
new_refresh_token = security.generate_refresh_token()
|
||||
new_refresh_hash = security.hash_token(new_refresh_token)
|
||||
new_refresh_id = uuid4()
|
||||
new_refresh_expiry = security.get_refresh_token_expiry()
|
||||
|
||||
rotated: dict | None = None
|
||||
membership: dict | None = None
|
||||
|
||||
async with self.db.transaction() as conn:
|
||||
db_conn = _as_conn(conn)
|
||||
refresh_repo = RefreshTokenRepository(db_conn)
|
||||
rotated = await refresh_repo.rotate(
|
||||
old_token_hash=old_hash,
|
||||
new_token_id=new_refresh_id,
|
||||
new_token_hash=new_refresh_hash,
|
||||
new_expires_at=new_refresh_expiry,
|
||||
)
|
||||
|
||||
if rotated is not None:
|
||||
org_repo = OrgRepository(db_conn)
|
||||
membership = await org_repo.get_member(rotated["user_id"], rotated["active_org_id"])
|
||||
if membership is None:
|
||||
raise exc.UnauthorizedError("Invalid refresh token")
|
||||
|
||||
if rotated is None or membership is None:
|
||||
await self._handle_invalid_refresh(old_hash)
|
||||
|
||||
assert rotated is not None and membership is not None
|
||||
access_token = security.create_access_token(
|
||||
sub=str(rotated["user_id"]),
|
||||
org_id=str(rotated["active_org_id"]),
|
||||
org_role=membership["role"],
|
||||
)
|
||||
|
||||
return TokenResponse(
|
||||
access_token=access_token,
|
||||
refresh_token=new_refresh_token,
|
||||
expires_in=self._access_token_expires_in,
|
||||
)
|
||||
|
||||
async def switch_org(
|
||||
self,
|
||||
current_user: CurrentUser,
|
||||
data: SwitchOrgRequest,
|
||||
) -> TokenResponse:
|
||||
"""Switch active organization (rotates refresh token + issues new JWT)."""
|
||||
|
||||
target_org_id = data.org_id
|
||||
old_hash = security.hash_token(data.refresh_token)
|
||||
new_refresh_token = security.generate_refresh_token()
|
||||
new_refresh_hash = security.hash_token(new_refresh_token)
|
||||
new_refresh_expiry = security.get_refresh_token_expiry()
|
||||
|
||||
rotated: dict | None = None
|
||||
membership: dict | None = None
|
||||
|
||||
async with self.db.transaction() as conn:
|
||||
db_conn = _as_conn(conn)
|
||||
org_repo = OrgRepository(db_conn)
|
||||
membership = await org_repo.get_member(current_user.user_id, target_org_id)
|
||||
if membership is None:
|
||||
raise exc.ForbiddenError("Not a member of the requested organization")
|
||||
|
||||
refresh_repo = RefreshTokenRepository(db_conn)
|
||||
rotated = await refresh_repo.rotate(
|
||||
old_token_hash=old_hash,
|
||||
new_token_id=uuid4(),
|
||||
new_token_hash=new_refresh_hash,
|
||||
new_expires_at=new_refresh_expiry,
|
||||
new_active_org_id=target_org_id,
|
||||
expected_user_id=current_user.user_id,
|
||||
)
|
||||
|
||||
if rotated is None:
|
||||
await self._handle_invalid_refresh(old_hash)
|
||||
|
||||
access_token = security.create_access_token(
|
||||
sub=str(current_user.user_id),
|
||||
org_id=str(target_org_id),
|
||||
org_role=membership["role"],
|
||||
)
|
||||
|
||||
return TokenResponse(
|
||||
access_token=access_token,
|
||||
refresh_token=new_refresh_token,
|
||||
expires_in=self._access_token_expires_in,
|
||||
)
|
||||
|
||||
async def logout(self, current_user: CurrentUser, data: LogoutRequest) -> None:
|
||||
"""Revoke the provided refresh token for the current session."""
|
||||
|
||||
token_hash = security.hash_token(data.refresh_token)
|
||||
|
||||
async with self.db.transaction() as conn:
|
||||
refresh_repo = RefreshTokenRepository(_as_conn(conn))
|
||||
token = await refresh_repo.get_by_hash(token_hash)
|
||||
if token and token["user_id"] != current_user.user_id:
|
||||
raise exc.ForbiddenError("Refresh token does not belong to this user")
|
||||
|
||||
if not token:
|
||||
return
|
||||
|
||||
await refresh_repo.revoke(token["id"])
|
||||
|
||||
async def _issue_token_pair(
|
||||
self,
|
||||
refresh_repo: RefreshTokenRepository,
|
||||
*,
|
||||
user_id: UUID,
|
||||
org_id: UUID,
|
||||
role: str,
|
||||
) -> TokenResponse:
|
||||
"""Create access/refresh tokens and persist the refresh token."""
|
||||
|
||||
access_token = security.create_access_token(
|
||||
sub=str(user_id),
|
||||
org_id=str(org_id),
|
||||
org_role=role,
|
||||
)
|
||||
|
||||
refresh_token = security.generate_refresh_token()
|
||||
await refresh_repo.create(
|
||||
token_id=uuid4(),
|
||||
user_id=user_id,
|
||||
token_hash=security.hash_token(refresh_token),
|
||||
active_org_id=org_id,
|
||||
expires_at=security.get_refresh_token_expiry(),
|
||||
)
|
||||
|
||||
return TokenResponse(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
expires_in=self._access_token_expires_in,
|
||||
)
|
||||
|
||||
async def _handle_invalid_refresh(self, token_hash: str) -> None:
|
||||
"""Raise appropriate errors for invalid/compromised refresh tokens."""
|
||||
|
||||
async with self.db.connection() as conn:
|
||||
refresh_repo = RefreshTokenRepository(_as_conn(conn))
|
||||
reused = await refresh_repo.check_token_reuse(token_hash)
|
||||
if reused:
|
||||
await refresh_repo.revoke_token_chain(reused["id"])
|
||||
raise exc.UnauthorizedError("Refresh token reuse detected")
|
||||
|
||||
raise exc.UnauthorizedError("Invalid refresh token")
|
||||
|
||||
async def _generate_unique_org_slug(
|
||||
self,
|
||||
org_repo: OrgRepository,
|
||||
org_name: str,
|
||||
) -> str:
|
||||
"""Slugify the org name and append a counter until unique."""
|
||||
|
||||
base_slug = self._slugify(org_name)
|
||||
candidate = base_slug
|
||||
counter = 1
|
||||
while await org_repo.slug_exists(candidate):
|
||||
suffix = f"-{counter}"
|
||||
max_base_len = 50 - len(suffix)
|
||||
candidate = f"{base_slug[:max_base_len]}{suffix}"
|
||||
counter += 1
|
||||
return candidate
|
||||
|
||||
def _slugify(self, value: str) -> str:
|
||||
"""Convert arbitrary text into a URL-friendly slug."""
|
||||
|
||||
slug = _SLUG_PATTERN.sub("-", value.strip().lower()).strip("-")
|
||||
return slug[:50] or "org"
|
||||
65
tests/api/helpers.py
Normal file
65
tests/api/helpers.py
Normal file
@@ -0,0 +1,65 @@
|
||||
"""Shared helpers for API integration tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import asyncpg
|
||||
from httpx import AsyncClient
|
||||
|
||||
API_PREFIX = "/v1"
|
||||
|
||||
|
||||
async def register_user(
|
||||
client: AsyncClient,
|
||||
*,
|
||||
email: str,
|
||||
password: str,
|
||||
org_name: str = "Test Org",
|
||||
) -> dict[str, Any]:
|
||||
"""Call the register endpoint and return JSON body (raises on failure)."""
|
||||
|
||||
response = await client.post(
|
||||
f"{API_PREFIX}/auth/register",
|
||||
json={"email": email, "password": password, "org_name": org_name},
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
|
||||
async def create_org(
|
||||
conn: asyncpg.Connection,
|
||||
*,
|
||||
name: str,
|
||||
slug: str | None = None,
|
||||
) -> UUID:
|
||||
"""Insert an organization row and return its ID."""
|
||||
|
||||
org_id = uuid4()
|
||||
slug_value = slug or name.lower().replace(" ", "-")
|
||||
await conn.execute(
|
||||
"INSERT INTO orgs (id, name, slug) VALUES ($1, $2, $3)",
|
||||
org_id,
|
||||
name,
|
||||
slug_value,
|
||||
)
|
||||
return org_id
|
||||
|
||||
|
||||
async def add_membership(
|
||||
conn: asyncpg.Connection,
|
||||
*,
|
||||
user_id: UUID,
|
||||
org_id: UUID,
|
||||
role: str,
|
||||
) -> None:
|
||||
"""Insert a membership record for the user/org pair."""
|
||||
|
||||
await conn.execute(
|
||||
"INSERT INTO org_members (id, user_id, org_id, role) VALUES ($1, $2, $3, $4)",
|
||||
uuid4(),
|
||||
user_id,
|
||||
org_id,
|
||||
role,
|
||||
)
|
||||
213
tests/api/test_auth.py
Normal file
213
tests/api/test_auth.py
Normal file
@@ -0,0 +1,213 @@
|
||||
"""Integration tests for FastAPI auth endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
import asyncpg
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
|
||||
from app.core import security
|
||||
from tests.api import helpers
|
||||
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
API_PREFIX = "/v1/auth"
|
||||
|
||||
|
||||
async def test_register_endpoint_persists_user_and_membership(
|
||||
api_client: AsyncClient,
|
||||
db_admin: asyncpg.Connection,
|
||||
) -> None:
|
||||
data = await helpers.register_user(
|
||||
api_client,
|
||||
email="api-register@example.com",
|
||||
password="SuperSecret1!",
|
||||
org_name="API Org",
|
||||
)
|
||||
assert "access_token" in data and "refresh_token" in data
|
||||
|
||||
token_payload = security.decode_access_token(data["access_token"])
|
||||
assert token_payload["org_role"] == "admin"
|
||||
|
||||
stored_user = await db_admin.fetchrow("SELECT email FROM users WHERE email = $1", "api-register@example.com")
|
||||
assert stored_user is not None
|
||||
|
||||
membership = await db_admin.fetchrow(
|
||||
"SELECT role FROM org_members WHERE user_id = $1 AND org_id = $2",
|
||||
UUID(token_payload["sub"]),
|
||||
UUID(token_payload["org_id"]),
|
||||
)
|
||||
assert membership is not None and membership["role"] == "admin"
|
||||
|
||||
|
||||
async def test_login_endpoint_rejects_bad_credentials(
|
||||
api_client: AsyncClient,
|
||||
) -> None:
|
||||
register_payload = {
|
||||
"email": "api-login@example.com",
|
||||
"password": "CorrectHorse1!",
|
||||
"org_name": "Login Org",
|
||||
}
|
||||
await helpers.register_user(api_client, **register_payload)
|
||||
|
||||
response = await api_client.post(
|
||||
f"{API_PREFIX}/login",
|
||||
json={"email": register_payload["email"], "password": "wrong"},
|
||||
)
|
||||
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
async def test_refresh_endpoint_rotates_refresh_token(
|
||||
api_client: AsyncClient,
|
||||
db_admin: asyncpg.Connection,
|
||||
) -> None:
|
||||
register_payload = {
|
||||
"email": "api-refresh@example.com",
|
||||
"password": "RefreshPass1!",
|
||||
"org_name": "Refresh Org",
|
||||
}
|
||||
initial = await helpers.register_user(api_client, **register_payload)
|
||||
|
||||
response = await api_client.post(
|
||||
f"{API_PREFIX}/refresh",
|
||||
json={"refresh_token": initial["refresh_token"]},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["refresh_token"] != initial["refresh_token"]
|
||||
|
||||
old_hash = security.hash_token(initial["refresh_token"])
|
||||
old_row = await db_admin.fetchrow(
|
||||
"SELECT rotated_to FROM refresh_tokens WHERE token_hash = $1",
|
||||
old_hash,
|
||||
)
|
||||
assert old_row is not None and old_row["rotated_to"] is not None
|
||||
|
||||
|
||||
async def test_refresh_endpoint_detects_reuse(
|
||||
api_client: AsyncClient,
|
||||
db_admin: asyncpg.Connection,
|
||||
) -> None:
|
||||
tokens = await helpers.register_user(
|
||||
api_client,
|
||||
email="api-reuse@example.com",
|
||||
password="ReusePass1!",
|
||||
org_name="Reuse Org",
|
||||
)
|
||||
|
||||
rotated = await api_client.post(
|
||||
f"{API_PREFIX}/refresh",
|
||||
json={"refresh_token": tokens["refresh_token"]},
|
||||
)
|
||||
assert rotated.status_code == 200
|
||||
|
||||
reuse_response = await api_client.post(
|
||||
f"{API_PREFIX}/refresh",
|
||||
json={"refresh_token": tokens["refresh_token"]},
|
||||
)
|
||||
assert reuse_response.status_code == 401
|
||||
|
||||
old_hash = security.hash_token(tokens["refresh_token"])
|
||||
old_row = await db_admin.fetchrow(
|
||||
"SELECT revoked_at FROM refresh_tokens WHERE token_hash = $1",
|
||||
old_hash,
|
||||
)
|
||||
assert old_row is not None and old_row["revoked_at"] is not None
|
||||
|
||||
|
||||
async def test_switch_org_changes_active_org(
|
||||
api_client: AsyncClient,
|
||||
db_admin: asyncpg.Connection,
|
||||
) -> None:
|
||||
email = "api-switch@example.com"
|
||||
register_payload = {
|
||||
"email": email,
|
||||
"password": "SwitchPass1!",
|
||||
"org_name": "Primary Org",
|
||||
}
|
||||
tokens = await helpers.register_user(api_client, **register_payload)
|
||||
|
||||
user_id_row = await db_admin.fetchrow("SELECT id FROM users WHERE email = $1", email)
|
||||
assert user_id_row is not None
|
||||
user_id = user_id_row["id"]
|
||||
|
||||
target_org_id = await helpers.create_org(db_admin, name="Secondary Org", slug="secondary-org")
|
||||
await helpers.add_membership(db_admin, user_id=user_id, org_id=target_org_id, role="member")
|
||||
|
||||
response = await api_client.post(
|
||||
f"{API_PREFIX}/switch-org",
|
||||
json={"org_id": str(target_org_id), "refresh_token": tokens["refresh_token"]},
|
||||
headers={"Authorization": f"Bearer {tokens['access_token']}"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
payload = security.decode_access_token(data["access_token"])
|
||||
assert payload["org_id"] == str(target_org_id)
|
||||
assert payload["org_role"] == "member"
|
||||
|
||||
new_hash = security.hash_token(data["refresh_token"])
|
||||
new_row = await db_admin.fetchrow(
|
||||
"SELECT active_org_id FROM refresh_tokens WHERE token_hash = $1",
|
||||
new_hash,
|
||||
)
|
||||
assert new_row is not None and new_row["active_org_id"] == target_org_id
|
||||
|
||||
|
||||
async def test_switch_org_forbidden_without_membership(
|
||||
api_client: AsyncClient,
|
||||
db_admin: asyncpg.Connection,
|
||||
) -> None:
|
||||
tokens = await helpers.register_user(
|
||||
api_client,
|
||||
email="api-switch-no-access@example.com",
|
||||
password="SwitchBlock1!",
|
||||
org_name="Primary",
|
||||
)
|
||||
|
||||
foreign_org = await helpers.create_org(db_admin, name="Foreign Org", slug="foreign-org")
|
||||
|
||||
response = await api_client.post(
|
||||
f"{API_PREFIX}/switch-org",
|
||||
json={"org_id": str(foreign_org), "refresh_token": tokens["refresh_token"]},
|
||||
headers={"Authorization": f"Bearer {tokens['access_token']}"},
|
||||
)
|
||||
assert response.status_code == 403
|
||||
|
||||
# ensure refresh token still valid after failed attempt
|
||||
retry = await api_client.post(
|
||||
f"{API_PREFIX}/refresh",
|
||||
json={"refresh_token": tokens["refresh_token"]},
|
||||
)
|
||||
assert retry.status_code == 200
|
||||
|
||||
|
||||
async def test_logout_revokes_refresh_token(
|
||||
api_client: AsyncClient,
|
||||
) -> None:
|
||||
register_payload = {
|
||||
"email": "api-logout@example.com",
|
||||
"password": "LogoutPass1!",
|
||||
"org_name": "Logout Org",
|
||||
}
|
||||
tokens = await helpers.register_user(api_client, **register_payload)
|
||||
|
||||
logout_response = await api_client.post(
|
||||
f"{API_PREFIX}/logout",
|
||||
json={"refresh_token": tokens["refresh_token"]},
|
||||
headers={"Authorization": f"Bearer {tokens['access_token']}"},
|
||||
)
|
||||
|
||||
assert logout_response.status_code == 204
|
||||
|
||||
refresh_response = await api_client.post(
|
||||
f"{API_PREFIX}/refresh",
|
||||
json={"refresh_token": tokens["refresh_token"]},
|
||||
)
|
||||
|
||||
assert refresh_response.status_code == 401
|
||||
@@ -3,9 +3,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from uuid import uuid4
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import AsyncGenerator, Callable
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import asyncpg
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
# Set test environment variables before importing app modules
|
||||
@@ -13,6 +16,8 @@ os.environ.setdefault("DATABASE_URL", "postgresql://incidentops:incidentops@loca
|
||||
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-testing-only")
|
||||
os.environ.setdefault("REDIS_URL", "redis://localhost:6379/1")
|
||||
|
||||
from app.main import app
|
||||
|
||||
|
||||
# Module-level setup: create database and run migrations once
|
||||
_db_initialized = False
|
||||
@@ -65,7 +70,7 @@ async def _init_test_db() -> None:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def db_conn() -> asyncpg.Connection:
|
||||
async def db_conn() -> AsyncGenerator[asyncpg.Connection, None]:
|
||||
"""Get a database connection with transaction rollback for test isolation."""
|
||||
await _init_test_db()
|
||||
|
||||
@@ -84,12 +89,77 @@ async def db_conn() -> asyncpg.Connection:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def make_user_id() -> uuid4:
|
||||
def make_user_id() -> Callable[[], UUID]:
|
||||
"""Factory for generating user IDs."""
|
||||
return lambda: uuid4()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def make_org_id() -> uuid4:
|
||||
def make_org_id() -> Callable[[], UUID]:
|
||||
"""Factory for generating org IDs."""
|
||||
return lambda: uuid4()
|
||||
|
||||
|
||||
TABLES_TO_TRUNCATE = [
|
||||
"incident_events",
|
||||
"notification_attempts",
|
||||
"incidents",
|
||||
"notification_targets",
|
||||
"services",
|
||||
"refresh_tokens",
|
||||
"org_members",
|
||||
"orgs",
|
||||
"users",
|
||||
]
|
||||
|
||||
|
||||
async def _truncate_all_tables() -> None:
|
||||
test_dsn = os.environ["DATABASE_URL"]
|
||||
conn = await asyncpg.connect(test_dsn)
|
||||
try:
|
||||
tables = ", ".join(TABLES_TO_TRUNCATE)
|
||||
await conn.execute(f"TRUNCATE TABLE {tables} CASCADE")
|
||||
finally:
|
||||
await conn.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def clean_database() -> AsyncGenerator[None, None]:
|
||||
"""Ensure the database is initialized and truncated before/after tests."""
|
||||
|
||||
await _init_test_db()
|
||||
await _truncate_all_tables()
|
||||
yield
|
||||
await _truncate_all_tables()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def _lifespan_manager() -> AsyncGenerator[None, None]:
|
||||
lifespan = app.router.lifespan_context
|
||||
if lifespan is None:
|
||||
yield
|
||||
else:
|
||||
async with lifespan(app):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def api_client(clean_database: None) -> AsyncGenerator[httpx.AsyncClient, None]:
|
||||
"""HTTPX async client bound to the FastAPI app with lifespan support."""
|
||||
|
||||
async with _lifespan_manager():
|
||||
transport = httpx.ASGITransport(app=app)
|
||||
async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client:
|
||||
yield client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def db_admin(clean_database: None) -> AsyncGenerator[asyncpg.Connection, None]:
|
||||
"""Plain connection for arranging/inspecting API test data (no rollback)."""
|
||||
|
||||
test_dsn = os.environ["DATABASE_URL"]
|
||||
conn = await asyncpg.connect(test_dsn)
|
||||
try:
|
||||
yield conn
|
||||
finally:
|
||||
await conn.close()
|
||||
|
||||
80
tests/db/test_get_conn.py
Normal file
80
tests/db/test_get_conn.py
Normal file
@@ -0,0 +1,80 @@
|
||||
"""Tests for the get_conn dependency helper."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.db import db, get_conn
|
||||
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
class _FakeConnection:
|
||||
def __init__(self, idx: int) -> None:
|
||||
self.idx = idx
|
||||
|
||||
|
||||
class _AcquireContext:
|
||||
def __init__(self, conn: _FakeConnection, tracker: "_FakePool") -> None:
|
||||
self._conn = conn
|
||||
self._tracker = tracker
|
||||
|
||||
async def __aenter__(self) -> _FakeConnection:
|
||||
self._tracker.active += 1
|
||||
return self._conn
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb) -> None:
|
||||
self._tracker.active -= 1
|
||||
|
||||
|
||||
class _FakePool:
|
||||
def __init__(self) -> None:
|
||||
self.acquire_calls = 0
|
||||
self.active = 0
|
||||
|
||||
def acquire(self) -> _AcquireContext:
|
||||
conn = _FakeConnection(self.acquire_calls)
|
||||
self.acquire_calls += 1
|
||||
return _AcquireContext(conn, self)
|
||||
|
||||
|
||||
async def _collect_single_connection():
|
||||
connection = None
|
||||
async for conn in get_conn():
|
||||
connection = conn
|
||||
return connection
|
||||
|
||||
|
||||
async def test_get_conn_reuses_connection_within_scope():
|
||||
original_pool = db.pool
|
||||
fake_pool = _FakePool()
|
||||
db.pool = fake_pool
|
||||
try:
|
||||
captured: list[_FakeConnection] = []
|
||||
|
||||
async for outer in get_conn():
|
||||
captured.append(outer)
|
||||
async for inner in get_conn():
|
||||
captured.append(inner)
|
||||
|
||||
assert len(captured) == 2
|
||||
assert captured[0] is captured[1]
|
||||
assert fake_pool.acquire_calls == 1
|
||||
finally:
|
||||
db.pool = original_pool
|
||||
|
||||
|
||||
async def test_get_conn_acquires_new_connection_per_root_scope():
|
||||
original_pool = db.pool
|
||||
fake_pool = _FakePool()
|
||||
db.pool = fake_pool
|
||||
try:
|
||||
first = await _collect_single_connection()
|
||||
second = await _collect_single_connection()
|
||||
|
||||
assert first is not None and second is not None
|
||||
assert first is not second
|
||||
assert fake_pool.acquire_calls == 2
|
||||
finally:
|
||||
db.pool = original_pool
|
||||
260
tests/services/test_auth_service.py
Normal file
260
tests/services/test_auth_service.py
Normal file
@@ -0,0 +1,260 @@
|
||||
"""Unit tests covering AuthService flows."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from app.api.deps import CurrentUser
|
||||
from app.core import security
|
||||
from app.db import Database
|
||||
from app.schemas.auth import (
|
||||
LoginRequest,
|
||||
LogoutRequest,
|
||||
RefreshRequest,
|
||||
RegisterRequest,
|
||||
SwitchOrgRequest,
|
||||
)
|
||||
from app.services.auth import AuthService
|
||||
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
class _SingleConnectionDatabase(Database):
|
||||
"""Database stub that reuses a single asyncpg connection."""
|
||||
|
||||
def __init__(self, conn) -> None: # type: ignore[override]
|
||||
self._conn = conn
|
||||
|
||||
@asynccontextmanager
|
||||
async def connection(self): # type: ignore[override]
|
||||
yield self._conn
|
||||
|
||||
@asynccontextmanager
|
||||
async def transaction(self): # type: ignore[override]
|
||||
tr = self._conn.transaction()
|
||||
await tr.start()
|
||||
try:
|
||||
yield self._conn
|
||||
except Exception:
|
||||
await tr.rollback()
|
||||
raise
|
||||
else:
|
||||
await tr.commit()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def auth_service(db_conn):
|
||||
"""AuthService bound to the per-test database connection."""
|
||||
|
||||
return AuthService(database=_SingleConnectionDatabase(db_conn))
|
||||
|
||||
|
||||
async def _create_user(conn, email: str, password: str) -> UUID:
|
||||
user_id = uuid4()
|
||||
password_hash = security.hash_password(password)
|
||||
await conn.execute(
|
||||
"INSERT INTO users (id, email, password_hash) VALUES ($1, $2, $3)",
|
||||
user_id,
|
||||
email,
|
||||
password_hash,
|
||||
)
|
||||
return user_id
|
||||
|
||||
|
||||
async def _create_org(
|
||||
conn,
|
||||
name: str,
|
||||
slug: str | None = None,
|
||||
*,
|
||||
created_at: datetime | None = None,
|
||||
) -> UUID:
|
||||
org_id = uuid4()
|
||||
slug_value = slug or f"{name.lower().replace(' ', '-')}-{org_id.hex[:6]}"
|
||||
created = created_at or datetime.now(UTC)
|
||||
await conn.execute(
|
||||
"INSERT INTO orgs (id, name, slug, created_at) VALUES ($1, $2, $3, $4)",
|
||||
org_id,
|
||||
name,
|
||||
slug_value,
|
||||
created,
|
||||
)
|
||||
return org_id
|
||||
|
||||
|
||||
async def _add_membership(conn, user_id: UUID, org_id: UUID, role: str) -> None:
|
||||
await conn.execute(
|
||||
"INSERT INTO org_members (id, user_id, org_id, role) VALUES ($1, $2, $3, $4)",
|
||||
uuid4(),
|
||||
user_id,
|
||||
org_id,
|
||||
role,
|
||||
)
|
||||
|
||||
|
||||
async def test_register_user_creates_admin_membership(auth_service, db_conn):
|
||||
request = RegisterRequest(
|
||||
email="founder@example.com",
|
||||
password="SuperSecret1!",
|
||||
org_name="Founders Inc",
|
||||
)
|
||||
|
||||
response = await auth_service.register_user(request)
|
||||
|
||||
payload = security.decode_access_token(response.access_token)
|
||||
assert payload["org_role"] == "admin"
|
||||
|
||||
user_id = UUID(payload["sub"])
|
||||
org_id = UUID(payload["org_id"])
|
||||
|
||||
user = await db_conn.fetchrow("SELECT email FROM users WHERE id = $1", user_id)
|
||||
assert user is not None and user["email"] == request.email
|
||||
|
||||
membership = await db_conn.fetchrow(
|
||||
"SELECT role FROM org_members WHERE user_id = $1 AND org_id = $2",
|
||||
user_id,
|
||||
org_id,
|
||||
)
|
||||
assert membership is not None and membership["role"] == "admin"
|
||||
|
||||
refresh_hash = security.hash_token(response.refresh_token)
|
||||
refresh_row = await db_conn.fetchrow(
|
||||
"SELECT user_id, active_org_id FROM refresh_tokens WHERE token_hash = $1",
|
||||
refresh_hash,
|
||||
)
|
||||
assert refresh_row is not None
|
||||
assert refresh_row["user_id"] == user_id
|
||||
assert refresh_row["active_org_id"] == org_id
|
||||
|
||||
|
||||
async def test_login_user_returns_tokens_for_valid_credentials(auth_service, db_conn):
|
||||
email = "member@example.com"
|
||||
password = "Password123!"
|
||||
user_id = await _create_user(db_conn, email, password)
|
||||
org_id = await _create_org(
|
||||
db_conn,
|
||||
name="Member Org",
|
||||
slug="member-org",
|
||||
created_at=datetime.now(UTC) - timedelta(days=1),
|
||||
)
|
||||
await _add_membership(db_conn, user_id, org_id, "member")
|
||||
|
||||
response = await auth_service.login_user(LoginRequest(email=email, password=password))
|
||||
|
||||
payload = security.decode_access_token(response.access_token)
|
||||
assert payload["sub"] == str(user_id)
|
||||
assert payload["org_id"] == str(org_id)
|
||||
|
||||
refresh_hash = security.hash_token(response.refresh_token)
|
||||
refresh_row = await db_conn.fetchrow(
|
||||
"SELECT active_org_id FROM refresh_tokens WHERE token_hash = $1",
|
||||
refresh_hash,
|
||||
)
|
||||
assert refresh_row is not None and refresh_row["active_org_id"] == org_id
|
||||
|
||||
|
||||
async def test_refresh_tokens_rotates_existing_token(auth_service, db_conn):
|
||||
email = "rotate@example.com"
|
||||
password = "Rotate123!"
|
||||
user_id = await _create_user(db_conn, email, password)
|
||||
org_id = await _create_org(db_conn, name="Rotate Org", slug="rotate-org")
|
||||
await _add_membership(db_conn, user_id, org_id, "member")
|
||||
|
||||
initial = await auth_service.login_user(LoginRequest(email=email, password=password))
|
||||
|
||||
rotated = await auth_service.refresh_tokens(
|
||||
RefreshRequest(refresh_token=initial.refresh_token)
|
||||
)
|
||||
|
||||
assert rotated.refresh_token != initial.refresh_token
|
||||
|
||||
old_hash = security.hash_token(initial.refresh_token)
|
||||
old_row = await db_conn.fetchrow(
|
||||
"SELECT rotated_to FROM refresh_tokens WHERE token_hash = $1",
|
||||
old_hash,
|
||||
)
|
||||
assert old_row is not None and old_row["rotated_to"] is not None
|
||||
|
||||
new_hash = security.hash_token(rotated.refresh_token)
|
||||
new_row = await db_conn.fetchrow(
|
||||
"SELECT user_id FROM refresh_tokens WHERE token_hash = $1",
|
||||
new_hash,
|
||||
)
|
||||
assert new_row is not None and new_row["user_id"] == user_id
|
||||
|
||||
|
||||
async def test_switch_org_updates_active_org(auth_service, db_conn):
|
||||
email = "switcher@example.com"
|
||||
password = "Switch123!"
|
||||
user_id = await _create_user(db_conn, email, password)
|
||||
|
||||
primary_org = await _create_org(
|
||||
db_conn,
|
||||
name="Primary Org",
|
||||
slug="primary-org",
|
||||
created_at=datetime.now(UTC) - timedelta(days=2),
|
||||
)
|
||||
await _add_membership(db_conn, user_id, primary_org, "member")
|
||||
|
||||
secondary_org = await _create_org(
|
||||
db_conn,
|
||||
name="Secondary Org",
|
||||
slug="secondary-org",
|
||||
created_at=datetime.now(UTC) - timedelta(days=1),
|
||||
)
|
||||
await _add_membership(db_conn, user_id, secondary_org, "admin")
|
||||
|
||||
initial = await auth_service.login_user(LoginRequest(email=email, password=password))
|
||||
current_user = CurrentUser(
|
||||
user_id=user_id,
|
||||
email=email,
|
||||
org_id=primary_org,
|
||||
org_role="member",
|
||||
token=initial.access_token,
|
||||
)
|
||||
|
||||
switched = await auth_service.switch_org(
|
||||
current_user,
|
||||
SwitchOrgRequest(org_id=secondary_org, refresh_token=initial.refresh_token),
|
||||
)
|
||||
|
||||
payload = security.decode_access_token(switched.access_token)
|
||||
assert payload["org_id"] == str(secondary_org)
|
||||
assert payload["org_role"] == "admin"
|
||||
|
||||
new_hash = security.hash_token(switched.refresh_token)
|
||||
new_row = await db_conn.fetchrow(
|
||||
"SELECT active_org_id FROM refresh_tokens WHERE token_hash = $1",
|
||||
new_hash,
|
||||
)
|
||||
assert new_row is not None and new_row["active_org_id"] == secondary_org
|
||||
|
||||
|
||||
async def test_logout_revokes_refresh_token(auth_service, db_conn):
|
||||
email = "logout@example.com"
|
||||
password = "Logout123!"
|
||||
user_id = await _create_user(db_conn, email, password)
|
||||
org_id = await _create_org(db_conn, name="Logout Org", slug="logout-org")
|
||||
await _add_membership(db_conn, user_id, org_id, "member")
|
||||
|
||||
initial = await auth_service.login_user(LoginRequest(email=email, password=password))
|
||||
current_user = CurrentUser(
|
||||
user_id=user_id,
|
||||
email=email,
|
||||
org_id=org_id,
|
||||
org_role="member",
|
||||
token=initial.access_token,
|
||||
)
|
||||
|
||||
await auth_service.logout(current_user, LogoutRequest(refresh_token=initial.refresh_token))
|
||||
|
||||
token_hash = security.hash_token(initial.refresh_token)
|
||||
row = await db_conn.fetchrow(
|
||||
"SELECT revoked_at FROM refresh_tokens WHERE token_hash = $1",
|
||||
token_hash,
|
||||
)
|
||||
assert row is not None and row["revoked_at"] is not None
|
||||
Reference in New Issue
Block a user