Mirror of https://github.com/hwchase17/langchain (synced 2024-11-06 03:20:49 +00:00)

Bagatur/runnable with fallbacks (#8543)

parent 003e1ca9a0
commit f437311eef
@@ -730,7 +730,7 @@ class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin):
 
     def on_chain_error(
         self,
-        error: Union[Exception, KeyboardInterrupt],
+        error: BaseException,
         **kwargs: Any,
     ) -> None:
         """Run when chain errors.
@@ -812,7 +812,7 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin):
 
     async def on_chain_error(
         self,
-        error: Union[Exception, KeyboardInterrupt],
+        error: BaseException,
         **kwargs: Any,
     ) -> None:
         """Run when chain errors.
@@ -14,6 +14,9 @@ from typing import (
     List,
     Mapping,
     Optional,
+    Sequence,
+    Tuple,
+    Type,
     TypedDict,
     TypeVar,
     Union,
@@ -189,6 +192,247 @@ class Runnable(Generic[Input, Output], ABC):
             )
             return output
 
+    def with_fallbacks(
+        self,
+        fallbacks: Sequence[Runnable[Input, Output]],
+        *,
+        exceptions_to_handle: Tuple[Type[BaseException]] = (Exception,),
+    ) -> RunnableWithFallbacks[Input, Output]:
+        return RunnableWithFallbacks(
+            runnable=self,
+            fallbacks=fallbacks,
+            exceptions_to_handle=exceptions_to_handle,
+        )
+
+
+class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
+    runnable: Runnable[Input, Output]
+    fallbacks: Sequence[Runnable[Input, Output]]
+    exceptions_to_handle: Tuple[Type[BaseException]] = (Exception,)
+
+    class Config:
+        arbitrary_types_allowed = True
+
+    @property
+    def runnables(self) -> Iterator[Runnable[Input, Output]]:
+        yield self.runnable
+        yield from self.fallbacks
+
+    def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
+        from langchain.callbacks.manager import CallbackManager
+
+        # setup callbacks
+        config = config or {}
+        callback_manager = CallbackManager.configure(
+            inheritable_callbacks=config.get("callbacks"),
+            local_callbacks=None,
+            verbose=False,
+            inheritable_tags=config.get("tags"),
+            local_tags=None,
+            inheritable_metadata=config.get("metadata"),
+            local_metadata=None,
+        )
+        # start the root run
+        run_manager = callback_manager.on_chain_start(
+            dumpd(self), input if isinstance(input, dict) else {"input": input}
+        )
+        first_error = None
+        for runnable in self.runnables:
+            try:
+                output = runnable.invoke(
+                    input,
+                    _patch_config(config, run_manager.get_child()),
+                )
+            except self.exceptions_to_handle as e:
+                if first_error is None:
+                    first_error = e
+            except BaseException as e:
+                run_manager.on_chain_error(e)
+                raise e
+            else:
+                run_manager.on_chain_end(
+                    output if isinstance(output, dict) else {"output": output}
+                )
+                return output
+        if first_error is None:
+            raise ValueError("No error stored at end of fallbacks.")
+        run_manager.on_chain_error(first_error)
+        raise first_error
+
+    async def ainvoke(
+        self, input: Input, config: Optional[RunnableConfig] = None
+    ) -> Output:
+        from langchain.callbacks.manager import AsyncCallbackManager
+
+        # setup callbacks
+        config = config or {}
+        callback_manager = AsyncCallbackManager.configure(
+            inheritable_callbacks=config.get("callbacks"),
+            local_callbacks=None,
+            verbose=False,
+            inheritable_tags=config.get("tags"),
+            local_tags=None,
+            inheritable_metadata=config.get("metadata"),
+            local_metadata=None,
+        )
+        # start the root run
+        run_manager = await callback_manager.on_chain_start(
+            dumpd(self), input if isinstance(input, dict) else {"input": input}
+        )
+
+        first_error = None
+        for runnable in self.runnables:
+            try:
+                output = await runnable.ainvoke(
+                    input,
+                    _patch_config(config, run_manager.get_child()),
+                )
+            except self.exceptions_to_handle as e:
+                if first_error is None:
+                    first_error = e
+            except BaseException as e:
+                await run_manager.on_chain_error(e)
+                raise e
+            else:
+                await run_manager.on_chain_end(
+                    output if isinstance(output, dict) else {"output": output}
+                )
+                return output
+        if first_error is None:
+            raise ValueError("No error stored at end of fallbacks.")
+        await run_manager.on_chain_error(first_error)
+        raise first_error
+
+    def batch(
+        self,
+        inputs: List[Input],
+        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
+        *,
+        max_concurrency: Optional[int] = None,
+    ) -> List[Output]:
+        from langchain.callbacks.manager import CallbackManager
+
+        # setup callbacks
+        configs = self._get_config_list(config, len(inputs))
+        callback_managers = [
+            CallbackManager.configure(
+                inheritable_callbacks=config.get("callbacks"),
+                local_callbacks=None,
+                verbose=False,
+                inheritable_tags=config.get("tags"),
+                local_tags=None,
+                inheritable_metadata=config.get("metadata"),
+                local_metadata=None,
+            )
+            for config in configs
+        ]
+        # start the root runs, one per input
+        run_managers = [
+            cm.on_chain_start(
+                dumpd(self), input if isinstance(input, dict) else {"input": input}
+            )
+            for cm, input in zip(callback_managers, inputs)
+        ]
+
+        first_error = None
+        for runnable in self.runnables:
+            try:
+                outputs = runnable.batch(
+                    inputs,
+                    [
+                        # each step a child run of the corresponding root run
+                        _patch_config(config, rm.get_child())
+                        for rm, config in zip(run_managers, configs)
+                    ],
+                    max_concurrency=max_concurrency,
+                )
+            except self.exceptions_to_handle as e:
+                if first_error is None:
+                    first_error = e
+            except BaseException as e:
+                for rm in run_managers:
+                    rm.on_chain_error(e)
+                raise e
+            else:
+                for rm, output in zip(run_managers, outputs):
+                    rm.on_chain_end(
+                        output if isinstance(output, dict) else {"output": output}
+                    )
+                return outputs
+        if first_error is None:
+            raise ValueError("No error stored at end of fallbacks.")
+        for rm in run_managers:
+            rm.on_chain_error(first_error)
+        raise first_error
+
+    async def abatch(
+        self,
+        inputs: List[Input],
+        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
+        *,
+        max_concurrency: Optional[int] = None,
+    ) -> List[Output]:
+        from langchain.callbacks.manager import (
+            AsyncCallbackManager,
+            AsyncCallbackManagerForChainRun,
+        )
+
+        # setup callbacks
+        configs = self._get_config_list(config, len(inputs))
+        callback_managers = [
+            AsyncCallbackManager.configure(
+                inheritable_callbacks=config.get("callbacks"),
+                local_callbacks=None,
+                verbose=False,
+                inheritable_tags=config.get("tags"),
+                local_tags=None,
+                inheritable_metadata=config.get("metadata"),
+                local_metadata=None,
+            )
+            for config in configs
+        ]
+        # start the root runs, one per input
+        run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather(
+            *(
+                cm.on_chain_start(
+                    dumpd(self), input if isinstance(input, dict) else {"input": input}
+                )
+                for cm, input in zip(callback_managers, inputs)
+            )
+        )
+
+        first_error = None
+        for runnable in self.runnables:
+            try:
+                outputs = await runnable.abatch(
+                    inputs,
+                    [
+                        # each step a child run of the corresponding root run
+                        _patch_config(config, rm.get_child())
+                        for rm, config in zip(run_managers, configs)
+                    ],
+                    max_concurrency=max_concurrency,
+                )
+            except self.exceptions_to_handle as e:
+                if first_error is None:
+                    first_error = e
+            except BaseException as e:
+                await asyncio.gather(*(rm.on_chain_error(e) for rm in run_managers))
+            else:
+                await asyncio.gather(
+                    *(
+                        rm.on_chain_end(
+                            output if isinstance(output, dict) else {"output": output}
+                        )
+                        for rm, output in zip(run_managers, outputs)
+                    )
+                )
+                return outputs
+        if first_error is None:
+            raise ValueError("No error stored at end of fallbacks.")
+        await asyncio.gather(*(rm.on_chain_error(first_error) for rm in run_managers))
+        raise first_error
+
 
 class RunnableSequence(Serializable, Runnable[Input, Output]):
     first: Runnable[Input, Any]
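For orientation while reading the hunk above: RunnableWithFallbacks.invoke tries each runnable in order, remembers the first exception of a handled type, and re-raises it only after every candidate has failed; exception types outside exceptions_to_handle are reported to the run manager and re-raised immediately. A minimal standalone sketch of that control flow follows (plain Python, no langchain imports; invoke_with_fallbacks, flaky, and stable are illustrative names, not part of the commit):

# Standalone sketch of the fallback pattern (illustrative names, no langchain).
from typing import Callable, Sequence, Tuple, Type


def invoke_with_fallbacks(
    candidates: Sequence[Callable[[str], str]],
    text: str,
    exceptions_to_handle: Tuple[Type[BaseException], ...] = (Exception,),
) -> str:
    first_error = None
    for candidate in candidates:
        try:
            return candidate(text)
        except exceptions_to_handle as e:
            # Only handled exception types trigger the next fallback; the
            # earliest one is kept so callers see the original failure.
            if first_error is None:
                first_error = e
    if first_error is None:
        raise ValueError("No error stored at end of fallbacks.")
    raise first_error


def flaky(_: str) -> str:
    raise RuntimeError("primary runnable unavailable")


def stable(text: str) -> str:
    return f"echo: {text}"


print(invoke_with_fallbacks([flaky, stable], "hello"))  # prints "echo: hello"

Keeping the first error rather than the last means that when every fallback also fails, callers see the primary runnable's failure.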
@@ -6,6 +6,7 @@ from freezegun import freeze_time
 from pytest_mock import MockerFixture
 from syrupy import SnapshotAssertion
 
+from langchain import PromptTemplate
 from langchain.callbacks.manager import Callbacks
 from langchain.callbacks.tracers.base import BaseTracer
 from langchain.callbacks.tracers.schemas import Run
@@ -30,6 +31,7 @@ from langchain.schema.runnable import (
     RunnableMap,
     RunnablePassthrough,
     RunnableSequence,
+    RunnableWithFallbacks,
 )
 
 
@@ -754,3 +756,48 @@ def test_bind_bind() -> None:
             stop=["Observation:"], hello="world"
         )
     ) == dumpd(llm.bind(stop=["Observation:"], one="two", hello="world"))
+
+
+@pytest.fixture()
+def llm_with_fallbacks() -> RunnableWithFallbacks:
+    error_llm = FakeListLLM(responses=["foo"], i=1)
+    pass_llm = FakeListLLM(responses=["bar"])
+
+    return error_llm.with_fallbacks([pass_llm])
+
+
+@pytest.fixture()
+def llm_with_multi_fallbacks() -> RunnableWithFallbacks:
+    error_llm = FakeListLLM(responses=["foo"], i=1)
+    error_llm_2 = FakeListLLM(responses=["baz"], i=1)
+    pass_llm = FakeListLLM(responses=["bar"])
+
+    return error_llm.with_fallbacks([error_llm_2, pass_llm])
+
+
+@pytest.fixture()
+def llm_chain_with_fallbacks() -> RunnableSequence:
+    error_llm = FakeListLLM(responses=["foo"], i=1)
+    pass_llm = FakeListLLM(responses=["bar"])
+
+    prompt = PromptTemplate.from_template("what did baz say to {buz}")
+    return RunnableMap({"buz": lambda x: x}) | (prompt | error_llm).with_fallbacks(
+        [prompt | pass_llm]
+    )
+
+
+@pytest.mark.parametrize(
+    "runnable",
+    ["llm_with_fallbacks", "llm_with_multi_fallbacks", "llm_chain_with_fallbacks"],
+)
+@pytest.mark.asyncio
+async def test_llm_with_fallbacks(
+    runnable: RunnableWithFallbacks, request: Any
+) -> None:
+    runnable = request.getfixturevalue(runnable)
+    assert runnable.invoke("hello") == "bar"
+    assert runnable.batch(["hi", "hey", "bye"]) == ["bar"] * 3
+    assert list(runnable.stream("hello")) == ["bar"]
+    assert await runnable.ainvoke("hello") == "bar"
+    assert await runnable.abatch(["hi", "hey", "bye"]) == ["bar"] * 3
+    assert list(await runnable.ainvoke("hello")) == list("bar")
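A closing note on the batch and abatch methods added earlier in this diff, which the parametrized test above exercises with three inputs: fallbacks apply to the whole batch call rather than per input, so if the primary runnable raises for any input, the entire batch is retried against the next fallback. A small standalone sketch of that behaviour (illustrative names only, no langchain imports):

# Standalone sketch: batch-level fallback, mirroring the shape of
# RunnableWithFallbacks.batch above.
from typing import Callable, List, Sequence


def batch_with_fallbacks(
    candidates: Sequence[Callable[[List[str]], List[str]]],
    inputs: List[str],
) -> List[str]:
    first_error = None
    for candidate in candidates:
        try:
            # The whole batch goes to one candidate; a single bad input
            # fails the call and sends every input to the next fallback.
            return candidate(inputs)
        except Exception as e:
            if first_error is None:
                first_error = e
    if first_error is None:
        raise ValueError("No error stored at end of fallbacks.")
    raise first_error


def strict_upper(batch: List[str]) -> List[str]:
    if any(not s for s in batch):
        raise ValueError("empty string in batch")
    return [s.upper() for s in batch]


def lenient_upper(batch: List[str]) -> List[str]:
    return [s.upper() for s in batch]


print(batch_with_fallbacks([strict_upper, lenient_upper], ["hi", ""]))  # ['HI', '']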