Refactor: Move backend files to backend/ directory and split .gitignore
This commit is contained in:
0
backend/src/ea_chatbot/__init__.py
Normal file
0
backend/src/ea_chatbot/__init__.py
Normal file
0
backend/src/ea_chatbot/api/__init__.py
Normal file
0
backend/src/ea_chatbot/api/__init__.py
Normal file
46
backend/src/ea_chatbot/api/dependencies.py
Normal file
46
backend/src/ea_chatbot/api/dependencies.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import os
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from ea_chatbot.config import Settings
|
||||
from ea_chatbot.history.manager import HistoryManager
|
||||
from ea_chatbot.auth import OIDCClient
|
||||
from ea_chatbot.api.utils import decode_access_token
|
||||
from ea_chatbot.history.models import User
|
||||
|
||||
# Application settings loaded once at import time and shared across routers.
settings = Settings()

# Shared singleton instances used by the API dependency functions.
history_manager = HistoryManager(settings.history_db_url)

# The OIDC client is only constructed when all three settings are present;
# otherwise OIDC endpoints report that SSO is not configured.
oidc_client = None
_oidc_configured = (
    settings.oidc_client_id
    and settings.oidc_client_secret
    and settings.oidc_server_metadata_url
)
if _oidc_configured:
    oidc_client = OIDCClient(
        client_id=settings.oidc_client_id,
        client_secret=settings.oidc_client_secret,
        server_metadata_url=settings.oidc_server_metadata_url,
        redirect_uri=os.getenv("OIDC_REDIRECT_URI", "http://localhost:3000/auth/callback"),
    )

# FastAPI helper that extracts the bearer token from the Authorization header.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/login")
|
||||
|
||||
async def get_current_user(token: str = Depends(oauth2_scheme)) -> User:
    """Resolve the authenticated user from the bearer JWT.

    Raises HTTP 401 when the token is invalid, carries no subject claim,
    or references a user that no longer exists.
    """
    unauthorized = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    claims = decode_access_token(token)
    if claims is None:
        raise unauthorized

    subject: str | None = claims.get("sub")
    if subject is None:
        raise unauthorized

    account = history_manager.get_user_by_id(subject)
    if account is None:
        raise unauthorized

    return account
|
||||
34
backend/src/ea_chatbot/api/main.py
Normal file
34
backend/src/ea_chatbot/api/main.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from ea_chatbot.api.routers import auth, history, artifacts, agent
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Read environment variables from a local .env file, if present.
load_dotenv()

# FastAPI application instance for the chatbot backend.
app = FastAPI(
    title="Election Analytics Chatbot API",
    description="Backend API for the LangGraph-based Election Analytics Chatbot",
    version="0.1.0"
)

# Allow cross-origin requests from any host; tighten before deploying.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Adjust for production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Mount the feature routers.
for _feature in (auth, history, artifacts, agent):
    app.include_router(_feature.router)


@app.get("/health")
async def health_check():
    """Liveness probe; returns a static OK payload."""
    return {"status": "ok"}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
0
backend/src/ea_chatbot/api/routers/__init__.py
Normal file
0
backend/src/ea_chatbot/api/routers/__init__.py
Normal file
162
backend/src/ea_chatbot/api/routers/agent.py
Normal file
162
backend/src/ea_chatbot/api/routers/agent.py
Normal file
@@ -0,0 +1,162 @@
|
||||
import json
|
||||
import asyncio
|
||||
from typing import AsyncGenerator, Optional, List
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.responses import StreamingResponse
|
||||
from ea_chatbot.api.dependencies import get_current_user, history_manager
|
||||
from ea_chatbot.api.utils import convert_to_json_compatible
|
||||
from ea_chatbot.graph.workflow import app
|
||||
from ea_chatbot.graph.checkpoint import get_checkpointer
|
||||
from ea_chatbot.history.models import User as UserDB, Conversation
|
||||
from ea_chatbot.api.schemas import ChatRequest
|
||||
import io
|
||||
import base64
|
||||
from langchain_core.messages import BaseMessage
|
||||
|
||||
# Router exposing the agent streaming endpoint under the /chat prefix.
router = APIRouter(prefix="/chat", tags=["agent"])
|
||||
|
||||
async def stream_agent_events(
    message: str,
    thread_id: str,
    user_id: str,
    summary: str
) -> AsyncGenerator[str, None]:
    """
    Run the LangGraph agent for one user message and yield SSE-formatted events.

    Buffers streamed assistant text and generated plots so the final reply
    (and its images) can be persisted to the history database, then emits a
    terminal "done" event. Failures are reported as an "error" event and are
    also stored as an assistant message.
    """
    initial_state = {
        "messages": [],
        "question": message,
        "summary": summary,
        "analysis": None,
        "next_action": "",
        "plan": None,
        "code": None,
        "code_output": None,
        "error": None,
        "plots": [],
        "dfs": {}
    }

    config = {"configurable": {"thread_id": thread_id}}

    text_buffer: List[str] = []
    plot_buffer: List[bytes] = []
    reply_text: str = ""
    updated_summary: str = ""

    try:
        async with get_checkpointer() as checkpointer:
            async for event in app.astream_events(
                initial_state,
                config,
                version="v2",
                checkpointer=checkpointer
            ):
                kind = event.get("event")
                name = event.get("name")
                data = event.get("data", {})

                # Event envelope forwarded to the frontend.
                output_event = {"type": kind, "name": name, "data": data}

                # Buffer streamed assistant text (summarizer/researcher nodes).
                if kind == "on_chat_model_stream" and name in ["summarizer", "researcher"]:
                    chunk_data = convert_to_json_compatible(data.get("chunk", ""))
                    if isinstance(chunk_data, dict) and "content" in chunk_data:
                        text_buffer.append(str(chunk_data["content"]))
                    else:
                        text_buffer.append(str(chunk_data))

                # Capture figures from the executor node; attach base64 PNGs
                # to the outgoing event so the frontend can render them.
                if kind == "on_chain_end" and name == "executor":
                    output = data.get("output", {})
                    if isinstance(output, dict) and "plots" in output:
                        encoded_plots = []
                        for fig in output["plots"]:
                            buf = io.BytesIO()
                            fig.savefig(buf, format="png")
                            png = buf.getvalue()
                            plot_buffer.append(png)
                            encoded_plots.append(base64.b64encode(png).decode('utf-8'))
                        output_event["data"]["encoded_plots"] = encoded_plots

                # Terminal nodes carry the final assistant reply; extract
                # clean text via the shared conversion utility.
                if kind == "on_chain_end" and name in ["summarizer", "researcher", "clarification"]:
                    output = data.get("output", {})
                    if isinstance(output, dict) and "messages" in output:
                        msg_data = convert_to_json_compatible(output["messages"][-1])
                        if isinstance(msg_data, dict) and "content" in msg_data:
                            reply_text = msg_data["content"]
                        else:
                            reply_text = str(msg_data)

                # Pick up an updated rolling conversation summary.
                if kind == "on_chain_end" and name == "summarize_conversation":
                    output = data.get("output", {})
                    if isinstance(output, dict) and "summary" in output:
                        updated_summary = output["summary"]

                # Convert before serializing to avoid JSON encoding errors.
                yield f"data: {json.dumps(convert_to_json_compatible(output_event))}\n\n"

        # No terminal node produced a reply: fall back to the buffered chunks.
        if not reply_text and text_buffer:
            reply_text = "".join(text_buffer)

        # Persist the assistant turn (with any plots) to the database.
        if reply_text:
            history_manager.add_message(thread_id, "assistant", reply_text, plots=plot_buffer)

        # Persist the refreshed summary, when one was produced.
        if updated_summary:
            history_manager.update_conversation_summary(thread_id, updated_summary)

        yield "data: {\"type\": \"done\"}\n\n"

    except Exception as e:
        error_msg = f"Agent execution failed: {str(e)}"
        history_manager.add_message(thread_id, "assistant", error_msg)
        yield f"data: {json.dumps({'type': 'error', 'message': error_msg})}\n\n"
|
||||
|
||||
@router.post("/stream")
async def chat_stream(
    request: ChatRequest,
    current_user: UserDB = Depends(get_current_user)
):
    """
    Stream agent execution events for one user message via SSE.

    Verifies that the target conversation exists and belongs to the caller,
    records the user message, then streams the agent's events.
    """
    with history_manager.get_session() as session:
        conversation = session.get(Conversation, request.thread_id)
        if not conversation:
            raise HTTPException(status_code=404, detail="Conversation not found")
        if conversation.user_id != current_user.id:
            raise HTTPException(status_code=403, detail="Not authorized to access this conversation")

    # Persist the user's message before the agent starts responding.
    history_manager.add_message(request.thread_id, "user", request.message)

    return StreamingResponse(
        stream_agent_events(
            request.message,
            request.thread_id,
            current_user.id,
            request.summary or ""
        ),
        media_type="text/event-stream"
    )
|
||||
45
backend/src/ea_chatbot/api/routers/artifacts.py
Normal file
45
backend/src/ea_chatbot/api/routers/artifacts.py
Normal file
@@ -0,0 +1,45 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, Response, status
|
||||
from ea_chatbot.api.dependencies import get_current_user, history_manager
|
||||
from ea_chatbot.history.models import Plot, Message, User as UserDB
|
||||
import io
|
||||
|
||||
# Router serving binary artifacts (plots, structured data) under /artifacts.
router = APIRouter(prefix="/artifacts", tags=["artifacts"])
|
||||
|
||||
@router.get("/plots/{plot_id}")
async def get_plot(
    plot_id: str,
    current_user: UserDB = Depends(get_current_user)
):
    """Return a stored plot as a PNG image, enforcing ownership."""
    with history_manager.get_session() as session:
        plot = session.get(Plot, plot_id)
        if not plot:
            raise HTTPException(status_code=404, detail="Plot not found")

        # Ownership chain: plot -> message -> conversation -> user.
        message = session.get(Message, plot.message_id)
        if not message:
            raise HTTPException(status_code=404, detail="Associated message not found")

        if message.conversation.user_id != current_user.id:
            raise HTTPException(status_code=403, detail="Not authorized to access this artifact")

        return Response(content=plot.image_data, media_type="image/png")
|
||||
|
||||
@router.get("/data/{message_id}")
async def get_message_data(
    message_id: str,
    current_user: UserDB = Depends(get_current_user)
):
    """
    Return structured dataframe data attached to a message.

    DataFrames are not persisted yet, so this endpoint always responds 404.
    """
    # TODO: Implement persistence for DataFrames in Phase 4 or a future track
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail="Structured data not found for this message"
    )
|
||||
86
backend/src/ea_chatbot/api/routers/auth.py
Normal file
86
backend/src/ea_chatbot/api/routers/auth.py
Normal file
@@ -0,0 +1,86 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from ea_chatbot.api.utils import create_access_token
|
||||
from ea_chatbot.api.dependencies import history_manager, oidc_client, get_current_user
|
||||
from ea_chatbot.history.models import User as UserDB
|
||||
from ea_chatbot.api.schemas import Token, UserCreate, UserResponse
|
||||
|
||||
# Router for registration, login, and OIDC flows under the /auth prefix.
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
|
||||
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
async def register(user_in: UserCreate):
    """Create a local account, rejecting duplicate email addresses."""
    if history_manager.get_user(user_in.email):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="User already exists"
        )

    new_user = history_manager.create_user(
        email=user_in.email,
        password=user_in.password,
        display_name=user_in.display_name
    )
    # The `username` column holds the email for local accounts (mirrors /auth/me).
    return {
        "id": str(new_user.id),
        "email": new_user.username,
        "display_name": new_user.display_name
    }
|
||||
|
||||
@router.post("/login", response_model=Token)
async def login(form_data: OAuth2PasswordRequestForm = Depends()):
    """Authenticate with email/password and return a bearer JWT."""
    user = history_manager.authenticate_user(form_data.username, form_data.password)
    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect email or password",
            headers={"WWW-Authenticate": "Bearer"},
        )

    token = create_access_token(data={"sub": str(user.id)})
    return {"access_token": token, "token_type": "bearer"}
|
||||
|
||||
@router.get("/oidc/login")
async def oidc_login():
    """Return the identity provider's authorization URL for OIDC login."""
    if not oidc_client:
        # NOTE(review): 510 Not Extended is an unusual code for "not
        # configured" — 501 or 503 may fit better; confirm before changing.
        raise HTTPException(
            status_code=status.HTTP_510_NOT_EXTENDED,
            detail="OIDC is not configured"
        )

    return {"url": oidc_client.get_login_url()}
|
||||
|
||||
@router.get("/oidc/callback", response_model=Token)
async def oidc_callback(code: str):
    """
    Exchange an OIDC authorization code for a local JWT.

    Returns 510 when OIDC is not configured, 400 when the provider does not
    supply an email, and 401 when the code exchange itself fails.
    """
    if not oidc_client:
        raise HTTPException(status_code=status.HTTP_510_NOT_EXTENDED, detail="OIDC not configured")

    try:
        token = oidc_client.exchange_code_for_token(code)
        user_info = oidc_client.get_user_info(token)
        email = user_info.get("email")
        name = user_info.get("name") or user_info.get("preferred_username")

        if not email:
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Email not provided by OIDC")

        # Create or update the local account mirroring the OIDC identity.
        user = history_manager.sync_user_from_oidc(email=email, display_name=name)

        access_token = create_access_token(data={"sub": str(user.id)})
        return {"access_token": access_token, "token_type": "bearer"}
    except HTTPException:
        # Bug fix: the 400 raised above was previously swallowed by the
        # broad except below and rewrapped as a misleading 401.
        raise
    except Exception as e:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=f"OIDC authentication failed: {str(e)}")
|
||||
|
||||
@router.get("/me", response_model=UserResponse)
async def get_me(current_user: UserDB = Depends(get_current_user)):
    """Return the authenticated user's profile."""
    return {
        "id": str(current_user.id),
        "email": current_user.username,  # username column stores the email
        "display_name": current_user.display_name
    }
|
||||
99
backend/src/ea_chatbot/api/routers/history.py
Normal file
99
backend/src/ea_chatbot/api/routers/history.py
Normal file
@@ -0,0 +1,99 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Response
|
||||
from typing import List, Optional
|
||||
from ea_chatbot.api.dependencies import get_current_user, history_manager, settings
|
||||
from ea_chatbot.history.models import User as UserDB
|
||||
from ea_chatbot.api.schemas import ConversationResponse, MessageResponse, ConversationUpdate, ConversationCreate
|
||||
|
||||
# Router for conversation CRUD under the /conversations prefix.
router = APIRouter(prefix="/conversations", tags=["history"])
|
||||
|
||||
@router.post("", response_model=ConversationResponse, status_code=status.HTTP_201_CREATED)
async def create_conversation(
    conv_in: ConversationCreate,
    current_user: UserDB = Depends(get_current_user)
):
    """Create a conversation for the caller; data_state falls back to settings."""
    conversation = history_manager.create_conversation(
        user_id=current_user.id,
        data_state=conv_in.data_state or settings.data_state,
        name=conv_in.name,
        summary=conv_in.summary
    )
    return {
        "id": str(conversation.id),
        "name": conversation.name,
        "summary": conversation.summary,
        "created_at": conversation.created_at,
        "data_state": conversation.data_state
    }
|
||||
|
||||
@router.get("", response_model=List[ConversationResponse])
async def list_conversations(
    current_user: UserDB = Depends(get_current_user),
    data_state: Optional[str] = None
):
    """List the caller's conversations for a data_state (settings default)."""
    effective_state = data_state or settings.data_state
    return [
        {
            "id": str(conv.id),
            "name": conv.name,
            "summary": conv.summary,
            "created_at": conv.created_at,
            "data_state": conv.data_state
        }
        for conv in history_manager.get_conversations(current_user.id, effective_state)
    ]
|
||||
|
||||
@router.get("/{conversation_id}/messages", response_model=List[MessageResponse])
async def get_conversation_messages(
    conversation_id: str,
    current_user: UserDB = Depends(get_current_user)
):
    """Return all messages for a conversation."""
    # TODO: Verify that the conversation belongs to the user
    return [
        {
            "id": str(msg.id),
            "role": msg.role,
            "content": msg.content,
            "created_at": msg.created_at
        }
        for msg in history_manager.get_messages(conversation_id)
    ]
|
||||
|
||||
@router.patch("/{conversation_id}", response_model=ConversationResponse)
async def update_conversation(
    conversation_id: str,
    update: ConversationUpdate,
    current_user: UserDB = Depends(get_current_user)
):
    """Rename a conversation and/or replace its rolling summary."""
    updated = None
    if update.name:
        updated = history_manager.rename_conversation(conversation_id, update.name)
    if update.summary:
        updated = history_manager.update_conversation_summary(conversation_id, update.summary)

    # Neither helper found the conversation (or no fields were provided).
    if not updated:
        raise HTTPException(status_code=404, detail="Conversation not found")

    return {
        "id": str(updated.id),
        "name": updated.name,
        "summary": updated.summary,
        "created_at": updated.created_at,
        "data_state": updated.data_state
    }
|
||||
|
||||
@router.delete("/{conversation_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_conversation(
    conversation_id: str,
    current_user: UserDB = Depends(get_current_user)
):
    """Delete a conversation; responds 404 when it does not exist."""
    if not history_manager.delete_conversation(conversation_id):
        raise HTTPException(status_code=404, detail="Conversation not found")
    return Response(status_code=status.HTTP_204_NO_CONTENT)
|
||||
51
backend/src/ea_chatbot/api/schemas.py
Normal file
51
backend/src/ea_chatbot/api/schemas.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from pydantic import BaseModel, EmailStr
|
||||
from typing import List, Optional
|
||||
from datetime import datetime
|
||||
|
||||
# --- Auth Schemas ---

class Token(BaseModel):
    """Bearer token response returned by the login endpoints."""
    access_token: str
    token_type: str


class UserCreate(BaseModel):
    """Payload for registering a local account."""
    email: EmailStr
    password: str
    display_name: Optional[str] = None


class UserResponse(BaseModel):
    """Public view of a user profile."""
    id: str
    email: str
    display_name: str


# --- History Schemas ---

class ConversationResponse(BaseModel):
    """Conversation metadata returned by the history endpoints."""
    id: str
    name: str
    summary: Optional[str] = None
    created_at: datetime
    data_state: str


class ConversationCreate(BaseModel):
    """Payload for creating a conversation."""
    name: str
    data_state: Optional[str] = None
    summary: Optional[str] = None


class MessageResponse(BaseModel):
    """A single chat message."""
    id: str
    role: str
    content: str
    created_at: datetime
    # Plots are fetched separately via artifact endpoints


class ConversationUpdate(BaseModel):
    """Partial update: rename and/or replace the summary."""
    name: Optional[str] = None
    summary: Optional[str] = None


# --- Agent Schemas ---

class ChatRequest(BaseModel):
    """Request body for the streaming chat endpoint."""
    message: str
    thread_id: str  # Maps to conversation_id
    summary: Optional[str] = ""
||||
94
backend/src/ea_chatbot/api/utils.py
Normal file
94
backend/src/ea_chatbot/api/utils.py
Normal file
@@ -0,0 +1,94 @@
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional, Union, Any, List, Dict
|
||||
from jose import JWTError, jwt
|
||||
from pydantic import BaseModel
|
||||
from langchain_core.messages import BaseMessage
|
||||
from ea_chatbot.config import Settings
|
||||
|
||||
# Module-level settings instance shared by the JWT helpers below.
settings = Settings()
|
||||
|
||||
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """
    Create a signed JWT access token.

    Args:
        data: The claims to embed in the token.
        expires_delta: Optional lifetime; defaults to the configured
            access_token_expire_minutes.

    Returns:
        str: The encoded JWT.
    """
    issued_at = datetime.now(timezone.utc)
    lifetime = expires_delta if expires_delta else timedelta(minutes=settings.access_token_expire_minutes)

    claims = dict(data)
    claims.update({
        "exp": issued_at + lifetime,
        "iat": issued_at,
        "iss": "ea-chatbot-api"
    })
    return jwt.encode(claims, settings.secret_key, algorithm=settings.algorithm)
|
||||
|
||||
def decode_access_token(token: str) -> Optional[dict]:
    """
    Validate and decode a JWT access token.

    Args:
        token: The encoded JWT.

    Returns:
        Optional[dict]: The claims payload, or None when the token is
        invalid or expired.
    """
    try:
        return jwt.decode(token, settings.secret_key, algorithms=[settings.algorithm])
    except JWTError:
        return None
|
||||
|
||||
def convert_to_json_compatible(obj: Any) -> Any:
    """Recursively convert LangChain objects, Pydantic models, datetimes, and
    other rich values into JSON-serializable structures.

    Containers are converted element-wise; LangChain messages become dicts
    with their text content flattened; Pydantic models are dumped; anything
    unrecognized is returned unchanged or stringified as a last resort.
    """
    if isinstance(obj, list):
        return [convert_to_json_compatible(item) for item in obj]
    elif isinstance(obj, dict):
        return {k: convert_to_json_compatible(v) for k, v in obj.items()}
    elif isinstance(obj, BaseMessage):
        # Content may be a list of blocks (e.g. from Gemini or OpenAI tools);
        # flatten the text blocks into a single string.
        content = obj.content
        if isinstance(content, list):
            text_parts = []
            for block in content:
                if isinstance(block, str):
                    text_parts.append(block)
                elif isinstance(block, dict):
                    if block.get("type") == "text":
                        text_parts.append(block.get("text", ""))
                # Other block types (tool calls, images) are ignored here.
            content = "".join(text_parts)

        # Prefer the .text property when it is a non-empty string.
        if hasattr(obj, "text") and isinstance(obj.text, str) and obj.text:
            content = obj.text

        return {"type": obj.type, "content": content, **convert_to_json_compatible(obj.additional_kwargs)}
    elif isinstance(obj, BaseModel):
        return convert_to_json_compatible(obj.model_dump())
    elif hasattr(obj, "model_dump"):  # For Pydantic v2 if not caught by BaseModel
        try:
            return convert_to_json_compatible(obj.model_dump())
        except Exception:
            return str(obj)
    elif hasattr(obj, "dict"):  # Fallback for Pydantic v1 or other objects
        try:
            return convert_to_json_compatible(obj.dict())
        except Exception:
            return str(obj)
    elif hasattr(obj, "content"):
        return str(obj.content)
    elif isinstance(obj, datetime):
        return obj.isoformat()
    elif isinstance(obj, timezone):
        # Bug fix: datetime.timezone has no isoformat(); the original
        # isinstance(obj, (datetime, timezone)) branch raised AttributeError
        # for timezone instances. Represent the tz by its name instead.
        return str(obj)
    return obj
|
||||
451
backend/src/ea_chatbot/app.py
Normal file
451
backend/src/ea_chatbot/app.py
Normal file
@@ -0,0 +1,451 @@
|
||||
import streamlit as st
|
||||
import asyncio
|
||||
import os
|
||||
import io
|
||||
from dotenv import load_dotenv
|
||||
from ea_chatbot.graph.workflow import app
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
from ea_chatbot.utils.logging import get_logger
|
||||
from ea_chatbot.utils.helpers import merge_agent_state
|
||||
from ea_chatbot.config import Settings
|
||||
from ea_chatbot.history.manager import HistoryManager
|
||||
from ea_chatbot.auth import OIDCClient, AuthType, get_user_auth_type
|
||||
|
||||
# Load environment variables
load_dotenv()

# Shared configuration and history store for the Streamlit UI.
settings = Settings()
history_manager = HistoryManager(settings.history_db_url)

# Build the OIDC client only when SSO is fully configured.
oidc_client = None
if settings.oidc_client_id and settings.oidc_client_secret and settings.oidc_server_metadata_url:
    oidc_client = OIDCClient(
        client_id=settings.oidc_client_id,
        client_secret=settings.oidc_client_secret,
        server_metadata_url=settings.oidc_server_metadata_url,
        # Redirect back to the same page
        redirect_uri=os.getenv("OIDC_REDIRECT_URI", "http://localhost:8501")
    )

# Structured JSON-lines logger for the app.
logger = get_logger(level=settings.log_level, log_file="logs/app.jsonl")
|
||||
|
||||
# --- Authentication Helpers ---
|
||||
|
||||
def login_user(user):
    """Store the authenticated user, reset chat state, and rerun the UI."""
    st.session_state.user = user
    st.session_state.current_conversation_id = None
    st.session_state.messages = []
    st.session_state.summary = ""
    st.rerun()
|
||||
|
||||
def logout_user():
    """Wipe the whole session state and rerun, returning to the login screen."""
    for stale_key in list(st.session_state.keys()):
        del st.session_state[stale_key]
    st.rerun()
|
||||
|
||||
def load_conversation(conv_id):
    """Load a stored conversation (messages, plots, summary) into session state."""
    # Convert DB message models into the dicts the chat UI renders.
    formatted = [
        {
            "role": record.role,
            "content": record.content,
            "plots": [p.image_data for p in record.plots]
        }
        for record in history_manager.get_messages(conv_id)
    ]

    st.session_state.messages = formatted
    st.session_state.current_conversation_id = conv_id

    # Pull the rolling summary straight from the database.
    with history_manager.get_session() as session:
        from ea_chatbot.history.models import Conversation
        conversation = session.get(Conversation, conv_id)
        st.session_state.summary = conversation.summary if conversation else ""
    st.rerun()
