Override Keys Option (#13537)

Should be able to override the global key if you want to evaluate
different outputs in a single run
pull/13541/head
William FH 10 months ago committed by GitHub
parent e584b28c54
commit 5a28dc3210
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -54,6 +54,26 @@ class EvalConfig(BaseModel):
return kwargs
class SingleKeyEvalConfig(EvalConfig):
reference_key: Optional[str] = None
"""The key in the dataset run to use as the reference string.
If not provided, we will attempt to infer automatically."""
prediction_key: Optional[str] = None
"""The key from the traced run's outputs dictionary to use to
represent the prediction. If not provided, it will be inferred
automatically."""
input_key: Optional[str] = None
"""The key from the traced run's inputs dictionary to use to represent the
input. If not provided, it will be inferred automatically."""
def get_kwargs(self) -> Dict[str, Any]:
kwargs = super().get_kwargs()
# Filer out the keys that are not needed for the evaluator.
for key in ["reference_key", "prediction_key", "input_key"]:
kwargs.pop(key, None)
return kwargs
class RunEvalConfig(BaseModel):
"""Configuration for a run evaluation.
@ -113,7 +133,7 @@ class RunEvalConfig(BaseModel):
class Config:
arbitrary_types_allowed = True
class Criteria(EvalConfig):
class Criteria(SingleKeyEvalConfig):
"""Configuration for a reference-free criteria evaluator.
Parameters
@ -134,7 +154,7 @@ class RunEvalConfig(BaseModel):
) -> None:
super().__init__(criteria=criteria, **kwargs)
class LabeledCriteria(EvalConfig):
class LabeledCriteria(SingleKeyEvalConfig):
"""Configuration for a labeled (with references) criteria evaluator.
Parameters
@ -154,7 +174,7 @@ class RunEvalConfig(BaseModel):
) -> None:
super().__init__(criteria=criteria, **kwargs)
class EmbeddingDistance(EvalConfig):
class EmbeddingDistance(SingleKeyEvalConfig):
"""Configuration for an embedding distance evaluator.
Parameters
@ -174,7 +194,7 @@ class RunEvalConfig(BaseModel):
class Config:
arbitrary_types_allowed = True
class StringDistance(EvalConfig):
class StringDistance(SingleKeyEvalConfig):
"""Configuration for a string distance evaluator.
Parameters
@ -196,7 +216,7 @@ class RunEvalConfig(BaseModel):
"""Whether to normalize the distance to between 0 and 1.
Applies only to the Levenshtein and Damerau-Levenshtein distances."""
class QA(EvalConfig):
class QA(SingleKeyEvalConfig):
"""Configuration for a QA evaluator.
Parameters
@ -211,7 +231,7 @@ class RunEvalConfig(BaseModel):
llm: Optional[BaseLanguageModel] = None
prompt: Optional[BasePromptTemplate] = None
class ContextQA(EvalConfig):
class ContextQA(SingleKeyEvalConfig):
"""Configuration for a context-based QA evaluator.
Parameters
@ -227,7 +247,7 @@ class RunEvalConfig(BaseModel):
llm: Optional[BaseLanguageModel] = None
prompt: Optional[BasePromptTemplate] = None
class CoTQA(EvalConfig):
class CoTQA(SingleKeyEvalConfig):
"""Configuration for a context-based QA evaluator.
Parameters
@ -243,7 +263,7 @@ class RunEvalConfig(BaseModel):
llm: Optional[BaseLanguageModel] = None
prompt: Optional[BasePromptTemplate] = None
class JsonValidity(EvalConfig):
class JsonValidity(SingleKeyEvalConfig):
"""Configuration for a json validity evaluator.
Parameters
@ -261,7 +281,7 @@ class RunEvalConfig(BaseModel):
evaluator_type: EvaluatorType = EvaluatorType.JSON_EQUALITY
class ExactMatch(EvalConfig):
class ExactMatch(SingleKeyEvalConfig):
"""Configuration for an exact match string evaluator.
Parameters
@ -279,7 +299,7 @@ class RunEvalConfig(BaseModel):
ignore_punctuation: bool = False
ignore_numbers: bool = False
class RegexMatch(EvalConfig):
class RegexMatch(SingleKeyEvalConfig):
"""Configuration for a regex match string evaluator.
Parameters
@ -291,7 +311,7 @@ class RunEvalConfig(BaseModel):
evaluator_type: EvaluatorType = EvaluatorType.REGEX_MATCH
flags: int = 0
class ScoreString(EvalConfig):
class ScoreString(SingleKeyEvalConfig):
"""Configuration for a score string evaluator.
This is like the criteria evaluator but it is configured by
default to return a score on the scale from 1-10.

@ -487,6 +487,11 @@ def _construct_run_evaluator(
kwargs = {"llm": eval_llm, **eval_config.get_kwargs()}
evaluator_ = load_evaluator(eval_config.evaluator_type, **kwargs)
eval_type_tag = eval_config.evaluator_type.value
# Override keys if specified in the config
if isinstance(eval_config, smith_eval_config.SingleKeyEvalConfig):
input_key = eval_config.input_key or input_key
prediction_key = eval_config.prediction_key or prediction_key
reference_key = eval_config.reference_key or reference_key
if isinstance(evaluator_, StringEvaluator):
if evaluator_.requires_reference and reference_key is None:

Loading…
Cancel
Save