165 lines
5.5 KiB
Python
165 lines
5.5 KiB
Python
import uuid
|
|
from datetime import datetime, timedelta, timezone
|
|
from typing import Optional, Any, List
|
|
from jose import JWTError, jwt
|
|
from pydantic import BaseModel
|
|
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, SystemMessage
|
|
from ea_chatbot.config import Settings
|
|
from ea_chatbot.history.models import Message
|
|
|
|
# Application settings, instantiated once when this module is imported. The JWT
# helpers in this module read secret_key, algorithm, access_token_expire_minutes
# and refresh_token_expire_days from it.
settings: Settings = Settings()
|
|
|
|
def map_db_messages_to_langchain(db_messages: List[Message]) -> List[BaseMessage] :
    """Translate persisted chat ``Message`` rows into LangChain message objects.

    Each row's ``role`` is compared case-insensitively: ``"user"`` becomes a
    ``HumanMessage``, ``"assistant"`` an ``AIMessage`` and ``"system"`` a
    ``SystemMessage``. Any other role is treated as a ``HumanMessage``.
    Input order is preserved.

    Args:
        db_messages: Message rows loaded from the database.

    Returns:
        One LangChain message per input row, carrying that row's ``content``.
    """
    # Roles missing from this table fall back to HumanMessage.
    message_cls_by_role = {
        "user": HumanMessage,
        "assistant": AIMessage,
        "system": SystemMessage,
    }
    return [
        message_cls_by_role.get(row.role.lower(), HumanMessage)(content=row.content)
        for row in db_messages
    ]
|
|
|
|
def _create_token(data: dict, token_type: str, lifetime: timedelta) -> str:
    """Sign ``data`` as a JWT carrying this module's standard claims.

    Adds ``exp`` (now + ``lifetime``), ``iat``, ``iss``, ``type`` and a random
    ``jti``. These overwrite any same-named keys in ``data``. The caller's dict
    is not modified.

    Args:
        data: The payload data to encode.
        token_type: Value stored in the ``type`` claim.
        lifetime: Time until the token expires.

    Returns:
        str: The encoded JWT token.
    """
    to_encode = data.copy()  # shallow copy so the caller's dict is untouched
    now = datetime.now(timezone.utc)
    to_encode.update({
        "exp": now + lifetime,
        "iat": now,
        "iss": "ea-chatbot-api",
        "type": token_type,
        "jti": str(uuid.uuid4()),
    })
    return jwt.encode(to_encode, settings.secret_key, algorithm=settings.algorithm)


def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """
    Create a JWT access token.

    Args:
        data: The payload data to encode.
        expires_delta: Optional expiration time delta. When ``None``, the lifetime
            is ``settings.access_token_expire_minutes`` minutes. Any explicit
            value is honoured as given, including ``timedelta(0)``.

    Returns:
        str: The encoded JWT token, with ``"type": "access"``.
    """
    # ``is None`` rather than truthiness: timedelta(0) is falsy, and an explicit
    # zero lifetime must not silently become the default lifetime.
    if expires_delta is None:
        expires_delta = timedelta(minutes=settings.access_token_expire_minutes)
    return _create_token(data, "access", expires_delta)


def create_refresh_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """
    Create a JWT refresh token.

    Args:
        data: The payload data to encode.
        expires_delta: Optional expiration time delta. When ``None``, the lifetime
            is ``settings.refresh_token_expire_days`` days. Any explicit value is
            honoured as given, including ``timedelta(0)``.

    Returns:
        str: The encoded JWT token, with ``"type": "refresh"``.
    """
    # See create_access_token: only a missing value selects the default.
    if expires_delta is None:
        expires_delta = timedelta(days=settings.refresh_token_expire_days)
    return _create_token(data, "refresh", expires_delta)
|
|
|
|
def decode_access_token(token: str) -> Optional[dict]:
    """Verify and decode a JWT signed with this module's settings.

    The signature is checked against ``settings.secret_key`` and only
    ``settings.algorithm`` is accepted.

    Args:
        token: The encoded JWT string.

    Returns:
        The decoded claims dict, or ``None`` if ``jose`` rejects the token with
        a ``JWTError`` (e.g. a bad signature, or an expired ``exp`` under jose's
        default options).
    """
    # NOTE(review): neither the ``type`` nor the ``iss`` claim is checked here, so a
    # validly signed token from create_refresh_token ("type": "refresh") decodes
    # successfully too. Confirm that callers check payload["type"] themselves.
    try:
        claims = jwt.decode(token, settings.secret_key, algorithms=[settings.algorithm])
    except JWTError:
        return None
    return claims