|
||||
|
||||
def main():
    """Run the Streamlit app: OIDC callback handling, login flow, sidebar
    (history/settings), chat rendering, and the agent-graph chat loop.

    Streamlit re-executes this function top-to-bottom on every interaction;
    all persistent state lives in ``st.session_state`` and the history DB.
    """
    st.set_page_config(
        page_title="Election Analytics Chatbot",
        page_icon="🗳️",
        layout="wide"
    )

    # Check for OIDC Callback
    if "code" in st.query_params and oidc_client:
        code = st.query_params["code"]
        try:
            token = oidc_client.exchange_code_for_token(code)
            user_info = oidc_client.get_user_info(token)
            email = user_info.get("email")
            name = user_info.get("name") or user_info.get("preferred_username")

            if email:
                user = history_manager.sync_user_from_oidc(email=email, display_name=name)
                # Clear query params
                st.query_params.clear()
                # NOTE(review): login_user() calls st.rerun(), which works by
                # raising an exception — the broad `except Exception` below may
                # intercept it and show a spurious "OIDC Login failed" error.
                # Confirm against Streamlit's RerunException hierarchy.
                login_user(user)
        except Exception as e:
            st.error(f"OIDC Login failed: {str(e)}")

    # Display Login Screen if not authenticated
    if "user" not in st.session_state:
        st.title("🗳️ Election Analytics Chatbot")

        # Initialize Login State
        # login_step drives a small state machine:
        # "email" -> "login_password" | "register_details" | "oidc_login"
        if "login_step" not in st.session_state:
            st.session_state.login_step = "email"
        if "login_email" not in st.session_state:
            st.session_state.login_email = ""

        col1, col2 = st.columns([1, 1])

        with col1:
            st.header("Login")

            # Step 1: Identification
            if st.session_state.login_step == "email":
                st.write("Please enter your email to begin:")
                with st.form("email_form"):
                    email_input = st.text_input("Email", value=st.session_state.login_email)
                    submitted = st.form_submit_button("Next")

                    if submitted:
                        if not email_input.strip():
                            st.error("Email cannot be empty.")
                        else:
                            st.session_state.login_email = email_input.strip()
                            # Route to the right flow based on what the DB
                            # knows about this email (local / SSO / unknown).
                            auth_type = get_user_auth_type(st.session_state.login_email, history_manager)

                            if auth_type == AuthType.LOCAL:
                                st.session_state.login_step = "login_password"
                            elif auth_type == AuthType.OIDC:
                                st.session_state.login_step = "oidc_login"
                            else: # AuthType.NEW
                                st.session_state.login_step = "register_details"
                            st.rerun()

            # Step 2a: Local Login
            elif st.session_state.login_step == "login_password":
                st.info(f"Welcome back, **{st.session_state.login_email}**!")
                with st.form("password_form"):
                    password = st.text_input("Password", type="password")

                    col_login, col_back = st.columns([1, 1])
                    submitted = col_login.form_submit_button("Login")
                    back = col_back.form_submit_button("Back")

                    if back:
                        st.session_state.login_step = "email"
                        st.rerun()

                    if submitted:
                        user = history_manager.authenticate_user(st.session_state.login_email, password)
                        if user:
                            login_user(user)
                        else:
                            st.error("Invalid email or password")

            # Step 2b: Registration
            elif st.session_state.login_step == "register_details":
                st.info(f"Create an account for **{st.session_state.login_email}**")
                with st.form("register_form"):
                    reg_name = st.text_input("Display Name")
                    reg_password = st.text_input("Password", type="password")

                    col_reg, col_back = st.columns([1, 1])
                    submitted = col_reg.form_submit_button("Register & Login")
                    back = col_back.form_submit_button("Back")

                    if back:
                        st.session_state.login_step = "email"
                        st.rerun()

                    if submitted:
                        if not reg_password:
                            st.error("Password is required for registration.")
                        else:
                            user = history_manager.create_user(st.session_state.login_email, reg_password, reg_name)
                            st.success("Registered! Logging in...")
                            login_user(user)

            # Step 2c: OIDC Redirection
            elif st.session_state.login_step == "oidc_login":
                st.info(f"**{st.session_state.login_email}** is configured for Single Sign-On (SSO).")

                col_sso, col_back = st.columns([1, 1])

                with col_sso:
                    if oidc_client:
                        login_url = oidc_client.get_login_url()
                        st.link_button("Login with SSO", login_url, type="primary", use_container_width=True)
                    else:
                        st.error("OIDC is not configured.")

                with col_back:
                    if st.button("Back", use_container_width=True):
                        st.session_state.login_step = "email"
                        st.rerun()

        with col2:
            if oidc_client:
                st.header("Single Sign-On")
                st.write("Login with your organizational account.")
                # NOTE(review): the link_button only appears for the single
                # rerun after the button click (nested-widget pattern), and the
                # "YXXU" label looks like a placeholder — confirm intent.
                if st.button("Login with SSO"):
                    login_url = oidc_client.get_login_url()
                    st.link_button("Go to **YXXU**", login_url, type="primary")
            else:
                st.info("SSO is not configured.")

        # Halt here: nothing below renders until the user is authenticated.
        st.stop()

    # --- Main App (Authenticated) ---

    user = st.session_state.user

    # Sidebar configuration
    with st.sidebar:
        st.title(f"Hi, {user.display_name or user.username}!")

        if st.button("Logout"):
            logout_user()

        st.divider()

        st.header("History")
        if st.button("➕ New Chat", use_container_width=True):
            st.session_state.messages = []
            st.session_state.summary = ""
            st.session_state.current_conversation_id = None
            st.rerun()

        # List conversations for the current user and data state
        conversations = history_manager.get_conversations(user.id, settings.data_state)

        for conv in conversations:
            # Three-column row: open | rename | delete.
            col_c, col_r, col_d = st.columns([0.7, 0.15, 0.15])

            is_current = st.session_state.get("current_conversation_id") == conv.id
            label = f"💬 {conv.name}" if not is_current else f"👉 {conv.name}"

            if col_c.button(label, key=f"conv_{conv.id}", use_container_width=True):
                load_conversation(conv.id)

            if col_r.button("✏️", key=f"ren_{conv.id}"):
                st.session_state.renaming_id = conv.id

            if col_d.button("🗑️", key=f"del_{conv.id}"):
                if history_manager.delete_conversation(conv.id):
                    # If the deleted conversation was open, clear the chat view.
                    if is_current:
                        st.session_state.current_conversation_id = None
                        st.session_state.messages = []
                    st.rerun()

        # Rename dialog
        if st.session_state.get("renaming_id"):
            rid = st.session_state.renaming_id
            with st.form("rename_form"):
                new_name = st.text_input("New Name")
                if st.form_submit_button("Save"):
                    history_manager.rename_conversation(rid, new_name)
                    del st.session_state.renaming_id
                    st.rerun()
                if st.form_submit_button("Cancel"):
                    del st.session_state.renaming_id
                    st.rerun()

        st.divider()
        st.header("Settings")
        # Check for DEV_MODE env var (defaults to False)
        default_dev_mode = os.getenv("DEV_MODE", "false").lower() == "true"
        # dev_mode gates all plan/code/raw-reasoning expanders below.
        dev_mode = st.checkbox("Dev Mode", value=default_dev_mode, help="Enable to see code generation and raw reasoning steps.")

    st.title("🗳️ Election Analytics Chatbot")

    # Initialize chat history state
    if "messages" not in st.session_state:
        st.session_state.messages = []
    if "summary" not in st.session_state:
        st.session_state.summary = ""

    # Display chat messages from history on app rerun
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            if message.get("plan") and dev_mode:
                with st.expander("Reasoning Plan"):
                    st.code(message["plan"], language="yaml")
            if message.get("code") and dev_mode:
                with st.expander("Generated Code"):
                    st.code(message["code"], language="python")

            st.markdown(message["content"])
            if message.get("plots"):
                for plot_data in message["plots"]:
                    # If plot_data is bytes, convert to image
                    if isinstance(plot_data, bytes):
                        st.image(plot_data)
                    else:
                        # Fallback for old session state or non-binary
                        st.pyplot(plot_data)
            if message.get("dfs"):
                for df_name, df in message["dfs"].items():
                    st.subheader(f"Data: {df_name}")
                    st.dataframe(df)

    # Accept user input
    if prompt := st.chat_input("Ask a question about election data..."):
        # Ensure we have a conversation ID
        if not st.session_state.get("current_conversation_id"):
            # Auto-create conversation
            # Conversation name = first 30 chars of the prompt.
            conv_name = (prompt[:30] + '...') if len(prompt) > 30 else prompt
            conv = history_manager.create_conversation(user.id, settings.data_state, conv_name)
            st.session_state.current_conversation_id = conv.id

        conv_id = st.session_state.current_conversation_id

        # Save user message to DB
        history_manager.add_message(conv_id, "user", prompt)

        # Add user message to session state
        st.session_state.messages.append({"role": "user", "content": prompt})

        # Display user message in chat message container
        with st.chat_message("user"):
            st.markdown(prompt)

        # Prepare graph input
        initial_state: AgentState = {
            "messages": st.session_state.messages[:-1], # Pass history (excluding the current prompt)
            "question": prompt,
            "summary": st.session_state.summary,
            "analysis": None,
            "next_action": "",
            "plan": None,
            "code": None,
            "code_output": None,
            "error": None,
            "plots": [],
            "dfs": {}
        }

        # Placeholder for graph output
        with st.chat_message("assistant"):
            final_state = initial_state
            # Real-time node updates
            with st.status("Thinking...", expanded=True) as status:
                try:
                    # Use app.stream to capture node transitions
                    # Each event maps node name -> partial state update.
                    for event in app.stream(initial_state):
                        for node_name, state_update in event.items():
                            # Error from BEFORE this node ran (for error_corrector display).
                            prev_error = final_state.get("error")
                            # Use helper to merge state correctly (appending messages/plots, updating dfs)
                            final_state = merge_agent_state(final_state, state_update)

                            if node_name == "query_analyzer":
                                analysis = state_update.get("analysis", {})
                                next_action = state_update.get("next_action", "unknown")
                                status.write(f"🔍 **Analyzed Query:**")
                                for k,v in analysis.items():
                                    status.write(f"- {k:<8}: {v}")
                                status.markdown(f"Next Step: {next_action.capitalize()}")

                            elif node_name == "planner":
                                status.write("📋 **Plan Generated**")
                                # Render artifacts
                                if state_update.get("plan") and dev_mode:
                                    with st.expander("Reasoning Plan", expanded=True):
                                        st.code(state_update["plan"], language="yaml")

                            elif node_name == "researcher":
                                status.write("🌐 **Research Complete**")
                                if state_update.get("messages") and dev_mode:
                                    for msg in state_update["messages"]:
                                        # Extract content from BaseMessage or show raw string
                                        content = getattr(msg, "text", msg.content)
                                        status.markdown(content)

                            elif node_name == "coder":
                                status.write("💻 **Code Generated**")
                                if state_update.get("code") and dev_mode:
                                    with st.expander("Generated Code"):
                                        st.code(state_update["code"], language="python")

                            elif node_name == "error_corrector":
                                status.write("🛠️ **Fixing Execution Error...**")
                                if prev_error:
                                    truncated_error = prev_error.strip()
                                    if len(truncated_error) > 180:
                                        truncated_error = truncated_error[:180] + "..."
                                    status.write(f"Previous error: {truncated_error}")
                                if state_update.get("code") and dev_mode:
                                    with st.expander("Corrected Code"):
                                        st.code(state_update["code"], language="python")

                            elif node_name == "executor":
                                if state_update.get("error"):
                                    # Dev mode shows the full error; otherwise first 100 chars.
                                    if dev_mode:
                                        status.write(f"❌ **Execution Error:** {state_update.get('error')}...")
                                    else:
                                        status.write(f"❌ **Execution Error:** {state_update.get('error')[:100]}...")
                                else:
                                    status.write("✅ **Execution Successful**")
                                    if state_update.get("plots"):
                                        status.write(f"📊 Generated {len(state_update['plots'])} plot(s)")

                            elif node_name == "summarizer":
                                status.write("📝 **Summarizing Results...**")

                    status.update(label="Complete!", state="complete", expanded=False)

                except Exception as e:
                    status.update(label="Error!", state="error")
                    st.error(f"Error during graph execution: {str(e)}")

            # Extract results
            response_text: str = ""
            if final_state.get("messages"):
                # The last message is the Assistant's response
                last_msg = final_state["messages"][-1]
                response_text = getattr(last_msg, "text", str(last_msg.content))
                st.markdown(response_text)

            # Collect plot bytes for saving to DB
            plot_bytes_list = []
            if final_state.get("plots"):
                for fig in final_state["plots"]:
                    st.pyplot(fig)
                    # Convert fig to bytes
                    buf = io.BytesIO()
                    fig.savefig(buf, format="png")
                    plot_bytes_list.append(buf.getvalue())

            if final_state.get("dfs"):
                for df_name, df in final_state["dfs"].items():
                    st.subheader(f"Data: {df_name}")
                    st.dataframe(df)

            # Save assistant message to DB
            history_manager.add_message(conv_id, "assistant", response_text, plots=plot_bytes_list)

            # Update summary in DB
            new_summary = final_state.get("summary", "")
            if new_summary:
                history_manager.update_conversation_summary(conv_id, new_summary)

            # Store assistant response in session history
            st.session_state.messages.append({
                "role": "assistant",
                "content": response_text,
                "plan": final_state.get("plan"),
                "code": final_state.get("code"),
                "plots": plot_bytes_list,
                "dfs": final_state.get("dfs")
            })
            st.session_state.summary = new_summary
|
||||
|
||||
# Script entry point.
if __name__ == "__main__":
    main()
|
||||
97
backend/src/ea_chatbot/auth.py
Normal file
97
backend/src/ea_chatbot/auth.py
Normal file
@@ -0,0 +1,97 @@
|
||||
import requests
|
||||
from enum import Enum
|
||||
from typing import Dict, Any, Optional
|
||||
from authlib.integrations.requests_client import OAuth2Session
|
||||
|
||||
class AuthType(Enum):
    """How a given email address is expected to authenticate."""
    LOCAL = "local"  # account with a locally stored password hash
    OIDC = "oidc"    # known account without a local password -> SSO
    NEW = "new"      # unknown email -> registration flow


def get_user_auth_type(email: str, history_manager: Any) -> AuthType:
    """
    Determine the authentication type for a given email.

    Args:
        email: The user's email address.
        history_manager: Instance of HistoryManager to check the DB.

    Returns:
        AuthType: LOCAL if password exists, OIDC if user exists but no password, NEW otherwise.
    """
    record = history_manager.get_user(email)

    # Unknown email -> this is a brand-new user.
    if not record:
        return AuthType.NEW

    # A stored password hash marks a locally managed account;
    # an existing account without one is assumed to use SSO.
    return AuthType.LOCAL if record.password_hash else AuthType.OIDC
|
||||
|
||||
class OIDCClient:
    """
    Client for OIDC Authentication using Authlib.
    Designed to work within a Streamlit environment.

    The provider's discovery document is fetched lazily and cached per
    instance; the authorization, token, and userinfo endpoints are all
    resolved from it.
    """
    def __init__(
        self,
        client_id: str,
        client_secret: str,
        server_metadata_url: str,
        redirect_uri: str = "http://localhost:8501"
    ):
        """Create an OIDC client.

        Args:
            client_id: OAuth2 client identifier registered with the provider.
            client_secret: Matching client secret.
            server_metadata_url: URL of the provider's OIDC discovery document
                (typically .../.well-known/openid-configuration).
            redirect_uri: Where the provider redirects the browser after login.
        """
        self.client_id = client_id
        self.client_secret = client_secret
        self.server_metadata_url = server_metadata_url
        self.redirect_uri = redirect_uri
        self.oauth_session = OAuth2Session(
            client_id=client_id,
            client_secret=client_secret,
            redirect_uri=redirect_uri,
            scope="openid email profile"
        )
        # Cached discovery document; empty until fetch_metadata() first runs.
        self.metadata: Dict[str, Any] = {}

    def fetch_metadata(self) -> Dict[str, Any]:
        """Fetch OIDC provider metadata if not already fetched.

        Raises:
            requests.HTTPError: if the metadata endpoint returns an error status.
            requests.Timeout: if the provider does not respond within 10 seconds.
        """
        if not self.metadata:
            # Fix: bound the request with a timeout so a hung provider cannot
            # stall the app indefinitely, and fail loudly on HTTP errors rather
            # than trying to JSON-decode an error page.
            response = requests.get(self.server_metadata_url, timeout=10)
            response.raise_for_status()
            self.metadata = response.json()
        return self.metadata

    def get_login_url(self) -> str:
        """Generate the authorization URL the user should be sent to."""
        metadata = self.fetch_metadata()
        authorization_endpoint = metadata.get("authorization_endpoint")
        if not authorization_endpoint:
            raise ValueError("authorization_endpoint not found in OIDC metadata")

        # NOTE(review): the CSRF `state` returned here is discarded, so the
        # callback cannot verify it — confirm whether state checking is
        # handled elsewhere before relying on this in production.
        uri, state = self.oauth_session.create_authorization_url(authorization_endpoint)
        return uri

    def exchange_code_for_token(self, code: str) -> Dict[str, Any]:
        """Exchange the authorization code for an access token.

        Raises:
            ValueError: if the discovery document has no token endpoint.
        """
        metadata = self.fetch_metadata()
        token_endpoint = metadata.get("token_endpoint")
        if not token_endpoint:
            raise ValueError("token_endpoint not found in OIDC metadata")

        token = self.oauth_session.fetch_token(
            token_endpoint,
            code=code,
            client_secret=self.client_secret
        )
        return token

    def get_user_info(self, token: Dict[str, Any]) -> Dict[str, Any]:
        """Fetch user information (claims) using the access token.

        Raises:
            ValueError: if the discovery document has no userinfo endpoint.
            requests.HTTPError: if the userinfo request fails.
        """
        metadata = self.fetch_metadata()
        userinfo_endpoint = metadata.get("userinfo_endpoint")
        if not userinfo_endpoint:
            raise ValueError("userinfo_endpoint not found in OIDC metadata")

        # Set the token on the session so it's used in the request
        self.oauth_session.token = token
        resp = self.oauth_session.get(userinfo_endpoint)
        resp.raise_for_status()
        return resp.json()
|
||||
52
backend/src/ea_chatbot/config.py
Normal file
52
backend/src/ea_chatbot/config.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from typing import Dict, Any, Optional
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
class LLMConfig(BaseModel):
    """Configuration for a specific LLM node."""
    # Provider key understood by the LLM factory (e.g. "openai").
    provider: str = "openai"
    # Model identifier passed to the provider.
    model: str = "gpt-5-mini"
    # Sampling temperature; 0.0 favours deterministic output.
    temperature: float = 0.0
    # Optional cap on generated tokens; None defers to the provider default.
    max_tokens: Optional[int] = None
    # Extra provider-specific keyword arguments, passed through verbatim.
    provider_specific: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
class Settings(BaseSettings):
    """Global application settings.

    Each field with an ``alias`` is read from the matching environment
    variable; nested LLM configs can be overridden with the ``__`` delimiter
    (e.g. QUERY_ANALYZER_LLM__MODEL).
    """

    # Local data directory used for dataset inspection/summaries.
    data_dir: str = "data"
    # Which state's election data the app serves; also partitions conversations.
    data_state: str = "new_jersey"
    log_level: str = Field(default="INFO", alias="LOG_LEVEL")

    # Voter Database configuration
    db_host: str = Field(default="localhost", alias="DB_HOST")
    db_port: int = Field(default=5432, alias="DB_PORT")
    db_name: str = Field(default="blockdata", alias="DB_NAME")
    db_user: str = Field(default="user", alias="DB_USER")
    db_pswd: str = Field(default="password", alias="DB_PSWD")
    db_table: str = Field(default="rd_gc_voters_nj", alias="DB_TABLE")

    # Application/History Database
    history_db_url: str = Field(default="postgresql://user:password@localhost:5433/ea_history", alias="HISTORY_DB_URL")

    # JWT Configuration
    # NOTE(review): the default secret is a placeholder and must be overridden
    # via SECRET_KEY in any real deployment.
    secret_key: str = Field(default="change-me-in-production", alias="SECRET_KEY")
    algorithm: str = Field(default="HS256", alias="ALGORITHM")
    access_token_expire_minutes: int = Field(default=30, alias="ACCESS_TOKEN_EXPIRE_MINUTES")

    # OIDC Configuration
    # All three must be set for the app to construct an OIDC client.
    oidc_client_id: Optional[str] = Field(default=None, alias="OIDC_CLIENT_ID")
    oidc_client_secret: Optional[str] = Field(default=None, alias="OIDC_CLIENT_SECRET")
    oidc_server_metadata_url: Optional[str] = Field(default=None, alias="OIDC_SERVER_METADATA_URL")

    # Default configurations for each node
    query_analyzer_llm: LLMConfig = Field(default_factory=lambda: LLMConfig(model="gpt-5-mini", temperature=0.0))
    planner_llm: LLMConfig = Field(default_factory=lambda: LLMConfig(model="gpt-5-mini", temperature=0.0))
    coder_llm: LLMConfig = Field(default_factory=lambda: LLMConfig(model="gpt-5-mini", temperature=0.0))
    summarizer_llm: LLMConfig = Field(default_factory=lambda: LLMConfig(model="gpt-5-mini", temperature=0.0))
    researcher_llm: LLMConfig = Field(default_factory=lambda: LLMConfig(model="gpt-5-mini", temperature=0.0))

    # Allow nested env vars like QUERY_ANALYZER_LLM__MODEL
    model_config = SettingsConfigDict(env_nested_delimiter='__', env_prefix='')
|
||||
0
backend/src/ea_chatbot/graph/__init__.py
Normal file
0
backend/src/ea_chatbot/graph/__init__.py
Normal file
44
backend/src/ea_chatbot/graph/checkpoint.py
Normal file
44
backend/src/ea_chatbot/graph/checkpoint.py
Normal file
@@ -0,0 +1,44 @@
|
||||
import contextlib
|
||||
from typing import AsyncGenerator
|
||||
from psycopg_pool import AsyncConnectionPool
|
||||
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||||
from ea_chatbot.config import Settings
|
||||
|
||||
settings = Settings()

# Module-level singleton connection pool; created lazily by get_pool().
_pool = None
|
||||
|
||||
def get_pool() -> AsyncConnectionPool:
    """Get or create the async connection pool.

    Returns:
        AsyncConnectionPool: the shared (module-level) pool for the history DB.
        The pool is created closed; callers must open it before use.
    """
    global _pool
    if _pool is None:
        # NOTE(review): this lazy init is not lock-guarded; two concurrent
        # first calls could race. Presumably fine under a single event loop —
        # confirm usage.
        _pool = AsyncConnectionPool(
            conninfo=settings.history_db_url,
            max_size=20,
            # Connection defaults applied to every pooled connection;
            # presumably what AsyncPostgresSaver expects — confirm against
            # the langgraph Postgres checkpointer docs.
            kwargs={"autocommit": True, "prepare_threshold": 0},
            open=False, # Don't open automatically on init
        )
    return _pool
