62 lines
2.1 KiB
Python
62 lines
2.1 KiB
Python
import pytest
|
|
from unittest.mock import MagicMock, patch
|
|
from ea_chatbot.graph.nodes.coder import coder_node
|
|
from ea_chatbot.graph.nodes.error_corrector import error_corrector_node
|
|
|
|
@pytest.fixture
def mock_state():
    """Build a fresh graph-state dict for a single test.

    Only ``question`` and ``plan`` carry content. ``code`` and ``error`` start
    as None, ``messages``/``plots``/``dfs`` start empty, and ``next_action``
    is ``"plan"``. A new dict (with new inner containers) is built on every
    call, so a test may mutate it without affecting other tests.
    """
    state = dict(
        messages=[],
        question="Show me results for New Jersey",
        plan="Step 1: Load data\nStep 2: Filter by NJ",
        code=None,
        error=None,
        plots=[],
        dfs={},
        next_action="plan",
    )
    return state
|
|
|
|
@patch("ea_chatbot.graph.nodes.coder.get_llm_model")
@patch("ea_chatbot.utils.database_inspection.get_data_summary")
def test_coder_node(mock_get_summary, mock_get_llm, mock_state):
    """Test coder node generates code from plan.

    The LLM factory is stubbed so its structured-output ``invoke`` returns a
    canned ``CodeGenerationResponse``. The test then checks that the node
    returns that code and an explicit ``error`` of None.
    """
    from ea_chatbot.schemas import CodeGenerationResponse

    # NOTE(review): this patches the defining module, not
    # ``ea_chatbot.graph.nodes.coder``; if coder binds the name via
    # ``from ... import get_data_summary``, this stub is bypassed — confirm.
    mock_get_summary.return_value = "Column: Name, Type: text"

    llm = MagicMock()
    mock_get_llm.return_value = llm

    # ``with_structured_output(...)`` returns this same child mock on every
    # call, so stubbing its ``invoke`` controls what the node receives.
    structured_llm = llm.with_structured_output.return_value
    structured_llm.invoke.return_value = CodeGenerationResponse(
        code="import pandas as pd\nprint('Hello')",
        explanation="Generated code",
    )

    result = coder_node(mock_state)

    assert "code" in result
    assert "import pandas as pd" in result["code"]
    assert "error" in result
    assert result["error"] is None
|
|
|
|
@patch("ea_chatbot.graph.nodes.error_corrector.get_llm_model")
def test_error_corrector_node(mock_get_llm, mock_state):
    """Test error corrector node fixes code.

    Seeds the state with failing code and its NameError message. Stubs the
    LLM's structured-output ``invoke`` to return corrected code. Then checks
    that the node returns that code and an explicit ``error`` of None.
    """
    from ea_chatbot.schemas import CodeGenerationResponse

    mock_state["code"] = "import pandas as pd\nprint(undefined_var)"
    mock_state["error"] = "NameError: name 'undefined_var' is not defined"

    mock_llm = MagicMock()
    mock_get_llm.return_value = mock_llm

    mock_response = CodeGenerationResponse(
        code="import pandas as pd\nprint('Defined')",
        explanation="Fixed variable",
    )
    mock_llm.with_structured_output.return_value.invoke.return_value = mock_response

    result = error_corrector_node(mock_state)

    assert "code" in result
    assert "print('Defined')" in result["code"]
    # Check key presence first (mirroring the coder test) so a missing key
    # fails as a readable assertion rather than an opaque KeyError.
    assert "error" in result
    assert result["error"] is None