Update Key Check (#8948)

Update the key check in the eval loop: it needn't be done unless the corresponding evaluators are actually being created.
Branch: pull/9000/head
Author: William FH (committed by GitHub)
Parent: 539672a7fd
Commit: 90579021f8
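The check in question is the input/prediction/reference key resolution: the `_determine_*` helpers presumably validate (and can raise for) chains with ambiguous multi-key inputs or outputs, so they should only run when an evaluator will actually consume the keys. A minimal, self-contained sketch of the pattern, where `_resolve` is a hypothetical stand-in for the library's `_determine_*` functions:

```python
from typing import List, Optional, Tuple


def _resolve(candidates: Optional[List[str]]) -> Optional[str]:
    # Hypothetical stand-in for _determine_input_key / _determine_prediction_key /
    # _determine_reference_key: pick the single key, raising when ambiguous.
    if not candidates:
        return None
    if len(candidates) > 1:
        raise ValueError(f"Must specify a key; got multiple candidates: {candidates}")
    return candidates[0]


def get_keys_lazily(
    evaluators: List[object],
    run_inputs: Optional[List[str]],
    run_outputs: Optional[List[str]],
    example_outputs: Optional[List[str]],
) -> Tuple[Optional[str], Optional[str], Optional[str]]:
    if not evaluators:
        # No evaluators configured: skip key resolution entirely, so a
        # multi-key chain doesn't fail here for no reason.
        return None, None, None
    return _resolve(run_inputs), _resolve(run_outputs), _resolve(example_outputs)
```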

@@ -502,6 +502,18 @@ def _construct_run_evaluator(
     return run_evaluator
 
 
+def _get_keys(
+    config: RunEvalConfig,
+    run_inputs: Optional[List[str]],
+    run_outputs: Optional[List[str]],
+    example_outputs: Optional[List[str]],
+) -> Tuple[Optional[str], Optional[str], Optional[str]]:
+    input_key = _determine_input_key(config, run_inputs)
+    prediction_key = _determine_prediction_key(config, run_outputs)
+    reference_key = _determine_reference_key(config, example_outputs)
+    return input_key, prediction_key, reference_key
+
+
 def _load_run_evaluators(
     config: RunEvalConfig,
     run_type: str,
@@ -521,9 +533,13 @@ def _load_run_evaluators(
     """
     eval_llm = config.eval_llm or ChatOpenAI(model="gpt-4", temperature=0.0)
     run_evaluators = []
-    input_key = _determine_input_key(config, run_inputs)
-    prediction_key = _determine_prediction_key(config, run_outputs)
-    reference_key = _determine_reference_key(config, example_outputs)
+    input_key, prediction_key, reference_key = None, None, None
+    if config.evaluators or any(
+        [isinstance(e, EvaluatorType) for e in config.evaluators]
+    ):
+        input_key, prediction_key, reference_key = _get_keys(
+            config, run_inputs, run_outputs, example_outputs
+        )
     for eval_config in config.evaluators:
         run_evaluator = _construct_run_evaluator(
             eval_config,
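Under the new guard, an empty evaluator config means the `_determine_*` checks never run and the keys simply stay `None`. A tiny illustration of that short-circuit (using `str` in place of `EvaluatorType` to keep it dependency-free):

```python
# Illustration only: an empty evaluator list short-circuits the guard,
# so key resolution (and any validation error it might raise) is skipped.
evaluators: list = []

input_key = prediction_key = reference_key = None
if evaluators or any(isinstance(e, str) for e in evaluators):
    raise AssertionError("unreachable with an empty evaluator list")

assert (input_key, prediction_key, reference_key) == (None, None, None)
```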
@@ -1074,15 +1090,15 @@ def _run_on_examples(
         A dictionary mapping example ids to the model outputs.
     """
     results: Dict[str, Any] = {}
-    llm_or_chain_factory = _wrap_in_chain_factory(llm_or_chain_factory)
-    project_name = _get_project_name(project_name, llm_or_chain_factory)
+    wrapped_model = _wrap_in_chain_factory(llm_or_chain_factory)
+    project_name = _get_project_name(project_name, wrapped_model)
     tracer = LangChainTracer(
         project_name=project_name, client=client, use_threading=False
     )
     run_evaluators, examples = _setup_evaluation(
-        llm_or_chain_factory, examples, evaluation, data_type
+        wrapped_model, examples, evaluation, data_type
     )
-    examples = _validate_example_inputs(examples, llm_or_chain_factory, input_mapper)
+    examples = _validate_example_inputs(examples, wrapped_model, input_mapper)
     evalution_handler = EvaluatorCallbackHandler(
         evaluators=run_evaluators or [],
         client=client,
@@ -1091,7 +1107,7 @@ def _run_on_examples(
     for i, example in enumerate(examples):
         result = _run_llm_or_chain(
             example,
-            llm_or_chain_factory,
+            wrapped_model,
             num_repetitions,
             tags=tags,
             callbacks=callbacks,
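The `llm_or_chain_factory` → `wrapped_model` hunks here and below are a mechanical rename: rather than rebinding the parameter to its wrapped form, the wrapped value gets its own name, so the raw argument and the normalized factory stay distinct. A small sketch of the pitfall being avoided, where `wrap` is a hypothetical stand-in for `_wrap_in_chain_factory`:

```python
from typing import Any, Callable


def wrap(model_or_factory: Any) -> Callable[[], Any]:
    # Hypothetical stand-in for _wrap_in_chain_factory: normalize either a
    # model instance or a zero-arg factory into a zero-arg factory.
    return model_or_factory if callable(model_or_factory) else (lambda: model_or_factory)


def prepare(llm_or_chain_factory: Any) -> Any:
    # Rebinding the parameter (llm_or_chain_factory = wrap(llm_or_chain_factory))
    # would silently change its type mid-function; a new name keeps each later
    # read unambiguous and the original argument still available.
    wrapped_model = wrap(llm_or_chain_factory)
    return wrapped_model()


print(prepare("plain model instance"))  # -> "plain model instance"
```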
@@ -1114,8 +1130,8 @@ def _prepare_eval_run(
     llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
     project_name: Optional[str],
 ) -> Tuple[MCF, str, Dataset, Iterator[Example]]:
-    llm_or_chain_factory = _wrap_in_chain_factory(llm_or_chain_factory, dataset_name)
-    project_name = _get_project_name(project_name, llm_or_chain_factory)
+    wrapped_model = _wrap_in_chain_factory(llm_or_chain_factory, dataset_name)
+    project_name = _get_project_name(project_name, wrapped_model)
     try:
         project = client.create_project(project_name)
     except ValueError as e:
@@ -1130,7 +1146,7 @@ def _prepare_eval_run(
     )
     dataset = client.read_dataset(dataset_name=dataset_name)
     examples = client.list_examples(dataset_id=str(dataset.id))
-    return llm_or_chain_factory, project_name, dataset, examples
+    return wrapped_model, project_name, dataset, examples
 
 
 async def arun_on_dataset(
@@ -1256,13 +1272,13 @@ async def arun_on_dataset(
             evaluation=evaluation_config,
         )
     """  # noqa: E501
-    llm_or_chain_factory, project_name, dataset, examples = _prepare_eval_run(
+    wrapped_model, project_name, dataset, examples = _prepare_eval_run(
         client, dataset_name, llm_or_chain_factory, project_name
     )
     results = await _arun_on_examples(
         client,
         examples,
-        llm_or_chain_factory,
+        wrapped_model,
         concurrency_level=concurrency_level,
         num_repetitions=num_repetitions,
         project_name=project_name,
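Call sites of `arun_on_dataset` are unchanged by the rename. A hedged usage sketch based on the docstring excerpted in this hunk (import paths are as of this era of the library; the dataset name is a placeholder):

```python
import asyncio

from langsmith import Client
from langchain.chains import LLMChain
from langchain.chat_models import ChatOpenAI
from langchain.smith import RunEvalConfig, arun_on_dataset


def construct_chain():
    # A fresh chain per example, as the docstring suggests.
    llm = ChatOpenAI(temperature=0)
    return LLMChain.from_string(llm, "What's the answer to {your_input_key}")


evaluation_config = RunEvalConfig(evaluators=["qa"])
client = Client()

results = asyncio.run(
    arun_on_dataset(
        client,
        "<my_dataset_name>",  # placeholder dataset name
        construct_chain,
        evaluation=evaluation_config,
    )
)
```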
@@ -1423,14 +1439,14 @@ def run_on_dataset(
             evaluation=evaluation_config,
         )
     """  # noqa: E501
-    llm_or_chain_factory, project_name, dataset, examples = _prepare_eval_run(
+    wrapped_model, project_name, dataset, examples = _prepare_eval_run(
         client, dataset_name, llm_or_chain_factory, project_name
     )
     if concurrency_level in (0, 1):
         results = _run_on_examples(
             client,
             examples,
-            llm_or_chain_factory,
+            wrapped_model,
             num_repetitions=num_repetitions,
             project_name=project_name,
             verbose=verbose,
@@ -1444,7 +1460,7 @@ def run_on_dataset(
         coro = _arun_on_examples(
             client,
             examples,
-            llm_or_chain_factory,
+            wrapped_model,
             concurrency_level=concurrency_level,
             num_repetitions=num_repetitions,
             project_name=project_name,
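The synchronous entry point mirrors the async sketch above; as this diff shows, a `concurrency_level` of 0 or 1 keeps it on the single-threaded `_run_on_examples` path. Reusing the same placeholders and caveats:

```python
# Synchronous variant of the sketch above (same client, chain factory,
# and evaluation_config as in the async example).
from langchain.smith import run_on_dataset

results = run_on_dataset(
    client,
    "<my_dataset_name>",
    construct_chain,
    evaluation=evaluation_config,
    concurrency_level=1,  # 0 or 1 selects the serial _run_on_examples path
)
```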
