116 lines
4.4 KiB
Python
116 lines
4.4 KiB
Python
import pytest
|
|
from fastapi.testclient import TestClient
|
|
from unittest.mock import MagicMock, patch
|
|
from ea_chatbot.api.main import app
|
|
from ea_chatbot.api.dependencies import get_current_user
|
|
from ea_chatbot.history.models import Conversation, Message, Plot, User
|
|
from ea_chatbot.api.utils import create_access_token
|
|
from datetime import datetime, timezone
|
|
|
|
# Single in-process test client shared by every test in this module.
client = TestClient(app=app)
|
|
|
|
@pytest.fixture
def mock_user():
    """Return a ``User`` model instance that the tests treat as the logged-in user.

    The instance is only constructed here; this fixture does not persist it.
    """
    return User(
        id="user-123",
        username="test@example.com",
        display_name="Test User",
    )
|
|
|
|
@pytest.fixture
def auth_header(mock_user):
    """Authenticate requests as ``mock_user`` and yield a bearer ``Authorization`` header.

    While the test runs, ``get_current_user`` is overridden on ``app`` so every
    protected route resolves to ``mock_user``. The token is minted with
    ``create_access_token`` so the header is well-formed. Because of the
    override, though, the token is not what authenticates the request here.

    Teardown undoes only the override installed by this fixture. Any override
    that existed for ``get_current_user`` beforehand is restored. Every other
    entry in ``app.dependency_overrides`` (e.g. overrides installed by other
    fixtures) is left untouched. ``dict.clear()`` would have wiped them all.
    """
    had_previous = get_current_user in app.dependency_overrides
    previous = app.dependency_overrides.get(get_current_user)
    app.dependency_overrides[get_current_user] = lambda: mock_user
    try:
        token = create_access_token(data={"sub": mock_user.id})
        yield {"Authorization": f"Bearer {token}"}
    finally:
        # Runs on normal teardown and if token creation raises, so the
        # override never leaks into later tests.
        if had_previous:
            app.dependency_overrides[get_current_user] = previous
        else:
            app.dependency_overrides.pop(get_current_user, None)
|
|
|
|
def test_get_conversations_success(auth_header, mock_user):
    """GET /conversations returns the conversations supplied by the history manager."""
    conversation = Conversation(
        id="c1",
        name="Conv 1",
        user_id=mock_user.id,
        data_state="nj",
        created_at=datetime.now(timezone.utc),
    )

    with patch("ea_chatbot.api.routers.history.history_manager") as mock_hm:
        mock_hm.get_conversations.return_value = [conversation]
        response = client.get("/conversations", headers=auth_header)

    assert response.status_code == 200
    payload = response.json()
    assert len(payload) == 1
    assert payload[0]["name"] == "Conv 1"
|
|
|
|
def test_create_conversation_success(auth_header, mock_user):
    """POST /conversations responds 201 with the conversation the manager created."""
    created = Conversation(
        id="c2",
        name="New Conv",
        user_id=mock_user.id,
        data_state="nj",
        created_at=datetime.now(timezone.utc),
    )

    with patch("ea_chatbot.api.routers.history.history_manager") as mock_hm:
        mock_hm.create_conversation.return_value = created
        response = client.post(
            "/conversations",
            headers=auth_header,
            json={"name": "New Conv"},
        )

    assert response.status_code == 201
    body = response.json()
    assert body["name"] == "New Conv"
    assert body["id"] == "c2"
|
|
|
|
def test_get_messages_success(auth_header):
    """GET /conversations/{id}/messages returns the messages supplied by the history manager."""
    greeting = Message(
        id="m1",
        role="user",
        content="Hello",
        conversation_id="c1",
        created_at=datetime.now(timezone.utc),
    )

    with patch("ea_chatbot.api.routers.history.history_manager") as mock_hm:
        mock_hm.get_messages.return_value = [greeting]
        response = client.get("/conversations/c1/messages", headers=auth_header)

    assert response.status_code == 200
    messages = response.json()
    assert len(messages) == 1
    assert messages[0]["content"] == "Hello"
|
|
|
|
def test_delete_conversation_success(auth_header):
    """DELETE /conversations/{id} responds 204 when the manager reports the delete succeeded."""
    manager_path = "ea_chatbot.api.routers.history.history_manager"
    with patch(manager_path) as mock_hm:
        mock_hm.delete_conversation.return_value = True
        response = client.delete("/conversations/c1", headers=auth_header)

    assert response.status_code == 204
|
|
|
|
def test_get_plot_success(auth_header, mock_user):
    """GET /artifacts/plots/{id} returns the stored plot bytes as ``image/png``."""
    # Linked object graph plot -> message -> conversation -> user.
    # NOTE(review): presumably the router walks this chain to verify that
    # the requesting user owns the plot; confirm against the artifacts router.
    mock_conv = Conversation(id="c1", user_id=mock_user.id, user=mock_user)
    mock_msg = Message(id="m1", conversation_id="c1", conversation=mock_conv)
    mock_plot = Plot(
        id="p1",
        image_data=b"fake-image-data",
        message_id="m1",
        message=mock_msg,
    )

    # Canned instance per model class. The identifier is deliberately ignored,
    # so any lookup for a given model returns the single fixture object above.
    instances_by_model = {
        Plot: mock_plot,
        Message: mock_msg,
        Conversation: mock_conv,
    }

    def fake_get(model, ident, *args, **kwargs):
        # Signature mirrors ``Session.get(entity, ident, **options)``, so a
        # keyword ``ident=`` call or extra options do not raise TypeError.
        return instances_by_model.get(model)

    with patch("ea_chatbot.api.routers.artifacts.history_manager") as mock_hm:
        mock_session = MagicMock()
        mock_session.get.side_effect = fake_get
        # The router obtains the session via ``with history_manager.get_session() as s``.
        mock_hm.get_session.return_value.__enter__.return_value = mock_session

        response = client.get("/artifacts/plots/p1", headers=auth_header)

    assert response.status_code == 200
    assert response.content == b"fake-image-data"
    assert response.headers["content-type"] == "image/png"