Refactor: Move backend files to backend/ directory and split .gitignore
This commit is contained in:
64
backend/tests/api/test_agent.py
Normal file
64
backend/tests/api/test_agent.py
Normal file
@@ -0,0 +1,64 @@
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from unittest.mock import MagicMock, patch, AsyncMock
|
||||
from ea_chatbot.api.main import app
|
||||
from ea_chatbot.api.dependencies import get_current_user
|
||||
from ea_chatbot.history.models import User
|
||||
from ea_chatbot.api.utils import create_access_token
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user():
|
||||
return User(id="user-123", username="test@example.com", display_name="Test User")
|
||||
|
||||
@pytest.fixture
|
||||
def auth_header(mock_user):
|
||||
app.dependency_overrides[get_current_user] = lambda: mock_user
|
||||
token = create_access_token(data={"sub": mock_user.id})
|
||||
yield {"Authorization": f"Bearer {token}"}
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
def test_stream_agent_unauthorized():
|
||||
"""Test that streaming requires authentication."""
|
||||
response = client.post("/chat/stream", json={"message": "hello"})
|
||||
assert response.status_code == 401
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_agent_success(auth_header, mock_user):
|
||||
"""Test successful agent streaming with SSE."""
|
||||
# We need to mock the LangGraph app.astream_events
|
||||
mock_events = [
|
||||
{"event": "on_chat_model_start", "name": "gpt-5", "data": {"input": "..."}},
|
||||
{"event": "on_chat_model_stream", "name": "gpt-5", "data": {"chunk": "Hello"}},
|
||||
{"event": "on_chain_end", "name": "agent", "data": {"output": "..."}}
|
||||
]
|
||||
|
||||
async def mock_astream_events(*args, **kwargs):
|
||||
for event in mock_events:
|
||||
yield event
|
||||
|
||||
with patch("ea_chatbot.api.routers.agent.app.astream_events", side_effect=mock_astream_events), \
|
||||
patch("ea_chatbot.api.routers.agent.get_checkpointer") as mock_cp, \
|
||||
patch("ea_chatbot.api.routers.agent.history_manager") as mock_hm:
|
||||
|
||||
mock_cp.return_value.__aenter__.return_value = AsyncMock()
|
||||
|
||||
# Mock session and DB objects
|
||||
mock_session = MagicMock()
|
||||
mock_hm.get_session.return_value.__enter__.return_value = mock_session
|
||||
from ea_chatbot.history.models import Conversation
|
||||
mock_conv = Conversation(id="t1", user_id=mock_user.id)
|
||||
mock_session.get.return_value = mock_conv
|
||||
|
||||
# Using TestClient with a stream context
|
||||
with client.stream("POST", "/chat/stream",
|
||||
json={"message": "hello", "thread_id": "t1"},
|
||||
headers=auth_header) as response:
|
||||
assert response.status_code == 200
|
||||
assert "text/event-stream" in response.headers["content-type"]
|
||||
|
||||
lines = list(response.iter_lines())
|
||||
# Each event should start with 'data: ' and be valid JSON
|
||||
data_lines = [line for line in lines if line.startswith("data: ")]
|
||||
assert len(data_lines) >= len(mock_events)
|
||||
107
backend/tests/api/test_api_auth.py
Normal file
107
backend/tests/api/test_api_auth.py
Normal file
@@ -0,0 +1,107 @@
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from unittest.mock import MagicMock, patch
|
||||
from ea_chatbot.api.main import app
|
||||
from ea_chatbot.history.models import User
|
||||
|
||||
# We will need to mock HistoryManager and get_db dependencies later
|
||||
# For now, we define the expected behavior of the auth endpoints.
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user():
|
||||
return User(
|
||||
id="user-123",
|
||||
username="test@example.com",
|
||||
display_name="Test User",
|
||||
password_hash="hashed_password"
|
||||
)
|
||||
|
||||
def test_register_user_success():
|
||||
"""Test successful user registration."""
|
||||
# We mock it where it is used in the router
|
||||
with patch("ea_chatbot.api.routers.auth.history_manager") as mock_hm:
|
||||
mock_hm.get_user.return_value = None
|
||||
mock_hm.create_user.return_value = User(id="1", username="new@example.com", display_name="New")
|
||||
|
||||
response = client.post(
|
||||
"/auth/register",
|
||||
json={"email": "new@example.com", "password": "password123", "display_name": "New"}
|
||||
)
|
||||
|
||||
assert response.status_code == 201
|
||||
assert response.json()["email"] == "new@example.com"
|
||||
|
||||
def test_login_success():
|
||||
"""Test successful login and JWT return."""
|
||||
with patch("ea_chatbot.api.routers.auth.history_manager") as mock_hm:
|
||||
mock_hm.authenticate_user.return_value = User(id="1", username="test@example.com")
|
||||
|
||||
response = client.post(
|
||||
"/auth/login",
|
||||
data={"username": "test@example.com", "password": "password123"}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert "access_token" in response.json()
|
||||
assert response.json()["token_type"] == "bearer"
|
||||
|
||||
def test_login_invalid_credentials():
|
||||
"""Test login with wrong password."""
|
||||
with patch("ea_chatbot.api.routers.auth.history_manager") as mock_hm:
|
||||
mock_hm.authenticate_user.return_value = None
|
||||
|
||||
response = client.post(
|
||||
"/auth/login",
|
||||
data={"username": "test@example.com", "password": "wrongpassword"}
|
||||
)
|
||||
|
||||
assert response.status_code == 401
|
||||
assert "detail" in response.json()
|
||||
|
||||
def test_protected_route_without_token():
|
||||
"""Test that protected routes require a token."""
|
||||
response = client.get("/auth/me")
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_oidc_login_redirect():
|
||||
"""Test that OIDC login returns a redirect URL."""
|
||||
with patch("ea_chatbot.api.routers.auth.oidc_client") as mock_oidc:
|
||||
mock_oidc.get_login_url.return_value = "https://oidc-provider.com/auth"
|
||||
|
||||
response = client.get("/auth/oidc/login")
|
||||
assert response.status_code == 200
|
||||
assert response.json()["url"] == "https://oidc-provider.com/auth"
|
||||
|
||||
def test_oidc_callback_success():
|
||||
"""Test successful OIDC callback and JWT issuance."""
|
||||
with patch("ea_chatbot.api.routers.auth.oidc_client") as mock_oidc, \
|
||||
patch("ea_chatbot.api.routers.auth.history_manager") as mock_hm:
|
||||
|
||||
mock_oidc.exchange_code_for_token.return_value = {"access_token": "oidc-token"}
|
||||
mock_oidc.get_user_info.return_value = {"email": "sso@example.com", "name": "SSO User"}
|
||||
mock_hm.sync_user_from_oidc.return_value = User(id="sso-123", username="sso@example.com", display_name="SSO User")
|
||||
|
||||
response = client.get("/auth/oidc/callback?code=some-code")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert "access_token" in response.json()
|
||||
assert response.json()["token_type"] == "bearer"
|
||||
|
||||
def test_get_me_success():
|
||||
"""Test getting current user with a valid token."""
|
||||
from ea_chatbot.api.utils import create_access_token
|
||||
token = create_access_token(data={"sub": "123"})
|
||||
|
||||
with patch("ea_chatbot.api.dependencies.history_manager") as mock_hm:
|
||||
mock_hm.get_user_by_id.return_value = User(id="123", username="test@example.com", display_name="Test")
|
||||
|
||||
response = client.get(
|
||||
"/auth/me",
|
||||
headers={"Authorization": f"Bearer {token}"}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["email"] == "test@example.com"
|
||||
assert response.json()["id"] == "123"
|
||||
116
backend/tests/api/test_api_history.py
Normal file
116
backend/tests/api/test_api_history.py
Normal file
@@ -0,0 +1,116 @@
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from unittest.mock import MagicMock, patch
|
||||
from ea_chatbot.api.main import app
|
||||
from ea_chatbot.api.dependencies import get_current_user
|
||||
from ea_chatbot.history.models import Conversation, Message, Plot, User
|
||||
from ea_chatbot.api.utils import create_access_token
|
||||
from datetime import datetime, timezone
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user():
|
||||
user = User(id="user-123", username="test@example.com", display_name="Test User")
|
||||
return user
|
||||
|
||||
@pytest.fixture
|
||||
def auth_header(mock_user):
|
||||
# Override get_current_user to return our mock user
|
||||
app.dependency_overrides[get_current_user] = lambda: mock_user
|
||||
token = create_access_token(data={"sub": mock_user.id})
|
||||
yield {"Authorization": f"Bearer {token}"}
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
def test_get_conversations_success(auth_header, mock_user):
|
||||
"""Test retrieving list of conversations."""
|
||||
with patch("ea_chatbot.api.routers.history.history_manager") as mock_hm:
|
||||
mock_hm.get_conversations.return_value = [
|
||||
Conversation(
|
||||
id="c1",
|
||||
name="Conv 1",
|
||||
user_id=mock_user.id,
|
||||
data_state="nj",
|
||||
created_at=datetime.now(timezone.utc)
|
||||
)
|
||||
]
|
||||
|
||||
response = client.get("/conversations", headers=auth_header)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert len(response.json()) == 1
|
||||
assert response.json()[0]["name"] == "Conv 1"
|
||||
|
||||
def test_create_conversation_success(auth_header, mock_user):
|
||||
"""Test creating a new conversation."""
|
||||
with patch("ea_chatbot.api.routers.history.history_manager") as mock_hm:
|
||||
mock_hm.create_conversation.return_value = Conversation(
|
||||
id="c2",
|
||||
name="New Conv",
|
||||
user_id=mock_user.id,
|
||||
data_state="nj",
|
||||
created_at=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/conversations",
|
||||
json={"name": "New Conv"},
|
||||
headers=auth_header
|
||||
)
|
||||
|
||||
assert response.status_code == 201
|
||||
assert response.json()["name"] == "New Conv"
|
||||
assert response.json()["id"] == "c2"
|
||||
|
||||
def test_get_messages_success(auth_header):
|
||||
"""Test retrieving messages for a conversation."""
|
||||
with patch("ea_chatbot.api.routers.history.history_manager") as mock_hm:
|
||||
mock_hm.get_messages.return_value = [
|
||||
Message(
|
||||
id="m1",
|
||||
role="user",
|
||||
content="Hello",
|
||||
conversation_id="c1",
|
||||
created_at=datetime.now(timezone.utc)
|
||||
)
|
||||
]
|
||||
|
||||
response = client.get("/conversations/c1/messages", headers=auth_header)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert len(response.json()) == 1
|
||||
assert response.json()[0]["content"] == "Hello"
|
||||
|
||||
def test_delete_conversation_success(auth_header):
|
||||
"""Test deleting a conversation."""
|
||||
with patch("ea_chatbot.api.routers.history.history_manager") as mock_hm:
|
||||
mock_hm.delete_conversation.return_value = True
|
||||
|
||||
response = client.delete("/conversations/c1", headers=auth_header)
|
||||
assert response.status_code == 204
|
||||
|
||||
def test_get_plot_success(auth_header, mock_user):
|
||||
"""Test retrieving a plot artifact."""
|
||||
with patch("ea_chatbot.api.routers.artifacts.history_manager") as mock_hm:
|
||||
# Mocking finding a plot by ID
|
||||
mock_session = MagicMock()
|
||||
mock_hm.get_session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Mocking the models and their relationships
|
||||
mock_conv = Conversation(id="c1", user_id=mock_user.id, user=mock_user)
|
||||
mock_msg = Message(id="m1", conversation_id="c1", conversation=mock_conv)
|
||||
mock_plot = Plot(id="p1", image_data=b"fake-image-data", message_id="m1", message=mock_msg)
|
||||
|
||||
def mock_get(model, id):
|
||||
if model == Plot: return mock_plot
|
||||
if model == Message: return mock_msg
|
||||
if model == Conversation: return mock_conv
|
||||
return None
|
||||
|
||||
mock_session.get.side_effect = mock_get
|
||||
|
||||
response = client.get("/artifacts/plots/p1", headers=auth_header)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.content == b"fake-image-data"
|
||||
assert response.headers["content-type"] == "image/png"
|
||||
9
backend/tests/api/test_api_main.py
Normal file
9
backend/tests/api/test_api_main.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from fastapi.testclient import TestClient
|
||||
from ea_chatbot.api.main import app
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
def test_health_check():
|
||||
response = client.get("/health")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"status": "ok"}
|
||||
63
backend/tests/api/test_persistence.py
Normal file
63
backend/tests/api/test_persistence.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from unittest.mock import MagicMock, patch, AsyncMock
|
||||
from ea_chatbot.api.main import app
|
||||
from ea_chatbot.api.dependencies import get_current_user
|
||||
from ea_chatbot.history.models import User, Conversation, Message, Plot
|
||||
from ea_chatbot.api.utils import create_access_token
|
||||
from datetime import datetime, timezone
|
||||
import json
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user():
|
||||
return User(id="user-123", username="test@example.com", display_name="Test User")
|
||||
|
||||
@pytest.fixture
|
||||
def auth_header(mock_user):
|
||||
app.dependency_overrides[get_current_user] = lambda: mock_user
|
||||
token = create_access_token(data={"sub": mock_user.id})
|
||||
yield {"Authorization": f"Bearer {token}"}
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
def test_persistence_integration_success(auth_header, mock_user):
|
||||
"""Test that messages and plots are persisted correctly during streaming."""
|
||||
mock_events = [
|
||||
{"event": "on_chat_model_stream", "name": "summarizer", "data": {"chunk": "Final answer"}},
|
||||
{"event": "on_chain_end", "name": "summarizer", "data": {"output": {"messages": [{"content": "Final answer"}]}}},
|
||||
{"event": "on_chain_end", "name": "summarize_conversation", "data": {"output": {"summary": "New summary"}}}
|
||||
]
|
||||
|
||||
async def mock_astream_events(*args, **kwargs):
|
||||
for event in mock_events:
|
||||
yield event
|
||||
|
||||
with patch("ea_chatbot.api.routers.agent.app.astream_events", side_effect=mock_astream_events), \
|
||||
patch("ea_chatbot.api.routers.agent.get_checkpointer") as mock_cp, \
|
||||
patch("ea_chatbot.api.routers.agent.history_manager") as mock_hm:
|
||||
|
||||
mock_cp.return_value.__aenter__.return_value = AsyncMock()
|
||||
|
||||
# Mock session and DB objects
|
||||
mock_session = MagicMock()
|
||||
mock_hm.get_session.return_value.__enter__.return_value = mock_session
|
||||
mock_conv = Conversation(id="t1", user_id=mock_user.id)
|
||||
mock_session.get.return_value = mock_conv
|
||||
|
||||
# Act
|
||||
with client.stream("POST", "/chat/stream",
|
||||
json={"message": "persistence test", "thread_id": "t1"},
|
||||
headers=auth_header) as response:
|
||||
assert response.status_code == 200
|
||||
list(response.iter_lines()) # Consume stream
|
||||
|
||||
# Assertions
|
||||
# 1. User message should be saved immediately
|
||||
mock_hm.add_message.assert_any_call("t1", "user", "persistence test")
|
||||
|
||||
# 2. Assistant message should be saved at the end
|
||||
mock_hm.add_message.assert_any_call("t1", "assistant", "Final answer", plots=[])
|
||||
|
||||
# 3. Summary should be updated
|
||||
mock_hm.update_conversation_summary.assert_called_once_with("t1", "New summary")
|
||||
51
backend/tests/api/test_utils.py
Normal file
51
backend/tests/api/test_utils.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from datetime import timedelta
|
||||
from ea_chatbot.api.utils import create_access_token, decode_access_token, convert_to_json_compatible
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
def test_create_and_decode_access_token():
|
||||
"""Test that a token can be created and then decoded."""
|
||||
data = {"sub": "test@example.com", "user_id": "123"}
|
||||
token = create_access_token(data)
|
||||
|
||||
decoded = decode_access_token(token)
|
||||
assert decoded["sub"] == data["sub"]
|
||||
assert decoded["user_id"] == data["user_id"]
|
||||
assert "exp" in decoded
|
||||
|
||||
def test_decode_invalid_token():
|
||||
"""Test that an invalid token returns None."""
|
||||
assert decode_access_token("invalid-token") is None
|
||||
|
||||
def test_expired_token():
|
||||
"""Test that an expired token returns None."""
|
||||
data = {"sub": "test@example.com"}
|
||||
# Create a token that expired 1 minute ago
|
||||
token = create_access_token(data, expires_delta=timedelta(minutes=-1))
|
||||
|
||||
assert decode_access_token(token) is None
|
||||
|
||||
def test_convert_to_json_compatible_complex_message():
|
||||
"""Test that list-based message content is handled correctly."""
|
||||
# Mock a message with list-based content (blocks)
|
||||
msg = AIMessage(content=[
|
||||
{"type": "text", "text": "Hello "},
|
||||
{"type": "text", "text": "world!"},
|
||||
{"type": "other", "data": "ignore me"}
|
||||
])
|
||||
|
||||
result = convert_to_json_compatible(msg)
|
||||
assert result["content"] == "Hello world!"
|
||||
assert result["type"] == "ai"
|
||||
|
||||
def test_convert_to_json_compatible_message_with_text_prop():
|
||||
"""Test that .text property is prioritized if available."""
|
||||
# Using a MagicMock to simulate the property safely
|
||||
from unittest.mock import MagicMock
|
||||
msg = MagicMock(spec=AIMessage)
|
||||
msg.content = "Raw content"
|
||||
msg.text = "Just the text"
|
||||
msg.type = "ai"
|
||||
msg.additional_kwargs = {}
|
||||
|
||||
result = convert_to_json_compatible(msg)
|
||||
assert result["content"] == "Just the text"
|
||||
49
backend/tests/graph/test_checkpoint.py
Normal file
49
backend/tests/graph/test_checkpoint.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from ea_chatbot.graph.checkpoint import get_checkpointer
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_checkpointer_initialization():
|
||||
"""Test that the checkpointer setup is called."""
|
||||
mock_conn = AsyncMock()
|
||||
mock_pool = MagicMock() # Changed from AsyncMock to MagicMock
|
||||
mock_pool.closed = True
|
||||
mock_pool.open = AsyncMock() # Ensure open is awaitable
|
||||
|
||||
# Setup mock_pool.connection() to return an async context manager
|
||||
mock_cm = MagicMock()
|
||||
mock_cm.__aenter__ = AsyncMock(return_value=mock_conn)
|
||||
mock_cm.__aexit__ = AsyncMock(return_value=None)
|
||||
mock_pool.connection.return_value = mock_cm
|
||||
|
||||
# We need to patch the get_pool function and AsyncPostgresSaver in the module
|
||||
with MagicMock() as mock_get_pool, \
|
||||
MagicMock() as mock_saver_class:
|
||||
import ea_chatbot.graph.checkpoint as checkpoint
|
||||
mock_get_pool.return_value = mock_pool
|
||||
|
||||
# Mock AsyncPostgresSaver class
|
||||
mock_saver_instance = AsyncMock()
|
||||
mock_saver_class.return_value = mock_saver_instance
|
||||
|
||||
original_get_pool = checkpoint.get_pool
|
||||
checkpoint.get_pool = mock_get_pool
|
||||
|
||||
# Patch AsyncPostgresSaver where it's imported in checkpoint.py
|
||||
import langgraph.checkpoint.postgres.aio as pg_aio
|
||||
original_saver = checkpoint.AsyncPostgresSaver
|
||||
checkpoint.AsyncPostgresSaver = mock_saver_class
|
||||
|
||||
try:
|
||||
async with get_checkpointer() as checkpointer:
|
||||
assert checkpointer == mock_saver_instance
|
||||
# Verify setup was called
|
||||
mock_saver_instance.setup.assert_called_once()
|
||||
|
||||
# Verify pool was opened
|
||||
mock_pool.open.assert_called_once()
|
||||
# Verify connection was requested
|
||||
mock_pool.connection.assert_called_once()
|
||||
finally:
|
||||
checkpoint.get_pool = original_get_pool
|
||||
checkpoint.AsyncPostgresSaver = original_saver
|
||||
90
backend/tests/test_app.py
Normal file
90
backend/tests/test_app.py
Normal file
@@ -0,0 +1,90 @@
|
||||
import os
|
||||
import sys
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from streamlit.testing.v1 import AppTest
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
# Ensure src is in python path
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../src')))
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_history_manager():
|
||||
"""Globally mock HistoryManager to avoid DB calls during AppTest."""
|
||||
with patch("ea_chatbot.history.manager.HistoryManager") as mock_cls:
|
||||
instance = mock_cls.return_value
|
||||
instance.create_conversation.return_value = MagicMock(id="conv_123")
|
||||
instance.get_conversations.return_value = []
|
||||
instance.get_messages.return_value = []
|
||||
instance.add_message.return_value = MagicMock()
|
||||
instance.update_conversation_summary.return_value = MagicMock()
|
||||
yield instance
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app_stream():
|
||||
with patch("ea_chatbot.graph.workflow.app.stream") as mock_stream:
|
||||
# Mock events from app.stream
|
||||
mock_stream.return_value = [
|
||||
{"query_analyzer": {"next_action": "research"}},
|
||||
{"researcher": {"messages": [AIMessage(content="Research result")]}}
|
||||
]
|
||||
yield mock_stream
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user():
|
||||
user = MagicMock()
|
||||
user.id = "test_id"
|
||||
user.username = "test@example.com"
|
||||
user.display_name = "Test User"
|
||||
return user
|
||||
|
||||
def test_app_initial_state(mock_app_stream, mock_user):
|
||||
"""Test that the app initializes with the correct title and empty history."""
|
||||
at = AppTest.from_file("src/ea_chatbot/app.py")
|
||||
|
||||
# Simulate logged-in user
|
||||
at.session_state["user"] = mock_user
|
||||
|
||||
at.run()
|
||||
|
||||
assert not at.exception
|
||||
assert at.title[0].value == "🗳️ Election Analytics Chatbot"
|
||||
|
||||
# Check session state initialization
|
||||
assert "messages" in at.session_state
|
||||
assert len(at.session_state["messages"]) == 0
|
||||
|
||||
def test_app_dev_mode_toggle(mock_app_stream, mock_user):
|
||||
"""Test that the dev mode toggle exists in the sidebar."""
|
||||
with patch.dict(os.environ, {"DEV_MODE": "false"}):
|
||||
at = AppTest.from_file("src/ea_chatbot/app.py")
|
||||
at.session_state["user"] = mock_user
|
||||
at.run()
|
||||
|
||||
# Check for sidebar toggle (checkbox)
|
||||
assert len(at.sidebar.checkbox) > 0
|
||||
dev_mode_toggle = at.sidebar.checkbox[0]
|
||||
assert dev_mode_toggle.label == "Dev Mode"
|
||||
assert dev_mode_toggle.value is False
|
||||
|
||||
def test_app_graph_execution_streaming(mock_app_stream, mock_user, mock_history_manager):
|
||||
"""Test that entering a prompt triggers the graph stream and displays response."""
|
||||
at = AppTest.from_file("src/ea_chatbot/app.py")
|
||||
at.session_state["user"] = mock_user
|
||||
at.run()
|
||||
|
||||
# Input a question
|
||||
at.chat_input[0].set_value("Test question").run()
|
||||
|
||||
# Verify graph stream was called
|
||||
assert mock_app_stream.called
|
||||
|
||||
# Message should be added to history
|
||||
assert len(at.session_state["messages"]) == 2
|
||||
assert at.session_state["messages"][0]["role"] == "user"
|
||||
assert at.session_state["messages"][1]["role"] == "assistant"
|
||||
assert "Research result" in at.session_state["messages"][1]["content"]
|
||||
|
||||
# Verify history manager was used
|
||||
assert mock_history_manager.create_conversation.called
|
||||
assert mock_history_manager.add_message.called
|
||||
83
backend/tests/test_app_auth.py
Normal file
83
backend/tests/test_app_auth.py
Normal file
@@ -0,0 +1,83 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from streamlit.testing.v1 import AppTest
|
||||
from ea_chatbot.auth import AuthType
|
||||
|
||||
@pytest.fixture
|
||||
def mock_history_manager_instance():
|
||||
# We need to patch before the AppTest loads the module
|
||||
with patch("ea_chatbot.history.manager.HistoryManager") as mock_cls:
|
||||
instance = mock_cls.return_value
|
||||
yield instance
|
||||
|
||||
def test_auth_ui_flow_step1_to_password(mock_history_manager_instance):
|
||||
"""Test UI transition from Step 1 (email) to Step 2a (password) for LOCAL user."""
|
||||
# Patch BEFORE creating AppTest
|
||||
mock_user = MagicMock()
|
||||
mock_user.password_hash = "hashed_password"
|
||||
mock_history_manager_instance.get_user.return_value = mock_user
|
||||
|
||||
at = AppTest.from_file("src/ea_chatbot/app.py")
|
||||
at.run()
|
||||
|
||||
# Step 1: Identification
|
||||
assert at.session_state["login_step"] == "email"
|
||||
at.text_input[0].set_value("local@example.com")
|
||||
|
||||
at.button[0].click().run()
|
||||
|
||||
# Verify transition to password step
|
||||
assert at.session_state["login_step"] == "login_password"
|
||||
assert at.session_state["login_email"] == "local@example.com"
|
||||
assert "Welcome back" in at.info[0].value
|
||||
|
||||
def test_auth_ui_flow_step1_to_register(mock_history_manager_instance):
|
||||
"""Test UI transition from Step 1 (email) to Step 2b (registration) for NEW user."""
|
||||
mock_history_manager_instance.get_user.return_value = None
|
||||
|
||||
at = AppTest.from_file("src/ea_chatbot/app.py")
|
||||
at.run()
|
||||
|
||||
# Step 1: Identification
|
||||
at.text_input[0].set_value("new@example.com")
|
||||
|
||||
at.button[0].click().run()
|
||||
|
||||
# Verify transition to registration step
|
||||
assert at.session_state["login_step"] == "register_details"
|
||||
assert at.session_state["login_email"] == "new@example.com"
|
||||
assert "Create an account" in at.info[0].value
|
||||
|
||||
def test_auth_ui_flow_step1_to_oidc(mock_history_manager_instance):
|
||||
"""Test UI transition from Step 1 (email) to Step 2c (OIDC) for OIDC user."""
|
||||
# Mock history_manager.get_user to return a user WITHOUT a password
|
||||
mock_user = MagicMock()
|
||||
mock_user.password_hash = None
|
||||
mock_history_manager_instance.get_user.return_value = mock_user
|
||||
|
||||
at = AppTest.from_file("src/ea_chatbot/app.py")
|
||||
at.run()
|
||||
|
||||
# Step 1: Identification
|
||||
at.text_input[0].set_value("oidc@example.com")
|
||||
|
||||
at.button[0].click().run()
|
||||
|
||||
# Verify transition to OIDC step
|
||||
assert at.session_state["login_step"] == "oidc_login"
|
||||
assert at.session_state["login_email"] == "oidc@example.com"
|
||||
assert "configured for Single Sign-On" in at.info[0].value
|
||||
|
||||
def test_auth_ui_flow_back_button(mock_history_manager_instance):
|
||||
"""Test that the 'Back' button returns to Step 1."""
|
||||
at = AppTest.from_file("src/ea_chatbot/app.py")
|
||||
# Simulate being on Step 2a
|
||||
at.session_state["login_step"] = "login_password"
|
||||
at.session_state["login_email"] = "local@example.com"
|
||||
at.run()
|
||||
|
||||
# Click Back (index 1 in Step 2a)
|
||||
at.button[1].click().run()
|
||||
|
||||
# Verify return to email step
|
||||
assert at.session_state["login_step"] == "email"
|
||||
43
backend/tests/test_auth.py
Normal file
43
backend/tests/test_auth.py
Normal file
@@ -0,0 +1,43 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from ea_chatbot.auth import OIDCClient
|
||||
|
||||
@patch("ea_chatbot.auth.OAuth2Session")
|
||||
def test_oidc_client_initialization(mock_oauth):
|
||||
client = OIDCClient(
|
||||
client_id="test_id",
|
||||
client_secret="test_secret",
|
||||
server_metadata_url="https://test.server/.well-known/openid-configuration"
|
||||
)
|
||||
assert client.oauth_session is not None
|
||||
|
||||
@patch("ea_chatbot.auth.requests")
|
||||
@patch("ea_chatbot.auth.OAuth2Session")
|
||||
def test_get_login_url(mock_oauth_cls, mock_requests):
|
||||
# Setup mock session
|
||||
mock_session = MagicMock()
|
||||
mock_oauth_cls.return_value = mock_session
|
||||
|
||||
# Mock metadata response
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"authorization_endpoint": "https://test.server/auth",
|
||||
"token_endpoint": "https://test.server/token",
|
||||
"userinfo_endpoint": "https://test.server/userinfo"
|
||||
}
|
||||
mock_requests.get.return_value = mock_response
|
||||
|
||||
# Mock authorization url generation
|
||||
mock_session.create_authorization_url.return_value = ("https://test.server/auth?response_type=code", "state")
|
||||
|
||||
client = OIDCClient(
|
||||
client_id="test_id",
|
||||
client_secret="test_secret",
|
||||
server_metadata_url="https://test.server/.well-known/openid-configuration"
|
||||
)
|
||||
|
||||
url = client.get_login_url()
|
||||
|
||||
assert url == "https://test.server/auth?response_type=code"
|
||||
# Verify metadata was fetched via requests
|
||||
mock_requests.get.assert_called_with("https://test.server/.well-known/openid-configuration")
|
||||
49
backend/tests/test_auth_flow.py
Normal file
49
backend/tests/test_auth_flow.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
from ea_chatbot.history.manager import HistoryManager
|
||||
from ea_chatbot.auth import get_user_auth_type, AuthType
|
||||
|
||||
# Mocks
|
||||
@pytest.fixture
|
||||
def mock_history_manager():
|
||||
return MagicMock(spec=HistoryManager)
|
||||
|
||||
def test_auth_flow_existing_local_user(mock_history_manager):
|
||||
"""Test that an existing user with a password returns LOCAL auth type."""
|
||||
# Setup
|
||||
mock_user = MagicMock()
|
||||
mock_user.password_hash = "hashed_secret"
|
||||
mock_history_manager.get_user.return_value = mock_user
|
||||
|
||||
# Execute
|
||||
auth_type = get_user_auth_type("test@example.com", mock_history_manager)
|
||||
|
||||
# Verify
|
||||
assert auth_type == AuthType.LOCAL
|
||||
mock_history_manager.get_user.assert_called_once_with("test@example.com")
|
||||
|
||||
def test_auth_flow_existing_oidc_user(mock_history_manager):
|
||||
"""Test that an existing user WITHOUT a password returns OIDC auth type."""
|
||||
# Setup
|
||||
mock_user = MagicMock()
|
||||
mock_user.password_hash = None # No password implies OIDC
|
||||
mock_history_manager.get_user.return_value = mock_user
|
||||
|
||||
# Execute
|
||||
auth_type = get_user_auth_type("sso@example.com", mock_history_manager)
|
||||
|
||||
# Verify
|
||||
assert auth_type == AuthType.OIDC
|
||||
mock_history_manager.get_user.assert_called_once_with("sso@example.com")
|
||||
|
||||
def test_auth_flow_new_user(mock_history_manager):
|
||||
"""Test that a non-existent user returns NEW auth type."""
|
||||
# Setup
|
||||
mock_history_manager.get_user.return_value = None
|
||||
|
||||
# Execute
|
||||
auth_type = get_user_auth_type("new@example.com", mock_history_manager)
|
||||
|
||||
# Verify
|
||||
assert auth_type == AuthType.NEW
|
||||
mock_history_manager.get_user.assert_called_once_with("new@example.com")
|
||||
62
backend/tests/test_coder.py
Normal file
62
backend/tests/test_coder.py
Normal file
@@ -0,0 +1,62 @@
|
||||
import pytest
from unittest.mock import MagicMock, patch
from ea_chatbot.graph.nodes.coder import coder_node
from ea_chatbot.graph.nodes.error_corrector import error_corrector_node


