"""End-to-end Celery worker tests against the real Redis broker.""" from __future__ import annotations import asyncio import inspect from uuid import UUID, uuid4 import asyncpg import pytest import redis from app.config import settings from app.repositories.incident import IncidentRepository from app.taskqueue import CeleryTaskQueue from celery.contrib.testing.worker import start_worker from worker.celery_app import celery_app pytestmark = pytest.mark.asyncio @pytest.fixture(scope="module", autouse=True) def ensure_redis_available() -> None: """Skip the module if the configured Redis broker is unreachable.""" client = redis.Redis.from_url(settings.resolved_task_queue_broker_url) try: client.ping() except redis.RedisError as exc: # pragma: no cover - diagnostic-only path pytest.skip(f"Redis broker unavailable: {exc}") finally: client.close() @pytest.fixture(scope="module") def celery_worker_instance(ensure_redis_available: None): """Run a real Celery worker connected to Redis for the duration of the module.""" queues = [settings.task_queue_default_queue, settings.task_queue_critical_queue] with start_worker( celery_app, loglevel="INFO", pool="solo", concurrency=1, queues=queues, perform_ping_check=False, ): yield @pytest.fixture(autouse=True) def purge_celery_queues(): """Clear any pending tasks before and after each test for isolation.""" celery_app.control.purge() yield celery_app.control.purge() @pytest.fixture def celery_queue() -> CeleryTaskQueue: return CeleryTaskQueue( default_queue=settings.task_queue_default_queue, critical_queue=settings.task_queue_critical_queue, ) async def _seed_incident_with_target(conn: asyncpg.Connection) -> tuple[UUID, UUID]: org_id = uuid4() service_id = uuid4() incident_id = uuid4() target_id = uuid4() await conn.execute( "INSERT INTO orgs (id, name, slug) VALUES ($1, $2, $3)", org_id, "Celery Org", f"celery-{org_id.hex[:6]}", ) await conn.execute( "INSERT INTO services (id, org_id, name, slug) VALUES ($1, $2, $3, $4)", service_id, org_id, "API", f"svc-{service_id.hex[:6]}", ) repo = IncidentRepository(conn) await repo.create( incident_id=incident_id, org_id=org_id, service_id=service_id, title="Latency spike", description="", severity="high", ) await conn.execute( """ INSERT INTO notification_targets (id, org_id, name, target_type, webhook_url, enabled) VALUES ($1, $2, $3, $4, $5, $6) """, target_id, org_id, "Primary Webhook", "webhook", "https://example.com/hook", True, ) return org_id, incident_id async def _wait_until(predicate, timeout: float = 5.0, interval: float = 0.1) -> None: deadline = asyncio.get_running_loop().time() + timeout while True: result = predicate() if inspect.isawaitable(result): result = await result if result: return if asyncio.get_running_loop().time() >= deadline: raise AssertionError("Timed out waiting for Celery worker to finish") await asyncio.sleep(interval) async def _attempt_sent(conn: asyncpg.Connection, incident_id: UUID) -> bool: row = await conn.fetchrow( "SELECT status FROM notification_attempts WHERE incident_id = $1", incident_id, ) return bool(row and row["status"] == "sent") async def _attempt_count(conn: asyncpg.Connection, incident_id: UUID) -> int: count = await conn.fetchval( "SELECT COUNT(*) FROM notification_attempts WHERE incident_id = $1", incident_id, ) return int(count or 0) async def _attempt_count_is(conn: asyncpg.Connection, incident_id: UUID, expected: int) -> bool: return await _attempt_count(conn, incident_id) == expected async def test_incident_triggered_task_marks_attempt_sent( db_admin: asyncpg.Connection, celery_worker_instance: None, celery_queue: CeleryTaskQueue, ) -> None: org_id, incident_id = await _seed_incident_with_target(db_admin) celery_queue.incident_triggered( incident_id=incident_id, org_id=org_id, triggered_by=uuid4(), ) await _wait_until(lambda: _attempt_sent(db_admin, incident_id)) async def test_escalate_task_refires_when_incident_still_triggered( db_admin: asyncpg.Connection, celery_worker_instance: None, celery_queue: CeleryTaskQueue, ) -> None: org_id, incident_id = await _seed_incident_with_target(db_admin) celery_queue.schedule_escalation_check( incident_id=incident_id, org_id=org_id, delay_seconds=0, ) await _wait_until(lambda: _attempt_count_is(db_admin, incident_id, 1)) async def test_escalate_task_skips_when_incident_acknowledged( db_admin: asyncpg.Connection, celery_worker_instance: None, celery_queue: CeleryTaskQueue, ) -> None: org_id, incident_id = await _seed_incident_with_target(db_admin) await db_admin.execute( "UPDATE incidents SET status = 'acknowledged' WHERE id = $1", incident_id, ) celery_queue.schedule_escalation_check( incident_id=incident_id, org_id=org_id, delay_seconds=0, ) await asyncio.sleep(1) assert await _attempt_count(db_admin, incident_id) == 0