# File: ea-chatbot-lg/backend/src/ea_chatbot/api/utils.py
# (repository listing: 165 lines, 5.5 KiB, Python)

import uuid
from datetime import datetime, timedelta, timezone
from typing import Optional, Any, List
from jose import JWTError, jwt
from pydantic import BaseModel
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, SystemMessage
from ea_chatbot.config import Settings
from ea_chatbot.history.models import Message
# Module-level settings instance. The token helpers in this module read
# secret_key, algorithm, access_token_expire_minutes and
# refresh_token_expire_days from it.
settings = Settings()
def map_db_messages_to_langchain(db_messages: List[Message]) -> List[BaseMessage]:
    """
    Translate stored chat messages into their LangChain message equivalents.

    The role of each record is compared case-insensitively:
    ``"assistant"`` yields an ``AIMessage``, ``"system"`` yields a
    ``SystemMessage``, and ``"user"`` as well as any unrecognised role
    yields a ``HumanMessage``.

    Args:
        db_messages: Message records to translate, in the order they should
            appear in the result.

    Returns:
        A new list with one LangChain message per input record, in the same
        order.
    """

    def _as_langchain(record: Message) -> BaseMessage:
        # Build the LangChain message matching a single record's role.
        kind = record.role.lower()
        if kind == "assistant":
            return AIMessage(content=record.content)
        if kind == "system":
            return SystemMessage(content=record.content)
        # "user" and every unknown role are both treated as human input.
        return HumanMessage(content=record.content)

    return [_as_langchain(record) for record in db_messages]
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """
    Create a JWT access token.

    A shallow copy of ``data`` is extended with the claims ``exp``, ``iat``,
    ``iss`` (``"ea-chatbot-api"``), ``type`` (``"access"``) and a random
    ``jti``. These overwrite any same-named keys supplied in ``data``.

    Args:
        data: The payload data to encode. The dict itself is not mutated.
        expires_delta: Optional lifetime of the token. When ``None`` the
            lifetime is ``settings.access_token_expire_minutes`` minutes.
            An explicit zero delta is honoured rather than being treated
            as "not provided".

    Returns:
        str: The encoded JWT token.
    """
    issued_at = datetime.now(timezone.utc)

    # Compare against None explicitly: timedelta(0) is falsy, so a plain
    # truthiness test would silently swap a zero lifetime for the default.
    if expires_delta is None:
        expires_delta = timedelta(minutes=settings.access_token_expire_minutes)

    to_encode = data.copy()
    to_encode.update({
        "exp": issued_at + expires_delta,
        "iat": issued_at,
        "iss": "ea-chatbot-api",
        "type": "access",
        # Unique token id so individual tokens can be told apart.
        "jti": str(uuid.uuid4()),
    })
    return jwt.encode(to_encode, settings.secret_key, algorithm=settings.algorithm)
def create_refresh_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """
    Create a JWT refresh token.

    A shallow copy of ``data`` is extended with the claims ``exp``, ``iat``,
    ``iss`` (``"ea-chatbot-api"``), ``type`` (``"refresh"``) and a random
    ``jti``. These overwrite any same-named keys supplied in ``data``.

    Args:
        data: The payload data to encode. The dict itself is not mutated.
        expires_delta: Optional lifetime of the token. When ``None`` the
            lifetime is ``settings.refresh_token_expire_days`` days.
            An explicit zero delta is honoured rather than being treated
            as "not provided".

    Returns:
        str: The encoded JWT token.
    """
    issued_at = datetime.now(timezone.utc)

    # Compare against None explicitly: timedelta(0) is falsy, so a plain
    # truthiness test would silently swap a zero lifetime for the default.
    if expires_delta is None:
        expires_delta = timedelta(days=settings.refresh_token_expire_days)

    to_encode = data.copy()
    to_encode.update({
        "exp": issued_at + expires_delta,
        "iat": issued_at,
        "iss": "ea-chatbot-api",
        "type": "refresh",
        # Unique token id so individual tokens can be told apart.
        "jti": str(uuid.uuid4()),
    })
    return jwt.encode(to_encode, settings.secret_key, algorithm=settings.algorithm)
def decode_access_token(token: str) -> Optional[dict]:
    """
    Decode a JWT using the module's secret key and configured algorithm.

    Args:
        token: The encoded token string.

    Returns:
        Optional[dict]: The decoded payload, or ``None`` when ``jwt.decode``
        raises ``JWTError`` for the token.
    """
    accepted_algorithms = [settings.algorithm]
    try:
        claims = jwt.decode(token, settings.secret_key, algorithms=accepted_algorithms)
    except JWTError:
        # Callers get None instead of an exception for a rejected token.
        return None
    return claims
def convert_to_json_compatible(obj: Any) -> Any:
    """Recursively convert LangChain objects, Pydantic models, and others to JSON compatible formats.

    Rules, checked in this order:

    * An object whose class is named ``Figure`` or ``DataFrame`` becomes the
      placeholder string ``"<Figure object>"`` / ``"<DataFrame object>"``.
    * ``None``, ``str``, ``int``, ``float`` and ``bool`` are returned as-is.
    * ``datetime`` becomes its ``isoformat()`` string; ``timezone`` becomes
      ``str(obj)`` (it has no ``isoformat``).
    * ``list``, ``tuple``, ``set`` and ``frozenset`` become a list of
      converted items.
    * ``dict`` values are converted recursively; keys that are not JSON
      scalars are replaced by ``str(key)``.
    * ``BaseMessage`` becomes ``{"type", "content", **additional_kwargs}``
      where list-style content is flattened to the concatenated text blocks.
    * ``BaseModel`` instances and objects exposing ``model_dump()`` or
      ``dict()`` are dumped and converted; if the dump raises, ``str(obj)``
      is returned.
    * Objects with a ``content`` attribute become ``str(obj.content)``.
    * Anything else is returned unchanged if ``json.dumps`` accepts it,
      otherwise as ``str(obj)``.

    Args:
        obj: Any value.

    Returns:
        A structure built only from dicts, lists and JSON scalars (plus
        whatever ``json.dumps`` accepted unchanged in the final fallback).
    """
    # Local import: json is not part of this module's top-level imports.
    import json

    json_scalars = (str, int, float, bool)

    # Handle known non-serializable types first to avoid recursion.
    # Matching on the class name avoids importing the defining libraries.
    type_name = type(obj).__name__
    if type_name in ("Figure", "DataFrame"):
        return f"<{type_name} object>"

    if obj is None or isinstance(obj, json_scalars):
        return obj

    if isinstance(obj, datetime):
        return obj.isoformat()
    if isinstance(obj, timezone):
        # timezone objects have no isoformat(); str() yields e.g. "UTC".
        return str(obj)

    if isinstance(obj, (list, tuple, set, frozenset)):
        # Convert element-wise so one non-serializable item does not force
        # the whole container to be stringified.
        return [convert_to_json_compatible(item) for item in obj]

    if isinstance(obj, dict):
        return {
            # JSON object keys must be scalars; stringify anything else.
            (key if key is None or isinstance(key, json_scalars) else str(key)):
                convert_to_json_compatible(value)
            for key, value in obj.items()
        }

    if isinstance(obj, BaseMessage):
        # Handle content that might be a list of blocks (e.g. from Gemini or OpenAI tools)
        content = obj.content
        if isinstance(content, list):
            text_parts = []
            for block in content:
                if isinstance(block, str):
                    text_parts.append(block)
                elif isinstance(block, dict) and block.get("type") == "text":
                    text_parts.append(block.get("text", ""))
                # Other block types are intentionally dropped.
            content = "".join(text_parts)

        # Prefer a non-empty string-valued .text attribute when present.
        text = getattr(obj, "text", None)
        if isinstance(text, str) and text:
            content = text

        # NOTE(review): additional_kwargs is spread last, so a "type" or
        # "content" key in it overrides the values above - confirm intended.
        return {
            "type": obj.type,
            "content": content,
            **convert_to_json_compatible(obj.additional_kwargs),
        }

    if isinstance(obj, BaseModel):
        return convert_to_json_compatible(obj.model_dump())

    # Duck-typed dumpers: model_dump() (Pydantic v2 style) takes precedence
    # over dict() (Pydantic v1 style or other objects).
    for dumper_name in ("model_dump", "dict"):
        if hasattr(obj, dumper_name):
            try:
                return convert_to_json_compatible(getattr(obj, dumper_name)())
            except Exception:
                # Deliberate best-effort: an object whose dump fails is
                # still representable as its string form.
                return str(obj)

    if hasattr(obj, "content"):
        return str(obj.content)

    # Final fallback for any other types that might not be JSON serializable
    try:
        json.dumps(obj)
    except (TypeError, OverflowError):
        return str(obj)
    return obj