Refactor: Move backend files to backend/ directory and split .gitignore
This commit is contained in:
0
backend/src/ea_chatbot/graph/__init__.py
Normal file
0
backend/src/ea_chatbot/graph/__init__.py
Normal file
44
backend/src/ea_chatbot/graph/checkpoint.py
Normal file
44
backend/src/ea_chatbot/graph/checkpoint.py
Normal file
@@ -0,0 +1,44 @@
|
||||
import contextlib
|
||||
from typing import AsyncGenerator
|
||||
from psycopg_pool import AsyncConnectionPool
|
||||
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||||
from ea_chatbot.config import Settings
|
||||
|
||||
# Module-level configuration; read once at import time.
settings = Settings()

# Process-wide connection-pool singleton, created lazily by get_pool().
_pool = None
def get_pool() -> AsyncConnectionPool:
    """Return the process-wide async connection pool, creating it lazily.

    The pool is configured from ``settings.history_db_url`` and is NOT
    opened here (``open=False``); callers open it on first use.
    """
    global _pool
    if _pool is not None:
        return _pool

    connection_kwargs = {"autocommit": True, "prepare_threshold": 0}
    _pool = AsyncConnectionPool(
        conninfo=settings.history_db_url,
        max_size=20,
        kwargs=connection_kwargs,
        open=False,  # Don't open automatically on init
    )
    return _pool
|
||||
@contextlib.asynccontextmanager
async def get_checkpointer() -> AsyncGenerator[AsyncPostgresSaver, None]:
    """Yield an ``AsyncPostgresSaver`` bound to a pooled connection.

    Opens the shared pool on first use and runs ``setup()`` so the
    checkpoint tables exist before the saver is handed to the caller.
    The connection is returned to the pool when the context exits.
    """
    pool = get_pool()
    if pool.closed:
        # Lazily open the pool the first time a checkpointer is requested.
        await pool.open()

    async with pool.connection() as connection:
        saver = AsyncPostgresSaver(connection)
        await saver.setup()  # create checkpoint tables if missing
        yield saver
|
||||
async def close_pool():
    """Close the module-level connection pool and reset the singleton.

    Resetting ``_pool`` to ``None`` matters: a closed psycopg
    ``AsyncConnectionPool`` cannot be reopened, so after a close a later
    ``get_pool()`` call must build a fresh pool instead of handing back
    the dead one (the previous version left the closed pool cached).
    """
    global _pool
    if _pool and not _pool.closed:
        await _pool.close()
    _pool = None
0
backend/src/ea_chatbot/graph/nodes/__init__.py
Normal file
0
backend/src/ea_chatbot/graph/nodes/__init__.py
Normal file
45
backend/src/ea_chatbot/graph/nodes/clarification.py
Normal file
45
backend/src/ea_chatbot/graph/nodes/clarification.py
Normal file
@@ -0,0 +1,45 @@
|
||||
from langchain_core.messages import AIMessage
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
from ea_chatbot.config import Settings
|
||||
from ea_chatbot.utils.llm_factory import get_llm_model
|
||||
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
|
||||
|
||||
def clarification_node(state: AgentState) -> dict:
    """Ask the user for missing information or clarifications.

    Builds a short prompt from the analyzer's list of ambiguities and has
    the LLM phrase a polite follow-up question. Returns the generated
    message plus ``next_action="end"`` so the graph stops and waits for
    the user's reply.
    """
    question = state["question"]
    ambiguities = state.get("analysis", {}).get("ambiguities", [])

    settings = Settings()
    logger = get_logger("clarification")
    logger.info(f"Generating clarification for {len(ambiguities)} ambiguities.")

    llm = get_llm_model(
        settings.query_analyzer_llm,
        callbacks=[LangChainLoggingHandler(logger=logger)],
    )

    system_prompt = """You are a Clarification Specialist. Your role is to identify what information is missing from a user's request to perform a data analysis or research task.
Based on the analysis of the user's question, formulate a polite and concise request for the missing information."""

    missing = ', '.join(ambiguities) if ambiguities else 'Unknown ambiguities'
    user_prompt = f"""Original Question: {question}
Missing/Ambiguous Information: {missing}

Please ask the user for the necessary details."""

    try:
        response = llm.invoke([("system", system_prompt), ("user", user_prompt)])
        logger.info("[bold green]Clarification generated.[/bold green]")
        # "end" tells the router we are done for now and await user input.
        return {"messages": [response], "next_action": "end"}
    except Exception as e:
        logger.error(f"Failed to generate clarification: {str(e)}")
        raise e
