From e736d60516d757a6b892fe183a42077b4c497979 Mon Sep 17 00:00:00 2001
From: William FH <13333726+hinthornw@users.noreply.github.com>
Date: Thu, 6 Jul 2023 13:58:58 -0700
Subject: [PATCH] Load Evaluator (#6942)

Create a `load_evaluators()` function so you don't have to import all the
individual evaluator classes
---
 langchain/evaluation/__init__.py              |  45 +++++---
 .../agents/trajectory_eval_chain.py           |  12 +-
 langchain/evaluation/comparison/eval_chain.py |  11 +-
 langchain/evaluation/criteria/eval_chain.py   |   4 +-
 langchain/evaluation/loading.py               | 104 +++++++++++++++++-
 langchain/evaluation/qa/eval_chain.py         |  13 ++-
 langchain/evaluation/schema.py                |  33 ++++++
 tests/unit_tests/evaluation/test_loading.py   |  16 +++
 8 files changed, 209 insertions(+), 29 deletions(-)
 create mode 100644 tests/unit_tests/evaluation/test_loading.py

diff --git a/langchain/evaluation/__init__.py b/langchain/evaluation/__init__.py
index b6dfc5027c..eb1c4b64da 100644
--- a/langchain/evaluation/__init__.py
+++ b/langchain/evaluation/__init__.py
@@ -1,33 +1,45 @@
-"""Functionality relating to evaluation.
+"""Evaluation chains for grading LLM and Chain outputs.
 
-This module contains off-the-shelf evaluation chains for
-grading the output of LangChain primitives such as LLMs and Chains.
+This module contains off-the-shelf evaluation chains for grading the output of
+LangChain primitives such as language models and chains.
+
+To load an evaluator, you can use the :func:`load_evaluators` function with the
+names of the evaluators to load.
+
+To load one of the LangChain HuggingFace datasets, you can use the :func:`load_dataset` function with the
+name of the dataset to load.
 
 Some common use cases for evaluation include:
 
-- Grading accuracy of a response against ground truth answers: QAEvalChain
-- Comparing the output of two models: PairwiseStringEvalChain
-- Judging the efficacy of an agent's tool usage: TrajectoryEvalChain
-- Checking whether an output complies with a set of criteria: CriteriaEvalChain
+- Grading the accuracy of a response against ground truth answers: :class:`QAEvalChain`
+- Comparing the output of two models: :class:`PairwiseStringEvalChain`
+- Judging the efficacy of an agent's tool usage: :class:`TrajectoryEvalChain`
+- Checking whether an output complies with a set of criteria: :class:`CriteriaEvalChain`
 
-This module also contains low level APIs for making more evaluators for your
-custom evaluation task. These include:
-- StringEvaluator: Evaluates an output string against a reference and/or
-  with input context.
-- PairwiseStringEvaluator: Evaluates two strings against each other.
-"""
+This module also contains low-level APIs for creating custom evaluators for
+specific evaluation tasks. These include:
 
-from langchain.evaluation.agents.trajectory_eval_chain import TrajectoryEvalChain
+- :class:`StringEvaluator`: Evaluate a prediction string against a reference label and/or input context.
+- :class:`PairwiseStringEvaluator`: Evaluate two prediction strings against each other.
+  Useful for scoring preferences, measuring similarity between two chains or LLM agents, or comparing outputs on similar inputs.
+- :class:`AgentTrajectoryEvaluator`: Evaluate the full sequence of actions
+  taken by an agent.
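+
+For example, to load a couple of evaluators (a minimal sketch; the default LLM
+is a gpt-4 ChatOpenAI model, so an OpenAI API key is assumed to be configured):
+
+.. code-block:: python
+
+    from langchain.evaluation import EvaluatorType, load_evaluators
+
+    evaluators = load_evaluators(
+        [EvaluatorType.QA, EvaluatorType.CRITERIA], criteria="helpfulness"
+    )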
+ +""" # noqa: E501 +from langchain.evaluation.agents import TrajectoryEvalChain from langchain.evaluation.comparison import PairwiseStringEvalChain -from langchain.evaluation.criteria.eval_chain import CriteriaEvalChain +from langchain.evaluation.criteria import CriteriaEvalChain +from langchain.evaluation.loading import load_dataset, load_evaluator, load_evaluators from langchain.evaluation.qa import ContextQAEvalChain, CotQAEvalChain, QAEvalChain from langchain.evaluation.schema import ( AgentTrajectoryEvaluator, + EvaluatorType, PairwiseStringEvaluator, StringEvaluator, ) __all__ = [ + "EvaluatorType", "PairwiseStringEvalChain", "QAEvalChain", "CotQAEvalChain", @@ -36,5 +48,8 @@ __all__ = [ "PairwiseStringEvaluator", "TrajectoryEvalChain", "CriteriaEvalChain", + "load_evaluators", + "load_evaluator", + "load_dataset", "AgentTrajectoryEvaluator", ] diff --git a/langchain/evaluation/agents/trajectory_eval_chain.py b/langchain/evaluation/agents/trajectory_eval_chain.py index dfd9a44c1e..b0249da226 100644 --- a/langchain/evaluation/agents/trajectory_eval_chain.py +++ b/langchain/evaluation/agents/trajectory_eval_chain.py @@ -7,7 +7,7 @@ chain (LLMChain) to generate the reasoning and scores. from typing import Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union -from pydantic import Field +from pydantic import Extra, Field from langchain.base_language import BaseLanguageModel from langchain.callbacks.manager import ( @@ -15,14 +15,13 @@ from langchain.callbacks.manager import ( CallbackManagerForChainRun, Callbacks, ) -from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.chat_models.base import BaseChatModel from langchain.evaluation.agents.trajectory_eval_prompt import ( EVAL_CHAT_PROMPT, TOOL_FREE_EVAL_CHAT_PROMPT, ) -from langchain.evaluation.schema import AgentTrajectoryEvaluator +from langchain.evaluation.schema import AgentTrajectoryEvaluator, LLMEvalChain from langchain.schema import AgentAction, BaseOutputParser, OutputParserException from langchain.tools.base import BaseTool @@ -71,7 +70,7 @@ class TrajectoryOutputParser(BaseOutputParser): return TrajectoryEval(score=int(score_str), reasoning=reasoning) -class TrajectoryEvalChain(AgentTrajectoryEvaluator, Chain): +class TrajectoryEvalChain(AgentTrajectoryEvaluator, LLMEvalChain): """A chain for evaluating ReAct style agents. This chain is used to evaluate ReAct style agents by reasoning about @@ -125,6 +124,11 @@ class TrajectoryEvalChain(AgentTrajectoryEvaluator, Chain): return_reasoning: bool = False """Whether to return the reasoning along with the score.""" + class Config: + """Configuration for the QAEvalChain.""" + + extra = Extra.ignore + @property def _tools_description(self) -> str: """Get the description of the agent tools. 
diff --git a/langchain/evaluation/comparison/eval_chain.py b/langchain/evaluation/comparison/eval_chain.py
index 7022004e61..4fc1c978d9 100644
--- a/langchain/evaluation/comparison/eval_chain.py
+++ b/langchain/evaluation/comparison/eval_chain.py
@@ -3,13 +3,13 @@ from __future__ import annotations
 
 from typing import Any, Optional
 
-from pydantic import Field
+from pydantic import Extra, Field
 
 from langchain.base_language import BaseLanguageModel
 from langchain.callbacks.manager import Callbacks
 from langchain.chains.llm import LLMChain
 from langchain.evaluation.comparison.prompt import PROMPT, PROMPT_WITH_REFERENCE
-from langchain.evaluation.schema import PairwiseStringEvaluator
+from langchain.evaluation.schema import LLMEvalChain, PairwiseStringEvaluator
 from langchain.prompts.prompt import PromptTemplate
 from langchain.schema import BaseOutputParser
 
@@ -51,7 +51,7 @@ class PairwiseStringResultOutputParser(BaseOutputParser[dict]):
         }
 
 
-class PairwiseStringEvalChain(PairwiseStringEvaluator, LLMChain):
+class PairwiseStringEvalChain(PairwiseStringEvaluator, LLMEvalChain, LLMChain):
     """A chain for comparing the output of two models.
 
     Example:
@@ -81,6 +81,11 @@ class PairwiseStringEvalChain(PairwiseStringEvaluator, LLMChain):
         default_factory=PairwiseStringResultOutputParser
     )
 
+    class Config:
+        """Configuration for the PairwiseStringEvalChain."""
+
+        extra = Extra.ignore
+
     @property
     def requires_reference(self) -> bool:
         return "reference" in self.prompt.input_variables
diff --git a/langchain/evaluation/criteria/eval_chain.py b/langchain/evaluation/criteria/eval_chain.py
index 067fb6c543..8270bbaa24 100644
--- a/langchain/evaluation/criteria/eval_chain.py
+++ b/langchain/evaluation/criteria/eval_chain.py
@@ -8,7 +8,7 @@ from langchain.base_language import BaseLanguageModel
 from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple
 from langchain.chains.llm import LLMChain
 from langchain.evaluation.criteria.prompt import PROMPT, PROMPT_WITH_REFERENCES
-from langchain.evaluation.schema import StringEvaluator
+from langchain.evaluation.schema import LLMEvalChain, StringEvaluator
 from langchain.schema import BaseOutputParser, BasePromptTemplate
 
 _SUPPORTED_CRITERIA = {
@@ -60,7 +60,7 @@ CRITERIA_TYPE = Union[
 ]
 
 
-class CriteriaEvalChain(StringEvaluator, LLMChain):
+class CriteriaEvalChain(StringEvaluator, LLMEvalChain, LLMChain):
     """LLM Chain for evaluating runs against criteria.
 
     Parameters
diff --git a/langchain/evaluation/loading.py b/langchain/evaluation/loading.py
index 613e261303..27b8b79abd 100644
--- a/langchain/evaluation/loading.py
+++ b/langchain/evaluation/loading.py
@@ -1,8 +1,108 @@
-from typing import Dict, List
+"""Loading datasets and evaluators."""
+from typing import Any, Dict, List, Optional, Sequence, Type
+
+from langchain.base_language import BaseLanguageModel
+from langchain.chains.base import Chain
+from langchain.chat_models.openai import ChatOpenAI
+from langchain.evaluation.agents.trajectory_eval_chain import TrajectoryEvalChain
+from langchain.evaluation.comparison import PairwiseStringEvalChain
+from langchain.evaluation.criteria.eval_chain import CriteriaEvalChain
+from langchain.evaluation.qa import ContextQAEvalChain, CotQAEvalChain, QAEvalChain
+from langchain.evaluation.schema import EvaluatorType, LLMEvalChain
 
 
 def load_dataset(uri: str) -> List[Dict]:
-    from datasets import load_dataset
+    """Load a dataset from the LangChainDatasets HuggingFace org."""
+    try:
+        from datasets import load_dataset
+    except ImportError:
+        raise ImportError(
+            "load_dataset requires the `datasets` package."
+            " Please install with `pip install datasets`"
+        )
 
     dataset = load_dataset(f"LangChainDatasets/{uri}")
     return [d for d in dataset["train"]]
+
+
+_EVALUATOR_MAP: Dict[EvaluatorType, Type[LLMEvalChain]] = {
+    EvaluatorType.QA: QAEvalChain,
+    EvaluatorType.COT_QA: CotQAEvalChain,
+    EvaluatorType.CONTEXT_QA: ContextQAEvalChain,
+    EvaluatorType.PAIRWISE_STRING: PairwiseStringEvalChain,
+    EvaluatorType.AGENT_TRAJECTORY: TrajectoryEvalChain,
+    EvaluatorType.CRITERIA: CriteriaEvalChain,
+}
+
+
+def load_evaluator(
+    evaluator: EvaluatorType,
+    *,
+    llm: Optional[BaseLanguageModel] = None,
+    **kwargs: Any,
+) -> Chain:
+    """Load the requested evaluation chain specified by a string.
+
+    Parameters
+    ----------
+    evaluator : EvaluatorType
+        The type of evaluator to load.
+    llm : BaseLanguageModel, optional
+        The language model to use for evaluation, by default None
+    **kwargs : Any
+        Additional keyword arguments to pass to the evaluator.
+
+    Returns
+    -------
+    Chain
+        The loaded evaluation chain.
+
+    Examples
+    --------
+    >>> llm = ChatOpenAI(model="gpt-4", temperature=0)
+    >>> evaluator = load_evaluator(EvaluatorType.QA, llm=llm)
+    """
+    llm = llm or ChatOpenAI(model="gpt-4", temperature=0)
+    return _EVALUATOR_MAP[evaluator].from_llm(llm=llm, **kwargs)
+
+
+def load_evaluators(
+    evaluators: Sequence[EvaluatorType],
+    *,
+    llm: Optional[BaseLanguageModel] = None,
+    config: Optional[dict] = None,
+    **kwargs: Any,
+) -> List[Chain]:
+    """Load evaluators specified by a list of evaluator types.
+
+    Parameters
+    ----------
+    evaluators : Sequence[EvaluatorType]
+        The list of evaluator types to load.
+    llm : BaseLanguageModel, optional
+        The language model to use for evaluation. If none is provided, a default
+        gpt-4 ChatOpenAI model will be used.
+    config : dict, optional
+        A dictionary mapping evaluator types to additional keyword arguments,
+        by default None
+    **kwargs : Any
+        Additional keyword arguments to pass to all evaluators.
+
+    Returns
+    -------
+    List[Chain]
+        The loaded evaluators.
+
+    Examples
+    --------
+    .. code-block:: python
+        from langchain.evaluation import load_evaluators, EvaluatorType
+        evaluators = [EvaluatorType.QA, EvaluatorType.CRITERIA]
+        loaded_evaluators = load_evaluators(evaluators, criteria="helpfulness")
+    """
+    llm = llm or ChatOpenAI(model="gpt-4", temperature=0)
+    loaded = []
+    for evaluator in evaluators:
+        _kwargs = config.get(evaluator, {}) if config else {}
+        loaded.append(load_evaluator(evaluator, llm=llm, **{**kwargs, **_kwargs}))
+    return loaded
diff --git a/langchain/evaluation/qa/eval_chain.py b/langchain/evaluation/qa/eval_chain.py
index 3a13ec8e11..725539616c 100644
--- a/langchain/evaluation/qa/eval_chain.py
+++ b/langchain/evaluation/qa/eval_chain.py
@@ -3,12 +3,14 @@ from __future__ import annotations
 
 from typing import Any, List, Optional, Sequence
 
+from pydantic import Extra
+
 from langchain import PromptTemplate
 from langchain.base_language import BaseLanguageModel
 from langchain.callbacks.manager import Callbacks
 from langchain.chains.llm import LLMChain
 from langchain.evaluation.qa.eval_prompt import CONTEXT_PROMPT, COT_PROMPT, PROMPT
-from langchain.evaluation.schema import StringEvaluator
+from langchain.evaluation.schema import LLMEvalChain, StringEvaluator
 
 
 def _parse_string_eval_output(text: str) -> dict:
@@ -39,9 +41,14 @@ def _parse_string_eval_output(text: str) -> dict:
     }
 
 
-class QAEvalChain(LLMChain, StringEvaluator):
+class QAEvalChain(LLMChain, StringEvaluator, LLMEvalChain):
     """LLM Chain specifically for evaluating question answering."""
 
+    class Config:
+        """Configuration for the QAEvalChain."""
+
+        extra = Extra.ignore
+
     @property
     def requires_reference(self) -> bool:
         return True
@@ -143,7 +150,7 @@ class QAEvalChain(LLMChain, StringEvaluator):
         return _parse_string_eval_output(result["text"])
 
 
-class ContextQAEvalChain(LLMChain, StringEvaluator):
+class ContextQAEvalChain(LLMChain, StringEvaluator, LLMEvalChain):
     """LLM Chain specifically for evaluating QA w/o GT based on context"""
 
     @property
diff --git a/langchain/evaluation/schema.py b/langchain/evaluation/schema.py
index bd6351a5c3..57fea9e6cc 100644
--- a/langchain/evaluation/schema.py
+++ b/langchain/evaluation/schema.py
@@ -3,14 +3,47 @@ from __future__ import annotations
 
 import logging
 from abc import ABC, abstractmethod
+from enum import Enum
 from typing import Any, Optional, Sequence, Tuple
 from warnings import warn
 
+from langchain.base_language import BaseLanguageModel
+from langchain.chains.base import Chain
 from langchain.schema.agent import AgentAction
 
 logger = logging.getLogger(__name__)
 
 
+class EvaluatorType(str, Enum):
+    """The types of the evaluators."""
+
+    QA = "qa"
+    """Question answering evaluator, which grades answers to questions
+    directly using an LLM."""
+    COT_QA = "cot_qa"
+    """Chain of thought question answering evaluator, which grades
+    answers to questions using
+    chain of thought 'reasoning'."""
+    CONTEXT_QA = "context_qa"
+    """Question answering evaluator that incorporates 'context' in the response."""
+    PAIRWISE_STRING = "pairwise_string"
+    """The pairwise string evaluator, which compares the output of two models."""
+    AGENT_TRAJECTORY = "trajectory"
+    """The agent trajectory evaluator, which grades the agent's intermediate steps."""
+    CRITERIA = "criteria"
+    """The criteria evaluator, which evaluates a model based on a
+    custom set of criteria."""
+
+
+class LLMEvalChain(Chain):
+    """A base class for evaluators that use an LLM."""
+
+    @classmethod
+    @abstractmethod
+    def from_llm(cls, llm: BaseLanguageModel, **kwargs: Any) -> LLMEvalChain:
+        """Create a new evaluator from an LLM."""
+
+
 class _EvalArgsMixin:
     """Mixin for checking evaluation arguments."""
 
diff --git a/tests/unit_tests/evaluation/test_loading.py b/tests/unit_tests/evaluation/test_loading.py
new file mode 100644
index 0000000000..27c538d8b3
--- /dev/null
+++ b/tests/unit_tests/evaluation/test_loading.py
@@ -0,0 +1,16 @@
+"""Test the loading function for evaluators."""
+
+import pytest
+
+from langchain.evaluation.loading import EvaluatorType, load_evaluators
+from tests.unit_tests.llms.fake_chat_model import FakeChatModel
+
+
+@pytest.mark.parametrize("evaluator_type", EvaluatorType)
+def test_load_evaluators(evaluator_type: EvaluatorType) -> None:
+    """Test loading evaluators."""
+    fake_llm = FakeChatModel()
+    load_evaluators([evaluator_type], llm=fake_llm)
+
+    # Test as string
+    load_evaluators([evaluator_type.value], llm=fake_llm)  # type: ignore
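
Example usage of the loading API introduced above (a rough sketch; it assumes an
OpenAI API key is configured, since `load_evaluator` falls back to a gpt-4
`ChatOpenAI` model when no `llm` is passed):

    from langchain.chat_models import ChatOpenAI
    from langchain.evaluation import EvaluatorType, load_evaluator, load_evaluators

    llm = ChatOpenAI(model="gpt-4", temperature=0)

    # Load a single evaluator...
    qa_evaluator = load_evaluator(EvaluatorType.QA, llm=llm)

    # ...or several at once; `config` maps an evaluator type to extra keyword
    # arguments for that specific evaluator.
    evaluators = load_evaluators(
        [EvaluatorType.QA, EvaluatorType.CRITERIA],
        llm=llm,
        config={EvaluatorType.CRITERIA: {"criteria": "helpfulness"}},
    )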