feat(orchestrator): Harden VFS and enhance artifact awareness across workers
This commit is contained in:
@@ -9,7 +9,7 @@ from matplotlib.figure import Figure
|
|||||||
|
|
||||||
from ea_chatbot.graph.state import AgentState
|
from ea_chatbot.graph.state import AgentState
|
||||||
from ea_chatbot.utils.db_client import DBClient
|
from ea_chatbot.utils.db_client import DBClient
|
||||||
from ea_chatbot.utils.vfs import VFSHelper
|
from ea_chatbot.utils.vfs import VFSHelper, safe_vfs_copy
|
||||||
from ea_chatbot.utils.logging import get_logger
|
from ea_chatbot.utils.logging import get_logger
|
||||||
from ea_chatbot.config import Settings
|
from ea_chatbot.config import Settings
|
||||||
|
|
||||||
@@ -40,7 +40,7 @@ def executor_node(state: AgentState) -> dict:
|
|||||||
db_client = DBClient(settings=db_settings)
|
db_client = DBClient(settings=db_settings)
|
||||||
|
|
||||||
# Initialize the Virtual File System (VFS) helper
|
# Initialize the Virtual File System (VFS) helper
|
||||||
vfs_state = copy.deepcopy(state.get("vfs", {}))
|
vfs_state = safe_vfs_copy(state.get("vfs", {}))
|
||||||
vfs_helper = VFSHelper(vfs_state)
|
vfs_helper = VFSHelper(vfs_state)
|
||||||
|
|
||||||
# Initialize local variables for execution
|
# Initialize local variables for execution
|
||||||
|
|||||||
@@ -23,11 +23,23 @@ def synthesizer_node(state: AgentState) -> dict:
|
|||||||
callbacks=[LangChainLoggingHandler(logger=logger)]
|
callbacks=[LangChainLoggingHandler(logger=logger)]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Artifact summary
|
||||||
|
plots = state.get("plots", [])
|
||||||
|
vfs = state.get("vfs", {})
|
||||||
|
artifacts_summary = ""
|
||||||
|
if plots:
|
||||||
|
artifacts_summary += f"- {len(plots)} generated plot(s) are attached to this response.\n"
|
||||||
|
if vfs:
|
||||||
|
artifacts_summary += "- Data files available in VFS: " + ", ".join(vfs.keys()) + "\n"
|
||||||
|
if not artifacts_summary:
|
||||||
|
artifacts_summary = "No additional artifacts generated."
|
||||||
|
|
||||||
# We provide the full history and the original question
|
# We provide the full history and the original question
|
||||||
messages = SYNTHESIZER_PROMPT.format_messages(
|
messages = SYNTHESIZER_PROMPT.format_messages(
|
||||||
question=question,
|
question=question,
|
||||||
history=history,
|
history=history,
|
||||||
worker_results="Review the worker summaries provided in the message history."
|
worker_results="Review the worker summaries provided in the message history.",
|
||||||
|
artifacts_summary=artifacts_summary
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -9,6 +9,14 @@ The user will provide a task and a plan.
|
|||||||
- Do NOT assume a dataframe `df` is already loaded unless explicitly stated. You usually need to query it first.
|
- Do NOT assume a dataframe `df` is already loaded unless explicitly stated. You usually need to query it first.
|
||||||
- The database schema is described in the prompt. Use it to construct valid SQL queries.
|
- The database schema is described in the prompt. Use it to construct valid SQL queries.
|
||||||
|
|
||||||
|
**Virtual File System (VFS):**
|
||||||
|
- An in-memory file system is available as `vfs`. Use it to persist intermediate data or large artifacts.
|
||||||
|
- `vfs.write(filename, content, metadata=None)`: Save a file (content can be any serializable object).
|
||||||
|
- `vfs.read(filename) -> (content, metadata)`: Read a file.
|
||||||
|
- `vfs.list() -> list[str]`: List all files.
|
||||||
|
- `vfs.delete(filename)`: Delete a file.
|
||||||
|
- Prefer using VFS for intermediate DataFrames or complex data structures instead of printing everything.
|
||||||
|
|
||||||
**Plotting:**
|
**Plotting:**
|
||||||
- If you need to plot any data, use the `plots` list to store the figures.
|
- If you need to plot any data, use the `plots` list to store the figures.
|
||||||
- Example: `plots.append(fig)` or `plots.append(plt.gcf())`.
|
- Example: `plots.append(fig)` or `plots.append(plt.gcf())`.
|
||||||
@@ -18,7 +26,8 @@ The user will provide a task and a plan.
|
|||||||
- Produce FULL, COMPLETE CODE that includes all steps and solves the task!
|
- Produce FULL, COMPLETE CODE that includes all steps and solves the task!
|
||||||
- Always include the import statements at the top of the code (e.g., `import pandas as pd`, `import matplotlib.pyplot as plt`).
|
- Always include the import statements at the top of the code (e.g., `import pandas as pd`, `import matplotlib.pyplot as plt`).
|
||||||
- Always include print statements to output the results of your code.
|
- Always include print statements to output the results of your code.
|
||||||
- Use `db.query_df("SELECT ...")` to get data."""
|
- Use `db.query_df("SELECT ...")` to get data.
|
||||||
|
"""
|
||||||
|
|
||||||
CODE_GENERATOR_USER = """TASK:
|
CODE_GENERATOR_USER = """TASK:
|
||||||
{question}
|
{question}
|
||||||
@@ -43,6 +52,7 @@ Return a complete, corrected python code that incorporates the fixes for the err
|
|||||||
- You have access to a database client via the variable `db`.
|
- You have access to a database client via the variable `db`.
|
||||||
- Use `db.query_df(sql)` to run queries.
|
- Use `db.query_df(sql)` to run queries.
|
||||||
- Use `plots.append(fig)` for plots.
|
- Use `plots.append(fig)` for plots.
|
||||||
|
- You have access to `vfs` for persistent in-memory storage.
|
||||||
- Always include imports and print statements."""
|
- Always include imports and print statements."""
|
||||||
|
|
||||||
ERROR_CORRECTOR_USER = """FAILED CODE:
|
ERROR_CORRECTOR_USER = """FAILED CODE:
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ Your goal is to synthesize their individual findings into a single, cohesive, an
|
|||||||
- Do NOT mention the internal 'workers' or 'checklist' names.
|
- Do NOT mention the internal 'workers' or 'checklist' names.
|
||||||
- Combine the data insights (from Data Analysts) and factual research (from Researchers) into a natural narrative.
|
- Combine the data insights (from Data Analysts) and factual research (from Researchers) into a natural narrative.
|
||||||
- Ensure all numbers, dates, and names from the worker reports are included accurately.
|
- Ensure all numbers, dates, and names from the worker reports are included accurately.
|
||||||
|
- **Artifacts & Plots:** If plots or charts were generated, refer to them naturally (e.g., "The chart below shows...").
|
||||||
- If any part of the plan failed, explain the status honestly but professionally.
|
- If any part of the plan failed, explain the status honestly but professionally.
|
||||||
- Present data in clear formats (tables, bullet points) where appropriate."""
|
- Present data in clear formats (tables, bullet points) where appropriate."""
|
||||||
|
|
||||||
@@ -18,6 +19,9 @@ SYNTHESIZER_USER = """USER QUESTION:
|
|||||||
EXECUTION SUMMARY (Results from specialized workers):
|
EXECUTION SUMMARY (Results from specialized workers):
|
||||||
{worker_results}
|
{worker_results}
|
||||||
|
|
||||||
|
AVAILABLE ARTIFACTS:
|
||||||
|
{artifacts_summary}
|
||||||
|
|
||||||
Provide the final integrated response:"""
|
Provide the final integrated response:"""
|
||||||
|
|
||||||
SYNTHESIZER_PROMPT = ChatPromptTemplate.from_messages([
|
SYNTHESIZER_PROMPT = ChatPromptTemplate.from_messages([
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ import copy
|
|||||||
from langchain_core.messages import HumanMessage, AIMessage
|
from langchain_core.messages import HumanMessage, AIMessage
|
||||||
from ea_chatbot.graph.state import AgentState
|
from ea_chatbot.graph.state import AgentState
|
||||||
from ea_chatbot.graph.workers.data_analyst.state import WorkerState
|
from ea_chatbot.graph.workers.data_analyst.state import WorkerState
|
||||||
|
from ea_chatbot.utils.vfs import safe_vfs_copy
|
||||||
|
|
||||||
def prepare_worker_input(state: AgentState) -> Dict[str, Any]:
|
def prepare_worker_input(state: AgentState) -> Dict[str, Any]:
|
||||||
"""Prepare the initial state for the Data Analyst worker."""
|
"""Prepare the initial state for the Data Analyst worker."""
|
||||||
@@ -17,7 +18,7 @@ def prepare_worker_input(state: AgentState) -> Dict[str, Any]:
|
|||||||
return {
|
return {
|
||||||
"task": task_desc,
|
"task": task_desc,
|
||||||
"messages": [HumanMessage(content=task_desc)], # Start worker loop with the task
|
"messages": [HumanMessage(content=task_desc)], # Start worker loop with the task
|
||||||
"vfs_state": copy.deepcopy(state.get("vfs", {})),
|
"vfs_state": safe_vfs_copy(state.get("vfs", {})),
|
||||||
"iterations": 0,
|
"iterations": 0,
|
||||||
"plots": [],
|
"plots": [],
|
||||||
"code": None,
|
"code": None,
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ from matplotlib.figure import Figure
|
|||||||
|
|
||||||
from ea_chatbot.graph.workers.data_analyst.state import WorkerState
|
from ea_chatbot.graph.workers.data_analyst.state import WorkerState
|
||||||
from ea_chatbot.utils.db_client import DBClient
|
from ea_chatbot.utils.db_client import DBClient
|
||||||
from ea_chatbot.utils.vfs import VFSHelper
|
from ea_chatbot.utils.vfs import VFSHelper, safe_vfs_copy
|
||||||
from ea_chatbot.utils.logging import get_logger
|
from ea_chatbot.utils.logging import get_logger
|
||||||
from ea_chatbot.config import Settings
|
from ea_chatbot.config import Settings
|
||||||
|
|
||||||
@@ -40,7 +40,7 @@ def executor_node(state: WorkerState) -> dict:
|
|||||||
db_client = DBClient(settings=db_settings)
|
db_client = DBClient(settings=db_settings)
|
||||||
|
|
||||||
# Initialize the Virtual File System (VFS) helper with the snapshot from state
|
# Initialize the Virtual File System (VFS) helper with the snapshot from state
|
||||||
vfs_state = copy.deepcopy(state.get("vfs_state", {}))
|
vfs_state = safe_vfs_copy(state.get("vfs_state", {}))
|
||||||
vfs_helper = VFSHelper(vfs_state)
|
vfs_helper = VFSHelper(vfs_state)
|
||||||
|
|
||||||
# Initialize local variables for execution
|
# Initialize local variables for execution
|
||||||
|
|||||||
@@ -9,12 +9,21 @@ def summarizer_node(state: WorkerState) -> dict:
|
|||||||
task = state["task"]
|
task = state["task"]
|
||||||
output = state.get("output", "")
|
output = state.get("output", "")
|
||||||
error = state.get("error")
|
error = state.get("error")
|
||||||
|
plots = state.get("plots", [])
|
||||||
|
vfs_state = state.get("vfs_state", {})
|
||||||
|
|
||||||
settings = Settings()
|
settings = Settings()
|
||||||
logger = get_logger("data_analyst_worker:summarizer")
|
logger = get_logger("data_analyst_worker:summarizer")
|
||||||
|
|
||||||
logger.info("Summarizing analysis results for the Orchestrator...")
|
logger.info("Summarizing analysis results for the Orchestrator...")
|
||||||
|
|
||||||
|
# Artifact summary
|
||||||
|
artifact_info = ""
|
||||||
|
if plots:
|
||||||
|
artifact_info += f"- Generated {len(plots)} plot(s).\n"
|
||||||
|
if vfs_state:
|
||||||
|
artifact_info += "- VFS Artifacts: " + ", ".join(vfs_state.keys()) + "\n"
|
||||||
|
|
||||||
# We can use a smaller/faster model for this summary if needed
|
# We can use a smaller/faster model for this summary if needed
|
||||||
llm = get_llm_model(
|
llm = get_llm_model(
|
||||||
settings.planner_llm, # Using planner model for summary logic
|
settings.planner_llm, # Using planner model for summary logic
|
||||||
@@ -25,8 +34,10 @@ def summarizer_node(state: WorkerState) -> dict:
|
|||||||
Task: {task}
|
Task: {task}
|
||||||
Execution Results: {output}
|
Execution Results: {output}
|
||||||
Error Log (if any): {error}
|
Error Log (if any): {error}
|
||||||
|
{artifact_info}
|
||||||
|
|
||||||
Provide a concise summary of the findings or status for the top-level Orchestrator.
|
Provide a concise summary of the findings or status for the top-level Orchestrator.
|
||||||
|
If plots or data files were generated, mention them.
|
||||||
If the execution failed after multiple retries, explain why concisely.
|
If the execution failed after multiple retries, explain why concisely.
|
||||||
Do NOT include the raw Python code, just the results of the analysis."""
|
Do NOT include the raw Python code, just the results of the analysis."""
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,37 @@
|
|||||||
|
import copy
|
||||||
from typing import Dict, Any, Optional, Tuple, List
|
from typing import Dict, Any, Optional, Tuple, List
|
||||||
|
from ea_chatbot.utils.logging import get_logger
|
||||||
|
|
||||||
|
logger = get_logger("utils:vfs")
|
||||||
|
|
||||||
|
def safe_vfs_copy(vfs_state: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    """
    Perform a safe, per-entry deep copy of the VFS state.

    Every entry is deep-copied on its own, so one uncopyable artifact
    (e.g. an object holding a lock or a DB handle) cannot abort the copy of
    the remaining entries. An entry that fails to copy is logged and replaced
    with an error placeholder dict carrying ``content`` and ``metadata`` keys,
    which keeps the failure explicit instead of crashing the caller.

    Args:
        vfs_state: Mapping of filename to artifact data. ``None`` or an empty
            mapping is accepted and yields an empty dict; callers obtain this
            value via ``state.get(...)``, which returns ``None`` when the key
            exists but was never populated.

    Returns:
        A new dict with the same keys as ``vfs_state``. Values are deep copies
        of the originals, or error placeholders for entries that could not be
        copied. The input mapping is never modified.
    """
    # Guard against a missing/None state: without it `.items()` would raise
    # AttributeError, the very kind of crash this helper exists to prevent.
    if not vfs_state:
        return {}

    new_vfs: Dict[str, Any] = {}
    for filename, data in vfs_state.items():
        try:
            # Attempt a standard deepcopy for isolation
            new_vfs[filename] = copy.deepcopy(data)
        except Exception as e:
            # Deliberately broad: deepcopy surfaces whatever the artifact's
            # copy/pickle hooks raise, and any such failure must degrade to a
            # placeholder rather than propagate.
            reason = str(e)
            logger.error(
                f"CRITICAL: VFS artifact '{filename}' is NOT copyable/serializable: {reason}. "
                "Replacing with error placeholder to prevent graph crash."
            )
            # Replace with a standardized error artifact
            new_vfs[filename] = {
                "content": f"<ERROR: This artifact could not be persisted or copied: {reason}>",
                "metadata": {
                    "type": "error",
                    "error": reason,
                    "original_filename": filename,
                },
            }
    return new_vfs
|
||||||
|
|
||||||
class VFSHelper:
|
class VFSHelper:
|
||||||
"""Helper class for managing in-memory Virtual File System (VFS) artifacts."""
|
"""Helper class for managing in-memory Virtual File System (VFS) artifacts."""
|
||||||
|
|||||||
48
backend/tests/test_vfs_robustness.py
Normal file
48
backend/tests/test_vfs_robustness.py
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
import pytest
|
||||||
|
import threading
|
||||||
|
from ea_chatbot.utils.vfs import safe_vfs_copy
|
||||||
|
|
||||||
|
def test_safe_vfs_copy_success():
    """A fully copyable VFS comes back equal in value but not shared."""
    original = {
        "test.csv": {"content": "data", "metadata": {"type": "csv"}},
        "num": 42,
    }

    duplicate = safe_vfs_copy(original)

    # Same contents, but neither the container nor the nested entry is aliased.
    assert duplicate == original
    assert duplicate is not original
    assert duplicate["test.csv"] is not original["test.csv"]
|
||||||
|
|
||||||
|
def test_safe_vfs_copy_handles_non_copyable():
    """An uncopyable entry is swapped for an error placeholder; others survive."""
    # threading.Lock objects cannot be deep-copied, which triggers the fallback.
    lock = threading.Lock()
    source = {"safe_file": "important data", "unsafe_lock": lock}

    result = safe_vfs_copy(source)

    # The copyable entry passes through untouched.
    assert result["safe_file"] == "important data"

    # The uncopyable entry is replaced by the standardized error dict.
    placeholder = result["unsafe_lock"]
    assert isinstance(placeholder, dict)
    assert "content" in placeholder
    assert "ERROR" in placeholder["content"]
    assert placeholder["metadata"]["type"] == "error"
    assert "lock" in str(placeholder["metadata"]["error"]).lower()

    # The input mapping still holds the very same lock object.
    assert source["unsafe_lock"] is lock
|
||||||
|
|
||||||
|
def test_safe_vfs_copy_preserves_nested_copyable():
    """Nested copyable structures are duplicated, not aliased."""
    source = {"data": {"a": [1, 2, 3], "b": {"c": True}}}

    result = safe_vfs_copy(source)

    inner_copy = result["data"]["a"]
    # Value is preserved while the nested list is a distinct object.
    assert inner_copy == [1, 2, 3]
    assert inner_copy is not source["data"]["a"]
|
||||||
Reference in New Issue
Block a user