core[patch]: pass exceptions to fallbacks (#16048)

Bagatur 2024-01-16 09:36:43 -08:00 committed by GitHub
parent 770f57196e
commit c5656a4905
6 changed files with 743 additions and 382 deletions

View File

@@ -923,12 +923,17 @@ class Runnable(Generic[Input, Output], ABC):
fallbacks: Sequence[Runnable[Input, Output]],
*,
exceptions_to_handle: Tuple[Type[BaseException], ...] = (Exception,),
exception_key: Optional[str] = None,
) -> RunnableWithFallbacksT[Input, Output]:
"""Add fallbacks to a runnable, returning a new Runnable.
Args:
fallbacks: A sequence of runnables to try if the original runnable fails.
exceptions_to_handle: A tuple of exception types to handle.
exception_key: If a string is specified, then handled exceptions will be passed
to fallbacks as part of the input under the specified key. If None,
exceptions will not be passed to fallbacks. If used, the base runnable
and its fallbacks must accept a dictionary as input.
Returns:
A new Runnable that will try the original runnable, and then each
@@ -940,6 +945,7 @@ class Runnable(Generic[Input, Output], ABC):
runnable=self,
fallbacks=fallbacks,
exceptions_to_handle=exceptions_to_handle,
exception_key=exception_key,
)
""" --- Helper methods for Subclasses --- """

View File

@@ -2,6 +2,7 @@ import asyncio
from typing import (
TYPE_CHECKING,
Any,
Dict,
Iterator,
List,
Optional,
@@ -9,6 +10,7 @@ from typing import (
Tuple,
Type,
Union,
cast,
)
from langchain_core.load.dump import dumpd
@@ -89,6 +91,11 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
Any exception that is not a subclass of these exceptions will be raised immediately.
"""
exception_key: Optional[str] = None
"""If string is specified then handled exceptions will be passed to fallbacks as
part of the input under the specified key. If None, exceptions
will not be passed to fallbacks. If used, the base runnable and its fallbacks
must accept a dictionary as input."""
class Config:
arbitrary_types_allowed = True
@@ -136,6 +143,11 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
def invoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Output:
if self.exception_key is not None and not isinstance(input, dict):
raise ValueError(
"If 'exception_key' is specified then input must be a dictionary."
f"However found a type of {type(input)} for input"
)
# setup callbacks
config = ensure_config(config)
callback_manager = get_callback_manager_for_config(config)
@@ -144,8 +156,11 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
dumpd(self), input, name=config.get("run_name")
)
first_error = None
last_error = None
for runnable in self.runnables:
try:
if self.exception_key and last_error is not None:
input[self.exception_key] = last_error
output = runnable.invoke(
input,
patch_config(config, callbacks=run_manager.get_child()),
@@ -154,6 +169,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
except self.exceptions_to_handle as e:
if first_error is None:
first_error = e
last_error = e
except BaseException as e:
run_manager.on_chain_error(e)
raise e
@@ -171,6 +187,11 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Output:
if self.exception_key is not None and not isinstance(input, dict):
raise ValueError(
"If 'exception_key' is specified then input must be a dictionary."
f"However found a type of {type(input)} for input"
)
# setup callbacks
config = ensure_config(config)
callback_manager = get_async_callback_manager_for_config(config)
@@ -180,8 +201,11 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
)
first_error = None
last_error = None
for runnable in self.runnables:
try:
if self.exception_key and last_error is not None:
input[self.exception_key] = last_error
output = await runnable.ainvoke(
input,
patch_config(config, callbacks=run_manager.get_child()),
@@ -190,6 +214,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
except self.exceptions_to_handle as e:
if first_error is None:
first_error = e
last_error = e
except BaseException as e:
await run_manager.on_chain_error(e)
raise e
@@ -211,8 +236,13 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
) -> List[Output]:
from langchain_core.callbacks.manager import CallbackManager
if return_exceptions:
raise NotImplementedError()
if self.exception_key is not None and not all(
isinstance(input, dict) for input in inputs
):
raise ValueError(
"If 'exception_key' is specified then inputs must be dictionaries."
f"However found a type of {type(inputs[0])} for input"
)
if not inputs:
return []
@@ -241,35 +271,51 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
for cm, input, config in zip(callback_managers, inputs, configs)
]
first_error = None
to_return: Dict[int, Any] = {}
run_again = {i: input for i, input in enumerate(inputs)}
handled_exceptions: Dict[int, BaseException] = {}
first_to_raise = None
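# run_again maps original input index -> input for everything that still needs a
# try against the next fallback; handled_exceptions records the latest error per
# index so it can be raised (or returned when return_exceptions=True) at the end.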
for runnable in self.runnables:
try:
outputs = runnable.batch(
inputs,
[
# each step a child run of the corresponding root run
patch_config(config, callbacks=rm.get_child())
for rm, config in zip(run_managers, configs)
],
return_exceptions=return_exceptions,
**kwargs,
)
except self.exceptions_to_handle as e:
if first_error is None:
first_error = e
except BaseException as e:
for rm in run_managers:
rm.on_chain_error(e)
raise e
else:
for rm, output in zip(run_managers, outputs):
rm.on_chain_end(output)
return outputs
if first_error is None:
raise ValueError("No error stored at end of fallbacks.")
for rm in run_managers:
rm.on_chain_error(first_error)
raise first_error
outputs = runnable.batch(
[input for _, input in sorted(run_again.items())],
[
# each step a child run of the corresponding root run
patch_config(configs[i], callbacks=run_managers[i].get_child())
for i in sorted(run_again)
],
return_exceptions=True,
**kwargs,
)
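# Sort results back to each input's original index: an unhandled exception is
# raised after this pass (or kept when return_exceptions=True) and dropped from
# the retry set; a handled exception is stashed and, if exception_key is set,
# injected into the input for the next fallback; a success is recorded and
# removed from the retry set.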
for (i, input), output in zip(sorted(run_again.copy().items()), outputs):
if isinstance(output, BaseException) and not isinstance(
output, self.exceptions_to_handle
):
if not return_exceptions:
first_to_raise = first_to_raise or output
else:
handled_exceptions[i] = cast(BaseException, output)
run_again.pop(i)
elif isinstance(output, self.exceptions_to_handle):
if self.exception_key:
input[self.exception_key] = output # type: ignore
handled_exceptions[i] = cast(BaseException, output)
else:
run_managers[i].on_chain_end(output)
to_return[i] = output
run_again.pop(i)
handled_exceptions.pop(i, None)
if first_to_raise:
raise first_to_raise
if not run_again:
break
sorted_handled_exceptions = sorted(handled_exceptions.items())
for i, error in sorted_handled_exceptions:
run_managers[i].on_chain_error(error)
if not return_exceptions and sorted_handled_exceptions:
raise sorted_handled_exceptions[0][1]
to_return.update(handled_exceptions)
return [output for _, output in sorted(to_return.items())]
async def abatch(
self,
@@ -281,8 +327,13 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
) -> List[Output]:
from langchain_core.callbacks.manager import AsyncCallbackManager
if return_exceptions:
raise NotImplementedError()
if self.exception_key is not None and not all(
isinstance(input, dict) for input in inputs
):
raise ValueError(
"If 'exception_key' is specified then inputs must be dictionaries."
f"However found a type of {type(inputs[0])} for input"
)
if not inputs:
return []
@@ -313,33 +364,54 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
)
)
first_error = None
to_return = {}
run_again = {i: input for i, input in enumerate(inputs)}
handled_exceptions: Dict[int, BaseException] = {}
first_to_raise = None
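# Same per-index bookkeeping as in batch(): only inputs whose latest error was a
# handled exception are retried, and any remaining errors are raised or returned
# at the end.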
for runnable in self.runnables:
try:
outputs = await runnable.abatch(
inputs,
[
# each step a child run of the corresponding root run
patch_config(config, callbacks=rm.get_child())
for rm, config in zip(run_managers, configs)
],
return_exceptions=return_exceptions,
**kwargs,
)
except self.exceptions_to_handle as e:
if first_error is None:
first_error = e
except BaseException as e:
await asyncio.gather(*(rm.on_chain_error(e) for rm in run_managers))
else:
await asyncio.gather(
*(
rm.on_chain_end(output)
for rm, output in zip(run_managers, outputs)
)
)
return outputs
if first_error is None:
raise ValueError("No error stored at end of fallbacks.")
await asyncio.gather(*(rm.on_chain_error(first_error) for rm in run_managers))
raise first_error
outputs = await runnable.abatch(
[input for _, input in sorted(run_again.items())],
[
# each step a child run of the corresponding root run
patch_config(configs[i], callbacks=run_managers[i].get_child())
for i in sorted(run_again)
],
return_exceptions=True,
**kwargs,
)
for (i, input), output in zip(sorted(run_again.copy().items()), outputs):
if isinstance(output, BaseException) and not isinstance(
output, self.exceptions_to_handle
):
if not return_exceptions:
first_to_raise = first_to_raise or output
else:
handled_exceptions[i] = cast(BaseException, output)
run_again.pop(i)
elif isinstance(output, self.exceptions_to_handle):
if self.exception_key:
input[self.exception_key] = output # type: ignore
handled_exceptions[i] = cast(BaseException, output)
else:
to_return[i] = output
await run_managers[i].on_chain_end(output)
run_again.pop(i)
handled_exceptions.pop(i, None)
if first_to_raise:
raise first_to_raise
if not run_again:
break
sorted_handled_exceptions = sorted(handled_exceptions.items())
await asyncio.gather(
*(
run_managers[i].on_chain_error(error)
for i, error in sorted_handled_exceptions
)
)
if not return_exceptions and sorted_handled_exceptions:
raise sorted_handled_exceptions[0][1]
to_return.update(handled_exceptions)
return [output for _, output in sorted(to_return.items())] # type: ignore
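
Stepping back from the diff, here is a small sketch of how the reworked batch()/abatch() path behaves (the _flaky helper is illustrative, and the default exceptions_to_handle=(Exception,) is assumed). Only the inputs whose latest attempt failed with a handled exception are retried against the next fallback, successes keep their original positions, and return_exceptions=True is now supported instead of raising NotImplementedError:

from langchain_core.runnables import RunnableLambda

def _flaky(inputs: dict) -> str:
    if inputs["text"] == "ok":
        return "fine"
    raise ValueError("boom")

wrapped = RunnableLambda(_flaky).with_fallbacks(
    [RunnableLambda(lambda _: "recovered")], exception_key="exception"
)

# Only the failing input is re-run against the fallback; results come back in
# the original input order.
wrapped.batch([{"text": "ok"}, {"text": "bad"}])  # -> ["fine", "recovered"]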

View File

@@ -0,0 +1,373 @@
# serializer version: 1
# name: test_fallbacks[chain]
'''
{
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"schema",
"runnable",
"RunnableSequence"
],
"kwargs": {
"first": {
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"schema",
"runnable",
"RunnableParallel"
],
"kwargs": {
"steps": {
"buz": {
"lc": 1,
"type": "not_implemented",
"id": [
"langchain_core",
"runnables",
"base",
"RunnableLambda"
],
"repr": "RunnableLambda(lambda x: x)"
}
}
}
},
"middle": [],
"last": {
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"schema",
"runnable",
"RunnableWithFallbacks"
],
"kwargs": {
"runnable": {
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"schema",
"runnable",
"RunnableSequence"
],
"kwargs": {
"first": {
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"prompts",
"prompt",
"PromptTemplate"
],
"kwargs": {
"input_variables": [
"buz"
],
"template": "what did baz say to {buz}",
"template_format": "f-string",
"partial_variables": {}
}
},
"middle": [],
"last": {
"lc": 1,
"type": "not_implemented",
"id": [
"tests",
"unit_tests",
"fake",
"llm",
"FakeListLLM"
],
"repr": "FakeListLLM(responses=['foo'], i=1)"
},
"name": null
}
},
"fallbacks": [
{
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"schema",
"runnable",
"RunnableSequence"
],
"kwargs": {
"first": {
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"prompts",
"prompt",
"PromptTemplate"
],
"kwargs": {
"input_variables": [
"buz"
],
"template": "what did baz say to {buz}",
"template_format": "f-string",
"partial_variables": {}
}
},
"middle": [],
"last": {
"lc": 1,
"type": "not_implemented",
"id": [
"tests",
"unit_tests",
"fake",
"llm",
"FakeListLLM"
],
"repr": "FakeListLLM(responses=['bar'])"
},
"name": null
}
}
],
"exceptions_to_handle": [
{
"lc": 1,
"type": "not_implemented",
"id": [
"builtins",
"Exception"
],
"repr": "<class 'Exception'>"
}
],
"exception_key": null
}
},
"name": null
}
}
'''
# ---
# name: test_fallbacks[chain_pass_exceptions]
'''
{
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"schema",
"runnable",
"RunnableSequence"
],
"kwargs": {
"first": {
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"schema",
"runnable",
"RunnableParallel"
],
"kwargs": {
"steps": {
"text": {
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"schema",
"runnable",
"RunnablePassthrough"
],
"kwargs": {
"func": null,
"afunc": null,
"input_type": null
}
}
}
}
},
"middle": [],
"last": {
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"schema",
"runnable",
"RunnableWithFallbacks"
],
"kwargs": {
"runnable": {
"lc": 1,
"type": "not_implemented",
"id": [
"langchain_core",
"runnables",
"base",
"RunnableLambda"
],
"repr": "RunnableLambda(_raise_error)"
},
"fallbacks": [
{
"lc": 1,
"type": "not_implemented",
"id": [
"langchain_core",
"runnables",
"base",
"RunnableLambda"
],
"repr": "RunnableLambda(_dont_raise_error)"
}
],
"exceptions_to_handle": [
{
"lc": 1,
"type": "not_implemented",
"id": [
"builtins",
"Exception"
],
"repr": "<class 'Exception'>"
}
],
"exception_key": "exception"
}
},
"name": null
}
}
'''
# ---
# name: test_fallbacks[llm]
'''
{
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"schema",
"runnable",
"RunnableWithFallbacks"
],
"kwargs": {
"runnable": {
"lc": 1,
"type": "not_implemented",
"id": [
"tests",
"unit_tests",
"fake",
"llm",
"FakeListLLM"
],
"repr": "FakeListLLM(responses=['foo'], i=1)"
},
"fallbacks": [
{
"lc": 1,
"type": "not_implemented",
"id": [
"tests",
"unit_tests",
"fake",
"llm",
"FakeListLLM"
],
"repr": "FakeListLLM(responses=['bar'])"
}
],
"exceptions_to_handle": [
{
"lc": 1,
"type": "not_implemented",
"id": [
"builtins",
"Exception"
],
"repr": "<class 'Exception'>"
}
],
"exception_key": null
}
}
'''
# ---
# name: test_fallbacks[llm_multi]
'''
{
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"schema",
"runnable",
"RunnableWithFallbacks"
],
"kwargs": {
"runnable": {
"lc": 1,
"type": "not_implemented",
"id": [
"tests",
"unit_tests",
"fake",
"llm",
"FakeListLLM"
],
"repr": "FakeListLLM(responses=['foo'], i=1)"
},
"fallbacks": [
{
"lc": 1,
"type": "not_implemented",
"id": [
"tests",
"unit_tests",
"fake",
"llm",
"FakeListLLM"
],
"repr": "FakeListLLM(responses=['baz'], i=1)"
},
{
"lc": 1,
"type": "not_implemented",
"id": [
"tests",
"unit_tests",
"fake",
"llm",
"FakeListLLM"
],
"repr": "FakeListLLM(responses=['bar'])"
}
],
"exceptions_to_handle": [
{
"lc": 1,
"type": "not_implemented",
"id": [
"builtins",
"Exception"
],
"repr": "<class 'Exception'>"
}
],
"exception_key": null
}
}
'''
# ---

View File

@@ -696,280 +696,6 @@
}
'''
# ---
# name: test_llm_with_fallbacks[llm_chain_with_fallbacks]
'''
{
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"schema",
"runnable",
"RunnableSequence"
],
"kwargs": {
"first": {
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"schema",
"runnable",
"RunnableParallel"
],
"kwargs": {
"steps": {
"buz": {
"lc": 1,
"type": "not_implemented",
"id": [
"langchain_core",
"runnables",
"base",
"RunnableLambda"
],
"repr": "RunnableLambda(lambda x: x)"
}
}
}
},
"middle": [],
"last": {
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"schema",
"runnable",
"RunnableWithFallbacks"
],
"kwargs": {
"runnable": {
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"schema",
"runnable",
"RunnableSequence"
],
"kwargs": {
"first": {
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"prompts",
"prompt",
"PromptTemplate"
],
"kwargs": {
"input_variables": [
"buz"
],
"template": "what did baz say to {buz}",
"template_format": "f-string",
"partial_variables": {}
}
},
"middle": [],
"last": {
"lc": 1,
"type": "not_implemented",
"id": [
"tests",
"unit_tests",
"fake",
"llm",
"FakeListLLM"
],
"repr": "FakeListLLM(responses=['foo'], i=1)"
},
"name": null
}
},
"fallbacks": [
{
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"schema",
"runnable",
"RunnableSequence"
],
"kwargs": {
"first": {
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"prompts",
"prompt",
"PromptTemplate"
],
"kwargs": {
"input_variables": [
"buz"
],
"template": "what did baz say to {buz}",
"template_format": "f-string",
"partial_variables": {}
}
},
"middle": [],
"last": {
"lc": 1,
"type": "not_implemented",
"id": [
"tests",
"unit_tests",
"fake",
"llm",
"FakeListLLM"
],
"repr": "FakeListLLM(responses=['bar'])"
},
"name": null
}
}
],
"exceptions_to_handle": [
{
"lc": 1,
"type": "not_implemented",
"id": [
"builtins",
"Exception"
],
"repr": "<class 'Exception'>"
}
]
}
},
"name": null
}
}
'''
# ---
# name: test_llm_with_fallbacks[llm_with_fallbacks]
'''
{
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"schema",
"runnable",
"RunnableWithFallbacks"
],
"kwargs": {
"runnable": {
"lc": 1,
"type": "not_implemented",
"id": [
"tests",
"unit_tests",
"fake",
"llm",
"FakeListLLM"
],
"repr": "FakeListLLM(responses=['foo'], i=1)"
},
"fallbacks": [
{
"lc": 1,
"type": "not_implemented",
"id": [
"tests",
"unit_tests",
"fake",
"llm",
"FakeListLLM"
],
"repr": "FakeListLLM(responses=['bar'])"
}
],
"exceptions_to_handle": [
{
"lc": 1,
"type": "not_implemented",
"id": [
"builtins",
"Exception"
],
"repr": "<class 'Exception'>"
}
]
}
}
'''
# ---
# name: test_llm_with_fallbacks[llm_with_multi_fallbacks]
'''
{
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"schema",
"runnable",
"RunnableWithFallbacks"
],
"kwargs": {
"runnable": {
"lc": 1,
"type": "not_implemented",
"id": [
"tests",
"unit_tests",
"fake",
"llm",
"FakeListLLM"
],
"repr": "FakeListLLM(responses=['foo'], i=1)"
},
"fallbacks": [
{
"lc": 1,
"type": "not_implemented",
"id": [
"tests",
"unit_tests",
"fake",
"llm",
"FakeListLLM"
],
"repr": "FakeListLLM(responses=['baz'], i=1)"
},
{
"lc": 1,
"type": "not_implemented",
"id": [
"tests",
"unit_tests",
"fake",
"llm",
"FakeListLLM"
],
"repr": "FakeListLLM(responses=['bar'])"
}
],
"exceptions_to_handle": [
{
"lc": 1,
"type": "not_implemented",
"id": [
"builtins",
"Exception"
],
"repr": "<class 'Exception'>"
}
]
}
}
'''
# ---
# name: test_prompt_with_chat_model
'''
ChatPromptTemplate(input_variables=['question'], messages=[SystemMessagePromptTemplate(prompt=PromptTemplate(input_variables=[], template='You are a nice assistant.')), HumanMessagePromptTemplate(prompt=PromptTemplate(input_variables=['question'], template='{question}'))])

View File

@@ -0,0 +1,231 @@
import sys
from typing import Any
import pytest
from syrupy import SnapshotAssertion
from langchain_core.load import dumps
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import (
Runnable,
RunnableLambda,
RunnableParallel,
RunnablePassthrough,
RunnableWithFallbacks,
)
from tests.unit_tests.fake.llm import FakeListLLM
@pytest.fixture()
def llm() -> RunnableWithFallbacks:
error_llm = FakeListLLM(responses=["foo"], i=1)
pass_llm = FakeListLLM(responses=["bar"])
return error_llm.with_fallbacks([pass_llm])
@pytest.fixture()
def llm_multi() -> RunnableWithFallbacks:
error_llm = FakeListLLM(responses=["foo"], i=1)
error_llm_2 = FakeListLLM(responses=["baz"], i=1)
pass_llm = FakeListLLM(responses=["bar"])
return error_llm.with_fallbacks([error_llm_2, pass_llm])
@pytest.fixture()
def chain() -> Runnable:
error_llm = FakeListLLM(responses=["foo"], i=1)
pass_llm = FakeListLLM(responses=["bar"])
prompt = PromptTemplate.from_template("what did baz say to {buz}")
return RunnableParallel({"buz": lambda x: x}) | (prompt | error_llm).with_fallbacks(
[prompt | pass_llm]
)
def _raise_error(inputs: dict) -> str:
raise ValueError()
def _dont_raise_error(inputs: dict) -> str:
if "exception" in inputs:
return "bar"
raise ValueError()
@pytest.fixture()
def chain_pass_exceptions() -> Runnable:
fallback = RunnableLambda(_dont_raise_error)
return {"text": RunnablePassthrough()} | RunnableLambda(
_raise_error
).with_fallbacks([fallback], exception_key="exception")
@pytest.mark.parametrize(
"runnable",
["llm", "llm_multi", "chain", "chain_pass_exceptions"],
)
async def test_fallbacks(
runnable: RunnableWithFallbacks, request: Any, snapshot: SnapshotAssertion
) -> None:
runnable = request.getfixturevalue(runnable)
assert runnable.invoke("hello") == "bar"
assert runnable.batch(["hi", "hey", "bye"]) == ["bar"] * 3
assert list(runnable.stream("hello")) == ["bar"]
assert await runnable.ainvoke("hello") == "bar"
assert await runnable.abatch(["hi", "hey", "bye"]) == ["bar"] * 3
assert list(await runnable.ainvoke("hello")) == list("bar")
if sys.version_info >= (3, 9):
assert dumps(runnable, pretty=True) == snapshot
def _runnable(inputs: dict) -> str:
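# Succeeds immediately for "foo". Other inputs fail on the first call (no
# exception passed yet); "bar" succeeds on the first retry, while anything else
# raises RuntimeError on the first retry (it sees the ValueError) and only
# returns "third" on a second retry that sees the RuntimeError.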
if inputs["text"] == "foo":
return "first"
if "exception" not in inputs:
raise ValueError()
if inputs["text"] == "bar":
return "second"
if isinstance(inputs["exception"], ValueError):
raise RuntimeError()
return "third"
def _assert_potential_error(actual: list, expected: list) -> None:
for x, y in zip(actual, expected):
if isinstance(x, Exception):
assert isinstance(y, type(x))
else:
assert x == y
def test_invoke_with_exception_key() -> None:
runnable = RunnableLambda(_runnable)
runnable_with_single = runnable.with_fallbacks(
[runnable], exception_key="exception"
)
with pytest.raises(ValueError):
runnable_with_single.invoke({"text": "baz"})
actual = runnable_with_single.invoke({"text": "bar"})
expected = "second"
_assert_potential_error([actual], [expected])
runnable_with_double = runnable.with_fallbacks(
[runnable, runnable], exception_key="exception"
)
actual = runnable_with_double.invoke({"text": "baz"})
expected = "third"
_assert_potential_error([actual], [expected])
async def test_ainvoke_with_exception_key() -> None:
runnable = RunnableLambda(_runnable)
runnable_with_single = runnable.with_fallbacks(
[runnable], exception_key="exception"
)
with pytest.raises(ValueError):
await runnable_with_single.ainvoke({"text": "baz"})
actual = await runnable_with_single.ainvoke({"text": "bar"})
expected = "second"
_assert_potential_error([actual], [expected])
runnable_with_double = runnable.with_fallbacks(
[runnable, runnable], exception_key="exception"
)
actual = await runnable_with_double.ainvoke({"text": "baz"})
expected = "third"
_assert_potential_error([actual], [expected])
def test_batch() -> None:
runnable = RunnableLambda(_runnable)
with pytest.raises(ValueError):
runnable.batch([{"text": "foo"}, {"text": "bar"}, {"text": "baz"}])
actual = runnable.batch(
[{"text": "foo"}, {"text": "bar"}, {"text": "baz"}], return_exceptions=True
)
expected = ["first", ValueError(), ValueError()]
_assert_potential_error(actual, expected)
runnable_with_single = runnable.with_fallbacks(
[runnable], exception_key="exception"
)
with pytest.raises(RuntimeError):
runnable_with_single.batch([{"text": "foo"}, {"text": "bar"}, {"text": "baz"}])
actual = runnable_with_single.batch(
[{"text": "foo"}, {"text": "bar"}, {"text": "baz"}], return_exceptions=True
)
expected = ["first", "second", RuntimeError()]
_assert_potential_error(actual, expected)
runnable_with_double = runnable.with_fallbacks(
[runnable, runnable], exception_key="exception"
)
actual = runnable_with_double.batch(
[{"text": "foo"}, {"text": "bar"}, {"text": "baz"}], return_exceptions=True
)
expected = ["first", "second", "third"]
_assert_potential_error(actual, expected)
runnable_with_double = runnable.with_fallbacks(
[runnable, runnable],
exception_key="exception",
exceptions_to_handle=(ValueError,),
)
actual = runnable_with_double.batch(
[{"text": "foo"}, {"text": "bar"}, {"text": "baz"}], return_exceptions=True
)
expected = ["first", "second", RuntimeError()]
_assert_potential_error(actual, expected)
async def test_abatch() -> None:
runnable = RunnableLambda(_runnable)
with pytest.raises(ValueError):
await runnable.abatch([{"text": "foo"}, {"text": "bar"}, {"text": "baz"}])
actual = await runnable.abatch(
[{"text": "foo"}, {"text": "bar"}, {"text": "baz"}], return_exceptions=True
)
expected = ["first", ValueError(), ValueError()]
_assert_potential_error(actual, expected)
runnable_with_single = runnable.with_fallbacks(
[runnable], exception_key="exception"
)
with pytest.raises(RuntimeError):
await runnable_with_single.abatch(
[{"text": "foo"}, {"text": "bar"}, {"text": "baz"}]
)
actual = await runnable_with_single.abatch(
[{"text": "foo"}, {"text": "bar"}, {"text": "baz"}], return_exceptions=True
)
expected = ["first", "second", RuntimeError()]
_assert_potential_error(actual, expected)
runnable_with_double = runnable.with_fallbacks(
[runnable, runnable], exception_key="exception"
)
actual = await runnable_with_double.abatch(
[{"text": "foo"}, {"text": "bar"}, {"text": "baz"}], return_exceptions=True
)
expected = ["first", "second", "third"]
_assert_potential_error(actual, expected)
runnable_with_double = runnable.with_fallbacks(
[runnable, runnable],
exception_key="exception",
exceptions_to_handle=(ValueError,),
)
actual = await runnable_with_double.abatch(
[{"text": "foo"}, {"text": "bar"}, {"text": "baz"}], return_exceptions=True
)
expected = ["first", "second", RuntimeError()]
_assert_potential_error(actual, expected)

View File

@@ -66,7 +66,6 @@ from langchain_core.runnables import (
RunnablePassthrough,
RunnablePick,
RunnableSequence,
RunnableWithFallbacks,
add,
chain,
)
@@ -3683,52 +3682,6 @@ async def test_runnable_sequence_atransform() -> None:
assert "".join(chunks) == "foo-lish"
@pytest.fixture()
def llm_with_fallbacks() -> RunnableWithFallbacks:
error_llm = FakeListLLM(responses=["foo"], i=1)
pass_llm = FakeListLLM(responses=["bar"])
return error_llm.with_fallbacks([pass_llm])
@pytest.fixture()
def llm_with_multi_fallbacks() -> RunnableWithFallbacks:
error_llm = FakeListLLM(responses=["foo"], i=1)
error_llm_2 = FakeListLLM(responses=["baz"], i=1)
pass_llm = FakeListLLM(responses=["bar"])
return error_llm.with_fallbacks([error_llm_2, pass_llm])
@pytest.fixture()
def llm_chain_with_fallbacks() -> Runnable:
error_llm = FakeListLLM(responses=["foo"], i=1)
pass_llm = FakeListLLM(responses=["bar"])
prompt = PromptTemplate.from_template("what did baz say to {buz}")
return RunnableParallel({"buz": lambda x: x}) | (prompt | error_llm).with_fallbacks(
[prompt | pass_llm]
)
@pytest.mark.parametrize(
"runnable",
["llm_with_fallbacks", "llm_with_multi_fallbacks", "llm_chain_with_fallbacks"],
)
async def test_llm_with_fallbacks(
runnable: RunnableWithFallbacks, request: Any, snapshot: SnapshotAssertion
) -> None:
runnable = request.getfixturevalue(runnable)
assert runnable.invoke("hello") == "bar"
assert runnable.batch(["hi", "hey", "bye"]) == ["bar"] * 3
assert list(runnable.stream("hello")) == ["bar"]
assert await runnable.ainvoke("hello") == "bar"
assert await runnable.abatch(["hi", "hey", "bye"]) == ["bar"] * 3
assert list(await runnable.ainvoke("hello")) == list("bar")
if sys.version_info >= (3, 9):
assert dumps(runnable, pretty=True) == snapshot
class FakeSplitIntoListParser(BaseOutputParser[List[str]]):
"""Parse the output of an LLM call to a comma-separated list."""