feat(workers): Implement Researcher worker subgraph for web research tasks
This commit is contained in:
31
backend/src/ea_chatbot/graph/workers/researcher/mapping.py
Normal file
31
backend/src/ea_chatbot/graph/workers/researcher/mapping.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from typing import Dict, Any, List
|
||||
from langchain_core.messages import HumanMessage, AIMessage
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
from ea_chatbot.graph.workers.researcher.state import WorkerState
|
||||
|
||||
def prepare_researcher_input(state: AgentState) -> Dict[str, Any]:
    """Build the initial WorkerState payload for the Researcher worker.

    Pulls the current checklist entry's task description from the global
    state; falls back to a generic instruction when the step index is out
    of range or the checklist entry lacks a "task" key.
    """
    steps: List[Dict[str, Any]] = state.get("checklist", [])
    step_idx = state.get("current_step", 0)

    default_task = "Perform research"
    in_range = 0 <= step_idx < len(steps)
    task = steps[step_idx].get("task", default_task) if in_range else default_task

    # Fresh worker context: empty message history apart from the task itself,
    # zeroed counters, and no result yet.
    return {
        "task": task,
        "messages": [HumanMessage(content=task)],
        "queries": [],
        "raw_results": [],
        "iterations": 0,
        "result": None,
    }
|
||||
|
||||
def merge_researcher_output(worker_state: WorkerState) -> Dict[str, Any]:
    """Map researcher results back onto the global AgentState.

    Returns a partial state update containing a single AIMessage with the
    worker's summary. The researcher does not update the VFS or plots, so
    only "messages" is emitted.
    """
    # NOTE: `.get("result", default)` alone is not enough here — the worker
    # state is initialized with result=None (see prepare_researcher_input),
    # and dict.get only falls back when the key is *missing*. Using `or`
    # also covers None/"" so AIMessage never receives a None content.
    result = worker_state.get("result") or "Research complete."

    return {
        "messages": [AIMessage(content=result)],
        # Researcher doesn't usually update VFS or Plots, but we keep the structure
    }
|
||||
@@ -0,0 +1,60 @@
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
from ea_chatbot.graph.workers.researcher.state import WorkerState
|
||||
from ea_chatbot.config import Settings
|
||||
from ea_chatbot.utils.llm_factory import get_llm_model
|
||||
from ea_chatbot.utils import helpers
|
||||
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
|
||||
from ea_chatbot.graph.prompts.researcher import RESEARCHER_PROMPT
|
||||
|
||||
def searcher_node(state: WorkerState) -> dict:
    """Execute web research for the worker's specific sub-task.

    Formats the shared researcher prompt for this task, binds the
    provider-native web-search tool when the model supports one (Gemini or
    OpenAI), and invokes the model once. The raw response content is kept
    so the summarizer node can consolidate it later.

    Returns:
        Partial WorkerState update with the model message, the raw result
        text, and an incremented iteration counter.

    Raises:
        Exception: re-raises any failure from the model invocation after
        logging it.
    """
    task = state["task"]
    logger = get_logger("researcher_worker:searcher")

    logger.info(f"Researching task: {task[:50]}...")

    settings = Settings()
    llm = get_llm_model(
        settings.researcher_llm,
        callbacks=[LangChainLoggingHandler(logger=logger)]
    )

    date_str = helpers.get_readable_date()

    # Adapt the global researcher prompt for the sub-task. The worker runs
    # with a fresh context, so history/summary are intentionally empty.
    messages = RESEARCHER_PROMPT.format_messages(
        date=date_str,
        question=task,
        history=[],  # Worker has fresh context or task-specific history
        summary=""
    )

    # Tool binding: each provider exposes its built-in search tool under a
    # different schema. If binding fails, degrade to the bare model so a
    # tooling problem yields a plain LLM answer instead of a crash.
    try:
        if isinstance(llm, ChatGoogleGenerativeAI):
            llm_with_tools = llm.bind_tools([{"google_search": {}}])
        elif isinstance(llm, ChatOpenAI):
            llm_with_tools = llm.bind_tools([{"type": "web_search"}])
        else:
            llm_with_tools = llm
    except Exception as e:
        logger.warning(f"Failed to bind search tools: {str(e)}")
        llm_with_tools = llm

    try:
        response = llm_with_tools.invoke(messages)
        logger.info("[bold green]Search complete.[/bold green]")

        # In a real tool-use scenario, we'd extract the tool outputs here.
        # For now, we'll store the response content as a 'raw_result'.
        content = response.content if hasattr(response, "content") else str(response)

        return {
            "messages": [response],
            "raw_results": [content],
            "iterations": state.get("iterations", 0) + 1
        }
    except Exception as e:
        logger.error(f"Search failed: {str(e)}")
        # Bare `raise` preserves the original traceback; `raise e` would
        # restart it from this frame.
        raise
