From afa9d1292b0a152e36d338dde7b02f0b93bd37d9 Mon Sep 17 00:00:00 2001 From: Zander Chase <130414180+vowelparrot@users.noreply.github.com> Date: Wed, 3 May 2023 13:16:41 -0700 Subject: [PATCH] Re-Permit Partials in `Tool` (#4058) Resolved issue #4053 Now that StructuredTool is a separate class, this constraint is no longer needed. Added/updated a unit test --- langchain/tools/base.py | 9 --------- tests/unit_tests/agents/test_tools.py | 15 ++++++--------- 2 files changed, 6 insertions(+), 18 deletions(-) diff --git a/langchain/tools/base.py b/langchain/tools/base.py index f231b5ef..5bbbf405 100644 --- a/langchain/tools/base.py +++ b/langchain/tools/base.py @@ -3,7 +3,6 @@ from __future__ import annotations import warnings from abc import ABC, abstractmethod -from functools import partial from inspect import signature from typing import Any, Awaitable, Callable, Dict, Optional, Tuple, Type, Union @@ -14,7 +13,6 @@ from pydantic import ( create_model, root_validator, validate_arguments, - validator, ) from pydantic.main import ModelMetaclass @@ -309,13 +307,6 @@ class Tool(BaseTool): coroutine: Optional[Callable[..., Awaitable[str]]] = None """The asynchronous version of the function.""" - @validator("func", pre=True, always=True) - def validate_func_not_partial(cls, func: Callable) -> Callable: - """Check that the function is not a partial.""" - if isinstance(func, partial): - raise ValueError("Partial functions not yet supported in tools.") - return func - @property def args(self) -> dict: """The tool's input arguments.""" diff --git a/tests/unit_tests/agents/test_tools.py b/tests/unit_tests/agents/test_tools.py index 094774fb..8001e5c8 100644 --- a/tests/unit_tests/agents/test_tools.py +++ b/tests/unit_tests/agents/test_tools.py @@ -4,7 +4,6 @@ from functools import partial from typing import Any, Optional, Type, Union from unittest.mock import MagicMock -import pydantic import pytest from pydantic import BaseModel @@ -252,14 +251,12 @@ def test_tool_partial_function_args_schema() -> None: def func(tool_input: str, other_arg: str) -> str: return tool_input + other_arg - with pytest.raises(pydantic.error_wrappers.ValidationError): - # We don't yet support args_schema inference for partial functions - # so want to make sure we proactively raise an error - Tool( - name="tool", - description="A tool", - func=partial(func, other_arg="foo"), - ) + tool = Tool( + name="tool", + description="A tool", + func=partial(func, other_arg="foo"), + ) + assert tool.run("bar") == "barfoo" def test_empty_args_decorator() -> None: