Bagatur/runnable with fallbacks (#8543)

This commit is contained in:
Bagatur 2023-08-04 14:06:05 -07:00 committed by GitHub
parent 003e1ca9a0
commit f437311eef
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 293 additions and 2 deletions

View File

@ -730,7 +730,7 @@ class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin):
def on_chain_error(
self,
error: Union[Exception, KeyboardInterrupt],
error: BaseException,
**kwargs: Any,
) -> None:
"""Run when chain errors.
@ -812,7 +812,7 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin):
async def on_chain_error(
self,
error: Union[Exception, KeyboardInterrupt],
error: BaseException,
**kwargs: Any,
) -> None:
"""Run when chain errors.

View File

@ -14,6 +14,9 @@ from typing import (
List,
Mapping,
Optional,
Sequence,
Tuple,
Type,
TypedDict,
TypeVar,
Union,
@ -189,6 +192,247 @@ class Runnable(Generic[Input, Output], ABC):
)
return output
def with_fallbacks(
self,
fallbacks: Sequence[Runnable[Input, Output]],
*,
exceptions_to_handle: Tuple[Type[BaseException]] = (Exception,),
) -> RunnableWithFallbacks[Input, Output]:
return RunnableWithFallbacks(
runnable=self,
fallbacks=fallbacks,
exceptions_to_handle=exceptions_to_handle,
)
class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
runnable: Runnable[Input, Output]
fallbacks: Sequence[Runnable[Input, Output]]
exceptions_to_handle: Tuple[Type[BaseException]] = (Exception,)
class Config:
arbitrary_types_allowed = True
@property
def runnables(self) -> Iterator[Runnable[Input, Output]]:
yield self.runnable
yield from self.fallbacks
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
from langchain.callbacks.manager import CallbackManager
# setup callbacks
config = config or {}
callback_manager = CallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
local_callbacks=None,
verbose=False,
inheritable_tags=config.get("tags"),
local_tags=None,
inheritable_metadata=config.get("metadata"),
local_metadata=None,
)
# start the root run
run_manager = callback_manager.on_chain_start(
dumpd(self), input if isinstance(input, dict) else {"input": input}
)
first_error = None
for runnable in self.runnables:
try:
output = runnable.invoke(
input,
_patch_config(config, run_manager.get_child()),
)
except self.exceptions_to_handle as e:
if first_error is None:
first_error = e
except BaseException as e:
run_manager.on_chain_error(e)
raise e
else:
run_manager.on_chain_end(
output if isinstance(output, dict) else {"output": output}
)
return output
if first_error is None:
raise ValueError("No error stored at end of fallbacks.")
run_manager.on_chain_error(first_error)
raise first_error
async def ainvoke(
self, input: Input, config: Optional[RunnableConfig] = None
) -> Output:
from langchain.callbacks.manager import AsyncCallbackManager
# setup callbacks
config = config or {}
callback_manager = AsyncCallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
local_callbacks=None,
verbose=False,
inheritable_tags=config.get("tags"),
local_tags=None,
inheritable_metadata=config.get("metadata"),
local_metadata=None,
)
# start the root run
run_manager = await callback_manager.on_chain_start(
dumpd(self), input if isinstance(input, dict) else {"input": input}
)
first_error = None
for runnable in self.runnables:
try:
output = await runnable.ainvoke(
input,
_patch_config(config, run_manager.get_child()),
)
except self.exceptions_to_handle as e:
if first_error is None:
first_error = e
except BaseException as e:
await run_manager.on_chain_error(e)
raise e
else:
await run_manager.on_chain_end(
output if isinstance(output, dict) else {"output": output}
)
return output
if first_error is None:
raise ValueError("No error stored at end of fallbacks.")
await run_manager.on_chain_error(first_error)
raise first_error
def batch(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
max_concurrency: Optional[int] = None,
) -> List[Output]:
from langchain.callbacks.manager import CallbackManager
# setup callbacks
configs = self._get_config_list(config, len(inputs))
callback_managers = [
CallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
local_callbacks=None,
verbose=False,
inheritable_tags=config.get("tags"),
local_tags=None,
inheritable_metadata=config.get("metadata"),
local_metadata=None,
)
for config in configs
]
# start the root runs, one per input
run_managers = [
cm.on_chain_start(
dumpd(self), input if isinstance(input, dict) else {"input": input}
)
for cm, input in zip(callback_managers, inputs)
]
first_error = None
for runnable in self.runnables:
try:
outputs = runnable.batch(
inputs,
[
# each step a child run of the corresponding root run
_patch_config(config, rm.get_child())
for rm, config in zip(run_managers, configs)
],
max_concurrency=max_concurrency,
)
except self.exceptions_to_handle as e:
if first_error is None:
first_error = e
except BaseException as e:
for rm in run_managers:
rm.on_chain_error(e)
raise e
else:
for rm, output in zip(run_managers, outputs):
rm.on_chain_end(
output if isinstance(output, dict) else {"output": output}
)
return outputs
if first_error is None:
raise ValueError("No error stored at end of fallbacks.")
for rm in run_managers:
rm.on_chain_error(first_error)
raise first_error
async def abatch(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
max_concurrency: Optional[int] = None,
) -> List[Output]:
from langchain.callbacks.manager import (
AsyncCallbackManager,
AsyncCallbackManagerForChainRun,
)
# setup callbacks
configs = self._get_config_list(config, len(inputs))
callback_managers = [
AsyncCallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
local_callbacks=None,
verbose=False,
inheritable_tags=config.get("tags"),
local_tags=None,
inheritable_metadata=config.get("metadata"),
local_metadata=None,
)
for config in configs
]
# start the root runs, one per input
run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather(
*(
cm.on_chain_start(
dumpd(self), input if isinstance(input, dict) else {"input": input}
)
for cm, input in zip(callback_managers, inputs)
)
)
first_error = None
for runnable in self.runnables:
try:
outputs = await runnable.abatch(
inputs,
[
# each step a child run of the corresponding root run
_patch_config(config, rm.get_child())
for rm, config in zip(run_managers, configs)
],
max_concurrency=max_concurrency,
)
except self.exceptions_to_handle as e:
if first_error is None:
first_error = e
except BaseException as e:
await asyncio.gather(*(rm.on_chain_error(e) for rm in run_managers))
else:
await asyncio.gather(
*(
rm.on_chain_end(
output if isinstance(output, dict) else {"output": output}
)
for rm, output in zip(run_managers, outputs)
)
)
return outputs
if first_error is None:
raise ValueError("No error stored at end of fallbacks.")
await asyncio.gather(*(rm.on_chain_error(first_error) for rm in run_managers))
raise first_error
class RunnableSequence(Serializable, Runnable[Input, Output]):
first: Runnable[Input, Any]

