chore(graph): Remove obsolete linear nodes and legacy tests
This commit is contained in:
@@ -1,47 +0,0 @@
|
|||||||
from ea_chatbot.graph.state import AgentState
|
|
||||||
from ea_chatbot.config import Settings
|
|
||||||
from ea_chatbot.utils.llm_factory import get_llm_model
|
|
||||||
from ea_chatbot.utils import database_inspection
|
|
||||||
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
|
|
||||||
from ea_chatbot.graph.prompts.coder import CODE_GENERATOR_PROMPT
|
|
||||||
from ea_chatbot.schemas import CodeGenerationResponse
|
|
||||||
|
|
||||||
def coder_node(state: AgentState) -> dict:
    """Generate Python code based on the plan and data summary.

    Reads the question, plan and previous code output from the agent state,
    prompts the coder LLM for a structured response, and returns the parsed
    code with any previous error cleared.

    Args:
        state: Current agent state; must contain "question", may contain
            "plan" and "code_output".

    Returns:
        dict with "code" (generated Python source) and "error" (always None,
        clearing any prior failure).

    Raises:
        Exception: re-raised from the LLM invocation after logging.
    """
    question = state["question"]
    plan = state.get("plan", "")
    code_output = state.get("code_output", "None")

    settings = Settings()
    logger = get_logger("coder")

    logger.info("Generating Python code...")

    llm = get_llm_model(
        settings.coder_llm,
        callbacks=[LangChainLoggingHandler(logger=logger)]
    )
    structured_llm = llm.with_structured_output(CodeGenerationResponse)

    # Always provide data summary so the model knows what data is available.
    database_description = database_inspection.get_data_summary(data_dir=settings.data_dir) or "No data available."
    example_code = ""  # Placeholder

    messages = CODE_GENERATOR_PROMPT.format_messages(
        question=question,
        plan=plan,
        database_description=database_description,
        code_exec_results=code_output,
        example_code=example_code
    )

    try:
        response = structured_llm.invoke(messages)
        logger.info("[bold green]Code generated.[/bold green]")
        return {
            "code": response.parsed_code,
            "error": None  # Clear previous errors on new code generation
        }
    except Exception as e:
        logger.error(f"Failed to generate code: {str(e)}")
        # Bare raise preserves the original traceback; `raise e` would
        # re-anchor it at this line.
        raise
|
|
||||||
@@ -1,44 +0,0 @@
|
|||||||
from ea_chatbot.graph.state import AgentState
|
|
||||||
from ea_chatbot.config import Settings
|
|
||||||
from ea_chatbot.utils.llm_factory import get_llm_model
|
|
||||||
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
|
|
||||||
from ea_chatbot.graph.prompts.coder import ERROR_CORRECTOR_PROMPT
|
|
||||||
from ea_chatbot.schemas import CodeGenerationResponse
|
|
||||||
|
|
||||||
def error_corrector_node(state: AgentState) -> dict:
    """Fix the code based on the execution error.

    Feeds the failing code and its error message to the coder LLM and
    returns a corrected version, bumping the retry counter.

    Args:
        state: Agent state; reads "code", "error" and "iterations".

    Returns:
        dict with the corrected "code", "error" reset to None, and
        "iterations" incremented by one.

    Raises:
        Exception: re-raised from the LLM invocation after logging.
    """
    code = state.get("code", "")
    # `state.get("error", default)` still yields None when the key exists
    # with a None value, which would crash the slice below — use `or`.
    error = state.get("error") or "Unknown error"

    settings = Settings()
    logger = get_logger("error_corrector")

    logger.warning(f"[bold red]Execution error detected:[/bold red] {error[:100]}...")
    logger.info("Attempting to correct the code...")

    # Reuse coder LLM config or add a new one. Using coder_llm for now.
    llm = get_llm_model(
        settings.coder_llm,
        callbacks=[LangChainLoggingHandler(logger=logger)]
    )
    structured_llm = llm.with_structured_output(CodeGenerationResponse)

    messages = ERROR_CORRECTOR_PROMPT.format_messages(
        code=code,
        error=error
    )

    try:
        response = structured_llm.invoke(messages)
        logger.info("[bold green]Correction generated.[/bold green]")

        current_iterations = state.get("iterations", 0)

        return {
            "code": response.parsed_code,
            "error": None,  # Clear error after fix attempt
            "iterations": current_iterations + 1
        }
    except Exception as e:
        logger.error(f"Failed to correct code: {str(e)}")
        # Bare raise preserves the original traceback; `raise e` would
        # re-anchor it at this line.
        raise
