mirror of
https://github.com/hwchase17/langchain
synced 2024-11-18 09:25:54 +00:00
fix: remove callbacks arg from Tool and StructuredTool inferred schema (#6483)
Fixes #5456 This PR removes the `callbacks` argument from a tool's schema when creating a `Tool` or `StructuredTool` with the `from_function` method and `infer_schema` is set to `True`. The `callbacks` argument is now removed in the `create_schema_from_function` and `_get_filtered_args` methods. As suggested by @vowelparrot, this fix provides a straightforward solution that minimally affects the existing implementation. A test was added to verify that this change enables the expected use of `Tool` and `StructuredTool` when using a `CallbackManager` and inferring the tool's schema. - @hwchase17
This commit is contained in:
parent
b4fe7f3a09
commit
980c865174
@ -82,7 +82,7 @@ def _get_filtered_args(
|
|||||||
"""Get the arguments from a function's signature."""
|
"""Get the arguments from a function's signature."""
|
||||||
schema = inferred_model.schema()["properties"]
|
schema = inferred_model.schema()["properties"]
|
||||||
valid_keys = signature(func).parameters
|
valid_keys = signature(func).parameters
|
||||||
return {k: schema[k] for k in valid_keys if k != "run_manager"}
|
return {k: schema[k] for k in valid_keys if k not in ("run_manager", "callbacks")}
|
||||||
|
|
||||||
|
|
||||||
class _SchemaConfig:
|
class _SchemaConfig:
|
||||||
@ -108,6 +108,8 @@ def create_schema_from_function(
|
|||||||
inferred_model = validated.model # type: ignore
|
inferred_model = validated.model # type: ignore
|
||||||
if "run_manager" in inferred_model.__fields__:
|
if "run_manager" in inferred_model.__fields__:
|
||||||
del inferred_model.__fields__["run_manager"]
|
del inferred_model.__fields__["run_manager"]
|
||||||
|
if "callbacks" in inferred_model.__fields__:
|
||||||
|
del inferred_model.__fields__["callbacks"]
|
||||||
# Pydantic adds placeholder virtual fields we need to strip
|
# Pydantic adds placeholder virtual fields we need to strip
|
||||||
valid_properties = _get_filtered_args(inferred_model, func)
|
valid_properties = _get_filtered_args(inferred_model, func)
|
||||||
return _create_subset_model(
|
return _create_subset_model(
|
||||||
|
@ -19,6 +19,7 @@ from langchain.tools.base import (
|
|||||||
StructuredTool,
|
StructuredTool,
|
||||||
ToolException,
|
ToolException,
|
||||||
)
|
)
|
||||||
|
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
|
||||||
|
|
||||||
|
|
||||||
def test_unnamed_decorator() -> None:
|
def test_unnamed_decorator() -> None:
|
||||||
@ -393,6 +394,64 @@ def test_empty_args_decorator() -> None:
|
|||||||
assert empty_tool_input.run({}) == "the empty result"
|
assert empty_tool_input.run({}) == "the empty result"
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_from_function_with_run_manager() -> None:
|
||||||
|
"""Test run of tool when using run_manager."""
|
||||||
|
|
||||||
|
def foo(bar: str, callbacks: Optional[CallbackManagerForToolRun] = None) -> str:
|
||||||
|
"""Docstring
|
||||||
|
Args:
|
||||||
|
bar: str
|
||||||
|
"""
|
||||||
|
assert callbacks is not None
|
||||||
|
return "foo" + bar
|
||||||
|
|
||||||
|
handler = FakeCallbackHandler()
|
||||||
|
tool = Tool.from_function(foo, name="foo", description="Docstring")
|
||||||
|
|
||||||
|
assert tool.run(tool_input={"bar": "bar"}, run_manager=[handler]) == "foobar"
|
||||||
|
assert tool.run("baz", run_manager=[handler]) == "foobaz"
|
||||||
|
|
||||||
|
|
||||||
|
def test_structured_tool_from_function_with_run_manager() -> None:
|
||||||
|
"""Test args and schema of structured tool when using callbacks."""
|
||||||
|
|
||||||
|
def foo(
|
||||||
|
bar: int, baz: str, callbacks: Optional[CallbackManagerForToolRun] = None
|
||||||
|
) -> str:
|
||||||
|
"""Docstring
|
||||||
|
Args:
|
||||||
|
bar: int
|
||||||
|
baz: str
|
||||||
|
"""
|
||||||
|
assert callbacks is not None
|
||||||
|
return str(bar) + baz
|
||||||
|
|
||||||
|
handler = FakeCallbackHandler()
|
||||||
|
structured_tool = StructuredTool.from_function(foo)
|
||||||
|
|
||||||
|
assert structured_tool.args == {
|
||||||
|
"bar": {"title": "Bar", "type": "integer"},
|
||||||
|
"baz": {"title": "Baz", "type": "string"},
|
||||||
|
}
|
||||||
|
|
||||||
|
assert structured_tool.args_schema.schema() == {
|
||||||
|
"properties": {
|
||||||
|
"bar": {"title": "Bar", "type": "integer"},
|
||||||
|
"baz": {"title": "Baz", "type": "string"},
|
||||||
|
},
|
||||||
|
"title": "fooSchemaSchema",
|
||||||
|
"type": "object",
|
||||||
|
"required": ["bar", "baz"],
|
||||||
|
}
|
||||||
|
|
||||||
|
assert (
|
||||||
|
structured_tool.run(
|
||||||
|
tool_input={"bar": "10", "baz": "baz"}, run_manger=[handler]
|
||||||
|
)
|
||||||
|
== "10baz"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_named_tool_decorator() -> None:
|
def test_named_tool_decorator() -> None:
|
||||||
"""Test functionality when arguments are provided as input to decorator."""
|
"""Test functionality when arguments are provided as input to decorator."""
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user