Refactor: Move backend files to backend/ directory and split .gitignore
backend/tests/test_llm_factory.py (new file, 54 lines added)
@@ -0,0 +1,54 @@
import pytest
from langchain_openai import ChatOpenAI
from langchain_google_genai import ChatGoogleGenerativeAI
from ea_chatbot.config import LLMConfig
from ea_chatbot.utils.llm_factory import get_llm_model

def test_get_openai_model(monkeypatch):
    """Test creating an OpenAI model."""
    monkeypatch.setenv("OPENAI_API_KEY", "dummy")
    config = LLMConfig(
        provider="openai",
        model="gpt-4o",
        temperature=0.5,
        max_tokens=100
    )
    model = get_llm_model(config)
    assert isinstance(model, ChatOpenAI)
    assert model.model_name == "gpt-4o"
    assert model.temperature == 0.5
    assert model.max_tokens == 100

def test_get_google_model(monkeypatch):
    """Test creating a Google model."""
    monkeypatch.setenv("GOOGLE_API_KEY", "dummy")
    config = LLMConfig(
        provider="google",
        model="gemini-1.5-pro",
        temperature=0.7
    )
    model = get_llm_model(config)
    assert isinstance(model, ChatGoogleGenerativeAI)
    assert model.model == "gemini-1.5-pro"
    assert model.temperature == 0.7

def test_unsupported_provider():
    """Test that an unsupported provider raises an error."""
    config = LLMConfig(provider="unknown", model="test")
    with pytest.raises(ValueError, match="Unsupported LLM provider: unknown"):
        get_llm_model(config)

def test_provider_specific_params(monkeypatch):
    """Test passing provider-specific params."""
    monkeypatch.setenv("OPENAI_API_KEY", "dummy")
    config = LLMConfig(
        provider="openai",
        model="o1-preview",
        provider_specific={"reasoning_effort": "high"}
    )
    # Note: reasoning_effort support depends on the langchain-openai version,
    # but we check that the kwargs are passed through.
    model = get_llm_model(config)
    assert isinstance(model, ChatOpenAI)
    # Check that reasoning_effort was passed correctly
    assert getattr(model, "reasoning_effort", None) == "high"
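For context, the tests above pin down the contract of get_llm_model. Below is a minimal sketch of a factory that would satisfy that contract; it assumes LLMConfig exposes provider, model, temperature, max_tokens, and provider_specific attributes. All names here are inferred from the tests, not from the actual source, and the real implementation in ea_chatbot/utils/llm_factory.py may differ.

# Hypothetical sketch of a factory satisfying the tests above;
# not the actual ea_chatbot implementation.
from langchain_openai import ChatOpenAI
from langchain_google_genai import ChatGoogleGenerativeAI
from ea_chatbot.config import LLMConfig

def get_llm_model(config: LLMConfig):
    """Instantiate a chat model for the provider named in config."""
    # Forward optional settings and any provider-specific entries
    # (e.g. reasoning_effort) as keyword arguments.
    kwargs = dict(getattr(config, "provider_specific", None) or {})
    if config.temperature is not None:
        kwargs["temperature"] = config.temperature
    if getattr(config, "max_tokens", None) is not None:
        kwargs["max_tokens"] = config.max_tokens

    if config.provider == "openai":
        return ChatOpenAI(model=config.model, **kwargs)
    if config.provider == "google":
        return ChatGoogleGenerativeAI(model=config.model, **kwargs)
    raise ValueError(f"Unsupported LLM provider: {config.provider}")

With an implementation along these lines, the suite can be run from the repository root with: pytest backend/tests/test_llm_factory.py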