feat(graph): Implement PostgresSaver checkpointer integration
This commit is contained in:
@@ -40,6 +40,7 @@ packages = ["src/ea_chatbot"]
|
|||||||
dev = [
|
dev = [
|
||||||
"pytest>=9.0.2",
|
"pytest>=9.0.2",
|
||||||
"pytest-cov>=7.0.0",
|
"pytest-cov>=7.0.0",
|
||||||
|
"pytest-asyncio>=0.23.0",
|
||||||
"ruff>=0.9.3",
|
"ruff>=0.9.3",
|
||||||
"mypy>=1.14.1",
|
"mypy>=1.14.1",
|
||||||
]
|
]
|
||||||
|
|||||||
44
src/ea_chatbot/graph/checkpoint.py
Normal file
44
src/ea_chatbot/graph/checkpoint.py
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
import contextlib
|
||||||
|
from typing import AsyncGenerator
|
||||||
|
from psycopg_pool import AsyncConnectionPool
|
||||||
|
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||||||
|
from ea_chatbot.config import Settings
|
||||||
|
|
||||||
|
settings = Settings()
|
||||||
|
|
||||||
|
_pool = None
|
||||||
|
|
||||||
|
def get_pool() -> AsyncConnectionPool:
|
||||||
|
"""Get or create the async connection pool."""
|
||||||
|
global _pool
|
||||||
|
if _pool is None:
|
||||||
|
_pool = AsyncConnectionPool(
|
||||||
|
conninfo=settings.history_db_url,
|
||||||
|
max_size=20,
|
||||||
|
kwargs={"autocommit": True, "prepare_threshold": 0},
|
||||||
|
open=False, # Don't open automatically on init
|
||||||
|
)
|
||||||
|
return _pool
|
||||||
|
|
||||||
|
@contextlib.asynccontextmanager
|
||||||
|
async def get_checkpointer() -> AsyncGenerator[AsyncPostgresSaver, None]:
|
||||||
|
"""
|
||||||
|
Context manager to get a PostgresSaver checkpointer.
|
||||||
|
Ensures that the checkpointer is properly initialized and the connection is managed.
|
||||||
|
"""
|
||||||
|
pool = get_pool()
|
||||||
|
# Ensure pool is open
|
||||||
|
if pool.closed:
|
||||||
|
await pool.open()
|
||||||
|
|
||||||
|
async with pool.connection() as conn:
|
||||||
|
checkpointer = AsyncPostgresSaver(conn)
|
||||||
|
# Ensure the necessary tables exist
|
||||||
|
await checkpointer.setup()
|
||||||
|
yield checkpointer
|
||||||
|
|
||||||
|
async def close_pool():
|
||||||
|
"""Close the connection pool."""
|
||||||
|
global _pool
|
||||||
|
if _pool and not _pool.closed:
|
||||||
|
await _pool.close()
|
||||||
49
tests/graph/test_checkpoint.py
Normal file
49
tests/graph/test_checkpoint.py
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
from ea_chatbot.graph.checkpoint import get_checkpointer
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_checkpointer_initialization():
|
||||||
|
"""Test that the checkpointer setup is called."""
|
||||||
|
mock_conn = AsyncMock()
|
||||||
|
mock_pool = MagicMock() # Changed from AsyncMock to MagicMock
|
||||||
|
mock_pool.closed = True
|
||||||
|
mock_pool.open = AsyncMock() # Ensure open is awaitable
|
||||||
|
|
||||||
|
# Setup mock_pool.connection() to return an async context manager
|
||||||
|
mock_cm = MagicMock()
|
||||||
|
mock_cm.__aenter__ = AsyncMock(return_value=mock_conn)
|
||||||
|
mock_cm.__aexit__ = AsyncMock(return_value=None)
|
||||||
|
mock_pool.connection.return_value = mock_cm
|
||||||
|
|
||||||
|
# We need to patch the get_pool function and AsyncPostgresSaver in the module
|
||||||
|
with MagicMock() as mock_get_pool, \
|
||||||
|
MagicMock() as mock_saver_class:
|
||||||
|
import ea_chatbot.graph.checkpoint as checkpoint
|
||||||
|
mock_get_pool.return_value = mock_pool
|
||||||
|
|
||||||
|
# Mock AsyncPostgresSaver class
|
||||||
|
mock_saver_instance = AsyncMock()
|
||||||
|
mock_saver_class.return_value = mock_saver_instance
|
||||||
|
|
||||||
|
original_get_pool = checkpoint.get_pool
|
||||||
|
checkpoint.get_pool = mock_get_pool
|
||||||
|
|
||||||
|
# Patch AsyncPostgresSaver where it's imported in checkpoint.py
|
||||||
|
import langgraph.checkpoint.postgres.aio as pg_aio
|
||||||
|
original_saver = checkpoint.AsyncPostgresSaver
|
||||||
|
checkpoint.AsyncPostgresSaver = mock_saver_class
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with get_checkpointer() as checkpointer:
|
||||||
|
assert checkpointer == mock_saver_instance
|
||||||
|
# Verify setup was called
|
||||||
|
mock_saver_instance.setup.assert_called_once()
|
||||||
|
|
||||||
|
# Verify pool was opened
|
||||||
|
mock_pool.open.assert_called_once()
|
||||||
|
# Verify connection was requested
|
||||||
|
mock_pool.connection.assert_called_once()
|
||||||
|
finally:
|
||||||
|
checkpoint.get_pool = original_get_pool
|
||||||
|
checkpoint.AsyncPostgresSaver = original_saver
|
||||||
Reference in New Issue
Block a user