"""Tests for ``ea_chatbot.graph.nodes.executor.executor_node``."""
from unittest.mock import MagicMock, patch

import matplotlib.pyplot as plt
import pandas as pd
import pytest
from matplotlib.figure import Figure

from ea_chatbot.graph.nodes.executor import executor_node
from ea_chatbot.graph.state import AgentState


@pytest.fixture
def mock_settings():
    """Replace ``Settings`` inside the executor module with a configured mock.

    Yields the object that ``Settings()`` returns while the patch is active,
    pre-populated with fake database connection values for the tests.
    """
    fake_db_config = {
        "db_host": "localhost",
        "db_port": 5432,
        "db_user": "user",
        "db_pswd": "pass",
        "db_name": "test_db",
        "db_table": "test_table",
    }
    with patch("ea_chatbot.graph.nodes.executor.Settings") as settings_cls:
        # Mock(**kwargs) forwards the kwargs to configure_mock, setting each
        # key as an attribute on the instance.
        settings_obj = MagicMock(**fake_db_config)
        settings_cls.return_value = settings_obj
        yield settings_obj


@pytest.fixture
def mock_db_client():
    """Replace ``DBClient`` inside the executor module with a mock class.

    Yields the single mock instance returned by every ``DBClient(...)`` call
    made while the patch is active.
    """
    fake_client = MagicMock()
    # Keyword arguments to patch() are passed on to the mock it creates,
    # so this sets the patched class's return_value.
    with patch("ea_chatbot.graph.nodes.executor.DBClient", return_value=fake_client):
        yield fake_client


def test_executor_node_success_simple_print(mock_settings, mock_db_client):
    """Stdout of a plain ``print`` call is captured in ``code_output``.

    A successful run that builds no figures or DataFrames must report no
    error, an empty ``plots`` list and an empty ``dfs`` mapping.
    """
    outcome = executor_node(
        dict(code="print('Hello, World!')", question="test", messages=[])
    )

    assert "code_output" in outcome
    stdout_text = outcome["code_output"]
    assert "Hello, World!" in stdout_text
    assert outcome["error"] is None
    assert outcome["plots"] == []
    assert outcome["dfs"] == {}


def test_executor_node_success_dataframe(mock_settings, mock_db_client):
    """Test executing code that creates a DataFrame.

    The printed frame must appear in ``code_output`` and the ``df`` variable
    defined by the snippet must be returned, with the same contents, under
    ``result["dfs"]["df"]``.
    """
    code = """
import pandas as pd
df = pd.DataFrame({'a': [1, 2], 'b': [3, 4]})
print(df)
"""
    state = {
        "code": code,
        "question": "test",
        "messages": [],
    }

    result = executor_node(state)

    assert result["error"] is None
    assert "code_output" in result
    # pandas right-aligns columns with a variable amount of padding (this
    # frame prints as "   a  b\n0  1  3\n1  2  4"), so compare the table on
    # whitespace-normalized output instead of a literal with fixed spacing.
    normalized_output = " ".join(result["code_output"].split())
    assert "a b 0 1 3 1 2 4" in normalized_output
    assert "dfs" in result
    assert "df" in result["dfs"]
    assert isinstance(result["dfs"]["df"], pd.DataFrame)
    # Verify the captured frame's contents, not merely its type.
    pd.testing.assert_frame_equal(
        result["dfs"]["df"], pd.DataFrame({"a": [1, 2], "b": [3, 4]})
    )


def test_executor_node_success_plot(mock_settings, mock_db_client):
    """Test executing code that generates a plot.

    The snippet appends a figure to ``plots`` (a name the executed code
    expects the executor to provide); exactly that one figure must come back
    in ``result["plots"]`` and the run must report no error.
    """
    code = """
import matplotlib.pyplot as plt
fig = plt.figure()
plots.append(fig)
print('Plot generated')
"""
    state = {
        "code": code,
        "question": "test",
        "messages": [],
    }

    result = executor_node(state)
    try:
        assert result["error"] is None
        assert "Plot generated" in result["code_output"]
        assert "plots" in result
        assert len(result["plots"]) == 1
        assert isinstance(result["plots"][0], Figure)
    finally:
        # Figures created via pyplot stay registered (and in memory) until
        # explicitly closed; release them so they do not pile up across the
        # test session, even when an assertion above fails.
        for fig in result.get("plots") or []:
            if isinstance(fig, Figure):
                plt.close(fig)


def test_executor_node_error_syntax(mock_settings, mock_db_client):
    """Code that cannot be parsed is reported through ``error`` as a SyntaxError."""
    broken_code = "print('Hello World"  # the string literal is never closed
    outcome = executor_node(dict(code=broken_code, question="test", messages=[]))

    error_text = outcome["error"]
    assert error_text is not None
    assert "SyntaxError" in error_text


def test_executor_node_error_runtime(mock_settings, mock_db_client):
    """An exception raised while the code runs is reported through ``error``."""
    failing_code = "print(1 / 0)"
    outcome = executor_node(dict(code=failing_code, question="test", messages=[]))

    error_text = outcome["error"]
    assert error_text is not None
    assert "ZeroDivisionError" in error_text


def test_executor_node_no_code(mock_settings, mock_db_client):
    """A state whose ``code`` is None yields a "No code provided" error."""
    empty_state = dict(code=None, question="test", messages=[])

    outcome = executor_node(empty_state)

    assert "error" in outcome
    assert "No code provided" in outcome["error"]