Explicitly list requires_reference in function (#7357)

pull/7372/head
Authored by William FH 1 year ago, committed by GitHub
parent 49b2b0e3c0
commit 38ca5c84cb

@@ -251,17 +251,24 @@ class CriteriaEvalChain(StringEvaluator, LLMEvalChain, LLMChain):
                 requires_reference=True,
             )
         """
+        expected_input_vars = {"input", "output", "criteria"}
         if prompt is None:
             if requires_reference:
                 prompt = PROMPT_WITH_REFERENCES
             else:
                 prompt = PROMPT
+        if requires_reference:
+            expected_input_vars.add("reference")
+        if expected_input_vars != set(prompt.input_variables):
+            raise ValueError(
+                f"Input variables should be {expected_input_vars}, "
+                f"but got {prompt.input_variables}"
+            )
         criteria_ = cls.resolve_criteria(criteria)
         criteria_str = " ".join(f"{k}: {v}" for k, v in criteria_.items())
         prompt_ = prompt.partial(criteria=criteria_str)
-        return cls(
-            llm=llm, prompt=prompt_, requires_reference=requires_reference, **kwargs
-        )
+        return cls(llm=llm, prompt=prompt_, **kwargs)

     def _get_eval_input(
         self,
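The new expected_input_vars check makes a mismatch between a custom prompt and requires_reference fail fast inside from_llm, instead of surfacing as a missing-variable error at generation time. A minimal sketch of that behaviour, assuming langchain's module layout at the time of this commit and using FakeListLLM so no API key is needed:

```python
from langchain.evaluation.criteria.eval_chain import CriteriaEvalChain
from langchain.llms.fake import FakeListLLM
from langchain.prompts import PromptTemplate

# Stand-in grading LLM; any BaseLanguageModel would do.
llm = FakeListLLM(responses=["The submission is concise. Y"])

# Exposes exactly {"input", "output", "criteria"}, so it passes the new check.
ok_prompt = PromptTemplate.from_template(
    "Criteria: {criteria}\nInput: {input}\nSubmission: {output}\nVerdict:"
)
chain = CriteriaEvalChain.from_llm(
    llm=llm,
    criteria={"conciseness": "Is the submission concise and to the point?"},
    prompt=ok_prompt,
)

# With requires_reference=True the expected set also contains "reference",
# so the same prompt is now rejected before the chain is ever built.
try:
    CriteriaEvalChain.from_llm(
        llm=llm,
        criteria={"conciseness": "Is the submission concise and to the point?"},
        prompt=ok_prompt,
        requires_reference=True,
    )
except ValueError as err:
    print(err)  # "Input variables should be ..., but got ..."
```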

@@ -14,7 +14,6 @@ from langchain.evaluation.criteria.eval_chain import (
     CriteriaEvalChain,
     CriteriaResultOutputParser,
 )
-from langchain.evaluation.criteria.prompt import PROMPT as CRITERIA_PROMPT
 from langchain.evaluation.qa.eval_chain import QAEvalChain
 from langchain.evaluation.qa.eval_prompt import PROMPT as QA_DEFAULT_PROMPT
 from langchain.evaluation.qa.eval_prompt import SQL_PROMPT
@@ -152,8 +151,9 @@ def get_criteria_evaluator(
     *,
     input_key: str = "input",
     prediction_key: str = "output",
-    prompt: BasePromptTemplate = CRITERIA_PROMPT,
+    prompt: Optional[BasePromptTemplate] = None,
     evaluation_name: Optional[str] = None,
+    requires_reference: bool = False,
     **kwargs: Any,
 ) -> RunEvaluatorChain:
     """Get an eval chain for grading a model's response against a map of criteria."""
@@ -174,7 +174,11 @@
     )
     tags = kwargs.pop("tags", [])
     eval_chain = CriteriaEvalChain.from_llm(
-        llm=llm, criteria=criteria_, prompt=prompt, **kwargs
+        llm=llm,
+        criteria=criteria_,
+        prompt=prompt,
+        requires_reference=requires_reference,
+        **kwargs,
     )
     return RunEvaluatorChain(
         eval_chain=eval_chain,
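For callers, the net effect is that requires_reference is now a first-class keyword on get_criteria_evaluator and is forwarded to CriteriaEvalChain.from_llm. A hedged usage sketch; the import path and the leading (llm, criteria) positional arguments are assumptions based on the surrounding run-evaluator helpers, not shown in this diff:

```python
from langchain.evaluation.run_evaluators import get_criteria_evaluator
from langchain.llms.fake import FakeListLLM

# Stand-in grading LLM so the sketch does not need an API key.
llm = FakeListLLM(responses=["The submission matches the reference. Y"])

evaluator = get_criteria_evaluator(
    llm,
    {"correctness": "Is the submission factually consistent with the reference?"},
    requires_reference=True,  # forwarded to CriteriaEvalChain.from_llm
)
# Evaluated runs will need a reference answer supplied at evaluation time;
# that wiring is outside the scope of this diff.
```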
