From d642609a23219b1037f84492c2bc56777e90397a Mon Sep 17 00:00:00 2001 From: "Jason B. Koh" Date: Thu, 6 Jul 2023 14:22:09 -0700 Subject: [PATCH] Fix: Recognize `List` at `from_function` (#7178) - Description: pydantic's `ModelField.type_` only exposes the native data type but not complex type hints like `List`. Thus, generating a Tool with `from_function` through function signature produces incorrect argument schemas (e.g., `str` instead of `List[str]`) - Issue: N/A - Dependencies: N/A - Tag maintainer: @hinthornw - Twitter handle: `mapped` All the unittest (with an additional one in this PR) passed, though I didn't try integration tests... --- langchain/tools/base.py | 2 +- tests/unit_tests/tools/test_base.py | 35 ++++++++++++++++++++++++++++- 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/langchain/tools/base.py b/langchain/tools/base.py index c41462ae2b..f39132efbd 100644 --- a/langchain/tools/base.py +++ b/langchain/tools/base.py @@ -71,7 +71,7 @@ def _create_subset_model( fields = {} for field_name in field_names: field = model.__fields__[field_name] - fields[field_name] = (field.type_, field.field_info) + fields[field_name] = (field.outer_type_, field.field_info) return create_model(name, **fields) # type: ignore diff --git a/tests/unit_tests/tools/test_base.py b/tests/unit_tests/tools/test_base.py index eadfbcf97c..0d6a62f416 100644 --- a/tests/unit_tests/tools/test_base.py +++ b/tests/unit_tests/tools/test_base.py @@ -3,7 +3,7 @@ import json from datetime import datetime from enum import Enum from functools import partial -from typing import Any, Optional, Type, Union +from typing import Any, List, Optional, Type, Union import pytest from pydantic import BaseModel @@ -349,6 +349,39 @@ def test_structured_tool_from_function_docstring() -> None: assert structured_tool.description == prefix + foo.__doc__.strip() +def test_structured_tool_from_function_docstring_complex_args() -> None: + """Test that structured tools can be created from functions.""" + + def foo(bar: int, baz: List[str]) -> str: + """Docstring + Args: + bar: int + baz: List[str] + """ + raise NotImplementedError() + + structured_tool = StructuredTool.from_function(foo) + assert structured_tool.name == "foo" + assert structured_tool.args == { + "bar": {"title": "Bar", "type": "integer"}, + "baz": {"title": "Baz", "type": "array", "items": {"type": "string"}}, + } + + assert structured_tool.args_schema.schema() == { + "properties": { + "bar": {"title": "Bar", "type": "integer"}, + "baz": {"title": "Baz", "type": "array", "items": {"type": "string"}}, + }, + "title": "fooSchemaSchema", + "type": "object", + "required": ["bar", "baz"], + } + + prefix = "foo(bar: int, baz: List[str]) -> str - " + assert foo.__doc__ is not None + assert structured_tool.description == prefix + foo.__doc__.strip() + + def test_structured_tool_lambda_multi_args_schema() -> None: """Test args schema inference when the tool argument is a lambda function.""" tool = StructuredTool.from_function(