37 lines
1.5 KiB
Python
37 lines
1.5 KiB
Python
from typing import Optional, cast, TYPE_CHECKING, Literal, Dict, List, Tuple, Any
|
|
from langchain_core.language_models.chat_models import BaseChatModel
|
|
from langchain_openai import ChatOpenAI
|
|
from langchain_google_genai import ChatGoogleGenerativeAI
|
|
from langchain_core.callbacks import BaseCallbackHandler
|
|
from ea_chatbot.config import LLMConfig
|
|
|
|
def get_llm_model(config: LLMConfig, callbacks: Optional[List[BaseCallbackHandler]] = None) -> BaseChatModel:
    """
    Factory function to get a LangChain chat model based on configuration.

    ``config.temperature`` and ``config.max_tokens`` are merged with
    ``config.provider_specific`` (provider-specific keys win on collision).
    Entries whose value is ``None`` are then dropped so that the model
    class's own defaults apply for anything left unspecified.

    Args:
        config: LLMConfig object containing model settings. The attributes
            read here are ``provider``, ``model``, ``temperature``,
            ``max_tokens`` and ``provider_specific``.
        callbacks: Optional list of LangChain callback handlers, forwarded
            unchanged to the model constructor.

    Returns:
        A ``ChatOpenAI`` instance when the provider is ``"openai"``, or a
        ``ChatGoogleGenerativeAI`` instance when it is ``"google"`` or
        ``"google_genai"``. Provider matching is case-insensitive.

    Raises:
        ValueError: If the provider is missing or not supported.
    """
    # A missing/None provider_specific mapping means "no extra overrides";
    # unpacking None directly would raise TypeError before provider checks.
    provider_overrides = config.provider_specific or {}

    params = {
        "temperature": config.temperature,
        "max_tokens": config.max_tokens,
        **provider_overrides,
    }

    # Filter out None values to allow defaults to take over if not specified
    params = {key: value for key, value in params.items() if value is not None}

    # Normalise once; a None/empty provider falls through to the ValueError.
    provider = (config.provider or "").lower()

    if provider == "openai":
        return ChatOpenAI(model=config.model, callbacks=callbacks, **params)

    if provider in ("google", "google_genai"):
        return ChatGoogleGenerativeAI(model=config.model, callbacks=callbacks, **params)

    raise ValueError(f"Unsupported LLM provider: {config.provider}")
|