Add tests for structured tools created from partial functions with docs

This commit is contained in:
zhanglei 2024-10-21 23:32:10 +08:00
parent c06ad4a6f6
commit 4ce578cbe2

View File

@ -483,6 +483,82 @@ def test_structured_tool_lambda_multi_args_schema() -> None:
assert tool.args == expected_args
def test_structured_tool_from_function_partial_docstring() -> None:
"""Test that structured tools can be created from partial functions."""
def foo(bar: int, baz: str) -> str:
"""Docstring
Args:
bar: the bar value
baz: the baz value
"""
raise NotImplementedError
structured_tool = StructuredTool.from_function(func=partial(foo, baz="foo"))
assert structured_tool.name == "foo"
assert structured_tool.args == {
"bar": {"title": "Bar", "type": "integer"},
"baz": {"title": "Baz", "type": "string"},
}
assert _schema(structured_tool.args_schema) == {
"properties": {
"bar": {"title": "Bar", "type": "integer"},
"baz": {"title": "Baz", "type": "string"},
},
"description": inspect.getdoc(foo),
"title": "foo",
"type": "object",
"required": ["bar", "baz"],
}
assert foo.__doc__ is not None
assert structured_tool.description == textwrap.dedent(foo.__doc__.strip())
def test_structured_tool_from_function_partial_docstring_complex_args() -> None:
"""Test that structured tools can be created from partial functions."""
def foo(bar: int, baz: list[str]) -> str:
"""Docstring
Args:
bar: int
baz: List[str]
"""
raise NotImplementedError
structured_tool = StructuredTool.from_function(func=partial(foo, baz="foo"))
assert structured_tool.name == "foo"
assert structured_tool.args == {
"bar": {"title": "Bar", "type": "integer"},
"baz": {
"title": "Baz",
"type": "array",
"items": {"type": "string"},
},
}
assert _schema(structured_tool.args_schema) == {
"properties": {
"bar": {"title": "Bar", "type": "integer"},
"baz": {
"title": "Baz",
"type": "array",
"items": {"type": "string"},
},
},
"description": inspect.getdoc(foo),
"title": "foo",
"type": "object",
"required": ["bar", "baz"],
}
assert foo.__doc__ is not None
assert structured_tool.description == textwrap.dedent(foo.__doc__).strip()
def test_structured_tool_from_function_partial() -> None:
"""Test that structured tools can be created from a partial function."""