|
||||
@@ -0,0 +1,41 @@
|
||||
from ea_chatbot.graph.workers.researcher.state import WorkerState
|
||||
from ea_chatbot.config import Settings
|
||||
from ea_chatbot.utils.llm_factory import get_llm_model
|
||||
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
|
||||
|
||||
def summarizer_node(state: WorkerState) -> dict:
    """Summarize raw research results into a report for the Orchestrator.

    Joins all raw search outputs, asks the model for a concise factual
    summary of the findings, and stores it as the worker's final result.

    Returns:
        Partial WorkerState update containing only "result".

    Raises:
        Exception: re-raises any failure from the model invocation after
        logging it.
    """
    task = state["task"]
    raw_results = state.get("raw_results", [])

    logger = get_logger("researcher_worker:summarizer")
    logger.info("Summarizing research results...")

    settings = Settings()
    # NOTE(review): this uses the *planner* model for summarization while the
    # searcher uses researcher_llm — confirm this is intentional.
    llm = get_llm_model(
        settings.planner_llm,
        callbacks=[LangChainLoggingHandler(logger=logger)]
    )

    results_str = "\n---\n".join(raw_results)

    prompt = f"""You are a Research Specialist sub-agent. You have completed a research sub-task.
Task: {task}

Raw Research Findings:
{results_str}

Provide a concise, factual summary of the findings for the top-level Orchestrator.
Ensure all key facts, dates, and sources (if provided) are preserved.
Do NOT include internal reasoning, just the factual summary."""

    try:
        response = llm.invoke(prompt)
        result = response.content if hasattr(response, "content") else str(response)
        logger.info("[bold green]Research summary complete.[/bold green]")
        return {
            "result": result
        }
    except Exception as e:
        logger.error(f"Failed to summarize research: {str(e)}")
        # Bare `raise` preserves the original traceback; `raise e` would
        # restart it from this frame.
        raise
|
||||
24
backend/src/ea_chatbot/graph/workers/researcher/workflow.py
Normal file
24
backend/src/ea_chatbot/graph/workers/researcher/workflow.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from langgraph.graph import StateGraph, END
|
||||
from ea_chatbot.graph.workers.researcher.state import WorkerState
|
||||
from ea_chatbot.graph.workers.researcher.nodes.searcher import searcher_node
|
||||
from ea_chatbot.graph.workers.researcher.nodes.summarizer import summarizer_node
|
||||
|
||||
def create_researcher_worker(
    searcher=searcher_node,
    summarizer=summarizer_node
) -> StateGraph:
    """Build and compile the linear Researcher worker subgraph.

    Flow: searcher -> summarizer -> END. The node callables are injectable
    parameters so tests can substitute mocks.
    """
    graph = StateGraph(WorkerState)

    # Register the two worker stages.
    graph.add_node("searcher", searcher)
    graph.add_node("summarizer", summarizer)

    # Wire the fixed linear pipeline.
    graph.set_entry_point("searcher")
    graph.add_edge("searcher", "summarizer")
    graph.add_edge("summarizer", END)

    return graph.compile()
|
||||
67
backend/tests/test_researcher_worker.py
Normal file
67
backend/tests/test_researcher_worker.py
Normal file
@@ -0,0 +1,67 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
from ea_chatbot.graph.workers.researcher.workflow import create_researcher_worker, WorkerState
|
||||
from ea_chatbot.graph.workers.researcher.mapping import prepare_researcher_input, merge_researcher_output
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
def test_researcher_worker_flow():
    """The compiled worker graph runs searcher then summarizer and keeps the result."""
    fake_searcher = MagicMock(
        return_value={
            "raw_results": ["Result A"],
            "messages": [AIMessage(content="Search result")],
        }
    )
    fake_summarizer = MagicMock(return_value={"result": "Consolidated Summary"})

    graph = create_researcher_worker(
        searcher=fake_searcher,
        summarizer=fake_summarizer,
    )

    start = WorkerState(
        messages=[],
        task="Find governor",
        queries=[],
        raw_results=[],
        iterations=0,
        result=None,
    )

    final_state = graph.invoke(start)

    # The summarizer's output must win, and both nodes must have executed.
    assert final_state["result"] == "Consolidated Summary"
    assert fake_searcher.called
    assert fake_summarizer.called
|
||||
|
||||
def test_researcher_mapping():
    """State mapping picks the current checklist task and surfaces the result."""
    global_state = AgentState(
        checklist=[{"task": "Search X", "worker": "researcher"}],
        current_step=0,
        messages=[],
        question="test",
        analysis={},
        next_action="",
        iterations=0,
        vfs={},
        plots=[],
        dfs={},
    )

    # Input mapping: step 0's task becomes the worker's task.
    assert prepare_researcher_input(global_state)["task"] == "Search X"

    finished_worker = WorkerState(
        messages=[],
        task="Search X",
        queries=[],
        raw_results=[],
        iterations=1,
        result="Found X",
    )

    # Output mapping: the worker result becomes an AI message for the orchestrator.
    updates = merge_researcher_output(finished_worker)
    assert updates["messages"][0].content == "Found X"
|
||||
Reference in New Issue
Block a user