47
backend/src/ea_chatbot/graph/nodes/coder.py
Normal file
47
backend/src/ea_chatbot/graph/nodes/coder.py
Normal file
@@ -0,0 +1,47 @@
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
from ea_chatbot.config import Settings
|
||||
from ea_chatbot.utils.llm_factory import get_llm_model
|
||||
from ea_chatbot.utils import helpers, database_inspection
|
||||
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
|
||||
from ea_chatbot.graph.prompts.coder import CODE_GENERATOR_PROMPT
|
||||
from ea_chatbot.schemas import CodeGenerationResponse
|
||||
|
||||
def coder_node(state: AgentState) -> dict:
    """Generate Python code based on the plan and data summary.

    Feeds the task, the planner's YAML plan, the database schema summary,
    and any previous execution output into the code-generation prompt,
    then asks the coder LLM for a structured ``CodeGenerationResponse``.
    """
    question = state["question"]
    plan = state.get("plan", "")
    previous_output = state.get("code_output", "None")

    settings = Settings()
    logger = get_logger("coder")
    logger.info("Generating Python code...")

    llm = get_llm_model(
        settings.coder_llm,
        callbacks=[LangChainLoggingHandler(logger=logger)],
    )
    code_generator = llm.with_structured_output(CodeGenerationResponse)

    # The schema summary is always supplied; the LLM decides relevance.
    database_description = database_inspection.get_data_summary(data_dir=settings.data_dir) or "No data available."

    messages = CODE_GENERATOR_PROMPT.format_messages(
        question=question,
        plan=plan,
        database_description=database_description,
        code_exec_results=previous_output,
        example_code="",  # placeholder for future few-shot examples
    )

    try:
        response = code_generator.invoke(messages)
        logger.info("[bold green]Code generated.[/bold green]")
        # A fresh generation clears any stale error from a prior run.
        return {"code": response.parsed_code, "error": None}
    except Exception as e:
        logger.error(f"Failed to generate code: {str(e)}")
        raise e
44
backend/src/ea_chatbot/graph/nodes/error_corrector.py
Normal file
44
backend/src/ea_chatbot/graph/nodes/error_corrector.py
Normal file
@@ -0,0 +1,44 @@
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
from ea_chatbot.config import Settings
|
||||
from ea_chatbot.utils.llm_factory import get_llm_model
|
||||
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
|
||||
from ea_chatbot.graph.prompts.coder import ERROR_CORRECTOR_PROMPT
|
||||
from ea_chatbot.schemas import CodeGenerationResponse
|
||||
|
||||
def error_corrector_node(state: AgentState) -> dict:
    """Fix the code based on the execution error.

    Sends the failed code plus its (truncated-in-log) error to the coder
    LLM and returns a corrected version, clearing the error flag and
    bumping the retry counter so the graph can bound correction loops.
    """
    broken_code = state.get("code", "")
    error = state.get("error", "Unknown error")

    settings = Settings()
    logger = get_logger("error_corrector")

    logger.warning(f"[bold red]Execution error detected:[/bold red] {error[:100]}...")
    logger.info("Attempting to correct the code...")

    # The coder model is reused for corrections; a dedicated setting could
    # be added later if correction needs a different model.
    llm = get_llm_model(
        settings.coder_llm,
        callbacks=[LangChainLoggingHandler(logger=logger)],
    )
    fixer = llm.with_structured_output(CodeGenerationResponse)

    messages = ERROR_CORRECTOR_PROMPT.format_messages(code=broken_code, error=error)

    try:
        response = fixer.invoke(messages)
        logger.info("[bold green]Correction generated.[/bold green]")
        return {
            "code": response.parsed_code,
            "error": None,  # clear the error once a fix attempt exists
            "iterations": state.get("iterations", 0) + 1,
        }
    except Exception as e:
        logger.error(f"Failed to correct code: {str(e)}")
        raise e
