improve tools (#6062)

This commit is contained in:
Harrison Chase 2023-06-12 22:19:03 -07:00 committed by GitHub
parent 5b6bbf4ab2
commit ec1a2adf9c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 75 additions and 14 deletions

View File

@ -68,18 +68,14 @@ def _create_subset_model(
name: str, model: BaseModel, field_names: list
) -> Type[BaseModel]:
"""Create a pydantic model with only a subset of model's fields."""
fields = {
field_name: (
model.__fields__[field_name].type_,
model.__fields__[field_name].default,
)
for field_name in field_names
if field_name in model.__fields__
}
fields = {}
for field_name in field_names:
field = model.__fields__[field_name]
fields[field_name] = (field.type_, field.field_info)
return create_model(name, **fields) # type: ignore
def get_filtered_args(
def _get_filtered_args(
inferred_model: Type[BaseModel],
func: Callable,
) -> dict:
@ -100,15 +96,22 @@ def create_schema_from_function(
model_name: str,
func: Callable,
) -> Type[BaseModel]:
"""Create a pydantic schema from a function's signature."""
"""Create a pydantic schema from a function's signature.
Args:
model_name: Name to assign to the generated pydandic schema
func: Function to generate the schema from
Returns:
A pydantic model with the same arguments as the function
"""
# https://docs.pydantic.dev/latest/usage/validation_decorator/
validated = validate_arguments(func, config=_SchemaConfig) # type: ignore
inferred_model = validated.model # type: ignore
if "run_manager" in inferred_model.__fields__:
del inferred_model.__fields__["run_manager"]
# Pydantic adds placeholder virtual fields we need to strip
filtered_args = get_filtered_args(inferred_model, func)
valid_properties = _get_filtered_args(inferred_model, func)
return _create_subset_model(
f"{model_name}Schema", inferred_model, list(filtered_args)
f"{model_name}Schema", inferred_model, list(valid_properties)
)
@ -534,6 +537,30 @@ class StructuredTool(BaseTool):
infer_schema: bool = True,
**kwargs: Any,
) -> StructuredTool:
"""Create tool from a given function.
A classmethod that helps to create a tool from a function.
Args:
func: The function from which to create a tool
name: The name of the tool. Defaults to the function name
description: The description of the tool. Defaults to the function docstring
return_direct: Whether to return the result directly or as a callback
args_schema: The schema of the tool's input arguments
infer_schema: Whether to infer the schema from the function's signature
**kwargs: Additional arguments to pass to the tool
Returns:
The tool
Examples:
... code-block:: python
def add(a: int, b: int) -> int:
\"\"\"Add two numbers\"\"\"
return a + b
tool = StructuredTool.from_function(add)
tool.run(1, 2) # 3
"""
name = name or func.__name__
description = description or func.__doc__
assert (

View File

@ -315,6 +315,39 @@ def test_tool_lambda_args_schema() -> None:
assert tool.args == expected_args
def test_structured_tool_from_function_docstring() -> None:
"""Test that structured tools can be created from functions."""
def foo(bar: int, baz: str) -> str:
"""Docstring
Args:
bar: int
baz: str
"""
raise NotImplementedError()
structured_tool = StructuredTool.from_function(foo)
assert structured_tool.name == "foo"
assert structured_tool.args == {
"bar": {"title": "Bar", "type": "integer"},
"baz": {"title": "Baz", "type": "string"},
}
assert structured_tool.args_schema.schema() == {
"properties": {
"bar": {"title": "Bar", "type": "integer"},
"baz": {"title": "Baz", "type": "string"},
},
"title": "fooSchemaSchema",
"type": "object",
"required": ["bar", "baz"],
}
prefix = "foo(bar: int, baz: str) -> str - "
assert foo.__doc__ is not None
assert structured_tool.description == prefix + foo.__doc__.strip()
def test_structured_tool_lambda_multi_args_schema() -> None:
"""Test args schema inference when the tool argument is a lambda function."""
tool = StructuredTool.from_function(
@ -577,12 +610,13 @@ def test_structured_tool_from_function() -> None:
}
assert structured_tool.args_schema.schema() == {
"title": "fooSchemaSchema",
"type": "object",
"properties": {
"bar": {"title": "Bar", "type": "integer"},
"baz": {"title": "Baz", "type": "string"},
},
"title": "fooSchemaSchema",
"type": "object",
"required": ["bar", "baz"],
}
prefix = "foo(bar: int, baz: str) -> str - "