mirror of
https://github.com/hwchase17/langchain
synced 2024-11-16 06:13:16 +00:00
core(minor): Implement stream and astream for RunnableBranch (#14805)
* This PR adds `stream` implementations to Runnable Branch. * Runnable Branch still does not support `transform` so it'll break streaming if it happens in middle or end of sequence, but will work if happens at beginning of sequence. * Fixes use the async callback manager for async methods * Handle BaseException rather than Exception, so more errors could be logged as errors when they are encountered --------- Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
parent
65a9193db2
commit
42822484ef
@ -1,7 +1,9 @@
|
|||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
|
AsyncIterator,
|
||||||
Awaitable,
|
Awaitable,
|
||||||
Callable,
|
Callable,
|
||||||
|
Iterator,
|
||||||
List,
|
List,
|
||||||
Mapping,
|
Mapping,
|
||||||
Optional,
|
Optional,
|
||||||
@ -23,6 +25,7 @@ from langchain_core.runnables.base import (
|
|||||||
from langchain_core.runnables.config import (
|
from langchain_core.runnables.config import (
|
||||||
RunnableConfig,
|
RunnableConfig,
|
||||||
ensure_config,
|
ensure_config,
|
||||||
|
get_async_callback_manager_for_config,
|
||||||
get_callback_manager_for_config,
|
get_callback_manager_for_config,
|
||||||
patch_config,
|
patch_config,
|
||||||
)
|
)
|
||||||
@ -212,7 +215,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
|
|||||||
),
|
),
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except BaseException as e:
|
||||||
run_manager.on_chain_error(e)
|
run_manager.on_chain_error(e)
|
||||||
raise
|
raise
|
||||||
run_manager.on_chain_end(dumpd(output))
|
run_manager.on_chain_end(dumpd(output))
|
||||||
@ -223,8 +226,8 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
|
|||||||
) -> Output:
|
) -> Output:
|
||||||
"""Async version of invoke."""
|
"""Async version of invoke."""
|
||||||
config = ensure_config(config)
|
config = ensure_config(config)
|
||||||
callback_manager = get_callback_manager_for_config(config)
|
callback_manager = get_async_callback_manager_for_config(config)
|
||||||
run_manager = callback_manager.on_chain_start(
|
run_manager = await callback_manager.on_chain_start(
|
||||||
dumpd(self),
|
dumpd(self),
|
||||||
input,
|
input,
|
||||||
name=config.get("run_name"),
|
name=config.get("run_name"),
|
||||||
@ -259,8 +262,156 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
|
|||||||
),
|
),
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except BaseException as e:
|
||||||
|
await run_manager.on_chain_error(e)
|
||||||
|
raise
|
||||||
|
await run_manager.on_chain_end(dumpd(output))
|
||||||
|
return output
|
||||||
|
|
||||||
|
def stream(
|
||||||
|
self,
|
||||||
|
input: Input,
|
||||||
|
config: Optional[RunnableConfig] = None,
|
||||||
|
**kwargs: Optional[Any],
|
||||||
|
) -> Iterator[Output]:
|
||||||
|
"""First evaluates the condition,
|
||||||
|
then delegate to true or false branch."""
|
||||||
|
config = ensure_config(config)
|
||||||
|
callback_manager = get_callback_manager_for_config(config)
|
||||||
|
run_manager = callback_manager.on_chain_start(
|
||||||
|
dumpd(self),
|
||||||
|
input,
|
||||||
|
name=config.get("run_name"),
|
||||||
|
)
|
||||||
|
final_output: Optional[Output] = None
|
||||||
|
final_output_supported = True
|
||||||
|
|
||||||
|
try:
|
||||||
|
for idx, branch in enumerate(self.branches):
|
||||||
|
condition, runnable = branch
|
||||||
|
|
||||||
|
expression_value = condition.invoke(
|
||||||
|
input,
|
||||||
|
config=patch_config(
|
||||||
|
config,
|
||||||
|
callbacks=run_manager.get_child(tag=f"condition:{idx + 1}"),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
if expression_value:
|
||||||
|
for chunk in runnable.stream(
|
||||||
|
input,
|
||||||
|
config=patch_config(
|
||||||
|
config,
|
||||||
|
callbacks=run_manager.get_child(tag=f"branch:{idx + 1}"),
|
||||||
|
),
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
yield chunk
|
||||||
|
if final_output_supported:
|
||||||
|
if final_output is None:
|
||||||
|
final_output = chunk
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
final_output = final_output + chunk # type: ignore
|
||||||
|
except TypeError:
|
||||||
|
final_output = None
|
||||||
|
final_output_supported = False
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
for chunk in self.default.stream(
|
||||||
|
input,
|
||||||
|
config=patch_config(
|
||||||
|
config,
|
||||||
|
callbacks=run_manager.get_child(tag="branch:default"),
|
||||||
|
),
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
yield chunk
|
||||||
|
if final_output_supported:
|
||||||
|
if final_output is None:
|
||||||
|
final_output = chunk
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
final_output = final_output + chunk # type: ignore
|
||||||
|
except TypeError:
|
||||||
|
final_output = None
|
||||||
|
final_output_supported = False
|
||||||
|
except BaseException as e:
|
||||||
run_manager.on_chain_error(e)
|
run_manager.on_chain_error(e)
|
||||||
raise
|
raise
|
||||||
run_manager.on_chain_end(dumpd(output))
|
run_manager.on_chain_end(final_output)
|
||||||
return output
|
|
||||||
|
async def astream(
|
||||||
|
self,
|
||||||
|
input: Input,
|
||||||
|
config: Optional[RunnableConfig] = None,
|
||||||
|
**kwargs: Optional[Any],
|
||||||
|
) -> AsyncIterator[Output]:
|
||||||
|
"""First evaluates the condition,
|
||||||
|
then delegate to true or false branch."""
|
||||||
|
config = ensure_config(config)
|
||||||
|
callback_manager = get_async_callback_manager_for_config(config)
|
||||||
|
run_manager = await callback_manager.on_chain_start(
|
||||||
|
dumpd(self),
|
||||||
|
input,
|
||||||
|
name=config.get("run_name"),
|
||||||
|
)
|
||||||
|
final_output: Optional[Output] = None
|
||||||
|
final_output_supported = True
|
||||||
|
|
||||||
|
try:
|
||||||
|
for idx, branch in enumerate(self.branches):
|
||||||
|
condition, runnable = branch
|
||||||
|
|
||||||
|
expression_value = await condition.ainvoke(
|
||||||
|
input,
|
||||||
|
config=patch_config(
|
||||||
|
config,
|
||||||
|
callbacks=run_manager.get_child(tag=f"condition:{idx + 1}"),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
if expression_value:
|
||||||
|
async for chunk in runnable.astream(
|
||||||
|
input,
|
||||||
|
config=patch_config(
|
||||||
|
config,
|
||||||
|
callbacks=run_manager.get_child(tag=f"branch:{idx + 1}"),
|
||||||
|
),
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
yield chunk
|
||||||
|
if final_output_supported:
|
||||||
|
if final_output is None:
|
||||||
|
final_output = chunk
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
final_output = final_output + chunk # type: ignore
|
||||||
|
except TypeError:
|
||||||
|
final_output = None
|
||||||
|
final_output_supported = False
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
async for chunk in self.default.astream(
|
||||||
|
input,
|
||||||
|
config=patch_config(
|
||||||
|
config,
|
||||||
|
callbacks=run_manager.get_child(tag="branch:default"),
|
||||||
|
),
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
yield chunk
|
||||||
|
if final_output_supported:
|
||||||
|
if final_output is None:
|
||||||
|
final_output = chunk
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
final_output = final_output + chunk # type: ignore
|
||||||
|
except TypeError:
|
||||||
|
final_output = None
|
||||||
|
final_output_supported = False
|
||||||
|
except BaseException as e:
|
||||||
|
await run_manager.on_chain_error(e)
|
||||||
|
raise
|
||||||
|
await run_manager.on_chain_end(final_output)
|
||||||
|
@ -3981,6 +3981,140 @@ async def test_runnable_branch_abatch() -> None:
|
|||||||
assert await branch.abatch([1, 10, 0]) == [2, 100, -1]
|
assert await branch.abatch([1, 10, 0]) == [2, 100, -1]
|
||||||
|
|
||||||
|
|
||||||
|
def test_runnable_branch_stream() -> None:
|
||||||
|
"""Verify that stream works for RunnableBranch."""
|
||||||
|
|
||||||
|
llm_res = "i'm a textbot"
|
||||||
|
# sleep to better simulate a real stream
|
||||||
|
llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)
|
||||||
|
|
||||||
|
branch = RunnableBranch[str, Any](
|
||||||
|
(lambda x: x == "hello", llm),
|
||||||
|
lambda x: x,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert list(branch.stream("hello")) == list(llm_res)
|
||||||
|
assert list(branch.stream("bye")) == ["bye"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_runnable_branch_stream_with_callbacks() -> None:
|
||||||
|
"""Verify that stream works for RunnableBranch when using callbacks."""
|
||||||
|
tracer = FakeTracer()
|
||||||
|
|
||||||
|
def raise_value_error(x: str) -> Any:
|
||||||
|
"""Raise a value error."""
|
||||||
|
raise ValueError(f"x is {x}")
|
||||||
|
|
||||||
|
llm_res = "i'm a textbot"
|
||||||
|
# sleep to better simulate a real stream
|
||||||
|
llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)
|
||||||
|
|
||||||
|
branch = RunnableBranch[str, Any](
|
||||||
|
(lambda x: x == "error", raise_value_error),
|
||||||
|
(lambda x: x == "hello", llm),
|
||||||
|
lambda x: x,
|
||||||
|
)
|
||||||
|
config: RunnableConfig = {"callbacks": [tracer]}
|
||||||
|
|
||||||
|
assert list(branch.stream("hello", config=config)) == list(llm_res)
|
||||||
|
|
||||||
|
assert len(tracer.runs) == 1
|
||||||
|
assert tracer.runs[0].error is None
|
||||||
|
assert tracer.runs[0].outputs == {"output": llm_res}
|
||||||
|
|
||||||
|
# Verify that the chain on error is invoked
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
for _ in branch.stream("error", config=config):
|
||||||
|
pass
|
||||||
|
|
||||||
|
assert len(tracer.runs) == 2
|
||||||
|
assert "ValueError('x is error')" in str(tracer.runs[1].error)
|
||||||
|
assert tracer.runs[1].outputs is None
|
||||||
|
|
||||||
|
assert list(branch.stream("bye", config=config)) == ["bye"]
|
||||||
|
|
||||||
|
assert len(tracer.runs) == 3
|
||||||
|
assert tracer.runs[2].error is None
|
||||||
|
assert tracer.runs[2].outputs == {"output": "bye"}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_runnable_branch_astream() -> None:
|
||||||
|
"""Verify that astream works for RunnableBranch."""
|
||||||
|
|
||||||
|
llm_res = "i'm a textbot"
|
||||||
|
# sleep to better simulate a real stream
|
||||||
|
llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)
|
||||||
|
|
||||||
|
branch = RunnableBranch[str, Any](
|
||||||
|
(lambda x: x == "hello", llm),
|
||||||
|
lambda x: x,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert [_ async for _ in branch.astream("hello")] == list(llm_res)
|
||||||
|
assert [_ async for _ in branch.astream("bye")] == ["bye"]
|
||||||
|
|
||||||
|
# Verify that the async variant is used if available
|
||||||
|
async def condition(x: str) -> bool:
|
||||||
|
return x == "hello"
|
||||||
|
|
||||||
|
async def repeat(x: str) -> str:
|
||||||
|
return x + x
|
||||||
|
|
||||||
|
async def reverse(x: str) -> str:
|
||||||
|
return x[::-1]
|
||||||
|
|
||||||
|
branch = RunnableBranch[str, Any]((condition, repeat), llm)
|
||||||
|
|
||||||
|
assert [_ async for _ in branch.astream("hello")] == ["hello" * 2]
|
||||||
|
assert [_ async for _ in branch.astream("bye")] == list(llm_res)
|
||||||
|
|
||||||
|
branch = RunnableBranch[str, Any]((condition, llm), reverse)
|
||||||
|
|
||||||
|
assert [_ async for _ in branch.astream("hello")] == list(llm_res)
|
||||||
|
assert [_ async for _ in branch.astream("bye")] == ["eyb"]
|
||||||
|
|
||||||
|
|
||||||
|
async def test_runnable_branch_astream_with_callbacks() -> None:
|
||||||
|
"""Verify that astream works for RunnableBranch when using callbacks."""
|
||||||
|
tracer = FakeTracer()
|
||||||
|
|
||||||
|
def raise_value_error(x: str) -> Any:
|
||||||
|
"""Raise a value error."""
|
||||||
|
raise ValueError(f"x is {x}")
|
||||||
|
|
||||||
|
llm_res = "i'm a textbot"
|
||||||
|
# sleep to better simulate a real stream
|
||||||
|
llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)
|
||||||
|
|
||||||
|
branch = RunnableBranch[str, Any](
|
||||||
|
(lambda x: x == "error", raise_value_error),
|
||||||
|
(lambda x: x == "hello", llm),
|
||||||
|
lambda x: x,
|
||||||
|
)
|
||||||
|
config: RunnableConfig = {"callbacks": [tracer]}
|
||||||
|
|
||||||
|
assert [_ async for _ in branch.astream("hello", config=config)] == list(llm_res)
|
||||||
|
|
||||||
|
assert len(tracer.runs) == 1
|
||||||
|
assert tracer.runs[0].error is None
|
||||||
|
assert tracer.runs[0].outputs == {"output": llm_res}
|
||||||
|
|
||||||
|
# Verify that the chain on error is invoked
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
async for _ in branch.astream("error", config=config):
|
||||||
|
pass
|
||||||
|
|
||||||
|
assert len(tracer.runs) == 2
|
||||||
|
assert "ValueError('x is error')" in str(tracer.runs[1].error)
|
||||||
|
assert tracer.runs[1].outputs is None
|
||||||
|
|
||||||
|
assert [_ async for _ in branch.astream("bye", config=config)] == ["bye"]
|
||||||
|
|
||||||
|
assert len(tracer.runs) == 3
|
||||||
|
assert tracer.runs[2].error is None
|
||||||
|
assert tracer.runs[2].outputs == {"output": "bye"}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
@pytest.mark.skipif(
|
||||||
sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
|
sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user