diff --git a/backend/src/ea_chatbot/app.py b/backend/dev_app.py similarity index 100% rename from backend/src/ea_chatbot/app.py rename to backend/dev_app.py diff --git a/backend/src/ea_chatbot/api/routers/agent.py b/backend/src/ea_chatbot/api/routers/agent.py index a39b868..788a87d 100644 --- a/backend/src/ea_chatbot/api/routers/agent.py +++ b/backend/src/ea_chatbot/api/routers/agent.py @@ -3,14 +3,12 @@ from typing import AsyncGenerator, List from fastapi import APIRouter, Depends, HTTPException from fastapi.responses import StreamingResponse from ea_chatbot.api.dependencies import get_current_user, history_manager -from ea_chatbot.api.utils import convert_to_json_compatible +from ea_chatbot.api.utils import convert_to_json_compatible, map_db_messages_to_langchain from ea_chatbot.graph.workflow import app from ea_chatbot.graph.checkpoint import get_checkpointer from ea_chatbot.history.models import User as UserDB, Conversation -from ea_chatbot.history.utils import map_db_messages_to_langchain from ea_chatbot.api.schemas import ChatRequest from ea_chatbot.utils.plots import fig_to_bytes -import io import base64 from langchain_core.runnables.config import RunnableConfig from langchain_core.messages import BaseMessage diff --git a/backend/src/ea_chatbot/api/utils.py b/backend/src/ea_chatbot/api/utils.py index d96801c..46949a2 100644 --- a/backend/src/ea_chatbot/api/utils.py +++ b/backend/src/ea_chatbot/api/utils.py @@ -1,12 +1,39 @@ from datetime import datetime, timedelta, timezone -from typing import Optional, Any +from typing import Optional, Any, List from jose import JWTError, jwt from pydantic import BaseModel -from langchain_core.messages import BaseMessage +from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, SystemMessage from ea_chatbot.config import Settings +from ea_chatbot.history.models import Message settings = Settings() +def map_db_messages_to_langchain(db_messages: List[Message]) -> List[BaseMessage] : + """ + Converts a list of database Message models to LangChain BaseMessage objects. + + Args: + db_messages: List of Message objects from the database. + + Returns: + List of HumanMessage, AIMessage, or SystemMessage objects. + """ + lc_messages: List[BaseMessage] = [] + + for m in db_messages: + role = m.role.lower() + if role == "user": + lc_messages.append(HumanMessage(content=m.content)) + elif role == "assistant": + lc_messages.append(AIMessage(content=m.content)) + elif role == "system": + lc_messages.append(SystemMessage(content=m.content)) + else: + # Default to HumanMessage for unknown roles + lc_messages.append(HumanMessage(content=m.content)) + + return lc_messages + def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str: """ Create a JWT access token. diff --git a/backend/src/ea_chatbot/history/utils.py b/backend/src/ea_chatbot/history/utils.py deleted file mode 100644 index 3d3400b..0000000 --- a/backend/src/ea_chatbot/history/utils.py +++ /dev/null @@ -1,30 +0,0 @@ -"""Utility functions for history management.""" -from typing import List -from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, SystemMessage -from ea_chatbot.history.models import Message - -def map_db_messages_to_langchain(db_messages: List[Message]) -> List[BaseMessage] : - """ - Converts a list of database Message models to LangChain BaseMessage objects. - - Args: - db_messages: List of Message objects from the database. - - Returns: - List of HumanMessage, AIMessage, or SystemMessage objects. - """ - lc_messages: List[BaseMessage] = [] - - for m in db_messages: - role = m.role.lower() - if role == "user": - lc_messages.append(HumanMessage(content=m.content)) - elif role == "assistant": - lc_messages.append(AIMessage(content=m.content)) - elif role == "system": - lc_messages.append(SystemMessage(content=m.content)) - else: - # Default to HumanMessage for unknown roles - lc_messages.append(HumanMessage(content=m.content)) - - return lc_messages diff --git a/backend/src/ea_chatbot/utils/plots.py b/backend/src/ea_chatbot/utils/plots.py index 5677a56..584f238 100644 --- a/backend/src/ea_chatbot/utils/plots.py +++ b/backend/src/ea_chatbot/utils/plots.py @@ -3,15 +3,25 @@ import matplotlib.pyplot as plt def fig_to_bytes(fig: plt.Figure, format: str = "png") -> bytes: """ - Convert a Matplotlib figure to bytes. + Convert a Matplotlib figure to a bytes object in the specified format. + This utility encapsulates the boilerplate logic for saving a figure to an + in-memory buffer and retrieving its binary content. + Args: - fig: The Matplotlib figure to convert. - format: The image format to use (default: "png"). + fig: The Matplotlib Figure object to be converted. + format: The image format string (e.g., "png", "jpg", "svg"). + Defaults to "png". Returns: - bytes: The image data. + The binary image data as a bytes object. + + Raises: + ValueError: If an unsupported image format is provided. """ buf = io.BytesIO() - fig.savefig(buf, format=format) + try: + fig.savefig(buf, format=format) + except Exception as e: + raise ValueError(f"Failed to convert figure to {format}: {str(e)}") from e return buf.getvalue() diff --git a/backend/tests/test_history_mapping.py b/backend/tests/test_history_mapping.py index 35bf45e..9c2c2de 100644 --- a/backend/tests/test_history_mapping.py +++ b/backend/tests/test_history_mapping.py @@ -1,6 +1,6 @@ import pytest from ea_chatbot.history.models import Message -from ea_chatbot.history.utils import map_db_messages_to_langchain +from ea_chatbot.api.utils import map_db_messages_to_langchain from langchain_core.messages import HumanMessage, AIMessage def test_map_db_messages_to_langchain():