diff --git a/backend/src/ea_chatbot/graph/workflow.py b/backend/src/ea_chatbot/graph/workflow.py index 8b982d7..a8ceb8b 100644 --- a/backend/src/ea_chatbot/graph/workflow.py +++ b/backend/src/ea_chatbot/graph/workflow.py @@ -7,6 +7,8 @@ from ea_chatbot.graph.nodes.reflector import reflector_node from ea_chatbot.graph.nodes.synthesizer import synthesizer_node from ea_chatbot.graph.workers.data_analyst.workflow import create_data_analyst_worker from ea_chatbot.graph.workers.data_analyst.mapping import prepare_worker_input, merge_worker_output +from ea_chatbot.graph.workers.researcher.workflow import create_researcher_worker +from ea_chatbot.graph.workers.researcher.mapping import prepare_researcher_input, merge_researcher_output from ea_chatbot.graph.nodes.researcher import researcher_node from ea_chatbot.graph.nodes.clarification import clarification_node from ea_chatbot.graph.nodes.summarize_conversation import summarize_conversation_node @@ -18,6 +20,13 @@ def data_analyst_worker_node(state: AgentState) -> dict: worker_result = worker_graph.invoke(worker_input) return merge_worker_output(worker_result) +def researcher_worker_node(state: AgentState) -> dict: + """Wrapper node for the Researcher subgraph with state mapping.""" + worker_graph = create_researcher_worker() + worker_input = prepare_researcher_input(state) + worker_result = worker_graph.invoke(worker_input) + return merge_researcher_output(worker_result) + def main_router(state: AgentState) -> str: """Route from query analyzer based on initial assessment.""" next_action = state.get("next_action") @@ -31,25 +40,35 @@ def delegation_router(state: AgentState) -> str: if next_action == "data_analyst": return "data_analyst_worker" elif next_action == "researcher": - return "researcher" + return "researcher_worker" elif next_action == "summarize": return "synthesizer" return "synthesizer" -def create_workflow(): +def create_workflow( + query_analyzer=query_analyzer_node, + planner=planner_node, + delegate=delegate_node, + data_analyst_worker=data_analyst_worker_node, + researcher_worker=researcher_worker_node, + reflector=reflector_node, + synthesizer=synthesizer_node, + clarification=clarification_node, + summarize_conversation=summarize_conversation_node +): """Create the high-level Orchestrator workflow.""" workflow = StateGraph(AgentState) # Add Nodes - workflow.add_node("query_analyzer", query_analyzer_node) - workflow.add_node("planner", planner_node) - workflow.add_node("delegate", delegate_node) - workflow.add_node("data_analyst_worker", data_analyst_worker_node) - workflow.add_node("researcher", researcher_node) - workflow.add_node("reflector", reflector_node) - workflow.add_node("synthesizer", synthesizer_node) - workflow.add_node("clarification", clarification_node) - workflow.add_node("summarize_conversation", summarize_conversation_node) + workflow.add_node("query_analyzer", query_analyzer) + workflow.add_node("planner", planner) + workflow.add_node("delegate", delegate) + workflow.add_node("data_analyst_worker", data_analyst_worker) + workflow.add_node("researcher_worker", researcher_worker) + workflow.add_node("reflector", reflector) + workflow.add_node("synthesizer", synthesizer) + workflow.add_node("clarification", clarification) + workflow.add_node("summarize_conversation", summarize_conversation) # Set entry point workflow.set_entry_point("query_analyzer") @@ -71,13 +90,13 @@ def create_workflow(): delegation_router, { "data_analyst_worker": "data_analyst_worker", - "researcher": "researcher", + "researcher_worker": "researcher_worker", "synthesizer": "synthesizer" } ) workflow.add_edge("data_analyst_worker", "reflector") - workflow.add_edge("researcher", "reflector") + workflow.add_edge("researcher_worker", "reflector") workflow.add_edge("reflector", "delegate") workflow.add_edge("synthesizer", "summarize_conversation")