feat(auth): implement auth stack
This commit is contained in:
80
tests/db/test_get_conn.py
Normal file
80
tests/db/test_get_conn.py
Normal file
@@ -0,0 +1,80 @@
|
||||
"""Tests for the get_conn dependency helper."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.db import db, get_conn
|
||||
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
class _FakeConnection:
|
||||
def __init__(self, idx: int) -> None:
|
||||
self.idx = idx
|
||||
|
||||
|
||||
class _AcquireContext:
|
||||
def __init__(self, conn: _FakeConnection, tracker: "_FakePool") -> None:
|
||||
self._conn = conn
|
||||
self._tracker = tracker
|
||||
|
||||
async def __aenter__(self) -> _FakeConnection:
|
||||
self._tracker.active += 1
|
||||
return self._conn
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb) -> None:
|
||||
self._tracker.active -= 1
|
||||
|
||||
|
||||
class _FakePool:
|
||||
def __init__(self) -> None:
|
||||
self.acquire_calls = 0
|
||||
self.active = 0
|
||||
|
||||
def acquire(self) -> _AcquireContext:
|
||||
conn = _FakeConnection(self.acquire_calls)
|
||||
self.acquire_calls += 1
|
||||
return _AcquireContext(conn, self)
|
||||
|
||||
|
||||
async def _collect_single_connection():
|
||||
connection = None
|
||||
async for conn in get_conn():
|
||||
connection = conn
|
||||
return connection
|
||||
|
||||
|
||||
async def test_get_conn_reuses_connection_within_scope():
|
||||
original_pool = db.pool
|
||||
fake_pool = _FakePool()
|
||||
db.pool = fake_pool
|
||||
try:
|
||||
captured: list[_FakeConnection] = []
|
||||
|
||||
async for outer in get_conn():
|
||||
captured.append(outer)
|
||||
async for inner in get_conn():
|
||||
captured.append(inner)
|
||||
|
||||
assert len(captured) == 2
|
||||
assert captured[0] is captured[1]
|
||||
assert fake_pool.acquire_calls == 1
|
||||
finally:
|
||||
db.pool = original_pool
|
||||
|
||||
|
||||
async def test_get_conn_acquires_new_connection_per_root_scope():
|
||||
original_pool = db.pool
|
||||
fake_pool = _FakePool()
|
||||
db.pool = fake_pool
|
||||
try:
|
||||
first = await _collect_single_connection()
|
||||
second = await _collect_single_connection()
|
||||
|
||||
assert first is not None and second is not None
|
||||
assert first is not second
|
||||
assert fake_pool.acquire_calls == 2
|
||||
finally:
|
||||
db.pool = original_pool
|
||||
Reference in New Issue
Block a user