Add Better Errors for Comparison Chain (#7033)

+ Change the evaluator interfaces from runtime-checkable Protocols to ABCs - this lets us add things like the evaluation name for loading
William FH 2023-07-06 06:37:04 -07:00 committed by GitHub
parent e61cfb6e99
commit ec66d5188c
9 changed files with 358 additions and 55 deletions

View File

@ -9,6 +9,7 @@ from typing import Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
from pydantic import Field
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
@ -186,10 +187,11 @@ The following is the expected answer. Use this to measure correctness:
@classmethod
def from_llm(
cls,
llm: BaseChatModel,
llm: BaseLanguageModel,
agent_tools: Optional[Sequence[BaseTool]] = None,
output_parser: Optional[TrajectoryOutputParser] = None,
return_reasoning: bool = False,
**kwargs: Any,
) -> "TrajectoryEvalChain":
"""Create a TrajectoryEvalChain object from a language model chain.
@ -205,6 +207,10 @@ The following is the expected answer. Use this to measure correctness:
Returns:
TrajectoryEvalChain: The TrajectoryEvalChain object.
"""
if not isinstance(llm, BaseChatModel):
raise NotImplementedError(
"Only chat models supported by the current trajectory eval"
)
if agent_tools:
prompt = EVAL_CHAT_PROMPT
else:
@ -215,6 +221,7 @@ The following is the expected answer. Use this to measure correctness:
return_reasoning=return_reasoning,
eval_chain=eval_chain,
output_parser=output_parser or TrajectoryOutputParser(),
**kwargs,
)
@property
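
Not part of the diff: a minimal sketch of the new guard in `from_llm`, assuming the repo's `FakeLLM` test helper stands in for a non-chat `BaseLanguageModel`:

```python
from langchain.evaluation.agents.trajectory_eval_chain import TrajectoryEvalChain
from tests.unit_tests.llms.fake_llm import FakeLLM  # assumption: test-only helper, not a public API

# The annotation is now the broader BaseLanguageModel, but non-chat models are
# rejected up front instead of failing later inside the chat prompt.
try:
    TrajectoryEvalChain.from_llm(llm=FakeLLM(queries={"a": "Trajectory good\nScore: 5"}))
except NotImplementedError as err:
    print(err)  # Only chat models are supported by the current trajectory eval
```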

View File

@ -9,6 +9,7 @@ from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import Callbacks
from langchain.chains.llm import LLMChain
from langchain.evaluation.comparison.prompt import PROMPT, PROMPT_WITH_REFERENCE
from langchain.evaluation.schema import PairwiseStringEvaluator
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import BaseOutputParser
@ -50,7 +51,7 @@ class PairwiseStringResultOutputParser(BaseOutputParser[dict]):
}
class PairwiseStringEvalChain(LLMChain):
class PairwiseStringEvalChain(PairwiseStringEvaluator, LLMChain):
"""A chain for comparing the output of two models.
Example:
@ -80,13 +81,31 @@ class PairwiseStringEvalChain(LLMChain):
default_factory=PairwiseStringResultOutputParser
)
@property
def requires_reference(self) -> bool:
return "reference" in self.prompt.input_variables
@property
def requires_input(self) -> bool:
return True
@property
def _skip_reference_warning(self) -> str:
"""Warning to show when reference is ignored."""
return (
f"Ignoring reference in {self.__class__.__name__}, as it is not expected."
"\nTo use a reference, initialize PairwiseStringEvalChain with"
" `requires_reference=True` or with a prompt with 'reference' as an"
" input variable."
)
@classmethod
def from_llm(
cls,
*,
llm: BaseLanguageModel,
*,
prompt: Optional[PromptTemplate] = None,
require_reference: bool = False,
requires_reference: bool = False,
**kwargs: Any,
) -> PairwiseStringEvalChain:
"""Initialize the PairwiseStringEvalChain from an LLM.
@ -94,7 +113,7 @@ class PairwiseStringEvalChain(LLMChain):
Args:
llm (BaseLanguageModel): The LLM to use.
prompt (PromptTemplate, optional): The prompt to use.
require_reference (bool, optional): Whether to require a reference
requires_reference (bool, optional): Whether to require a reference
string. Defaults to False.
**kwargs (Any): Additional keyword arguments.
@ -103,13 +122,13 @@ class PairwiseStringEvalChain(LLMChain):
"""
expected_input_vars = {"prediction", "prediction_b", "input"}
if prompt is None:
if require_reference:
if requires_reference:
expected_input_vars.add("reference")
prompt_ = PROMPT_WITH_REFERENCE
else:
prompt_ = PROMPT
else:
if require_reference:
if requires_reference:
expected_input_vars.add("reference")
prompt_ = prompt
@ -121,23 +140,32 @@ class PairwiseStringEvalChain(LLMChain):
return cls(llm=llm, prompt=prompt_, **kwargs)
def _prepare_input(
self, prediction: str, prediction_b: str, input: str, reference: Optional[str]
self,
prediction: str,
prediction_b: str,
input: Optional[str],
reference: Optional[str],
) -> dict:
input_ = {
"prediction": prediction,
"prediction_b": prediction_b,
"input": input,
}
if reference is not None and "reference" in self.prompt.input_variables:
if self.requires_input:
if not input:
raise ValueError("Input is require for this comparison evaluator")
input_["input"] = input
if self.requires_reference:
if reference is None:
raise ValueError("Reference is required for this comparison evaluator")
input_["reference"] = reference
return input_
def evaluate_string_pairs(
def _evaluate_string_pairs(
self,
*,
prediction: str,
prediction_b: str,
input: str,
input: Optional[str] = None,
reference: Optional[str] = None,
callbacks: Callbacks = None,
**kwargs: Any,
@ -168,13 +196,13 @@ class PairwiseStringEvalChain(LLMChain):
)
return result["text"]
async def aevaluate_string_pairs(
async def _aevaluate_string_pairs(
self,
*,
prediction: str,
prediction_b: str,
input: str,
reference: Optional[str] = None,
input: Optional[str] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> dict:
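
Not part of the diff: a rough usage sketch of the renamed `requires_reference` flag and the new argument checks, with `ChatOpenAI` assumed as a stand-in for any supported model:

```python
from langchain.chat_models import ChatOpenAI  # assumption: any chat model works here
from langchain.evaluation.comparison.eval_chain import PairwiseStringEvalChain

# requires_reference (renamed from require_reference) selects PROMPT_WITH_REFERENCE,
# so the requires_reference property becomes True for this chain.
chain = PairwiseStringEvalChain.from_llm(llm=ChatOpenAI(temperature=0), requires_reference=True)

# The public evaluate_string_pairs wrapper now runs _check_evaluation_args before
# delegating to _evaluate_string_pairs, so omitting the reference raises ValueError.
chain.evaluate_string_pairs(
    prediction="I like pie.",
    prediction_b="I love pie.",
    input="What is your favorite food?",
)
```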

View File

@ -2,12 +2,13 @@ from __future__ import annotations
from typing import Any, Dict, List, Mapping, Optional, Sequence, Union
from pydantic import Field
from pydantic import Extra, Field
from langchain.base_language import BaseLanguageModel
from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple
from langchain.chains.llm import LLMChain
from langchain.evaluation.criteria.prompt import PROMPT, PROMPT_WITH_REFERENCES
from langchain.evaluation.schema import StringEvaluator
from langchain.schema import BaseOutputParser, BasePromptTemplate
_SUPPORTED_CRITERIA = {
@ -59,7 +60,7 @@ CRITERIA_TYPE = Union[
]
class CriteriaEvalChain(LLMChain):
class CriteriaEvalChain(StringEvaluator, LLMChain):
"""LLM Chain for evaluating runs against criteria.
Parameters
@ -96,11 +97,32 @@ class CriteriaEvalChain(LLMChain):
>>> chain = CriteriaEvalChain.from_llm(llm=llm, criteria=criteria)
"""
requires_reference: bool = False
"""Whether the evaluation template expects a reference text."""
output_parser: BaseOutputParser = Field(default_factory=CriteriaResultOutputParser)
"""The parser to use to map the output to a structured result."""
class Config:
"""Configuration for the QAEvalChain."""
extra = Extra.ignore
@property
def requires_reference(self) -> bool:
return "reference" in self.prompt.input_variables
@property
def requires_input(self) -> bool:
return True
@property
def _skip_reference_warning(self) -> str:
"""Warning to show when reference is ignored."""
return (
f"Ignoring reference in {self.__class__.__name__}, as it is not expected."
"\nTo use a reference, initialize CriteriaEvalChain with"
" `require_reference=True` or with a prompt with 'reference'"
" as an input variable."
)
@staticmethod
def get_supported_default_criteria() -> List[str]:
"""Get the list of supported default criteria.
@ -122,7 +144,7 @@ class CriteriaEvalChain(LLMChain):
@classmethod
def resolve_criteria(
cls,
criteria: CRITERIA_TYPE,
criteria: Optional[CRITERIA_TYPE],
) -> Dict[str, str]:
"""Resolve the criteria to evaluate.
@ -148,6 +170,10 @@ class CriteriaEvalChain(LLMChain):
{'relevance': 'Is the submission referring to a real quote from the text?',
'coherence': 'Is the submission coherent, well-structured, and organized?'}
""" # noqa: E501
if criteria is None:
return {
"helpfulness": _SUPPORTED_CRITERIA["helpfulness"],
}
if isinstance(criteria, str):
criteria_ = {criteria: _SUPPORTED_CRITERIA[criteria]}
elif isinstance(criteria, ConstitutionalPrinciple):
@ -172,7 +198,7 @@ class CriteriaEvalChain(LLMChain):
def from_llm(
cls,
llm: BaseLanguageModel,
criteria: CRITERIA_TYPE,
criteria: Optional[CRITERIA_TYPE] = None,
*,
prompt: Optional[BasePromptTemplate] = None,
requires_reference: bool = False,
@ -184,7 +210,7 @@ class CriteriaEvalChain(LLMChain):
----------
llm : BaseLanguageModel
The language model to use for evaluation.
criteria : CRITERIA_TYPE
criteria : Optional[CRITERIA_TYPE], default=None (resolves to "helpfulness")
The criteria to evaluate the runs against. It can be:
- a mapping of criterion names to descriptions
- a sequence of criterion names
@ -252,7 +278,7 @@ class CriteriaEvalChain(LLMChain):
input_["reference"] = reference
return input_
def evaluate_strings(
def _evaluate_strings(
self,
*,
prediction: str,
@ -296,7 +322,7 @@ class CriteriaEvalChain(LLMChain):
input_ = self._get_eval_input(prediction, reference, input)
return self(input_, **kwargs)["text"]
async def aevaluate_strings(
async def _aevaluate_strings(
self,
*,
prediction: str,
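
Not part of the diff: a short sketch of the now-optional `criteria` argument, again assuming `ChatOpenAI` as the model:

```python
from langchain.chat_models import ChatOpenAI  # assumption
from langchain.evaluation.criteria.eval_chain import CriteriaEvalChain

# criteria=None now resolves to the "helpfulness" default criterion.
chain = CriteriaEvalChain.from_llm(llm=ChatOpenAI(temperature=0))

# requires_input is True, so the input question must be supplied; the default
# prompt has no "reference" variable, so a reference would only trigger a warning.
result = chain.evaluate_strings(
    prediction="The answer is 42.",
    input="What is the meaning of life?",
)
# result is the parsed dict (reasoning / value / score) from CriteriaResultOutputParser
```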

View File

@ -8,6 +8,7 @@ from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import Callbacks
from langchain.chains.llm import LLMChain
from langchain.evaluation.qa.eval_prompt import CONTEXT_PROMPT, COT_PROMPT, PROMPT
from langchain.evaluation.schema import StringEvaluator
def _parse_string_eval_output(text: str) -> dict:
@ -38,9 +39,17 @@ def _parse_string_eval_output(text: str) -> dict:
}
class QAEvalChain(LLMChain):
class QAEvalChain(LLMChain, StringEvaluator):
"""LLM Chain specifically for evaluating question answering."""
@property
def requires_reference(self) -> bool:
return True
@property
def requires_input(self) -> bool:
return True
@classmethod
def from_llm(
cls, llm: BaseLanguageModel, prompt: PromptTemplate = PROMPT, **kwargs: Any
@ -90,7 +99,7 @@ class QAEvalChain(LLMChain):
return self.apply(inputs, callbacks=callbacks)
def evaluate_strings(
def _evaluate_strings(
self,
*,
prediction: str,
@ -118,7 +127,7 @@ class QAEvalChain(LLMChain):
)[0]
return _parse_string_eval_output(result["text"])
async def aevaluate_strings(
async def _aevaluate_strings(
self,
*,
prediction: str,
@ -134,9 +143,17 @@ class QAEvalChain(LLMChain):
return _parse_string_eval_output(result["text"])
class ContextQAEvalChain(LLMChain):
class ContextQAEvalChain(LLMChain, StringEvaluator):
"""LLM Chain specifically for evaluating QA w/o GT based on context"""
@property
def requires_reference(self) -> bool:
return True
@property
def requires_input(self) -> bool:
return True
@classmethod
def _validate_input_vars(cls, prompt: PromptTemplate) -> None:
expected_input_vars = {"query", "context", "result"}
@ -193,7 +210,7 @@ class ContextQAEvalChain(LLMChain):
return self.apply(inputs, callbacks=callbacks)
def evaluate_strings(
def _evaluate_strings(
self,
*,
prediction: str,
@ -208,7 +225,7 @@ class ContextQAEvalChain(LLMChain):
)[0]
return _parse_string_eval_output(result["text"])
async def aevaluate_strings(
async def _aevaluate_strings(
self,
*,
prediction: str,
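
Not part of the diff: a sketch of how the mixin's checks surface for `QAEvalChain`, which now reports both `requires_input` and `requires_reference` as True; `ChatOpenAI` is assumed:

```python
from langchain.chat_models import ChatOpenAI  # assumption
from langchain.evaluation.qa.eval_chain import QAEvalChain

chain = QAEvalChain.from_llm(llm=ChatOpenAI(temperature=0))

# evaluate_strings is now the inherited wrapper: it raises ValueError if either
# the input question or the reference answer is missing, then calls _evaluate_strings.
graded = chain.evaluate_strings(
    prediction="Paris",
    reference="The capital of France is Paris.",
    input="What is the capital of France?",
)
# graded is the dict produced by _parse_string_eval_output (reasoning / value / score)
```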

View File

@ -1,14 +1,63 @@
"""Interfaces to be implemented by general evaluators."""
from abc import abstractmethod
from typing import Any, Optional, Protocol, runtime_checkable
from __future__ import annotations
import logging
from abc import ABC, abstractmethod
from typing import Any, Optional
from warnings import warn
logger = logging.getLogger(__name__)
@runtime_checkable
class StringEvaluator(Protocol):
class _EvalArgsMixin:
"""Mixin for checking evaluation arguments."""
@property
def requires_reference(self) -> bool:
"""Whether this evaluator requires a reference label."""
return False
@property
def requires_input(self) -> bool:
"""Whether this evaluator requires an input string."""
return False
@property
def _skip_input_warning(self) -> str:
"""Warning to show when input is ignored."""
return f"Ignoring input in {self.__class__.__name__}, as it is not expected."
@property
def _skip_reference_warning(self) -> str:
"""Warning to show when reference is ignored."""
return (
f"Ignoring reference in {self.__class__.__name__}, as it is not expected."
)
def _check_evaluation_args(
self,
reference: Optional[str] = None,
input: Optional[str] = None,
) -> None:
if self.requires_input and input is None:
raise ValueError(f"{self.__class__.__name__} requires an input string.")
elif input is not None and not self.requires_input:
warn(self._skip_input_warning)
else:
pass
if self.requires_reference and reference is None:
raise ValueError(f"{self.__class__.__name__} requires a reference string.")
elif reference is not None and not self.requires_reference:
warn(self._skip_reference_warning)
else:
pass
class StringEvaluator(_EvalArgsMixin, ABC):
"""Protocol for evaluating strings."""
@abstractmethod
def evaluate_strings(
def _evaluate_strings(
self,
*,
prediction: str,
@ -28,7 +77,7 @@ class StringEvaluator(Protocol):
dict: The evaluation results containing the score or value.
"""
async def aevaluate_strings(
async def _aevaluate_strings(
self,
*,
prediction: str,
@ -53,13 +102,61 @@ class StringEvaluator(Protocol):
"async aevaluate_strings method."
)
def evaluate_strings(
self,
*,
prediction: str,
reference: Optional[str] = None,
input: Optional[str] = None,
**kwargs: Any,
) -> dict:
"""Evaluate Chain or LLM output, based on optional input and label.
Args:
prediction (str): the LLM or chain prediction to evaluate.
reference (Optional[str], optional): the reference label
to evaluate against.
input (Optional[str], optional): the input to consider during evaluation
**kwargs: additional keyword arguments, including callbacks, tags, etc.
Returns:
dict: The evaluation results containing the score or value.
"""
self._check_evaluation_args(reference=reference, input=input)
return self._evaluate_strings(
prediction=prediction, reference=reference, input=input, **kwargs
)
async def aevaluate_strings(
self,
*,
prediction: str,
reference: Optional[str] = None,
input: Optional[str] = None,
**kwargs: Any,
) -> dict:
"""Asynchronously evaluate Chain or LLM output, based on optional
input and label.
Args:
prediction (str): the LLM or chain prediction to evaluate.
reference (Optional[str], optional): the reference label
to evaluate against.
input (Optional[str], optional): the input to consider during evaluation
**kwargs: additional keyword arguments, including callbacks, tags, etc.
Returns:
dict: The evaluation results containing the score or value.
"""
self._check_evaluation_args(reference=reference, input=input)
return await self._aevaluate_strings(
prediction=prediction, reference=reference, input=input, **kwargs
)
@runtime_checkable
class PairwiseStringEvaluator(Protocol):
class PairwiseStringEvaluator(_EvalArgsMixin, ABC):
"""A protocol for comparing the output of two models."""
@abstractmethod
def evaluate_string_pairs(
def _evaluate_string_pairs(
self,
*,
prediction: str,
@ -84,8 +181,9 @@ class PairwiseStringEvaluator(Protocol):
other information.
"""
async def aevaluate_string_pairs(
async def _aevaluate_string_pairs(
self,
*,
prediction: str,
prediction_b: str,
reference: Optional[str] = None,
@ -111,3 +209,69 @@ class PairwiseStringEvaluator(Protocol):
f"{self.__class__.__name__} hasn't implemented an async "
"aevaluate_string_pairs method."
)
def evaluate_string_pairs(
self,
*,
prediction: str,
prediction_b: str,
reference: Optional[str] = None,
input: Optional[str] = None,
**kwargs: Any,
) -> dict:
"""Evaluate the output string pairs.
Args:
prediction (str): The output string from the first model.
prediction_b (str): The output string from the second model.
reference (str, optional): The expected output / reference
string. Defaults to None.
input (str, optional): The input string. Defaults to None.
**kwargs (Any): Additional keyword arguments, such
as callbacks and optional reference strings.
Returns:
dict: A dictionary containing the preference, scores, and/or
other information.
"""
self._check_evaluation_args(reference=reference, input=input)
return self._evaluate_string_pairs(
prediction=prediction,
prediction_b=prediction_b,
reference=reference,
input=input,
**kwargs,
)
async def aevaluate_string_pairs(
self,
*,
prediction: str,
prediction_b: str,
reference: Optional[str] = None,
input: Optional[str] = None,
**kwargs: Any,
) -> dict:
"""Evaluate the output string pairs.
Args:
prediction (str): The output string from the first model.
prediction_b (str): The output string from the second model.
reference (str, optional): The expected output / reference
string. Defaults to None.
input (str, optional): The input string. Defaults to None.
**kwargs (Any): Additional keyword arguments, such
as callbacks and optional reference strings.
Returns:
dict: A dictionary containing the preference, scores, and/or
other information.
"""
self._check_evaluation_args(reference=reference, input=input)
return await self._aevaluate_string_pairs(
prediction=prediction,
prediction_b=prediction_b,
reference=reference,
input=input,
**kwargs,
)
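
Not part of the diff: a minimal sketch of what the Protocol-to-ABC change enables, using a hypothetical `ExactMatchEvaluator`; a subclass implements only the private `_evaluate_strings` hook and inherits argument checking (and any future hooks, such as an evaluation name for loading) from the base class:

```python
from typing import Any, Optional

from langchain.evaluation.schema import StringEvaluator


class ExactMatchEvaluator(StringEvaluator):  # hypothetical example, not in the commit
    """Toy evaluator: score 1 if the prediction equals the reference."""

    @property
    def requires_reference(self) -> bool:
        return True

    def _evaluate_strings(
        self,
        *,
        prediction: str,
        reference: Optional[str] = None,
        input: Optional[str] = None,
        **kwargs: Any,
    ) -> dict:
        return {"score": int(prediction.strip() == (reference or "").strip())}


evaluator = ExactMatchEvaluator()
evaluator.evaluate_strings(prediction="4", reference="4")  # {"score": 1}
evaluator.evaluate_strings(prediction="4")  # raises ValueError: a reference is required
```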

View File

@ -1,13 +1,15 @@
"""Test agent trajectory evaluation chain."""
from typing import List, Tuple
from typing import Any, Dict, List, Optional, Tuple
import pytest
from pydantic import Field
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.evaluation.agents.trajectory_eval_chain import TrajectoryEvalChain
from langchain.schema import AgentAction
from langchain.schema import AgentAction, BaseMessage
from langchain.tools.base import tool
from tests.unit_tests.llms.fake_llm import FakeLLM
from tests.unit_tests.llms.fake_chat_model import FakeChatModel
@pytest.fixture
@ -30,10 +32,31 @@ def foo(bar: str) -> str:
return bar
class _FakeTrajectoryChatModel(FakeChatModel):
queries: Dict = Field(default_factory=dict)
sequential_responses: Optional[bool] = False
response_index: int = 0
def _call(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
if self.sequential_responses:
response = self.queries[list(self.queries.keys())[self.response_index]]
self.response_index = self.response_index + 1
return response
else:
prompt = messages[0].content
return self.queries[prompt]
def test_trajectory_eval_chain(
intermediate_steps: List[Tuple[AgentAction, str]]
) -> None:
llm = FakeLLM(
llm = _FakeTrajectoryChatModel(
queries={
"a": "Trajectory good\nScore: 5",
"b": "Trajectory not good\nScore: 1",
@ -61,7 +84,7 @@ def test_trajectory_eval_chain(
def test_trajectory_eval_chain_no_tools(
intermediate_steps: List[Tuple[AgentAction, str]]
) -> None:
llm = FakeLLM(
llm = _FakeTrajectoryChatModel(
queries={
"a": "Trajectory good\nScore: 5",
"b": "Trajectory not good\nScore: 1",
@ -85,7 +108,7 @@ def test_trajectory_eval_chain_no_tools(
def test_old_api_works(intermediate_steps: List[Tuple[AgentAction, str]]) -> None:
llm = FakeLLM(
llm = _FakeTrajectoryChatModel(
queries={
"a": "Trajectory good\nScore: 5",
"b": "Trajectory not good\nScore: 1",

View File

@ -1,6 +1,8 @@
"""Test the comparison chains."""
import pytest
from langchain.evaluation.comparison.eval_chain import PairwiseStringEvalChain
from tests.unit_tests.llms.fake_llm import FakeLLM
@ -30,10 +32,30 @@ def test_pairwise_string_comparison_chain() -> None:
)
assert res["value"] == "A"
assert res["score"] == 1
with pytest.warns(UserWarning, match=chain._skip_reference_warning):
res = chain.evaluate_string_pairs(
prediction="I like pie.",
prediction_b="I hate pie.",
input="What is your favorite food?",
reference="I enjoy pie.",
)
assert res["value"] == "B"
assert res["score"] == 0
def test_pairwise_string_comparison_chain_missing_ref() -> None:
llm = FakeLLM(
queries={
"a": "The values are the same.\n[[C]]",
"b": "A is clearly better than b.\n[[A]]",
"c": "B is clearly better than a.\n[[B]]",
},
sequential_responses=True,
)
chain = PairwiseStringEvalChain.from_llm(llm=llm, requires_reference=True)
with pytest.raises(ValueError):
chain.evaluate_string_pairs(
prediction="I like pie.",
prediction_b="I love pie.",
input="What is your favorite food?",
)

View File

@ -1,6 +1,8 @@
"""Test the criteria eval chain."""
import pytest
from langchain.evaluation.criteria.eval_chain import (
_SUPPORTED_CRITERIA,
CriteriaEvalChain,
@ -25,11 +27,25 @@ def test_criteria_eval_chain() -> None:
),
criteria={"my criterion": "my criterion description"},
)
with pytest.warns(UserWarning, match=chain._skip_reference_warning):
result = chain.evaluate_strings(
prediction="my prediction", reference="my reference", input="my input"
)
assert result["reasoning"] == "The meaning of life"
def test_criteria_eval_chain_missing_reference() -> None:
chain = CriteriaEvalChain.from_llm(
llm=FakeLLM(
queries={"text": "The meaning of life\nY"},
sequential_responses=True,
),
requires_reference=True,
criteria={"my criterion": "my criterion description"},
)
with pytest.raises(ValueError):
chain.evaluate_strings(prediction="my prediction", input="my input")
def test_implements_string_protocol() -> None:
assert isinstance(CriteriaEvalChain, StringEvaluator)
assert issubclass(CriteriaEvalChain, StringEvaluator)

View File

@ -52,7 +52,7 @@ def test_context_eval_chain(chain_cls: Type[ContextQAEvalChain]) -> None:
def test_implements_string_evaluator_protocol(
chain_cls: Type[LLMChain],
) -> None:
assert isinstance(chain_cls, StringEvaluator)
assert issubclass(chain_cls, StringEvaluator)
@pytest.mark.parametrize("chain_cls", [QAEvalChain, ContextQAEvalChain, CotQAEvalChain])