|
||||
|
||||
@contextlib.asynccontextmanager
async def get_checkpointer() -> AsyncGenerator[AsyncPostgresSaver, None]:
    """
    Context manager to get a PostgresSaver checkpointer.
    Ensures that the checkpointer is properly initialized and the connection is managed.

    Yields:
        AsyncPostgresSaver: a checkpointer bound to one pooled connection; the
        connection is returned to the pool when the context exits.
    """
    pool = get_pool()
    # Ensure pool is open
    if pool.closed:
        await pool.open()

    async with pool.connection() as conn:
        checkpointer = AsyncPostgresSaver(conn)
        # Ensure the necessary tables exist
        # NOTE(review): setup() runs on every acquisition — presumably
        # idempotent, but it costs a round-trip per use; consider one-time init.
        await checkpointer.setup()
        yield checkpointer
|
||||
|
||||
async def close_pool():
    """Close the connection pool.

    Safe to call when the pool was never created or is already closed.
    """
    global _pool
    if _pool and not _pool.closed:
        await _pool.close()
|
||||
0
backend/src/ea_chatbot/graph/nodes/__init__.py
Normal file
0
backend/src/ea_chatbot/graph/nodes/__init__.py
Normal file
45
backend/src/ea_chatbot/graph/nodes/clarification.py
Normal file
45
backend/src/ea_chatbot/graph/nodes/clarification.py
Normal file
@@ -0,0 +1,45 @@
|
||||
from langchain_core.messages import AIMessage
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
from ea_chatbot.config import Settings
|
||||
from ea_chatbot.utils.llm_factory import get_llm_model
|
||||
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
|
||||
|
||||
def clarification_node(state: AgentState) -> dict:
    """Ask the user for missing information or clarifications.

    Reads ``question`` and ``analysis.ambiguities`` from the agent state and
    uses the query-analyzer LLM to compose a follow-up question.

    Args:
        state: Current agent state.

    Returns:
        dict: Partial state update — the clarification as a new entry in
        ``messages`` and ``next_action`` set to "end".

    Raises:
        Exception: re-raises any LLM invocation failure after logging it.
    """
    question = state["question"]
    analysis = state.get("analysis", {})
    ambiguities = analysis.get("ambiguities", [])

    settings = Settings()
    logger = get_logger("clarification")

    logger.info(f"Generating clarification for {len(ambiguities)} ambiguities.")

    # Reuses the query-analyzer LLM config rather than a dedicated one.
    llm = get_llm_model(
        settings.query_analyzer_llm,
        callbacks=[LangChainLoggingHandler(logger=logger)]
    )

    system_prompt = """You are a Clarification Specialist. Your role is to identify what information is missing from a user's request to perform a data analysis or research task.
Based on the analysis of the user's question, formulate a polite and concise request for the missing information."""

    prompt = f"""Original Question: {question}
Missing/Ambiguous Information: {', '.join(ambiguities) if ambiguities else 'Unknown ambiguities'}

Please ask the user for the necessary details."""

    messages = [
        ("system", system_prompt),
        ("user", prompt)
    ]

    try:
        response = llm.invoke(messages)
        logger.info("[bold green]Clarification generated.[/bold green]")
        return {
            "messages": [response],
            "next_action": "end" # To indicate we are done for now
        }
    except Exception as e:
        logger.error(f"Failed to generate clarification: {str(e)}")
        raise e
|
||||
47
backend/src/ea_chatbot/graph/nodes/coder.py
Normal file
47
backend/src/ea_chatbot/graph/nodes/coder.py
Normal file
@@ -0,0 +1,47 @@
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
from ea_chatbot.config import Settings
|
||||
from ea_chatbot.utils.llm_factory import get_llm_model
|
||||
from ea_chatbot.utils import helpers, database_inspection
|
||||
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
|
||||
from ea_chatbot.graph.prompts.coder import CODE_GENERATOR_PROMPT
|
||||
from ea_chatbot.schemas import CodeGenerationResponse
|
||||
|
||||
def coder_node(state: AgentState) -> dict:
    """Generate Python code based on the plan and data summary.

    Args:
        state: Current agent state; reads ``question``, ``plan``, and any
            previous ``code_output`` (fed back into the prompt).

    Returns:
        dict: Partial state update with the generated ``code`` and ``error``
        reset to None.

    Raises:
        Exception: re-raises any LLM invocation failure after logging it.
    """
    question = state["question"]
    plan = state.get("plan", "")
    code_output = state.get("code_output", "None")

    settings = Settings()
    logger = get_logger("coder")

    logger.info("Generating Python code...")

    llm = get_llm_model(
        settings.coder_llm,
        callbacks=[LangChainLoggingHandler(logger=logger)]
    )
    # Constrain the LLM output to the CodeGenerationResponse schema.
    structured_llm = llm.with_structured_output(CodeGenerationResponse)

    # Always provide data summary
    database_description = database_inspection.get_data_summary(data_dir=settings.data_dir) or "No data available."
    example_code = "" # Placeholder

    messages = CODE_GENERATOR_PROMPT.format_messages(
        question=question,
        plan=plan,
        database_description=database_description,
        code_exec_results=code_output,
        example_code=example_code
    )

    try:
        response = structured_llm.invoke(messages)
        logger.info("[bold green]Code generated.[/bold green]")
        return {
            "code": response.parsed_code,
            "error": None # Clear previous errors on new code generation
        }
    except Exception as e:
        logger.error(f"Failed to generate code: {str(e)}")
        raise e
|
||||
44
backend/src/ea_chatbot/graph/nodes/error_corrector.py
Normal file
44
backend/src/ea_chatbot/graph/nodes/error_corrector.py
Normal file
@@ -0,0 +1,44 @@
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
from ea_chatbot.config import Settings
|
||||
from ea_chatbot.utils.llm_factory import get_llm_model
|
||||
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
|
||||
from ea_chatbot.graph.prompts.coder import ERROR_CORRECTOR_PROMPT
|
||||
from ea_chatbot.schemas import CodeGenerationResponse
|
||||
|
||||
def error_corrector_node(state: AgentState) -> dict:
    """Fix the code based on the execution error.

    Args:
        state: Current agent state; reads the failing ``code``, the executor's
            ``error`` text, and the running ``iterations`` counter.

    Returns:
        dict: Partial state update with corrected ``code``, ``error`` cleared,
        and ``iterations`` incremented (used to bound retry loops).

    Raises:
        Exception: re-raises any LLM invocation failure after logging it.
    """
    code = state.get("code", "")
    error = state.get("error", "Unknown error")

    settings = Settings()
    logger = get_logger("error_corrector")

    # Only the first 100 chars of the error are logged to keep the log terse.
    logger.warning(f"[bold red]Execution error detected:[/bold red] {error[:100]}...")
    logger.info("Attempting to correct the code...")

    # Reuse coder LLM config or add a new one. Using coder_llm for now.
    llm = get_llm_model(
        settings.coder_llm,
        callbacks=[LangChainLoggingHandler(logger=logger)]
    )
    structured_llm = llm.with_structured_output(CodeGenerationResponse)

    messages = ERROR_CORRECTOR_PROMPT.format_messages(
        code=code,
        error=error
    )

    try:
        response = structured_llm.invoke(messages)
        logger.info("[bold green]Correction generated.[/bold green]")

        current_iterations = state.get("iterations", 0)

        return {
            "code": response.parsed_code,
            "error": None, # Clear error after fix attempt
            "iterations": current_iterations + 1
        }
    except Exception as e:
        logger.error(f"Failed to correct code: {str(e)}")
        raise e
|
||||
102
backend/src/ea_chatbot/graph/nodes/executor.py
Normal file
102
backend/src/ea_chatbot/graph/nodes/executor.py
Normal file
@@ -0,0 +1,102 @@
|
||||
import io
|
||||
import sys
|
||||
import traceback
|
||||
from contextlib import redirect_stdout
|
||||
from typing import Any, Dict, List, TYPE_CHECKING
|
||||
import pandas as pd
|
||||
from matplotlib.figure import Figure
|
||||
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
from ea_chatbot.utils.db_client import DBClient
|
||||
from ea_chatbot.utils.logging import get_logger
|
||||
from ea_chatbot.config import Settings
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ea_chatbot.types import DBSettings
|
||||
|
||||
def executor_node(state: AgentState) -> dict:
    """Execute the Python code and capture output, plots, and dataframes.

    The generated code runs with a `db` client, an empty `plots` list, and
    `pandas` injected as locals. Stdout is captured and truncated; matplotlib
    figures appended to `plots` and DataFrames referenced in the code are
    harvested from the locals afterwards.

    Args:
        state: Current agent state; reads ``code``.

    Returns:
        dict: Partial state update with ``code_output``, ``error`` (None on
        success), ``plots``, and ``dfs``.
    """
    code = state.get("code")
    logger = get_logger("executor")

    if not code:
        logger.error("No code provided to executor.")
        return {"error": "No code provided to executor."}

    logger.info("Executing Python code...")
    settings = Settings()

    db_settings: "DBSettings" = {
        "host": settings.db_host,
        "port": settings.db_port,
        "user": settings.db_user,
        "pswd": settings.db_pswd,
        "db": settings.db_name,
        "table": settings.db_table
    }

    db_client = DBClient(settings=db_settings)

    # Initialize local variables for execution
    # 'db' is the DBClient instance, 'plots' is for matplotlib figures
    local_vars = {
        'db': db_client,
        'plots': [],
        'pd': pd
    }

    stdout_buffer = io.StringIO()
    error = None
    code_output = ""
    plots = []
    dfs = {}

    try:
        with redirect_stdout(stdout_buffer):
            # Execute the code in the context of local_vars
            # SECURITY: exec() runs LLM-generated code with full process
            # privileges — any sandboxing must happen outside this function.
            # NOTE(review): with an empty globals dict, functions defined
            # inside `code` cannot see each other or the injected locals —
            # confirm generated code is always a flat script.
            exec(code, {}, local_vars)

        code_output = stdout_buffer.getvalue()

        # Limit the output length if it's too long
        if code_output.count('\n') > 32:
            code_output = '\n'.join(code_output.split('\n')[:32]) + '\n...'

        # Extract plots
        raw_plots = local_vars.get('plots', [])
        if isinstance(raw_plots, list):
            # Keep only genuine matplotlib Figures; ignore anything else appended.
            plots = [p for p in raw_plots if isinstance(p, Figure)]

        # Extract DataFrames that were likely intended for display
        # We look for DataFrames in local_vars that were mentioned in the code
        for key, value in local_vars.items():
            if isinstance(value, pd.DataFrame):
                # Heuristic: if the variable name is in the code, it might be a result DF
                if key in code:
                    dfs[key] = value

        logger.info(f"[bold green]Execution complete.[/bold green] Captured {len(plots)} plots and {len(dfs)} dataframes.")

    except Exception as e:
        # Capture the traceback
        exc_type, exc_value, tb = sys.exc_info()
        full_traceback = traceback.format_exc()

        # Filter traceback to show only the relevant part (the executed string)
        # (frames from the exec'ed code show up with filename "<string>")
        filtered_tb_lines = [line for line in full_traceback.split('\n') if '<string>' in line]
        error = '\n'.join(filtered_tb_lines)
        if error:
            error += '\n'
        error += f"{exc_type.__name__ if exc_type else 'Exception'}: {exc_value}"

        logger.error(f"Execution failed: {str(e)}")

        # If we have an error, we still might want to see partial stdout
        code_output = stdout_buffer.getvalue()

    return {
        "code_output": code_output,
        "error": error,
        "plots": plots,
        "dfs": dfs
    }
|
||||
51
backend/src/ea_chatbot/graph/nodes/planner.py
Normal file
51
backend/src/ea_chatbot/graph/nodes/planner.py
Normal file
@@ -0,0 +1,51 @@
|
||||
import yaml
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
from ea_chatbot.config import Settings
|
||||
from ea_chatbot.utils.llm_factory import get_llm_model
|
||||
from ea_chatbot.utils import helpers, database_inspection
|
||||
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
|
||||
from ea_chatbot.graph.prompts.planner import PLANNER_PROMPT
|
||||
from ea_chatbot.schemas import TaskPlanResponse
|
||||
|
||||
def planner_node(state: AgentState) -> dict:
    """Generate a structured plan based on the query analysis.

    Reads the current question, the recent history and the running
    conversation summary from the state, asks the planner LLM for a
    structured ``TaskPlanResponse``, and serialises it back to YAML for
    downstream nodes.

    Args:
        state: Current agent state; must contain ``question`` and may
            contain ``messages`` and ``summary``.

    Returns:
        Partial state update ``{"plan": <yaml string>}``.

    Raises:
        Exception: Re-raised from the LLM call on failure so the graph
            surfaces the error.
    """
    question = state["question"]
    # Keep only the last 3 turns (6 messages) to bound the prompt size.
    history = state.get("messages", [])[-6:]
    summary = state.get("summary", "")

    settings = Settings()
    logger = get_logger("planner")

    logger.info("Generating task plan...")

    llm = get_llm_model(
        settings.planner_llm,
        callbacks=[LangChainLoggingHandler(logger=logger)]
    )
    structured_llm = llm.with_structured_output(TaskPlanResponse)

    date_str = helpers.get_readable_date()

    # Always provide data summary; LLM decides relevance.
    database_description = database_inspection.get_data_summary(data_dir=settings.data_dir) or "No data available."
    # No few-shot example is injected yet; the prompt slot stays empty.
    example_plan = ""

    messages = PLANNER_PROMPT.format_messages(
        date=date_str,
        question=question,
        history=history,
        summary=summary,
        database_description=database_description,
        example_plan=example_plan
    )

    # Generate the structured plan
    try:
        response = structured_llm.invoke(messages)
        # Convert the structured response back to YAML string for the state
        plan_yaml = yaml.dump(response.model_dump(), sort_keys=False)
        logger.info("[bold green]Plan generated successfully.[/bold green]")
        return {"plan": plan_yaml}
    except Exception as e:
        logger.error(f"Failed to generate plan: {str(e)}")
        # Bare raise preserves the original traceback ("raise e" would restart it here).
        raise
|
||||
73
backend/src/ea_chatbot/graph/nodes/query_analyzer.py
Normal file
73
backend/src/ea_chatbot/graph/nodes/query_analyzer.py
Normal file
@@ -0,0 +1,73 @@
|
||||
from typing import List, Literal
|
||||
from pydantic import BaseModel, Field
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
from ea_chatbot.config import Settings
|
||||
from ea_chatbot.utils.llm_factory import get_llm_model
|
||||
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
|
||||
from ea_chatbot.graph.prompts.query_analyzer import QUERY_ANALYZER_PROMPT
|
||||
|
||||
class QueryAnalysis(BaseModel):
    """Analysis of the user's query.

    Structured-output schema filled in by the query-analyzer LLM; the
    ``Field`` descriptions double as instructions to the model.
    """
    # Entities / data points referenced by the question (or resolved from history).
    data_required: List[str] = Field(description="List of data points or entities mentioned (e.g., ['2024 results', 'Florida']).")
    # The information the user ultimately wants answered.
    unknowns: List[str] = Field(description="List of target information the user wants to know or needed for final answer (e.g., 'who won', 'total votes').")
    # Only blockers go here; solvable-with-defaults gaps must not trigger clarification.
    ambiguities: List[str] = Field(description="List of CRITICAL missing details that prevent ANY analysis. Do NOT include database names or plot types if defaults can be used.")
    # Filters/constraints, including context inherited from prior turns.
    conditions: List[str] = Field(description="List of any filters or constraints (e.g., ['year=2024', 'state=Florida']). Include context resolved from history.")
    # Routing decision consumed by the workflow router.
    next_action: Literal["plan", "clarify", "research"] = Field(description="The next action to take. 'plan' for data analysis (even with defaults), 'research' for general knowledge, or 'clarify' ONLY for critical ambiguities.")
