Adds transform and atransform support to runnable sequences (#9583)

Allow runnable sequences to support transform if each individual
runnable inside supports transform/atransform.

@nfcampos
pull/10649/head
Jacob Lee 1 year ago committed by GitHub
parent c0e1a1d32c
commit a50e62e44b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1231,11 +1231,11 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
else:
raise first_exception
def stream(
def _transform(
self,
input: Input,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
input: Iterator[Input],
run_manager: CallbackManagerForChainRun,
config: RunnableConfig,
) -> Iterator[Output]:
# setup callbacks
config = ensure_config(config)
@ -1254,37 +1254,50 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
else:
break
# invoke the first steps
try:
for step in steps[0:streaming_start_index]:
input = step.invoke(
input,
# mark each step as a child run
final_pipeline = None
gathered_input = None
if streaming_start_index == 0:
final_pipeline = steps[streaming_start_index].transform(
input,
patch_config(config, callbacks=run_manager.get_child("seq:step:1")),
)
else:
try:
for input_chunk in input:
if gathered_input is None:
gathered_input = input_chunk
else:
gathered_input += input_chunk
# invoke the first steps
for step in steps[0:streaming_start_index]:
gathered_input = step.invoke(
gathered_input,
# mark each step as a child run
patch_config(
config,
callbacks=run_manager.get_child(
f"seq:step:{steps.index(step)+1}"
),
),
)
# stream the first of the last steps with the final non-streaming input
final_pipeline = steps[streaming_start_index].stream(
gathered_input,
patch_config(
config,
callbacks=run_manager.get_child(
f"seq:step:{steps.index(step)+1}"
f"seq:step:{streaming_start_index+1}"
),
),
)
except (KeyboardInterrupt, Exception) as e:
run_manager.on_chain_error(e)
raise
except (KeyboardInterrupt, Exception) as e:
run_manager.on_chain_error(e)
raise
# stream the last steps
final: Union[Output, None] = None
final_supported = True
try:
# stream the first of the last steps with non-streaming input
final_pipeline = steps[streaming_start_index].stream(
input,
patch_config(
config,
callbacks=run_manager.get_child(
f"seq:step:{streaming_start_index+1}"
),
),
)
# stream the rest of the last steps with streaming input
for step in steps[streaming_start_index + 1 :]:
final_pipeline = step.transform(
@ -1296,6 +1309,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
),
),
)
for output in final_pipeline:
yield output
# Accumulate output if possible, otherwise disable accumulation
@ -1316,11 +1330,11 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
else:
run_manager.on_chain_end(final)
async def astream(
async def _atransform(
self,
input: Input,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
input: AsyncIterator[Input],
run_manager: AsyncCallbackManagerForChainRun,
config: RunnableConfig,
) -> AsyncIterator[Output]:
# setup callbacks
config = ensure_config(config)
@ -1334,42 +1348,55 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
streaming_start_index = len(steps) - 1
for i in range(len(steps) - 1, 0, -1):
if type(steps[i]).transform != Runnable.transform:
if type(steps[i]).atransform != Runnable.atransform:
streaming_start_index = i - 1
else:
break
# invoke the first steps
try:
for step in steps[0:streaming_start_index]:
input = await step.ainvoke(
input,
# mark each step as a child run
final_pipeline = None
gathered_input = None
if streaming_start_index == 0:
final_pipeline = steps[0].atransform(
input,
patch_config(config, callbacks=run_manager.get_child("seq:step:1")),
)
else:
try:
async for input_chunk in input:
if gathered_input is None:
gathered_input = input_chunk
else:
gathered_input += input_chunk
# invoke the first steps
for step in steps[0:streaming_start_index]:
gathered_input = await step.ainvoke(
gathered_input,
# mark each step as a child run
patch_config(
config,
callbacks=run_manager.get_child(
f"seq:step:{steps.index(step)+1}"
),
),
)
# stream the first of the last steps with the final non-streaming input
final_pipeline = steps[streaming_start_index].astream(
gathered_input,
patch_config(
config,
callbacks=run_manager.get_child(
f"seq:step:{steps.index(step)+1}"
f"seq:step:{streaming_start_index+1}"
),
),
)
except (KeyboardInterrupt, Exception) as e:
await run_manager.on_chain_error(e)
raise
except (KeyboardInterrupt, Exception) as e:
await run_manager.on_chain_error(e)
raise
# stream the last steps
final: Union[Output, None] = None
final_supported = True
try:
# stream the first of the last steps with non-streaming input
final_pipeline = steps[streaming_start_index].astream(
input,
patch_config(
config,
callbacks=run_manager.get_child(
f"seq:step:{streaming_start_index+1}"
),
),
)
# stream the rest of the last steps with streaming input
for step in steps[streaming_start_index + 1 :]:
final_pipeline = step.atransform(
@ -1401,6 +1428,47 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
else:
await run_manager.on_chain_end(final)
def transform(
self,
input: Iterator[Input],
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Iterator[Output]:
yield from self._transform_stream_with_config(
input, self._transform, config, **kwargs
)
def stream(
self,
input: Input,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Iterator[Output]:
yield from self.transform(iter([input]), config, **kwargs)
async def atransform(
self,
input: AsyncIterator[Input],
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> AsyncIterator[Output]:
async for chunk in self._atransform_stream_with_config(
input, self._atransform, config, **kwargs
):
yield chunk
async def astream(
self,
input: Input,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> AsyncIterator[Output]:
async def input_aiter() -> AsyncIterator[Input]:
yield input
async for chunk in self.atransform(input_aiter(), config, **kwargs):
yield chunk
class RunnableMapChunk(Dict[str, Any]):
"""

@ -1315,6 +1315,37 @@ async def test_deep_astream() -> None:
assert "".join(chunks) == "foo-lish"
def test_runnable_sequence_transform() -> None:
llm = FakeStreamingListLLM(responses=["foo-lish"])
chain = llm | StrOutputParser()
stream = chain.transform(llm.stream("Hi there!"))
chunks = []
for chunk in stream:
chunks.append(chunk)
assert len(chunks) == len("foo-lish")
assert "".join(chunks) == "foo-lish"
@pytest.mark.asyncio
async def test_runnable_sequence_atransform() -> None:
llm = FakeStreamingListLLM(responses=["foo-lish"])
chain = llm | StrOutputParser()
stream = chain.atransform(llm.astream("Hi there!"))
chunks = []
async for chunk in stream:
chunks.append(chunk)
assert len(chunks) == len("foo-lish")
assert "".join(chunks) == "foo-lish"
@pytest.fixture()
def llm_with_fallbacks() -> RunnableWithFallbacks:
error_llm = FakeListLLM(responses=["foo"], i=1)

Loading…
Cancel
Save