core[patch]: docstrings `langchain_core/` files update (#24285)

Added missed docstrings. Formatted docstrings to the consistent form.
pull/24307/head
Leonid Ganeline 2 months ago committed by GitHub
parent 7aeaa1974d
commit 198b85334f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -65,12 +65,15 @@ class AgentAction(Serializable):
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether or not the class is serializable."""
"""Return whether or not the class is serializable.
Default is True.
"""
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
"""Get the namespace of the langchain object.
Default is ["langchain", "schema", "agent"]."""
return ["langchain", "schema", "agent"]
@property
@ -80,7 +83,7 @@ class AgentAction(Serializable):
class AgentActionMessageLog(AgentAction):
"""A representation of an action to be executed by an agent.
"""Representation of an action to be executed by an agent.
This is similar to AgentAction, but includes a message log consisting of
chat messages. This is useful when working with ChatModels, and is used
@ -102,7 +105,7 @@ class AgentActionMessageLog(AgentAction):
class AgentStep(Serializable):
"""The result of running an AgentAction."""
"""Result of running an AgentAction."""
action: AgentAction
"""The AgentAction that was executed."""
@ -111,12 +114,12 @@ class AgentStep(Serializable):
@property
def messages(self) -> Sequence[BaseMessage]:
"""Return the messages that correspond to this observation."""
"""Messages that correspond to this observation."""
return _convert_agent_observation_to_messages(self.action, self.observation)
class AgentFinish(Serializable):
"""The final return value of an ActionAgent.
"""Final return value of an ActionAgent.
Agents return an AgentFinish when they have reached a stopping condition.
"""
@ -148,7 +151,7 @@ class AgentFinish(Serializable):
@property
def messages(self) -> Sequence[BaseMessage]:
"""Return the messages that correspond to this observation."""
"""Messages that correspond to this observation."""
return [AIMessage(content=self.log)]
@ -180,6 +183,7 @@ def _convert_agent_observation_to_messages(
Args:
agent_action: Agent action to convert.
observation: Observation to convert to a message.
Returns:
AIMessage that corresponds to the original tool invocation.
@ -196,11 +200,11 @@ def _create_function_message(
"""Convert agent action and observation into a function message.
Args:
agent_action: the tool invocation request from the agent
observation: the result of the tool invocation
agent_action: the tool invocation request from the agent.
observation: the result of the tool invocation.
Returns:
FunctionMessage that corresponds to the original tool invocation
FunctionMessage that corresponds to the original tool invocation.
"""
if not isinstance(observation, str):
try:

@ -44,7 +44,7 @@ class BaseCache(ABC):
The default implementation of the async methods is to run the synchronous
method in an executor. It's recommended to override the async methods
and provide an async implementations to avoid unnecessary overhead.
and provide async implementations to avoid unnecessary overhead.
"""
@abstractmethod
@ -56,8 +56,8 @@ class BaseCache(ABC):
Args:
prompt: a string representation of the prompt.
In the case of a Chat model, the prompt is a non-trivial
serialization of the prompt into the language model.
In the case of a Chat model, the prompt is a non-trivial
serialization of the prompt into the language model.
llm_string: A string representation of the LLM configuration.
This is used to capture the invocation parameters of the LLM
(e.g., model name, temperature, stop tokens, max tokens, etc.).
@ -78,8 +78,8 @@ class BaseCache(ABC):
Args:
prompt: a string representation of the prompt.
In the case of a Chat model, the prompt is a non-trivial
serialization of the prompt into the language model.
In the case of a Chat model, the prompt is a non-trivial
serialization of the prompt into the language model.
llm_string: A string representation of the LLM configuration.
This is used to capture the invocation parameters of the LLM
(e.g., model name, temperature, stop tokens, max tokens, etc.).
@ -101,8 +101,8 @@ class BaseCache(ABC):
Args:
prompt: a string representation of the prompt.
In the case of a Chat model, the prompt is a non-trivial
serialization of the prompt into the language model.
In the case of a Chat model, the prompt is a non-trivial
serialization of the prompt into the language model.
llm_string: A string representation of the LLM configuration.
This is used to capture the invocation parameters of the LLM
(e.g., model name, temperature, stop tokens, max tokens, etc.).
@ -125,8 +125,8 @@ class BaseCache(ABC):
Args:
prompt: a string representation of the prompt.
In the case of a Chat model, the prompt is a non-trivial
serialization of the prompt into the language model.
In the case of a Chat model, the prompt is a non-trivial
serialization of the prompt into the language model.
llm_string: A string representation of the LLM configuration.
This is used to capture the invocation parameters of the LLM
(e.g., model name, temperature, stop tokens, max tokens, etc.).
@ -152,6 +152,10 @@ class InMemoryCache(BaseCache):
maxsize: The maximum number of items to store in the cache.
If None, the cache has no maximum size.
If the cache exceeds the maximum size, the oldest items are removed.
Default is None.
Raises:
ValueError: If maxsize is less than or equal to 0.
"""
self._cache: Dict[Tuple[str, str], RETURN_VAL_TYPE] = {}
if maxsize is not None and maxsize <= 0:

@ -116,7 +116,7 @@ class BaseChatMessageHistory(ABC):
This method may be deprecated in a future release.
Args:
message: The human message to add
message: The human message to add to the store.
"""
if isinstance(message, HumanMessage):
self.add_message(message)
@ -200,22 +200,38 @@ class BaseChatMessageHistory(ABC):
class InMemoryChatMessageHistory(BaseChatMessageHistory, BaseModel):
"""In memory implementation of chat message history.
Stores messages in an in memory list.
Stores messages in a memory list.
"""
messages: List[BaseMessage] = Field(default_factory=list)
"""A list of messages stored in memory."""
async def aget_messages(self) -> List[BaseMessage]:
"""Async version of getting messages."""
"""Async version of getting messages.
Can over-ride this method to provide an efficient async implementation.
In general, fetching messages may involve IO to the underlying
persistence layer.
Returns:
List of messages.
"""
return self.messages
def add_message(self, message: BaseMessage) -> None:
"""Add a self-created message to the store."""
"""Add a self-created message to the store.
Args:
message: The message to add.
"""
self.messages.append(message)
async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None:
"""Async add messages to the store"""
"""Async add messages to the store.
Args:
messages: The messages to add.
"""
self.add_messages(messages)
def clear(self) -> None:

@ -4,7 +4,11 @@ from functools import lru_cache
@lru_cache(maxsize=1)
def get_runtime_environment() -> dict:
"""Get information about the LangChain runtime environment."""
"""Get information about the LangChain runtime environment.
Returns:
A dictionary with information about the runtime environment.
"""
# Lazy import to avoid circular imports
from langchain_core import __version__

@ -22,13 +22,14 @@ class OutputParserException(ValueError, LangChainException):
Parameters:
error: The error that's being re-raised or an error message.
observation: String explanation of error which can be passed to a
model to try and remediate the issue.
model to try and remediate the issue. Defaults to None.
llm_output: String model output which is error-ing.
Defaults to None.
send_to_llm: Whether to send the observation and llm_output back to an Agent
after an OutputParserException has been raised. This gives the underlying
model driving the agent the context that the previous output was improperly
structured, in the hopes that it will update the output to the correct
format.
format. Defaults to False.
"""
def __init__(

@ -18,7 +18,11 @@ _llm_cache: Optional["BaseCache"] = None
def set_verbose(value: bool) -> None:
"""Set a new value for the `verbose` global setting."""
"""Set a new value for the `verbose` global setting.
Args:
value: The new value for the `verbose` global setting.
"""
try:
import langchain # type: ignore[import]
@ -46,7 +50,11 @@ def set_verbose(value: bool) -> None:
def get_verbose() -> bool:
"""Get the value of the `verbose` global setting."""
"""Get the value of the `verbose` global setting.
Returns:
The value of the `verbose` global setting.
"""
try:
import langchain # type: ignore[import]
@ -79,7 +87,11 @@ def get_verbose() -> bool:
def set_debug(value: bool) -> None:
"""Set a new value for the `debug` global setting."""
"""Set a new value for the `debug` global setting.
Args:
value: The new value for the `debug` global setting.
"""
try:
import langchain # type: ignore[import]
@ -105,7 +117,11 @@ def set_debug(value: bool) -> None:
def get_debug() -> bool:
"""Get the value of the `debug` global setting."""
"""Get the value of the `debug` global setting.
Returns:
The value of the `debug` global setting.
"""
try:
import langchain # type: ignore[import]
@ -168,7 +184,11 @@ def set_llm_cache(value: Optional["BaseCache"]) -> None:
def get_llm_cache() -> "BaseCache":
"""Get the value of the `llm_cache` global setting."""
"""Get the value of the `llm_cache` global setting.
Returns:
The value of the `llm_cache` global setting.
"""
try:
import langchain # type: ignore[import]

@ -62,13 +62,20 @@ class BaseMemory(Serializable, ABC):
"""Return key-value pairs given the text input to the chain.
Args:
inputs: The inputs to the chain."""
inputs: The inputs to the chain.
Returns:
A dictionary of key-value pairs.
"""
async def aload_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""Async return key-value pairs given the text input to the chain.
Args:
inputs: The inputs to the chain.
Returns:
A dictionary of key-value pairs.
"""
return await run_in_executor(None, self.load_memory_variables, inputs)

@ -24,17 +24,20 @@ class PromptValue(Serializable, ABC):
"""Base abstract class for inputs to any language model.
PromptValues can be converted to both LLM (pure text-generation) inputs and
ChatModel inputs.
ChatModel inputs.
"""
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether this class is serializable."""
"""Return whether this class is serializable. Defaults to True."""
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
"""Get the namespace of the langchain object.
This is used to determine the namespace of the object when serializing.
Defaults to ["langchain", "schema", "prompt"].
"""
return ["langchain", "schema", "prompt"]
@abstractmethod
@ -55,7 +58,10 @@ class StringPromptValue(PromptValue):
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
"""Get the namespace of the langchain object.
This is used to determine the namespace of the object when serializing.
Defaults to ["langchain", "prompts", "base"].
"""
return ["langchain", "prompts", "base"]
def to_string(self) -> str:
@ -86,7 +92,10 @@ class ChatPromptValue(PromptValue):
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
"""Get the namespace of the langchain object.
This is used to determine the namespace of the object when serializing.
Defaults to ["langchain", "prompts", "chat"].
"""
return ["langchain", "prompts", "chat"]
@ -94,7 +103,8 @@ class ImageURL(TypedDict, total=False):
"""Image URL."""
detail: Literal["auto", "low", "high"]
"""Specifies the detail level of the image."""
"""Specifies the detail level of the image. Defaults to "auto".
Can be "auto", "low", or "high"."""
url: str
"""Either a URL of the image or the base64 encoded image data."""
@ -127,5 +137,8 @@ class ChatPromptValueConcrete(ChatPromptValue):
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
"""Get the namespace of the langchain object.
This is used to determine the namespace of the object when serializing.
Defaults to ["langchain", "prompts", "chat"].
"""
return ["langchain", "prompts", "chat"]

@ -53,14 +53,13 @@ RetrieverOutputLike = Runnable[Any, RetrieverOutput]
class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
"""Abstract base class for a Document retrieval system.
A retrieval system is defined as something that can take string queries and return
the most 'relevant' Documents from some source.
Usage:
A retriever follows the standard Runnable interface, and should be used
via the standard runnable methods of `invoke`, `ainvoke`, `batch`, `abatch`.
via the standard Runnable methods of `invoke`, `ainvoke`, `batch`, `abatch`.
Implementation:
@ -89,7 +88,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
\"\"\"(Optional) async native implementation.\"\"\"
return self.docs[:self.k]
Example: A simple retriever based on a scitkit learn vectorizer
Example: A simple retriever based on a scikit-learn vectorizer
.. code-block:: python
@ -178,12 +177,12 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
Main entry point for synchronous retriever invocations.
Args:
input: The query string
config: Configuration for the retriever
**kwargs: Additional arguments to pass to the retriever
input: The query string.
config: Configuration for the retriever. Defaults to None.
**kwargs: Additional arguments to pass to the retriever.
Returns:
List of relevant documents
List of relevant documents.
Examples:
@ -237,12 +236,12 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
Main entry point for asynchronous retriever invocations.
Args:
input: The query string
config: Configuration for the retriever
**kwargs: Additional arguments to pass to the retriever
input: The query string.
config: Configuration for the retriever. Defaults to None.
**kwargs: Additional arguments to pass to the retriever.
Returns:
List of relevant documents
List of relevant documents.
Examples:
@ -292,10 +291,10 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
"""Get documents relevant to a query.
Args:
query: String to find relevant documents for
run_manager: The callback handler to use
query: String to find relevant documents for.
run_manager: The callback handler to use.
Returns:
List of relevant documents
List of relevant documents.
"""
async def _aget_relevant_documents(
@ -333,18 +332,21 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
`get_relevant_documents directly`.
Args:
query: string to find relevant documents for
callbacks: Callback manager or list of callbacks
tags: Optional list of tags associated with the retriever. Defaults to None
query: string to find relevant documents for.
callbacks: Callback manager or list of callbacks. Defaults to None.
tags: Optional list of tags associated with the retriever.
These tags will be associated with each call to this retriever,
and passed as arguments to the handlers defined in `callbacks`.
metadata: Optional metadata associated with the retriever. Defaults to None
Defaults to None.
metadata: Optional metadata associated with the retriever.
This metadata will be associated with each call to this retriever,
and passed as arguments to the handlers defined in `callbacks`.
run_name: Optional name for the run.
Defaults to None.
run_name: Optional name for the run. Defaults to None.
**kwargs: Additional arguments to pass to the retriever.
Returns:
List of relevant documents
List of relevant documents.
"""
config: RunnableConfig = {}
if callbacks:
@ -374,18 +376,21 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
`aget_relevant_documents directly`.
Args:
query: string to find relevant documents for
callbacks: Callback manager or list of callbacks
tags: Optional list of tags associated with the retriever. Defaults to None
query: string to find relevant documents for.
callbacks: Callback manager or list of callbacks.
tags: Optional list of tags associated with the retriever.
These tags will be associated with each call to this retriever,
and passed as arguments to the handlers defined in `callbacks`.
metadata: Optional metadata associated with the retriever. Defaults to None
Defaults to None.
metadata: Optional metadata associated with the retriever.
This metadata will be associated with each call to this retriever,
and passed as arguments to the handlers defined in `callbacks`.
run_name: Optional name for the run.
Defaults to None.
run_name: Optional name for the run. Defaults to None.
**kwargs: Additional arguments to pass to the retriever.
Returns:
List of relevant documents
List of relevant documents.
"""
config: RunnableConfig = {}
if callbacks:

@ -150,7 +150,6 @@ class BaseStore(Generic[K, V], ABC):
Yields:
Iterator[K | str]: An iterator over keys that match the given prefix.
This method is allowed to return an iterator over either K or str
depending on what makes more sense for the given store.
"""
@ -165,7 +164,6 @@ class BaseStore(Generic[K, V], ABC):
Yields:
Iterator[K | str]: An iterator over keys that match the given prefix.
This method is allowed to return an iterator over either K or str
depending on what makes more sense for the given store.
"""

@ -10,10 +10,12 @@ from langchain_core.pydantic_v1 import BaseModel
class Visitor(ABC):
"""Defines interface for IR translation using visitor pattern."""
"""Defines interface for IR translation using a visitor pattern."""
allowed_comparators: Optional[Sequence[Comparator]] = None
"""Allowed comparators for the visitor."""
allowed_operators: Optional[Sequence[Operator]] = None
"""Allowed operators for the visitor."""
def _validate_func(self, func: Union[Operator, Comparator]) -> None:
if isinstance(func, Operator) and self.allowed_operators is not None:
@ -31,15 +33,27 @@ class Visitor(ABC):
@abstractmethod
def visit_operation(self, operation: Operation) -> Any:
"""Translate an Operation."""
"""Translate an Operation.
Args:
operation: Operation to translate.
"""
@abstractmethod
def visit_comparison(self, comparison: Comparison) -> Any:
"""Translate a Comparison."""
"""Translate a Comparison.
Args:
comparison: Comparison to translate.
"""
@abstractmethod
def visit_structured_query(self, structured_query: StructuredQuery) -> Any:
"""Translate a StructuredQuery."""
"""Translate a StructuredQuery.
Args:
structured_query: StructuredQuery to translate.
"""
def _to_snake_case(name: str) -> str:
@ -60,10 +74,10 @@ class Expr(BaseModel):
"""Accept a visitor.
Args:
visitor: visitor to accept
visitor: visitor to accept.
Returns:
result of visiting
result of visiting.
"""
return getattr(visitor, f"visit_{_to_snake_case(self.__class__.__name__)}")(
self
@ -98,7 +112,13 @@ class FilterDirective(Expr, ABC):
class Comparison(FilterDirective):
"""Comparison to a value."""
"""Comparison to a value.
Parameters:
comparator: The comparator to use.
attribute: The attribute to compare.
value: The value to compare to.
"""
comparator: Comparator
attribute: str
@ -113,7 +133,12 @@ class Comparison(FilterDirective):
class Operation(FilterDirective):
"""Llogical operation over other directives."""
"""Logical operation over other directives.
Parameters:
operator: The operator to use.
arguments: The arguments to the operator.
"""
operator: Operator
arguments: List[FilterDirective]

@ -6,7 +6,11 @@ from typing import Sequence
def print_sys_info(*, additional_pkgs: Sequence[str] = tuple()) -> None:
"""Print information about the environment for debugging purposes."""
"""Print information about the environment for debugging purposes.
Args:
additional_pkgs: Additional packages to include in the output.
"""
import pkgutil
import platform
import sys

@ -249,7 +249,16 @@ def _infer_arg_descriptions(
class _SchemaConfig:
"""Configuration for the pydantic model."""
"""Configuration for the pydantic model.
This is used to configure the pydantic model created from
a function's signature.
Parameters:
extra: Whether to allow extra fields in the model.
arbitrary_types_allowed: Whether to allow arbitrary types in the model.
Defaults to True.
"""
extra: Any = Extra.forbid
arbitrary_types_allowed: bool = True
@ -265,15 +274,18 @@ def create_schema_from_function(
) -> Type[BaseModel]:
"""Create a pydantic schema from a function's signature.
Args:
model_name: Name to assign to the generated pydandic schema
func: Function to generate the schema from
filter_args: Optional list of arguments to exclude from the schema
model_name: Name to assign to the generated pydantic schema.
func: Function to generate the schema from.
filter_args: Optional list of arguments to exclude from the schema.
Defaults to FILTERED_ARGS.
parse_docstring: Whether to parse the function's docstring for descriptions
for each argument.
error_on_invalid_docstring: if ``parse_docstring`` is provided, configures
for each argument. Defaults to False.
error_on_invalid_docstring: if ``parse_docstring`` is provided, configure
whether to raise ValueError on invalid Google Style docstrings.
Defaults to False.
Returns:
A pydantic model with the same arguments as the function
A pydantic model with the same arguments as the function.
"""
# https://docs.pydantic.dev/latest/usage/validation_decorator/
validated = validate_arguments(func, config=_SchemaConfig) # type: ignore
@ -348,8 +360,9 @@ class ChildTool(BaseTool):
args_schema: Optional[Type[BaseModel]] = None
"""Pydantic model class to validate and parse the tool's input arguments."""
return_direct: bool = False
"""Whether to return the tool's output directly. Setting this to True means
"""Whether to return the tool's output directly.
Setting this to True means
that after the tool is called, the AgentExecutor will stop looping.
"""
verbose: bool = False
@ -360,13 +373,13 @@ class ChildTool(BaseTool):
callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
"""Deprecated. Please use callbacks instead."""
tags: Optional[List[str]] = None
"""Optional list of tags associated with the tool. Defaults to None
"""Optional list of tags associated with the tool. Defaults to None.
These tags will be associated with each call to this tool,
and passed as arguments to the handlers defined in `callbacks`.
You can use these to eg identify a specific instance of a tool with its use case.
"""
metadata: Optional[Dict[str, Any]] = None
"""Optional metadata associated with the tool. Defaults to None
"""Optional metadata associated with the tool. Defaults to None.
This metadata will be associated with each call to this tool,
and passed as arguments to the handlers defined in `callbacks`.
You can use these to eg identify a specific instance of a tool with its use case.
@ -383,7 +396,7 @@ class ChildTool(BaseTool):
"""Handle the content of the ValidationError thrown."""
response_format: Literal["content", "content_and_artifact"] = "content"
"""The tool response format.
"""The tool response format. Defaults to 'content'.
If "content" then the output of the tool is interpreted as the contents of a
ToolMessage. If "content_and_artifact" then the output is expected to be a
@ -414,7 +427,14 @@ class ChildTool(BaseTool):
def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
"""The tool's input schema."""
"""The tool's input schema.
Args:
config: The configuration for the tool.
Returns:
The input schema for the tool.
"""
if self.args_schema is not None:
return self.args_schema
else:
@ -441,7 +461,11 @@ class ChildTool(BaseTool):
# --- Tool ---
def _parse_input(self, tool_input: Union[str, Dict]) -> Union[str, Dict[str, Any]]:
"""Convert tool input to pydantic model."""
"""Convert tool input to a pydantic model.
Args:
tool_input: The input to the tool.
"""
input_args = self.args_schema
if isinstance(tool_input, str):
if input_args is not None:
@ -460,7 +484,14 @@ class ChildTool(BaseTool):
@root_validator(pre=True)
def raise_deprecation(cls, values: Dict) -> Dict:
"""Raise deprecation warning if callback_manager is used."""
"""Raise deprecation warning if callback_manager is used.
Args:
values: The values to validate.
Returns:
The validated values.
"""
if values.get("callback_manager") is not None:
warnings.warn(
"callback_manager is deprecated. Please use callbacks instead.",
@ -514,7 +545,28 @@ class ChildTool(BaseTool):
tool_call_id: Optional[str] = None,
**kwargs: Any,
) -> Any:
"""Run the tool."""
"""Run the tool.
Args:
tool_input: The input to the tool.
verbose: Whether to log the tool's progress. Defaults to None.
start_color: The color to use when starting the tool. Defaults to 'green'.
color: The color to use when ending the tool. Defaults to 'green'.
callbacks: Callbacks to be called during tool execution. Defaults to None.
tags: Optional list of tags associated with the tool. Defaults to None.
metadata: Optional metadata associated with the tool. Defaults to None.
run_name: The name of the run. Defaults to None.
run_id: The id of the run. Defaults to None.
config: The configuration for the tool. Defaults to None.
tool_call_id: The id of the tool call. Defaults to None.
kwargs: Additional arguments to pass to the tool
Returns:
The output of the tool.
Raises:
ToolException: If an error occurs during tool execution.
"""
callback_manager = CallbackManager.configure(
callbacks,
self.callbacks,
@ -600,7 +652,28 @@ class ChildTool(BaseTool):
tool_call_id: Optional[str] = None,
**kwargs: Any,
) -> Any:
"""Run the tool asynchronously."""
"""Run the tool asynchronously.
Args:
tool_input: The input to the tool.
verbose: Whether to log the tool's progress. Defaults to None.
start_color: The color to use when starting the tool. Defaults to 'green'.
color: The color to use when ending the tool. Defaults to 'green'.
callbacks: Callbacks to be called during tool execution. Defaults to None.
tags: Optional list of tags associated with the tool. Defaults to None.
metadata: Optional metadata associated with the tool. Defaults to None.
run_name: The name of the run. Defaults to None.
run_id: The id of the run. Defaults to None.
config: The configuration for the tool. Defaults to None.
tool_call_id: The id of the tool call. Defaults to None.
kwargs: Additional arguments to pass to the tool
Returns:
The output of the tool.
Raises:
ToolException: If an error occurs during tool execution.
"""
callback_manager = AsyncCallbackManager.configure(
callbacks,
self.callbacks,
@ -709,7 +782,11 @@ class Tool(BaseTool):
@property
def args(self) -> dict:
"""The tool's input arguments."""
"""The tool's input arguments.
Returns:
The input arguments for the tool.
"""
if self.args_schema is not None:
return self.args_schema.schema()["properties"]
# For backwards compatibility, if the function signature is ambiguous,
@ -788,7 +865,23 @@ class Tool(BaseTool):
] = None, # This is last for compatibility, but should be after func
**kwargs: Any,
) -> Tool:
"""Initialize tool from a function."""
"""Initialize tool from a function.
Args:
func: The function to create the tool from.
name: The name of the tool.
description: The description of the tool.
return_direct: Whether to return the output directly. Defaults to False.
args_schema: The schema of the tool's input arguments. Defaults to None.
coroutine: The asynchronous version of the function. Defaults to None.
**kwargs: Additional arguments to pass to the tool.
Returns:
The tool.
Raises:
ValueError: If the function is not provided.
"""
if func is None and coroutine is None:
raise ValueError("Function and/or coroutine must be provided")
return cls(
@ -893,25 +986,34 @@ class StructuredTool(BaseTool):
A classmethod that helps to create a tool from a function.
Args:
func: The function from which to create a tool
coroutine: The async function from which to create a tool
name: The name of the tool. Defaults to the function name
description: The description of the tool. Defaults to the function docstring
return_direct: Whether to return the result directly or as a callback
args_schema: The schema of the tool's input arguments
infer_schema: Whether to infer the schema from the function's signature
func: The function from which to create a tool.
coroutine: The async function from which to create a tool.
name: The name of the tool. Defaults to the function name.
description: The description of the tool.
Defaults to the function docstring.
return_direct: Whether to return the result directly or as a callback.
Defaults to False.
args_schema: The schema of the tool's input arguments. Defaults to None.
infer_schema: Whether to infer the schema from the function's signature.
Defaults to True.
response_format: The tool response format. If "content" then the output of
the tool is interpreted as the contents of a ToolMessage. If
"content_and_artifact" then the output is expected to be a two-tuple
corresponding to the (content, artifact) of a ToolMessage.
Defaults to "content".
parse_docstring: if ``infer_schema`` and ``parse_docstring``, will attempt
to parse parameter descriptions from Google Style function docstrings.
error_on_invalid_docstring: if ``parse_docstring`` is provided, configures
Defaults to False.
error_on_invalid_docstring: if ``parse_docstring`` is provided, configure
whether to raise ValueError on invalid Google Style docstrings.
Defaults to False.
**kwargs: Additional arguments to pass to the tool
Returns:
The tool
The tool.
Raises:
ValueError: If the function is not provided.
Examples:
@ -989,19 +1091,27 @@ def tool(
Args:
*args: The arguments to the tool.
return_direct: Whether to return directly from the tool rather
than continuing the agent loop.
args_schema: optional argument schema for user to specify
than continuing the agent loop. Defaults to False.
args_schema: optional argument schema for user to specify.
Defaults to None.
infer_schema: Whether to infer the schema of the arguments from
the function's signature. This also makes the resultant tool
accept a dictionary input to its `run()` function.
Defaults to True.
response_format: The tool response format. If "content" then the output of
the tool is interpreted as the contents of a ToolMessage. If
"content_and_artifact" then the output is expected to be a two-tuple
corresponding to the (content, artifact) of a ToolMessage.
Defaults to "content".
parse_docstring: if ``infer_schema`` and ``parse_docstring``, will attempt to
parse parameter descriptions from Google Style function docstrings.
error_on_invalid_docstring: if ``parse_docstring`` is provided, configures
Defaults to False.
error_on_invalid_docstring: if ``parse_docstring`` is provided, configure
whether to raise ValueError on invalid Google Style docstrings.
Defaults to True.
Returns:
The tool.
Requires:
- Function must be of type (str) -> str
@ -1230,9 +1340,11 @@ def create_retriever_tool(
so should be unique and somewhat descriptive.
description: The description for the tool. This will be passed to the language
model, so should be descriptive.
document_prompt: The prompt to use for the document. Defaults to None.
document_separator: The separator to use between documents. Defaults to "\n\n".
Returns:
Tool class to pass to an agent
Tool class to pass to an agent.
"""
document_prompt = document_prompt or PromptTemplate.from_template("{page_content}")
func = partial(
@ -1262,6 +1374,12 @@ ToolsRenderer = Callable[[List[BaseTool]], str]
def render_text_description(tools: List[BaseTool]) -> str:
"""Render the tool name and description in plain text.
Args:
tools: The tools to render.
Returns:
The rendered text.
Output will be in the format of:
.. code-block:: markdown
@ -1284,6 +1402,12 @@ def render_text_description(tools: List[BaseTool]) -> str:
def render_text_description_and_args(tools: List[BaseTool]) -> str:
"""Render the tool name, description, and args in plain text.
Args:
tools: The tools to render.
Returns:
The rendered text.
Output will be in the format of:
.. code-block:: markdown
@ -1444,7 +1568,18 @@ def convert_runnable_to_tool(
description: Optional[str] = None,
arg_types: Optional[Dict[str, Type]] = None,
) -> BaseTool:
"""Convert a Runnable into a BaseTool."""
"""Convert a Runnable into a BaseTool.
Args:
runnable: The runnable to convert.
args_schema: The schema for the tool's input arguments. Defaults to None.
name: The name of the tool. Defaults to None.
description: The description of the tool. Defaults to None.
arg_types: The types of the arguments. Defaults to None.
Returns:
The tool.
"""
if args_schema:
runnable = runnable.with_types(input_type=args_schema)
description = description or _get_description_from_runnable(runnable)

Loading…
Cancel
Save