Bagatur/runnable with fallbacks (#8543)

This commit is contained in:
Bagatur 2023-08-04 14:06:05 -07:00 committed by GitHub
parent 003e1ca9a0
commit f437311eef
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 293 additions and 2 deletions

View File

@ -730,7 +730,7 @@ class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin):
def on_chain_error( def on_chain_error(
self, self,
error: Union[Exception, KeyboardInterrupt], error: BaseException,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
"""Run when chain errors. """Run when chain errors.
@ -812,7 +812,7 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin):
async def on_chain_error( async def on_chain_error(
self, self,
error: Union[Exception, KeyboardInterrupt], error: BaseException,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
"""Run when chain errors. """Run when chain errors.

View File

@ -14,6 +14,9 @@ from typing import (
List, List,
Mapping, Mapping,
Optional, Optional,
Sequence,
Tuple,
Type,
TypedDict, TypedDict,
TypeVar, TypeVar,
Union, Union,
@ -189,6 +192,247 @@ class Runnable(Generic[Input, Output], ABC):
) )
return output return output
def with_fallbacks(
self,
fallbacks: Sequence[Runnable[Input, Output]],
*,
exceptions_to_handle: Tuple[Type[BaseException]] = (Exception,),
) -> RunnableWithFallbacks[Input, Output]:
return RunnableWithFallbacks(
runnable=self,
fallbacks=fallbacks,
exceptions_to_handle=exceptions_to_handle,
)
class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
runnable: Runnable[Input, Output]
fallbacks: Sequence[Runnable[Input, Output]]
exceptions_to_handle: Tuple[Type[BaseException]] = (Exception,)
class Config:
arbitrary_types_allowed = True
@property
def runnables(self) -> Iterator[Runnable[Input, Output]]:
yield self.runnable
yield from self.fallbacks
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
from langchain.callbacks.manager import CallbackManager
# setup callbacks
config = config or {}
callback_manager = CallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
local_callbacks=None,
verbose=False,
inheritable_tags=config.get("tags"),
local_tags=None,
inheritable_metadata=config.get("metadata"),
local_metadata=None,
)
# start the root run
run_manager = callback_manager.on_chain_start(
dumpd(self), input if isinstance(input, dict) else {"input": input}
)
first_error = None
for runnable in self.runnables:
try:
output = runnable.invoke(
input,
_patch_config(config, run_manager.get_child()),
)
except self.exceptions_to_handle as e:
if first_error is None:
first_error = e
except BaseException as e:
run_manager.on_chain_error(e)
raise e
else:
run_manager.on_chain_end(
output if isinstance(output, dict) else {"output": output}
)
return output
if first_error is None:
raise ValueError("No error stored at end of fallbacks.")
run_manager.on_chain_error(first_error)
raise first_error
async def ainvoke(
self, input: Input, config: Optional[RunnableConfig] = None
) -> Output:
from langchain.callbacks.manager import AsyncCallbackManager
# setup callbacks
config = config or {}
callback_manager = AsyncCallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
local_callbacks=None,
verbose=False,
inheritable_tags=config.get("tags"),
local_tags=None,
inheritable_metadata=config.get("metadata"),
local_metadata=None,
)
# start the root run
run_manager = await callback_manager.on_chain_start(
dumpd(self), input if isinstance(input, dict) else {"input": input}
)
first_error = None
for runnable in self.runnables:
try:
output = await runnable.ainvoke(
input,
_patch_config(config, run_manager.get_child()),
)
except self.exceptions_to_handle as e:
if first_error is None:
first_error = e
except BaseException as e:
await run_manager.on_chain_error(e)
raise e
else:
await run_manager.on_chain_end(
output if isinstance(output, dict) else {"output": output}
)
return output
if first_error is None:
raise ValueError("No error stored at end of fallbacks.")
await run_manager.on_chain_error(first_error)
raise first_error
def batch(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
max_concurrency: Optional[int] = None,
) -> List[Output]:
from langchain.callbacks.manager import CallbackManager
# setup callbacks
configs = self._get_config_list(config, len(inputs))
callback_managers = [
CallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
local_callbacks=None,
verbose=False,
inheritable_tags=config.get("tags"),
local_tags=None,
inheritable_metadata=config.get("metadata"),
local_metadata=None,
)
for config in configs
]
# start the root runs, one per input
run_managers = [
cm.on_chain_start(
dumpd(self), input if isinstance(input, dict) else {"input": input}
)
for cm, input in zip(callback_managers, inputs)
]
first_error = None
for runnable in self.runnables:
try:
outputs = runnable.batch(
inputs,
[
# each step a child run of the corresponding root run
_patch_config(config, rm.get_child())
for rm, config in zip(run_managers, configs)
],
max_concurrency=max_concurrency,
)
except self.exceptions_to_handle as e:
if first_error is None:
first_error = e
except BaseException as e:
for rm in run_managers:
rm.on_chain_error(e)
raise e
else:
for rm, output in zip(run_managers, outputs):
rm.on_chain_end(
output if isinstance(output, dict) else {"output": output}
)
return outputs
if first_error is None:
raise ValueError("No error stored at end of fallbacks.")
for rm in run_managers:
rm.on_chain_error(first_error)
raise first_error
async def abatch(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
max_concurrency: Optional[int] = None,
) -> List[Output]:
from langchain.callbacks.manager import (
AsyncCallbackManager,
AsyncCallbackManagerForChainRun,
)
# setup callbacks
configs = self._get_config_list(config, len(inputs))
callback_managers = [
AsyncCallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
local_callbacks=None,
verbose=False,
inheritable_tags=config.get("tags"),
local_tags=None,
inheritable_metadata=config.get("metadata"),
local_metadata=None,
)
for config in configs
]
# start the root runs, one per input
run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather(
*(
cm.on_chain_start(
dumpd(self), input if isinstance(input, dict) else {"input": input}
)
for cm, input in zip(callback_managers, inputs)
)
)
first_error = None
for runnable in self.runnables:
try:
outputs = await runnable.abatch(
inputs,
[
# each step a child run of the corresponding root run
_patch_config(config, rm.get_child())
for rm, config in zip(run_managers, configs)
],
max_concurrency=max_concurrency,
)
except self.exceptions_to_handle as e:
if first_error is None:
first_error = e
except BaseException as e:
await asyncio.gather(*(rm.on_chain_error(e) for rm in run_managers))
else:
await asyncio.gather(
*(
rm.on_chain_end(
output if isinstance(output, dict) else {"output": output}
)
for rm, output in zip(run_managers, outputs)
)
)
return outputs
if first_error is None:
raise ValueError("No error stored at end of fallbacks.")
await asyncio.gather(*(rm.on_chain_error(first_error) for rm in run_managers))
raise first_error
class RunnableSequence(Serializable, Runnable[Input, Output]): class RunnableSequence(Serializable, Runnable[Input, Output]):
first: Runnable[Input, Any] first: Runnable[Input, Any]

