core[patch]: simple fallback streaming (#16055)

This commit is contained in:
Bagatur 2024-01-19 16:31:54 -08:00 committed by GitHub
parent 4ef0ed4ddc
commit 1e29b676d5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 179 additions and 2 deletions

View File

@ -302,7 +302,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
"version": "3.9.1"
}
},
"nbformat": 4,

View File

@ -2,6 +2,8 @@ import asyncio
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Awaitable,
Dict,
Iterator,
List,
@ -30,6 +32,7 @@ from langchain_core.runnables.utils import (
Output,
get_unique_config_specs,
)
from langchain_core.utils.aiter import py_anext
if TYPE_CHECKING:
from langchain_core.callbacks.manager import AsyncCallbackManagerForChainRun
@ -415,3 +418,118 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
raise sorted_handled_exceptions[0][1]
to_return.update(handled_exceptions)
return [output for _, output in sorted(to_return.items())] # type: ignore
def stream(
self,
input: Input,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Iterator[Output]:
""""""
if self.exception_key is not None and not isinstance(input, dict):
raise ValueError(
"If 'exception_key' is specified then input must be a dictionary."
f"However found a type of {type(input)} for input"
)
# setup callbacks
config = ensure_config(config)
callback_manager = get_callback_manager_for_config(config)
# start the root run
run_manager = callback_manager.on_chain_start(
dumpd(self), input, name=config.get("run_name")
)
first_error = None
last_error = None
for runnable in self.runnables:
try:
if self.exception_key and last_error is not None:
input[self.exception_key] = last_error
stream = runnable.stream(
input,
patch_config(config, callbacks=run_manager.get_child()),
**kwargs,
)
chunk = next(stream)
except self.exceptions_to_handle as e:
first_error = e if first_error is None else first_error
last_error = e
except BaseException as e:
run_manager.on_chain_error(e)
raise e
else:
first_error = None
break
if first_error:
run_manager.on_chain_error(first_error)
raise first_error
yield chunk
output: Optional[Output] = chunk
try:
for chunk in stream:
yield chunk
try:
output = output + chunk # type: ignore
except TypeError:
output = None
except BaseException as e:
run_manager.on_chain_error(e)
raise e
run_manager.on_chain_end(output)
async def astream(
self,
input: Input,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> AsyncIterator[Output]:
if self.exception_key is not None and not isinstance(input, dict):
raise ValueError(
"If 'exception_key' is specified then input must be a dictionary."
f"However found a type of {type(input)} for input"
)
# setup callbacks
config = ensure_config(config)
callback_manager = get_async_callback_manager_for_config(config)
# start the root run
run_manager = await callback_manager.on_chain_start(
dumpd(self), input, name=config.get("run_name")
)
first_error = None
last_error = None
for runnable in self.runnables:
try:
if self.exception_key and last_error is not None:
input[self.exception_key] = last_error
stream = runnable.astream(
input,
patch_config(config, callbacks=run_manager.get_child()),
**kwargs,
)
chunk = await cast(Awaitable[Output], py_anext(stream))
except self.exceptions_to_handle as e:
first_error = e if first_error is None else first_error
last_error = e
except BaseException as e:
await run_manager.on_chain_error(e)
raise e
else:
first_error = None
break
if first_error:
await run_manager.on_chain_error(first_error)
raise first_error
yield chunk
output: Optional[Output] = chunk
try:
async for chunk in stream:
yield chunk
try:
output = output + chunk # type: ignore
except TypeError:
output = None
except BaseException as e:
await run_manager.on_chain_error(e)
raise e
await run_manager.on_chain_end(output)

View File

@ -1,5 +1,5 @@
import sys
from typing import Any
from typing import Any, AsyncIterator, Iterator
import pytest
from syrupy import SnapshotAssertion
@ -8,6 +8,7 @@ from langchain_core.load import dumps
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import (
Runnable,
RunnableGenerator,
RunnableLambda,
RunnableParallel,
RunnablePassthrough,
@ -229,3 +230,61 @@ async def test_abatch() -> None:
expected = ["first", "second", RuntimeError()]
_assert_potential_error(actual, expected)
def _generate(input: Iterator) -> Iterator[str]:
yield from "foo bar"
def _generate_immediate_error(input: Iterator) -> Iterator[str]:
raise ValueError()
yield ""
def _generate_delayed_error(input: Iterator) -> Iterator[str]:
yield ""
raise ValueError()
def test_fallbacks_stream() -> None:
runnable = RunnableGenerator(_generate_immediate_error).with_fallbacks(
[RunnableGenerator(_generate)]
)
assert list(runnable.stream({})) == [c for c in "foo bar"]
with pytest.raises(ValueError):
runnable = RunnableGenerator(_generate_delayed_error).with_fallbacks(
[RunnableGenerator(_generate)]
)
list(runnable.stream({}))
async def _agenerate(input: AsyncIterator) -> AsyncIterator[str]:
for c in "foo bar":
yield c
async def _agenerate_immediate_error(input: AsyncIterator) -> AsyncIterator[str]:
raise ValueError()
yield ""
async def _agenerate_delayed_error(input: AsyncIterator) -> AsyncIterator[str]:
yield ""
raise ValueError()
async def test_fallbacks_astream() -> None:
runnable = RunnableGenerator(_agenerate_immediate_error).with_fallbacks(
[RunnableGenerator(_agenerate)]
)
expected = (c for c in "foo bar")
async for c in runnable.astream({}):
assert c == next(expected)
with pytest.raises(ValueError):
runnable = RunnableGenerator(_agenerate_delayed_error).with_fallbacks(
[RunnableGenerator(_agenerate)]
)
async for c in runnable.astream({}):
pass