|
|
||||||
@@ -1,110 +0,0 @@
|
|||||||
import io
|
|
||||||
import sys
|
|
||||||
import traceback
|
|
||||||
import copy
|
|
||||||
from contextlib import redirect_stdout
|
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
import pandas as pd
|
|
||||||
from matplotlib.figure import Figure
|
|
||||||
|
|
||||||
from ea_chatbot.graph.state import AgentState
|
|
||||||
from ea_chatbot.utils.db_client import DBClient
|
|
||||||
from ea_chatbot.utils.vfs import VFSHelper, safe_vfs_copy
|
|
||||||
from ea_chatbot.utils.logging import get_logger
|
|
||||||
from ea_chatbot.config import Settings
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from ea_chatbot.types import DBSettings
|
|
||||||
|
|
||||||
def executor_node(state: AgentState) -> dict:
    """Execute the Python code and capture output, plots, and dataframes.

    Runs the generated code in a namespace pre-populated with a database
    client (``db``), a plot collector list (``plots``), pandas (``pd``) and
    a virtual-file-system helper (``vfs``). Captures stdout (truncated to
    32 lines), matplotlib figures appended to ``plots``, and DataFrames
    whose variable names appear in the code.

    Args:
        state: Agent state; reads "code" and "vfs".

    Returns:
        dict with "code_output", "error" (None on success), "plots",
        "dfs" and the updated "vfs" mapping.
    """
    code = state.get("code")
    logger = get_logger("executor")

    if not code:
        logger.error("No code provided to executor.")
        return {"error": "No code provided to executor."}

    logger.info("Executing Python code...")
    settings = Settings()

    db_settings: "DBSettings" = {
        "host": settings.db_host,
        "port": settings.db_port,
        "user": settings.db_user,
        "pswd": settings.db_pswd,
        "db": settings.db_name,
        "table": settings.db_table
    }

    db_client = DBClient(settings=db_settings)

    # Initialize the Virtual File System (VFS) helper on a safe copy so
    # the incoming state mapping is never mutated in place.
    vfs_state = safe_vfs_copy(state.get("vfs", {}))
    vfs_helper = VFSHelper(vfs_state)

    # Execution namespace: 'db' is the DBClient instance, 'plots' collects
    # matplotlib figures produced by the generated code.
    exec_ns = {
        'db': db_client,
        'plots': [],
        'pd': pd,
        'vfs': vfs_helper
    }

    stdout_buffer = io.StringIO()
    error = None
    code_output = ""
    plots = []
    dfs = {}

    try:
        with redirect_stdout(stdout_buffer):
            # NOTE(security): exec of LLM-generated code is inherently
            # unsafe; this relies on the surrounding deployment to sandbox
            # the process.
            # A single dict serves as both globals and locals: with
            # separate dicts (exec(code, {}, ns)) top-level bindings land
            # in locals and are invisible inside functions the generated
            # code defines, breaking any code that declares helpers.
            exec(code, exec_ns)

        code_output = stdout_buffer.getvalue()

        # Limit the output length if it's too long
        if code_output.count('\n') > 32:
            code_output = '\n'.join(code_output.split('\n')[:32]) + '\n...'

        # Extract plots
        raw_plots = exec_ns.get('plots', [])
        if isinstance(raw_plots, list):
            plots = [p for p in raw_plots if isinstance(p, Figure)]

        # Extract DataFrames that were likely intended for display.
        # Heuristic: a DataFrame whose variable name appears in the code
        # is probably a result the user wants to see.
        for key, value in exec_ns.items():
            if isinstance(value, pd.DataFrame) and key in code:
                dfs[key] = value

        logger.info(f"[bold green]Execution complete.[/bold green] Captured {len(plots)} plots and {len(dfs)} dataframes.")

    except Exception as e:
        # Capture the traceback
        exc_type, exc_value, tb = sys.exc_info()
        full_traceback = traceback.format_exc()

        # Filter traceback to show only the relevant part (the executed string)
        filtered_tb_lines = [line for line in full_traceback.split('\n') if '<string>' in line]
        error = '\n'.join(filtered_tb_lines)
        if error:
            error += '\n'
        error += f"{exc_type.__name__ if exc_type else 'Exception'}: {exc_value}"

        logger.error(f"Execution failed: {str(e)}")

        # If we have an error, we still might want to see partial stdout
        code_output = stdout_buffer.getvalue()

    return {
        "code_output": code_output,
        "error": error,
        "plots": plots,
        "dfs": dfs,
        "vfs": vfs_state
    }