View File

@ -6,6 +6,7 @@ from freezegun import freeze_time
from pytest_mock import MockerFixture from pytest_mock import MockerFixture
from syrupy import SnapshotAssertion from syrupy import SnapshotAssertion
from langchain import PromptTemplate
from langchain.callbacks.manager import Callbacks from langchain.callbacks.manager import Callbacks
from langchain.callbacks.tracers.base import BaseTracer from langchain.callbacks.tracers.base import BaseTracer
from langchain.callbacks.tracers.schemas import Run from langchain.callbacks.tracers.schemas import Run
@ -30,6 +31,7 @@ from langchain.schema.runnable import (
RunnableMap, RunnableMap,
RunnablePassthrough, RunnablePassthrough,
RunnableSequence, RunnableSequence,
RunnableWithFallbacks,
) )
@ -754,3 +756,48 @@ def test_bind_bind() -> None:
stop=["Observation:"], hello="world" stop=["Observation:"], hello="world"
) )
) == dumpd(llm.bind(stop=["Observation:"], one="two", hello="world")) ) == dumpd(llm.bind(stop=["Observation:"], one="two", hello="world"))
@pytest.fixture()
def llm_with_fallbacks() -> RunnableWithFallbacks:
error_llm = FakeListLLM(responses=["foo"], i=1)
pass_llm = FakeListLLM(responses=["bar"])
return error_llm.with_fallbacks([pass_llm])
@pytest.fixture()
def llm_with_multi_fallbacks() -> RunnableWithFallbacks:
error_llm = FakeListLLM(responses=["foo"], i=1)
error_llm_2 = FakeListLLM(responses=["baz"], i=1)
pass_llm = FakeListLLM(responses=["bar"])
return error_llm.with_fallbacks([error_llm_2, pass_llm])
@pytest.fixture()
def llm_chain_with_fallbacks() -> RunnableSequence:
error_llm = FakeListLLM(responses=["foo"], i=1)
pass_llm = FakeListLLM(responses=["bar"])
prompt = PromptTemplate.from_template("what did baz say to {buz}")
return RunnableMap({"buz": lambda x: x}) | (prompt | error_llm).with_fallbacks(
[prompt | pass_llm]
)
@pytest.mark.parametrize(
"runnable",
["llm_with_fallbacks", "llm_with_multi_fallbacks", "llm_chain_with_fallbacks"],
)
@pytest.mark.asyncio
async def test_llm_with_fallbacks(
runnable: RunnableWithFallbacks, request: Any
) -> None:
runnable = request.getfixturevalue(runnable)
assert runnable.invoke("hello") == "bar"
assert runnable.batch(["hi", "hey", "bye"]) == ["bar"] * 3
assert list(runnable.stream("hello")) == ["bar"]
assert await runnable.ainvoke("hello") == "bar"
assert await runnable.abatch(["hi", "hey", "bye"]) == ["bar"] * 3
assert list(await runnable.ainvoke("hello")) == list("bar")