From ec66d5188c768423e5b90113e3807d7a01ce84b8 Mon Sep 17 00:00:00 2001
From: William FH <13333726+hinthornw@users.noreply.github.com>
Date: Thu, 6 Jul 2023 06:37:04 -0700
Subject: [PATCH] Add Better Errors for Comparison Chain (#7033)

+ change to ABC - this lets us add things like the evaluation name for loading
---
 .../agents/trajectory_eval_chain.py           |   9 +-
 langchain/evaluation/comparison/eval_chain.py |  54 +++--
 langchain/evaluation/criteria/eval_chain.py   |  44 ++++-
 langchain/evaluation/qa/eval_chain.py         |  29 ++-
 langchain/evaluation/schema.py                | 184 +++++++++++++++++-
 .../evaluation/agents/test_eval_chain.py      |  35 +++-
 .../evaluation/comparison/test_eval_chain.py  |  32 ++-
 .../evaluation/criteria/test_eval_chain.py    |  24 ++-
 .../evaluation/qa/test_eval_chain.py          |   2 +-
 9 files changed, 358 insertions(+), 55 deletions(-)

diff --git a/langchain/evaluation/agents/trajectory_eval_chain.py b/langchain/evaluation/agents/trajectory_eval_chain.py
index 6d0f07e40f..546f5d3b8e 100644
--- a/langchain/evaluation/agents/trajectory_eval_chain.py
+++ b/langchain/evaluation/agents/trajectory_eval_chain.py
@@ -9,6 +9,7 @@ from typing import Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
 
 from pydantic import Field
 
+from langchain.base_language import BaseLanguageModel
 from langchain.callbacks.manager import (
     AsyncCallbackManagerForChainRun,
     CallbackManagerForChainRun,
@@ -186,10 +187,11 @@ The following is the expected answer. Use this to measure correctness:
     @classmethod
     def from_llm(
         cls,
-        llm: BaseChatModel,
+        llm: BaseLanguageModel,
         agent_tools: Optional[Sequence[BaseTool]] = None,
         output_parser: Optional[TrajectoryOutputParser] = None,
         return_reasoning: bool = False,
+        **kwargs: Any,
     ) -> "TrajectoryEvalChain":
         """Create a TrajectoryEvalChain object from a language model chain.
 
@@ -205,6 +207,10 @@ The following is the expected answer. Use this to measure correctness:
         Returns:
             TrajectoryEvalChain: The TrajectoryEvalChain object.
         """
+        if not isinstance(llm, BaseChatModel):
+            raise NotImplementedError(
+                "Only chat models are supported by the current trajectory eval"
+            )
         if agent_tools:
             prompt = EVAL_CHAT_PROMPT
         else:
@@ -215,6 +221,7 @@ The following is the expected answer. Use this to measure correctness:
             return_reasoning=return_reasoning,
             eval_chain=eval_chain,
             output_parser=output_parser or TrajectoryOutputParser(),
+            **kwargs,
         )
 
     @property
diff --git a/langchain/evaluation/comparison/eval_chain.py b/langchain/evaluation/comparison/eval_chain.py
index 3c3f4f666b..7022004e61 100644
--- a/langchain/evaluation/comparison/eval_chain.py
+++ b/langchain/evaluation/comparison/eval_chain.py
@@ -9,6 +9,7 @@ from langchain.base_language import BaseLanguageModel
 from langchain.callbacks.manager import Callbacks
 from langchain.chains.llm import LLMChain
 from langchain.evaluation.comparison.prompt import PROMPT, PROMPT_WITH_REFERENCE
+from langchain.evaluation.schema import PairwiseStringEvaluator
 from langchain.prompts.prompt import PromptTemplate
 from langchain.schema import BaseOutputParser
 
@@ -50,7 +51,7 @@ class PairwiseStringResultOutputParser(BaseOutputParser[dict]):
         }
 
 
-class PairwiseStringEvalChain(LLMChain):
+class PairwiseStringEvalChain(PairwiseStringEvaluator, LLMChain):
     """A chain for comparing the output of two models.
     Example:
@@ -80,13 +81,31 @@ class PairwiseStringEvalChain(LLMChain):
         default_factory=PairwiseStringResultOutputParser
     )
 
+    @property
+    def requires_reference(self) -> bool:
+        return "reference" in self.prompt.input_variables
+
+    @property
+    def requires_input(self) -> bool:
+        return True
+
+    @property
+    def _skip_reference_warning(self) -> str:
+        """Warning to show when reference is ignored."""
+        return (
+            f"Ignoring reference in {self.__class__.__name__}, as it is not expected."
+            "\nTo use a reference, initialize PairwiseStringEvalChain with"
+            " `requires_reference=True` or with a prompt with 'reference' as an"
+            " input variable."
+        )
+
     @classmethod
     def from_llm(
         cls,
-        *,
         llm: BaseLanguageModel,
+        *,
         prompt: Optional[PromptTemplate] = None,
-        require_reference: bool = False,
+        requires_reference: bool = False,
         **kwargs: Any,
     ) -> PairwiseStringEvalChain:
         """Initialize the PairwiseStringEvalChain from an LLM.
@@ -94,7 +113,7 @@ class PairwiseStringEvalChain(LLMChain):
         Args:
             llm (BaseLanguageModel): The LLM to use.
             prompt (PromptTemplate, optional): The prompt to use.
-            require_reference (bool, optional): Whether to require a reference
+            requires_reference (bool, optional): Whether to require a reference
                 string. Defaults to False.
             **kwargs (Any): Additional keyword arguments.
 
@@ -103,13 +122,13 @@ class PairwiseStringEvalChain(LLMChain):
         """
         expected_input_vars = {"prediction", "prediction_b", "input"}
         if prompt is None:
-            if require_reference:
+            if requires_reference:
                 expected_input_vars.add("reference")
                 prompt_ = PROMPT_WITH_REFERENCE
             else:
                 prompt_ = PROMPT
         else:
-            if require_reference:
+            if requires_reference:
                 expected_input_vars.add("reference")
             prompt_ = prompt
 
@@ -121,23 +140,32 @@ class PairwiseStringEvalChain(LLMChain):
         return cls(llm=llm, prompt=prompt_, **kwargs)
 
     def _prepare_input(
-        self, prediction: str, prediction_b: str, input: str, reference: Optional[str]
+        self,
+        prediction: str,
+        prediction_b: str,
+        input: Optional[str],
+        reference: Optional[str],
     ) -> dict:
         input_ = {
             "prediction": prediction,
             "prediction_b": prediction_b,
-            "input": input,
         }
-        if reference is not None and "reference" in self.prompt.input_variables:
+        if self.requires_input:
+            if not input:
+                raise ValueError("Input is required for this comparison evaluator")
+            input_["input"] = input
+        if self.requires_reference:
+            if reference is None:
+                raise ValueError("Reference is required for this comparison evaluator")
             input_["reference"] = reference
         return input_
 
-    def evaluate_string_pairs(
+    def _evaluate_string_pairs(
        self,
        *,
        prediction: str,
        prediction_b: str,
-        input: str,
+        input: Optional[str] = None,
        reference: Optional[str] = None,
        callbacks: Callbacks = None,
        **kwargs: Any,
@@ -168,13 +196,13 @@ class PairwiseStringEvalChain(LLMChain):
         )
         return result["text"]
 
-    async def aevaluate_string_pairs(
+    async def _aevaluate_string_pairs(
         self,
         *,
         prediction: str,
         prediction_b: str,
-        input: str,
         reference: Optional[str] = None,
+        input: Optional[str] = None,
         callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> dict:
diff --git a/langchain/evaluation/criteria/eval_chain.py b/langchain/evaluation/criteria/eval_chain.py
index c7d6a1f7c4..067fb6c543 100644
--- a/langchain/evaluation/criteria/eval_chain.py
+++ b/langchain/evaluation/criteria/eval_chain.py
@@ -2,12 +2,13 @@ from __future__ import annotations
 
 from typing import Any, Dict, List, Mapping, Optional, Sequence, Union
 
-from pydantic import Field
+from pydantic import Extra, Field
 
 from langchain.base_language import BaseLanguageModel
 from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple
 from langchain.chains.llm import LLMChain
 from langchain.evaluation.criteria.prompt import PROMPT, PROMPT_WITH_REFERENCES
+from langchain.evaluation.schema import StringEvaluator
 from langchain.schema import BaseOutputParser, BasePromptTemplate
 
 _SUPPORTED_CRITERIA = {
@@ -59,7 +60,7 @@ CRITERIA_TYPE = Union[
 ]
 
 
-class CriteriaEvalChain(LLMChain):
+class CriteriaEvalChain(StringEvaluator, LLMChain):
     """LLM Chain for evaluating runs against criteria.
 
     Parameters
@@ -96,11 +97,32 @@ class CriteriaEvalChain(LLMChain):
     >>> chain = CriteriaEvalChain.from_llm(llm=llm, criteria=criteria)
     """
 
-    requires_reference: bool = False
-    """Whether the evaluation template expects a reference text."""
     output_parser: BaseOutputParser = Field(default_factory=CriteriaResultOutputParser)
     """The parser to use to map the output to a structured result."""
 
+    class Config:
+        """Configuration for the CriteriaEvalChain."""
+
+        extra = Extra.ignore
+
+    @property
+    def requires_reference(self) -> bool:
+        return "reference" in self.prompt.input_variables
+
+    @property
+    def requires_input(self) -> bool:
+        return True
+
+    @property
+    def _skip_reference_warning(self) -> str:
+        """Warning to show when reference is ignored."""
+        return (
+            f"Ignoring reference in {self.__class__.__name__}, as it is not expected."
+            "\nTo use a reference, initialize CriteriaEvalChain with"
+            " `requires_reference=True` or with a prompt with 'reference'"
+            " as an input variable."
+        )
+
     @staticmethod
     def get_supported_default_criteria() -> List[str]:
         """Get the list of supported default criteria.
@@ -122,7 +144,7 @@ class CriteriaEvalChain(LLMChain):
     @classmethod
     def resolve_criteria(
         cls,
-        criteria: CRITERIA_TYPE,
+        criteria: Optional[CRITERIA_TYPE],
     ) -> Dict[str, str]:
         """Resolve the criteria to evaluate.
 
@@ -148,6 +170,10 @@ class CriteriaEvalChain(LLMChain):
         {'relevance': 'Is the submission referring to a real quote from the text?',
          'coherence': 'Is the submission coherent, well-structured, and organized?'}
         """  # noqa: E501
+        if criteria is None:
+            return {
+                "helpfulness": _SUPPORTED_CRITERIA["helpfulness"],
+            }
         if isinstance(criteria, str):
             criteria_ = {criteria: _SUPPORTED_CRITERIA[criteria]}
         elif isinstance(criteria, ConstitutionalPrinciple):
@@ -172,7 +198,7 @@ class CriteriaEvalChain(LLMChain):
     def from_llm(
         cls,
         llm: BaseLanguageModel,
-        criteria: CRITERIA_TYPE,
+        criteria: Optional[CRITERIA_TYPE] = None,
         *,
         prompt: Optional[BasePromptTemplate] = None,
         requires_reference: bool = False,
@@ -184,7 +210,7 @@ class CriteriaEvalChain(LLMChain):
         ----------
         llm : BaseLanguageModel
             The language model to use for evaluation.
-        criteria : CRITERIA_TYPE
+        criteria : Optional[CRITERIA_TYPE], default=None (uses the "helpfulness" criterion)
             The criteria to evaluate the runs against. It can be:
                - a mapping of criterion names to descriptions
                - a sequence of criterion names
@@ -252,7 +278,7 @@ class CriteriaEvalChain(LLMChain):
             input_["reference"] = reference
         return input_
 
-    def evaluate_strings(
+    def _evaluate_strings(
         self,
         *,
         prediction: str,
@@ -296,7 +322,7 @@ class CriteriaEvalChain(LLMChain):
         input_ = self._get_eval_input(prediction, reference, input)
         return self(input_, **kwargs)["text"]
 
-    async def aevaluate_strings(
+    async def _aevaluate_strings(
         self,
         *,
         prediction: str,
diff --git a/langchain/evaluation/qa/eval_chain.py b/langchain/evaluation/qa/eval_chain.py
index bba4c29826..3a13ec8e11 100644
--- a/langchain/evaluation/qa/eval_chain.py
+++ b/langchain/evaluation/qa/eval_chain.py
@@ -8,6 +8,7 @@ from langchain.base_language import BaseLanguageModel
 from langchain.callbacks.manager import Callbacks
 from langchain.chains.llm import LLMChain
 from langchain.evaluation.qa.eval_prompt import CONTEXT_PROMPT, COT_PROMPT, PROMPT
+from langchain.evaluation.schema import StringEvaluator
 
 
 def _parse_string_eval_output(text: str) -> dict:
@@ -38,9 +39,17 @@ def _parse_string_eval_output(text: str) -> dict:
     }
 
 
-class QAEvalChain(LLMChain):
+class QAEvalChain(LLMChain, StringEvaluator):
     """LLM Chain specifically for evaluating question answering."""
 
+    @property
+    def requires_reference(self) -> bool:
+        return True
+
+    @property
+    def requires_input(self) -> bool:
+        return True
+
     @classmethod
     def from_llm(
         cls, llm: BaseLanguageModel, prompt: PromptTemplate = PROMPT, **kwargs: Any
@@ -90,7 +99,7 @@ class QAEvalChain(LLMChain):
 
         return self.apply(inputs, callbacks=callbacks)
 
-    def evaluate_strings(
+    def _evaluate_strings(
         self,
         *,
         prediction: str,
@@ -118,7 +127,7 @@ class QAEvalChain(LLMChain):
         )[0]
         return _parse_string_eval_output(result["text"])
 
-    async def aevaluate_strings(
+    async def _aevaluate_strings(
         self,
         *,
         prediction: str,
@@ -134,9 +143,17 @@ class QAEvalChain(LLMChain):
         return _parse_string_eval_output(result["text"])
 
 
-class ContextQAEvalChain(LLMChain):
+class ContextQAEvalChain(LLMChain, StringEvaluator):
     """LLM Chain specifically for evaluating QA w/o GT based on context"""
 
+    @property
+    def requires_reference(self) -> bool:
+        return True
+
+    @property
+    def requires_input(self) -> bool:
+        return True
+
     @classmethod
     def _validate_input_vars(cls, prompt: PromptTemplate) -> None:
         expected_input_vars = {"query", "context", "result"}
@@ -193,7 +210,7 @@ class ContextQAEvalChain(LLMChain):
 
         return self.apply(inputs, callbacks=callbacks)
 
-    def evaluate_strings(
+    def _evaluate_strings(
         self,
         *,
         prediction: str,
@@ -208,7 +225,7 @@ class ContextQAEvalChain(LLMChain):
         )[0]
         return _parse_string_eval_output(result["text"])
 
-    async def aevaluate_strings(
+    async def _aevaluate_strings(
         self,
         *,
         prediction: str,
diff --git a/langchain/evaluation/schema.py b/langchain/evaluation/schema.py
index 8c3362088e..4bcfc51307 100644
--- a/langchain/evaluation/schema.py
+++ b/langchain/evaluation/schema.py
@@ -1,14 +1,63 @@
 """Interfaces to be implemented by general evaluators."""
-from abc import abstractmethod
-from typing import Any, Optional, Protocol, runtime_checkable
+from __future__ import annotations
+
+import logging
+from abc import ABC, abstractmethod
+from typing import Any, Optional
+from warnings import warn
 
-@runtime_checkable
-class StringEvaluator(Protocol):
+logger = logging.getLogger(__name__)
+
+
+class _EvalArgsMixin:
+    """Mixin for checking evaluation arguments."""
+
+    @property
+    def requires_reference(self) -> bool:
+        """Whether this evaluator requires a reference label."""
+        return False
+
+    @property
+    def requires_input(self) -> bool:
+        """Whether this evaluator requires an input string."""
+        return False
+
+    @property
+    def _skip_input_warning(self) -> str:
+        """Warning to show when input is ignored."""
+        return f"Ignoring input in {self.__class__.__name__}, as it is not expected."
+
+    @property
+    def _skip_reference_warning(self) -> str:
+        """Warning to show when reference is ignored."""
+        return (
+            f"Ignoring reference in {self.__class__.__name__}, as it is not expected."
+        )
+
+    def _check_evaluation_args(
+        self,
+        reference: Optional[str] = None,
+        input: Optional[str] = None,
+    ) -> None:
+        if self.requires_input and input is None:
+            raise ValueError(f"{self.__class__.__name__} requires an input string.")
+        elif input is not None and not self.requires_input:
+            warn(self._skip_input_warning)
+        else:
+            pass
+        if self.requires_reference and reference is None:
+            raise ValueError(f"{self.__class__.__name__} requires a reference string.")
+        elif reference is not None and not self.requires_reference:
+            warn(self._skip_reference_warning)
+        else:
+            pass
+
+
+class StringEvaluator(_EvalArgsMixin, ABC):
     """Protocol for evaluating strings."""
 
     @abstractmethod
-    def evaluate_strings(
+    def _evaluate_strings(
         self,
         *,
         prediction: str,
@@ -28,7 +77,7 @@
             dict: The evaluation results containing the score or value.
         """
 
-    async def aevaluate_strings(
+    async def _aevaluate_strings(
         self,
         *,
         prediction: str,
@@ -53,13 +102,61 @@
             "async aevaluate_strings method."
         )
 
+    def evaluate_strings(
+        self,
+        *,
+        prediction: str,
+        reference: Optional[str] = None,
+        input: Optional[str] = None,
+        **kwargs: Any,
+    ) -> dict:
+        """Evaluate Chain or LLM output, based on optional input and label.
+
+        Args:
+            prediction (str): the LLM or chain prediction to evaluate.
+            reference (Optional[str], optional): the reference label
+                to evaluate against.
+            input (Optional[str], optional): the input to consider during evaluation
+            **kwargs: additional keyword arguments, including callbacks, tags, etc.
+        Returns:
+            dict: The evaluation results containing the score or value.
+        """
+        self._check_evaluation_args(reference=reference, input=input)
+        return self._evaluate_strings(
+            prediction=prediction, reference=reference, input=input, **kwargs
+        )
+
+    async def aevaluate_strings(
+        self,
+        *,
+        prediction: str,
+        reference: Optional[str] = None,
+        input: Optional[str] = None,
+        **kwargs: Any,
+    ) -> dict:
+        """Asynchronously evaluate Chain or LLM output, based on optional
+        input and label.
+
+        Args:
+            prediction (str): the LLM or chain prediction to evaluate.
+            reference (Optional[str], optional): the reference label
+                to evaluate against.
+            input (Optional[str], optional): the input to consider during evaluation
+            **kwargs: additional keyword arguments, including callbacks, tags, etc.
+        Returns:
+            dict: The evaluation results containing the score or value.
+        """
+        self._check_evaluation_args(reference=reference, input=input)
+        return await self._aevaluate_strings(
+            prediction=prediction, reference=reference, input=input, **kwargs
+        )
+
 
-@runtime_checkable
-class PairwiseStringEvaluator(Protocol):
+class PairwiseStringEvaluator(_EvalArgsMixin, ABC):
     """A protocol for comparing the output of two models."""
 
     @abstractmethod
-    def evaluate_string_pairs(
+    def _evaluate_string_pairs(
         self,
         *,
         prediction: str,
         prediction_b: str,
@@ -84,8 +181,9 @@
             other information.
""" - async def aevaluate_string_pairs( + async def _aevaluate_string_pairs( self, + *, prediction: str, prediction_b: str, reference: Optional[str] = None, @@ -111,3 +209,69 @@ class PairwiseStringEvaluator(Protocol): f"{self.__class__.__name__} hasn't implemented an async " "aevaluate_string_pairs method." ) + + def evaluate_string_pairs( + self, + *, + prediction: str, + prediction_b: str, + reference: Optional[str] = None, + input: Optional[str] = None, + **kwargs: Any, + ) -> dict: + """Evaluate the output string pairs. + + Args: + prediction (str): The output string from the first model. + prediction_b (str): The output string from the second model. + reference (str, optional): The expected output / reference + string. Defaults to None. + input (str, optional): The input string. Defaults to None. + **kwargs (Any): Additional keyword arguments, such + as callbacks and optional reference strings. + + Returns: + dict: A dictionary containing the preference, scores, and/or + other information. + """ + self._check_evaluation_args(reference=reference, input=input) + return self._evaluate_string_pairs( + prediction=prediction, + prediction_b=prediction_b, + reference=reference, + input=input, + **kwargs, + ) + + async def aevaluate_string_pairs( + self, + *, + prediction: str, + prediction_b: str, + reference: Optional[str] = None, + input: Optional[str] = None, + **kwargs: Any, + ) -> dict: + """Evaluate the output string pairs. + + Args: + prediction (str): The output string from the first model. + prediction_b (str): The output string from the second model. + reference (str, optional): The expected output / reference + string. Defaults to None. + input (str, optional): The input string. Defaults to None. + **kwargs (Any): Additional keyword arguments, such + as callbacks and optional reference strings. + + Returns: + dict: A dictionary containing the preference, scores, and/or + other information. 
+ """ + self._check_evaluation_args(reference=reference, input=input) + return await self._aevaluate_string_pairs( + prediction=prediction, + prediction_b=prediction_b, + reference=reference, + input=input, + **kwargs, + ) diff --git a/tests/unit_tests/evaluation/agents/test_eval_chain.py b/tests/unit_tests/evaluation/agents/test_eval_chain.py index 59fa3de017..c8c84ae574 100644 --- a/tests/unit_tests/evaluation/agents/test_eval_chain.py +++ b/tests/unit_tests/evaluation/agents/test_eval_chain.py @@ -1,13 +1,15 @@ """Test agent trajectory evaluation chain.""" -from typing import List, Tuple +from typing import Any, Dict, List, Optional, Tuple import pytest +from pydantic import Field +from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.evaluation.agents.trajectory_eval_chain import TrajectoryEvalChain -from langchain.schema import AgentAction +from langchain.schema import AgentAction, BaseMessage from langchain.tools.base import tool -from tests.unit_tests.llms.fake_llm import FakeLLM +from tests.unit_tests.llms.fake_chat_model import FakeChatModel @pytest.fixture @@ -30,10 +32,31 @@ def foo(bar: str) -> str: return bar +class _FakeTrajectoryChatModel(FakeChatModel): + queries: Dict = Field(default_factory=dict) + sequential_responses: Optional[bool] = False + response_index: int = 0 + + def _call( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> str: + if self.sequential_responses: + response = self.queries[list(self.queries.keys())[self.response_index]] + self.response_index = self.response_index + 1 + return response + else: + prompt = messages[0].content + return self.queries[prompt] + + def test_trajectory_eval_chain( intermediate_steps: List[Tuple[AgentAction, str]] ) -> None: - llm = FakeLLM( + llm = _FakeTrajectoryChatModel( queries={ "a": "Trajectory good\nScore: 5", "b": "Trajectory not good\nScore: 1", @@ -61,7 +84,7 @@ def test_trajectory_eval_chain( def test_trajectory_eval_chain_no_tools( intermediate_steps: List[Tuple[AgentAction, str]] ) -> None: - llm = FakeLLM( + llm = _FakeTrajectoryChatModel( queries={ "a": "Trajectory good\nScore: 5", "b": "Trajectory not good\nScore: 1", @@ -85,7 +108,7 @@ def test_trajectory_eval_chain_no_tools( def test_old_api_works(intermediate_steps: List[Tuple[AgentAction, str]]) -> None: - llm = FakeLLM( + llm = _FakeTrajectoryChatModel( queries={ "a": "Trajectory good\nScore: 5", "b": "Trajectory not good\nScore: 1", diff --git a/tests/unit_tests/evaluation/comparison/test_eval_chain.py b/tests/unit_tests/evaluation/comparison/test_eval_chain.py index 4a96b43e18..4eb2508a9e 100644 --- a/tests/unit_tests/evaluation/comparison/test_eval_chain.py +++ b/tests/unit_tests/evaluation/comparison/test_eval_chain.py @@ -1,6 +1,8 @@ """Test the comparison chains.""" +import pytest + from langchain.evaluation.comparison.eval_chain import PairwiseStringEvalChain from tests.unit_tests.llms.fake_llm import FakeLLM @@ -30,10 +32,30 @@ def test_pairwise_string_comparison_chain() -> None: ) assert res["value"] == "A" assert res["score"] == 1 - res = chain.evaluate_string_pairs( - prediction="I like pie.", - prediction_b="I hate pie.", - input="What is your favorite food?", - ) + with pytest.warns(UserWarning, match=chain._skip_reference_warning): + res = chain.evaluate_string_pairs( + prediction="I like pie.", + prediction_b="I hate pie.", + input="What is your favorite food?", + reference="I enjoy pie.", + ) assert 
     assert res["score"] == 0
+
+
+def test_pairwise_string_comparison_chain_missing_ref() -> None:
+    llm = FakeLLM(
+        queries={
+            "a": "The values are the same.\n[[C]]",
+            "b": "A is clearly better than b.\n[[A]]",
+            "c": "B is clearly better than a.\n[[B]]",
+        },
+        sequential_responses=True,
+    )
+    chain = PairwiseStringEvalChain.from_llm(llm=llm, requires_reference=True)
+    with pytest.raises(ValueError):
+        chain.evaluate_string_pairs(
+            prediction="I like pie.",
+            prediction_b="I love pie.",
+            input="What is your favorite food?",
+        )
diff --git a/tests/unit_tests/evaluation/criteria/test_eval_chain.py b/tests/unit_tests/evaluation/criteria/test_eval_chain.py
index f978fa70e7..56a892cecc 100644
--- a/tests/unit_tests/evaluation/criteria/test_eval_chain.py
+++ b/tests/unit_tests/evaluation/criteria/test_eval_chain.py
@@ -1,6 +1,8 @@
 """Test the criteria eval chain."""
 
+import pytest
+
 from langchain.evaluation.criteria.eval_chain import (
     _SUPPORTED_CRITERIA,
     CriteriaEvalChain,
@@ -25,11 +27,25 @@ def test_criteria_eval_chain() -> None:
         ),
         criteria={"my criterion": "my criterion description"},
     )
-    result = chain.evaluate_strings(
-        prediction="my prediction", reference="my reference", input="my input"
-    )
+    with pytest.warns(UserWarning, match=chain._skip_reference_warning):
+        result = chain.evaluate_strings(
+            prediction="my prediction", reference="my reference", input="my input"
+        )
     assert result["reasoning"] == "The meaning of life"
 
 
+def test_criteria_eval_chain_missing_reference() -> None:
+    chain = CriteriaEvalChain.from_llm(
+        llm=FakeLLM(
+            queries={"text": "The meaning of life\nY"},
+            sequential_responses=True,
+        ),
+        requires_reference=True,
+        criteria={"my criterion": "my criterion description"},
+    )
+    with pytest.raises(ValueError):
+        chain.evaluate_strings(prediction="my prediction", input="my input")
+
+
 def test_implements_string_protocol() -> None:
-    assert isinstance(CriteriaEvalChain, StringEvaluator)
+    assert issubclass(CriteriaEvalChain, StringEvaluator)
diff --git a/tests/unit_tests/evaluation/qa/test_eval_chain.py b/tests/unit_tests/evaluation/qa/test_eval_chain.py
index 514fd28757..acdad692e3 100644
--- a/tests/unit_tests/evaluation/qa/test_eval_chain.py
+++ b/tests/unit_tests/evaluation/qa/test_eval_chain.py
@@ -52,7 +52,7 @@ def test_context_eval_chain(chain_cls: Type[ContextQAEvalChain]) -> None:
 def test_implements_string_evaluator_protocol(
     chain_cls: Type[LLMChain],
 ) -> None:
-    assert isinstance(chain_cls, StringEvaluator)
+    assert issubclass(chain_cls, StringEvaluator)
 
 
 @pytest.mark.parametrize("chain_cls", [QAEvalChain, ContextQAEvalChain, CotQAEvalChain])
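
Usage note (editorial, not part of the patch): after this change, `StringEvaluator` and
`PairwiseStringEvaluator` are ABCs whose public `evaluate_strings` / `evaluate_string_pairs`
wrappers run `_check_evaluation_args` before delegating to the underscored abstract methods,
so every subclass gets the same missing-argument errors and ignored-argument warnings.
A minimal sketch of a custom evaluator built on the new interface follows;
`ExactMatchEvaluator` is a hypothetical illustration, not part of this commit:

    from typing import Any, Optional

    from langchain.evaluation.schema import StringEvaluator


    class ExactMatchEvaluator(StringEvaluator):
        """Hypothetical evaluator: scores 1 if the prediction matches the reference."""

        @property
        def requires_reference(self) -> bool:
            # Opting in here makes the inherited evaluate_strings wrapper
            # raise a ValueError whenever it is called without a reference.
            return True

        def _evaluate_strings(
            self,
            *,
            prediction: str,
            reference: Optional[str] = None,
            input: Optional[str] = None,
            **kwargs: Any,
        ) -> dict:
            # The public wrapper has already run _check_evaluation_args, so
            # reference is present by the time this is called; the fallback
            # to "" only satisfies the Optional[str] type.
            return {"score": int(prediction.strip() == (reference or "").strip())}


    evaluator = ExactMatchEvaluator()
    evaluator.evaluate_strings(prediction="4", reference="4")  # -> {"score": 1}
    evaluator.evaluate_strings(prediction="4")  # ValueError: requires a reference string

This is the same template-method pattern the patch applies to the QA, criteria, and
pairwise chains above: subclasses implement only `_evaluate_strings`, and argument
validation lives in one place in the base class.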