From 7ea2b08d1f94dc4f8b08fa1834252f5cb7be4065 Mon Sep 17 00:00:00 2001 From: William FH <13333726+hinthornw@users.noreply.github.com> Date: Wed, 2 Aug 2023 17:11:39 -0700 Subject: [PATCH] Use call directly for chain (#8655) for run_on_dataset since the `run()` method requires a single output --- .../langchain/smith/evaluation/runner_utils.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/libs/langchain/langchain/smith/evaluation/runner_utils.py b/libs/langchain/langchain/smith/evaluation/runner_utils.py index ebb3f6c767..070def0600 100644 --- a/libs/langchain/langchain/smith/evaluation/runner_utils.py +++ b/libs/langchain/langchain/smith/evaluation/runner_utils.py @@ -604,11 +604,8 @@ async def _arun_chain( inputs_, callbacks=callbacks, tags=tags ) else: - if len(inputs) == 1: - inputs_ = next(iter(inputs.values())) - output = await chain.arun(inputs_, callbacks=callbacks, tags=tags) - else: - output = await chain.acall(inputs, callbacks=callbacks, tags=tags) + inputs_ = next(iter(inputs.values())) if len(inputs) == 1 else inputs + output = await chain.acall(inputs_, callbacks=callbacks, tags=tags) return output @@ -926,11 +923,8 @@ def _run_chain( inputs_ = input_mapper(inputs) output: Union[dict, str] = chain(inputs_, callbacks=callbacks, tags=tags) else: - if len(inputs) == 1: - inputs_ = next(iter(inputs.values())) - output = chain.run(inputs_, callbacks=callbacks, tags=tags) - else: - output = chain(inputs, callbacks=callbacks, tags=tags) + inputs_ = next(iter(inputs.values())) if len(inputs) == 1 else inputs + output = chain(inputs_, callbacks=callbacks, tags=tags) return output