|
|
||||||
@@ -38,7 +38,7 @@ def planner_node(state: AgentState) -> dict:
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = structured_llm.invoke(messages)
|
response = ChecklistResponse.model_validate(structured_llm.invoke(messages))
|
||||||
# Convert ChecklistTask objects to dicts for state
|
# Convert ChecklistTask objects to dicts for state
|
||||||
checklist = [task.model_dump() for task in response.checklist]
|
checklist = [task.model_dump() for task in response.checklist]
|
||||||
logger.info(f"[bold green]Checklist generated with {len(checklist)} tasks.[/bold green]")
|
logger.info(f"[bold green]Checklist generated with {len(checklist)} tasks.[/bold green]")
|
||||||
|
|||||||
@@ -1,59 +0,0 @@
|
|||||||
from langchain_openai import ChatOpenAI
|
|
||||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
|
||||||
from ea_chatbot.graph.state import AgentState
|
|
||||||
from ea_chatbot.config import Settings
|
|
||||||
from ea_chatbot.utils.llm_factory import get_llm_model
|
|
||||||
from ea_chatbot.utils import helpers
|
|
||||||
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
|
|
||||||
from ea_chatbot.graph.prompts.researcher import RESEARCHER_PROMPT
|
|
||||||
|
|
||||||
def researcher_node(state: AgentState) -> dict:
    """Handle general research queries or web searches.

    Binds the provider-native search tool when the configured LLM supports
    one (Gemini or OpenAI), falling back to the plain model otherwise, and
    answers the user's question.

    Args:
        state: Agent state; reads "question", the last six "messages" and
            "summary".

    Returns:
        dict with "messages" containing the LLM response.

    Raises:
        Exception: re-raised from the LLM invocation after logging.
    """
    question = state["question"]
    history = state.get("messages", [])[-6:]
    summary = state.get("summary", "")

    settings = Settings()
    logger = get_logger("researcher")

    logger.info(f"Researching question: [italic]\"{question}\"[/italic]")

    # Use researcher_llm from settings
    llm = get_llm_model(
        settings.researcher_llm,
        callbacks=[LangChainLoggingHandler(logger=logger)]
    )

    date_str = helpers.get_readable_date()

    messages = RESEARCHER_PROMPT.format_messages(
        date=date_str,
        question=question,
        history=history,
        summary=summary
    )

    # Provider-aware tool binding
    try:
        if isinstance(llm, ChatGoogleGenerativeAI):
            # Native Google Search for Gemini
            llm_with_tools = llm.bind_tools([{"google_search": {}}])
        elif isinstance(llm, ChatOpenAI):
            # Native Web Search for OpenAI (built-in tool)
            llm_with_tools = llm.bind_tools([{"type": "web_search"}])
        else:
            # Fallback for other providers that might not support these specific search tools
            llm_with_tools = llm
    except Exception as e:
        logger.warning(f"Failed to bind search tools: {str(e)}. Falling back to base LLM.")
        llm_with_tools = llm

    try:
        response = llm_with_tools.invoke(messages)
        logger.info("[bold green]Research complete.[/bold green]")
        return {
            "messages": [response]
        }
    except Exception as e:
        logger.error(f"Research failed: {str(e)}")
        # Bare raise preserves the original traceback; `raise e` would
        # re-anchor it at this line.
        raise
