From ec1a2adf9ce4f310b53f90602ead72c1f2483f60 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Mon, 12 Jun 2023 22:19:03 -0700 Subject: [PATCH] improve tools (#6062) --- langchain/tools/base.py | 51 ++++++++++++++++++++++------- tests/unit_tests/tools/test_base.py | 38 +++++++++++++++++++-- 2 files changed, 75 insertions(+), 14 deletions(-) diff --git a/langchain/tools/base.py b/langchain/tools/base.py index ea69731b67..ca3bebb041 100644 --- a/langchain/tools/base.py +++ b/langchain/tools/base.py @@ -68,18 +68,14 @@ def _create_subset_model( name: str, model: BaseModel, field_names: list ) -> Type[BaseModel]: """Create a pydantic model with only a subset of model's fields.""" - fields = { - field_name: ( - model.__fields__[field_name].type_, - model.__fields__[field_name].default, - ) - for field_name in field_names - if field_name in model.__fields__ - } + fields = {} + for field_name in field_names: + field = model.__fields__[field_name] + fields[field_name] = (field.type_, field.field_info) return create_model(name, **fields) # type: ignore -def get_filtered_args( +def _get_filtered_args( inferred_model: Type[BaseModel], func: Callable, ) -> dict: @@ -100,15 +96,22 @@ def create_schema_from_function( model_name: str, func: Callable, ) -> Type[BaseModel]: - """Create a pydantic schema from a function's signature.""" + """Create a pydantic schema from a function's signature. + Args: + model_name: Name to assign to the generated pydandic schema + func: Function to generate the schema from + Returns: + A pydantic model with the same arguments as the function + """ + # https://docs.pydantic.dev/latest/usage/validation_decorator/ validated = validate_arguments(func, config=_SchemaConfig) # type: ignore inferred_model = validated.model # type: ignore if "run_manager" in inferred_model.__fields__: del inferred_model.__fields__["run_manager"] # Pydantic adds placeholder virtual fields we need to strip - filtered_args = get_filtered_args(inferred_model, func) + valid_properties = _get_filtered_args(inferred_model, func) return _create_subset_model( - f"{model_name}Schema", inferred_model, list(filtered_args) + f"{model_name}Schema", inferred_model, list(valid_properties) ) @@ -534,6 +537,30 @@ class StructuredTool(BaseTool): infer_schema: bool = True, **kwargs: Any, ) -> StructuredTool: + """Create tool from a given function. + + A classmethod that helps to create a tool from a function. + + Args: + func: The function from which to create a tool + name: The name of the tool. Defaults to the function name + description: The description of the tool. Defaults to the function docstring + return_direct: Whether to return the result directly or as a callback + args_schema: The schema of the tool's input arguments + infer_schema: Whether to infer the schema from the function's signature + **kwargs: Additional arguments to pass to the tool + + Returns: + The tool + + Examples: + ... code-block:: python + def add(a: int, b: int) -> int: + \"\"\"Add two numbers\"\"\" + return a + b + tool = StructuredTool.from_function(add) + tool.run(1, 2) # 3 + """ name = name or func.__name__ description = description or func.__doc__ assert ( diff --git a/tests/unit_tests/tools/test_base.py b/tests/unit_tests/tools/test_base.py index cea017d79d..9f05faa0c3 100644 --- a/tests/unit_tests/tools/test_base.py +++ b/tests/unit_tests/tools/test_base.py @@ -315,6 +315,39 @@ def test_tool_lambda_args_schema() -> None: assert tool.args == expected_args +def test_structured_tool_from_function_docstring() -> None: + """Test that structured tools can be created from functions.""" + + def foo(bar: int, baz: str) -> str: + """Docstring + Args: + bar: int + baz: str + """ + raise NotImplementedError() + + structured_tool = StructuredTool.from_function(foo) + assert structured_tool.name == "foo" + assert structured_tool.args == { + "bar": {"title": "Bar", "type": "integer"}, + "baz": {"title": "Baz", "type": "string"}, + } + + assert structured_tool.args_schema.schema() == { + "properties": { + "bar": {"title": "Bar", "type": "integer"}, + "baz": {"title": "Baz", "type": "string"}, + }, + "title": "fooSchemaSchema", + "type": "object", + "required": ["bar", "baz"], + } + + prefix = "foo(bar: int, baz: str) -> str - " + assert foo.__doc__ is not None + assert structured_tool.description == prefix + foo.__doc__.strip() + + def test_structured_tool_lambda_multi_args_schema() -> None: """Test args schema inference when the tool argument is a lambda function.""" tool = StructuredTool.from_function( @@ -577,12 +610,13 @@ def test_structured_tool_from_function() -> None: } assert structured_tool.args_schema.schema() == { + "title": "fooSchemaSchema", + "type": "object", "properties": { "bar": {"title": "Bar", "type": "integer"}, "baz": {"title": "Baz", "type": "string"}, }, - "title": "fooSchemaSchema", - "type": "object", + "required": ["bar", "baz"], } prefix = "foo(bar: int, baz: str) -> str - "