diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index 0bb5ddb11b..5ac7ce0dac 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -923,12 +923,17 @@ class Runnable(Generic[Input, Output], ABC): fallbacks: Sequence[Runnable[Input, Output]], *, exceptions_to_handle: Tuple[Type[BaseException], ...] = (Exception,), + exception_key: Optional[str] = None, ) -> RunnableWithFallbacksT[Input, Output]: """Add fallbacks to a runnable, returning a new Runnable. Args: fallbacks: A sequence of runnables to try if the original runnable fails. exceptions_to_handle: A tuple of exception types to handle. + exception_key: If string is specified then handled exceptions will be passed + to fallbacks as part of the input under the specified key. If None, + exceptions will not be passed to fallbacks. If used, the base runnable + and its fallbacks must accept a dictionary as input. Returns: A new Runnable that will try the original runnable, and then each @@ -940,6 +945,7 @@ class Runnable(Generic[Input, Output], ABC): runnable=self, fallbacks=fallbacks, exceptions_to_handle=exceptions_to_handle, + exception_key=exception_key, ) """ --- Helper methods for Subclasses --- """ diff --git a/libs/core/langchain_core/runnables/fallbacks.py b/libs/core/langchain_core/runnables/fallbacks.py index 5f6dbf11bf..7f8ab1f866 100644 --- a/libs/core/langchain_core/runnables/fallbacks.py +++ b/libs/core/langchain_core/runnables/fallbacks.py @@ -2,6 +2,7 @@ import asyncio from typing import ( TYPE_CHECKING, Any, + Dict, Iterator, List, Optional, @@ -9,6 +10,7 @@ from typing import ( Tuple, Type, Union, + cast, ) from langchain_core.load.dump import dumpd @@ -89,6 +91,11 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]): Any exception that is not a subclass of these exceptions will be raised immediately. """ + exception_key: Optional[str] = None + """If string is specified then handled exceptions will be passed to fallbacks as + part of the input under the specified key. If None, exceptions + will not be passed to fallbacks. If used, the base runnable and its fallbacks + must accept a dictionary as input.""" class Config: arbitrary_types_allowed = True @@ -136,6 +143,11 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]): def invoke( self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any ) -> Output: + if self.exception_key is not None and not isinstance(input, dict): + raise ValueError( + "If 'exception_key' is specified then input must be a dictionary." + f"However found a type of {type(input)} for input" + ) # setup callbacks config = ensure_config(config) callback_manager = get_callback_manager_for_config(config) @@ -144,8 +156,11 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]): dumpd(self), input, name=config.get("run_name") ) first_error = None + last_error = None for runnable in self.runnables: try: + if self.exception_key and last_error is not None: + input[self.exception_key] = last_error output = runnable.invoke( input, patch_config(config, callbacks=run_manager.get_child()), @@ -154,6 +169,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]): except self.exceptions_to_handle as e: if first_error is None: first_error = e + last_error = e except BaseException as e: run_manager.on_chain_error(e) raise e @@ -171,6 +187,11 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]): config: Optional[RunnableConfig] = None, **kwargs: Optional[Any], ) -> Output: + if self.exception_key is not None and not isinstance(input, dict): + raise ValueError( + "If 'exception_key' is specified then input must be a dictionary." + f"However found a type of {type(input)} for input" + ) # setup callbacks config = ensure_config(config) callback_manager = get_async_callback_manager_for_config(config) @@ -180,8 +201,11 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]): ) first_error = None + last_error = None for runnable in self.runnables: try: + if self.exception_key and last_error is not None: + input[self.exception_key] = last_error output = await runnable.ainvoke( input, patch_config(config, callbacks=run_manager.get_child()), @@ -190,6 +214,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]): except self.exceptions_to_handle as e: if first_error is None: first_error = e + last_error = e except BaseException as e: await run_manager.on_chain_error(e) raise e @@ -211,8 +236,13 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]): ) -> List[Output]: from langchain_core.callbacks.manager import CallbackManager - if return_exceptions: - raise NotImplementedError() + if self.exception_key is not None and not all( + isinstance(input, dict) for input in inputs + ): + raise ValueError( + "If 'exception_key' is specified then inputs must be dictionaries." + f"However found a type of {type(inputs[0])} for input" + ) if not inputs: return [] @@ -241,35 +271,51 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]): for cm, input, config in zip(callback_managers, inputs, configs) ] - first_error = None + to_return: Dict[int, Any] = {} + run_again = {i: input for i, input in enumerate(inputs)} + handled_exceptions: Dict[int, BaseException] = {} + first_to_raise = None for runnable in self.runnables: - try: - outputs = runnable.batch( - inputs, - [ - # each step a child run of the corresponding root run - patch_config(config, callbacks=rm.get_child()) - for rm, config in zip(run_managers, configs) - ], - return_exceptions=return_exceptions, - **kwargs, - ) - except self.exceptions_to_handle as e: - if first_error is None: - first_error = e - except BaseException as e: - for rm in run_managers: - rm.on_chain_error(e) - raise e - else: - for rm, output in zip(run_managers, outputs): - rm.on_chain_end(output) - return outputs - if first_error is None: - raise ValueError("No error stored at end of fallbacks.") - for rm in run_managers: - rm.on_chain_error(first_error) - raise first_error + outputs = runnable.batch( + [input for _, input in sorted(run_again.items())], + [ + # each step a child run of the corresponding root run + patch_config(configs[i], callbacks=run_managers[i].get_child()) + for i in sorted(run_again) + ], + return_exceptions=True, + **kwargs, + ) + for (i, input), output in zip(sorted(run_again.copy().items()), outputs): + if isinstance(output, BaseException) and not isinstance( + output, self.exceptions_to_handle + ): + if not return_exceptions: + first_to_raise = first_to_raise or output + else: + handled_exceptions[i] = cast(BaseException, output) + run_again.pop(i) + elif isinstance(output, self.exceptions_to_handle): + if self.exception_key: + input[self.exception_key] = output # type: ignore + handled_exceptions[i] = cast(BaseException, output) + else: + run_managers[i].on_chain_end(output) + to_return[i] = output + run_again.pop(i) + handled_exceptions.pop(i, None) + if first_to_raise: + raise first_to_raise + if not run_again: + break + + sorted_handled_exceptions = sorted(handled_exceptions.items()) + for i, error in sorted_handled_exceptions: + run_managers[i].on_chain_error(error) + if not return_exceptions and sorted_handled_exceptions: + raise sorted_handled_exceptions[0][1] + to_return.update(handled_exceptions) + return [output for _, output in sorted(to_return.items())] async def abatch( self, @@ -281,8 +327,13 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]): ) -> List[Output]: from langchain_core.callbacks.manager import AsyncCallbackManager - if return_exceptions: - raise NotImplementedError() + if self.exception_key is not None and not all( + isinstance(input, dict) for input in inputs + ): + raise ValueError( + "If 'exception_key' is specified then inputs must be dictionaries." + f"However found a type of {type(inputs[0])} for input" + ) if not inputs: return [] @@ -313,33 +364,54 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]): ) ) - first_error = None + to_return = {} + run_again = {i: input for i, input in enumerate(inputs)} + handled_exceptions: Dict[int, BaseException] = {} + first_to_raise = None for runnable in self.runnables: - try: - outputs = await runnable.abatch( - inputs, - [ - # each step a child run of the corresponding root run - patch_config(config, callbacks=rm.get_child()) - for rm, config in zip(run_managers, configs) - ], - return_exceptions=return_exceptions, - **kwargs, - ) - except self.exceptions_to_handle as e: - if first_error is None: - first_error = e - except BaseException as e: - await asyncio.gather(*(rm.on_chain_error(e) for rm in run_managers)) - else: - await asyncio.gather( - *( - rm.on_chain_end(output) - for rm, output in zip(run_managers, outputs) - ) - ) - return outputs - if first_error is None: - raise ValueError("No error stored at end of fallbacks.") - await asyncio.gather(*(rm.on_chain_error(first_error) for rm in run_managers)) - raise first_error + outputs = await runnable.abatch( + [input for _, input in sorted(run_again.items())], + [ + # each step a child run of the corresponding root run + patch_config(configs[i], callbacks=run_managers[i].get_child()) + for i in sorted(run_again) + ], + return_exceptions=True, + **kwargs, + ) + + for (i, input), output in zip(sorted(run_again.copy().items()), outputs): + if isinstance(output, BaseException) and not isinstance( + output, self.exceptions_to_handle + ): + if not return_exceptions: + first_to_raise = first_to_raise or output + else: + handled_exceptions[i] = cast(BaseException, output) + run_again.pop(i) + elif isinstance(output, self.exceptions_to_handle): + if self.exception_key: + input[self.exception_key] = output # type: ignore + handled_exceptions[i] = cast(BaseException, output) + else: + to_return[i] = output + await run_managers[i].on_chain_end(output) + run_again.pop(i) + handled_exceptions.pop(i, None) + + if first_to_raise: + raise first_to_raise + if not run_again: + break + + sorted_handled_exceptions = sorted(handled_exceptions.items()) + await asyncio.gather( + *( + run_managers[i].on_chain_error(error) + for i, error in sorted_handled_exceptions + ) + ) + if not return_exceptions and sorted_handled_exceptions: + raise sorted_handled_exceptions[0][1] + to_return.update(handled_exceptions) + return [output for _, output in sorted(to_return.items())] # type: ignore diff --git a/libs/core/tests/unit_tests/runnables/__snapshots__/test_fallbacks.ambr b/libs/core/tests/unit_tests/runnables/__snapshots__/test_fallbacks.ambr new file mode 100644 index 0000000000..751274bf76 --- /dev/null +++ b/libs/core/tests/unit_tests/runnables/__snapshots__/test_fallbacks.ambr @@ -0,0 +1,373 @@ +# serializer version: 1 +# name: test_fallbacks[chain] + ''' + { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "schema", + "runnable", + "RunnableSequence" + ], + "kwargs": { + "first": { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "schema", + "runnable", + "RunnableParallel" + ], + "kwargs": { + "steps": { + "buz": { + "lc": 1, + "type": "not_implemented", + "id": [ + "langchain_core", + "runnables", + "base", + "RunnableLambda" + ], + "repr": "RunnableLambda(lambda x: x)" + } + } + } + }, + "middle": [], + "last": { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "schema", + "runnable", + "RunnableWithFallbacks" + ], + "kwargs": { + "runnable": { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "schema", + "runnable", + "RunnableSequence" + ], + "kwargs": { + "first": { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "prompts", + "prompt", + "PromptTemplate" + ], + "kwargs": { + "input_variables": [ + "buz" + ], + "template": "what did baz say to {buz}", + "template_format": "f-string", + "partial_variables": {} + } + }, + "middle": [], + "last": { + "lc": 1, + "type": "not_implemented", + "id": [ + "tests", + "unit_tests", + "fake", + "llm", + "FakeListLLM" + ], + "repr": "FakeListLLM(responses=['foo'], i=1)" + }, + "name": null + } + }, + "fallbacks": [ + { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "schema", + "runnable", + "RunnableSequence" + ], + "kwargs": { + "first": { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "prompts", + "prompt", + "PromptTemplate" + ], + "kwargs": { + "input_variables": [ + "buz" + ], + "template": "what did baz say to {buz}", + "template_format": "f-string", + "partial_variables": {} + } + }, + "middle": [], + "last": { + "lc": 1, + "type": "not_implemented", + "id": [ + "tests", + "unit_tests", + "fake", + "llm", + "FakeListLLM" + ], + "repr": "FakeListLLM(responses=['bar'])" + }, + "name": null + } + } + ], + "exceptions_to_handle": [ + { + "lc": 1, + "type": "not_implemented", + "id": [ + "builtins", + "Exception" + ], + "repr": "" + } + ], + "exception_key": null + } + }, + "name": null + } + } + ''' +# --- +# name: test_fallbacks[chain_pass_exceptions] + ''' + { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "schema", + "runnable", + "RunnableSequence" + ], + "kwargs": { + "first": { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "schema", + "runnable", + "RunnableParallel" + ], + "kwargs": { + "steps": { + "text": { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "schema", + "runnable", + "RunnablePassthrough" + ], + "kwargs": { + "func": null, + "afunc": null, + "input_type": null + } + } + } + } + }, + "middle": [], + "last": { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "schema", + "runnable", + "RunnableWithFallbacks" + ], + "kwargs": { + "runnable": { + "lc": 1, + "type": "not_implemented", + "id": [ + "langchain_core", + "runnables", + "base", + "RunnableLambda" + ], + "repr": "RunnableLambda(_raise_error)" + }, + "fallbacks": [ + { + "lc": 1, + "type": "not_implemented", + "id": [ + "langchain_core", + "runnables", + "base", + "RunnableLambda" + ], + "repr": "RunnableLambda(_dont_raise_error)" + } + ], + "exceptions_to_handle": [ + { + "lc": 1, + "type": "not_implemented", + "id": [ + "builtins", + "Exception" + ], + "repr": "" + } + ], + "exception_key": "exception" + } + }, + "name": null + } + } + ''' +# --- +# name: test_fallbacks[llm] + ''' + { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "schema", + "runnable", + "RunnableWithFallbacks" + ], + "kwargs": { + "runnable": { + "lc": 1, + "type": "not_implemented", + "id": [ + "tests", + "unit_tests", + "fake", + "llm", + "FakeListLLM" + ], + "repr": "FakeListLLM(responses=['foo'], i=1)" + }, + "fallbacks": [ + { + "lc": 1, + "type": "not_implemented", + "id": [ + "tests", + "unit_tests", + "fake", + "llm", + "FakeListLLM" + ], + "repr": "FakeListLLM(responses=['bar'])" + } + ], + "exceptions_to_handle": [ + { + "lc": 1, + "type": "not_implemented", + "id": [ + "builtins", + "Exception" + ], + "repr": "" + } + ], + "exception_key": null + } + } + ''' +# --- +# name: test_fallbacks[llm_multi] + ''' + { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "schema", + "runnable", + "RunnableWithFallbacks" + ], + "kwargs": { + "runnable": { + "lc": 1, + "type": "not_implemented", + "id": [ + "tests", + "unit_tests", + "fake", + "llm", + "FakeListLLM" + ], + "repr": "FakeListLLM(responses=['foo'], i=1)" + }, + "fallbacks": [ + { + "lc": 1, + "type": "not_implemented", + "id": [ + "tests", + "unit_tests", + "fake", + "llm", + "FakeListLLM" + ], + "repr": "FakeListLLM(responses=['baz'], i=1)" + }, + { + "lc": 1, + "type": "not_implemented", + "id": [ + "tests", + "unit_tests", + "fake", + "llm", + "FakeListLLM" + ], + "repr": "FakeListLLM(responses=['bar'])" + } + ], + "exceptions_to_handle": [ + { + "lc": 1, + "type": "not_implemented", + "id": [ + "builtins", + "Exception" + ], + "repr": "" + } + ], + "exception_key": null + } + } + ''' +# --- diff --git a/libs/core/tests/unit_tests/runnables/__snapshots__/test_runnable.ambr b/libs/core/tests/unit_tests/runnables/__snapshots__/test_runnable.ambr index 8ff1fe98db..051520c045 100644 --- a/libs/core/tests/unit_tests/runnables/__snapshots__/test_runnable.ambr +++ b/libs/core/tests/unit_tests/runnables/__snapshots__/test_runnable.ambr @@ -696,280 +696,6 @@ } ''' # --- -# name: test_llm_with_fallbacks[llm_chain_with_fallbacks] - ''' - { - "lc": 1, - "type": "constructor", - "id": [ - "langchain", - "schema", - "runnable", - "RunnableSequence" - ], - "kwargs": { - "first": { - "lc": 1, - "type": "constructor", - "id": [ - "langchain", - "schema", - "runnable", - "RunnableParallel" - ], - "kwargs": { - "steps": { - "buz": { - "lc": 1, - "type": "not_implemented", - "id": [ - "langchain_core", - "runnables", - "base", - "RunnableLambda" - ], - "repr": "RunnableLambda(lambda x: x)" - } - } - } - }, - "middle": [], - "last": { - "lc": 1, - "type": "constructor", - "id": [ - "langchain", - "schema", - "runnable", - "RunnableWithFallbacks" - ], - "kwargs": { - "runnable": { - "lc": 1, - "type": "constructor", - "id": [ - "langchain", - "schema", - "runnable", - "RunnableSequence" - ], - "kwargs": { - "first": { - "lc": 1, - "type": "constructor", - "id": [ - "langchain", - "prompts", - "prompt", - "PromptTemplate" - ], - "kwargs": { - "input_variables": [ - "buz" - ], - "template": "what did baz say to {buz}", - "template_format": "f-string", - "partial_variables": {} - } - }, - "middle": [], - "last": { - "lc": 1, - "type": "not_implemented", - "id": [ - "tests", - "unit_tests", - "fake", - "llm", - "FakeListLLM" - ], - "repr": "FakeListLLM(responses=['foo'], i=1)" - }, - "name": null - } - }, - "fallbacks": [ - { - "lc": 1, - "type": "constructor", - "id": [ - "langchain", - "schema", - "runnable", - "RunnableSequence" - ], - "kwargs": { - "first": { - "lc": 1, - "type": "constructor", - "id": [ - "langchain", - "prompts", - "prompt", - "PromptTemplate" - ], - "kwargs": { - "input_variables": [ - "buz" - ], - "template": "what did baz say to {buz}", - "template_format": "f-string", - "partial_variables": {} - } - }, - "middle": [], - "last": { - "lc": 1, - "type": "not_implemented", - "id": [ - "tests", - "unit_tests", - "fake", - "llm", - "FakeListLLM" - ], - "repr": "FakeListLLM(responses=['bar'])" - }, - "name": null - } - } - ], - "exceptions_to_handle": [ - { - "lc": 1, - "type": "not_implemented", - "id": [ - "builtins", - "Exception" - ], - "repr": "" - } - ] - } - }, - "name": null - } - } - ''' -# --- -# name: test_llm_with_fallbacks[llm_with_fallbacks] - ''' - { - "lc": 1, - "type": "constructor", - "id": [ - "langchain", - "schema", - "runnable", - "RunnableWithFallbacks" - ], - "kwargs": { - "runnable": { - "lc": 1, - "type": "not_implemented", - "id": [ - "tests", - "unit_tests", - "fake", - "llm", - "FakeListLLM" - ], - "repr": "FakeListLLM(responses=['foo'], i=1)" - }, - "fallbacks": [ - { - "lc": 1, - "type": "not_implemented", - "id": [ - "tests", - "unit_tests", - "fake", - "llm", - "FakeListLLM" - ], - "repr": "FakeListLLM(responses=['bar'])" - } - ], - "exceptions_to_handle": [ - { - "lc": 1, - "type": "not_implemented", - "id": [ - "builtins", - "Exception" - ], - "repr": "" - } - ] - } - } - ''' -# --- -# name: test_llm_with_fallbacks[llm_with_multi_fallbacks] - ''' - { - "lc": 1, - "type": "constructor", - "id": [ - "langchain", - "schema", - "runnable", - "RunnableWithFallbacks" - ], - "kwargs": { - "runnable": { - "lc": 1, - "type": "not_implemented", - "id": [ - "tests", - "unit_tests", - "fake", - "llm", - "FakeListLLM" - ], - "repr": "FakeListLLM(responses=['foo'], i=1)" - }, - "fallbacks": [ - { - "lc": 1, - "type": "not_implemented", - "id": [ - "tests", - "unit_tests", - "fake", - "llm", - "FakeListLLM" - ], - "repr": "FakeListLLM(responses=['baz'], i=1)" - }, - { - "lc": 1, - "type": "not_implemented", - "id": [ - "tests", - "unit_tests", - "fake", - "llm", - "FakeListLLM" - ], - "repr": "FakeListLLM(responses=['bar'])" - } - ], - "exceptions_to_handle": [ - { - "lc": 1, - "type": "not_implemented", - "id": [ - "builtins", - "Exception" - ], - "repr": "" - } - ] - } - } - ''' -# --- # name: test_prompt_with_chat_model ''' ChatPromptTemplate(input_variables=['question'], messages=[SystemMessagePromptTemplate(prompt=PromptTemplate(input_variables=[], template='You are a nice assistant.')), HumanMessagePromptTemplate(prompt=PromptTemplate(input_variables=['question'], template='{question}'))]) diff --git a/libs/core/tests/unit_tests/runnables/test_fallbacks.py b/libs/core/tests/unit_tests/runnables/test_fallbacks.py new file mode 100644 index 0000000000..ecd9cb6fc9 --- /dev/null +++ b/libs/core/tests/unit_tests/runnables/test_fallbacks.py @@ -0,0 +1,231 @@ +import sys +from typing import Any + +import pytest +from syrupy import SnapshotAssertion + +from langchain_core.load import dumps +from langchain_core.prompts import PromptTemplate +from langchain_core.runnables import ( + Runnable, + RunnableLambda, + RunnableParallel, + RunnablePassthrough, + RunnableWithFallbacks, +) +from tests.unit_tests.fake.llm import FakeListLLM + + +@pytest.fixture() +def llm() -> RunnableWithFallbacks: + error_llm = FakeListLLM(responses=["foo"], i=1) + pass_llm = FakeListLLM(responses=["bar"]) + + return error_llm.with_fallbacks([pass_llm]) + + +@pytest.fixture() +def llm_multi() -> RunnableWithFallbacks: + error_llm = FakeListLLM(responses=["foo"], i=1) + error_llm_2 = FakeListLLM(responses=["baz"], i=1) + pass_llm = FakeListLLM(responses=["bar"]) + + return error_llm.with_fallbacks([error_llm_2, pass_llm]) + + +@pytest.fixture() +def chain() -> Runnable: + error_llm = FakeListLLM(responses=["foo"], i=1) + pass_llm = FakeListLLM(responses=["bar"]) + + prompt = PromptTemplate.from_template("what did baz say to {buz}") + return RunnableParallel({"buz": lambda x: x}) | (prompt | error_llm).with_fallbacks( + [prompt | pass_llm] + ) + + +def _raise_error(inputs: dict) -> str: + raise ValueError() + + +def _dont_raise_error(inputs: dict) -> str: + if "exception" in inputs: + return "bar" + raise ValueError() + + +@pytest.fixture() +def chain_pass_exceptions() -> Runnable: + fallback = RunnableLambda(_dont_raise_error) + return {"text": RunnablePassthrough()} | RunnableLambda( + _raise_error + ).with_fallbacks([fallback], exception_key="exception") + + +@pytest.mark.parametrize( + "runnable", + ["llm", "llm_multi", "chain", "chain_pass_exceptions"], +) +async def test_fallbacks( + runnable: RunnableWithFallbacks, request: Any, snapshot: SnapshotAssertion +) -> None: + runnable = request.getfixturevalue(runnable) + assert runnable.invoke("hello") == "bar" + assert runnable.batch(["hi", "hey", "bye"]) == ["bar"] * 3 + assert list(runnable.stream("hello")) == ["bar"] + assert await runnable.ainvoke("hello") == "bar" + assert await runnable.abatch(["hi", "hey", "bye"]) == ["bar"] * 3 + assert list(await runnable.ainvoke("hello")) == list("bar") + if sys.version_info >= (3, 9): + assert dumps(runnable, pretty=True) == snapshot + + +def _runnable(inputs: dict) -> str: + if inputs["text"] == "foo": + return "first" + if "exception" not in inputs: + raise ValueError() + if inputs["text"] == "bar": + return "second" + if isinstance(inputs["exception"], ValueError): + raise RuntimeError() + return "third" + + +def _assert_potential_error(actual: list, expected: list) -> None: + for x, y in zip(actual, expected): + if isinstance(x, Exception): + assert isinstance(y, type(x)) + else: + assert x == y + + +def test_invoke_with_exception_key() -> None: + runnable = RunnableLambda(_runnable) + runnable_with_single = runnable.with_fallbacks( + [runnable], exception_key="exception" + ) + with pytest.raises(ValueError): + runnable_with_single.invoke({"text": "baz"}) + + actual = runnable_with_single.invoke({"text": "bar"}) + expected = "second" + _assert_potential_error([actual], [expected]) + + runnable_with_double = runnable.with_fallbacks( + [runnable, runnable], exception_key="exception" + ) + actual = runnable_with_double.invoke({"text": "baz"}) + + expected = "third" + _assert_potential_error([actual], [expected]) + + +async def test_ainvoke_with_exception_key() -> None: + runnable = RunnableLambda(_runnable) + runnable_with_single = runnable.with_fallbacks( + [runnable], exception_key="exception" + ) + with pytest.raises(ValueError): + await runnable_with_single.ainvoke({"text": "baz"}) + + actual = await runnable_with_single.ainvoke({"text": "bar"}) + expected = "second" + _assert_potential_error([actual], [expected]) + + runnable_with_double = runnable.with_fallbacks( + [runnable, runnable], exception_key="exception" + ) + actual = await runnable_with_double.ainvoke({"text": "baz"}) + expected = "third" + _assert_potential_error([actual], [expected]) + + +def test_batch() -> None: + runnable = RunnableLambda(_runnable) + with pytest.raises(ValueError): + runnable.batch([{"text": "foo"}, {"text": "bar"}, {"text": "baz"}]) + actual = runnable.batch( + [{"text": "foo"}, {"text": "bar"}, {"text": "baz"}], return_exceptions=True + ) + expected = ["first", ValueError(), ValueError()] + _assert_potential_error(actual, expected) + + runnable_with_single = runnable.with_fallbacks( + [runnable], exception_key="exception" + ) + with pytest.raises(RuntimeError): + runnable_with_single.batch([{"text": "foo"}, {"text": "bar"}, {"text": "baz"}]) + actual = runnable_with_single.batch( + [{"text": "foo"}, {"text": "bar"}, {"text": "baz"}], return_exceptions=True + ) + expected = ["first", "second", RuntimeError()] + _assert_potential_error(actual, expected) + + runnable_with_double = runnable.with_fallbacks( + [runnable, runnable], exception_key="exception" + ) + actual = runnable_with_double.batch( + [{"text": "foo"}, {"text": "bar"}, {"text": "baz"}], return_exceptions=True + ) + + expected = ["first", "second", "third"] + _assert_potential_error(actual, expected) + + runnable_with_double = runnable.with_fallbacks( + [runnable, runnable], + exception_key="exception", + exceptions_to_handle=(ValueError,), + ) + actual = runnable_with_double.batch( + [{"text": "foo"}, {"text": "bar"}, {"text": "baz"}], return_exceptions=True + ) + + expected = ["first", "second", RuntimeError()] + _assert_potential_error(actual, expected) + + +async def test_abatch() -> None: + runnable = RunnableLambda(_runnable) + with pytest.raises(ValueError): + await runnable.abatch([{"text": "foo"}, {"text": "bar"}, {"text": "baz"}]) + actual = await runnable.abatch( + [{"text": "foo"}, {"text": "bar"}, {"text": "baz"}], return_exceptions=True + ) + expected = ["first", ValueError(), ValueError()] + _assert_potential_error(actual, expected) + + runnable_with_single = runnable.with_fallbacks( + [runnable], exception_key="exception" + ) + with pytest.raises(RuntimeError): + await runnable_with_single.abatch( + [{"text": "foo"}, {"text": "bar"}, {"text": "baz"}] + ) + actual = await runnable_with_single.abatch( + [{"text": "foo"}, {"text": "bar"}, {"text": "baz"}], return_exceptions=True + ) + expected = ["first", "second", RuntimeError()] + _assert_potential_error(actual, expected) + + runnable_with_double = runnable.with_fallbacks( + [runnable, runnable], exception_key="exception" + ) + actual = await runnable_with_double.abatch( + [{"text": "foo"}, {"text": "bar"}, {"text": "baz"}], return_exceptions=True + ) + + expected = ["first", "second", "third"] + _assert_potential_error(actual, expected) + + runnable_with_double = runnable.with_fallbacks( + [runnable, runnable], + exception_key="exception", + exceptions_to_handle=(ValueError,), + ) + actual = await runnable_with_double.abatch( + [{"text": "foo"}, {"text": "bar"}, {"text": "baz"}], return_exceptions=True + ) + + expected = ["first", "second", RuntimeError()] + _assert_potential_error(actual, expected) diff --git a/libs/core/tests/unit_tests/runnables/test_runnable.py b/libs/core/tests/unit_tests/runnables/test_runnable.py index 2a94cf2469..49e729b595 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable.py @@ -66,7 +66,6 @@ from langchain_core.runnables import ( RunnablePassthrough, RunnablePick, RunnableSequence, - RunnableWithFallbacks, add, chain, ) @@ -3683,52 +3682,6 @@ async def test_runnable_sequence_atransform() -> None: assert "".join(chunks) == "foo-lish" -@pytest.fixture() -def llm_with_fallbacks() -> RunnableWithFallbacks: - error_llm = FakeListLLM(responses=["foo"], i=1) - pass_llm = FakeListLLM(responses=["bar"]) - - return error_llm.with_fallbacks([pass_llm]) - - -@pytest.fixture() -def llm_with_multi_fallbacks() -> RunnableWithFallbacks: - error_llm = FakeListLLM(responses=["foo"], i=1) - error_llm_2 = FakeListLLM(responses=["baz"], i=1) - pass_llm = FakeListLLM(responses=["bar"]) - - return error_llm.with_fallbacks([error_llm_2, pass_llm]) - - -@pytest.fixture() -def llm_chain_with_fallbacks() -> Runnable: - error_llm = FakeListLLM(responses=["foo"], i=1) - pass_llm = FakeListLLM(responses=["bar"]) - - prompt = PromptTemplate.from_template("what did baz say to {buz}") - return RunnableParallel({"buz": lambda x: x}) | (prompt | error_llm).with_fallbacks( - [prompt | pass_llm] - ) - - -@pytest.mark.parametrize( - "runnable", - ["llm_with_fallbacks", "llm_with_multi_fallbacks", "llm_chain_with_fallbacks"], -) -async def test_llm_with_fallbacks( - runnable: RunnableWithFallbacks, request: Any, snapshot: SnapshotAssertion -) -> None: - runnable = request.getfixturevalue(runnable) - assert runnable.invoke("hello") == "bar" - assert runnable.batch(["hi", "hey", "bye"]) == ["bar"] * 3 - assert list(runnable.stream("hello")) == ["bar"] - assert await runnable.ainvoke("hello") == "bar" - assert await runnable.abatch(["hi", "hey", "bye"]) == ["bar"] * 3 - assert list(await runnable.ainvoke("hello")) == list("bar") - if sys.version_info >= (3, 9): - assert dumps(runnable, pretty=True) == snapshot - - class FakeSplitIntoListParser(BaseOutputParser[List[str]]): """Parse the output of an LLM call to a comma-separated list."""