import contextlib
from typing import AsyncGenerator, Optional

from psycopg_pool import AsyncConnectionPool
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver

from ea_chatbot.config import Settings

settings = Settings()

# Lazily created, module-wide pool. get_pool() creates it; close_pool() closes it
# and clears this reference so a later get_pool() can build a fresh one.
_pool: Optional[AsyncConnectionPool] = None


def get_pool() -> AsyncConnectionPool:
    """Get or create the async connection pool.

    The pool is built from ``settings.history_db_url`` with at most 20
    connections. It is constructed with ``open=False``, so the returned pool
    is not yet opened. Callers must ``await pool.open()`` before use
    (``get_checkpointer`` does this).

    Returns:
        The shared ``AsyncConnectionPool`` instance.
    """
    global _pool
    if _pool is None:
        _pool = AsyncConnectionPool(
            conninfo=settings.history_db_url,
            max_size=20,
            # Passed through to every connection the pool creates.
            # NOTE(review): these appear to follow the connection settings used
            # in LangGraph's Postgres checkpointer examples — confirm before changing.
            kwargs={"autocommit": True, "prepare_threshold": 0},
            open=False,  # Don't open automatically on init
        )
    return _pool


@contextlib.asynccontextmanager
async def get_checkpointer() -> AsyncGenerator[AsyncPostgresSaver, None]:
    """Yield an ``AsyncPostgresSaver`` bound to one pooled connection.

    The shared pool is opened if it is not open yet. A single connection is
    borrowed for the duration of the ``async with`` block, and
    ``checkpointer.setup()`` is awaited so the checkpoint tables exist.
    The connection goes back to the pool when the block exits.

    Yields:
        An ``AsyncPostgresSaver`` wrapping the borrowed connection.
    """
    pool = get_pool()
    # A freshly created pool starts closed because of open=False.
    if pool.closed:
        await pool.open()

    async with pool.connection() as conn:
        checkpointer = AsyncPostgresSaver(conn)
        # Ensure the necessary tables exist
        await checkpointer.setup()
        yield checkpointer


async def close_pool() -> None:
    """Close the shared connection pool, if one exists and is open.

    The module-level reference is cleared *before* awaiting ``close()``. This
    matters because a closed psycopg_pool pool cannot be reopened: if the
    closed instance stayed cached, the next ``get_checkpointer()`` would try
    to reopen it and fail. With the reference cleared, the next
    ``get_pool()`` builds a new pool, and concurrent callers never receive
    the instance that is being closed.
    """
    global _pool
    pool, _pool = _pool, None
    if pool is not None and not pool.closed:
        await pool.close()