From 1e29b676d5ed1e2fbc1ec1fe2e04a4360dda569f Mon Sep 17 00:00:00 2001 From: Bagatur <22008038+baskaryan@users.noreply.github.com> Date: Fri, 19 Jan 2024 16:31:54 -0800 Subject: [PATCH] core[patch]: simple fallback streaming (#16055) --- .../how_to/fallbacks.ipynb | 2 +- .../langchain_core/runnables/fallbacks.py | 118 ++++++++++++++++++ .../unit_tests/runnables/test_fallbacks.py | 61 ++++++++- 3 files changed, 179 insertions(+), 2 deletions(-) diff --git a/docs/docs/expression_language/how_to/fallbacks.ipynb b/docs/docs/expression_language/how_to/fallbacks.ipynb index 23459f8be7..de915b3240 100644 --- a/docs/docs/expression_language/how_to/fallbacks.ipynb +++ b/docs/docs/expression_language/how_to/fallbacks.ipynb @@ -302,7 +302,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.4" + "version": "3.9.1" } }, "nbformat": 4, diff --git a/libs/core/langchain_core/runnables/fallbacks.py b/libs/core/langchain_core/runnables/fallbacks.py index 7f8ab1f866..bc7128c1bf 100644 --- a/libs/core/langchain_core/runnables/fallbacks.py +++ b/libs/core/langchain_core/runnables/fallbacks.py @@ -2,6 +2,8 @@ import asyncio from typing import ( TYPE_CHECKING, Any, + AsyncIterator, + Awaitable, Dict, Iterator, List, @@ -30,6 +32,7 @@ from langchain_core.runnables.utils import ( Output, get_unique_config_specs, ) +from langchain_core.utils.aiter import py_anext if TYPE_CHECKING: from langchain_core.callbacks.manager import AsyncCallbackManagerForChainRun @@ -415,3 +418,118 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]): raise sorted_handled_exceptions[0][1] to_return.update(handled_exceptions) return [output for _, output in sorted(to_return.items())] # type: ignore + + def stream( + self, + input: Input, + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], + ) -> Iterator[Output]: + """""" + if self.exception_key is not None and not isinstance(input, dict): + raise ValueError( + "If 'exception_key' is specified then input must be a dictionary." + f"However found a type of {type(input)} for input" + ) + # setup callbacks + config = ensure_config(config) + callback_manager = get_callback_manager_for_config(config) + # start the root run + run_manager = callback_manager.on_chain_start( + dumpd(self), input, name=config.get("run_name") + ) + first_error = None + last_error = None + for runnable in self.runnables: + try: + if self.exception_key and last_error is not None: + input[self.exception_key] = last_error + stream = runnable.stream( + input, + patch_config(config, callbacks=run_manager.get_child()), + **kwargs, + ) + chunk = next(stream) + except self.exceptions_to_handle as e: + first_error = e if first_error is None else first_error + last_error = e + except BaseException as e: + run_manager.on_chain_error(e) + raise e + else: + first_error = None + break + if first_error: + run_manager.on_chain_error(first_error) + raise first_error + + yield chunk + output: Optional[Output] = chunk + try: + for chunk in stream: + yield chunk + try: + output = output + chunk # type: ignore + except TypeError: + output = None + except BaseException as e: + run_manager.on_chain_error(e) + raise e + run_manager.on_chain_end(output) + + async def astream( + self, + input: Input, + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], + ) -> AsyncIterator[Output]: + if self.exception_key is not None and not isinstance(input, dict): + raise ValueError( + "If 'exception_key' is specified then input must be a dictionary." + f"However found a type of {type(input)} for input" + ) + # setup callbacks + config = ensure_config(config) + callback_manager = get_async_callback_manager_for_config(config) + # start the root run + run_manager = await callback_manager.on_chain_start( + dumpd(self), input, name=config.get("run_name") + ) + first_error = None + last_error = None + for runnable in self.runnables: + try: + if self.exception_key and last_error is not None: + input[self.exception_key] = last_error + stream = runnable.astream( + input, + patch_config(config, callbacks=run_manager.get_child()), + **kwargs, + ) + chunk = await cast(Awaitable[Output], py_anext(stream)) + except self.exceptions_to_handle as e: + first_error = e if first_error is None else first_error + last_error = e + except BaseException as e: + await run_manager.on_chain_error(e) + raise e + else: + first_error = None + break + if first_error: + await run_manager.on_chain_error(first_error) + raise first_error + + yield chunk + output: Optional[Output] = chunk + try: + async for chunk in stream: + yield chunk + try: + output = output + chunk # type: ignore + except TypeError: + output = None + except BaseException as e: + await run_manager.on_chain_error(e) + raise e + await run_manager.on_chain_end(output) diff --git a/libs/core/tests/unit_tests/runnables/test_fallbacks.py b/libs/core/tests/unit_tests/runnables/test_fallbacks.py index ecd9cb6fc9..de1447a726 100644 --- a/libs/core/tests/unit_tests/runnables/test_fallbacks.py +++ b/libs/core/tests/unit_tests/runnables/test_fallbacks.py @@ -1,5 +1,5 @@ import sys -from typing import Any +from typing import Any, AsyncIterator, Iterator import pytest from syrupy import SnapshotAssertion @@ -8,6 +8,7 @@ from langchain_core.load import dumps from langchain_core.prompts import PromptTemplate from langchain_core.runnables import ( Runnable, + RunnableGenerator, RunnableLambda, RunnableParallel, RunnablePassthrough, @@ -229,3 +230,61 @@ async def test_abatch() -> None: expected = ["first", "second", RuntimeError()] _assert_potential_error(actual, expected) + + +def _generate(input: Iterator) -> Iterator[str]: + yield from "foo bar" + + +def _generate_immediate_error(input: Iterator) -> Iterator[str]: + raise ValueError() + yield "" + + +def _generate_delayed_error(input: Iterator) -> Iterator[str]: + yield "" + raise ValueError() + + +def test_fallbacks_stream() -> None: + runnable = RunnableGenerator(_generate_immediate_error).with_fallbacks( + [RunnableGenerator(_generate)] + ) + assert list(runnable.stream({})) == [c for c in "foo bar"] + + with pytest.raises(ValueError): + runnable = RunnableGenerator(_generate_delayed_error).with_fallbacks( + [RunnableGenerator(_generate)] + ) + list(runnable.stream({})) + + +async def _agenerate(input: AsyncIterator) -> AsyncIterator[str]: + for c in "foo bar": + yield c + + +async def _agenerate_immediate_error(input: AsyncIterator) -> AsyncIterator[str]: + raise ValueError() + yield "" + + +async def _agenerate_delayed_error(input: AsyncIterator) -> AsyncIterator[str]: + yield "" + raise ValueError() + + +async def test_fallbacks_astream() -> None: + runnable = RunnableGenerator(_agenerate_immediate_error).with_fallbacks( + [RunnableGenerator(_agenerate)] + ) + expected = (c for c in "foo bar") + async for c in runnable.astream({}): + assert c == next(expected) + + with pytest.raises(ValueError): + runnable = RunnableGenerator(_agenerate_delayed_error).with_fallbacks( + [RunnableGenerator(_agenerate)] + ) + async for c in runnable.astream({}): + pass