def _llm_stub(code, explanation):
    """Build a mock LLM whose structured-output chain yields the given response."""
    from ea_chatbot.schemas import CodeGenerationResponse

    llm = MagicMock()
    llm.with_structured_output.return_value.invoke.return_value = CodeGenerationResponse(
        code=code,
        explanation=explanation,
    )
    return llm


@pytest.fixture
def mock_state():
    """Minimal agent state: a question plus a two-step plan, nothing executed yet."""
    return {
        "messages": [],
        "question": "Show me results for New Jersey",
        "plan": "Step 1: Load data\nStep 2: Filter by NJ",
        "code": None,
        "error": None,
        "plots": [],
        "dfs": {},
        "next_action": "plan",
    }


@patch("ea_chatbot.graph.nodes.coder.get_llm_model")
@patch("ea_chatbot.utils.database_inspection.get_data_summary")
def test_coder_node(mock_get_summary, mock_get_llm, mock_state):
    """The coder node turns the plan into code and leaves the error unset."""
    mock_get_summary.return_value = "Column: Name, Type: text"
    mock_get_llm.return_value = _llm_stub(
        code="import pandas as pd\nprint('Hello')",
        explanation="Generated code",
    )

    result = coder_node(mock_state)

    assert "code" in result
    assert "import pandas as pd" in result["code"]
    assert "error" in result
    assert result["error"] is None


