diff --git a/pyproject.toml b/pyproject.toml index e9243f8..2c68ecf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,7 @@ packages = ["src/ea_chatbot"] dev = [ "pytest>=9.0.2", "pytest-cov>=7.0.0", + "pytest-asyncio>=0.23.0", "ruff>=0.9.3", "mypy>=1.14.1", ] diff --git a/src/ea_chatbot/graph/checkpoint.py b/src/ea_chatbot/graph/checkpoint.py new file mode 100644 index 0000000..e32f15c --- /dev/null +++ b/src/ea_chatbot/graph/checkpoint.py @@ -0,0 +1,44 @@ +import contextlib +from typing import AsyncGenerator +from psycopg_pool import AsyncConnectionPool +from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver +from ea_chatbot.config import Settings + +settings = Settings() + +_pool = None + +def get_pool() -> AsyncConnectionPool: + """Get or create the async connection pool.""" + global _pool + if _pool is None: + _pool = AsyncConnectionPool( + conninfo=settings.history_db_url, + max_size=20, + kwargs={"autocommit": True, "prepare_threshold": 0}, + open=False, # Don't open automatically on init + ) + return _pool + +@contextlib.asynccontextmanager +async def get_checkpointer() -> AsyncGenerator[AsyncPostgresSaver, None]: + """ + Context manager to get a PostgresSaver checkpointer. + Ensures that the checkpointer is properly initialized and the connection is managed. + """ + pool = get_pool() + # Ensure pool is open + if pool.closed: + await pool.open() + + async with pool.connection() as conn: + checkpointer = AsyncPostgresSaver(conn) + # Ensure the necessary tables exist + await checkpointer.setup() + yield checkpointer + +async def close_pool(): + """Close the connection pool.""" + global _pool + if _pool and not _pool.closed: + await _pool.close() diff --git a/tests/graph/test_checkpoint.py b/tests/graph/test_checkpoint.py new file mode 100644 index 0000000..4a23128 --- /dev/null +++ b/tests/graph/test_checkpoint.py @@ -0,0 +1,49 @@ +import pytest +from unittest.mock import AsyncMock, MagicMock +from ea_chatbot.graph.checkpoint import get_checkpointer + +@pytest.mark.asyncio +async def test_get_checkpointer_initialization(): + """Test that the checkpointer setup is called.""" + mock_conn = AsyncMock() + mock_pool = MagicMock() # Changed from AsyncMock to MagicMock + mock_pool.closed = True + mock_pool.open = AsyncMock() # Ensure open is awaitable + + # Setup mock_pool.connection() to return an async context manager + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_conn) + mock_cm.__aexit__ = AsyncMock(return_value=None) + mock_pool.connection.return_value = mock_cm + + # We need to patch the get_pool function and AsyncPostgresSaver in the module + with MagicMock() as mock_get_pool, \ + MagicMock() as mock_saver_class: + import ea_chatbot.graph.checkpoint as checkpoint + mock_get_pool.return_value = mock_pool + + # Mock AsyncPostgresSaver class + mock_saver_instance = AsyncMock() + mock_saver_class.return_value = mock_saver_instance + + original_get_pool = checkpoint.get_pool + checkpoint.get_pool = mock_get_pool + + # Patch AsyncPostgresSaver where it's imported in checkpoint.py + import langgraph.checkpoint.postgres.aio as pg_aio + original_saver = checkpoint.AsyncPostgresSaver + checkpoint.AsyncPostgresSaver = mock_saver_class + + try: + async with get_checkpointer() as checkpointer: + assert checkpointer == mock_saver_instance + # Verify setup was called + mock_saver_instance.setup.assert_called_once() + + # Verify pool was opened + mock_pool.open.assert_called_once() + # Verify connection was requested + mock_pool.connection.assert_called_once() + finally: + checkpoint.get_pool = original_get_pool + checkpoint.AsyncPostgresSaver = original_saver