import pytest
from unittest.mock import AsyncMock, MagicMock, patch

import ea_chatbot.graph.checkpoint as checkpoint
from ea_chatbot.graph.checkpoint import get_checkpointer


@pytest.mark.asyncio
async def test_get_checkpointer_initialization():
    """Entering ``get_checkpointer()`` opens the pool, borrows a connection and awaits setup.

    ``get_pool`` and ``AsyncPostgresSaver`` are replaced on the
    ``ea_chatbot.graph.checkpoint`` module object, which is where
    ``get_checkpointer`` resolves them at call time, so no real database is used.
    """
    mock_conn = AsyncMock()

    # The pool is a plain MagicMock; only ``open`` needs to be awaitable.
    mock_pool = MagicMock()
    mock_pool.closed = True  # marked closed, so get_checkpointer is expected to open it
    mock_pool.open = AsyncMock()

    # pool.connection() must return an async context manager that yields mock_conn.
    mock_cm = MagicMock()
    mock_cm.__aenter__ = AsyncMock(return_value=mock_conn)
    mock_cm.__aexit__ = AsyncMock(return_value=None)
    mock_pool.connection.return_value = mock_cm

    mock_get_pool = MagicMock(return_value=mock_pool)

    # Calling the patched saver "class" yields this instance; as an AsyncMock its
    # ``setup`` attribute is awaitable and records awaits.
    mock_saver_instance = AsyncMock()
    mock_saver_class = MagicMock(return_value=mock_saver_instance)

    # patch.multiple restores the real attributes even if an assertion fails.
    with patch.multiple(
        checkpoint,
        get_pool=mock_get_pool,
        AsyncPostgresSaver=mock_saver_class,
    ):
        async with get_checkpointer() as checkpointer:
            assert checkpointer is mock_saver_instance

            # assert_awaited_* (not assert_called_*) so that a coroutine which is
            # created but never awaited makes the test fail.
            mock_saver_instance.setup.assert_awaited_once()
            mock_pool.open.assert_awaited_once()

            # Exactly one connection should have been requested from the pool.
            mock_pool.connection.assert_called_once()