# File: ea-chatbot-lg/backend/tests/test_researcher_search_tools.py
#
# 63 lines
# 2.5 KiB
# Python

import pytest
from unittest.mock import MagicMock, patch
from langchain_core.messages import AIMessage
from langchain_openai import ChatOpenAI
from langchain_google_genai import ChatGoogleGenerativeAI
from ea_chatbot.graph.nodes.researcher import researcher_node
@pytest.fixture
def base_state():
    """Return the minimal state dict handed to ``researcher_node`` in these tests.

    The state holds a fixed question, an empty message list and an empty
    summary string.
    """
    state = dict(
        question="Who won the 2024 election?",
        messages=[],
        summary="",
    )
    return state
@patch("ea_chatbot.graph.nodes.researcher.get_llm_model")
def test_researcher_binds_openai_search(mock_get_llm, base_state):
    """Check that an OpenAI model has the 'web_search' tool bound to it."""
    # Stand-in for the tool-bound model; its reply is what the node should return.
    bound_model = MagicMock()
    bound_model.invoke.return_value = AIMessage(content="OpenAI Search Result")

    # spec=ChatOpenAI makes the mock present itself as a ChatOpenAI instance.
    openai_model = MagicMock(spec=ChatOpenAI)
    openai_model.bind_tools.return_value = bound_model
    mock_get_llm.return_value = openai_model

    output = researcher_node(base_state)

    # bind_tools must have been called exactly once with the OpenAI tool spec.
    expected_tools = [{"type": "web_search"}]
    openai_model.bind_tools.assert_called_once_with(expected_tools)

    first_message = output["messages"][0]
    assert first_message.content == "OpenAI Search Result"
@patch("ea_chatbot.graph.nodes.researcher.get_llm_model")
def test_researcher_binds_google_search(mock_get_llm, base_state):
    """Check that a Google model has the 'google_search' tool bound to it."""
    # Stand-in for the tool-bound model; its reply is what the node should return.
    bound_model = MagicMock()
    bound_model.invoke.return_value = AIMessage(content="Google Search Result")

    # spec=ChatGoogleGenerativeAI makes the mock present itself as that class.
    google_model = MagicMock(spec=ChatGoogleGenerativeAI)
    google_model.bind_tools.return_value = bound_model
    mock_get_llm.return_value = google_model

    output = researcher_node(base_state)

    # bind_tools must have been called exactly once with the Google tool spec.
    expected_tools = [{"google_search": {}}]
    google_model.bind_tools.assert_called_once_with(expected_tools)

    first_message = output["messages"][0]
    assert first_message.content == "Google Search Result"
@patch("ea_chatbot.graph.nodes.researcher.get_llm_model")
def test_researcher_fallback_on_bind_error(mock_get_llm, base_state):
    """Check that the researcher uses the plain LLM when bind_tools raises."""
    plain_model = MagicMock(spec=ChatOpenAI)
    plain_model.invoke.return_value = AIMessage(content="Basic Result")
    # Make tool binding blow up, e.g. as if the model did not support tools.
    plain_model.bind_tools.side_effect = Exception("Not supported")
    mock_get_llm.return_value = plain_model

    output = researcher_node(base_state)

    # The node should still produce a reply, obtained from the unbound model.
    first_message = output["messages"][0]
    assert first_message.content == "Basic Result"
    plain_model.invoke.assert_called_once()