feat(workers): Implement Researcher worker subgraph for web research tasks
This commit is contained in:
67
backend/tests/test_researcher_worker.py
Normal file
67
backend/tests/test_researcher_worker.py
Normal file
@@ -0,0 +1,67 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
from ea_chatbot.graph.workers.researcher.workflow import create_researcher_worker, WorkerState
|
||||
from ea_chatbot.graph.workers.researcher.mapping import prepare_researcher_input, merge_researcher_output
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
def test_researcher_worker_flow():
|
||||
"""Verify that the Researcher worker flow works as expected."""
|
||||
mock_searcher = MagicMock()
|
||||
mock_summarizer = MagicMock()
|
||||
|
||||
mock_searcher.return_value = {
|
||||
"raw_results": ["Result A"],
|
||||
"messages": [AIMessage(content="Search result")]
|
||||
}
|
||||
mock_summarizer.return_value = {"result": "Consolidated Summary"}
|
||||
|
||||
graph = create_researcher_worker(
|
||||
searcher=mock_searcher,
|
||||
summarizer=mock_summarizer
|
||||
)
|
||||
|
||||
initial_state = WorkerState(
|
||||
messages=[],
|
||||
task="Find governor",
|
||||
queries=[],
|
||||
raw_results=[],
|
||||
iterations=0,
|
||||
result=None
|
||||
)
|
||||
|
||||
final_state = graph.invoke(initial_state)
|
||||
|
||||
assert final_state["result"] == "Consolidated Summary"
|
||||
assert mock_searcher.called
|
||||
assert mock_summarizer.called
|
||||
|
||||
def test_researcher_mapping():
|
||||
"""Verify that we correctly map states for the researcher."""
|
||||
global_state = AgentState(
|
||||
checklist=[{"task": "Search X", "worker": "researcher"}],
|
||||
current_step=0,
|
||||
messages=[],
|
||||
question="test",
|
||||
analysis={},
|
||||
next_action="",
|
||||
iterations=0,
|
||||
vfs={},
|
||||
plots=[],
|
||||
dfs={}
|
||||
)
|
||||
|
||||
worker_input = prepare_researcher_input(global_state)
|
||||
assert worker_input["task"] == "Search X"
|
||||
|
||||
worker_output = WorkerState(
|
||||
messages=[],
|
||||
task="Search X",
|
||||
queries=[],
|
||||
raw_results=[],
|
||||
iterations=1,
|
||||
result="Found X"
|
||||
)
|
||||
|
||||
updates = merge_researcher_output(worker_output)
|
||||
assert updates["messages"][0].content == "Found X"
|
||||
Reference in New Issue
Block a user