feat(auth): implement auth stack
This commit is contained in:
@@ -3,9 +3,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from uuid import uuid4
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import AsyncGenerator, Callable
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import asyncpg
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
# Set test environment variables before importing app modules
|
||||
@@ -13,6 +16,8 @@ os.environ.setdefault("DATABASE_URL", "postgresql://incidentops:incidentops@loca
|
||||
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-testing-only")
|
||||
os.environ.setdefault("REDIS_URL", "redis://localhost:6379/1")
|
||||
|
||||
from app.main import app
|
||||
|
||||
|
||||
# Module-level setup: create database and run migrations once
|
||||
_db_initialized = False
|
||||
@@ -65,7 +70,7 @@ async def _init_test_db() -> None:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def db_conn() -> asyncpg.Connection:
|
||||
async def db_conn() -> AsyncGenerator[asyncpg.Connection, None]:
|
||||
"""Get a database connection with transaction rollback for test isolation."""
|
||||
await _init_test_db()
|
||||
|
||||
@@ -84,12 +89,77 @@ async def db_conn() -> asyncpg.Connection:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def make_user_id() -> uuid4:
|
||||
def make_user_id() -> Callable[[], UUID]:
|
||||
"""Factory for generating user IDs."""
|
||||
return lambda: uuid4()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def make_org_id() -> uuid4:
|
||||
def make_org_id() -> Callable[[], UUID]:
|
||||
"""Factory for generating org IDs."""
|
||||
return lambda: uuid4()
|
||||
|
||||
|
||||
TABLES_TO_TRUNCATE = [
|
||||
"incident_events",
|
||||
"notification_attempts",
|
||||
"incidents",
|
||||
"notification_targets",
|
||||
"services",
|
||||
"refresh_tokens",
|
||||
"org_members",
|
||||
"orgs",
|
||||
"users",
|
||||
]
|
||||
|
||||
|
||||
async def _truncate_all_tables() -> None:
|
||||
test_dsn = os.environ["DATABASE_URL"]
|
||||
conn = await asyncpg.connect(test_dsn)
|
||||
try:
|
||||
tables = ", ".join(TABLES_TO_TRUNCATE)
|
||||
await conn.execute(f"TRUNCATE TABLE {tables} CASCADE")
|
||||
finally:
|
||||
await conn.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def clean_database() -> AsyncGenerator[None, None]:
|
||||
"""Ensure the database is initialized and truncated before/after tests."""
|
||||
|
||||
await _init_test_db()
|
||||
await _truncate_all_tables()
|
||||
yield
|
||||
await _truncate_all_tables()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def _lifespan_manager() -> AsyncGenerator[None, None]:
|
||||
lifespan = app.router.lifespan_context
|
||||
if lifespan is None:
|
||||
yield
|
||||
else:
|
||||
async with lifespan(app):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def api_client(clean_database: None) -> AsyncGenerator[httpx.AsyncClient, None]:
|
||||
"""HTTPX async client bound to the FastAPI app with lifespan support."""
|
||||
|
||||
async with _lifespan_manager():
|
||||
transport = httpx.ASGITransport(app=app)
|
||||
async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client:
|
||||
yield client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def db_admin(clean_database: None) -> AsyncGenerator[asyncpg.Connection, None]:
|
||||
"""Plain connection for arranging/inspecting API test data (no rollback)."""
|
||||
|
||||
test_dsn = os.environ["DATABASE_URL"]
|
||||
conn = await asyncpg.connect(test_dsn)
|
||||
try:
|
||||
yield conn
|
||||
finally:
|
||||
await conn.close()
|
||||
|
||||
Reference in New Issue
Block a user