feat: implement mvp with email-first login flow and langgraph architecture

This commit is contained in:
Yunxiao Xu
2026-02-09 23:22:30 -08:00
parent af227d40e6
commit 5a943b902a
79 changed files with 8,200 additions and 1 deletion

tests/test_app.py Normal file

@@ -0,0 +1,90 @@
import os
import sys
import pytest
from unittest.mock import MagicMock, patch
from streamlit.testing.v1 import AppTest
from langchain_core.messages import AIMessage

# Ensure src is in python path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../src')))


@pytest.fixture(autouse=True)
def mock_history_manager():
    """Globally mock HistoryManager to avoid DB calls during AppTest."""
    with patch("ea_chatbot.history.manager.HistoryManager") as mock_cls:
        instance = mock_cls.return_value
        instance.create_conversation.return_value = MagicMock(id="conv_123")
        instance.get_conversations.return_value = []
        instance.get_messages.return_value = []
        instance.add_message.return_value = MagicMock()
        instance.update_conversation_summary.return_value = MagicMock()
        yield instance


@pytest.fixture
def mock_app_stream():
    with patch("ea_chatbot.graph.workflow.app.stream") as mock_stream:
        # Mock events from app.stream
        mock_stream.return_value = [
            {"query_analyzer": {"next_action": "research"}},
            {"researcher": {"messages": [AIMessage(content="Research result")]}}
        ]
        yield mock_stream


@pytest.fixture
def mock_user():
    user = MagicMock()
    user.id = "test_id"
    user.username = "test@example.com"
    user.display_name = "Test User"
    return user


def test_app_initial_state(mock_app_stream, mock_user):
    """Test that the app initializes with the correct title and empty history."""
    at = AppTest.from_file("src/ea_chatbot/app.py")
    # Simulate logged-in user
    at.session_state["user"] = mock_user
    at.run()
    assert not at.exception
    assert at.title[0].value == "🗳️ Election Analytics Chatbot"
    # Check session state initialization
    assert "messages" in at.session_state
    assert len(at.session_state["messages"]) == 0


def test_app_dev_mode_toggle(mock_app_stream, mock_user):
    """Test that the dev mode toggle exists in the sidebar."""
    with patch.dict(os.environ, {"DEV_MODE": "false"}):
        at = AppTest.from_file("src/ea_chatbot/app.py")
        at.session_state["user"] = mock_user
        at.run()
        # Check for sidebar toggle (checkbox)
        assert len(at.sidebar.checkbox) > 0
        dev_mode_toggle = at.sidebar.checkbox[0]
        assert dev_mode_toggle.label == "Dev Mode"
        assert dev_mode_toggle.value is False


def test_app_graph_execution_streaming(mock_app_stream, mock_user, mock_history_manager):
    """Test that entering a prompt triggers the graph stream and displays response."""
    at = AppTest.from_file("src/ea_chatbot/app.py")
    at.session_state["user"] = mock_user
    at.run()
    # Input a question
    at.chat_input[0].set_value("Test question").run()
    # Verify graph stream was called
    assert mock_app_stream.called
    # Messages should be added to history
    assert len(at.session_state["messages"]) == 2
    assert at.session_state["messages"][0]["role"] == "user"
    assert at.session_state["messages"][1]["role"] == "assistant"
    assert "Research result" in at.session_state["messages"][1]["content"]
    # Verify history manager was used
    assert mock_history_manager.create_conversation.called
    assert mock_history_manager.add_message.called
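
The streaming test above implies that app.py iterates over the events yielded by app.stream, where each event maps a node name to its state update, and concatenates AIMessage content into the assistant reply. A minimal sketch of that consumption loop, assuming the event shape mocked in mock_app_stream (the function name and details below are illustrative, not the actual app.py code):

from langchain_core.messages import AIMessage

def collect_assistant_reply(events):
    """Concatenate AIMessage content from streamed node updates."""
    parts = []
    for event in events:  # e.g. {"researcher": {"messages": [...]}}
        for node_update in event.values():
            for msg in node_update.get("messages", []):
                if isinstance(msg, AIMessage):
                    parts.append(msg.content)
    return "\n".join(parts)

With the events mocked above, this yields "Research result", matching the content assertion.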

tests/test_app_auth.py Normal file

@@ -0,0 +1,83 @@
import pytest
from unittest.mock import MagicMock, patch
from streamlit.testing.v1 import AppTest
from ea_chatbot.auth import AuthType


@pytest.fixture
def mock_history_manager_instance():
    # We need to patch before the AppTest loads the module
    with patch("ea_chatbot.history.manager.HistoryManager") as mock_cls:
        instance = mock_cls.return_value
        yield instance


def test_auth_ui_flow_step1_to_password(mock_history_manager_instance):
    """Test UI transition from Step 1 (email) to Step 2a (password) for LOCAL user."""
    # Patch BEFORE creating AppTest
    mock_user = MagicMock()
    mock_user.password_hash = "hashed_password"
    mock_history_manager_instance.get_user.return_value = mock_user
    at = AppTest.from_file("src/ea_chatbot/app.py")
    at.run()
    # Step 1: Identification
    assert at.session_state["login_step"] == "email"
    at.text_input[0].set_value("local@example.com")
    at.button[0].click().run()
    # Verify transition to password step
    assert at.session_state["login_step"] == "login_password"
    assert at.session_state["login_email"] == "local@example.com"
    assert "Welcome back" in at.info[0].value


def test_auth_ui_flow_step1_to_register(mock_history_manager_instance):
    """Test UI transition from Step 1 (email) to Step 2b (registration) for NEW user."""
    mock_history_manager_instance.get_user.return_value = None
    at = AppTest.from_file("src/ea_chatbot/app.py")
    at.run()
    # Step 1: Identification
    at.text_input[0].set_value("new@example.com")
    at.button[0].click().run()
    # Verify transition to registration step
    assert at.session_state["login_step"] == "register_details"
    assert at.session_state["login_email"] == "new@example.com"
    assert "Create an account" in at.info[0].value


def test_auth_ui_flow_step1_to_oidc(mock_history_manager_instance):
    """Test UI transition from Step 1 (email) to Step 2c (OIDC) for OIDC user."""
    # Mock history_manager.get_user to return a user WITHOUT a password
    mock_user = MagicMock()
    mock_user.password_hash = None
    mock_history_manager_instance.get_user.return_value = mock_user
    at = AppTest.from_file("src/ea_chatbot/app.py")
    at.run()
    # Step 1: Identification
    at.text_input[0].set_value("oidc@example.com")
    at.button[0].click().run()
    # Verify transition to OIDC step
    assert at.session_state["login_step"] == "oidc_login"
    assert at.session_state["login_email"] == "oidc@example.com"
    assert "configured for Single Sign-On" in at.info[0].value


def test_auth_ui_flow_back_button(mock_history_manager_instance):
    """Test that the 'Back' button returns to Step 1."""
    at = AppTest.from_file("src/ea_chatbot/app.py")
    # Simulate being on Step 2a
    at.session_state["login_step"] = "login_password"
    at.session_state["login_email"] = "local@example.com"
    at.run()
    # Click Back (index 1 in Step 2a)
    at.button[1].click().run()
    # Verify return to email step
    assert at.session_state["login_step"] == "email"
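
Taken together, these UI tests pin down a three-branch step machine keyed on st.session_state["login_step"]. A minimal sketch of the Step 1 transition under that reading (the handler name and Streamlit wiring are assumptions; only the step values and AuthType branches come from the tests):

import streamlit as st
from ea_chatbot.auth import get_user_auth_type, AuthType

def handle_email_submitted(email, history_manager):
    st.session_state["login_email"] = email
    auth_type = get_user_auth_type(email, history_manager)
    if auth_type == AuthType.LOCAL:
        st.session_state["login_step"] = "login_password"    # Step 2a
    elif auth_type == AuthType.NEW:
        st.session_state["login_step"] = "register_details"  # Step 2b
    else:  # AuthType.OIDC
        st.session_state["login_step"] = "oidc_login"        # Step 2c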

tests/test_auth.py Normal file

@@ -0,0 +1,43 @@
import pytest
from unittest.mock import MagicMock, patch
from ea_chatbot.auth import OIDCClient


@patch("ea_chatbot.auth.OAuth2Session")
def test_oidc_client_initialization(mock_oauth):
    client = OIDCClient(
        client_id="test_id",
        client_secret="test_secret",
        server_metadata_url="https://test.server/.well-known/openid-configuration"
    )
    assert client.oauth_session is not None


@patch("ea_chatbot.auth.requests")
@patch("ea_chatbot.auth.OAuth2Session")
def test_get_login_url(mock_oauth_cls, mock_requests):
    # Setup mock session
    mock_session = MagicMock()
    mock_oauth_cls.return_value = mock_session
    # Mock metadata response
    mock_response = MagicMock()
    mock_response.json.return_value = {
        "authorization_endpoint": "https://test.server/auth",
        "token_endpoint": "https://test.server/token",
        "userinfo_endpoint": "https://test.server/userinfo"
    }
    mock_requests.get.return_value = mock_response
    # Mock authorization url generation
    mock_session.create_authorization_url.return_value = ("https://test.server/auth?response_type=code", "state")
    client = OIDCClient(
        client_id="test_id",
        client_secret="test_secret",
        server_metadata_url="https://test.server/.well-known/openid-configuration"
    )
    url = client.get_login_url()
    assert url == "https://test.server/auth?response_type=code"
    # Verify metadata was fetched via requests
    mock_requests.get.assert_called_with("https://test.server/.well-known/openid-configuration")

tests/test_auth_flow.py Normal file

@@ -0,0 +1,49 @@
import pytest
from unittest.mock import MagicMock
from ea_chatbot.history.manager import HistoryManager
from ea_chatbot.auth import get_user_auth_type, AuthType


# Mocks
@pytest.fixture
def mock_history_manager():
    return MagicMock(spec=HistoryManager)


def test_auth_flow_existing_local_user(mock_history_manager):
    """Test that an existing user with a password returns LOCAL auth type."""
    # Setup
    mock_user = MagicMock()
    mock_user.password_hash = "hashed_secret"
    mock_history_manager.get_user.return_value = mock_user
    # Execute
    auth_type = get_user_auth_type("test@example.com", mock_history_manager)
    # Verify
    assert auth_type == AuthType.LOCAL
    mock_history_manager.get_user.assert_called_once_with("test@example.com")


def test_auth_flow_existing_oidc_user(mock_history_manager):
    """Test that an existing user WITHOUT a password returns OIDC auth type."""
    # Setup
    mock_user = MagicMock()
    mock_user.password_hash = None  # No password implies OIDC
    mock_history_manager.get_user.return_value = mock_user
    # Execute
    auth_type = get_user_auth_type("sso@example.com", mock_history_manager)
    # Verify
    assert auth_type == AuthType.OIDC
    mock_history_manager.get_user.assert_called_once_with("sso@example.com")


def test_auth_flow_new_user(mock_history_manager):
    """Test that a non-existent user returns NEW auth type."""
    # Setup
    mock_history_manager.get_user.return_value = None
    # Execute
    auth_type = get_user_auth_type("new@example.com", mock_history_manager)
    # Verify
    assert auth_type == AuthType.NEW
    mock_history_manager.get_user.assert_called_once_with("new@example.com")
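
The three tests above fully determine the branch logic. A sketch consistent with them (the enum values are assumptions; only the member names and the password_hash check come from the tests):

from enum import Enum

class AuthType(Enum):
    LOCAL = "local"  # existing user with a password hash
    OIDC = "oidc"    # existing user without one (SSO-managed)
    NEW = "new"      # no user record yet

def get_user_auth_type(email, history_manager):
    user = history_manager.get_user(email)
    if user is None:
        return AuthType.NEW
    return AuthType.LOCAL if user.password_hash else AuthType.OIDC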

tests/test_coder.py Normal file

@@ -0,0 +1,62 @@
import pytest
from unittest.mock import MagicMock, patch
from ea_chatbot.graph.nodes.coder import coder_node
from ea_chatbot.graph.nodes.error_corrector import error_corrector_node


@pytest.fixture
def mock_state():
    return {
        "messages": [],
        "question": "Show me results for New Jersey",
        "plan": "Step 1: Load data\nStep 2: Filter by NJ",
        "code": None,
        "error": None,
        "plots": [],
        "dfs": {},
        "next_action": "plan"
    }


@patch("ea_chatbot.graph.nodes.coder.get_llm_model")
@patch("ea_chatbot.utils.database_inspection.get_data_summary")
def test_coder_node(mock_get_summary, mock_get_llm, mock_state):
    """Test coder node generates code from plan."""
    mock_get_summary.return_value = "Column: Name, Type: text"
    mock_llm = MagicMock()
    mock_get_llm.return_value = mock_llm
    from ea_chatbot.schemas import CodeGenerationResponse
    mock_response = CodeGenerationResponse(
        code="import pandas as pd\nprint('Hello')",
        explanation="Generated code"
    )
    mock_llm.with_structured_output.return_value.invoke.return_value = mock_response
    result = coder_node(mock_state)
    assert "code" in result
    assert "import pandas as pd" in result["code"]
    assert "error" in result
    assert result["error"] is None


@patch("ea_chatbot.graph.nodes.error_corrector.get_llm_model")
def test_error_corrector_node(mock_get_llm, mock_state):
    """Test error corrector node fixes code."""
    mock_state["code"] = "import pandas as pd\nprint(undefined_var)"
    mock_state["error"] = "NameError: name 'undefined_var' is not defined"
    mock_llm = MagicMock()
    mock_get_llm.return_value = mock_llm
    from ea_chatbot.schemas import CodeGenerationResponse
    mock_response = CodeGenerationResponse(
        code="import pandas as pd\nprint('Defined')",
        explanation="Fixed variable"
    )
    mock_llm.with_structured_output.return_value.invoke.return_value = mock_response
    result = error_corrector_node(mock_state)
    assert "code" in result
    assert "print('Defined')" in result["code"]
    assert result["error"] is None

tests/test_config.py Normal file

@@ -0,0 +1,47 @@
import pytest
from pydantic import ValidationError
from ea_chatbot.config import Settings, LLMConfig


def test_default_settings():
    """Test that default settings are loaded correctly."""
    settings = Settings()
    # Check default config for query analyzer
    assert isinstance(settings.query_analyzer_llm, LLMConfig)
    assert settings.query_analyzer_llm.provider == "openai"
    assert settings.query_analyzer_llm.model == "gpt-5-mini"
    assert settings.query_analyzer_llm.temperature == 0.0
    # Check default config for planner
    assert isinstance(settings.planner_llm, LLMConfig)
    assert settings.planner_llm.provider == "openai"
    assert settings.planner_llm.model == "gpt-5-mini"


def test_env_override(monkeypatch):
    """Test that environment variables override defaults."""
    monkeypatch.setenv("QUERY_ANALYZER_LLM__MODEL", "gpt-3.5-turbo")
    monkeypatch.setenv("QUERY_ANALYZER_LLM__TEMPERATURE", "0.7")
    settings = Settings()
    assert settings.query_analyzer_llm.model == "gpt-3.5-turbo"
    assert settings.query_analyzer_llm.temperature == 0.7


def test_provider_specific_params():
    """Test that provider-specific parameters can be set."""
    config = LLMConfig(
        provider="openai",
        model="o1-preview",
        provider_specific={"reasoning_effort": "high"}
    )
    assert config.provider_specific["reasoning_effort"] == "high"


def test_oidc_settings(monkeypatch):
    """Test OIDC settings configuration."""
    monkeypatch.setenv("OIDC_CLIENT_ID", "test_client_id")
    monkeypatch.setenv("OIDC_CLIENT_SECRET", "test_client_secret")
    monkeypatch.setenv("OIDC_SERVER_METADATA_URL", "https://test.server/.well-known/openid-configuration")
    settings = Settings()
    assert settings.oidc_client_id == "test_client_id"
    assert settings.oidc_client_secret == "test_client_secret"
    assert settings.oidc_server_metadata_url == "https://test.server/.well-known/openid-configuration"
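
The QUERY_ANALYZER_LLM__MODEL override only works if Settings is a pydantic-settings model with a "__" nested delimiter. A minimal sketch under that assumption (field defaults taken from test_default_settings; the class layout is illustrative):

from typing import Optional
from pydantic import BaseModel
from pydantic_settings import BaseSettings, SettingsConfigDict

class LLMConfig(BaseModel):
    provider: str = "openai"
    model: str = "gpt-5-mini"
    temperature: float = 0.0
    max_tokens: Optional[int] = None
    provider_specific: dict = {}

class Settings(BaseSettings):
    # "__" lets QUERY_ANALYZER_LLM__MODEL target query_analyzer_llm.model
    model_config = SettingsConfigDict(env_nested_delimiter="__")
    query_analyzer_llm: LLMConfig = LLMConfig()
    planner_llm: LLMConfig = LLMConfig()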


@@ -0,0 +1,56 @@
import pytest
from unittest.mock import MagicMock, patch
from langchain_core.messages import HumanMessage, AIMessage
from ea_chatbot.graph.nodes.summarize_conversation import summarize_conversation_node


@pytest.fixture
def mock_state_with_history():
    return {
        "messages": [
            HumanMessage(content="Show me the 2024 results for Florida"),
            AIMessage(content="Here are the results for Florida in 2024...")
        ],
        "summary": "The user is asking about 2024 election results."
    }


@patch("ea_chatbot.graph.nodes.summarize_conversation.get_llm_model")
def test_summarize_conversation_node_updates_summary(mock_get_llm, mock_state_with_history):
    mock_llm_instance = MagicMock()
    mock_get_llm.return_value = mock_llm_instance
    # Mock LLM response for updating summary
    mock_llm_instance.invoke.return_value = AIMessage(content="Updated summary including NJ results.")
    # Add new messages to simulate a completed turn
    mock_state_with_history["messages"].extend([
        HumanMessage(content="What about in New Jersey?"),
        AIMessage(content="In New Jersey, the 2024 results were...")
    ])
    result = summarize_conversation_node(mock_state_with_history)
    assert "summary" in result
    assert result["summary"] == "Updated summary including NJ results."
    # Verify LLM was called with the correct context
    call_messages = mock_llm_instance.invoke.call_args[0][0]
    # Should include current summary and last turn messages
    assert "Current summary: The user is asking about 2024 election results." in call_messages[0].content


@patch("ea_chatbot.graph.nodes.summarize_conversation.get_llm_model")
def test_summarize_conversation_node_initial_summary(mock_get_llm):
    state = {
        "messages": [
            HumanMessage(content="Hi"),
            AIMessage(content="Hello! How can I help you today?")
        ],
        "summary": ""
    }
    mock_llm_instance = MagicMock()
    mock_get_llm.return_value = mock_llm_instance
    mock_llm_instance.invoke.return_value = AIMessage(content="Initial greeting.")
    result = summarize_conversation_node(state)
    assert result["summary"] == "Initial greeting."


@@ -0,0 +1,195 @@
import pytest
import pandas as pd
import os
from unittest.mock import MagicMock, patch
from ea_chatbot.utils.database_inspection import get_primary_key, inspect_db_table, get_data_summary


@pytest.fixture
def mock_db_client():
    mock_client = MagicMock()
    mock_client.settings = {"table": "test_table"}
    return mock_client


def test_get_primary_key(mock_db_client):
    """Test dynamic primary key discovery."""
    # Mock response for the primary key query
    mock_df = pd.DataFrame({"column_name": ["my_pk"]})
    mock_db_client.query_df.return_value = mock_df
    pk = get_primary_key(mock_db_client, "test_table")
    assert pk == "my_pk"
    # Verify the query was called (at least once)
    assert mock_db_client.query_df.called


def test_inspect_db_table_improved(mock_db_client, tmp_path):
    """Test improved inspect_db_table with cardinality and sampling."""
    data_dir = str(tmp_path)
    # 1. Mock columns and types
    columns_df = pd.DataFrame({
        "column_name": ["id", "category", "count"],
        "data_type": ["integer", "text", "integer"]
    })
    # 2. Mock row count
    total_rows_df = pd.DataFrame([{"count": 100}])
    # 3. Mock PK discovery
    pk_df = pd.DataFrame({"column_name": ["id"]})
    # 4. Mock stats for columns.
    # We need to handle multiple calls to query_df.
    def side_effect(query):
        if "information_schema.columns" in query:
            return columns_df
        if "COUNT(*)" in query:
            return total_rows_df
        if "information_schema.key_column_usage" in query:
            return pk_df
        # Category stats
        if 'COUNT("category")' in query:
            return pd.DataFrame([{"count": 100}])
        if 'COUNT(DISTINCT "category")' in query:
            return pd.DataFrame([{"count": 5}])
        if 'SELECT DISTINCT "category"' in query:
            return pd.DataFrame({"category": ["A", "B", "C", "D", "E"]})
        # Count stats
        if 'COUNT("count")' in query:
            return pd.DataFrame([{"count": 100}])
        if 'COUNT(DISTINCT "count")' in query:
            return pd.DataFrame([{"count": 100}])
        if 'AVG("count")' in query:
            return pd.DataFrame([{"avg": 10.0, "min": 1, "max": 20}])
        # ID stats (fix for IndexError)
        if 'COUNT("id")' in query:
            return pd.DataFrame([{"count": 100}])
        if 'COUNT(DISTINCT "id")' in query:
            return pd.DataFrame([{"count": 100}])
        if 'AVG("id")' in query:
            return pd.DataFrame([{"avg": 50.0, "min": 1, "max": 100}])
        return pd.DataFrame()
    mock_db_client.query_df.side_effect = side_effect
    # Run inspection
    inspect_db_table(mock_db_client, data_dir=data_dir)
    # Read the summary to verify
    summary = get_data_summary(data_dir)
    assert summary is not None
    assert "test_table" in summary
    assert "category" in summary
    assert "distinct_values" in summary
    assert "unique_count: 5" in summary
    assert "- A" in summary
    assert "- E" in summary
    assert "primary_key: id" in summary


def test_get_data_summary_none(tmp_path):
    """Test get_data_summary when the file doesn't exist."""
    assert get_data_summary(str(tmp_path)) is None


def test_inspect_db_table_temporal(mock_db_client, tmp_path):
    """Test inspect_db_table with temporal columns."""
    data_dir = str(tmp_path)
    columns_df = pd.DataFrame({
        "column_name": ["created_at"],
        "data_type": ["timestamp without time zone"]
    })
    total_rows_df = pd.DataFrame([{"count": 50}])
    pk_df = pd.DataFrame()  # No PK
    def side_effect(query):
        if "information_schema.columns" in query:
            return columns_df
        if "COUNT(*)" in query:
            return total_rows_df
        if "information_schema.key_column_usage" in query:
            return pk_df
        if 'COUNT("created_at")' in query:
            return pd.DataFrame([{"count": 50}])
        if 'COUNT(DISTINCT "created_at")' in query:
            return pd.DataFrame([{"count": 50}])
        if 'MIN("created_at")' in query:
            return pd.DataFrame([{"min": "2023-01-01", "max": "2023-12-31"}])
        return pd.DataFrame()
    mock_db_client.query_df.side_effect = side_effect
    inspect_db_table(mock_db_client, data_dir=data_dir)
    summary = get_data_summary(data_dir)
    assert "created_at" in summary
    assert "min: '2023-01-01'" in summary
    assert "max: '2023-12-31'" in summary


def test_inspect_db_table_high_cardinality(mock_db_client, tmp_path):
    """Test inspect_db_table with a high-cardinality categorical column (no sample values)."""
    data_dir = str(tmp_path)
    columns_df = pd.DataFrame({
        "column_name": ["user_id"],
        "data_type": ["text"]
    })
    total_rows_df = pd.DataFrame([{"count": 100}])
    pk_df = pd.DataFrame()
    def side_effect(query):
        if "information_schema.columns" in query:
            return columns_df
        if "COUNT(*)" in query:
            return total_rows_df
        if "information_schema.key_column_usage" in query:
            return pk_df
        if 'COUNT("user_id")' in query:
            return pd.DataFrame([{"count": 100}])
        if 'COUNT(DISTINCT "user_id")' in query:
            # High cardinality (> 20)
            return pd.DataFrame([{"count": 50}])
        return pd.DataFrame()
    mock_db_client.query_df.side_effect = side_effect
    inspect_db_table(mock_db_client, data_dir=data_dir)
    summary = get_data_summary(data_dir)
    assert "user_id" in summary
    assert "unique_count: 50" in summary
    # Should NOT have distinct_values
    assert "distinct_values" not in summary


def test_inspect_db_table_checksum_skip(mock_db_client, tmp_path):
    """Test that inspection is skipped if the checksum matches."""
    data_dir = str(tmp_path)
    table = "test_table"
    os.makedirs(data_dir, exist_ok=True)
    # The checksum is the md5 of "ins|upd|del"; here the mocked query returns "my_hash"
    mock_db_client.query_df.return_value = pd.DataFrame([{"dml_hash": "my_hash"}])
    # Write an existing checksum
    with open(os.path.join(data_dir, "checksum"), "w") as f:
        f.write(f"{table}:my_hash\n")
    # Write an existing inspection
    with open(os.path.join(data_dir, "inspection.yaml"), "w") as f:
        f.write(f"{table}: {{ existing: true }}")
    # Run inspection
    result = inspect_db_table(mock_db_client, data_dir=data_dir)
    # Should return the existing content
    assert "existing: true" in result
    # query_df should be called exactly once, for the checksum query;
    # a match means the full inspection is skipped.
    assert mock_db_client.query_df.call_count == 1
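
A sketch of the checksum short-circuit the last test pins down; the file names ("checksum", "inspection.yaml") and the dml_hash column come from the test, while the query text and control flow are assumptions:

import os

def inspect_db_table_sketch(db_client, data_dir):
    table = db_client.settings["table"]
    # One query to hash the table's DML stats (e.g. md5 of "ins|upd|del")
    current_hash = db_client.query_df("SELECT ... AS dml_hash ...").iloc[0]["dml_hash"]
    checksum_path = os.path.join(data_dir, "checksum")
    if os.path.exists(checksum_path):
        with open(checksum_path) as f:
            if f"{table}:{current_hash}" in f.read():
                # Unchanged table: reuse the cached inspection, no further queries
                with open(os.path.join(data_dir, "inspection.yaml")) as cached:
                    return cached.read()
    # ...otherwise run the full per-column inspection and rewrite both files.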

tests/test_executor.py Normal file

@@ -0,0 +1,123 @@
import pytest
import pandas as pd
from unittest.mock import MagicMock, patch
from matplotlib.figure import Figure
from ea_chatbot.graph.nodes.executor import executor_node
from ea_chatbot.graph.state import AgentState


@pytest.fixture
def mock_settings():
    with patch("ea_chatbot.graph.nodes.executor.Settings") as MockSettings:
        mock_settings_instance = MagicMock()
        mock_settings_instance.db_host = "localhost"
        mock_settings_instance.db_port = 5432
        mock_settings_instance.db_user = "user"
        mock_settings_instance.db_pswd = "pass"
        mock_settings_instance.db_name = "test_db"
        mock_settings_instance.db_table = "test_table"
        MockSettings.return_value = mock_settings_instance
        yield mock_settings_instance


@pytest.fixture
def mock_db_client():
    with patch("ea_chatbot.graph.nodes.executor.DBClient") as MockDBClient:
        mock_client_instance = MagicMock()
        MockDBClient.return_value = mock_client_instance
        yield mock_client_instance


def test_executor_node_success_simple_print(mock_settings, mock_db_client):
    """Test executing simple code that prints to stdout."""
    state = {
        "code": "print('Hello, World!')",
        "question": "test",
        "messages": []
    }
    result = executor_node(state)
    assert "code_output" in result
    assert "Hello, World!" in result["code_output"]
    assert result["error"] is None
    assert result["plots"] == []
    assert result["dfs"] == {}


def test_executor_node_success_dataframe(mock_settings, mock_db_client):
    """Test executing code that creates a DataFrame."""
    code = """
import pandas as pd
df = pd.DataFrame({'a': [1, 2], 'b': [3, 4]})
print(df)
"""
    state = {
        "code": code,
        "question": "test",
        "messages": []
    }
    result = executor_node(state)
    assert "code_output" in result
    assert "a  b" in result["code_output"]  # Check part of the DataFrame's string representation
    assert "dfs" in result
    assert "df" in result["dfs"]
    assert isinstance(result["dfs"]["df"], pd.DataFrame)


def test_executor_node_success_plot(mock_settings, mock_db_client):
    """Test executing code that generates a plot."""
    code = """
import matplotlib.pyplot as plt
fig = plt.figure()
plots.append(fig)
print('Plot generated')
"""
    state = {
        "code": code,
        "question": "test",
        "messages": []
    }
    result = executor_node(state)
    assert "Plot generated" in result["code_output"]
    assert "plots" in result
    assert len(result["plots"]) == 1
    assert isinstance(result["plots"][0], Figure)


def test_executor_node_error_syntax(mock_settings, mock_db_client):
    """Test executing code with a syntax error."""
    state = {
        "code": "print('Hello World",  # Missing closing quote
        "question": "test",
        "messages": []
    }
    result = executor_node(state)
    assert result["error"] is not None
    assert "SyntaxError" in result["error"]


def test_executor_node_error_runtime(mock_settings, mock_db_client):
    """Test executing code with a runtime error."""
    state = {
        "code": "print(1 / 0)",
        "question": "test",
        "messages": []
    }
    result = executor_node(state)
    assert result["error"] is not None
    assert "ZeroDivisionError" in result["error"]


def test_executor_node_no_code(mock_settings, mock_db_client):
    """Test handling when no code is provided."""
    state = {
        "code": None,
        "question": "test",
        "messages": []
    }
    result = executor_node(state)
    assert "error" in result
    assert "No code provided" in result["error"]

tests/test_helpers.py Normal file

@@ -0,0 +1,77 @@
import pytest
from langchain_core.messages import HumanMessage, AIMessage
from ea_chatbot.utils.helpers import merge_agent_state


def test_merge_agent_state_list_accumulation():
    """Verify that list fields (messages, plots) are accumulated (appended)."""
    current_state = {
        "messages": [HumanMessage(content="hello")],
        "plots": ["plot1"]
    }
    update = {
        "messages": [AIMessage(content="hi")],
        "plots": ["plot2"]
    }
    merged = merge_agent_state(current_state, update)
    assert len(merged["messages"]) == 2
    assert merged["messages"][0].content == "hello"
    assert merged["messages"][1].content == "hi"
    assert len(merged["plots"]) == 2
    assert merged["plots"] == ["plot1", "plot2"]


def test_merge_agent_state_dict_update():
    """Verify that dictionary fields (dfs) are updated (shallow merge)."""
    current_state = {
        "dfs": {"df1": "data1"}
    }
    update = {
        "dfs": {"df2": "data2"}
    }
    merged = merge_agent_state(current_state, update)
    assert merged["dfs"] == {"df1": "data1", "df2": "data2"}
    # Verify overwrite within dict
    update_overwrite = {
        "dfs": {"df1": "new_data1"}
    }
    merged_overwrite = merge_agent_state(merged, update_overwrite)
    assert merged_overwrite["dfs"] == {"df1": "new_data1", "df2": "data2"}


def test_merge_agent_state_standard_overwrite():
    """Verify that standard fields are overwritten."""
    current_state = {
        "question": "old question",
        "next_action": "old action",
        "plan": "old plan"
    }
    update = {
        "question": "new question",
        "next_action": "new action",
        "plan": "new plan"
    }
    merged = merge_agent_state(current_state, update)
    assert merged["question"] == "new question"
    assert merged["next_action"] == "new action"
    assert merged["plan"] == "new plan"


def test_merge_agent_state_none_handling():
    """Verify that None updates or missing keys in update don't break things."""
    current_state = {
        "question": "test",
        "messages": ["msg1"]
    }
    # Empty update
    assert merge_agent_state(current_state, {}) == current_state
    # Update with None value for overwritable field
    merged_none = merge_agent_state(current_state, {"question": None})
    assert merged_none["question"] is None
    assert merged_none["messages"] == ["msg1"]
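
The four tests fully pin down the merge semantics: lists accumulate, dicts shallow-merge, and everything else (including None) overwrites. A sketch matching them (the real helper may differ in detail):

def merge_agent_state_sketch(current, update):
    merged = dict(current)
    for key, value in update.items():
        if isinstance(value, list) and isinstance(merged.get(key), list):
            merged[key] = merged[key] + value        # accumulate lists
        elif isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = {**merged[key], **value}   # shallow-merge dicts
        else:
            merged[key] = value                      # overwrite, even with None
    return merged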


@@ -0,0 +1,145 @@
import pytest
from ea_chatbot.history.manager import HistoryManager
from ea_chatbot.history.models import User, Conversation, Message, Plot
from ea_chatbot.config import Settings
from sqlalchemy import delete


@pytest.fixture
def history_manager():
    settings = Settings()
    manager = HistoryManager(settings.history_db_url)
    # Clean up tables before tests (order matters because of foreign keys)
    with manager.get_session() as session:
        session.execute(delete(Plot))
        session.execute(delete(Message))
        session.execute(delete(Conversation))
        session.execute(delete(User))
    return manager


def test_history_manager_initialization(history_manager):
    assert history_manager.engine is not None
    assert history_manager.SessionLocal is not None


def test_history_manager_session_context(history_manager):
    with history_manager.get_session() as session:
        assert session is not None


def test_get_user_not_found(history_manager):
    user = history_manager.get_user("nonexistent@example.com")
    assert user is None


def test_authenticate_user_success(history_manager):
    email = "test@example.com"
    password = "secretpassword"
    history_manager.create_user(email=email, password=password)
    user = history_manager.authenticate_user(email, password)
    assert user is not None
    assert user.username == email


def test_authenticate_user_failure(history_manager):
    email = "test@example.com"
    history_manager.create_user(email=email, password="correctpassword")
    user = history_manager.authenticate_user(email, "wrongpassword")
    assert user is None


def test_sync_user_from_oidc_new_user(history_manager):
    user = history_manager.sync_user_from_oidc(
        email="new@example.com",
        display_name="New User"
    )
    assert user is not None
    assert user.username == "new@example.com"
    assert user.display_name == "New User"


def test_sync_user_from_oidc_existing_user(history_manager):
    # First sync
    history_manager.sync_user_from_oidc(
        email="existing@example.com",
        display_name="First Name"
    )
    # Second sync should update or return the same user
    user = history_manager.sync_user_from_oidc(
        email="existing@example.com",
        display_name="Updated Name"
    )
    assert user.display_name == "Updated Name"


# --- Conversation Management Tests ---

@pytest.fixture
def user(history_manager):
    return history_manager.create_user(email="conv_user@example.com")


def test_create_conversation(history_manager, user):
    conv = history_manager.create_conversation(
        user_id=user.id,
        data_state="new_jersey",
        name="Test Chat",
        summary="A test conversation summary"
    )
    assert conv is not None
    assert conv.name == "Test Chat"
    assert conv.summary == "A test conversation summary"
    assert conv.user_id == user.id


def test_get_conversations(history_manager, user):
    history_manager.create_conversation(user_id=user.id, data_state="nj", name="C1")
    history_manager.create_conversation(user_id=user.id, data_state="nj", name="C2")
    history_manager.create_conversation(user_id=user.id, data_state="ny", name="C3")
    nj_convs = history_manager.get_conversations(user_id=user.id, data_state="nj")
    assert len(nj_convs) == 2
    ny_convs = history_manager.get_conversations(user_id=user.id, data_state="ny")
    assert len(ny_convs) == 1


def test_rename_conversation(history_manager, user):
    conv = history_manager.create_conversation(user.id, "nj", "Old Name")
    updated = history_manager.rename_conversation(conv.id, "New Name")
    assert updated.name == "New Name"


def test_delete_conversation(history_manager, user):
    conv = history_manager.create_conversation(user.id, "nj", "To Delete")
    history_manager.delete_conversation(conv.id)
    convs = history_manager.get_conversations(user.id, "nj")
    assert len(convs) == 0


# --- Message Management Tests ---

@pytest.fixture
def conversation(history_manager, user):
    return history_manager.create_conversation(user.id, "nj", "Msg Test Conv")


def test_add_message(history_manager, conversation):
    msg = history_manager.add_message(
        conversation_id=conversation.id,
        role="user",
        content="Hello world"
    )
    assert msg is not None
    assert msg.content == "Hello world"
    assert msg.role == "user"
    assert msg.conversation_id == conversation.id


def test_add_message_with_plots(history_manager, conversation):
    plots_data = [b"fake_plot_1", b"fake_plot_2"]
    msg = history_manager.add_message(
        conversation_id=conversation.id,
        role="assistant",
        content="Here are plots",
        plots=plots_data
    )
    assert len(msg.plots) == 2
    assert msg.plots[0].image_data == b"fake_plot_1"


def test_get_messages(history_manager, conversation):
    history_manager.add_message(conversation.id, "user", "Q1")
    history_manager.add_message(conversation.id, "assistant", "A1")
    messages = history_manager.get_messages(conversation.id)
    assert len(messages) == 2
    assert messages[0].content == "Q1"
    assert messages[1].content == "A1"


@@ -0,0 +1,55 @@
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, DeclarativeBase

# We anticipate these imports will fail initially
try:
    from ea_chatbot.history.models import Base, User, Conversation, Message, Plot
except ImportError:
    Base = None
    User = None
    Conversation = None
    Message = None
    Plot = None


def test_models_exist():
    assert User is not None, "User model not found"
    assert Conversation is not None, "Conversation model not found"
    assert Message is not None, "Message model not found"
    assert Plot is not None, "Plot model not found"
    assert Base is not None, "Base declarative class not found"


def test_user_model_columns():
    if not User:
        pytest.fail("User model undefined")
    # Basic check that the columns exist (by inspecting __table__.columns)
    columns = User.__table__.columns.keys()
    assert "id" in columns
    assert "username" in columns
    assert "password_hash" in columns
    assert "display_name" in columns


def test_conversation_model_columns():
    if not Conversation:
        pytest.fail("Conversation model undefined")
    columns = Conversation.__table__.columns.keys()
    assert "id" in columns
    assert "user_id" in columns
    assert "data_state" in columns
    assert "name" in columns
    assert "summary" in columns
    assert "created_at" in columns


def test_message_model_columns():
    if not Message:
        pytest.fail("Message model undefined")
    columns = Message.__table__.columns.keys()
    assert "id" in columns
    assert "role" in columns
    assert "content" in columns
    assert "conversation_id" in columns
    assert "created_at" in columns


def test_plot_model_columns():
    if not Plot:
        pytest.fail("Plot model undefined")
    columns = Plot.__table__.columns.keys()
    assert "id" in columns
    assert "message_id" in columns
    assert "image_data" in columns


@@ -0,0 +1,4 @@
import ea_chatbot.history


def test_history_module_importable():
    assert ea_chatbot.history is not None

tests/test_llm_factory.py Normal file

@@ -0,0 +1,54 @@
import pytest
from langchain_openai import ChatOpenAI
from langchain_google_genai import ChatGoogleGenerativeAI
from ea_chatbot.config import LLMConfig
from ea_chatbot.utils.llm_factory import get_llm_model


def test_get_openai_model(monkeypatch):
    """Test creating an OpenAI model."""
    monkeypatch.setenv("OPENAI_API_KEY", "dummy")
    config = LLMConfig(
        provider="openai",
        model="gpt-4o",
        temperature=0.5,
        max_tokens=100
    )
    model = get_llm_model(config)
    assert isinstance(model, ChatOpenAI)
    assert model.model_name == "gpt-4o"
    assert model.temperature == 0.5
    assert model.max_tokens == 100


def test_get_google_model(monkeypatch):
    """Test creating a Google model."""
    monkeypatch.setenv("GOOGLE_API_KEY", "dummy")
    config = LLMConfig(
        provider="google",
        model="gemini-1.5-pro",
        temperature=0.7
    )
    model = get_llm_model(config)
    assert isinstance(model, ChatGoogleGenerativeAI)
    assert model.model == "gemini-1.5-pro"
    assert model.temperature == 0.7


def test_unsupported_provider():
    """Test that an unsupported provider raises an error."""
    config = LLMConfig(provider="unknown", model="test")
    with pytest.raises(ValueError, match="Unsupported LLM provider: unknown"):
        get_llm_model(config)


def test_provider_specific_params(monkeypatch):
    """Test passing provider-specific params."""
    monkeypatch.setenv("OPENAI_API_KEY", "dummy")
    config = LLMConfig(
        provider="openai",
        model="o1-preview",
        provider_specific={"reasoning_effort": "high"}
    )
    # Note: reasoning_effort support depends on the langchain-openai version,
    # but we check if kwargs are passed.
    model = get_llm_model(config)
    assert isinstance(model, ChatOpenAI)
    # Check if reasoning_effort was passed correctly
    assert getattr(model, "reasoning_effort", None) == "high"
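
A sketch of the factory consistent with these tests, assuming provider_specific entries are forwarded as keyword arguments (the error message comes from test_unsupported_provider; the rest is illustrative):

from langchain_openai import ChatOpenAI
from langchain_google_genai import ChatGoogleGenerativeAI

def get_llm_model_sketch(config, callbacks=None):
    # Forward provider-specific kwargs (e.g. reasoning_effort) unchanged
    common = {"temperature": config.temperature, "callbacks": callbacks,
              **(config.provider_specific or {})}
    if config.provider == "openai":
        return ChatOpenAI(model=config.model, max_tokens=config.max_tokens, **common)
    if config.provider == "google":
        return ChatGoogleGenerativeAI(model=config.model, **common)
    raise ValueError(f"Unsupported LLM provider: {config.provider}")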


@@ -0,0 +1,19 @@
import pytest
from langchain_openai import ChatOpenAI
from ea_chatbot.config import LLMConfig
from ea_chatbot.utils.llm_factory import get_llm_model
from langchain_core.callbacks import BaseCallbackHandler


class MockHandler(BaseCallbackHandler):
    pass


def test_get_llm_model_with_callbacks(monkeypatch):
    """Test that callbacks are passed to the model."""
    monkeypatch.setenv("OPENAI_API_KEY", "dummy")
    config = LLMConfig(provider="openai", model="gpt-4o")
    handler = MockHandler()
    model = get_llm_model(config, callbacks=[handler])
    assert isinstance(model, ChatOpenAI)
    assert handler in model.callbacks


@@ -0,0 +1,44 @@
import logging
import pytest
import io
import json
from ea_chatbot.utils.logging import ContextLoggerAdapter, JsonFormatter


@pytest.fixture
def json_log_capture():
    """Fixture to capture JSON logs."""
    log_stream = io.StringIO()
    logger = logging.getLogger("test_context")
    logger.setLevel(logging.INFO)
    for handler in logger.handlers[:]:
        logger.removeHandler(handler)
    handler = logging.StreamHandler(log_stream)
    handler.setFormatter(JsonFormatter())
    logger.addHandler(handler)
    return logger, log_stream


def test_context_logger_adapter_injects_metadata(json_log_capture):
    """Test that ContextLoggerAdapter injects metadata into the log record."""
    logger, log_stream = json_log_capture
    adapter = ContextLoggerAdapter(logger, {"run_id": "123", "node_name": "test_node"})
    adapter.info("test message")
    data = json.loads(log_stream.getvalue())
    assert data["message"] == "test message"
    assert data["run_id"] == "123"
    assert data["node_name"] == "test_node"


def test_context_logger_adapter_override_metadata(json_log_capture):
    """Test that extra metadata can be provided during a call."""
    logger, log_stream = json_log_capture
    adapter = ContextLoggerAdapter(logger, {"run_id": "123"})
    # Passing extra context via the 'extra' parameter in standard logging;
    # our adapter should handle merging this.
    adapter.info("test message", extra={"node_name": "dynamic_node"})
    data = json.loads(log_stream.getvalue())
    assert data["run_id"] == "123"
    assert data["node_name"] == "dynamic_node"


@@ -0,0 +1,67 @@
import logging
import pytest
from ea_chatbot.utils.logging import get_logger


@pytest.fixture(autouse=True)
def reset_logging():
    """Reset the ea_chatbot logger handlers before each test."""
    logger = logging.getLogger("ea_chatbot")
    # Remove all existing handlers
    for handler in logger.handlers[:]:
        logger.removeHandler(handler)
    yield
    # Also clean up after the test
    for handler in logger.handlers[:]:
        logger.removeHandler(handler)


def test_get_logger_singleton():
    """Test that get_logger returns the same logger instance for the same name."""
    logger1 = get_logger("test_logger")
    logger2 = get_logger("test_logger")
    assert logger1 is logger2


def test_get_logger_rich_handler():
    """Test that get_logger configures a RichHandler on root."""
    get_logger("test_rich")
    root = logging.getLogger("ea_chatbot")
    # Check if any handler is a RichHandler
    handler_names = [h.__class__.__name__ for h in root.handlers]
    assert "RichHandler" in handler_names


def test_get_logger_level():
    """Test that get_logger sets the correct log level."""
    logger = get_logger("test_level", level="DEBUG")
    assert logger.level == logging.DEBUG


def test_json_formatter_serializes_dict():
    """Test that JsonFormatter serializes log records to JSON."""
    from ea_chatbot.utils.logging import JsonFormatter
    import json
    formatter = JsonFormatter()
    record = logging.LogRecord(
        name="test", level=logging.INFO, pathname="test.py", lineno=10,
        msg="test message", args=(), exc_info=None
    )
    formatted = formatter.format(record)
    data = json.loads(formatted)
    assert data["message"] == "test message"
    assert data["level"] == "INFO"
    assert "timestamp" in data


def test_get_logger_file_handler(tmp_path):
    """Test that get_logger configures a file handler on root."""
    log_file = tmp_path / "test.json"
    logger = get_logger("test_file", log_file=str(log_file))
    root = logging.getLogger("ea_chatbot")
    handler_names = [h.__class__.__name__ for h in root.handlers]
    assert "RotatingFileHandler" in handler_names
    logger.info("file log test")
    # Check that the file exists and has content
    assert log_file.exists()
    content = log_file.read_text()
    assert "file log test" in content

tests/test_logging_e2e.py Normal file

@@ -0,0 +1,83 @@
import os
import json
import pytest
import logging
from unittest.mock import MagicMock, patch
from ea_chatbot.graph.workflow import app
from ea_chatbot.graph.state import AgentState
from ea_chatbot.utils.logging import get_logger
from langchain_community.chat_models import FakeListChatModel
from langchain_core.runnables import RunnableLambda


@pytest.fixture(autouse=True)
def reset_logging():
    """Reset handlers on the root ea_chatbot logger."""
    root = logging.getLogger("ea_chatbot")
    for handler in root.handlers[:]:
        root.removeHandler(handler)
    yield
    for handler in root.handlers[:]:
        root.removeHandler(handler)


class FakeStructuredModel(FakeListChatModel):
    def with_structured_output(self, schema, **kwargs):
        # Return a runnable that returns a parsed object
        def _invoke(input, config=None, **kwargs):
            content = self.responses[0]
            data = json.loads(content)
            if hasattr(schema, "model_validate"):
                return schema.model_validate(data)
            return data
        return RunnableLambda(_invoke)


def test_logging_e2e_json_output(tmp_path):
    """Test that a full graph run produces structured JSON logs from multiple nodes."""
    log_file = tmp_path / "e2e_test.jsonl"
    # Configure the root logger
    get_logger("ea_chatbot", log_file=str(log_file))
    initial_state: AgentState = {
        "messages": [],
        "question": "Who won in 2024?",
        "analysis": None,
        "next_action": "",
        "plan": None,
        "code": None,
        "code_output": None,
        "error": None,
        "plots": [],
        "dfs": {}
    }
    # Create fake models that support callbacks and structured output
    fake_analyzer_response = """{"data_required": [], "unknowns": [], "ambiguities": ["Which year?"], "conditions": [], "next_action": "clarify"}"""
    fake_analyzer = FakeStructuredModel(responses=[fake_analyzer_response])
    fake_clarify = FakeListChatModel(responses=["Please specify."])
    with patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model") as mock_llm_factory:
        mock_llm_factory.return_value = fake_analyzer
        with patch("ea_chatbot.graph.nodes.clarification.get_llm_model") as mock_clarify_llm_factory:
            mock_clarify_llm_factory.return_value = fake_clarify
            # Run the graph
            list(app.stream(initial_state))
    # Verify the file content
    assert log_file.exists()
    lines = log_file.read_text().splitlines()
    assert len(lines) > 0
    # Verify we have logs from different nodes
    node_names = [json.loads(line)["name"] for line in lines]
    assert "ea_chatbot.query_analyzer" in node_names
    assert "ea_chatbot.clarification" in node_names
    # Verify events
    messages = [json.loads(line)["message"] for line in lines]
    assert any("Analyzing question" in m for m in messages)
    assert any("Clarification generated" in m for m in messages)


@@ -0,0 +1,64 @@
import logging
import pytest
import io
from unittest.mock import MagicMock
from ea_chatbot.utils.logging import LangChainLoggingHandler


@pytest.fixture
def log_capture():
    """Fixture to capture logs from a logger."""
    log_stream = io.StringIO()
    logger = logging.getLogger("test_langchain")
    logger.setLevel(logging.INFO)
    # Remove existing handlers
    for handler in logger.handlers[:]:
        logger.removeHandler(handler)
    handler = logging.StreamHandler(log_stream)
    logger.addHandler(handler)
    return logger, log_stream


def test_langchain_logging_handler_on_llm_start(log_capture):
    """Test that on_llm_start logs the correct message."""
    logger, log_stream = log_capture
    handler = LangChainLoggingHandler(logger=logger)
    handler.on_llm_start(serialized={"name": "test_model"}, prompts=["test prompt"])
    output = log_stream.getvalue()
    assert "LLM Started:" in output
    assert "test_model" in output


def test_langchain_logging_handler_on_llm_end(log_capture):
    """Test that on_llm_end logs token usage."""
    logger, log_stream = log_capture
    handler = LangChainLoggingHandler(logger=logger)
    response = MagicMock()
    response.llm_output = {
        "token_usage": {
            "prompt_tokens": 10,
            "completion_tokens": 20,
            "total_tokens": 30
        },
        "model_name": "test_model"
    }
    handler.on_llm_end(response=response)
    output = log_stream.getvalue()
    assert "LLM Ended:" in output
    assert "test_model" in output
    assert "Tokens: 30" in output
    assert "10 prompt" in output
    assert "20 completion" in output


def test_langchain_logging_handler_on_llm_error(log_capture):
    """Test that on_llm_error logs the error."""
    logger, log_stream = log_capture
    handler = LangChainLoggingHandler(logger=logger)
    error = Exception("test error")
    handler.on_llm_error(error=error)
    output = log_stream.getvalue()
    assert "LLM Error:" in output
    assert "test error" in output


@@ -0,0 +1,79 @@
import pytest
from unittest.mock import MagicMock, patch
from langchain_core.messages import HumanMessage, AIMessage
from ea_chatbot.graph.nodes.planner import planner_node
from ea_chatbot.graph.nodes.researcher import researcher_node
from ea_chatbot.graph.nodes.summarizer import summarizer_node
from ea_chatbot.schemas import TaskPlanResponse


@pytest.fixture
def mock_state_with_history():
    return {
        "messages": [
            HumanMessage(content="Show me the 2024 results for Florida"),
            AIMessage(content="Here are the results for Florida in 2024...")
        ],
        "question": "What about in New Jersey?",
        "analysis": {"data_required": ["2024 results", "New Jersey"], "unknowns": [], "ambiguities": [], "conditions": []},
        "next_action": "plan",
        "summary": "The user is asking about 2024 election results.",
        "plan": "Plan steps...",
        "code_output": "Code output..."
    }


@patch("ea_chatbot.graph.nodes.planner.get_llm_model")
@patch("ea_chatbot.utils.database_inspection.get_data_summary")
@patch("ea_chatbot.graph.nodes.planner.PLANNER_PROMPT")
def test_planner_uses_history_and_summary(mock_prompt, mock_get_summary, mock_get_llm, mock_state_with_history):
    mock_get_summary.return_value = "Data summary"
    mock_llm_instance = MagicMock()
    mock_get_llm.return_value = mock_llm_instance
    mock_structured_llm = MagicMock()
    mock_llm_instance.with_structured_output.return_value = mock_structured_llm
    mock_structured_llm.invoke.return_value = TaskPlanResponse(
        goal="goal",
        reflection="reflection",
        context={
            "initial_context": "context",
            "assumptions": [],
            "constraints": []
        },
        steps=["Step 1: test"]
    )
    planner_node(mock_state_with_history)
    mock_prompt.format_messages.assert_called_once()
    kwargs = mock_prompt.format_messages.call_args[1]
    assert kwargs["question"] == "What about in New Jersey?"
    assert kwargs["summary"] == mock_state_with_history["summary"]
    assert len(kwargs["history"]) == 2


@patch("ea_chatbot.graph.nodes.researcher.get_llm_model")
@patch("ea_chatbot.graph.nodes.researcher.RESEARCHER_PROMPT")
def test_researcher_uses_history_and_summary(mock_prompt, mock_get_llm, mock_state_with_history):
    mock_llm_instance = MagicMock()
    mock_get_llm.return_value = mock_llm_instance
    researcher_node(mock_state_with_history)
    mock_prompt.format_messages.assert_called_once()
    kwargs = mock_prompt.format_messages.call_args[1]
    assert kwargs["question"] == "What about in New Jersey?"
    assert kwargs["summary"] == mock_state_with_history["summary"]
    assert len(kwargs["history"]) == 2


@patch("ea_chatbot.graph.nodes.summarizer.get_llm_model")
@patch("ea_chatbot.graph.nodes.summarizer.SUMMARIZER_PROMPT")
def test_summarizer_uses_history_and_summary(mock_prompt, mock_get_llm, mock_state_with_history):
    mock_llm_instance = MagicMock()
    mock_get_llm.return_value = mock_llm_instance
    summarizer_node(mock_state_with_history)
    mock_prompt.format_messages.assert_called_once()
    kwargs = mock_prompt.format_messages.call_args[1]
    assert kwargs["question"] == "What about in New Jersey?"
    assert kwargs["summary"] == mock_state_with_history["summary"]
    assert len(kwargs["history"]) == 2


@@ -0,0 +1,76 @@
import pytest
from unittest.mock import MagicMock, patch
from langchain_core.messages import HumanMessage, AIMessage
from ea_chatbot.graph.nodes.query_analyzer import query_analyzer_node, QueryAnalysis
from ea_chatbot.graph.state import AgentState


@pytest.fixture
def mock_state_with_history():
    return {
        "messages": [
            HumanMessage(content="Show me the 2024 results for Florida"),
            AIMessage(content="Here are the results for Florida in 2024...")
        ],
        "question": "What about in New Jersey?",
        "analysis": None,
        "next_action": "",
        "summary": "The user is asking about 2024 election results."
    }


@patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model")
@patch("ea_chatbot.graph.nodes.query_analyzer.QUERY_ANALYZER_PROMPT")
def test_query_analyzer_uses_history_and_summary(mock_prompt, mock_get_llm, mock_state_with_history):
    """Test that query_analyzer_node passes history and summary to the prompt."""
    mock_llm_instance = MagicMock()
    mock_get_llm.return_value = mock_llm_instance
    mock_structured_llm = MagicMock()
    mock_llm_instance.with_structured_output.return_value = mock_structured_llm
    mock_structured_llm.invoke.return_value = QueryAnalysis(
        data_required=["2024 results", "New Jersey"],
        unknowns=[],
        ambiguities=[],
        conditions=[],
        next_action="plan"
    )
    query_analyzer_node(mock_state_with_history)
    # Verify that the prompt was formatted with the correct variables
    mock_prompt.format_messages.assert_called_once()
    kwargs = mock_prompt.format_messages.call_args[1]
    assert kwargs["question"] == "What about in New Jersey?"
    assert "summary" in kwargs
    assert kwargs["summary"] == mock_state_with_history["summary"]
    assert "history" in kwargs
    # History should contain the messages from the state
    assert len(kwargs["history"]) == 2
    assert kwargs["history"][0].content == "Show me the 2024 results for Florida"


@patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model")
def test_query_analyzer_context_window(mock_get_llm):
    """Test that query_analyzer_node only uses the last 6 messages (3 turns)."""
    messages = [HumanMessage(content=f"Msg {i}") for i in range(10)]
    state = {
        "messages": messages,
        "question": "Latest question",
        "analysis": None,
        "next_action": "",
        "summary": "Summary"
    }
    mock_llm_instance = MagicMock()
    mock_get_llm.return_value = mock_llm_instance
    mock_structured_llm = MagicMock()
    mock_llm_instance.with_structured_output.return_value = mock_structured_llm
    mock_structured_llm.invoke.return_value = QueryAnalysis(
        data_required=[], unknowns=[], ambiguities=[], conditions=[], next_action="plan"
    )
    with patch("ea_chatbot.graph.nodes.query_analyzer.QUERY_ANALYZER_PROMPT") as mock_prompt:
        query_analyzer_node(state)
        kwargs = mock_prompt.format_messages.call_args[1]
        # Should only have the last 6 messages
        assert len(kwargs["history"]) == 6
        assert kwargs["history"][0].content == "Msg 4"

tests/test_oidc_client.py Normal file

@@ -0,0 +1,87 @@
import pytest
from unittest.mock import MagicMock, patch
from ea_chatbot.auth import OIDCClient


@pytest.fixture
def oidc_config():
    return {
        "client_id": "test_id",
        "client_secret": "test_secret",
        "server_metadata_url": "https://example.com/.well-known/openid-configuration",
        "redirect_uri": "http://localhost:8501"
    }


@pytest.fixture
def mock_metadata():
    return {
        "authorization_endpoint": "https://example.com/auth",
        "token_endpoint": "https://example.com/token",
        "userinfo_endpoint": "https://example.com/userinfo"
    }


def test_oidc_fetch_metadata(oidc_config, mock_metadata):
    client = OIDCClient(**oidc_config)
    with patch("requests.get") as mock_get:
        mock_response = MagicMock()
        mock_response.json.return_value = mock_metadata
        mock_get.return_value = mock_response
        metadata = client.fetch_metadata()
        assert metadata == mock_metadata
        mock_get.assert_called_once_with(oidc_config["server_metadata_url"])
        # The second call should use the cache
        client.fetch_metadata()
        assert mock_get.call_count == 1


def test_oidc_get_login_url(oidc_config, mock_metadata):
    client = OIDCClient(**oidc_config)
    client.metadata = mock_metadata
    with patch.object(client.oauth_session, "create_authorization_url") as mock_create_url:
        mock_create_url.return_value = ("https://example.com/auth?state=xyz", "xyz")
        url = client.get_login_url()
        assert url == "https://example.com/auth?state=xyz"
        mock_create_url.assert_called_once_with(mock_metadata["authorization_endpoint"])


def test_oidc_get_login_url_missing_endpoint(oidc_config):
    client = OIDCClient(**oidc_config)
    client.metadata = {"some_other": "field"}
    with pytest.raises(ValueError, match="authorization_endpoint not found"):
        client.get_login_url()


def test_oidc_exchange_code_for_token(oidc_config, mock_metadata):
    client = OIDCClient(**oidc_config)
    client.metadata = mock_metadata
    with patch.object(client.oauth_session, "fetch_token") as mock_fetch_token:
        mock_fetch_token.return_value = {"access_token": "abc"}
        token = client.exchange_code_for_token("test_code")
        assert token == {"access_token": "abc"}
        mock_fetch_token.assert_called_once_with(
            mock_metadata["token_endpoint"],
            code="test_code",
            client_secret=oidc_config["client_secret"]
        )


def test_oidc_get_user_info(oidc_config, mock_metadata):
    client = OIDCClient(**oidc_config)
    client.metadata = mock_metadata
    with patch.object(client.oauth_session, "get") as mock_get:
        mock_response = MagicMock()
        mock_response.json.return_value = {"sub": "user123"}
        mock_get.return_value = mock_response
        user_info = client.get_user_info({"access_token": "abc"})
        assert user_info == {"sub": "user123"}
        assert client.oauth_session.token == {"access_token": "abc"}
        mock_get.assert_called_once_with(mock_metadata["userinfo_endpoint"])
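
Combined with tests/test_auth.py, these tests suggest a client that lazily fetches and caches provider metadata and delegates to an OAuth2Session (the patches point at authlib's, though that import is an assumption). A sketch consistent with the assertions:

import requests
from authlib.integrations.requests_client import OAuth2Session  # assumed dependency

class OIDCClientSketch:
    def __init__(self, client_id, client_secret, server_metadata_url, redirect_uri=None):
        self.client_secret = client_secret
        self.server_metadata_url = server_metadata_url
        self.metadata = None
        self.oauth_session = OAuth2Session(client_id, client_secret, redirect_uri=redirect_uri)

    def fetch_metadata(self):
        if self.metadata is None:  # cache after the first fetch
            self.metadata = requests.get(self.server_metadata_url).json()
        return self.metadata

    def get_login_url(self):
        endpoint = self.fetch_metadata().get("authorization_endpoint")
        if not endpoint:
            raise ValueError("authorization_endpoint not found in OIDC metadata")
        url, _state = self.oauth_session.create_authorization_url(endpoint)
        return url

    def exchange_code_for_token(self, code):
        return self.oauth_session.fetch_token(
            self.metadata["token_endpoint"], code=code, client_secret=self.client_secret
        )

    def get_user_info(self, token):
        self.oauth_session.token = token
        return self.oauth_session.get(self.metadata["userinfo_endpoint"]).json()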

tests/test_planner.py Normal file

@@ -0,0 +1,46 @@
import pytest
from unittest.mock import MagicMock, patch
from ea_chatbot.graph.nodes.planner import planner_node


@pytest.fixture
def mock_state():
    return {
        "messages": [],
        "question": "Show me results for New Jersey",
        "analysis": {
            # "requires_dataset" removed as it's no longer used
            "expert": "Data Analyst",
            "data": "NJ data",
            "unknown": "results",
            "condition": "state=NJ"
        },
        "next_action": "plan",
        "plan": None
    }


@patch("ea_chatbot.graph.nodes.planner.get_llm_model")
@patch("ea_chatbot.utils.database_inspection.get_data_summary")
def test_planner_node(mock_get_summary, mock_get_llm, mock_state):
    """Test planner node with unified prompt."""
    mock_get_summary.return_value = "Column: Name, Type: text"
    mock_llm = MagicMock()
    mock_get_llm.return_value = mock_llm
    from ea_chatbot.schemas import TaskPlanResponse, TaskPlanContext
    mock_plan = TaskPlanResponse(
        goal="Get NJ results",
        reflection="The user wants NJ results",
        context=TaskPlanContext(initial_context="NJ data", assumptions=[], constraints=[]),
        steps=["Step 1: Load data", "Step 2: Filter by NJ"]
    )
    mock_llm.with_structured_output.return_value.invoke.return_value = mock_plan
    result = planner_node(mock_state)
    assert "plan" in result
    assert "Step 1: Load data" in result["plan"]
    assert "Step 2: Filter by NJ" in result["plan"]
    # Verify the helper was called
    mock_get_summary.assert_called_once()


@@ -0,0 +1,80 @@
import pytest
from unittest.mock import MagicMock, patch
from ea_chatbot.graph.nodes.query_analyzer import query_analyzer_node, QueryAnalysis
from ea_chatbot.graph.state import AgentState


@pytest.fixture
def mock_state():
    return {
        "messages": [],
        "question": "Show me the 2024 results for Florida",
        "analysis": None,
        "next_action": ""
    }


@patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model")
def test_query_analyzer_data_analysis(mock_get_llm, mock_state):
    """Test that a clear data analysis query is routed to the planner."""
    # Mock the LLM and the structured output runnable
    mock_llm_instance = MagicMock()
    mock_get_llm.return_value = mock_llm_instance
    mock_structured_llm = MagicMock()
    mock_llm_instance.with_structured_output.return_value = mock_structured_llm
    # Define the expected Pydantic result
    expected_analysis = QueryAnalysis(
        data_required=["2024 results", "Florida"],
        unknowns=[],
        ambiguities=[],
        conditions=[],
        next_action="plan"
    )
    # When structured_llm.invoke is called with messages, return the Pydantic object
    mock_structured_llm.invoke.return_value = expected_analysis
    new_state_update = query_analyzer_node(mock_state)
    assert new_state_update["next_action"] == "plan"
    assert "2024 results" in new_state_update["analysis"]["data_required"]


@patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model")
def test_query_analyzer_ambiguous(mock_get_llm, mock_state):
    """Test that an ambiguous query is routed to clarification."""
    mock_state["question"] = "What happened?"
    mock_llm_instance = MagicMock()
    mock_get_llm.return_value = mock_llm_instance
    mock_structured_llm = MagicMock()
    mock_llm_instance.with_structured_output.return_value = mock_structured_llm
    expected_analysis = QueryAnalysis(
        data_required=[],
        unknowns=["What event?"],
        ambiguities=[],
        conditions=[],
        next_action="clarify"
    )
    mock_structured_llm.invoke.return_value = expected_analysis
    new_state_update = query_analyzer_node(mock_state)
    assert new_state_update["next_action"] == "clarify"
    assert len(new_state_update["analysis"]["unknowns"]) > 0


@patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model")
def test_query_analyzer_uses_config(mock_get_llm, mock_state, monkeypatch):
    """Test that the node uses the configured LLM settings."""
    monkeypatch.setenv("QUERY_ANALYZER_LLM__MODEL", "gpt-3.5-turbo")
    mock_llm_instance = MagicMock()
    mock_get_llm.return_value = mock_llm_instance
    mock_structured_llm = MagicMock()
    mock_llm_instance.with_structured_output.return_value = mock_structured_llm
    mock_structured_llm.invoke.return_value = QueryAnalysis(
        data_required=[], unknowns=[], ambiguities=[], conditions=[], next_action="plan"
    )
    query_analyzer_node(mock_state)
    # Verify get_llm_model was called with the overridden config
    called_config = mock_get_llm.call_args[0][0]
    assert called_config.model == "gpt-3.5-turbo"


@@ -0,0 +1,45 @@
import pytest
import logging
from unittest.mock import MagicMock, patch
from ea_chatbot.graph.nodes.query_analyzer import query_analyzer_node, QueryAnalysis
@pytest.fixture
def mock_state():
return {
"messages": [],
"question": "Show me the 2024 results for Florida",
"analysis": None,
"next_action": ""
}
@patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model")
@patch("ea_chatbot.graph.nodes.query_analyzer.get_logger")
def test_query_analyzer_logs_actions(mock_get_logger, mock_get_llm, mock_state):
"""Test that query_analyzer_node logs its main actions."""
# Mock Logger
mock_logger = MagicMock()
mock_get_logger.return_value = mock_logger
# Mock LLM
mock_llm_instance = MagicMock()
mock_get_llm.return_value = mock_llm_instance
mock_structured_llm = MagicMock()
mock_llm_instance.with_structured_output.return_value = mock_structured_llm
expected_analysis = QueryAnalysis(
data_required=["2024 results", "Florida"],
unknowns=[],
ambiguities=[],
conditions=[],
next_action="plan"
)
mock_structured_llm.invoke.return_value = expected_analysis
query_analyzer_node(mock_state)
# Check that the logger was called: we expect at least one info-level log
# at the start of the node and one at the end.
assert mock_logger.info.called
# Specific log messages can be asserted once they are finalized; for the
# TDD red phase, verifying that logging happens at all is sufficient.
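# The node-side logging asserted here might look like the following
# (a sketch only; message text and call sites are assumptions):
#
#     logger = get_logger(__name__)
#     logger.info("query_analyzer: analyzing question")
#     ...
#     logger.info("query_analyzer: routing to %s", analysis.next_action)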

View File

@@ -0,0 +1,103 @@
import pytest
from unittest.mock import MagicMock, patch
from langchain_core.messages import HumanMessage, AIMessage
from ea_chatbot.graph.nodes.query_analyzer import query_analyzer_node, QueryAnalysis
from ea_chatbot.graph.state import AgentState
@pytest.fixture
def base_state():
return {
"messages": [],
"question": "",
"analysis": None,
"next_action": "",
"summary": ""
}
@patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model")
def test_refinement_coreference_from_history(mock_get_llm, base_state):
"""
Test that the analyzer can resolve the year/state from history.
The user asks "What about in NJ?" after a Florida 2024 query.
Expected: next_action = 'plan', not 'clarify', since the missing year is recoverable from the conversation history.
"""
state = base_state.copy()
state["messages"] = [
HumanMessage(content="Show me 2024 results for Florida"),
AIMessage(content="Here are the 2024 results for Florida...")
]
state["question"] = "What about in New Jersey?"
state["summary"] = "The user is looking for 2024 election results."
mock_llm = MagicMock()
mock_get_llm.return_value = mock_llm
mock_structured = MagicMock()
mock_llm.with_structured_output.return_value = mock_structured
# The structured output is mocked to return 'plan' here, reflecting the
# behavior we want once the prompt is relaxed; the node must pass that
# routing decision through unchanged.
mock_structured.invoke.return_value = QueryAnalysis(
data_required=["2024 results", "New Jersey"],
unknowns=[],
ambiguities=[],
conditions=["state=NJ", "year=2024"],
next_action="plan"
)
result = query_analyzer_node(state)
assert result["next_action"] == "plan"
assert "NJ" in str(result["analysis"]["conditions"])
@patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model")
def test_refinement_tolerance_for_missing_format(mock_get_llm, base_state):
"""
Test that the analyzer doesn't flag missing output format or database name.
User asks "Give me a graph of turnout".
Expected: next_action = 'plan', even if 'format' or 'db' is not in query.
"""
state = base_state.copy()
state["question"] = "Give me a graph of voter turnout in 2024 for Florida"
mock_llm = MagicMock()
mock_get_llm.return_value = mock_llm
mock_structured = MagicMock()
mock_llm.with_structured_output.return_value = mock_structured
mock_structured.invoke.return_value = QueryAnalysis(
data_required=["voter turnout", "Florida"],
unknowns=[],
ambiguities=[],
conditions=["year=2024"],
next_action="plan"
)
result = query_analyzer_node(state)
assert result["next_action"] == "plan"
# Ensure the analyzer itself added no ambiguities (i.e., no hallucinated requirements)
assert len(result["analysis"]["ambiguities"]) == 0
@patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model")
def test_refinement_enforces_voter_identity_clarification(mock_get_llm, base_state):
"""
Test that 'track the same voter' still triggers clarification.
"""
state = base_state.copy()
state["question"] = "Track the same voter participation in 2020 and 2024."
mock_llm = MagicMock()
mock_get_llm.return_value = mock_llm
mock_structured = MagicMock()
mock_llm.with_structured_output.return_value = mock_structured
# We WANT it to clarify here because voter identity is not defined.
mock_structured.invoke.return_value = QueryAnalysis(
data_required=["voter participation"],
unknowns=[],
ambiguities=["Please define what fields constitute 'the same voter' (e.g. ID, or Name and DOB)."],
conditions=[],
next_action="clarify"
)
result = query_analyzer_node(state)
assert result["next_action"] == "clarify"
assert "identity" in str(result["analysis"]["ambiguities"]).lower() or "same voter" in str(result["analysis"]["ambiguities"]).lower()

34
tests/test_researcher.py Normal file
View File

@@ -0,0 +1,34 @@
import pytest
from unittest.mock import MagicMock, patch
from langchain_core.messages import AIMessage
from langchain_openai import ChatOpenAI
from ea_chatbot.graph.nodes.researcher import researcher_node
@pytest.fixture
def mock_llm():
with patch("ea_chatbot.graph.nodes.researcher.get_llm_model") as mock_get_llm:
mock_llm_instance = MagicMock(spec=ChatOpenAI)
mock_get_llm.return_value = mock_llm_instance
yield mock_llm_instance
def test_researcher_node_success(mock_llm):
"""Test that researcher_node invokes LLM with web_search tool and returns messages."""
state = {
"question": "What is the capital of France?",
"messages": []
}
mock_llm_with_tools = MagicMock()
mock_llm.bind_tools.return_value = mock_llm_with_tools
mock_llm_with_tools.invoke.return_value = AIMessage(content="The capital of France is Paris.")
result = researcher_node(state)
assert mock_llm.bind_tools.called
# Check that it was called with web_search
args, kwargs = mock_llm.bind_tools.call_args
assert {"type": "web_search"} in args[0]
assert mock_llm_with_tools.invoke.called
assert "messages" in result
assert result["messages"][0].content == "The capital of France is Paris."

View File

@@ -0,0 +1,62 @@
import pytest
from unittest.mock import MagicMock, patch
from langchain_core.messages import AIMessage
from langchain_openai import ChatOpenAI
from langchain_google_genai import ChatGoogleGenerativeAI
from ea_chatbot.graph.nodes.researcher import researcher_node
@pytest.fixture
def base_state():
return {
"question": "Who won the 2024 election?",
"messages": [],
"summary": ""
}
@patch("ea_chatbot.graph.nodes.researcher.get_llm_model")
def test_researcher_binds_openai_search(mock_get_llm, base_state):
"""Test that OpenAI LLM binds 'web_search' tool."""
mock_llm = MagicMock(spec=ChatOpenAI)
mock_get_llm.return_value = mock_llm
mock_llm_with_tools = MagicMock()
mock_llm.bind_tools.return_value = mock_llm_with_tools
mock_llm_with_tools.invoke.return_value = AIMessage(content="OpenAI Search Result")
result = researcher_node(base_state)
# Verify bind_tools called with correct OpenAI tool
mock_llm.bind_tools.assert_called_once_with([{"type": "web_search"}])
assert result["messages"][0].content == "OpenAI Search Result"
@patch("ea_chatbot.graph.nodes.researcher.get_llm_model")
def test_researcher_binds_google_search(mock_get_llm, base_state):
"""Test that Google LLM binds 'google_search' tool."""
mock_llm = MagicMock(spec=ChatGoogleGenerativeAI)
mock_get_llm.return_value = mock_llm
mock_llm_with_tools = MagicMock()
mock_llm.bind_tools.return_value = mock_llm_with_tools
mock_llm_with_tools.invoke.return_value = AIMessage(content="Google Search Result")
result = researcher_node(base_state)
# Verify bind_tools called with correct Google tool
mock_llm.bind_tools.assert_called_once_with([{"google_search": {}}])
assert result["messages"][0].content == "Google Search Result"
@patch("ea_chatbot.graph.nodes.researcher.get_llm_model")
def test_researcher_fallback_on_bind_error(mock_get_llm, base_state):
"""Test that researcher falls back to basic LLM if bind_tools fails."""
mock_llm = MagicMock(spec=ChatOpenAI)
mock_get_llm.return_value = mock_llm
# Simulate bind_tools failing (e.g. model doesn't support it)
mock_llm.bind_tools.side_effect = Exception("Not supported")
mock_llm.invoke.return_value = AIMessage(content="Basic Result")
result = researcher_node(base_state)
# Should still succeed using the base LLM
assert result["messages"][0].content == "Basic Result"
mock_llm.invoke.assert_called_once()
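# Taken together, these tests pin down a dispatch of roughly this shape
# (a sketch under those assumptions, not the actual implementation):
#
#     llm = get_llm_model(...)
#     try:
#         if isinstance(llm, ChatOpenAI):
#             llm = llm.bind_tools([{"type": "web_search"}])
#         elif isinstance(llm, ChatGoogleGenerativeAI):
#             llm = llm.bind_tools([{"google_search": {}}])
#     except Exception:
#         pass  # model without tool support: fall back to the bare LLM
#     response = llm.invoke(state["question"])
#     return {"messages": [response]}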

41
tests/test_state.py Normal file
View File

@@ -0,0 +1,41 @@
import pytest
from typing import get_type_hints, List
from langchain_core.messages import BaseMessage, HumanMessage
from ea_chatbot.graph.state import AgentState
import operator
def test_agent_state_structure():
"""Verify that AgentState has the required fields and types."""
hints = get_type_hints(AgentState)
assert "messages" in hints
# messages is expected to be Annotated[List[BaseMessage], operator.add].
# Key existence is asserted here; the reducer metadata is asserted
# separately in test_messages_reducer_behavior below.
assert "question" in hints
assert hints["question"] == str
# The spec only says analysis is a dictionary (likely Optional[Dict[str, Any]]),
# so assert presence without pinning the exact type.
assert "analysis" in hints
assert "next_action" in hints
assert hints["next_action"] == str
assert "summary" in hints
# summary may be declared as str or Optional[str]; presence is sufficient here.
assert "plots" in hints
assert "dfs" in hints
def test_messages_reducer_behavior():
"""Verify that messages is annotated with operator.add as its reducer."""
# get_type_hints strips Annotated metadata unless include_extras=True.
hints = get_type_hints(AgentState, include_extras=True)
metadata = getattr(hints["messages"], "__metadata__", ())
assert operator.add in metadata
# Sanity-check the reducer itself: list concatenation appends messages.
merged = operator.add([HumanMessage(content="hi")], [HumanMessage(content="again")])
assert len(merged) == 2
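# Putting the assertions together, the state shape these tests assume is
# roughly (illustrative only; the real definition is in ea_chatbot.graph.state):
#
#     class AgentState(TypedDict):
#         messages: Annotated[List[BaseMessage], operator.add]
#         question: str
#         analysis: Optional[Dict[str, Any]]
#         next_action: str
#         summary: Optional[str]
#         plots: List[bytes]
#         dfs: Dict[str, Any]
#
# LangGraph applies the operator.add reducer when merging node updates, so a
# node returning {"messages": [msg]} appends rather than overwrites.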

48
tests/test_summarizer.py Normal file
View File

@@ -0,0 +1,48 @@
import pytest
from unittest.mock import MagicMock, patch
from langchain_core.messages import AIMessage
from ea_chatbot.graph.nodes.summarizer import summarizer_node
from ea_chatbot.graph.state import AgentState
@pytest.fixture
def mock_llm():
with patch("ea_chatbot.graph.nodes.summarizer.get_llm_model") as mock_get_llm:
mock_llm_instance = MagicMock()
mock_get_llm.return_value = mock_llm_instance
yield mock_llm_instance
def test_summarizer_node_success(mock_llm):
"""Test that summarizer_node invokes LLM with correct inputs and returns messages."""
state = {
"question": "What is the total count?",
"plan": "1. Run query\n2. Sum results",
"code_output": "The total is 100",
"messages": []
}
mock_llm.invoke.return_value = AIMessage(content="The final answer is 100.")
result = summarizer_node(state)
# Verify LLM was called
assert mock_llm.invoke.called
# Verify result structure
assert "messages" in result
assert len(result["messages"]) == 1
assert isinstance(result["messages"][0], AIMessage)
assert result["messages"][0].content == "The final answer is 100."
def test_summarizer_node_empty_state(mock_llm):
"""Test handling of empty or minimal state."""
state = {
"question": "Empty?",
"messages": []
}
mock_llm.invoke.return_value = AIMessage(content="No data provided.")
result = summarizer_node(state)
assert "messages" in result
assert result["messages"][0].content == "No data provided."

93
tests/test_workflow.py Normal file
View File

@@ -0,0 +1,93 @@
import pytest
from unittest.mock import MagicMock, patch
from ea_chatbot.graph.workflow import app
from ea_chatbot.graph.nodes.query_analyzer import QueryAnalysis
from ea_chatbot.schemas import TaskPlanResponse, TaskPlanContext, CodeGenerationResponse
from langchain_core.messages import AIMessage
@patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model")
@patch("ea_chatbot.graph.nodes.planner.get_llm_model")
@patch("ea_chatbot.graph.nodes.coder.get_llm_model")
@patch("ea_chatbot.graph.nodes.summarizer.get_llm_model")
@patch("ea_chatbot.graph.nodes.researcher.get_llm_model")
@patch("ea_chatbot.utils.database_inspection.get_data_summary")
@patch("ea_chatbot.graph.nodes.executor.Settings")
@patch("ea_chatbot.graph.nodes.executor.DBClient")
def test_workflow_full_flow(mock_db_client, mock_settings, mock_get_summary, mock_researcher_llm, mock_summarizer_llm, mock_coder_llm, mock_planner_llm, mock_qa_llm):
"""Test the flow from query_analyzer through planner to coder."""
# Mock Settings for Executor
mock_settings_instance = MagicMock()
mock_settings_instance.db_host = "localhost"
mock_settings_instance.db_port = 5432
mock_settings_instance.db_user = "user"
mock_settings_instance.db_pswd = "pass"
mock_settings_instance.db_name = "test_db"
mock_settings_instance.db_table = "test_table"
mock_settings.return_value = mock_settings_instance
# Mock DBClient
mock_client_instance = MagicMock()
mock_db_client.return_value = mock_client_instance
# 1. Mock Query Analyzer
mock_qa_instance = MagicMock()
mock_qa_llm.return_value = mock_qa_instance
mock_qa_instance.with_structured_output.return_value.invoke.return_value = QueryAnalysis(
data_required=["2024 results"],
unknowns=[],
ambiguities=[],
conditions=[],
next_action="plan"
)
# 2. Mock Planner
mock_planner_instance = MagicMock()
mock_planner_llm.return_value = mock_planner_instance
mock_get_summary.return_value = "Data summary"
mock_planner_instance.with_structured_output.return_value.invoke.return_value = TaskPlanResponse(
goal="Task Goal",
reflection="Reflection",
context=TaskPlanContext(initial_context="Ctx", assumptions=[], constraints=[]),
steps=["Step 1"]
)
# 3. Mock Coder
mock_coder_instance = MagicMock()
mock_coder_llm.return_value = mock_coder_instance
mock_coder_instance.with_structured_output.return_value.invoke.return_value = CodeGenerationResponse(
code="print('Hello')",
explanation="Explanation"
)
# 4. Mock Summarizer
mock_summarizer_instance = MagicMock()
mock_summarizer_llm.return_value = mock_summarizer_instance
mock_summarizer_instance.invoke.return_value = AIMessage(content="Summary")
# 5. Mock Researcher (just in case)
mock_researcher_instance = MagicMock()
mock_researcher_llm.return_value = mock_researcher_instance
# Initial state
initial_state = {
"messages": [],
"question": "Show me the 2024 results",
"analysis": None,
"next_action": "",
"plan": None,
"code": None,
"error": None,
"plots": [],
"dfs": {}
}
# Run the graph; recursion_limit guards against infinite loops in any
# placeholder nodes.
result = app.invoke(initial_state, config={"recursion_limit": 10})
assert result["next_action"] == "plan"
assert "plan" in result and result["plan"] is not None
assert "code" in result and "print('Hello')" in result["code"]
assert "analysis" in result

139
tests/test_workflow_e2e.py Normal file
View File

@@ -0,0 +1,139 @@
import pytest
from unittest.mock import MagicMock, patch
from langchain_core.messages import AIMessage
from ea_chatbot.graph.workflow import app
from ea_chatbot.graph.nodes.query_analyzer import QueryAnalysis
from ea_chatbot.schemas import TaskPlanResponse, TaskPlanContext, CodeGenerationResponse
@pytest.fixture
def mock_llms():
with patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model") as mock_qa_llm, \
patch("ea_chatbot.graph.nodes.planner.get_llm_model") as mock_planner_llm, \
patch("ea_chatbot.graph.nodes.coder.get_llm_model") as mock_coder_llm, \
patch("ea_chatbot.graph.nodes.summarizer.get_llm_model") as mock_summarizer_llm, \
patch("ea_chatbot.graph.nodes.researcher.get_llm_model") as mock_researcher_llm, \
patch("ea_chatbot.graph.nodes.summarize_conversation.get_llm_model") as mock_summary_llm, \
patch("ea_chatbot.utils.database_inspection.get_data_summary") as mock_get_summary:
mock_get_summary.return_value = "Data summary"
# Mock summary LLM to return a simple response
mock_summary_instance = MagicMock()
mock_summary_llm.return_value = mock_summary_instance
mock_summary_instance.invoke.return_value = AIMessage(content="Turn summary")
yield {
"qa": mock_qa_llm,
"planner": mock_planner_llm,
"coder": mock_coder_llm,
"summarizer": mock_summarizer_llm,
"researcher": mock_researcher_llm,
"summary": mock_summary_llm
}
def test_workflow_data_analysis_flow(mock_llms):
"""Test full flow: QueryAnalyzer -> Planner -> Coder -> Executor -> Summarizer."""
# 1. Mock Query Analyzer (routes to plan)
mock_qa_instance = MagicMock()
mock_llms["qa"].return_value = mock_qa_instance
mock_qa_instance.with_structured_output.return_value.invoke.return_value = QueryAnalysis(
data_required=["2024 results"],
unknowns=[],
ambiguities=[],
conditions=[],
next_action="plan"
)
# 2. Mock Planner
mock_planner_instance = MagicMock()
mock_llms["planner"].return_value = mock_planner_instance
mock_planner_instance.with_structured_output.return_value.invoke.return_value = TaskPlanResponse(
goal="Get results",
reflection="Reflect",
context=TaskPlanContext(initial_context="Ctx", assumptions=[], constraints=[]),
steps=["Step 1"]
)
# 3. Mock Coder
mock_coder_instance = MagicMock()
mock_llms["coder"].return_value = mock_coder_instance
mock_coder_instance.with_structured_output.return_value.invoke.return_value = CodeGenerationResponse(
code="print('Execution Success')",
explanation="Explain"
)
# 4. Mock Summarizer
mock_summarizer_instance = MagicMock()
mock_llms["summarizer"].return_value = mock_summarizer_instance
mock_summarizer_instance.invoke.return_value = AIMessage(content="Final Summary: Success")
# Initial state
initial_state = {
"messages": [],
"question": "Show me 2024 results",
"analysis": None,
"next_action": "",
"plan": None,
"code": None,
"error": None,
"plots": [],
"dfs": {}
}
# Run the graph
result = app.invoke(initial_state, config={"recursion_limit": 15})
assert result["next_action"] == "plan"
assert "Execution Success" in result["code_output"]
assert "Final Summary: Success" in result["messages"][-1].content
def test_workflow_research_flow(mock_llms):
"""Test flow: QueryAnalyzer -> Researcher -> Summarizer."""
# 1. Mock Query Analyzer (routes to research)
mock_qa_instance = MagicMock()
mock_llms["qa"].return_value = mock_qa_instance
mock_qa_instance.with_structured_output.return_value.invoke.return_value = QueryAnalysis(
data_required=[],
unknowns=[],
ambiguities=[],
conditions=[],
next_action="research"
)
# 2. Mock Researcher
mock_researcher_instance = MagicMock()
mock_llms["researcher"].return_value = mock_researcher_instance
# The researcher node binds search tools based on the LLM's concrete type;
# since this is a plain MagicMock (no spec), the node falls back to invoking
# the base instance directly.
mock_researcher_instance.invoke.return_value = AIMessage(content="Research Results")
# Also stub bind_tools in case this mock is later given a spec.
mock_llm_with_tools = MagicMock()
mock_researcher_instance.bind_tools.return_value = mock_llm_with_tools
mock_llm_with_tools.invoke.return_value = AIMessage(content="Research Results")
# 3. Mock Summarizer (not used in this flow, but kept for completeness)
mock_summarizer_instance = MagicMock()
mock_llms["summarizer"].return_value = mock_summarizer_instance
mock_summarizer_instance.invoke.return_value = AIMessage(content="Final Summary: Research Success")
# Initial state
initial_state = {
"messages": [],
"question": "Who is the governor of Florida?",
"analysis": None,
"next_action": "",
"plan": None,
"code": None,
"error": None,
"plots": [],
"dfs": {}
}
# Run the graph
result = app.invoke(initial_state, config={"recursion_limit": 10})
assert result["next_action"] == "research"
assert "Research Results" in result["messages"][-1].content

View File

@@ -0,0 +1,62 @@
import pytest
from ea_chatbot.history.manager import HistoryManager
from ea_chatbot.history.models import User, Conversation, Message, Plot
from ea_chatbot.config import Settings
from sqlalchemy import delete
@pytest.fixture
def history_manager():
settings = Settings()
manager = HistoryManager(settings.history_db_url)
with manager.get_session() as session:
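# Delete children before parents to satisfy foreign-key constraints.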
session.execute(delete(Plot))
session.execute(delete(Message))
session.execute(delete(Conversation))
session.execute(delete(User))
return manager
def test_full_history_workflow(history_manager):
# 1. Create and Authenticate User
email = "e2e@example.com"
password = "password123"
history_manager.create_user(email, password, "E2E User")
user = history_manager.authenticate_user(email, password)
assert user is not None
assert user.display_name == "E2E User"
# 2. Create Conversation
conv = history_manager.create_conversation(user.id, "nj", "Test Analytics")
assert conv.id is not None
# 3. Add User Message
history_manager.add_message(conv.id, "user", "How many voters in NJ?")
# 4. Add Assistant Message with Plot
plot_data = b"fake_png_data"
history_manager.add_message(
conv.id,
"assistant",
"There are X voters.",
plots=[plot_data]
)
# 5. Retrieve and Verify History
messages = history_manager.get_messages(conv.id)
assert len(messages) == 2
assert messages[0].role == "user"
assert messages[1].role == "assistant"
assert len(messages[1].plots) == 1
assert messages[1].plots[0].image_data == plot_data
# 6. Verify Conversation listing
convs = history_manager.get_conversations(user.id, "nj")
assert len(convs) == 1
assert convs[0].name == "Test Analytics"
# 7. Update summary
history_manager.update_conversation_summary(conv.id, "Voter count analysis")
# 8. Reload and verify summary
updated_convs = history_manager.get_conversations(user.id, "nj")
assert updated_convs[0].summary == "Voter count analysis"
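# This fixture truncates whatever database Settings points at, so CI should
# supply a disposable history_db_url. Assuming HistoryManager accepts any
# SQLAlchemy URL, an in-memory SQLite target would look like:
#
#     manager = HistoryManager("sqlite:///:memory:")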