View File

@ -6,6 +6,7 @@ from freezegun import freeze_time
from pytest_mock import MockerFixture
from syrupy import SnapshotAssertion
from langchain import PromptTemplate
from langchain.callbacks.manager import Callbacks
from langchain.callbacks.tracers.base import BaseTracer
from langchain.callbacks.tracers.schemas import Run
@ -30,6 +31,7 @@ from langchain.schema.runnable import (
RunnableMap,
RunnablePassthrough,
RunnableSequence,
RunnableWithFallbacks,
)
@ -754,3 +756,48 @@ def test_bind_bind() -> None:
stop=["Observation:"], hello="world"
)
) == dumpd(llm.bind(stop=["Observation:"], one="two", hello="world"))
@pytest.fixture()
def llm_with_fallbacks() -> RunnableWithFallbacks:
error_llm = FakeListLLM(responses=["foo"], i=1)
pass_llm = FakeListLLM(responses=["bar"])
return error_llm.with_fallbacks([pass_llm])
@pytest.fixture()
def llm_with_multi_fallbacks() -> RunnableWithFallbacks:
error_llm = FakeListLLM(responses=["foo"], i=1)
error_llm_2 = FakeListLLM(responses=["baz"], i=1)
pass_llm = FakeListLLM(responses=["bar"])
return error_llm.with_fallbacks([error_llm_2, pass_llm])
@pytest.fixture()
def llm_chain_with_fallbacks() -> RunnableSequence:
error_llm = FakeListLLM(responses=["foo"], i=1)
pass_llm = FakeListLLM(responses=["bar"])
prompt = PromptTemplate.from_template("what did baz say to {buz}")
return RunnableMap({"buz": lambda x: x}) | (prompt | error_llm).with_fallbacks(
[prompt | pass_llm]
)
@pytest.mark.parametrize(
"runnable",
["llm_with_fallbacks", "llm_with_multi_fallbacks", "llm_chain_with_fallbacks"],
)
@pytest.mark.asyncio
async def test_llm_with_fallbacks(
runnable: RunnableWithFallbacks, request: Any
) -> None:
runnable = request.getfixturevalue(runnable)
assert runnable.invoke("hello") == "bar"
assert runnable.batch(["hi", "hey", "bye"]) == ["bar"] * 3
assert list(runnable.stream("hello")) == ["bar"]
assert await runnable.ainvoke("hello") == "bar"
assert await runnable.abatch(["hi", "hey", "bye"]) == ["bar"] * 3
assert list(await runnable.ainvoke("hello")) == list("bar")