langchain/tests/unit_tests/chains/test_llm_summarization_checker.py
Ankush Gola d3ec00b566
Callbacks Refactor [base] (#3256)
Co-authored-by: Nuno Campos <nuno@boringbits.io>
Co-authored-by: Davis Chase <130488702+dev2049@users.noreply.github.com>
Co-authored-by: Zander Chase <130414180+vowelparrot@users.noreply.github.com>
Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
2023-04-30 11:14:09 -07:00

47 lines
1.3 KiB
Python

# flake8: noqa E501
"""Test LLMSummarization functionality."""
import pytest
from langchain.chains.llm_summarization_checker.base import (
ARE_ALL_TRUE_PROMPT,
CHECK_ASSERTIONS_PROMPT,
CREATE_ASSERTIONS_PROMPT,
REVISED_SUMMARY_PROMPT,
LLMSummarizationCheckerChain,
)
from tests.unit_tests.llms.fake_llm import FakeLLM
@pytest.fixture
def fake_llm_summarization_checker_chain() -> LLMSummarizationCheckerChain:
    """Build an LLMSummarizationCheckerChain backed by a FakeLLM.

    The FakeLLM is given one canned reply for each of the four checker
    prompts, each formatted with fixed inputs (summary "a", assertions "b",
    checked assertions "- b - True"). The chain is created with input key
    "q" and output key "a".
    """
    # Format the prompts in the same order they appear in the mapping below.
    create_prompt = CREATE_ASSERTIONS_PROMPT.format(summary="a")
    check_prompt = CHECK_ASSERTIONS_PROMPT.format(assertions="b")
    revise_prompt = REVISED_SUMMARY_PROMPT.format(
        checked_assertions="- b - True", summary="a"
    )
    all_true_prompt = ARE_ALL_TRUE_PROMPT.format(checked_assertions="- b - True")

    # Prompt text -> reply the fake LLM is configured with.
    canned_replies = {
        create_prompt: "b",
        check_prompt: "- b - True",
        revise_prompt: "b",
        all_true_prompt: "True",
    }
    llm = FakeLLM(queries=canned_replies)
    return LLMSummarizationCheckerChain.from_llm(llm, input_key="q", output_key="a")
def test_simple_text(
    fake_llm_summarization_checker_chain: LLMSummarizationCheckerChain,
) -> None:
    """Run the fixture chain on the input "a" and expect the output "b"."""
    summary = "a"
    result = fake_llm_summarization_checker_chain.run(summary)
    assert result == "b"