mirror of
https://github.com/hwchase17/langchain
synced 2024-11-04 06:00:26 +00:00
Bagatur/runnable with fallbacks (#8543)
This commit is contained in:
parent
003e1ca9a0
commit
f437311eef
@ -730,7 +730,7 @@ class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin):
|
|||||||
|
|
||||||
def on_chain_error(
|
def on_chain_error(
|
||||||
self,
|
self,
|
||||||
error: Union[Exception, KeyboardInterrupt],
|
error: BaseException,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Run when chain errors.
|
"""Run when chain errors.
|
||||||
@ -812,7 +812,7 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin):
|
|||||||
|
|
||||||
async def on_chain_error(
|
async def on_chain_error(
|
||||||
self,
|
self,
|
||||||
error: Union[Exception, KeyboardInterrupt],
|
error: BaseException,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Run when chain errors.
|
"""Run when chain errors.
|
||||||
|
@ -14,6 +14,9 @@ from typing import (
|
|||||||
List,
|
List,
|
||||||
Mapping,
|
Mapping,
|
||||||
Optional,
|
Optional,
|
||||||
|
Sequence,
|
||||||
|
Tuple,
|
||||||
|
Type,
|
||||||
TypedDict,
|
TypedDict,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
Union,
|
Union,
|
||||||
@ -189,6 +192,247 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
)
|
)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
def with_fallbacks(
    self,
    fallbacks: Sequence[Runnable[Input, Output]],
    *,
    exceptions_to_handle: Tuple[Type[BaseException], ...] = (Exception,),
) -> RunnableWithFallbacks[Input, Output]:
    """Wrap this runnable so that ``fallbacks`` are tried when it fails.

    Args:
        fallbacks: Alternative runnables, passed through unchanged to
            ``RunnableWithFallbacks`` together with this runnable.
        exceptions_to_handle: Exception classes forwarded to
            ``RunnableWithFallbacks``. The annotation is variadic so that
            any number of classes may be listed; the bare
            ``Tuple[Type[BaseException]]`` form only describes a tuple
            holding exactly one class.

    Returns:
        A ``RunnableWithFallbacks`` whose ``runnable`` is ``self``.
    """
    return RunnableWithFallbacks(
        runnable=self,
        fallbacks=fallbacks,
        exceptions_to_handle=exceptions_to_handle,
    )
|
||||||
|
|
||||||
|
|
||||||
|
class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
    """Runnable that tries ``runnable`` first and then each fallback in order.

    The first candidate that completes without raising supplies the result.
    An exception matching ``exceptions_to_handle`` moves on to the next
    candidate; any other exception is reported to the callbacks and re-raised
    immediately. When every candidate fails with a handled exception, the
    first such exception is reported and re-raised.
    """

    # Primary runnable; always attempted first.
    runnable: Runnable[Input, Output]
    # Alternatives attempted, in order, after a handled failure.
    fallbacks: Sequence[Runnable[Input, Output]]
    # Exception classes that trigger a fallback. Variadic on purpose: the bare
    # ``Tuple[Type[BaseException]]`` form describes a tuple of exactly one class.
    exceptions_to_handle: Tuple[Type[BaseException], ...] = (Exception,)

    class Config:
        arbitrary_types_allowed = True

    @property
    def runnables(self) -> Iterator[Runnable[Input, Output]]:
        """Yield the primary runnable followed by every fallback, in order."""
        yield self.runnable
        yield from self.fallbacks

    def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
        """Run the candidates one after another until one succeeds.

        Args:
            input: Value passed unchanged to every candidate.
            config: Optional config; its ``callbacks``, ``tags`` and
                ``metadata`` entries configure the callback manager.

        Returns:
            The output of the first candidate that does not raise.

        Raises:
            BaseException: Immediately, for an exception that is not an
                instance of ``exceptions_to_handle``.
            Exception: The first handled exception, when all candidates fail.
        """
        from langchain.callbacks.manager import CallbackManager

        # setup callbacks
        config = config or {}
        callback_manager = CallbackManager.configure(
            inheritable_callbacks=config.get("callbacks"),
            local_callbacks=None,
            verbose=False,
            inheritable_tags=config.get("tags"),
            local_tags=None,
            inheritable_metadata=config.get("metadata"),
            local_metadata=None,
        )
        # start the root run
        run_manager = callback_manager.on_chain_start(
            dumpd(self), input if isinstance(input, dict) else {"input": input}
        )
        first_error = None
        for runnable in self.runnables:
            try:
                output = runnable.invoke(
                    input,
                    # each attempt is a child run of the root run
                    _patch_config(config, run_manager.get_child()),
                )
            except self.exceptions_to_handle as e:
                # Remember only the earliest failure; it is what gets re-raised.
                if first_error is None:
                    first_error = e
            except BaseException as e:
                # Not a handled exception type: report it and stop right away.
                run_manager.on_chain_error(e)
                raise
            else:
                run_manager.on_chain_end(
                    output if isinstance(output, dict) else {"output": output}
                )
                return output
        if first_error is None:
            raise ValueError("No error stored at end of fallbacks.")
        run_manager.on_chain_error(first_error)
        raise first_error

    async def ainvoke(
        self, input: Input, config: Optional[RunnableConfig] = None
    ) -> Output:
        """Async counterpart of ``invoke`` with the same fallback rules.

        Args:
            input: Value passed unchanged to every candidate.
            config: Optional config; its ``callbacks``, ``tags`` and
                ``metadata`` entries configure the callback manager.

        Returns:
            The output of the first candidate that does not raise.

        Raises:
            BaseException: Immediately, for an exception that is not an
                instance of ``exceptions_to_handle``.
            Exception: The first handled exception, when all candidates fail.
        """
        from langchain.callbacks.manager import AsyncCallbackManager

        # setup callbacks
        config = config or {}
        callback_manager = AsyncCallbackManager.configure(
            inheritable_callbacks=config.get("callbacks"),
            local_callbacks=None,
            verbose=False,
            inheritable_tags=config.get("tags"),
            local_tags=None,
            inheritable_metadata=config.get("metadata"),
            local_metadata=None,
        )
        # start the root run
        run_manager = await callback_manager.on_chain_start(
            dumpd(self), input if isinstance(input, dict) else {"input": input}
        )

        first_error = None
        for runnable in self.runnables:
            try:
                output = await runnable.ainvoke(
                    input,
                    # each attempt is a child run of the root run
                    _patch_config(config, run_manager.get_child()),
                )
            except self.exceptions_to_handle as e:
                if first_error is None:
                    first_error = e
            except BaseException as e:
                await run_manager.on_chain_error(e)
                raise
            else:
                await run_manager.on_chain_end(
                    output if isinstance(output, dict) else {"output": output}
                )
                return output
        if first_error is None:
            raise ValueError("No error stored at end of fallbacks.")
        await run_manager.on_chain_error(first_error)
        raise first_error

    def batch(
        self,
        inputs: List[Input],
        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
        *,
        max_concurrency: Optional[int] = None,
    ) -> List[Output]:
        """Run a whole batch through the candidates until one succeeds.

        The batch is retried as a unit: when a candidate's ``batch`` call
        raises a handled exception, every input is sent to the next candidate.

        Args:
            inputs: Values passed unchanged to every candidate.
            config: A single config or one config per input.
            max_concurrency: Forwarded to each candidate's ``batch``.

        Returns:
            The outputs of the first candidate whose ``batch`` does not raise.

        Raises:
            BaseException: Immediately, for an exception that is not an
                instance of ``exceptions_to_handle``.
            Exception: The first handled exception, when all candidates fail.
        """
        from langchain.callbacks.manager import CallbackManager

        # setup callbacks
        configs = self._get_config_list(config, len(inputs))
        callback_managers = [
            CallbackManager.configure(
                inheritable_callbacks=cfg.get("callbacks"),
                local_callbacks=None,
                verbose=False,
                inheritable_tags=cfg.get("tags"),
                local_tags=None,
                inheritable_metadata=cfg.get("metadata"),
                local_metadata=None,
            )
            for cfg in configs
        ]
        # start the root runs, one per input
        run_managers = [
            cm.on_chain_start(
                dumpd(self), input if isinstance(input, dict) else {"input": input}
            )
            for cm, input in zip(callback_managers, inputs)
        ]

        first_error = None
        for runnable in self.runnables:
            try:
                outputs = runnable.batch(
                    inputs,
                    [
                        # each step a child run of the corresponding root run
                        _patch_config(cfg, rm.get_child())
                        for rm, cfg in zip(run_managers, configs)
                    ],
                    max_concurrency=max_concurrency,
                )
            except self.exceptions_to_handle as e:
                if first_error is None:
                    first_error = e
            except BaseException as e:
                for rm in run_managers:
                    rm.on_chain_error(e)
                raise
            else:
                for rm, output in zip(run_managers, outputs):
                    rm.on_chain_end(
                        output if isinstance(output, dict) else {"output": output}
                    )
                return outputs
        if first_error is None:
            raise ValueError("No error stored at end of fallbacks.")
        for rm in run_managers:
            rm.on_chain_error(first_error)
        raise first_error

    async def abatch(
        self,
        inputs: List[Input],
        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
        *,
        max_concurrency: Optional[int] = None,
    ) -> List[Output]:
        """Async counterpart of ``batch`` with the same fallback rules.

        Args:
            inputs: Values passed unchanged to every candidate.
            config: A single config or one config per input.
            max_concurrency: Forwarded to each candidate's ``abatch``.

        Returns:
            The outputs of the first candidate whose ``abatch`` does not raise.

        Raises:
            BaseException: Immediately, for an exception that is not an
                instance of ``exceptions_to_handle``.
            Exception: The first handled exception, when all candidates fail.
        """
        from langchain.callbacks.manager import (
            AsyncCallbackManager,
            AsyncCallbackManagerForChainRun,
        )

        # setup callbacks
        configs = self._get_config_list(config, len(inputs))
        callback_managers = [
            AsyncCallbackManager.configure(
                inheritable_callbacks=cfg.get("callbacks"),
                local_callbacks=None,
                verbose=False,
                inheritable_tags=cfg.get("tags"),
                local_tags=None,
                inheritable_metadata=cfg.get("metadata"),
                local_metadata=None,
            )
            for cfg in configs
        ]
        # start the root runs, one per input
        run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather(
            *(
                cm.on_chain_start(
                    dumpd(self), input if isinstance(input, dict) else {"input": input}
                )
                for cm, input in zip(callback_managers, inputs)
            )
        )

        first_error = None
        for runnable in self.runnables:
            try:
                outputs = await runnable.abatch(
                    inputs,
                    [
                        # each step a child run of the corresponding root run
                        _patch_config(cfg, rm.get_child())
                        for rm, cfg in zip(run_managers, configs)
                    ],
                    max_concurrency=max_concurrency,
                )
            except self.exceptions_to_handle as e:
                if first_error is None:
                    first_error = e
            except BaseException as e:
                await asyncio.gather(*(rm.on_chain_error(e) for rm in run_managers))
                # Re-raise like invoke/ainvoke/batch do. Without this the
                # exception is swallowed and the loop moves on to the next
                # fallback after the runs were already marked as failed.
                raise
            else:
                await asyncio.gather(
                    *(
                        rm.on_chain_end(
                            output if isinstance(output, dict) else {"output": output}
                        )
                        for rm, output in zip(run_managers, outputs)
                    )
                )
                return outputs
        if first_error is None:
            raise ValueError("No error stored at end of fallbacks.")
        await asyncio.gather(*(rm.on_chain_error(first_error) for rm in run_managers))
        raise first_error
|
||||||
|
|
||||||
|
|
||||||
class RunnableSequence(Serializable, Runnable[Input, Output]):
|
class RunnableSequence(Serializable, Runnable[Input, Output]):
|
||||||
first: Runnable[Input, Any]
|
first: Runnable[Input, Any]
|
||||||
|
@ -6,6 +6,7 @@ from freezegun import freeze_time
|
|||||||
from pytest_mock import MockerFixture
|
from pytest_mock import MockerFixture
|
||||||
from syrupy import SnapshotAssertion
|
from syrupy import SnapshotAssertion
|
||||||
|
|
||||||
|
from langchain import PromptTemplate
|
||||||
from langchain.callbacks.manager import Callbacks
|
from langchain.callbacks.manager import Callbacks
|
||||||
from langchain.callbacks.tracers.base import BaseTracer
|
from langchain.callbacks.tracers.base import BaseTracer
|
||||||
from langchain.callbacks.tracers.schemas import Run
|
from langchain.callbacks.tracers.schemas import Run
|
||||||
@ -30,6 +31,7 @@ from langchain.schema.runnable import (
|
|||||||
RunnableMap,
|
RunnableMap,
|
||||||
RunnablePassthrough,
|
RunnablePassthrough,
|
||||||
RunnableSequence,
|
RunnableSequence,
|
||||||
|
RunnableWithFallbacks,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -754,3 +756,48 @@ def test_bind_bind() -> None:
|
|||||||
stop=["Observation:"], hello="world"
|
stop=["Observation:"], hello="world"
|
||||||
)
|
)
|
||||||
) == dumpd(llm.bind(stop=["Observation:"], one="two", hello="world"))
|
) == dumpd(llm.bind(stop=["Observation:"], one="two", hello="world"))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
def llm_with_fallbacks() -> RunnableWithFallbacks:
    """Provide a fake LLM wrapped with a single fallback LLM.

    NOTE(review): the primary is built with ``i=1`` and only one response,
    presumably so that it raises when called -- confirm against FakeListLLM.
    """
    primary = FakeListLLM(responses=["foo"], i=1)
    backup = FakeListLLM(responses=["bar"])
    return primary.with_fallbacks([backup])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
