feat(workers): Implement Researcher worker subgraph for web research tasks

This commit is contained in:
Yunxiao Xu
2026-02-23 06:10:15 -08:00
parent 120b6fd11a
commit 5cc5bd91ae
6 changed files with 223 additions and 0 deletions

View File

@@ -0,0 +1,31 @@
from typing import Dict, Any, List
from langchain_core.messages import HumanMessage, AIMessage
from ea_chatbot.graph.state import AgentState
from ea_chatbot.graph.workers.researcher.state import WorkerState
def prepare_researcher_input(state: AgentState) -> Dict[str, Any]:
    """Build the initial WorkerState payload for the Researcher worker.

    Pulls the current checklist entry's task description out of the global
    agent state, falling back to a generic instruction when the step index
    is out of range or the entry has no "task" key.
    """
    steps = state.get("checklist", [])
    step_idx = state.get("current_step", 0)

    description = "Perform research"
    if 0 <= step_idx < len(steps):
        description = steps[step_idx].get("task", description)

    # Fresh worker context: the task doubles as the opening human message.
    initial_state: Dict[str, Any] = {
        "task": description,
        "messages": [HumanMessage(content=description)],
        "queries": [],
        "raw_results": [],
        "iterations": 0,
        "result": None,
    }
    return initial_state
def merge_researcher_output(worker_state: WorkerState) -> Dict[str, Any]:
    """Map researcher results back to the global AgentState.

    Returns a partial AgentState update containing a single AIMessage with
    the worker's final summary.
    """
    # BUG FIX: the worker state is initialized with result=None, so the key
    # exists and dict.get()'s default never kicked in — AIMessage(content=None)
    # would be produced. Use `or` so falsy results also fall back.
    result = worker_state.get("result") or "Research complete."
    return {
        "messages": [AIMessage(content=result)],
        # Researcher doesn't usually update VFS or Plots, but we keep the structure
    }

View File

@@ -0,0 +1,60 @@
from langchain_openai import ChatOpenAI
from langchain_google_genai import ChatGoogleGenerativeAI
from ea_chatbot.graph.workers.researcher.state import WorkerState
from ea_chatbot.config import Settings
from ea_chatbot.utils.llm_factory import get_llm_model
from ea_chatbot.utils import helpers
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
from ea_chatbot.graph.prompts.researcher import RESEARCHER_PROMPT
def searcher_node(state: WorkerState) -> dict:
    """Execute web research for the specific task.

    Formats the researcher prompt for the current sub-task, binds a
    provider-appropriate web-search tool when available, invokes the model,
    and records the response text as a raw research result.
    """
    task = state["task"]
    logger = get_logger("researcher_worker:searcher")
    logger.info(f"Researching task: {task[:50]}...")

    llm = get_llm_model(
        Settings().researcher_llm,
        callbacks=[LangChainLoggingHandler(logger=logger)],
    )

    # Adapt the global researcher prompt for the sub-task; the worker runs
    # with a fresh, task-specific context (empty history/summary).
    prompt_messages = RESEARCHER_PROMPT.format_messages(
        date=helpers.get_readable_date(),
        question=task,
        history=[],
        summary="",
    )

    # Bind the provider-native search tool; on any failure (or for unknown
    # providers) degrade gracefully to the bare model.
    searchable_llm = llm
    try:
        if isinstance(llm, ChatGoogleGenerativeAI):
            searchable_llm = llm.bind_tools([{"google_search": {}}])
        elif isinstance(llm, ChatOpenAI):
            searchable_llm = llm.bind_tools([{"type": "web_search"}])
    except Exception as e:
        logger.warning(f"Failed to bind search tools: {str(e)}")
        searchable_llm = llm

    try:
        response = searchable_llm.invoke(prompt_messages)
        logger.info("[bold green]Search complete.[/bold green]")
        # In a real tool-use scenario, we'd extract the tool outputs here.
        # For now, we'll store the response content as a 'raw_result'.
        content = response.content if hasattr(response, "content") else str(response)
        return {
            "messages": [response],
            "raw_results": [content],
            "iterations": state.get("iterations", 0) + 1,
        }
    except Exception as e:
        logger.error(f"Search failed: {str(e)}")
        raise e

View File