102
backend/src/ea_chatbot/graph/nodes/executor.py
Normal file
102
backend/src/ea_chatbot/graph/nodes/executor.py
Normal file
@@ -0,0 +1,102 @@
|
||||
import io
|
||||
import sys
|
||||
import traceback
|
||||
from contextlib import redirect_stdout
|
||||
from typing import Any, Dict, List, TYPE_CHECKING
|
||||
import pandas as pd
|
||||
from matplotlib.figure import Figure
|
||||
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
from ea_chatbot.utils.db_client import DBClient
|
||||
from ea_chatbot.utils.logging import get_logger
|
||||
from ea_chatbot.config import Settings
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ea_chatbot.types import DBSettings
|
||||
|
||||
def executor_node(state: AgentState) -> dict:
    """Execute the generated Python code and capture stdout, plots, and dataframes.

    Returns a partial state update with:
      - ``code_output``: captured (possibly truncated) stdout
      - ``error``: filtered traceback string, or ``None`` on success
      - ``plots``: matplotlib Figures collected via the injected ``plots`` list
      - ``dfs``: DataFrames whose variable names appear in the executed code
    """
    code = state.get("code")
    logger = get_logger("executor")

    if not code:
        logger.error("No code provided to executor.")
        return {"error": "No code provided to executor."}

    logger.info("Executing Python code...")
    settings = Settings()

    db_settings: "DBSettings" = {
        "host": settings.db_host,
        "port": settings.db_port,
        "user": settings.db_user,
        "pswd": settings.db_pswd,
        "db": settings.db_name,
        "table": settings.db_table
    }

    db_client = DBClient(settings=db_settings)

    # Execution namespace: 'db' is the DBClient, 'plots' collects figures.
    # BUGFIX: a SINGLE dict is used as both globals and locals. The previous
    # exec(code, {}, local_vars) split them, so any function defined inside
    # the generated code could not see top-level names (imports, `db`, `pd`)
    # — nested scopes resolve free names through globals only.
    exec_env: Dict[str, Any] = {
        'db': db_client,
        'plots': [],
        'pd': pd
    }

    stdout_buffer = io.StringIO()
    error = None
    code_output = ""
    plots = []
    dfs = {}

    try:
        with redirect_stdout(stdout_buffer):
            # SECURITY NOTE: exec() of LLM-generated code is inherently
            # unsafe; this should run inside a sandboxed process/container.
            exec(code, exec_env)

        code_output = stdout_buffer.getvalue()

        # Truncate very long output to keep downstream prompts small.
        if code_output.count('\n') > 32:
            code_output = '\n'.join(code_output.split('\n')[:32]) + '\n...'

        # Collect matplotlib figures the code appended to `plots`.
        raw_plots = exec_env.get('plots', [])
        if isinstance(raw_plots, list):
            plots = [p for p in raw_plots if isinstance(p, Figure)]

        # Heuristic: DataFrames whose variable name appears in the code
        # were likely intended as results to surface.
        for key, value in exec_env.items():
            if isinstance(value, pd.DataFrame) and key in code:
                dfs[key] = value

        logger.info(f"[bold green]Execution complete.[/bold green] Captured {len(plots)} plots and {len(dfs)} dataframes.")

    except Exception as e:
        # Keep only traceback frames from the executed string so the error
        # shown to the corrector refers to the generated code, not ours.
        exc_type, exc_value, tb = sys.exc_info()
        full_traceback = traceback.format_exc()
        filtered_tb_lines = [line for line in full_traceback.split('\n') if '<string>' in line]
        error = '\n'.join(filtered_tb_lines)
        if error:
            error += '\n'
        error += f"{exc_type.__name__ if exc_type else 'Exception'}: {exc_value}"

        logger.error(f"Execution failed: {str(e)}")

        # Preserve any partial stdout produced before the failure.
        code_output = stdout_buffer.getvalue()

    return {
        "code_output": code_output,
        "error": error,
        "plots": plots,
        "dfs": dfs
    }
51
backend/src/ea_chatbot/graph/nodes/planner.py
Normal file
51
backend/src/ea_chatbot/graph/nodes/planner.py
Normal file
@@ -0,0 +1,51 @@
|
||||
import yaml
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
from ea_chatbot.config import Settings
|
||||
from ea_chatbot.utils.llm_factory import get_llm_model
|
||||
from ea_chatbot.utils import helpers, database_inspection
|
||||
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
|
||||
from ea_chatbot.graph.prompts.planner import PLANNER_PROMPT
|
||||
from ea_chatbot.schemas import TaskPlanResponse
|
||||
|
||||
def planner_node(state: AgentState) -> dict:
    """Generate a structured plan based on the query analysis.

    Invokes the planner LLM with the question, recent history, summary,
    and database schema, then serializes the structured
    ``TaskPlanResponse`` back to a YAML string stored under ``plan``.
    """
    question = state["question"]
    recent_history = state.get("messages", [])[-6:]
    summary = state.get("summary", "")

    settings = Settings()
    logger = get_logger("planner")
    logger.info("Generating task plan...")

    llm = get_llm_model(
        settings.planner_llm,
        callbacks=[LangChainLoggingHandler(logger=logger)],
    )
    structured_planner = llm.with_structured_output(TaskPlanResponse)

    # Always include the schema summary; the LLM decides whether it is relevant.
    database_description = database_inspection.get_data_summary(data_dir=settings.data_dir) or "No data available."

    messages = PLANNER_PROMPT.format_messages(
        date=helpers.get_readable_date(),
        question=question,
        history=recent_history,
        summary=summary,
        database_description=database_description,
        example_plan="",  # placeholder for future few-shot examples
    )

    try:
        response = structured_planner.invoke(messages)
        # Downstream nodes consume the plan as YAML text, not as a model.
        plan_yaml = yaml.dump(response.model_dump(), sort_keys=False)
        logger.info("[bold green]Plan generated successfully.[/bold green]")
        return {"plan": plan_yaml}
    except Exception as e:
        logger.error(f"Failed to generate plan: {str(e)}")
        raise e