|
|
||||||
@@ -1,43 +0,0 @@
|
|||||||
from ea_chatbot.graph.state import AgentState
|
|
||||||
from ea_chatbot.config import Settings
|
|
||||||
from ea_chatbot.utils.llm_factory import get_llm_model
|
|
||||||
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
|
|
||||||
from ea_chatbot.graph.prompts.summarizer import SUMMARIZER_PROMPT
|
|
||||||
|
|
||||||
def summarizer_node(state: AgentState) -> dict:
    """Summarize the code execution results into a final answer.

    Args:
        state: Agent state; reads "question", "plan", "code_output", the
            last six "messages" and "summary".

    Returns:
        dict with "messages" containing the final answer message.

    Raises:
        Exception: re-raised from the LLM invocation after logging.
    """
    question = state["question"]
    plan = state.get("plan", "")
    code_output = state.get("code_output", "")
    history = state.get("messages", [])[-6:]
    summary = state.get("summary", "")

    settings = Settings()
    logger = get_logger("summarizer")

    logger.info("Generating final summary...")

    llm = get_llm_model(
        settings.summarizer_llm,
        callbacks=[LangChainLoggingHandler(logger=logger)]
    )

    messages = SUMMARIZER_PROMPT.format_messages(
        question=question,
        plan=plan,
        code_output=code_output,
        history=history,
        summary=summary
    )

    try:
        response = llm.invoke(messages)
        logger.info("[bold green]Summary generated.[/bold green]")

        # Return the final message to be added to the state
        return {
            "messages": [response]
        }
    except Exception as e:
        logger.error(f"Failed to generate summary: {str(e)}")
        # Bare raise preserves the original traceback; `raise e` would
        # re-anchor it at this line.
        raise
|
|
||||||
@@ -1,62 +0,0 @@
|
|||||||
import pytest
from unittest.mock import MagicMock, patch
from ea_chatbot.graph.nodes.coder import coder_node
from ea_chatbot.graph.nodes.error_corrector import error_corrector_node
from ea_chatbot.schemas import CodeGenerationResponse


@pytest.fixture
def mock_state():
    """Baseline agent state shared by the coder/corrector tests."""
    return {
        "messages": [],
        "question": "Show me results for New Jersey",
        "plan": "Step 1: Load data\nStep 2: Filter by NJ",
        "code": None,
        "error": None,
        "plots": [],
        "dfs": {},
        "next_action": "plan",
    }


def _structured_response(code, explanation):
    """Build the CodeGenerationResponse a mocked structured LLM returns."""
    return CodeGenerationResponse(code=code, explanation=explanation)


@patch("ea_chatbot.graph.nodes.coder.get_llm_model")
@patch("ea_chatbot.utils.database_inspection.get_data_summary")
def test_coder_node(mock_get_summary, mock_get_llm, mock_state):
    """Test coder node generates code from plan."""
    mock_get_summary.return_value = "Column: Name, Type: text"

    llm = MagicMock()
    mock_get_llm.return_value = llm
    llm.with_structured_output.return_value.invoke.return_value = _structured_response(
        "import pandas as pd\nprint('Hello')", "Generated code"
    )

    result = coder_node(mock_state)

    assert "code" in result
    assert "import pandas as pd" in result["code"]
    assert "error" in result
    assert result["error"] is None


@patch("ea_chatbot.graph.nodes.error_corrector.get_llm_model")
def test_error_corrector_node(mock_get_llm, mock_state):
    """Test error corrector node fixes code."""
    mock_state["code"] = "import pandas as pd\nprint(undefined_var)"
    mock_state["error"] = "NameError: name 'undefined_var' is not defined"

    llm = MagicMock()
    mock_get_llm.return_value = llm
    llm.with_structured_output.return_value.invoke.return_value = _structured_response(
        "import pandas as pd\nprint('Defined')", "Fixed variable"
    )

    result = error_corrector_node(mock_state)

    assert "code" in result
    assert "print('Defined')" in result["code"]
    assert result["error"] is None
|
|
||||||
@@ -1,122 +0,0 @@
|
|||||||
import pytest
import pandas as pd
from unittest.mock import MagicMock, patch
from matplotlib.figure import Figure
from ea_chatbot.graph.nodes.executor import executor_node


def _make_state(code):
    """Wrap *code* in the minimal state dict the executor expects."""
    return {"code": code, "question": "test", "messages": []}


@pytest.fixture
def mock_settings():
    """Patch Settings inside the executor module with canned DB config."""
    with patch("ea_chatbot.graph.nodes.executor.Settings") as settings_cls:
        instance = MagicMock()
        instance.configure_mock(
            db_host="localhost",
            db_port=5432,
            db_user="user",
            db_pswd="pass",
            db_name="test_db",
            db_table="test_table",
        )
        settings_cls.return_value = instance
        yield instance


@pytest.fixture
def mock_db_client():
    """Patch DBClient inside the executor module with a MagicMock."""
    with patch("ea_chatbot.graph.nodes.executor.DBClient") as client_cls:
        instance = MagicMock()
        client_cls.return_value = instance
        yield instance


