mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
Dynamic tool -> single purpose (#3697)
I think the logic of https://github.com/hwchase17/langchain/pull/3684#pullrequestreview-1405358565 is too confusing. I prefer this alternative because: - All `Tool()` implementations by default will be treated the same as before. No breaking changes. - Less reliance on pydantic magic - The decorator (which only is typed as returning a callable) can infer schema and generate a structured tool - Either way, the recommended way to create a custom tool is through inheriting from the base tool
This commit is contained in:
parent
1bf1c37c0c
commit
da7b51455c
@ -1,15 +1,10 @@
|
|||||||
"""Interface for tools."""
|
"""Interface for tools."""
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from inspect import signature
|
from typing import Any, Awaitable, Callable, Dict, Optional, Tuple, Type, Union
|
||||||
from typing import Any, Awaitable, Callable, Optional, Type, Union
|
|
||||||
|
|
||||||
from pydantic import BaseModel, validate_arguments, validator
|
from pydantic import BaseModel, validator
|
||||||
|
|
||||||
from langchain.tools.base import (
|
from langchain.tools.base import BaseTool, StructuredTool
|
||||||
BaseTool,
|
|
||||||
create_schema_from_function,
|
|
||||||
get_filtered_args,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class Tool(BaseTool):
|
class Tool(BaseTool):
|
||||||
@ -30,17 +25,30 @@ class Tool(BaseTool):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def args(self) -> dict:
|
def args(self) -> dict:
|
||||||
|
"""The tool's input arguments."""
|
||||||
if self.args_schema is not None:
|
if self.args_schema is not None:
|
||||||
return self.args_schema.schema()["properties"]
|
return self.args_schema.schema()["properties"]
|
||||||
else:
|
# For backwards compatibility, if the function signature is ambiguous,
|
||||||
inferred_model = validate_arguments(self.func).model # type: ignore
|
# assume it takes a single string input.
|
||||||
return get_filtered_args(inferred_model, self.func)
|
return {"tool_input": {"type": "string"}}
|
||||||
|
|
||||||
def _run(self, *args: Any, **kwargs: Any) -> str:
|
def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]:
|
||||||
|
"""Convert tool input to pydantic model."""
|
||||||
|
args, kwargs = super()._to_args_and_kwargs(tool_input)
|
||||||
|
# For backwards compatibility. The tool must be run with a single input
|
||||||
|
all_args = list(args) + list(kwargs.values())
|
||||||
|
if len(all_args) != 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"Too many arguments to single-input tool {self.name}."
|
||||||
|
f" Args: {all_args}"
|
||||||
|
)
|
||||||
|
return tuple(all_args), {}
|
||||||
|
|
||||||
|
def _run(self, *args: Any, **kwargs: Any) -> Any:
|
||||||
"""Use the tool."""
|
"""Use the tool."""
|
||||||
return self.func(*args, **kwargs)
|
return self.func(*args, **kwargs)
|
||||||
|
|
||||||
async def _arun(self, *args: Any, **kwargs: Any) -> str:
|
async def _arun(self, *args: Any, **kwargs: Any) -> Any:
|
||||||
"""Use the tool asynchronously."""
|
"""Use the tool asynchronously."""
|
||||||
if self.coroutine:
|
if self.coroutine:
|
||||||
return await self.coroutine(*args, **kwargs)
|
return await self.coroutine(*args, **kwargs)
|
||||||
@ -48,7 +56,7 @@ class Tool(BaseTool):
|
|||||||
|
|
||||||
# TODO: this is for backwards compatibility, remove in future
|
# TODO: this is for backwards compatibility, remove in future
|
||||||
def __init__(
|
def __init__(
|
||||||
self, name: str, func: Callable[[str], str], description: str, **kwargs: Any
|
self, name: str, func: Callable, description: str, **kwargs: Any
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize tool."""
|
"""Initialize tool."""
|
||||||
super(Tool, self).__init__(
|
super(Tool, self).__init__(
|
||||||
@ -107,22 +115,24 @@ def tool(
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def _make_with_name(tool_name: str) -> Callable:
|
def _make_with_name(tool_name: str) -> Callable:
|
||||||
def _make_tool(func: Callable) -> Tool:
|
def _make_tool(func: Callable) -> BaseTool:
|
||||||
assert func.__doc__, "Function must have a docstring"
|
if infer_schema or args_schema is not None:
|
||||||
# Description example:
|
return StructuredTool.from_function(
|
||||||
# search_api(query: str) - Searches the API for the query.
|
func,
|
||||||
description = f"{tool_name}{signature(func)} - {func.__doc__.strip()}"
|
name=tool_name,
|
||||||
_args_schema = args_schema
|
return_direct=return_direct,
|
||||||
if _args_schema is None and infer_schema:
|
args_schema=args_schema,
|
||||||
_args_schema = create_schema_from_function(f"{tool_name}Schema", func)
|
infer_schema=infer_schema,
|
||||||
tool_ = Tool(
|
)
|
||||||
|
# If someone doesn't want a schema applied, we must treat it as
|
||||||
|
# a simple string->string function
|
||||||
|
assert func.__doc__ is not None, "Function must have a docstring"
|
||||||
|
return Tool(
|
||||||
name=tool_name,
|
name=tool_name,
|
||||||
func=func,
|
func=func,
|
||||||
args_schema=_args_schema,
|
description=f"{tool_name} tool",
|
||||||
description=description,
|
|
||||||
return_direct=return_direct,
|
return_direct=return_direct,
|
||||||
)
|
)
|
||||||
return tool_
|
|
||||||
|
|
||||||
return _make_tool
|
return _make_tool
|
||||||
|
|
||||||
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from inspect import signature
|
from inspect import signature
|
||||||
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type, Union
|
from typing import Any, Awaitable, Callable, Dict, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
from pydantic import (
|
from pydantic import (
|
||||||
BaseModel,
|
BaseModel,
|
||||||
@ -19,15 +19,6 @@ from langchain.callbacks import get_callback_manager
|
|||||||
from langchain.callbacks.base import BaseCallbackManager
|
from langchain.callbacks.base import BaseCallbackManager
|
||||||
|
|
||||||
|
|
||||||
def _to_args_and_kwargs(run_input: Union[str, Dict]) -> Tuple[Sequence, dict]:
|
|
||||||
# For backwards compatability, if run_input is a string,
|
|
||||||
# pass as a positional argument.
|
|
||||||
if isinstance(run_input, str):
|
|
||||||
return (run_input,), {}
|
|
||||||
else:
|
|
||||||
return [], run_input
|
|
||||||
|
|
||||||
|
|
||||||
class SchemaAnnotationError(TypeError):
|
class SchemaAnnotationError(TypeError):
|
||||||
"""Raised when 'args_schema' is missing or has an incorrect type annotation."""
|
"""Raised when 'args_schema' is missing or has an incorrect type annotation."""
|
||||||
|
|
||||||
@ -81,14 +72,20 @@ def _create_subset_model(
|
|||||||
return create_model(name, **fields) # type: ignore
|
return create_model(name, **fields) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
def get_filtered_args(inferred_model: Type[BaseModel], func: Callable) -> dict:
|
def get_filtered_args(
|
||||||
|
inferred_model: Type[BaseModel],
|
||||||
|
func: Callable,
|
||||||
|
) -> dict:
|
||||||
"""Get the arguments from a function's signature."""
|
"""Get the arguments from a function's signature."""
|
||||||
schema = inferred_model.schema()["properties"]
|
schema = inferred_model.schema()["properties"]
|
||||||
valid_keys = signature(func).parameters
|
valid_keys = signature(func).parameters
|
||||||
return {k: schema[k] for k in valid_keys}
|
return {k: schema[k] for k in valid_keys}
|
||||||
|
|
||||||
|
|
||||||
def create_schema_from_function(model_name: str, func: Callable) -> Type[BaseModel]:
|
def create_schema_from_function(
|
||||||
|
model_name: str,
|
||||||
|
func: Callable,
|
||||||
|
) -> Type[BaseModel]:
|
||||||
"""Create a pydantic schema from a function's signature."""
|
"""Create a pydantic schema from a function's signature."""
|
||||||
inferred_model = validate_arguments(func).model # type: ignore
|
inferred_model = validate_arguments(func).model # type: ignore
|
||||||
# Pydantic adds placeholder virtual fields we need to strip
|
# Pydantic adds placeholder virtual fields we need to strip
|
||||||
@ -102,12 +99,23 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
|
|||||||
"""Interface LangChain tools must implement."""
|
"""Interface LangChain tools must implement."""
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
|
"""The unique name of the tool that clearly communicates its purpose."""
|
||||||
description: str
|
description: str
|
||||||
|
"""Used to tell the model how/when/why to use the tool.
|
||||||
|
|
||||||
|
You can provide few-shot examples as a part of the description.
|
||||||
|
"""
|
||||||
args_schema: Optional[Type[BaseModel]] = None
|
args_schema: Optional[Type[BaseModel]] = None
|
||||||
"""Pydantic model class to validate and parse the tool's input arguments."""
|
"""Pydantic model class to validate and parse the tool's input arguments."""
|
||||||
return_direct: bool = False
|
return_direct: bool = False
|
||||||
|
"""Whether to return the tool's output directly. Setting this to True means
|
||||||
|
|
||||||
|
that after the tool is called, the AgentExecutor will stop looping.
|
||||||
|
"""
|
||||||
verbose: bool = False
|
verbose: bool = False
|
||||||
|
"""Whether to log the tool's progress."""
|
||||||
callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager)
|
callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager)
|
||||||
|
"""Callback manager for this tool."""
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
"""Configuration for this pydantic object."""
|
"""Configuration for this pydantic object."""
|
||||||
@ -160,6 +168,14 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
|
|||||||
async def _arun(self, *args: Any, **kwargs: Any) -> Any:
|
async def _arun(self, *args: Any, **kwargs: Any) -> Any:
|
||||||
"""Use the tool asynchronously."""
|
"""Use the tool asynchronously."""
|
||||||
|
|
||||||
|
def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]:
|
||||||
|
# For backwards compatibility, if run_input is a string,
|
||||||
|
# pass as a positional argument.
|
||||||
|
if isinstance(tool_input, str):
|
||||||
|
return (tool_input,), {}
|
||||||
|
else:
|
||||||
|
return (), tool_input
|
||||||
|
|
||||||
def run(
|
def run(
|
||||||
self,
|
self,
|
||||||
tool_input: Union[str, Dict],
|
tool_input: Union[str, Dict],
|
||||||
@ -182,7 +198,7 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
tool_args, tool_kwargs = _to_args_and_kwargs(tool_input)
|
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input)
|
||||||
observation = self._run(*tool_args, **tool_kwargs)
|
observation = self._run(*tool_args, **tool_kwargs)
|
||||||
except (Exception, KeyboardInterrupt) as e:
|
except (Exception, KeyboardInterrupt) as e:
|
||||||
self.callback_manager.on_tool_error(e, verbose=verbose_)
|
self.callback_manager.on_tool_error(e, verbose=verbose_)
|
||||||
@ -224,8 +240,8 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
|
|||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
# We then call the tool on the tool input to get an observation
|
# We then call the tool on the tool input to get an observation
|
||||||
args, kwargs = _to_args_and_kwargs(tool_input)
|
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input)
|
||||||
observation = await self._arun(*args, **kwargs)
|
observation = await self._arun(*tool_args, **tool_kwargs)
|
||||||
except (Exception, KeyboardInterrupt) as e:
|
except (Exception, KeyboardInterrupt) as e:
|
||||||
if self.callback_manager.is_async:
|
if self.callback_manager.is_async:
|
||||||
await self.callback_manager.on_tool_error(e, verbose=verbose_)
|
await self.callback_manager.on_tool_error(e, verbose=verbose_)
|
||||||
@ -249,3 +265,62 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
|
|||||||
def __call__(self, tool_input: Union[str, dict]) -> Any:
|
def __call__(self, tool_input: Union[str, dict]) -> Any:
|
||||||
"""Make tool callable."""
|
"""Make tool callable."""
|
||||||
return self.run(tool_input)
|
return self.run(tool_input)
|
||||||
|
|
||||||
|
|
||||||
|
class StructuredTool(BaseTool):
|
||||||
|
"""Tool that can operate on any number of inputs."""
|
||||||
|
|
||||||
|
description: str = ""
|
||||||
|
args_schema: Type[BaseModel] = Field(..., description="The tool schema.")
|
||||||
|
"""The input arguments' schema."""
|
||||||
|
func: Callable[..., str]
|
||||||
|
"""The function to run when the tool is called."""
|
||||||
|
coroutine: Optional[Callable[..., Awaitable[str]]] = None
|
||||||
|
"""The asynchronous version of the function."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def args(self) -> dict:
|
||||||
|
"""The tool's input arguments."""
|
||||||
|
return self.args_schema.schema()["properties"]
|
||||||
|
|
||||||
|
def _run(self, *args: Any, **kwargs: Any) -> Any:
|
||||||
|
"""Use the tool."""
|
||||||
|
return self.func(*args, **kwargs)
|
||||||
|
|
||||||
|
async def _arun(self, *args: Any, **kwargs: Any) -> Any:
|
||||||
|
"""Use the tool asynchronously."""
|
||||||
|
if self.coroutine:
|
||||||
|
return await self.coroutine(*args, **kwargs)
|
||||||
|
raise NotImplementedError("Tool does not support async")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_function(
|
||||||
|
cls,
|
||||||
|
func: Callable,
|
||||||
|
name: Optional[str] = None,
|
||||||
|
description: Optional[str] = None,
|
||||||
|
return_direct: bool = False,
|
||||||
|
args_schema: Optional[Type[BaseModel]] = None,
|
||||||
|
infer_schema: bool = True,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> StructuredTool:
|
||||||
|
name = name or func.__name__
|
||||||
|
description = description or func.__doc__
|
||||||
|
assert (
|
||||||
|
description is not None
|
||||||
|
), "Function must have a docstring if description not provided."
|
||||||
|
|
||||||
|
# Description example:
|
||||||
|
# search_api(query: str) - Searches the API for the query.
|
||||||
|
description = f"{name}{signature(func)} - {description.strip()}"
|
||||||
|
_args_schema = args_schema
|
||||||
|
if _args_schema is None and infer_schema:
|
||||||
|
_args_schema = create_schema_from_function(f"{name}Schema", func)
|
||||||
|
return cls(
|
||||||
|
name=name,
|
||||||
|
func=func,
|
||||||
|
args_schema=_args_schema,
|
||||||
|
description=description,
|
||||||
|
return_direct=return_direct,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
"""Test tool utils."""
|
"""Test tool utils."""
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Optional, Type, Union
|
from typing import Any, Optional, Type, Union
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
import pydantic
|
import pydantic
|
||||||
@ -16,7 +16,7 @@ from langchain.agents.mrkl.base import ZeroShotAgent
|
|||||||
from langchain.agents.react.base import ReActDocstoreAgent, ReActTextWorldAgent
|
from langchain.agents.react.base import ReActDocstoreAgent, ReActTextWorldAgent
|
||||||
from langchain.agents.self_ask_with_search.base import SelfAskWithSearchAgent
|
from langchain.agents.self_ask_with_search.base import SelfAskWithSearchAgent
|
||||||
from langchain.agents.tools import Tool, tool
|
from langchain.agents.tools import Tool, tool
|
||||||
from langchain.tools.base import BaseTool, SchemaAnnotationError
|
from langchain.tools.base import BaseTool, SchemaAnnotationError, StructuredTool
|
||||||
|
|
||||||
|
|
||||||
def test_unnamed_decorator() -> None:
|
def test_unnamed_decorator() -> None:
|
||||||
@ -27,7 +27,7 @@ def test_unnamed_decorator() -> None:
|
|||||||
"""Search the API for the query."""
|
"""Search the API for the query."""
|
||||||
return "API result"
|
return "API result"
|
||||||
|
|
||||||
assert isinstance(search_api, Tool)
|
assert isinstance(search_api, BaseTool)
|
||||||
assert search_api.name == "search_api"
|
assert search_api.name == "search_api"
|
||||||
assert not search_api.return_direct
|
assert not search_api.return_direct
|
||||||
assert search_api("test") == "API result"
|
assert search_api("test") == "API result"
|
||||||
@ -145,7 +145,7 @@ def test_decorator_with_specified_schema() -> None:
|
|||||||
"""Return the arguments directly."""
|
"""Return the arguments directly."""
|
||||||
return f"{arg1} {arg2} {arg3}"
|
return f"{arg1} {arg2} {arg3}"
|
||||||
|
|
||||||
assert isinstance(tool_func, Tool)
|
assert isinstance(tool_func, BaseTool)
|
||||||
assert tool_func.args_schema == _MockSchema
|
assert tool_func.args_schema == _MockSchema
|
||||||
|
|
||||||
|
|
||||||
@ -159,7 +159,7 @@ def test_decorated_function_schema_equivalent() -> None:
|
|||||||
"""Return the arguments directly."""
|
"""Return the arguments directly."""
|
||||||
return f"{arg1} {arg2} {arg3}"
|
return f"{arg1} {arg2} {arg3}"
|
||||||
|
|
||||||
assert isinstance(structured_tool_input, Tool)
|
assert isinstance(structured_tool_input, BaseTool)
|
||||||
assert structured_tool_input.args_schema is not None
|
assert structured_tool_input.args_schema is not None
|
||||||
assert (
|
assert (
|
||||||
structured_tool_input.args_schema.schema()["properties"]
|
structured_tool_input.args_schema.schema()["properties"]
|
||||||
@ -171,14 +171,14 @@ def test_decorated_function_schema_equivalent() -> None:
|
|||||||
def test_structured_args_decorator_no_infer_schema() -> None:
|
def test_structured_args_decorator_no_infer_schema() -> None:
|
||||||
"""Test functionality with structured arguments parsed as a decorator."""
|
"""Test functionality with structured arguments parsed as a decorator."""
|
||||||
|
|
||||||
@tool(infer_schema=False)
|
@tool
|
||||||
def structured_tool_input(
|
def structured_tool_input(
|
||||||
arg1: int, arg2: Union[float, datetime], opt_arg: Optional[dict] = None
|
arg1: int, arg2: Union[float, datetime], opt_arg: Optional[dict] = None
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Return the arguments directly."""
|
"""Return the arguments directly."""
|
||||||
return f"{arg1}, {arg2}, {opt_arg}"
|
return f"{arg1}, {arg2}, {opt_arg}"
|
||||||
|
|
||||||
assert isinstance(structured_tool_input, Tool)
|
assert isinstance(structured_tool_input, BaseTool)
|
||||||
assert structured_tool_input.name == "structured_tool_input"
|
assert structured_tool_input.name == "structured_tool_input"
|
||||||
args = {"arg1": 1, "arg2": 0.001, "opt_arg": {"foo": "bar"}}
|
args = {"arg1": 1, "arg2": 0.001, "opt_arg": {"foo": "bar"}}
|
||||||
expected_result = "1, 0.001, {'foo': 'bar'}"
|
expected_result = "1, 0.001, {'foo': 'bar'}"
|
||||||
@ -193,8 +193,9 @@ def test_structured_single_str_decorator_no_infer_schema() -> None:
|
|||||||
"""Return the arguments directly."""
|
"""Return the arguments directly."""
|
||||||
return f"{tool_input}"
|
return f"{tool_input}"
|
||||||
|
|
||||||
assert isinstance(unstructured_tool_input, Tool)
|
assert isinstance(unstructured_tool_input, BaseTool)
|
||||||
assert unstructured_tool_input.args_schema is None
|
assert unstructured_tool_input.args_schema is None
|
||||||
|
assert unstructured_tool_input.run("foo") == "foo"
|
||||||
|
|
||||||
|
|
||||||
def test_base_tool_inheritance_base_schema() -> None:
|
def test_base_tool_inheritance_base_schema() -> None:
|
||||||
@ -225,18 +226,18 @@ def test_tool_lambda_args_schema() -> None:
|
|||||||
func=lambda tool_input: tool_input,
|
func=lambda tool_input: tool_input,
|
||||||
)
|
)
|
||||||
assert tool.args_schema is None
|
assert tool.args_schema is None
|
||||||
expected_args = {"tool_input": {"title": "Tool Input"}}
|
expected_args = {"tool_input": {"type": "string"}}
|
||||||
assert tool.args == expected_args
|
assert tool.args == expected_args
|
||||||
|
|
||||||
|
|
||||||
def test_tool_lambda_multi_args_schema() -> None:
|
def test_structured_tool_lambda_multi_args_schema() -> None:
|
||||||
"""Test args schema inference when the tool argument is a lambda function."""
|
"""Test args schema inference when the tool argument is a lambda function."""
|
||||||
tool = Tool(
|
tool = StructuredTool.from_function(
|
||||||
name="tool",
|
name="tool",
|
||||||
description="A tool",
|
description="A tool",
|
||||||
func=lambda tool_input, other_arg: f"{tool_input}{other_arg}", # type: ignore
|
func=lambda tool_input, other_arg: f"{tool_input}{other_arg}", # type: ignore
|
||||||
)
|
)
|
||||||
assert tool.args_schema is None
|
assert tool.args_schema is not None
|
||||||
expected_args = {
|
expected_args = {
|
||||||
"tool_input": {"title": "Tool Input"},
|
"tool_input": {"title": "Tool Input"},
|
||||||
"other_arg": {"title": "Other Arg"},
|
"other_arg": {"title": "Other Arg"},
|
||||||
@ -268,7 +269,7 @@ def test_empty_args_decorator() -> None:
|
|||||||
"""Return a constant."""
|
"""Return a constant."""
|
||||||
return "the empty result"
|
return "the empty result"
|
||||||
|
|
||||||
assert isinstance(empty_tool_input, Tool)
|
assert isinstance(empty_tool_input, BaseTool)
|
||||||
assert empty_tool_input.name == "empty_tool_input"
|
assert empty_tool_input.name == "empty_tool_input"
|
||||||
assert empty_tool_input.args == {}
|
assert empty_tool_input.args == {}
|
||||||
assert empty_tool_input.run({}) == "the empty result"
|
assert empty_tool_input.run({}) == "the empty result"
|
||||||
@ -282,7 +283,7 @@ def test_named_tool_decorator() -> None:
|
|||||||
"""Search the API for the query."""
|
"""Search the API for the query."""
|
||||||
return "API result"
|
return "API result"
|
||||||
|
|
||||||
assert isinstance(search_api, Tool)
|
assert isinstance(search_api, BaseTool)
|
||||||
assert search_api.name == "search"
|
assert search_api.name == "search"
|
||||||
assert not search_api.return_direct
|
assert not search_api.return_direct
|
||||||
|
|
||||||
@ -295,7 +296,7 @@ def test_named_tool_decorator_return_direct() -> None:
|
|||||||
"""Search the API for the query."""
|
"""Search the API for the query."""
|
||||||
return "API result"
|
return "API result"
|
||||||
|
|
||||||
assert isinstance(search_api, Tool)
|
assert isinstance(search_api, BaseTool)
|
||||||
assert search_api.name == "search"
|
assert search_api.name == "search"
|
||||||
assert search_api.return_direct
|
assert search_api.return_direct
|
||||||
|
|
||||||
@ -308,7 +309,7 @@ def test_unnamed_tool_decorator_return_direct() -> None:
|
|||||||
"""Search the API for the query."""
|
"""Search the API for the query."""
|
||||||
return "API result"
|
return "API result"
|
||||||
|
|
||||||
assert isinstance(search_api, Tool)
|
assert isinstance(search_api, BaseTool)
|
||||||
assert search_api.name == "search_api"
|
assert search_api.name == "search_api"
|
||||||
assert search_api.return_direct
|
assert search_api.return_direct
|
||||||
|
|
||||||
@ -325,7 +326,7 @@ def test_tool_with_kwargs() -> None:
|
|||||||
"""Search the API for the query."""
|
"""Search the API for the query."""
|
||||||
return f"arg_0={arg_0}, arg_1={arg_1}, ping={ping}"
|
return f"arg_0={arg_0}, arg_1={arg_1}, ping={ping}"
|
||||||
|
|
||||||
assert isinstance(search_api, Tool)
|
assert isinstance(search_api, BaseTool)
|
||||||
result = search_api.run(
|
result = search_api.run(
|
||||||
tool_input={
|
tool_input={
|
||||||
"arg_0": "foo",
|
"arg_0": "foo",
|
||||||
@ -423,3 +424,23 @@ def test_single_input_agent_raises_error_on_structured_tool(
|
|||||||
f" multi-input tool {the_tool.name}.",
|
f" multi-input tool {the_tool.name}.",
|
||||||
):
|
):
|
||||||
agent_cls.from_llm_and_tools(MagicMock(), [the_tool]) # type: ignore
|
agent_cls.from_llm_and_tools(MagicMock(), [the_tool]) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_no_args_specified_assumes_str() -> None:
|
||||||
|
"""Older tools could assume *args and **kwargs were passed in."""
|
||||||
|
|
||||||
|
def ambiguous_function(*args: Any, **kwargs: Any) -> str:
|
||||||
|
"""An ambiguously defined function."""
|
||||||
|
return args[0]
|
||||||
|
|
||||||
|
some_tool = Tool(
|
||||||
|
name="chain_run",
|
||||||
|
description="Run the chain",
|
||||||
|
func=ambiguous_function,
|
||||||
|
)
|
||||||
|
expected_args = {"tool_input": {"type": "string"}}
|
||||||
|
assert some_tool.args == expected_args
|
||||||
|
assert some_tool.run("foobar") == "foobar"
|
||||||
|
assert some_tool.run({"tool_input": "foobar"}) == "foobar"
|
||||||
|
with pytest.raises(ValueError, match="Too many arguments to single-input tool"):
|
||||||
|
some_tool.run({"tool_input": "foobar", "other_input": "bar"})
|
||||||
|
Loading…
Reference in New Issue
Block a user