# Tests for ea_chatbot.graph.nodes.researcher (63 lines, 2.5 KiB, Python)
import pytest
|
|
from unittest.mock import MagicMock, patch
|
|
from langchain_core.messages import AIMessage
|
|
from langchain_openai import ChatOpenAI
|
|
from langchain_google_genai import ChatGoogleGenerativeAI
|
|
from ea_chatbot.graph.nodes.researcher import researcher_node
|
|
|
|
@pytest.fixture
def base_state():
    """Minimal researcher-node input state shared by the tests in this module."""
    state = {"question": "Who won the 2024 election?"}
    state["messages"] = []
    state["summary"] = ""
    return state
|
|
|
|
@patch("ea_chatbot.graph.nodes.researcher.get_llm_model")
def test_researcher_binds_openai_search(mock_get_llm, base_state):
    """An OpenAI-backed LLM must be bound with the 'web_search' tool."""
    # Arrange: a ChatOpenAI stub whose bound variant yields a fixed message.
    llm_stub = MagicMock(spec=ChatOpenAI)
    bound_stub = MagicMock()
    bound_stub.invoke.return_value = AIMessage(content="OpenAI Search Result")
    llm_stub.bind_tools.return_value = bound_stub
    mock_get_llm.return_value = llm_stub

    # Act
    result = researcher_node(base_state)

    # Assert: the OpenAI-flavoured tool spec was bound and its answer surfaced.
    llm_stub.bind_tools.assert_called_once_with([{"type": "web_search"}])
    assert result["messages"][0].content == "OpenAI Search Result"
|
|
|
|
@patch("ea_chatbot.graph.nodes.researcher.get_llm_model")
def test_researcher_binds_google_search(mock_get_llm, base_state):
    """A Google-backed LLM must be bound with the 'google_search' tool."""
    # Arrange: a ChatGoogleGenerativeAI stub whose bound variant yields a fixed message.
    llm_stub = MagicMock(spec=ChatGoogleGenerativeAI)
    bound_stub = MagicMock()
    bound_stub.invoke.return_value = AIMessage(content="Google Search Result")
    llm_stub.bind_tools.return_value = bound_stub
    mock_get_llm.return_value = llm_stub

    # Act
    result = researcher_node(base_state)

    # Assert: the Google-flavoured tool spec was bound and its answer surfaced.
    llm_stub.bind_tools.assert_called_once_with([{"google_search": {}}])
    assert result["messages"][0].content == "Google Search Result"
|
|
|
|
@patch("ea_chatbot.graph.nodes.researcher.get_llm_model")
def test_researcher_fallback_on_bind_error(mock_get_llm, base_state):
    """When bind_tools raises, the node must fall back to the plain LLM."""
    # Arrange: binding tools blows up, but a direct invoke still answers.
    llm_stub = MagicMock(spec=ChatOpenAI)
    llm_stub.bind_tools.side_effect = Exception("Not supported")
    llm_stub.invoke.return_value = AIMessage(content="Basic Result")
    mock_get_llm.return_value = llm_stub

    # Act
    result = researcher_node(base_state)

    # Assert: the base LLM was used exactly once and its answer surfaced.
    assert result["messages"][0].content == "Basic Result"
    llm_stub.invoke.assert_called_once()
|