Create new RunnableSerializable base class in preparation for configurable runnables (#11279)

- Also move RunnableBranch to its own file

<!-- Thank you for contributing to LangChain!

Replace this entire comment with:
  - **Description:** a description of the change, 
  - **Issue:** the issue # it fixes (if applicable),
  - **Dependencies:** any dependencies required for this change,
- **Tag maintainer:** for a quicker response, tag the relevant
maintainer (see below),
- **Twitter handle:** we announce bigger features on Twitter. If your PR
gets announced, and you'd like a mention, we'll gladly shout you out!

Please make sure your PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` to check this
locally.

See contribution guidelines for more information on how to write/run
tests, lint, etc:

https://github.com/langchain-ai/langchain/blob/master/.github/CONTRIBUTING.md

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in `docs/extras`
directory.

If no one reviews your PR within a few days, please @-mention one of
@baskaryan, @eyurtsev, @hwchase17.
 -->
pull/11294/head^2
Nuno Campos 1 year ago committed by GitHub
commit 0638f7b83a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -21,7 +21,6 @@ from langchain.callbacks.manager import (
Callbacks,
)
from langchain.load.dump import dumpd
from langchain.load.serializable import Serializable
from langchain.pydantic_v1 import (
BaseModel,
Field,
@ -30,7 +29,7 @@ from langchain.pydantic_v1 import (
validator,
)
from langchain.schema import RUN_KEY, BaseMemory, RunInfo
from langchain.schema.runnable import Runnable, RunnableConfig
from langchain.schema.runnable import RunnableConfig, RunnableSerializable
logger = logging.getLogger(__name__)
@ -39,7 +38,7 @@ def _get_verbosity() -> bool:
return langchain.verbose
class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
"""Abstract base class for creating structured sequences of calls to components.
Chains should be used to encode a sequence of calls to components like

@ -14,7 +14,7 @@ from langchain.schema.runnable import RunnableConfig
class FakeListLLM(LLM):
"""Fake LLM for testing purposes."""
responses: List
responses: List[str]
sleep: Optional[float] = None
i: int = 0

@ -15,11 +15,10 @@ from typing import (
from typing_extensions import TypeAlias
from langchain.load.serializable import Serializable
from langchain.schema.messages import AnyMessage, BaseMessage, get_buffer_string
from langchain.schema.output import LLMResult
from langchain.schema.prompt import PromptValue
from langchain.schema.runnable import Runnable
from langchain.schema.runnable import RunnableSerializable
from langchain.utils import get_pydantic_field_names
if TYPE_CHECKING:
@ -54,7 +53,7 @@ LanguageModelOutput = TypeVar("LanguageModelOutput")
class BaseLanguageModel(
Serializable, Runnable[LanguageModelInput, LanguageModelOutput], ABC
RunnableSerializable[LanguageModelInput, LanguageModelOutput], ABC
):
"""Abstract base class for interfacing with language models.

@ -16,7 +16,6 @@ from typing import (
from typing_extensions import get_args
from langchain.load.serializable import Serializable
from langchain.schema.messages import AnyMessage, BaseMessage, BaseMessageChunk
from langchain.schema.output import (
ChatGeneration,
@ -25,12 +24,12 @@ from langchain.schema.output import (
GenerationChunk,
)
from langchain.schema.prompt import PromptValue
from langchain.schema.runnable import Runnable, RunnableConfig
from langchain.schema.runnable import RunnableConfig, RunnableSerializable
T = TypeVar("T")
class BaseLLMOutputParser(Serializable, Generic[T], ABC):
class BaseLLMOutputParser(Generic[T], ABC):
"""Abstract base class for parsing the outputs of a model."""
@abstractmethod
@ -63,7 +62,7 @@ class BaseLLMOutputParser(Serializable, Generic[T], ABC):
class BaseGenerationOutputParser(
BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T]
BaseLLMOutputParser, RunnableSerializable[Union[str, BaseMessage], T]
):
"""Base class to parse the output of an LLM call."""
@ -121,7 +120,9 @@ class BaseGenerationOutputParser(
)
class BaseOutputParser(BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T]):
class BaseOutputParser(
BaseLLMOutputParser, RunnableSerializable[Union[str, BaseMessage], T]
):
"""Base class to parse the output of an LLM call.
Output parsers help structure language model responses.

@ -7,15 +7,14 @@ from typing import Any, Callable, Dict, List, Mapping, Optional, Union
import yaml
from langchain.load.serializable import Serializable
from langchain.pydantic_v1 import BaseModel, Field, create_model, root_validator
from langchain.schema.document import Document
from langchain.schema.output_parser import BaseOutputParser
from langchain.schema.prompt import PromptValue
from langchain.schema.runnable import Runnable, RunnableConfig
from langchain.schema.runnable import RunnableConfig, RunnableSerializable
class BasePromptTemplate(Serializable, Runnable[Dict, PromptValue], ABC):
class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC):
"""Base class for all prompt templates, returning a prompt."""
input_variables: List[str]

@ -6,9 +6,8 @@ from inspect import signature
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from langchain.load.dump import dumpd
from langchain.load.serializable import Serializable
from langchain.schema.document import Document
from langchain.schema.runnable import Runnable, RunnableConfig
from langchain.schema.runnable import RunnableConfig, RunnableSerializable
if TYPE_CHECKING:
from langchain.callbacks.manager import (
@ -18,7 +17,7 @@ if TYPE_CHECKING:
)
class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
class BaseRetriever(RunnableSerializable[str, List[Document]], ABC):
"""Abstract base class for a Document retrieval system.
A retrieval system is defined as something that can take string queries and return

@ -2,13 +2,14 @@ from langchain.schema.runnable._locals import GetLocalVar, PutLocalVar
from langchain.schema.runnable.base import (
Runnable,
RunnableBinding,
RunnableBranch,
RunnableLambda,
RunnableMap,
RunnableSequence,
RunnableWithFallbacks,
RunnableSerializable,
)
from langchain.schema.runnable.branch import RunnableBranch
from langchain.schema.runnable.config import RunnableConfig, patch_config
from langchain.schema.runnable.fallbacks import RunnableWithFallbacks
from langchain.schema.runnable.passthrough import RunnablePassthrough
from langchain.schema.runnable.router import RouterInput, RouterRunnable
@ -19,6 +20,7 @@ __all__ = [
"RouterInput",
"RouterRunnable",
"Runnable",
"RunnableSerializable",
"RunnableBinding",
"RunnableBranch",
"RunnableConfig",

@ -11,8 +11,7 @@ from typing import (
Union,
)
from langchain.load.serializable import Serializable
from langchain.schema.runnable.base import Input, Output, Runnable
from langchain.schema.runnable.base import Input, Output, RunnableSerializable
from langchain.schema.runnable.config import RunnableConfig
from langchain.schema.runnable.passthrough import RunnablePassthrough
@ -104,7 +103,7 @@ class PutLocalVar(RunnablePassthrough):
class GetLocalVar(
Serializable, Runnable[Input, Union[Output, Dict[str, Union[Input, Output]]]]
RunnableSerializable[Input, Union[Output, Dict[str, Union[Input, Output]]]]
):
key: str
"""The key to extract from the local state."""

@ -36,6 +36,9 @@ if TYPE_CHECKING:
CallbackManagerForChainRun,
)
from langchain.callbacks.tracers.log_stream import RunLogPatch
from langchain.schema.runnable.fallbacks import (
RunnableWithFallbacks as RunnableWithFallbacksT,
)
from langchain.load.dump import dumpd
@ -119,6 +122,24 @@ class Runnable(Generic[Input, Output], ABC):
self.__class__.__name__ + "Output", __root__=(root_type, None)
)
def config_schema(
self, *, include: Optional[Sequence[str]] = None
) -> Type[BaseModel]:
class _Config:
arbitrary_types_allowed = True
include = include or []
return create_model( # type: ignore[call-overload]
self.__class__.__name__ + "Config",
__config__=_Config,
**{
field_name: (field_type, None)
for field_name, field_type in RunnableConfig.__annotations__.items()
if field_name in include
},
)
def __or__(
self,
other: Union[
@ -437,7 +458,9 @@ class Runnable(Generic[Input, Output], ABC):
fallbacks: Sequence[Runnable[Input, Output]],
*,
exceptions_to_handle: Tuple[Type[BaseException], ...] = (Exception,),
) -> RunnableWithFallbacks[Input, Output]:
) -> RunnableWithFallbacksT[Input, Output]:
from langchain.schema.runnable.fallbacks import RunnableWithFallbacks
return RunnableWithFallbacks(
runnable=self,
fallbacks=fallbacks,
@ -812,462 +835,11 @@ class Runnable(Generic[Input, Output], ABC):
await run_manager.on_chain_end(final_output, inputs=final_input)
class RunnableBranch(Serializable, Runnable[Input, Output]):
"""A Runnable that selects which branch to run based on a condition.
The runnable is initialized with a list of (condition, runnable) pairs and
a default branch.
When operating on an input, the first condition that evaluates to True is
selected, and the corresponding runnable is run on the input.
class RunnableSerializable(Serializable, Runnable[Input, Output]):
pass
If no condition evaluates to True, the default branch is run on the input.
Examples:
.. code-block:: python
from langchain.schema.runnable import RunnableBranch
branch = RunnableBranch(
(lambda x: isinstance(x, str), lambda x: x.upper()),
(lambda x: isinstance(x, int), lambda x: x + 1),
(lambda x: isinstance(x, float), lambda x: x * 2),
lambda x: "goodbye",
)
branch.invoke("hello") # "HELLO"
branch.invoke(None) # "goodbye"
"""
branches: Sequence[Tuple[Runnable[Input, bool], Runnable[Input, Output]]]
default: Runnable[Input, Output]
def __init__(
self,
*branches: Union[
Tuple[
Union[
Runnable[Input, bool],
Callable[[Input], bool],
Callable[[Input], Awaitable[bool]],
],
RunnableLike,
],
RunnableLike, # To accommodate the default branch
],
) -> None:
"""A Runnable that runs one of two branches based on a condition."""
if len(branches) < 2:
raise ValueError("RunnableBranch requires at least two branches")
default = branches[-1]
if not isinstance(
default, (Runnable, Callable, Mapping) # type: ignore[arg-type]
):
raise TypeError(
"RunnableBranch default must be runnable, callable or mapping."
)
default_ = cast(
Runnable[Input, Output], coerce_to_runnable(cast(RunnableLike, default))
)
_branches = []
for branch in branches[:-1]:
if not isinstance(branch, (tuple, list)): # type: ignore[arg-type]
raise TypeError(
f"RunnableBranch branches must be "
f"tuples or lists, not {type(branch)}"
)
if not len(branch) == 2:
raise ValueError(
f"RunnableBranch branches must be "
f"tuples or lists of length 2, not {len(branch)}"
)
condition, runnable = branch
condition = cast(Runnable[Input, bool], coerce_to_runnable(condition))
runnable = coerce_to_runnable(runnable)
_branches.append((condition, runnable))
super().__init__(branches=_branches, default=default_)
class Config:
arbitrary_types_allowed = True
@classmethod
def is_lc_serializable(cls) -> bool:
"""RunnableBranch is serializable if all its branches are serializable."""
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""The namespace of a RunnableBranch is the namespace of its default branch."""
return cls.__module__.split(".")[:-1]
@property
def input_schema(self) -> type[BaseModel]:
runnables = (
[self.default]
+ [r for _, r in self.branches]
+ [r for r, _ in self.branches]
)
for runnable in runnables:
if runnable.input_schema.schema().get("type") is not None:
return runnable.input_schema
return super().input_schema
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
"""First evaluates the condition, then delegate to true or false branch."""
config = ensure_config(config)
callback_manager = get_callback_manager_for_config(config)
run_manager = callback_manager.on_chain_start(
dumpd(self),
input,
name=config.get("run_name"),
)
try:
for idx, branch in enumerate(self.branches):
condition, runnable = branch
expression_value = condition.invoke(
input,
config=patch_config(
config,
callbacks=run_manager.get_child(tag=f"condition:{idx + 1}"),
),
)
if expression_value:
output = runnable.invoke(
input,
config=patch_config(
config,
callbacks=run_manager.get_child(tag=f"branch:{idx + 1}"),
),
)
break
else:
output = self.default.invoke(
input,
config=patch_config(
config, callbacks=run_manager.get_child(tag="branch:default")
),
)
except Exception as e:
run_manager.on_chain_error(e)
raise
run_manager.on_chain_end(dumpd(output))
return output
async def ainvoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Output:
"""Async version of invoke."""
config = ensure_config(config)
callback_manager = get_callback_manager_for_config(config)
run_manager = callback_manager.on_chain_start(
dumpd(self),
input,
name=config.get("run_name"),
)
try:
for idx, branch in enumerate(self.branches):
condition, runnable = branch
expression_value = await condition.ainvoke(
input,
config=patch_config(
config,
callbacks=run_manager.get_child(tag=f"condition:{idx + 1}"),
),
)
if expression_value:
output = await runnable.ainvoke(
input,
config=patch_config(
config,
callbacks=run_manager.get_child(tag=f"branch:{idx + 1}"),
),
**kwargs,
)
break
else:
output = await self.default.ainvoke(
input,
config=patch_config(
config, callbacks=run_manager.get_child(tag="branch:default")
),
**kwargs,
)
except Exception as e:
run_manager.on_chain_error(e)
raise
run_manager.on_chain_end(dumpd(output))
return output
class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
"""
A Runnable that can fallback to other Runnables if it fails.
"""
runnable: Runnable[Input, Output]
fallbacks: Sequence[Runnable[Input, Output]]
exceptions_to_handle: Tuple[Type[BaseException], ...] = (Exception,)
class Config:
arbitrary_types_allowed = True
@property
def InputType(self) -> Type[Input]:
return self.runnable.InputType
@property
def OutputType(self) -> Type[Output]:
return self.runnable.OutputType
@property
def input_schema(self) -> Type[BaseModel]:
return self.runnable.input_schema
@property
def output_schema(self) -> Type[BaseModel]:
return self.runnable.output_schema
@classmethod
def is_lc_serializable(cls) -> bool:
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
return cls.__module__.split(".")[:-1]
@property
def runnables(self) -> Iterator[Runnable[Input, Output]]:
yield self.runnable
yield from self.fallbacks
def invoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Output:
# setup callbacks
config = ensure_config(config)
callback_manager = get_callback_manager_for_config(config)
# start the root run
run_manager = callback_manager.on_chain_start(
dumpd(self), input, name=config.get("run_name")
)
first_error = None
for runnable in self.runnables:
try:
output = runnable.invoke(
input,
patch_config(config, callbacks=run_manager.get_child()),
**kwargs,
)
except self.exceptions_to_handle as e:
if first_error is None:
first_error = e
except BaseException as e:
run_manager.on_chain_error(e)
raise e
else:
run_manager.on_chain_end(output)
return output
if first_error is None:
raise ValueError("No error stored at end of fallbacks.")
run_manager.on_chain_error(first_error)
raise first_error
async def ainvoke(
self,
input: Input,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Output:
# setup callbacks
config = ensure_config(config)
callback_manager = get_async_callback_manager_for_config(config)
# start the root run
run_manager = await callback_manager.on_chain_start(
dumpd(self), input, name=config.get("run_name")
)
first_error = None
for runnable in self.runnables:
try:
output = await runnable.ainvoke(
input,
patch_config(config, callbacks=run_manager.get_child()),
**kwargs,
)
except self.exceptions_to_handle as e:
if first_error is None:
first_error = e
except BaseException as e:
await run_manager.on_chain_error(e)
raise e
else:
await run_manager.on_chain_end(output)
return output
if first_error is None:
raise ValueError("No error stored at end of fallbacks.")
await run_manager.on_chain_error(first_error)
raise first_error
def batch(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> List[Output]:
from langchain.callbacks.manager import CallbackManager
if return_exceptions:
raise NotImplementedError()
if not inputs:
return []
# setup callbacks
configs = get_config_list(config, len(inputs))
callback_managers = [
CallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
local_callbacks=None,
verbose=False,
inheritable_tags=config.get("tags"),
local_tags=None,
inheritable_metadata=config.get("metadata"),
local_metadata=None,
)
for config in configs
]
# start the root runs, one per input
run_managers = [
cm.on_chain_start(
dumpd(self),
input if isinstance(input, dict) else {"input": input},
name=config.get("run_name"),
)
for cm, input, config in zip(callback_managers, inputs, configs)
]
first_error = None
for runnable in self.runnables:
try:
outputs = runnable.batch(
inputs,
[
# each step a child run of the corresponding root run
patch_config(config, callbacks=rm.get_child())
for rm, config in zip(run_managers, configs)
],
return_exceptions=return_exceptions,
**kwargs,
)
except self.exceptions_to_handle as e:
if first_error is None:
first_error = e
except BaseException as e:
for rm in run_managers:
rm.on_chain_error(e)
raise e
else:
for rm, output in zip(run_managers, outputs):
rm.on_chain_end(output)
return outputs
if first_error is None:
raise ValueError("No error stored at end of fallbacks.")
for rm in run_managers:
rm.on_chain_error(first_error)
raise first_error
async def abatch(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> List[Output]:
from langchain.callbacks.manager import AsyncCallbackManager
if return_exceptions:
raise NotImplementedError()
if not inputs:
return []
# setup callbacks
configs = get_config_list(config, len(inputs))
callback_managers = [
AsyncCallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
local_callbacks=None,
verbose=False,
inheritable_tags=config.get("tags"),
local_tags=None,
inheritable_metadata=config.get("metadata"),
local_metadata=None,
)
for config in configs
]
# start the root runs, one per input
run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather(
*(
cm.on_chain_start(
dumpd(self),
input,
name=config.get("run_name"),
)
for cm, input, config in zip(callback_managers, inputs, configs)
)
)
first_error = None
for runnable in self.runnables:
try:
outputs = await runnable.abatch(
inputs,
[
# each step a child run of the corresponding root run
patch_config(config, callbacks=rm.get_child())
for rm, config in zip(run_managers, configs)
],
return_exceptions=return_exceptions,
**kwargs,
)
except self.exceptions_to_handle as e:
if first_error is None:
first_error = e
except BaseException as e:
await asyncio.gather(*(rm.on_chain_error(e) for rm in run_managers))
else:
await asyncio.gather(
*(
rm.on_chain_end(output)
for rm, output in zip(run_managers, outputs)
)
)
return outputs
if first_error is None:
raise ValueError("No error stored at end of fallbacks.")
await asyncio.gather(*(rm.on_chain_error(first_error) for rm in run_managers))
raise first_error
class RunnableSequence(Serializable, Runnable[Input, Output]):
class RunnableSequence(RunnableSerializable[Input, Output]):
"""
A sequence of runnables, where the output of each is the input of the next.
"""
@ -1749,7 +1321,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
yield chunk
class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
class RunnableMap(RunnableSerializable[Input, Dict[str, Any]]):
"""
A runnable that runs a mapping of runnables in parallel,
and returns a mapping of their outputs.
@ -1799,7 +1371,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
return create_model( # type: ignore[call-overload]
"RunnableMapInput",
**{
k: (v.type_, v.default)
k: (v.annotation, v.default)
for step in self.steps.values()
for k, v in step.input_schema.__fields__.items()
if k != "__root__"
@ -2374,7 +1946,7 @@ class RunnableLambda(Runnable[Input, Output]):
return await super().ainvoke(input, config)
class RunnableEach(Serializable, Runnable[List[Input], List[Output]]):
class RunnableEach(RunnableSerializable[List[Input], List[Output]]):
"""
A runnable that delegates calls to another runnable
with each element of the input sequence.
@ -2413,6 +1985,11 @@ class RunnableEach(Serializable, Runnable[List[Input], List[Output]]):
),
)
def config_schema(
self, *, include: Optional[Sequence[str]] = None
) -> Type[BaseModel]:
return self.bound.config_schema(include=include)
@classmethod
def is_lc_serializable(cls) -> bool:
return True
@ -2455,7 +2032,7 @@ class RunnableEach(Serializable, Runnable[List[Input], List[Output]]):
return await self._acall_with_config(self._ainvoke, input, config)
class RunnableBinding(Serializable, Runnable[Input, Output]):
class RunnableBinding(RunnableSerializable[Input, Output]):
"""
A runnable that delegates calls to another runnable with a set of kwargs.
"""
@ -2485,6 +2062,11 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
def output_schema(self) -> Type[BaseModel]:
return self.bound.output_schema
def config_schema(
self, *, include: Optional[Sequence[str]] = None
) -> Type[BaseModel]:
return self.bound.config_schema(include=include)
@classmethod
def is_lc_serializable(cls) -> bool:
return True

@ -0,0 +1,235 @@
from typing import (
Any,
Awaitable,
Callable,
List,
Mapping,
Optional,
Sequence,
Tuple,
Type,
Union,
cast,
)
from langchain.load.dump import dumpd
from langchain.pydantic_v1 import BaseModel
from langchain.schema.runnable.base import (
Runnable,
RunnableLike,
RunnableSerializable,
coerce_to_runnable,
)
from langchain.schema.runnable.config import (
RunnableConfig,
ensure_config,
get_callback_manager_for_config,
patch_config,
)
from langchain.schema.runnable.utils import Input, Output
class RunnableBranch(RunnableSerializable[Input, Output]):
"""A Runnable that selects which branch to run based on a condition.
The runnable is initialized with a list of (condition, runnable) pairs and
a default branch.
When operating on an input, the first condition that evaluates to True is
selected, and the corresponding runnable is run on the input.
If no condition evaluates to True, the default branch is run on the input.
Examples:
.. code-block:: python
from langchain.schema.runnable import RunnableBranch
branch = RunnableBranch(
(lambda x: isinstance(x, str), lambda x: x.upper()),
(lambda x: isinstance(x, int), lambda x: x + 1),
(lambda x: isinstance(x, float), lambda x: x * 2),
lambda x: "goodbye",
)
branch.invoke("hello") # "HELLO"
branch.invoke(None) # "goodbye"
"""
branches: Sequence[Tuple[Runnable[Input, bool], Runnable[Input, Output]]]
default: Runnable[Input, Output]
def __init__(
self,
*branches: Union[
Tuple[
Union[
Runnable[Input, bool],
Callable[[Input], bool],
Callable[[Input], Awaitable[bool]],
],
RunnableLike,
],
RunnableLike, # To accommodate the default branch
],
) -> None:
"""A Runnable that runs one of two branches based on a condition."""
if len(branches) < 2:
raise ValueError("RunnableBranch requires at least two branches")
default = branches[-1]
if not isinstance(
default, (Runnable, Callable, Mapping) # type: ignore[arg-type]
):
raise TypeError(
"RunnableBranch default must be runnable, callable or mapping."
)
default_ = cast(
Runnable[Input, Output], coerce_to_runnable(cast(RunnableLike, default))
)
_branches = []
for branch in branches[:-1]:
if not isinstance(branch, (tuple, list)): # type: ignore[arg-type]
raise TypeError(
f"RunnableBranch branches must be "
f"tuples or lists, not {type(branch)}"
)
if not len(branch) == 2:
raise ValueError(
f"RunnableBranch branches must be "
f"tuples or lists of length 2, not {len(branch)}"
)
condition, runnable = branch
condition = cast(Runnable[Input, bool], coerce_to_runnable(condition))
runnable = coerce_to_runnable(runnable)
_branches.append((condition, runnable))
super().__init__(branches=_branches, default=default_)
class Config:
arbitrary_types_allowed = True
@classmethod
def is_lc_serializable(cls) -> bool:
"""RunnableBranch is serializable if all its branches are serializable."""
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""The namespace of a RunnableBranch is the namespace of its default branch."""
return cls.__module__.split(".")[:-1]
@property
def input_schema(self) -> Type[BaseModel]:
runnables = (
[self.default]
+ [r for _, r in self.branches]
+ [r for r, _ in self.branches]
)
for runnable in runnables:
if runnable.input_schema.schema().get("type") is not None:
return runnable.input_schema
return super().input_schema
def invoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Output:
"""First evaluates the condition, then delegate to true or false branch."""
config = ensure_config(config)
callback_manager = get_callback_manager_for_config(config)
run_manager = callback_manager.on_chain_start(
dumpd(self),
input,
name=config.get("run_name"),
)
try:
for idx, branch in enumerate(self.branches):
condition, runnable = branch
expression_value = condition.invoke(
input,
config=patch_config(
config,
callbacks=run_manager.get_child(tag=f"condition:{idx + 1}"),
),
)
if expression_value:
output = runnable.invoke(
input,
config=patch_config(
config,
callbacks=run_manager.get_child(tag=f"branch:{idx + 1}"),
),
**kwargs,
)
break
else:
output = self.default.invoke(
input,
config=patch_config(
config, callbacks=run_manager.get_child(tag="branch:default")
),
**kwargs,
)
except Exception as e:
run_manager.on_chain_error(e)
raise
run_manager.on_chain_end(dumpd(output))
return output
async def ainvoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Output:
"""Async version of invoke."""
config = ensure_config(config)
callback_manager = get_callback_manager_for_config(config)
run_manager = callback_manager.on_chain_start(
dumpd(self),
input,
name=config.get("run_name"),
)
try:
for idx, branch in enumerate(self.branches):
condition, runnable = branch
expression_value = await condition.ainvoke(
input,
config=patch_config(
config,
callbacks=run_manager.get_child(tag=f"condition:{idx + 1}"),
),
)
if expression_value:
output = await runnable.ainvoke(
input,
config=patch_config(
config,
callbacks=run_manager.get_child(tag=f"branch:{idx + 1}"),
),
**kwargs,
)
break
else:
output = await self.default.ainvoke(
input,
config=patch_config(
config, callbacks=run_manager.get_child(tag="branch:default")
),
**kwargs,
)
except Exception as e:
run_manager.on_chain_error(e)
raise
run_manager.on_chain_end(dumpd(output))
return output

@ -0,0 +1,286 @@
import asyncio
from typing import (
TYPE_CHECKING,
Any,
Iterator,
List,
Optional,
Sequence,
Tuple,
Type,
Union,
)
from langchain.load.dump import dumpd
from langchain.pydantic_v1 import BaseModel
from langchain.schema.runnable.base import Runnable, RunnableSerializable
from langchain.schema.runnable.config import (
RunnableConfig,
ensure_config,
get_async_callback_manager_for_config,
get_callback_manager_for_config,
get_config_list,
patch_config,
)
from langchain.schema.runnable.utils import Input, Output
if TYPE_CHECKING:
from langchain.callbacks.manager import AsyncCallbackManagerForChainRun
class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
"""
A Runnable that can fallback to other Runnables if it fails.
"""
runnable: Runnable[Input, Output]
fallbacks: Sequence[Runnable[Input, Output]]
exceptions_to_handle: Tuple[Type[BaseException], ...] = (Exception,)
class Config:
arbitrary_types_allowed = True
@property
def InputType(self) -> Type[Input]:
return self.runnable.InputType
@property
def OutputType(self) -> Type[Output]:
return self.runnable.OutputType
@property
def input_schema(self) -> Type[BaseModel]:
return self.runnable.input_schema
@property
def output_schema(self) -> Type[BaseModel]:
return self.runnable.output_schema
def config_schema(
self, *, include: Optional[Sequence[str]] = None
) -> Type[BaseModel]:
return self.runnable.config_schema(include=include)
@classmethod
def is_lc_serializable(cls) -> bool:
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
return cls.__module__.split(".")[:-1]
@property
def runnables(self) -> Iterator[Runnable[Input, Output]]:
yield self.runnable
yield from self.fallbacks
def invoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Output:
# setup callbacks
config = ensure_config(config)
callback_manager = get_callback_manager_for_config(config)
# start the root run
run_manager = callback_manager.on_chain_start(
dumpd(self), input, name=config.get("run_name")
)
first_error = None
for runnable in self.runnables:
try:
output = runnable.invoke(
input,
patch_config(config, callbacks=run_manager.get_child()),
**kwargs,
)
except self.exceptions_to_handle as e:
if first_error is None:
first_error = e
except BaseException as e:
run_manager.on_chain_error(e)
raise e
else:
run_manager.on_chain_end(output)
return output
if first_error is None:
raise ValueError("No error stored at end of fallbacks.")
run_manager.on_chain_error(first_error)
raise first_error
async def ainvoke(
self,
input: Input,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Output:
# setup callbacks
config = ensure_config(config)
callback_manager = get_async_callback_manager_for_config(config)
# start the root run
run_manager = await callback_manager.on_chain_start(
dumpd(self), input, name=config.get("run_name")
)
first_error = None
for runnable in self.runnables:
try:
output = await runnable.ainvoke(
input,
patch_config(config, callbacks=run_manager.get_child()),
**kwargs,
)
except self.exceptions_to_handle as e:
if first_error is None:
first_error = e
except BaseException as e:
await run_manager.on_chain_error(e)
raise e
else:
await run_manager.on_chain_end(output)
return output
if first_error is None:
raise ValueError("No error stored at end of fallbacks.")
await run_manager.on_chain_error(first_error)
raise first_error
def batch(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> List[Output]:
from langchain.callbacks.manager import CallbackManager
if return_exceptions:
raise NotImplementedError()
if not inputs:
return []
# setup callbacks
configs = get_config_list(config, len(inputs))
callback_managers = [
CallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
local_callbacks=None,
verbose=False,
inheritable_tags=config.get("tags"),
local_tags=None,
inheritable_metadata=config.get("metadata"),
local_metadata=None,
)
for config in configs
]
# start the root runs, one per input
run_managers = [
cm.on_chain_start(
dumpd(self),
input if isinstance(input, dict) else {"input": input},
name=config.get("run_name"),
)
for cm, input, config in zip(callback_managers, inputs, configs)
]
first_error = None
for runnable in self.runnables:
try:
outputs = runnable.batch(
inputs,
[
# each step a child run of the corresponding root run
patch_config(config, callbacks=rm.get_child())
for rm, config in zip(run_managers, configs)
],
return_exceptions=return_exceptions,
**kwargs,
)
except self.exceptions_to_handle as e:
if first_error is None:
first_error = e
except BaseException as e:
for rm in run_managers:
rm.on_chain_error(e)
raise e
else:
for rm, output in zip(run_managers, outputs):
rm.on_chain_end(output)
return outputs
if first_error is None:
raise ValueError("No error stored at end of fallbacks.")
for rm in run_managers:
rm.on_chain_error(first_error)
raise first_error
async def abatch(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> List[Output]:
from langchain.callbacks.manager import AsyncCallbackManager
if return_exceptions:
raise NotImplementedError()
if not inputs:
return []
# setup callbacks
configs = get_config_list(config, len(inputs))
callback_managers = [
AsyncCallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
local_callbacks=None,
verbose=False,
inheritable_tags=config.get("tags"),
local_tags=None,
inheritable_metadata=config.get("metadata"),
local_metadata=None,
)
for config in configs
]
# start the root runs, one per input
run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather(
*(
cm.on_chain_start(
dumpd(self),
input,
name=config.get("run_name"),
)
for cm, input, config in zip(callback_managers, inputs, configs)
)
)
first_error = None
for runnable in self.runnables:
try:
outputs = await runnable.abatch(
inputs,
[
# each step a child run of the corresponding root run
patch_config(config, callbacks=rm.get_child())
for rm, config in zip(run_managers, configs)
],
return_exceptions=return_exceptions,
**kwargs,
)
except self.exceptions_to_handle as e:
if first_error is None:
first_error = e
except BaseException as e:
await asyncio.gather(*(rm.on_chain_error(e) for rm in run_managers))
else:
await asyncio.gather(
*(
rm.on_chain_end(output)
for rm, output in zip(run_managers, outputs)
)
)
return outputs
if first_error is None:
raise ValueError("No error stored at end of fallbacks.")
await asyncio.gather(*(rm.on_chain_error(first_error) for rm in run_managers))
raise first_error

@ -16,9 +16,13 @@ from typing import (
cast,
)
from langchain.load.serializable import Serializable
from langchain.pydantic_v1 import BaseModel, create_model
from langchain.schema.runnable.base import Input, Runnable, RunnableMap
from langchain.schema.runnable.base import (
Input,
Runnable,
RunnableMap,
RunnableSerializable,
)
from langchain.schema.runnable.config import RunnableConfig, get_executor_for_config
from langchain.schema.runnable.utils import AddableDict
from langchain.utils.aiter import atee, py_anext
@ -33,7 +37,7 @@ async def aidentity(x: Input) -> Input:
return x
class RunnablePassthrough(Serializable, Runnable[Input, Input]):
class RunnablePassthrough(RunnableSerializable[Input, Input]):
"""
A runnable that passes through the input.
"""
@ -109,7 +113,7 @@ class RunnablePassthrough(Serializable, Runnable[Input, Input]):
yield chunk
class RunnableAssign(Serializable, Runnable[Dict[str, Any], Dict[str, Any]]):
class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
"""
A runnable that assigns key-value pairs to Dict[str, Any] inputs.
"""

@ -14,8 +14,13 @@ from typing import (
from typing_extensions import TypedDict
from langchain.load.serializable import Serializable
from langchain.schema.runnable.base import Input, Output, Runnable, coerce_to_runnable
from langchain.schema.runnable.base import (
Input,
Output,
Runnable,
RunnableSerializable,
coerce_to_runnable,
)
from langchain.schema.runnable.config import (
RunnableConfig,
get_config_list,
@ -36,7 +41,7 @@ class RouterInput(TypedDict):
input: Any
class RouterRunnable(Serializable, Runnable[RouterInput, Output]):
class RouterRunnable(RunnableSerializable[RouterInput, Output]):
"""
A runnable that routes to a set of runnables based on Input['key'].
Returns the output of the selected runnable.

@ -17,6 +17,7 @@ from langchain.callbacks.manager import (
CallbackManagerForToolRun,
Callbacks,
)
from langchain.load.serializable import Serializable
from langchain.pydantic_v1 import (
BaseModel,
Extra,
@ -25,7 +26,7 @@ from langchain.pydantic_v1 import (
root_validator,
validate_arguments,
)
from langchain.schema.runnable import Runnable, RunnableConfig
from langchain.schema.runnable import Runnable, RunnableConfig, RunnableSerializable
class SchemaAnnotationError(TypeError):
@ -97,7 +98,7 @@ class ToolException(Exception):
pass
class BaseTool(BaseModel, Runnable[Union[str, Dict], Any]):
class BaseTool(RunnableSerializable[Union[str, Dict], Any]):
"""Interface LangChain tools must implement."""
def __init_subclass__(cls, **kwargs: Any) -> None:
@ -165,10 +166,9 @@ class ChildTool(BaseTool):
] = False
"""Handle the content of the ToolException thrown."""
class Config:
class Config(Serializable.Config):
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
@property

@ -2,7 +2,7 @@
"""Tools for interacting with Spark SQL."""
from typing import Any, Dict, Optional
from langchain.pydantic_v1 import BaseModel, Extra, Field, root_validator
from langchain.pydantic_v1 import BaseModel, Field, root_validator
from langchain.schema.language_model import BaseLanguageModel
from langchain.callbacks.manager import (
@ -21,13 +21,8 @@ class BaseSparkSQLTool(BaseModel):
db: SparkSQL = Field(exclude=True)
# Override BaseTool.Config to appease mypy
# See https://github.com/pydantic/pydantic/issues/4173
class Config(BaseTool.Config):
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
extra = Extra.forbid
pass
class QuerySparkSQLTool(BaseSparkSQLTool, BaseTool):

@ -21,13 +21,8 @@ class BaseSQLDatabaseTool(BaseModel):
db: SQLDatabase = Field(exclude=True)
# Override BaseTool.Config to appease mypy
# See https://github.com/pydantic/pydantic/issues/4173
class Config(BaseTool.Config):
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
extra = Extra.forbid
pass
class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool):

@ -18,9 +18,7 @@ class BaseVectorStoreTool(BaseModel):
llm: BaseLanguageModel = Field(default_factory=lambda: OpenAI(temperature=0))
class Config(BaseTool.Config):
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
pass
def _create_description_from_template(values: Dict[str, Any]) -> Dict[str, Any]:

Loading…
Cancel
Save