@patch("ea_chatbot.graph.nodes.error_corrector.get_llm_model")
def test_error_corrector_node(mock_get_llm, mock_state):
    """The error corrector rewrites failing code and clears the error field."""
    mock_state["code"] = "import pandas as pd\nprint(undefined_var)"
    mock_state["error"] = "NameError: name 'undefined_var' is not defined"
    mock_get_llm.return_value = _llm_stub(
        code="import pandas as pd\nprint('Defined')",
        explanation="Fixed variable",
    )

    result = error_corrector_node(mock_state)

    assert "code" in result
    assert "print('Defined')" in result["code"]
    assert result["error"] is None
|
||||
47
backend/tests/test_config.py
Normal file
47
backend/tests/test_config.py
Normal file
@@ -0,0 +1,47 @@
|
||||
import pytest
from pydantic import ValidationError
from ea_chatbot.config import Settings, LLMConfig


def test_default_settings():
    """Out-of-the-box settings expose OpenAI gpt-5-mini for both node configs."""
    settings = Settings()

    # The query analyzer and the planner default to the same provider/model.
    for llm_cfg in (settings.query_analyzer_llm, settings.planner_llm):
        assert isinstance(llm_cfg, LLMConfig)
        assert llm_cfg.provider == "openai"
        assert llm_cfg.model == "gpt-5-mini"

    # Only the analyzer's temperature default is pinned here.
    assert settings.query_analyzer_llm.temperature == 0.0


def test_env_override(monkeypatch):
    """Nested env vars (double-underscore delimiter) override field defaults."""
    monkeypatch.setenv("QUERY_ANALYZER_LLM__MODEL", "gpt-3.5-turbo")
    monkeypatch.setenv("QUERY_ANALYZER_LLM__TEMPERATURE", "0.7")

    overridden = Settings()
    assert overridden.query_analyzer_llm.model == "gpt-3.5-turbo"
    assert overridden.query_analyzer_llm.temperature == 0.7


def test_provider_specific_params():
    """Arbitrary provider-specific kwargs survive LLMConfig construction."""
    cfg = LLMConfig(
        provider="openai",
        model="o1-preview",
        provider_specific={"reasoning_effort": "high"},
    )
    assert cfg.provider_specific["reasoning_effort"] == "high"


def test_oidc_settings(monkeypatch):
    """OIDC credentials and the metadata URL are read from the environment."""
    expected = {
        "OIDC_CLIENT_ID": "test_client_id",
        "OIDC_CLIENT_SECRET": "test_client_secret",
        "OIDC_SERVER_METADATA_URL": "https://test.server/.well-known/openid-configuration",
    }
    for key, value in expected.items():
        monkeypatch.setenv(key, value)

    settings = Settings()
    assert settings.oidc_client_id == expected["OIDC_CLIENT_ID"]
    assert settings.oidc_client_secret == expected["OIDC_CLIENT_SECRET"]
    assert settings.oidc_server_metadata_url == expected["OIDC_SERVER_METADATA_URL"]
|
||||
56
backend/tests/test_conversation_summary.py
Normal file
56
backend/tests/test_conversation_summary.py
Normal file
@@ -0,0 +1,56 @@
|
||||
import pytest
from unittest.mock import MagicMock, patch
from langchain_core.messages import HumanMessage, AIMessage
from ea_chatbot.graph.nodes.summarize_conversation import summarize_conversation_node


def _stub_llm(reply_text):
    """Return a mock LLM whose invoke() yields an AIMessage with *reply_text*."""
    llm = MagicMock()
    llm.invoke.return_value = AIMessage(content=reply_text)
    return llm


@pytest.fixture
def mock_state_with_history():
    """State with one completed turn and a pre-existing running summary."""
    return {
        "messages": [
            HumanMessage(content="Show me the 2024 results for Florida"),
            AIMessage(content="Here are the results for Florida in 2024..."),
        ],
        "summary": "The user is asking about 2024 election results.",
    }


@patch("ea_chatbot.graph.nodes.summarize_conversation.get_llm_model")
def test_summarize_conversation_node_updates_summary(mock_get_llm, mock_state_with_history):
    """A follow-up turn is folded into the running summary via the LLM."""
    llm = _stub_llm("Updated summary including NJ results.")
    mock_get_llm.return_value = llm

    # Simulate a freshly completed turn appended to the history.
    mock_state_with_history["messages"] += [
        HumanMessage(content="What about in New Jersey?"),
        AIMessage(content="In New Jersey, the 2024 results were..."),
    ]

    result = summarize_conversation_node(mock_state_with_history)

    assert "summary" in result
    assert result["summary"] == "Updated summary including NJ results."

    # The prompt sent to the LLM must carry the prior summary as context.
    prompt_messages = llm.invoke.call_args[0][0]
    assert "Current summary: The user is asking about 2024 election results." in prompt_messages[0].content


@patch("ea_chatbot.graph.nodes.summarize_conversation.get_llm_model")
def test_summarize_conversation_node_initial_summary(mock_get_llm):
    """With no prior summary, the node creates one from scratch."""
    state = {
        "messages": [
            HumanMessage(content="Hi"),
            AIMessage(content="Hello! How can I help you today?"),
        ],
        "summary": "",
    }
    mock_get_llm.return_value = _stub_llm("Initial greeting.")

    result = summarize_conversation_node(state)

    assert result["summary"] == "Initial greeting."
|
||||
195
backend/tests/test_database_inspection.py
Normal file
195
backend/tests/test_database_inspection.py
Normal file
@@ -0,0 +1,195 @@
|
||||
"""Tests for database inspection: PK discovery, column stats, caching via checksum.

All database traffic is simulated by programming ``query_df`` on a MagicMock
client; the SQL text itself is matched by substring in per-test side effects.
"""
import pytest
import pandas as pd
import os
from unittest.mock import MagicMock, patch
from ea_chatbot.utils.database_inspection import get_primary_key, inspect_db_table, get_data_summary


@pytest.fixture
def mock_db_client():
    """Stand-in DB client; each test programs query_df's return/side_effect."""
    mock_client = MagicMock()
    # inspect_db_table appears to read the target table name from client settings.
    mock_client.settings = {"table": "test_table"}
    return mock_client


def test_get_primary_key(mock_db_client):
    """Test dynamic primary key discovery."""
    # Mock response for primary key query
    mock_df = pd.DataFrame({"column_name": ["my_pk"]})
    mock_db_client.query_df.return_value = mock_df

    pk = get_primary_key(mock_db_client, "test_table")

    assert pk == "my_pk"
    # Verify the query was called (at least once)
    assert mock_db_client.query_df.called


def test_inspect_db_table_improved(mock_db_client, tmp_path):
    """Test improved inspect_db_table with cardinality and sampling."""
    data_dir = str(tmp_path)

    # 1. Mock columns and types
    columns_df = pd.DataFrame({
        "column_name": ["id", "category", "count"],
        "data_type": ["integer", "text", "integer"]
    })

    # 2. Mock row count
    total_rows_df = pd.DataFrame([{"count": 100}])

    # 3. Mock PK discovery
    pk_df = pd.DataFrame({"column_name": ["id"]})

    # 4. Mock stats for columns
    # We need to handle multiple calls to query_df
    # NOTE(review): dispatch is substring-based and order-sensitive — the
    # generic "COUNT(*)" branch matches only the total-row query because the
    # per-column queries use COUNT("col") / COUNT(DISTINCT "col") instead.
    def side_effect(query):
        if "information_schema.columns" in query:
            return columns_df
        if "COUNT(*)" in query:
            return total_rows_df
        if "information_schema.key_column_usage" in query:
            return pk_df

        # Category stats
        if 'COUNT("category")' in query:
            return pd.DataFrame([{"count": 100}])
        if 'COUNT(DISTINCT "category")' in query:
            return pd.DataFrame([{"count": 5}])
        if 'SELECT DISTINCT "category"' in query:
            return pd.DataFrame({"category": ["A", "B", "C", "D", "E"]})

        # Count stats
        if 'COUNT("count")' in query:
            return pd.DataFrame([{"count": 100}])
        if 'COUNT(DISTINCT "count")' in query:
            return pd.DataFrame([{"count": 100}])
        if 'AVG("count")' in query:
            return pd.DataFrame([{"avg": 10.0, "min": 1, "max": 20}])

        # ID stats (fix for IndexError)
        if 'COUNT("id")' in query:
            return pd.DataFrame([{"count": 100}])
        if 'COUNT(DISTINCT "id")' in query:
            return pd.DataFrame([{"count": 100}])
        if 'AVG("id")' in query:
            return pd.DataFrame([{"avg": 50.0, "min": 1, "max": 100}])

        # Fallback for any query the test did not anticipate.
        return pd.DataFrame()
    mock_db_client.query_df.side_effect = side_effect

    # Run inspection
    inspect_db_table(mock_db_client, data_dir=data_dir)

    # Read summary to verify
    summary = get_data_summary(data_dir)
    assert summary is not None
    assert "test_table" in summary
    assert "category" in summary
    # Low-cardinality text column: the summary lists its distinct values.
    assert "distinct_values" in summary
    assert "unique_count: 5" in summary
    assert "- A" in summary
    assert "- E" in summary
    assert "primary_key: id" in summary


