Make pairwise comparison chain more like LLM as a judge (#11013)

<!-- Thank you for contributing to LangChain!

Replace this entire comment with:
  - **Description:** Adds LLM-as-a-judge as an eval chain
  - **Tag maintainer:** @hwchase17 

Please make sure your PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` to check this
locally.

See contribution guidelines for more information on how to write/run
tests, lint, etc:

https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in `docs/extras`
directory.

If no one reviews your PR within a few days, please @-mention one of
@baskaryan, @eyurtsev, @hwchase17.
 -->

---------

Co-authored-by: William FH <13333726+hinthornw@users.noreply.github.com>
Authored by CG80499, committed via GitHub
parent 175ef0a55d
commit 64385c4eae

langchain/evaluation/comparison/eval_chain.py

@@ -1,12 +1,20 @@
 """Base classes for comparing the output of two models."""
 from __future__ import annotations

+import logging
+import re
 from typing import Any, Dict, List, Optional, Union

 from langchain.callbacks.manager import Callbacks
 from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple
 from langchain.chains.llm import LLMChain
-from langchain.evaluation.comparison.prompt import PROMPT, PROMPT_WITH_REFERENCE
+from langchain.chat_models.azure_openai import AzureChatOpenAI
+from langchain.chat_models.openai import ChatOpenAI
+from langchain.evaluation.comparison.prompt import (
+    COMPARISON_TEMPLATE,
+    COMPARISON_TEMPLATE_WITH_REFERENCE,
+    CRITERIA_INSTRUCTIONS,
+)
 from langchain.evaluation.criteria.eval_chain import (
     CRITERIA_TYPE,
     Criteria,
@@ -17,6 +25,10 @@ from langchain.pydantic_v1 import Extra, Field
 from langchain.schema import RUN_KEY, BaseOutputParser
 from langchain.schema.language_model import BaseLanguageModel

+logger = logging.getLogger(__name__)
+
+_FIND_DOUBLE_BRACKETS = re.compile(r"\[\[(.*?)\]\]")
+
 _SUPPORTED_CRITERIA = {
     Criteria.CONCISENESS: "Is the submission concise and to the point?",
     Criteria.RELEVANCE: "Is the submission referring to a real quote from the text?",
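As an aside, here is a minimal standalone sketch of what the `_FIND_DOUBLE_BRACKETS` pattern added above extracts from a judge response; the sample string is invented for illustration:

```python
import re

# Same pattern as above: capture whatever sits inside the first [[...]] pair.
_FIND_DOUBLE_BRACKETS = re.compile(r"\[\[(.*?)\]\]")

sample = "Assistant A answers the question directly and cites the source.\n[[A]]"
match = _FIND_DOUBLE_BRACKETS.search(sample)
print(match.group(1) if match else None)  # -> A
```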
@@ -112,27 +124,26 @@ class PairwiseStringResultOutputParser(BaseOutputParser[dict]):
             ValueError: If the verdict is invalid.
         """
-        parsed = text.strip().rsplit("\n", maxsplit=1)
-        if len(parsed) == 1:
-            reasoning = ""
-            verdict = parsed[0]
-        else:
-            reasoning, verdict = parsed
-        verdict = verdict.strip("[").strip("]")
-        if verdict not in {"A", "B", "C"}:
+        match = _FIND_DOUBLE_BRACKETS.search(text)
+
+        if match:
+            verdict = match.group(1)
+
+        if not match or verdict not in {"A", "B", "C"}:
             raise ValueError(
-                f"Invalid verdict: {verdict}. "
-                "Verdict must be one of 'A', 'B', or 'C'."
+                f"Invalid output: {text}. "
+                "Output must contain a double bracketed string\
+                 with the verdict 'A', 'B', or 'C'."
             )
         # C means the models are tied. Return 'None' meaning no preference
         verdict_ = None if verdict == "C" else verdict
         score = {
             "A": 1,
             "B": 0,
-            None: 0.5,
-        }.get(verdict_)
+            "C": 0.5,
+        }[verdict]
         return {
-            "reasoning": reasoning,
+            "reasoning": text,
             "value": verdict_,
             "score": score,
         }
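In practice, the reworked parser keeps the full judge response as the reasoning and only extracts the bracketed verdict. A hedged usage sketch (the import path follows this diff; the judge output is invented):

```python
from langchain.evaluation.comparison.eval_chain import PairwiseStringResultOutputParser

parser = PairwiseStringResultOutputParser()

# The whole response is preserved as "reasoning"; only [[A]]/[[B]]/[[C]] drives value/score.
result = parser.parse("Assistant B is more accurate and complete.\n[[B]]")
print(result)
# {'reasoning': 'Assistant B is more accurate and complete.\n[[B]]', 'value': 'B', 'score': 0}

# Output without a double-bracketed verdict now raises ValueError instead of
# silently returning a malformed verdict.
```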
@@ -225,7 +236,7 @@ class PairwiseStringEvalChain(PairwiseStringEvaluator, LLMEvalChain, LLMChain):
         """Initialize the PairwiseStringEvalChain from an LLM.

         Args:
-            llm (BaseLanguageModel): The LLM to use.
+            llm (BaseChatModel): The LLM to use (GPT-4 recommended).
             prompt (PromptTemplate, optional): The prompt to use.
             **kwargs (Any): Additional keyword arguments.
@@ -236,8 +247,17 @@ class PairwiseStringEvalChain(PairwiseStringEvaluator, LLMEvalChain, LLMChain):
             ValueError: If the input variables are not as expected.

         """
+        if not (
+            isinstance(llm, (ChatOpenAI, AzureChatOpenAI))
+            and llm.model_name.startswith("gpt-4")
+        ):
+            logger.warning(
+                "This chain was only tested with GPT-4. \
+Performance may be significantly worse with other models."
+            )
+
         expected_input_vars = {"prediction", "prediction_b", "input", "criteria"}
-        prompt_ = prompt or PROMPT
+        prompt_ = prompt or COMPARISON_TEMPLATE.partial(reference="")
         if expected_input_vars != set(prompt_.input_variables):
             raise ValueError(
                 f"Input variables should be {expected_input_vars}, "
@@ -245,6 +265,7 @@ class PairwiseStringEvalChain(PairwiseStringEvaluator, LLMEvalChain, LLMChain):
             )
         criteria_ = resolve_pairwise_criteria(criteria)
         criteria_str = "\n".join(f"{k}: {v}" if v else k for k, v in criteria_.items())
+        criteria_str = CRITERIA_INSTRUCTIONS + criteria_str if criteria_str else ""
         return cls(llm=llm, prompt=prompt_.partial(criteria=criteria_str), **kwargs)

     def _prepare_input(
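Taken together, a possible end-to-end sketch of the updated `from_llm` path. The model name, criteria, and example strings are illustrative, and an OpenAI API key is assumed to be configured:

```python
from langchain.chat_models import ChatOpenAI
from langchain.evaluation.comparison.eval_chain import PairwiseStringEvalChain

# GPT-4 is recommended; any other model now triggers the logger warning added above.
llm = ChatOpenAI(model_name="gpt-4", temperature=0)
chain = PairwiseStringEvalChain.from_llm(llm=llm, criteria="conciseness")

result = chain.evaluate_string_pairs(
    prediction="Paris is the capital of France.",
    prediction_b="The capital of France is Paris, a city also famous for the Eiffel Tower and the Louvre.",
    input="What is the capital of France?",
)
print(result["value"], result["score"], result["reasoning"])
```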
@@ -418,7 +439,7 @@ class LabeledPairwiseStringEvalChain(PairwiseStringEvalChain):
             "reference",
             "criteria",
         }
-        prompt_ = prompt or PROMPT_WITH_REFERENCE
+        prompt_ = prompt or COMPARISON_TEMPLATE_WITH_REFERENCE
         if expected_input_vars != set(prompt_.input_variables):
             raise ValueError(
                 f"Input variables should be {expected_input_vars}, "
@@ -426,4 +447,5 @@ class LabeledPairwiseStringEvalChain(PairwiseStringEvalChain):
             )
         criteria_ = resolve_pairwise_criteria(criteria)
         criteria_str = "\n".join(f"{k}: {v}" for k, v in criteria_.items())
+        criteria_str = CRITERIA_INSTRUCTIONS + criteria_str if criteria_str else ""
         return cls(llm=llm, prompt=prompt_.partial(criteria=criteria_str), **kwargs)
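The reference-grounded variant can also be reached through the evaluation loader. A hedged sketch, assuming the `labeled_pairwise_string` evaluator name from this release; the question, answers, and reference are invented:

```python
from langchain.chat_models import ChatOpenAI
from langchain.evaluation import load_evaluator

evaluator = load_evaluator(
    "labeled_pairwise_string",
    llm=ChatOpenAI(model_name="gpt-4", temperature=0),
    criteria="correctness",
)

result = evaluator.evaluate_string_pairs(
    prediction="There are 3 prime numbers below 5.",
    prediction_b="There are 2 prime numbers below 5: 2 and 3.",
    input="How many prime numbers are there below 5?",
    reference="Only 2 and 3 are prime and below 5, so the answer is 2.",
)
print(result["value"])  # the verdict still depends on the judge model's answer
```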

langchain/evaluation/comparison/prompt.py

@@ -5,64 +5,55 @@ and answers the question. The prompt is based on the paper from
 Zheng, et. al. https://arxiv.org/abs/2306.05685
 """
 # flake8: noqa
-from langchain.prompts import PromptTemplate
+from langchain.prompts.chat import ChatPromptTemplate

-template = """Act as a fair judge and rate the two responses to the question below.\
- Choose the response that best followed the instructions and answered the question.\
- Your assessment should weigh the following criteria:
-{criteria}\
- Start by comparing both responses and give a brief rationale.\
- Avoid bias from the order of presentation or response length.
-After giving your rationale, make your final decision using this format:\
- "[[A]]" if assistant A is better, "[[B]]" if assistant B is better,\
- and "[[C]]" for a tie. Finally, repeat the decision again on its own on a new line.
-
-[QUESTION]
-{input}
-[/QUESTION]
-
-[RESPONSE A]
-{prediction}
-[/RESPONSE A]
-
-[RESPONSE B]
-{prediction_b}
-[/RESPONSE B]"""
-PROMPT = PromptTemplate(
-    input_variables=["input", "prediction", "prediction_b", "criteria"],
-    template=template,
-)
+SYSTEM_MESSAGE = 'Please act as an impartial judge and evaluate the quality \
+of the responses provided by two AI assistants to the user question displayed below. \
+You should choose the assistant that follows the user\'s instructions \
+and answers \the user\'s question better. \
+Your evaluation should consider factors such as the \
+helpfulness, relevance, accuracy, depth, creativity, \
+and level of detail of their responses. \
+Begin your evaluation by comparing the two responses and provide a short explanation. \
+Avoid any position biases and ensure that the order in which \
+the responses were presented does not influence your decision. \
+Do not allow the length of the responses to influence your evaluation. \
+Do not favor certain names of the assistants. Be as objective as possible. \
+After providing your explanation, output your final verdict by strictly following \
+this format: "[[A]]" if assistant A is better, "[[B]]" if assistant B is better, \
+and "[[C]]" for a tie.'
+
+CRITERIA_INSTRUCTIONS = (
+    "For this evaluation, you should primarily consider the following criteria:\n"
+)

-template = """Act as a fair judge and rate the two responses to the question below.\
- Choose the response that best followed the instructions and answered the question.\
- Your assessment should weigh the following criteria:
-{criteria}\
- Start by comparing both responses and give a brief rationale.\
- Avoid bias from the order of presentation or response length.\
- Weigh accuracy based on the following ground truth reference\
- answer to the question:
-
-[REFERENCE]
-{reference}
-[/REFERENCE]
-
-After giving your rationale, make your final decision using this format:\
- "[[A]]" if assistant A is better, "[[B]]" if assistant B is better,\
- and "[[C]]" for a tie. Finally, repeat the decision again on its own on a new line.
-
-[QUESTION]
-{input}
-[/QUESTION]
-
-[RESPONSE A]
-{prediction}
-[/RESPONSE A]
-
-[RESPONSE B]
-{prediction_b}
-[/RESPONSE B]"""
-PROMPT_WITH_REFERENCE = PromptTemplate(
-    input_variables=["input", "prediction", "prediction_b", "reference", "criteria"],
-    template=template,
-)
+COMPARISON_TEMPLATE = ChatPromptTemplate.from_messages(
+    [
+        ("system", SYSTEM_MESSAGE),
+        (
+            "human",
+            "{criteria}[User Question]\n{input}\n\n\
+[The Start of Assistant A's Answer]\n{prediction}\n\
+[The End of Assistant A's Answer]\
+\n\n[The Start of Assistant B's Answer]\n{prediction_b}\n\
+[The End of Assistant B's Answer]",
+        ),
+    ]
+)
+
+COMPARISON_TEMPLATE_WITH_REFERENCE = ChatPromptTemplate.from_messages(
+    [
+        ("system", SYSTEM_MESSAGE),
+        (
+            "human",
+            "{criteria}\n\nTo help you evaluate the responses, \
+here is a reference answer to the user's question:\n\
+{reference}\
+[User Question]\n{input}\n\n\
+[The Start of Assistant A's Answer]\n{prediction}\n\
+[The End of Assistant A's Answer]\
+\n\n[The Start of Assistant B's Answer]\n{prediction_b}\n\
+[The End of Assistant B's Answer]",
+        ),
+    ]
+)
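To see what the new chat-style prompt actually sends, here is a small sketch that formats `COMPARISON_TEMPLATE` into messages, mirroring the `partial(reference="")` call in `from_llm`; the question and answers are placeholders:

```python
from langchain.evaluation.comparison.prompt import COMPARISON_TEMPLATE

# Mirrors from_llm: blank out the unused reference slot and pass an empty criteria block.
prompt = COMPARISON_TEMPLATE.partial(reference="")
messages = prompt.format_messages(
    criteria="",
    input="What is 2 + 2?",
    prediction="4",
    prediction_b="2 + 2 equals 4; you can check this on a number line.",
)
for message in messages:
    print(message.type, "->", message.content[:80])
```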

tests/unit_tests/evaluation/comparison/test_eval_chain.py

@@ -34,7 +34,7 @@ def test_PairwiseStringResultOutputParser_parse() -> None:
 [[A]]"""
     got = output_parser.parse(text)
     want = {
-        "reasoning": "I like pie better than cake.",
+        "reasoning": text,
         "value": "A",
         "score": 1,
     }
@@ -46,7 +46,7 @@ def test_PairwiseStringResultOutputParser_parse() -> None:
 [[B]]"""
     got = output_parser.parse(text)
     want = {
-        "reasoning": "I like cake better than pie.",
+        "reasoning": text,
         "value": "B",
         "score": 0,
     }
@@ -58,7 +58,7 @@ def test_PairwiseStringResultOutputParser_parse() -> None:
 [[C]]"""
     got = output_parser.parse(text)
     want = {
-        "reasoning": "I like cake and pie.",
+        "reasoning": text,
         "value": None,
         "score": 0.5,
     }
@@ -84,7 +84,7 @@ def test_pairwise_string_comparison_chain() -> None:
     )
     assert res["value"] is None
    assert res["score"] == 0.5
-    assert res["reasoning"] == "The values are the same."
+    assert res["reasoning"] == "The values are the same.\n[[C]]"
     res = chain.evaluate_string_pairs(
         prediction="I like pie.",
         prediction_b="I like pie.",
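The tie case tested above maps "[[C]]" to no preference. A minimal check of that mapping against the new parser (import path as in this diff):

```python
from langchain.evaluation.comparison.eval_chain import PairwiseStringResultOutputParser

parser = PairwiseStringResultOutputParser()
text = "The values are the same.\n[[C]]"

# [[C]] means a tie: value is None, score is 0.5, and the full text is kept as reasoning.
assert parser.parse(text) == {"reasoning": text, "value": None, "score": 0.5}
```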
