ea-chatbot-lg/src/ea_chatbot/utils/llm_factory.py

from typing import List, Optional

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import ChatOpenAI

from ea_chatbot.config import LLMConfig


def get_llm_model(config: LLMConfig, callbacks: Optional[List[BaseCallbackHandler]] = None) -> BaseChatModel:
    """
    Factory function to get a LangChain chat model based on configuration.

    Args:
        config: LLMConfig object containing model settings.
        callbacks: Optional list of LangChain callback handlers.

    Returns:
        Initialized BaseChatModel instance.

    Raises:
        ValueError: If the provider is not supported.
    """
    params = {
        "temperature": config.temperature,
        "max_tokens": config.max_tokens,
        **config.provider_specific,
    }
    # Filter out None values so provider defaults take over when a setting is not specified
    params = {k: v for k, v in params.items() if v is not None}

    provider = config.provider.lower()
    if provider == "openai":
        return ChatOpenAI(model=config.model, callbacks=callbacks, **params)
    elif provider in ("google", "google_genai"):
        return ChatGoogleGenerativeAI(model=config.model, callbacks=callbacks, **params)
    else:
        raise ValueError(f"Unsupported LLM provider: {config.provider}")
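
A minimal usage sketch of the factory. The LLMConfig constructor arguments below are assumptions inferred from the attributes read in get_llm_model (provider, model, temperature, max_tokens, provider_specific); the real ea_chatbot.config.LLMConfig may be constructed differently.

# Usage sketch; LLMConfig field names are assumed, not confirmed by ea_chatbot.config.
from ea_chatbot.config import LLMConfig
from ea_chatbot.utils.llm_factory import get_llm_model

config = LLMConfig(
    provider="openai",
    model="gpt-4o-mini",
    temperature=0.2,
    max_tokens=1024,
    provider_specific={},
)

# Build the chat model and run a single prompt through it.
llm = get_llm_model(config)
response = llm.invoke("Summarize the purpose of a factory function in one sentence.")
print(response.content)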