def test_get_data_summary_none(tmp_path):
    """Test get_data_summary when file doesn't exist."""
    assert get_data_summary(str(tmp_path)) is None


def test_inspect_db_table_temporal(mock_db_client, tmp_path):
    """Test inspect_db_table with temporal columns."""
    data_dir = str(tmp_path)

    columns_df = pd.DataFrame({
        "column_name": ["created_at"],
        "data_type": ["timestamp without time zone"]
    })
    total_rows_df = pd.DataFrame([{"count": 50}])
    pk_df = pd.DataFrame()  # No PK

    def side_effect(query):
        if "information_schema.columns" in query:
            return columns_df
        if "COUNT(*)" in query:
            return total_rows_df
        if "information_schema.key_column_usage" in query:
            return pk_df
        if 'COUNT("created_at")' in query:
            return pd.DataFrame([{"count": 50}])
        if 'COUNT(DISTINCT "created_at")' in query:
            return pd.DataFrame([{"count": 50}])
        # Temporal columns are summarized by their min/max range.
        if 'MIN("created_at")' in query:
            return pd.DataFrame([{"min": "2023-01-01", "max": "2023-12-31"}])
        return pd.DataFrame()

    mock_db_client.query_df.side_effect = side_effect

    inspect_db_table(mock_db_client, data_dir=data_dir)

    summary = get_data_summary(data_dir)
    assert "created_at" in summary
    assert "min: '2023-01-01'" in summary
    assert "max: '2023-12-31'" in summary


def test_inspect_db_table_high_cardinality(mock_db_client, tmp_path):
    """Test inspect_db_table with high cardinality categorical column (no sample values)."""
    data_dir = str(tmp_path)

    columns_df = pd.DataFrame({
        "column_name": ["user_id"],
        "data_type": ["text"]
    })
    total_rows_df = pd.DataFrame([{"count": 100}])
    pk_df = pd.DataFrame()

    def side_effect(query):
        if "information_schema.columns" in query:
            return columns_df
        if "COUNT(*)" in query:
            return total_rows_df
        if "information_schema.key_column_usage" in query:
            return pk_df
        if 'COUNT("user_id")' in query:
            return pd.DataFrame([{"count": 100}])
        if 'COUNT(DISTINCT "user_id")' in query:
            # High cardinality > 20
            return pd.DataFrame([{"count": 50}])
        return pd.DataFrame()

    mock_db_client.query_df.side_effect = side_effect

    inspect_db_table(mock_db_client, data_dir=data_dir)

    summary = get_data_summary(data_dir)
    assert "user_id" in summary
    assert "unique_count: 50" in summary
    # Should NOT have distinct_values
    assert "distinct_values" not in summary


def test_inspect_db_table_checksum_skip(mock_db_client, tmp_path):
    """Test that inspection is skipped if checksum matches."""
    data_dir = str(tmp_path)
    table = "test_table"

    # 1. Create a fake checksum file
    os.makedirs(data_dir, exist_ok=True)
    # Checksum is md5 of "ins|upd|del". Let's say mock returns "my_hash"

    # Mock checksum query
    mock_db_client.query_df.return_value = pd.DataFrame([{"dml_hash": "my_hash"}])

    # Write existing checksum
    with open(os.path.join(data_dir, "checksum"), "w") as f:
        f.write(f"{table}:my_hash\n")

    # Write existing inspection
    with open(os.path.join(data_dir, "inspection.yaml"), "w") as f:
        f.write(f"{table}: {{ existing: true }}")

    # Run inspection
    result = inspect_db_table(mock_db_client, data_dir=data_dir)

    # Should return existing content
    assert "existing: true" in result
    # query_df should be called ONLY for checksum (once)
    # verify count of calls?
    # Logic: 1 call for checksum. If match, return.
    assert mock_db_client.query_df.call_count == 1
|
||||
|
||||
123
backend/tests/test_executor.py
Normal file
123
backend/tests/test_executor.py
Normal file
@@ -0,0 +1,123 @@
|
||||
import pytest
import pandas as pd
from unittest.mock import MagicMock, patch
from matplotlib.figure import Figure
from ea_chatbot.graph.nodes.executor import executor_node
from ea_chatbot.graph.state import AgentState


def _state(code):
    """Build the minimal executor input state around *code*."""
    return {"code": code, "question": "test", "messages": []}


@pytest.fixture
def mock_settings():
    """Patch the executor's Settings so no real DB configuration is read."""
    with patch("ea_chatbot.graph.nodes.executor.Settings") as settings_cls:
        fake = MagicMock()
        fake.configure_mock(
            db_host="localhost",
            db_port=5432,
            db_user="user",
            db_pswd="pass",
            db_name="test_db",
            db_table="test_table",
        )
        settings_cls.return_value = fake
        yield fake


@pytest.fixture
def mock_db_client():
    """Patch the executor's DBClient so no real database connection is made."""
    with patch("ea_chatbot.graph.nodes.executor.DBClient") as client_cls:
        fake_client = MagicMock()
        client_cls.return_value = fake_client
        yield fake_client


def test_executor_node_success_simple_print(mock_settings, mock_db_client):
    """Plain print output lands in code_output with no error, plots, or dfs."""
    result = executor_node(_state("print('Hello, World!')"))

    assert "code_output" in result
    assert "Hello, World!" in result["code_output"]
    assert result["error"] is None
    assert result["plots"] == []
    assert result["dfs"] == {}


def test_executor_node_success_dataframe(mock_settings, mock_db_client):
    """DataFrames created by the executed code are collected into dfs."""
    code = """
import pandas as pd
df = pd.DataFrame({'a': [1, 2], 'b': [3, 4]})
print(df)
"""
    result = executor_node(_state(code))

    assert "code_output" in result
    assert "a  b" in result["code_output"]  # part of the DataFrame header row
    assert "dfs" in result
    assert "df" in result["dfs"]
    assert isinstance(result["dfs"]["df"], pd.DataFrame)


def test_executor_node_success_plot(mock_settings, mock_db_client):
    """Figures appended to the injected `plots` list are returned."""
    code = """
import matplotlib.pyplot as plt
fig = plt.figure()
plots.append(fig)
print('Plot generated')
"""
    result = executor_node(_state(code))

    assert "Plot generated" in result["code_output"]
    assert "plots" in result
    assert len(result["plots"]) == 1
    assert isinstance(result["plots"][0], Figure)


def test_executor_node_error_syntax(mock_settings, mock_db_client):
    """A syntax error in the code is surfaced via the error field."""
    # Deliberately unterminated string literal.
    result = executor_node(_state("print('Hello World"))

    assert result["error"] is not None
    assert "SyntaxError" in result["error"]


def test_executor_node_error_runtime(mock_settings, mock_db_client):
    """A runtime exception is surfaced via the error field."""
    result = executor_node(_state("print(1 / 0)"))

    assert result["error"] is not None
    assert "ZeroDivisionError" in result["error"]


def test_executor_node_no_code(mock_settings, mock_db_client):
    """Missing code yields an explicit 'No code provided' error."""
    result = executor_node(_state(None))

    assert "error" in result
    assert "No code provided" in result["error"]
|
||||
77
backend/tests/test_helpers.py
Normal file
77
backend/tests/test_helpers.py
Normal file
@@ -0,0 +1,77 @@
|
||||
import pytest
from langchain_core.messages import HumanMessage, AIMessage
from ea_chatbot.utils.helpers import merge_agent_state


def test_merge_agent_state_list_accumulation():
    """List-valued fields (messages, plots) are appended, not replaced."""
    base = {
        "messages": [HumanMessage(content="hello")],
        "plots": ["plot1"],
    }
    delta = {
        "messages": [AIMessage(content="hi")],
        "plots": ["plot2"],
    }

    merged = merge_agent_state(base, delta)

    # Equality on the full lists checks both length and order.
    assert [m.content for m in merged["messages"]] == ["hello", "hi"]
    assert merged["plots"] == ["plot1", "plot2"]


def test_merge_agent_state_dict_update():
    """Dict-valued fields (dfs) are shallow-merged, with the update winning."""
    merged = merge_agent_state({"dfs": {"df1": "data1"}}, {"dfs": {"df2": "data2"}})
    assert merged["dfs"] == {"df1": "data1", "df2": "data2"}

    # A key present on both sides takes the updated value.
    remerged = merge_agent_state(merged, {"dfs": {"df1": "new_data1"}})
    assert remerged["dfs"] == {"df1": "new_data1", "df2": "data2"}


def test_merge_agent_state_standard_overwrite():
    """Scalar/string fields are simply overwritten by the update."""
    base = {"question": "old question", "next_action": "old action", "plan": "old plan"}
    delta = {"question": "new question", "next_action": "new action", "plan": "new plan"}

    merged = merge_agent_state(base, delta)

    for field in ("question", "next_action", "plan"):
        assert merged[field] == delta[field]


def test_merge_agent_state_none_handling():
    """Empty updates are no-ops; explicit None values still overwrite."""
    base = {"question": "test", "messages": ["msg1"]}

    # An empty update leaves the state unchanged.
    assert merge_agent_state(base, {}) == base

    # None is a legitimate overwrite value for a scalar field.
    merged = merge_agent_state(base, {"question": None})
    assert merged["question"] is None
    assert merged["messages"] == ["msg1"]
|
||||
145
backend/tests/test_history_manager.py
Normal file
145
backend/tests/test_history_manager.py
Normal file
@@ -0,0 +1,145 @@
|
||||
"""Integration tests for HistoryManager: users, conversations, messages, plots.

NOTE(review): these talk to the real database configured by
``Settings().history_db_url`` — they are integration tests, not unit tests.
"""
import pytest
from ea_chatbot.history.manager import HistoryManager
from ea_chatbot.history.models import User, Conversation, Message, Plot
from ea_chatbot.config import Settings
from sqlalchemy import delete


@pytest.fixture
def history_manager():
    """HistoryManager on the configured history DB, with all tables emptied first."""
    settings = Settings()
    manager = HistoryManager(settings.history_db_url)
    # Clean up tables before tests (order matters because of foreign keys)
    with manager.get_session() as session:
        session.execute(delete(Plot))
        session.execute(delete(Message))
        session.execute(delete(Conversation))
        session.execute(delete(User))
    return manager


def test_history_manager_initialization(history_manager):
    """The manager exposes a SQLAlchemy engine and session factory."""
    assert history_manager.engine is not None
    assert history_manager.SessionLocal is not None


def test_history_manager_session_context(history_manager):
    """get_session() works as a context manager yielding a live session."""
    with history_manager.get_session() as session:
        assert session is not None


def test_get_user_not_found(history_manager):
    """Looking up an unknown email returns None rather than raising."""
    user = history_manager.get_user("nonexistent@example.com")
    assert user is None


def test_authenticate_user_success(history_manager):
    """A created user authenticates with the matching password."""
    email = "test@example.com"
    password = "secretpassword"
    history_manager.create_user(email=email, password=password)

    user = history_manager.authenticate_user(email, password)
    assert user is not None
    # The email doubles as the username.
    assert user.username == email


def test_authenticate_user_failure(history_manager):
    """A wrong password yields None, not an exception."""
    email = "test@example.com"
    history_manager.create_user(email=email, password="correctpassword")

    user = history_manager.authenticate_user(email, "wrongpassword")
    assert user is None


def test_sync_user_from_oidc_new_user(history_manager):
    """Syncing an unknown OIDC identity creates the user."""
    user = history_manager.sync_user_from_oidc(
        email="new@example.com",
        display_name="New User"
    )
    assert user is not None
    assert user.username == "new@example.com"
    assert user.display_name == "New User"


def test_sync_user_from_oidc_existing_user(history_manager):
    """Re-syncing an existing OIDC identity updates the display name."""
    # First sync
    history_manager.sync_user_from_oidc(
        email="existing@example.com",
        display_name="First Name"
    )
    # Second sync should update or return same user
    user = history_manager.sync_user_from_oidc(
        email="existing@example.com",
        display_name="Updated Name"
    )
    assert user.display_name == "Updated Name"


# --- Conversation Management Tests ---


@pytest.fixture
def user(history_manager):
    """A fresh user record to own conversations under test."""
    return history_manager.create_user(email="conv_user@example.com")


def test_create_conversation(history_manager, user):
    """A conversation stores its name, summary, and owning user."""
    conv = history_manager.create_conversation(
        user_id=user.id,
        data_state="new_jersey",
        name="Test Chat",
        summary="A test conversation summary"
    )
    assert conv is not None
    assert conv.name == "Test Chat"
    assert conv.summary == "A test conversation summary"
    assert conv.user_id == user.id


def test_get_conversations(history_manager, user):
    """Listing is filtered by data_state for the given user."""
    history_manager.create_conversation(user_id=user.id, data_state="nj", name="C1")
    history_manager.create_conversation(user_id=user.id, data_state="nj", name="C2")
    history_manager.create_conversation(user_id=user.id, data_state="ny", name="C3")

    nj_convs = history_manager.get_conversations(user_id=user.id, data_state="nj")
    assert len(nj_convs) == 2

    ny_convs = history_manager.get_conversations(user_id=user.id, data_state="ny")
    assert len(ny_convs) == 1


def test_rename_conversation(history_manager, user):
    """rename_conversation returns the record with the new name applied."""
    conv = history_manager.create_conversation(user.id, "nj", "Old Name")
    updated = history_manager.rename_conversation(conv.id, "New Name")
    assert updated.name == "New Name"


