45 lines
1.3 KiB
Python
45 lines
1.3 KiB
Python
import contextlib
|
|
from typing import AsyncGenerator
|
|
from psycopg_pool import AsyncConnectionPool
|
|
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
|
from ea_chatbot.config import Settings
|
|
|
|
# Application settings, instantiated once when this module is imported.
settings = Settings()

# Shared async connection pool; stays None until get_pool() first builds it.
_pool: "AsyncConnectionPool | None" = None
|
|
|
|
def get_pool() -> AsyncConnectionPool:
    """Return the shared async connection pool, building it on first use.

    The pool is cached in the module-level ``_pool``, so every call after the
    first returns the same object. It is constructed with ``open=False``: the
    caller is responsible for opening it before borrowing connections.

    Returns:
        The cached ``AsyncConnectionPool`` (possibly not yet opened).
    """
    global _pool
    if _pool is not None:
        return _pool

    # Every pooled connection is created with autocommit on and server-side
    # prepared statements disabled (prepare_threshold=0).
    # NOTE(review): LangGraph's Postgres saver docs also mention
    # row_factory=dict_row for hand-built connections -- confirm whether the
    # installed saver version needs it.
    _pool = AsyncConnectionPool(
        conninfo=settings.history_db_url,
        max_size=20,
        kwargs={"autocommit": True, "prepare_threshold": 0},
        open=False,  # do not open on construction; callers open it explicitly
    )
    return _pool
|
|
|
|
@contextlib.asynccontextmanager
async def get_checkpointer() -> AsyncGenerator[AsyncPostgresSaver, None]:
    """Yield an ``AsyncPostgresSaver`` bound to a connection from the shared pool.

    The shared pool is opened if it is still closed. A single connection is
    borrowed for the duration of the ``async with`` block, ``setup()`` is
    awaited on the saver so its tables exist, and the connection is returned
    to the pool when the block exits.

    Yields:
        An ``AsyncPostgresSaver`` wrapping the borrowed connection.
    """
    shared_pool = get_pool()
    if shared_pool.closed:
        # get_pool() builds the pool unopened, so the first caller opens it.
        await shared_pool.open()

    async with shared_pool.connection() as connection:
        saver = AsyncPostgresSaver(connection)
        await saver.setup()  # create the checkpoint tables if missing
        yield saver
|
|
|
|
async def close_pool() -> None:
    """Close the shared connection pool and drop the cached reference.

    The module-level ``_pool`` is reset to ``None`` before closing, so the
    next ``get_pool()`` call builds a fresh pool. psycopg_pool cannot re-open
    a pool that has already been opened and closed. Keeping the closed
    instance cached would therefore make every later checkpointer request
    fail.

    Safe to call when no pool was ever created, or when it is already closed.
    """
    global _pool
    # Detach first so concurrent get_pool() callers never receive a pool that
    # is in the middle of shutting down.
    pool, _pool = _pool, None
    if pool is not None and not pool.closed:
        await pool.close()
|