forked from Archives/langchain
Re-Permit Partials in Tool
(#4058)
Resolved issue #4053 Now that StructuredTool is a separate class, this constraint is no longer needed. Added/updated a unit test
This commit is contained in:
parent
7e967aa4d5
commit
afa9d1292b
@ -3,7 +3,6 @@ from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from functools import partial
|
||||
from inspect import signature
|
||||
from typing import Any, Awaitable, Callable, Dict, Optional, Tuple, Type, Union
|
||||
|
||||
@ -14,7 +13,6 @@ from pydantic import (
|
||||
create_model,
|
||||
root_validator,
|
||||
validate_arguments,
|
||||
validator,
|
||||
)
|
||||
from pydantic.main import ModelMetaclass
|
||||
|
||||
@ -309,13 +307,6 @@ class Tool(BaseTool):
|
||||
coroutine: Optional[Callable[..., Awaitable[str]]] = None
|
||||
"""The asynchronous version of the function."""
|
||||
|
||||
@validator("func", pre=True, always=True)
|
||||
def validate_func_not_partial(cls, func: Callable) -> Callable:
|
||||
"""Check that the function is not a partial."""
|
||||
if isinstance(func, partial):
|
||||
raise ValueError("Partial functions not yet supported in tools.")
|
||||
return func
|
||||
|
||||
@property
|
||||
def args(self) -> dict:
|
||||
"""The tool's input arguments."""
|
||||
|
@ -4,7 +4,6 @@ from functools import partial
|
||||
from typing import Any, Optional, Type, Union
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pydantic
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
@ -252,14 +251,12 @@ def test_tool_partial_function_args_schema() -> None:
|
||||
def func(tool_input: str, other_arg: str) -> str:
|
||||
return tool_input + other_arg
|
||||
|
||||
with pytest.raises(pydantic.error_wrappers.ValidationError):
|
||||
# We don't yet support args_schema inference for partial functions
|
||||
# so want to make sure we proactively raise an error
|
||||
Tool(
|
||||
tool = Tool(
|
||||
name="tool",
|
||||
description="A tool",
|
||||
func=partial(func, other_arg="foo"),
|
||||
)
|
||||
assert tool.run("bar") == "barfoo"
|
||||
|
||||
|
||||
def test_empty_args_decorator() -> None:
|
||||
|
Loading…
Reference in New Issue
Block a user