mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
improve tools (#6062)
This commit is contained in:
parent
5b6bbf4ab2
commit
ec1a2adf9c
@ -68,18 +68,14 @@ def _create_subset_model(
|
||||
name: str, model: BaseModel, field_names: list
|
||||
) -> Type[BaseModel]:
|
||||
"""Create a pydantic model with only a subset of model's fields."""
|
||||
fields = {
|
||||
field_name: (
|
||||
model.__fields__[field_name].type_,
|
||||
model.__fields__[field_name].default,
|
||||
)
|
||||
for field_name in field_names
|
||||
if field_name in model.__fields__
|
||||
}
|
||||
fields = {}
|
||||
for field_name in field_names:
|
||||
field = model.__fields__[field_name]
|
||||
fields[field_name] = (field.type_, field.field_info)
|
||||
return create_model(name, **fields) # type: ignore
|
||||
|
||||
|
||||
def get_filtered_args(
|
||||
def _get_filtered_args(
|
||||
inferred_model: Type[BaseModel],
|
||||
func: Callable,
|
||||
) -> dict:
|
||||
@ -100,15 +96,22 @@ def create_schema_from_function(
|
||||
model_name: str,
|
||||
func: Callable,
|
||||
) -> Type[BaseModel]:
|
||||
"""Create a pydantic schema from a function's signature."""
|
||||
"""Create a pydantic schema from a function's signature.
|
||||
Args:
|
||||
model_name: Name to assign to the generated pydandic schema
|
||||
func: Function to generate the schema from
|
||||
Returns:
|
||||
A pydantic model with the same arguments as the function
|
||||
"""
|
||||
# https://docs.pydantic.dev/latest/usage/validation_decorator/
|
||||
validated = validate_arguments(func, config=_SchemaConfig) # type: ignore
|
||||
inferred_model = validated.model # type: ignore
|
||||
if "run_manager" in inferred_model.__fields__:
|
||||
del inferred_model.__fields__["run_manager"]
|
||||
# Pydantic adds placeholder virtual fields we need to strip
|
||||
filtered_args = get_filtered_args(inferred_model, func)
|
||||
valid_properties = _get_filtered_args(inferred_model, func)
|
||||
return _create_subset_model(
|
||||
f"{model_name}Schema", inferred_model, list(filtered_args)
|
||||
f"{model_name}Schema", inferred_model, list(valid_properties)
|
||||
)
|
||||
|
||||
|
||||
@ -534,6 +537,30 @@ class StructuredTool(BaseTool):
|
||||
infer_schema: bool = True,
|
||||
**kwargs: Any,
|
||||
) -> StructuredTool:
|
||||
"""Create tool from a given function.
|
||||
|
||||
A classmethod that helps to create a tool from a function.
|
||||
|
||||
Args:
|
||||
func: The function from which to create a tool
|
||||
name: The name of the tool. Defaults to the function name
|
||||
description: The description of the tool. Defaults to the function docstring
|
||||
return_direct: Whether to return the result directly or as a callback
|
||||
args_schema: The schema of the tool's input arguments
|
||||
infer_schema: Whether to infer the schema from the function's signature
|
||||
**kwargs: Additional arguments to pass to the tool
|
||||
|
||||
Returns:
|
||||
The tool
|
||||
|
||||
Examples:
|
||||
... code-block:: python
|
||||
def add(a: int, b: int) -> int:
|
||||
\"\"\"Add two numbers\"\"\"
|
||||
return a + b
|
||||
tool = StructuredTool.from_function(add)
|
||||
tool.run(1, 2) # 3
|
||||
"""
|
||||
name = name or func.__name__
|
||||
description = description or func.__doc__
|
||||
assert (
|
||||
|
@ -315,6 +315,39 @@ def test_tool_lambda_args_schema() -> None:
|
||||
assert tool.args == expected_args
|
||||
|
||||
|
||||
def test_structured_tool_from_function_docstring() -> None:
|
||||
"""Test that structured tools can be created from functions."""
|
||||
|
||||
def foo(bar: int, baz: str) -> str:
|
||||
"""Docstring
|
||||
Args:
|
||||
bar: int
|
||||
baz: str
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
structured_tool = StructuredTool.from_function(foo)
|
||||
assert structured_tool.name == "foo"
|
||||
assert structured_tool.args == {
|
||||
"bar": {"title": "Bar", "type": "integer"},
|
||||
"baz": {"title": "Baz", "type": "string"},
|
||||
}
|
||||
|
||||
assert structured_tool.args_schema.schema() == {
|
||||
"properties": {
|
||||
"bar": {"title": "Bar", "type": "integer"},
|
||||
"baz": {"title": "Baz", "type": "string"},
|
||||
},
|
||||
"title": "fooSchemaSchema",
|
||||
"type": "object",
|
||||
"required": ["bar", "baz"],
|
||||
}
|
||||
|
||||
prefix = "foo(bar: int, baz: str) -> str - "
|
||||
assert foo.__doc__ is not None
|
||||
assert structured_tool.description == prefix + foo.__doc__.strip()
|
||||
|
||||
|
||||
def test_structured_tool_lambda_multi_args_schema() -> None:
|
||||
"""Test args schema inference when the tool argument is a lambda function."""
|
||||
tool = StructuredTool.from_function(
|
||||
@ -577,12 +610,13 @@ def test_structured_tool_from_function() -> None:
|
||||
}
|
||||
|
||||
assert structured_tool.args_schema.schema() == {
|
||||
"title": "fooSchemaSchema",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"bar": {"title": "Bar", "type": "integer"},
|
||||
"baz": {"title": "Baz", "type": "string"},
|
||||
},
|
||||
"title": "fooSchemaSchema",
|
||||
"type": "object",
|
||||
"required": ["bar", "baz"],
|
||||
}
|
||||
|
||||
prefix = "foo(bar: int, baz: str) -> str - "
|
||||
|
Loading…
Reference in New Issue
Block a user