langchain/tests/unit_tests/evaluation/qa/test_eval_chain.py
William FH a673a51efa
[Breaking] Update Evaluation Functionality (#7388)
- Migrate from deprecated langchainplus_sdk to `langsmith` package
- Update the `run_on_dataset()` API to use an eval config
- Update a number of evaluators, as well as the loading logic
- Update docstrings / reference docs
- Update tracer to share single HTTP session
2023-07-13 02:13:06 -07:00

70 lines
2.3 KiB
Python

"""Test the QA evaluation chains."""
import sys
from typing import Type
import pytest
from langchain.chains.llm import LLMChain
from langchain.evaluation.qa.eval_chain import (
ContextQAEvalChain,
CotQAEvalChain,
QAEvalChain,
)
from langchain.evaluation.schema import StringEvaluator
from tests.unit_tests.llms.fake_llm import FakeLLM
@pytest.mark.skipif(
    sys.platform.startswith("win"), reason="Test not supported on Windows"
)
def test_eval_chain() -> None:
    """Grade two identical QA examples with ``QAEvalChain`` and a fake LLM.

    Both grades must be equal and must be stored under the chain's own
    ``output_key``, holding the fake LLM's canned reply ``"foo"``.
    """
    evaluator = QAEvalChain.from_llm(FakeLLM())
    # The same example/prediction object is repeated, so both rows are identical.
    examples = [{"query": "What's my name", "answer": "John Doe"}] * 2
    predictions = [{"result": "John Doe"}] * 2
    graded = evaluator.evaluate(examples, predictions)
    first_grade, second_grade = graded[0], graded[1]
    key = evaluator.output_key
    assert first_grade == second_grade
    assert key in first_grade
    assert first_grade[key] == "foo"
@pytest.mark.skipif(
    sys.platform.startswith("win"), reason="Test not supported on Windows"
)
@pytest.mark.parametrize("chain_cls", [ContextQAEvalChain, CotQAEvalChain])
def test_context_eval_chain(chain_cls: Type[ContextQAEvalChain]) -> None:
    """Grade context-based QA examples (no reference answer) with a fake LLM.

    Runs for both ``ContextQAEvalChain`` and ``CotQAEvalChain``. Two identical
    examples must yield one grade each, the grades must be equal, and each
    grade must hold the fake LLM's canned reply ``"foo"``.
    """
    example = {
        "query": "What's my name",
        "context": "The name of this person is John Doe",
    }
    prediction = {"result": "John Doe"}
    fake_qa_eval_chain = chain_cls.from_llm(FakeLLM())
    outputs = fake_qa_eval_chain.evaluate([example, example], [prediction, prediction])
    # One grade per (example, prediction) pair.
    assert len(outputs) == 2
    assert outputs[0] == outputs[1]
    # Read the grade through the chain's configured output key, as
    # test_eval_chain does, instead of hard-coding "text"; the test keeps
    # working if the chain's output key changes.
    output_key = fake_qa_eval_chain.output_key
    assert output_key in outputs[0]
    assert outputs[0][output_key] == "foo"
@pytest.mark.parametrize(
    "chain_cls",
    [QAEvalChain, ContextQAEvalChain, CotQAEvalChain],
)
def test_implements_string_evaluator_protocol(
    chain_cls: Type[LLMChain],
) -> None:
    """Every QA eval chain class must register as a ``StringEvaluator``."""
    is_string_evaluator = issubclass(chain_cls, StringEvaluator)
    assert is_string_evaluator
@pytest.mark.parametrize(
    "chain_cls",
    [QAEvalChain, ContextQAEvalChain, CotQAEvalChain],
)
def test_returns_expected_results(
    chain_cls: Type[LLMChain],
) -> None:
    """A "CORRECT" verdict on the last line of the LLM reply scores as 1.

    The fake LLM replies with a reasoning line followed by the verdict line.
    Each QA eval chain's ``evaluate_strings`` must report ``score == 1``.
    """
    canned_reply = "The meaning of life\nCORRECT"
    fake_llm = FakeLLM(queries={"text": canned_reply}, sequential_responses=True)
    # type: ignore presumably because the Type[LLMChain] annotation does not
    # expose ``from_llm``.
    evaluator = chain_cls.from_llm(fake_llm)  # type: ignore
    outcome = evaluator.evaluate_strings(
        input="my input",
        prediction="my prediction",
        reference="my reference",
    )
    assert outcome["score"] == 1