@@ -0,0 +1,41 @@
from ea_chatbot.graph.workers.researcher.state import WorkerState
from ea_chatbot.config import Settings
from ea_chatbot.utils.llm_factory import get_llm_model
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
def summarizer_node(state: WorkerState) -> dict:
    """Summarize research results for the Orchestrator.

    Joins the accumulated raw findings, asks the planner LLM for a concise
    factual summary, and returns it as the worker's final result.
    """
    task = state["task"]
    findings = state.get("raw_results", [])
    logger = get_logger("researcher_worker:summarizer")
    logger.info("Summarizing research results...")

    # The planner model is reused here; summarization needs no search tools.
    llm = get_llm_model(
        Settings().planner_llm,
        callbacks=[LangChainLoggingHandler(logger=logger)],
    )

    joined_findings = "\n---\n".join(findings)
    prompt = f"""You are a Research Specialist sub-agent. You have completed a research sub-task.
Task: {task}
Raw Research Findings:
{joined_findings}
Provide a concise, factual summary of the findings for the top-level Orchestrator.
Ensure all key facts, dates, and sources (if provided) are preserved.
Do NOT include internal reasoning, just the factual summary."""

    try:
        response = llm.invoke(prompt)
        summary = response.content if hasattr(response, "content") else str(response)
        logger.info("[bold green]Research summary complete.[/bold green]")
        return {"result": summary}
    except Exception as e:
        logger.error(f"Failed to summarize research: {str(e)}")
        raise e

View File

@@ -0,0 +1,24 @@
from langgraph.graph import StateGraph, END
from langgraph.graph.state import CompiledStateGraph

from ea_chatbot.graph.workers.researcher.nodes.searcher import searcher_node
from ea_chatbot.graph.workers.researcher.nodes.summarizer import summarizer_node
from ea_chatbot.graph.workers.researcher.state import WorkerState
def create_researcher_worker(
    searcher=searcher_node,
    summarizer=summarizer_node
) -> CompiledStateGraph:
    """Create the Researcher worker subgraph.

    Args:
        searcher: Node callable performing the web-search step. Injectable
            for testing; defaults to the real searcher node.
        summarizer: Node callable condensing raw findings into a result.
            Injectable for testing; defaults to the real summarizer node.

    Returns:
        The compiled worker graph. NOTE: the original annotation said
        ``StateGraph``, but ``workflow.compile()`` returns a
        ``CompiledStateGraph`` — the annotation is corrected here.
    """
    workflow = StateGraph(WorkerState)

    # Linear two-step pipeline: search first, then summarize for the
    # Orchestrator.
    workflow.add_node("searcher", searcher)
    workflow.add_node("summarizer", summarizer)

    workflow.set_entry_point("searcher")
    workflow.add_edge("searcher", "summarizer")
    workflow.add_edge("summarizer", END)

    return workflow.compile()

View File

@@ -0,0 +1,67 @@
import pytest
from unittest.mock import MagicMock
from ea_chatbot.graph.workers.researcher.workflow import create_researcher_worker, WorkerState
from ea_chatbot.graph.workers.researcher.mapping import prepare_researcher_input, merge_researcher_output
from ea_chatbot.graph.state import AgentState
from langchain_core.messages import AIMessage
def test_researcher_worker_flow():
    """The compiled worker runs searcher then summarizer and surfaces the summary."""
    searcher_stub = MagicMock(
        return_value={
            "raw_results": ["Result A"],
            "messages": [AIMessage(content="Search result")],
        }
    )
    summarizer_stub = MagicMock(return_value={"result": "Consolidated Summary"})

    worker = create_researcher_worker(
        searcher=searcher_stub,
        summarizer=summarizer_stub,
    )

    start_state = WorkerState(
        messages=[],
        task="Find governor",
        queries=[],
        raw_results=[],
        iterations=0,
        result=None,
    )
    end_state = worker.invoke(start_state)

    assert end_state["result"] == "Consolidated Summary"
    assert searcher_stub.called
    assert summarizer_stub.called
def test_researcher_mapping():
    """prepare/merge helpers translate between AgentState and WorkerState."""
    agent_state = AgentState(
        checklist=[{"task": "Search X", "worker": "researcher"}],
        current_step=0,
        messages=[],
        question="test",
        analysis={},
        next_action="",
        iterations=0,
        vfs={},
        plots=[],
        dfs={},
    )

    # The current checklist step's task becomes the worker's task.
    assert prepare_researcher_input(agent_state)["task"] == "Search X"

    finished_worker = WorkerState(
        messages=[],
        task="Search X",
        queries=[],
        raw_results=[],
        iterations=1,
        result="Found X",
    )

    # The worker's result is surfaced as an AI message to the global state.
    merged = merge_researcher_output(finished_worker)
    assert merged["messages"][0].content == "Found X"