def test_delete_conversation(history_manager, user):
    """A deleted conversation no longer appears in the listing."""
    conv = history_manager.create_conversation(user.id, "nj", "To Delete")
    history_manager.delete_conversation(conv.id)

    convs = history_manager.get_conversations(user.id, "nj")
    assert len(convs) == 0


# --- Message Management Tests ---


@pytest.fixture
def conversation(history_manager, user):
    """A fresh conversation to attach messages to."""
    return history_manager.create_conversation(user.id, "nj", "Msg Test Conv")


def test_add_message(history_manager, conversation):
    """A message stores its role, content, and parent conversation."""
    msg = history_manager.add_message(
        conversation_id=conversation.id,
        role="user",
        content="Hello world"
    )
    assert msg is not None
    assert msg.content == "Hello world"
    assert msg.role == "user"
    assert msg.conversation_id == conversation.id


def test_add_message_with_plots(history_manager, conversation):
    """Raw plot bytes attached to a message become Plot rows with image_data."""
    plots_data = [b"fake_plot_1", b"fake_plot_2"]
    msg = history_manager.add_message(
        conversation_id=conversation.id,
        role="assistant",
        content="Here are plots",
        plots=plots_data
    )
    assert len(msg.plots) == 2
    assert msg.plots[0].image_data == b"fake_plot_1"


def test_get_messages(history_manager, conversation):
    """Messages are returned in insertion order."""
    history_manager.add_message(conversation.id, "user", "Q1")
    history_manager.add_message(conversation.id, "assistant", "A1")

    messages = history_manager.get_messages(conversation.id)
    assert len(messages) == 2
    assert messages[0].content == "Q1"
    assert messages[1].content == "A1"
|
||||
55
backend/tests/test_history_models.py
Normal file
55
backend/tests/test_history_models.py
Normal file
@@ -0,0 +1,55 @@
|
||||
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, DeclarativeBase

# We anticipate these imports will fail initially (TDD-style guard): fall back
# to None so the tests report missing models instead of erroring at import time.
try:
    from ea_chatbot.history.models import Base, User, Conversation, Message, Plot
except ImportError:
    Base = None
    User = None
    Conversation = None
    Message = None
    Plot = None


def _column_names(model):
    """Return the set of column names declared on a SQLAlchemy model's table."""
    return set(model.__table__.columns.keys())


def test_models_exist():
    """All ORM models and the declarative base must be importable."""
    for obj, label in [
        (User, "User model"),
        (Conversation, "Conversation model"),
        (Message, "Message model"),
        (Plot, "Plot model"),
        (Base, "Base declarative class"),
    ]:
        assert obj is not None, f"{label} not found"


def test_user_model_columns():
    """User rows carry credentials and display metadata."""
    if User is None:
        pytest.fail("User model undefined")
    assert {"id", "username", "password_hash", "display_name"} <= _column_names(User)


def test_conversation_model_columns():
    """Conversations are scoped to a user and a data_state, with metadata."""
    if Conversation is None:
        pytest.fail("Conversation model undefined")
    assert {"id", "user_id", "data_state", "name", "summary", "created_at"} <= _column_names(Conversation)


def test_message_model_columns():
    """Messages store role/content and link back to a conversation."""
    if Message is None:
        pytest.fail("Message model undefined")
    assert {"id", "role", "content", "conversation_id", "created_at"} <= _column_names(Message)


def test_plot_model_columns():
    """Plots store raw image bytes attached to a message."""
    if Plot is None:
        pytest.fail("Plot model undefined")
    assert {"id", "message_id", "image_data"} <= _column_names(Plot)
||||
4
backend/tests/test_history_module.py
Normal file
4
backend/tests/test_history_module.py
Normal file
@@ -0,0 +1,4 @@
|
||||
import ea_chatbot.history


def test_history_module_importable():
    """Smoke test: the history package imports without side-effect failures."""
    assert ea_chatbot.history is not None
|
||||
54
backend/tests/test_llm_factory.py
Normal file
54
backend/tests/test_llm_factory.py
Normal file
@@ -0,0 +1,54 @@
|
||||
import pytest
from langchain_openai import ChatOpenAI
from langchain_google_genai import ChatGoogleGenerativeAI
from ea_chatbot.config import LLMConfig
from ea_chatbot.utils.llm_factory import get_llm_model


def test_get_openai_model(monkeypatch):
    """An OpenAI config maps onto a ChatOpenAI with matching parameters."""
    monkeypatch.setenv("OPENAI_API_KEY", "dummy")
    cfg = LLMConfig(provider="openai", model="gpt-4o", temperature=0.5, max_tokens=100)

    llm = get_llm_model(cfg)

    assert isinstance(llm, ChatOpenAI)
    assert llm.model_name == "gpt-4o"
    assert llm.temperature == 0.5
    assert llm.max_tokens == 100


def test_get_google_model(monkeypatch):
    """A Google config maps onto a ChatGoogleGenerativeAI with matching parameters."""
    monkeypatch.setenv("GOOGLE_API_KEY", "dummy")
    cfg = LLMConfig(provider="google", model="gemini-1.5-pro", temperature=0.7)

    llm = get_llm_model(cfg)

    assert isinstance(llm, ChatGoogleGenerativeAI)
    assert llm.model == "gemini-1.5-pro"
    assert llm.temperature == 0.7


def test_unsupported_provider():
    """An unknown provider name raises a descriptive ValueError."""
    bad_cfg = LLMConfig(provider="unknown", model="test")
    with pytest.raises(ValueError, match="Unsupported LLM provider: unknown"):
        get_llm_model(bad_cfg)


def test_provider_specific_params(monkeypatch):
    """Entries in provider_specific are forwarded to the model as kwargs."""
    monkeypatch.setenv("OPENAI_API_KEY", "dummy")
    cfg = LLMConfig(
        provider="openai",
        model="o1-preview",
        provider_specific={"reasoning_effort": "high"},
    )
    # Note: whether reasoning_effort is a first-class attribute depends on the
    # installed langchain-openai version; we only verify the kwarg is applied.
    llm = get_llm_model(cfg)
    assert isinstance(llm, ChatOpenAI)
    assert getattr(llm, "reasoning_effort", None) == "high"
|
||||
19
backend/tests/test_llm_factory_callbacks.py
Normal file
19
backend/tests/test_llm_factory_callbacks.py
Normal file
@@ -0,0 +1,19 @@
|
||||
import pytest
from langchain_openai import ChatOpenAI
from ea_chatbot.config import LLMConfig
from ea_chatbot.utils.llm_factory import get_llm_model
from langchain_core.callbacks import BaseCallbackHandler


class MockHandler(BaseCallbackHandler):
    """No-op callback handler used only for identity checks."""
    pass


def test_get_llm_model_with_callbacks(monkeypatch):
    """Handlers passed to the factory end up attached to the built model."""
    monkeypatch.setenv("OPENAI_API_KEY", "dummy")
    probe = MockHandler()

    llm = get_llm_model(LLMConfig(provider="openai", model="gpt-4o"), callbacks=[probe])

    assert isinstance(llm, ChatOpenAI)
    assert probe in llm.callbacks
|
||||
44
backend/tests/test_logging_context.py
Normal file
44
backend/tests/test_logging_context.py
Normal file
@@ -0,0 +1,44 @@
|
||||
import logging
|
||||
import pytest
|
||||
import io
|
||||
import json
|
||||
from ea_chatbot.utils.logging import ContextLoggerAdapter, JsonFormatter
|
||||
|
||||
@pytest.fixture
def json_log_capture():
    """Provide a logger wired to an in-memory stream with JSON formatting."""
    stream = io.StringIO()
    logger = logging.getLogger("test_context")
    logger.setLevel(logging.INFO)
    # Drop any handlers left over from earlier tests.
    while logger.handlers:
        logger.removeHandler(logger.handlers[0])

    json_handler = logging.StreamHandler(stream)
    json_handler.setFormatter(JsonFormatter())
    logger.addHandler(json_handler)
    return logger, stream


def test_context_logger_adapter_injects_metadata(json_log_capture):
    """ContextLoggerAdapter must inject its bound metadata into each record."""
    logger, stream = json_log_capture
    adapter = ContextLoggerAdapter(logger, {"run_id": "123", "node_name": "test_node"})

    adapter.info("test message")

    record = json.loads(stream.getvalue())
    assert record["message"] == "test message"
    assert record["run_id"] == "123"
    assert record["node_name"] == "test_node"


def test_context_logger_adapter_override_metadata(json_log_capture):
    """Per-call 'extra' metadata should be merged over the bound context."""
    logger, stream = json_log_capture
    adapter = ContextLoggerAdapter(logger, {"run_id": "123"})

    # Passing extra context via the 'extra' parameter in standard logging;
    # the adapter is expected to handle merging it with the bound metadata.
    adapter.info("test message", extra={"node_name": "dynamic_node"})

    record = json.loads(stream.getvalue())
    assert record["run_id"] == "123"
    assert record["node_name"] == "dynamic_node"
67
backend/tests/test_logging_core.py
Normal file
67
backend/tests/test_logging_core.py
Normal file
@@ -0,0 +1,67 @@
|
||||
import logging
|
||||
import pytest
|
||||
from ea_chatbot.utils.logging import get_logger
|
||||
|
||||
@pytest.fixture(autouse=True)
def reset_logging():
    """Strip all handlers from the ea_chatbot logger before and after each test."""
    logger = logging.getLogger("ea_chatbot")
    while logger.handlers:
        logger.removeHandler(logger.handlers[0])
    yield
    # Also remove anything the test itself attached.
    while logger.handlers:
        logger.removeHandler(logger.handlers[0])


def test_get_logger_singleton():
    """Requesting the same name twice must yield the identical logger object."""
    first = get_logger("test_logger")
    second = get_logger("test_logger")
    assert first is second


def test_get_logger_rich_handler():
    """get_logger must install a RichHandler on the root ea_chatbot logger."""
    get_logger("test_rich")
    root = logging.getLogger("ea_chatbot")
    # Check by class name so the test does not import rich directly.
    assert any(h.__class__.__name__ == "RichHandler" for h in root.handlers)


def test_get_logger_level():
    """The requested level string must be applied to the returned logger."""
    logger = get_logger("test_level", level="DEBUG")
    assert logger.level == logging.DEBUG


def test_json_formatter_serializes_dict():
    """JsonFormatter must render records as parseable JSON."""
    from ea_chatbot.utils.logging import JsonFormatter
    import json

    record = logging.LogRecord(
        name="test", level=logging.INFO, pathname="test.py", lineno=10,
        msg="test message", args=(), exc_info=None,
    )
    payload = json.loads(JsonFormatter().format(record))

    assert payload["message"] == "test message"
    assert payload["level"] == "INFO"
    assert "timestamp" in payload


def test_get_logger_file_handler(tmp_path):
    """A log_file argument must attach a RotatingFileHandler and write to disk."""
    log_file = tmp_path / "test.json"
    logger = get_logger("test_file", log_file=str(log_file))

    root = logging.getLogger("ea_chatbot")
    assert any(h.__class__.__name__ == "RotatingFileHandler" for h in root.handlers)

    logger.info("file log test")

    # The handler should have created the file and flushed the record.
    assert log_file.exists()
    assert "file log test" in log_file.read_text()
83
backend/tests/test_logging_e2e.py
Normal file
83
backend/tests/test_logging_e2e.py
Normal file
@@ -0,0 +1,83 @@
|
||||
import os
|
||||
import json
|
||||
import pytest
|
||||
import logging
|
||||
from unittest.mock import MagicMock, patch
|
||||
from ea_chatbot.graph.workflow import app
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
from ea_chatbot.utils.logging import get_logger
|
||||
from langchain_community.chat_models import FakeListChatModel
|
||||
from langchain_core.runnables import RunnableLambda
|
||||
|
||||
@pytest.fixture(autouse=True)
def reset_logging():
    """Strip handlers from the root ea_chatbot logger around each test."""
    root = logging.getLogger("ea_chatbot")
    while root.handlers:
        root.removeHandler(root.handlers[0])
    yield
    while root.handlers:
        root.removeHandler(root.handlers[0])


class FakeStructuredModel(FakeListChatModel):
    """FakeListChatModel whose structured output parses the first canned JSON response."""

    def with_structured_output(self, schema, **kwargs):
        # Return a runnable that decodes the canned response and, when the
        # schema looks like a Pydantic model, validates it into that type.
        def _invoke(input, config=None, **kwargs):
            import json
            data = json.loads(self.responses[0])
            if hasattr(schema, "model_validate"):
                return schema.model_validate(data)
            return data

        return RunnableLambda(_invoke)


def test_logging_e2e_json_output(tmp_path):
    """Test that a full graph run produces structured JSON logs from multiple nodes."""
    log_file = tmp_path / "e2e_test.jsonl"

    # Configure the root logger
    get_logger("ea_chatbot", log_file=str(log_file))

    initial_state: AgentState = {
        "messages": [],
        "question": "Who won in 2024?",
        "analysis": None,
        "next_action": "",
        "plan": None,
        "code": None,
        "code_output": None,
        "error": None,
        "plots": [],
        "dfs": {}
    }

    # Create fake models that support callbacks and structured output
    fake_analyzer_response = """{"data_required": [], "unknowns": [], "ambiguities": ["Which year?"], "conditions": [], "next_action": "clarify"}"""
    fake_analyzer = FakeStructuredModel(responses=[fake_analyzer_response])

    fake_clarify = FakeListChatModel(responses=["Please specify."])

    with patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model") as mock_llm_factory, \
         patch("ea_chatbot.graph.nodes.clarification.get_llm_model") as mock_clarify_llm_factory:
        mock_llm_factory.return_value = fake_analyzer
        mock_clarify_llm_factory.return_value = fake_clarify

        # Drain the stream to execute the full graph run.
        list(app.stream(initial_state))

    # Verify file content
    assert log_file.exists()
    lines = log_file.read_text().splitlines()
    assert len(lines) > 0

    # Verify we have logs from different nodes
    node_names = [json.loads(line)["name"] for line in lines]
    assert "ea_chatbot.query_analyzer" in node_names
    assert "ea_chatbot.clarification" in node_names

    # Verify events
    messages = [json.loads(line)["message"] for line in lines]
    assert any("Analyzing question" in m for m in messages)
    assert any("Clarification generated" in m for m in messages)
