From f437311eef4a02fd3dd97758ab5455d6f2f093ba Mon Sep 17 00:00:00 2001 From: Bagatur <22008038+baskaryan@users.noreply.github.com> Date: Fri, 4 Aug 2023 14:06:05 -0700 Subject: [PATCH] Bagatur/runnable with fallbacks (#8543) --- libs/langchain/langchain/callbacks/manager.py | 4 +- libs/langchain/langchain/schema/runnable.py | 244 ++++++++++++++++++ .../tests/unit_tests/schema/test_runnable.py | 47 ++++ 3 files changed, 293 insertions(+), 2 deletions(-) diff --git a/libs/langchain/langchain/callbacks/manager.py b/libs/langchain/langchain/callbacks/manager.py index 1a0f6ac0a0..7016a13ed9 100644 --- a/libs/langchain/langchain/callbacks/manager.py +++ b/libs/langchain/langchain/callbacks/manager.py @@ -730,7 +730,7 @@ class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin): def on_chain_error( self, - error: Union[Exception, KeyboardInterrupt], + error: BaseException, **kwargs: Any, ) -> None: """Run when chain errors. @@ -812,7 +812,7 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin): async def on_chain_error( self, - error: Union[Exception, KeyboardInterrupt], + error: BaseException, **kwargs: Any, ) -> None: """Run when chain errors. 
diff --git a/libs/langchain/langchain/schema/runnable.py b/libs/langchain/langchain/schema/runnable.py index 2669409a3a..4cfd3f913f 100644 --- a/libs/langchain/langchain/schema/runnable.py +++ b/libs/langchain/langchain/schema/runnable.py @@ -14,6 +14,9 @@ from typing import ( List, Mapping, Optional, + Sequence, + Tuple, + Type, TypedDict, TypeVar, Union, @@ -189,6 +192,247 @@ class Runnable(Generic[Input, Output], ABC): ) return output + def with_fallbacks( + self, + fallbacks: Sequence[Runnable[Input, Output]], + *, + exceptions_to_handle: Tuple[Type[BaseException]] = (Exception,), + ) -> RunnableWithFallbacks[Input, Output]: + return RunnableWithFallbacks( + runnable=self, + fallbacks=fallbacks, + exceptions_to_handle=exceptions_to_handle, + ) + + +class RunnableWithFallbacks(Serializable, Runnable[Input, Output]): + runnable: Runnable[Input, Output] + fallbacks: Sequence[Runnable[Input, Output]] + exceptions_to_handle: Tuple[Type[BaseException]] = (Exception,) + + class Config: + arbitrary_types_allowed = True + + @property + def runnables(self) -> Iterator[Runnable[Input, Output]]: + yield self.runnable + yield from self.fallbacks + + def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output: + from langchain.callbacks.manager import CallbackManager + + # setup callbacks + config = config or {} + callback_manager = CallbackManager.configure( + inheritable_callbacks=config.get("callbacks"), + local_callbacks=None, + verbose=False, + inheritable_tags=config.get("tags"), + local_tags=None, + inheritable_metadata=config.get("metadata"), + local_metadata=None, + ) + # start the root run + run_manager = callback_manager.on_chain_start( + dumpd(self), input if isinstance(input, dict) else {"input": input} + ) + first_error = None + for runnable in self.runnables: + try: + output = runnable.invoke( + input, + _patch_config(config, run_manager.get_child()), + ) + except self.exceptions_to_handle as e: + if first_error is None: + first_error = e 
+ except BaseException as e: + run_manager.on_chain_error(e) + raise e + else: + run_manager.on_chain_end( + output if isinstance(output, dict) else {"output": output} + ) + return output + if first_error is None: + raise ValueError("No error stored at end of fallbacks.") + run_manager.on_chain_error(first_error) + raise first_error + + async def ainvoke( + self, input: Input, config: Optional[RunnableConfig] = None + ) -> Output: + from langchain.callbacks.manager import AsyncCallbackManager + + # setup callbacks + config = config or {} + callback_manager = AsyncCallbackManager.configure( + inheritable_callbacks=config.get("callbacks"), + local_callbacks=None, + verbose=False, + inheritable_tags=config.get("tags"), + local_tags=None, + inheritable_metadata=config.get("metadata"), + local_metadata=None, + ) + # start the root run + run_manager = await callback_manager.on_chain_start( + dumpd(self), input if isinstance(input, dict) else {"input": input} + ) + + first_error = None + for runnable in self.runnables: + try: + output = await runnable.ainvoke( + input, + _patch_config(config, run_manager.get_child()), + ) + except self.exceptions_to_handle as e: + if first_error is None: + first_error = e + except BaseException as e: + await run_manager.on_chain_error(e) + raise e + else: + await run_manager.on_chain_end( + output if isinstance(output, dict) else {"output": output} + ) + return output + if first_error is None: + raise ValueError("No error stored at end of fallbacks.") + await run_manager.on_chain_error(first_error) + raise first_error + + def batch( + self, + inputs: List[Input], + config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + *, + max_concurrency: Optional[int] = None, + ) -> List[Output]: + from langchain.callbacks.manager import CallbackManager + + # setup callbacks + configs = self._get_config_list(config, len(inputs)) + callback_managers = [ + CallbackManager.configure( + inheritable_callbacks=config.get("callbacks"), + 
local_callbacks=None, + verbose=False, + inheritable_tags=config.get("tags"), + local_tags=None, + inheritable_metadata=config.get("metadata"), + local_metadata=None, + ) + for config in configs + ] + # start the root runs, one per input + run_managers = [ + cm.on_chain_start( + dumpd(self), input if isinstance(input, dict) else {"input": input} + ) + for cm, input in zip(callback_managers, inputs) + ] + + first_error = None + for runnable in self.runnables: + try: + outputs = runnable.batch( + inputs, + [ + # each step a child run of the corresponding root run + _patch_config(config, rm.get_child()) + for rm, config in zip(run_managers, configs) + ], + max_concurrency=max_concurrency, + ) + except self.exceptions_to_handle as e: + if first_error is None: + first_error = e + except BaseException as e: + for rm in run_managers: + rm.on_chain_error(e) + raise e + else: + for rm, output in zip(run_managers, outputs): + rm.on_chain_end( + output if isinstance(output, dict) else {"output": output} + ) + return outputs + if first_error is None: + raise ValueError("No error stored at end of fallbacks.") + for rm in run_managers: + rm.on_chain_error(first_error) + raise first_error + + async def abatch( + self, + inputs: List[Input], + config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + *, + max_concurrency: Optional[int] = None, + ) -> List[Output]: + from langchain.callbacks.manager import ( + AsyncCallbackManager, + AsyncCallbackManagerForChainRun, + ) + + # setup callbacks + configs = self._get_config_list(config, len(inputs)) + callback_managers = [ + AsyncCallbackManager.configure( + inheritable_callbacks=config.get("callbacks"), + local_callbacks=None, + verbose=False, + inheritable_tags=config.get("tags"), + local_tags=None, + inheritable_metadata=config.get("metadata"), + local_metadata=None, + ) + for config in configs + ] + # start the root runs, one per input + run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather( + *( 
+ cm.on_chain_start( + dumpd(self), input if isinstance(input, dict) else {"input": input} + ) + for cm, input in zip(callback_managers, inputs) + ) + ) + + first_error = None + for runnable in self.runnables: + try: + outputs = await runnable.abatch( + inputs, + [ + # each step a child run of the corresponding root run + _patch_config(config, rm.get_child()) + for rm, config in zip(run_managers, configs) + ], + max_concurrency=max_concurrency, + ) + except self.exceptions_to_handle as e: + if first_error is None: + first_error = e + except BaseException as e: + await asyncio.gather(*(rm.on_chain_error(e) for rm in run_managers)) + raise e + else: + await asyncio.gather( + *( + rm.on_chain_end( + output if isinstance(output, dict) else {"output": output} + ) + for rm, output in zip(run_managers, outputs) + ) + ) + return outputs + if first_error is None: + raise ValueError("No error stored at end of fallbacks.") + await asyncio.gather(*(rm.on_chain_error(first_error) for rm in run_managers)) + raise first_error + class RunnableSequence(Serializable, Runnable[Input, Output]): first: Runnable[Input, Any] diff --git a/libs/langchain/tests/unit_tests/schema/test_runnable.py b/libs/langchain/tests/unit_tests/schema/test_runnable.py index 3b73658886..181cf50c3a 100644 --- a/libs/langchain/tests/unit_tests/schema/test_runnable.py +++ b/libs/langchain/tests/unit_tests/schema/test_runnable.py @@ -6,6 +6,7 @@ from freezegun import freeze_time from pytest_mock import MockerFixture from syrupy import SnapshotAssertion +from langchain import PromptTemplate from langchain.callbacks.manager import Callbacks from langchain.callbacks.tracers.base import BaseTracer from langchain.callbacks.tracers.schemas import Run @@ -30,6 +31,7 @@ from langchain.schema.runnable import ( RunnableMap, RunnablePassthrough, RunnableSequence, + RunnableWithFallbacks, ) @@ -754,3 +756,48 @@ def test_bind_bind() -> None: stop=["Observation:"], hello="world" ) ) == dumpd(llm.bind(stop=["Observation:"], one="two", 
hello="world")) + + +@pytest.fixture() +def llm_with_fallbacks() -> RunnableWithFallbacks: + error_llm = FakeListLLM(responses=["foo"], i=1) + pass_llm = FakeListLLM(responses=["bar"]) + + return error_llm.with_fallbacks([pass_llm]) + + +@pytest.fixture() +def llm_with_multi_fallbacks() -> RunnableWithFallbacks: + error_llm = FakeListLLM(responses=["foo"], i=1) + error_llm_2 = FakeListLLM(responses=["baz"], i=1) + pass_llm = FakeListLLM(responses=["bar"]) + + return error_llm.with_fallbacks([error_llm_2, pass_llm]) + + +@pytest.fixture() +def llm_chain_with_fallbacks() -> RunnableSequence: + error_llm = FakeListLLM(responses=["foo"], i=1) + pass_llm = FakeListLLM(responses=["bar"]) + + prompt = PromptTemplate.from_template("what did baz say to {buz}") + return RunnableMap({"buz": lambda x: x}) | (prompt | error_llm).with_fallbacks( + [prompt | pass_llm] + ) + + +@pytest.mark.parametrize( + "runnable", + ["llm_with_fallbacks", "llm_with_multi_fallbacks", "llm_chain_with_fallbacks"], +) +@pytest.mark.asyncio +async def test_llm_with_fallbacks( + runnable: RunnableWithFallbacks, request: Any +) -> None: + runnable = request.getfixturevalue(runnable) + assert runnable.invoke("hello") == "bar" + assert runnable.batch(["hi", "hey", "bye"]) == ["bar"] * 3 + assert list(runnable.stream("hello")) == ["bar"] + assert await runnable.ainvoke("hello") == "bar" + assert await runnable.abatch(["hi", "hey", "bye"]) == ["bar"] * 3 + assert list(await runnable.ainvoke("hello")) == list("bar")