feat(graph): Implement PostgresSaver checkpointer integration
This commit is contained in:
44
src/ea_chatbot/graph/checkpoint.py
Normal file
44
src/ea_chatbot/graph/checkpoint.py
Normal file
@@ -0,0 +1,44 @@
|
||||
import contextlib
|
||||
from typing import AsyncGenerator
|
||||
from psycopg_pool import AsyncConnectionPool
|
||||
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||||
from ea_chatbot.config import Settings
|
||||
|
||||
settings = Settings()
|
||||
|
||||
_pool = None
|
||||
|
||||
def get_pool() -> AsyncConnectionPool:
|
||||
"""Get or create the async connection pool."""
|
||||
global _pool
|
||||
if _pool is None:
|
||||
_pool = AsyncConnectionPool(
|
||||
conninfo=settings.history_db_url,
|
||||
max_size=20,
|
||||
kwargs={"autocommit": True, "prepare_threshold": 0},
|
||||
open=False, # Don't open automatically on init
|
||||
)
|
||||
return _pool
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def get_checkpointer() -> AsyncGenerator[AsyncPostgresSaver, None]:
|
||||
"""
|
||||
Context manager to get a PostgresSaver checkpointer.
|
||||
Ensures that the checkpointer is properly initialized and the connection is managed.
|
||||
"""
|
||||
pool = get_pool()
|
||||
# Ensure pool is open
|
||||
if pool.closed:
|
||||
await pool.open()
|
||||
|
||||
async with pool.connection() as conn:
|
||||
checkpointer = AsyncPostgresSaver(conn)
|
||||
# Ensure the necessary tables exist
|
||||
await checkpointer.setup()
|
||||
yield checkpointer
|
||||
|
||||
async def close_pool():
|
||||
"""Close the connection pool."""
|
||||
global _pool
|
||||
if _pool and not _pool.closed:
|
||||
await _pool.close()
|
||||
Reference in New Issue
Block a user