def llm_with_multi_fallbacks() -> RunnableWithFallbacks:
    """Provide a fake LLM wrapped with two fallbacks, listed in order.

    NOTE(review): the primary and the first fallback are built with ``i=1``
    and only one response, presumably so that both raise and only the last
    fallback answers -- confirm against FakeListLLM.
    """
    primary = FakeListLLM(responses=["foo"], i=1)
    first_backup = FakeListLLM(responses=["baz"], i=1)
    last_backup = FakeListLLM(responses=["bar"])
    candidates = [first_backup, last_backup]
    return primary.with_fallbacks(candidates)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
def llm_chain_with_fallbacks() -> RunnableSequence:
    """Provide a sequence whose prompt-plus-LLM step carries a fallback.

    A ``RunnableMap`` feeds the raw input in as ``buz``; the next step is
    ``prompt | error_llm`` with ``prompt | pass_llm`` as its only fallback.

    NOTE(review): the first LLM is built with ``i=1`` and only one response,
    presumably so that it raises when called -- confirm against FakeListLLM.
    """
    failing_llm = FakeListLLM(responses=["foo"], i=1)
    working_llm = FakeListLLM(responses=["bar"])

    template = PromptTemplate.from_template("what did baz say to {buz}")
    guarded_step = (template | failing_llm).with_fallbacks([template | working_llm])
    return RunnableMap({"buz": lambda x: x}) | guarded_step
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "runnable",
    ["llm_with_fallbacks", "llm_with_multi_fallbacks", "llm_chain_with_fallbacks"],
)
@pytest.mark.asyncio
async def test_llm_with_fallbacks(runnable: str, request: Any) -> None:
    """Check that each fallback fixture answers "bar" through every entry point.

    Args:
        runnable: Name of the fixture to exercise; ``parametrize`` supplies
            the name as a string, not the fixture value.
        request: Pytest request object used to resolve the fixture by name.
    """
    # Resolve the fixture name into the runnable under test, keeping the
    # string parameter and the resolved object in separate variables.
    chain = request.getfixturevalue(runnable)
    assert chain.invoke("hello") == "bar"
    assert chain.batch(["hi", "hey", "bye"]) == ["bar"] * 3
    assert list(chain.stream("hello")) == ["bar"]
    assert await chain.ainvoke("hello") == "bar"
    assert await chain.abatch(["hi", "hey", "bye"]) == ["bar"] * 3
    # NOTE(review): this repeats ainvoke and compares character lists; it may
    # have been meant to exercise an async streaming path -- confirm.
    assert list(await chain.ainvoke("hello")) == list("bar")
|
||||||
|
Loading…
Reference in New Issue
Block a user