Mirror of https://github.com/hwchase17/langchain, synced 2024-11-06 03:20:49 +00:00
Add Better Errors for Comparison Chain (#7033)
Also changes the evaluator interfaces to ABCs - this lets us add things like the evaluation name for loading.
This commit is contained in: parent e61cfb6e99, commit ec66d5188c
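At its core, the commit moves the evaluator interfaces from typing.Protocol to ABCs and turns the public evaluate_* methods into template methods: they validate arguments once, then delegate to underscore-prefixed hooks that subclasses implement. A minimal sketch of that pattern (names abridged; a sketch, not the full implementation from the diff below):

from abc import ABC, abstractmethod
from typing import Any, Optional


class StringEvaluatorSketch(ABC):
    @property
    def requires_reference(self) -> bool:
        # Subclasses flip this to True when a reference label is mandatory.
        return False

    @abstractmethod
    def _evaluate_strings(
        self, *, prediction: str, reference: Optional[str] = None, **kwargs: Any
    ) -> dict:
        """Hook: subclasses put the actual evaluation logic here."""

    def evaluate_strings(
        self, *, prediction: str, reference: Optional[str] = None, **kwargs: Any
    ) -> dict:
        # Public entry point: check arguments once, then delegate to the hook.
        if self.requires_reference and reference is None:
            raise ValueError(f"{self.__class__.__name__} requires a reference string.")
        return self._evaluate_strings(
            prediction=prediction, reference=reference, **kwargs
        )

Because the classes are now ABCs rather than runtime-checkable Protocols, subclasses can also carry loadable metadata (such as an evaluation name) that a purely structural protocol could not require.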
@@ -9,6 +9,7 @@ from typing import Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
 from pydantic import Field
 
+from langchain.base_language import BaseLanguageModel
 from langchain.callbacks.manager import (
     AsyncCallbackManagerForChainRun,
     CallbackManagerForChainRun,
@@ -186,10 +187,11 @@ The following is the expected answer. Use this to measure correctness:
     @classmethod
     def from_llm(
         cls,
-        llm: BaseChatModel,
+        llm: BaseLanguageModel,
         agent_tools: Optional[Sequence[BaseTool]] = None,
+        output_parser: Optional[TrajectoryOutputParser] = None,
         return_reasoning: bool = False,
         **kwargs: Any,
     ) -> "TrajectoryEvalChain":
         """Create a TrajectoryEvalChain object from a language model chain.
 
@@ -205,6 +207,10 @@ The following is the expected answer. Use this to measure correctness:
         Returns:
             TrajectoryEvalChain: The TrajectoryEvalChain object.
         """
+        if not isinstance(llm, BaseChatModel):
+            raise NotImplementedError(
+                "Only chat models supported by the current trajectory eval"
+            )
         if agent_tools:
             prompt = EVAL_CHAT_PROMPT
         else:
@@ -215,6 +221,7 @@ The following is the expected answer. Use this to measure correctness:
             return_reasoning=return_reasoning,
             eval_chain=eval_chain,
+            output_parser=output_parser or TrajectoryOutputParser(),
             **kwargs,
         )
 
     @property
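The guard added above makes the trajectory evaluator fail fast instead of producing a confusing downstream error. A self-contained sketch of the same check (ChatModel and CompletionModel here are illustrative stand-ins, not langchain classes):

class ChatModel:
    """Stand-in for langchain's BaseChatModel."""


class CompletionModel:
    """Stand-in for a non-chat LLM."""


def from_llm_sketch(llm: object) -> str:
    # Mirrors the isinstance check added in the hunk above.
    if not isinstance(llm, ChatModel):
        raise NotImplementedError(
            "Only chat models supported by the current trajectory eval"
        )
    return "TrajectoryEvalChain"


try:
    from_llm_sketch(CompletionModel())
except NotImplementedError as err:
    print(err)  # Only chat models supported by the current trajectory eval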
@@ -9,6 +9,7 @@ from langchain.base_language import BaseLanguageModel
 from langchain.callbacks.manager import Callbacks
 from langchain.chains.llm import LLMChain
 from langchain.evaluation.comparison.prompt import PROMPT, PROMPT_WITH_REFERENCE
+from langchain.evaluation.schema import PairwiseStringEvaluator
 from langchain.prompts.prompt import PromptTemplate
 from langchain.schema import BaseOutputParser
@@ -50,7 +51,7 @@ class PairwiseStringResultOutputParser(BaseOutputParser[dict]):
     }
 
 
-class PairwiseStringEvalChain(LLMChain):
+class PairwiseStringEvalChain(PairwiseStringEvaluator, LLMChain):
     """A chain for comparing the output of two models.
 
     Example:
|
||||
default_factory=PairwiseStringResultOutputParser
|
||||
)
|
||||
|
||||
@property
|
||||
def requires_reference(self) -> bool:
|
||||
return "reference" in self.prompt.input_variables
|
||||
|
||||
@property
|
||||
def requires_input(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def _skip_reference_warning(self) -> str:
|
||||
"""Warning to show when reference is ignored."""
|
||||
return (
|
||||
f"Ignoring reference in {self.__class__.__name__}, as it is not expected."
|
||||
"\nTo use a reference, initialize PairwiseStringEvalChain with"
|
||||
" `requires_reference=True` or with a prompt with 'reference' as an"
|
||||
" input variable."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_llm(
|
||||
cls,
|
||||
*,
|
||||
llm: BaseLanguageModel,
|
||||
*,
|
||||
prompt: Optional[PromptTemplate] = None,
|
||||
require_reference: bool = False,
|
||||
requires_reference: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> PairwiseStringEvalChain:
|
||||
"""Initialize the PairwiseStringEvalChain from an LLM.
|
||||
@@ -94,7 +113,7 @@ class PairwiseStringEvalChain(LLMChain):
         Args:
             llm (BaseLanguageModel): The LLM to use.
             prompt (PromptTemplate, optional): The prompt to use.
-            require_reference (bool, optional): Whether to require a reference
+            requires_reference (bool, optional): Whether to require a reference
                 string. Defaults to False.
             **kwargs (Any): Additional keyword arguments.
 
@@ -103,13 +122,13 @@ class PairwiseStringEvalChain(LLMChain):
         """
         expected_input_vars = {"prediction", "prediction_b", "input"}
         if prompt is None:
-            if require_reference:
+            if requires_reference:
                 expected_input_vars.add("reference")
                 prompt_ = PROMPT_WITH_REFERENCE
             else:
                 prompt_ = PROMPT
         else:
-            if require_reference:
+            if requires_reference:
                 expected_input_vars.add("reference")
             prompt_ = prompt
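Restated in plain terms: with no custom prompt, requires_reference selects between the reference-aware and reference-free default prompts; with a custom prompt, the flag only adds "reference" to the variables the prompt is expected to declare. A hedged sketch of that selection (the two PROMPT constants are reduced to plain strings here):

from typing import Optional, Set, Tuple

PROMPT = "Compare {prediction} and {prediction_b} for {input}."  # stand-in text
PROMPT_WITH_REFERENCE = PROMPT + " Reference: {reference}."      # stand-in text


def pick_prompt(prompt: Optional[str], requires_reference: bool) -> Tuple[str, Set[str]]:
    # Mirrors the branching in from_llm above.
    expected_vars = {"prediction", "prediction_b", "input"}
    if prompt is None:
        if requires_reference:
            expected_vars.add("reference")
            return PROMPT_WITH_REFERENCE, expected_vars
        return PROMPT, expected_vars
    if requires_reference:
        expected_vars.add("reference")
    return prompt, expected_vars


print(pick_prompt(None, True)[1])  # includes 'reference' (set order may vary)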
@@ -121,23 +140,32 @@ class PairwiseStringEvalChain(LLMChain):
         return cls(llm=llm, prompt=prompt_, **kwargs)
 
     def _prepare_input(
-        self, prediction: str, prediction_b: str, input: str, reference: Optional[str]
+        self,
+        prediction: str,
+        prediction_b: str,
+        input: Optional[str],
+        reference: Optional[str],
     ) -> dict:
         input_ = {
             "prediction": prediction,
             "prediction_b": prediction_b,
-            "input": input,
         }
-        if reference is not None and "reference" in self.prompt.input_variables:
+        if self.requires_input:
+            if not input:
+                raise ValueError("Input is required for this comparison evaluator")
+            input_["input"] = input
+        if self.requires_reference:
+            if reference is None:
+                raise ValueError("Reference is required for this comparison evaluator")
             input_["reference"] = reference
         return input_
 
-    def evaluate_string_pairs(
+    def _evaluate_string_pairs(
         self,
         *,
         prediction: str,
         prediction_b: str,
-        input: str,
+        input: Optional[str] = None,
         reference: Optional[str] = None,
         callbacks: Callbacks = None,
         **kwargs: Any,
@@ -168,13 +196,13 @@ class PairwiseStringEvalChain(LLMChain):
         )
         return result["text"]
 
-    async def aevaluate_string_pairs(
+    async def _aevaluate_string_pairs(
         self,
         *,
         prediction: str,
         prediction_b: str,
-        input: str,
         reference: Optional[str] = None,
+        input: Optional[str] = None,
         callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> dict:
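The "better errors" in the commit title largely come from the reworked _prepare_input above: a missing input or reference now fails immediately with a targeted message instead of surfacing later as a prompt-formatting error. A standalone sketch of that flow, with the chain's two requires_* properties as plain parameters:

from typing import Optional


def prepare_input_sketch(
    prediction: str,
    prediction_b: str,
    input: Optional[str] = None,
    reference: Optional[str] = None,
    requires_input: bool = True,
    requires_reference: bool = False,
) -> dict:
    input_ = {"prediction": prediction, "prediction_b": prediction_b}
    if requires_input:
        if not input:
            raise ValueError("Input is required for this comparison evaluator")
        input_["input"] = input
    if requires_reference:
        if reference is None:
            raise ValueError("Reference is required for this comparison evaluator")
        input_["reference"] = reference
    return input_


try:
    prepare_input_sketch(
        "I like pie.", "I love pie.",
        input="What is your favorite food?",
        requires_reference=True,  # but no reference supplied
    )
except ValueError as err:
    print(err)  # Reference is required for this comparison evaluator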
@@ -2,12 +2,13 @@ from __future__ import annotations
 
 from typing import Any, Dict, List, Mapping, Optional, Sequence, Union
 
-from pydantic import Field
+from pydantic import Extra, Field
 
 from langchain.base_language import BaseLanguageModel
 from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple
 from langchain.chains.llm import LLMChain
 from langchain.evaluation.criteria.prompt import PROMPT, PROMPT_WITH_REFERENCES
+from langchain.evaluation.schema import StringEvaluator
 from langchain.schema import BaseOutputParser, BasePromptTemplate
 
 _SUPPORTED_CRITERIA = {
@@ -59,7 +60,7 @@ CRITERIA_TYPE = Union[
 ]
 
 
-class CriteriaEvalChain(LLMChain):
+class CriteriaEvalChain(StringEvaluator, LLMChain):
     """LLM Chain for evaluating runs against criteria.
 
     Parameters
@@ -96,11 +97,32 @@ class CriteriaEvalChain(LLMChain):
     >>> chain = CriteriaEvalChain.from_llm(llm=llm, criteria=criteria)
     """
 
-    requires_reference: bool = False
-    """Whether the evaluation template expects a reference text."""
     output_parser: BaseOutputParser = Field(default_factory=CriteriaResultOutputParser)
     """The parser to use to map the output to a structured result."""
 
+    class Config:
+        """Configuration for the QAEvalChain."""
+
+        extra = Extra.ignore
+
+    @property
+    def requires_reference(self) -> bool:
+        return "reference" in self.prompt.input_variables
+
+    @property
+    def requires_input(self) -> bool:
+        return True
+
+    @property
+    def _skip_reference_warning(self) -> str:
+        """Warning to show when reference is ignored."""
+        return (
+            f"Ignoring reference in {self.__class__.__name__}, as it is not expected."
+            "\nTo use a reference, initialize CriteriaEvalChain with"
+            " `require_reference=True` or with a prompt with 'reference'"
+            " as an input variable."
+        )
+
     @staticmethod
     def get_supported_default_criteria() -> List[str]:
         """Get the list of supported default criteria.
@@ -122,7 +144,7 @@ class CriteriaEvalChain(LLMChain):
     @classmethod
     def resolve_criteria(
         cls,
-        criteria: CRITERIA_TYPE,
+        criteria: Optional[CRITERIA_TYPE],
     ) -> Dict[str, str]:
         """Resolve the criteria to evaluate.
 
@@ -148,6 +170,10 @@ class CriteriaEvalChain(LLMChain):
             {'relevance': 'Is the submission referring to a real quote from the text?',
             'coherence': 'Is the submission coherent, well-structured, and organized?'}
         """  # noqa: E501
+        if criteria is None:
+            return {
+                "helpfulness": _SUPPORTED_CRITERIA["helpfulness"],
+            }
         if isinstance(criteria, str):
             criteria_ = {criteria: _SUPPORTED_CRITERIA[criteria]}
         elif isinstance(criteria, ConstitutionalPrinciple):
@@ -172,7 +198,7 @@ class CriteriaEvalChain(LLMChain):
     def from_llm(
         cls,
         llm: BaseLanguageModel,
-        criteria: CRITERIA_TYPE,
+        criteria: Optional[CRITERIA_TYPE] = None,
         *,
         prompt: Optional[BasePromptTemplate] = None,
         requires_reference: bool = False,
@@ -184,7 +210,7 @@ class CriteriaEvalChain(LLMChain):
         ----------
         llm : BaseLanguageModel
             The language model to use for evaluation.
-        criteria : CRITERIA_TYPE
+        criteria : CRITERIA_TYPE - default=None for "helpfulness"
             The criteria to evaluate the runs against. It can be:
                 - a mapping of criterion names to descriptions
                 - a sequence of criterion names
@@ -252,7 +278,7 @@ class CriteriaEvalChain(LLMChain):
             input_["reference"] = reference
         return input_
 
-    def evaluate_strings(
+    def _evaluate_strings(
         self,
         *,
         prediction: str,
@@ -296,7 +322,7 @@ class CriteriaEvalChain(LLMChain):
         input_ = self._get_eval_input(prediction, reference, input)
         return self(input_, **kwargs)["text"]
 
-    async def aevaluate_strings(
+    async def _aevaluate_strings(
         self,
         *,
         prediction: str,
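resolve_criteria's new None branch gives the criteria chain a sensible default. A sketch with an abridged criteria map (_SUPPORTED here is a two-entry stand-in for the real _SUPPORTED_CRITERIA table; the helpfulness text is invented):

_SUPPORTED = {
    "helpfulness": "Is the submission helpful?",
    "relevance": "Is the submission referring to a real quote from the text?",
}


def resolve_criteria_sketch(criteria) -> dict:
    # criteria=None now means "evaluate helpfulness", matching the hunk above.
    if criteria is None:
        return {"helpfulness": _SUPPORTED["helpfulness"]}
    if isinstance(criteria, str):
        return {criteria: _SUPPORTED[criteria]}
    return dict(criteria)  # assume a mapping of name -> description


print(resolve_criteria_sketch(None))         # {'helpfulness': 'Is the submission helpful?'}
print(resolve_criteria_sketch("relevance"))  # {'relevance': 'Is the submission referring ...'}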
@@ -8,6 +8,7 @@ from langchain.base_language import BaseLanguageModel
 from langchain.callbacks.manager import Callbacks
 from langchain.chains.llm import LLMChain
 from langchain.evaluation.qa.eval_prompt import CONTEXT_PROMPT, COT_PROMPT, PROMPT
+from langchain.evaluation.schema import StringEvaluator
 
 
 def _parse_string_eval_output(text: str) -> dict:
@@ -38,9 +39,17 @@ def _parse_string_eval_output(text: str) -> dict:
     }
 
 
-class QAEvalChain(LLMChain):
+class QAEvalChain(LLMChain, StringEvaluator):
     """LLM Chain specifically for evaluating question answering."""
 
+    @property
+    def requires_reference(self) -> bool:
+        return True
+
+    @property
+    def requires_input(self) -> bool:
+        return True
+
     @classmethod
     def from_llm(
         cls, llm: BaseLanguageModel, prompt: PromptTemplate = PROMPT, **kwargs: Any
@@ -90,7 +99,7 @@ class QAEvalChain(LLMChain):
 
         return self.apply(inputs, callbacks=callbacks)
 
-    def evaluate_strings(
+    def _evaluate_strings(
         self,
         *,
         prediction: str,
@@ -118,7 +127,7 @@ class QAEvalChain(LLMChain):
         )[0]
         return _parse_string_eval_output(result["text"])
 
-    async def aevaluate_strings(
+    async def _aevaluate_strings(
         self,
         *,
         prediction: str,
@@ -134,9 +143,17 @@ class QAEvalChain(LLMChain):
         return _parse_string_eval_output(result["text"])
 
 
-class ContextQAEvalChain(LLMChain):
+class ContextQAEvalChain(LLMChain, StringEvaluator):
     """LLM Chain specifically for evaluating QA w/o GT based on context"""
 
+    @property
+    def requires_reference(self) -> bool:
+        return True
+
+    @property
+    def requires_input(self) -> bool:
+        return True
+
     @classmethod
     def _validate_input_vars(cls, prompt: PromptTemplate) -> None:
         expected_input_vars = {"query", "context", "result"}
@@ -193,7 +210,7 @@ class ContextQAEvalChain(LLMChain):
 
         return self.apply(inputs, callbacks=callbacks)
 
-    def evaluate_strings(
+    def _evaluate_strings(
         self,
         *,
         prediction: str,
@@ -208,7 +225,7 @@ class ContextQAEvalChain(LLMChain):
         )[0]
         return _parse_string_eval_output(result["text"])
 
-    async def aevaluate_strings(
+    async def _aevaluate_strings(
         self,
         *,
         prediction: str,
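QAEvalChain and ContextQAEvalChain now advertise that they need both an input (the question) and a reference, so the shared argument check in the schema module, whose diff follows, can raise before any LLM call is made. A compressed sketch of what those two properties buy (QaLikeEvaluator is illustrative, not a langchain class):

class QaLikeEvaluator:
    # Both flags hard-coded to True, as in the hunks above.
    requires_input = True
    requires_reference = True

    def check(self, *, input=None, reference=None) -> None:
        if self.requires_input and input is None:
            raise ValueError(f"{type(self).__name__} requires an input string.")
        if self.requires_reference and reference is None:
            raise ValueError(f"{type(self).__name__} requires a reference string.")


try:
    QaLikeEvaluator().check(input="What is 2+2?")  # no reference supplied
except ValueError as err:
    print(err)  # QaLikeEvaluator requires a reference string.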
@@ -1,14 +1,63 @@
 """Interfaces to be implemented by general evaluators."""
-from abc import abstractmethod
-from typing import Any, Optional, Protocol, runtime_checkable
+from __future__ import annotations
+
+import logging
+from abc import ABC, abstractmethod
+from typing import Any, Optional
+from warnings import warn
+
+logger = logging.getLogger(__name__)
+
+
-@runtime_checkable
-class StringEvaluator(Protocol):
+class _EvalArgsMixin:
+    """Mixin for checking evaluation arguments."""
+
+    @property
+    def requires_reference(self) -> bool:
+        """Whether this evaluator requires a reference label."""
+        return False
+
+    @property
+    def requires_input(self) -> bool:
+        """Whether this evaluator requires an input string."""
+        return False
+
+    @property
+    def _skip_input_warning(self) -> str:
+        """Warning to show when input is ignored."""
+        return f"Ignoring input in {self.__class__.__name__}, as it is not expected."
+
+    @property
+    def _skip_reference_warning(self) -> str:
+        """Warning to show when reference is ignored."""
+        return (
+            f"Ignoring reference in {self.__class__.__name__}, as it is not expected."
+        )
+
+    def _check_evaluation_args(
+        self,
+        reference: Optional[str] = None,
+        input: Optional[str] = None,
+    ) -> None:
+        if self.requires_input and input is None:
+            raise ValueError(f"{self.__class__.__name__} requires an input string.")
+        elif input is not None and not self.requires_input:
+            warn(self._skip_input_warning)
+        else:
+            pass
+        if self.requires_reference and reference is None:
+            raise ValueError(f"{self.__class__.__name__} requires a reference string.")
+        elif reference is not None and not self.requires_reference:
+            warn(self._skip_reference_warning)
+        else:
+            pass
+
+
+class StringEvaluator(_EvalArgsMixin, ABC):
     """Protocol for evaluating strings."""
 
     @abstractmethod
-    def evaluate_strings(
+    def _evaluate_strings(
         self,
         *,
         prediction: str,
@@ -28,7 +77,7 @@ class StringEvaluator(Protocol):
             dict: The evaluation results containing the score or value.
         """
 
-    async def aevaluate_strings(
+    async def _aevaluate_strings(
         self,
         *,
         prediction: str,
@@ -53,13 +102,61 @@ class StringEvaluator(Protocol):
             "async aevaluate_strings method."
         )
 
+    def evaluate_strings(
+        self,
+        *,
+        prediction: str,
+        reference: Optional[str] = None,
+        input: Optional[str] = None,
+        **kwargs: Any,
+    ) -> dict:
+        """Evaluate Chain or LLM output, based on optional input and label.
+
+        Args:
+            prediction (str): the LLM or chain prediction to evaluate.
+            reference (Optional[str], optional): the reference label
+                to evaluate against.
+            input (Optional[str], optional): the input to consider during evaluation
+            **kwargs: additional keyword arguments, including callbacks, tags, etc.
+        Returns:
+            dict: The evaluation results containing the score or value.
+        """
+        self._check_evaluation_args(reference=reference, input=input)
+        return self._evaluate_strings(
+            prediction=prediction, reference=reference, input=input, **kwargs
+        )
+
+    async def aevaluate_strings(
+        self,
+        *,
+        prediction: str,
+        reference: Optional[str] = None,
+        input: Optional[str] = None,
+        **kwargs: Any,
+    ) -> dict:
+        """Asynchronously evaluate Chain or LLM output, based on optional
+        input and label.
+
+        Args:
+            prediction (str): the LLM or chain prediction to evaluate.
+            reference (Optional[str], optional): the reference label
+                to evaluate against.
+            input (Optional[str], optional): the input to consider during evaluation
+            **kwargs: additional keyword arguments, including callbacks, tags, etc.
+        Returns:
+            dict: The evaluation results containing the score or value.
+        """
+        self._check_evaluation_args(reference=reference, input=input)
+        return await self._aevaluate_strings(
+            prediction=prediction, reference=reference, input=input, **kwargs
+        )
+
+
-@runtime_checkable
-class PairwiseStringEvaluator(Protocol):
+class PairwiseStringEvaluator(_EvalArgsMixin, ABC):
     """A protocol for comparing the output of two models."""
 
     @abstractmethod
-    def evaluate_string_pairs(
+    def _evaluate_string_pairs(
         self,
         *,
         prediction: str,
@@ -84,8 +181,9 @@ class PairwiseStringEvaluator(Protocol):
                 other information.
         """
 
-    async def aevaluate_string_pairs(
+    async def _aevaluate_string_pairs(
         self,
         *,
         prediction: str,
         prediction_b: str,
         reference: Optional[str] = None,
@@ -111,3 +209,69 @@ class PairwiseStringEvaluator(Protocol):
             f"{self.__class__.__name__} hasn't implemented an async "
             "aevaluate_string_pairs method."
         )
+
+    def evaluate_string_pairs(
+        self,
+        *,
+        prediction: str,
+        prediction_b: str,
+        reference: Optional[str] = None,
+        input: Optional[str] = None,
+        **kwargs: Any,
+    ) -> dict:
+        """Evaluate the output string pairs.
+
+        Args:
+            prediction (str): The output string from the first model.
+            prediction_b (str): The output string from the second model.
+            reference (str, optional): The expected output / reference
+                string. Defaults to None.
+            input (str, optional): The input string. Defaults to None.
+            **kwargs (Any): Additional keyword arguments, such
+                as callbacks and optional reference strings.
+
+        Returns:
+            dict: A dictionary containing the preference, scores, and/or
+                other information.
+        """
+        self._check_evaluation_args(reference=reference, input=input)
+        return self._evaluate_string_pairs(
+            prediction=prediction,
+            prediction_b=prediction_b,
+            reference=reference,
+            input=input,
+            **kwargs,
+        )
+
+    async def aevaluate_string_pairs(
+        self,
+        *,
+        prediction: str,
+        prediction_b: str,
+        reference: Optional[str] = None,
+        input: Optional[str] = None,
+        **kwargs: Any,
+    ) -> dict:
+        """Evaluate the output string pairs.
+
+        Args:
+            prediction (str): The output string from the first model.
+            prediction_b (str): The output string from the second model.
+            reference (str, optional): The expected output / reference
+                string. Defaults to None.
+            input (str, optional): The input string. Defaults to None.
+            **kwargs (Any): Additional keyword arguments, such
+                as callbacks and optional reference strings.
+
+        Returns:
+            dict: A dictionary containing the preference, scores, and/or
+                other information.
+        """
+        self._check_evaluation_args(reference=reference, input=input)
+        return await self._aevaluate_string_pairs(
+            prediction=prediction,
+            prediction_b=prediction_b,
+            reference=reference,
+            input=input,
+            **kwargs,
+        )
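Note the deliberate asymmetry in _check_evaluation_args above: an argument that is required but missing raises ValueError, while one that is supplied but unused only emits a warning. A runnable sketch of both paths, using only the stdlib (StrictEvaluator and LaxEvaluator are illustrative classes):

import warnings


class StrictEvaluator:
    """Requires an input, like the QA and comparison chains above."""

    requires_input = True

    def _check(self, input=None):
        if self.requires_input and input is None:
            raise ValueError(f"{type(self).__name__} requires an input string.")
        elif input is not None and not self.requires_input:
            warnings.warn(f"Ignoring input in {type(self).__name__}, as it is not expected.")


class LaxEvaluator(StrictEvaluator):
    requires_input = False


try:
    StrictEvaluator()._check()  # required input missing -> ValueError
except ValueError as err:
    print(err)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    LaxEvaluator()._check(input="unused")  # unexpected input -> UserWarning only
    print(caught[0].message)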
@@ -1,13 +1,15 @@
 """Test agent trajectory evaluation chain."""
 
-from typing import List, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 import pytest
+from pydantic import Field
 
+from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.evaluation.agents.trajectory_eval_chain import TrajectoryEvalChain
-from langchain.schema import AgentAction
+from langchain.schema import AgentAction, BaseMessage
 from langchain.tools.base import tool
-from tests.unit_tests.llms.fake_llm import FakeLLM
+from tests.unit_tests.llms.fake_chat_model import FakeChatModel
 
 
 @pytest.fixture
@@ -30,10 +32,31 @@ def foo(bar: str) -> str:
     return bar
 
 
+class _FakeTrajectoryChatModel(FakeChatModel):
+    queries: Dict = Field(default_factory=dict)
+    sequential_responses: Optional[bool] = False
+    response_index: int = 0
+
+    def _call(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> str:
+        if self.sequential_responses:
+            response = self.queries[list(self.queries.keys())[self.response_index]]
+            self.response_index = self.response_index + 1
+            return response
+        else:
+            prompt = messages[0].content
+            return self.queries[prompt]
+
+
 def test_trajectory_eval_chain(
     intermediate_steps: List[Tuple[AgentAction, str]]
 ) -> None:
-    llm = FakeLLM(
+    llm = _FakeTrajectoryChatModel(
         queries={
             "a": "Trajectory good\nScore: 5",
             "b": "Trajectory not good\nScore: 1",
@@ -61,7 +84,7 @@ def test_trajectory_eval_chain(
 def test_trajectory_eval_chain_no_tools(
     intermediate_steps: List[Tuple[AgentAction, str]]
 ) -> None:
-    llm = FakeLLM(
+    llm = _FakeTrajectoryChatModel(
         queries={
             "a": "Trajectory good\nScore: 5",
             "b": "Trajectory not good\nScore: 1",
@@ -85,7 +108,7 @@ def test_trajectory_eval_chain_no_tools(
 
 
 def test_old_api_works(intermediate_steps: List[Tuple[AgentAction, str]]) -> None:
-    llm = FakeLLM(
+    llm = _FakeTrajectoryChatModel(
        queries={
            "a": "Trajectory good\nScore: 5",
            "b": "Trajectory not good\nScore: 1",
@@ -1,6 +1,8 @@
 """Test the comparison chains."""
 
+import pytest
+
 from langchain.evaluation.comparison.eval_chain import PairwiseStringEvalChain
 from tests.unit_tests.llms.fake_llm import FakeLLM
 
@@ -30,10 +32,30 @@ def test_pairwise_string_comparison_chain() -> None:
     )
     assert res["value"] == "A"
     assert res["score"] == 1
+    with pytest.warns(UserWarning, match=chain._skip_reference_warning):
+        res = chain.evaluate_string_pairs(
+            prediction="I like pie.",
+            prediction_b="I hate pie.",
+            input="What is your favorite food?",
+            reference="I enjoy pie.",
+        )
     assert res["value"] == "B"
     assert res["score"] == 0
+
+
+def test_pairwise_string_comparison_chain_missing_ref() -> None:
+    llm = FakeLLM(
+        queries={
+            "a": "The values are the same.\n[[C]]",
+            "b": "A is clearly better than b.\n[[A]]",
+            "c": "B is clearly better than a.\n[[B]]",
+        },
+        sequential_responses=True,
+    )
+    chain = PairwiseStringEvalChain.from_llm(llm=llm, requires_reference=True)
+    with pytest.raises(ValueError):
+        chain.evaluate_string_pairs(
+            prediction="I like pie.",
+            prediction_b="I love pie.",
+            input="What is your favorite food?",
+        )
@@ -1,6 +1,8 @@
 """Test the criteria eval chain."""
 
+import pytest
+
 from langchain.evaluation.criteria.eval_chain import (
     _SUPPORTED_CRITERIA,
     CriteriaEvalChain,
@@ -25,11 +27,25 @@ def test_criteria_eval_chain() -> None:
         ),
         criteria={"my criterion": "my criterion description"},
     )
+    with pytest.warns(UserWarning, match=chain._skip_reference_warning):
+        result = chain.evaluate_strings(
+            prediction="my prediction", reference="my reference", input="my input"
+        )
     assert result["reasoning"] == "The meaning of life"
+
+
+def test_criteria_eval_chain_missing_reference() -> None:
+    chain = CriteriaEvalChain.from_llm(
+        llm=FakeLLM(
+            queries={"text": "The meaning of life\nY"},
+            sequential_responses=True,
+        ),
+        requires_reference=True,
+        criteria={"my criterion": "my criterion description"},
+    )
+    with pytest.raises(ValueError):
+        chain.evaluate_strings(prediction="my prediction", input="my input")
 
 
 def test_implements_string_protocol() -> None:
-    assert isinstance(CriteriaEvalChain, StringEvaluator)
+    assert issubclass(CriteriaEvalChain, StringEvaluator)
@@ -52,7 +52,7 @@ def test_context_eval_chain(chain_cls: Type[ContextQAEvalChain]) -> None:
 def test_implements_string_evaluator_protocol(
     chain_cls: Type[LLMChain],
 ) -> None:
-    assert isinstance(chain_cls, StringEvaluator)
+    assert issubclass(chain_cls, StringEvaluator)
 
 
 @pytest.mark.parametrize("chain_cls", [QAEvalChain, ContextQAEvalChain, CotQAEvalChain])