Bug fixes for runnables (#10738)

- tools invoked in async methods would not work due to a missing await
- RunnableSequence.stream() was creating an extra root run by mistake,
and it can be simplified due to the existence of a default implementation for
.transform()

<!-- Thank you for contributing to LangChain!

Replace this entire comment with:
  - **Description:** a description of the change,
  - **Issue:** the issue # it fixes (if applicable),
  - **Dependencies:** any dependencies required for this change,
  - **Tag maintainer:** for a quicker response, tag the relevant
    maintainer (see below),
  - **Twitter handle:** we announce bigger features on Twitter. If your PR
    gets announced, and you'd like a mention, we'll gladly shout you out!

Please make sure your PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` to check this
locally.

See contribution guidelines for more information on how to write/run
tests, lint, etc:

https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in `docs/extras`
directory.

If no one reviews your PR within a few days, please @-mention one of
@baskaryan, @eyurtsev, @hwchase17.
 -->
pull/10607/head^2
Nuno Campos 11 months ago committed by GitHub
parent 6e48092746
commit 8201cae770
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1269,98 +1269,23 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
run_manager: CallbackManagerForChainRun,
config: RunnableConfig,
) -> Iterator[Output]:
# setup callbacks
config = ensure_config(config)
callback_manager = get_callback_manager_for_config(config)
# start the root run
run_manager = callback_manager.on_chain_start(
dumpd(self), input, name=config.get("run_name")
)
steps = [self.first] + self.middle + [self.last]
streaming_start_index = 0
for i in range(len(steps) - 1, 0, -1):
if type(steps[i]).transform != Runnable.transform:
streaming_start_index = i - 1
else:
break
final_pipeline = None
gathered_input = None
if streaming_start_index == 0:
final_pipeline = steps[streaming_start_index].transform(
input,
patch_config(config, callbacks=run_manager.get_child("seq:step:1")),
# transform the input stream of each step with the next
# steps that don't natively support transforming an input stream will
# buffer input in memory until all available, and then start emitting output
final_pipeline = cast(Iterator[Output], input)
for step in steps:
final_pipeline = step.transform(
final_pipeline,
patch_config(
config,
callbacks=run_manager.get_child(f"seq:step:{steps.index(step)+1}"),
),
)
else:
try:
for input_chunk in input:
if gathered_input is None:
gathered_input = input_chunk
else:
gathered_input += input_chunk
# invoke the first steps
for step in steps[0:streaming_start_index]:
gathered_input = step.invoke(
gathered_input,
# mark each step as a child run
patch_config(
config,
callbacks=run_manager.get_child(
f"seq:step:{steps.index(step)+1}"
),
),
)
# stream the first of the last steps with the final non-streaming input
final_pipeline = steps[streaming_start_index].stream(
gathered_input,
patch_config(
config,
callbacks=run_manager.get_child(
f"seq:step:{streaming_start_index+1}"
),
),
)
except (KeyboardInterrupt, Exception) as e:
run_manager.on_chain_error(e)
raise
# stream the last steps
final: Union[Output, None] = None
final_supported = True
try:
# stream the rest of the last steps with streaming input
for step in steps[streaming_start_index + 1 :]:
final_pipeline = step.transform(
final_pipeline,
patch_config(
config,
callbacks=run_manager.get_child(
f"seq:step:{steps.index(step)+1}"
),
),
)
for output in final_pipeline:
yield output
# Accumulate output if possible, otherwise disable accumulation
if final_supported:
if final is None:
final = output
else:
try:
final += output # type: ignore[operator]
except TypeError:
final = None
final_supported = False
pass
# finish the root run
except (KeyboardInterrupt, Exception) as e:
run_manager.on_chain_error(e)
raise
else:
run_manager.on_chain_end(final)
for output in final_pipeline:
yield output
async def _atransform(
self,
@ -1368,97 +1293,23 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
run_manager: AsyncCallbackManagerForChainRun,
config: RunnableConfig,
) -> AsyncIterator[Output]:
# setup callbacks
config = ensure_config(config)
callback_manager = get_async_callback_manager_for_config(config)
# start the root run
run_manager = await callback_manager.on_chain_start(
dumpd(self), input, name=config.get("run_name")
)
steps = [self.first] + self.middle + [self.last]
streaming_start_index = len(steps) - 1
for i in range(len(steps) - 1, 0, -1):
if type(steps[i]).atransform != Runnable.atransform:
streaming_start_index = i - 1
else:
break
final_pipeline = None
gathered_input = None
if streaming_start_index == 0:
final_pipeline = steps[0].atransform(
input,
patch_config(config, callbacks=run_manager.get_child("seq:step:1")),
)
else:
try:
async for input_chunk in input:
if gathered_input is None:
gathered_input = input_chunk
else:
gathered_input += input_chunk
# invoke the first steps
for step in steps[0:streaming_start_index]:
gathered_input = await step.ainvoke(
gathered_input,
# mark each step as a child run
patch_config(
config,
callbacks=run_manager.get_child(
f"seq:step:{steps.index(step)+1}"
),
),
)
# stream the first of the last steps with the final non-streaming input
final_pipeline = steps[streaming_start_index].astream(
gathered_input,
patch_config(
config,
callbacks=run_manager.get_child(
f"seq:step:{streaming_start_index+1}"
),
),
)
except (KeyboardInterrupt, Exception) as e:
await run_manager.on_chain_error(e)
raise
# stream the last steps
final: Union[Output, None] = None
final_supported = True
try:
# stream the rest of the last steps with streaming input
for step in steps[streaming_start_index + 1 :]:
final_pipeline = step.atransform(
final_pipeline,
patch_config(
config,
callbacks=run_manager.get_child(
f"seq:step:{steps.index(step)+1}"
),
),
)
async for output in final_pipeline:
yield output
# Accumulate output if possible, otherwise disable accumulation
if final_supported:
if final is None:
final = output
else:
try:
final += output # type: ignore[operator]
except TypeError:
final = None
final_supported = False
pass
# finish the root run
except (KeyboardInterrupt, Exception) as e:
await run_manager.on_chain_error(e)
raise
else:
await run_manager.on_chain_end(final)
# transform the input stream of each step with the next
# steps that don't natively support transforming an input stream will
# buffer input in memory until all available, and then start emitting output
final_pipeline = cast(AsyncIterator[Output], input)
for step in steps:
final_pipeline = step.atransform(
final_pipeline,
patch_config(
config,
callbacks=run_manager.get_child(f"seq:step:{steps.index(step)+1}"),
),
)
async for output in final_pipeline:
yield output
def transform(
self,

@ -210,7 +210,7 @@ class ChildTool(BaseTool):
) -> Any:
if type(self)._arun == BaseTool._arun:
# If the tool does not implement async, fall back to default implementation
return super().ainvoke(input, config, **kwargs)
return await super().ainvoke(input, config, **kwargs)
config = config or {}
return await self.arun(
@ -461,7 +461,7 @@ class Tool(BaseTool):
None, partial(self.invoke, input, config, **kwargs)
)
return super().ainvoke(input, config, **kwargs)
return await super().ainvoke(input, config, **kwargs)
# --- Tool ---

Loading…
Cancel
Save