73
backend/src/ea_chatbot/graph/nodes/query_analyzer.py
Normal file
73
backend/src/ea_chatbot/graph/nodes/query_analyzer.py
Normal file
@@ -0,0 +1,73 @@
|
||||
from typing import List, Literal
|
||||
from pydantic import BaseModel, Field
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
from ea_chatbot.config import Settings
|
||||
from ea_chatbot.utils.llm_factory import get_llm_model
|
||||
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
|
||||
from ea_chatbot.graph.prompts.query_analyzer import QUERY_ANALYZER_PROMPT
|
||||
|
||||
class QueryAnalysis(BaseModel):
    """Analysis of the user's query."""

    # NOTE: the Field descriptions below are consumed by the LLM's
    # structured-output schema — their wording is effectively prompt text,
    # so treat edits to them as prompt changes, not documentation changes.
    data_required: List[str] = Field(description="List of data points or entities mentioned (e.g., ['2024 results', 'Florida']).")
    unknowns: List[str] = Field(description="List of target information the user wants to know or needed for final answer (e.g., 'who won', 'total votes').")
    ambiguities: List[str] = Field(description="List of CRITICAL missing details that prevent ANY analysis. Do NOT include database names or plot types if defaults can be used.")
    conditions: List[str] = Field(description="List of any filters or constraints (e.g., ['year=2024', 'state=Florida']). Include context resolved from history.")
    next_action: Literal["plan", "clarify", "research"] = Field(description="The next action to take. 'plan' for data analysis (even with defaults), 'research' for general knowledge, or 'clarify' ONLY for critical ambiguities.")
||||
def query_analyzer_node(state: AgentState) -> dict:
    """Analyze the user's question and determine the next course of action.

    On success, returns the analysis (minus ``next_action``) plus the
    routing decision. On failure, falls back to ``next_action="clarify"``
    with the error recorded as an ambiguity, so the graph degrades
    gracefully instead of crashing.
    """
    question = state["question"]
    summary = state.get("summary", "")
    # Keep only the last 3 turns (6 messages) of history.
    recent_history = state.get("messages", [])[-6:]

    settings = Settings()
    logger = get_logger("query_analyzer")
    logger.info(f"Analyzing question: [italic]\"{question}\"[/italic]")

    # Structured-output LLM from the factory; the callback logs LLM usage.
    llm = get_llm_model(
        settings.query_analyzer_llm,
        callbacks=[LangChainLoggingHandler(logger=logger)],
    )
    analyzer = llm.with_structured_output(QueryAnalysis)

    messages = QUERY_ANALYZER_PROMPT.format_messages(
        question=question,
        history=recent_history,
        summary=summary,
    )

    try:
        raw_result = analyzer.invoke(messages)
        result = QueryAnalysis.model_validate(raw_result)

        analysis_dict = result.model_dump()
        # The routing decision travels separately in the state update.
        analysis_dict.pop("next_action")
        next_action = result.next_action

        logger.info(f"Analysis complete. Next action: [bold magenta]{next_action}[/bold magenta]")

    except Exception as e:
        logger.error(f"Error during query analysis: {str(e)}")
        analysis_dict = {
            "data_required": [],
            "unknowns": [],
            "ambiguities": [f"Error during analysis: {str(e)}"],
            "conditions": [],
        }
        next_action = "clarify"

    return {
        "analysis": analysis_dict,
        "next_action": next_action,
        "iterations": 0,
    }
|
||||
|
||||
60
backend/src/ea_chatbot/graph/nodes/researcher.py
Normal file
60
backend/src/ea_chatbot/graph/nodes/researcher.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
from ea_chatbot.config import Settings
|
||||
from ea_chatbot.utils.llm_factory import get_llm_model
|
||||
from ea_chatbot.utils import helpers
|
||||
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
|
||||
from ea_chatbot.graph.prompts.researcher import RESEARCHER_PROMPT
|
||||
|
||||
def researcher_node(state: AgentState) -> dict:
    """Handle general research queries or web searches.

    Binds a provider-native web-search tool when the underlying model
    supports one (Gemini or OpenAI); otherwise the bare model answers
    from its own knowledge.
    """
    question = state["question"]
    recent_history = state.get("messages", [])[-6:]
    summary = state.get("summary", "")

    settings = Settings()
    logger = get_logger("researcher")
    logger.info(f"Researching question: [italic]\"{question}\"[/italic]")

    llm = get_llm_model(
        settings.researcher_llm,
        callbacks=[LangChainLoggingHandler(logger=logger)],
    )

    messages = RESEARCHER_PROMPT.format_messages(
        date=helpers.get_readable_date(),
        question=question,
        history=recent_history,
        summary=summary,
    )

    # Provider-aware tool binding: each vendor exposes search differently.
    try:
        if isinstance(llm, ChatGoogleGenerativeAI):
            # Gemini's built-in Google Search tool.
            llm_with_tools = llm.bind_tools([{"google_search": {}}])
        elif isinstance(llm, ChatOpenAI):
            # OpenAI's built-in web search tool.
            llm_with_tools = llm.bind_tools([{"type": "web_search"}])
        else:
            # No known native search tool for this provider.
            llm_with_tools = llm
    except Exception as e:
        logger.warning(f"Failed to bind search tools: {str(e)}. Falling back to base LLM.")
        llm_with_tools = llm

    try:
        response = llm_with_tools.invoke(messages)
        logger.info("[bold green]Research complete.[/bold green]")
        return {"messages": [response]}
    except Exception as e:
        logger.error(f"Research failed: {str(e)}")
        raise e
