65 lines
2.0 KiB
Python
65 lines
2.0 KiB
Python
import logging
|
|
import pytest
|
|
import io
|
|
from unittest.mock import MagicMock
|
|
from ea_chatbot.utils.logging import LangChainLoggingHandler
|
|
|
|
@pytest.fixture
def log_capture():
    """Capture output written to the ``test_langchain`` logger.

    Sets the ``test_langchain`` logger to INFO, removes any handlers already
    attached to it, and installs a single ``StreamHandler`` that writes into an
    in-memory ``io.StringIO`` buffer so tests can assert on the rendered text.

    Yields:
        tuple[logging.Logger, io.StringIO]: the configured logger and the
        buffer receiving its output.

    Teardown (runs after the test, pass or fail) detaches and closes the
    capture handler, closes the buffer, and restores the logger's previous
    level, so no state leaks into later tests that use the same logger name.
    """
    log_stream = io.StringIO()
    logger = logging.getLogger("test_langchain")
    # Remember the level so teardown can undo our change to this shared logger.
    previous_level = logger.level
    logger.setLevel(logging.INFO)

    # Remove existing handlers so only our capture handler sees the records.
    for existing in logger.handlers[:]:
        logger.removeHandler(existing)

    handler = logging.StreamHandler(log_stream)
    logger.addHandler(handler)

    yield logger, log_stream

    # Teardown: without this, the handler would stay attached to the
    # process-global logger and keep writing into an abandoned buffer.
    logger.removeHandler(handler)
    handler.close()
    log_stream.close()
    logger.setLevel(previous_level)
|
|
|
|
def test_langchain_logging_handler_on_llm_start(log_capture):
    """Starting an LLM call should emit a start marker naming the serialized model."""
    capture_logger, buffer = log_capture
    callback = LangChainLoggingHandler(logger=capture_logger)

    callback.on_llm_start(serialized={"name": "test_model"}, prompts=["test prompt"])

    logged = buffer.getvalue()
    # Both the start marker and the model name must appear in the log text.
    for fragment in ("LLM Started:", "test_model"):
        assert fragment in logged
|
|
|
|
def test_langchain_logging_handler_on_llm_end(log_capture):
    """Finishing an LLM call should report the model name and token usage breakdown."""
    capture_logger, buffer = log_capture
    callback = LangChainLoggingHandler(logger=capture_logger)

    # Only `llm_output` is populated on the mock; any other attribute the
    # handler reads will be an auto-created MagicMock.
    usage = {
        "prompt_tokens": 10,
        "completion_tokens": 20,
        "total_tokens": 30,
    }
    fake_result = MagicMock()
    fake_result.llm_output = {"token_usage": usage, "model_name": "test_model"}

    callback.on_llm_end(response=fake_result)

    logged = buffer.getvalue()
    expected_fragments = (
        "LLM Ended:",
        "test_model",
        "Tokens: 30",
        "10 prompt",
        "20 completion",
    )
    for fragment in expected_fragments:
        assert fragment in logged
|
|
|
|
def test_langchain_logging_handler_on_llm_error(log_capture):
    """A failed LLM call should log an error marker together with the exception text."""
    capture_logger, buffer = log_capture
    callback = LangChainLoggingHandler(logger=capture_logger)

    callback.on_llm_error(error=Exception("test error"))

    logged = buffer.getvalue()
    for fragment in ("LLM Error:", "test error"):
        assert fragment in logged
|