|
||||
|
||||
def query_analyzer_node(state: AgentState) -> dict:
    """Analyze the user's question and determine the next course of action.

    Decomposes the question into data requirements, unknowns, ambiguities
    and conditions via a structured-output LLM call, and decides whether
    to plan, research, or ask for clarification.

    Args:
        state: Current agent state; must contain ``question`` and may
            contain ``messages`` and ``summary``.

    Returns:
        Partial state update with ``analysis`` (dict without the routing
        key), ``next_action`` and a reset ``iterations`` counter. On LLM
        failure, falls back to ``next_action="clarify"`` with the error
        recorded as an ambiguity.
    """
    question = state["question"]
    history = state.get("messages", [])
    summary = state.get("summary", "")

    # Keep last 3 turns (6 messages)
    history = history[-6:]

    settings = Settings()
    logger = get_logger("query_analyzer")

    logger.info(f"Analyzing question: [italic]\"{question}\"[/italic]")

    # Initialize the LLM with structured output using the factory
    # Pass logging callback to track LLM usage
    llm = get_llm_model(
        settings.query_analyzer_llm,
        callbacks=[LangChainLoggingHandler(logger=logger)]
    )
    structured_llm = llm.with_structured_output(QueryAnalysis)

    # Prepare messages using the prompt template
    messages = QUERY_ANALYZER_PROMPT.format_messages(
        question=question,
        history=history,
        summary=summary
    )

    try:
        # Invoke the structured LLM directly with the list of messages
        analysis_result = structured_llm.invoke(messages)
        # Normalize: with_structured_output may yield a dict or a model instance.
        analysis_result = QueryAnalysis.model_validate(analysis_result)

        # Store everything except the routing decision under "analysis".
        analysis_dict = analysis_result.model_dump(exclude={"next_action"})
        next_action = analysis_result.next_action

        logger.info(f"Analysis complete. Next action: [bold magenta]{next_action}[/bold magenta]")

    except Exception as e:
        logger.error(f"Error during query analysis: {str(e)}")
        # Degrade gracefully: ask the user to clarify instead of crashing the graph.
        analysis_dict = {
            "data_required": [],
            "unknowns": [],
            "ambiguities": [f"Error during analysis: {str(e)}"],
            "conditions": []
        }
        next_action = "clarify"

    return {
        "analysis": analysis_dict,
        "next_action": next_action,
        "iterations": 0
    }
|
||||
|
||||
|
||||
60
backend/src/ea_chatbot/graph/nodes/researcher.py
Normal file
60
backend/src/ea_chatbot/graph/nodes/researcher.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
from ea_chatbot.config import Settings
|
||||
from ea_chatbot.utils.llm_factory import get_llm_model
|
||||
from ea_chatbot.utils import helpers
|
||||
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
|
||||
from ea_chatbot.graph.prompts.researcher import RESEARCHER_PROMPT
|
||||
|
||||
def researcher_node(state: AgentState) -> dict:
    """Handle general research queries or web searches.

    Binds a provider-native search tool when the model supports one
    (Gemini google_search, OpenAI web_search), falling back to the base
    LLM otherwise, and returns the model's answer as a new message.

    Args:
        state: Current agent state; must contain ``question`` and may
            contain ``messages`` and ``summary``.

    Returns:
        Partial state update ``{"messages": [response]}``.

    Raises:
        Exception: Re-raised from the LLM call on failure.
    """
    question = state["question"]
    # Only the last 3 turns (6 messages) are forwarded to keep the prompt small.
    history = state.get("messages", [])[-6:]
    summary = state.get("summary", "")

    settings = Settings()
    logger = get_logger("researcher")

    logger.info(f"Researching question: [italic]\"{question}\"[/italic]")

    # Use researcher_llm from settings
    llm = get_llm_model(
        settings.researcher_llm,
        callbacks=[LangChainLoggingHandler(logger=logger)]
    )

    date_str = helpers.get_readable_date()

    messages = RESEARCHER_PROMPT.format_messages(
        date=date_str,
        question=question,
        history=history,
        summary=summary
    )

    # Provider-aware tool binding
    try:
        if isinstance(llm, ChatGoogleGenerativeAI):
            # Native Google Search for Gemini
            llm_with_tools = llm.bind_tools([{"google_search": {}}])
        elif isinstance(llm, ChatOpenAI):
            # Native Web Search for OpenAI (built-in tool)
            llm_with_tools = llm.bind_tools([{"type": "web_search"}])
        else:
            # Fallback for other providers that might not support these specific search tools
            llm_with_tools = llm
    except Exception as e:
        logger.warning(f"Failed to bind search tools: {str(e)}. Falling back to base LLM.")
        llm_with_tools = llm

    try:
        response = llm_with_tools.invoke(messages)
        logger.info("[bold green]Research complete.[/bold green]")
        return {
            "messages": [response]
        }
    except Exception as e:
        logger.error(f"Research failed: {str(e)}")
        # Bare raise keeps the original traceback intact.
        raise
|
||||
52
backend/src/ea_chatbot/graph/nodes/summarize_conversation.py
Normal file
52
backend/src/ea_chatbot/graph/nodes/summarize_conversation.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from langchain_core.messages import SystemMessage, HumanMessage
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
from ea_chatbot.config import Settings
|
||||
from ea_chatbot.utils.llm_factory import get_llm_model
|
||||
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
|
||||
|
||||
def summarize_conversation_node(state: AgentState) -> dict:
    """Update the conversation summary based on the latest interaction.

    Folds the most recent user/assistant exchange into the running
    summary via the summarizer LLM; on failure the previous summary is
    kept unchanged.
    """
    summary = state.get("summary", "")
    messages = state.get("messages", [])

    # Nothing to summarize yet.
    if not messages:
        return {}

    # The most recent user + assistant exchange.
    last_turn = messages[-2:]

    settings = Settings()
    logger = get_logger("summarize_conversation")

    logger.info("Updating conversation summary...")

    # The summarizer model doubles as the conversation-summary model.
    llm = get_llm_model(
        settings.summarizer_llm,
        callbacks=[LangChainLoggingHandler(logger=logger)]
    )

    if not summary:
        prompt = "Create a summary of the conversation above."
    else:
        prompt = (
            f"This is a summary of the conversation so far: {summary}\n\n"
            "Extend the summary by taking into account the new messages above."
        )

    system_text = f"Current summary: {summary}" if summary else "You are a helpful assistant that summarizes conversations."
    human_text = f"Recent messages:\n{last_turn}\n\n{prompt}\n\nKeep the summary concise and focused on the key topics and data points discussed."

    # Construct the messages for the summarization LLM.
    summarization_messages = [
        SystemMessage(content=system_text),
        HumanMessage(content=human_text),
    ]

    try:
        response = llm.invoke(summarization_messages)
        logger.info("[bold green]Conversation summary updated.[/bold green]")
        return {"summary": response.content}
    except Exception as e:
        logger.error(f"Failed to update summary: {str(e)}")
        # If summarization fails, keep the old summary.
        return {"summary": summary}
|
||||
44
backend/src/ea_chatbot/graph/nodes/summarizer.py
Normal file
44
backend/src/ea_chatbot/graph/nodes/summarizer.py
Normal file
@@ -0,0 +1,44 @@
|
||||
from langchain_core.messages import AIMessage
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
from ea_chatbot.config import Settings
|
||||
from ea_chatbot.utils.llm_factory import get_llm_model
|
||||
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
|
||||
from ea_chatbot.graph.prompts.summarizer import SUMMARIZER_PROMPT
|
||||
|
||||
def summarizer_node(state: AgentState) -> dict:
    """Summarize the code execution results into a final answer.

    Feeds the question, plan, code output, recent history and running
    conversation summary to the summarizer LLM and returns its response
    as the final assistant message.

    Args:
        state: Current agent state; must contain ``question``; ``plan``,
            ``code_output``, ``messages`` and ``summary`` are optional.

    Returns:
        Partial state update ``{"messages": [response]}``.

    Raises:
        Exception: Re-raised from the LLM call on failure.
    """
    question = state["question"]
    plan = state.get("plan", "")
    code_output = state.get("code_output", "")
    # Last 3 turns (6 messages) keep the prompt bounded.
    history = state.get("messages", [])[-6:]
    summary = state.get("summary", "")

    settings = Settings()
    logger = get_logger("summarizer")

    logger.info("Generating final summary...")

    llm = get_llm_model(
        settings.summarizer_llm,
        callbacks=[LangChainLoggingHandler(logger=logger)]
    )

    messages = SUMMARIZER_PROMPT.format_messages(
        question=question,
        plan=plan,
        code_output=code_output,
        history=history,
        summary=summary
    )

    try:
        response = llm.invoke(messages)
        logger.info("[bold green]Summary generated.[/bold green]")

        # Return the final message to be added to the state
        return {
            "messages": [response]
        }
    except Exception as e:
        logger.error(f"Failed to generate summary: {str(e)}")
        # Bare raise keeps the original traceback intact.
        raise
|
||||
10
backend/src/ea_chatbot/graph/prompts/__init__.py
Normal file
10
backend/src/ea_chatbot/graph/prompts/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
# Re-export every prompt template defined in this package so callers can
# `from ea_chatbot.graph.prompts import ...` uniformly. Previously the
# researcher and summarizer prompts were missing from this list even
# though their modules exist in the package.
from .query_analyzer import QUERY_ANALYZER_PROMPT
from .planner import PLANNER_PROMPT
from .coder import CODE_GENERATOR_PROMPT, ERROR_CORRECTOR_PROMPT
from .researcher import RESEARCHER_PROMPT
from .summarizer import SUMMARIZER_PROMPT

__all__ = [
    "QUERY_ANALYZER_PROMPT",
    "PLANNER_PROMPT",
    "CODE_GENERATOR_PROMPT",
    "ERROR_CORRECTOR_PROMPT",
    "RESEARCHER_PROMPT",
    "SUMMARIZER_PROMPT",
]
|
||||
64
backend/src/ea_chatbot/graph/prompts/coder.py
Normal file
64
backend/src/ea_chatbot/graph/prompts/coder.py
Normal file
@@ -0,0 +1,64 @@
|
||||
from langchain_core.prompts import ChatPromptTemplate

# System prompt for first-pass code generation: establishes the `db` client,
# the `plots` list contract, and the "complete runnable code" requirements.
CODE_GENERATOR_SYSTEM = """You are an AI data analyst and your job is to assist users with data analysis and coding tasks.
The user will provide a task and a plan.

**Data Access:**
- A database client is available as a variable named `db`.
- You MUST use `db.query_df(sql_query)` to execute SQL queries and retrieve data as a Pandas DataFrame.
- Do NOT assume a dataframe `df` is already loaded unless explicitly stated. You usually need to query it first.
- The database schema is described in the prompt. Use it to construct valid SQL queries.

**Plotting:**
- If you need to plot any data, use the `plots` list to store the figures.
- Example: `plots.append(fig)` or `plots.append(plt.gcf())`.
- Do not use `plt.show()` as it will render the plot and cause an error.

**Code Requirements:**
- Produce FULL, COMPLETE CODE that includes all steps and solves the task!
- Always include the import statements at the top of the code (e.g., `import pandas as pd`, `import matplotlib.pyplot as plt`).
- Always include print statements to output the results of your code.
- Use `db.query_df("SELECT ...")` to get data."""

# Human message for code generation. Template variables: {question}, {plan},
# {database_description}, {code_exec_results}, {example_code}.
CODE_GENERATOR_USER = """TASK:
{question}

PLAN:
```yaml
{plan}
```

AVAILABLE DATA SUMMARY (Database Schema):
{database_description}

CODE EXECUTION OF THE PREVIOUS TASK RESULTED IN:
{code_exec_results}

{example_code}"""

# System prompt for the error-correction retry loop: asks for a full
# corrected program, restating the same `db`/`plots` contract.
ERROR_CORRECTOR_SYSTEM = """The execution of the code resulted in an error.
Return a complete, corrected python code that incorporates the fixes for the error.

**Reminders:**
- You have access to a database client via the variable `db`.
- Use `db.query_df(sql)` to run queries.
- Use `plots.append(fig)` for plots.
- Always include imports and print statements."""

# Human message for error correction. Template variables: {code}, {error}.
ERROR_CORRECTOR_USER = """FAILED CODE:
```python
{code}
```

ERROR:
{error}"""

# Assembled chat templates consumed by the coder / error-corrector nodes.
CODE_GENERATOR_PROMPT = ChatPromptTemplate.from_messages([
    ("system", CODE_GENERATOR_SYSTEM),
    ("human", CODE_GENERATOR_USER),
])

ERROR_CORRECTOR_PROMPT = ChatPromptTemplate.from_messages([
    ("system", ERROR_CORRECTOR_SYSTEM),
    ("human", ERROR_CORRECTOR_USER),
])
|
||||
46
backend/src/ea_chatbot/graph/prompts/planner.py
Normal file
46
backend/src/ea_chatbot/graph/prompts/planner.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder

# System prompt for the planner. Template variable: {date}.
# NOTE(review): this persona text ("Research Specialist") is identical to the
# researcher prompt's system text — confirm the planner was meant to share it.
PLANNER_SYSTEM = """You are a Research Specialist and your job is to find answers and educate the user.
Provide factual information responding directly to the user's question. Include key details and context to ensure your response comprehensively answers their query.

Today's Date is: {date}"""

# Human message: walks the model through a four-step reasoning procedure and
# asks for the plan as a fenced YAML algorithm. Template variables:
# {summary}, {question}, {database_description}, {example_plan}.
PLANNER_USER = """Conversation Summary: {summary}

TASK:
{question}

AVAILABLE DATA SUMMARY (Use only if relevant to the task):
{database_description}

First: Evaluate whether you have all necessary and requested information to provide a solution.
Use the dataset description above to determine what data and in what format you have available to you.
You are able to search internet if the user asks for it, or you require any information that you can not derive from the given dataset or the instruction.

Second: Incorporate any additional relevant context, reasoning, or details from previous interactions or internal chain-of-thought that may impact the solution.
Ensure that all such information is fully included in your response rather than referring to previous answers indirectly.

Third: Reflect on the problem and briefly describe it, while addressing the problem goal, inputs, outputs,
rules, constraints, and other relevant details that appear in the problem description.

Fourth: Based on the preceding steps, formulate your response as an algorithm, breaking the solution in up to eight simple concise yet descriptive, clear English steps.
You MUST Include all values or instructions as described in the above task, or retrieved using internet search!
If fewer steps suffice, that's acceptable. If more are needed, please include them.
Remember to explain steps rather than write code.

This algorithm will be later converted to Python code.
If a dataframe is required, assume it is named 'df' and is already defined/populated based on the data summary.

There is a list variable called `plots` that you need to use to store any plots you generate. Do not use `plt.show()` as it will render the plot and cause an error.

Output the algorithm as a YAML string. Always enclose the YAML string within ```yaml tags.

**Note: Ensure that any necessary context from prior interactions is fully embedded in the plan. Do not use phrases like "refer to previous answer"; instead, provide complete details inline.**

{example_plan}"""

# Assembled template; conversation history is injected between system and human.
PLANNER_PROMPT = ChatPromptTemplate.from_messages([
    ("system", PLANNER_SYSTEM),
    MessagesPlaceholder(variable_name="history"),
    ("human", PLANNER_USER),
])
|
||||
33
backend/src/ea_chatbot/graph/prompts/query_analyzer.py
Normal file
33
backend/src/ea_chatbot/graph/prompts/query_analyzer.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder

# System prompt for query decomposition. Steers the model toward `plan`
# whenever defaults/history suffice, reserving `clarify` for true blockers.
SYSTEM_PROMPT = """You are an expert election data analyst. Decompose the user's question into key elements to determine the next action.

### Context & Defaults
- **History:** Use the conversation history and summary to resolve coreferences (e.g., "those results", "that state"). Assume the current question inherits missing context (Year, State, County) from history.
- **Data Access:** You have access to voter and election databases. Proceed to planning without asking for database or table names.
- **Downstream Capabilities:** Visualizations are generated as Matplotlib figures. Proceed to planning for "graphs" or "plots" without asking for file formats or plot types.
- **Trends:** For trend requests without a specified interval, allow the Planner to use a sensible default (e.g., by election cycle).

### Instructions:
1. **Analyze:** Identify if the request is for data analysis, general facts (web research), or is critically ambiguous.
2. **Extract Entities & Conditions:**
   - **Data Required:** e.g., "vote count", "demographics".
   - **Conditions:** e.g., "Year=2024". Include context from history.
3. **Identify Target & Critical Ambiguities:**
   - **Unknowns:** The core target question.
   - **Critical Ambiguities:** ONLY list issues that PREVENT any analysis.
     - Examples: No timeframe/geography in query OR history; "track the same voter" without an identity definition.
4. **Determine Action:**
   - `plan`: For data analysis where defaults or history provide sufficient context.
   - `research`: For general knowledge.
   - `clarify`: ONLY for CRITICAL ambiguities."""

# Human message. Template variables: {summary}, {question}.
USER_PROMPT_TEMPLATE = """Conversation Summary: {summary}

Analyze the following question: {question}"""

# Assembled template; conversation history is injected between system and human.
QUERY_ANALYZER_PROMPT = ChatPromptTemplate.from_messages([
    ("system", SYSTEM_PROMPT),
    MessagesPlaceholder(variable_name="history"),
    ("human", USER_PROMPT_TEMPLATE),
])
|
||||
12
backend/src/ea_chatbot/graph/prompts/researcher.py
Normal file
12
backend/src/ea_chatbot/graph/prompts/researcher.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder

# Prompt for the research node. Template variables: {date} (system),
# {summary} and {question} (user); conversation history is injected in between.
RESEARCHER_PROMPT = ChatPromptTemplate.from_messages([
    ("system", """You are a Research Specialist and your job is to find answers and educate the user.
Provide factual information responding directly to the user's question. Include key details and context to ensure your response comprehensively answers their query.

Today's Date is: {date}"""),
    MessagesPlaceholder(variable_name="history"),
    ("user", """Conversation Summary: {summary}

{question}""")
])
|
||||
27
backend/src/ea_chatbot/graph/prompts/summarizer.py
Normal file
27
backend/src/ea_chatbot/graph/prompts/summarizer.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder

# Prompt for the final-answer summarizer. Template variables: {summary}
# (system), {question}, {plan}, {code_output} (user); conversation history
# is injected between system and user.
SUMMARIZER_PROMPT = ChatPromptTemplate.from_messages([
    ("system", """You are an expert election data analyst providing a final answer to the user.
Use the provided conversation history and summary to ensure your response is contextually relevant and flows naturally from previous turns.

Conversation Summary: {summary}"""),
    MessagesPlaceholder(variable_name="history"),
    ("user", """The user presented you with the following question.
Question: {question}

To address this, you have designed an algorithm.
Algorithm: {plan}.

You have crafted a Python code based on this algorithm, and the output generated by the code's execution is as follows.
Output: {code_output}.

Please produce a comprehensive, easy-to-understand answer that:
1. Summarizes the main insights or conclusions achieved through your method's implementation. Include execution results if necessary.
2. Includes relevant findings from the code execution in a clear format (e.g., text explanation, tables, lists, bullet points).
   - Avoid referencing the code or output as 'the above results' or saying 'it's in the code output.'
   - Instead, present the actual key data or statistics within your explanation.
3. If the user requested specific information that does not appear in the code's output but you can provide it, include that information directly in your summary.
4. Present any data or tables that might have been generated by the code in full, since the user cannot directly see the execution output.

Your goal is to give a final answer that stands on its own without requiring the user to see the code or raw output directly.""")
])
|
||||
36
backend/src/ea_chatbot/graph/state.py
Normal file
36
backend/src/ea_chatbot/graph/state.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from typing import TypedDict, Annotated, List, Dict, Any, Optional
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain.agents import AgentState as AS
|
||||
import operator
|
||||
|
||||
class AgentState(AS):
    """Shared state threaded through every node of the workflow graph.

    Fields annotated with ``operator.add`` are accumulated (appended)
    across node updates; all other fields are replaced wholesale.
    """
    # Conversation history
    messages: Annotated[List[BaseMessage], operator.add]

    # Task context
    question: str

    # Query Analysis (Decomposition results)
    analysis: Optional[Dict[str, Any]]
    # Expected keys (what query_analyzer_node stores, i.e. QueryAnalysis minus
    # next_action): "data_required", "unknowns", "ambiguities", "conditions"

    # Step-by-step reasoning
    plan: Optional[str]

    # Code execution context
    code: Optional[str]
    code_output: Optional[str]
    error: Optional[str]

    # Artifacts (for UI display)
    plots: Annotated[List[Any], operator.add]  # Matplotlib figures
    dfs: Dict[str, Any]  # Pandas DataFrames

    # Conversation summary
    summary: Optional[str]

    # Routing hint: "clarify", "plan", "research", "end"
    next_action: str

    # Number of execution attempts
    iterations: int
|
||||
92
backend/src/ea_chatbot/graph/workflow.py
Normal file
92
backend/src/ea_chatbot/graph/workflow.py
Normal file
@@ -0,0 +1,92 @@
|
||||
from langgraph.graph import StateGraph, END
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
from ea_chatbot.graph.nodes.query_analyzer import query_analyzer_node
|
||||
from ea_chatbot.graph.nodes.planner import planner_node
|
||||
from ea_chatbot.graph.nodes.coder import coder_node
|
||||
from ea_chatbot.graph.nodes.error_corrector import error_corrector_node
|
||||
from ea_chatbot.graph.nodes.executor import executor_node
|
||||
from ea_chatbot.graph.nodes.summarizer import summarizer_node
|
||||
from ea_chatbot.graph.nodes.researcher import researcher_node
|
||||
from ea_chatbot.graph.nodes.clarification import clarification_node
|
||||
from ea_chatbot.graph.nodes.summarize_conversation import summarize_conversation_node
|
||||
|
||||
# Maximum coder retry attempts after execution errors before the graph
# gives up and routes to the summarizer anyway (see executor_router).
MAX_ITERATIONS = 3
|
||||
|
||||
def router(state: AgentState) -> str:
    """Route to the next node based on the analysis.

    Maps the query analyzer's ``next_action`` to the matching node name;
    any unrecognized (or missing) action ends the graph.
    """
    routes = {
        "plan": "planner",
        "research": "researcher",
        "clarify": "clarification",
    }
    return routes.get(state.get("next_action"), END)
|
||||
|
||||
def create_workflow():
    """Create the LangGraph workflow.

    Topology: query_analyzer routes to planner / researcher / clarification
    (or END). planner -> coder -> executor; the executor loops through
    error_corrector on failure (bounded by MAX_ITERATIONS) and otherwise
    proceeds to summarizer. Researcher and summarizer both feed
    summarize_conversation before END; clarification ends immediately.

    Returns:
        The compiled LangGraph application.
    """
    workflow = StateGraph(AgentState)

    # Add nodes
    workflow.add_node("query_analyzer", query_analyzer_node)
    workflow.add_node("planner", planner_node)
    workflow.add_node("coder", coder_node)
    workflow.add_node("error_corrector", error_corrector_node)
    workflow.add_node("researcher", researcher_node)
    workflow.add_node("clarification", clarification_node)
    workflow.add_node("executor", executor_node)
    workflow.add_node("summarizer", summarizer_node)
    workflow.add_node("summarize_conversation", summarize_conversation_node)

    # Set entry point
    workflow.set_entry_point("query_analyzer")

    # Add conditional edges from query_analyzer
    workflow.add_conditional_edges(
        "query_analyzer",
        router,
        {
            "planner": "planner",
            "researcher": "researcher",
            "clarification": "clarification",
            END: END
        }
    )

    # Linear flow for planning and coding
    workflow.add_edge("planner", "coder")
    workflow.add_edge("coder", "executor")

    # Executor routing: retry via error_corrector on failure, but fall
    # through to summarizer once the retry budget is exhausted.
    def executor_router(state: AgentState) -> str:
        if state.get("error"):
            # Check for iteration limit to prevent infinite loops
            if state.get("iterations", 0) >= MAX_ITERATIONS:
                return "summarizer"
            return "error_corrector"
        return "summarizer"

    workflow.add_conditional_edges(
        "executor",
        executor_router,
        {
            "error_corrector": "error_corrector",
            "summarizer": "summarizer"
        }
    )

    # Corrected code goes back to the executor for another attempt.
    workflow.add_edge("error_corrector", "executor")

    workflow.add_edge("researcher", "summarize_conversation")
    workflow.add_edge("clarification", END)
    workflow.add_edge("summarizer", "summarize_conversation")
    workflow.add_edge("summarize_conversation", END)

    # Compile the graph
    app = workflow.compile()

    return app
|
||||
|
||||
# Initialize the app: module-level compiled graph, importable as
# `ea_chatbot.graph.workflow.app`.
app = create_workflow()
|
||||
0
backend/src/ea_chatbot/history/__init__.py
Normal file
0
backend/src/ea_chatbot/history/__init__.py
Normal file
188
backend/src/ea_chatbot/history/manager.py
Normal file
188
backend/src/ea_chatbot/history/manager.py
Normal file
@@ -0,0 +1,188 @@
|
||||
from contextlib import contextmanager
|
||||
from typing import Optional, List
|
||||
from sqlalchemy import create_engine, select, delete
|
||||
from sqlalchemy.orm import sessionmaker, Session
|
||||
from argon2 import PasswordHasher
|
||||
from argon2.exceptions import VerifyMismatchError
|
||||
|
||||
from ea_chatbot.history.models import User, Conversation, Message, Plot
|
||||
|
||||
# Argon2 Password Hasher
|
||||
ph = PasswordHasher()
|
||||
|
||||
class HistoryManager:
|
||||
"""Manages database sessions and operations for history and user data."""
|
||||
|
||||
    def __init__(self, db_url: str):
        """Create the SQLAlchemy engine and session factory for *db_url*."""
        self.engine = create_engine(db_url)
        # expire_on_commit=False is important so we can use objects after session closes
        self.SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=self.engine, expire_on_commit=False)
|
||||
|
||||
    @contextmanager
    def get_session(self):
        """Context manager for database sessions.

        Commits on clean exit, rolls back (and re-raises) on any
        exception, and always closes the session.
        """
        session = self.SessionLocal()
        try:
            yield session
            session.commit()
        except Exception:
            session.rollback()
            raise
        finally:
            session.close()
|
||||
|
||||
# --- User Management ---
|
||||
|
||||
def get_user(self, email: str) -> Optional[User]:
|
||||
"""Fetch a user by their email (username)."""
|
||||
with self.get_session() as session:
|
||||
result = session.execute(select(User).where(User.username == email))
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
def get_user_by_id(self, user_id: str) -> Optional[User]:
|
||||
"""Fetch a user by their ID."""
|
||||
with self.get_session() as session:
|
||||
return session.get(User, user_id)
|
||||
|
||||
    def create_user(self, email: str, password: Optional[str] = None, display_name: Optional[str] = None) -> User:
        """Create a new local user.

        Args:
            email: Used as the username.
            password: Optional plaintext password; hashed with Argon2 when
                given. OIDC-provisioned users have no password.
            display_name: Defaults to the local part of the email address.

        Returns:
            The persisted, refreshed User instance.
        """
        hashed_password = ph.hash(password) if password else None
        user = User(
            username=email,
            password_hash=hashed_password,
            display_name=display_name or email.split("@")[0]
        )
        with self.get_session() as session:
            session.add(user)
            # Commit before refresh so generated fields (id, timestamps) are loaded.
            session.commit()
            session.refresh(user)
            return user
|
||||
|
||||
    def authenticate_user(self, email: str, password: str) -> Optional[User]:
        """Authenticate a user by email and password.

        Returns the user on success; None when the user does not exist,
        has no local password (e.g. OIDC-only account), or the password
        does not match.
        """
        user = self.get_user(email)
        if not user or not user.password_hash:
            return None

        try:
            ph.verify(user.password_hash, password)
            return user
        except VerifyMismatchError:
            # NOTE(review): argon2 can also raise e.g. InvalidHashError for a
            # corrupt stored hash; that currently propagates — confirm intended.
            return None
|
||||
|
||||
    def sync_user_from_oidc(self, email: str, display_name: Optional[str] = None) -> User:
        """
        Synchronize a user from an OIDC provider.
        If a user with the same email exists, update their display name.
        Otherwise, create a new user.
        """
        user = self.get_user(email)
        if user:
            # Update existing user if needed
            if display_name and user.display_name != display_name:
                with self.get_session() as session:
                    # Re-fetch inside this session so the change is tracked.
                    db_user = session.get(User, user.id)
                    db_user.display_name = display_name
                    session.commit()
                    session.refresh(db_user)
                    return db_user
            return user
        else:
            # Create new user (no password for OIDC users initially)
            return self.create_user(email=email, display_name=display_name)
|
||||
|
||||
# --- Conversation Management ---
|
||||
|
||||
    def create_conversation(self, user_id: str, data_state: str, name: str, summary: Optional[str] = None) -> Conversation:
        """Create a new conversation for a user.

        Args:
            user_id: Owning user's id.
            data_state: Dataset/state scope this conversation belongs to.
            name: Display name of the conversation.
            summary: Optional initial running summary.

        Returns:
            The persisted, refreshed Conversation instance.
        """
        conv = Conversation(
            user_id=user_id,
            data_state=data_state,
            name=name,
            summary=summary
        )
        with self.get_session() as session:
            session.add(conv)
            # Commit before refresh so generated fields (id, created_at) are loaded.
            session.commit()
            session.refresh(conv)
            return conv
|
||||
|
||||
def get_conversations(self, user_id: str, data_state: str) -> List[Conversation]:
    """Return the user's conversations for a data state, newest first."""
    query = (
        select(Conversation)
        .where(
            Conversation.user_id == user_id,
            Conversation.data_state == data_state,
        )
        .order_by(Conversation.created_at.desc())
    )
    with self.get_session() as session:
        rows = session.execute(query).scalars().all()
        return list(rows)
||||
def rename_conversation(self, conversation_id: str, new_name: str) -> Optional[Conversation]:
    """Rename a conversation; returns None when the id is unknown."""
    with self.get_session() as session:
        conversation = session.get(Conversation, conversation_id)
        if conversation is None:
            return None
        conversation.name = new_name
        session.commit()
        session.refresh(conversation)
        return conversation
||||
def delete_conversation(self, conversation_id: str) -> bool:
    """Delete a conversation (messages/plots removed via ORM cascade).

    Returns True when the conversation existed, False otherwise.
    """
    with self.get_session() as session:
        conversation = session.get(Conversation, conversation_id)
        if conversation is None:
            return False
        session.delete(conversation)
        session.commit()
        return True
||||
def update_conversation_summary(self, conversation_id: str, summary: str) -> Optional[Conversation]:
    """Replace a conversation's summary; returns None when the id is unknown."""
    with self.get_session() as session:
        conversation = session.get(Conversation, conversation_id)
        if conversation is None:
            return None
        conversation.summary = summary
        session.commit()
        session.refresh(conversation)
        return conversation
||||
# --- Message & Plot Management ---
|
||||
|
||||
def add_message(self, conversation_id: str, role: str, content: str, plots: Optional[List[bytes]] = None) -> Message:
    """Append a message (and any attached plot images) to a conversation."""
    message = Message(
        conversation_id=conversation_id,
        role=role,
        content=content,
    )
    with self.get_session() as session:
        session.add(message)
        # Flush (not commit) so message.id exists for the plot foreign keys.
        session.flush()

        for image in plots or []:
            session.add(Plot(message_id=message.id, image_data=image))

        session.commit()
        session.refresh(message)
        # Touch the relationship so plots are loaded before the session closes.
        _ = message.plots
        return message