64
backend/tests/test_logging_langchain.py
Normal file
64
backend/tests/test_logging_langchain.py
Normal file
@@ -0,0 +1,64 @@
|
||||
import logging
|
||||
import pytest
|
||||
import io
|
||||
from unittest.mock import MagicMock
|
||||
from ea_chatbot.utils.logging import LangChainLoggingHandler
|
||||
|
||||
@pytest.fixture
def log_capture():
    """Provide a plain logger whose output is captured in a StringIO stream."""
    stream = io.StringIO()
    logger = logging.getLogger("test_langchain")
    logger.setLevel(logging.INFO)
    # Remove handlers left over from previous tests.
    while logger.handlers:
        logger.removeHandler(logger.handlers[0])

    logger.addHandler(logging.StreamHandler(stream))
    return logger, stream


def test_langchain_logging_handler_on_llm_start(log_capture):
    """on_llm_start must log the model name."""
    logger, stream = log_capture
    handler = LangChainLoggingHandler(logger=logger)

    handler.on_llm_start(serialized={"name": "test_model"}, prompts=["test prompt"])

    captured = stream.getvalue()
    assert "LLM Started:" in captured
    assert "test_model" in captured


def test_langchain_logging_handler_on_llm_end(log_capture):
    """on_llm_end must log the model name and the token usage breakdown."""
    logger, stream = log_capture
    handler = LangChainLoggingHandler(logger=logger)
    response = MagicMock()
    response.llm_output = {
        "token_usage": {
            "prompt_tokens": 10,
            "completion_tokens": 20,
            "total_tokens": 30
        },
        "model_name": "test_model"
    }

    handler.on_llm_end(response=response)

    captured = stream.getvalue()
    assert "LLM Ended:" in captured
    assert "test_model" in captured
    assert "Tokens: 30" in captured
    assert "10 prompt" in captured
    assert "20 completion" in captured


def test_langchain_logging_handler_on_llm_error(log_capture):
    """on_llm_error must log the error message."""
    logger, stream = log_capture
    handler = LangChainLoggingHandler(logger=logger)

    handler.on_llm_error(error=Exception("test error"))

    captured = stream.getvalue()
    assert "LLM Error:" in captured
    assert "test error" in captured
79
backend/tests/test_multi_turn_planner_researcher.py
Normal file
79
backend/tests/test_multi_turn_planner_researcher.py
Normal file
@@ -0,0 +1,79 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from langchain_core.messages import HumanMessage, AIMessage
|
||||
from ea_chatbot.graph.nodes.planner import planner_node
|
||||
from ea_chatbot.graph.nodes.researcher import researcher_node
|
||||
from ea_chatbot.graph.nodes.summarizer import summarizer_node
|
||||
from ea_chatbot.schemas import TaskPlanResponse
|
||||
|
||||
@pytest.fixture
def mock_state_with_history():
    """Agent state carrying one prior exchange plus a running summary."""
    return {
        "messages": [
            HumanMessage(content="Show me the 2024 results for Florida"),
            AIMessage(content="Here are the results for Florida in 2024...")
        ],
        "question": "What about in New Jersey?",
        "analysis": {"data_required": ["2024 results", "New Jersey"], "unknowns": [], "ambiguities": [], "conditions": []},
        "next_action": "plan",
        "summary": "The user is asking about 2024 election results.",
        "plan": "Plan steps...",
        "code_output": "Code output..."
    }


@patch("ea_chatbot.graph.nodes.planner.get_llm_model")
@patch("ea_chatbot.utils.database_inspection.get_data_summary")
@patch("ea_chatbot.graph.nodes.planner.PLANNER_PROMPT")
def test_planner_uses_history_and_summary(mock_prompt, mock_get_summary, mock_get_llm, mock_state_with_history):
    """The planner prompt must receive the question, summary and chat history."""
    mock_get_summary.return_value = "Data summary"
    llm = MagicMock()
    mock_get_llm.return_value = llm
    structured = MagicMock()
    llm.with_structured_output.return_value = structured

    structured.invoke.return_value = TaskPlanResponse(
        goal="goal",
        reflection="reflection",
        context={
            "initial_context": "context",
            "assumptions": [],
            "constraints": []
        },
        steps=["Step 1: test"]
    )

    planner_node(mock_state_with_history)

    mock_prompt.format_messages.assert_called_once()
    kwargs = mock_prompt.format_messages.call_args[1]
    assert kwargs["question"] == "What about in New Jersey?"
    assert kwargs["summary"] == mock_state_with_history["summary"]
    assert len(kwargs["history"]) == 2


@patch("ea_chatbot.graph.nodes.researcher.get_llm_model")
@patch("ea_chatbot.graph.nodes.researcher.RESEARCHER_PROMPT")
def test_researcher_uses_history_and_summary(mock_prompt, mock_get_llm, mock_state_with_history):
    """The researcher prompt must receive the question, summary and chat history."""
    mock_get_llm.return_value = MagicMock()

    researcher_node(mock_state_with_history)

    mock_prompt.format_messages.assert_called_once()
    kwargs = mock_prompt.format_messages.call_args[1]
    assert kwargs["question"] == "What about in New Jersey?"
    assert kwargs["summary"] == mock_state_with_history["summary"]
    assert len(kwargs["history"]) == 2


@patch("ea_chatbot.graph.nodes.summarizer.get_llm_model")
@patch("ea_chatbot.graph.nodes.summarizer.SUMMARIZER_PROMPT")
def test_summarizer_uses_history_and_summary(mock_prompt, mock_get_llm, mock_state_with_history):
    """The summarizer prompt must receive the question, summary and chat history."""
    mock_get_llm.return_value = MagicMock()

    summarizer_node(mock_state_with_history)

    mock_prompt.format_messages.assert_called_once()
    kwargs = mock_prompt.format_messages.call_args[1]
    assert kwargs["question"] == "What about in New Jersey?"
    assert kwargs["summary"] == mock_state_with_history["summary"]
    assert len(kwargs["history"]) == 2
76
backend/tests/test_multi_turn_query_analyzer.py
Normal file
76
backend/tests/test_multi_turn_query_analyzer.py
Normal file
@@ -0,0 +1,76 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from langchain_core.messages import HumanMessage, AIMessage
|
||||
from ea_chatbot.graph.nodes.query_analyzer import query_analyzer_node, QueryAnalysis
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
|
||||
@pytest.fixture
def mock_state_with_history():
    """State with one prior user/assistant exchange and a running summary."""
    return {
        "messages": [
            HumanMessage(content="Show me the 2024 results for Florida"),
            AIMessage(content="Here are the results for Florida in 2024...")
        ],
        "question": "What about in New Jersey?",
        "analysis": None,
        "next_action": "",
        "summary": "The user is asking about 2024 election results."
    }


@patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model")
@patch("ea_chatbot.graph.nodes.query_analyzer.QUERY_ANALYZER_PROMPT")
def test_query_analyzer_uses_history_and_summary(mock_prompt, mock_get_llm, mock_state_with_history):
    """Test that query_analyzer_node passes history and summary to the prompt."""
    llm = MagicMock()
    mock_get_llm.return_value = llm
    structured = MagicMock()
    llm.with_structured_output.return_value = structured

    structured.invoke.return_value = QueryAnalysis(
        data_required=["2024 results", "New Jersey"],
        unknowns=[],
        ambiguities=[],
        conditions=[],
        next_action="plan"
    )

    query_analyzer_node(mock_state_with_history)

    # Verify that the prompt was formatted with the correct variables
    mock_prompt.format_messages.assert_called_once()
    kwargs = mock_prompt.format_messages.call_args[1]

    assert kwargs["question"] == "What about in New Jersey?"
    assert "summary" in kwargs
    assert kwargs["summary"] == mock_state_with_history["summary"]
    assert "history" in kwargs
    # History should contain the messages from the state
    assert len(kwargs["history"]) == 2
    assert kwargs["history"][0].content == "Show me the 2024 results for Florida"


@patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model")
def test_query_analyzer_context_window(mock_get_llm):
    """Test that query_analyzer_node only uses the last 6 messages (3 turns)."""
    state = {
        "messages": [HumanMessage(content=f"Msg {i}") for i in range(10)],
        "question": "Latest question",
        "analysis": None,
        "next_action": "",
        "summary": "Summary"
    }

    llm = MagicMock()
    mock_get_llm.return_value = llm
    structured = MagicMock()
    llm.with_structured_output.return_value = structured
    structured.invoke.return_value = QueryAnalysis(
        data_required=[], unknowns=[], ambiguities=[], conditions=[], next_action="plan"
    )

    with patch("ea_chatbot.graph.nodes.query_analyzer.QUERY_ANALYZER_PROMPT") as mock_prompt:
        query_analyzer_node(state)
        kwargs = mock_prompt.format_messages.call_args[1]
        # Only the trailing 6 messages should survive the context window.
        assert len(kwargs["history"]) == 6
        assert kwargs["history"][0].content == "Msg 4"
87
backend/tests/test_oidc_client.py
Normal file
87
backend/tests/test_oidc_client.py
Normal file
@@ -0,0 +1,87 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from ea_chatbot.auth import OIDCClient
|
||||
|
||||
@pytest.fixture
def oidc_config():
    """Constructor kwargs for an OIDCClient under test."""
    return {
        "client_id": "test_id",
        "client_secret": "test_secret",
        "server_metadata_url": "https://example.com/.well-known/openid-configuration",
        "redirect_uri": "http://localhost:8501"
    }


@pytest.fixture
def mock_metadata():
    """Canned OIDC discovery-document endpoints."""
    return {
        "authorization_endpoint": "https://example.com/auth",
        "token_endpoint": "https://example.com/token",
        "userinfo_endpoint": "https://example.com/userinfo"
    }


def test_oidc_fetch_metadata(oidc_config, mock_metadata):
    """Metadata is fetched once over HTTP and then served from cache."""
    client = OIDCClient(**oidc_config)

    with patch("requests.get") as mock_get:
        fake_response = MagicMock()
        fake_response.json.return_value = mock_metadata
        mock_get.return_value = fake_response

        fetched = client.fetch_metadata()

        assert fetched == mock_metadata
        mock_get.assert_called_once_with(oidc_config["server_metadata_url"])

        # Second call should use cache
        client.fetch_metadata()
        assert mock_get.call_count == 1


def test_oidc_get_login_url(oidc_config, mock_metadata):
    """The login URL comes from the authorization endpoint in the metadata."""
    client = OIDCClient(**oidc_config)
    client.metadata = mock_metadata

    with patch.object(client.oauth_session, "create_authorization_url") as mock_create_url:
        mock_create_url.return_value = ("https://example.com/auth?state=xyz", "xyz")

        login_url = client.get_login_url()

        assert login_url == "https://example.com/auth?state=xyz"
        mock_create_url.assert_called_once_with(mock_metadata["authorization_endpoint"])


def test_oidc_get_login_url_missing_endpoint(oidc_config):
    """A metadata document without an authorization endpoint is an error."""
    client = OIDCClient(**oidc_config)
    client.metadata = {"some_other": "field"}

    with pytest.raises(ValueError, match="authorization_endpoint not found"):
        client.get_login_url()


def test_oidc_exchange_code_for_token(oidc_config, mock_metadata):
    """The auth code is exchanged at the token endpoint with the client secret."""
    client = OIDCClient(**oidc_config)
    client.metadata = mock_metadata

    with patch.object(client.oauth_session, "fetch_token") as mock_fetch_token:
        mock_fetch_token.return_value = {"access_token": "abc"}

        token = client.exchange_code_for_token("test_code")

        assert token == {"access_token": "abc"}
        mock_fetch_token.assert_called_once_with(
            mock_metadata["token_endpoint"],
            code="test_code",
            client_secret=oidc_config["client_secret"]
        )


def test_oidc_get_user_info(oidc_config, mock_metadata):
    """User info is requested from the userinfo endpoint with the token set."""
    client = OIDCClient(**oidc_config)
    client.metadata = mock_metadata

    with patch.object(client.oauth_session, "get") as mock_get:
        fake_response = MagicMock()
        fake_response.json.return_value = {"sub": "user123"}
        mock_get.return_value = fake_response

        info = client.get_user_info({"access_token": "abc"})

        assert info == {"sub": "user123"}
        assert client.oauth_session.token == {"access_token": "abc"}
        mock_get.assert_called_once_with(mock_metadata["userinfo_endpoint"])
46
backend/tests/test_planner.py
Normal file
46
backend/tests/test_planner.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from ea_chatbot.graph.nodes.planner import planner_node
|
||||
|
||||
@pytest.fixture
def mock_state():
    """Minimal agent state for exercising the planner node."""
    return {
        "messages": [],
        "question": "Show me results for New Jersey",
        "analysis": {
            # "requires_dataset" removed as it's no longer used
            "expert": "Data Analyst",
            "data": "NJ data",
            "unknown": "results",
            "condition": "state=NJ"
        },
        "next_action": "plan",
        "plan": None
    }


@patch("ea_chatbot.graph.nodes.planner.get_llm_model")
@patch("ea_chatbot.utils.database_inspection.get_data_summary")
def test_planner_node(mock_get_summary, mock_get_llm, mock_state):
    """Test planner node with unified prompt."""
    mock_get_summary.return_value = "Column: Name, Type: text"

    llm = MagicMock()
    mock_get_llm.return_value = llm

    from ea_chatbot.schemas import TaskPlanResponse, TaskPlanContext
    plan = TaskPlanResponse(
        goal="Get NJ results",
        reflection="The user wants NJ results",
        context=TaskPlanContext(initial_context="NJ data", assumptions=[], constraints=[]),
        steps=["Step 1: Load data", "Step 2: Filter by NJ"]
    )
    llm.with_structured_output.return_value.invoke.return_value = plan

    result = planner_node(mock_state)

    assert "plan" in result
    assert "Step 1: Load data" in result["plan"]
    assert "Step 2: Filter by NJ" in result["plan"]

    # Verify helper was called
    mock_get_summary.assert_called_once()
