Structured Tool Bugfixes (#3324)

- Proactively raise error if a tool subclasses BaseTool, defines its
own schema, but fails to add the type-hints
- fix the auto-inferred schema of the decorator to strip the
unneeded virtual kwargs from the schema dict

Helps avoid silent instances of #3297
This commit is contained in:
Zander Chase 2023-04-24 09:58:29 -07:00 committed by GitHub
parent f22b9d0e57
commit 49122a96e7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 287 additions and 16 deletions

View File

@ -1,10 +1,15 @@
"""Interface for tools.""" """Interface for tools."""
from functools import partial
from inspect import signature from inspect import signature
from typing import Any, Awaitable, Callable, Optional, Type, Union from typing import Any, Awaitable, Callable, Optional, Type, Union
from pydantic import BaseModel, validate_arguments from pydantic import BaseModel, validate_arguments, validator
from langchain.tools.base import BaseTool from langchain.tools.base import (
BaseTool,
create_schema_from_function,
get_filtered_args,
)
class Tool(BaseTool): class Tool(BaseTool):
@ -16,15 +21,20 @@ class Tool(BaseTool):
coroutine: Optional[Callable[..., Awaitable[str]]] = None coroutine: Optional[Callable[..., Awaitable[str]]] = None
"""The asynchronous version of the function.""" """The asynchronous version of the function."""
@validator("func", pre=True, always=True)
def validate_func_not_partial(cls, func: Callable) -> Callable:
"""Check that the function is not a partial."""
if isinstance(func, partial):
raise ValueError("Partial functions not yet supported in tools.")
return func
@property @property
def args(self) -> dict: def args(self) -> dict:
if self.args_schema is not None: if self.args_schema is not None:
return self.args_schema.schema()["properties"] return self.args_schema.schema()["properties"]
else: else:
inferred_model = validate_arguments(self.func).model # type: ignore inferred_model = validate_arguments(self.func).model # type: ignore
schema = inferred_model.schema()["properties"] return get_filtered_args(inferred_model, self.func)
valid_keys = signature(self.func).parameters
return {k: schema[k] for k in valid_keys}
def _run(self, *args: Any, **kwargs: Any) -> str: def _run(self, *args: Any, **kwargs: Any) -> str:
"""Use the tool.""" """Use the tool."""
@ -104,7 +114,7 @@ def tool(
description = f"{tool_name}{signature(func)} - {func.__doc__.strip()}" description = f"{tool_name}{signature(func)} - {func.__doc__.strip()}"
_args_schema = args_schema _args_schema = args_schema
if _args_schema is None and infer_schema: if _args_schema is None and infer_schema:
_args_schema = validate_arguments(func).model # type: ignore _args_schema = create_schema_from_function(f"{tool_name}Schema", func)
tool_ = Tool( tool_ = Tool(
name=tool_name, name=tool_name,
func=func, func=func,

View File

@ -1,10 +1,19 @@
"""Base implementation for tools or skills.""" """Base implementation for tools or skills."""
from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from inspect import signature from inspect import signature
from typing import Any, Dict, Optional, Sequence, Tuple, Type, Union from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type, Union
from pydantic import BaseModel, Extra, Field, validate_arguments, validator from pydantic import (
BaseModel,
Extra,
Field,
create_model,
validate_arguments,
validator,
)
from pydantic.main import ModelMetaclass
from langchain.callbacks import get_callback_manager from langchain.callbacks import get_callback_manager
from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.base import BaseCallbackManager
@ -19,7 +28,77 @@ def _to_args_and_kwargs(run_input: Union[str, Dict]) -> Tuple[Sequence, dict]:
return [], run_input return [], run_input
class BaseTool(ABC, BaseModel): class SchemaAnnotationError(TypeError):
"""Raised when 'args_schema' is missing or has an incorrect type annotation."""
class ToolMetaclass(ModelMetaclass):
"""Metaclass for BaseTool to ensure the provided args_schema
doesn't silently ignored."""
def __new__(
cls: Type[ToolMetaclass], name: str, bases: Tuple[Type, ...], dct: dict
) -> ToolMetaclass:
"""Create the definition of the new tool class."""
schema_type: Optional[Type[BaseModel]] = dct.get("args_schema")
if schema_type is not None:
schema_annotations = dct.get("__annotations__", {})
args_schema_type = schema_annotations.get("args_schema", None)
if args_schema_type is None or args_schema_type == BaseModel:
# Throw errors for common mis-annotations.
# TODO: Use get_args / get_origin and fully
# specify valid annotations.
typehint_mandate = """
class ChildTool(BaseTool):
...
args_schema: Type[BaseModel] = SchemaClass
..."""
raise SchemaAnnotationError(
f"Tool definition for {name} must include valid type annotations"
f" for argument 'args_schema' to behave as expected.\n"
f"Expected annotation of 'Type[BaseModel]'"
f" but got '{args_schema_type}'.\n"
f"Expected class looks like:\n"
f"{typehint_mandate}"
)
# Pass through to Pydantic's metaclass
return super().__new__(cls, name, bases, dct)
def _create_subset_model(
name: str, model: BaseModel, field_names: list
) -> Type[BaseModel]:
"""Create a pydantic model with only a subset of model's fields."""
fields = {
field_name: (
model.__fields__[field_name].type_,
model.__fields__[field_name].default,
)
for field_name in field_names
if field_name in model.__fields__
}
return create_model(name, **fields) # type: ignore
def get_filtered_args(inferred_model: Type[BaseModel], func: Callable) -> dict:
"""Get the arguments from a function's signature."""
schema = inferred_model.schema()["properties"]
valid_keys = signature(func).parameters
return {k: schema[k] for k in valid_keys}
def create_schema_from_function(model_name: str, func: Callable) -> Type[BaseModel]:
"""Create a pydantic schema from a function's signature."""
inferred_model = validate_arguments(func).model # type: ignore
# Pydantic adds placeholder virtual fields we need to strip
filtered_args = get_filtered_args(inferred_model, func)
return _create_subset_model(
f"{model_name}Schema", inferred_model, list(filtered_args)
)
class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
"""Interface LangChain tools must implement.""" """Interface LangChain tools must implement."""
name: str name: str
@ -42,9 +121,7 @@ class BaseTool(ABC, BaseModel):
return self.args_schema.schema()["properties"] return self.args_schema.schema()["properties"]
else: else:
inferred_model = validate_arguments(self._run).model # type: ignore inferred_model = validate_arguments(self._run).model # type: ignore
schema = inferred_model.schema()["properties"] return get_filtered_args(inferred_model, self._run)
valid_keys = signature(self._run).parameters
return {k: schema[k] for k in valid_keys}
def _parse_input( def _parse_input(
self, self,

View File

@ -1,12 +1,14 @@
"""Test tool utils.""" """Test tool utils."""
from datetime import datetime from datetime import datetime
from functools import partial
from typing import Optional, Type, Union from typing import Optional, Type, Union
import pydantic
import pytest import pytest
from pydantic import BaseModel from pydantic import BaseModel
from langchain.agents.tools import Tool, tool from langchain.agents.tools import Tool, tool
from langchain.tools.base import BaseTool from langchain.tools.base import BaseTool, SchemaAnnotationError
def test_unnamed_decorator() -> None: def test_unnamed_decorator() -> None:
@ -51,10 +53,116 @@ def test_structured_args() -> None:
assert structured_api.run(args) == expected_result assert structured_api.run(args) == expected_result
def test_structured_args_decorator() -> None: def test_unannotated_base_tool_raises_error() -> None:
"""Test functionality with structured arguments parsed as a decorator.""" """Test that a BaseTool without type hints raises an exception.""" ""
with pytest.raises(SchemaAnnotationError):
class _UnAnnotatedTool(BaseTool):
name = "structured_api"
# This would silently be ignored without the custom metaclass
args_schema = _MockSchema
description = "A Structured Tool"
def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
return f"{arg1} {arg2} {arg3}"
async def _arun(
self, arg1: int, arg2: bool, arg3: Optional[dict] = None
) -> str:
raise NotImplementedError
def test_misannotated_base_tool_raises_error() -> None:
"""Test that a BaseTool with the incorrrect typehint raises an exception.""" ""
with pytest.raises(SchemaAnnotationError):
class _MisAnnotatedTool(BaseTool):
name = "structured_api"
# This would silently be ignored without the custom metaclass
args_schema: BaseModel = _MockSchema # type: ignore
description = "A Structured Tool"
def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
return f"{arg1} {arg2} {arg3}"
async def _arun(
self, arg1: int, arg2: bool, arg3: Optional[dict] = None
) -> str:
raise NotImplementedError
def test_forward_ref_annotated_base_tool_accepted() -> None:
"""Test that a using forward ref annotation syntax is accepted.""" ""
class _ForwardRefAnnotatedTool(BaseTool):
name = "structured_api"
args_schema: "Type[BaseModel]" = _MockSchema
description = "A Structured Tool"
def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
return f"{arg1} {arg2} {arg3}"
async def _arun(
self, arg1: int, arg2: bool, arg3: Optional[dict] = None
) -> str:
raise NotImplementedError
def test_subclass_annotated_base_tool_accepted() -> None:
"""Test BaseTool child w/ custom schema isn't overwritten."""
class _ForwardRefAnnotatedTool(BaseTool):
name = "structured_api"
args_schema: Type[_MockSchema] = _MockSchema
description = "A Structured Tool"
def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
return f"{arg1} {arg2} {arg3}"
async def _arun(
self, arg1: int, arg2: bool, arg3: Optional[dict] = None
) -> str:
raise NotImplementedError
assert issubclass(_ForwardRefAnnotatedTool, BaseTool)
tool = _ForwardRefAnnotatedTool()
assert tool.args_schema == _MockSchema
def test_decorator_with_specified_schema() -> None:
"""Test that manually specified schemata are passed through to the tool."""
@tool(args_schema=_MockSchema)
def tool_func(arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
"""Return the arguments directly."""
return f"{arg1} {arg2} {arg3}"
assert isinstance(tool_func, Tool)
assert tool_func.args_schema == _MockSchema
def test_decorated_function_schema_equivalent() -> None:
"""Test that a BaseTool without a schema meets expectations."""
@tool @tool
def structured_tool_input(
arg1: int, arg2: bool, arg3: Optional[dict] = None
) -> str:
"""Return the arguments directly."""
return f"{arg1} {arg2} {arg3}"
assert isinstance(structured_tool_input, Tool)
assert (
structured_tool_input.args_schema.schema()["properties"]
== _MockSchema.schema()["properties"]
== structured_tool_input.args
)
def test_structured_args_decorator_no_infer_schema() -> None:
"""Test functionality with structured arguments parsed as a decorator."""
@tool(infer_schema=False)
def structured_tool_input( def structured_tool_input(
arg1: int, arg2: Union[float, datetime], opt_arg: Optional[dict] = None arg1: int, arg2: Union[float, datetime], opt_arg: Optional[dict] = None
) -> str: ) -> str:
@ -68,8 +176,83 @@ def test_structured_args_decorator() -> None:
assert structured_tool_input.run(args) == expected_result assert structured_tool_input.run(args) == expected_result
def test_structured_single_str_decorator_no_infer_schema() -> None:
"""Test functionality with structured arguments parsed as a decorator."""
@tool(infer_schema=False)
def unstructured_tool_input(tool_input: str) -> str:
"""Return the arguments directly."""
return f"{tool_input}"
assert isinstance(unstructured_tool_input, Tool)
assert unstructured_tool_input.args_schema is None
def test_base_tool_inheritance_base_schema() -> None:
"""Test schema is correctly inferred when inheriting from BaseTool."""
class _MockSimpleTool(BaseTool):
name = "simple_tool"
description = "A Simple Tool"
def _run(self, tool_input: str) -> str:
return f"{tool_input}"
async def _arun(self, tool_input: str) -> str:
raise NotImplementedError
simple_tool = _MockSimpleTool()
assert simple_tool.args_schema is None
expected_args = {"tool_input": {"title": "Tool Input", "type": "string"}}
assert simple_tool.args == expected_args
def test_tool_lambda_args_schema() -> None:
"""Test args schema inference when the tool argument is a lambda function."""
tool = Tool(
name="tool",
description="A tool",
func=lambda tool_input: tool_input,
)
assert tool.args_schema is None
expected_args = {"tool_input": {"title": "Tool Input"}}
assert tool.args == expected_args
def test_tool_lambda_multi_args_schema() -> None:
"""Test args schema inference when the tool argument is a lambda function."""
tool = Tool(
name="tool",
description="A tool",
func=lambda tool_input, other_arg: f"{tool_input}{other_arg}", # type: ignore
)
assert tool.args_schema is None
expected_args = {
"tool_input": {"title": "Tool Input"},
"other_arg": {"title": "Other Arg"},
}
assert tool.args == expected_args
def test_tool_partial_function_args_schema() -> None:
"""Test args schema inference when the tool argument is a partial function."""
def func(tool_input: str, other_arg: str) -> str:
return tool_input + other_arg
with pytest.raises(pydantic.error_wrappers.ValidationError):
# We don't yet support args_schema inference for partial functions
# so want to make sure we proactively raise an error
Tool(
name="tool",
description="A tool",
func=partial(func, other_arg="foo"),
)
def test_empty_args_decorator() -> None: def test_empty_args_decorator() -> None:
"""Test functionality with no args parsed as a decorator.""" """Test inferred schema of decorated fn with no args."""
@tool @tool
def empty_tool_input() -> str: def empty_tool_input() -> str:
@ -78,6 +261,7 @@ def test_empty_args_decorator() -> None:
assert isinstance(empty_tool_input, Tool) assert isinstance(empty_tool_input, Tool)
assert empty_tool_input.name == "empty_tool_input" assert empty_tool_input.name == "empty_tool_input"
assert empty_tool_input.args == {}
assert empty_tool_input.run({}) == "the empty result" assert empty_tool_input.run({}) == "the empty result"