diff --git a/langchain/evaluation/criteria/eval_chain.py b/langchain/evaluation/criteria/eval_chain.py
index 8270bbaa24..f2dd3c6b08 100644
--- a/langchain/evaluation/criteria/eval_chain.py
+++ b/langchain/evaluation/criteria/eval_chain.py
@@ -251,17 +251,24 @@ class CriteriaEvalChain(StringEvaluator, LLMEvalChain, LLMChain):
                 requires_reference=True,
             )
         """
+        expected_input_vars = {"input", "output", "criteria"}
         if prompt is None:
             if requires_reference:
                 prompt = PROMPT_WITH_REFERENCES
             else:
                 prompt = PROMPT
+        if requires_reference:
+            expected_input_vars.add("reference")
+        if expected_input_vars != set(prompt.input_variables):
+            raise ValueError(
+                f"Input variables should be {expected_input_vars}, "
+                f"but got {prompt.input_variables}"
+            )
+
         criteria_ = cls.resolve_criteria(criteria)
         criteria_str = " ".join(f"{k}: {v}" for k, v in criteria_.items())
         prompt_ = prompt.partial(criteria=criteria_str)
-        return cls(
-            llm=llm, prompt=prompt_, requires_reference=requires_reference, **kwargs
-        )
+        return cls(llm=llm, prompt=prompt_, **kwargs)

     def _get_eval_input(
         self,
diff --git a/langchain/evaluation/run_evaluators/implementations.py b/langchain/evaluation/run_evaluators/implementations.py
index d9a7ddb689..1c8ad643b9 100644
--- a/langchain/evaluation/run_evaluators/implementations.py
+++ b/langchain/evaluation/run_evaluators/implementations.py
@@ -14,7 +14,6 @@ from langchain.evaluation.criteria.eval_chain import (
     CriteriaEvalChain,
     CriteriaResultOutputParser,
 )
-from langchain.evaluation.criteria.prompt import PROMPT as CRITERIA_PROMPT
 from langchain.evaluation.qa.eval_chain import QAEvalChain
 from langchain.evaluation.qa.eval_prompt import PROMPT as QA_DEFAULT_PROMPT
 from langchain.evaluation.qa.eval_prompt import SQL_PROMPT
@@ -152,8 +151,9 @@ def get_criteria_evaluator(
     *,
     input_key: str = "input",
     prediction_key: str = "output",
-    prompt: BasePromptTemplate = CRITERIA_PROMPT,
+    prompt: Optional[BasePromptTemplate] = None,
     evaluation_name: Optional[str] = None,
+    requires_reference: bool = False,
     **kwargs: Any,
 ) -> RunEvaluatorChain:
     """Get an eval chain for grading a model's response against a map of criteria."""
@@ -174,7 +174,11 @@ def get_criteria_evaluator(
     )
     tags = kwargs.pop("tags", [])
     eval_chain = CriteriaEvalChain.from_llm(
-        llm=llm, criteria=criteria_, prompt=prompt, **kwargs
+        llm=llm,
+        criteria=criteria_,
+        prompt=prompt,
+        requires_reference=requires_reference,
+        **kwargs,
    )
     return RunEvaluatorChain(
         eval_chain=eval_chain,
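
For reference, a minimal usage sketch of the changed API (not part of the diff). It assumes a chat model such as langchain.chat_models.ChatOpenAI is available and configured; the criteria dictionary and bad_prompt below are made-up examples used only to exercise the new requires_reference flag and the prompt input-variable validation added to CriteriaEvalChain.from_llm.

# Sketch only: assumes ChatOpenAI credentials are configured; the criteria
# dict and `bad_prompt` are illustrative, not from the diff.
from langchain.chat_models import ChatOpenAI
from langchain.evaluation.criteria.eval_chain import CriteriaEvalChain
from langchain.prompts import PromptTemplate

llm = ChatOpenAI(temperature=0)

# With requires_reference=True and no custom prompt, from_llm falls back to
# PROMPT_WITH_REFERENCES, whose input variables include "reference", so the
# new validation passes.
chain = CriteriaEvalChain.from_llm(
    llm=llm,
    criteria={"correctness": "Is the submission factually correct?"},
    requires_reference=True,
)

# A custom prompt must expose exactly the expected input variables
# ({"input", "output", "criteria"}, plus "reference" when
# requires_reference=True); otherwise the new check raises ValueError.
bad_prompt = PromptTemplate.from_template("Grade this submission: {output}")
try:
    CriteriaEvalChain.from_llm(
        llm=llm,
        criteria={"helpfulness": "Is the submission helpful?"},
        prompt=bad_prompt,
    )
except ValueError as err:
    # e.g. "Input variables should be {...}, but got ['output']"
    print(err)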