diff --git a/libs/langchain/langchain/smith/evaluation/config.py b/libs/langchain/langchain/smith/evaluation/config.py
index 8fa57e778f..f723452864 100644
--- a/libs/langchain/langchain/smith/evaluation/config.py
+++ b/libs/langchain/langchain/smith/evaluation/config.py
@@ -54,6 +54,26 @@ class EvalConfig(BaseModel):
         return kwargs
 
 
+class SingleKeyEvalConfig(EvalConfig):
+    reference_key: Optional[str] = None
+    """The key in the dataset run to use as the reference string.
+    If not provided, we will attempt to infer automatically."""
+    prediction_key: Optional[str] = None
+    """The key from the traced run's outputs dictionary to use to
+    represent the prediction. If not provided, it will be inferred
+    automatically."""
+    input_key: Optional[str] = None
+    """The key from the traced run's inputs dictionary to use to represent the
+    input. If not provided, it will be inferred automatically."""
+
+    def get_kwargs(self) -> Dict[str, Any]:
+        kwargs = super().get_kwargs()
+        # Filter out the keys that are not needed for the evaluator.
+        for key in ["reference_key", "prediction_key", "input_key"]:
+            kwargs.pop(key, None)
+        return kwargs
+
+
 class RunEvalConfig(BaseModel):
     """Configuration for a run evaluation.
 
@@ -113,7 +133,7 @@ class RunEvalConfig(BaseModel):
     class Config:
         arbitrary_types_allowed = True
 
-    class Criteria(EvalConfig):
+    class Criteria(SingleKeyEvalConfig):
         """Configuration for a reference-free criteria evaluator.
 
         Parameters
@@ -134,7 +154,7 @@ class RunEvalConfig(BaseModel):
         ) -> None:
             super().__init__(criteria=criteria, **kwargs)
 
-    class LabeledCriteria(EvalConfig):
+    class LabeledCriteria(SingleKeyEvalConfig):
         """Configuration for a labeled (with references) criteria evaluator.
 
         Parameters
@@ -154,7 +174,7 @@ class RunEvalConfig(BaseModel):
         ) -> None:
             super().__init__(criteria=criteria, **kwargs)
 
-    class EmbeddingDistance(EvalConfig):
+    class EmbeddingDistance(SingleKeyEvalConfig):
         """Configuration for an embedding distance evaluator.
 
         Parameters
@@ -174,7 +194,7 @@ class RunEvalConfig(BaseModel):
         class Config:
             arbitrary_types_allowed = True
 
-    class StringDistance(EvalConfig):
+    class StringDistance(SingleKeyEvalConfig):
         """Configuration for a string distance evaluator.
 
         Parameters
@@ -196,7 +216,7 @@ class RunEvalConfig(BaseModel):
         """Whether to normalize the distance to between 0 and 1.
         Applies only to the Levenshtein and Damerau-Levenshtein distances."""
 
-    class QA(EvalConfig):
+    class QA(SingleKeyEvalConfig):
         """Configuration for a QA evaluator.
 
         Parameters
@@ -211,7 +231,7 @@ class RunEvalConfig(BaseModel):
         llm: Optional[BaseLanguageModel] = None
         prompt: Optional[BasePromptTemplate] = None
 
-    class ContextQA(EvalConfig):
+    class ContextQA(SingleKeyEvalConfig):
         """Configuration for a context-based QA evaluator.
 
         Parameters
@@ -227,7 +247,7 @@ class RunEvalConfig(BaseModel):
         llm: Optional[BaseLanguageModel] = None
         prompt: Optional[BasePromptTemplate] = None
 
-    class CoTQA(EvalConfig):
+    class CoTQA(SingleKeyEvalConfig):
         """Configuration for a context-based QA evaluator.
 
         Parameters
@@ -243,7 +263,7 @@ class RunEvalConfig(BaseModel):
         llm: Optional[BaseLanguageModel] = None
         prompt: Optional[BasePromptTemplate] = None
 
-    class JsonValidity(EvalConfig):
+    class JsonValidity(SingleKeyEvalConfig):
         """Configuration for a json validity evaluator.
 
         Parameters
@@ -261,7 +281,7 @@ class RunEvalConfig(BaseModel):
 
         evaluator_type: EvaluatorType = EvaluatorType.JSON_EQUALITY
 
-    class ExactMatch(EvalConfig):
+    class ExactMatch(SingleKeyEvalConfig):
         """Configuration for an exact match string evaluator.
 
         Parameters
@@ -279,7 +299,7 @@ class RunEvalConfig(BaseModel):
         ignore_punctuation: bool = False
         ignore_numbers: bool = False
 
-    class RegexMatch(EvalConfig):
+    class RegexMatch(SingleKeyEvalConfig):
         """Configuration for a regex match string evaluator.
 
         Parameters
@@ -291,7 +311,7 @@ class RunEvalConfig(BaseModel):
         evaluator_type: EvaluatorType = EvaluatorType.REGEX_MATCH
         flags: int = 0
 
-    class ScoreString(EvalConfig):
+    class ScoreString(SingleKeyEvalConfig):
         """Configuration for a score string evaluator.
         This is like the criteria evaluator but it is configured by default to
         return a score on the scale from 1-10.
diff --git a/libs/langchain/langchain/smith/evaluation/runner_utils.py b/libs/langchain/langchain/smith/evaluation/runner_utils.py
index 1e17652f39..291d3a9fc1 100644
--- a/libs/langchain/langchain/smith/evaluation/runner_utils.py
+++ b/libs/langchain/langchain/smith/evaluation/runner_utils.py
@@ -487,6 +487,11 @@ def _construct_run_evaluator(
         kwargs = {"llm": eval_llm, **eval_config.get_kwargs()}
         evaluator_ = load_evaluator(eval_config.evaluator_type, **kwargs)
         eval_type_tag = eval_config.evaluator_type.value
+        # Override keys if specified in the config
+        if isinstance(eval_config, smith_eval_config.SingleKeyEvalConfig):
+            input_key = eval_config.input_key or input_key
+            prediction_key = eval_config.prediction_key or prediction_key
+            reference_key = eval_config.reference_key or reference_key
 
     if isinstance(evaluator_, StringEvaluator):
         if evaluator_.requires_reference and reference_key is None:
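
Usage sketch (not part of the patch above): a minimal example of how the new per-evaluator key overrides could be supplied, assuming the public import path langchain.smith.RunEvalConfig and hypothetical dataset keys "question" and "answer". With this change, every SingleKeyEvalConfig subclass accepts optional input_key, prediction_key, and reference_key fields that take precedence over the keys inferred from the run.

# Sketch only; import path and key names ("question", "answer") are assumptions.
from langchain.smith import RunEvalConfig

eval_config = RunEvalConfig(
    evaluators=[
        # Point the QA grader at the run's "question" input and "answer" output.
        RunEvalConfig.QA(input_key="question", prediction_key="answer"),
        # Reference-free criteria evaluation against the same prediction key.
        RunEvalConfig.Criteria(criteria="conciseness", prediction_key="answer"),
    ],
)

Any key left unset on a config falls back to the automatically inferred value, because each override in _construct_run_evaluator is written as eval_config.input_key or input_key (and likewise for prediction_key and reference_key).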