forked from Archives/langchain
Structured Tool Bugfixes (#3324)
- Proactively raise an error if a tool subclasses BaseTool and defines its own schema but fails to add the type hints. - Fix the auto-inferred schema of the decorator to strip the unneeded virtual kwargs from the schema dict. Helps avoid silent instances of #3297.
This commit is contained in:
parent
f22b9d0e57
commit
49122a96e7
@ -1,10 +1,15 @@
|
|||||||
"""Interface for tools."""
|
"""Interface for tools."""
|
||||||
|
from functools import partial
|
||||||
from inspect import signature
|
from inspect import signature
|
||||||
from typing import Any, Awaitable, Callable, Optional, Type, Union
|
from typing import Any, Awaitable, Callable, Optional, Type, Union
|
||||||
|
|
||||||
from pydantic import BaseModel, validate_arguments
|
from pydantic import BaseModel, validate_arguments, validator
|
||||||
|
|
||||||
from langchain.tools.base import BaseTool
|
from langchain.tools.base import (
|
||||||
|
BaseTool,
|
||||||
|
create_schema_from_function,
|
||||||
|
get_filtered_args,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Tool(BaseTool):
|
class Tool(BaseTool):
|
||||||
@ -16,15 +21,20 @@ class Tool(BaseTool):
|
|||||||
coroutine: Optional[Callable[..., Awaitable[str]]] = None
|
coroutine: Optional[Callable[..., Awaitable[str]]] = None
|
||||||
"""The asynchronous version of the function."""
|
"""The asynchronous version of the function."""
|
||||||
|
|
||||||
|
@validator("func", pre=True, always=True)
def validate_func_not_partial(cls, func: Callable) -> Callable:
    """Reject ``functools.partial`` objects supplied as the tool function.

    Args:
        func: The candidate value for the ``func`` field.

    Returns:
        ``func`` unchanged when it is not a ``partial``.

    Raises:
        ValueError: If ``func`` is a ``functools.partial`` instance.
    """
    if not isinstance(func, partial):
        return func
    raise ValueError("Partial functions not yet supported in tools.")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def args(self) -> dict:
|
def args(self) -> dict:
|
||||||
if self.args_schema is not None:
|
if self.args_schema is not None:
|
||||||
return self.args_schema.schema()["properties"]
|
return self.args_schema.schema()["properties"]
|
||||||
else:
|
else:
|
||||||
inferred_model = validate_arguments(self.func).model # type: ignore
|
inferred_model = validate_arguments(self.func).model # type: ignore
|
||||||
schema = inferred_model.schema()["properties"]
|
return get_filtered_args(inferred_model, self.func)
|
||||||
valid_keys = signature(self.func).parameters
|
|
||||||
return {k: schema[k] for k in valid_keys}
|
|
||||||
|
|
||||||
def _run(self, *args: Any, **kwargs: Any) -> str:
|
def _run(self, *args: Any, **kwargs: Any) -> str:
|
||||||
"""Use the tool."""
|
"""Use the tool."""
|
||||||
@ -104,7 +114,7 @@ def tool(
|
|||||||
description = f"{tool_name}{signature(func)} - {func.__doc__.strip()}"
|
description = f"{tool_name}{signature(func)} - {func.__doc__.strip()}"
|
||||||
_args_schema = args_schema
|
_args_schema = args_schema
|
||||||
if _args_schema is None and infer_schema:
|
if _args_schema is None and infer_schema:
|
||||||
_args_schema = validate_arguments(func).model # type: ignore
|
_args_schema = create_schema_from_function(f"{tool_name}Schema", func)
|
||||||
tool_ = Tool(
|
tool_ = Tool(
|
||||||
name=tool_name,
|
name=tool_name,
|
||||||
func=func,
|
func=func,
|
||||||
|
@ -1,10 +1,19 @@
|
|||||||
"""Base implementation for tools or skills."""
|
"""Base implementation for tools or skills."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from inspect import signature
|
from inspect import signature
|
||||||
from typing import Any, Dict, Optional, Sequence, Tuple, Type, Union
|
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type, Union
|
||||||
|
|
||||||
from pydantic import BaseModel, Extra, Field, validate_arguments, validator
|
from pydantic import (
|
||||||
|
BaseModel,
|
||||||
|
Extra,
|
||||||
|
Field,
|
||||||
|
create_model,
|
||||||
|
validate_arguments,
|
||||||
|
validator,
|
||||||
|
)
|
||||||
|
from pydantic.main import ModelMetaclass
|
||||||
|
|
||||||
from langchain.callbacks import get_callback_manager
|
from langchain.callbacks import get_callback_manager
|
||||||
from langchain.callbacks.base import BaseCallbackManager
|
from langchain.callbacks.base import BaseCallbackManager
|
||||||
@ -19,7 +28,77 @@ def _to_args_and_kwargs(run_input: Union[str, Dict]) -> Tuple[Sequence, dict]:
|
|||||||
return [], run_input
|
return [], run_input
|
||||||
|
|
||||||
|
|
||||||
class BaseTool(ABC, BaseModel):
|
class SchemaAnnotationError(TypeError):
    """Signal that 'args_schema' lacks a type annotation or has a wrong one."""
|
||||||
|
|
||||||
|
|
||||||
|
class ToolMetaclass(ModelMetaclass):
    """Metaclass for BaseTool to ensure the provided args_schema
    isn't silently ignored.

    The check runs while a tool class is being created: a class body that
    assigns ``args_schema`` must also carry a usable annotation for it.
    """

    def __new__(
        cls: Type[ToolMetaclass], name: str, bases: Tuple[Type, ...], dct: dict
    ) -> ToolMetaclass:
        """Create the definition of the new tool class.

        Args:
            cls: The metaclass itself.
            name: Name of the class being created.
            bases: Base classes of the class being created.
            dct: Namespace of the class body.

        Returns:
            The class object produced by pydantic's metaclass.

        Raises:
            SchemaAnnotationError: If the class body assigns ``args_schema``
                without annotating it, or annotates it as plain ``BaseModel``.
        """
        schema_type: Optional[Type[BaseModel]] = dct.get("args_schema")
        if schema_type is not None:
            schema_annotations = dct.get("__annotations__", {})
            args_schema_type = schema_annotations.get("args_schema", None)
            # When the defining module uses postponed annotations (PEP 563),
            # the namespace holds the annotation's source text, so the
            # bare-BaseModel mistake shows up as the string "BaseModel"
            # instead of the class object. Reject both spellings.
            misannotations = (BaseModel, "BaseModel")
            if args_schema_type is None or args_schema_type in misannotations:
                # Throw errors for common mis-annotations.
                # TODO: Use get_args / get_origin and fully
                # specify valid annotations.
                typehint_mandate = """
class ChildTool(BaseTool):
    ...
    args_schema: Type[BaseModel] = SchemaClass
    ..."""
                raise SchemaAnnotationError(
                    f"Tool definition for {name} must include valid type annotations"
                    f" for argument 'args_schema' to behave as expected.\n"
                    f"Expected annotation of 'Type[BaseModel]'"
                    f" but got '{args_schema_type}'.\n"
                    f"Expected class looks like:\n"
                    f"{typehint_mandate}"
                )
        # Pass through to Pydantic's metaclass
        return super().__new__(cls, name, bases, dct)
|
||||||
|
|
||||||
|
|
||||||
|
def _create_subset_model(
    name: str, model: Type[BaseModel], field_names: list
) -> Type[BaseModel]:
    """Create a pydantic model with only a subset of model's fields.

    Args:
        name: Name given to the generated model class.
        model: Source model whose field definitions are copied.
        field_names: Names of the fields to keep; names that ``model`` does
            not define are skipped.

    Returns:
        A new model class containing only the selected fields.
    """
    fields = {}
    for field_name in field_names:
        if field_name not in model.__fields__:
            continue
        field = model.__fields__[field_name]
        # ``outer_type_`` keeps container/Optional wrappers that ``type_``
        # strips. Passing ``field_info`` instead of ``default`` keeps required
        # fields required: a required field's ``default`` is None, which
        # ``create_model`` would read as "optional, defaults to None".
        fields[field_name] = (field.outer_type_, field.field_info)
    return create_model(name, **fields)  # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
def get_filtered_args(inferred_model: Type[BaseModel], func: Callable) -> dict:
    """Return the schema properties that correspond to ``func``'s parameters.

    Args:
        inferred_model: Model whose ``schema()["properties"]`` is consulted.
        func: Callable whose signature decides which properties are kept.

    Returns:
        A dict mapping each parameter name of ``func``, in signature order,
        to its entry in the model's schema properties.
    """
    properties = inferred_model.schema()["properties"]
    filtered = {}
    for param_name in signature(func).parameters:
        filtered[param_name] = properties[param_name]
    return filtered
|
||||||
|
|
||||||
|
|
||||||
|
def create_schema_from_function(model_name: str, func: Callable) -> Type[BaseModel]:
    """Build a pydantic model class describing ``func``'s parameters.

    Args:
        model_name: Base name for the generated model; "Schema" is appended.
        func: Callable whose signature is turned into model fields.

    Returns:
        A model class restricted to the parameters in ``func``'s signature.
    """
    validated_model = validate_arguments(func).model  # type: ignore
    # The model pydantic infers carries extra placeholder fields; keep only
    # the names that appear in the function's signature.
    signature_fields = list(get_filtered_args(validated_model, func))
    # NOTE(review): the suffix is added unconditionally, so a caller that
    # already passes a "...Schema" name ends up with it doubled.
    schema_name = f"{model_name}Schema"
    return _create_subset_model(schema_name, validated_model, signature_fields)
|
||||||
|
|
||||||
|
|
||||||
|
class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
|
||||||
"""Interface LangChain tools must implement."""
|
"""Interface LangChain tools must implement."""
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
@ -42,9 +121,7 @@ class BaseTool(ABC, BaseModel):
|
|||||||
return self.args_schema.schema()["properties"]
|
return self.args_schema.schema()["properties"]
|
||||||
else:
|
else:
|
||||||
inferred_model = validate_arguments(self._run).model # type: ignore
|
inferred_model = validate_arguments(self._run).model # type: ignore
|
||||||
schema = inferred_model.schema()["properties"]
|
return get_filtered_args(inferred_model, self._run)
|
||||||
valid_keys = signature(self._run).parameters
|
|
||||||
return {k: schema[k] for k in valid_keys}
|
|
||||||
|
|
||||||
def _parse_input(
|
def _parse_input(
|
||||||
self,
|
self,
|
||||||
|
@ -1,12 +1,14 @@
|
|||||||
"""Test tool utils."""
|
"""Test tool utils."""
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from functools import partial
|
||||||
from typing import Optional, Type, Union
|
from typing import Optional, Type, Union
|
||||||
|
|
||||||
|
import pydantic
|
||||||
import pytest
|
import pytest
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from langchain.agents.tools import Tool, tool
|
from langchain.agents.tools import Tool, tool
|
||||||
from langchain.tools.base import BaseTool
|
from langchain.tools.base import BaseTool, SchemaAnnotationError
|
||||||
|
|
||||||
|
|
||||||
def test_unnamed_decorator() -> None:
|
def test_unnamed_decorator() -> None:
|
||||||
@ -51,10 +53,116 @@ def test_structured_args() -> None:
|
|||||||
assert structured_api.run(args) == expected_result
|
assert structured_api.run(args) == expected_result
|
||||||
|
|
||||||
|
|
||||||
def test_structured_args_decorator() -> None:
|
def test_unannotated_base_tool_raises_error() -> None:
|
||||||
"""Test functionality with structured arguments parsed as a decorator."""
|
"""Test that a BaseTool without type hints raises an exception.""" ""
|
||||||
|
with pytest.raises(SchemaAnnotationError):
|
||||||
|
|
||||||
|
class _UnAnnotatedTool(BaseTool):
|
||||||
|
name = "structured_api"
|
||||||
|
# This would silently be ignored without the custom metaclass
|
||||||
|
args_schema = _MockSchema
|
||||||
|
description = "A Structured Tool"
|
||||||
|
|
||||||
|
def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
|
||||||
|
return f"{arg1} {arg2} {arg3}"
|
||||||
|
|
||||||
|
async def _arun(
|
||||||
|
self, arg1: int, arg2: bool, arg3: Optional[dict] = None
|
||||||
|
) -> str:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
def test_misannotated_base_tool_raises_error() -> None:
    """Test that a BaseTool with the incorrect typehint raises an exception."""
    with pytest.raises(SchemaAnnotationError):

        class _MisAnnotatedTool(BaseTool):
            name = "structured_api"
            # Annotated as a BaseModel instance rather than Type[BaseModel];
            # without the custom metaclass this would silently be ignored.
            args_schema: BaseModel = _MockSchema  # type: ignore
            description = "A Structured Tool"

            def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
                return f"{arg1} {arg2} {arg3}"

            async def _arun(
                self, arg1: int, arg2: bool, arg3: Optional[dict] = None
            ) -> str:
                raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
def test_forward_ref_annotated_base_tool_accepted() -> None:
    """Test that a tool using forward-ref annotation syntax is accepted."""

    # Defining the class must not raise: the annotation is a string.
    class _ForwardRefAnnotatedTool(BaseTool):
        name = "structured_api"
        args_schema: "Type[BaseModel]" = _MockSchema
        description = "A Structured Tool"

        def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
            return f"{arg1} {arg2} {arg3}"

        async def _arun(
            self, arg1: int, arg2: bool, arg3: Optional[dict] = None
        ) -> str:
            raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
def test_subclass_annotated_base_tool_accepted() -> None:
    """Test BaseTool child w/ custom schema isn't overwritten."""

    class _SubclassAnnotatedTool(BaseTool):
        name = "structured_api"
        # Annotated with the concrete schema subclass, not Type[BaseModel].
        args_schema: Type[_MockSchema] = _MockSchema
        description = "A Structured Tool"

        def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
            return f"{arg1} {arg2} {arg3}"

        async def _arun(
            self, arg1: int, arg2: bool, arg3: Optional[dict] = None
        ) -> str:
            raise NotImplementedError

    assert issubclass(_SubclassAnnotatedTool, BaseTool)
    instance = _SubclassAnnotatedTool()
    assert instance.args_schema == _MockSchema
|
||||||
|
|
||||||
|
|
||||||
|
def test_decorator_with_specified_schema() -> None:
    """Test that a schema passed to the decorator ends up on the tool."""

    @tool(args_schema=_MockSchema)
    def tool_func(arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
        """Return the arguments directly."""
        return f"{arg1} {arg2} {arg3}"

    decorated = tool_func
    assert isinstance(decorated, Tool)
    assert decorated.args_schema == _MockSchema
|
||||||
|
|
||||||
|
|
||||||
|
def test_decorated_function_schema_equivalent() -> None:
|
||||||
|
"""Test that a BaseTool without a schema meets expectations."""
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
|
def structured_tool_input(
|
||||||
|
arg1: int, arg2: bool, arg3: Optional[dict] = None
|
||||||
|
) -> str:
|
||||||
|
"""Return the arguments directly."""
|
||||||
|
return f"{arg1} {arg2} {arg3}"
|
||||||
|
|
||||||
|
assert isinstance(structured_tool_input, Tool)
|
||||||
|
assert (
|
||||||
|
structured_tool_input.args_schema.schema()["properties"]
|
||||||
|
== _MockSchema.schema()["properties"]
|
||||||
|
== structured_tool_input.args
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_structured_args_decorator_no_infer_schema() -> None:
|
||||||
|
"""Test functionality with structured arguments parsed as a decorator."""
|
||||||
|
|
||||||
|
@tool(infer_schema=False)
|
||||||
def structured_tool_input(
|
def structured_tool_input(
|
||||||
arg1: int, arg2: Union[float, datetime], opt_arg: Optional[dict] = None
|
arg1: int, arg2: Union[float, datetime], opt_arg: Optional[dict] = None
|
||||||
) -> str:
|
) -> str:
|
||||||
@ -68,8 +176,83 @@ def test_structured_args_decorator() -> None:
|
|||||||
assert structured_tool_input.run(args) == expected_result
|
assert structured_tool_input.run(args) == expected_result
|
||||||
|
|
||||||
|
|
||||||
|
def test_structured_single_str_decorator_no_infer_schema() -> None:
    """Test a single-str function decorated with infer_schema=False has no schema."""

    @tool(infer_schema=False)
    def unstructured_tool_input(tool_input: str) -> str:
        """Return the arguments directly."""
        return f"{tool_input}"

    decorated = unstructured_tool_input
    assert isinstance(decorated, Tool)
    assert decorated.args_schema is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_base_tool_inheritance_base_schema() -> None:
    """Test args are inferred from _run when a BaseTool child sets no schema."""

    class _MockSimpleTool(BaseTool):
        name = "simple_tool"
        description = "A Simple Tool"

        def _run(self, tool_input: str) -> str:
            return f"{tool_input}"

        async def _arun(self, tool_input: str) -> str:
            raise NotImplementedError

    instance = _MockSimpleTool()
    assert instance.args_schema is None
    assert instance.args == {
        "tool_input": {"title": "Tool Input", "type": "string"}
    }
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_lambda_args_schema() -> None:
    """Test args inference when the tool function is a one-argument lambda."""
    lambda_tool = Tool(
        name="tool",
        description="A tool",
        func=lambda tool_input: tool_input,
    )
    assert lambda_tool.args_schema is None
    # The lambda parameter is unannotated, so only a title is expected.
    assert lambda_tool.args == {"tool_input": {"title": "Tool Input"}}
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_lambda_multi_args_schema() -> None:
    """Test args inference when the tool function is a two-argument lambda."""
    lambda_tool = Tool(
        name="tool",
        description="A tool",
        func=lambda tool_input, other_arg: f"{tool_input}{other_arg}",  # type: ignore
    )
    assert lambda_tool.args_schema is None
    assert lambda_tool.args == {
        "tool_input": {"title": "Tool Input"},
        "other_arg": {"title": "Other Arg"},
    }
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_partial_function_args_schema() -> None:
    """Test that building a Tool from a partial function raises an error."""

    def func(tool_input: str, other_arg: str) -> str:
        return tool_input + other_arg

    bound_func = partial(func, other_arg="foo")
    # We don't yet support args_schema inference for partial functions
    # so want to make sure we proactively raise an error
    with pytest.raises(pydantic.error_wrappers.ValidationError):
        Tool(name="tool", description="A tool", func=bound_func)
|
||||||
|
|
||||||
|
|
||||||
def test_empty_args_decorator() -> None:
|
def test_empty_args_decorator() -> None:
|
||||||
"""Test functionality with no args parsed as a decorator."""
|
"""Test inferred schema of decorated fn with no args."""
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
def empty_tool_input() -> str:
|
def empty_tool_input() -> str:
|
||||||
@ -78,6 +261,7 @@ def test_empty_args_decorator() -> None:
|
|||||||
|
|
||||||
assert isinstance(empty_tool_input, Tool)
|
assert isinstance(empty_tool_input, Tool)
|
||||||
assert empty_tool_input.name == "empty_tool_input"
|
assert empty_tool_input.name == "empty_tool_input"
|
||||||
|
assert empty_tool_input.args == {}
|
||||||
assert empty_tool_input.run({}) == "the empty result"
|
assert empty_tool_input.run({}) == "the empty result"
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user