forked from Archives/langchain
Re-Permit Partials in Tool
(#4058)
Resolved issue #4053 Now that StructuredTool is a separate class, this constraint is no longer needed. Added/updated a unit test
This commit is contained in:
parent
7e967aa4d5
commit
afa9d1292b
@ -3,7 +3,6 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import warnings
|
import warnings
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from functools import partial
|
|
||||||
from inspect import signature
|
from inspect import signature
|
||||||
from typing import Any, Awaitable, Callable, Dict, Optional, Tuple, Type, Union
|
from typing import Any, Awaitable, Callable, Dict, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
@ -14,7 +13,6 @@ from pydantic import (
|
|||||||
create_model,
|
create_model,
|
||||||
root_validator,
|
root_validator,
|
||||||
validate_arguments,
|
validate_arguments,
|
||||||
validator,
|
|
||||||
)
|
)
|
||||||
from pydantic.main import ModelMetaclass
|
from pydantic.main import ModelMetaclass
|
||||||
|
|
||||||
@ -309,13 +307,6 @@ class Tool(BaseTool):
|
|||||||
coroutine: Optional[Callable[..., Awaitable[str]]] = None
|
coroutine: Optional[Callable[..., Awaitable[str]]] = None
|
||||||
"""The asynchronous version of the function."""
|
"""The asynchronous version of the function."""
|
||||||
|
|
||||||
@validator("func", pre=True, always=True)
|
|
||||||
def validate_func_not_partial(cls, func: Callable) -> Callable:
|
|
||||||
"""Check that the function is not a partial."""
|
|
||||||
if isinstance(func, partial):
|
|
||||||
raise ValueError("Partial functions not yet supported in tools.")
|
|
||||||
return func
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def args(self) -> dict:
|
def args(self) -> dict:
|
||||||
"""The tool's input arguments."""
|
"""The tool's input arguments."""
|
||||||
|
@ -4,7 +4,6 @@ from functools import partial
|
|||||||
from typing import Any, Optional, Type, Union
|
from typing import Any, Optional, Type, Union
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
import pydantic
|
|
||||||
import pytest
|
import pytest
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
@ -252,14 +251,12 @@ def test_tool_partial_function_args_schema() -> None:
|
|||||||
def func(tool_input: str, other_arg: str) -> str:
|
def func(tool_input: str, other_arg: str) -> str:
|
||||||
return tool_input + other_arg
|
return tool_input + other_arg
|
||||||
|
|
||||||
with pytest.raises(pydantic.error_wrappers.ValidationError):
|
tool = Tool(
|
||||||
# We don't yet support args_schema inference for partial functions
|
name="tool",
|
||||||
# so want to make sure we proactively raise an error
|
description="A tool",
|
||||||
Tool(
|
func=partial(func, other_arg="foo"),
|
||||||
name="tool",
|
)
|
||||||
description="A tool",
|
assert tool.run("bar") == "barfoo"
|
||||||
func=partial(func, other_arg="foo"),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_empty_args_decorator() -> None:
|
def test_empty_args_decorator() -> None:
|
||||||
|
Loading…
Reference in New Issue
Block a user