mirror of https://github.com/hwchase17/langchain
Create new RunnableSerializable base class in preparation for configurable runnables (#11279)
- Also move RunnableBranch to its own file <!-- Thank you for contributing to LangChain! Replace this entire comment with: - **Description:** a description of the change, - **Issue:** the issue # it fixes (if applicable), - **Dependencies:** any dependencies required for this change, - **Tag maintainer:** for a quicker response, tag the relevant maintainer (see below), - **Twitter handle:** we announce bigger features on Twitter. If your PR gets announced, and you'd like a mention, we'll gladly shout you out! Please make sure your PR is passing linting and testing before submitting. Run `make format`, `make lint` and `make test` to check this locally. See contribution guidelines for more information on how to write/run tests, lint, etc: https://github.com/langchain-ai/langchain/blob/master/.github/CONTRIBUTING.md If you're adding a new integration, please include: 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. It lives in `docs/extras` directory. If no one reviews your PR within a few days, please @-mention one of @baskaryan, @eyurtsev, @hwchase17. -->pull/11294/head^2
commit
0638f7b83a
@ -0,0 +1,235 @@
|
|||||||
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Awaitable,
|
||||||
|
Callable,
|
||||||
|
List,
|
||||||
|
Mapping,
|
||||||
|
Optional,
|
||||||
|
Sequence,
|
||||||
|
Tuple,
|
||||||
|
Type,
|
||||||
|
Union,
|
||||||
|
cast,
|
||||||
|
)
|
||||||
|
|
||||||
|
from langchain.load.dump import dumpd
|
||||||
|
from langchain.pydantic_v1 import BaseModel
|
||||||
|
from langchain.schema.runnable.base import (
|
||||||
|
Runnable,
|
||||||
|
RunnableLike,
|
||||||
|
RunnableSerializable,
|
||||||
|
coerce_to_runnable,
|
||||||
|
)
|
||||||
|
from langchain.schema.runnable.config import (
|
||||||
|
RunnableConfig,
|
||||||
|
ensure_config,
|
||||||
|
get_callback_manager_for_config,
|
||||||
|
patch_config,
|
||||||
|
)
|
||||||
|
from langchain.schema.runnable.utils import Input, Output
|
||||||
|
|
||||||
|
|
||||||
|
class RunnableBranch(RunnableSerializable[Input, Output]):
|
||||||
|
"""A Runnable that selects which branch to run based on a condition.
|
||||||
|
|
||||||
|
The runnable is initialized with a list of (condition, runnable) pairs and
|
||||||
|
a default branch.
|
||||||
|
|
||||||
|
When operating on an input, the first condition that evaluates to True is
|
||||||
|
selected, and the corresponding runnable is run on the input.
|
||||||
|
|
||||||
|
If no condition evaluates to True, the default branch is run on the input.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from langchain.schema.runnable import RunnableBranch
|
||||||
|
|
||||||
|
branch = RunnableBranch(
|
||||||
|
(lambda x: isinstance(x, str), lambda x: x.upper()),
|
||||||
|
(lambda x: isinstance(x, int), lambda x: x + 1),
|
||||||
|
(lambda x: isinstance(x, float), lambda x: x * 2),
|
||||||
|
lambda x: "goodbye",
|
||||||
|
)
|
||||||
|
|
||||||
|
branch.invoke("hello") # "HELLO"
|
||||||
|
branch.invoke(None) # "goodbye"
|
||||||
|
"""
|
||||||
|
|
||||||
|
branches: Sequence[Tuple[Runnable[Input, bool], Runnable[Input, Output]]]
|
||||||
|
default: Runnable[Input, Output]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*branches: Union[
|
||||||
|
Tuple[
|
||||||
|
Union[
|
||||||
|
Runnable[Input, bool],
|
||||||
|
Callable[[Input], bool],
|
||||||
|
Callable[[Input], Awaitable[bool]],
|
||||||
|
],
|
||||||
|
RunnableLike,
|
||||||
|
],
|
||||||
|
RunnableLike, # To accommodate the default branch
|
||||||
|
],
|
||||||
|
) -> None:
|
||||||
|
"""A Runnable that runs one of two branches based on a condition."""
|
||||||
|
if len(branches) < 2:
|
||||||
|
raise ValueError("RunnableBranch requires at least two branches")
|
||||||
|
|
||||||
|
default = branches[-1]
|
||||||
|
|
||||||
|
if not isinstance(
|
||||||
|
default, (Runnable, Callable, Mapping) # type: ignore[arg-type]
|
||||||
|
):
|
||||||
|
raise TypeError(
|
||||||
|
"RunnableBranch default must be runnable, callable or mapping."
|
||||||
|
)
|
||||||
|
|
||||||
|
default_ = cast(
|
||||||
|
Runnable[Input, Output], coerce_to_runnable(cast(RunnableLike, default))
|
||||||
|
)
|
||||||
|
|
||||||
|
_branches = []
|
||||||
|
|
||||||
|
for branch in branches[:-1]:
|
||||||
|
if not isinstance(branch, (tuple, list)): # type: ignore[arg-type]
|
||||||
|
raise TypeError(
|
||||||
|
f"RunnableBranch branches must be "
|
||||||
|
f"tuples or lists, not {type(branch)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not len(branch) == 2:
|
||||||
|
raise ValueError(
|
||||||
|
f"RunnableBranch branches must be "
|
||||||
|
f"tuples or lists of length 2, not {len(branch)}"
|
||||||
|
)
|
||||||
|
condition, runnable = branch
|
||||||
|
condition = cast(Runnable[Input, bool], coerce_to_runnable(condition))
|
||||||
|
runnable = coerce_to_runnable(runnable)
|
||||||
|
_branches.append((condition, runnable))
|
||||||
|
|
||||||
|
super().__init__(branches=_branches, default=default_)
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def is_lc_serializable(cls) -> bool:
|
||||||
|
"""RunnableBranch is serializable if all its branches are serializable."""
|
||||||
|
return True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_lc_namespace(cls) -> List[str]:
|
||||||
|
"""The namespace of a RunnableBranch is the namespace of its default branch."""
|
||||||
|
return cls.__module__.split(".")[:-1]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def input_schema(self) -> Type[BaseModel]:
|
||||||
|
runnables = (
|
||||||
|
[self.default]
|
||||||
|
+ [r for _, r in self.branches]
|
||||||
|
+ [r for r, _ in self.branches]
|
||||||
|
)
|
||||||
|
|
||||||
|
for runnable in runnables:
|
||||||
|
if runnable.input_schema.schema().get("type") is not None:
|
||||||
|
return runnable.input_schema
|
||||||
|
|
||||||
|
return super().input_schema
|
||||||
|
|
||||||
|
def invoke(
|
||||||
|
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||||
|
) -> Output:
|
||||||
|
"""First evaluates the condition, then delegate to true or false branch."""
|
||||||
|
config = ensure_config(config)
|
||||||
|
callback_manager = get_callback_manager_for_config(config)
|
||||||
|
run_manager = callback_manager.on_chain_start(
|
||||||
|
dumpd(self),
|
||||||
|
input,
|
||||||
|
name=config.get("run_name"),
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
for idx, branch in enumerate(self.branches):
|
||||||
|
condition, runnable = branch
|
||||||
|
|
||||||
|
expression_value = condition.invoke(
|
||||||
|
input,
|
||||||
|
config=patch_config(
|
||||||
|
config,
|
||||||
|
callbacks=run_manager.get_child(tag=f"condition:{idx + 1}"),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
if expression_value:
|
||||||
|
output = runnable.invoke(
|
||||||
|
input,
|
||||||
|
config=patch_config(
|
||||||
|
config,
|
||||||
|
callbacks=run_manager.get_child(tag=f"branch:{idx + 1}"),
|
||||||
|
),
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
output = self.default.invoke(
|
||||||
|
input,
|
||||||
|
config=patch_config(
|
||||||
|
config, callbacks=run_manager.get_child(tag="branch:default")
|
||||||
|
),
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
run_manager.on_chain_error(e)
|
||||||
|
raise
|
||||||
|
run_manager.on_chain_end(dumpd(output))
|
||||||
|
return output
|
||||||
|
|
||||||
|
async def ainvoke(
|
||||||
|
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||||
|
) -> Output:
|
||||||
|
"""Async version of invoke."""
|
||||||
|
config = ensure_config(config)
|
||||||
|
callback_manager = get_callback_manager_for_config(config)
|
||||||
|
run_manager = callback_manager.on_chain_start(
|
||||||
|
dumpd(self),
|
||||||
|
input,
|
||||||
|
name=config.get("run_name"),
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
for idx, branch in enumerate(self.branches):
|
||||||
|
condition, runnable = branch
|
||||||
|
|
||||||
|
expression_value = await condition.ainvoke(
|
||||||
|
input,
|
||||||
|
config=patch_config(
|
||||||
|
config,
|
||||||
|
callbacks=run_manager.get_child(tag=f"condition:{idx + 1}"),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
if expression_value:
|
||||||
|
output = await runnable.ainvoke(
|
||||||
|
input,
|
||||||
|
config=patch_config(
|
||||||
|
config,
|
||||||
|
callbacks=run_manager.get_child(tag=f"branch:{idx + 1}"),
|
||||||
|
),
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
output = await self.default.ainvoke(
|
||||||
|
input,
|
||||||
|
config=patch_config(
|
||||||
|
config, callbacks=run_manager.get_child(tag="branch:default")
|
||||||
|
),
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
run_manager.on_chain_error(e)
|
||||||
|
raise
|
||||||
|
run_manager.on_chain_end(dumpd(output))
|
||||||
|
return output
|
@ -0,0 +1,286 @@
|
|||||||
|
import asyncio
|
||||||
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Any,
|
||||||
|
Iterator,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Sequence,
|
||||||
|
Tuple,
|
||||||
|
Type,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
|
from langchain.load.dump import dumpd
|
||||||
|
from langchain.pydantic_v1 import BaseModel
|
||||||
|
from langchain.schema.runnable.base import Runnable, RunnableSerializable
|
||||||
|
from langchain.schema.runnable.config import (
|
||||||
|
RunnableConfig,
|
||||||
|
ensure_config,
|
||||||
|
get_async_callback_manager_for_config,
|
||||||
|
get_callback_manager_for_config,
|
||||||
|
get_config_list,
|
||||||
|
patch_config,
|
||||||
|
)
|
||||||
|
from langchain.schema.runnable.utils import Input, Output
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from langchain.callbacks.manager import AsyncCallbackManagerForChainRun
|
||||||
|
|
||||||
|
|
||||||
|
class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
||||||
|
"""
|
||||||
|
A Runnable that can fallback to other Runnables if it fails.
|
||||||
|
"""
|
||||||
|
|
||||||
|
runnable: Runnable[Input, Output]
|
||||||
|
fallbacks: Sequence[Runnable[Input, Output]]
|
||||||
|
exceptions_to_handle: Tuple[Type[BaseException], ...] = (Exception,)
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def InputType(self) -> Type[Input]:
|
||||||
|
return self.runnable.InputType
|
||||||
|
|
||||||
|
@property
|
||||||
|
def OutputType(self) -> Type[Output]:
|
||||||
|
return self.runnable.OutputType
|
||||||
|
|
||||||
|
@property
|
||||||
|
def input_schema(self) -> Type[BaseModel]:
|
||||||
|
return self.runnable.input_schema
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_schema(self) -> Type[BaseModel]:
|
||||||
|
return self.runnable.output_schema
|
||||||
|
|
||||||
|
def config_schema(
|
||||||
|
self, *, include: Optional[Sequence[str]] = None
|
||||||
|
) -> Type[BaseModel]:
|
||||||
|
return self.runnable.config_schema(include=include)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def is_lc_serializable(cls) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_lc_namespace(cls) -> List[str]:
|
||||||
|
return cls.__module__.split(".")[:-1]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def runnables(self) -> Iterator[Runnable[Input, Output]]:
|
||||||
|
yield self.runnable
|
||||||
|
yield from self.fallbacks
|
||||||
|
|
||||||
|
def invoke(
|
||||||
|
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||||
|
) -> Output:
|
||||||
|
# setup callbacks
|
||||||
|
config = ensure_config(config)
|
||||||
|
callback_manager = get_callback_manager_for_config(config)
|
||||||
|
# start the root run
|
||||||
|
run_manager = callback_manager.on_chain_start(
|
||||||
|
dumpd(self), input, name=config.get("run_name")
|
||||||
|
)
|
||||||
|
first_error = None
|
||||||
|
for runnable in self.runnables:
|
||||||
|
try:
|
||||||
|
output = runnable.invoke(
|
||||||
|
input,
|
||||||
|
patch_config(config, callbacks=run_manager.get_child()),
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
except self.exceptions_to_handle as e:
|
||||||
|
if first_error is None:
|
||||||
|
first_error = e
|
||||||
|
except BaseException as e:
|
||||||
|
run_manager.on_chain_error(e)
|
||||||
|
raise e
|
||||||
|
else:
|
||||||
|
run_manager.on_chain_end(output)
|
||||||
|
return output
|
||||||
|
if first_error is None:
|
||||||
|
raise ValueError("No error stored at end of fallbacks.")
|
||||||
|
run_manager.on_chain_error(first_error)
|
||||||
|
raise first_error
|
||||||
|
|
||||||
|
async def ainvoke(
|
||||||
|
self,
|
||||||
|
input: Input,
|
||||||
|
config: Optional[RunnableConfig] = None,
|
||||||
|
**kwargs: Optional[Any],
|
||||||
|
) -> Output:
|
||||||
|
# setup callbacks
|
||||||
|
config = ensure_config(config)
|
||||||
|
callback_manager = get_async_callback_manager_for_config(config)
|
||||||
|
# start the root run
|
||||||
|
run_manager = await callback_manager.on_chain_start(
|
||||||
|
dumpd(self), input, name=config.get("run_name")
|
||||||
|
)
|
||||||
|
|
||||||
|
first_error = None
|
||||||
|
for runnable in self.runnables:
|
||||||
|
try:
|
||||||
|
output = await runnable.ainvoke(
|
||||||
|
input,
|
||||||
|
patch_config(config, callbacks=run_manager.get_child()),
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
except self.exceptions_to_handle as e:
|
||||||
|
if first_error is None:
|
||||||
|
first_error = e
|
||||||
|
except BaseException as e:
|
||||||
|
await run_manager.on_chain_error(e)
|
||||||
|
raise e
|
||||||
|
else:
|
||||||
|
await run_manager.on_chain_end(output)
|
||||||
|
return output
|
||||||
|
if first_error is None:
|
||||||
|
raise ValueError("No error stored at end of fallbacks.")
|
||||||
|
await run_manager.on_chain_error(first_error)
|
||||||
|
raise first_error
|
||||||
|
|
||||||
|
def batch(
|
||||||
|
self,
|
||||||
|
inputs: List[Input],
|
||||||
|
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||||
|
*,
|
||||||
|
return_exceptions: bool = False,
|
||||||
|
**kwargs: Optional[Any],
|
||||||
|
) -> List[Output]:
|
||||||
|
from langchain.callbacks.manager import CallbackManager
|
||||||
|
|
||||||
|
if return_exceptions:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
if not inputs:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# setup callbacks
|
||||||
|
configs = get_config_list(config, len(inputs))
|
||||||
|
callback_managers = [
|
||||||
|
CallbackManager.configure(
|
||||||
|
inheritable_callbacks=config.get("callbacks"),
|
||||||
|
local_callbacks=None,
|
||||||
|
verbose=False,
|
||||||
|
inheritable_tags=config.get("tags"),
|
||||||
|
local_tags=None,
|
||||||
|
inheritable_metadata=config.get("metadata"),
|
||||||
|
local_metadata=None,
|
||||||
|
)
|
||||||
|
for config in configs
|
||||||
|
]
|
||||||
|
# start the root runs, one per input
|
||||||
|
run_managers = [
|
||||||
|
cm.on_chain_start(
|
||||||
|
dumpd(self),
|
||||||
|
input if isinstance(input, dict) else {"input": input},
|
||||||
|
name=config.get("run_name"),
|
||||||
|
)
|
||||||
|
for cm, input, config in zip(callback_managers, inputs, configs)
|
||||||
|
]
|
||||||
|
|
||||||
|
first_error = None
|
||||||
|
for runnable in self.runnables:
|
||||||
|
try:
|
||||||
|
outputs = runnable.batch(
|
||||||
|
inputs,
|
||||||
|
[
|
||||||
|
# each step a child run of the corresponding root run
|
||||||
|
patch_config(config, callbacks=rm.get_child())
|
||||||
|
for rm, config in zip(run_managers, configs)
|
||||||
|
],
|
||||||
|
return_exceptions=return_exceptions,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
except self.exceptions_to_handle as e:
|
||||||
|
if first_error is None:
|
||||||
|
first_error = e
|
||||||
|
except BaseException as e:
|
||||||
|
for rm in run_managers:
|
||||||
|
rm.on_chain_error(e)
|
||||||
|
raise e
|
||||||
|
else:
|
||||||
|
for rm, output in zip(run_managers, outputs):
|
||||||
|
rm.on_chain_end(output)
|
||||||
|
return outputs
|
||||||
|
if first_error is None:
|
||||||
|
raise ValueError("No error stored at end of fallbacks.")
|
||||||
|
for rm in run_managers:
|
||||||
|
rm.on_chain_error(first_error)
|
||||||
|
raise first_error
|
||||||
|
|
||||||
|
async def abatch(
|
||||||
|
self,
|
||||||
|
inputs: List[Input],
|
||||||
|
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||||
|
*,
|
||||||
|
return_exceptions: bool = False,
|
||||||
|
**kwargs: Optional[Any],
|
||||||
|
) -> List[Output]:
|
||||||
|
from langchain.callbacks.manager import AsyncCallbackManager
|
||||||
|
|
||||||
|
if return_exceptions:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
if not inputs:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# setup callbacks
|
||||||
|
configs = get_config_list(config, len(inputs))
|
||||||
|
callback_managers = [
|
||||||
|
AsyncCallbackManager.configure(
|
||||||
|
inheritable_callbacks=config.get("callbacks"),
|
||||||
|
local_callbacks=None,
|
||||||
|
verbose=False,
|
||||||
|
inheritable_tags=config.get("tags"),
|
||||||
|
local_tags=None,
|
||||||
|
inheritable_metadata=config.get("metadata"),
|
||||||
|
local_metadata=None,
|
||||||
|
)
|
||||||
|
for config in configs
|
||||||
|
]
|
||||||
|
# start the root runs, one per input
|
||||||
|
run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather(
|
||||||
|
*(
|
||||||
|
cm.on_chain_start(
|
||||||
|
dumpd(self),
|
||||||
|
input,
|
||||||
|
name=config.get("run_name"),
|
||||||
|
)
|
||||||
|
for cm, input, config in zip(callback_managers, inputs, configs)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
first_error = None
|
||||||
|
for runnable in self.runnables:
|
||||||
|
try:
|
||||||
|
outputs = await runnable.abatch(
|
||||||
|
inputs,
|
||||||
|
[
|
||||||
|
# each step a child run of the corresponding root run
|
||||||
|
patch_config(config, callbacks=rm.get_child())
|
||||||
|
for rm, config in zip(run_managers, configs)
|
||||||
|
],
|
||||||
|
return_exceptions=return_exceptions,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
except self.exceptions_to_handle as e:
|
||||||
|
if first_error is None:
|
||||||
|
first_error = e
|
||||||
|
except BaseException as e:
|
||||||
|
await asyncio.gather(*(rm.on_chain_error(e) for rm in run_managers))
|
||||||
|
else:
|
||||||
|
await asyncio.gather(
|
||||||
|
*(
|
||||||
|
rm.on_chain_end(output)
|
||||||
|
for rm, output in zip(run_managers, outputs)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return outputs
|
||||||
|
if first_error is None:
|
||||||
|
raise ValueError("No error stored at end of fallbacks.")
|
||||||
|
await asyncio.gather(*(rm.on_chain_error(first_error) for rm in run_managers))
|
||||||
|
raise first_error
|
Loading…
Reference in New Issue