feat(workers): Implement Researcher worker subgraph for web research tasks
This commit is contained in:
31
backend/src/ea_chatbot/graph/workers/researcher/mapping.py
Normal file
31
backend/src/ea_chatbot/graph/workers/researcher/mapping.py
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
from typing import Dict, Any, List
|
||||||
|
from langchain_core.messages import HumanMessage, AIMessage
|
||||||
|
from ea_chatbot.graph.state import AgentState
|
||||||
|
from ea_chatbot.graph.workers.researcher.state import WorkerState
|
||||||
|
|
||||||
|
def prepare_researcher_input(state: AgentState) -> Dict[str, Any]:
    """Build the initial WorkerState payload for the Researcher worker.

    Pulls the task description for the current checklist step out of the
    global agent state; falls back to a generic description when the step
    index is out of range.
    """
    steps: List[Dict[str, Any]] = state.get("checklist", [])
    idx = state.get("current_step", 0)

    default_task = "Perform research"
    if 0 <= idx < len(steps):
        task = steps[idx].get("task", default_task)
    else:
        task = default_task

    return {
        "task": task,
        "messages": [HumanMessage(content=task)],
        "queries": [],
        "raw_results": [],
        "iterations": 0,
        "result": None,
    }
|
def merge_researcher_output(worker_state: WorkerState) -> Dict[str, Any]:
    """Map researcher results back to the global AgentState.

    Returns the partial state update consumed by the orchestrator graph:
    a single AIMessage carrying the worker's summarized result.
    """
    # dict.get(key, default) only falls back when the key is MISSING; the
    # worker state initializes "result" to None, so an explicit `or` guard
    # is needed to avoid emitting AIMessage(content=None).
    result = worker_state.get("result") or "Research complete."

    return {
        "messages": [AIMessage(content=result)],
        # Researcher doesn't usually update VFS or Plots, but we keep the structure
    }
@@ -0,0 +1,60 @@
|
|||||||
|
from langchain_openai import ChatOpenAI
|
||||||
|
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||||
|
from ea_chatbot.graph.workers.researcher.state import WorkerState
|
||||||
|
from ea_chatbot.config import Settings
|
||||||
|
from ea_chatbot.utils.llm_factory import get_llm_model
|
||||||
|
from ea_chatbot.utils import helpers
|
||||||
|
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
|
||||||
|
from ea_chatbot.graph.prompts.researcher import RESEARCHER_PROMPT
|
||||||
|
|
||||||
|
def searcher_node(state: WorkerState) -> dict:
    """Execute web research for the specific task.

    Formats the shared researcher prompt for this sub-task, binds a
    provider-appropriate native web-search tool when the model supports
    one, and invokes the model once. The response content is stored in
    ``raw_results`` for the downstream summarizer node.

    Args:
        state: Current worker state; must contain a "task" string.

    Returns:
        Partial WorkerState update: the model response message, the raw
        result text, and an incremented iteration counter.

    Raises:
        Exception: re-raised (with original traceback) if invocation fails.
    """
    task = state["task"]
    logger = get_logger("researcher_worker:searcher")

    logger.info(f"Researching task: {task[:50]}...")

    settings = Settings()
    llm = get_llm_model(
        settings.researcher_llm,
        callbacks=[LangChainLoggingHandler(logger=logger)]
    )

    date_str = helpers.get_readable_date()

    # Adapt the global researcher prompt for the sub-task
    messages = RESEARCHER_PROMPT.format_messages(
        date=date_str,
        question=task,
        history=[],  # Worker has fresh context or task-specific history
        summary=""
    )

    # Tool binding: each provider exposes native web search under a
    # different tool schema. Fall back to the bare model on any failure so
    # research degrades gracefully instead of crashing the worker.
    try:
        if isinstance(llm, ChatGoogleGenerativeAI):
            llm_with_tools = llm.bind_tools([{"google_search": {}}])
        elif isinstance(llm, ChatOpenAI):
            llm_with_tools = llm.bind_tools([{"type": "web_search"}])
        else:
            llm_with_tools = llm
    except Exception as e:
        logger.warning(f"Failed to bind search tools: {str(e)}")
        llm_with_tools = llm

    try:
        response = llm_with_tools.invoke(messages)
        logger.info("[bold green]Search complete.[/bold green]")

        # In a real tool-use scenario, we'd extract the tool outputs here.
        # For now, we'll store the response content as a 'raw_result'.
        content = response.content if hasattr(response, "content") else str(response)

        return {
            "messages": [response],
            "raw_results": [content],
            "iterations": state.get("iterations", 0) + 1
        }
    except Exception as e:
        logger.error(f"Search failed: {str(e)}")
        # Bare `raise` preserves the original traceback; `raise e` would
        # re-anchor it at this line.
        raise
