import io
import logging
from unittest.mock import MagicMock

import pytest

from ea_chatbot.utils.logging import LangChainLoggingHandler


@pytest.fixture
def log_capture():
    """Yield a (logger, stream) pair that captures log output in memory.

    Attaches a StreamHandler backed by a StringIO to the shared
    ``test_langchain`` logger, and detaches it again on teardown so
    handlers do not leak across tests in the session.
    """
    log_stream = io.StringIO()
    logger = logging.getLogger("test_langchain")
    logger.setLevel(logging.INFO)
    # Keep records out of the root logger so other tests' output stays clean.
    logger.propagate = False
    # Remove existing handlers so earlier test runs cannot duplicate output.
    for handler in logger.handlers[:]:
        logger.removeHandler(handler)
    stream_handler = logging.StreamHandler(log_stream)
    logger.addHandler(stream_handler)
    yield logger, log_stream
    # Teardown: detach our handler so the module-global logger is restored.
    logger.removeHandler(stream_handler)


def test_langchain_logging_handler_on_llm_start(log_capture):
    """Test that on_llm_start logs the correct message."""
    logger, log_stream = log_capture
    handler = LangChainLoggingHandler(logger=logger)
    handler.on_llm_start(serialized={"name": "test_model"}, prompts=["test prompt"])
    output = log_stream.getvalue()
    assert "LLM Started:" in output
    assert "test_model" in output


def test_langchain_logging_handler_on_llm_end(log_capture):
    """Test that on_llm_end logs token usage."""
    logger, log_stream = log_capture
    handler = LangChainLoggingHandler(logger=logger)
    # Minimal stand-in for an LLMResult: only llm_output is read by the handler.
    response = MagicMock()
    response.llm_output = {
        "token_usage": {
            "prompt_tokens": 10,
            "completion_tokens": 20,
            "total_tokens": 30,
        },
        "model_name": "test_model",
    }
    handler.on_llm_end(response=response)
    output = log_stream.getvalue()
    assert "LLM Ended:" in output
    assert "test_model" in output
    assert "Tokens: 30" in output
    assert "10 prompt" in output
    assert "20 completion" in output


def test_langchain_logging_handler_on_llm_error(log_capture):
    """Test that on_llm_error logs the error."""
    logger, log_stream = log_capture
    handler = LangChainLoggingHandler(logger=logger)
    error = Exception("test error")
    handler.on_llm_error(error=error)
    output = log_stream.getvalue()
    assert "LLM Error:" in output
    assert "test error" in output