diff --git a/libs/langchain/langchain/smith/evaluation/runner_utils.py b/libs/langchain/langchain/smith/evaluation/runner_utils.py
index 5b3d5775c4..88cd3f8986 100644
--- a/libs/langchain/langchain/smith/evaluation/runner_utils.py
+++ b/libs/langchain/langchain/smith/evaluation/runner_utils.py
@@ -502,6 +502,18 @@ def _construct_run_evaluator(
     return run_evaluator
 
 
+def _get_keys(
+    config: RunEvalConfig,
+    run_inputs: Optional[List[str]],
+    run_outputs: Optional[List[str]],
+    example_outputs: Optional[List[str]],
+) -> Tuple[Optional[str], Optional[str], Optional[str]]:
+    input_key = _determine_input_key(config, run_inputs)
+    prediction_key = _determine_prediction_key(config, run_outputs)
+    reference_key = _determine_reference_key(config, example_outputs)
+    return input_key, prediction_key, reference_key
+
+
 def _load_run_evaluators(
     config: RunEvalConfig,
     run_type: str,
@@ -521,9 +533,13 @@ def _load_run_evaluators(
     """
     eval_llm = config.eval_llm or ChatOpenAI(model="gpt-4", temperature=0.0)
     run_evaluators = []
-    input_key = _determine_input_key(config, run_inputs)
-    prediction_key = _determine_prediction_key(config, run_outputs)
-    reference_key = _determine_reference_key(config, example_outputs)
+    input_key, prediction_key, reference_key = None, None, None
+    if config.evaluators or any(
+        [isinstance(e, EvaluatorType) for e in config.evaluators]
+    ):
+        input_key, prediction_key, reference_key = _get_keys(
+            config, run_inputs, run_outputs, example_outputs
+        )
     for eval_config in config.evaluators:
         run_evaluator = _construct_run_evaluator(
             eval_config,
@@ -1074,15 +1090,15 @@ def _run_on_examples(
         A dictionary mapping example ids to the model outputs.
     """
     results: Dict[str, Any] = {}
-    llm_or_chain_factory = _wrap_in_chain_factory(llm_or_chain_factory)
-    project_name = _get_project_name(project_name, llm_or_chain_factory)
+    wrapped_model = _wrap_in_chain_factory(llm_or_chain_factory)
+    project_name = _get_project_name(project_name, wrapped_model)
     tracer = LangChainTracer(
         project_name=project_name, client=client, use_threading=False
     )
     run_evaluators, examples = _setup_evaluation(
-        llm_or_chain_factory, examples, evaluation, data_type
+        wrapped_model, examples, evaluation, data_type
     )
-    examples = _validate_example_inputs(examples, llm_or_chain_factory, input_mapper)
+    examples = _validate_example_inputs(examples, wrapped_model, input_mapper)
     evalution_handler = EvaluatorCallbackHandler(
         evaluators=run_evaluators or [],
         client=client,
@@ -1091,7 +1107,7 @@ def _run_on_examples(
     for i, example in enumerate(examples):
         result = _run_llm_or_chain(
             example,
-            llm_or_chain_factory,
+            wrapped_model,
             num_repetitions,
             tags=tags,
             callbacks=callbacks,
@@ -1114,8 +1130,8 @@ def _prepare_eval_run(
     llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
     project_name: Optional[str],
 ) -> Tuple[MCF, str, Dataset, Iterator[Example]]:
-    llm_or_chain_factory = _wrap_in_chain_factory(llm_or_chain_factory, dataset_name)
-    project_name = _get_project_name(project_name, llm_or_chain_factory)
+    wrapped_model = _wrap_in_chain_factory(llm_or_chain_factory, dataset_name)
+    project_name = _get_project_name(project_name, wrapped_model)
     try:
         project = client.create_project(project_name)
     except ValueError as e:
@@ -1130,7 +1146,7 @@ def _prepare_eval_run(
     )
     dataset = client.read_dataset(dataset_name=dataset_name)
     examples = client.list_examples(dataset_id=str(dataset.id))
-    return llm_or_chain_factory, project_name, dataset, examples
+    return wrapped_model, project_name, dataset, examples
 
 
 async def arun_on_dataset(
@@ -1256,13 +1272,13 @@ async def arun_on_dataset(
             evaluation=evaluation_config,
         )
     """  # noqa: E501
-    llm_or_chain_factory, project_name, dataset, examples = _prepare_eval_run(
+    wrapped_model, project_name, dataset, examples = _prepare_eval_run(
         client, dataset_name, llm_or_chain_factory, project_name
     )
     results = await _arun_on_examples(
         client,
         examples,
-        llm_or_chain_factory,
+        wrapped_model,
         concurrency_level=concurrency_level,
         num_repetitions=num_repetitions,
         project_name=project_name,
@@ -1423,14 +1439,14 @@ def run_on_dataset(
             evaluation=evaluation_config,
         )
     """  # noqa: E501
-    llm_or_chain_factory, project_name, dataset, examples = _prepare_eval_run(
+    wrapped_model, project_name, dataset, examples = _prepare_eval_run(
         client, dataset_name, llm_or_chain_factory, project_name
     )
     if concurrency_level in (0, 1):
         results = _run_on_examples(
             client,
             examples,
-            llm_or_chain_factory,
+            wrapped_model,
             num_repetitions=num_repetitions,
             project_name=project_name,
             verbose=verbose,
@@ -1444,7 +1460,7 @@ def run_on_dataset(
         coro = _arun_on_examples(
             client,
             examples,
-            llm_or_chain_factory,
+            wrapped_model,
             concurrency_level=concurrency_level,
             num_repetitions=num_repetitions,
             project_name=project_name,
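
Illustrative note (not part of the patch): the hunk in _load_run_evaluators now defaults input_key, prediction_key, and reference_key to None and only resolves them when evaluators are configured. The standalone sketch below approximates that control flow under stated assumptions; get_keys_sketch and _determine_input_key_stub are hypothetical names, not functions from runner_utils.py.

# Sketch only: *_stub and *_sketch names are hypothetical stand-ins for the
# real _determine_* helpers and _get_keys in runner_utils.py.
from typing import List, Optional, Tuple


def _determine_input_key_stub(run_inputs: Optional[List[str]]) -> Optional[str]:
    # Stand-in resolution rule: use the sole run input when it is unambiguous.
    return run_inputs[0] if run_inputs and len(run_inputs) == 1 else None


def get_keys_sketch(
    evaluators: List[str],
    run_inputs: Optional[List[str]],
) -> Tuple[Optional[str], Optional[str], Optional[str]]:
    # Mirrors the new control flow: keys default to None and are only resolved
    # when built-in evaluators are actually configured.
    input_key = prediction_key = reference_key = None
    if evaluators:
        input_key = _determine_input_key_stub(run_inputs)
        # prediction_key and reference_key would be resolved analogously.
    return input_key, prediction_key, reference_key


if __name__ == "__main__":
    print(get_keys_sketch([], ["question"]))      # -> (None, None, None)
    print(get_keys_sketch(["qa"], ["question"]))  # -> ('question', None, None)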