80
backend/tests/test_query_analyzer.py
Normal file
80
backend/tests/test_query_analyzer.py
Normal file
@@ -0,0 +1,80 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from ea_chatbot.graph.nodes.query_analyzer import query_analyzer_node, QueryAnalysis
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
|
||||
@pytest.fixture
def mock_state():
    """Fresh agent state holding only the incoming question."""
    return {
        "messages": [],
        "question": "Show me the 2024 results for Florida",
        "analysis": None,
        "next_action": ""
    }


@patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model")
def test_query_analyzer_data_analysis(mock_get_llm, mock_state):
    """Test that a clear data analysis query is routed to the planner."""
    # Mock the LLM and the structured output runnable
    llm = MagicMock()
    mock_get_llm.return_value = llm
    structured = MagicMock()
    llm.with_structured_output.return_value = structured
    # When structured.invoke is called with messages, return this Pydantic object
    structured.invoke.return_value = QueryAnalysis(
        data_required=["2024 results", "Florida"],
        unknowns=[],
        ambiguities=[],
        conditions=[],
        next_action="plan"
    )

    update = query_analyzer_node(mock_state)

    assert update["next_action"] == "plan"
    assert "2024 results" in update["analysis"]["data_required"]


@patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model")
def test_query_analyzer_ambiguous(mock_get_llm, mock_state):
    """Test that an ambiguous query is routed to clarification."""
    mock_state["question"] = "What happened?"

    llm = MagicMock()
    mock_get_llm.return_value = llm
    structured = MagicMock()
    llm.with_structured_output.return_value = structured
    structured.invoke.return_value = QueryAnalysis(
        data_required=[],
        unknowns=["What event?"],
        ambiguities=[],
        conditions=[],
        next_action="clarify"
    )

    update = query_analyzer_node(mock_state)

    assert update["next_action"] == "clarify"
    assert len(update["analysis"]["unknowns"]) > 0


@patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model")
def test_query_analyzer_uses_config(mock_get_llm, mock_state, monkeypatch):
    """Test that the node uses the configured LLM settings."""
    monkeypatch.setenv("QUERY_ANALYZER_LLM__MODEL", "gpt-3.5-turbo")

    llm = MagicMock()
    mock_get_llm.return_value = llm
    structured = MagicMock()
    llm.with_structured_output.return_value = structured
    structured.invoke.return_value = QueryAnalysis(
        data_required=[], unknowns=[], ambiguities=[], conditions=[], next_action="plan"
    )

    query_analyzer_node(mock_state)

    # Verify get_llm_model was called with the overridden config
    called_config = mock_get_llm.call_args[0][0]
    assert called_config.model == "gpt-3.5-turbo"
45
backend/tests/test_query_analyzer_logging.py
Normal file
45
backend/tests/test_query_analyzer_logging.py
Normal file
@@ -0,0 +1,45 @@
|
||||
import pytest
|
||||
import logging
|
||||
from unittest.mock import MagicMock, patch
|
||||
from ea_chatbot.graph.nodes.query_analyzer import query_analyzer_node, QueryAnalysis
|
||||
|
||||
@pytest.fixture
def mock_state():
    """Fresh agent state holding only the incoming question."""
    return {
        "messages": [],
        "question": "Show me the 2024 results for Florida",
        "analysis": None,
        "next_action": ""
    }


@patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model")
@patch("ea_chatbot.graph.nodes.query_analyzer.get_logger")
def test_query_analyzer_logs_actions(mock_get_logger, mock_get_llm, mock_state):
    """Test that query_analyzer_node logs its main actions."""
    # Mock Logger
    logger = MagicMock()
    mock_get_logger.return_value = logger

    # Mock LLM
    llm = MagicMock()
    mock_get_llm.return_value = llm
    structured = MagicMock()
    llm.with_structured_output.return_value = structured

    structured.invoke.return_value = QueryAnalysis(
        data_required=["2024 results", "Florida"],
        unknowns=[],
        ambiguities=[],
        conditions=[],
        next_action="plan"
    )

    query_analyzer_node(mock_state)

    # Check that logger was called
    # We expect at least one log at the start and one at the end
    assert logger.info.called

    # Verify specific log messages if we decide on them
    # For now, just ensuring it's called is enough for Red phase
103
backend/tests/test_query_analyzer_refinement.py
Normal file
103
backend/tests/test_query_analyzer_refinement.py
Normal file
@@ -0,0 +1,103 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from langchain_core.messages import HumanMessage, AIMessage
|
||||
from ea_chatbot.graph.nodes.query_analyzer import query_analyzer_node, QueryAnalysis
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
|
||||
@pytest.fixture
def base_state():
    """Blank agent state used as a template by the refinement tests."""
    return {
        "messages": [],
        "question": "",
        "analysis": None,
        "next_action": "",
        "summary": ""
    }


@patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model")
def test_refinement_coreference_from_history(mock_get_llm, base_state):
    """
    Test that the analyzer can resolve Year/State from history.
    User asks "What about in NJ?" after a Florida 2024 query.
    Expected: next_action = 'plan', NOT 'clarify' due to missing year.
    """
    state = base_state.copy()
    state["messages"] = [
        HumanMessage(content="Show me 2024 results for Florida"),
        AIMessage(content="Here are the 2024 results for Florida...")
    ]
    state["question"] = "What about in New Jersey?"
    state["summary"] = "The user is looking for 2024 election results."

    llm = MagicMock()
    mock_get_llm.return_value = llm
    structured = MagicMock()
    llm.with_structured_output.return_value = structured

    # We expect the LLM to eventually return 'plan' because it sees the context.
    # For now, if it returns 'clarify', this test should fail once we update the prompt to BE less strict.
    structured.invoke.return_value = QueryAnalysis(
        data_required=["2024 results", "New Jersey"],
        unknowns=[],
        ambiguities=[],
        conditions=["state=NJ", "year=2024"],
        next_action="plan"
    )

    result = query_analyzer_node(state)
    assert result["next_action"] == "plan"
    assert "NJ" in str(result["analysis"]["conditions"])


@patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model")
def test_refinement_tolerance_for_missing_format(mock_get_llm, base_state):
    """
    Test that the analyzer doesn't flag missing output format or database name.
    User asks "Give me a graph of turnout".
    Expected: next_action = 'plan', even if 'format' or 'db' is not in query.
    """
    state = base_state.copy()
    state["question"] = "Give me a graph of voter turnout in 2024 for Florida"

    llm = MagicMock()
    mock_get_llm.return_value = llm
    structured = MagicMock()
    llm.with_structured_output.return_value = structured

    structured.invoke.return_value = QueryAnalysis(
        data_required=["voter turnout", "Florida"],
        unknowns=[],
        ambiguities=[],
        conditions=["year=2024"],
        next_action="plan"
    )

    result = query_analyzer_node(state)
    assert result["next_action"] == "plan"
    # Ensure no ambiguities were added by the analyzer itself (hallucinated requirement)
    assert len(result["analysis"]["ambiguities"]) == 0


@patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model")
def test_refinement_enforces_voter_identity_clarification(mock_get_llm, base_state):
    """
    Test that 'track the same voter' still triggers clarification.
    """
    state = base_state.copy()
    state["question"] = "Track the same voter participation in 2020 and 2024."

    llm = MagicMock()
    mock_get_llm.return_value = llm
    structured = MagicMock()
    llm.with_structured_output.return_value = structured

    # We WANT it to clarify here because voter identity is not defined.
    structured.invoke.return_value = QueryAnalysis(
        data_required=["voter participation"],
        unknowns=[],
        ambiguities=["Please define what fields constitute 'the same voter' (e.g. ID, or Name and DOB)."],
        conditions=[],
        next_action="clarify"
    )

    result = query_analyzer_node(state)
    assert result["next_action"] == "clarify"
    flagged = str(result["analysis"]["ambiguities"]).lower()
    assert "identity" in flagged or "same voter" in flagged
34
backend/tests/test_researcher.py
Normal file
34
backend/tests/test_researcher.py
Normal file
@@ -0,0 +1,34 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_openai import ChatOpenAI
|
||||
from ea_chatbot.graph.nodes.researcher import researcher_node
|
||||
|
||||
@pytest.fixture
def mock_llm():
    """Yield a ChatOpenAI-spec'd MagicMock wired in as the researcher's LLM."""
    with patch("ea_chatbot.graph.nodes.researcher.get_llm_model") as factory:
        llm = MagicMock(spec=ChatOpenAI)
        factory.return_value = llm
        yield llm
|
||||
|
||||
def test_researcher_node_success(mock_llm):
    """Test that researcher_node invokes LLM with web_search tool and returns messages."""
    state = {
        "question": "What is the capital of France?",
        "messages": [],
    }

    tool_llm = MagicMock()
    mock_llm.bind_tools.return_value = tool_llm
    tool_llm.invoke.return_value = AIMessage(content="The capital of France is Paris.")

    result = researcher_node(state)

    # The web_search tool must be among those bound to the LLM.
    assert mock_llm.bind_tools.called
    bound_tools = mock_llm.bind_tools.call_args[0][0]
    assert {"type": "web_search"} in bound_tools

    # The tool-bound LLM answers, and its reply lands in the state messages.
    assert tool_llm.invoke.called
    assert "messages" in result
    assert result["messages"][0].content == "The capital of France is Paris."
|
||||
62
backend/tests/test_researcher_search_tools.py
Normal file
62
backend/tests/test_researcher_search_tools.py
Normal file
@@ -0,0 +1,62 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
from ea_chatbot.graph.nodes.researcher import researcher_node
|
||||
|
||||
@pytest.fixture
def base_state():
    """Minimal agent state shared by the researcher search-tool tests."""
    return dict(
        question="Who won the 2024 election?",
        messages=[],
        summary="",
    )
|
||||
|
||||
@patch("ea_chatbot.graph.nodes.researcher.get_llm_model")
|
||||
def test_researcher_binds_openai_search(mock_get_llm, base_state):
|
||||
"""Test that OpenAI LLM binds 'web_search' tool."""
|
||||
mock_llm = MagicMock(spec=ChatOpenAI)
|
||||
mock_get_llm.return_value = mock_llm
|
||||
|
||||
mock_llm_with_tools = MagicMock()
|
||||
mock_llm.bind_tools.return_value = mock_llm_with_tools
|
||||
mock_llm_with_tools.invoke.return_value = AIMessage(content="OpenAI Search Result")
|
||||
|
||||
result = researcher_node(base_state)
|
||||
|
||||
# Verify bind_tools called with correct OpenAI tool
|
||||
mock_llm.bind_tools.assert_called_once_with([{"type": "web_search"}])
|
||||
assert result["messages"][0].content == "OpenAI Search Result"
|
||||
|
||||
@patch("ea_chatbot.graph.nodes.researcher.get_llm_model")
def test_researcher_binds_google_search(mock_get_llm, base_state):
    """Test that Google LLM binds 'google_search' tool."""
    llm = MagicMock(spec=ChatGoogleGenerativeAI)
    mock_get_llm.return_value = llm

    tool_llm = MagicMock()
    llm.bind_tools.return_value = tool_llm
    tool_llm.invoke.return_value = AIMessage(content="Google Search Result")

    result = researcher_node(base_state)

    # A Gemini model must receive the Google-native google_search tool.
    llm.bind_tools.assert_called_once_with([{"google_search": {}}])
    assert result["messages"][0].content == "Google Search Result"
|
||||
|
||||
@patch("ea_chatbot.graph.nodes.researcher.get_llm_model")
def test_researcher_fallback_on_bind_error(mock_get_llm, base_state):
    """Test that researcher falls back to basic LLM if bind_tools fails."""
    llm = MagicMock(spec=ChatOpenAI)
    mock_get_llm.return_value = llm

    # Binding tools blows up, e.g. the model doesn't support tool use.
    llm.bind_tools.side_effect = Exception("Not supported")
    llm.invoke.return_value = AIMessage(content="Basic Result")

    result = researcher_node(base_state)

    # The node should degrade gracefully to the unbound LLM.
    assert result["messages"][0].content == "Basic Result"
    llm.invoke.assert_called_once()
|
||||
41
backend/tests/test_state.py
Normal file
41
backend/tests/test_state.py
Normal file
@@ -0,0 +1,41 @@
|
||||
import pytest
|
||||
from typing import get_type_hints, List
|
||||
from langchain_core.messages import BaseMessage, HumanMessage
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
import operator
|
||||
|
||||
def test_agent_state_structure():
    """Verify that AgentState exposes the required fields and types."""
    hints = get_type_hints(AgentState)

    # Presence checks. The messages annotation is expected to be
    # Annotated[List[BaseMessage], operator.add], which is awkward to assert
    # strictly on a TypedDict, so key existence suffices here; `analysis` and
    # `summary` are likewise only checked for presence (their exact Optional
    # shape is an implementation detail).
    for field in ("messages", "question", "analysis", "next_action", "summary", "plots", "dfs"):
        assert field in hints

    # Plain string fields can be pinned exactly.
    assert hints["question"] == str
    assert hints["next_action"] == str
|
||||
|
||||
def test_messages_reducer_behavior():
    """Verify that the `messages` field carries a reducer in its annotation.

    The graph relies on ``messages`` being declared as
    ``Annotated[List[BaseMessage], <reducer>]`` (typically ``operator.add``)
    so that node outputs are appended to the history instead of replacing it.
    The previous version of this test was a bare ``pass`` and asserted
    nothing; here we inspect the annotation metadata directly via
    ``include_extras=True``, which preserves ``Annotated[...]`` wrappers.
    """
    hints = get_type_hints(AgentState, include_extras=True)
    # __metadata__ exists only on Annotated[...] aliases; default to an empty
    # tuple so a plain (un-annotated) hint fails the assertion cleanly.
    metadata = getattr(hints["messages"], "__metadata__", ())

    # At least one reducer must be attached to the annotation.
    assert len(metadata) >= 1
