Use call directly for chain (#8655)

for run_on_dataset since the `run()` method requires a single output
This commit is contained in:
William FH 2023-08-02 17:11:39 -07:00 committed by GitHub
parent 368aa4ede7
commit 7ea2b08d1f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -604,11 +604,8 @@ async def _arun_chain(
inputs_, callbacks=callbacks, tags=tags
)
else:
if len(inputs) == 1:
inputs_ = next(iter(inputs.values()))
output = await chain.arun(inputs_, callbacks=callbacks, tags=tags)
else:
output = await chain.acall(inputs, callbacks=callbacks, tags=tags)
inputs_ = next(iter(inputs.values())) if len(inputs) == 1 else inputs
output = await chain.acall(inputs_, callbacks=callbacks, tags=tags)
return output
@ -926,11 +923,8 @@ def _run_chain(
inputs_ = input_mapper(inputs)
output: Union[dict, str] = chain(inputs_, callbacks=callbacks, tags=tags)
else:
if len(inputs) == 1:
inputs_ = next(iter(inputs.values()))
output = chain.run(inputs_, callbacks=callbacks, tags=tags)
else:
output = chain(inputs, callbacks=callbacks, tags=tags)
inputs_ = next(iter(inputs.values())) if len(inputs) == 1 else inputs
output = chain(inputs_, callbacks=callbacks, tags=tags)
return output