def test_executor_node_success_simple_print(mock_settings, mock_db_client):
    """Test executing simple code that prints to stdout."""
    result = executor_node(_make_state("print('Hello, World!')"))

    assert "code_output" in result
    assert "Hello, World!" in result["code_output"]
    assert result["error"] is None
    assert result["plots"] == []
    assert result["dfs"] == {}


def test_executor_node_success_dataframe(mock_settings, mock_db_client):
    """Test executing code that creates a DataFrame."""
    code = """
import pandas as pd
df = pd.DataFrame({'a': [1, 2], 'b': [3, 4]})
print(df)
"""
    result = executor_node(_make_state(code))

    assert "code_output" in result
    assert "a b" in result["code_output"]  # Check part of DF string representation
    assert "dfs" in result
    assert "df" in result["dfs"]
    assert isinstance(result["dfs"]["df"], pd.DataFrame)


def test_executor_node_success_plot(mock_settings, mock_db_client):
    """Test executing code that generates a plot."""
    code = """
import matplotlib.pyplot as plt
fig = plt.figure()
plots.append(fig)
print('Plot generated')
"""
    result = executor_node(_make_state(code))

    assert "Plot generated" in result["code_output"]
    assert "plots" in result
    assert len(result["plots"]) == 1
    assert isinstance(result["plots"][0], Figure)


def test_executor_node_error_syntax(mock_settings, mock_db_client):
    """Test executing code with a syntax error."""
    # Missing closing quote
    result = executor_node(_make_state("print('Hello World"))

    assert result["error"] is not None
    assert "SyntaxError" in result["error"]


def test_executor_node_error_runtime(mock_settings, mock_db_client):
    """Test executing code with a runtime error."""
    result = executor_node(_make_state("print(1 / 0)"))

    assert result["error"] is not None
    assert "ZeroDivisionError" in result["error"]


def test_executor_node_no_code(mock_settings, mock_db_client):
    """Test handling when no code is provided."""
    result = executor_node(_make_state(None))

    assert "error" in result
    assert "No code provided" in result["error"]
|
|
||||||
@@ -1,36 +0,0 @@
|
|||||||
from ea_chatbot.graph.nodes.executor import executor_node
from ea_chatbot.graph.state import AgentState


def test_executor_with_vfs():
    """Verify that the executor node provides VFS access to the code."""
    # Code that uses the 'vfs' helper
    code = """
vfs.write("output.txt", "Execution Result", metadata={"type": "text"})
print("VFS Write Complete")
"""
    state = AgentState(
        messages=[],
        question="test",
        analysis={},
        next_action="test",
        iterations=0,
        checklist=[],
        current_step=0,
        vfs={},
        plots=[],
        dfs={}
    )
    state["code"] = code

    result = executor_node(state)

    # The execution must succeed and the print output must be captured.
    assert result["error"] is None
    assert "VFS Write Complete" in result["code_output"]

    # executor_node returns a dict of updates, which should include the
    # updated 'vfs' mapping reflecting the write above.
    assert "vfs" in result
    assert "output.txt" in result["vfs"]
    assert result["vfs"]["output.txt"]["content"] == "Execution Result"
|
|
||||||
@@ -1,34 +0,0 @@
|
|||||||
import pytest
from unittest.mock import MagicMock, patch
from langchain_core.messages import AIMessage
from langchain_openai import ChatOpenAI
from ea_chatbot.graph.nodes.researcher import researcher_node


@pytest.fixture
def mock_llm():
    """Patch get_llm_model in the researcher module with an OpenAI-spec mock."""
    with patch("ea_chatbot.graph.nodes.researcher.get_llm_model") as get_llm:
        llm = MagicMock(spec=ChatOpenAI)
        get_llm.return_value = llm
        yield llm


def test_researcher_node_success(mock_llm):
    """Test that researcher_node invokes LLM with web_search tool and returns messages."""
    bound = MagicMock()
    mock_llm.bind_tools.return_value = bound
    bound.invoke.return_value = AIMessage(content="The capital of France is Paris.")

    state = {
        "question": "What is the capital of France?",
        "messages": []
    }
    result = researcher_node(state)

    assert mock_llm.bind_tools.called
    # Check that it was called with web_search
    args, kwargs = mock_llm.bind_tools.call_args
    assert {"type": "web_search"} in args[0]

    assert bound.invoke.called
    assert "messages" in result
    assert result["messages"][0].content == "The capital of France is Paris."
