Create new RunnableSerializable class in preparation for configurable runnables

- Also move RunnableBranch to its own file
pull/11279/head
Nuno Campos 9 months ago
parent 33eb5f8300
commit 52e5a8b43e
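
For context: the new base class simply combines Serializable with Runnable, so downstream classes (Chain, BaseLanguageModel, BasePromptTemplate, BaseRetriever, and the runnable combinators) inherit both the serialization machinery and the runnable protocol from a single parent. A minimal sketch of the pattern, mirroring the definition this commit adds to base.py (import paths as they appear in this diff):

    from langchain.load.serializable import Serializable
    from langchain.schema.runnable.base import Runnable
    from langchain.schema.runnable.utils import Input, Output

    class RunnableSerializable(Serializable, Runnable[Input, Output]):
        pass

    # Subclasses parameterize the input/output types, e.g.
    # class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC): ...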

@@ -21,7 +21,6 @@ from langchain.callbacks.manager import (
Callbacks,
)
from langchain.load.dump import dumpd
from langchain.load.serializable import Serializable
from langchain.pydantic_v1 import (
BaseModel,
Field,
@@ -30,7 +29,7 @@ from langchain.pydantic_v1 import (
validator,
)
from langchain.schema import RUN_KEY, BaseMemory, RunInfo
from langchain.schema.runnable import Runnable, RunnableConfig
from langchain.schema.runnable import RunnableConfig, RunnableSerializable
logger = logging.getLogger(__name__)
@@ -39,7 +38,7 @@ def _get_verbosity() -> bool:
return langchain.verbose
class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
"""Abstract base class for creating structured sequences of calls to components.
Chains should be used to encode a sequence of calls to components like

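With Chain now subclassing RunnableSerializable[Dict[str, Any], Dict[str, Any]], every chain exposes the standard runnable interface over dict inputs and outputs. A rough sketch with a made-up minimal chain (EchoChain is hypothetical, for illustration only):

    from typing import Any, Dict, List
    from langchain.chains.base import Chain

    class EchoChain(Chain):
        @property
        def input_keys(self) -> List[str]:
            return ["text"]

        @property
        def output_keys(self) -> List[str]:
            return ["echoed"]

        def _call(self, inputs: Dict[str, Any], run_manager: Any = None) -> Dict[str, str]:
            return {"echoed": inputs["text"]}

    # invoke() comes from the Runnable side of RunnableSerializable
    EchoChain().invoke({"text": "hello"})  # -> {"text": "hello", "echoed": "hello"} (inputs echoed back by default)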
@@ -14,7 +14,7 @@ from langchain.schema.runnable import RunnableConfig
class FakeListLLM(LLM):
"""Fake LLM for testing purposes."""
responses: List
responses: List[str]
sleep: Optional[float] = None
i: int = 0

@@ -15,11 +15,10 @@ from typing import (
from typing_extensions import TypeAlias
from langchain.load.serializable import Serializable
from langchain.schema.messages import AnyMessage, BaseMessage, get_buffer_string
from langchain.schema.output import LLMResult
from langchain.schema.prompt import PromptValue
from langchain.schema.runnable import Runnable
from langchain.schema.runnable import RunnableSerializable
from langchain.utils import get_pydantic_field_names
if TYPE_CHECKING:
@@ -54,7 +53,7 @@ LanguageModelOutput = TypeVar("LanguageModelOutput")
class BaseLanguageModel(
Serializable, Runnable[LanguageModelInput, LanguageModelOutput], ABC
RunnableSerializable[LanguageModelInput, LanguageModelOutput], ABC
):
"""Abstract base class for interfacing with language models.

@@ -16,7 +16,6 @@ from typing import (
from typing_extensions import get_args
from langchain.load.serializable import Serializable
from langchain.schema.messages import AnyMessage, BaseMessage, BaseMessageChunk
from langchain.schema.output import (
ChatGeneration,
@@ -25,12 +24,12 @@ from langchain.schema.output import (
GenerationChunk,
)
from langchain.schema.prompt import PromptValue
from langchain.schema.runnable import Runnable, RunnableConfig
from langchain.schema.runnable import RunnableConfig, RunnableSerializable
T = TypeVar("T")
class BaseLLMOutputParser(Serializable, Generic[T], ABC):
class BaseLLMOutputParser(Generic[T], ABC):
"""Abstract base class for parsing the outputs of a model."""
@abstractmethod
@@ -63,7 +62,7 @@ class BaseLLMOutputParser(Serializable, Generic[T], ABC):
class BaseGenerationOutputParser(
BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T]
BaseLLMOutputParser, RunnableSerializable[Union[str, BaseMessage], T]
):
"""Base class to parse the output of an LLM call."""
@@ -121,7 +120,9 @@ class BaseGenerationOutputParser(
)
class BaseOutputParser(BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T]):
class BaseOutputParser(
BaseLLMOutputParser, RunnableSerializable[Union[str, BaseMessage], T]
):
"""Base class to parse the output of an LLM call.
Output parsers help structure language model responses.
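
Because BaseOutputParser is now RunnableSerializable over string-or-message input, a concrete parser only needs parse() and can then be invoked or composed with | like any other runnable. A small illustrative sketch (CommaListParser is not part of the diff):

    from typing import List
    from langchain.schema.output_parser import BaseOutputParser

    class CommaListParser(BaseOutputParser[List[str]]):
        def parse(self, text: str) -> List[str]:
            return [part.strip() for part in text.split(",")]

    CommaListParser().invoke("red, green, blue")  # ["red", "green", "blue"]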

@@ -7,15 +7,14 @@ from typing import Any, Callable, Dict, List, Mapping, Optional, Union
import yaml
from langchain.load.serializable import Serializable
from langchain.pydantic_v1 import BaseModel, Field, create_model, root_validator
from langchain.schema.document import Document
from langchain.schema.output_parser import BaseOutputParser
from langchain.schema.prompt import PromptValue
from langchain.schema.runnable import Runnable, RunnableConfig
from langchain.schema.runnable import RunnableConfig, RunnableSerializable
class BasePromptTemplate(Serializable, Runnable[Dict, PromptValue], ABC):
class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC):
"""Base class for all prompt templates, returning a prompt."""
input_variables: List[str]
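
BasePromptTemplate keeps its Dict -> PromptValue signature under the new base class, so concrete templates stay directly invokable, e.g.:

    from langchain.prompts import PromptTemplate

    prompt = PromptTemplate.from_template("Tell me a joke about {topic}")
    prompt.invoke({"topic": "bears"})  # a PromptValue wrapping "Tell me a joke about bears"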

@@ -6,9 +6,8 @@ from inspect import signature
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from langchain.load.dump import dumpd
from langchain.load.serializable import Serializable
from langchain.schema.document import Document
from langchain.schema.runnable import Runnable, RunnableConfig
from langchain.schema.runnable import RunnableConfig, RunnableSerializable
if TYPE_CHECKING:
from langchain.callbacks.manager import (
@@ -18,7 +17,7 @@ if TYPE_CHECKING:
)
class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
class BaseRetriever(RunnableSerializable[str, List[Document]], ABC):
"""Abstract base class for a Document retrieval system.
A retrieval system is defined as something that can take string queries and return

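BaseRetriever is now RunnableSerializable[str, List[Document]], so retrievers are invoked with a query string and return documents. A minimal sketch, assuming the usual _get_relevant_documents hook (KeywordRetriever is made up for illustration):

    from typing import Any, List
    from langchain.schema import BaseRetriever, Document

    class KeywordRetriever(BaseRetriever):
        docs: List[Document] = []

        def _get_relevant_documents(self, query: str, *, run_manager: Any) -> List[Document]:
            return [d for d in self.docs if query.lower() in d.page_content.lower()]

    retriever = KeywordRetriever(docs=[Document(page_content="LangChain runnables")])
    retriever.invoke("runnable")  # [Document(page_content="LangChain runnables")]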
@@ -2,12 +2,13 @@ from langchain.schema.runnable._locals import GetLocalVar, PutLocalVar
from langchain.schema.runnable.base import (
Runnable,
RunnableBinding,
RunnableBranch,
RunnableLambda,
RunnableMap,
RunnableSequence,
RunnableSerializable,
RunnableWithFallbacks,
)
from langchain.schema.runnable.branch import RunnableBranch
from langchain.schema.runnable.config import RunnableConfig, patch_config
from langchain.schema.runnable.passthrough import RunnablePassthrough
from langchain.schema.runnable.router import RouterInput, RouterRunnable
@@ -19,6 +20,7 @@ __all__ = [
"RouterInput",
"RouterRunnable",
"Runnable",
"RunnableSerializable",
"RunnableBinding",
"RunnableBranch",
"RunnableConfig",

@@ -11,8 +11,7 @@ from typing import (
Union,
)
from langchain.load.serializable import Serializable
from langchain.schema.runnable.base import Input, Output, Runnable
from langchain.schema.runnable.base import Input, Output, RunnableSerializable
from langchain.schema.runnable.config import RunnableConfig
from langchain.schema.runnable.passthrough import RunnablePassthrough
@@ -104,7 +103,7 @@ class PutLocalVar(RunnablePassthrough):
class GetLocalVar(
Serializable, Runnable[Input, Union[Output, Dict[str, Union[Input, Output]]]]
RunnableSerializable[Input, Union[Output, Dict[str, Union[Input, Output]]]]
):
key: str
"""The key to extract from the local state."""

@@ -119,6 +119,24 @@ class Runnable(Generic[Input, Output], ABC):
self.__class__.__name__ + "Output", __root__=(root_type, None)
)
def config_schema(
self, *, include: Optional[Sequence[str]] = None
) -> Type[BaseModel]:
class _Config:
arbitrary_types_allowed = True
include = include or []
return create_model( # type: ignore[call-overload]
self.__class__.__name__ + "Config",
__config__=_Config,
**{
field_name: (field_type, None)
for field_name, field_type in RunnableConfig.__annotations__.items()
if field_name in include
},
)
def __or__(
self,
other: Union[
@@ -812,209 +830,11 @@ class Runnable(Generic[Input, Output], ABC):
await run_manager.on_chain_end(final_output, inputs=final_input)
class RunnableBranch(Serializable, Runnable[Input, Output]):
"""A Runnable that selects which branch to run based on a condition.
The runnable is initialized with a list of (condition, runnable) pairs and
a default branch.
When operating on an input, the first condition that evaluates to True is
selected, and the corresponding runnable is run on the input.
If no condition evaluates to True, the default branch is run on the input.
Examples:
.. code-block:: python
from langchain.schema.runnable import RunnableBranch
branch = RunnableBranch(
(lambda x: isinstance(x, str), lambda x: x.upper()),
(lambda x: isinstance(x, int), lambda x: x + 1),
(lambda x: isinstance(x, float), lambda x: x * 2),
lambda x: "goodbye",
)
branch.invoke("hello") # "HELLO"
branch.invoke(None) # "goodbye"
"""
branches: Sequence[Tuple[Runnable[Input, bool], Runnable[Input, Output]]]
default: Runnable[Input, Output]
def __init__(
self,
*branches: Union[
Tuple[
Union[
Runnable[Input, bool],
Callable[[Input], bool],
Callable[[Input], Awaitable[bool]],
],
RunnableLike,
],
RunnableLike, # To accommodate the default branch
],
) -> None:
"""A Runnable that runs one of two branches based on a condition."""
if len(branches) < 2:
raise ValueError("RunnableBranch requires at least two branches")
default = branches[-1]
if not isinstance(
default, (Runnable, Callable, Mapping) # type: ignore[arg-type]
):
raise TypeError(
"RunnableBranch default must be runnable, callable or mapping."
)
default_ = cast(
Runnable[Input, Output], coerce_to_runnable(cast(RunnableLike, default))
)
_branches = []
for branch in branches[:-1]:
if not isinstance(branch, (tuple, list)): # type: ignore[arg-type]
raise TypeError(
f"RunnableBranch branches must be "
f"tuples or lists, not {type(branch)}"
)
if not len(branch) == 2:
raise ValueError(
f"RunnableBranch branches must be "
f"tuples or lists of length 2, not {len(branch)}"
)
condition, runnable = branch
condition = cast(Runnable[Input, bool], coerce_to_runnable(condition))
runnable = coerce_to_runnable(runnable)
_branches.append((condition, runnable))
super().__init__(branches=_branches, default=default_)
class Config:
arbitrary_types_allowed = True
@classmethod
def is_lc_serializable(cls) -> bool:
"""RunnableBranch is serializable if all its branches are serializable."""
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""The namespace of a RunnableBranch is the namespace of its default branch."""
return cls.__module__.split(".")[:-1]
@property
def input_schema(self) -> type[BaseModel]:
runnables = (
[self.default]
+ [r for _, r in self.branches]
+ [r for r, _ in self.branches]
)
for runnable in runnables:
if runnable.input_schema.schema().get("type") is not None:
return runnable.input_schema
return super().input_schema
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
"""First evaluates the condition, then delegate to true or false branch."""
config = ensure_config(config)
callback_manager = get_callback_manager_for_config(config)
run_manager = callback_manager.on_chain_start(
dumpd(self),
input,
name=config.get("run_name"),
)
try:
for idx, branch in enumerate(self.branches):
condition, runnable = branch
expression_value = condition.invoke(
input,
config=patch_config(
config,
callbacks=run_manager.get_child(tag=f"condition:{idx + 1}"),
),
)
if expression_value:
output = runnable.invoke(
input,
config=patch_config(
config,
callbacks=run_manager.get_child(tag=f"branch:{idx + 1}"),
),
)
break
else:
output = self.default.invoke(
input,
config=patch_config(
config, callbacks=run_manager.get_child(tag="branch:default")
),
)
except Exception as e:
run_manager.on_chain_error(e)
raise
run_manager.on_chain_end(dumpd(output))
return output
async def ainvoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Output:
"""Async version of invoke."""
config = ensure_config(config)
callback_manager = get_callback_manager_for_config(config)
run_manager = callback_manager.on_chain_start(
dumpd(self),
input,
name=config.get("run_name"),
)
try:
for idx, branch in enumerate(self.branches):
condition, runnable = branch
expression_value = await condition.ainvoke(
input,
config=patch_config(
config,
callbacks=run_manager.get_child(tag=f"condition:{idx + 1}"),
),
)
if expression_value:
output = await runnable.ainvoke(
input,
config=patch_config(
config,
callbacks=run_manager.get_child(tag=f"branch:{idx + 1}"),
),
**kwargs,
)
break
else:
output = await self.default.ainvoke(
input,
config=patch_config(
config, callbacks=run_manager.get_child(tag="branch:default")
),
**kwargs,
)
except Exception as e:
run_manager.on_chain_error(e)
raise
run_manager.on_chain_end(dumpd(output))
return output
class RunnableSerializable(Serializable, Runnable[Input, Output]):
    pass


class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
"""
A Runnable that can fallback to other Runnables if it fails.
"""
@@ -1042,6 +862,11 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
def output_schema(self) -> Type[BaseModel]:
return self.runnable.output_schema
def config_schema(
self, *, include: Optional[Sequence[str]] = None
) -> Type[BaseModel]:
return self.runnable.config_schema(include=include)
@classmethod
def is_lc_serializable(cls) -> bool:
return True
@@ -1267,7 +1092,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
raise first_error
class RunnableSequence(Serializable, Runnable[Input, Output]):
class RunnableSequence(RunnableSerializable[Input, Output]):
"""
A sequence of runnables, where the output of each is the input of the next.
"""
@@ -1749,7 +1574,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
yield chunk
class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
class RunnableMap(RunnableSerializable[Input, Dict[str, Any]]):
"""
A runnable that runs a mapping of runnables in parallel,
and returns a mapping of their outputs.
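
A short usage sketch (assuming the constructor accepts the mapping of steps directly, as elsewhere in this module):

    from langchain.schema.runnable import RunnableLambda, RunnableMap

    fanout = RunnableMap({
        "doubled": RunnableLambda(lambda x: x * 2),
        "original": RunnableLambda(lambda x: x),
    })
    fanout.invoke(3)  # {"doubled": 6, "original": 3}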
@@ -1799,7 +1624,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
return create_model( # type: ignore[call-overload]
"RunnableMapInput",
**{
k: (v.type_, v.default)
k: (v.annotation, v.default)
for step in self.steps.values()
for k, v in step.input_schema.__fields__.items()
if k != "__root__"
@@ -2374,7 +2199,7 @@ class RunnableLambda(Runnable[Input, Output]):
return await super().ainvoke(input, config)
class RunnableEach(Serializable, Runnable[List[Input], List[Output]]):
class RunnableEach(RunnableSerializable[List[Input], List[Output]]):
"""
A runnable that delegates calls to another runnable
with each element of the input sequence.
@@ -2413,6 +2238,11 @@ class RunnableEach(Serializable, Runnable[List[Input], List[Output]]):
),
)
def config_schema(
self, *, include: Optional[Sequence[str]] = None
) -> Type[BaseModel]:
return self.bound.config_schema(include=include)
@classmethod
def is_lc_serializable(cls) -> bool:
return True
@@ -2455,7 +2285,7 @@ class RunnableEach(Serializable, Runnable[List[Input], List[Output]]):
return await self._acall_with_config(self._ainvoke, input, config)
class RunnableBinding(Serializable, Runnable[Input, Output]):
class RunnableBinding(RunnableSerializable[Input, Output]):
"""
A runnable that delegates calls to another runnable with a set of kwargs.
"""
@@ -2485,6 +2315,11 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
def output_schema(self) -> Type[BaseModel]:
return self.bound.output_schema
def config_schema(
self, *, include: Optional[Sequence[str]] = None
) -> Type[BaseModel]:
return self.bound.config_schema(include=include)
@classmethod
def is_lc_serializable(cls) -> bool:
return True

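The config_schema method added to Runnable above builds a pydantic model from RunnableConfig's annotations, keeping only the fields named in include; the wrappers (RunnableWithFallbacks, RunnableEach, RunnableBinding) simply delegate to the runnable they wrap. A sketch of what a call looks like (the resulting schema is abbreviated):

    from langchain.schema.runnable import RunnableLambda

    runnable = RunnableLambda(lambda x: x + 1)
    runnable.config_schema(include=["tags"]).schema()
    # -> a "RunnableLambdaConfig" model whose only property is the requested "tags" field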
@@ -0,0 +1,234 @@
from typing import (
Any,
Awaitable,
Callable,
List,
Mapping,
Optional,
Sequence,
Tuple,
Union,
cast,
)
from langchain.load.dump import dumpd
from langchain.pydantic_v1 import BaseModel
from langchain.schema.runnable.base import (
Runnable,
RunnableLike,
RunnableSerializable,
coerce_to_runnable,
)
from langchain.schema.runnable.config import (
RunnableConfig,
ensure_config,
get_callback_manager_for_config,
patch_config,
)
from langchain.schema.runnable.utils import Input, Output
class RunnableBranch(RunnableSerializable[Input, Output]):
"""A Runnable that selects which branch to run based on a condition.
The runnable is initialized with a list of (condition, runnable) pairs and
a default branch.
When operating on an input, the first condition that evaluates to True is
selected, and the corresponding runnable is run on the input.
If no condition evaluates to True, the default branch is run on the input.
Examples:
.. code-block:: python
from langchain.schema.runnable import RunnableBranch
branch = RunnableBranch(
(lambda x: isinstance(x, str), lambda x: x.upper()),
(lambda x: isinstance(x, int), lambda x: x + 1),
(lambda x: isinstance(x, float), lambda x: x * 2),
lambda x: "goodbye",
)
branch.invoke("hello") # "HELLO"
branch.invoke(None) # "goodbye"
"""
branches: Sequence[Tuple[Runnable[Input, bool], Runnable[Input, Output]]]
default: Runnable[Input, Output]
def __init__(
self,
*branches: Union[
Tuple[
Union[
Runnable[Input, bool],
Callable[[Input], bool],
Callable[[Input], Awaitable[bool]],
],
RunnableLike,
],
RunnableLike, # To accommodate the default branch
],
) -> None:
"""A Runnable that runs one of two branches based on a condition."""
if len(branches) < 2:
raise ValueError("RunnableBranch requires at least two branches")
default = branches[-1]
if not isinstance(
default, (Runnable, Callable, Mapping) # type: ignore[arg-type]
):
raise TypeError(
"RunnableBranch default must be runnable, callable or mapping."
)
default_ = cast(
Runnable[Input, Output], coerce_to_runnable(cast(RunnableLike, default))
)
_branches = []
for branch in branches[:-1]:
if not isinstance(branch, (tuple, list)): # type: ignore[arg-type]
raise TypeError(
f"RunnableBranch branches must be "
f"tuples or lists, not {type(branch)}"
)
if not len(branch) == 2:
raise ValueError(
f"RunnableBranch branches must be "
f"tuples or lists of length 2, not {len(branch)}"
)
condition, runnable = branch
condition = cast(Runnable[Input, bool], coerce_to_runnable(condition))
runnable = coerce_to_runnable(runnable)
_branches.append((condition, runnable))
super().__init__(branches=_branches, default=default_)
class Config:
arbitrary_types_allowed = True
@classmethod
def is_lc_serializable(cls) -> bool:
"""RunnableBranch is serializable if all its branches are serializable."""
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""The namespace of a RunnableBranch is the namespace of its default branch."""
return cls.__module__.split(".")[:-1]
@property
def input_schema(self) -> type[BaseModel]:
runnables = (
[self.default]
+ [r for _, r in self.branches]
+ [r for r, _ in self.branches]
)
for runnable in runnables:
if runnable.input_schema.schema().get("type") is not None:
return runnable.input_schema
return super().input_schema
def invoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Output:
"""First evaluates the condition, then delegate to true or false branch."""
config = ensure_config(config)
callback_manager = get_callback_manager_for_config(config)
run_manager = callback_manager.on_chain_start(
dumpd(self),
input,
name=config.get("run_name"),
)
try:
for idx, branch in enumerate(self.branches):
condition, runnable = branch
expression_value = condition.invoke(
input,
config=patch_config(
config,
callbacks=run_manager.get_child(tag=f"condition:{idx + 1}"),
),
)
if expression_value:
output = runnable.invoke(
input,
config=patch_config(
config,
callbacks=run_manager.get_child(tag=f"branch:{idx + 1}"),
),
**kwargs,
)
break
else:
output = self.default.invoke(
input,
config=patch_config(
config, callbacks=run_manager.get_child(tag="branch:default")
),
**kwargs,
)
except Exception as e:
run_manager.on_chain_error(e)
raise
run_manager.on_chain_end(dumpd(output))
return output
async def ainvoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Output:
"""Async version of invoke."""
config = ensure_config(config)
callback_manager = get_callback_manager_for_config(config)
run_manager = callback_manager.on_chain_start(
dumpd(self),
input,
name=config.get("run_name"),
)
try:
for idx, branch in enumerate(self.branches):
condition, runnable = branch
expression_value = await condition.ainvoke(
input,
config=patch_config(
config,
callbacks=run_manager.get_child(tag=f"condition:{idx + 1}"),
),
)
if expression_value:
output = await runnable.ainvoke(
input,
config=patch_config(
config,
callbacks=run_manager.get_child(tag=f"branch:{idx + 1}"),
),
**kwargs,
)
break
else:
output = await self.default.ainvoke(
input,
config=patch_config(
config, callbacks=run_manager.get_child(tag="branch:default")
),
**kwargs,
)
except Exception as e:
run_manager.on_chain_error(e)
raise
run_manager.on_chain_end(dumpd(output))
return output

@@ -16,9 +16,13 @@ from typing import (
cast,
)
from langchain.load.serializable import Serializable
from langchain.pydantic_v1 import BaseModel, create_model
from langchain.schema.runnable.base import Input, Runnable, RunnableMap
from langchain.schema.runnable.base import (
Input,
Runnable,
RunnableMap,
RunnableSerializable,
)
from langchain.schema.runnable.config import RunnableConfig, get_executor_for_config
from langchain.schema.runnable.utils import AddableDict
from langchain.utils.aiter import atee, py_anext
@@ -33,7 +37,7 @@ async def aidentity(x: Input) -> Input:
return x
class RunnablePassthrough(Serializable, Runnable[Input, Input]):
class RunnablePassthrough(RunnableSerializable[Input, Input]):
"""
A runnable that passes through the input.
"""
@@ -109,7 +113,7 @@ class RunnablePassthrough(Serializable, Runnable[Input, Input]):
yield chunk
class RunnableAssign(Serializable, Runnable[Dict[str, Any], Dict[str, Any]]):
class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
"""
A runnable that assigns key-value pairs to Dict[str, Any] inputs.
"""

@@ -14,8 +14,13 @@ from typing import (
from typing_extensions import TypedDict
from langchain.load.serializable import Serializable
from langchain.schema.runnable.base import Input, Output, Runnable, coerce_to_runnable
from langchain.schema.runnable.base import (
Input,
Output,
Runnable,
RunnableSerializable,
coerce_to_runnable,
)
from langchain.schema.runnable.config import (
RunnableConfig,
get_config_list,
@@ -36,7 +41,7 @@ class RouterInput(TypedDict):
input: Any
class RouterRunnable(Serializable, Runnable[RouterInput, Output]):
class RouterRunnable(RunnableSerializable[RouterInput, Output]):
"""
A runnable that routes to a set of runnables based on Input['key'].
Returns the output of the selected runnable.
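
A usage sketch (assuming the constructor takes the mapping of runnables directly):

    from langchain.schema.runnable import RouterRunnable, RunnableLambda

    router = RouterRunnable({
        "double": RunnableLambda(lambda x: x * 2),
        "square": RunnableLambda(lambda x: x * x),
    })
    router.invoke({"key": "double", "input": 4})  # 8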

@@ -25,7 +25,7 @@ from langchain.pydantic_v1 import (
root_validator,
validate_arguments,
)
from langchain.schema.runnable import Runnable, RunnableConfig
from langchain.schema.runnable import Runnable, RunnableConfig, RunnableSerializable
class SchemaAnnotationError(TypeError):
@@ -97,7 +97,7 @@ class ToolException(Exception):
pass
class BaseTool(BaseModel, Runnable[Union[str, Dict], Any]):
class BaseTool(RunnableSerializable[Union[str, Dict], Any]):
"""Interface LangChain tools must implement."""
def __init_subclass__(cls, **kwargs: Any) -> None:
@@ -168,7 +168,6 @@ class ChildTool(BaseTool):
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
@property

@@ -21,14 +21,6 @@ class BaseSparkSQLTool(BaseModel):
db: SparkSQL = Field(exclude=True)
# Override BaseTool.Config to appease mypy
# See https://github.com/pydantic/pydantic/issues/4173
class Config(BaseTool.Config):
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
extra = Extra.forbid
class QuerySparkSQLTool(BaseSparkSQLTool, BaseTool):
"""Tool for querying a Spark SQL."""

@@ -21,14 +21,6 @@ class BaseSQLDatabaseTool(BaseModel):
db: SQLDatabase = Field(exclude=True)
# Override BaseTool.Config to appease mypy
# See https://github.com/pydantic/pydantic/issues/4173
class Config(BaseTool.Config):
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
extra = Extra.forbid
class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool):
"""Tool for querying a SQL database."""

@@ -17,11 +17,6 @@ class BaseVectorStoreTool(BaseModel):
vectorstore: VectorStore = Field(exclude=True)
llm: BaseLanguageModel = Field(default_factory=lambda: OpenAI(temperature=0))
class Config(BaseTool.Config):
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
def _create_description_from_template(values: Dict[str, Any]) -> Dict[str, Any]:
values["description"] = values["template"].format(name=values["name"])