|
||||
def get_messages(self, conversation_id: str) -> List[Message]:
    """Return a conversation's messages in chronological order."""
    query = (
        select(Message)
        .where(Message.conversation_id == conversation_id)
        .order_by(Message.created_at.asc())
    )
    with self.get_session() as session:
        messages = list(session.execute(query).scalars().all())
        # Eagerly touch each plots relationship so the data remains usable
        # after the session closes.
        for message in messages:
            _ = message.plots
        return messages
52
backend/src/ea_chatbot/history/models.py
Normal file
52
backend/src/ea_chatbot/history/models.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from typing import List, Optional
|
||||
from datetime import datetime, timezone
|
||||
import uuid
|
||||
from sqlalchemy import String, ForeignKey, DateTime, LargeBinary, Text
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
|
||||
|
||||
class Base(DeclarativeBase):
    """Declarative base shared by all ORM models in this module."""
    pass
||||
class User(Base):
    """An account holder, authenticated locally (password hash) or via OIDC."""

    __tablename__ = "users"

    # UUID string primary key, generated application-side.
    id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
    # Unique login name; the history manager stores the user's email here.
    username: Mapped[str] = mapped_column(String, unique=True, index=True)
    # Hash produced by the history manager's password hasher;
    # NULL for OIDC-provisioned users who have no local password.
    password_hash: Mapped[Optional[str]] = mapped_column(String, nullable=True)
    display_name: Mapped[Optional[str]] = mapped_column(String, nullable=True)

    # Deleting a user deletes their conversations as well.
    conversations: Mapped[List["Conversation"]] = relationship(back_populates="user", cascade="all, delete-orphan")
||||
class Conversation(Base):
    """A named chat thread belonging to one user."""

    __tablename__ = "conversations"

    id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
    user_id: Mapped[str] = mapped_column(ForeignKey("users.id"))
    # Opaque tag queried alongside user_id to scope conversation listings.
    data_state: Mapped[str] = mapped_column(String)
    name: Mapped[str] = mapped_column(String)
    summary: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    # Default is UTC now; used for newest-first ordering in listings.
    created_at: Mapped[datetime] = mapped_column(DateTime, default=lambda: datetime.now(timezone.utc))

    user: Mapped["User"] = relationship(back_populates="conversations")
    # Deleting a conversation deletes its messages (and their plots) too.
    messages: Mapped[List["Message"]] = relationship(back_populates="conversation", cascade="all, delete-orphan")
||||
class Message(Base):
    """A single chat message within a conversation."""

    __tablename__ = "messages"

    id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
    conversation_id: Mapped[str] = mapped_column(ForeignKey("conversations.id"))
    # Speaker label (e.g. user/assistant -- presumably; confirm against callers).
    role: Mapped[str] = mapped_column(String)
    content: Mapped[str] = mapped_column(Text)
    # Default is UTC now; used for chronological ordering when fetching.
    created_at: Mapped[datetime] = mapped_column(DateTime, default=lambda: datetime.now(timezone.utc))

    conversation: Mapped["Conversation"] = relationship(back_populates="messages")
    # Deleting a message deletes its attached plot images.
    plots: Mapped[List["Plot"]] = relationship(back_populates="message", cascade="all, delete-orphan")
||||
class Plot(Base):
    """A binary image blob attached to a message."""

    __tablename__ = "plots"

    id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
    message_id: Mapped[str] = mapped_column(ForeignKey("messages.id"))
    # Raw image bytes stored inline in the database.
    image_data: Mapped[bytes] = mapped_column(LargeBinary)

    message: Mapped["Message"] = relationship(back_populates="plots")
83
backend/src/ea_chatbot/schemas.py
Normal file
83
backend/src/ea_chatbot/schemas.py
Normal file
@@ -0,0 +1,83 @@
|
||||
from pydantic import BaseModel, Field, computed_field
|
||||
from typing import Sequence, Optional
|
||||
import re
|
||||
|
||||
class TaskPlanContext(BaseModel):
    '''Background context relevant to the task plan'''
    initial_context: str = Field(
        min_length=1,
        description="Background information about the database/tables and previous conversations relevant to the task.",
    )
    assumptions: Sequence[str] = Field(
        description="Assumptions made while working on the task.",
    )
    # NOTE(review): no default is supplied, so despite being Optional this
    # field is still required when the model is constructed -- confirm intended.
    constraints: Optional[Sequence[str]] = Field(
        description="Constraints that apply to the task.",
    )
||||
class TaskPlanResponse(BaseModel):
    '''Structured plan to achieve the task objective.

    Produced by the planner LLM; field descriptions double as prompt
    guidance for structured-output generation.
    '''
    goal: str = Field(
        min_length=1,
        description="Single-sentence objective the plan must achieve.",
    )
    reflection: str = Field(
        min_length=1,
        description="High-level natural-language reasoning describing the user's request and the intended solution approach.",
    )
    context: TaskPlanContext = Field(
        description="Background context relevant to the task plan.",
    )
    # min_length here constrains the sequence to at least one step.
    steps: Sequence[str] = Field(
        min_length=1,
        description="Ordered list of steps to execute that follow the 'Step <number>: <detail>' pattern.",
    )
||||
# Token some models emit where a markdown fence belongs; normalised to ``` below.
_IM_SEP_TOKEN_PATTERN = re.compile(re.escape("<|im_sep|>"))
# Captures the body of a fenced code block, optionally tagged "python".
_CODE_BLOCK_PATTERN = re.compile(r"```(?:python\s*)?(.*?)\s*```", re.DOTALL)
# Module/builtin names that generated code must not reference (best-effort
# denylist; matching is lexical, not an actual sandbox).
_FORBIDDEN_MODULES = (
    "subprocess",
    "sys",
    "eval",
    "exec",
    "socket",
    "urllib",
    "shutil",
    "pickle",
    "ctypes",
    "multiprocessing",
    "tempfile",
    "glob",
    "pty",
    "commands",
    "cgi",
    "cgitb",
    "xml.etree.ElementTree",
    "builtins",
)
# Matches any whole line that mentions a forbidden name as a word, unless the
# line starts with '#': group 1 is the full line, group 2 the matched name.
_FORBIDDEN_MODULE_PATTERN = re.compile(
    r"^((?:[^#].*)?\b(" + "|".join(map(re.escape, _FORBIDDEN_MODULES)) + r")\b.*)$",
    flags=re.MULTILINE,
)
|
||||
class CodeGenerationResponse(BaseModel):
    '''Code generation response structure'''
    # Raw model output; may include markdown fences or <|im_sep|> tokens.
    code: str = Field(description="The generated code snippet to accomplish the task")
    explanation: str = Field(description="Explanation of the generated code and its functionality")

    @computed_field(return_type=str)
    @property
    def parsed_code(self) -> str:
        '''Extracts the code snippet without any surrounding text'''
        # Normalise the <|im_sep|> token back into a markdown fence first.
        normalised = _IM_SEP_TOKEN_PATTERN.sub("```", self.code).strip()
        match = _CODE_BLOCK_PATTERN.search(normalised)
        # Fall back to the whole (normalised) text when no fence is found.
        candidate = match.group(1).strip() if match else normalised
        # Comment out any line referencing a denylisted module/builtin.
        sanitised = _FORBIDDEN_MODULE_PATTERN.sub(r"# not allowed \1", candidate)
        return sanitised.strip()
||||
class RankResponse(BaseModel):
    '''Code ranking response structure'''
    # Bounded integer score; 1 is best, 10 is worst.
    rank: int = Field(
        ge=1, le=10,
        description="Rank of the code snippet from 1 (best) to 10 (worst)"
    )
||||
24
backend/src/ea_chatbot/types.py
Normal file
24
backend/src/ea_chatbot/types.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from typing import TypedDict, Optional
|
||||
from enum import StrEnum
|
||||
|
||||
class DBSettings(TypedDict):
    """Connection settings consumed by DBClient (PostgreSQL URL parts)."""
    host: str
    port: int
    user: str
    # Plain-text password; interpolated into the engine URL by DBClient.
    pswd: str
    # Database name.
    db: str
    # Optional default table, used by the inspection helpers.
    table: Optional[str]
||||
class Agent(StrEnum):
    """Agent roles in the pipeline; values are human-readable names."""
    EXPERT_SELECTOR = "Expert Selector"
    ANALYST_SELECTOR = "Analyst Selector"
    THEORIST = "Theorist"
    THEORIST_WEB = "Theorist-Web"
    THEORIST_CLARIFICATION = "Theorist-Clarification"
    PLANNER = "Planner"
    CODE_GENERATOR = "Code Generator"
    CODE_DEBUGGER = "Code Debugger"
    CODE_EXECUTOR = "Code Executor"
    ERROR_CORRECTOR = "Error Corrector"
    CODE_RANKER = "Code Ranker"
    SOLUTION_SUMMARIZER = "Solution Summarizer"
12
backend/src/ea_chatbot/utils/__init__.py
Normal file
12
backend/src/ea_chatbot/utils/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from .db_client import DBClient
|
||||
from .llm_factory import get_llm_model
|
||||
from .logging import get_logger, LangChainLoggingHandler
|
||||
from . import helpers
|
||||
|
||||
__all__ = [
|
||||
"DBClient",
|
||||
"get_llm_model",
|
||||
"get_logger",
|
||||
"LangChainLoggingHandler",
|
||||
"helpers"
|
||||
]
|
||||
234
backend/src/ea_chatbot/utils/database_inspection.py
Normal file
234
backend/src/ea_chatbot/utils/database_inspection.py
Normal file
@@ -0,0 +1,234 @@
|
||||
from typing import Optional, Dict, Any, List, TYPE_CHECKING
|
||||
import yaml
|
||||
import json
|
||||
import os
|
||||
from ea_chatbot.utils.db_client import DBClient
|
||||
if TYPE_CHECKING:
|
||||
from ea_chatbot.types import DBSettings
|
||||
|
||||
def _get_table_checksum(db_client: DBClient, table: str) -> str:
    """Return a cheap change-detection hash for *table*.

    Hashes the cumulative insert/update/delete counters from PostgreSQL's
    pg_stat_user_tables, so the value changes whenever DML touches the table.
    """
    # NOTE(review): the table name is interpolated directly into the SQL --
    # safe only while callers pass trusted, internally-configured names.
    query = f"""
    SELECT md5(concat_ws('|', n_tup_ins, n_tup_upd, n_tup_del)) AS dml_hash
    FROM pg_stat_user_tables
    WHERE schemaname = 'public' AND relname = '{table}';"""
    try:
        return str(db_client.query_df(query).iloc[0, 0])
    except Exception:
        # Sentinel never matches a stored checksum, so callers treat the
        # table as changed when statistics are unavailable.
        return "unknown_checksum"
||||
def _update_checksum_file(filepath: str, table: str, checksum: str):
|
||||
"""Updates the checksum file with the new checksum for the table."""
|
||||
checksums = {}
|
||||
if os.path.exists(filepath):
|
||||
with open(filepath, 'r') as f:
|
||||
for line in f:
|
||||
if ':' in line:
|
||||
k, v = line.strip().split(':', 1)
|
||||
checksums[k] = v
|
||||
|
||||
checksums[table] = checksum
|
||||
|
||||
with open(filepath, 'w') as f:
|
||||
for k, v in checksums.items():
|
||||
f.write(f"{k}:{v}")
|
||||
|
||||
def get_data_summary(data_dir: str = "data") -> Optional[str]:
    """Return the cached inspection.yaml content, or None when it is absent."""
    path = os.path.join(data_dir, "inspection.yaml")
    if not os.path.exists(path):
        return None
    with open(path, 'r') as handle:
        return handle.read()
|
||||
def get_primary_key(db_client: DBClient, table_name: str) -> Optional[str]:
    """
    Dynamically identifies the primary key of the table.
    Returns the column name of the primary key, or None if not found.

    Queries information_schema for the table's PRIMARY KEY constraint.
    Composite keys yield only a single column (LIMIT 1, order unspecified).
    """
    # NOTE(review): table_name is interpolated into the SQL string -- safe
    # only for trusted, internally-configured table names.
    query = f"""
    SELECT kcu.column_name
    FROM information_schema.key_column_usage AS kcu
    JOIN information_schema.table_constraints AS tc
        ON kcu.constraint_name = tc.constraint_name
        AND kcu.table_schema = tc.table_schema
    WHERE kcu.table_name = '{table_name}'
        AND tc.constraint_type = 'PRIMARY KEY'
    LIMIT 1;
    """
    try:
        df = db_client.query_df(query)
        if not df.empty:
            return str(df.iloc[0, 0])
    except Exception as e:
        # Best-effort: callers treat a missing PK as "no ordering available".
        print(f"Warning: Could not determine primary key for {table_name}: {e}")
    return None
||||
|
||||
def inspect_db_table(
    db_client: Optional[DBClient]=None,
    db_settings: Optional["DBSettings"]=None,
    data_dir: str = "data",
    force_update: bool = False
) -> Optional[str]:
    """
    Inspects the database table, generates statistics for each column,
    and saves the inspection results to a YAML file locally.

    Args:
        db_client: Existing client to reuse; if None, one is built from
            db_settings.
        db_settings: Connection settings, required when db_client is None.
            The table to inspect comes from its 'table' entry.
        data_dir: Directory holding inspection.yaml and the checksum file.
        force_update: Regenerate even when the table checksum is unchanged.

    Returns:
        The YAML inspection content as a string, or None on any failure.

    Improvements:
    - Dynamic Primary Key Discovery
    - Cardinality (Unique Counts)
    - Categorical Sample Values for low cardinality columns
    - Robust Quoting
    """
    inspection_file = os.path.join(data_dir, "inspection.yaml")
    checksum_file = os.path.join(data_dir, "checksum")

    # Initialize DB Client
    if db_client is None:
        if db_settings is None:
            print("Error: Either db_client or db_settings must be provided.")
            return None
        try:
            db_client = DBClient(db_settings)
        except Exception as e:
            print(f"Failed to create DBClient: {e}")
            return None

    table_name = db_client.settings.get('table')
    if not table_name:
        print("Error: Table name must be specified in DBSettings.")
        return None

    os.makedirs(data_dir, exist_ok=True)

    # Checksum verification: skip regeneration when the table's DML counters
    # have not changed since the last run (see _get_table_checksum).
    new_checksum = _get_table_checksum(db_client, table_name)
    has_changed = True

    if os.path.exists(checksum_file):
        try:
            with open(checksum_file, 'r') as f:
                saved_checksums = f.read().strip()
                if f"{table_name}:{new_checksum}" in saved_checksums:
                    has_changed = False
        except Exception:
            pass  # Force update on read error

    if not has_changed and not force_update:
        # Cache hit: return the previously saved summary.
        return get_data_summary(data_dir)

    print(f"Regenerating inspection file for table '{table_name}'...")

    # Fetch Table Metadata
    try:
        # Get columns and types
        columns_query = f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{table_name}';"
        columns_df = db_client.query_df(columns_query)

        # Get Row Counts
        total_rows_df = db_client.query_df(f'SELECT COUNT(*) FROM "{table_name}"')
        total_rows = int(total_rows_df.iloc[0, 0])

        # Dynamic Primary Key
        primary_key = get_primary_key(db_client, table_name)

        # Get First/Last Rows (if PK exists) -- used below for per-column
        # first/last sample values.
        first_row_df = None
        last_row_df = None
        if primary_key:
            first_row_df = db_client.query_df(f'SELECT * FROM "{table_name}" ORDER BY "{primary_key}" ASC LIMIT 1')
            last_row_df = db_client.query_df(f'SELECT * FROM "{table_name}" ORDER BY "{primary_key}" DESC LIMIT 1')

    except Exception as e:
        print(f"Failed to retrieve basic table info: {e}")
        return None

    # Per-column statistics keyed by column name (plus 'primary_key').
    stats_dict: Dict[str, Any] = {}
    if primary_key:
        stats_dict['primary_key'] = primary_key

    for _, row in columns_df.iterrows():
        col_name = row['column_name']
        dtype = row['data_type']

        try:
            # Count Values
            # Using robust quoting
            count_df = db_client.query_df(f'SELECT COUNT("{col_name}") FROM "{table_name}"')
            count_val = int(count_df.iloc[0,0])

            # Count Unique (Cardinality)
            unique_df = db_client.query_df(f'SELECT COUNT(DISTINCT "{col_name}") FROM "{table_name}"')
            unique_count = int(unique_df.iloc[0,0])

            col_stats: Dict[str, Any] = {
                'dtype': dtype,
                'count_of_values': count_val,
                'count_of_nulls': total_rows - count_val,
                'unique_count': unique_count
            }

            # All-NULL column: no further stats are computable.
            if count_val == 0:
                stats_dict[col_name] = col_stats
                continue

            # Numerical Stats (substring match on the information_schema dtype)
            if any(t in dtype for t in ('int', 'float', 'numeric', 'double', 'real', 'decimal')):
                stats_query = f'SELECT AVG("{col_name}"), MIN("{col_name}"), MAX("{col_name}") FROM "{table_name}"'
                stats_df = db_client.query_df(stats_query)
                if not stats_df.empty:
                    col_stats['mean'] = float(stats_df.iloc[0,0]) if stats_df.iloc[0,0] is not None else None
                    col_stats['min'] = float(stats_df.iloc[0,1]) if stats_df.iloc[0,1] is not None else None
                    col_stats['max'] = float(stats_df.iloc[0,2]) if stats_df.iloc[0,2] is not None else None

            # Temporal Stats
            elif any(t in dtype for t in ('date', 'timestamp')):
                stats_query = f'SELECT MIN("{col_name}"), MAX("{col_name}") FROM "{table_name}"'
                stats_df = db_client.query_df(stats_query)
                if not stats_df.empty:
                    col_stats['min'] = str(stats_df.iloc[0,0])
                    col_stats['max'] = str(stats_df.iloc[0,1])

            # Categorical/Text Stats
            else:
                # Sample values if cardinality is low (< 20)
                if 0 < unique_count < 20:
                    distinct_query = f'SELECT DISTINCT "{col_name}" FROM "{table_name}" ORDER BY "{col_name}" LIMIT 20'
                    distinct_df = db_client.query_df(distinct_query)
                    col_stats['distinct_values'] = distinct_df.iloc[:, 0].tolist()

            if first_row_df is not None and not first_row_df.empty and col_name in first_row_df.columns:
                col_stats['first_value'] = str(first_row_df.iloc[0][col_name])
            if last_row_df is not None and not last_row_df.empty and col_name in last_row_df.columns:
                col_stats['last_value'] = str(last_row_df.iloc[0][col_name])

            stats_dict[col_name] = col_stats

        except Exception as e:
            # Best-effort: a failing column is skipped, not fatal.
            print(f"Warning: Could not process column {col_name}: {e}")

    # Load existing inspections to merge (if multiple tables)
    existing_inspections = {}
    if os.path.exists(inspection_file):
        try:
            with open(inspection_file, 'r') as f:
                existing_inspections = yaml.safe_load(f) or {}
            # Backup old file
            os.rename(inspection_file, inspection_file + ".old")
        except Exception:
            pass

    existing_inspections[table_name] = stats_dict

    # Save new inspection
    inspection_content = yaml.dump(existing_inspections, sort_keys=False, default_flow_style=False)
    with open(inspection_file, 'w') as f:
        f.write(inspection_content)

    # Update Checksum
    _update_checksum_file(checksum_file, table_name, new_checksum)

    print(f"Inspection saved to {inspection_file}")
    return inspection_content
