From cf384dcb7ffe2889dcbac8e5bc8a587907ec0c06 Mon Sep 17 00:00:00 2001 From: ccw630 Date: Thu, 27 Apr 2023 13:07:20 +0800 Subject: [PATCH] Supports async in SequentialChain/SimpleSequentialChain (#3503) --- langchain/chains/sequential.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/langchain/chains/sequential.py b/langchain/chains/sequential.py index 76d69946..b21dfac5 100644 --- a/langchain/chains/sequential.py +++ b/langchain/chains/sequential.py @@ -93,6 +93,13 @@ class SequentialChain(Chain): known_values.update(outputs) return {k: known_values[k] for k in self.output_variables} + async def _acall(self, inputs: Dict[str, str]) -> Dict[str, str]: + known_values = inputs.copy() + for i, chain in enumerate(self.chains): + outputs = await chain.acall(known_values, return_only_outputs=True) + known_values.update(outputs) + return {k: known_values[k] for k in self.output_variables} + class SimpleSequentialChain(Chain): """Simple chain where the outputs of one step feed directly into next.""" @@ -151,3 +158,20 @@ class SimpleSequentialChain(Chain): _input, color=color_mapping[str(i)], end="\n", verbose=self.verbose ) return {self.output_key: _input} + + async def _acall(self, inputs: Dict[str, str]) -> Dict[str, str]: + _input = inputs[self.input_key] + color_mapping = get_color_mapping([str(i) for i in range(len(self.chains))]) + for i, chain in enumerate(self.chains): + _input = await chain.arun(_input) + if self.strip_outputs: + _input = _input.strip() + if self.callback_manager.is_async: + await self.callback_manager.on_text( + _input, color=color_mapping[str(i)], end="\n", verbose=self.verbose + ) + else: + self.callback_manager.on_text( + _input, color=color_mapping[str(i)], end="\n", verbose=self.verbose + ) + return {self.output_key: _input}