|
||||
48
backend/tests/test_summarizer.py
Normal file
48
backend/tests/test_summarizer.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from langchain_core.messages import AIMessage
|
||||
from ea_chatbot.graph.nodes.summarizer import summarizer_node
|
||||
from ea_chatbot.graph.state import AgentState
|
||||
|
||||
@pytest.fixture
def mock_llm():
    """Yield a MagicMock standing in for the summarizer's LLM."""
    with patch("ea_chatbot.graph.nodes.summarizer.get_llm_model") as factory:
        llm = MagicMock()
        factory.return_value = llm
        yield llm
|
||||
|
||||
def test_summarizer_node_success(mock_llm):
    """Test that summarizer_node invokes LLM with correct inputs and returns messages."""
    state = {
        "question": "What is the total count?",
        "plan": "1. Run query\n2. Sum results",
        "code_output": "The total is 100",
        "messages": [],
    }
    mock_llm.invoke.return_value = AIMessage(content="The final answer is 100.")

    result = summarizer_node(state)

    # The LLM must have been consulted.
    assert mock_llm.invoke.called

    # Exactly one AIMessage comes back, carrying the final answer.
    assert "messages" in result
    produced = result["messages"]
    assert len(produced) == 1
    reply = produced[0]
    assert isinstance(reply, AIMessage)
    assert reply.content == "The final answer is 100."
|
||||
|
||||
def test_summarizer_node_empty_state(mock_llm):
    """Test handling of empty or minimal state."""
    minimal_state = {"question": "Empty?", "messages": []}
    mock_llm.invoke.return_value = AIMessage(content="No data provided.")

    result = summarizer_node(minimal_state)

    # Even with no plan/code_output the node must still produce a reply.
    assert "messages" in result
    assert result["messages"][0].content == "No data provided."
|
||||
93
backend/tests/test_workflow.py
Normal file
93
backend/tests/test_workflow.py
Normal file
@@ -0,0 +1,93 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from ea_chatbot.graph.workflow import app
|
||||
from ea_chatbot.graph.nodes.query_analyzer import QueryAnalysis
|
||||
from ea_chatbot.schemas import TaskPlanResponse, TaskPlanContext, CodeGenerationResponse
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
@patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model")
@patch("ea_chatbot.graph.nodes.planner.get_llm_model")
@patch("ea_chatbot.graph.nodes.coder.get_llm_model")
@patch("ea_chatbot.graph.nodes.summarizer.get_llm_model")
@patch("ea_chatbot.graph.nodes.researcher.get_llm_model")
@patch("ea_chatbot.utils.database_inspection.get_data_summary")
@patch("ea_chatbot.graph.nodes.executor.Settings")
@patch("ea_chatbot.graph.nodes.executor.DBClient")
def test_workflow_full_flow(mock_db_client, mock_settings, mock_get_summary, mock_researcher_llm, mock_summarizer_llm, mock_coder_llm, mock_planner_llm, mock_qa_llm):
    """Test the flow from query_analyzer through planner to coder."""
    # NOTE: @patch decorators are applied bottom-up, so the parameters above
    # are in REVERSE order of the decorator list: DBClient -> mock_db_client
    # (first), ..., query_analyzer -> mock_qa_llm (last). Keep both lists in
    # sync when adding patches.

    # Mock Settings for Executor
    # Fake DB connection settings so the executor node can build its client
    # without touching a real database.
    mock_settings_instance = MagicMock()
    mock_settings_instance.db_host = "localhost"
    mock_settings_instance.db_port = 5432
    mock_settings_instance.db_user = "user"
    mock_settings_instance.db_pswd = "pass"
    mock_settings_instance.db_name = "test_db"
    mock_settings_instance.db_table = "test_table"
    mock_settings.return_value = mock_settings_instance

    # Mock DBClient
    mock_client_instance = MagicMock()
    mock_db_client.return_value = mock_client_instance

    # 1. Mock Query Analyzer
    # Structured output routes the question straight to the planning branch.
    mock_qa_instance = MagicMock()
    mock_qa_llm.return_value = mock_qa_instance
    mock_qa_instance.with_structured_output.return_value.invoke.return_value = QueryAnalysis(
        data_required=["2024 results"],
        unknowns=[],
        ambiguities=[],
        conditions=[],
        next_action="plan"
    )

    # 2. Mock Planner
    # The planner also consumes the data summary helper, so stub it too.
    mock_planner_instance = MagicMock()
    mock_planner_llm.return_value = mock_planner_instance
    mock_get_summary.return_value = "Data summary"
    mock_planner_instance.with_structured_output.return_value.invoke.return_value = TaskPlanResponse(
        goal="Task Goal",
        reflection="Reflection",
        context=TaskPlanContext(initial_context="Ctx", assumptions=[], constraints=[]),
        steps=["Step 1"]
    )

    # 3. Mock Coder
    # Emits trivial code; the asserted substring below must match this.
    mock_coder_instance = MagicMock()
    mock_coder_llm.return_value = mock_coder_instance
    mock_coder_instance.with_structured_output.return_value.invoke.return_value = CodeGenerationResponse(
        code="print('Hello')",
        explanation="Explanation"
    )

    # 4. Mock Summarizer
    mock_summarizer_instance = MagicMock()
    mock_summarizer_llm.return_value = mock_summarizer_instance
    mock_summarizer_instance.invoke.return_value = AIMessage(content="Summary")

    # 5. Mock Researcher (just in case)
    # Not expected on this branch, but patched so a routing bug cannot make
    # the test hit a real LLM.
    mock_researcher_instance = MagicMock()
    mock_researcher_llm.return_value = mock_researcher_instance

    # Initial state
    initial_state = {
        "messages": [],
        "question": "Show me the 2024 results",
        "analysis": None,
        "next_action": "",
        "plan": None,
        "code": None,
        "error": None,
        "plots": [],
        "dfs": {}
    }

    # Run the graph
    # We use recursion_limit to avoid infinite loops in placeholders if any
    result = app.invoke(initial_state, config={"recursion_limit": 10})

    assert result["next_action"] == "plan"
    assert "plan" in result and result["plan"] is not None
    assert "code" in result and "print('Hello')" in result["code"]
    assert "analysis" in result
|
||||
139
backend/tests/test_workflow_e2e.py
Normal file
139
backend/tests/test_workflow_e2e.py
Normal file
@@ -0,0 +1,139 @@
|
||||
import pytest
|
||||
import yaml
|
||||
from unittest.mock import MagicMock, patch
|
||||
from langchain_core.messages import AIMessage
|
||||
from ea_chatbot.graph.workflow import app
|
||||
from ea_chatbot.graph.nodes.query_analyzer import QueryAnalysis
|
||||
from ea_chatbot.schemas import TaskPlanResponse, TaskPlanContext, CodeGenerationResponse
|
||||
|
||||
@pytest.fixture
def mock_llms():
    """Patch every node-level LLM factory and yield the patch mocks keyed by node."""
    with patch("ea_chatbot.graph.nodes.query_analyzer.get_llm_model") as qa, \
         patch("ea_chatbot.graph.nodes.planner.get_llm_model") as planner, \
         patch("ea_chatbot.graph.nodes.coder.get_llm_model") as coder, \
         patch("ea_chatbot.graph.nodes.summarizer.get_llm_model") as summarizer, \
         patch("ea_chatbot.graph.nodes.researcher.get_llm_model") as researcher, \
         patch("ea_chatbot.graph.nodes.summarize_conversation.get_llm_model") as summary, \
         patch("ea_chatbot.utils.database_inspection.get_data_summary") as data_summary:
        data_summary.return_value = "Data summary"

        # The per-turn conversation summarizer just needs a canned reply.
        summary_llm = MagicMock()
        summary.return_value = summary_llm
        summary_llm.invoke.return_value = AIMessage(content="Turn summary")

        yield {
            "qa": qa,
            "planner": planner,
            "coder": coder,
            "summarizer": summarizer,
            "researcher": researcher,
            "summary": summary,
        }
|
||||
|
||||
def test_workflow_data_analysis_flow(mock_llms):
    """Test full flow: QueryAnalyzer -> Planner -> Coder -> Executor -> Summarizer."""
    # Query analyzer routes the question to the planning branch.
    qa_llm = MagicMock()
    mock_llms["qa"].return_value = qa_llm
    qa_llm.with_structured_output.return_value.invoke.return_value = QueryAnalysis(
        data_required=["2024 results"],
        unknowns=[],
        ambiguities=[],
        conditions=[],
        next_action="plan",
    )

    # Planner produces a single-step plan.
    planner_llm = MagicMock()
    mock_llms["planner"].return_value = planner_llm
    planner_llm.with_structured_output.return_value.invoke.return_value = TaskPlanResponse(
        goal="Get results",
        reflection="Reflect",
        context=TaskPlanContext(initial_context="Ctx", assumptions=[], constraints=[]),
        steps=["Step 1"],
    )

    # Coder emits code whose stdout should surface in code_output.
    coder_llm = MagicMock()
    mock_llms["coder"].return_value = coder_llm
    coder_llm.with_structured_output.return_value.invoke.return_value = CodeGenerationResponse(
        code="print('Execution Success')",
        explanation="Explain",
    )

    # Summarizer wraps everything up with the final user-facing answer.
    summarizer_llm = MagicMock()
    mock_llms["summarizer"].return_value = summarizer_llm
    summarizer_llm.invoke.return_value = AIMessage(content="Final Summary: Success")

    initial_state = {
        "messages": [],
        "question": "Show me 2024 results",
        "analysis": None,
        "next_action": "",
        "plan": None,
        "code": None,
        "error": None,
        "plots": [],
        "dfs": {},
    }

    # recursion_limit bounds the run in case a placeholder node loops.
    result = app.invoke(initial_state, config={"recursion_limit": 15})

    assert result["next_action"] == "plan"
    assert "Execution Success" in result["code_output"]
    assert "Final Summary: Success" in result["messages"][-1].content
|
||||
|
||||
def test_workflow_research_flow(mock_llms):
    """Test flow: QueryAnalyzer -> Researcher -> Summarizer."""
    # Query analyzer routes to the research branch.
    qa_llm = MagicMock()
    mock_llms["qa"].return_value = qa_llm
    qa_llm.with_structured_output.return_value.invoke.return_value = QueryAnalysis(
        data_required=[],
        unknowns=[],
        ambiguities=[],
        conditions=[],
        next_action="research",
    )

    # Researcher: the node only binds search tools for real
    # ChatOpenAI/ChatGoogleGenerativeAI instances; a bare MagicMock takes the
    # fallback path, so wire both the unbound and bound LLMs to the same
    # canned answer.
    researcher_llm = MagicMock()
    mock_llms["researcher"].return_value = researcher_llm
    researcher_llm.invoke.return_value = AIMessage(content="Research Results")
    bound_llm = MagicMock()
    researcher_llm.bind_tools.return_value = bound_llm
    bound_llm.invoke.return_value = AIMessage(content="Research Results")

    # Summarizer is not expected on this branch; mocked for completeness.
    summarizer_llm = MagicMock()
    mock_llms["summarizer"].return_value = summarizer_llm
    summarizer_llm.invoke.return_value = AIMessage(content="Final Summary: Research Success")

    initial_state = {
        "messages": [],
        "question": "Who is the governor of Florida?",
        "analysis": None,
        "next_action": "",
        "plan": None,
        "code": None,
        "error": None,
        "plots": [],
        "dfs": {},
    }

    result = app.invoke(initial_state, config={"recursion_limit": 10})

    assert result["next_action"] == "research"
    assert "Research Results" in result["messages"][-1].content
|
||||
72
backend/tests/test_workflow_history.py
Normal file
72
backend/tests/test_workflow_history.py
Normal file
@@ -0,0 +1,72 @@
|
||||
import pytest
|
||||
from ea_chatbot.history.manager import HistoryManager
|
||||
from ea_chatbot.history.models import User, Conversation, Message, Plot
|
||||
from ea_chatbot.config import Settings
|
||||
from sqlalchemy import delete
|
||||
|
||||
@pytest.fixture
def history_manager():
    """Yield a HistoryManager against the configured DB, wiping tables around the test."""
    manager = HistoryManager(Settings().history_db_url)

    def wipe_tables():
        # Delete children before parents so FK constraints are respected.
        # NOTE(review): assumes get_session() commits on context exit — confirm.
        with manager.get_session() as session:
            for model in (Plot, Message, Conversation, User):
                session.execute(delete(model))

    wipe_tables()
    yield manager
    wipe_tables()
|
||||
|
||||
def test_full_history_workflow(history_manager):
    """End-to-end history flow: auth, conversation, messages with plots, listing, summary."""
    email = "e2e@example.com"
    password = "password123"

    # 1. Create and authenticate the user.
    history_manager.create_user(email, password, "E2E User")
    user = history_manager.authenticate_user(email, password)
    assert user is not None
    assert user.display_name == "E2E User"

    # 1.1 Lookup by id round-trips.
    fetched = history_manager.get_user_by_id(user.id)
    assert fetched is not None
    assert fetched.username == email

    # 2. Start a conversation in the "nj" workspace.
    conversation = history_manager.create_conversation(user.id, "nj", "Test Analytics")
    assert conversation.id is not None

    # 3/4. One user message, then an assistant reply carrying a plot.
    history_manager.add_message(conversation.id, "user", "How many voters in NJ?")
    fake_png = b"fake_png_data"
    history_manager.add_message(
        conversation.id,
        "assistant",
        "There are X voters.",
        plots=[fake_png],
    )

    # 5. History comes back in order with the plot attached.
    history = history_manager.get_messages(conversation.id)
    assert len(history) == 2
    assert [m.role for m in history] == ["user", "assistant"]
    attached = history[1].plots
    assert len(attached) == 1
    assert attached[0].image_data == fake_png

    # 6. The conversation shows up in the workspace listing.
    listed = history_manager.get_conversations(user.id, "nj")
    assert len(listed) == 1
    assert listed[0].name == "Test Analytics"

    # 7/8. Summary updates persist across a reload.
    history_manager.update_conversation_summary(conversation.id, "Voter count analysis")
    reloaded = history_manager.get_conversations(user.id, "nj")
    assert reloaded[0].summary == "Voter count analysis"
|
||||
Reference in New Issue
Block a user