"""Shared pytest fixtures for all tests.""" from __future__ import annotations import os from contextlib import asynccontextmanager from typing import AsyncGenerator, Callable, Generator from uuid import UUID, uuid4 import asyncpg import httpx import pytest # Set test environment variables before importing app modules os.environ.setdefault("DATABASE_URL", "postgresql://incidentops:incidentops@localhost:5432/incidentops_test") os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-testing-only") os.environ.setdefault("REDIS_URL", "redis://localhost:6379/1") os.environ.setdefault("TASK_QUEUE_DRIVER", "inmemory") os.environ.setdefault("TASK_QUEUE_BROKER_URL", "redis://localhost:6379/2") from app.main import app from app.taskqueue import task_queue # Module-level setup: create database and run migrations once _db_initialized = False async def _init_test_db() -> None: """Initialize test database and run migrations (once per session).""" global _db_initialized if _db_initialized: return admin_dsn = os.environ["DATABASE_URL"].rsplit("/", 1)[0] + "/postgres" test_db_name = "incidentops_test" admin_conn = await asyncpg.connect(admin_dsn) try: # Terminate existing connections to the test database await admin_conn.execute(f""" SELECT pg_terminate_backend(pg_stat_activity.pid) FROM pg_stat_activity WHERE pg_stat_activity.datname = '{test_db_name}' AND pid <> pg_backend_pid() """) # Drop and recreate test database await admin_conn.execute(f"DROP DATABASE IF EXISTS {test_db_name}") await admin_conn.execute(f"CREATE DATABASE {test_db_name}") finally: await admin_conn.close() # Connect to test database and run migrations test_dsn = os.environ["DATABASE_URL"] conn = await asyncpg.connect(test_dsn) try: # Run migrations migrations_dir = os.path.join(os.path.dirname(__file__), "..", "migrations") migration_files = sorted( f for f in os.listdir(migrations_dir) if f.endswith(".sql") ) for migration_file in migration_files: migration_path = os.path.join(migrations_dir, migration_file) with open(migration_path) as f: sql = f.read() await conn.execute(sql) finally: await conn.close() _db_initialized = True @pytest.fixture async def db_conn() -> AsyncGenerator[asyncpg.Connection, None]: """Get a database connection with transaction rollback for test isolation.""" await _init_test_db() test_dsn = os.environ["DATABASE_URL"] conn = await asyncpg.connect(test_dsn) # Start a transaction that will be rolled back tr = conn.transaction() await tr.start() try: yield conn finally: await tr.rollback() await conn.close() @pytest.fixture def make_user_id() -> Callable[[], UUID]: """Factory for generating user IDs.""" return lambda: uuid4() @pytest.fixture def make_org_id() -> Callable[[], UUID]: """Factory for generating org IDs.""" return lambda: uuid4() TABLES_TO_TRUNCATE = [ "incident_events", "notification_attempts", "incidents", "notification_targets", "services", "refresh_tokens", "org_members", "orgs", "users", ] async def _truncate_all_tables() -> None: test_dsn = os.environ["DATABASE_URL"] conn = await asyncpg.connect(test_dsn) try: tables = ", ".join(TABLES_TO_TRUNCATE) await conn.execute(f"TRUNCATE TABLE {tables} CASCADE") finally: await conn.close() @pytest.fixture async def clean_database() -> AsyncGenerator[None, None]: """Ensure the database is initialized and truncated before/after tests.""" await _init_test_db() await _truncate_all_tables() yield await _truncate_all_tables() @asynccontextmanager async def _lifespan_manager() -> AsyncGenerator[None, None]: lifespan = app.router.lifespan_context if lifespan is None: yield else: async with lifespan(app): yield @pytest.fixture async def api_client(clean_database: None) -> AsyncGenerator[httpx.AsyncClient, None]: """HTTPX async client bound to the FastAPI app with lifespan support.""" async with _lifespan_manager(): transport = httpx.ASGITransport(app=app) async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: yield client @pytest.fixture async def db_admin(clean_database: None) -> AsyncGenerator[asyncpg.Connection, None]: """Plain connection for arranging/inspecting API test data (no rollback).""" test_dsn = os.environ["DATABASE_URL"] conn = await asyncpg.connect(test_dsn) try: yield conn finally: await conn.close() @pytest.fixture(autouse=True) def reset_task_queue() -> Generator[None, None, None]: """Ensure in-memory task queue state is cleared between tests.""" if hasattr(task_queue, "reset"): task_queue.reset() yield if hasattr(task_queue, "reset"): task_queue.reset()