52
backend/src/ea_chatbot/graph/nodes/summarize_conversation.py
Normal file
52
backend/src/ea_chatbot/graph/nodes/summarize_conversation.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from langchain_core.messages import SystemMessage, HumanMessage
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
from ea_chatbot.config import Settings
|
||||
from ea_chatbot.utils.llm_factory import get_llm_model
|
||||
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
|
||||
|
||||
def summarize_conversation_node(state: AgentState) -> dict:
    """Update the conversation summary based on the latest interaction.

    Folds only the most recent user/assistant exchange into the existing
    summary. On failure the previous summary is kept rather than raising.
    """
    summary = state.get("summary", "")
    messages = state.get("messages", [])

    # Nothing to summarize yet.
    if not messages:
        return {}

    # The most recent user + assistant exchange.
    last_turn = messages[-2:]

    settings = Settings()
    logger = get_logger("summarize_conversation")
    logger.info("Updating conversation summary...")

    llm = get_llm_model(
        settings.summarizer_llm,
        callbacks=[LangChainLoggingHandler(logger=logger)],
    )

    if summary:
        prompt = (
            f"This is a summary of the conversation so far: {summary}\n\n"
            "Extend the summary by taking into account the new messages above."
        )
    else:
        prompt = "Create a summary of the conversation above."

    system_content = f"Current summary: {summary}" if summary else "You are a helpful assistant that summarizes conversations."
    request = [
        SystemMessage(content=system_content),
        HumanMessage(content=f"Recent messages:\n{last_turn}\n\n{prompt}\n\nKeep the summary concise and focused on the key topics and data points discussed."),
    ]

    try:
        response = llm.invoke(request)
        logger.info("[bold green]Conversation summary updated.[/bold green]")
        return {"summary": response.content}
    except Exception as e:
        logger.error(f"Failed to update summary: {str(e)}")
        # If summarization fails, keep the old summary.
        return {"summary": summary}
44
backend/src/ea_chatbot/graph/nodes/summarizer.py
Normal file
44
backend/src/ea_chatbot/graph/nodes/summarizer.py
Normal file
@@ -0,0 +1,44 @@
|
||||
from langchain_core.messages import AIMessage
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
from ea_chatbot.config import Settings
|
||||
from ea_chatbot.utils.llm_factory import get_llm_model
|
||||
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
|
||||
from ea_chatbot.graph.prompts.summarizer import SUMMARIZER_PROMPT
|
||||
|
||||
def summarizer_node(state: AgentState) -> dict:
    """Summarize the code execution results into a final answer.

    Combines the question, the plan, the execution output, recent history,
    and the running summary into one final user-facing response message.
    """
    question = state["question"]
    plan = state.get("plan", "")
    code_output = state.get("code_output", "")
    recent_history = state.get("messages", [])[-6:]
    summary = state.get("summary", "")

    settings = Settings()
    logger = get_logger("summarizer")
    logger.info("Generating final summary...")

    llm = get_llm_model(
        settings.summarizer_llm,
        callbacks=[LangChainLoggingHandler(logger=logger)],
    )

    messages = SUMMARIZER_PROMPT.format_messages(
        question=question,
        plan=plan,
        code_output=code_output,
        history=recent_history,
        summary=summary,
    )

    try:
        response = llm.invoke(messages)
        logger.info("[bold green]Summary generated.[/bold green]")
        # The final answer message is appended to the conversation state.
        return {"messages": [response]}
    except Exception as e:
        logger.error(f"Failed to generate summary: {str(e)}")
        raise e
10
backend/src/ea_chatbot/graph/prompts/__init__.py
Normal file
10
backend/src/ea_chatbot/graph/prompts/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
# Public prompt registry for the graph.prompts package.
# Consistency fix: RESEARCHER_PROMPT and SUMMARIZER_PROMPT live in this
# package too but were not re-exported; adding them is backward-compatible.
from .query_analyzer import QUERY_ANALYZER_PROMPT
from .planner import PLANNER_PROMPT
from .coder import CODE_GENERATOR_PROMPT, ERROR_CORRECTOR_PROMPT
from .researcher import RESEARCHER_PROMPT
from .summarizer import SUMMARIZER_PROMPT

__all__ = [
    "QUERY_ANALYZER_PROMPT",
    "PLANNER_PROMPT",
    "CODE_GENERATOR_PROMPT",
    "ERROR_CORRECTOR_PROMPT",
    "RESEARCHER_PROMPT",
    "SUMMARIZER_PROMPT",
]
64
backend/src/ea_chatbot/graph/prompts/coder.py
Normal file
64
backend/src/ea_chatbot/graph/prompts/coder.py
Normal file
@@ -0,0 +1,64 @@
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
|
||||
# System prompt for first-pass code generation: defines the execution
# contract the generated code must follow (`db` client for SQL, `plots`
# list for figures, full runnable code with imports and prints).
CODE_GENERATOR_SYSTEM = """You are an AI data analyst and your job is to assist users with data analysis and coding tasks.
The user will provide a task and a plan.

**Data Access:**
- A database client is available as a variable named `db`.
- You MUST use `db.query_df(sql_query)` to execute SQL queries and retrieve data as a Pandas DataFrame.
- Do NOT assume a dataframe `df` is already loaded unless explicitly stated. You usually need to query it first.
- The database schema is described in the prompt. Use it to construct valid SQL queries.

**Plotting:**
- If you need to plot any data, use the `plots` list to store the figures.
- Example: `plots.append(fig)` or `plots.append(plt.gcf())`.
- Do not use `plt.show()` as it will render the plot and cause an error.

**Code Requirements:**
- Produce FULL, COMPLETE CODE that includes all steps and solves the task!
- Always include the import statements at the top of the code (e.g., `import pandas as pd`, `import matplotlib.pyplot as plt`).
- Always include print statements to output the results of your code.
- Use `db.query_df("SELECT ...")` to get data."""


