Update Key Check (#8948)

In eval loop. It needn't be done unless you are creating the
corresponding evaluators
pull/9000/head
William FH 11 months ago committed by GitHub
parent 539672a7fd
commit 90579021f8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -502,6 +502,18 @@ def _construct_run_evaluator(
return run_evaluator
def _get_keys(
config: RunEvalConfig,
run_inputs: Optional[List[str]],
run_outputs: Optional[List[str]],
example_outputs: Optional[List[str]],
) -> Tuple[Optional[str], Optional[str], Optional[str]]:
input_key = _determine_input_key(config, run_inputs)
prediction_key = _determine_prediction_key(config, run_outputs)
reference_key = _determine_reference_key(config, example_outputs)
return input_key, prediction_key, reference_key
def _load_run_evaluators(
config: RunEvalConfig,
run_type: str,
@ -521,9 +533,13 @@ def _load_run_evaluators(
"""
eval_llm = config.eval_llm or ChatOpenAI(model="gpt-4", temperature=0.0)
run_evaluators = []
input_key = _determine_input_key(config, run_inputs)
prediction_key = _determine_prediction_key(config, run_outputs)
reference_key = _determine_reference_key(config, example_outputs)
input_key, prediction_key, reference_key = None, None, None
if config.evaluators or any(
[isinstance(e, EvaluatorType) for e in config.evaluators]
):
input_key, prediction_key, reference_key = _get_keys(
config, run_inputs, run_outputs, example_outputs
)
for eval_config in config.evaluators:
run_evaluator = _construct_run_evaluator(
eval_config,
@ -1074,15 +1090,15 @@ def _run_on_examples(
A dictionary mapping example ids to the model outputs.
"""
results: Dict[str, Any] = {}
llm_or_chain_factory = _wrap_in_chain_factory(llm_or_chain_factory)
project_name = _get_project_name(project_name, llm_or_chain_factory)
wrapped_model = _wrap_in_chain_factory(llm_or_chain_factory)
project_name = _get_project_name(project_name, wrapped_model)
tracer = LangChainTracer(
project_name=project_name, client=client, use_threading=False
)
run_evaluators, examples = _setup_evaluation(
llm_or_chain_factory, examples, evaluation, data_type
wrapped_model, examples, evaluation, data_type
)
examples = _validate_example_inputs(examples, llm_or_chain_factory, input_mapper)
examples = _validate_example_inputs(examples, wrapped_model, input_mapper)
evalution_handler = EvaluatorCallbackHandler(
evaluators=run_evaluators or [],
client=client,
@ -1091,7 +1107,7 @@ def _run_on_examples(
for i, example in enumerate(examples):
result = _run_llm_or_chain(
example,
llm_or_chain_factory,
wrapped_model,
num_repetitions,
tags=tags,
callbacks=callbacks,
@ -1114,8 +1130,8 @@ def _prepare_eval_run(
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
project_name: Optional[str],
) -> Tuple[MCF, str, Dataset, Iterator[Example]]:
llm_or_chain_factory = _wrap_in_chain_factory(llm_or_chain_factory, dataset_name)
project_name = _get_project_name(project_name, llm_or_chain_factory)
wrapped_model = _wrap_in_chain_factory(llm_or_chain_factory, dataset_name)
project_name = _get_project_name(project_name, wrapped_model)
try:
project = client.create_project(project_name)
except ValueError as e:
@ -1130,7 +1146,7 @@ def _prepare_eval_run(
)
dataset = client.read_dataset(dataset_name=dataset_name)
examples = client.list_examples(dataset_id=str(dataset.id))
return llm_or_chain_factory, project_name, dataset, examples
return wrapped_model, project_name, dataset, examples
async def arun_on_dataset(
@ -1256,13 +1272,13 @@ async def arun_on_dataset(
evaluation=evaluation_config,
)
""" # noqa: E501
llm_or_chain_factory, project_name, dataset, examples = _prepare_eval_run(
wrapped_model, project_name, dataset, examples = _prepare_eval_run(
client, dataset_name, llm_or_chain_factory, project_name
)
results = await _arun_on_examples(
client,
examples,
llm_or_chain_factory,
wrapped_model,
concurrency_level=concurrency_level,
num_repetitions=num_repetitions,
project_name=project_name,
@ -1423,14 +1439,14 @@ def run_on_dataset(
evaluation=evaluation_config,
)
""" # noqa: E501
llm_or_chain_factory, project_name, dataset, examples = _prepare_eval_run(
wrapped_model, project_name, dataset, examples = _prepare_eval_run(
client, dataset_name, llm_or_chain_factory, project_name
)
if concurrency_level in (0, 1):
results = _run_on_examples(
client,
examples,
llm_or_chain_factory,
wrapped_model,
num_repetitions=num_repetitions,
project_name=project_name,
verbose=verbose,
@ -1444,7 +1460,7 @@ def run_on_dataset(
coro = _arun_on_examples(
client,
examples,
llm_or_chain_factory,
wrapped_model,
concurrency_level=concurrency_level,
num_repetitions=num_repetitions,
project_name=project_name,

Loading…
Cancel
Save