68 lines
1.9 KiB
Python
68 lines
1.9 KiB
Python
import pytest
|
|
from unittest.mock import MagicMock
|
|
from ea_chatbot.graph.workers.researcher.workflow import create_researcher_worker, WorkerState
|
|
from ea_chatbot.graph.workers.researcher.mapping import prepare_researcher_input, merge_researcher_output
|
|
from ea_chatbot.graph.state import AgentState
|
|
from langchain_core.messages import AIMessage
|
|
|
|
def test_researcher_worker_flow():
    """Run the researcher worker graph end to end with stubbed node callables.

    The searcher and summarizer are replaced by ``MagicMock`` stubs with fixed
    return payloads. The test checks three things:

    - the graph built by ``create_researcher_worker`` invokes both stubs;
    - the summarizer's ``result`` value appears in the state returned by
      ``invoke``.
    """
    search_stub = MagicMock(
        return_value={
            "raw_results": ["Result A"],
            "messages": [AIMessage(content="Search result")],
        }
    )
    summary_stub = MagicMock(return_value={"result": "Consolidated Summary"})

    worker_graph = create_researcher_worker(
        searcher=search_stub,
        summarizer=summary_stub,
    )

    # Fresh worker state: nothing searched yet, no result produced.
    start = WorkerState(
        messages=[],
        task="Find governor",
        queries=[],
        raw_results=[],
        iterations=0,
        result=None,
    )
    end = worker_graph.invoke(start)

    assert end["result"] == "Consolidated Summary"
    # Only "was called at least once" is checked for each stub; the call count
    # depends on how many iterations the graph performs.
    assert search_stub.called
    assert summary_stub.called
|
|
|
|
def test_researcher_mapping():
    """Check both directions of state mapping for the researcher worker.

    - ``prepare_researcher_input`` must copy the task of the single checklist
      entry into the worker input's ``task`` field.
    - ``merge_researcher_output`` must turn the worker's ``result`` into the
      content of the first message it returns.
    """
    agent_state = AgentState(
        checklist=[{"task": "Search X", "worker": "researcher"}],
        current_step=0,
        messages=[],
        question="test",
        analysis={},
        next_action="",
        iterations=0,
        vfs={},
        plots=[],
        dfs={},
    )

    # Global -> worker direction.
    assert prepare_researcher_input(agent_state)["task"] == "Search X"

    # Worker -> global direction: a finished worker run carrying a result.
    finished = WorkerState(
        messages=[],
        task="Search X",
        queries=[],
        raw_results=[],
        iterations=1,
        result="Found X",
    )
    merged = merge_researcher_output(finished)
    first_message = merged["messages"][0]
    assert first_message.content == "Found X"
|