From 8201cae770df0941ab7263c7f8f1843a953c6966 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Mon, 18 Sep 2023 15:36:57 +0100 Subject: [PATCH] Bug fixes for runnables (#10738) - tools invoked in async methods would not work due to missing await - RunnableSequence.stream() was creating an extra root run by mistake, and it can be simplified due to the existence of a default implementation for .transform() --- .../langchain/schema/runnable/base.py | 203 +++--------------- libs/langchain/langchain/tools/base.py | 4 +- 2 files changed, 29 insertions(+), 178 deletions(-) diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index 57a993bdaf..49c3abf462 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -1269,98 +1269,23 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): run_manager: CallbackManagerForChainRun, config: RunnableConfig, ) -> Iterator[Output]: - # setup callbacks - config = ensure_config(config) - callback_manager = get_callback_manager_for_config(config) - # start the root run - run_manager = callback_manager.on_chain_start( - dumpd(self), input, name=config.get("run_name") - ) - steps = [self.first] + self.middle + [self.last] - streaming_start_index = 0 - - for i in range(len(steps) - 1, 0, -1): - if type(steps[i]).transform != Runnable.transform: - streaming_start_index = i - 1 - else: - break - final_pipeline = None - gathered_input = None - if streaming_start_index == 0: - final_pipeline = steps[streaming_start_index].transform( - input, - patch_config(config, callbacks=run_manager.get_child("seq:step:1")), + # transform the input stream of each step with the next + # steps that don't natively support transforming an input stream will + # buffer input in memory until all available, and then start emitting output + final_pipeline = cast(Iterator[Output], input) + for step in steps: + final_pipeline = step.transform( 
final_pipeline, + patch_config( + config, + callbacks=run_manager.get_child(f"seq:step:{steps.index(step)+1}"), + ), ) - else: - try: - for input_chunk in input: - if gathered_input is None: - gathered_input = input_chunk - else: - gathered_input += input_chunk - # invoke the first steps - for step in steps[0:streaming_start_index]: - gathered_input = step.invoke( - gathered_input, - # mark each step as a child run - patch_config( - config, - callbacks=run_manager.get_child( - f"seq:step:{steps.index(step)+1}" - ), - ), - ) - # stream the first of the last steps with the final non-streaming input - final_pipeline = steps[streaming_start_index].stream( - gathered_input, - patch_config( - config, - callbacks=run_manager.get_child( - f"seq:step:{streaming_start_index+1}" - ), - ), - ) - except (KeyboardInterrupt, Exception) as e: - run_manager.on_chain_error(e) - raise - - # stream the last steps - final: Union[Output, None] = None - final_supported = True - try: - # stream the rest of the last steps with streaming input - for step in steps[streaming_start_index + 1 :]: - final_pipeline = step.transform( - final_pipeline, - patch_config( - config, - callbacks=run_manager.get_child( - f"seq:step:{steps.index(step)+1}" - ), - ), - ) - for output in final_pipeline: - yield output - # Accumulate output if possible, otherwise disable accumulation - if final_supported: - if final is None: - final = output - else: - try: - final += output # type: ignore[operator] - except TypeError: - final = None - final_supported = False - pass - # finish the root run - except (KeyboardInterrupt, Exception) as e: - run_manager.on_chain_error(e) - raise - else: - run_manager.on_chain_end(final) + for output in final_pipeline: + yield output async def _atransform( self, @@ -1368,97 +1293,23 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): run_manager: AsyncCallbackManagerForChainRun, config: RunnableConfig, ) -> AsyncIterator[Output]: - # setup callbacks - config = 
ensure_config(config) - callback_manager = get_async_callback_manager_for_config(config) - # start the root run - run_manager = await callback_manager.on_chain_start( - dumpd(self), input, name=config.get("run_name") - ) - steps = [self.first] + self.middle + [self.last] - streaming_start_index = len(steps) - 1 - - for i in range(len(steps) - 1, 0, -1): - if type(steps[i]).atransform != Runnable.atransform: - streaming_start_index = i - 1 - else: - break - - final_pipeline = None - gathered_input = None - if streaming_start_index == 0: - final_pipeline = steps[0].atransform( - input, - patch_config(config, callbacks=run_manager.get_child("seq:step:1")), - ) - else: - try: - async for input_chunk in input: - if gathered_input is None: - gathered_input = input_chunk - else: - gathered_input += input_chunk - # invoke the first steps - for step in steps[0:streaming_start_index]: - gathered_input = await step.ainvoke( - gathered_input, - # mark each step as a child run - patch_config( - config, - callbacks=run_manager.get_child( - f"seq:step:{steps.index(step)+1}" - ), - ), - ) - # stream the first of the last steps with the final non-streaming input - final_pipeline = steps[streaming_start_index].astream( - gathered_input, - patch_config( - config, - callbacks=run_manager.get_child( - f"seq:step:{streaming_start_index+1}" - ), - ), - ) - except (KeyboardInterrupt, Exception) as e: - await run_manager.on_chain_error(e) - raise # stream the last steps - final: Union[Output, None] = None - final_supported = True - try: - # stream the rest of the last steps with streaming input - for step in steps[streaming_start_index + 1 :]: - final_pipeline = step.atransform( - final_pipeline, - patch_config( - config, - callbacks=run_manager.get_child( - f"seq:step:{steps.index(step)+1}" - ), - ), - ) - async for output in final_pipeline: - yield output - # Accumulate output if possible, otherwise disable accumulation - if final_supported: - if final is None: - final = output - else: - 
try: - final += output # type: ignore[operator] - except TypeError: - final = None - final_supported = False - pass - # finish the root run - except (KeyboardInterrupt, Exception) as e: - await run_manager.on_chain_error(e) - raise - else: - await run_manager.on_chain_end(final) + # transform the input stream of each step with the next + # steps that don't natively support transforming an input stream will + # buffer input in memory until all available, and then start emitting output + final_pipeline = cast(AsyncIterator[Output], input) + for step in steps: + final_pipeline = step.atransform( + final_pipeline, + patch_config( + config, + callbacks=run_manager.get_child(f"seq:step:{steps.index(step)+1}"), + ), + ) + async for output in final_pipeline: + yield output def transform( self, diff --git a/libs/langchain/langchain/tools/base.py b/libs/langchain/langchain/tools/base.py index 69597cd903..b6fac8bf9d 100644 --- a/libs/langchain/langchain/tools/base.py +++ b/libs/langchain/langchain/tools/base.py @@ -210,7 +210,7 @@ class ChildTool(BaseTool): ) -> Any: if type(self)._arun == BaseTool._arun: # If the tool does not implement async, fall back to default implementation - return super().ainvoke(input, config, **kwargs) + return await super().ainvoke(input, config, **kwargs) config = config or {} return await self.arun( @@ -461,7 +461,7 @@ class Tool(BaseTool): None, partial(self.invoke, input, config, **kwargs) ) - return super().ainvoke(input, config, **kwargs) + return await super().ainvoke(input, config, **kwargs) # --- Tool ---