feat: add observability stack and background task infrastructure

Add OpenTelemetry instrumentation with distributed tracing and metrics:
- Structured JSON logging with trace context correlation
- Auto-instrumentation for FastAPI, asyncpg, httpx, redis
- OTLP exporter for traces and Prometheus metrics endpoint

Implement Celery worker and notification task system:
- Celery app with Redis/SQS broker support and configurable queues
- Notification tasks for incident fan-out, webhooks, and escalations
- Pluggable TaskQueue abstraction with in-memory driver for testing

Add Grafana observability stack (Loki, Tempo, Prometheus, Grafana):
- OpenTelemetry Collector for receiving OTLP traces and logs
- Tempo for distributed tracing backend
- Loki for log aggregation with Promtail DaemonSet
- Prometheus for metrics scraping with RBAC configuration
- Grafana with pre-provisioned datasources and API overview dashboard
- Helm templates for all observability components

Enhance application infrastructure:
- Global exception handlers with structured ErrorResponse schema
- Request logging middleware with timing metrics
- Readiness probe now verifies task queue connectivity in place of the direct Redis check
- Non-root user in Dockerfile for security
- Init containers in Helm deployments for dependency ordering
- Production Helm values with autoscaling and retention policies
commit 46ede7757d (parent f427d191e0)
2026-01-07 20:51:13 -05:00
45 changed files with 3742 additions and 76 deletions


@@ -2,7 +2,8 @@
from fastapi import APIRouter, Response, status
from app.db import db, redis_client
from app.db import db
from app.taskqueue import task_queue
router = APIRouter()
@@ -16,14 +17,14 @@ async def healthz() -> dict[str, str]:
@router.get("/readyz")
async def readyz(response: Response) -> dict[str, str | dict[str, bool]]:
"""
Readiness probe - checks database and Redis connectivity.
Readiness probe - checks database and task queue connectivity.
- Check Postgres status
- Check Redis status
- Check configured task queue backend
- Return overall healthiness
"""
checks = {
"postgres": False,
"redis": False,
"task_queue": False,
}
try:
@@ -34,7 +35,7 @@ async def readyz(response: Response) -> dict[str, str | dict[str, bool]]:
except Exception:
pass
checks["redis"] = await redis_client.ping()
checks["task_queue"] = await task_queue.ping()
all_healthy = all(checks.values())
if not all_healthy:
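
For illustration, a failing readiness response under the new checks might look like the sketch below (HTTP 503; the top-level "status" key and its value are assumptions inferred from the dict[str, str | dict[str, bool]] return annotation):

# Sketch: plausible body of GET /readyz when the Celery broker is unreachable
{"status": "unhealthy", "checks": {"postgres": True, "task_queue": False}}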


@@ -1,5 +1,7 @@
"""Application configuration via pydantic-settings."""
from typing import Literal
from pydantic_settings import BaseSettings, SettingsConfigDict
@@ -15,9 +17,22 @@ class Settings(BaseSettings):
# Database
database_url: str
# Redis
# Redis (legacy default for Celery broker)
redis_url: str = "redis://localhost:6379/0"
# Task queue
task_queue_driver: Literal["celery", "inmemory"] = "celery"
task_queue_broker_url: str | None = None
task_queue_backend: Literal["redis", "sqs"] = "redis"
task_queue_default_queue: str = "default"
task_queue_critical_queue: str = "critical"
task_queue_visibility_timeout: int = 600
task_queue_polling_interval: float = 1.0
notification_escalation_delay_seconds: int = 900
# AWS (used when task_queue_backend="sqs")
aws_region: str | None = None
# JWT
jwt_secret_key: str
jwt_algorithm: str = "HS256"
@@ -30,5 +45,22 @@ class Settings(BaseSettings):
debug: bool = False
api_v1_prefix: str = "/v1"
# OpenTelemetry
otel_enabled: bool = True
otel_service_name: str = "incidentops-api"
otel_environment: str = "development"
otel_exporter_otlp_endpoint: str | None = None # e.g., "http://tempo:4317"
otel_exporter_otlp_insecure: bool = True
otel_log_level: str = "INFO"
settings = Settings()
# Metrics
prometheus_port: int = 9464 # Port for Prometheus metrics endpoint
@property
def resolved_task_queue_broker_url(self) -> str:
"""Return the broker URL with redis fallback for backwards compatibility."""
return self.task_queue_broker_url or self.redis_url
settings = Settings() # type: ignore[call-arg]
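
A small sketch of the broker-URL fallback, assuming pydantic-settings' default mapping of field names to upper-cased environment variables (connection strings here are hypothetical):

import os

# Hypothetical env: the two required fields plus only the legacy Redis URL.
os.environ["DATABASE_URL"] = "postgresql://localhost/incidentops"
os.environ["JWT_SECRET_KEY"] = "dev-secret"
os.environ["REDIS_URL"] = "redis://localhost:6379/0"

from app.config import Settings

s = Settings()
assert s.task_queue_broker_url is None
assert s.resolved_task_queue_broker_url == "redis://localhost:6379/0"  # falls back to redis_url

# An explicit broker URL takes precedence over the fallback:
s2 = Settings(task_queue_broker_url="sqs://")
assert s2.resolved_task_queue_broker_url == "sqs://"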

New file: app/core/logging.py (164 lines)

@@ -0,0 +1,164 @@
"""Structured JSON logging configuration with OpenTelemetry integration."""
import json
import logging
import sys
from datetime import datetime, timezone
from typing import Any
from app.config import settings
class JSONFormatter(logging.Formatter):
"""
JSON log formatter that outputs structured logs with trace context.
Log format includes:
- timestamp: ISO 8601 format
- level: Log level name
- message: Log message
- logger: Logger name
- trace_id: OpenTelemetry trace ID (if available)
- span_id: OpenTelemetry span ID (if available)
- Extra fields from log record
"""
def format(self, record: logging.LogRecord) -> str:
log_data: dict[str, Any] = {
"timestamp": datetime.now(timezone.utc).isoformat(),
"level": record.levelname,
"message": record.getMessage(),
"logger": record.name,
}
# Add trace context if available (injected by OpenTelemetry LoggingInstrumentor)
if hasattr(record, "otelTraceID") and record.otelTraceID != "0":
log_data["trace_id"] = record.otelTraceID
if hasattr(record, "otelSpanID") and record.otelSpanID != "0":
log_data["span_id"] = record.otelSpanID
# Add exception info if present
if record.exc_info:
log_data["exception"] = self.formatException(record.exc_info)
# Add extra fields (excluding standard LogRecord attributes)
standard_attrs = {
"name",
"msg",
"args",
"created",
"filename",
"funcName",
"levelname",
"levelno",
"lineno",
"module",
"msecs",
"pathname",
"process",
"processName",
"relativeCreated",
"stack_info",
"exc_info",
"exc_text",
"thread",
"threadName",
"taskName",
"message",
"otelTraceID",
"otelSpanID",
"otelTraceSampled",
"otelServiceName",
}
for key, value in record.__dict__.items():
if key not in standard_attrs and not key.startswith("_"):
log_data[key] = value
return json.dumps(log_data, default=str)
class DevelopmentFormatter(logging.Formatter):
"""
Human-readable formatter for development with color support.
Format: [TIME] LEVEL logger - message [trace_id]
"""
COLORS = {
"DEBUG": "\033[36m", # Cyan
"INFO": "\033[32m", # Green
"WARNING": "\033[33m", # Yellow
"ERROR": "\033[31m", # Red
"CRITICAL": "\033[35m", # Magenta
}
RESET = "\033[0m"
def format(self, record: logging.LogRecord) -> str:
color = self.COLORS.get(record.levelname, "")
reset = self.RESET
# Format timestamp
timestamp = datetime.now(timezone.utc).strftime("%H:%M:%S.%f")[:-3]
# Build message
msg = f"[{timestamp}] {color}{record.levelname:8}{reset} {record.name} - {record.getMessage()}"
# Add trace context if available
if hasattr(record, "otelTraceID") and record.otelTraceID != "0":
msg += f" [{record.otelTraceID[:8]}...]"
# Add exception if present
if record.exc_info:
msg += f"\n{self.formatException(record.exc_info)}"
return msg
def setup_logging() -> None:
"""
Configure application logging.
- JSON format in production (OTEL enabled)
- Human-readable format in development
- Integrates with OpenTelemetry trace context
"""
# Determine log level
log_level = getattr(logging, settings.otel_log_level.upper(), logging.INFO)
# Choose formatter based on environment
if settings.otel_enabled and not settings.debug:
formatter = JSONFormatter()
else:
formatter = DevelopmentFormatter()
# Configure root logger
root_logger = logging.getLogger()
root_logger.setLevel(log_level)
# Remove existing handlers
for handler in root_logger.handlers[:]:
root_logger.removeHandler(handler)
# Add stdout handler
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(formatter)
root_logger.addHandler(handler)
# Reduce noise from third-party libraries (keep uvicorn access at INFO so requests are logged)
logging.getLogger("uvicorn.access").setLevel(logging.INFO)
logging.getLogger("asyncpg").setLevel(logging.WARNING)
logging.getLogger("httpx").setLevel(logging.WARNING)
logging.getLogger("httpcore").setLevel(logging.WARNING)
logging.info(
"Logging configured",
extra={
"log_level": settings.otel_log_level,
"format": "json" if settings.otel_enabled and not settings.debug else "dev",
},
)
def get_logger(name: str) -> logging.Logger:
"""Get a logger instance with the given name."""
return logging.getLogger(name)
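
For reference, a line emitted through JSONFormatter would look roughly like this (trace context is illustrative and only appears once LoggingInstrumentor has injected it; the extra fields come from the caller's extra= dict):

# logging.getLogger("app.main").info("request", extra={"method": "GET", "path": "/v1/incidents"})
# produces a single JSON line similar to:
{"timestamp": "2026-01-07T20:51:13.481203+00:00", "level": "INFO", "message": "request",
 "logger": "app.main", "trace_id": "4bf92f3577b34da6a3ce929d0e0e4736",
 "span_id": "00f067aa0ba902b7", "method": "GET", "path": "/v1/incidents"}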

New file: app/core/telemetry.py (271 lines)

@@ -0,0 +1,271 @@
"""OpenTelemetry instrumentation for tracing, metrics, and logging."""
import logging
from contextlib import contextmanager
from typing import Any
from opentelemetry import metrics, trace
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.exporter.prometheus import PrometheusMetricReader
from opentelemetry.instrumentation.asyncpg import AsyncPGInstrumentor
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor
from opentelemetry.instrumentation.logging import LoggingInstrumentor
from opentelemetry.instrumentation.redis import RedisInstrumentor
from opentelemetry.instrumentation.system_metrics import SystemMetricsInstrumentor
from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor, ConsoleSpanExporter
from opentelemetry.semconv.resource import ResourceAttributes
from prometheus_client import REGISTRY, start_http_server
from app.config import settings
logger = logging.getLogger(__name__)
_tracer_provider: TracerProvider | None = None
_meter_provider: MeterProvider | None = None
# Custom metrics
_request_counter = None
_request_duration = None
_active_requests = None
_error_counter = None
def setup_telemetry(app: Any) -> None:
"""
Initialize OpenTelemetry with tracing, metrics, and logging instrumentation.
Configures:
- OTLP exporter for traces (to Tempo/Jaeger)
- Prometheus exporter for metrics (scraped by Prometheus)
- Auto-instrumentation for FastAPI, asyncpg, httpx, redis
- System metrics (CPU, memory, etc.)
- Logging instrumentation for trace context injection
"""
global _tracer_provider, _meter_provider
global _request_counter, _request_duration, _active_requests, _error_counter
if not settings.otel_enabled:
logger.info("OpenTelemetry disabled")
return
# Create resource with service info
resource = Resource.create(
{
ResourceAttributes.SERVICE_NAME: settings.otel_service_name,
ResourceAttributes.SERVICE_VERSION: "0.1.0",
ResourceAttributes.DEPLOYMENT_ENVIRONMENT: settings.otel_environment,
}
)
# =========================================
# TRACING SETUP
# =========================================
_tracer_provider = TracerProvider(resource=resource)
if settings.otel_exporter_otlp_endpoint:
otlp_exporter = OTLPSpanExporter(
endpoint=settings.otel_exporter_otlp_endpoint,
insecure=settings.otel_exporter_otlp_insecure,
)
_tracer_provider.add_span_processor(BatchSpanProcessor(otlp_exporter))
logger.info(f"OTLP exporter configured: {settings.otel_exporter_otlp_endpoint}")
else:
_tracer_provider.add_span_processor(BatchSpanProcessor(ConsoleSpanExporter()))
logger.info("Console span exporter configured (no OTLP endpoint)")
trace.set_tracer_provider(_tracer_provider)
# =========================================
# METRICS SETUP
# =========================================
# Prometheus metric reader exposes metrics at /metrics endpoint
prometheus_reader = PrometheusMetricReader()
_meter_provider = MeterProvider(resource=resource, metric_readers=[prometheus_reader])
metrics.set_meter_provider(_meter_provider)
# Start Prometheus HTTP server on port 9464
prometheus_port = settings.prometheus_port
try:
start_http_server(port=prometheus_port, registry=REGISTRY)
logger.info(f"Prometheus metrics server started on port {prometheus_port}")
except OSError as e:
logger.warning(f"Could not start Prometheus server on port {prometheus_port}: {e}")
# Create custom metrics
meter = metrics.get_meter(__name__)
_request_counter = meter.create_counter(
name="http_requests_total",
description="Total number of HTTP requests",
unit="1",
)
_request_duration = meter.create_histogram(
name="http_request_duration_seconds",
description="HTTP request duration in seconds",
unit="s",
)
_active_requests = meter.create_up_down_counter(
name="http_requests_active",
description="Number of active HTTP requests",
unit="1",
)
_error_counter = meter.create_counter(
name="http_errors_total",
description="Total number of HTTP errors",
unit="1",
)
# Instrument system metrics (CPU, memory, etc.)
SystemMetricsInstrumentor().instrument()
logger.info("System metrics instrumentation enabled")
# =========================================
# LIBRARY INSTRUMENTATION
# =========================================
FastAPIInstrumentor.instrument_app(
app,
excluded_urls="healthz,readyz,metrics",
tracer_provider=_tracer_provider,
meter_provider=_meter_provider,
)
AsyncPGInstrumentor().instrument(tracer_provider=_tracer_provider)
HTTPXClientInstrumentor().instrument(tracer_provider=_tracer_provider)
RedisInstrumentor().instrument(tracer_provider=_tracer_provider)
# Inject trace context into logs
LoggingInstrumentor().instrument(
set_logging_format=True,
log_level=logging.INFO,
)
logger.info(
f"OpenTelemetry initialized: service={settings.otel_service_name}, "
f"env={settings.otel_environment}, metrics_port={prometheus_port}"
)
async def shutdown_telemetry() -> None:
"""Gracefully shutdown the tracer and meter providers."""
global _tracer_provider, _meter_provider
if _tracer_provider:
_tracer_provider.shutdown()
_tracer_provider = None
logger.info("Tracer provider shutdown complete")
if _meter_provider:
_meter_provider.shutdown()
_meter_provider = None
logger.info("Meter provider shutdown complete")
def get_tracer(name: str) -> trace.Tracer:
"""Get a tracer instance for manual span creation."""
return trace.get_tracer(name)
def get_meter(name: str) -> metrics.Meter:
"""Get a meter instance for custom metrics."""
return metrics.get_meter(name)
def get_current_trace_id() -> str | None:
"""Get the current trace ID for request correlation."""
span = trace.get_current_span()
if span and span.get_span_context().is_valid:
return format(span.get_span_context().trace_id, "032x")
return None
def get_current_span_id() -> str | None:
"""Get the current span ID."""
span = trace.get_current_span()
if span and span.get_span_context().is_valid:
return format(span.get_span_context().span_id, "016x")
return None
@contextmanager
def create_span(name: str, attributes: dict[str, Any] | None = None):
"""Context manager for creating manual spans."""
tracer = get_tracer(__name__)
with tracer.start_as_current_span(name, attributes=attributes) as span:
yield span
def add_span_attributes(attributes: dict[str, Any]) -> None:
"""Add attributes to the current span."""
span = trace.get_current_span()
if span:
for key, value in attributes.items():
span.set_attribute(key, value)
def record_exception(exception: Exception) -> None:
"""Record an exception on the current span."""
span = trace.get_current_span()
if span:
span.record_exception(exception)
span.set_status(trace.Status(trace.StatusCode.ERROR, str(exception)))
# =========================================
# CUSTOM METRICS HELPERS
# =========================================
def record_request(method: str, endpoint: str, status_code: int) -> None:
"""Record a request metric."""
if _request_counter:
_request_counter.add(
1,
{
"method": method,
"endpoint": endpoint,
"status_code": str(status_code),
},
)
def record_request_duration(method: str, endpoint: str, duration: float) -> None:
"""Record request duration in seconds."""
if _request_duration:
_request_duration.record(
duration,
{
"method": method,
"endpoint": endpoint,
},
)
def increment_active_requests(method: str, endpoint: str) -> None:
"""Increment active requests counter."""
if _active_requests:
_active_requests.add(1, {"method": method, "endpoint": endpoint})
def decrement_active_requests(method: str, endpoint: str) -> None:
"""Decrement active requests counter."""
if _active_requests:
_active_requests.add(-1, {"method": method, "endpoint": endpoint})
def record_error(method: str, endpoint: str, error_type: str) -> None:
"""Record an error metric."""
if _error_counter:
_error_counter.add(
1,
{
"method": method,
"endpoint": endpoint,
"error_type": error_type,
},
)
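
A minimal sketch of using the manual-instrumentation helpers above in service code (the function, the attribute names, and the "TASK"/"notify_on_call" labels are hypothetical):

from app.core.telemetry import add_span_attributes, create_span, record_error

async def notify_on_call(incident_id: str) -> None:
    # Open a child span of the current request trace for this unit of work.
    with create_span("notify_on_call", attributes={"incident.id": incident_id}):
        try:
            ...  # look up the rotation and deliver the page
            add_span_attributes({"notify.channel": "webhook"})
        except Exception:
            # Bump the custom error counter; the span itself records the
            # exception and error status when it exits with an exception.
            record_error("TASK", "notify_on_call", "delivery_failed")
            raise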


@@ -6,7 +6,6 @@ from contextvars import ContextVar
import asyncpg
from asyncpg.pool import PoolConnectionProxy
import redis.asyncio as redis
class Database:
@@ -46,34 +45,8 @@ class Database:
yield conn
class RedisClient:
"""Manages Redis connection."""
client: redis.Redis | None = None
async def connect(self, url: str) -> None:
"""Create Redis connection."""
self.client = redis.from_url(url, decode_responses=True)
async def disconnect(self) -> None:
"""Close Redis connection."""
if self.client:
await self.client.aclose()
async def ping(self) -> bool:
"""Check if Redis is reachable."""
if not self.client:
return False
try:
await self.client.ping()
return True
except redis.RedisError:
return False
# Global instances
# Global instance
db = Database()
redis_client = RedisClient()
_connection_ctx: ContextVar[asyncpg.Connection | PoolConnectionProxy | None] = ContextVar(


@@ -1,26 +1,50 @@
"""FastAPI application entry point."""
import logging
import time
from contextlib import asynccontextmanager
from typing import AsyncGenerator
from fastapi import FastAPI
from fastapi import FastAPI, Request, status
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError
from fastapi.openapi.utils import get_openapi
from fastapi.responses import JSONResponse
from starlette.exceptions import HTTPException as StarletteHTTPException
from app.api.v1 import auth, health, incidents, org
from app.config import settings
from app.db import db, redis_client
from app.core.logging import setup_logging
from app.core.telemetry import (
get_current_trace_id,
record_exception,
setup_telemetry,
shutdown_telemetry,
)
from app.db import db
from app.schemas.common import ErrorDetail, ErrorResponse
from app.taskqueue import task_queue
# Initialize logging before anything else
setup_logging()
logger = logging.getLogger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
"""Manage application lifecycle - connect/disconnect resources."""
# Startup
logger.info("Starting IncidentOps API")
await db.connect(settings.database_url)
await redis_client.connect(settings.redis_url)
await task_queue.startup()
logger.info("Startup complete")
yield
# Shutdown
await redis_client.disconnect()
logger.info("Shutting down IncidentOps API")
await task_queue.shutdown()
await db.disconnect()
await shutdown_telemetry()
logger.info("Shutdown complete")
app = FastAPI(
@@ -33,6 +57,26 @@ app = FastAPI(
lifespan=lifespan,
)
# Set up OpenTelemetry instrumentation
setup_telemetry(app)
@app.middleware("http")
async def request_logging_middleware(request: Request, call_next):
start = time.time()
response = await call_next(request)
duration_ms = (time.time() - start) * 1000
logger.info(
"request",
extra={
"method": request.method,
"path": request.url.path,
"status_code": response.status_code,
"duration_ms": round(duration_ms, 2),
},
)
return response
app.openapi_tags = [
{"name": "auth", "description": "Registration, login, token lifecycle"},
{"name": "org", "description": "Organization membership, services, and notifications"},
@@ -41,9 +85,133 @@ app.openapi_tags = [
]
def custom_openapi() -> dict:
"""Add JWT bearer security scheme to the generated OpenAPI schema."""
# ---------------------------------------------------------------------------
# Global Exception Handlers
# ---------------------------------------------------------------------------
def _build_error_response(
error: str,
message: str,
status_code: int,
details: list[ErrorDetail] | None = None,
) -> JSONResponse:
"""Build a structured error response with trace context."""
response = ErrorResponse(
error=error,
message=message,
details=details,
request_id=get_current_trace_id(),
)
return JSONResponse(
status_code=status_code,
content=jsonable_encoder(response),
)
@app.exception_handler(StarletteHTTPException)
async def http_exception_handler(
request: Request, exc: StarletteHTTPException
) -> JSONResponse:
"""Handle HTTP exceptions with structured error responses."""
# Map status codes to error type strings
error_types = {
400: "bad_request",
401: "unauthorized",
403: "forbidden",
404: "not_found",
409: "conflict",
422: "validation_error",
429: "rate_limited",
500: "internal_error",
502: "bad_gateway",
503: "service_unavailable",
}
error_type = error_types.get(exc.status_code, "error")
logger.warning(
"HTTP exception",
extra={
"status_code": exc.status_code,
"error": error_type,
"detail": exc.detail,
"path": str(request.url.path),
"method": request.method,
},
)
return _build_error_response(
error=error_type,
message=str(exc.detail),
status_code=exc.status_code,
)
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(
request: Request, exc: RequestValidationError
) -> JSONResponse:
"""Handle Pydantic validation errors with detailed error responses."""
details = [
ErrorDetail(
loc=[str(loc) for loc in error["loc"]],
msg=error["msg"],
type=error["type"],
)
for error in exc.errors()
]
logger.warning(
"Validation error",
extra={
"path": str(request.url.path),
"method": request.method,
"error_count": len(details),
},
)
return _build_error_response(
error="validation_error",
message="Request validation failed",
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
details=details,
)
@app.exception_handler(Exception)
async def unhandled_exception_handler(request: Request, exc: Exception) -> JSONResponse:
"""Handle unexpected exceptions with logging and safe error response."""
# Record exception in the current span for tracing
record_exception(exc)
logger.exception(
"Unhandled exception",
extra={
"path": str(request.url.path),
"method": request.method,
"exception_type": type(exc).__name__,
},
)
# Don't leak internal error details in production
message = "An unexpected error occurred"
if settings.debug:
message = f"{type(exc).__name__}: {exc}"
return _build_error_response(
error="internal_error",
message=message,
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
# ---------------------------------------------------------------------------
# OpenAPI Customization
# ---------------------------------------------------------------------------
def custom_openapi() -> dict:
"""Add JWT bearer security scheme and error responses to OpenAPI schema."""
if app.openapi_schema:
return app.openapi_schema
@@ -52,8 +220,12 @@ def custom_openapi() -> dict:
version=app.version,
description=app.description,
routes=app.routes,
tags=app.openapi_tags,
)
security_schemes = openapi_schema.setdefault("components", {}).setdefault("securitySchemes", {})
# Add security schemes
components = openapi_schema.setdefault("components", {})
security_schemes = components.setdefault("securitySchemes", {})
security_schemes["BearerToken"] = {
"type": "http",
"scheme": "bearer",
@@ -61,6 +233,42 @@ def custom_openapi() -> dict:
"description": "Paste the JWT access token returned by /auth endpoints",
}
openapi_schema["security"] = [{"BearerToken": []}]
# Add common error response schemas
schemas = components.setdefault("schemas", {})
schemas["ErrorResponse"] = {
"type": "object",
"properties": {
"error": {"type": "string", "description": "Error type identifier"},
"message": {"type": "string", "description": "Human-readable error message"},
"details": {
"type": "array",
"items": {"$ref": "#/components/schemas/ErrorDetail"},
"nullable": True,
"description": "Validation error details",
},
"request_id": {
"type": "string",
"nullable": True,
"description": "Trace ID for debugging",
},
},
"required": ["error", "message"],
}
schemas["ErrorDetail"] = {
"type": "object",
"properties": {
"loc": {
"type": "array",
"items": {"oneOf": [{"type": "string"}, {"type": "integer"}]},
"description": "Error location path",
},
"msg": {"type": "string", "description": "Error message"},
"type": {"type": "string", "description": "Error type"},
},
"required": ["loc", "msg", "type"],
}
app.openapi_schema = openapi_schema
return app.openapi_schema
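
End to end, an HTTPException raised in a route now surfaces as a structured body; a sketch (the route and message are hypothetical, request_id is the active trace ID):

# raise HTTPException(status_code=404, detail="Incident not found")
# is mapped by http_exception_handler to "not_found" and rendered as:
{"error": "not_found", "message": "Incident not found",
 "details": None, "request_id": "4bf92f3577b34da6a3ce929d0e0e4736"}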


@@ -8,7 +8,7 @@ from app.schemas.auth import (
SwitchOrgRequest,
TokenResponse,
)
from app.schemas.common import CursorParams, PaginatedResponse
from app.schemas.common import CursorParams, ErrorDetail, ErrorResponse, PaginatedResponse
from app.schemas.incident import (
CommentRequest,
IncidentCreate,
@@ -35,6 +35,8 @@ __all__ = [
"TokenResponse",
# Common
"CursorParams",
"ErrorDetail",
"ErrorResponse",
"PaginatedResponse",
# Incident
"CommentRequest",


@@ -3,6 +3,47 @@
from pydantic import BaseModel, Field
class ErrorDetail(BaseModel):
"""Individual error detail for validation errors."""
loc: list[str | int] = Field(description="Location of the error (field path)")
msg: str = Field(description="Error message")
type: str = Field(description="Error type identifier")
class ErrorResponse(BaseModel):
"""Structured error response returned by all error handlers."""
error: str = Field(description="Error type (e.g., 'not_found', 'validation_error')")
message: str = Field(description="Human-readable error message")
details: list[ErrorDetail] | None = Field(
default=None, description="Additional error details for validation errors"
)
request_id: str | None = Field(
default=None, description="Request trace ID for debugging"
)
model_config = {
"json_schema_extra": {
"examples": [
{
"error": "not_found",
"message": "Incident not found",
"request_id": "abc123def456",
},
{
"error": "validation_error",
"message": "Request validation failed",
"details": [
{"loc": ["body", "title"], "msg": "Field required", "type": "missing"}
],
"request_id": "abc123def456",
},
]
}
}
class CursorParams(BaseModel):
"""Pagination parameters using cursor-based pagination."""


@@ -10,6 +10,7 @@ import asyncpg
from asyncpg.pool import PoolConnectionProxy
from app.api.deps import CurrentUser, ensure_org_access
from app.config import settings
from app.core import exceptions as exc
from app.db import Database, db
from app.repositories import IncidentRepository, ServiceRepository
@@ -21,7 +22,8 @@ from app.schemas.incident import (
IncidentResponse,
TransitionRequest,
)
from app.taskqueue import TaskQueue
from app.taskqueue import task_queue as default_task_queue
_ALLOWED_TRANSITIONS: dict[str, set[str]] = {
"triggered": {"acknowledged"},
@@ -40,8 +42,19 @@ def _as_conn(conn: asyncpg.Connection | PoolConnectionProxy) -> asyncpg.Connecti
class IncidentService:
"""Encapsulates incident lifecycle operations within an org context."""
def __init__(self, database: Database | None = None) -> None:
def __init__(
self,
database: Database | None = None,
task_queue: TaskQueue | None = None,
escalation_delay_seconds: int | None = None,
) -> None:
self.db = database or db
self.task_queue = task_queue or default_task_queue
self.escalation_delay_seconds = (
escalation_delay_seconds
if escalation_delay_seconds is not None
else settings.notification_escalation_delay_seconds
)
async def create_incident(
self,
@@ -83,7 +96,22 @@ class IncidentService:
},
)
return IncidentResponse(**incident)
incident_response = IncidentResponse(**incident)
self.task_queue.incident_triggered(
incident_id=incident_response.id,
org_id=current_user.org_id,
triggered_by=current_user.user_id,
)
if self.escalation_delay_seconds > 0:
self.task_queue.schedule_escalation_check(
incident_id=incident_response.id,
org_id=current_user.org_id,
delay_seconds=self.escalation_delay_seconds,
)
return incident_response
async def get_incidents(
self,

New file: app/taskqueue.py (188 lines)

@@ -0,0 +1,188 @@
"""Task queue abstractions for scheduling background work."""
from __future__ import annotations
import asyncio
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any
from uuid import UUID
from app.config import settings
try:
from worker.celery_app import celery_app
except Exception: # pragma: no cover - celery app may not import during docs builds
celery_app = None # type: ignore[assignment]
class TaskQueue(ABC):
"""Interface for enqueueing background work."""
async def startup(self) -> None: # pragma: no cover - default no-op
"""Hook for queue initialization."""
async def shutdown(self) -> None: # pragma: no cover - default no-op
"""Hook for queue teardown."""
async def ping(self) -> bool:
"""Check if the queue backend is reachable."""
return True
def reset(self) -> None: # pragma: no cover - optional for in-memory impls
"""Reset any in-memory state (used in tests)."""
@abstractmethod
def incident_triggered(
self,
*,
incident_id: UUID,
org_id: UUID,
triggered_by: UUID | None,
) -> None:
"""Fan out an incident triggered notification."""
@abstractmethod
def schedule_escalation_check(
self,
*,
incident_id: UUID,
org_id: UUID,
delay_seconds: int,
) -> None:
"""Schedule a delayed escalation check."""
class CeleryTaskQueue(TaskQueue):
"""Celery-backed task queue that can use Redis or SQS brokers."""
def __init__(self, default_queue: str, critical_queue: str) -> None:
if celery_app is None: # pragma: no cover - guarded by try/except
raise RuntimeError("Celery application is unavailable")
self._celery = celery_app
self._default_queue = default_queue
self._critical_queue = critical_queue
def incident_triggered(
self,
*,
incident_id: UUID,
org_id: UUID,
triggered_by: UUID | None,
) -> None:
self._celery.send_task(
"worker.tasks.notifications.incident_triggered",
kwargs={
"incident_id": str(incident_id),
"org_id": str(org_id),
"triggered_by": str(triggered_by) if triggered_by else None,
},
queue=self._default_queue,
)
def schedule_escalation_check(
self,
*,
incident_id: UUID,
org_id: UUID,
delay_seconds: int,
) -> None:
self._celery.send_task(
"worker.tasks.notifications.escalate_if_unacked",
kwargs={
"incident_id": str(incident_id),
"org_id": str(org_id),
},
countdown=max(delay_seconds, 0),
queue=self._critical_queue,
)
async def ping(self) -> bool:
loop = asyncio.get_running_loop()
return await loop.run_in_executor(None, self._ping_sync)
def _ping_sync(self) -> bool:
connection = self._celery.connection()
try:
connection.connect()
return True
except Exception:
return False
finally:
try:
connection.release()
except Exception: # pragma: no cover - release best effort
pass
@dataclass
class InMemoryTaskQueue(TaskQueue):
"""Test-friendly queue that records dispatched tasks in memory."""
dispatched: list[tuple[str, dict[str, Any]]] | None = None
def __post_init__(self) -> None:
if self.dispatched is None:
self.dispatched = []
def incident_triggered(
self,
*,
incident_id: UUID,
org_id: UUID,
triggered_by: UUID | None,
) -> None:
self.dispatched.append(
(
"incident_triggered",
{
"incident_id": incident_id,
"org_id": org_id,
"triggered_by": triggered_by,
},
)
)
def schedule_escalation_check(
self,
*,
incident_id: UUID,
org_id: UUID,
delay_seconds: int,
) -> None:
self.dispatched.append(
(
"escalate_if_unacked",
{
"incident_id": incident_id,
"org_id": org_id,
"delay_seconds": delay_seconds,
},
)
)
def reset(self) -> None:
if self.dispatched is not None:
self.dispatched.clear()
def _build_task_queue() -> TaskQueue:
if settings.task_queue_driver == "inmemory":
return InMemoryTaskQueue()
return CeleryTaskQueue(
default_queue=settings.task_queue_default_queue,
critical_queue=settings.task_queue_critical_queue,
)
task_queue = _build_task_queue()
__all__ = [
"CeleryTaskQueue",
"InMemoryTaskQueue",
"TaskQueue",
"task_queue",
]
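
Tests can swap in the in-memory driver and assert on dispatched work; a minimal runnable sketch (constructing the queue directly rather than via TASK_QUEUE_DRIVER=inmemory):

from uuid import uuid4
from app.taskqueue import InMemoryTaskQueue

queue = InMemoryTaskQueue()
queue.incident_triggered(incident_id=uuid4(), org_id=uuid4(), triggered_by=None)

name, payload = queue.dispatched[0]
assert name == "incident_triggered" and payload["triggered_by"] is None

queue.reset()  # clear recorded tasks between tests
assert queue.dispatched == []

The same instance can be injected into the service under test via IncidentService(task_queue=queue, escalation_delay_seconds=0), per the constructor change above.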