@@ -0,0 +1,41 @@
|
|||||||
|
from ea_chatbot.graph.workers.researcher.state import WorkerState
|
||||||
|
from ea_chatbot.config import Settings
|
||||||
|
from ea_chatbot.utils.llm_factory import get_llm_model
|
||||||
|
from ea_chatbot.utils.logging import get_logger, LangChainLoggingHandler
|
||||||
|
|
||||||
|
def summarizer_node(state: WorkerState) -> dict:
    """Summarize research results for the Orchestrator.

    Joins the raw research findings collected by the searcher node and asks
    an LLM to condense them into a factual summary suitable for the
    top-level Orchestrator.

    Args:
        state: Worker state with "task" and accumulated "raw_results".

    Returns:
        Partial WorkerState update containing the final "result" string.

    Raises:
        Exception: re-raised (with original traceback) if invocation fails.
    """
    task = state["task"]
    raw_results = state.get("raw_results", [])

    logger = get_logger("researcher_worker:summarizer")
    logger.info("Summarizing research results...")

    settings = Settings()
    # NOTE(review): this reuses the planner's model config rather than a
    # researcher-/summarizer-specific one — confirm this is intentional.
    llm = get_llm_model(
        settings.planner_llm,
        callbacks=[LangChainLoggingHandler(logger=logger)]
    )

    results_str = "\n---\n".join(raw_results)

    prompt = f"""You are a Research Specialist sub-agent. You have completed a research sub-task.
Task: {task}

Raw Research Findings:
{results_str}

Provide a concise, factual summary of the findings for the top-level Orchestrator.
Ensure all key facts, dates, and sources (if provided) are preserved.
Do NOT include internal reasoning, just the factual summary."""

    try:
        response = llm.invoke(prompt)
        result = response.content if hasattr(response, "content") else str(response)
        logger.info("[bold green]Research summary complete.[/bold green]")
        return {
            "result": result
        }
    except Exception as e:
        logger.error(f"Failed to summarize research: {str(e)}")
        # Bare `raise` preserves the original traceback; `raise e` would
        # re-anchor it at this line.
        raise
24
backend/src/ea_chatbot/graph/workers/researcher/workflow.py
Normal file
24
backend/src/ea_chatbot/graph/workers/researcher/workflow.py
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
from langgraph.graph import StateGraph, END
|
||||||
|
from ea_chatbot.graph.workers.researcher.state import WorkerState
|
||||||
|
from ea_chatbot.graph.workers.researcher.nodes.searcher import searcher_node
|
||||||
|
from ea_chatbot.graph.workers.researcher.nodes.summarizer import summarizer_node
|
||||||
|
|
||||||
|
def create_researcher_worker(
    searcher=searcher_node,
    summarizer=summarizer_node
) -> StateGraph:
    """Build and compile the linear Researcher worker subgraph.

    Flow: searcher -> summarizer -> END. The node callables are injectable
    as parameters so tests can substitute mocks.
    """
    graph = StateGraph(WorkerState)

    # Register nodes
    graph.add_node("searcher", searcher)
    graph.add_node("summarizer", summarizer)

    # Wire the linear pipeline
    graph.set_entry_point("searcher")
    graph.add_edge("searcher", "summarizer")
    graph.add_edge("summarizer", END)

    return graph.compile()
67
backend/tests/test_researcher_worker.py
Normal file
67
backend/tests/test_researcher_worker.py
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
import pytest
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
from ea_chatbot.graph.workers.researcher.workflow import create_researcher_worker, WorkerState
|
||||||
|
from ea_chatbot.graph.workers.researcher.mapping import prepare_researcher_input, merge_researcher_output
|
||||||
|
from ea_chatbot.graph.state import AgentState
|
||||||
|
from langchain_core.messages import AIMessage
|
||||||
|
|
||||||
|
def test_researcher_worker_flow():
    """Verify that the Researcher worker flow works as expected."""
    fake_searcher = MagicMock(
        return_value={
            "raw_results": ["Result A"],
            "messages": [AIMessage(content="Search result")],
        }
    )
    fake_summarizer = MagicMock(return_value={"result": "Consolidated Summary"})

    graph = create_researcher_worker(
        searcher=fake_searcher,
        summarizer=fake_summarizer
    )

    start_state = WorkerState(
        messages=[],
        task="Find governor",
        queries=[],
        raw_results=[],
        iterations=0,
        result=None
    )

    end_state = graph.invoke(start_state)

    # Summarizer output must flow through to the final state, and both
    # nodes must have been executed.
    assert end_state["result"] == "Consolidated Summary"
    assert fake_searcher.called
    assert fake_summarizer.called
|
def test_researcher_mapping():
    """Verify that we correctly map states for the researcher."""
    global_state = AgentState(
        checklist=[{"task": "Search X", "worker": "researcher"}],
        current_step=0,
        messages=[],
        question="test",
        analysis={},
        next_action="",
        iterations=0,
        vfs={},
        plots=[],
        dfs={}
    )

    # Global -> worker: the current checklist step's task is extracted.
    assert prepare_researcher_input(global_state)["task"] == "Search X"

    # Worker -> global: the result is wrapped as an AIMessage.
    finished_worker = WorkerState(
        messages=[],
        task="Search X",
        queries=[],
        raw_results=[],
        iterations=1,
        result="Found X"
    )
    merged = merge_researcher_output(finished_worker)
    assert merged["messages"][0].content == "Found X"
Reference in New Issue
Block a user