fix: remove callbacks arg from Tool and StructuredTool inferred schema (#6483)

Fixes #5456 

This PR removes the `callbacks` argument from a tool's schema when
creating a `Tool` or `StructuredTool` with the `from_function` method
and `infer_schema` is set to `True`. The `callbacks` argument is now
removed in the `create_schema_from_function` and `_get_filtered_args`
methods. As suggested by @vowelparrot, this fix provides a
straightforward solution that minimally affects the existing
implementation.

A test was added to verify that this change enables the expected use of
`Tool` and `StructuredTool` when using a `CallbackManager` and inferring
the tool's schema.

  - @hwchase17
This commit is contained in:
Alejandra De Luna 2023-06-23 04:48:27 -04:00 committed by GitHub
parent b4fe7f3a09
commit 980c865174
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 62 additions and 1 deletions

View File

@ -82,7 +82,7 @@ def _get_filtered_args(
"""Get the arguments from a function's signature."""
schema = inferred_model.schema()["properties"]
valid_keys = signature(func).parameters
return {k: schema[k] for k in valid_keys if k != "run_manager"}
return {k: schema[k] for k in valid_keys if k not in ("run_manager", "callbacks")}
class _SchemaConfig:
@ -108,6 +108,8 @@ def create_schema_from_function(
inferred_model = validated.model # type: ignore
if "run_manager" in inferred_model.__fields__:
del inferred_model.__fields__["run_manager"]
if "callbacks" in inferred_model.__fields__:
del inferred_model.__fields__["callbacks"]
# Pydantic adds placeholder virtual fields we need to strip
valid_properties = _get_filtered_args(inferred_model, func)
return _create_subset_model(

View File

@ -19,6 +19,7 @@ from langchain.tools.base import (
StructuredTool,
ToolException,
)
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
def test_unnamed_decorator() -> None:
@ -393,6 +394,64 @@ def test_empty_args_decorator() -> None:
assert empty_tool_input.run({}) == "the empty result"
def test_tool_from_function_with_run_manager() -> None:
"""Test run of tool when using run_manager."""
def foo(bar: str, callbacks: Optional[CallbackManagerForToolRun] = None) -> str:
"""Docstring
Args:
bar: str
"""
assert callbacks is not None
return "foo" + bar
handler = FakeCallbackHandler()
tool = Tool.from_function(foo, name="foo", description="Docstring")
assert tool.run(tool_input={"bar": "bar"}, run_manager=[handler]) == "foobar"
assert tool.run("baz", run_manager=[handler]) == "foobaz"
def test_structured_tool_from_function_with_run_manager() -> None:
"""Test args and schema of structured tool when using callbacks."""
def foo(
bar: int, baz: str, callbacks: Optional[CallbackManagerForToolRun] = None
) -> str:
"""Docstring
Args:
bar: int
baz: str
"""
assert callbacks is not None
return str(bar) + baz
handler = FakeCallbackHandler()
structured_tool = StructuredTool.from_function(foo)
assert structured_tool.args == {
"bar": {"title": "Bar", "type": "integer"},
"baz": {"title": "Baz", "type": "string"},
}
assert structured_tool.args_schema.schema() == {
"properties": {
"bar": {"title": "Bar", "type": "integer"},
"baz": {"title": "Baz", "type": "string"},
},
"title": "fooSchemaSchema",
"type": "object",
"required": ["bar", "baz"],
}
assert (
structured_tool.run(
tool_input={"bar": "10", "baz": "baz"}, run_manger=[handler]
)
== "10baz"
)
def test_named_tool_decorator() -> None:
"""Test functionality when arguments are provided as input to decorator."""