"""Tests for the ``/chat/stream`` agent streaming (SSE) endpoint."""

import pytest
from fastapi.testclient import TestClient
from unittest.mock import MagicMock, patch, AsyncMock

from ea_chatbot.api.main import app
from ea_chatbot.api.dependencies import get_current_user
from ea_chatbot.history.models import Conversation, User
from ea_chatbot.api.utils import create_access_token

client = TestClient(app)


@pytest.fixture
def mock_user():
    """Return an in-memory ``User`` used as the authenticated principal."""
    return User(id="user-123", username="test@example.com", display_name="Test User")


@pytest.fixture
def auth_header(mock_user):
    """Authenticate requests as ``mock_user`` and yield a Bearer auth header.

    ``get_current_user`` is overridden to return ``mock_user`` directly, and a
    real access token for that user is minted for the ``Authorization`` header.

    Only the ``get_current_user`` override installed here is removed on
    teardown. Calling ``dependency_overrides.clear()`` would also wipe
    overrides that other fixtures placed on the shared ``app``.
    """
    # Mint the token before touching the app. If this raises, pytest skips
    # the post-yield teardown, so an already-installed override would leak.
    token = create_access_token(data={"sub": mock_user.username, "user_id": mock_user.id})
    app.dependency_overrides[get_current_user] = lambda: mock_user
    yield {"Authorization": f"Bearer {token}"}
    app.dependency_overrides.pop(get_current_user, None)


def test_stream_agent_unauthorized():
    """Streaming without an Authorization header must be rejected with 401."""
    response = client.post("/chat/stream", json={"message": "hello"})
    assert response.status_code == 401


def test_stream_agent_success(auth_header, mock_user):
    """Stream a mocked LangGraph run and check the response is an SSE stream.

    ``TestClient`` is a synchronous client and this test awaits nothing, so it
    is a plain (non-asyncio) test function.
    """
    # Events replayed by the patched ``app.astream_events``.
    mock_events = [
        {"event": "on_chat_model_start", "name": "gpt-5", "data": {"input": "..."}},
        {"event": "on_chat_model_stream", "name": "gpt-5", "data": {"chunk": "Hello"}},
        {"event": "on_chain_end", "name": "agent", "data": {"output": "..."}},
    ]

    async def mock_astream_events(*args, **kwargs):
        # Async generator standing in for LangGraph's event stream.
        for event in mock_events:
            yield event

    with patch("ea_chatbot.api.routers.agent.app.astream_events", side_effect=mock_astream_events), \
            patch("ea_chatbot.api.routers.agent.get_checkpointer") as mock_cp, \
            patch("ea_chatbot.api.routers.agent.history_manager") as mock_hm:
        # The checkpointer is configured as an async context manager
        # yielding a stub checkpointer.
        mock_cp.return_value.__aenter__.return_value = AsyncMock()

        # history_manager.get_session() is configured as a sync context
        # manager yielding a mock DB session.
        mock_session = MagicMock()
        mock_hm.get_session.return_value.__enter__.return_value = mock_session

        # session.get(...) returns a conversation owned by mock_user whose id
        # matches the request's thread_id ("t1").
        mock_conv = Conversation(id="t1", user_id=mock_user.id)
        mock_session.get.return_value = mock_conv

        with client.stream(
            "POST",
            "/chat/stream",
            json={"message": "hello", "thread_id": "t1"},
            headers=auth_header,
        ) as response:
            assert response.status_code == 200
            assert "text/event-stream" in response.headers["content-type"]

            lines = list(response.iter_lines())
            # Expect at least one SSE "data: " line per mocked event.
            # NOTE(review): payloads are not JSON-validated here; add that
            # once the router's exact frame format (e.g. terminal markers)
            # is confirmed.
            data_lines = [line for line in lines if line.startswith("data: ")]
            assert len(data_lines) >= len(mock_events)