Explicitly list requires_reference in function (#7357)

pull/7372/head
William FH 1 year ago committed by GitHub
parent 49b2b0e3c0
commit 38ca5c84cb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -251,17 +251,24 @@ class CriteriaEvalChain(StringEvaluator, LLMEvalChain, LLMChain):
requires_reference=True,
)
"""
expected_input_vars = {"input", "output", "criteria"}
if prompt is None:
if requires_reference:
prompt = PROMPT_WITH_REFERENCES
else:
prompt = PROMPT
if requires_reference:
expected_input_vars.add("reference")
if expected_input_vars != set(prompt.input_variables):
raise ValueError(
f"Input variables should be {expected_input_vars}, "
f"but got {prompt.input_variables}"
)
criteria_ = cls.resolve_criteria(criteria)
criteria_str = " ".join(f"{k}: {v}" for k, v in criteria_.items())
prompt_ = prompt.partial(criteria=criteria_str)
return cls(
llm=llm, prompt=prompt_, requires_reference=requires_reference, **kwargs
)
return cls(llm=llm, prompt=prompt_, **kwargs)
def _get_eval_input(
self,

@ -14,7 +14,6 @@ from langchain.evaluation.criteria.eval_chain import (
CriteriaEvalChain,
CriteriaResultOutputParser,
)
from langchain.evaluation.criteria.prompt import PROMPT as CRITERIA_PROMPT
from langchain.evaluation.qa.eval_chain import QAEvalChain
from langchain.evaluation.qa.eval_prompt import PROMPT as QA_DEFAULT_PROMPT
from langchain.evaluation.qa.eval_prompt import SQL_PROMPT
@ -152,8 +151,9 @@ def get_criteria_evaluator(
*,
input_key: str = "input",
prediction_key: str = "output",
prompt: BasePromptTemplate = CRITERIA_PROMPT,
prompt: Optional[BasePromptTemplate] = None,
evaluation_name: Optional[str] = None,
requires_reference: bool = False,
**kwargs: Any,
) -> RunEvaluatorChain:
"""Get an eval chain for grading a model's response against a map of criteria."""
@ -174,7 +174,11 @@ def get_criteria_evaluator(
)
tags = kwargs.pop("tags", [])
eval_chain = CriteriaEvalChain.from_llm(
llm=llm, criteria=criteria_, prompt=prompt, **kwargs
llm=llm,
criteria=criteria_,
prompt=prompt,
requires_reference=requires_reference,
**kwargs,
)
return RunEvaluatorChain(
eval_chain=eval_chain,

Loading…
Cancel
Save