diff --git a/langchain/agents/tools.py b/langchain/agents/tools.py index c5aa094e..913110ba 100644 --- a/langchain/agents/tools.py +++ b/langchain/agents/tools.py @@ -1,10 +1,15 @@ """Interface for tools.""" +from functools import partial from inspect import signature from typing import Any, Awaitable, Callable, Optional, Type, Union -from pydantic import BaseModel, validate_arguments +from pydantic import BaseModel, validate_arguments, validator -from langchain.tools.base import BaseTool +from langchain.tools.base import ( + BaseTool, + create_schema_from_function, + get_filtered_args, +) class Tool(BaseTool): @@ -16,15 +21,20 @@ class Tool(BaseTool): coroutine: Optional[Callable[..., Awaitable[str]]] = None """The asynchronous version of the function.""" + @validator("func", pre=True, always=True) + def validate_func_not_partial(cls, func: Callable) -> Callable: + """Check that the function is not a partial.""" + if isinstance(func, partial): + raise ValueError("Partial functions not yet supported in tools.") + return func + @property def args(self) -> dict: if self.args_schema is not None: return self.args_schema.schema()["properties"] else: inferred_model = validate_arguments(self.func).model # type: ignore - schema = inferred_model.schema()["properties"] - valid_keys = signature(self.func).parameters - return {k: schema[k] for k in valid_keys} + return get_filtered_args(inferred_model, self.func) def _run(self, *args: Any, **kwargs: Any) -> str: """Use the tool.""" @@ -104,7 +114,7 @@ def tool( description = f"{tool_name}{signature(func)} - {func.__doc__.strip()}" _args_schema = args_schema if _args_schema is None and infer_schema: - _args_schema = validate_arguments(func).model # type: ignore + _args_schema = create_schema_from_function(f"{tool_name}Schema", func) tool_ = Tool( name=tool_name, func=func, diff --git a/langchain/tools/base.py b/langchain/tools/base.py index bc173866..54d8db5d 100644 --- a/langchain/tools/base.py +++ b/langchain/tools/base.py @@ -1,10 +1,19 @@ """Base implementation for tools or skills.""" +from __future__ import annotations from abc import ABC, abstractmethod from inspect import signature -from typing import Any, Dict, Optional, Sequence, Tuple, Type, Union - -from pydantic import BaseModel, Extra, Field, validate_arguments, validator +from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type, Union + +from pydantic import ( + BaseModel, + Extra, + Field, + create_model, + validate_arguments, + validator, +) +from pydantic.main import ModelMetaclass from langchain.callbacks import get_callback_manager from langchain.callbacks.base import BaseCallbackManager @@ -19,7 +28,77 @@ def _to_args_and_kwargs(run_input: Union[str, Dict]) -> Tuple[Sequence, dict]: return [], run_input -class BaseTool(ABC, BaseModel): +class SchemaAnnotationError(TypeError): + """Raised when 'args_schema' is missing or has an incorrect type annotation.""" + + +class ToolMetaclass(ModelMetaclass): + """Metaclass for BaseTool to ensure the provided args_schema + + doesn't silently ignored.""" + + def __new__( + cls: Type[ToolMetaclass], name: str, bases: Tuple[Type, ...], dct: dict + ) -> ToolMetaclass: + """Create the definition of the new tool class.""" + schema_type: Optional[Type[BaseModel]] = dct.get("args_schema") + if schema_type is not None: + schema_annotations = dct.get("__annotations__", {}) + args_schema_type = schema_annotations.get("args_schema", None) + if args_schema_type is None or args_schema_type == BaseModel: + # Throw errors for common mis-annotations. + # TODO: Use get_args / get_origin and fully + # specify valid annotations. + typehint_mandate = """ +class ChildTool(BaseTool): + ... + args_schema: Type[BaseModel] = SchemaClass + ...""" + raise SchemaAnnotationError( + f"Tool definition for {name} must include valid type annotations" + f" for argument 'args_schema' to behave as expected.\n" + f"Expected annotation of 'Type[BaseModel]'" + f" but got '{args_schema_type}'.\n" + f"Expected class looks like:\n" + f"{typehint_mandate}" + ) + # Pass through to Pydantic's metaclass + return super().__new__(cls, name, bases, dct) + + +def _create_subset_model( + name: str, model: BaseModel, field_names: list +) -> Type[BaseModel]: + """Create a pydantic model with only a subset of model's fields.""" + fields = { + field_name: ( + model.__fields__[field_name].type_, + model.__fields__[field_name].default, + ) + for field_name in field_names + if field_name in model.__fields__ + } + return create_model(name, **fields) # type: ignore + + +def get_filtered_args(inferred_model: Type[BaseModel], func: Callable) -> dict: + """Get the arguments from a function's signature.""" + schema = inferred_model.schema()["properties"] + valid_keys = signature(func).parameters + return {k: schema[k] for k in valid_keys} + + +def create_schema_from_function(model_name: str, func: Callable) -> Type[BaseModel]: + """Create a pydantic schema from a function's signature.""" + inferred_model = validate_arguments(func).model # type: ignore + # Pydantic adds placeholder virtual fields we need to strip + filtered_args = get_filtered_args(inferred_model, func) + return _create_subset_model( + f"{model_name}Schema", inferred_model, list(filtered_args) + ) + + +class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass): """Interface LangChain tools must implement.""" name: str @@ -42,9 +121,7 @@ class BaseTool(ABC, BaseModel): return self.args_schema.schema()["properties"] else: inferred_model = validate_arguments(self._run).model # type: ignore - schema = inferred_model.schema()["properties"] - valid_keys = signature(self._run).parameters - return {k: schema[k] for k in valid_keys} + return get_filtered_args(inferred_model, self._run) def _parse_input( self, diff --git a/tests/unit_tests/agents/test_tools.py b/tests/unit_tests/agents/test_tools.py index cdb7929d..d76011dc 100644 --- a/tests/unit_tests/agents/test_tools.py +++ b/tests/unit_tests/agents/test_tools.py @@ -1,12 +1,14 @@ """Test tool utils.""" from datetime import datetime +from functools import partial from typing import Optional, Type, Union +import pydantic import pytest from pydantic import BaseModel from langchain.agents.tools import Tool, tool -from langchain.tools.base import BaseTool +from langchain.tools.base import BaseTool, SchemaAnnotationError def test_unnamed_decorator() -> None: @@ -51,10 +53,116 @@ def test_structured_args() -> None: assert structured_api.run(args) == expected_result -def test_structured_args_decorator() -> None: - """Test functionality with structured arguments parsed as a decorator.""" +def test_unannotated_base_tool_raises_error() -> None: + """Test that a BaseTool without type hints raises an exception.""" "" + with pytest.raises(SchemaAnnotationError): + + class _UnAnnotatedTool(BaseTool): + name = "structured_api" + # This would silently be ignored without the custom metaclass + args_schema = _MockSchema + description = "A Structured Tool" + + def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str: + return f"{arg1} {arg2} {arg3}" + + async def _arun( + self, arg1: int, arg2: bool, arg3: Optional[dict] = None + ) -> str: + raise NotImplementedError + + +def test_misannotated_base_tool_raises_error() -> None: + """Test that a BaseTool with the incorrrect typehint raises an exception.""" "" + with pytest.raises(SchemaAnnotationError): + + class _MisAnnotatedTool(BaseTool): + name = "structured_api" + # This would silently be ignored without the custom metaclass + args_schema: BaseModel = _MockSchema # type: ignore + description = "A Structured Tool" + + def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str: + return f"{arg1} {arg2} {arg3}" + + async def _arun( + self, arg1: int, arg2: bool, arg3: Optional[dict] = None + ) -> str: + raise NotImplementedError + + +def test_forward_ref_annotated_base_tool_accepted() -> None: + """Test that a using forward ref annotation syntax is accepted.""" "" + + class _ForwardRefAnnotatedTool(BaseTool): + name = "structured_api" + args_schema: "Type[BaseModel]" = _MockSchema + description = "A Structured Tool" + + def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str: + return f"{arg1} {arg2} {arg3}" + + async def _arun( + self, arg1: int, arg2: bool, arg3: Optional[dict] = None + ) -> str: + raise NotImplementedError + + +def test_subclass_annotated_base_tool_accepted() -> None: + """Test BaseTool child w/ custom schema isn't overwritten.""" + + class _ForwardRefAnnotatedTool(BaseTool): + name = "structured_api" + args_schema: Type[_MockSchema] = _MockSchema + description = "A Structured Tool" + + def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str: + return f"{arg1} {arg2} {arg3}" + + async def _arun( + self, arg1: int, arg2: bool, arg3: Optional[dict] = None + ) -> str: + raise NotImplementedError + + assert issubclass(_ForwardRefAnnotatedTool, BaseTool) + tool = _ForwardRefAnnotatedTool() + assert tool.args_schema == _MockSchema + + +def test_decorator_with_specified_schema() -> None: + """Test that manually specified schemata are passed through to the tool.""" + + @tool(args_schema=_MockSchema) + def tool_func(arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str: + """Return the arguments directly.""" + return f"{arg1} {arg2} {arg3}" + + assert isinstance(tool_func, Tool) + assert tool_func.args_schema == _MockSchema + + +def test_decorated_function_schema_equivalent() -> None: + """Test that a BaseTool without a schema meets expectations.""" @tool + def structured_tool_input( + arg1: int, arg2: bool, arg3: Optional[dict] = None + ) -> str: + """Return the arguments directly.""" + return f"{arg1} {arg2} {arg3}" + + assert isinstance(structured_tool_input, Tool) + assert ( + structured_tool_input.args_schema.schema()["properties"] + == _MockSchema.schema()["properties"] + == structured_tool_input.args + ) + + +def test_structured_args_decorator_no_infer_schema() -> None: + """Test functionality with structured arguments parsed as a decorator.""" + + @tool(infer_schema=False) def structured_tool_input( arg1: int, arg2: Union[float, datetime], opt_arg: Optional[dict] = None ) -> str: @@ -68,8 +176,83 @@ def test_structured_args_decorator() -> None: assert structured_tool_input.run(args) == expected_result +def test_structured_single_str_decorator_no_infer_schema() -> None: + """Test functionality with structured arguments parsed as a decorator.""" + + @tool(infer_schema=False) + def unstructured_tool_input(tool_input: str) -> str: + """Return the arguments directly.""" + return f"{tool_input}" + + assert isinstance(unstructured_tool_input, Tool) + assert unstructured_tool_input.args_schema is None + + +def test_base_tool_inheritance_base_schema() -> None: + """Test schema is correctly inferred when inheriting from BaseTool.""" + + class _MockSimpleTool(BaseTool): + name = "simple_tool" + description = "A Simple Tool" + + def _run(self, tool_input: str) -> str: + return f"{tool_input}" + + async def _arun(self, tool_input: str) -> str: + raise NotImplementedError + + simple_tool = _MockSimpleTool() + assert simple_tool.args_schema is None + expected_args = {"tool_input": {"title": "Tool Input", "type": "string"}} + assert simple_tool.args == expected_args + + +def test_tool_lambda_args_schema() -> None: + """Test args schema inference when the tool argument is a lambda function.""" + + tool = Tool( + name="tool", + description="A tool", + func=lambda tool_input: tool_input, + ) + assert tool.args_schema is None + expected_args = {"tool_input": {"title": "Tool Input"}} + assert tool.args == expected_args + + +def test_tool_lambda_multi_args_schema() -> None: + """Test args schema inference when the tool argument is a lambda function.""" + tool = Tool( + name="tool", + description="A tool", + func=lambda tool_input, other_arg: f"{tool_input}{other_arg}", # type: ignore + ) + assert tool.args_schema is None + expected_args = { + "tool_input": {"title": "Tool Input"}, + "other_arg": {"title": "Other Arg"}, + } + assert tool.args == expected_args + + +def test_tool_partial_function_args_schema() -> None: + """Test args schema inference when the tool argument is a partial function.""" + + def func(tool_input: str, other_arg: str) -> str: + return tool_input + other_arg + + with pytest.raises(pydantic.error_wrappers.ValidationError): + # We don't yet support args_schema inference for partial functions + # so want to make sure we proactively raise an error + Tool( + name="tool", + description="A tool", + func=partial(func, other_arg="foo"), + ) + + def test_empty_args_decorator() -> None: - """Test functionality with no args parsed as a decorator.""" + """Test inferred schema of decorated fn with no args.""" @tool def empty_tool_input() -> str: @@ -78,6 +261,7 @@ def test_empty_args_decorator() -> None: assert isinstance(empty_tool_input, Tool) assert empty_tool_input.name == "empty_tool_input" + assert empty_tool_input.args == {} assert empty_tool_input.run({}) == "the empty result"