Fix Runnable.transform() for false-y inputs (#10893)

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
Nuno Campos 2023-09-21 19:27:09 +01:00 committed by GitHub
parent fcb5aba9f0
commit ea26c12b23
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -286,16 +286,19 @@ class Runnable(Generic[Input, Output], ABC):
Subclasses should override this method if they can start producing output while Subclasses should override this method if they can start producing output while
input is still being generated. input is still being generated.
""" """
final: Union[Input, None] = None final: Input
got_first_val = False
for chunk in input: for chunk in input:
if final is None: if not got_first_val:
final = chunk final = chunk
got_first_val = True
else: else:
# Make a best effort to gather, for any type that supports `+` # Make a best effort to gather, for any type that supports `+`
# This method should throw an error if gathering fails. # This method should throw an error if gathering fails.
final += chunk # type: ignore[operator] final += chunk # type: ignore[operator]
if final:
if got_first_val:
yield from self.stream(final, config, **kwargs) yield from self.stream(final, config, **kwargs)
async def atransform( async def atransform(
@ -309,17 +312,19 @@ class Runnable(Generic[Input, Output], ABC):
Subclasses should override this method if they can start producing output while Subclasses should override this method if they can start producing output while
input is still being generated. input is still being generated.
""" """
final: Union[Input, None] = None final: Input
got_first_val = False
async for chunk in input: async for chunk in input:
if final is None: if not got_first_val:
final = chunk final = chunk
got_first_val = True
else: else:
# Make a best effort to gather, for any type that supports `+` # Make a best effort to gather, for any type that supports `+`
# This method should throw an error if gathering fails. # This method should throw an error if gathering fails.
final += chunk # type: ignore[operator] final += chunk # type: ignore[operator]
if final: if got_first_val:
async for output in self.astream(final, config, **kwargs): async for output in self.astream(final, config, **kwargs):
yield output yield output