diff --git a/langchain/tools/base.py b/langchain/tools/base.py index c41462ae2b..f39132efbd 100644 --- a/langchain/tools/base.py +++ b/langchain/tools/base.py @@ -71,7 +71,7 @@ def _create_subset_model( fields = {} for field_name in field_names: field = model.__fields__[field_name] - fields[field_name] = (field.type_, field.field_info) + fields[field_name] = (field.outer_type_, field.field_info) return create_model(name, **fields) # type: ignore diff --git a/tests/unit_tests/tools/test_base.py b/tests/unit_tests/tools/test_base.py index eadfbcf97c..0d6a62f416 100644 --- a/tests/unit_tests/tools/test_base.py +++ b/tests/unit_tests/tools/test_base.py @@ -3,7 +3,7 @@ import json from datetime import datetime from enum import Enum from functools import partial -from typing import Any, Optional, Type, Union +from typing import Any, List, Optional, Type, Union import pytest from pydantic import BaseModel @@ -349,6 +349,39 @@ def test_structured_tool_from_function_docstring() -> None: assert structured_tool.description == prefix + foo.__doc__.strip() +def test_structured_tool_from_function_docstring_complex_args() -> None: + """Test that structured tools can be created from functions.""" + + def foo(bar: int, baz: List[str]) -> str: + """Docstring + Args: + bar: int + baz: List[str] + """ + raise NotImplementedError() + + structured_tool = StructuredTool.from_function(foo) + assert structured_tool.name == "foo" + assert structured_tool.args == { + "bar": {"title": "Bar", "type": "integer"}, + "baz": {"title": "Baz", "type": "array", "items": {"type": "string"}}, + } + + assert structured_tool.args_schema.schema() == { + "properties": { + "bar": {"title": "Bar", "type": "integer"}, + "baz": {"title": "Baz", "type": "array", "items": {"type": "string"}}, + }, + "title": "fooSchemaSchema", + "type": "object", + "required": ["bar", "baz"], + } + + prefix = "foo(bar: int, baz: List[str]) -> str - " + assert foo.__doc__ is not None + assert structured_tool.description == prefix + foo.__doc__.strip() + + def test_structured_tool_lambda_multi_args_schema() -> None: """Test args schema inference when the tool argument is a lambda function.""" tool = StructuredTool.from_function(