From ea26c12b23f84ca380f053bf15647036aacf8d25 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Thu, 21 Sep 2023 19:27:09 +0100 Subject: [PATCH] Fix Runnable.transform() for false-y inputs (#10893) --------- Co-authored-by: Bagatur --- .../langchain/langchain/schema/runnable/base.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index 709fb53309..e049ed73fa 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -286,16 +286,19 @@ class Runnable(Generic[Input, Output], ABC): Subclasses should override this method if they can start producing output while input is still being generated. """ - final: Union[Input, None] = None + final: Input + got_first_val = False for chunk in input: - if final is None: + if not got_first_val: final = chunk + got_first_val = True else: # Make a best effort to gather, for any type that supports `+` # This method should throw an error if gathering fails. final += chunk # type: ignore[operator] - if final: + + if got_first_val: yield from self.stream(final, config, **kwargs) async def atransform( @@ -309,17 +312,19 @@ class Runnable(Generic[Input, Output], ABC): Subclasses should override this method if they can start producing output while input is still being generated. """ - final: Union[Input, None] = None + final: Input + got_first_val = False async for chunk in input: - if final is None: + if not got_first_val: final = chunk + got_first_val = True else: # Make a best effort to gather, for any type that supports `+` # This method should throw an error if gathering fails. final += chunk # type: ignore[operator] - if final: + if got_first_val: async for output in self.astream(final, config, **kwargs): yield output