Fix: Recognize `List` at `from_function` (#7178)

- Description: pydantic's `ModelField.type_` only exposes the native
data type but not complex type hints like `List`. Thus, generating a
Tool with `from_function` through function signature produces incorrect
argument schemas (e.g., `str` instead of `List[str]`)
  - Issue: N/A
  - Dependencies: N/A
  - Tag maintainer: @hinthornw
  - Twitter handle: `mapped`

All the unittest (with an additional one in this PR) passed, though I
didn't try integration tests...
pull/7306/head
Jason B. Koh 1 year ago committed by GitHub
parent ec10787bc7
commit d642609a23
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -71,7 +71,7 @@ def _create_subset_model(
fields = {}
for field_name in field_names:
field = model.__fields__[field_name]
fields[field_name] = (field.type_, field.field_info)
fields[field_name] = (field.outer_type_, field.field_info)
return create_model(name, **fields) # type: ignore

@ -3,7 +3,7 @@ import json
from datetime import datetime
from enum import Enum
from functools import partial
from typing import Any, Optional, Type, Union
from typing import Any, List, Optional, Type, Union
import pytest
from pydantic import BaseModel
@ -349,6 +349,39 @@ def test_structured_tool_from_function_docstring() -> None:
assert structured_tool.description == prefix + foo.__doc__.strip()
def test_structured_tool_from_function_docstring_complex_args() -> None:
"""Test that structured tools can be created from functions."""
def foo(bar: int, baz: List[str]) -> str:
"""Docstring
Args:
bar: int
baz: List[str]
"""
raise NotImplementedError()
structured_tool = StructuredTool.from_function(foo)
assert structured_tool.name == "foo"
assert structured_tool.args == {
"bar": {"title": "Bar", "type": "integer"},
"baz": {"title": "Baz", "type": "array", "items": {"type": "string"}},
}
assert structured_tool.args_schema.schema() == {
"properties": {
"bar": {"title": "Bar", "type": "integer"},
"baz": {"title": "Baz", "type": "array", "items": {"type": "string"}},
},
"title": "fooSchemaSchema",
"type": "object",
"required": ["bar", "baz"],
}
prefix = "foo(bar: int, baz: List[str]) -> str - "
assert foo.__doc__ is not None
assert structured_tool.description == prefix + foo.__doc__.strip()
def test_structured_tool_lambda_multi_args_schema() -> None:
"""Test args schema inference when the tool argument is a lambda function."""
tool = StructuredTool.from_function(

Loading…
Cancel
Save