# Human message for code generation: task, plan, schema, prior execution
# results, and an optional few-shot example slot ({example_code}).
CODE_GENERATOR_USER = """TASK:
{question}

PLAN:
```yaml
{plan}
```

AVAILABLE DATA SUMMARY (Database Schema):
{database_description}

CODE EXECUTION OF THE PREVIOUS TASK RESULTED IN:
{code_exec_results}

{example_code}"""


# System prompt for the error-correction pass; restates the same runtime
# contract so the fixed code stays compatible with the executor.
ERROR_CORRECTOR_SYSTEM = """The execution of the code resulted in an error.
Return a complete, corrected python code that incorporates the fixes for the error.

**Reminders:**
- You have access to a database client via the variable `db`.
- Use `db.query_df(sql)` to run queries.
- Use `plots.append(fig)` for plots.
- Always include imports and print statements."""


# Human message for error correction: the failed code and its traceback.
ERROR_CORRECTOR_USER = """FAILED CODE:
```python
{code}
```

ERROR:
{error}"""


# Assembled chat templates consumed by coder_node / error_corrector_node.
CODE_GENERATOR_PROMPT = ChatPromptTemplate.from_messages([
    ("system", CODE_GENERATOR_SYSTEM),
    ("human", CODE_GENERATOR_USER),
])

ERROR_CORRECTOR_PROMPT = ChatPromptTemplate.from_messages([
    ("system", ERROR_CORRECTOR_SYSTEM),
    ("human", ERROR_CORRECTOR_USER),
])
46
backend/src/ea_chatbot/graph/prompts/planner.py
Normal file
46
backend/src/ea_chatbot/graph/prompts/planner.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
|
||||
# NOTE(review): this system persona ("Research Specialist") appears to be
# copy-pasted from prompts/researcher.py; a planning-specific persona was
# probably intended here — confirm before changing, since the human prompt
# below carries the actual planning instructions.
PLANNER_SYSTEM = """You are a Research Specialist and your job is to find answers and educate the user.
Provide factual information responding directly to the user's question. Include key details and context to ensure your response comprehensively answers their query.

Today's Date is: {date}"""


# Main planning instructions: evaluate available info, reflect on the
# problem, and emit an up-to-eight-step algorithm as a ```yaml block.
# NOTE(review): "assume it is named 'df' and is already defined" conflicts
# with prompts/coder.py, which tells the coder NOT to assume a preloaded
# `df` and to fetch data via db.query_df(...) — confirm which contract
# the planner should describe.
PLANNER_USER = """Conversation Summary: {summary}

TASK:
{question}

AVAILABLE DATA SUMMARY (Use only if relevant to the task):
{database_description}

First: Evaluate whether you have all necessary and requested information to provide a solution.
Use the dataset description above to determine what data and in what format you have available to you.
You are able to search internet if the user asks for it, or you require any information that you can not derive from the given dataset or the instruction.

Second: Incorporate any additional relevant context, reasoning, or details from previous interactions or internal chain-of-thought that may impact the solution.
Ensure that all such information is fully included in your response rather than referring to previous answers indirectly.

Third: Reflect on the problem and briefly describe it, while addressing the problem goal, inputs, outputs,
rules, constraints, and other relevant details that appear in the problem description.

Fourth: Based on the preceding steps, formulate your response as an algorithm, breaking the solution in up to eight simple concise yet descriptive, clear English steps.
You MUST Include all values or instructions as described in the above task, or retrieved using internet search!
If fewer steps suffice, that's acceptable. If more are needed, please include them.
Remember to explain steps rather than write code.

This algorithm will be later converted to Python code.
If a dataframe is required, assume it is named 'df' and is already defined/populated based on the data summary.

There is a list variable called `plots` that you need to use to store any plots you generate. Do not use `plt.show()` as it will render the plot and cause an error.

Output the algorithm as a YAML string. Always enclose the YAML string within ```yaml tags.

**Note: Ensure that any necessary context from prior interactions is fully embedded in the plan. Do not use phrases like "refer to previous answer"; instead, provide complete details inline.**

{example_plan}"""


