feat(graph): Implement PostgresSaver checkpointer integration

This commit is contained in:
Yunxiao Xu
2026-02-10 03:32:14 -08:00
parent 60429e1adc
commit 452d4a1957
3 changed files with 94 additions and 0 deletions

View File

@@ -40,6 +40,7 @@ packages = ["src/ea_chatbot"]
dev = [
"pytest>=9.0.2",
"pytest-cov>=7.0.0",
"pytest-asyncio>=0.23.0",
"ruff>=0.9.3",
"mypy>=1.14.1",
]

View File

@@ -0,0 +1,44 @@
import contextlib
from typing import AsyncGenerator
from psycopg_pool import AsyncConnectionPool
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from ea_chatbot.config import Settings
settings = Settings()
_pool = None
def get_pool() -> AsyncConnectionPool:
"""Get or create the async connection pool."""
global _pool
if _pool is None:
_pool = AsyncConnectionPool(
conninfo=settings.history_db_url,
max_size=20,
kwargs={"autocommit": True, "prepare_threshold": 0},
open=False, # Don't open automatically on init
)
return _pool
@contextlib.asynccontextmanager
async def get_checkpointer() -> AsyncGenerator[AsyncPostgresSaver, None]:
"""
Context manager to get a PostgresSaver checkpointer.
Ensures that the checkpointer is properly initialized and the connection is managed.
"""
pool = get_pool()
# Ensure pool is open
if pool.closed:
await pool.open()
async with pool.connection() as conn:
checkpointer = AsyncPostgresSaver(conn)
# Ensure the necessary tables exist
await checkpointer.setup()
yield checkpointer
async def close_pool():
"""Close the connection pool."""
global _pool
if _pool and not _pool.closed:
await _pool.close()

View File

@@ -0,0 +1,49 @@
import pytest
from unittest.mock import AsyncMock, MagicMock
from ea_chatbot.graph.checkpoint import get_checkpointer
@pytest.mark.asyncio
async def test_get_checkpointer_initialization():
"""Test that the checkpointer setup is called."""
mock_conn = AsyncMock()
mock_pool = MagicMock() # Changed from AsyncMock to MagicMock
mock_pool.closed = True
mock_pool.open = AsyncMock() # Ensure open is awaitable
# Setup mock_pool.connection() to return an async context manager
mock_cm = MagicMock()
mock_cm.__aenter__ = AsyncMock(return_value=mock_conn)
mock_cm.__aexit__ = AsyncMock(return_value=None)
mock_pool.connection.return_value = mock_cm
# We need to patch the get_pool function and AsyncPostgresSaver in the module
with MagicMock() as mock_get_pool, \
MagicMock() as mock_saver_class:
import ea_chatbot.graph.checkpoint as checkpoint
mock_get_pool.return_value = mock_pool
# Mock AsyncPostgresSaver class
mock_saver_instance = AsyncMock()
mock_saver_class.return_value = mock_saver_instance
original_get_pool = checkpoint.get_pool
checkpoint.get_pool = mock_get_pool
# Patch AsyncPostgresSaver where it's imported in checkpoint.py
import langgraph.checkpoint.postgres.aio as pg_aio
original_saver = checkpoint.AsyncPostgresSaver
checkpoint.AsyncPostgresSaver = mock_saver_class
try:
async with get_checkpointer() as checkpointer:
assert checkpointer == mock_saver_instance
# Verify setup was called
mock_saver_instance.setup.assert_called_once()
# Verify pool was opened
mock_pool.open.assert_called_once()
# Verify connection was requested
mock_pool.connection.assert_called_once()
finally:
checkpoint.get_pool = original_get_pool
checkpoint.AsyncPostgresSaver = original_saver