feat(auth): implement auth stack

This commit is contained in:
2025-12-29 09:55:30 +00:00
parent 3170f10e86
commit ad94833830
13 changed files with 1199 additions and 11 deletions

101
app/api/deps.py Normal file
View File

@@ -0,0 +1,101 @@
"""Shared FastAPI dependencies (auth, RBAC, ownership)."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Callable
from uuid import UUID
from fastapi import Depends
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from app.core import exceptions as exc, security
from app.db import db
from app.repositories import OrgRepository, UserRepository
bearer_scheme = HTTPBearer(auto_error=False)
ROLE_RANKS: dict[str, int] = {"viewer": 0, "member": 1, "admin": 2}
@dataclass(slots=True)
class CurrentUser:
"""Authenticated user context derived from the access token."""
user_id: UUID
email: str
org_id: UUID
org_role: str
token: str
async def get_current_user(
credentials: HTTPAuthorizationCredentials | None = Depends(bearer_scheme),
) -> CurrentUser:
"""Extract and validate the current user from the Authorization header."""
if credentials is None or credentials.scheme.lower() != "bearer":
raise exc.UnauthorizedError("Missing bearer token")
try:
payload = security.TokenPayload(security.decode_access_token(credentials.credentials))
except security.JWTError as err: # pragma: no cover - jose error types
raise exc.UnauthorizedError("Invalid access token") from err
async with db.connection() as conn:
user_repo = UserRepository(conn)
user = await user_repo.get_by_id(payload.user_id)
if user is None:
raise exc.UnauthorizedError("User not found")
org_repo = OrgRepository(conn)
membership = await org_repo.get_member(payload.user_id, payload.org_id)
if membership is None:
raise exc.ForbiddenError("Organization access denied")
return CurrentUser(
user_id=payload.user_id,
email=user["email"],
org_id=payload.org_id,
org_role=membership["role"],
token=credentials.credentials,
)
class RoleChecker:
"""Dependency that enforces a minimum organization role."""
def __init__(self, minimum_role: str) -> None:
if minimum_role not in ROLE_RANKS:
raise ValueError(f"Unknown role '{minimum_role}'")
self.minimum_role = minimum_role
def __call__(self, current_user: CurrentUser = Depends(get_current_user)) -> CurrentUser:
if ROLE_RANKS[current_user.org_role] < ROLE_RANKS[self.minimum_role]:
raise exc.ForbiddenError("Insufficient role for this operation")
return current_user
def require_role(min_role: str) -> Callable[[CurrentUser], CurrentUser]:
"""Factory that returns a dependency enforcing the specified role."""
return RoleChecker(min_role)
def ensure_org_access(resource_org_id: UUID, current_user: CurrentUser) -> None:
"""Verify that the resource belongs to the active org in the token."""
if resource_org_id != current_user.org_id:
raise exc.ForbiddenError("Resource does not belong to the active organization")
__all__ = [
"CurrentUser",
"ROLE_RANKS",
"RoleChecker",
"bearer_scheme",
"ensure_org_access",
"get_current_user",
"require_role",
]

59
app/api/v1/auth.py Normal file
View File

@@ -0,0 +1,59 @@
"""Authentication API endpoints."""
from fastapi import APIRouter, Depends, status
from app.api.deps import CurrentUser, get_current_user
from app.schemas.auth import (
LoginRequest,
LogoutRequest,
RefreshRequest,
RegisterRequest,
SwitchOrgRequest,
TokenResponse,
)
from app.services import AuthService
router = APIRouter(prefix="/auth", tags=["auth"])
auth_service = AuthService()
@router.post("/register", response_model=TokenResponse, status_code=status.HTTP_201_CREATED)
async def register_user(payload: RegisterRequest) -> TokenResponse:
"""Register a new user and default org, returning auth tokens."""
return await auth_service.register_user(payload)
@router.post("/login", response_model=TokenResponse)
async def login_user(payload: LoginRequest) -> TokenResponse:
"""Authenticate an existing user and issue tokens."""
return await auth_service.login_user(payload)
@router.post("/refresh", response_model=TokenResponse)
async def refresh_tokens(payload: RefreshRequest) -> TokenResponse:
"""Rotate refresh token and mint a new access token."""
return await auth_service.refresh_tokens(payload)
@router.post("/switch-org", response_model=TokenResponse)
async def switch_org(
payload: SwitchOrgRequest,
current_user: CurrentUser = Depends(get_current_user),
) -> TokenResponse:
"""Switch the active organization for the authenticated user."""
return await auth_service.switch_org(current_user, payload)
@router.post("/logout", status_code=status.HTTP_204_NO_CONTENT)
async def logout(
payload: LogoutRequest,
current_user: CurrentUser = Depends(get_current_user),
) -> None:
"""Revoke the provided refresh token for the current session."""
await auth_service.logout(current_user, payload)

View File

@@ -2,8 +2,10 @@
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from contextvars import ContextVar
import asyncpg
from asyncpg.pool import PoolConnectionProxy
import redis.asyncio as redis
@@ -27,7 +29,7 @@ class Database:
await self.pool.close()
@asynccontextmanager
async def connection(self) -> AsyncGenerator[asyncpg.Connection, None]:
async def connection(self) -> AsyncGenerator[asyncpg.Connection | PoolConnectionProxy, None]:
"""Acquire a connection from the pool."""
if not self.pool:
raise RuntimeError("Database not connected")
@@ -35,7 +37,7 @@ class Database:
yield conn
@asynccontextmanager
async def transaction(self) -> AsyncGenerator[asyncpg.Connection, None]:
async def transaction(self) -> AsyncGenerator[asyncpg.Connection | PoolConnectionProxy, None]:
"""Acquire a connection with an active transaction."""
if not self.pool:
raise RuntimeError("Database not connected")
@@ -74,7 +76,26 @@ db = Database()
redis_client = RedisClient()
async def get_conn() -> AsyncGenerator[asyncpg.Connection, None]:
"""Dependency for getting a database connection."""
async with db.connection() as conn:
yield conn
_connection_ctx: ContextVar[asyncpg.Connection | PoolConnectionProxy | None] = ContextVar(
"db_connection",
default=None,
)
async def get_conn() -> AsyncGenerator[asyncpg.Connection | PoolConnectionProxy, None]:
"""Dependency that reuses the same DB connection within a request context."""
existing_conn = _connection_ctx.get()
if existing_conn is not None:
yield existing_conn
return
if not db.pool:
raise RuntimeError("Database not connected")
async with db.pool.acquire() as conn:
token = _connection_ctx.set(conn)
try:
yield conn
finally:
_connection_ctx.reset(token)

View File

@@ -4,8 +4,9 @@ from contextlib import asynccontextmanager
from typing import AsyncGenerator
from fastapi import FastAPI
from fastapi.openapi.utils import get_openapi
from app.api.v1 import health
from app.api.v1 import auth, health
from app.config import settings
from app.db import db, redis_client
@@ -26,8 +27,44 @@ app = FastAPI(
title="IncidentOps",
description="Incident management API with multi-tenant org support",
version="0.1.0",
docs_url="/docs",
redoc_url="/redoc",
openapi_url="/openapi.json",
lifespan=lifespan,
)
app.openapi_tags = [
{"name": "auth", "description": "Registration, login, token lifecycle"},
{"name": "health", "description": "Service health probes"},
]
def custom_openapi() -> dict:
"""Add JWT bearer security scheme to the generated OpenAPI schema."""
if app.openapi_schema:
return app.openapi_schema
openapi_schema = get_openapi(
title=app.title,
version=app.version,
description=app.description,
routes=app.routes,
)
security_schemes = openapi_schema.setdefault("components", {}).setdefault("securitySchemes", {})
security_schemes["BearerToken"] = {
"type": "http",
"scheme": "bearer",
"bearerFormat": "JWT",
"description": "Paste the JWT access token returned by /auth endpoints",
}
openapi_schema["security"] = [{"BearerToken": []}]
app.openapi_schema = openapi_schema
return app.openapi_schema
app.openapi = custom_openapi # type: ignore[assignment]
# Include routers
app.include_router(auth.router, prefix=settings.api_v1_prefix)
app.include_router(health.router, prefix=settings.api_v1_prefix, tags=["health"])

View File

@@ -2,6 +2,7 @@
from app.schemas.auth import (
LoginRequest,
LogoutRequest,
RefreshRequest,
RegisterRequest,
SwitchOrgRequest,
@@ -27,6 +28,7 @@ from app.schemas.org import (
__all__ = [
# Auth
"LoginRequest",
"LogoutRequest",
"RefreshRequest",
"RegisterRequest",
"SwitchOrgRequest",

View File

@@ -33,6 +33,12 @@ class SwitchOrgRequest(BaseModel):
refresh_token: str
class LogoutRequest(BaseModel):
"""Request body for logging out and revoking a refresh token."""
refresh_token: str
class TokenResponse(BaseModel):
"""Response containing access and refresh tokens."""

5
app/services/__init__.py Normal file
View File

@@ -0,0 +1,5 @@
"""Service layer entrypoints."""
from app.services.auth import AuthService
__all__ = ["AuthService"]

269
app/services/auth.py Normal file
View File

@@ -0,0 +1,269 @@
"""Authentication service providing business logic for auth flows."""
from __future__ import annotations
import re
from typing import cast
from uuid import UUID, uuid4
import asyncpg
from asyncpg.pool import PoolConnectionProxy
from app.api.deps import CurrentUser
from app.config import settings
from app.core import exceptions as exc, security
from app.db import Database, db
from app.repositories import OrgRepository, RefreshTokenRepository, UserRepository
from app.schemas.auth import (
LoginRequest,
LogoutRequest,
RefreshRequest,
RegisterRequest,
SwitchOrgRequest,
TokenResponse,
)
_SLUG_PATTERN = re.compile(r"[^a-z0-9]+")
def _as_conn(conn: asyncpg.Connection | PoolConnectionProxy) -> asyncpg.Connection:
"""Helper to satisfy typing when a pool proxy is returned."""
return cast(asyncpg.Connection, conn)
class AuthService:
"""Encapsulates authentication workflows (register/login/refresh/logout)."""
def __init__(self, database: Database | None = None) -> None:
self.db = database or db
self._access_token_expires_in = settings.access_token_expire_minutes * 60
async def register_user(self, data: RegisterRequest) -> TokenResponse:
"""Create a new user, default org, membership, and token pair."""
async with self.db.transaction() as conn:
db_conn = _as_conn(conn)
user_repo = UserRepository(db_conn)
org_repo = OrgRepository(db_conn)
refresh_repo = RefreshTokenRepository(db_conn)
if await user_repo.exists_by_email(data.email):
raise exc.ConflictError("Email already registered")
user_id = uuid4()
org_id = uuid4()
member_id = uuid4()
password_hash = security.hash_password(data.password)
await user_repo.create(user_id, data.email, password_hash)
slug = await self._generate_unique_org_slug(org_repo, data.org_name)
await org_repo.create(org_id, data.org_name, slug)
await org_repo.add_member(member_id, user_id, org_id, "admin")
return await self._issue_token_pair(
refresh_repo,
user_id=user_id,
org_id=org_id,
role="admin",
)
async def login_user(self, data: LoginRequest) -> TokenResponse:
"""Authenticate a user and issue tokens for their first organization."""
async with self.db.connection() as conn:
db_conn = _as_conn(conn)
user_repo = UserRepository(db_conn)
org_repo = OrgRepository(db_conn)
refresh_repo = RefreshTokenRepository(db_conn)
user = await user_repo.get_by_email(data.email)
if not user or not security.verify_password(data.password, user["password_hash"]):
raise exc.UnauthorizedError("Invalid email or password")
orgs = await org_repo.get_user_orgs(user["id"])
if not orgs:
raise exc.ForbiddenError("User does not belong to any organization")
active_org = orgs[0]
return await self._issue_token_pair(
refresh_repo,
user_id=user["id"],
org_id=active_org["id"],
role=active_org["role"],
)
async def refresh_tokens(self, data: RefreshRequest) -> TokenResponse:
"""Rotate refresh token and mint a new access token."""
old_hash = security.hash_token(data.refresh_token)
new_refresh_token = security.generate_refresh_token()
new_refresh_hash = security.hash_token(new_refresh_token)
new_refresh_id = uuid4()
new_refresh_expiry = security.get_refresh_token_expiry()
rotated: dict | None = None
membership: dict | None = None
async with self.db.transaction() as conn:
db_conn = _as_conn(conn)
refresh_repo = RefreshTokenRepository(db_conn)
rotated = await refresh_repo.rotate(
old_token_hash=old_hash,
new_token_id=new_refresh_id,
new_token_hash=new_refresh_hash,
new_expires_at=new_refresh_expiry,
)
if rotated is not None:
org_repo = OrgRepository(db_conn)
membership = await org_repo.get_member(rotated["user_id"], rotated["active_org_id"])
if membership is None:
raise exc.UnauthorizedError("Invalid refresh token")
if rotated is None or membership is None:
await self._handle_invalid_refresh(old_hash)
assert rotated is not None and membership is not None
access_token = security.create_access_token(
sub=str(rotated["user_id"]),
org_id=str(rotated["active_org_id"]),
org_role=membership["role"],
)
return TokenResponse(
access_token=access_token,
refresh_token=new_refresh_token,
expires_in=self._access_token_expires_in,
)
async def switch_org(
self,
current_user: CurrentUser,
data: SwitchOrgRequest,
) -> TokenResponse:
"""Switch active organization (rotates refresh token + issues new JWT)."""
target_org_id = data.org_id
old_hash = security.hash_token(data.refresh_token)
new_refresh_token = security.generate_refresh_token()
new_refresh_hash = security.hash_token(new_refresh_token)
new_refresh_expiry = security.get_refresh_token_expiry()
rotated: dict | None = None
membership: dict | None = None
async with self.db.transaction() as conn:
db_conn = _as_conn(conn)
org_repo = OrgRepository(db_conn)
membership = await org_repo.get_member(current_user.user_id, target_org_id)
if membership is None:
raise exc.ForbiddenError("Not a member of the requested organization")
refresh_repo = RefreshTokenRepository(db_conn)
rotated = await refresh_repo.rotate(
old_token_hash=old_hash,
new_token_id=uuid4(),
new_token_hash=new_refresh_hash,
new_expires_at=new_refresh_expiry,
new_active_org_id=target_org_id,
expected_user_id=current_user.user_id,
)
if rotated is None:
await self._handle_invalid_refresh(old_hash)
access_token = security.create_access_token(
sub=str(current_user.user_id),
org_id=str(target_org_id),
org_role=membership["role"],
)
return TokenResponse(
access_token=access_token,
refresh_token=new_refresh_token,
expires_in=self._access_token_expires_in,
)
async def logout(self, current_user: CurrentUser, data: LogoutRequest) -> None:
"""Revoke the provided refresh token for the current session."""
token_hash = security.hash_token(data.refresh_token)
async with self.db.transaction() as conn:
refresh_repo = RefreshTokenRepository(_as_conn(conn))
token = await refresh_repo.get_by_hash(token_hash)
if token and token["user_id"] != current_user.user_id:
raise exc.ForbiddenError("Refresh token does not belong to this user")
if not token:
return
await refresh_repo.revoke(token["id"])
async def _issue_token_pair(
self,
refresh_repo: RefreshTokenRepository,
*,
user_id: UUID,
org_id: UUID,
role: str,
) -> TokenResponse:
"""Create access/refresh tokens and persist the refresh token."""
access_token = security.create_access_token(
sub=str(user_id),
org_id=str(org_id),
org_role=role,
)
refresh_token = security.generate_refresh_token()
await refresh_repo.create(
token_id=uuid4(),
user_id=user_id,
token_hash=security.hash_token(refresh_token),
active_org_id=org_id,
expires_at=security.get_refresh_token_expiry(),
)
return TokenResponse(
access_token=access_token,
refresh_token=refresh_token,
expires_in=self._access_token_expires_in,
)
async def _handle_invalid_refresh(self, token_hash: str) -> None:
"""Raise appropriate errors for invalid/compromised refresh tokens."""
async with self.db.connection() as conn:
refresh_repo = RefreshTokenRepository(_as_conn(conn))
reused = await refresh_repo.check_token_reuse(token_hash)
if reused:
await refresh_repo.revoke_token_chain(reused["id"])
raise exc.UnauthorizedError("Refresh token reuse detected")
raise exc.UnauthorizedError("Invalid refresh token")
async def _generate_unique_org_slug(
self,
org_repo: OrgRepository,
org_name: str,
) -> str:
"""Slugify the org name and append a counter until unique."""
base_slug = self._slugify(org_name)
candidate = base_slug
counter = 1
while await org_repo.slug_exists(candidate):
suffix = f"-{counter}"
max_base_len = 50 - len(suffix)
candidate = f"{base_slug[:max_base_len]}{suffix}"
counter += 1
return candidate
def _slugify(self, value: str) -> str:
"""Convert arbitrary text into a URL-friendly slug."""
slug = _SLUG_PATTERN.sub("-", value.strip().lower()).strip("-")
return slug[:50] or "org"