# Assembled template: system persona + rolling history + planning request.
PLANNER_PROMPT = ChatPromptTemplate.from_messages([
    ("system", PLANNER_SYSTEM),
    MessagesPlaceholder(variable_name="history"),
    ("human", PLANNER_USER),
])
33
backend/src/ea_chatbot/graph/prompts/query_analyzer.py
Normal file
33
backend/src/ea_chatbot/graph/prompts/query_analyzer.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
|
||||
# System prompt for the query-analyzer node: decomposes the question and
# routes between `plan`, `research`, and `clarify`. The defaults section
# deliberately discourages unnecessary clarification round-trips.
SYSTEM_PROMPT = """You are an expert election data analyst. Decompose the user's question into key elements to determine the next action.

### Context & Defaults
- **History:** Use the conversation history and summary to resolve coreferences (e.g., "those results", "that state"). Assume the current question inherits missing context (Year, State, County) from history.
- **Data Access:** You have access to voter and election databases. Proceed to planning without asking for database or table names.
- **Downstream Capabilities:** Visualizations are generated as Matplotlib figures. Proceed to planning for "graphs" or "plots" without asking for file formats or plot types.
- **Trends:** For trend requests without a specified interval, allow the Planner to use a sensible default (e.g., by election cycle).

### Instructions:
1. **Analyze:** Identify if the request is for data analysis, general facts (web research), or is critically ambiguous.
2. **Extract Entities & Conditions:**
- **Data Required:** e.g., "vote count", "demographics".
- **Conditions:** e.g., "Year=2024". Include context from history.
3. **Identify Target & Critical Ambiguities:**
- **Unknowns:** The core target question.
- **Critical Ambiguities:** ONLY list issues that PREVENT any analysis.
- Examples: No timeframe/geography in query OR history; "track the same voter" without an identity definition.
4. **Determine Action:**
- `plan`: For data analysis where defaults or history provide sufficient context.
- `research`: For general knowledge.
- `clarify`: ONLY for CRITICAL ambiguities."""


# Human message: running summary plus the raw question to analyze.
USER_PROMPT_TEMPLATE = """Conversation Summary: {summary}

Analyze the following question: {question}"""


# Assembled template consumed by query_analyzer_node.
QUERY_ANALYZER_PROMPT = ChatPromptTemplate.from_messages([
    ("system", SYSTEM_PROMPT),
    MessagesPlaceholder(variable_name="history"),
    ("human", USER_PROMPT_TEMPLATE),
])
12
backend/src/ea_chatbot/graph/prompts/researcher.py
Normal file
12
backend/src/ea_chatbot/graph/prompts/researcher.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder


# System turn: researcher persona; {date} keeps answers time-aware.
_SYSTEM_TEMPLATE = """You are a Research Specialist and your job is to find answers and educate the user.
Provide factual information responding directly to the user's question. Include key details and context to ensure your response comprehensively answers their query.

Today's Date is: {date}"""

# Human turn: rolling conversation summary plus the user's current question.
_USER_TEMPLATE = """Conversation Summary: {summary}

{question}"""

# Prompt for the researcher node: system role, prior turns, current question.
RESEARCHER_PROMPT = ChatPromptTemplate.from_messages([
    ("system", _SYSTEM_TEMPLATE),
    MessagesPlaceholder(variable_name="history"),
    ("user", _USER_TEMPLATE),
])
|
||||
27
backend/src/ea_chatbot/graph/prompts/summarizer.py
Normal file
27
backend/src/ea_chatbot/graph/prompts/summarizer.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder


# System turn: final-answer persona plus the rolling conversation summary.
_SYSTEM_TEMPLATE = """You are an expert election data analyst providing a final answer to the user.
Use the provided conversation history and summary to ensure your response is contextually relevant and flows naturally from previous turns.

Conversation Summary: {summary}"""

# Human turn: the question, the plan, and the raw code output, with
# instructions to turn them into a self-contained answer.
_USER_TEMPLATE = """The user presented you with the following question.
Question: {question}

To address this, you have designed an algorithm.
Algorithm: {plan}.

You have crafted a Python code based on this algorithm, and the output generated by the code's execution is as follows.
Output: {code_output}.

Please produce a comprehensive, easy-to-understand answer that:
1. Summarizes the main insights or conclusions achieved through your method's implementation. Include execution results if necessary.
2. Includes relevant findings from the code execution in a clear format (e.g., text explanation, tables, lists, bullet points).
- Avoid referencing the code or output as 'the above results' or saying 'it's in the code output.'
- Instead, present the actual key data or statistics within your explanation.
3. If the user requested specific information that does not appear in the code's output but you can provide it, include that information directly in your summary.
4. Present any data or tables that might have been generated by the code in full, since the user cannot directly see the execution output.

Your goal is to give a final answer that stands on its own without requiring the user to see the code or raw output directly."""

