From 52e5a8b43e46977f0f6fdea1788488b7201fccc0 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Mon, 2 Oct 2023 10:07:30 +0100 Subject: [PATCH] Create new RunnableSerializable class in preparation for configurable runnables - Also move RunnableBranch to its own file --- libs/langchain/langchain/chains/base.py | 5 +- libs/langchain/langchain/llms/fake.py | 2 +- .../langchain/schema/language_model.py | 5 +- .../langchain/schema/output_parser.py | 11 +- .../langchain/schema/prompt_template.py | 5 +- libs/langchain/langchain/schema/retriever.py | 5 +- .../langchain/schema/runnable/__init__.py | 4 +- .../langchain/schema/runnable/_locals.py | 5 +- .../langchain/schema/runnable/base.py | 247 +++--------------- .../langchain/schema/runnable/branch.py | 234 +++++++++++++++++ .../langchain/schema/runnable/passthrough.py | 12 +- .../langchain/schema/runnable/router.py | 11 +- libs/langchain/langchain/tools/base.py | 5 +- .../langchain/tools/spark_sql/tool.py | 8 - .../langchain/tools/sql_database/tool.py | 8 - .../langchain/tools/vectorstore/tool.py | 5 - 16 files changed, 313 insertions(+), 259 deletions(-) create mode 100644 libs/langchain/langchain/schema/runnable/branch.py diff --git a/libs/langchain/langchain/chains/base.py b/libs/langchain/langchain/chains/base.py index 94a9bdfd0a..fd54cfe6bc 100644 --- a/libs/langchain/langchain/chains/base.py +++ b/libs/langchain/langchain/chains/base.py @@ -21,7 +21,6 @@ from langchain.callbacks.manager import ( Callbacks, ) from langchain.load.dump import dumpd -from langchain.load.serializable import Serializable from langchain.pydantic_v1 import ( BaseModel, Field, @@ -30,7 +29,7 @@ from langchain.pydantic_v1 import ( validator, ) from langchain.schema import RUN_KEY, BaseMemory, RunInfo -from langchain.schema.runnable import Runnable, RunnableConfig +from langchain.schema.runnable import RunnableConfig, RunnableSerializable logger = logging.getLogger(__name__) @@ -39,7 +38,7 @@ def _get_verbosity() -> bool: return langchain.verbose -class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC): +class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC): """Abstract base class for creating structured sequences of calls to components. Chains should be used to encode a sequence of calls to components like diff --git a/libs/langchain/langchain/llms/fake.py b/libs/langchain/langchain/llms/fake.py index b3c22eb644..cb3ea2792f 100644 --- a/libs/langchain/langchain/llms/fake.py +++ b/libs/langchain/langchain/llms/fake.py @@ -14,7 +14,7 @@ from langchain.schema.runnable import RunnableConfig class FakeListLLM(LLM): """Fake LLM for testing purposes.""" - responses: List + responses: List[str] sleep: Optional[float] = None i: int = 0 diff --git a/libs/langchain/langchain/schema/language_model.py b/libs/langchain/langchain/schema/language_model.py index 16e8edbc9c..c4e8e5169d 100644 --- a/libs/langchain/langchain/schema/language_model.py +++ b/libs/langchain/langchain/schema/language_model.py @@ -15,11 +15,10 @@ from typing import ( from typing_extensions import TypeAlias -from langchain.load.serializable import Serializable from langchain.schema.messages import AnyMessage, BaseMessage, get_buffer_string from langchain.schema.output import LLMResult from langchain.schema.prompt import PromptValue -from langchain.schema.runnable import Runnable +from langchain.schema.runnable import RunnableSerializable from langchain.utils import get_pydantic_field_names if TYPE_CHECKING: @@ -54,7 +53,7 @@ LanguageModelOutput = TypeVar("LanguageModelOutput") class BaseLanguageModel( - Serializable, Runnable[LanguageModelInput, LanguageModelOutput], ABC + RunnableSerializable[LanguageModelInput, LanguageModelOutput], ABC ): """Abstract base class for interfacing with language models. diff --git a/libs/langchain/langchain/schema/output_parser.py b/libs/langchain/langchain/schema/output_parser.py index 157c1cd5f0..c675dfe49b 100644 --- a/libs/langchain/langchain/schema/output_parser.py +++ b/libs/langchain/langchain/schema/output_parser.py @@ -16,7 +16,6 @@ from typing import ( from typing_extensions import get_args -from langchain.load.serializable import Serializable from langchain.schema.messages import AnyMessage, BaseMessage, BaseMessageChunk from langchain.schema.output import ( ChatGeneration, @@ -25,12 +24,12 @@ from langchain.schema.output import ( GenerationChunk, ) from langchain.schema.prompt import PromptValue -from langchain.schema.runnable import Runnable, RunnableConfig +from langchain.schema.runnable import RunnableConfig, RunnableSerializable T = TypeVar("T") -class BaseLLMOutputParser(Serializable, Generic[T], ABC): +class BaseLLMOutputParser(Generic[T], ABC): """Abstract base class for parsing the outputs of a model.""" @abstractmethod @@ -63,7 +62,7 @@ class BaseLLMOutputParser(Serializable, Generic[T], ABC): class BaseGenerationOutputParser( - BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T] + BaseLLMOutputParser, RunnableSerializable[Union[str, BaseMessage], T] ): """Base class to parse the output of an LLM call.""" @@ -121,7 +120,9 @@ class BaseGenerationOutputParser( ) -class BaseOutputParser(BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T]): +class BaseOutputParser( + BaseLLMOutputParser, RunnableSerializable[Union[str, BaseMessage], T] +): """Base class to parse the output of an LLM call. Output parsers help structure language model responses. diff --git a/libs/langchain/langchain/schema/prompt_template.py b/libs/langchain/langchain/schema/prompt_template.py index ab790753aa..b72e2fe55e 100644 --- a/libs/langchain/langchain/schema/prompt_template.py +++ b/libs/langchain/langchain/schema/prompt_template.py @@ -7,15 +7,14 @@ from typing import Any, Callable, Dict, List, Mapping, Optional, Union import yaml -from langchain.load.serializable import Serializable from langchain.pydantic_v1 import BaseModel, Field, create_model, root_validator from langchain.schema.document import Document from langchain.schema.output_parser import BaseOutputParser from langchain.schema.prompt import PromptValue -from langchain.schema.runnable import Runnable, RunnableConfig +from langchain.schema.runnable import RunnableConfig, RunnableSerializable -class BasePromptTemplate(Serializable, Runnable[Dict, PromptValue], ABC): +class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC): """Base class for all prompt templates, returning a prompt.""" input_variables: List[str] diff --git a/libs/langchain/langchain/schema/retriever.py b/libs/langchain/langchain/schema/retriever.py index 04c2835434..25934eb3ed 100644 --- a/libs/langchain/langchain/schema/retriever.py +++ b/libs/langchain/langchain/schema/retriever.py @@ -6,9 +6,8 @@ from inspect import signature from typing import TYPE_CHECKING, Any, Dict, List, Optional from langchain.load.dump import dumpd -from langchain.load.serializable import Serializable from langchain.schema.document import Document -from langchain.schema.runnable import Runnable, RunnableConfig +from langchain.schema.runnable import RunnableConfig, RunnableSerializable if TYPE_CHECKING: from langchain.callbacks.manager import ( @@ -18,7 +17,7 @@ if TYPE_CHECKING: ) -class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC): +class BaseRetriever(RunnableSerializable[str, List[Document]], ABC): """Abstract base class for a Document retrieval system. A retrieval system is defined as something that can take string queries and return diff --git a/libs/langchain/langchain/schema/runnable/__init__.py b/libs/langchain/langchain/schema/runnable/__init__.py index 2b068d5eba..bde6121bed 100644 --- a/libs/langchain/langchain/schema/runnable/__init__.py +++ b/libs/langchain/langchain/schema/runnable/__init__.py @@ -2,12 +2,13 @@ from langchain.schema.runnable._locals import GetLocalVar, PutLocalVar from langchain.schema.runnable.base import ( Runnable, RunnableBinding, - RunnableBranch, RunnableLambda, RunnableMap, RunnableSequence, + RunnableSerializable, RunnableWithFallbacks, ) +from langchain.schema.runnable.branch import RunnableBranch from langchain.schema.runnable.config import RunnableConfig, patch_config from langchain.schema.runnable.passthrough import RunnablePassthrough from langchain.schema.runnable.router import RouterInput, RouterRunnable @@ -19,6 +20,7 @@ __all__ = [ "RouterInput", "RouterRunnable", "Runnable", + "RunnableSerializable", "RunnableBinding", "RunnableBranch", "RunnableConfig", diff --git a/libs/langchain/langchain/schema/runnable/_locals.py b/libs/langchain/langchain/schema/runnable/_locals.py index 839bb5fbc8..e2fe854114 100644 --- a/libs/langchain/langchain/schema/runnable/_locals.py +++ b/libs/langchain/langchain/schema/runnable/_locals.py @@ -11,8 +11,7 @@ from typing import ( Union, ) -from langchain.load.serializable import Serializable -from langchain.schema.runnable.base import Input, Output, Runnable +from langchain.schema.runnable.base import Input, Output, RunnableSerializable from langchain.schema.runnable.config import RunnableConfig from langchain.schema.runnable.passthrough import RunnablePassthrough @@ -104,7 +103,7 @@ class PutLocalVar(RunnablePassthrough): class GetLocalVar( - Serializable, Runnable[Input, Union[Output, Dict[str, Union[Input, Output]]]] + RunnableSerializable[Input, Union[Output, Dict[str, Union[Input, Output]]]] ): key: str """The key to extract from the local state.""" diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index 7286333cf4..35939b128c 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -119,6 +119,24 @@ class Runnable(Generic[Input, Output], ABC): self.__class__.__name__ + "Output", __root__=(root_type, None) ) + def config_schema( + self, *, include: Optional[Sequence[str]] = None + ) -> Type[BaseModel]: + class _Config: + arbitrary_types_allowed = True + + include = include or [] + + return create_model( # type: ignore[call-overload] + self.__class__.__name__ + "Config", + __config__=_Config, + **{ + field_name: (field_type, None) + for field_name, field_type in RunnableConfig.__annotations__.items() + if field_name in include + }, + ) + def __or__( self, other: Union[ @@ -812,209 +830,11 @@ class Runnable(Generic[Input, Output], ABC): await run_manager.on_chain_end(final_output, inputs=final_input) -class RunnableBranch(Serializable, Runnable[Input, Output]): - """A Runnable that selects which branch to run based on a condition. - - The runnable is initialized with a list of (condition, runnable) pairs and - a default branch. - - When operating on an input, the first condition that evaluates to True is - selected, and the corresponding runnable is run on the input. - - If no condition evaluates to True, the default branch is run on the input. - - Examples: - - .. code-block:: python - - from langchain.schema.runnable import RunnableBranch - - branch = RunnableBranch( - (lambda x: isinstance(x, str), lambda x: x.upper()), - (lambda x: isinstance(x, int), lambda x: x + 1), - (lambda x: isinstance(x, float), lambda x: x * 2), - lambda x: "goodbye", - ) - - branch.invoke("hello") # "HELLO" - branch.invoke(None) # "goodbye" - """ - - branches: Sequence[Tuple[Runnable[Input, bool], Runnable[Input, Output]]] - default: Runnable[Input, Output] - - def __init__( - self, - *branches: Union[ - Tuple[ - Union[ - Runnable[Input, bool], - Callable[[Input], bool], - Callable[[Input], Awaitable[bool]], - ], - RunnableLike, - ], - RunnableLike, # To accommodate the default branch - ], - ) -> None: - """A Runnable that runs one of two branches based on a condition.""" - if len(branches) < 2: - raise ValueError("RunnableBranch requires at least two branches") - - default = branches[-1] - - if not isinstance( - default, (Runnable, Callable, Mapping) # type: ignore[arg-type] - ): - raise TypeError( - "RunnableBranch default must be runnable, callable or mapping." - ) - - default_ = cast( - Runnable[Input, Output], coerce_to_runnable(cast(RunnableLike, default)) - ) - - _branches = [] - - for branch in branches[:-1]: - if not isinstance(branch, (tuple, list)): # type: ignore[arg-type] - raise TypeError( - f"RunnableBranch branches must be " - f"tuples or lists, not {type(branch)}" - ) - - if not len(branch) == 2: - raise ValueError( - f"RunnableBranch branches must be " - f"tuples or lists of length 2, not {len(branch)}" - ) - condition, runnable = branch - condition = cast(Runnable[Input, bool], coerce_to_runnable(condition)) - runnable = coerce_to_runnable(runnable) - _branches.append((condition, runnable)) - - super().__init__(branches=_branches, default=default_) - - class Config: - arbitrary_types_allowed = True - - @classmethod - def is_lc_serializable(cls) -> bool: - """RunnableBranch is serializable if all its branches are serializable.""" - return True - - @classmethod - def get_lc_namespace(cls) -> List[str]: - """The namespace of a RunnableBranch is the namespace of its default branch.""" - return cls.__module__.split(".")[:-1] - - @property - def input_schema(self) -> type[BaseModel]: - runnables = ( - [self.default] - + [r for _, r in self.branches] - + [r for r, _ in self.branches] - ) - - for runnable in runnables: - if runnable.input_schema.schema().get("type") is not None: - return runnable.input_schema - - return super().input_schema - - def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output: - """First evaluates the condition, then delegate to true or false branch.""" - config = ensure_config(config) - callback_manager = get_callback_manager_for_config(config) - run_manager = callback_manager.on_chain_start( - dumpd(self), - input, - name=config.get("run_name"), - ) - - try: - for idx, branch in enumerate(self.branches): - condition, runnable = branch - - expression_value = condition.invoke( - input, - config=patch_config( - config, - callbacks=run_manager.get_child(tag=f"condition:{idx + 1}"), - ), - ) +class RunnableSerializable(Serializable, Runnable[Input, Output]): + pass - if expression_value: - output = runnable.invoke( - input, - config=patch_config( - config, - callbacks=run_manager.get_child(tag=f"branch:{idx + 1}"), - ), - ) - break - else: - output = self.default.invoke( - input, - config=patch_config( - config, callbacks=run_manager.get_child(tag="branch:default") - ), - ) - except Exception as e: - run_manager.on_chain_error(e) - raise - run_manager.on_chain_end(dumpd(output)) - return output - async def ainvoke( - self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any - ) -> Output: - """Async version of invoke.""" - config = ensure_config(config) - callback_manager = get_callback_manager_for_config(config) - run_manager = callback_manager.on_chain_start( - dumpd(self), - input, - name=config.get("run_name"), - ) - try: - for idx, branch in enumerate(self.branches): - condition, runnable = branch - - expression_value = await condition.ainvoke( - input, - config=patch_config( - config, - callbacks=run_manager.get_child(tag=f"condition:{idx + 1}"), - ), - ) - - if expression_value: - output = await runnable.ainvoke( - input, - config=patch_config( - config, - callbacks=run_manager.get_child(tag=f"branch:{idx + 1}"), - ), - **kwargs, - ) - break - else: - output = await self.default.ainvoke( - input, - config=patch_config( - config, callbacks=run_manager.get_child(tag="branch:default") - ), - **kwargs, - ) - except Exception as e: - run_manager.on_chain_error(e) - raise - run_manager.on_chain_end(dumpd(output)) - return output - - -class RunnableWithFallbacks(Serializable, Runnable[Input, Output]): +class RunnableWithFallbacks(RunnableSerializable[Input, Output]): """ A Runnable that can fallback to other Runnables if it fails. """ @@ -1042,6 +862,11 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]): def output_schema(self) -> Type[BaseModel]: return self.runnable.output_schema + def config_schema( + self, *, include: Optional[Sequence[str]] = None + ) -> Type[BaseModel]: + return self.runnable.config_schema(include=include) + @classmethod def is_lc_serializable(cls) -> bool: return True @@ -1267,7 +1092,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]): raise first_error -class RunnableSequence(Serializable, Runnable[Input, Output]): +class RunnableSequence(RunnableSerializable[Input, Output]): """ A sequence of runnables, where the output of each is the input of the next. """ @@ -1749,7 +1574,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): yield chunk -class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]): +class RunnableMap(RunnableSerializable[Input, Dict[str, Any]]): """ A runnable that runs a mapping of runnables in parallel, and returns a mapping of their outputs. @@ -1799,7 +1624,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]): return create_model( # type: ignore[call-overload] "RunnableMapInput", **{ - k: (v.type_, v.default) + k: (v.annotation, v.default) for step in self.steps.values() for k, v in step.input_schema.__fields__.items() if k != "__root__" @@ -2374,7 +2199,7 @@ class RunnableLambda(Runnable[Input, Output]): return await super().ainvoke(input, config) -class RunnableEach(Serializable, Runnable[List[Input], List[Output]]): +class RunnableEach(RunnableSerializable[List[Input], List[Output]]): """ A runnable that delegates calls to another runnable with each element of the input sequence. @@ -2413,6 +2238,11 @@ class RunnableEach(Serializable, Runnable[List[Input], List[Output]]): ), ) + def config_schema( + self, *, include: Optional[Sequence[str]] = None + ) -> Type[BaseModel]: + return self.bound.config_schema(include=include) + @classmethod def is_lc_serializable(cls) -> bool: return True @@ -2455,7 +2285,7 @@ class RunnableEach(Serializable, Runnable[List[Input], List[Output]]): return await self._acall_with_config(self._ainvoke, input, config) -class RunnableBinding(Serializable, Runnable[Input, Output]): +class RunnableBinding(RunnableSerializable[Input, Output]): """ A runnable that delegates calls to another runnable with a set of kwargs. """ @@ -2485,6 +2315,11 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): def output_schema(self) -> Type[BaseModel]: return self.bound.output_schema + def config_schema( + self, *, include: Optional[Sequence[str]] = None + ) -> Type[BaseModel]: + return self.bound.config_schema(include=include) + @classmethod def is_lc_serializable(cls) -> bool: return True diff --git a/libs/langchain/langchain/schema/runnable/branch.py b/libs/langchain/langchain/schema/runnable/branch.py new file mode 100644 index 0000000000..d609fedeff --- /dev/null +++ b/libs/langchain/langchain/schema/runnable/branch.py @@ -0,0 +1,234 @@ +from typing import ( + Any, + Awaitable, + Callable, + List, + Mapping, + Optional, + Sequence, + Tuple, + Union, + cast, +) + +from langchain.load.dump import dumpd +from langchain.pydantic_v1 import BaseModel +from langchain.schema.runnable.base import ( + Runnable, + RunnableLike, + RunnableSerializable, + coerce_to_runnable, +) +from langchain.schema.runnable.config import ( + RunnableConfig, + ensure_config, + get_callback_manager_for_config, + patch_config, +) +from langchain.schema.runnable.utils import Input, Output + + +class RunnableBranch(RunnableSerializable[Input, Output]): + """A Runnable that selects which branch to run based on a condition. + + The runnable is initialized with a list of (condition, runnable) pairs and + a default branch. + + When operating on an input, the first condition that evaluates to True is + selected, and the corresponding runnable is run on the input. + + If no condition evaluates to True, the default branch is run on the input. + + Examples: + + .. code-block:: python + + from langchain.schema.runnable import RunnableBranch + + branch = RunnableBranch( + (lambda x: isinstance(x, str), lambda x: x.upper()), + (lambda x: isinstance(x, int), lambda x: x + 1), + (lambda x: isinstance(x, float), lambda x: x * 2), + lambda x: "goodbye", + ) + + branch.invoke("hello") # "HELLO" + branch.invoke(None) # "goodbye" + """ + + branches: Sequence[Tuple[Runnable[Input, bool], Runnable[Input, Output]]] + default: Runnable[Input, Output] + + def __init__( + self, + *branches: Union[ + Tuple[ + Union[ + Runnable[Input, bool], + Callable[[Input], bool], + Callable[[Input], Awaitable[bool]], + ], + RunnableLike, + ], + RunnableLike, # To accommodate the default branch + ], + ) -> None: + """A Runnable that runs one of two branches based on a condition.""" + if len(branches) < 2: + raise ValueError("RunnableBranch requires at least two branches") + + default = branches[-1] + + if not isinstance( + default, (Runnable, Callable, Mapping) # type: ignore[arg-type] + ): + raise TypeError( + "RunnableBranch default must be runnable, callable or mapping." + ) + + default_ = cast( + Runnable[Input, Output], coerce_to_runnable(cast(RunnableLike, default)) + ) + + _branches = [] + + for branch in branches[:-1]: + if not isinstance(branch, (tuple, list)): # type: ignore[arg-type] + raise TypeError( + f"RunnableBranch branches must be " + f"tuples or lists, not {type(branch)}" + ) + + if not len(branch) == 2: + raise ValueError( + f"RunnableBranch branches must be " + f"tuples or lists of length 2, not {len(branch)}" + ) + condition, runnable = branch + condition = cast(Runnable[Input, bool], coerce_to_runnable(condition)) + runnable = coerce_to_runnable(runnable) + _branches.append((condition, runnable)) + + super().__init__(branches=_branches, default=default_) + + class Config: + arbitrary_types_allowed = True + + @classmethod + def is_lc_serializable(cls) -> bool: + """RunnableBranch is serializable if all its branches are serializable.""" + return True + + @classmethod + def get_lc_namespace(cls) -> List[str]: + """The namespace of a RunnableBranch is the namespace of its default branch.""" + return cls.__module__.split(".")[:-1] + + @property + def input_schema(self) -> type[BaseModel]: + runnables = ( + [self.default] + + [r for _, r in self.branches] + + [r for r, _ in self.branches] + ) + + for runnable in runnables: + if runnable.input_schema.schema().get("type") is not None: + return runnable.input_schema + + return super().input_schema + + def invoke( + self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any + ) -> Output: + """First evaluates the condition, then delegate to true or false branch.""" + config = ensure_config(config) + callback_manager = get_callback_manager_for_config(config) + run_manager = callback_manager.on_chain_start( + dumpd(self), + input, + name=config.get("run_name"), + ) + + try: + for idx, branch in enumerate(self.branches): + condition, runnable = branch + + expression_value = condition.invoke( + input, + config=patch_config( + config, + callbacks=run_manager.get_child(tag=f"condition:{idx + 1}"), + ), + ) + + if expression_value: + output = runnable.invoke( + input, + config=patch_config( + config, + callbacks=run_manager.get_child(tag=f"branch:{idx + 1}"), + ), + **kwargs, + ) + break + else: + output = self.default.invoke( + input, + config=patch_config( + config, callbacks=run_manager.get_child(tag="branch:default") + ), + **kwargs, + ) + except Exception as e: + run_manager.on_chain_error(e) + raise + run_manager.on_chain_end(dumpd(output)) + return output + + async def ainvoke( + self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any + ) -> Output: + """Async version of invoke.""" + config = ensure_config(config) + callback_manager = get_callback_manager_for_config(config) + run_manager = callback_manager.on_chain_start( + dumpd(self), + input, + name=config.get("run_name"), + ) + try: + for idx, branch in enumerate(self.branches): + condition, runnable = branch + + expression_value = await condition.ainvoke( + input, + config=patch_config( + config, + callbacks=run_manager.get_child(tag=f"condition:{idx + 1}"), + ), + ) + + if expression_value: + output = await runnable.ainvoke( + input, + config=patch_config( + config, + callbacks=run_manager.get_child(tag=f"branch:{idx + 1}"), + ), + **kwargs, + ) + break + else: + output = await self.default.ainvoke( + input, + config=patch_config( + config, callbacks=run_manager.get_child(tag="branch:default") + ), + **kwargs, + ) + except Exception as e: + run_manager.on_chain_error(e) + raise + run_manager.on_chain_end(dumpd(output)) + return output diff --git a/libs/langchain/langchain/schema/runnable/passthrough.py b/libs/langchain/langchain/schema/runnable/passthrough.py index 18afe82591..1d1b046a57 100644 --- a/libs/langchain/langchain/schema/runnable/passthrough.py +++ b/libs/langchain/langchain/schema/runnable/passthrough.py @@ -16,9 +16,13 @@ from typing import ( cast, ) -from langchain.load.serializable import Serializable from langchain.pydantic_v1 import BaseModel, create_model -from langchain.schema.runnable.base import Input, Runnable, RunnableMap +from langchain.schema.runnable.base import ( + Input, + Runnable, + RunnableMap, + RunnableSerializable, +) from langchain.schema.runnable.config import RunnableConfig, get_executor_for_config from langchain.schema.runnable.utils import AddableDict from langchain.utils.aiter import atee, py_anext @@ -33,7 +37,7 @@ async def aidentity(x: Input) -> Input: return x -class RunnablePassthrough(Serializable, Runnable[Input, Input]): +class RunnablePassthrough(RunnableSerializable[Input, Input]): """ A runnable that passes through the input. """ @@ -109,7 +113,7 @@ class RunnablePassthrough(Serializable, Runnable[Input, Input]): yield chunk -class RunnableAssign(Serializable, Runnable[Dict[str, Any], Dict[str, Any]]): +class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]): """ A runnable that assigns key-value pairs to Dict[str, Any] inputs. """ diff --git a/libs/langchain/langchain/schema/runnable/router.py b/libs/langchain/langchain/schema/runnable/router.py index f697c0328c..9638235fc8 100644 --- a/libs/langchain/langchain/schema/runnable/router.py +++ b/libs/langchain/langchain/schema/runnable/router.py @@ -14,8 +14,13 @@ from typing import ( from typing_extensions import TypedDict -from langchain.load.serializable import Serializable -from langchain.schema.runnable.base import Input, Output, Runnable, coerce_to_runnable +from langchain.schema.runnable.base import ( + Input, + Output, + Runnable, + RunnableSerializable, + coerce_to_runnable, +) from langchain.schema.runnable.config import ( RunnableConfig, get_config_list, @@ -36,7 +41,7 @@ class RouterInput(TypedDict): input: Any -class RouterRunnable(Serializable, Runnable[RouterInput, Output]): +class RouterRunnable(RunnableSerializable[RouterInput, Output]): """ A runnable that routes to a set of runnables based on Input['key']. Returns the output of the selected runnable. diff --git a/libs/langchain/langchain/tools/base.py b/libs/langchain/langchain/tools/base.py index 269e2b4846..2ae81d246b 100644 --- a/libs/langchain/langchain/tools/base.py +++ b/libs/langchain/langchain/tools/base.py @@ -25,7 +25,7 @@ from langchain.pydantic_v1 import ( root_validator, validate_arguments, ) -from langchain.schema.runnable import Runnable, RunnableConfig +from langchain.schema.runnable import Runnable, RunnableConfig, RunnableSerializable class SchemaAnnotationError(TypeError): @@ -97,7 +97,7 @@ class ToolException(Exception): pass -class BaseTool(BaseModel, Runnable[Union[str, Dict], Any]): +class BaseTool(RunnableSerializable[Union[str, Dict], Any]): """Interface LangChain tools must implement.""" def __init_subclass__(cls, **kwargs: Any) -> None: @@ -168,7 +168,6 @@ class ChildTool(BaseTool): class Config: """Configuration for this pydantic object.""" - extra = Extra.forbid arbitrary_types_allowed = True @property diff --git a/libs/langchain/langchain/tools/spark_sql/tool.py b/libs/langchain/langchain/tools/spark_sql/tool.py index 4a650f7d51..c79bfd193a 100644 --- a/libs/langchain/langchain/tools/spark_sql/tool.py +++ b/libs/langchain/langchain/tools/spark_sql/tool.py @@ -21,14 +21,6 @@ class BaseSparkSQLTool(BaseModel): db: SparkSQL = Field(exclude=True) - # Override BaseTool.Config to appease mypy - # See https://github.com/pydantic/pydantic/issues/4173 - class Config(BaseTool.Config): - """Configuration for this pydantic object.""" - - arbitrary_types_allowed = True - extra = Extra.forbid - class QuerySparkSQLTool(BaseSparkSQLTool, BaseTool): """Tool for querying a Spark SQL.""" diff --git a/libs/langchain/langchain/tools/sql_database/tool.py b/libs/langchain/langchain/tools/sql_database/tool.py index 75f45c7b9e..eba921c163 100644 --- a/libs/langchain/langchain/tools/sql_database/tool.py +++ b/libs/langchain/langchain/tools/sql_database/tool.py @@ -21,14 +21,6 @@ class BaseSQLDatabaseTool(BaseModel): db: SQLDatabase = Field(exclude=True) - # Override BaseTool.Config to appease mypy - # See https://github.com/pydantic/pydantic/issues/4173 - class Config(BaseTool.Config): - """Configuration for this pydantic object.""" - - arbitrary_types_allowed = True - extra = Extra.forbid - class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool): """Tool for querying a SQL database.""" diff --git a/libs/langchain/langchain/tools/vectorstore/tool.py b/libs/langchain/langchain/tools/vectorstore/tool.py index a0507964e7..c62145c4c9 100644 --- a/libs/langchain/langchain/tools/vectorstore/tool.py +++ b/libs/langchain/langchain/tools/vectorstore/tool.py @@ -17,11 +17,6 @@ class BaseVectorStoreTool(BaseModel): vectorstore: VectorStore = Field(exclude=True) llm: BaseLanguageModel = Field(default_factory=lambda: OpenAI(temperature=0)) - class Config(BaseTool.Config): - """Configuration for this pydantic object.""" - - arbitrary_types_allowed = True - def _create_description_from_template(values: Dict[str, Any]) -> Dict[str, Any]: values["description"] = values["template"].format(name=values["name"])