Re-Permit Partials in Tool (#4058)

Resolved issue #4053

Now that StructuredTool is a separate class, this constraint is no
longer needed.

Added/updated a unit test
This commit is contained in:
Zander Chase 2023-05-03 13:16:41 -07:00 committed by GitHub
parent 7e967aa4d5
commit afa9d1292b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 6 additions and 18 deletions

View File

@ -3,7 +3,6 @@ from __future__ import annotations
import warnings
from abc import ABC, abstractmethod
from functools import partial
from inspect import signature
from typing import Any, Awaitable, Callable, Dict, Optional, Tuple, Type, Union
@ -14,7 +13,6 @@ from pydantic import (
create_model,
root_validator,
validate_arguments,
validator,
)
from pydantic.main import ModelMetaclass
@ -309,13 +307,6 @@ class Tool(BaseTool):
coroutine: Optional[Callable[..., Awaitable[str]]] = None
"""The asynchronous version of the function."""
@validator("func", pre=True, always=True)
def validate_func_not_partial(cls, func: Callable) -> Callable:
"""Check that the function is not a partial."""
if isinstance(func, partial):
raise ValueError("Partial functions not yet supported in tools.")
return func
@property
def args(self) -> dict:
"""The tool's input arguments."""

View File

@ -4,7 +4,6 @@ from functools import partial
from typing import Any, Optional, Type, Union
from unittest.mock import MagicMock
import pydantic
import pytest
from pydantic import BaseModel
@ -252,14 +251,12 @@ def test_tool_partial_function_args_schema() -> None:
def func(tool_input: str, other_arg: str) -> str:
return tool_input + other_arg
with pytest.raises(pydantic.error_wrappers.ValidationError):
# We don't yet support args_schema inference for partial functions
# so want to make sure we proactively raise an error
Tool(
name="tool",
description="A tool",
func=partial(func, other_arg="foo"),
)
tool = Tool(
name="tool",
description="A tool",
func=partial(func, other_arg="foo"),
)
assert tool.run("bar") == "barfoo"
def test_empty_args_decorator() -> None: