"""Tests for ``researcher_node``'s provider-specific web-search tool binding."""

import pytest
from unittest.mock import MagicMock, patch

from langchain_core.messages import AIMessage
from langchain_openai import ChatOpenAI
from langchain_google_genai import ChatGoogleGenerativeAI

from ea_chatbot.graph.nodes.researcher import researcher_node


@pytest.fixture
def base_state():
    """Return a minimal graph state accepted by ``researcher_node``."""
    return {
        "question": "Who won the 2024 election?",
        "messages": [],
        "summary": "",
    }


@patch("ea_chatbot.graph.nodes.researcher.get_llm_model")
def test_researcher_binds_openai_search(mock_get_llm, base_state):
    """Test that OpenAI LLM binds 'web_search' tool."""
    # spec=ChatOpenAI makes isinstance() checks inside the node see an OpenAI model.
    mock_llm = MagicMock(spec=ChatOpenAI)
    mock_get_llm.return_value = mock_llm

    # bind_tools returns a new runnable; the node is expected to invoke that one.
    mock_llm_with_tools = MagicMock()
    mock_llm.bind_tools.return_value = mock_llm_with_tools
    mock_llm_with_tools.invoke.return_value = AIMessage(content="OpenAI Search Result")

    result = researcher_node(base_state)

    # Verify bind_tools called with correct OpenAI tool
    mock_llm.bind_tools.assert_called_once_with([{"type": "web_search"}])
    mock_llm_with_tools.invoke.assert_called_once()
    assert result["messages"][0].content == "OpenAI Search Result"


@patch("ea_chatbot.graph.nodes.researcher.get_llm_model")
def test_researcher_binds_google_search(mock_get_llm, base_state):
    """Test that Google LLM binds 'google_search' tool."""
    # spec=ChatGoogleGenerativeAI makes isinstance() checks see a Gemini model.
    mock_llm = MagicMock(spec=ChatGoogleGenerativeAI)
    mock_get_llm.return_value = mock_llm

    # bind_tools returns a new runnable; the node is expected to invoke that one.
    mock_llm_with_tools = MagicMock()
    mock_llm.bind_tools.return_value = mock_llm_with_tools
    mock_llm_with_tools.invoke.return_value = AIMessage(content="Google Search Result")

    result = researcher_node(base_state)

    # Verify bind_tools called with correct Google tool
    mock_llm.bind_tools.assert_called_once_with([{"google_search": {}}])
    mock_llm_with_tools.invoke.assert_called_once()
    assert result["messages"][0].content == "Google Search Result"


@patch("ea_chatbot.graph.nodes.researcher.get_llm_model")
def test_researcher_fallback_on_bind_error(mock_get_llm, base_state):
    """Test that researcher falls back to basic LLM if bind_tools fails."""
    mock_llm = MagicMock(spec=ChatOpenAI)
    mock_get_llm.return_value = mock_llm

    # Simulate bind_tools failing (e.g. model doesn't support it)
    mock_llm.bind_tools.side_effect = Exception("Not supported")
    mock_llm.invoke.return_value = AIMessage(content="Basic Result")

    result = researcher_node(base_state)

    # The node must have attempted tool binding before falling back;
    # otherwise this test would pass even if binding were skipped entirely.
    mock_llm.bind_tools.assert_called_once()
    # Should still succeed using the base LLM
    assert result["messages"][0].content == "Basic Result"
    mock_llm.invoke.assert_called_once()