|
|
||||||
@@ -1,62 +0,0 @@
|
|||||||
import pytest
from unittest.mock import MagicMock, patch
from langchain_core.messages import AIMessage
from langchain_openai import ChatOpenAI
from langchain_google_genai import ChatGoogleGenerativeAI
from ea_chatbot.graph.nodes.researcher import researcher_node


@pytest.fixture
def base_state():
    """Minimal research state shared by the tool-binding tests."""
    return {
        "question": "Who won the 2024 election?",
        "messages": [],
        "summary": "",
    }


def _wire_bound_llm(llm, content):
    """Make llm.bind_tools return a mock whose invoke yields *content*."""
    bound = MagicMock()
    llm.bind_tools.return_value = bound
    bound.invoke.return_value = AIMessage(content=content)
    return bound


@patch("ea_chatbot.graph.nodes.researcher.get_llm_model")
def test_researcher_binds_openai_search(mock_get_llm, base_state):
    """Test that OpenAI LLM binds 'web_search' tool."""
    llm = MagicMock(spec=ChatOpenAI)
    mock_get_llm.return_value = llm
    _wire_bound_llm(llm, "OpenAI Search Result")

    result = researcher_node(base_state)

    # Verify bind_tools called with correct OpenAI tool
    llm.bind_tools.assert_called_once_with([{"type": "web_search"}])
    assert result["messages"][0].content == "OpenAI Search Result"


@patch("ea_chatbot.graph.nodes.researcher.get_llm_model")
def test_researcher_binds_google_search(mock_get_llm, base_state):
    """Test that Google LLM binds 'google_search' tool."""
    llm = MagicMock(spec=ChatGoogleGenerativeAI)
    mock_get_llm.return_value = llm
    _wire_bound_llm(llm, "Google Search Result")

    result = researcher_node(base_state)

    # Verify bind_tools called with correct Google tool
    llm.bind_tools.assert_called_once_with([{"google_search": {}}])
    assert result["messages"][0].content == "Google Search Result"


@patch("ea_chatbot.graph.nodes.researcher.get_llm_model")
def test_researcher_fallback_on_bind_error(mock_get_llm, base_state):
    """Test that researcher falls back to basic LLM if bind_tools fails."""
    llm = MagicMock(spec=ChatOpenAI)
    mock_get_llm.return_value = llm

    # Simulate bind_tools failing (e.g. model doesn't support it)
    llm.bind_tools.side_effect = Exception("Not supported")
    llm.invoke.return_value = AIMessage(content="Basic Result")

    result = researcher_node(base_state)

    # Should still succeed using the base LLM
    assert result["messages"][0].content == "Basic Result"
    llm.invoke.assert_called_once()
|
|
||||||
@@ -1,47 +0,0 @@
|
|||||||
import pytest
from unittest.mock import MagicMock, patch
from langchain_core.messages import AIMessage
from ea_chatbot.graph.nodes.summarizer import summarizer_node


@pytest.fixture
def mock_llm():
    """Patch get_llm_model in the summarizer module with a plain mock."""
    with patch("ea_chatbot.graph.nodes.summarizer.get_llm_model") as get_llm:
        llm = MagicMock()
        get_llm.return_value = llm
        yield llm


def test_summarizer_node_success(mock_llm):
    """Test that summarizer_node invokes LLM with correct inputs and returns messages."""
    mock_llm.invoke.return_value = AIMessage(content="The final answer is 100.")

    state = {
        "question": "What is the total count?",
        "plan": "1. Run query\n2. Sum results",
        "code_output": "The total is 100",
        "messages": [],
    }
    result = summarizer_node(state)

    # Verify LLM was called
    assert mock_llm.invoke.called

    # Verify result structure
    assert "messages" in result
    assert len(result["messages"]) == 1
    assert isinstance(result["messages"][0], AIMessage)
    assert result["messages"][0].content == "The final answer is 100."


def test_summarizer_node_empty_state(mock_llm):
    """Test handling of empty or minimal state."""
    mock_llm.invoke.return_value = AIMessage(content="No data provided.")

    result = summarizer_node({"question": "Empty?", "messages": []})

    assert "messages" in result
    assert result["messages"][0].content == "No data provided."