||||
21
backend/src/ea_chatbot/utils/db_client.py
Normal file
21
backend/src/ea_chatbot/utils/db_client.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
import pandas as pd
|
||||
from sqlalchemy import create_engine, text
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ea_chatbot.types import DBSettings
|
||||
|
||||
class DBClient:
    """Thin SQLAlchemy wrapper that runs SQL and returns pandas DataFrames."""

    def __init__(self, settings: "DBSettings"):
        self.settings = settings
        self._engine = self._create_engine()

    def _create_engine(self):
        """Build a PostgreSQL engine from the settings mapping.

        BUG FIX: credentials are now URL-encoded, so a user or password
        containing '@', ':' or '/' no longer produces a malformed DSN
        (previously they were interpolated raw into the URL).
        """
        from urllib.parse import quote_plus

        user = quote_plus(str(self.settings['user']))
        pswd = quote_plus(str(self.settings['pswd']))
        url = (
            f"postgresql://{user}:{pswd}"
            f"@{self.settings['host']}:{self.settings['port']}/{self.settings['db']}"
        )
        return create_engine(url)

    def query_df(self, sql: str, params: Optional[dict] = None) -> pd.DataFrame:
        """Execute *sql* (optionally parameterised) and return rows as a DataFrame."""
        with self._engine.connect() as conn:
            result = conn.execute(text(sql), params or {})
            df = pd.DataFrame(result.fetchall(), columns=result.keys())
        return df
||||
73
backend/src/ea_chatbot/utils/helpers.py
Normal file
73
backend/src/ea_chatbot/utils/helpers.py
Normal file
@@ -0,0 +1,73 @@
|
||||
from typing import Optional, TYPE_CHECKING, Dict, Any
|
||||
from datetime import datetime, timezone
|
||||
import yaml
|
||||
import json
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
|
||||
def ordinal(n: int) -> str:
    """Return *n* with its English ordinal suffix (1st, 2nd, 3rd, 4th, ...)."""
    # 11-13 are irregular: "11th", not "11st".
    if 11 <= n <= 13:
        suffix = "th"
    else:
        suffix = {1: "st", 2: "nd", 3: "rd"}.get(n % 10, "th")
    return f"{n}{suffix}"
|
||||
def get_readable_date(date_obj: Optional[datetime] = None, tz: Optional[timezone] = None) -> str:
    """Format a datetime as e.g. "Mon 1st of Jan 2024" (defaults to now, UTC)."""
    moment = datetime.now(timezone.utc) if date_obj is None else date_obj
    if tz:
        # Convert into the requested zone before formatting.
        moment = moment.astimezone(tz)
    return moment.strftime(f"%a {ordinal(moment.day)} of %b %Y")
|
||||
def to_yaml(json_str: str, indent: int = 2) -> str:
    """
    Best-effort conversion of a JSON string (possibly malformed LLM output)
    into a YAML string. Returns the raw input when it cannot be parsed.
    """
    if not json_str:
        return ""

    try:
        payload = json.loads(json_str)
    except json.JSONDecodeError:
        try:
            # Common LLM mistake: single-quoted "JSON".
            payload = json.loads(json_str.replace("'", '"'))
        except Exception:
            # Unparseable even after repair: hand back the raw string.
            return json_str

    return yaml.dump(payload, indent=indent, sort_keys=False)
||||
|
||||
def merge_agent_state(current_state: "AgentState", update: Dict[str, Any]) -> "AgentState":
    """
    Merge a partial update into *current_state*, mimicking LangGraph reducers.

    Channels behave differently: the list channels ("messages", "plots")
    append, the "dfs" dict shallow-merges, and every other key is simply
    overwritten (a None value always overwrites). Returns a new state object;
    the inputs are left unmodified.
    """
    merged = current_state.copy()

    for key, value in update.items():
        if value is None:
            merged[key] = None
        elif key in ("messages", "plots") and isinstance(value, list):
            # Accumulate list channels.
            existing = merged.get(key, [])
            base = existing if isinstance(existing, list) else []
            merged[key] = base + value
        elif key == "dfs" and isinstance(value, dict):
            # Shallow-merge the dataframes mapping.
            existing = merged.get(key, {})
            base = dict(existing) if isinstance(existing, dict) else {}
            base.update(value)
            merged[key] = base
        else:
            merged[key] = value

    return merged
||||
36
backend/src/ea_chatbot/utils/llm_factory.py
Normal file
36
backend/src/ea_chatbot/utils/llm_factory.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from typing import Optional, cast, TYPE_CHECKING, Literal, Dict, List, Tuple, Any
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from ea_chatbot.config import LLMConfig
|
||||
|
||||
def get_llm_model(config: LLMConfig, callbacks: Optional[List[BaseCallbackHandler]] = None) -> BaseChatModel:
    """
    Factory function to get a LangChain chat model based on configuration.

    Args:
        config: LLMConfig object containing model settings.
        callbacks: Optional list of LangChain callback handlers.

    Returns:
        Initialized BaseChatModel instance.

    Raises:
        ValueError: If the provider is not supported.
    """
    options = {
        "temperature": config.temperature,
        "max_tokens": config.max_tokens,
        **config.provider_specific,
    }
    # Drop unset options so the provider's own defaults take over.
    options = {key: val for key, val in options.items() if val is not None}

    provider = config.provider.lower()
    if provider == "openai":
        return ChatOpenAI(model=config.model, callbacks=callbacks, **options)
    if provider in ("google", "google_genai"):
        return ChatGoogleGenerativeAI(model=config.model, callbacks=callbacks, **options)
    raise ValueError(f"Unsupported LLM provider: {config.provider}")
||||
141
backend/src/ea_chatbot/utils/logging.py
Normal file
141
backend/src/ea_chatbot/utils/logging.py
Normal file
@@ -0,0 +1,141 @@
|
||||
import json
|
||||
import os
|
||||
import logging
|
||||
from rich.logging import RichHandler
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from logging.handlers import RotatingFileHandler
|
||||
from typing import Any, Optional, Dict, List
|
||||
|
||||
class LangChainLoggingHandler(BaseCallbackHandler):
    """Callback handler for logging LangChain events.

    Emits Rich-markup log lines for LLM start/end/error events, including
    token usage when the provider reports it.
    """

    def __init__(self, logger: Optional[logging.Logger] = None):
        # Fall back to this package's "langchain" logger.
        self.logger = logger or get_logger("langchain")

    def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> Any:
        # Serialized might be empty or missing name depending on how it's called
        model_name = serialized.get("name") or kwargs.get("name") or "LLM"
        self.logger.info(f"[bold blue]LLM Started:[/bold blue] {model_name}")

    def on_llm_end(self, response: Any, **kwargs: Any) -> Any:
        # llm_output may be absent or None depending on the provider.
        llm_output = getattr(response, "llm_output", {}) or {}
        # Try to find model name in output or use fallback
        model_name = llm_output.get("model_name") or "LLM"
        token_usage = llm_output.get("token_usage", {})

        msg = f"[bold green]LLM Ended:[/bold green] {model_name}"
        if token_usage:
            prompt = token_usage.get("prompt_tokens", 0)
            completion = token_usage.get("completion_tokens", 0)
            total = token_usage.get("total_tokens", 0)
            msg += f" | [yellow]Tokens: {total}[/yellow] ({prompt} prompt, {completion} completion)"

        self.logger.info(msg)

    def on_llm_error(self, error: Exception | KeyboardInterrupt, **kwargs: Any) -> Any:
        self.logger.error(f"[bold red]LLM Error:[/bold red] {str(error)}")
|
||||
class ContextLoggerAdapter(logging.LoggerAdapter):
    """Adapter that injects contextual metadata into every log record.

    Call-site ``extra`` values are merged on top of (and override) the
    adapter's own context mapping.
    """

    def process(self, msg: Any, kwargs: Any) -> tuple[Any, Any]:
        # BUG FIX: logging.LoggerAdapter permits extra=None, which previously
        # crashed on self.extra.copy(); guard with an empty mapping.
        context = dict(self.extra or {})
        if "extra" in kwargs:
            context.update(kwargs.pop("extra"))
        kwargs["extra"] = context
        return msg, kwargs
||||
|
||||
class FlexibleJSONEncoder(json.JSONEncoder):
    """JSON encoder that degrades gracefully for non-serializable objects.

    Handles Pydantic v2 (`model_dump`) and v1 (`dict`) models and arbitrary
    objects with a ``__dict__``; anything else falls through to the standard
    encoder, which raises TypeError.
    """

    def default(self, obj: Any) -> Any:
        if hasattr(obj, 'model_dump'):  # Pydantic v2
            return obj.model_dump()
        if hasattr(obj, 'dict'):  # Pydantic v1
            return obj.dict()
        if hasattr(obj, '__dict__'):
            return self.serialize_custom_object(obj)
        if isinstance(obj, (dict, list)):
            # BUG FIX: the previous code re-invoked self.default() on every
            # element, which raised TypeError for already-serializable values
            # (default() is only meant to receive non-serializable objects).
            # Returning the container lets the encoder recurse into it
            # normally, calling default() again only where needed.
            return obj
        return super().default(obj)

    def serialize_custom_object(self, obj: Any) -> dict:
        """Serialize an arbitrary object from its __dict__, tagging the class name."""
        obj_dict = obj.__dict__.copy()
        obj_dict['__custom_class__'] = obj.__class__.__name__
        return obj_dict
||||
|
||||
class JsonFormatter(logging.Formatter):
    """Custom JSON formatter for structured logging.

    Produces one JSON object per record with standard fields plus any
    caller-supplied ``extra`` attributes, serialised via FlexibleJSONEncoder.
    """
    def format(self, record: logging.LogRecord) -> str:
        # Standard fields
        log_record = {
            "timestamp": self.formatTime(record, self.datefmt),
            "level": record.levelname,
            "message": record.getMessage(),
            "module": record.module,
            "name": record.name,
        }

        # Add exception info if present
        if record.exc_info:
            log_record["exception"] = self.formatException(record.exc_info)

        # Add all other extra fields from the record
        # Filter out standard logging attributes
        standard_attrs = {
            'args', 'asctime', 'created', 'exc_info', 'exc_text', 'filename',
            'funcName', 'levelname', 'levelno', 'lineno', 'module',
            'msecs', 'message', 'msg', 'name', 'pathname', 'process',
            'processName', 'relativeCreated', 'stack_info', 'thread', 'threadName'
        }
        # Anything not in the standard set came from the `extra` mapping
        # (e.g. via ContextLoggerAdapter) and is passed through verbatim.
        for key, value in record.__dict__.items():
            if key not in standard_attrs:
                log_record[key] = value

        return json.dumps(log_record, cls=FlexibleJSONEncoder)
||||
|
||||
def get_logger(name: str = "ea_chatbot", level: Optional[str] = None, log_file: Optional[str] = None) -> logging.Logger:
    """Get a configured logger with RichHandler and optional Json FileHandler.

    Args:
        name: Logger name; prefixed with "ea_chatbot." when not already in
            that hierarchy so all loggers share the package root's handlers.
        level: Optional level name (e.g. "DEBUG"). Configures the root
            package logger on first call, and is always applied to the
            returned sub-logger when given.
        log_file: Optional path for a rotating JSON log file (added at most
            once to the root package logger).

    Returns:
        The configured (or already-configured) logger.
    """
    # Ensure name starts with ea_chatbot for hierarchy if not already
    if name != "ea_chatbot" and not name.startswith("ea_chatbot."):
        full_name = f"ea_chatbot.{name}"
    else:
        full_name = name

    logger = logging.getLogger(full_name)

    # Configure root ea_chatbot logger if it hasn't been configured
    # (handler presence is used as the "already configured" marker).
    root_logger = logging.getLogger("ea_chatbot")
    if not root_logger.handlers:
        # Default to INFO if level not provided
        log_level = getattr(logging, (level or "INFO").upper(), logging.INFO)
        root_logger.setLevel(log_level)

        # Console Handler (Rich)
        rich_handler = RichHandler(
            rich_tracebacks=True,
            markup=True,
            show_time=False,
            show_path=False
        )
        root_logger.addHandler(rich_handler)
        # Stop records propagating to the stdlib root logger (avoids
        # duplicate console output).
        root_logger.propagate = False

    # Always check if we need to add a FileHandler, even if root is already configured
    if log_file:
        existing_file_handlers = [h for h in root_logger.handlers if isinstance(h, RotatingFileHandler)]
        if not existing_file_handlers:
            os.makedirs(os.path.dirname(log_file), exist_ok=True)
            file_handler = RotatingFileHandler(
                log_file, maxBytes=5*1024*1024, backupCount=3
            )
            file_handler.setFormatter(JsonFormatter())
            root_logger.addHandler(file_handler)

    # Refresh logger object in case it was created before root was configured
    logger = logging.getLogger(full_name)

    # If level is explicitly provided for a sub-logger, set it
    if level:
        logger.setLevel(getattr(logging, level.upper(), logging.INFO))

    return logger
|
||||
Reference in New Issue
Block a user