diff --git a/tests/api/test_agent.py b/tests/api/test_agent.py new file mode 100644 index 0000000..8eb0ea7 --- /dev/null +++ b/tests/api/test_agent.py @@ -0,0 +1,56 @@ +import pytest +from fastapi.testclient import TestClient +from unittest.mock import MagicMock, patch, AsyncMock +from ea_chatbot.api.main import app +from ea_chatbot.api.dependencies import get_current_user +from ea_chatbot.history.models import User +from ea_chatbot.api.utils import create_access_token + +client = TestClient(app) + +@pytest.fixture +def mock_user(): + return User(id="user-123", username="test@example.com", display_name="Test User") + +@pytest.fixture +def auth_header(mock_user): + app.dependency_overrides[get_current_user] = lambda: mock_user + token = create_access_token(data={"sub": mock_user.username, "user_id": mock_user.id}) + yield {"Authorization": f"Bearer {token}"} + app.dependency_overrides.clear() + +def test_stream_agent_unauthorized(): + """Test that streaming requires authentication.""" + response = client.post("/chat/stream", json={"message": "hello"}) + assert response.status_code == 401 + +@pytest.mark.asyncio +async def test_stream_agent_success(auth_header): + """Test successful agent streaming with SSE.""" + # We need to mock the LangGraph app.astream_events + mock_events = [ + {"event": "on_chat_model_start", "name": "gpt-5", "data": {"input": "..."}}, + {"event": "on_chat_model_stream", "name": "gpt-5", "data": {"chunk": "Hello"}}, + {"event": "on_chain_end", "name": "agent", "data": {"output": "..."}} + ] + + async def mock_astream_events(*args, **kwargs): + for event in mock_events: + yield event + + with patch("ea_chatbot.api.routers.agent.app.astream_events", side_effect=mock_astream_events), \ + patch("ea_chatbot.api.routers.agent.get_checkpointer") as mock_cp: + + mock_cp.return_value.__aenter__.return_value = AsyncMock() + + # Using TestClient with a stream context + with client.stream("POST", "/chat/stream", + json={"message": "hello", "thread_id": "t1"}, + headers=auth_header) as response: + assert response.status_code == 200 + assert response.headers["content-type"] == "text/event-stream" + + lines = list(response.iter_lines()) + # Each event should start with 'data: ' and be valid JSON + data_lines = [line for line in lines if line.startswith("data: ")] + assert len(data_lines) >= len(mock_events)