|
|
||||||
@@ -12,7 +12,8 @@ def mock_llms():
|
|||||||
patch("ea_chatbot.graph.workers.data_analyst.nodes.coder.get_llm_model") as mock_coder, \
|
patch("ea_chatbot.graph.workers.data_analyst.nodes.coder.get_llm_model") as mock_coder, \
|
||||||
patch("ea_chatbot.graph.workers.data_analyst.nodes.summarizer.get_llm_model") as mock_worker_summarizer, \
|
patch("ea_chatbot.graph.workers.data_analyst.nodes.summarizer.get_llm_model") as mock_worker_summarizer, \
|
||||||
patch("ea_chatbot.graph.nodes.synthesizer.get_llm_model") as mock_synthesizer, \
|
patch("ea_chatbot.graph.nodes.synthesizer.get_llm_model") as mock_synthesizer, \
|
||||||
patch("ea_chatbot.graph.nodes.researcher.get_llm_model") as mock_researcher, \
|
patch("ea_chatbot.graph.workers.researcher.nodes.searcher.get_llm_model") as mock_researcher, \
|
||||||
|
patch("ea_chatbot.graph.workers.researcher.nodes.summarizer.get_llm_model") as mock_res_summarizer, \
|
||||||
patch("ea_chatbot.graph.nodes.reflector.get_llm_model") as mock_reflector:
|
patch("ea_chatbot.graph.nodes.reflector.get_llm_model") as mock_reflector:
|
||||||
yield {
|
yield {
|
||||||
"qa": mock_qa,
|
"qa": mock_qa,
|
||||||
@@ -21,6 +22,7 @@ def mock_llms():
|
|||||||
"worker_summarizer": mock_worker_summarizer,
|
"worker_summarizer": mock_worker_summarizer,
|
||||||
"synthesizer": mock_synthesizer,
|
"synthesizer": mock_synthesizer,
|
||||||
"researcher": mock_researcher,
|
"researcher": mock_researcher,
|
||||||
|
"res_summarizer": mock_res_summarizer,
|
||||||
"reflector": mock_reflector
|
"reflector": mock_reflector
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -113,17 +115,22 @@ def test_workflow_research_flow(mock_llms):
|
|||||||
checklist=[ChecklistTask(task="Search Web", worker="researcher")]
|
checklist=[ChecklistTask(task="Search Web", worker="researcher")]
|
||||||
)
|
)
|
||||||
|
|
||||||
# 3. Mock Researcher
|
# 3. Mock Researcher Searcher
|
||||||
mock_res_instance = MagicMock()
|
mock_res_instance = MagicMock()
|
||||||
mock_llms["researcher"].return_value = mock_res_instance
|
mock_llms["researcher"].return_value = mock_res_instance
|
||||||
mock_res_instance.invoke.return_value = AIMessage(content="Research Result")
|
mock_res_instance.invoke.return_value = AIMessage(content="Research Result")
|
||||||
|
|
||||||
# 4. Mock Reflector
|
# 4. Mock Researcher Summarizer
|
||||||
|
mock_rs_instance = MagicMock()
|
||||||
|
mock_llms["res_summarizer"].return_value = mock_rs_instance
|
||||||
|
mock_rs_instance.invoke.return_value = AIMessage(content="Researcher Summary")
|
||||||
|
|
||||||
|
# 5. Mock Reflector
|
||||||
mock_reflector_instance = MagicMock()
|
mock_reflector_instance = MagicMock()
|
||||||
mock_llms["reflector"].return_value = mock_reflector_instance
|
mock_llms["reflector"].return_value = mock_reflector_instance
|
||||||
mock_reflector_instance.with_structured_output.return_value.invoke.return_value = MagicMock(satisfied=True, reasoning="Good.")
|
mock_reflector_instance.with_structured_output.return_value.invoke.return_value = MagicMock(satisfied=True, reasoning="Good.")
|
||||||
|
|
||||||
# 5. Mock Synthesizer
|
# 6. Mock Synthesizer
|
||||||
mock_syn_instance = MagicMock()
|
mock_syn_instance = MagicMock()
|
||||||
mock_llms["synthesizer"].return_value = mock_syn_instance
|
mock_llms["synthesizer"].return_value = mock_syn_instance
|
||||||
mock_syn_instance.invoke.return_value = AIMessage(content="Final Research Summary")
|
mock_syn_instance.invoke.return_value = AIMessage(content="Final Research Summary")
|
||||||
|
|||||||
Reference in New Issue
Block a user