Files
ea-chatbot-lg/src/ea_chatbot/graph/checkpoint.py

45 lines
1.3 KiB
Python

import contextlib
from typing import AsyncGenerator
from psycopg_pool import AsyncConnectionPool
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from ea_chatbot.config import Settings
# Application settings; get_pool() reads ``history_db_url`` from here when it
# builds the connection pool.
settings = Settings()
# Module-level connection pool, created lazily by get_pool() and closed by
# close_pool(). ``None`` until get_pool() is first called.
_pool = None
def get_pool() -> AsyncConnectionPool:
    """Return the module's shared async connection pool, creating it on first use.

    The pool targets ``settings.history_db_url``, holds at most 20 connections,
    and configures each connection with ``autocommit=True`` and
    ``prepare_threshold=0``. It is constructed with ``open=False``, so it starts
    closed; callers must open it before borrowing connections
    (``get_checkpointer`` does this).
    """
    global _pool
    if _pool is not None:
        return _pool

    connection_kwargs = {"autocommit": True, "prepare_threshold": 0}
    _pool = AsyncConnectionPool(
        conninfo=settings.history_db_url,
        max_size=20,
        kwargs=connection_kwargs,
        open=False,  # defer connecting until someone explicitly opens the pool
    )
    return _pool
# Set once AsyncPostgresSaver.setup() has completed successfully in this
# process, so later checkpointers skip the redundant table/migration check.
_checkpointer_setup_done = False


@contextlib.asynccontextmanager
async def get_checkpointer() -> AsyncGenerator[AsyncPostgresSaver, None]:
    """Yield an AsyncPostgresSaver bound to a single pooled connection.

    The shared pool (from ``get_pool()``) is opened if it is still closed. One
    connection is then borrowed for the lifetime of the ``async with`` block
    and returned to the pool on exit.

    The checkpoint tables are created/migrated via ``setup()`` the first time
    a checkpointer is obtained in this process. Later calls skip that round
    trip. If ``setup()`` raises, the flag stays unset and the next call
    retries it.
    """
    global _checkpointer_setup_done
    pool = get_pool()
    # get_pool() builds the pool with open=False, so open it lazily here.
    if pool.closed:
        await pool.open()
    async with pool.connection() as conn:
        checkpointer = AsyncPostgresSaver(conn)
        if not _checkpointer_setup_done:
            # NOTE: two concurrent first callers may both run setup(), exactly
            # as the previous every-call behavior did; only later calls are
            # short-circuited.
            await checkpointer.setup()
            _checkpointer_setup_done = True
        yield checkpointer
async def close_pool() -> None:
    """Close the shared connection pool and forget it.

    The module-level reference is cleared as well as the pool being closed.
    A psycopg pool that has been opened and then closed cannot be reopened,
    so the next ``get_pool()`` call must build a fresh pool rather than hand
    back the dead one. Calling this when no pool exists, or when the pool was
    never opened, is a no-op apart from clearing the reference.
    """
    global _pool
    # Detach first so no caller of get_pool() picks up the pool while
    # close() is still awaiting.
    pool, _pool = _pool, None
    if pool is not None and not pool.closed:
        await pool.close()