# Prompt for the summarizer node: system context, prior turns, task payload.
SUMMARIZER_PROMPT = ChatPromptTemplate.from_messages([
    ("system", _SYSTEM_TEMPLATE),
    MessagesPlaceholder(variable_name="history"),
    ("user", _USER_TEMPLATE),
])
|
||||
36
backend/src/ea_chatbot/graph/state.py
Normal file
36
backend/src/ea_chatbot/graph/state.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from typing import TypedDict, Annotated, List, Dict, Any, Optional
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain.agents import AgentState as AS
|
||||
import operator
|
||||
|
||||
class AgentState(AS):
    """Shared graph state threaded through every node of the workflow.

    Extends the imported base ``AgentState`` (aliased ``AS`` — presumably a
    TypedDict from langchain; confirm against the installed version).
    Fields annotated with ``operator.add`` accumulate across node returns
    instead of being replaced (LangGraph reducer convention).
    """

    # Conversation history; node returns append via operator.add.
    messages: Annotated[List[BaseMessage], operator.add]

    # The user's current question for this turn.
    question: str

    # Query Analysis (decomposition results).
    analysis: Optional[Dict[str, Any]]
    # Expected keys: "requires_dataset", "expert", "data", "unknown", "condition"

    # Step-by-step reasoning produced by the planner node.
    plan: Optional[str]

    # Code execution context: generated code, its captured output, and any
    # error text from the last execution attempt.
    code: Optional[str]
    code_output: Optional[str]
    error: Optional[str]

    # Artifacts (for UI display); plots accumulate via operator.add.
    plots: Annotated[List[Any], operator.add]  # Matplotlib figures
    dfs: Dict[str, Any]  # Pandas DataFrames

    # Rolling conversation summary.
    summary: Optional[str]

    # Routing hint consumed by the workflow router:
    # "clarify", "plan", "research", "end"
    next_action: str

    # Number of execution attempts (retry counter for the code-fix loop).
    iterations: int
||||
92
backend/src/ea_chatbot/graph/workflow.py
Normal file
92
backend/src/ea_chatbot/graph/workflow.py
Normal file
@@ -0,0 +1,92 @@
|
||||
from langgraph.graph import StateGraph, END
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
from ea_chatbot.graph.nodes.query_analyzer import query_analyzer_node
|
||||
from ea_chatbot.graph.nodes.planner import planner_node
|
||||
from ea_chatbot.graph.nodes.coder import coder_node
|
||||
from ea_chatbot.graph.nodes.error_corrector import error_corrector_node
|
||||
from ea_chatbot.graph.nodes.executor import executor_node
|
||||
from ea_chatbot.graph.nodes.summarizer import summarizer_node
|
||||
from ea_chatbot.graph.nodes.researcher import researcher_node
|
||||
from ea_chatbot.graph.nodes.clarification import clarification_node
|
||||
from ea_chatbot.graph.nodes.summarize_conversation import summarize_conversation_node
|
||||
|
||||
# Cap on executor -> error_corrector retry round-trips before the workflow
# gives up and routes to the summarizer with whatever output/error it has.
MAX_ITERATIONS = 3
|
||||
|
||||
def router(state: AgentState) -> str:
    """Route to the next node based on the analysis."""
    # Dispatch table from the analyzer's routing hint to a node name;
    # anything unrecognized (including None/"end") terminates the graph.
    routes = {
        "plan": "planner",
        "research": "researcher",
        "clarify": "clarification",
    }
    return routes.get(state.get("next_action"), END)
|
||||
|
||||
def create_workflow():
    """Create the LangGraph workflow."""
    workflow = StateGraph(AgentState)

    # Register every node once; the keys double as edge targets below.
    node_registry = {
        "query_analyzer": query_analyzer_node,
        "planner": planner_node,
        "coder": coder_node,
        "error_corrector": error_corrector_node,
        "researcher": researcher_node,
        "clarification": clarification_node,
        "executor": executor_node,
        "summarizer": summarizer_node,
        "summarize_conversation": summarize_conversation_node,
    }
    for node_name, node_fn in node_registry.items():
        workflow.add_node(node_name, node_fn)

    # Everything starts at the query analyzer.
    workflow.set_entry_point("query_analyzer")

    # Fan out from the analyzer according to router()'s decision.
    workflow.add_conditional_edges(
        "query_analyzer",
        router,
        {
            "planner": "planner",
            "researcher": "researcher",
            "clarification": "clarification",
            END: END,
        },
    )

    # Linear flow for planning and coding.
    workflow.add_edge("planner", "coder")
    workflow.add_edge("coder", "executor")

    def executor_router(state: AgentState) -> str:
        # Success goes straight to the summarizer; failures retry through
        # the error corrector until MAX_ITERATIONS, then summarize anyway
        # to avoid an infinite correction loop.
        if not state.get("error"):
            return "summarizer"
        if state.get("iterations", 0) >= MAX_ITERATIONS:
            return "summarizer"
        return "error_corrector"

    workflow.add_conditional_edges(
        "executor",
        executor_router,
        {
            "error_corrector": "error_corrector",
            "summarizer": "summarizer",
        },
    )

    # Corrected code is re-executed.
    workflow.add_edge("error_corrector", "executor")

    # Terminal paths: research and summarization refresh the conversation
    # summary before ending; clarification ends immediately.
    workflow.add_edge("researcher", "summarize_conversation")
    workflow.add_edge("clarification", END)
    workflow.add_edge("summarizer", "summarize_conversation")
    workflow.add_edge("summarize_conversation", END)

    # Compile the graph into a runnable app.
    return workflow.compile()


# Compiled application graph, built once at import time.
app = create_workflow()
|
||||
Reference in New Issue
Block a user