"""Tests for ``get_llm_model``, which builds a chat model from an ``LLMConfig``."""

import pytest
from langchain_openai import ChatOpenAI
from langchain_google_genai import ChatGoogleGenerativeAI

from ea_chatbot.config import LLMConfig
from ea_chatbot.utils.llm_factory import get_llm_model


def test_get_openai_model(monkeypatch) -> None:
    """An ``openai`` config yields a ``ChatOpenAI`` carrying the configured values."""
    # A placeholder key is set so the client can be constructed without real
    # credentials.
    monkeypatch.setenv("OPENAI_API_KEY", "dummy")
    config = LLMConfig(
        provider="openai",
        model="gpt-4o",
        temperature=0.5,
        max_tokens=100,
    )

    model = get_llm_model(config)

    assert isinstance(model, ChatOpenAI)
    assert model.model_name == "gpt-4o"
    assert model.temperature == 0.5
    assert model.max_tokens == 100


def test_get_google_model(monkeypatch) -> None:
    """A ``google`` config yields a ``ChatGoogleGenerativeAI`` with the configured values."""
    # A placeholder key is set so the client can be constructed without real
    # credentials.
    monkeypatch.setenv("GOOGLE_API_KEY", "dummy")
    config = LLMConfig(
        provider="google",
        model="gemini-1.5-pro",
        temperature=0.7,
    )

    model = get_llm_model(config)

    assert isinstance(model, ChatGoogleGenerativeAI)
    # NOTE(review): some langchain-google-genai releases appear to normalize a
    # bare model name to "models/<name>"; accept both spellings so the test
    # checks the factory rather than the installed library version.
    assert model.model in ("gemini-1.5-pro", "models/gemini-1.5-pro")
    assert model.temperature == 0.7


def test_unsupported_provider() -> None:
    """An unrecognized provider name makes ``get_llm_model`` raise ``ValueError``."""
    config = LLMConfig(provider="unknown", model="test")

    with pytest.raises(ValueError, match="Unsupported LLM provider: unknown"):
        get_llm_model(config)


def test_provider_specific_params(monkeypatch) -> None:
    """Entries in ``provider_specific`` show up on the constructed model."""
    monkeypatch.setenv("OPENAI_API_KEY", "dummy")
    config = LLMConfig(
        provider="openai",
        model="o1-preview",
        provider_specific={"reasoning_effort": "high"},
    )

    # Note: reasoning_effort support depends on the langchain-openai version,
    # but we check if kwargs are passed.
    model = get_llm_model(config)

    assert isinstance(model, ChatOpenAI)
    # getattr with a default keeps this an assertion failure (not an
    # AttributeError) if the installed ChatOpenAI lacks the attribute.
    assert getattr(model, "reasoning_effort", None) == "high"