diff --git a/langchain/tools/base.py b/langchain/tools/base.py index c3fab242c9..a8d4c88863 100644 --- a/langchain/tools/base.py +++ b/langchain/tools/base.py @@ -82,7 +82,7 @@ def _get_filtered_args( """Get the arguments from a function's signature.""" schema = inferred_model.schema()["properties"] valid_keys = signature(func).parameters - return {k: schema[k] for k in valid_keys if k != "run_manager"} + return {k: schema[k] for k in valid_keys if k not in ("run_manager", "callbacks")} class _SchemaConfig: @@ -108,6 +108,8 @@ def create_schema_from_function( inferred_model = validated.model # type: ignore if "run_manager" in inferred_model.__fields__: del inferred_model.__fields__["run_manager"] + if "callbacks" in inferred_model.__fields__: + del inferred_model.__fields__["callbacks"] # Pydantic adds placeholder virtual fields we need to strip valid_properties = _get_filtered_args(inferred_model, func) return _create_subset_model( diff --git a/tests/unit_tests/tools/test_base.py b/tests/unit_tests/tools/test_base.py index e486fcc9d0..eadfbcf97c 100644 --- a/tests/unit_tests/tools/test_base.py +++ b/tests/unit_tests/tools/test_base.py @@ -19,6 +19,7 @@ from langchain.tools.base import ( StructuredTool, ToolException, ) +from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler def test_unnamed_decorator() -> None: @@ -393,6 +394,64 @@ def test_empty_args_decorator() -> None: assert empty_tool_input.run({}) == "the empty result" +def test_tool_from_function_with_run_manager() -> None: + """Test run of tool when using run_manager.""" + + def foo(bar: str, callbacks: Optional[CallbackManagerForToolRun] = None) -> str: + """Docstring + Args: + bar: str + """ + assert callbacks is not None + return "foo" + bar + + handler = FakeCallbackHandler() + tool = Tool.from_function(foo, name="foo", description="Docstring") + + assert tool.run(tool_input={"bar": "bar"}, run_manager=[handler]) == "foobar" + assert tool.run("baz", run_manager=[handler]) == "foobaz" + + +def test_structured_tool_from_function_with_run_manager() -> None: + """Test args and schema of structured tool when using callbacks.""" + + def foo( + bar: int, baz: str, callbacks: Optional[CallbackManagerForToolRun] = None + ) -> str: + """Docstring + Args: + bar: int + baz: str + """ + assert callbacks is not None + return str(bar) + baz + + handler = FakeCallbackHandler() + structured_tool = StructuredTool.from_function(foo) + + assert structured_tool.args == { + "bar": {"title": "Bar", "type": "integer"}, + "baz": {"title": "Baz", "type": "string"}, + } + + assert structured_tool.args_schema.schema() == { + "properties": { + "bar": {"title": "Bar", "type": "integer"}, + "baz": {"title": "Baz", "type": "string"}, + }, + "title": "fooSchemaSchema", + "type": "object", + "required": ["bar", "baz"], + } + + assert ( + structured_tool.run( + tool_input={"bar": "10", "baz": "baz"}, run_manger=[handler] + ) + == "10baz" + ) + + def test_named_tool_decorator() -> None: """Test functionality when arguments are provided as input to decorator."""