Create new RunnableSerializable base class in preparation for configurable runnables (#11279)

- Also move RunnableBranch to its own file

<!-- Thank you for contributing to LangChain!

Replace this entire comment with:
  - **Description:** a description of the change, 
  - **Issue:** the issue # it fixes (if applicable),
  - **Dependencies:** any dependencies required for this change,
- **Tag maintainer:** for a quicker response, tag the relevant
maintainer (see below),
- **Twitter handle:** we announce bigger features on Twitter. If your PR
gets announced, and you'd like a mention, we'll gladly shout you out!

Please make sure your PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` to check this
locally.

See contribution guidelines for more information on how to write/run
tests, lint, etc:

https://github.com/langchain-ai/langchain/blob/master/.github/CONTRIBUTING.md

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in `docs/extras`
directory.

If no one reviews your PR within a few days, please @-mention one of
@baskaryan, @eyurtsev, @hwchase17.
 -->
pull/11294/head^2
Nuno Campos 1 year ago committed by GitHub
commit 0638f7b83a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -21,7 +21,6 @@ from langchain.callbacks.manager import (
Callbacks, Callbacks,
) )
from langchain.load.dump import dumpd from langchain.load.dump import dumpd
from langchain.load.serializable import Serializable
from langchain.pydantic_v1 import ( from langchain.pydantic_v1 import (
BaseModel, BaseModel,
Field, Field,
@ -30,7 +29,7 @@ from langchain.pydantic_v1 import (
validator, validator,
) )
from langchain.schema import RUN_KEY, BaseMemory, RunInfo from langchain.schema import RUN_KEY, BaseMemory, RunInfo
from langchain.schema.runnable import Runnable, RunnableConfig from langchain.schema.runnable import RunnableConfig, RunnableSerializable
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -39,7 +38,7 @@ def _get_verbosity() -> bool:
return langchain.verbose return langchain.verbose
class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC): class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
"""Abstract base class for creating structured sequences of calls to components. """Abstract base class for creating structured sequences of calls to components.
Chains should be used to encode a sequence of calls to components like Chains should be used to encode a sequence of calls to components like

@ -14,7 +14,7 @@ from langchain.schema.runnable import RunnableConfig
class FakeListLLM(LLM): class FakeListLLM(LLM):
"""Fake LLM for testing purposes.""" """Fake LLM for testing purposes."""
responses: List responses: List[str]
sleep: Optional[float] = None sleep: Optional[float] = None
i: int = 0 i: int = 0

@ -15,11 +15,10 @@ from typing import (
from typing_extensions import TypeAlias from typing_extensions import TypeAlias
from langchain.load.serializable import Serializable
from langchain.schema.messages import AnyMessage, BaseMessage, get_buffer_string from langchain.schema.messages import AnyMessage, BaseMessage, get_buffer_string
from langchain.schema.output import LLMResult from langchain.schema.output import LLMResult
from langchain.schema.prompt import PromptValue from langchain.schema.prompt import PromptValue
from langchain.schema.runnable import Runnable from langchain.schema.runnable import RunnableSerializable
from langchain.utils import get_pydantic_field_names from langchain.utils import get_pydantic_field_names
if TYPE_CHECKING: if TYPE_CHECKING:
@ -54,7 +53,7 @@ LanguageModelOutput = TypeVar("LanguageModelOutput")
class BaseLanguageModel( class BaseLanguageModel(
Serializable, Runnable[LanguageModelInput, LanguageModelOutput], ABC RunnableSerializable[LanguageModelInput, LanguageModelOutput], ABC
): ):
"""Abstract base class for interfacing with language models. """Abstract base class for interfacing with language models.

@ -16,7 +16,6 @@ from typing import (
from typing_extensions import get_args from typing_extensions import get_args
from langchain.load.serializable import Serializable
from langchain.schema.messages import AnyMessage, BaseMessage, BaseMessageChunk from langchain.schema.messages import AnyMessage, BaseMessage, BaseMessageChunk
from langchain.schema.output import ( from langchain.schema.output import (
ChatGeneration, ChatGeneration,
@ -25,12 +24,12 @@ from langchain.schema.output import (
GenerationChunk, GenerationChunk,
) )
from langchain.schema.prompt import PromptValue from langchain.schema.prompt import PromptValue
from langchain.schema.runnable import Runnable, RunnableConfig from langchain.schema.runnable import RunnableConfig, RunnableSerializable
T = TypeVar("T") T = TypeVar("T")
class BaseLLMOutputParser(Serializable, Generic[T], ABC): class BaseLLMOutputParser(Generic[T], ABC):
"""Abstract base class for parsing the outputs of a model.""" """Abstract base class for parsing the outputs of a model."""
@abstractmethod @abstractmethod
@ -63,7 +62,7 @@ class BaseLLMOutputParser(Serializable, Generic[T], ABC):
class BaseGenerationOutputParser( class BaseGenerationOutputParser(
BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T] BaseLLMOutputParser, RunnableSerializable[Union[str, BaseMessage], T]
): ):
"""Base class to parse the output of an LLM call.""" """Base class to parse the output of an LLM call."""
@ -121,7 +120,9 @@ class BaseGenerationOutputParser(
) )
class BaseOutputParser(BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T]): class BaseOutputParser(
BaseLLMOutputParser, RunnableSerializable[Union[str, BaseMessage], T]
):
"""Base class to parse the output of an LLM call. """Base class to parse the output of an LLM call.
Output parsers help structure language model responses. Output parsers help structure language model responses.

@ -7,15 +7,14 @@ from typing import Any, Callable, Dict, List, Mapping, Optional, Union
import yaml import yaml
from langchain.load.serializable import Serializable
from langchain.pydantic_v1 import BaseModel, Field, create_model, root_validator from langchain.pydantic_v1 import BaseModel, Field, create_model, root_validator
from langchain.schema.document import Document from langchain.schema.document import Document
from langchain.schema.output_parser import BaseOutputParser from langchain.schema.output_parser import BaseOutputParser
from langchain.schema.prompt import PromptValue from langchain.schema.prompt import PromptValue
from langchain.schema.runnable import Runnable, RunnableConfig from langchain.schema.runnable import RunnableConfig, RunnableSerializable
class BasePromptTemplate(Serializable, Runnable[Dict, PromptValue], ABC): class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC):
"""Base class for all prompt templates, returning a prompt.""" """Base class for all prompt templates, returning a prompt."""
input_variables: List[str] input_variables: List[str]

@ -6,9 +6,8 @@ from inspect import signature
from typing import TYPE_CHECKING, Any, Dict, List, Optional from typing import TYPE_CHECKING, Any, Dict, List, Optional
from langchain.load.dump import dumpd from langchain.load.dump import dumpd
from langchain.load.serializable import Serializable
from langchain.schema.document import Document from langchain.schema.document import Document
from langchain.schema.runnable import Runnable, RunnableConfig from langchain.schema.runnable import RunnableConfig, RunnableSerializable
if TYPE_CHECKING: if TYPE_CHECKING:
from langchain.callbacks.manager import ( from langchain.callbacks.manager import (
@ -18,7 +17,7 @@ if TYPE_CHECKING:
) )
class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC): class BaseRetriever(RunnableSerializable[str, List[Document]], ABC):
"""Abstract base class for a Document retrieval system. """Abstract base class for a Document retrieval system.
A retrieval system is defined as something that can take string queries and return A retrieval system is defined as something that can take string queries and return

@ -2,13 +2,14 @@ from langchain.schema.runnable._locals import GetLocalVar, PutLocalVar
from langchain.schema.runnable.base import ( from langchain.schema.runnable.base import (
Runnable, Runnable,
RunnableBinding, RunnableBinding,
RunnableBranch,
RunnableLambda, RunnableLambda,
RunnableMap, RunnableMap,
RunnableSequence, RunnableSequence,
RunnableWithFallbacks, RunnableSerializable,
) )
from langchain.schema.runnable.branch import RunnableBranch
from langchain.schema.runnable.config import RunnableConfig, patch_config from langchain.schema.runnable.config import RunnableConfig, patch_config
from langchain.schema.runnable.fallbacks import RunnableWithFallbacks
from langchain.schema.runnable.passthrough import RunnablePassthrough from langchain.schema.runnable.passthrough import RunnablePassthrough
from langchain.schema.runnable.router import RouterInput, RouterRunnable from langchain.schema.runnable.router import RouterInput, RouterRunnable
@ -19,6 +20,7 @@ __all__ = [
"RouterInput", "RouterInput",
"RouterRunnable", "RouterRunnable",
"Runnable", "Runnable",
"RunnableSerializable",
"RunnableBinding", "RunnableBinding",
"RunnableBranch", "RunnableBranch",
"RunnableConfig", "RunnableConfig",

@ -11,8 +11,7 @@ from typing import (
Union, Union,
) )
from langchain.load.serializable import Serializable from langchain.schema.runnable.base import Input, Output, RunnableSerializable
from langchain.schema.runnable.base import Input, Output, Runnable
from langchain.schema.runnable.config import RunnableConfig from langchain.schema.runnable.config import RunnableConfig
from langchain.schema.runnable.passthrough import RunnablePassthrough from langchain.schema.runnable.passthrough import RunnablePassthrough
@ -104,7 +103,7 @@ class PutLocalVar(RunnablePassthrough):
class GetLocalVar( class GetLocalVar(
Serializable, Runnable[Input, Union[Output, Dict[str, Union[Input, Output]]]] RunnableSerializable[Input, Union[Output, Dict[str, Union[Input, Output]]]]
): ):
key: str key: str
"""The key to extract from the local state.""" """The key to extract from the local state."""

@ -36,6 +36,9 @@ if TYPE_CHECKING:
CallbackManagerForChainRun, CallbackManagerForChainRun,
) )
from langchain.callbacks.tracers.log_stream import RunLogPatch from langchain.callbacks.tracers.log_stream import RunLogPatch
from langchain.schema.runnable.fallbacks import (
RunnableWithFallbacks as RunnableWithFallbacksT,
)
from langchain.load.dump import dumpd from langchain.load.dump import dumpd
@ -119,6 +122,24 @@ class Runnable(Generic[Input, Output], ABC):
self.__class__.__name__ + "Output", __root__=(root_type, None) self.__class__.__name__ + "Output", __root__=(root_type, None)
) )
def config_schema(
self, *, include: Optional[Sequence[str]] = None
) -> Type[BaseModel]:
class _Config:
arbitrary_types_allowed = True
include = include or []
return create_model( # type: ignore[call-overload]
self.__class__.__name__ + "Config",
__config__=_Config,
**{
field_name: (field_type, None)
for field_name, field_type in RunnableConfig.__annotations__.items()
if field_name in include
},
)
def __or__( def __or__(
self, self,
other: Union[ other: Union[
@ -437,7 +458,9 @@ class Runnable(Generic[Input, Output], ABC):
fallbacks: Sequence[Runnable[Input, Output]], fallbacks: Sequence[Runnable[Input, Output]],
*, *,
exceptions_to_handle: Tuple[Type[BaseException], ...] = (Exception,), exceptions_to_handle: Tuple[Type[BaseException], ...] = (Exception,),
) -> RunnableWithFallbacks[Input, Output]: ) -> RunnableWithFallbacksT[Input, Output]:
from langchain.schema.runnable.fallbacks import RunnableWithFallbacks
return RunnableWithFallbacks( return RunnableWithFallbacks(
runnable=self, runnable=self,
fallbacks=fallbacks, fallbacks=fallbacks,
@ -812,462 +835,11 @@ class Runnable(Generic[Input, Output], ABC):
await run_manager.on_chain_end(final_output, inputs=final_input) await run_manager.on_chain_end(final_output, inputs=final_input)
class RunnableBranch(Serializable, Runnable[Input, Output]): class RunnableSerializable(Serializable, Runnable[Input, Output]):
"""A Runnable that selects which branch to run based on a condition. pass
The runnable is initialized with a list of (condition, runnable) pairs and
a default branch.
When operating on an input, the first condition that evaluates to True is
selected, and the corresponding runnable is run on the input.
If no condition evaluates to True, the default branch is run on the input.
Examples:
.. code-block:: python
from langchain.schema.runnable import RunnableBranch
branch = RunnableBranch(
(lambda x: isinstance(x, str), lambda x: x.upper()),
(lambda x: isinstance(x, int), lambda x: x + 1),
(lambda x: isinstance(x, float), lambda x: x * 2),
lambda x: "goodbye",
)
branch.invoke("hello") # "HELLO"
branch.invoke(None) # "goodbye"
"""
branches: Sequence[Tuple[Runnable[Input, bool], Runnable[Input, Output]]]
default: Runnable[Input, Output]
def __init__(
self,
*branches: Union[
Tuple[
Union[
Runnable[Input, bool],
Callable[[Input], bool],
Callable[[Input], Awaitable[bool]],
],
RunnableLike,
],
RunnableLike, # To accommodate the default branch
],
) -> None:
"""A Runnable that runs one of two branches based on a condition."""
if len(branches) < 2:
raise ValueError("RunnableBranch requires at least two branches")
default = branches[-1]
if not isinstance(
default, (Runnable, Callable, Mapping) # type: ignore[arg-type]
):
raise TypeError(
"RunnableBranch default must be runnable, callable or mapping."
)
default_ = cast(
Runnable[Input, Output], coerce_to_runnable(cast(RunnableLike, default))
)
_branches = []
for branch in branches[:-1]:
if not isinstance(branch, (tuple, list)): # type: ignore[arg-type]
raise TypeError(
f"RunnableBranch branches must be "
f"tuples or lists, not {type(branch)}"
)
if not len(branch) == 2:
raise ValueError(
f"RunnableBranch branches must be "
f"tuples or lists of length 2, not {len(branch)}"
)
condition, runnable = branch
condition = cast(Runnable[Input, bool], coerce_to_runnable(condition))
runnable = coerce_to_runnable(runnable)
_branches.append((condition, runnable))
super().__init__(branches=_branches, default=default_)
class Config:
arbitrary_types_allowed = True
@classmethod
def is_lc_serializable(cls) -> bool:
"""RunnableBranch is serializable if all its branches are serializable."""
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""The namespace of a RunnableBranch is the namespace of its default branch."""
return cls.__module__.split(".")[:-1]
@property
def input_schema(self) -> type[BaseModel]:
runnables = (
[self.default]
+ [r for _, r in self.branches]
+ [r for r, _ in self.branches]
)
for runnable in runnables:
if runnable.input_schema.schema().get("type") is not None:
return runnable.input_schema
return super().input_schema
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
"""First evaluates the condition, then delegate to true or false branch."""
config = ensure_config(config)
callback_manager = get_callback_manager_for_config(config)
run_manager = callback_manager.on_chain_start(
dumpd(self),
input,
name=config.get("run_name"),
)
try:
for idx, branch in enumerate(self.branches):
condition, runnable = branch
expression_value = condition.invoke(
input,
config=patch_config(
config,
callbacks=run_manager.get_child(tag=f"condition:{idx + 1}"),
),
)
if expression_value:
output = runnable.invoke(
input,
config=patch_config(
config,
callbacks=run_manager.get_child(tag=f"branch:{idx + 1}"),
),
)
break
else:
output = self.default.invoke(
input,
config=patch_config(
config, callbacks=run_manager.get_child(tag="branch:default")
),
)
except Exception as e:
run_manager.on_chain_error(e)
raise
run_manager.on_chain_end(dumpd(output))
return output
async def ainvoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Output:
"""Async version of invoke."""
config = ensure_config(config)
callback_manager = get_callback_manager_for_config(config)
run_manager = callback_manager.on_chain_start(
dumpd(self),
input,
name=config.get("run_name"),
)
try:
for idx, branch in enumerate(self.branches):
condition, runnable = branch
expression_value = await condition.ainvoke(
input,
config=patch_config(
config,
callbacks=run_manager.get_child(tag=f"condition:{idx + 1}"),
),
)
if expression_value:
output = await runnable.ainvoke(
input,
config=patch_config(
config,
callbacks=run_manager.get_child(tag=f"branch:{idx + 1}"),
),
**kwargs,
)
break
else:
output = await self.default.ainvoke(
input,
config=patch_config(
config, callbacks=run_manager.get_child(tag="branch:default")
),
**kwargs,
)
except Exception as e:
run_manager.on_chain_error(e)
raise
run_manager.on_chain_end(dumpd(output))
return output
class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
"""
A Runnable that can fallback to other Runnables if it fails.
"""
runnable: Runnable[Input, Output]
fallbacks: Sequence[Runnable[Input, Output]]
exceptions_to_handle: Tuple[Type[BaseException], ...] = (Exception,)
class Config:
arbitrary_types_allowed = True
@property
def InputType(self) -> Type[Input]:
return self.runnable.InputType
@property
def OutputType(self) -> Type[Output]:
return self.runnable.OutputType
@property
def input_schema(self) -> Type[BaseModel]:
return self.runnable.input_schema
@property
def output_schema(self) -> Type[BaseModel]:
return self.runnable.output_schema
@classmethod
def is_lc_serializable(cls) -> bool:
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
return cls.__module__.split(".")[:-1]
@property
def runnables(self) -> Iterator[Runnable[Input, Output]]:
yield self.runnable
yield from self.fallbacks
def invoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Output:
# setup callbacks
config = ensure_config(config)
callback_manager = get_callback_manager_for_config(config)
# start the root run
run_manager = callback_manager.on_chain_start(
dumpd(self), input, name=config.get("run_name")
)
first_error = None
for runnable in self.runnables:
try:
output = runnable.invoke(
input,
patch_config(config, callbacks=run_manager.get_child()),
**kwargs,
)
except self.exceptions_to_handle as e:
if first_error is None:
first_error = e
except BaseException as e:
run_manager.on_chain_error(e)
raise e
else:
run_manager.on_chain_end(output)
return output
if first_error is None:
raise ValueError("No error stored at end of fallbacks.")
run_manager.on_chain_error(first_error)
raise first_error
async def ainvoke(
self,
input: Input,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Output:
# setup callbacks
config = ensure_config(config)
callback_manager = get_async_callback_manager_for_config(config)
# start the root run
run_manager = await callback_manager.on_chain_start(
dumpd(self), input, name=config.get("run_name")
)
first_error = None
for runnable in self.runnables:
try:
output = await runnable.ainvoke(
input,
patch_config(config, callbacks=run_manager.get_child()),
**kwargs,
)
except self.exceptions_to_handle as e:
if first_error is None:
first_error = e
except BaseException as e:
await run_manager.on_chain_error(e)
raise e
else:
await run_manager.on_chain_end(output)
return output
if first_error is None:
raise ValueError("No error stored at end of fallbacks.")
await run_manager.on_chain_error(first_error)
raise first_error
def batch(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> List[Output]:
from langchain.callbacks.manager import CallbackManager
if return_exceptions:
raise NotImplementedError()
if not inputs:
return []
# setup callbacks
configs = get_config_list(config, len(inputs))
callback_managers = [
CallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
local_callbacks=None,
verbose=False,
inheritable_tags=config.get("tags"),
local_tags=None,
inheritable_metadata=config.get("metadata"),
local_metadata=None,
)
for config in configs
]
# start the root runs, one per input
run_managers = [
cm.on_chain_start(
dumpd(self),
input if isinstance(input, dict) else {"input": input},
name=config.get("run_name"),
)
for cm, input, config in zip(callback_managers, inputs, configs)
]
first_error = None
for runnable in self.runnables:
try:
outputs = runnable.batch(
inputs,
[
# each step a child run of the corresponding root run
patch_config(config, callbacks=rm.get_child())
for rm, config in zip(run_managers, configs)
],
return_exceptions=return_exceptions,
**kwargs,
)
except self.exceptions_to_handle as e:
if first_error is None:
first_error = e
except BaseException as e:
for rm in run_managers:
rm.on_chain_error(e)
raise e
else:
for rm, output in zip(run_managers, outputs):
rm.on_chain_end(output)
return outputs
if first_error is None:
raise ValueError("No error stored at end of fallbacks.")
for rm in run_managers:
rm.on_chain_error(first_error)
raise first_error
async def abatch(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> List[Output]:
from langchain.callbacks.manager import AsyncCallbackManager
if return_exceptions:
raise NotImplementedError()
if not inputs:
return []
# setup callbacks
configs = get_config_list(config, len(inputs))
callback_managers = [
AsyncCallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
local_callbacks=None,
verbose=False,
inheritable_tags=config.get("tags"),
local_tags=None,
inheritable_metadata=config.get("metadata"),
local_metadata=None,
)
for config in configs
]
# start the root runs, one per input
run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather(
*(
cm.on_chain_start(
dumpd(self),
input,
name=config.get("run_name"),
)
for cm, input, config in zip(callback_managers, inputs, configs)
)
)
first_error = None
for runnable in self.runnables:
try:
outputs = await runnable.abatch(
inputs,
[
# each step a child run of the corresponding root run
patch_config(config, callbacks=rm.get_child())
for rm, config in zip(run_managers, configs)
],
return_exceptions=return_exceptions,
**kwargs,
)
except self.exceptions_to_handle as e:
if first_error is None:
first_error = e
except BaseException as e:
await asyncio.gather(*(rm.on_chain_error(e) for rm in run_managers))
else:
await asyncio.gather(
*(
rm.on_chain_end(output)
for rm, output in zip(run_managers, outputs)
)
)
return outputs
if first_error is None:
raise ValueError("No error stored at end of fallbacks.")
await asyncio.gather(*(rm.on_chain_error(first_error) for rm in run_managers))
raise first_error
class RunnableSequence(Serializable, Runnable[Input, Output]): class RunnableSequence(RunnableSerializable[Input, Output]):
""" """
A sequence of runnables, where the output of each is the input of the next. A sequence of runnables, where the output of each is the input of the next.
""" """
@ -1749,7 +1321,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
yield chunk yield chunk
class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]): class RunnableMap(RunnableSerializable[Input, Dict[str, Any]]):
""" """
A runnable that runs a mapping of runnables in parallel, A runnable that runs a mapping of runnables in parallel,
and returns a mapping of their outputs. and returns a mapping of their outputs.
@ -1799,7 +1371,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
return create_model( # type: ignore[call-overload] return create_model( # type: ignore[call-overload]
"RunnableMapInput", "RunnableMapInput",
**{ **{
k: (v.type_, v.default) k: (v.annotation, v.default)
for step in self.steps.values() for step in self.steps.values()
for k, v in step.input_schema.__fields__.items() for k, v in step.input_schema.__fields__.items()
if k != "__root__" if k != "__root__"
@ -2374,7 +1946,7 @@ class RunnableLambda(Runnable[Input, Output]):
return await super().ainvoke(input, config) return await super().ainvoke(input, config)
class RunnableEach(Serializable, Runnable[List[Input], List[Output]]): class RunnableEach(RunnableSerializable[List[Input], List[Output]]):
""" """
A runnable that delegates calls to another runnable A runnable that delegates calls to another runnable
with each element of the input sequence. with each element of the input sequence.
@ -2413,6 +1985,11 @@ class RunnableEach(Serializable, Runnable[List[Input], List[Output]]):
), ),
) )
def config_schema(
self, *, include: Optional[Sequence[str]] = None
) -> Type[BaseModel]:
return self.bound.config_schema(include=include)
@classmethod @classmethod
def is_lc_serializable(cls) -> bool: def is_lc_serializable(cls) -> bool:
return True return True
@ -2455,7 +2032,7 @@ class RunnableEach(Serializable, Runnable[List[Input], List[Output]]):
return await self._acall_with_config(self._ainvoke, input, config) return await self._acall_with_config(self._ainvoke, input, config)
class RunnableBinding(Serializable, Runnable[Input, Output]): class RunnableBinding(RunnableSerializable[Input, Output]):
""" """
A runnable that delegates calls to another runnable with a set of kwargs. A runnable that delegates calls to another runnable with a set of kwargs.
""" """
@ -2485,6 +2062,11 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
def output_schema(self) -> Type[BaseModel]: def output_schema(self) -> Type[BaseModel]:
return self.bound.output_schema return self.bound.output_schema
def config_schema(
self, *, include: Optional[Sequence[str]] = None
) -> Type[BaseModel]:
return self.bound.config_schema(include=include)
@classmethod @classmethod
def is_lc_serializable(cls) -> bool: def is_lc_serializable(cls) -> bool:
return True return True

@ -0,0 +1,235 @@
from typing import (
Any,
Awaitable,
Callable,
List,
Mapping,
Optional,
Sequence,
Tuple,
Type,
Union,
cast,
)
from langchain.load.dump import dumpd
from langchain.pydantic_v1 import BaseModel
from langchain.schema.runnable.base import (
Runnable,
RunnableLike,
RunnableSerializable,
coerce_to_runnable,
)
from langchain.schema.runnable.config import (
RunnableConfig,
ensure_config,
get_callback_manager_for_config,
patch_config,
)
from langchain.schema.runnable.utils import Input, Output
class RunnableBranch(RunnableSerializable[Input, Output]):
    """A Runnable that selects which branch to run based on a condition.

    The runnable is initialized with a list of (condition, runnable) pairs and
    a default branch.

    When operating on an input, the first condition that evaluates to True is
    selected, and the corresponding runnable is run on the input.

    If no condition evaluates to True, the default branch is run on the input.

    Examples:

        .. code-block:: python

            from langchain.schema.runnable import RunnableBranch

            branch = RunnableBranch(
                (lambda x: isinstance(x, str), lambda x: x.upper()),
                (lambda x: isinstance(x, int), lambda x: x + 1),
                (lambda x: isinstance(x, float), lambda x: x * 2),
                lambda x: "goodbye",
            )

            branch.invoke("hello") # "HELLO"
            branch.invoke(None) # "goodbye"
    """

    # (condition, runnable) pairs, evaluated in declaration order by invoke/ainvoke.
    branches: Sequence[Tuple[Runnable[Input, bool], Runnable[Input, Output]]]
    # Branch run when no condition evaluates to True.
    default: Runnable[Input, Output]

    def __init__(
        self,
        *branches: Union[
            Tuple[
                Union[
                    Runnable[Input, bool],
                    Callable[[Input], bool],
                    Callable[[Input], Awaitable[bool]],
                ],
                RunnableLike,
            ],
            RunnableLike,  # To accommodate the default branch
        ],
    ) -> None:
        """A Runnable that runs one of two branches based on a condition.

        Args:
            *branches: One or more (condition, runnable) tuples followed by a
                final positional argument used as the default branch.

        Raises:
            ValueError: If fewer than two positional arguments are given, or a
                branch tuple/list does not have exactly two elements.
            TypeError: If the default is not a runnable, callable or mapping,
                or a branch is not a tuple or list.
        """
        if len(branches) < 2:
            raise ValueError("RunnableBranch requires at least two branches")

        # The last positional argument is the default branch.
        default = branches[-1]

        if not isinstance(
            default, (Runnable, Callable, Mapping)  # type: ignore[arg-type]
        ):
            raise TypeError(
                "RunnableBranch default must be runnable, callable or mapping."
            )

        # Coerce plain callables/mappings into Runnables up front so invoke()
        # can treat every branch uniformly.
        default_ = cast(
            Runnable[Input, Output], coerce_to_runnable(cast(RunnableLike, default))
        )

        _branches = []

        for branch in branches[:-1]:
            if not isinstance(branch, (tuple, list)):  # type: ignore[arg-type]
                raise TypeError(
                    f"RunnableBranch branches must be "
                    f"tuples or lists, not {type(branch)}"
                )

            if not len(branch) == 2:
                raise ValueError(
                    f"RunnableBranch branches must be "
                    f"tuples or lists of length 2, not {len(branch)}"
                )
            condition, runnable = branch
            condition = cast(Runnable[Input, bool], coerce_to_runnable(condition))
            runnable = coerce_to_runnable(runnable)
            _branches.append((condition, runnable))

        super().__init__(branches=_branches, default=default_)

    class Config:
        # Runnable instances are not pydantic models; allow them as field values.
        arbitrary_types_allowed = True

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """RunnableBranch is serializable if all its branches are serializable."""
        return True

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        """The namespace of a RunnableBranch is the namespace of its default branch."""
        return cls.__module__.split(".")[:-1]

    @property
    def input_schema(self) -> Type[BaseModel]:
        # Return the first sub-runnable schema with a concrete "type"; fall
        # back to the generic schema from the base class when none declares one.
        runnables = (
            [self.default]
            + [r for _, r in self.branches]
            + [r for r, _ in self.branches]
        )

        for runnable in runnables:
            if runnable.input_schema.schema().get("type") is not None:
                return runnable.input_schema

        return super().input_schema

    def invoke(
        self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
    ) -> Output:
        """First evaluates the condition, then delegate to true or false branch."""
        config = ensure_config(config)
        callback_manager = get_callback_manager_for_config(config)
        # Root run for tracing; conditions and branches are child runs tagged
        # "condition:N" / "branch:N" (1-indexed) or "branch:default".
        run_manager = callback_manager.on_chain_start(
            dumpd(self),
            input,
            name=config.get("run_name"),
        )

        try:
            for idx, branch in enumerate(self.branches):
                condition, runnable = branch

                expression_value = condition.invoke(
                    input,
                    config=patch_config(
                        config,
                        callbacks=run_manager.get_child(tag=f"condition:{idx + 1}"),
                    ),
                )

                # First truthy condition wins; remaining branches are skipped.
                if expression_value:
                    output = runnable.invoke(
                        input,
                        config=patch_config(
                            config,
                            callbacks=run_manager.get_child(tag=f"branch:{idx + 1}"),
                        ),
                        **kwargs,
                    )
                    break
            else:
                # for/else: runs only when no condition matched.
                output = self.default.invoke(
                    input,
                    config=patch_config(
                        config, callbacks=run_manager.get_child(tag="branch:default")
                    ),
                    **kwargs,
                )
        except Exception as e:
            run_manager.on_chain_error(e)
            raise
        run_manager.on_chain_end(dumpd(output))
        return output

    async def ainvoke(
        self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
    ) -> Output:
        """Async version of invoke."""
        config = ensure_config(config)
        # NOTE(review): uses the sync callback manager here (mirrors invoke);
        # confirm an async manager is not required for this code path.
        callback_manager = get_callback_manager_for_config(config)
        run_manager = callback_manager.on_chain_start(
            dumpd(self),
            input,
            name=config.get("run_name"),
        )
        try:
            for idx, branch in enumerate(self.branches):
                condition, runnable = branch

                expression_value = await condition.ainvoke(
                    input,
                    config=patch_config(
                        config,
                        callbacks=run_manager.get_child(tag=f"condition:{idx + 1}"),
                    ),
                )

                # First truthy condition wins; remaining branches are skipped.
                if expression_value:
                    output = await runnable.ainvoke(
                        input,
                        config=patch_config(
                            config,
                            callbacks=run_manager.get_child(tag=f"branch:{idx + 1}"),
                        ),
                        **kwargs,
                    )
                    break
            else:
                # for/else: runs only when no condition matched.
                output = await self.default.ainvoke(
                    input,
                    config=patch_config(
                        config, callbacks=run_manager.get_child(tag="branch:default")
                    ),
                    **kwargs,
                )
        except Exception as e:
            run_manager.on_chain_error(e)
            raise
        run_manager.on_chain_end(dumpd(output))
        return output

@ -0,0 +1,286 @@
import asyncio
from typing import (
TYPE_CHECKING,
Any,
Iterator,
List,
Optional,
Sequence,
Tuple,
Type,
Union,
)
from langchain.load.dump import dumpd
from langchain.pydantic_v1 import BaseModel
from langchain.schema.runnable.base import Runnable, RunnableSerializable
from langchain.schema.runnable.config import (
RunnableConfig,
ensure_config,
get_async_callback_manager_for_config,
get_callback_manager_for_config,
get_config_list,
patch_config,
)
from langchain.schema.runnable.utils import Input, Output
if TYPE_CHECKING:
from langchain.callbacks.manager import AsyncCallbackManagerForChainRun
class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
    """
    A Runnable that can fallback to other Runnables if it fails.

    The primary ``runnable`` is tried first; if it raises one of
    ``exceptions_to_handle``, each fallback is tried in order until one
    succeeds. If every candidate fails with a handled exception, the *first*
    such exception is re-raised. Exceptions outside ``exceptions_to_handle``
    propagate immediately without trying further fallbacks.
    """

    # The primary runnable to invoke first.
    runnable: Runnable[Input, Output]
    # Alternatives, tried in order after ``runnable`` raises a handled error.
    fallbacks: Sequence[Runnable[Input, Output]]
    # Exception types that trigger falling back; anything else propagates.
    exceptions_to_handle: Tuple[Type[BaseException], ...] = (Exception,)

    class Config:
        # Runnable instances are not pydantic models.
        arbitrary_types_allowed = True

    @property
    def InputType(self) -> Type[Input]:
        # Delegate typing/schema questions to the primary runnable.
        return self.runnable.InputType

    @property
    def OutputType(self) -> Type[Output]:
        return self.runnable.OutputType

    @property
    def input_schema(self) -> Type[BaseModel]:
        return self.runnable.input_schema

    @property
    def output_schema(self) -> Type[BaseModel]:
        return self.runnable.output_schema

    def config_schema(
        self, *, include: Optional[Sequence[str]] = None
    ) -> Type[BaseModel]:
        return self.runnable.config_schema(include=include)

    @classmethod
    def is_lc_serializable(cls) -> bool:
        return True

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        # Serialization namespace, e.g. ["langchain", "schema", "runnable"].
        return cls.__module__.split(".")[:-1]

    @property
    def runnables(self) -> Iterator[Runnable[Input, Output]]:
        """Yield the primary runnable followed by each fallback, in order."""
        yield self.runnable
        yield from self.fallbacks

    def invoke(
        self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
    ) -> Output:
        """Invoke candidates in order, returning the first successful output."""
        # setup callbacks
        config = ensure_config(config)
        callback_manager = get_callback_manager_for_config(config)
        # start the root run
        run_manager = callback_manager.on_chain_start(
            dumpd(self), input, name=config.get("run_name")
        )
        first_error = None
        for runnable in self.runnables:
            try:
                output = runnable.invoke(
                    input,
                    patch_config(config, callbacks=run_manager.get_child()),
                    **kwargs,
                )
            except self.exceptions_to_handle as e:
                # Remember only the first handled error; try the next runnable.
                if first_error is None:
                    first_error = e
            except BaseException as e:
                # Unhandled exception type: report and propagate immediately.
                run_manager.on_chain_error(e)
                raise e
            else:
                run_manager.on_chain_end(output)
                return output
        if first_error is None:
            # Unreachable unless runnables is empty; guards the raise below.
            raise ValueError("No error stored at end of fallbacks.")
        run_manager.on_chain_error(first_error)
        raise first_error

    async def ainvoke(
        self,
        input: Input,
        config: Optional[RunnableConfig] = None,
        **kwargs: Optional[Any],
    ) -> Output:
        """Async counterpart of ``invoke`` with identical fallback semantics."""
        # setup callbacks
        config = ensure_config(config)
        callback_manager = get_async_callback_manager_for_config(config)
        # start the root run
        run_manager = await callback_manager.on_chain_start(
            dumpd(self), input, name=config.get("run_name")
        )

        first_error = None
        for runnable in self.runnables:
            try:
                output = await runnable.ainvoke(
                    input,
                    patch_config(config, callbacks=run_manager.get_child()),
                    **kwargs,
                )
            except self.exceptions_to_handle as e:
                if first_error is None:
                    first_error = e
            except BaseException as e:
                await run_manager.on_chain_error(e)
                raise e
            else:
                await run_manager.on_chain_end(output)
                return output
        if first_error is None:
            raise ValueError("No error stored at end of fallbacks.")
        await run_manager.on_chain_error(first_error)
        raise first_error

    def batch(
        self,
        inputs: List[Input],
        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
        *,
        return_exceptions: bool = False,
        **kwargs: Optional[Any],
    ) -> List[Output]:
        """Batch-invoke candidates; a single failure falls the whole batch back.

        ``return_exceptions=True`` is not supported (fallback semantics are
        all-or-nothing per candidate).
        """
        from langchain.callbacks.manager import CallbackManager

        if return_exceptions:
            raise NotImplementedError()

        if not inputs:
            return []

        # setup callbacks
        configs = get_config_list(config, len(inputs))
        callback_managers = [
            CallbackManager.configure(
                inheritable_callbacks=config.get("callbacks"),
                local_callbacks=None,
                verbose=False,
                inheritable_tags=config.get("tags"),
                local_tags=None,
                inheritable_metadata=config.get("metadata"),
                local_metadata=None,
            )
            for config in configs
        ]
        # start the root runs, one per input
        run_managers = [
            cm.on_chain_start(
                dumpd(self),
                input if isinstance(input, dict) else {"input": input},
                name=config.get("run_name"),
            )
            for cm, input, config in zip(callback_managers, inputs, configs)
        ]

        first_error = None
        for runnable in self.runnables:
            try:
                outputs = runnable.batch(
                    inputs,
                    [
                        # each step a child run of the corresponding root run
                        patch_config(config, callbacks=rm.get_child())
                        for rm, config in zip(run_managers, configs)
                    ],
                    return_exceptions=return_exceptions,
                    **kwargs,
                )
            except self.exceptions_to_handle as e:
                if first_error is None:
                    first_error = e
            except BaseException as e:
                for rm in run_managers:
                    rm.on_chain_error(e)
                raise e
            else:
                for rm, output in zip(run_managers, outputs):
                    rm.on_chain_end(output)
                return outputs
        if first_error is None:
            raise ValueError("No error stored at end of fallbacks.")
        for rm in run_managers:
            rm.on_chain_error(first_error)
        raise first_error

    async def abatch(
        self,
        inputs: List[Input],
        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
        *,
        return_exceptions: bool = False,
        **kwargs: Optional[Any],
    ) -> List[Output]:
        """Async counterpart of ``batch`` with identical fallback semantics."""
        from langchain.callbacks.manager import AsyncCallbackManager

        if return_exceptions:
            raise NotImplementedError()

        if not inputs:
            return []

        # setup callbacks
        configs = get_config_list(config, len(inputs))
        callback_managers = [
            AsyncCallbackManager.configure(
                inheritable_callbacks=config.get("callbacks"),
                local_callbacks=None,
                verbose=False,
                inheritable_tags=config.get("tags"),
                local_tags=None,
                inheritable_metadata=config.get("metadata"),
                local_metadata=None,
            )
            for config in configs
        ]
        # start the root runs, one per input
        # NOTE(review): unlike ``batch``, the raw input is passed here without
        # wrapping non-dict inputs in {"input": ...} — confirm intended.
        run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather(
            *(
                cm.on_chain_start(
                    dumpd(self),
                    input,
                    name=config.get("run_name"),
                )
                for cm, input, config in zip(callback_managers, inputs, configs)
            )
        )

        first_error = None
        for runnable in self.runnables:
            try:
                outputs = await runnable.abatch(
                    inputs,
                    [
                        # each step a child run of the corresponding root run
                        patch_config(config, callbacks=rm.get_child())
                        for rm, config in zip(run_managers, configs)
                    ],
                    return_exceptions=return_exceptions,
                    **kwargs,
                )
            except self.exceptions_to_handle as e:
                if first_error is None:
                    first_error = e
            except BaseException as e:
                await asyncio.gather(*(rm.on_chain_error(e) for rm in run_managers))
                # FIX: propagate unhandled exceptions. Previously this arm fell
                # through and continued trying fallbacks, diverging from the
                # sync ``batch`` and ending in a misleading ValueError.
                raise e
            else:
                await asyncio.gather(
                    *(
                        rm.on_chain_end(output)
                        for rm, output in zip(run_managers, outputs)
                    )
                )
                return outputs
        if first_error is None:
            raise ValueError("No error stored at end of fallbacks.")
        await asyncio.gather(*(rm.on_chain_error(first_error) for rm in run_managers))
        raise first_error

@ -16,9 +16,13 @@ from typing import (
cast, cast,
) )
from langchain.load.serializable import Serializable
from langchain.pydantic_v1 import BaseModel, create_model from langchain.pydantic_v1 import BaseModel, create_model
from langchain.schema.runnable.base import Input, Runnable, RunnableMap from langchain.schema.runnable.base import (
Input,
Runnable,
RunnableMap,
RunnableSerializable,
)
from langchain.schema.runnable.config import RunnableConfig, get_executor_for_config from langchain.schema.runnable.config import RunnableConfig, get_executor_for_config
from langchain.schema.runnable.utils import AddableDict from langchain.schema.runnable.utils import AddableDict
from langchain.utils.aiter import atee, py_anext from langchain.utils.aiter import atee, py_anext
@ -33,7 +37,7 @@ async def aidentity(x: Input) -> Input:
return x return x
class RunnablePassthrough(Serializable, Runnable[Input, Input]): class RunnablePassthrough(RunnableSerializable[Input, Input]):
""" """
A runnable that passes through the input. A runnable that passes through the input.
""" """
@ -109,7 +113,7 @@ class RunnablePassthrough(Serializable, Runnable[Input, Input]):
yield chunk yield chunk
class RunnableAssign(Serializable, Runnable[Dict[str, Any], Dict[str, Any]]): class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
""" """
A runnable that assigns key-value pairs to Dict[str, Any] inputs. A runnable that assigns key-value pairs to Dict[str, Any] inputs.
""" """

@ -14,8 +14,13 @@ from typing import (
from typing_extensions import TypedDict from typing_extensions import TypedDict
from langchain.load.serializable import Serializable from langchain.schema.runnable.base import (
from langchain.schema.runnable.base import Input, Output, Runnable, coerce_to_runnable Input,
Output,
Runnable,
RunnableSerializable,
coerce_to_runnable,
)
from langchain.schema.runnable.config import ( from langchain.schema.runnable.config import (
RunnableConfig, RunnableConfig,
get_config_list, get_config_list,
@ -36,7 +41,7 @@ class RouterInput(TypedDict):
input: Any input: Any
class RouterRunnable(Serializable, Runnable[RouterInput, Output]): class RouterRunnable(RunnableSerializable[RouterInput, Output]):
""" """
A runnable that routes to a set of runnables based on Input['key']. A runnable that routes to a set of runnables based on Input['key'].
Returns the output of the selected runnable. Returns the output of the selected runnable.

@ -17,6 +17,7 @@ from langchain.callbacks.manager import (
CallbackManagerForToolRun, CallbackManagerForToolRun,
Callbacks, Callbacks,
) )
from langchain.load.serializable import Serializable
from langchain.pydantic_v1 import ( from langchain.pydantic_v1 import (
BaseModel, BaseModel,
Extra, Extra,
@ -25,7 +26,7 @@ from langchain.pydantic_v1 import (
root_validator, root_validator,
validate_arguments, validate_arguments,
) )
from langchain.schema.runnable import Runnable, RunnableConfig from langchain.schema.runnable import Runnable, RunnableConfig, RunnableSerializable
class SchemaAnnotationError(TypeError): class SchemaAnnotationError(TypeError):
@ -97,7 +98,7 @@ class ToolException(Exception):
pass pass
class BaseTool(BaseModel, Runnable[Union[str, Dict], Any]): class BaseTool(RunnableSerializable[Union[str, Dict], Any]):
"""Interface LangChain tools must implement.""" """Interface LangChain tools must implement."""
def __init_subclass__(cls, **kwargs: Any) -> None: def __init_subclass__(cls, **kwargs: Any) -> None:
@ -165,10 +166,9 @@ class ChildTool(BaseTool):
] = False ] = False
"""Handle the content of the ToolException thrown.""" """Handle the content of the ToolException thrown."""
class Config: class Config(Serializable.Config):
"""Configuration for this pydantic object.""" """Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True arbitrary_types_allowed = True
@property @property

@ -2,7 +2,7 @@
"""Tools for interacting with Spark SQL.""" """Tools for interacting with Spark SQL."""
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from langchain.pydantic_v1 import BaseModel, Extra, Field, root_validator from langchain.pydantic_v1 import BaseModel, Field, root_validator
from langchain.schema.language_model import BaseLanguageModel from langchain.schema.language_model import BaseLanguageModel
from langchain.callbacks.manager import ( from langchain.callbacks.manager import (
@ -21,13 +21,8 @@ class BaseSparkSQLTool(BaseModel):
db: SparkSQL = Field(exclude=True) db: SparkSQL = Field(exclude=True)
# Override BaseTool.Config to appease mypy
# See https://github.com/pydantic/pydantic/issues/4173
class Config(BaseTool.Config): class Config(BaseTool.Config):
"""Configuration for this pydantic object.""" pass
arbitrary_types_allowed = True
extra = Extra.forbid
class QuerySparkSQLTool(BaseSparkSQLTool, BaseTool): class QuerySparkSQLTool(BaseSparkSQLTool, BaseTool):

@ -21,13 +21,8 @@ class BaseSQLDatabaseTool(BaseModel):
db: SQLDatabase = Field(exclude=True) db: SQLDatabase = Field(exclude=True)
# Override BaseTool.Config to appease mypy
# See https://github.com/pydantic/pydantic/issues/4173
class Config(BaseTool.Config): class Config(BaseTool.Config):
"""Configuration for this pydantic object.""" pass
arbitrary_types_allowed = True
extra = Extra.forbid
class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool): class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool):

@ -18,9 +18,7 @@ class BaseVectorStoreTool(BaseModel):
llm: BaseLanguageModel = Field(default_factory=lambda: OpenAI(temperature=0)) llm: BaseLanguageModel = Field(default_factory=lambda: OpenAI(temperature=0))
class Config(BaseTool.Config): class Config(BaseTool.Config):
"""Configuration for this pydantic object.""" pass
arbitrary_types_allowed = True
def _create_description_from_template(values: Dict[str, Any]) -> Dict[str, Any]: def _create_description_from_template(values: Dict[str, Any]) -> Dict[str, Any]:

Loading…
Cancel
Save