"""End-to-end test: a full graph run emits structured JSON logs from multiple nodes."""
import json
|
|
import pytest
|
|
import logging
|
|
from unittest.mock import patch
|
|
from ea_chatbot.graph.workflow import create_workflow
|
|
from ea_chatbot.graph.state import AgentState
|
|
from ea_chatbot.utils.logging import get_logger
|
|
from langchain_community.chat_models import FakeListChatModel
|
|
from langchain_core.runnables import RunnableLambda
|
|
|
|
@pytest.fixture(autouse=True)
def reset_logging():
    """Detach every handler from the ea_chatbot root logger before and after each test."""
    package_logger = logging.getLogger("ea_chatbot")

    def _clear_handlers():
        # Iterate over a snapshot: removing while iterating would skip entries.
        for h in list(package_logger.handlers):
            package_logger.removeHandler(h)

    _clear_handlers()
    yield
    _clear_handlers()
|
|
|
|
class FakeStructuredModel(FakeListChatModel):
    """Fake chat model whose structured-output mode parses the canned response as JSON.

    Mimics enough of the real ``with_structured_output`` API for the graph
    nodes under test: the returned runnable yields a validated schema
    instance (pydantic v2) or, failing that, the raw parsed dict.
    """

    def with_structured_output(self, schema, **kwargs):
        """Return a runnable that parses ``self.responses[0]`` into *schema*.

        Args:
            schema: Target type; if it exposes ``model_validate`` (pydantic v2)
                the parsed JSON is validated into an instance of it.
            **kwargs: Accepted for API compatibility; ignored.
        """

        # NOTE: renamed from ``input`` to avoid shadowing the builtin, and the
        # inner ``**kwargs`` renamed so it no longer shadows the outer one.
        def _invoke(payload, config=None, **_ignored):
            # Always the first canned response; tests configure exactly one.
            # ``json`` comes from the module-level import (no local re-import).
            data = json.loads(self.responses[0])
            if hasattr(schema, "model_validate"):
                # Pydantic v2 model: return a typed, validated instance.
                return schema.model_validate(data)
            return data

        return RunnableLambda(_invoke)
|
|
|
|
def test_logging_e2e_json_output(tmp_path):
    """A full graph run must produce structured JSON log records from multiple nodes."""
    log_file = tmp_path / "e2e_test.jsonl"

    # Route the package root logger to a JSONL file for this test.
    get_logger("ea_chatbot", log_file=str(log_file))

    initial_state: AgentState = {
        "messages": [],
        "question": "Who won in 2024?",
        "analysis": None,
        "next_action": "",
        "iterations": 0,
        "checklist": [],
        "current_step": 0,
        "vfs": {},
        "plots": [],
        "dfs": {},
    }

    # Canned model outputs: the analyzer asks for clarification and the
    # clarification node answers with a fixed message. Both fakes support
    # callbacks, and the analyzer supports structured output.
    analyzer_json = """{"data_required": [], "unknowns": [], "ambiguities": ["Which year?"], "conditions": [], "next_action": "clarify"}"""
    analyzer_model = FakeStructuredModel(responses=[analyzer_json])
    clarify_model = FakeListChatModel(responses=["Please specify."])

    # Build a two-node graph locally instead of using the global 'app',
    # which compiles with interrupts.
    from langgraph.graph import StateGraph, END
    from ea_chatbot.graph.nodes.query_analyzer import query_analyzer_node
    from ea_chatbot.graph.nodes.clarification import clarification_node

    graph = StateGraph(AgentState)
    graph.add_node("query_analyzer", query_analyzer_node)
    graph.add_node("clarification", clarification_node)
    graph.set_entry_point("query_analyzer")
    graph.add_edge("query_analyzer", "clarification")
    graph.add_edge("clarification", END)
    compiled = graph.compile()

    # Swap both LLM factories for the fakes, then run the graph.
    with patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model") as analyzer_factory, \
            patch("ea_chatbot.graph.nodes.clarification.get_llm_model") as clarify_factory:
        analyzer_factory.return_value = analyzer_model
        clarify_factory.return_value = clarify_model
        compiled.invoke(initial_state)

    # The log file must exist and hold at least one JSON record per line.
    assert log_file.exists()
    records = [json.loads(line) for line in log_file.read_text().splitlines()]
    assert len(records) > 0

    # Records must originate from both graph nodes.
    logger_names = [record["name"] for record in records]
    assert "ea_chatbot.query_analyzer" in logger_names
    assert "ea_chatbot.clarification" in logger_names

    # And must include the expected per-node events.
    logged_messages = [record["message"] for record in records]
    assert any("Analyzing question" in m for m in logged_messages)
    assert any("Clarification generated" in m for m in logged_messages)
|