Re-Permit Partials in `Tool` (#4058)

Resolved issue #4053

Now that StructuredTool is a separate class, this constraint is no
longer needed.

Added/updated a unit test
fix_agent_callbacks
Zander Chase 1 year ago committed by GitHub
parent 7e967aa4d5
commit afa9d1292b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -3,7 +3,6 @@ from __future__ import annotations
import warnings import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from functools import partial
from inspect import signature from inspect import signature
from typing import Any, Awaitable, Callable, Dict, Optional, Tuple, Type, Union from typing import Any, Awaitable, Callable, Dict, Optional, Tuple, Type, Union
@ -14,7 +13,6 @@ from pydantic import (
create_model, create_model,
root_validator, root_validator,
validate_arguments, validate_arguments,
validator,
) )
from pydantic.main import ModelMetaclass from pydantic.main import ModelMetaclass
@ -309,13 +307,6 @@ class Tool(BaseTool):
coroutine: Optional[Callable[..., Awaitable[str]]] = None coroutine: Optional[Callable[..., Awaitable[str]]] = None
"""The asynchronous version of the function.""" """The asynchronous version of the function."""
@validator("func", pre=True, always=True)
def validate_func_not_partial(cls, func: Callable) -> Callable:
"""Check that the function is not a partial."""
if isinstance(func, partial):
raise ValueError("Partial functions not yet supported in tools.")
return func
@property @property
def args(self) -> dict: def args(self) -> dict:
"""The tool's input arguments.""" """The tool's input arguments."""

@ -4,7 +4,6 @@ from functools import partial
from typing import Any, Optional, Type, Union from typing import Any, Optional, Type, Union
from unittest.mock import MagicMock from unittest.mock import MagicMock
import pydantic
import pytest import pytest
from pydantic import BaseModel from pydantic import BaseModel
@ -252,14 +251,12 @@ def test_tool_partial_function_args_schema() -> None:
def func(tool_input: str, other_arg: str) -> str: def func(tool_input: str, other_arg: str) -> str:
return tool_input + other_arg return tool_input + other_arg
with pytest.raises(pydantic.error_wrappers.ValidationError): tool = Tool(
# We don't yet support args_schema inference for partial functions name="tool",
# so want to make sure we proactively raise an error description="A tool",
Tool( func=partial(func, other_arg="foo"),
name="tool", )
description="A tool", assert tool.run("bar") == "barfoo"
func=partial(func, other_arg="foo"),
)
def test_empty_args_decorator() -> None: def test_empty_args_decorator() -> None:

Loading…
Cancel
Save