|
|
|
|
# Exact scalar types that ``json.dumps`` accepts natively. Matched by exact type
# (not isinstance) so subclasses still go through the full dispatch below.
_JSON_SCALAR_TYPES = (str, int, float, bool, type(None))


def _message_content_to_text(content: Any) -> Any:
    """Flatten LangChain block-style message content into one string.

    Non-list content is returned unchanged. A list of blocks (e.g. from Gemini or
    OpenAI tool-enabled models) is reduced to the concatenation of its bare-string
    items and the ``text`` of its ``{"type": "text"}`` dict items. Other block kinds
    are dropped.
    """
    if not isinstance(content, list):
        return content
    parts = []
    for block in content:
        if isinstance(block, str):
            parts.append(block)
        elif isinstance(block, dict) and block.get("type") == "text":
            parts.append(block.get("text") or "")
    return "".join(parts)


def convert_to_json_compatible(obj: Any) -> Any:
    """Recursively convert LangChain objects, Pydantic models, and others to JSON compatible formats.

    Rules, applied in order:

    * Exact ``str``/``int``/``float``/``bool``/``None`` values are returned as-is.
    * Objects whose class is named ``Figure`` or ``DataFrame`` become a
      ``"<TypeName object>"`` placeholder.
    * Lists, tuples, sets and frozensets become lists of converted items.
    * Dicts keep their keys and have their values converted.
    * ``datetime`` values become ISO-8601 strings.
    * LangChain ``BaseMessage`` objects become
      ``{"type": ..., "content": ..., **additional_kwargs}``, with block-list
      content flattened to text.
    * Objects with ``model_dump`` (Pydantic v2) or ``dict`` (Pydantic v1 and
      similar) are dumped and converted. If that fails, ``str(obj)`` is used.
    * Other objects with a ``content`` attribute become ``str(obj.content)``.
    * Anything else is returned unchanged if ``json.dumps`` accepts it, otherwise
      ``str(obj)``.

    Args:
        obj: Arbitrary value to convert.

    Returns:
        A structure of JSON-serializable values. Dict keys are not converted.
    """
    if type(obj) in _JSON_SCALAR_TYPES:
        return obj

    # Known heavy types are matched by class name so their libraries need not be
    # imported here, and so we never recurse into them.
    type_name = type(obj).__name__
    if type_name in ("Figure", "DataFrame"):
        return f"<{type_name} object>"

    # Tuples and sets are emitted as lists (their JSON form). This way their items
    # are converted recursively instead of the whole container being stringified.
    if isinstance(obj, (list, tuple, set, frozenset)):
        return [convert_to_json_compatible(item) for item in obj]
    if isinstance(obj, dict):
        return {k: convert_to_json_compatible(v) for k, v in obj.items()}
    # Only datetime is handled here: ``timezone`` has no ``isoformat()``, so it is
    # left to the ``str()`` fallback at the end.
    if isinstance(obj, datetime):
        return obj.isoformat()

    if isinstance(obj, BaseMessage):
        content = _message_content_to_text(obj.content)
        # Prefer .text when it is a non-empty string (common in some message types).
        text = getattr(obj, "text", None)
        if isinstance(text, str) and text:
            content = text
        return {"type": obj.type, "content": content, **convert_to_json_compatible(obj.additional_kwargs)}

    if isinstance(obj, BaseModel):
        return convert_to_json_compatible(obj.model_dump())
    if hasattr(obj, "model_dump"):  # Pydantic-v2-like objects not caught by BaseModel
        try:
            return convert_to_json_compatible(obj.model_dump())
        except Exception:
            return str(obj)
    if hasattr(obj, "dict"):  # Fallback for Pydantic v1 or other objects
        try:
            return convert_to_json_compatible(obj.dict())
        except Exception:
            return str(obj)
    if hasattr(obj, "content"):
        return str(obj.content)

    # Final fallback for any other types that might not be JSON serializable
    import json

    try:
        json.dumps(obj)
    except (TypeError, OverflowError):
        return str(obj)
    return obj