mirror of
https://github.com/hwchase17/langchain
synced 2024-11-04 06:00:26 +00:00
core[patch]: pass exceptions to fallbacks (#16048)
This commit is contained in:
parent
770f57196e
commit
c5656a4905
@ -923,12 +923,17 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
fallbacks: Sequence[Runnable[Input, Output]],
|
||||
*,
|
||||
exceptions_to_handle: Tuple[Type[BaseException], ...] = (Exception,),
|
||||
exception_key: Optional[str] = None,
|
||||
) -> RunnableWithFallbacksT[Input, Output]:
|
||||
"""Add fallbacks to a runnable, returning a new Runnable.
|
||||
|
||||
Args:
|
||||
fallbacks: A sequence of runnables to try if the original runnable fails.
|
||||
exceptions_to_handle: A tuple of exception types to handle.
|
||||
exception_key: If string is specified then handled exceptions will be passed
|
||||
to fallbacks as part of the input under the specified key. If None,
|
||||
exceptions will not be passed to fallbacks. If used, the base runnable
|
||||
and its fallbacks must accept a dictionary as input.
|
||||
|
||||
Returns:
|
||||
A new Runnable that will try the original runnable, and then each
|
||||
@ -940,6 +945,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
runnable=self,
|
||||
fallbacks=fallbacks,
|
||||
exceptions_to_handle=exceptions_to_handle,
|
||||
exception_key=exception_key,
|
||||
)
|
||||
|
||||
""" --- Helper methods for Subclasses --- """
|
||||
|
@ -2,6 +2,7 @@ import asyncio
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
@ -9,6 +10,7 @@ from typing import (
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from langchain_core.load.dump import dumpd
|
||||
@ -89,6 +91,11 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
||||
|
||||
Any exception that is not a subclass of these exceptions will be raised immediately.
|
||||
"""
|
||||
exception_key: Optional[str] = None
|
||||
"""If string is specified then handled exceptions will be passed to fallbacks as
|
||||
part of the input under the specified key. If None, exceptions
|
||||
will not be passed to fallbacks. If used, the base runnable and its fallbacks
|
||||
must accept a dictionary as input."""
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
@ -136,6 +143,11 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
||||
def invoke(
|
||||
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
) -> Output:
|
||||
if self.exception_key is not None and not isinstance(input, dict):
|
||||
raise ValueError(
|
||||
"If 'exception_key' is specified then input must be a dictionary."
|
||||
f"However found a type of {type(input)} for input"
|
||||
)
|
||||
# setup callbacks
|
||||
config = ensure_config(config)
|
||||
callback_manager = get_callback_manager_for_config(config)
|
||||
@ -144,8 +156,11 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
||||
dumpd(self), input, name=config.get("run_name")
|
||||
)
|
||||
first_error = None
|
||||
last_error = None
|
||||
for runnable in self.runnables:
|
||||
try:
|
||||
if self.exception_key and last_error is not None:
|
||||
input[self.exception_key] = last_error
|
||||
output = runnable.invoke(
|
||||
input,
|
||||
patch_config(config, callbacks=run_manager.get_child()),
|
||||
@ -154,6 +169,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
||||
except self.exceptions_to_handle as e:
|
||||
if first_error is None:
|
||||
first_error = e
|
||||
last_error = e
|
||||
except BaseException as e:
|
||||
run_manager.on_chain_error(e)
|
||||
raise e
|
||||
@ -171,6 +187,11 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Optional[Any],
|
||||
) -> Output:
|
||||
if self.exception_key is not None and not isinstance(input, dict):
|
||||
raise ValueError(
|
||||
"If 'exception_key' is specified then input must be a dictionary."
|
||||
f"However found a type of {type(input)} for input"
|
||||
)
|
||||
# setup callbacks
|
||||
config = ensure_config(config)
|
||||
callback_manager = get_async_callback_manager_for_config(config)
|
||||
@ -180,8 +201,11 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
||||
)
|
||||
|
||||
first_error = None
|
||||
last_error = None
|
||||
for runnable in self.runnables:
|
||||
try:
|
||||
if self.exception_key and last_error is not None:
|
||||
input[self.exception_key] = last_error
|
||||
output = await runnable.ainvoke(
|
||||
input,
|
||||
patch_config(config, callbacks=run_manager.get_child()),
|
||||
@ -190,6 +214,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
||||
except self.exceptions_to_handle as e:
|
||||
if first_error is None:
|
||||
first_error = e
|
||||
last_error = e
|
||||
except BaseException as e:
|
||||
await run_manager.on_chain_error(e)
|
||||
raise e
|
||||
@ -211,8 +236,13 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
||||
) -> List[Output]:
|
||||
from langchain_core.callbacks.manager import CallbackManager
|
||||
|
||||
if return_exceptions:
|
||||
raise NotImplementedError()
|
||||
if self.exception_key is not None and not all(
|
||||
isinstance(input, dict) for input in inputs
|
||||
):
|
||||
raise ValueError(
|
||||
"If 'exception_key' is specified then inputs must be dictionaries."
|
||||
f"However found a type of {type(inputs[0])} for input"
|
||||
)
|
||||
|
||||
if not inputs:
|
||||
return []
|
||||
@ -241,35 +271,51 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
||||
for cm, input, config in zip(callback_managers, inputs, configs)
|
||||
]
|
||||
|
||||
first_error = None
|
||||
to_return: Dict[int, Any] = {}
|
||||
run_again = {i: input for i, input in enumerate(inputs)}
|
||||
handled_exceptions: Dict[int, BaseException] = {}
|
||||
first_to_raise = None
|
||||
for runnable in self.runnables:
|
||||
try:
|
||||
outputs = runnable.batch(
|
||||
inputs,
|
||||
[
|
||||
# each step a child run of the corresponding root run
|
||||
patch_config(config, callbacks=rm.get_child())
|
||||
for rm, config in zip(run_managers, configs)
|
||||
],
|
||||
return_exceptions=return_exceptions,
|
||||
**kwargs,
|
||||
)
|
||||
except self.exceptions_to_handle as e:
|
||||
if first_error is None:
|
||||
first_error = e
|
||||
except BaseException as e:
|
||||
for rm in run_managers:
|
||||
rm.on_chain_error(e)
|
||||
raise e
|
||||
else:
|
||||
for rm, output in zip(run_managers, outputs):
|
||||
rm.on_chain_end(output)
|
||||
return outputs
|
||||
if first_error is None:
|
||||
raise ValueError("No error stored at end of fallbacks.")
|
||||
for rm in run_managers:
|
||||
rm.on_chain_error(first_error)
|
||||
raise first_error
|
||||
outputs = runnable.batch(
|
||||
[input for _, input in sorted(run_again.items())],
|
||||
[
|
||||
# each step a child run of the corresponding root run
|
||||
patch_config(configs[i], callbacks=run_managers[i].get_child())
|
||||
for i in sorted(run_again)
|
||||
],
|
||||
return_exceptions=True,
|
||||
**kwargs,
|
||||
)
|
||||
for (i, input), output in zip(sorted(run_again.copy().items()), outputs):
|
||||
if isinstance(output, BaseException) and not isinstance(
|
||||
output, self.exceptions_to_handle
|
||||
):
|
||||
if not return_exceptions:
|
||||
first_to_raise = first_to_raise or output
|
||||
else:
|
||||
handled_exceptions[i] = cast(BaseException, output)
|
||||
run_again.pop(i)
|
||||
elif isinstance(output, self.exceptions_to_handle):
|
||||
if self.exception_key:
|
||||
input[self.exception_key] = output # type: ignore
|
||||
handled_exceptions[i] = cast(BaseException, output)
|
||||
else:
|
||||
run_managers[i].on_chain_end(output)
|
||||
to_return[i] = output
|
||||
run_again.pop(i)
|
||||
handled_exceptions.pop(i, None)
|
||||
if first_to_raise:
|
||||
raise first_to_raise
|
||||
if not run_again:
|
||||
break
|
||||
|
||||
sorted_handled_exceptions = sorted(handled_exceptions.items())
|
||||
for i, error in sorted_handled_exceptions:
|
||||
run_managers[i].on_chain_error(error)
|
||||
if not return_exceptions and sorted_handled_exceptions:
|
||||
raise sorted_handled_exceptions[0][1]
|
||||
to_return.update(handled_exceptions)
|
||||
return [output for _, output in sorted(to_return.items())]
|
||||
|
||||
async def abatch(
|
||||
self,
|
||||
@ -281,8 +327,13 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
||||
) -> List[Output]:
|
||||
from langchain_core.callbacks.manager import AsyncCallbackManager
|
||||
|
||||
if return_exceptions:
|
||||
raise NotImplementedError()
|
||||
if self.exception_key is not None and not all(
|
||||
isinstance(input, dict) for input in inputs
|
||||
):
|
||||
raise ValueError(
|
||||
"If 'exception_key' is specified then inputs must be dictionaries."
|
||||
f"However found a type of {type(inputs[0])} for input"
|
||||
)
|
||||
|
||||
if not inputs:
|
||||
return []
|
||||
@ -313,33 +364,54 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
||||
)
|
||||
)
|
||||
|
||||
first_error = None
|
||||
to_return = {}
|
||||
run_again = {i: input for i, input in enumerate(inputs)}
|
||||
handled_exceptions: Dict[int, BaseException] = {}
|
||||
first_to_raise = None
|
||||
for runnable in self.runnables:
|
||||
try:
|
||||
outputs = await runnable.abatch(
|
||||
inputs,
|
||||
[
|
||||
# each step a child run of the corresponding root run
|
||||
patch_config(config, callbacks=rm.get_child())
|
||||
for rm, config in zip(run_managers, configs)
|
||||
],
|
||||
return_exceptions=return_exceptions,
|
||||
**kwargs,
|
||||
)
|
||||
except self.exceptions_to_handle as e:
|
||||
if first_error is None:
|
||||
first_error = e
|
||||
except BaseException as e:
|
||||
await asyncio.gather(*(rm.on_chain_error(e) for rm in run_managers))
|
||||
else:
|
||||
await asyncio.gather(
|
||||
*(
|
||||
rm.on_chain_end(output)
|
||||
for rm, output in zip(run_managers, outputs)
|
||||
)
|
||||
)
|
||||
return outputs
|
||||
if first_error is None:
|
||||
raise ValueError("No error stored at end of fallbacks.")
|
||||
await asyncio.gather(*(rm.on_chain_error(first_error) for rm in run_managers))
|
||||
raise first_error
|
||||
outputs = await runnable.abatch(
|
||||
[input for _, input in sorted(run_again.items())],
|
||||
[
|
||||
# each step a child run of the corresponding root run
|
||||
patch_config(configs[i], callbacks=run_managers[i].get_child())
|
||||
for i in sorted(run_again)
|
||||
],
|
||||
return_exceptions=True,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
for (i, input), output in zip(sorted(run_again.copy().items()), outputs):
|
||||
if isinstance(output, BaseException) and not isinstance(
|
||||
output, self.exceptions_to_handle
|
||||
):
|
||||
if not return_exceptions:
|
||||
first_to_raise = first_to_raise or output
|
||||
else:
|
||||
handled_exceptions[i] = cast(BaseException, output)
|
||||
run_again.pop(i)
|
||||
elif isinstance(output, self.exceptions_to_handle):
|
||||
if self.exception_key:
|
||||
input[self.exception_key] = output # type: ignore
|
||||
handled_exceptions[i] = cast(BaseException, output)
|
||||
else:
|
||||
to_return[i] = output
|
||||
await run_managers[i].on_chain_end(output)
|
||||
run_again.pop(i)
|
||||
handled_exceptions.pop(i, None)
|
||||
|
||||
if first_to_raise:
|
||||
raise first_to_raise
|
||||
if not run_again:
|
||||
break
|
||||
|
||||
sorted_handled_exceptions = sorted(handled_exceptions.items())
|
||||
await asyncio.gather(
|
||||
*(
|
||||
run_managers[i].on_chain_error(error)
|
||||
for i, error in sorted_handled_exceptions
|
||||
)
|
||||
)
|
||||
if not return_exceptions and sorted_handled_exceptions:
|
||||
raise sorted_handled_exceptions[0][1]
|
||||
to_return.update(handled_exceptions)
|
||||
return [output for _, output in sorted(to_return.items())] # type: ignore
|
||||
|
@ -0,0 +1,373 @@
|
||||
# serializer version: 1
|
||||
# name: test_fallbacks[chain]
|
||||
'''
|
||||
{
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"schema",
|
||||
"runnable",
|
||||
"RunnableSequence"
|
||||
],
|
||||
"kwargs": {
|
||||
"first": {
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"schema",
|
||||
"runnable",
|
||||
"RunnableParallel"
|
||||
],
|
||||
"kwargs": {
|
||||
"steps": {
|
||||
"buz": {
|
||||
"lc": 1,
|
||||
"type": "not_implemented",
|
||||
"id": [
|
||||
"langchain_core",
|
||||
"runnables",
|
||||
"base",
|
||||
"RunnableLambda"
|
||||
],
|
||||
"repr": "RunnableLambda(lambda x: x)"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"middle": [],
|
||||
"last": {
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"schema",
|
||||
"runnable",
|
||||
"RunnableWithFallbacks"
|
||||
],
|
||||
"kwargs": {
|
||||
"runnable": {
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"schema",
|
||||
"runnable",
|
||||
"RunnableSequence"
|
||||
],
|
||||
"kwargs": {
|
||||
"first": {
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"prompts",
|
||||
"prompt",
|
||||
"PromptTemplate"
|
||||
],
|
||||
"kwargs": {
|
||||
"input_variables": [
|
||||
"buz"
|
||||
],
|
||||
"template": "what did baz say to {buz}",
|
||||
"template_format": "f-string",
|
||||
"partial_variables": {}
|
||||
}
|
||||
},
|
||||
"middle": [],
|
||||
"last": {
|
||||
"lc": 1,
|
||||
"type": "not_implemented",
|
||||
"id": [
|
||||
"tests",
|
||||
"unit_tests",
|
||||
"fake",
|
||||
"llm",
|
||||
"FakeListLLM"
|
||||
],
|
||||
"repr": "FakeListLLM(responses=['foo'], i=1)"
|
||||
},
|
||||
"name": null
|
||||
}
|
||||
},
|
||||
"fallbacks": [
|
||||
{
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"schema",
|
||||
"runnable",
|
||||
"RunnableSequence"
|
||||
],
|
||||
"kwargs": {
|
||||
"first": {
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"prompts",
|
||||
"prompt",
|
||||
"PromptTemplate"
|
||||
],
|
||||
"kwargs": {
|
||||
"input_variables": [
|
||||
"buz"
|
||||
],
|
||||
"template": "what did baz say to {buz}",
|
||||
"template_format": "f-string",
|
||||
"partial_variables": {}
|
||||
}
|
||||
},
|
||||
"middle": [],
|
||||
"last": {
|
||||
"lc": 1,
|
||||
"type": "not_implemented",
|
||||
"id": [
|
||||
"tests",
|
||||
"unit_tests",
|
||||
"fake",
|
||||
"llm",
|
||||
"FakeListLLM"
|
||||
],
|
||||
"repr": "FakeListLLM(responses=['bar'])"
|
||||
},
|
||||
"name": null
|
||||
}
|
||||
}
|
||||
],
|
||||
"exceptions_to_handle": [
|
||||
{
|
||||
"lc": 1,
|
||||
"type": "not_implemented",
|
||||
"id": [
|
||||
"builtins",
|
||||
"Exception"
|
||||
],
|
||||
"repr": "<class 'Exception'>"
|
||||
}
|
||||
],
|
||||
"exception_key": null
|
||||
}
|
||||
},
|
||||
"name": null
|
||||
}
|
||||
}
|
||||
'''
|
||||
# ---
|
||||
# name: test_fallbacks[chain_pass_exceptions]
|
||||
'''
|
||||
{
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"schema",
|
||||
"runnable",
|
||||
"RunnableSequence"
|
||||
],
|
||||
"kwargs": {
|
||||
"first": {
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"schema",
|
||||
"runnable",
|
||||
"RunnableParallel"
|
||||
],
|
||||
"kwargs": {
|
||||
"steps": {
|
||||
"text": {
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"schema",
|
||||
"runnable",
|
||||
"RunnablePassthrough"
|
||||
],
|
||||
"kwargs": {
|
||||
"func": null,
|
||||
"afunc": null,
|
||||
"input_type": null
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"middle": [],
|
||||
"last": {
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"schema",
|
||||
"runnable",
|
||||
"RunnableWithFallbacks"
|
||||
],
|
||||
"kwargs": {
|
||||
"runnable": {
|
||||
"lc": 1,
|
||||
"type": "not_implemented",
|
||||
"id": [
|
||||
"langchain_core",
|
||||
"runnables",
|
||||
"base",
|
||||
"RunnableLambda"
|
||||
],
|
||||
"repr": "RunnableLambda(_raise_error)"
|
||||
},
|
||||
"fallbacks": [
|
||||
{
|
||||
"lc": 1,
|
||||
"type": "not_implemented",
|
||||
"id": [
|
||||
"langchain_core",
|
||||
"runnables",
|
||||
"base",
|
||||
"RunnableLambda"
|
||||
],
|
||||
"repr": "RunnableLambda(_dont_raise_error)"
|
||||
}
|
||||
],
|
||||
"exceptions_to_handle": [
|
||||
{
|
||||
"lc": 1,
|
||||
"type": "not_implemented",
|
||||
"id": [
|
||||
"builtins",
|
||||
"Exception"
|
||||
],
|
||||
"repr": "<class 'Exception'>"
|
||||
}
|
||||
],
|
||||
"exception_key": "exception"
|
||||
}
|
||||
},
|
||||
"name": null
|
||||
}
|
||||
}
|
||||
'''
|
||||
# ---
|
||||
# name: test_fallbacks[llm]
|
||||
'''
|
||||
{
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"schema",
|
||||
"runnable",
|
||||
"RunnableWithFallbacks"
|
||||
],
|
||||
"kwargs": {
|
||||
"runnable": {
|
||||
"lc": 1,
|
||||
"type": "not_implemented",
|
||||
"id": [
|
||||
"tests",
|
||||
"unit_tests",
|
||||
"fake",
|
||||
"llm",
|
||||
"FakeListLLM"
|
||||
],
|
||||
"repr": "FakeListLLM(responses=['foo'], i=1)"
|
||||
},
|
||||
"fallbacks": [
|
||||
{
|
||||
"lc": 1,
|
||||
"type": "not_implemented",
|
||||
"id": [
|
||||
"tests",
|
||||
"unit_tests",
|
||||
"fake",
|
||||
"llm",
|
||||
"FakeListLLM"
|
||||
],
|
||||
"repr": "FakeListLLM(responses=['bar'])"
|
||||
}
|
||||
],
|
||||
"exceptions_to_handle": [
|
||||
{
|
||||
"lc": 1,
|
||||
"type": "not_implemented",
|
||||
"id": [
|
||||
"builtins",
|
||||
"Exception"
|
||||
],
|
||||
"repr": "<class 'Exception'>"
|
||||
}
|
||||
],
|
||||
"exception_key": null
|
||||
}
|
||||
}
|
||||
'''
|
||||
# ---
|
||||
# name: test_fallbacks[llm_multi]
|
||||
'''
|
||||
{
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"schema",
|
||||
"runnable",
|
||||
"RunnableWithFallbacks"
|
||||
],
|
||||
"kwargs": {
|
||||
"runnable": {
|
||||
"lc": 1,
|
||||
"type": "not_implemented",
|
||||
"id": [
|
||||
"tests",
|
||||
"unit_tests",
|
||||
"fake",
|
||||
"llm",
|
||||
"FakeListLLM"
|
||||
],
|
||||
"repr": "FakeListLLM(responses=['foo'], i=1)"
|
||||
},
|
||||
"fallbacks": [
|
||||
{
|
||||
"lc": 1,
|
||||
"type": "not_implemented",
|
||||
"id": [
|
||||
"tests",
|
||||
"unit_tests",
|
||||
"fake",
|
||||
"llm",
|
||||
"FakeListLLM"
|
||||
],
|
||||
"repr": "FakeListLLM(responses=['baz'], i=1)"
|
||||
},
|
||||
{
|
||||
"lc": 1,
|
||||
"type": "not_implemented",
|
||||
"id": [
|
||||
"tests",
|
||||
"unit_tests",
|
||||
"fake",
|
||||
"llm",
|
||||
"FakeListLLM"
|
||||
],
|
||||
"repr": "FakeListLLM(responses=['bar'])"
|
||||
}
|
||||
],
|
||||
"exceptions_to_handle": [
|
||||
{
|
||||
"lc": 1,
|
||||
"type": "not_implemented",
|
||||
"id": [
|
||||
"builtins",
|
||||
"Exception"
|
||||
],
|
||||
"repr": "<class 'Exception'>"
|
||||
}
|
||||
],
|
||||
"exception_key": null
|
||||
}
|
||||
}
|
||||
'''
|
||||
# ---
|
@ -696,280 +696,6 @@
|
||||
}
|
||||
'''
|
||||
# ---
|
||||
# name: test_llm_with_fallbacks[llm_chain_with_fallbacks]
|
||||
'''
|
||||
{
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"schema",
|
||||
"runnable",
|
||||
"RunnableSequence"
|
||||
],
|
||||
"kwargs": {
|
||||
"first": {
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"schema",
|
||||
"runnable",
|
||||
"RunnableParallel"
|
||||
],
|
||||
"kwargs": {
|
||||
"steps": {
|
||||
"buz": {
|
||||
"lc": 1,
|
||||
"type": "not_implemented",
|
||||
"id": [
|
||||
"langchain_core",
|
||||
"runnables",
|
||||
"base",
|
||||
"RunnableLambda"
|
||||
],
|
||||
"repr": "RunnableLambda(lambda x: x)"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"middle": [],
|
||||
"last": {
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"schema",
|
||||
"runnable",
|
||||
"RunnableWithFallbacks"
|
||||
],
|
||||
"kwargs": {
|
||||
"runnable": {
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"schema",
|
||||
"runnable",
|
||||
"RunnableSequence"
|
||||
],
|
||||
"kwargs": {
|
||||
"first": {
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"prompts",
|
||||
"prompt",
|
||||
"PromptTemplate"
|
||||
],
|
||||
"kwargs": {
|
||||
"input_variables": [
|
||||
"buz"
|
||||
],
|
||||
"template": "what did baz say to {buz}",
|
||||
"template_format": "f-string",
|
||||
"partial_variables": {}
|
||||
}
|
||||
},
|
||||
"middle": [],
|
||||
"last": {
|
||||
"lc": 1,
|
||||
"type": "not_implemented",
|
||||
"id": [
|
||||
"tests",
|
||||
"unit_tests",
|
||||
"fake",
|
||||
"llm",
|
||||
"FakeListLLM"
|
||||
],
|
||||
"repr": "FakeListLLM(responses=['foo'], i=1)"
|
||||
},
|
||||
"name": null
|
||||
}
|
||||
},
|
||||
"fallbacks": [
|
||||
{
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"schema",
|
||||
"runnable",
|
||||
"RunnableSequence"
|
||||
],
|
||||
"kwargs": {
|
||||
"first": {
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"prompts",
|
||||
"prompt",
|
||||
"PromptTemplate"
|
||||
],
|
||||
"kwargs": {
|
||||
"input_variables": [
|
||||
"buz"
|
||||
],
|
||||
"template": "what did baz say to {buz}",
|
||||
"template_format": "f-string",
|
||||
"partial_variables": {}
|
||||
}
|
||||
},
|
||||
"middle": [],
|
||||
"last": {
|
||||
"lc": 1,
|
||||
"type": "not_implemented",
|
||||
"id": [
|
||||
"tests",
|
||||
"unit_tests",
|
||||
"fake",
|
||||
"llm",
|
||||
"FakeListLLM"
|
||||
],
|
||||
"repr": "FakeListLLM(responses=['bar'])"
|
||||
},
|
||||
"name": null
|
||||
}
|
||||
}
|
||||
],
|
||||
"exceptions_to_handle": [
|
||||
{
|
||||
"lc": 1,
|
||||
"type": "not_implemented",
|
||||
"id": [
|
||||
"builtins",
|
||||
"Exception"
|
||||
],
|
||||
"repr": "<class 'Exception'>"
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"name": null
|
||||
}
|
||||
}
|
||||
'''
|
||||
# ---
|
||||
# name: test_llm_with_fallbacks[llm_with_fallbacks]
|
||||
'''
|
||||
{
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"schema",
|
||||
"runnable",
|
||||
"RunnableWithFallbacks"
|
||||
],
|
||||
"kwargs": {
|
||||
"runnable": {
|
||||
"lc": 1,
|
||||
"type": "not_implemented",
|
||||
"id": [
|
||||
"tests",
|
||||
"unit_tests",
|
||||
"fake",
|
||||
"llm",
|
||||
"FakeListLLM"
|
||||
],
|
||||
"repr": "FakeListLLM(responses=['foo'], i=1)"
|
||||
},
|
||||
"fallbacks": [
|
||||
{
|
||||
"lc": 1,
|
||||
"type": "not_implemented",
|
||||
"id": [
|
||||
"tests",
|
||||
"unit_tests",
|
||||
"fake",
|
||||
"llm",
|
||||
"FakeListLLM"
|
||||
],
|
||||
"repr": "FakeListLLM(responses=['bar'])"
|
||||
}
|
||||
],
|
||||
"exceptions_to_handle": [
|
||||
{
|
||||
"lc": 1,
|
||||
"type": "not_implemented",
|
||||
"id": [
|
||||
"builtins",
|
||||
"Exception"
|
||||
],
|
||||
"repr": "<class 'Exception'>"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
'''
|
||||
# ---
|
||||
# name: test_llm_with_fallbacks[llm_with_multi_fallbacks]
|
||||
'''
|
||||
{
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"schema",
|
||||
"runnable",
|
||||
"RunnableWithFallbacks"
|
||||
],
|
||||
"kwargs": {
|
||||
"runnable": {
|
||||
"lc": 1,
|
||||
"type": "not_implemented",
|
||||
"id": [
|
||||
"tests",
|
||||
"unit_tests",
|
||||
"fake",
|
||||
"llm",
|
||||
"FakeListLLM"
|
||||
],
|
||||
"repr": "FakeListLLM(responses=['foo'], i=1)"
|
||||
},
|
||||
"fallbacks": [
|
||||
{
|
||||
"lc": 1,
|
||||
"type": "not_implemented",
|
||||
"id": [
|
||||
"tests",
|
||||
"unit_tests",
|
||||
"fake",
|
||||
"llm",
|
||||
"FakeListLLM"
|
||||
],
|
||||
"repr": "FakeListLLM(responses=['baz'], i=1)"
|
||||
},
|
||||
{
|
||||
"lc": 1,
|
||||
"type": "not_implemented",
|
||||
"id": [
|
||||
"tests",
|
||||
"unit_tests",
|
||||
"fake",
|
||||
"llm",
|
||||
"FakeListLLM"
|
||||
],
|
||||
"repr": "FakeListLLM(responses=['bar'])"
|
||||
}
|
||||
],
|
||||
"exceptions_to_handle": [
|
||||
{
|
||||
"lc": 1,
|
||||
"type": "not_implemented",
|
||||
"id": [
|
||||
"builtins",
|
||||
"Exception"
|
||||
],
|
||||
"repr": "<class 'Exception'>"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
'''
|
||||
# ---
|
||||
# name: test_prompt_with_chat_model
|
||||
'''
|
||||
ChatPromptTemplate(input_variables=['question'], messages=[SystemMessagePromptTemplate(prompt=PromptTemplate(input_variables=[], template='You are a nice assistant.')), HumanMessagePromptTemplate(prompt=PromptTemplate(input_variables=['question'], template='{question}'))])
|
||||
|
231
libs/core/tests/unit_tests/runnables/test_fallbacks.py
Normal file
231
libs/core/tests/unit_tests/runnables/test_fallbacks.py
Normal file
@ -0,0 +1,231 @@
|
||||
import sys
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from syrupy import SnapshotAssertion
|
||||
|
||||
from langchain_core.load import dumps
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
from langchain_core.runnables import (
|
||||
Runnable,
|
||||
RunnableLambda,
|
||||
RunnableParallel,
|
||||
RunnablePassthrough,
|
||||
RunnableWithFallbacks,
|
||||
)
|
||||
from tests.unit_tests.fake.llm import FakeListLLM
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def llm() -> RunnableWithFallbacks:
|
||||
error_llm = FakeListLLM(responses=["foo"], i=1)
|
||||
pass_llm = FakeListLLM(responses=["bar"])
|
||||
|
||||
return error_llm.with_fallbacks([pass_llm])
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def llm_multi() -> RunnableWithFallbacks:
|
||||
error_llm = FakeListLLM(responses=["foo"], i=1)
|
||||
error_llm_2 = FakeListLLM(responses=["baz"], i=1)
|
||||
pass_llm = FakeListLLM(responses=["bar"])
|
||||
|
||||
return error_llm.with_fallbacks([error_llm_2, pass_llm])
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def chain() -> Runnable:
|
||||
error_llm = FakeListLLM(responses=["foo"], i=1)
|
||||
pass_llm = FakeListLLM(responses=["bar"])
|
||||
|
||||
prompt = PromptTemplate.from_template("what did baz say to {buz}")
|
||||
return RunnableParallel({"buz": lambda x: x}) | (prompt | error_llm).with_fallbacks(
|
||||
[prompt | pass_llm]
|
||||
)
|
||||
|
||||
|
||||
def _raise_error(inputs: dict) -> str:
|
||||
raise ValueError()
|
||||
|
||||
|
||||
def _dont_raise_error(inputs: dict) -> str:
|
||||
if "exception" in inputs:
|
||||
return "bar"
|
||||
raise ValueError()
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def chain_pass_exceptions() -> Runnable:
|
||||
fallback = RunnableLambda(_dont_raise_error)
|
||||
return {"text": RunnablePassthrough()} | RunnableLambda(
|
||||
_raise_error
|
||||
).with_fallbacks([fallback], exception_key="exception")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"runnable",
|
||||
["llm", "llm_multi", "chain", "chain_pass_exceptions"],
|
||||
)
|
||||
async def test_fallbacks(
|
||||
runnable: RunnableWithFallbacks, request: Any, snapshot: SnapshotAssertion
|
||||
) -> None:
|
||||
runnable = request.getfixturevalue(runnable)
|
||||
assert runnable.invoke("hello") == "bar"
|
||||
assert runnable.batch(["hi", "hey", "bye"]) == ["bar"] * 3
|
||||
assert list(runnable.stream("hello")) == ["bar"]
|
||||
assert await runnable.ainvoke("hello") == "bar"
|
||||
assert await runnable.abatch(["hi", "hey", "bye"]) == ["bar"] * 3
|
||||
assert list(await runnable.ainvoke("hello")) == list("bar")
|
||||
if sys.version_info >= (3, 9):
|
||||
assert dumps(runnable, pretty=True) == snapshot
|
||||
|
||||
|
||||
def _runnable(inputs: dict) -> str:
|
||||
if inputs["text"] == "foo":
|
||||
return "first"
|
||||
if "exception" not in inputs:
|
||||
raise ValueError()
|
||||
if inputs["text"] == "bar":
|
||||
return "second"
|
||||
if isinstance(inputs["exception"], ValueError):
|
||||
raise RuntimeError()
|
||||
return "third"
|
||||
|
||||
|
||||
def _assert_potential_error(actual: list, expected: list) -> None:
|
||||
for x, y in zip(actual, expected):
|
||||
if isinstance(x, Exception):
|
||||
assert isinstance(y, type(x))
|
||||
else:
|
||||
assert x == y
|
||||
|
||||
|
||||
def test_invoke_with_exception_key() -> None:
|
||||
runnable = RunnableLambda(_runnable)
|
||||
runnable_with_single = runnable.with_fallbacks(
|
||||
[runnable], exception_key="exception"
|
||||
)
|
||||
with pytest.raises(ValueError):
|
||||
runnable_with_single.invoke({"text": "baz"})
|
||||
|
||||
actual = runnable_with_single.invoke({"text": "bar"})
|
||||
expected = "second"
|
||||
_assert_potential_error([actual], [expected])
|
||||
|
||||
runnable_with_double = runnable.with_fallbacks(
|
||||
[runnable, runnable], exception_key="exception"
|
||||
)
|
||||
actual = runnable_with_double.invoke({"text": "baz"})
|
||||
|
||||
expected = "third"
|
||||
_assert_potential_error([actual], [expected])
|
||||
|
||||
|
||||
async def test_ainvoke_with_exception_key() -> None:
|
||||
runnable = RunnableLambda(_runnable)
|
||||
runnable_with_single = runnable.with_fallbacks(
|
||||
[runnable], exception_key="exception"
|
||||
)
|
||||
with pytest.raises(ValueError):
|
||||
await runnable_with_single.ainvoke({"text": "baz"})
|
||||
|
||||
actual = await runnable_with_single.ainvoke({"text": "bar"})
|
||||
expected = "second"
|
||||
_assert_potential_error([actual], [expected])
|
||||
|
||||
runnable_with_double = runnable.with_fallbacks(
|
||||
[runnable, runnable], exception_key="exception"
|
||||
)
|
||||
actual = await runnable_with_double.ainvoke({"text": "baz"})
|
||||
expected = "third"
|
||||
_assert_potential_error([actual], [expected])
|
||||
|
||||
|
||||
def test_batch() -> None:
|
||||
runnable = RunnableLambda(_runnable)
|
||||
with pytest.raises(ValueError):
|
||||
runnable.batch([{"text": "foo"}, {"text": "bar"}, {"text": "baz"}])
|
||||
actual = runnable.batch(
|
||||
[{"text": "foo"}, {"text": "bar"}, {"text": "baz"}], return_exceptions=True
|
||||
)
|
||||
expected = ["first", ValueError(), ValueError()]
|
||||
_assert_potential_error(actual, expected)
|
||||
|
||||
runnable_with_single = runnable.with_fallbacks(
|
||||
[runnable], exception_key="exception"
|
||||
)
|
||||
with pytest.raises(RuntimeError):
|
||||
runnable_with_single.batch([{"text": "foo"}, {"text": "bar"}, {"text": "baz"}])
|
||||
actual = runnable_with_single.batch(
|
||||
[{"text": "foo"}, {"text": "bar"}, {"text": "baz"}], return_exceptions=True
|
||||
)
|
||||
expected = ["first", "second", RuntimeError()]
|
||||
_assert_potential_error(actual, expected)
|
||||
|
||||
runnable_with_double = runnable.with_fallbacks(
|
||||
[runnable, runnable], exception_key="exception"
|
||||
)
|
||||
actual = runnable_with_double.batch(
|
||||
[{"text": "foo"}, {"text": "bar"}, {"text": "baz"}], return_exceptions=True
|
||||
)
|
||||
|
||||
expected = ["first", "second", "third"]
|
||||
_assert_potential_error(actual, expected)
|
||||
|
||||
runnable_with_double = runnable.with_fallbacks(
|
||||
[runnable, runnable],
|
||||
exception_key="exception",
|
||||
exceptions_to_handle=(ValueError,),
|
||||
)
|
||||
actual = runnable_with_double.batch(
|
||||
[{"text": "foo"}, {"text": "bar"}, {"text": "baz"}], return_exceptions=True
|
||||
)
|
||||
|
||||
expected = ["first", "second", RuntimeError()]
|
||||
_assert_potential_error(actual, expected)
|
||||
|
||||
|
||||
async def test_abatch() -> None:
|
||||
runnable = RunnableLambda(_runnable)
|
||||
with pytest.raises(ValueError):
|
||||
await runnable.abatch([{"text": "foo"}, {"text": "bar"}, {"text": "baz"}])
|
||||
actual = await runnable.abatch(
|
||||
[{"text": "foo"}, {"text": "bar"}, {"text": "baz"}], return_exceptions=True
|
||||
)
|
||||
expected = ["first", ValueError(), ValueError()]
|
||||
_assert_potential_error(actual, expected)
|
||||
|
||||
runnable_with_single = runnable.with_fallbacks(
|
||||
[runnable], exception_key="exception"
|
||||
)
|
||||
with pytest.raises(RuntimeError):
|
||||
await runnable_with_single.abatch(
|
||||
[{"text": "foo"}, {"text": "bar"}, {"text": "baz"}]
|
||||
)
|
||||
actual = await runnable_with_single.abatch(
|
||||
[{"text": "foo"}, {"text": "bar"}, {"text": "baz"}], return_exceptions=True
|
||||
)
|
||||
expected = ["first", "second", RuntimeError()]
|
||||
_assert_potential_error(actual, expected)
|
||||
|
||||
runnable_with_double = runnable.with_fallbacks(
|
||||
[runnable, runnable], exception_key="exception"
|
||||
)
|
||||
actual = await runnable_with_double.abatch(
|
||||
[{"text": "foo"}, {"text": "bar"}, {"text": "baz"}], return_exceptions=True
|
||||
)
|
||||
|
||||
expected = ["first", "second", "third"]
|
||||
_assert_potential_error(actual, expected)
|
||||
|
||||
runnable_with_double = runnable.with_fallbacks(
|
||||
[runnable, runnable],
|
||||
exception_key="exception",
|
||||
exceptions_to_handle=(ValueError,),
|
||||
)
|
||||
actual = await runnable_with_double.abatch(
|
||||
[{"text": "foo"}, {"text": "bar"}, {"text": "baz"}], return_exceptions=True
|
||||
)
|
||||
|
||||
expected = ["first", "second", RuntimeError()]
|
||||
_assert_potential_error(actual, expected)
|
@ -66,7 +66,6 @@ from langchain_core.runnables import (
|
||||
RunnablePassthrough,
|
||||
RunnablePick,
|
||||
RunnableSequence,
|
||||
RunnableWithFallbacks,
|
||||
add,
|
||||
chain,
|
||||
)
|
||||
@ -3683,52 +3682,6 @@ async def test_runnable_sequence_atransform() -> None:
|
||||
assert "".join(chunks) == "foo-lish"
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def llm_with_fallbacks() -> RunnableWithFallbacks:
|
||||
error_llm = FakeListLLM(responses=["foo"], i=1)
|
||||
pass_llm = FakeListLLM(responses=["bar"])
|
||||
|
||||
return error_llm.with_fallbacks([pass_llm])
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def llm_with_multi_fallbacks() -> RunnableWithFallbacks:
|
||||
error_llm = FakeListLLM(responses=["foo"], i=1)
|
||||
error_llm_2 = FakeListLLM(responses=["baz"], i=1)
|
||||
pass_llm = FakeListLLM(responses=["bar"])
|
||||
|
||||
return error_llm.with_fallbacks([error_llm_2, pass_llm])
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def llm_chain_with_fallbacks() -> Runnable:
|
||||
error_llm = FakeListLLM(responses=["foo"], i=1)
|
||||
pass_llm = FakeListLLM(responses=["bar"])
|
||||
|
||||
prompt = PromptTemplate.from_template("what did baz say to {buz}")
|
||||
return RunnableParallel({"buz": lambda x: x}) | (prompt | error_llm).with_fallbacks(
|
||||
[prompt | pass_llm]
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"runnable",
|
||||
["llm_with_fallbacks", "llm_with_multi_fallbacks", "llm_chain_with_fallbacks"],
|
||||
)
|
||||
async def test_llm_with_fallbacks(
|
||||
runnable: RunnableWithFallbacks, request: Any, snapshot: SnapshotAssertion
|
||||
) -> None:
|
||||
runnable = request.getfixturevalue(runnable)
|
||||
assert runnable.invoke("hello") == "bar"
|
||||
assert runnable.batch(["hi", "hey", "bye"]) == ["bar"] * 3
|
||||
assert list(runnable.stream("hello")) == ["bar"]
|
||||
assert await runnable.ainvoke("hello") == "bar"
|
||||
assert await runnable.abatch(["hi", "hey", "bye"]) == ["bar"] * 3
|
||||
assert list(await runnable.ainvoke("hello")) == list("bar")
|
||||
if sys.version_info >= (3, 9):
|
||||
assert dumps(runnable, pretty=True) == snapshot
|
||||
|
||||
|
||||
class FakeSplitIntoListParser(BaseOutputParser[List[str]]):
|
||||
"""Parse the output of an LLM call to a comma-separated list."""
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user