20 lines
644 B
Python
20 lines
644 B
Python
import pytest
|
|
from langchain_openai import ChatOpenAI
|
|
from ea_chatbot.config import LLMConfig
|
|
from ea_chatbot.utils.llm_factory import get_llm_model
|
|
from langchain_core.callbacks import BaseCallbackHandler
|
|
|
|
class MockHandler(BaseCallbackHandler):
    """No-op callback handler; tests only check that this exact instance is forwarded."""


def test_get_llm_model_with_callbacks(monkeypatch):
    """Test that callbacks passed to ``get_llm_model`` are attached to the model.

    A dummy ``OPENAI_API_KEY`` is set (scoped to this test by ``monkeypatch``)
    so the OpenAI chat model can be constructed without a real key.
    """
    monkeypatch.setenv("OPENAI_API_KEY", "dummy")
    config = LLMConfig(provider="openai", model="gpt-4o")
    handler = MockHandler()

    model = get_llm_model(config, callbacks=[handler])

    assert isinstance(model, ChatOpenAI)
    # Check for None first: ``handler in None`` would raise a TypeError and
    # obscure the real failure (callbacks were never forwarded to the model).
    assert model.callbacks is not None, "get_llm_model did not attach any callbacks"
    assert handler in model.callbacks