mirror of https://github.com/hwchase17/langchain
core[patch]: pass exceptions to fallbacks (#16048)
parent
770f57196e
commit
c5656a4905
@ -0,0 +1,373 @@
|
||||
# serializer version: 1
|
||||
# name: test_fallbacks[chain]
|
||||
'''
|
||||
{
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"schema",
|
||||
"runnable",
|
||||
"RunnableSequence"
|
||||
],
|
||||
"kwargs": {
|
||||
"first": {
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"schema",
|
||||
"runnable",
|
||||
"RunnableParallel"
|
||||
],
|
||||
"kwargs": {
|
||||
"steps": {
|
||||
"buz": {
|
||||
"lc": 1,
|
||||
"type": "not_implemented",
|
||||
"id": [
|
||||
"langchain_core",
|
||||
"runnables",
|
||||
"base",
|
||||
"RunnableLambda"
|
||||
],
|
||||
"repr": "RunnableLambda(lambda x: x)"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"middle": [],
|
||||
"last": {
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"schema",
|
||||
"runnable",
|
||||
"RunnableWithFallbacks"
|
||||
],
|
||||
"kwargs": {
|
||||
"runnable": {
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"schema",
|
||||
"runnable",
|
||||
"RunnableSequence"
|
||||
],
|
||||
"kwargs": {
|
||||
"first": {
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"prompts",
|
||||
"prompt",
|
||||
"PromptTemplate"
|
||||
],
|
||||
"kwargs": {
|
||||
"input_variables": [
|
||||
"buz"
|
||||
],
|
||||
"template": "what did baz say to {buz}",
|
||||
"template_format": "f-string",
|
||||
"partial_variables": {}
|
||||
}
|
||||
},
|
||||
"middle": [],
|
||||
"last": {
|
||||
"lc": 1,
|
||||
"type": "not_implemented",
|
||||
"id": [
|
||||
"tests",
|
||||
"unit_tests",
|
||||
"fake",
|
||||
"llm",
|
||||
"FakeListLLM"
|
||||
],
|
||||
"repr": "FakeListLLM(responses=['foo'], i=1)"
|
||||
},
|
||||
"name": null
|
||||
}
|
||||
},
|
||||
"fallbacks": [
|
||||
{
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"schema",
|
||||
"runnable",
|
||||
"RunnableSequence"
|
||||
],
|
||||
"kwargs": {
|
||||
"first": {
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"prompts",
|
||||
"prompt",
|
||||
"PromptTemplate"
|
||||
],
|
||||
"kwargs": {
|
||||
"input_variables": [
|
||||
"buz"
|
||||
],
|
||||
"template": "what did baz say to {buz}",
|
||||
"template_format": "f-string",
|
||||
"partial_variables": {}
|
||||
}
|
||||
},
|
||||
"middle": [],
|
||||
"last": {
|
||||
"lc": 1,
|
||||
"type": "not_implemented",
|
||||
"id": [
|
||||
"tests",
|
||||
"unit_tests",
|
||||
"fake",
|
||||
"llm",
|
||||
"FakeListLLM"
|
||||
],
|
||||
"repr": "FakeListLLM(responses=['bar'])"
|
||||
},
|
||||
"name": null
|
||||
}
|
||||
}
|
||||
],
|
||||
"exceptions_to_handle": [
|
||||
{
|
||||
"lc": 1,
|
||||
"type": "not_implemented",
|
||||
"id": [
|
||||
"builtins",
|
||||
"Exception"
|
||||
],
|
||||
"repr": "<class 'Exception'>"
|
||||
}
|
||||
],
|
||||
"exception_key": null
|
||||
}
|
||||
},
|
||||
"name": null
|
||||
}
|
||||
}
|
||||
'''
|
||||
# ---
|
||||
# name: test_fallbacks[chain_pass_exceptions]
|
||||
'''
|
||||
{
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"schema",
|
||||
"runnable",
|
||||
"RunnableSequence"
|
||||
],
|
||||
"kwargs": {
|
||||
"first": {
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"schema",
|
||||
"runnable",
|
||||
"RunnableParallel"
|
||||
],
|
||||
"kwargs": {
|
||||
"steps": {
|
||||
"text": {
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"schema",
|
||||
"runnable",
|
||||
"RunnablePassthrough"
|
||||
],
|
||||
"kwargs": {
|
||||
"func": null,
|
||||
"afunc": null,
|
||||
"input_type": null
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"middle": [],
|
||||
"last": {
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"schema",
|
||||
"runnable",
|
||||
"RunnableWithFallbacks"
|
||||
],
|
||||
"kwargs": {
|
||||
"runnable": {
|
||||
"lc": 1,
|
||||
"type": "not_implemented",
|
||||
"id": [
|
||||
"langchain_core",
|
||||
"runnables",
|
||||
"base",
|
||||
"RunnableLambda"
|
||||
],
|
||||
"repr": "RunnableLambda(_raise_error)"
|
||||
},
|
||||
"fallbacks": [
|
||||
{
|
||||
"lc": 1,
|
||||
"type": "not_implemented",
|
||||
"id": [
|
||||
"langchain_core",
|
||||
"runnables",
|
||||
"base",
|
||||
"RunnableLambda"
|
||||
],
|
||||
"repr": "RunnableLambda(_dont_raise_error)"
|
||||
}
|
||||
],
|
||||
"exceptions_to_handle": [
|
||||
{
|
||||
"lc": 1,
|
||||
"type": "not_implemented",
|
||||
"id": [
|
||||
"builtins",
|
||||
"Exception"
|
||||
],
|
||||
"repr": "<class 'Exception'>"
|
||||
}
|
||||
],
|
||||
"exception_key": "exception"
|
||||
}
|
||||
},
|
||||
"name": null
|
||||
}
|
||||
}
|
||||
'''
|
||||
# ---
|
||||
# name: test_fallbacks[llm]
|
||||
'''
|
||||
{
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"schema",
|
||||
"runnable",
|
||||
"RunnableWithFallbacks"
|
||||
],
|
||||
"kwargs": {
|
||||
"runnable": {
|
||||
"lc": 1,
|
||||
"type": "not_implemented",
|
||||
"id": [
|
||||
"tests",
|
||||
"unit_tests",
|
||||
"fake",
|
||||
"llm",
|
||||
"FakeListLLM"
|
||||
],
|
||||
"repr": "FakeListLLM(responses=['foo'], i=1)"
|
||||
},
|
||||
"fallbacks": [
|
||||
{
|
||||
"lc": 1,
|
||||
"type": "not_implemented",
|
||||
"id": [
|
||||
"tests",
|
||||
"unit_tests",
|
||||
"fake",
|
||||
"llm",
|
||||
"FakeListLLM"
|
||||
],
|
||||
"repr": "FakeListLLM(responses=['bar'])"
|
||||
}
|
||||
],
|
||||
"exceptions_to_handle": [
|
||||
{
|
||||
"lc": 1,
|
||||
"type": "not_implemented",
|
||||
"id": [
|
||||
"builtins",
|
||||
"Exception"
|
||||
],
|
||||
"repr": "<class 'Exception'>"
|
||||
}
|
||||
],
|
||||
"exception_key": null
|
||||
}
|
||||
}
|
||||
'''
|
||||
# ---
|
||||
# name: test_fallbacks[llm_multi]
|
||||
'''
|
||||
{
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"schema",
|
||||
"runnable",
|
||||
"RunnableWithFallbacks"
|
||||
],
|
||||
"kwargs": {
|
||||
"runnable": {
|
||||
"lc": 1,
|
||||
"type": "not_implemented",
|
||||
"id": [
|
||||
"tests",
|
||||
"unit_tests",
|
||||
"fake",
|
||||
"llm",
|
||||
"FakeListLLM"
|
||||
],
|
||||
"repr": "FakeListLLM(responses=['foo'], i=1)"
|
||||
},
|
||||
"fallbacks": [
|
||||
{
|
||||
"lc": 1,
|
||||
"type": "not_implemented",
|
||||
"id": [
|
||||
"tests",
|
||||
"unit_tests",
|
||||
"fake",
|
||||
"llm",
|
||||
"FakeListLLM"
|
||||
],
|
||||
"repr": "FakeListLLM(responses=['baz'], i=1)"
|
||||
},
|
||||
{
|
||||
"lc": 1,
|
||||
"type": "not_implemented",
|
||||
"id": [
|
||||
"tests",
|
||||
"unit_tests",
|
||||
"fake",
|
||||
"llm",
|
||||
"FakeListLLM"
|
||||
],
|
||||
"repr": "FakeListLLM(responses=['bar'])"
|
||||
}
|
||||
],
|
||||
"exceptions_to_handle": [
|
||||
{
|
||||
"lc": 1,
|
||||
"type": "not_implemented",
|
||||
"id": [
|
||||
"builtins",
|
||||
"Exception"
|
||||
],
|
||||
"repr": "<class 'Exception'>"
|
||||
}
|
||||
],
|
||||
"exception_key": null
|
||||
}
|
||||
}
|
||||
'''
|
||||
# ---
|
@ -0,0 +1,231 @@
|
||||
import sys
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from syrupy import SnapshotAssertion
|
||||
|
||||
from langchain_core.load import dumps
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
from langchain_core.runnables import (
|
||||
Runnable,
|
||||
RunnableLambda,
|
||||
RunnableParallel,
|
||||
RunnablePassthrough,
|
||||
RunnableWithFallbacks,
|
||||
)
|
||||
from tests.unit_tests.fake.llm import FakeListLLM
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def llm() -> RunnableWithFallbacks:
|
||||
error_llm = FakeListLLM(responses=["foo"], i=1)
|
||||
pass_llm = FakeListLLM(responses=["bar"])
|
||||
|
||||
return error_llm.with_fallbacks([pass_llm])
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def llm_multi() -> RunnableWithFallbacks:
|
||||
error_llm = FakeListLLM(responses=["foo"], i=1)
|
||||
error_llm_2 = FakeListLLM(responses=["baz"], i=1)
|
||||
pass_llm = FakeListLLM(responses=["bar"])
|
||||
|
||||
return error_llm.with_fallbacks([error_llm_2, pass_llm])
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def chain() -> Runnable:
|
||||
error_llm = FakeListLLM(responses=["foo"], i=1)
|
||||
pass_llm = FakeListLLM(responses=["bar"])
|
||||
|
||||
prompt = PromptTemplate.from_template("what did baz say to {buz}")
|
||||
return RunnableParallel({"buz": lambda x: x}) | (prompt | error_llm).with_fallbacks(
|
||||
[prompt | pass_llm]
|
||||
)
|
||||
|
||||
|
||||
def _raise_error(inputs: dict) -> str:
|
||||
raise ValueError()
|
||||
|
||||
|
||||
def _dont_raise_error(inputs: dict) -> str:
|
||||
if "exception" in inputs:
|
||||
return "bar"
|
||||
raise ValueError()
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def chain_pass_exceptions() -> Runnable:
|
||||
fallback = RunnableLambda(_dont_raise_error)
|
||||
return {"text": RunnablePassthrough()} | RunnableLambda(
|
||||
_raise_error
|
||||
).with_fallbacks([fallback], exception_key="exception")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"runnable",
|
||||
["llm", "llm_multi", "chain", "chain_pass_exceptions"],
|
||||
)
|
||||
async def test_fallbacks(
|
||||
runnable: RunnableWithFallbacks, request: Any, snapshot: SnapshotAssertion
|
||||
) -> None:
|
||||
runnable = request.getfixturevalue(runnable)
|
||||
assert runnable.invoke("hello") == "bar"
|
||||
assert runnable.batch(["hi", "hey", "bye"]) == ["bar"] * 3
|
||||
assert list(runnable.stream("hello")) == ["bar"]
|
||||
assert await runnable.ainvoke("hello") == "bar"
|
||||
assert await runnable.abatch(["hi", "hey", "bye"]) == ["bar"] * 3
|
||||
assert list(await runnable.ainvoke("hello")) == list("bar")
|
||||
if sys.version_info >= (3, 9):
|
||||
assert dumps(runnable, pretty=True) == snapshot
|
||||
|
||||
|
||||
def _runnable(inputs: dict) -> str:
|
||||
if inputs["text"] == "foo":
|
||||
return "first"
|
||||
if "exception" not in inputs:
|
||||
raise ValueError()
|
||||
if inputs["text"] == "bar":
|
||||
return "second"
|
||||
if isinstance(inputs["exception"], ValueError):
|
||||
raise RuntimeError()
|
||||
return "third"
|
||||
|
||||
|
||||
def _assert_potential_error(actual: list, expected: list) -> None:
|
||||
for x, y in zip(actual, expected):
|
||||
if isinstance(x, Exception):
|
||||
assert isinstance(y, type(x))
|
||||
else:
|
||||
assert x == y
|
||||
|
||||
|
||||
def test_invoke_with_exception_key() -> None:
|
||||
runnable = RunnableLambda(_runnable)
|
||||
runnable_with_single = runnable.with_fallbacks(
|
||||
[runnable], exception_key="exception"
|
||||
)
|
||||
with pytest.raises(ValueError):
|
||||
runnable_with_single.invoke({"text": "baz"})
|
||||
|
||||
actual = runnable_with_single.invoke({"text": "bar"})
|
||||
expected = "second"
|
||||
_assert_potential_error([actual], [expected])
|
||||
|
||||
runnable_with_double = runnable.with_fallbacks(
|
||||
[runnable, runnable], exception_key="exception"
|
||||
)
|
||||
actual = runnable_with_double.invoke({"text": "baz"})
|
||||
|
||||
expected = "third"
|
||||
_assert_potential_error([actual], [expected])
|
||||
|
||||
|
||||
async def test_ainvoke_with_exception_key() -> None:
|
||||
runnable = RunnableLambda(_runnable)
|
||||
runnable_with_single = runnable.with_fallbacks(
|
||||
[runnable], exception_key="exception"
|
||||
)
|
||||
with pytest.raises(ValueError):
|
||||
await runnable_with_single.ainvoke({"text": "baz"})
|
||||
|
||||
actual = await runnable_with_single.ainvoke({"text": "bar"})
|
||||
expected = "second"
|
||||
_assert_potential_error([actual], [expected])
|
||||
|
||||
runnable_with_double = runnable.with_fallbacks(
|
||||
[runnable, runnable], exception_key="exception"
|
||||
)
|
||||
actual = await runnable_with_double.ainvoke({"text": "baz"})
|
||||
expected = "third"
|
||||
_assert_potential_error([actual], [expected])
|
||||
|
||||
|
||||
def test_batch() -> None:
|
||||
runnable = RunnableLambda(_runnable)
|
||||
with pytest.raises(ValueError):
|
||||
runnable.batch([{"text": "foo"}, {"text": "bar"}, {"text": "baz"}])
|
||||
actual = runnable.batch(
|
||||
[{"text": "foo"}, {"text": "bar"}, {"text": "baz"}], return_exceptions=True
|
||||
)
|
||||
expected = ["first", ValueError(), ValueError()]
|
||||
_assert_potential_error(actual, expected)
|
||||
|
||||
runnable_with_single = runnable.with_fallbacks(
|
||||
[runnable], exception_key="exception"
|
||||
)
|
||||
with pytest.raises(RuntimeError):
|
||||
runnable_with_single.batch([{"text": "foo"}, {"text": "bar"}, {"text": "baz"}])
|
||||
actual = runnable_with_single.batch(
|
||||
[{"text": "foo"}, {"text": "bar"}, {"text": "baz"}], return_exceptions=True
|
||||
)
|
||||
expected = ["first", "second", RuntimeError()]
|
||||
_assert_potential_error(actual, expected)
|
||||
|
||||
runnable_with_double = runnable.with_fallbacks(
|
||||
[runnable, runnable], exception_key="exception"
|
||||
)
|
||||
actual = runnable_with_double.batch(
|
||||
[{"text": "foo"}, {"text": "bar"}, {"text": "baz"}], return_exceptions=True
|
||||
)
|
||||
|
||||
expected = ["first", "second", "third"]
|
||||
_assert_potential_error(actual, expected)
|
||||
|
||||
runnable_with_double = runnable.with_fallbacks(
|
||||
[runnable, runnable],
|
||||
exception_key="exception",
|
||||
exceptions_to_handle=(ValueError,),
|
||||
)
|
||||
actual = runnable_with_double.batch(
|
||||
[{"text": "foo"}, {"text": "bar"}, {"text": "baz"}], return_exceptions=True
|
||||
)
|
||||
|
||||
expected = ["first", "second", RuntimeError()]
|
||||
_assert_potential_error(actual, expected)
|
||||
|
||||
|
||||
async def test_abatch() -> None:
|
||||
runnable = RunnableLambda(_runnable)
|
||||
with pytest.raises(ValueError):
|
||||
await runnable.abatch([{"text": "foo"}, {"text": "bar"}, {"text": "baz"}])
|
||||
actual = await runnable.abatch(
|
||||
[{"text": "foo"}, {"text": "bar"}, {"text": "baz"}], return_exceptions=True
|
||||
)
|
||||
expected = ["first", ValueError(), ValueError()]
|
||||
_assert_potential_error(actual, expected)
|
||||
|
||||
runnable_with_single = runnable.with_fallbacks(
|
||||
[runnable], exception_key="exception"
|
||||
)
|
||||
with pytest.raises(RuntimeError):
|
||||
await runnable_with_single.abatch(
|
||||
[{"text": "foo"}, {"text": "bar"}, {"text": "baz"}]
|
||||
)
|
||||
actual = await runnable_with_single.abatch(
|
||||
[{"text": "foo"}, {"text": "bar"}, {"text": "baz"}], return_exceptions=True
|
||||
)
|
||||
expected = ["first", "second", RuntimeError()]
|
||||
_assert_potential_error(actual, expected)
|
||||
|
||||
runnable_with_double = runnable.with_fallbacks(
|
||||
[runnable, runnable], exception_key="exception"
|
||||
)
|
||||
actual = await runnable_with_double.abatch(
|
||||
[{"text": "foo"}, {"text": "bar"}, {"text": "baz"}], return_exceptions=True
|
||||
)
|
||||
|
||||
expected = ["first", "second", "third"]
|
||||
_assert_potential_error(actual, expected)
|
||||
|
||||
runnable_with_double = runnable.with_fallbacks(
|
||||
[runnable, runnable],
|
||||
exception_key="exception",
|
||||
exceptions_to_handle=(ValueError,),
|
||||
)
|
||||
actual = await runnable_with_double.abatch(
|
||||
[{"text": "foo"}, {"text": "bar"}, {"text": "baz"}], return_exceptions=True
|
||||
)
|
||||
|
||||
expected = ["first", "second", RuntimeError()]
|
||||
_assert_potential_error(actual, expected)
|
Loading…
Reference in New Issue