"""**Tools** are classes that an Agent uses to interact with the world.
|
|
|
|
Each tool has a **description**. Agent uses the description to choose the right
|
|
tool for the job.
|
|
|
|
**Class hierarchy:**
|
|
|
|
.. code-block::
|
|
|
|
RunnableSerializable --> BaseTool --> <name>Tool # Examples: AIPluginTool, BaseGraphQLTool
|
|
<name> # Examples: BraveSearch, HumanInputRun
|
|
|
|
**Main helpers:**
|
|
|
|
.. code-block::
|
|
|
|
CallbackManagerForToolRun, AsyncCallbackManagerForToolRun
|
|
""" # noqa: E501
|
|
|
|
from __future__ import annotations

import asyncio
import functools
import inspect
import json
import textwrap
import uuid
import warnings
from abc import ABC, abstractmethod
from contextvars import copy_context
from functools import partial
from inspect import signature
from typing import (
    Any,
    Awaitable,
    Callable,
    Dict,
    List,
    Literal,
    Optional,
    Sequence,
    Tuple,
    Type,
    Union,
    get_type_hints,
)

from typing_extensions import Annotated, TypeVar, cast, get_args, get_origin

from langchain_core._api import deprecated
from langchain_core.callbacks import (
    AsyncCallbackManager,
    AsyncCallbackManagerForToolRun,
    BaseCallbackManager,
    CallbackManager,
    CallbackManagerForToolRun,
)
from langchain_core.callbacks.manager import (
    Callbacks,
)
from langchain_core.load.serializable import Serializable
from langchain_core.messages.tool import ToolCall, ToolMessage
from langchain_core.prompts import (
    BasePromptTemplate,
    PromptTemplate,
    aformat_document,
    format_document,
)
from langchain_core.pydantic_v1 import (
    BaseModel,
    Extra,
    Field,
    ValidationError,
    create_model,
    root_validator,
    validate_arguments,
)
from langchain_core.retrievers import BaseRetriever
from langchain_core.runnables import (
    Runnable,
    RunnableConfig,
    RunnableSerializable,
    ensure_config,
)
from langchain_core.runnables.config import (
    _set_config_context,
    patch_config,
    run_in_executor,
)
from langchain_core.runnables.utils import asyncio_accepts_context
from langchain_core.utils.function_calling import (
    _parse_google_docstring,
    _py_38_safe_origin,
)
from langchain_core.utils.pydantic import (
    TypeBaseModel,
    _create_subset_model,
    is_basemodel_subclass,
    is_pydantic_v1_subclass,
    is_pydantic_v2_subclass,
)

FILTERED_ARGS = ("run_manager", "callbacks")
|
|
|
|
|
|
class SchemaAnnotationError(TypeError):
|
|
"""Raised when 'args_schema' is missing or has an incorrect type annotation."""
|
|
|
|
|
|
def _is_annotated_type(typ: Type[Any]) -> bool:
|
|
return get_origin(typ) is Annotated
|
|
|
|
|
|
def _get_annotation_description(arg_type: Type) -> str | None:
|
|
if _is_annotated_type(arg_type):
|
|
annotated_args = get_args(arg_type)
|
|
for annotation in annotated_args[1:]:
|
|
if isinstance(annotation, str):
|
|
return annotation
|
|
return None
|
|
|
|
|
|
def _get_filtered_args(
|
|
inferred_model: Type[BaseModel],
|
|
func: Callable,
|
|
*,
|
|
filter_args: Sequence[str],
|
|
include_injected: bool = True,
|
|
) -> dict:
|
|
"""Get the arguments from a function's signature."""
|
|
schema = inferred_model.schema()["properties"]
|
|
valid_keys = signature(func).parameters
|
|
return {
|
|
k: schema[k]
|
|
for i, (k, param) in enumerate(valid_keys.items())
|
|
if k not in filter_args
|
|
and (i > 0 or param.name not in ("self", "cls"))
|
|
and (include_injected or not _is_injected_arg_type(param.annotation))
|
|
}
|
|
|
|
|
|
def _parse_python_function_docstring(
|
|
function: Callable, annotations: dict, error_on_invalid_docstring: bool = False
|
|
) -> Tuple[str, dict]:
|
|
"""Parse the function and argument descriptions from the docstring of a function.
|
|
|
|
Assumes the function docstring follows Google Python style guide.
|
|
"""
|
|
docstring = inspect.getdoc(function)
|
|
return _parse_google_docstring(
|
|
docstring,
|
|
list(annotations),
|
|
error_on_invalid_docstring=error_on_invalid_docstring,
|
|
)
|
|
|
|
|
|
def _validate_docstring_args_against_annotations(
|
|
arg_descriptions: dict, annotations: dict
|
|
) -> None:
|
|
"""Raise error if docstring arg is not in type annotations."""
|
|
for docstring_arg in arg_descriptions:
|
|
if docstring_arg not in annotations:
|
|
raise ValueError(
|
|
f"Arg {docstring_arg} in docstring not found in function signature."
|
|
)
|
|
|
|
|
|
def _infer_arg_descriptions(
|
|
fn: Callable,
|
|
*,
|
|
parse_docstring: bool = False,
|
|
error_on_invalid_docstring: bool = False,
|
|
) -> Tuple[str, dict]:
|
|
"""Infer argument descriptions from a function's docstring."""
|
|
    if hasattr(inspect, "get_annotations"):
        # inspect.get_annotations is only available on Python >= 3.10.
        annotations = inspect.get_annotations(fn)  # type: ignore
    else:
        annotations = getattr(fn, "__annotations__", {})
|
|
if parse_docstring:
|
|
description, arg_descriptions = _parse_python_function_docstring(
|
|
fn, annotations, error_on_invalid_docstring=error_on_invalid_docstring
|
|
)
|
|
else:
|
|
description = inspect.getdoc(fn) or ""
|
|
arg_descriptions = {}
|
|
if parse_docstring:
|
|
_validate_docstring_args_against_annotations(arg_descriptions, annotations)
|
|
for arg, arg_type in annotations.items():
|
|
if arg in arg_descriptions:
|
|
continue
|
|
if desc := _get_annotation_description(arg_type):
|
|
arg_descriptions[arg] = desc
|
|
return description, arg_descriptions
|
|
|
|
|
|
class _SchemaConfig:
|
|
"""Configuration for the pydantic model.
|
|
|
|
This is used to configure the pydantic model created from
|
|
a function's signature.
|
|
|
|
Parameters:
|
|
extra: Whether to allow extra fields in the model.
|
|
arbitrary_types_allowed: Whether to allow arbitrary types in the model.
|
|
Defaults to True.
|
|
"""
|
|
|
|
extra: Any = Extra.forbid
|
|
arbitrary_types_allowed: bool = True
|
|
|
|
|
|
def create_schema_from_function(
|
|
model_name: str,
|
|
func: Callable,
|
|
*,
|
|
filter_args: Optional[Sequence[str]] = None,
|
|
parse_docstring: bool = False,
|
|
error_on_invalid_docstring: bool = False,
|
|
include_injected: bool = True,
|
|
) -> Type[BaseModel]:
|
|
"""Create a pydantic schema from a function's signature.
|
|
|
|
Args:
|
|
model_name: Name to assign to the generated pydantic schema.
|
|
func: Function to generate the schema from.
|
|
filter_args: Optional list of arguments to exclude from the schema.
|
|
Defaults to FILTERED_ARGS.
|
|
parse_docstring: Whether to parse the function's docstring for descriptions
|
|
for each argument. Defaults to False.
|
|
error_on_invalid_docstring: if ``parse_docstring`` is provided, configure
|
|
whether to raise ValueError on invalid Google Style docstrings.
|
|
Defaults to False.
|
|
include_injected: Whether to include injected arguments in the schema.
|
|
Defaults to True, since we want to include them in the schema
|
|
when *validating* tool inputs.
|
|
|
|
Returns:
|
|
A pydantic model with the same arguments as the function.
|
|
"""
|
|
# https://docs.pydantic.dev/latest/usage/validation_decorator/
|
|
validated = validate_arguments(func, config=_SchemaConfig) # type: ignore
|
|
inferred_model = validated.model # type: ignore
|
|
filter_args = filter_args if filter_args is not None else FILTERED_ARGS
|
|
for arg in filter_args:
|
|
if arg in inferred_model.__fields__:
|
|
del inferred_model.__fields__[arg]
|
|
description, arg_descriptions = _infer_arg_descriptions(
|
|
func,
|
|
parse_docstring=parse_docstring,
|
|
error_on_invalid_docstring=error_on_invalid_docstring,
|
|
)
|
|
# Pydantic adds placeholder virtual fields we need to strip
|
|
valid_properties = _get_filtered_args(
|
|
inferred_model, func, filter_args=filter_args, include_injected=include_injected
|
|
)
|
|
return _create_subset_model(
|
|
f"{model_name}Schema",
|
|
inferred_model,
|
|
list(valid_properties),
|
|
descriptions=arg_descriptions,
|
|
fn_description=description,
|
|
)
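# A minimal usage sketch (not exercised elsewhere in this module): building a
# pydantic schema from a plain function, with Google-style docstring parsing.
# ``_example_create_schema`` and ``multiply`` are illustrative names only.
def _example_create_schema() -> None:
    def multiply(a: int, b: int) -> int:
        """Multiply two numbers.

        Args:
            a: The first factor.
            b: The second factor.
        """
        return a * b

    schema = create_schema_from_function("multiply", multiply, parse_docstring=True)
    # The generated model is named "multiplySchema"; it validates tool inputs and
    # carries the per-argument descriptions parsed from the docstring.
    print(schema.schema()["properties"])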
|
|
|
|
|
|
class ToolException(Exception):
    """Optional exception that a tool can raise when an execution error occurs.

    When this exception is raised, the agent does not stop working. Instead, the
    exception is handled according to the tool's ``handle_tool_error`` setting,
    and the resulting text is returned to the agent as an observation (and
    printed in red on the console).
    """

    pass
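# Illustrative sketch of how ToolException interacts with ``handle_tool_error``.
# ``fragile_search`` is a hypothetical function used only for this example; the
# sketch relies on StructuredTool, which is defined later in this module.
def _example_handle_tool_error() -> None:
    def fragile_search(query: str) -> str:
        """Search that always fails."""
        raise ToolException(f"No results found for: {query}")

    search_tool = StructuredTool.from_function(fragile_search, handle_tool_error=True)
    # Because handle_tool_error=True, the exception text is returned as the
    # observation instead of being re-raised.
    print(search_tool.invoke({"query": "langchain"}))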
|
|
|
|
|
|
class BaseTool(RunnableSerializable[Union[str, Dict, ToolCall], Any]):
|
|
"""Interface LangChain tools must implement."""
|
|
|
|
def __init_subclass__(cls, **kwargs: Any) -> None:
|
|
"""Create the definition of the new tool class."""
|
|
super().__init_subclass__(**kwargs)
|
|
|
|
args_schema_type = cls.__annotations__.get("args_schema", None)
|
|
|
|
if args_schema_type is not None and args_schema_type == BaseModel:
|
|
# Throw errors for common mis-annotations.
|
|
# TODO: Use get_args / get_origin and fully
|
|
# specify valid annotations.
|
|
typehint_mandate = """
|
|
class ChildTool(BaseTool):
|
|
...
|
|
args_schema: Type[BaseModel] = SchemaClass
|
|
..."""
|
|
name = cls.__name__
|
|
raise SchemaAnnotationError(
|
|
f"Tool definition for {name} must include valid type annotations"
|
|
f" for argument 'args_schema' to behave as expected.\n"
|
|
f"Expected annotation of 'Type[BaseModel]'"
|
|
f" but got '{args_schema_type}'.\n"
|
|
f"Expected class looks like:\n"
|
|
f"{typehint_mandate}"
|
|
)
|
|
|
|
name: str
|
|
"""The unique name of the tool that clearly communicates its purpose."""
|
|
description: str
|
|
"""Used to tell the model how/when/why to use the tool.
|
|
|
|
You can provide few-shot examples as a part of the description.
|
|
"""
|
|
args_schema: Optional[TypeBaseModel] = None
|
|
"""Pydantic model class to validate and parse the tool's input arguments.
|
|
|
|
Args schema should be either:
|
|
|
|
- A subclass of pydantic.BaseModel.
|
|
or
|
|
- A subclass of pydantic.v1.BaseModel if accessing v1 namespace in pydantic 2
|
|
"""
|
|
return_direct: bool = False
|
|
"""Whether to return the tool's output directly.
|
|
|
|
Setting this to True means
|
|
that after the tool is called, the AgentExecutor will stop looping.
|
|
"""
|
|
verbose: bool = False
|
|
"""Whether to log the tool's progress."""
|
|
|
|
callbacks: Callbacks = Field(default=None, exclude=True)
|
|
"""Callbacks to be called during tool execution."""
|
|
|
|
callback_manager: Optional[BaseCallbackManager] = deprecated(
|
|
name="callback_manager", since="0.1.7", removal="0.3.0", alternative="callbacks"
|
|
)(
|
|
Field(
|
|
default=None,
|
|
exclude=True,
|
|
description="Callback manager to add to the run trace.",
|
|
)
|
|
)
|
|
tags: Optional[List[str]] = None
|
|
"""Optional list of tags associated with the tool. Defaults to None.
|
|
These tags will be associated with each call to this tool,
|
|
and passed as arguments to the handlers defined in `callbacks`.
|
|
    You can use these to, for example, identify a specific instance of a tool
    with its use case.
|
|
"""
|
|
metadata: Optional[Dict[str, Any]] = None
|
|
"""Optional metadata associated with the tool. Defaults to None.
|
|
This metadata will be associated with each call to this tool,
|
|
and passed as arguments to the handlers defined in `callbacks`.
|
|
    You can use these to, for example, identify a specific instance of a tool
    with its use case.
|
|
"""
|
|
|
|
handle_tool_error: Optional[Union[bool, str, Callable[[ToolException], str]]] = (
|
|
False
|
|
)
|
|
"""Handle the content of the ToolException thrown."""
|
|
|
|
handle_validation_error: Optional[
|
|
Union[bool, str, Callable[[ValidationError], str]]
|
|
] = False
|
|
"""Handle the content of the ValidationError thrown."""
|
|
|
|
response_format: Literal["content", "content_and_artifact"] = "content"
|
|
"""The tool response format. Defaults to 'content'.
|
|
|
|
If "content" then the output of the tool is interpreted as the contents of a
|
|
ToolMessage. If "content_and_artifact" then the output is expected to be a
|
|
two-tuple corresponding to the (content, artifact) of a ToolMessage.
|
|
"""
|
|
|
|
def __init__(self, **kwargs: Any) -> None:
|
|
"""Initialize the tool."""
|
|
if "args_schema" in kwargs and kwargs["args_schema"] is not None:
|
|
if not is_basemodel_subclass(kwargs["args_schema"]):
|
|
raise TypeError(
|
|
f"args_schema must be a subclass of pydantic BaseModel. "
|
|
f"Got: {kwargs['args_schema']}."
|
|
)
|
|
super().__init__(**kwargs)
|
|
|
|
class Config(Serializable.Config):
|
|
arbitrary_types_allowed = True
|
|
|
|
@property
|
|
def is_single_input(self) -> bool:
|
|
"""Whether the tool only accepts a single input."""
|
|
keys = {k for k in self.args if k != "kwargs"}
|
|
return len(keys) == 1
|
|
|
|
@property
|
|
def args(self) -> dict:
|
|
return self.get_input_schema().schema()["properties"]
|
|
|
|
@property
|
|
def tool_call_schema(self) -> Type[BaseModel]:
|
|
full_schema = self.get_input_schema()
|
|
fields = []
|
|
for name, type_ in _get_all_basemodel_annotations(full_schema).items():
|
|
if not _is_injected_arg_type(type_):
|
|
fields.append(name)
|
|
return _create_subset_model(
|
|
self.name, full_schema, fields, fn_description=self.description
|
|
)
|
|
|
|
# --- Runnable ---
|
|
|
|
def get_input_schema(
|
|
self, config: Optional[RunnableConfig] = None
|
|
) -> Type[BaseModel]:
|
|
"""The tool's input schema.
|
|
|
|
Args:
|
|
config: The configuration for the tool.
|
|
|
|
Returns:
|
|
The input schema for the tool.
|
|
"""
|
|
if self.args_schema is not None:
|
|
return self.args_schema
|
|
else:
|
|
return create_schema_from_function(self.name, self._run)
|
|
|
|
def invoke(
|
|
self,
|
|
input: Union[str, Dict, ToolCall],
|
|
config: Optional[RunnableConfig] = None,
|
|
**kwargs: Any,
|
|
) -> Any:
|
|
tool_input, kwargs = _prep_run_args(input, config, **kwargs)
|
|
return self.run(tool_input, **kwargs)
|
|
|
|
async def ainvoke(
|
|
self,
|
|
input: Union[str, Dict, ToolCall],
|
|
config: Optional[RunnableConfig] = None,
|
|
**kwargs: Any,
|
|
) -> Any:
|
|
tool_input, kwargs = _prep_run_args(input, config, **kwargs)
|
|
return await self.arun(tool_input, **kwargs)
|
|
|
|
# --- Tool ---
|
|
|
|
def _parse_input(self, tool_input: Union[str, Dict]) -> Union[str, Dict[str, Any]]:
|
|
"""Convert tool input to a pydantic model.
|
|
|
|
Args:
|
|
tool_input: The input to the tool.
|
|
"""
|
|
input_args = self.args_schema
|
|
if isinstance(tool_input, str):
|
|
if input_args is not None:
|
|
key_ = next(iter(input_args.__fields__.keys()))
|
|
input_args.validate({key_: tool_input})
|
|
return tool_input
|
|
else:
|
|
if input_args is not None:
|
|
result = input_args.parse_obj(tool_input)
|
|
return {
|
|
k: getattr(result, k)
|
|
for k, v in result.dict().items()
|
|
if k in tool_input
|
|
}
|
|
return tool_input
|
|
|
|
@root_validator(pre=True)
|
|
def raise_deprecation(cls, values: Dict) -> Dict:
|
|
"""Raise deprecation warning if callback_manager is used.
|
|
|
|
Args:
|
|
values: The values to validate.
|
|
|
|
Returns:
|
|
The validated values.
|
|
"""
|
|
if values.get("callback_manager") is not None:
|
|
warnings.warn(
|
|
"callback_manager is deprecated. Please use callbacks instead.",
|
|
DeprecationWarning,
|
|
)
|
|
values["callbacks"] = values.pop("callback_manager", None)
|
|
return values
|
|
|
|
@abstractmethod
|
|
def _run(self, *args: Any, **kwargs: Any) -> Any:
|
|
"""Use the tool.
|
|
|
|
Add run_manager: Optional[CallbackManagerForToolRun] = None
|
|
to child implementations to enable tracing.
|
|
"""
|
|
|
|
async def _arun(self, *args: Any, **kwargs: Any) -> Any:
|
|
"""Use the tool asynchronously.
|
|
|
|
Add run_manager: Optional[AsyncCallbackManagerForToolRun] = None
|
|
to child implementations to enable tracing.
|
|
"""
|
|
if kwargs.get("run_manager") and signature(self._run).parameters.get(
|
|
"run_manager"
|
|
):
|
|
kwargs["run_manager"] = kwargs["run_manager"].get_sync()
|
|
return await run_in_executor(None, self._run, *args, **kwargs)
|
|
|
|
def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]:
|
|
tool_input = self._parse_input(tool_input)
|
|
# For backwards compatibility, if run_input is a string,
|
|
# pass as a positional argument.
|
|
if isinstance(tool_input, str):
|
|
return (tool_input,), {}
|
|
else:
|
|
return (), tool_input
|
|
|
|
def run(
|
|
self,
|
|
tool_input: Union[str, Dict[str, Any]],
|
|
verbose: Optional[bool] = None,
|
|
start_color: Optional[str] = "green",
|
|
color: Optional[str] = "green",
|
|
callbacks: Callbacks = None,
|
|
*,
|
|
tags: Optional[List[str]] = None,
|
|
metadata: Optional[Dict[str, Any]] = None,
|
|
run_name: Optional[str] = None,
|
|
run_id: Optional[uuid.UUID] = None,
|
|
config: Optional[RunnableConfig] = None,
|
|
tool_call_id: Optional[str] = None,
|
|
**kwargs: Any,
|
|
) -> Any:
|
|
"""Run the tool.
|
|
|
|
Args:
|
|
tool_input: The input to the tool.
|
|
verbose: Whether to log the tool's progress. Defaults to None.
|
|
start_color: The color to use when starting the tool. Defaults to 'green'.
|
|
color: The color to use when ending the tool. Defaults to 'green'.
|
|
callbacks: Callbacks to be called during tool execution. Defaults to None.
|
|
tags: Optional list of tags associated with the tool. Defaults to None.
|
|
metadata: Optional metadata associated with the tool. Defaults to None.
|
|
run_name: The name of the run. Defaults to None.
|
|
run_id: The id of the run. Defaults to None.
|
|
config: The configuration for the tool. Defaults to None.
|
|
tool_call_id: The id of the tool call. Defaults to None.
|
|
kwargs: Additional arguments to pass to the tool
|
|
|
|
Returns:
|
|
The output of the tool.
|
|
|
|
Raises:
|
|
ToolException: If an error occurs during tool execution.
|
|
"""
|
|
callback_manager = CallbackManager.configure(
|
|
callbacks,
|
|
self.callbacks,
|
|
self.verbose or bool(verbose),
|
|
tags,
|
|
self.tags,
|
|
metadata,
|
|
self.metadata,
|
|
)
|
|
|
|
run_manager = callback_manager.on_tool_start(
|
|
{"name": self.name, "description": self.description},
|
|
tool_input if isinstance(tool_input, str) else str(tool_input),
|
|
color=start_color,
|
|
name=run_name,
|
|
run_id=run_id,
|
|
# Inputs by definition should always be dicts.
|
|
# For now, it's unclear whether this assumption is ever violated,
|
|
# but if it is we will send a `None` value to the callback instead
|
|
# TODO: will need to address issue via a patch.
|
|
inputs=tool_input if isinstance(tool_input, dict) else None,
|
|
**kwargs,
|
|
)
|
|
|
|
content = None
|
|
artifact = None
|
|
error_to_raise: Union[Exception, KeyboardInterrupt, None] = None
|
|
try:
|
|
child_config = patch_config(config, callbacks=run_manager.get_child())
|
|
context = copy_context()
|
|
context.run(_set_config_context, child_config)
|
|
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input)
|
|
if signature(self._run).parameters.get("run_manager"):
|
|
tool_kwargs["run_manager"] = run_manager
|
|
|
|
if config_param := _get_runnable_config_param(self._run):
|
|
tool_kwargs[config_param] = config
|
|
response = context.run(self._run, *tool_args, **tool_kwargs)
|
|
if self.response_format == "content_and_artifact":
|
|
if not isinstance(response, tuple) or len(response) != 2:
|
|
raise ValueError(
|
|
"Since response_format='content_and_artifact' "
|
|
"a two-tuple of the message content and raw tool output is "
|
|
f"expected. Instead generated response of type: "
|
|
f"{type(response)}."
|
|
)
|
|
content, artifact = response
|
|
else:
|
|
content = response
|
|
status = "success"
|
|
except ValidationError as e:
|
|
if not self.handle_validation_error:
|
|
error_to_raise = e
|
|
else:
|
|
content = _handle_validation_error(e, flag=self.handle_validation_error)
|
|
status = "error"
|
|
except ToolException as e:
|
|
if not self.handle_tool_error:
|
|
error_to_raise = e
|
|
else:
|
|
content = _handle_tool_error(e, flag=self.handle_tool_error)
|
|
status = "error"
|
|
except (Exception, KeyboardInterrupt) as e:
|
|
error_to_raise = e
|
|
status = "error"
|
|
|
|
if error_to_raise:
|
|
run_manager.on_tool_error(error_to_raise)
|
|
raise error_to_raise
|
|
output = _format_output(content, artifact, tool_call_id, self.name, status)
|
|
run_manager.on_tool_end(output, color=color, name=self.name, **kwargs)
|
|
return output
|
|
|
|
async def arun(
|
|
self,
|
|
tool_input: Union[str, Dict],
|
|
verbose: Optional[bool] = None,
|
|
start_color: Optional[str] = "green",
|
|
color: Optional[str] = "green",
|
|
callbacks: Callbacks = None,
|
|
*,
|
|
tags: Optional[List[str]] = None,
|
|
metadata: Optional[Dict[str, Any]] = None,
|
|
run_name: Optional[str] = None,
|
|
run_id: Optional[uuid.UUID] = None,
|
|
config: Optional[RunnableConfig] = None,
|
|
tool_call_id: Optional[str] = None,
|
|
**kwargs: Any,
|
|
) -> Any:
|
|
"""Run the tool asynchronously.
|
|
|
|
Args:
|
|
tool_input: The input to the tool.
|
|
verbose: Whether to log the tool's progress. Defaults to None.
|
|
start_color: The color to use when starting the tool. Defaults to 'green'.
|
|
color: The color to use when ending the tool. Defaults to 'green'.
|
|
callbacks: Callbacks to be called during tool execution. Defaults to None.
|
|
tags: Optional list of tags associated with the tool. Defaults to None.
|
|
metadata: Optional metadata associated with the tool. Defaults to None.
|
|
run_name: The name of the run. Defaults to None.
|
|
run_id: The id of the run. Defaults to None.
|
|
config: The configuration for the tool. Defaults to None.
|
|
tool_call_id: The id of the tool call. Defaults to None.
|
|
kwargs: Additional arguments to pass to the tool
|
|
|
|
Returns:
|
|
The output of the tool.
|
|
|
|
Raises:
|
|
ToolException: If an error occurs during tool execution.
|
|
"""
|
|
callback_manager = AsyncCallbackManager.configure(
|
|
callbacks,
|
|
self.callbacks,
|
|
self.verbose or bool(verbose),
|
|
tags,
|
|
self.tags,
|
|
metadata,
|
|
self.metadata,
|
|
)
|
|
run_manager = await callback_manager.on_tool_start(
|
|
{"name": self.name, "description": self.description},
|
|
tool_input if isinstance(tool_input, str) else str(tool_input),
|
|
color=start_color,
|
|
name=run_name,
|
|
run_id=run_id,
|
|
# Inputs by definition should always be dicts.
|
|
# For now, it's unclear whether this assumption is ever violated,
|
|
# but if it is we will send a `None` value to the callback instead
|
|
# TODO: will need to address issue via a patch.
|
|
inputs=tool_input if isinstance(tool_input, dict) else None,
|
|
**kwargs,
|
|
)
|
|
content = None
|
|
artifact = None
|
|
error_to_raise: Optional[Union[Exception, KeyboardInterrupt]] = None
|
|
try:
|
|
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input)
|
|
child_config = patch_config(config, callbacks=run_manager.get_child())
|
|
context = copy_context()
|
|
context.run(_set_config_context, child_config)
|
|
func_to_check = (
|
|
self._run if self.__class__._arun is BaseTool._arun else self._arun
|
|
)
|
|
if signature(func_to_check).parameters.get("run_manager"):
|
|
tool_kwargs["run_manager"] = run_manager
|
|
if config_param := _get_runnable_config_param(func_to_check):
|
|
tool_kwargs[config_param] = config
|
|
|
|
coro = context.run(self._arun, *tool_args, **tool_kwargs)
|
|
if asyncio_accepts_context():
|
|
response = await asyncio.create_task(coro, context=context) # type: ignore
|
|
else:
|
|
response = await coro
|
|
if self.response_format == "content_and_artifact":
|
|
if not isinstance(response, tuple) or len(response) != 2:
|
|
raise ValueError(
|
|
"Since response_format='content_and_artifact' "
|
|
"a two-tuple of the message content and raw tool output is "
|
|
f"expected. Instead generated response of type: "
|
|
f"{type(response)}."
|
|
)
|
|
content, artifact = response
|
|
else:
|
|
content = response
|
|
status = "success"
|
|
except ValidationError as e:
|
|
if not self.handle_validation_error:
|
|
error_to_raise = e
|
|
else:
|
|
content = _handle_validation_error(e, flag=self.handle_validation_error)
|
|
status = "error"
|
|
except ToolException as e:
|
|
if not self.handle_tool_error:
|
|
error_to_raise = e
|
|
else:
|
|
content = _handle_tool_error(e, flag=self.handle_tool_error)
|
|
status = "error"
|
|
except (Exception, KeyboardInterrupt) as e:
|
|
error_to_raise = e
|
|
status = "error"
|
|
|
|
if error_to_raise:
|
|
await run_manager.on_tool_error(error_to_raise)
|
|
raise error_to_raise
|
|
|
|
output = _format_output(content, artifact, tool_call_id, self.name, status)
|
|
await run_manager.on_tool_end(output, color=color, name=self.name, **kwargs)
|
|
return output
|
|
|
|
@deprecated("0.1.47", alternative="invoke", removal="0.3.0")
|
|
def __call__(self, tool_input: str, callbacks: Callbacks = None) -> str:
|
|
"""Make tool callable."""
|
|
return self.run(tool_input, callbacks=callbacks)
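# Illustrative sketch of a minimal BaseTool subclass. The names below are
# hypothetical; the point is to show the pieces a subclass provides: ``name``,
# ``description``, ``args_schema``, and a ``_run`` implementation.
def _example_base_tool_subclass() -> None:
    class CircleInput(BaseModel):
        """Input for the example circumference tool."""

        radius: float = Field(description="Radius of the circle.")

    class CircumferenceTool(BaseTool):
        name: str = "circumference_calculator"
        description: str = "Use this tool to compute the circumference of a circle."
        args_schema: Type[BaseModel] = CircleInput

        def _run(
            self,
            radius: float,
            run_manager: Optional[CallbackManagerForToolRun] = None,
        ) -> str:
            return str(2 * 3.141592653589793 * radius)

    print(CircumferenceTool().invoke({"radius": 2.0}))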
|
|
|
|
|
|
class Tool(BaseTool):
|
|
"""Tool that takes in function or coroutine directly."""
|
|
|
|
description: str = ""
|
|
func: Optional[Callable[..., str]]
|
|
"""The function to run when the tool is called."""
|
|
coroutine: Optional[Callable[..., Awaitable[str]]] = None
|
|
"""The asynchronous version of the function."""
|
|
|
|
# --- Runnable ---
|
|
|
|
async def ainvoke(
|
|
self,
|
|
input: Union[str, Dict, ToolCall],
|
|
config: Optional[RunnableConfig] = None,
|
|
**kwargs: Any,
|
|
) -> Any:
|
|
if not self.coroutine:
|
|
# If the tool does not implement async, fall back to default implementation
|
|
return await run_in_executor(config, self.invoke, input, config, **kwargs)
|
|
|
|
return await super().ainvoke(input, config, **kwargs)
|
|
|
|
# --- Tool ---
|
|
|
|
@property
|
|
def args(self) -> dict:
|
|
"""The tool's input arguments.
|
|
|
|
Returns:
|
|
The input arguments for the tool.
|
|
"""
|
|
if self.args_schema is not None:
|
|
return self.args_schema.schema()["properties"]
|
|
# For backwards compatibility, if the function signature is ambiguous,
|
|
# assume it takes a single string input.
|
|
return {"tool_input": {"type": "string"}}
|
|
|
|
def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]:
|
|
"""Convert tool input to pydantic model."""
|
|
args, kwargs = super()._to_args_and_kwargs(tool_input)
|
|
# For backwards compatibility. The tool must be run with a single input
|
|
all_args = list(args) + list(kwargs.values())
|
|
if len(all_args) != 1:
|
|
raise ToolException(
|
|
f"""Too many arguments to single-input tool {self.name}.
|
|
Consider using StructuredTool instead."""
|
|
f" Args: {all_args}"
|
|
)
|
|
return tuple(all_args), {}
|
|
|
|
def _run(
|
|
self,
|
|
*args: Any,
|
|
config: RunnableConfig,
|
|
run_manager: Optional[CallbackManagerForToolRun] = None,
|
|
**kwargs: Any,
|
|
) -> Any:
|
|
"""Use the tool."""
|
|
if self.func:
|
|
if run_manager and signature(self.func).parameters.get("callbacks"):
|
|
kwargs["callbacks"] = run_manager.get_child()
|
|
if config_param := _get_runnable_config_param(self.func):
|
|
kwargs[config_param] = config
|
|
return self.func(*args, **kwargs)
|
|
raise NotImplementedError("Tool does not support sync invocation.")
|
|
|
|
async def _arun(
|
|
self,
|
|
*args: Any,
|
|
config: RunnableConfig,
|
|
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
|
|
**kwargs: Any,
|
|
) -> Any:
|
|
"""Use the tool asynchronously."""
|
|
if self.coroutine:
|
|
if run_manager and signature(self.coroutine).parameters.get("callbacks"):
|
|
kwargs["callbacks"] = run_manager.get_child()
|
|
if config_param := _get_runnable_config_param(self.coroutine):
|
|
kwargs[config_param] = config
|
|
return await self.coroutine(*args, **kwargs)
|
|
|
|
# NOTE: this code is unreachable since _arun is only called if coroutine is not
|
|
# None.
|
|
return await super()._arun(
|
|
*args, config=config, run_manager=run_manager, **kwargs
|
|
)
|
|
|
|
# TODO: this is for backwards compatibility, remove in future
|
|
def __init__(
|
|
self, name: str, func: Optional[Callable], description: str, **kwargs: Any
|
|
) -> None:
|
|
"""Initialize tool."""
|
|
super(Tool, self).__init__( # type: ignore[call-arg]
|
|
name=name, func=func, description=description, **kwargs
|
|
)
|
|
|
|
@classmethod
|
|
def from_function(
|
|
cls,
|
|
func: Optional[Callable],
|
|
name: str, # We keep these required to support backwards compatibility
|
|
description: str,
|
|
return_direct: bool = False,
|
|
args_schema: Optional[Type[BaseModel]] = None,
|
|
coroutine: Optional[
|
|
Callable[..., Awaitable[Any]]
|
|
] = None, # This is last for compatibility, but should be after func
|
|
**kwargs: Any,
|
|
) -> Tool:
|
|
"""Initialize tool from a function.
|
|
|
|
Args:
|
|
func: The function to create the tool from.
|
|
name: The name of the tool.
|
|
description: The description of the tool.
|
|
return_direct: Whether to return the output directly. Defaults to False.
|
|
args_schema: The schema of the tool's input arguments. Defaults to None.
|
|
coroutine: The asynchronous version of the function. Defaults to None.
|
|
kwargs: Additional arguments to pass to the tool.
|
|
|
|
Returns:
|
|
The tool.
|
|
|
|
Raises:
|
|
ValueError: If the function is not provided.
|
|
"""
|
|
if func is None and coroutine is None:
|
|
raise ValueError("Function and/or coroutine must be provided")
|
|
return cls(
|
|
name=name,
|
|
func=func,
|
|
coroutine=coroutine,
|
|
description=description,
|
|
return_direct=return_direct,
|
|
args_schema=args_schema,
|
|
**kwargs,
|
|
)
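# Illustrative sketch of Tool.from_function for a single-input, string-in /
# string-out function. ``greet`` is a hypothetical helper used only here.
def _example_tool_from_function() -> None:
    def greet(name: str) -> str:
        return f"Hello, {name}!"

    greeting_tool = Tool.from_function(
        func=greet,
        name="greeter",
        description="Greets the given person by name.",
    )
    # A Tool accepts a single (string) input.
    print(greeting_tool.invoke("world"))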
|
|
|
|
|
|
class StructuredTool(BaseTool):
|
|
"""Tool that can operate on any number of inputs."""
|
|
|
|
description: str = ""
|
|
args_schema: TypeBaseModel = Field(..., description="The tool schema.")
|
|
"""The input arguments' schema."""
|
|
func: Optional[Callable[..., Any]]
|
|
"""The function to run when the tool is called."""
|
|
coroutine: Optional[Callable[..., Awaitable[Any]]] = None
|
|
"""The asynchronous version of the function."""
|
|
|
|
# --- Runnable ---
|
|
|
|
# TODO: Is this needed?
|
|
async def ainvoke(
|
|
self,
|
|
input: Union[str, Dict, ToolCall],
|
|
config: Optional[RunnableConfig] = None,
|
|
**kwargs: Any,
|
|
) -> Any:
|
|
if not self.coroutine:
|
|
# If the tool does not implement async, fall back to default implementation
|
|
return await run_in_executor(config, self.invoke, input, config, **kwargs)
|
|
|
|
return await super().ainvoke(input, config, **kwargs)
|
|
|
|
# --- Tool ---
|
|
|
|
@property
|
|
def args(self) -> dict:
|
|
"""The tool's input arguments."""
|
|
return self.args_schema.schema()["properties"]
|
|
|
|
def _run(
|
|
self,
|
|
*args: Any,
|
|
config: RunnableConfig,
|
|
run_manager: Optional[CallbackManagerForToolRun] = None,
|
|
**kwargs: Any,
|
|
) -> Any:
|
|
"""Use the tool."""
|
|
if self.func:
|
|
if run_manager and signature(self.func).parameters.get("callbacks"):
|
|
kwargs["callbacks"] = run_manager.get_child()
|
|
if config_param := _get_runnable_config_param(self.func):
|
|
kwargs[config_param] = config
|
|
return self.func(*args, **kwargs)
|
|
raise NotImplementedError("StructuredTool does not support sync invocation.")
|
|
|
|
async def _arun(
|
|
self,
|
|
*args: Any,
|
|
config: RunnableConfig,
|
|
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
|
|
**kwargs: Any,
|
|
) -> Any:
|
|
"""Use the tool asynchronously."""
|
|
if self.coroutine:
|
|
if run_manager and signature(self.coroutine).parameters.get("callbacks"):
|
|
kwargs["callbacks"] = run_manager.get_child()
|
|
if config_param := _get_runnable_config_param(self.coroutine):
|
|
kwargs[config_param] = config
|
|
return await self.coroutine(*args, **kwargs)
|
|
|
|
# NOTE: this code is unreachable since _arun is only called if coroutine is not
|
|
# None.
|
|
return await super()._arun(
|
|
*args, config=config, run_manager=run_manager, **kwargs
|
|
)
|
|
|
|
@classmethod
|
|
def from_function(
|
|
cls,
|
|
func: Optional[Callable] = None,
|
|
coroutine: Optional[Callable[..., Awaitable[Any]]] = None,
|
|
name: Optional[str] = None,
|
|
description: Optional[str] = None,
|
|
return_direct: bool = False,
|
|
args_schema: Optional[Type[BaseModel]] = None,
|
|
infer_schema: bool = True,
|
|
*,
|
|
response_format: Literal["content", "content_and_artifact"] = "content",
|
|
parse_docstring: bool = False,
|
|
error_on_invalid_docstring: bool = False,
|
|
**kwargs: Any,
|
|
) -> StructuredTool:
|
|
"""Create tool from a given function.
|
|
|
|
A classmethod that helps to create a tool from a function.
|
|
|
|
Args:
|
|
func: The function from which to create a tool.
|
|
coroutine: The async function from which to create a tool.
|
|
name: The name of the tool. Defaults to the function name.
|
|
description: The description of the tool.
|
|
Defaults to the function docstring.
|
|
return_direct: Whether to return the result directly or as a callback.
|
|
Defaults to False.
|
|
args_schema: The schema of the tool's input arguments. Defaults to None.
|
|
infer_schema: Whether to infer the schema from the function's signature.
|
|
Defaults to True.
|
|
response_format: The tool response format. If "content" then the output of
|
|
the tool is interpreted as the contents of a ToolMessage. If
|
|
"content_and_artifact" then the output is expected to be a two-tuple
|
|
corresponding to the (content, artifact) of a ToolMessage.
|
|
Defaults to "content".
|
|
parse_docstring: if ``infer_schema`` and ``parse_docstring``, will attempt
|
|
to parse parameter descriptions from Google Style function docstrings.
|
|
Defaults to False.
|
|
error_on_invalid_docstring: if ``parse_docstring`` is provided, configure
|
|
whether to raise ValueError on invalid Google Style docstrings.
|
|
Defaults to False.
|
|
kwargs: Additional arguments to pass to the tool
|
|
|
|
Returns:
|
|
The tool.
|
|
|
|
Raises:
|
|
ValueError: If the function is not provided.
|
|
|
|
Examples:
|
|
|
|
.. code-block:: python
|
|
|
|
                def add(a: int, b: int) -> int:
                    \"\"\"Add two numbers.\"\"\"
                    return a + b

                tool = StructuredTool.from_function(add)
                tool.run({"a": 1, "b": 2})  # 3
|
|
"""
|
|
|
|
if func is not None:
|
|
source_function = func
|
|
elif coroutine is not None:
|
|
source_function = coroutine
|
|
else:
|
|
raise ValueError("Function and/or coroutine must be provided")
|
|
name = name or source_function.__name__
|
|
if args_schema is None and infer_schema:
|
|
# schema name is appended within function
|
|
args_schema = create_schema_from_function(
|
|
name,
|
|
source_function,
|
|
parse_docstring=parse_docstring,
|
|
error_on_invalid_docstring=error_on_invalid_docstring,
|
|
filter_args=_filter_schema_args(source_function),
|
|
)
|
|
description_ = description
|
|
if description is None and not parse_docstring:
|
|
description_ = source_function.__doc__ or None
|
|
if description_ is None and args_schema:
|
|
description_ = args_schema.__doc__ or None
|
|
if description_ is None:
|
|
raise ValueError(
|
|
"Function must have a docstring if description not provided."
|
|
)
|
|
if description is None:
|
|
# Only apply if using the function's docstring
|
|
description_ = textwrap.dedent(description_).strip()
|
|
|
|
# Description example:
|
|
# search_api(query: str) - Searches the API for the query.
|
|
description_ = f"{description_.strip()}"
|
|
return cls(
|
|
name=name,
|
|
func=func,
|
|
coroutine=coroutine,
|
|
args_schema=args_schema, # type: ignore[arg-type]
|
|
description=description_,
|
|
return_direct=return_direct,
|
|
response_format=response_format,
|
|
**kwargs,
|
|
)
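# Illustrative sketch of StructuredTool.from_function with
# response_format="content_and_artifact". ``sample_squares`` is a hypothetical
# function; invoking the tool with a ToolCall-shaped dict yields a ToolMessage
# whose ``artifact`` carries the raw output.
def _example_structured_tool_with_artifact() -> None:
    def sample_squares(n: int) -> Tuple[str, List[int]]:
        """Generate the first n square numbers.

        Args:
            n: How many squares to generate.
        """
        squares = [i * i for i in range(n)]
        return f"Generated {n} squares.", squares

    squares_tool = StructuredTool.from_function(
        sample_squares,
        response_format="content_and_artifact",
        parse_docstring=True,
    )
    message = squares_tool.invoke(
        {
            "name": "sample_squares",
            "args": {"n": 3},
            "id": "call_1",
            "type": "tool_call",
        }
    )
    print(message.content)   # "Generated 3 squares."
    print(message.artifact)  # [0, 1, 4]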
|
|
|
|
|
|
# TODO: Type args_schema as TypeBaseModel if we can get mypy to correctly recognize
|
|
# pydantic v2 BaseModel classes.
|
|
def tool(
|
|
*args: Union[str, Callable, Runnable],
|
|
return_direct: bool = False,
|
|
args_schema: Optional[Type] = None,
|
|
infer_schema: bool = True,
|
|
response_format: Literal["content", "content_and_artifact"] = "content",
|
|
parse_docstring: bool = False,
|
|
error_on_invalid_docstring: bool = True,
|
|
) -> Callable:
|
|
"""Make tools out of functions, can be used with or without arguments.
|
|
|
|
Args:
|
|
*args: The arguments to the tool.
|
|
return_direct: Whether to return directly from the tool rather
|
|
than continuing the agent loop. Defaults to False.
|
|
args_schema: optional argument schema for user to specify.
|
|
Defaults to None.
|
|
infer_schema: Whether to infer the schema of the arguments from
|
|
the function's signature. This also makes the resultant tool
|
|
accept a dictionary input to its `run()` function.
|
|
Defaults to True.
|
|
response_format: The tool response format. If "content" then the output of
|
|
the tool is interpreted as the contents of a ToolMessage. If
|
|
"content_and_artifact" then the output is expected to be a two-tuple
|
|
corresponding to the (content, artifact) of a ToolMessage.
|
|
Defaults to "content".
|
|
parse_docstring: if ``infer_schema`` and ``parse_docstring``, will attempt to
|
|
parse parameter descriptions from Google Style function docstrings.
|
|
Defaults to False.
|
|
error_on_invalid_docstring: if ``parse_docstring`` is provided, configure
|
|
whether to raise ValueError on invalid Google Style docstrings.
|
|
Defaults to True.
|
|
|
|
Returns:
|
|
The tool.
|
|
|
|
Requires:
|
|
- Function must be of type (str) -> str
|
|
- Function must have a docstring
|
|
|
|
Examples:
|
|
.. code-block:: python
|
|
|
|
            @tool
            def search_api(query: str) -> str:
                \"\"\"Searches the API for the query.\"\"\"
                return

            @tool("search", return_direct=True)
            def search_api(query: str) -> str:
                \"\"\"Searches the API for the query.\"\"\"
                return

            @tool(response_format="content_and_artifact")
            def search_api(query: str) -> Tuple[str, dict]:
                \"\"\"Searches the API for the query.\"\"\"
                return "partial json of results", {"full": "object of results"}
|
|
|
|
.. versionadded:: 0.2.14
|
|
Parse Google-style docstrings:
|
|
|
|
.. code-block:: python
|
|
|
|
@tool(parse_docstring=True)
|
|
def foo(bar: str, baz: int) -> str:
|
|
\"\"\"The foo.
|
|
|
|
Args:
|
|
bar: The bar.
|
|
baz: The baz.
|
|
\"\"\"
|
|
return bar
|
|
|
|
foo.args_schema.schema()
|
|
|
|
.. code-block:: python
|
|
|
|
{
|
|
"title": "fooSchema",
|
|
"description": "The foo.",
|
|
"type": "object",
|
|
"properties": {
|
|
"bar": {
|
|
"title": "Bar",
|
|
"description": "The bar.",
|
|
"type": "string"
|
|
},
|
|
"baz": {
|
|
"title": "Baz",
|
|
"description": "The baz.",
|
|
"type": "integer"
|
|
}
|
|
},
|
|
"required": [
|
|
"bar",
|
|
"baz"
|
|
]
|
|
}
|
|
|
|
Note that parsing by default will raise ``ValueError`` if the docstring
|
|
is considered invalid. A docstring is considered invalid if it contains
|
|
arguments not in the function signature, or is unable to be parsed into
|
|
a summary and "Args:" blocks. Examples below:
|
|
|
|
.. code-block:: python
|
|
|
|
# No args section
|
|
def invalid_docstring_1(bar: str, baz: int) -> str:
|
|
\"\"\"The foo.\"\"\"
|
|
return bar
|
|
|
|
# Improper whitespace between summary and args section
|
|
def invalid_docstring_2(bar: str, baz: int) -> str:
|
|
\"\"\"The foo.
|
|
Args:
|
|
bar: The bar.
|
|
baz: The baz.
|
|
\"\"\"
|
|
return bar
|
|
|
|
# Documented args absent from function signature
|
|
def invalid_docstring_3(bar: str, baz: int) -> str:
|
|
\"\"\"The foo.
|
|
|
|
Args:
|
|
banana: The bar.
|
|
monkey: The baz.
|
|
\"\"\"
|
|
return bar
|
|
"""
|
|
|
|
def _make_with_name(tool_name: str) -> Callable:
|
|
def _make_tool(dec_func: Union[Callable, Runnable]) -> BaseTool:
|
|
if isinstance(dec_func, Runnable):
|
|
runnable = dec_func
|
|
|
|
if runnable.input_schema.schema().get("type") != "object":
|
|
raise ValueError("Runnable must have an object schema.")
|
|
|
|
async def ainvoke_wrapper(
|
|
callbacks: Optional[Callbacks] = None, **kwargs: Any
|
|
) -> Any:
|
|
return await runnable.ainvoke(kwargs, {"callbacks": callbacks})
|
|
|
|
def invoke_wrapper(
|
|
callbacks: Optional[Callbacks] = None, **kwargs: Any
|
|
) -> Any:
|
|
return runnable.invoke(kwargs, {"callbacks": callbacks})
|
|
|
|
coroutine = ainvoke_wrapper
|
|
func = invoke_wrapper
|
|
schema: Optional[Type[BaseModel]] = runnable.input_schema
|
|
description = repr(runnable)
|
|
elif inspect.iscoroutinefunction(dec_func):
|
|
coroutine = dec_func
|
|
func = None
|
|
schema = args_schema
|
|
description = None
|
|
else:
|
|
coroutine = None
|
|
func = dec_func
|
|
schema = args_schema
|
|
description = None
|
|
|
|
if infer_schema or args_schema is not None:
|
|
return StructuredTool.from_function(
|
|
func,
|
|
coroutine,
|
|
name=tool_name,
|
|
description=description,
|
|
return_direct=return_direct,
|
|
args_schema=schema,
|
|
infer_schema=infer_schema,
|
|
response_format=response_format,
|
|
parse_docstring=parse_docstring,
|
|
error_on_invalid_docstring=error_on_invalid_docstring,
|
|
)
|
|
# If someone doesn't want a schema applied, we must treat it as
|
|
# a simple string->string function
|
|
if dec_func.__doc__ is None:
|
|
raise ValueError(
|
|
"Function must have a docstring if "
|
|
"description not provided and infer_schema is False."
|
|
)
|
|
return Tool(
|
|
name=tool_name,
|
|
func=func,
|
|
description=f"{tool_name} tool",
|
|
return_direct=return_direct,
|
|
coroutine=coroutine,
|
|
response_format=response_format,
|
|
)
|
|
|
|
return _make_tool
|
|
|
|
if len(args) == 2 and isinstance(args[0], str) and isinstance(args[1], Runnable):
|
|
return _make_with_name(args[0])(args[1])
|
|
elif len(args) == 1 and isinstance(args[0], str):
|
|
# if the argument is a string, then we use the string as the tool name
|
|
# Example usage: @tool("search", return_direct=True)
|
|
return _make_with_name(args[0])
|
|
elif len(args) == 1 and callable(args[0]):
|
|
# if the argument is a function, then we use the function name as the tool name
|
|
# Example usage: @tool
|
|
return _make_with_name(args[0].__name__)(args[0])
|
|
elif len(args) == 0:
|
|
# if there are no arguments, then we use the function name as the tool name
|
|
# Example usage: @tool(return_direct=True)
|
|
def _partial(func: Callable[[str], str]) -> BaseTool:
|
|
return _make_with_name(func.__name__)(func)
|
|
|
|
return _partial
|
|
else:
|
|
raise ValueError("Too many arguments for tool decorator")
|
|
|
|
|
|
class RetrieverInput(BaseModel):
|
|
"""Input to the retriever."""
|
|
|
|
query: str = Field(description="query to look up in retriever")
|
|
|
|
|
|
def _get_relevant_documents(
|
|
query: str,
|
|
retriever: BaseRetriever,
|
|
document_prompt: BasePromptTemplate,
|
|
document_separator: str,
|
|
callbacks: Callbacks = None,
|
|
) -> str:
|
|
docs = retriever.invoke(query, config={"callbacks": callbacks})
|
|
return document_separator.join(
|
|
format_document(doc, document_prompt) for doc in docs
|
|
)
|
|
|
|
|
|
async def _aget_relevant_documents(
|
|
query: str,
|
|
retriever: BaseRetriever,
|
|
document_prompt: BasePromptTemplate,
|
|
document_separator: str,
|
|
callbacks: Callbacks = None,
|
|
) -> str:
|
|
docs = await retriever.ainvoke(query, config={"callbacks": callbacks})
|
|
return document_separator.join(
|
|
[await aformat_document(doc, document_prompt) for doc in docs]
|
|
)
|
|
|
|
|
|
def create_retriever_tool(
|
|
retriever: BaseRetriever,
|
|
name: str,
|
|
description: str,
|
|
*,
|
|
document_prompt: Optional[BasePromptTemplate] = None,
|
|
document_separator: str = "\n\n",
|
|
) -> Tool:
|
|
"""Create a tool to do retrieval of documents.
|
|
|
|
Args:
|
|
retriever: The retriever to use for the retrieval
|
|
name: The name for the tool. This will be passed to the language model,
|
|
so should be unique and somewhat descriptive.
|
|
description: The description for the tool. This will be passed to the language
|
|
model, so should be descriptive.
|
|
document_prompt: The prompt to use for the document. Defaults to None.
|
|
document_separator: The separator to use between documents. Defaults to "\n\n".
|
|
|
|
Returns:
|
|
Tool class to pass to an agent.
|
|
"""
|
|
document_prompt = document_prompt or PromptTemplate.from_template("{page_content}")
|
|
func = partial(
|
|
_get_relevant_documents,
|
|
retriever=retriever,
|
|
document_prompt=document_prompt,
|
|
document_separator=document_separator,
|
|
)
|
|
afunc = partial(
|
|
_aget_relevant_documents,
|
|
retriever=retriever,
|
|
document_prompt=document_prompt,
|
|
document_separator=document_separator,
|
|
)
|
|
return Tool(
|
|
name=name,
|
|
description=description,
|
|
func=func,
|
|
coroutine=afunc,
|
|
args_schema=RetrieverInput,
|
|
)
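# Illustrative sketch of create_retriever_tool. ``_FixedRetriever`` is a stub
# retriever defined only for this example; a real retriever (e.g. from a vector
# store) would normally be passed in instead.
def _example_create_retriever_tool() -> None:
    from langchain_core.documents import Document

    class _FixedRetriever(BaseRetriever):
        def _get_relevant_documents(self, query, *, run_manager):  # type: ignore[override]
            return [Document(page_content=f"Stub result for {query!r}.")]

    retriever_tool = create_retriever_tool(
        _FixedRetriever(),
        name="docs_search",
        description="Searches the project documentation.",
    )
    print(retriever_tool.invoke({"query": "tools"}))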
|
|
|
|
|
|
ToolsRenderer = Callable[[List[BaseTool]], str]
|
|
|
|
|
|
def render_text_description(tools: List[BaseTool]) -> str:
|
|
"""Render the tool name and description in plain text.
|
|
|
|
Args:
|
|
tools: The tools to render.
|
|
|
|
Returns:
|
|
The rendered text.
|
|
|
|
Output will be in the format of:
|
|
|
|
.. code-block:: markdown
|
|
|
|
search: This tool is used for search
|
|
calculator: This tool is used for math
|
|
"""
|
|
descriptions = []
|
|
for tool in tools:
|
|
if hasattr(tool, "func") and tool.func:
|
|
sig = signature(tool.func)
|
|
description = f"{tool.name}{sig} - {tool.description}"
|
|
else:
|
|
description = f"{tool.name} - {tool.description}"
|
|
|
|
descriptions.append(description)
|
|
return "\n".join(descriptions)
|
|
|
|
|
|
def render_text_description_and_args(tools: List[BaseTool]) -> str:
|
|
"""Render the tool name, description, and args in plain text.
|
|
|
|
Args:
|
|
tools: The tools to render.
|
|
|
|
Returns:
|
|
The rendered text.
|
|
|
|
Output will be in the format of:
|
|
|
|
.. code-block:: markdown
|
|
|
|
search: This tool is used for search, args: {"query": {"type": "string"}}
|
|
calculator: This tool is used for math, \
|
|
args: {"expression": {"type": "string"}}
|
|
"""
|
|
tool_strings = []
|
|
for tool in tools:
|
|
args_schema = str(tool.args)
|
|
if hasattr(tool, "func") and tool.func:
|
|
sig = signature(tool.func)
|
|
description = f"{tool.name}{sig} - {tool.description}"
|
|
else:
|
|
description = f"{tool.name} - {tool.description}"
|
|
tool_strings.append(f"{description}, args: {args_schema}")
|
|
return "\n".join(tool_strings)
|
|
|
|
|
|
class BaseToolkit(BaseModel, ABC):
|
|
"""Base Toolkit representing a collection of related tools."""
|
|
|
|
@abstractmethod
|
|
def get_tools(self) -> List[BaseTool]:
|
|
"""Get the tools in the toolkit."""
|
|
|
|
|
|
def _is_tool_call(x: Any) -> bool:
|
|
return isinstance(x, dict) and x.get("type") == "tool_call"
|
|
|
|
|
|
def _handle_validation_error(
|
|
e: ValidationError,
|
|
*,
|
|
flag: Union[Literal[True], str, Callable[[ValidationError], str]],
|
|
) -> str:
|
|
if isinstance(flag, bool):
|
|
content = "Tool input validation error"
|
|
elif isinstance(flag, str):
|
|
content = flag
|
|
elif callable(flag):
|
|
content = flag(e)
|
|
else:
|
|
raise ValueError(
|
|
f"Got unexpected type of `handle_validation_error`. Expected bool, "
|
|
f"str or callable. Received: {flag}"
|
|
)
|
|
return content
|
|
|
|
|
|
def _handle_tool_error(
|
|
e: ToolException,
|
|
*,
|
|
flag: Optional[Union[Literal[True], str, Callable[[ToolException], str]]],
|
|
) -> str:
|
|
if isinstance(flag, bool):
|
|
if e.args:
|
|
content = e.args[0]
|
|
else:
|
|
content = "Tool execution error"
|
|
elif isinstance(flag, str):
|
|
content = flag
|
|
elif callable(flag):
|
|
content = flag(e)
|
|
else:
|
|
raise ValueError(
|
|
f"Got unexpected type of `handle_tool_error`. Expected bool, str "
|
|
f"or callable. Received: {flag}"
|
|
)
|
|
return content
|
|
|
|
|
|
def _prep_run_args(
|
|
input: Union[str, dict, ToolCall],
|
|
config: Optional[RunnableConfig],
|
|
**kwargs: Any,
|
|
) -> Tuple[Union[str, Dict], Dict]:
|
|
config = ensure_config(config)
|
|
if _is_tool_call(input):
|
|
tool_call_id: Optional[str] = cast(ToolCall, input)["id"]
|
|
tool_input: Union[str, dict] = cast(ToolCall, input)["args"].copy()
|
|
else:
|
|
tool_call_id = None
|
|
tool_input = cast(Union[str, dict], input)
|
|
return (
|
|
tool_input,
|
|
dict(
|
|
callbacks=config.get("callbacks"),
|
|
tags=config.get("tags"),
|
|
metadata=config.get("metadata"),
|
|
run_name=config.get("run_name"),
|
|
run_id=config.pop("run_id", None),
|
|
config=config,
|
|
tool_call_id=tool_call_id,
|
|
**kwargs,
|
|
),
|
|
)
|
|
|
|
|
|
def _format_output(
|
|
content: Any, artifact: Any, tool_call_id: Optional[str], name: str, status: str
|
|
) -> Union[ToolMessage, Any]:
|
|
if tool_call_id:
|
|
if not _is_message_content_type(content):
|
|
content = _stringify(content)
|
|
return ToolMessage(
|
|
content,
|
|
artifact=artifact,
|
|
tool_call_id=tool_call_id,
|
|
name=name,
|
|
status=status,
|
|
)
|
|
else:
|
|
return content
|
|
|
|
|
|
def _is_message_content_type(obj: Any) -> bool:
|
|
"""Check for OpenAI or Anthropic format tool message content."""
|
|
if isinstance(obj, str):
|
|
return True
|
|
elif isinstance(obj, list) and all(_is_message_content_block(e) for e in obj):
|
|
return True
|
|
else:
|
|
return False
|
|
|
|
|
|
def _is_message_content_block(obj: Any) -> bool:
|
|
"""Check for OpenAI or Anthropic format tool message content blocks."""
|
|
if isinstance(obj, str):
|
|
return True
|
|
elif isinstance(obj, dict):
|
|
return obj.get("type", None) in ("text", "image_url", "image", "json")
|
|
else:
|
|
return False
|
|
|
|
|
|
def _stringify(content: Any) -> str:
|
|
try:
|
|
return json.dumps(content)
|
|
except Exception:
|
|
return str(content)
|
|
|
|
|
|
def _get_description_from_runnable(runnable: Runnable) -> str:
|
|
"""Generate a placeholder description of a runnable."""
|
|
input_schema = runnable.input_schema.schema()
|
|
return f"Takes {input_schema}."
|
|
|
|
|
|
def _get_schema_from_runnable_and_arg_types(
|
|
runnable: Runnable,
|
|
name: str,
|
|
arg_types: Optional[Dict[str, Type]] = None,
|
|
) -> Type[BaseModel]:
|
|
"""Infer args_schema for tool."""
|
|
if arg_types is None:
|
|
try:
|
|
arg_types = get_type_hints(runnable.InputType)
|
|
except TypeError as e:
|
|
raise TypeError(
|
|
"Tool input must be str or dict. If dict, dict arguments must be "
|
|
"typed. Either annotate types (e.g., with TypedDict) or pass "
|
|
f"arg_types into `.as_tool` to specify. {str(e)}"
|
|
)
|
|
fields = {key: (key_type, Field(...)) for key, key_type in arg_types.items()}
|
|
return create_model(name, **fields) # type: ignore
|
|
|
|
|
|
def convert_runnable_to_tool(
|
|
runnable: Runnable,
|
|
args_schema: Optional[Type[BaseModel]] = None,
|
|
*,
|
|
name: Optional[str] = None,
|
|
description: Optional[str] = None,
|
|
arg_types: Optional[Dict[str, Type]] = None,
|
|
) -> BaseTool:
|
|
"""Convert a Runnable into a BaseTool.
|
|
|
|
Args:
|
|
runnable: The runnable to convert.
|
|
args_schema: The schema for the tool's input arguments. Defaults to None.
|
|
name: The name of the tool. Defaults to None.
|
|
description: The description of the tool. Defaults to None.
|
|
arg_types: The types of the arguments. Defaults to None.
|
|
|
|
Returns:
|
|
The tool.
|
|
"""
|
|
if args_schema:
|
|
runnable = runnable.with_types(input_type=args_schema)
|
|
description = description or _get_description_from_runnable(runnable)
|
|
name = name or runnable.get_name()
|
|
|
|
schema = runnable.input_schema.schema()
|
|
if schema.get("type") == "string":
|
|
return Tool(
|
|
name=name,
|
|
func=runnable.invoke,
|
|
coroutine=runnable.ainvoke,
|
|
description=description,
|
|
)
|
|
else:
|
|
|
|
async def ainvoke_wrapper(
|
|
callbacks: Optional[Callbacks] = None, **kwargs: Any
|
|
) -> Any:
|
|
return await runnable.ainvoke(kwargs, config={"callbacks": callbacks})
|
|
|
|
def invoke_wrapper(callbacks: Optional[Callbacks] = None, **kwargs: Any) -> Any:
|
|
return runnable.invoke(kwargs, config={"callbacks": callbacks})
|
|
|
|
if (
|
|
arg_types is None
|
|
and schema.get("type") == "object"
|
|
and schema.get("properties")
|
|
):
|
|
args_schema = runnable.input_schema
|
|
else:
|
|
args_schema = _get_schema_from_runnable_and_arg_types(
|
|
runnable, name, arg_types=arg_types
|
|
)
|
|
|
|
return StructuredTool.from_function(
|
|
name=name,
|
|
func=invoke_wrapper,
|
|
coroutine=ainvoke_wrapper,
|
|
description=description,
|
|
args_schema=args_schema,
|
|
)
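# Illustrative sketch of convert_runnable_to_tool on a dict-input runnable.
# ``RunnableLambda`` is imported locally for the example; ``arg_types`` is
# supplied because the lambda's input type cannot be inferred.
def _example_convert_runnable_to_tool() -> None:
    from langchain_core.runnables import RunnableLambda

    runnable = RunnableLambda(lambda kwargs: kwargs["a"] + kwargs["b"])
    add_tool = convert_runnable_to_tool(
        runnable,
        name="adder",
        description="Adds the numbers a and b.",
        arg_types={"a": int, "b": int},
    )
    print(add_tool.invoke({"a": 1, "b": 2}))  # 3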
|
|
|
|
|
|
def _get_type_hints(func: Callable) -> Optional[Dict[str, Type]]:
|
|
if isinstance(func, functools.partial):
|
|
func = func.func
|
|
try:
|
|
return get_type_hints(func)
|
|
except Exception:
|
|
return None
|
|
|
|
|
|
def _get_runnable_config_param(func: Callable) -> Optional[str]:
|
|
type_hints = _get_type_hints(func)
|
|
if not type_hints:
|
|
return None
|
|
for name, type_ in type_hints.items():
|
|
if type_ is RunnableConfig:
|
|
return name
|
|
return None
|
|
|
|
|
|
class InjectedToolArg:
|
|
"""Annotation for a Tool arg that is **not** meant to be generated by a model."""
|
|
|
|
|
|
def _is_injected_arg_type(type_: Type) -> bool:
|
|
return any(
|
|
isinstance(arg, InjectedToolArg)
|
|
or (isinstance(arg, type) and issubclass(arg, InjectedToolArg))
|
|
for arg in get_args(type_)[1:]
|
|
)
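# Illustrative sketch of InjectedToolArg: the injected argument is hidden from
# the model-facing tool-call schema but must still be supplied when the tool is
# invoked. ``lookup_order`` is a hypothetical function used only here.
def _example_injected_tool_arg() -> None:
    def lookup_order(
        order_id: str,
        user_id: Annotated[str, InjectedToolArg],
    ) -> str:
        """Look up an order on behalf of a user.

        Args:
            order_id: The order to look up.
            user_id: The authenticated user, supplied by the caller rather than
                generated by the model.
        """
        return f"Order {order_id} for user {user_id}."

    order_tool = StructuredTool.from_function(lookup_order, parse_docstring=True)
    # The schema exposed for tool calling only lists "order_id"...
    print(list(order_tool.tool_call_schema.schema()["properties"]))
    # ...but the injected argument is still required when the tool is invoked.
    print(order_tool.invoke({"order_id": "o-123", "user_id": "u-456"}))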
|
|
|
|
|
|
def _filter_schema_args(func: Callable) -> List[str]:
|
|
filter_args = list(FILTERED_ARGS)
|
|
if config_param := _get_runnable_config_param(func):
|
|
filter_args.append(config_param)
|
|
# filter_args.extend(_get_non_model_params(type_hints))
|
|
return filter_args
|
|
|
|
|
|
def _get_all_basemodel_annotations(
|
|
cls: Union[TypeBaseModel, Any], *, default_to_bound: bool = True
|
|
) -> Dict[str, Type]:
|
|
# cls has no subscript: cls = FooBar
|
|
if isinstance(cls, type):
|
|
annotations: Dict[str, Type] = {}
|
|
for name, param in inspect.signature(cls).parameters.items():
|
|
            # Exclude hidden init args added by pydantic Config. For example, if
            # BaseModel(extra="allow") is used, then "extra_data" will be part of
            # the init signature.
|
|
if (
|
|
fields := getattr(cls, "model_fields", {}) # pydantic v2+
|
|
or getattr(cls, "__fields__", {}) # pydantic v1
|
|
) and name not in fields:
|
|
continue
|
|
annotations[name] = param.annotation
|
|
orig_bases: Tuple = getattr(cls, "__orig_bases__", tuple())
|
|
# cls has subscript: cls = FooBar[int]
|
|
else:
|
|
annotations = _get_all_basemodel_annotations(
|
|
get_origin(cls), default_to_bound=False
|
|
)
|
|
orig_bases = (cls,)
|
|
|
|
# Pydantic v2 automatically resolves inherited generics, Pydantic v1 does not.
|
|
if not (isinstance(cls, type) and is_pydantic_v2_subclass(cls)):
|
|
# if cls = FooBar inherits from Baz[str], orig_bases will contain Baz[str]
|
|
# if cls = FooBar inherits from Baz, orig_bases will contain Baz
|
|
# if cls = FooBar[int], orig_bases will contain FooBar[int]
|
|
for parent in orig_bases:
|
|
# if class = FooBar inherits from Baz, parent = Baz
|
|
if isinstance(parent, type) and is_pydantic_v1_subclass(parent):
|
|
annotations.update(
|
|
_get_all_basemodel_annotations(parent, default_to_bound=False)
|
|
)
|
|
continue
|
|
|
|
parent_origin = get_origin(parent)
|
|
|
|
# if class = FooBar inherits from non-pydantic class
|
|
if not parent_origin:
|
|
continue
|
|
|
|
# if class = FooBar inherits from Baz[str]:
|
|
# parent = Baz[str],
|
|
# parent_origin = Baz,
|
|
# generic_type_vars = (type vars in Baz)
|
|
# generic_map = {type var in Baz: str}
|
|
generic_type_vars: Tuple = getattr(parent_origin, "__parameters__", tuple())
|
|
generic_map = {
|
|
type_var: t for type_var, t in zip(generic_type_vars, get_args(parent))
|
|
}
|
|
for field in getattr(parent_origin, "__annotations__", dict()):
|
|
annotations[field] = _replace_type_vars(
|
|
annotations[field], generic_map, default_to_bound
|
|
)
|
|
|
|
return {
|
|
k: _replace_type_vars(v, default_to_bound=default_to_bound)
|
|
for k, v in annotations.items()
|
|
}
|
|
|
|
|
|
def _replace_type_vars(
|
|
type_: Type,
|
|
generic_map: Optional[Dict[TypeVar, Type]] = None,
|
|
default_to_bound: bool = True,
|
|
) -> Type:
|
|
generic_map = generic_map or {}
|
|
if isinstance(type_, TypeVar):
|
|
if type_ in generic_map:
|
|
return generic_map[type_]
|
|
elif default_to_bound:
|
|
return type_.__bound__ or Any
|
|
else:
|
|
return type_
|
|
elif (origin := get_origin(type_)) and (args := get_args(type_)):
|
|
new_args = tuple(
|
|
_replace_type_vars(arg, generic_map, default_to_bound) for arg in args
|
|
)
|
|
return _py_38_safe_origin(origin)[new_args]
|
|
else:
|
|
return type_
|