Wfh/async tool (#9878)

Co-authored-by: Daniel Brenot <dbrenot@pelmorex.com>
Co-authored-by: Daniel <daniel.alexander.brenot@gmail.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
pull/9965/head
William FH 1 year ago committed by GitHub
parent 7bba1d911b
commit d799963870
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -2,6 +2,7 @@
from __future__ import annotations
import asyncio
import inspect
import warnings
from abc import abstractmethod
from functools import partial
@ -437,7 +438,7 @@ class Tool(BaseTool):
"""Tool that takes in function or coroutine directly."""
description: str = ""
func: Callable[..., str]
func: Optional[Callable[..., str]]
"""The function to run when the tool is called."""
coroutine: Optional[Callable[..., Awaitable[str]]] = None
"""The asynchronous version of the function."""
@ -488,16 +489,18 @@ class Tool(BaseTool):
**kwargs: Any,
) -> Any:
"""Use the tool."""
new_argument_supported = signature(self.func).parameters.get("callbacks")
return (
self.func(
*args,
callbacks=run_manager.get_child() if run_manager else None,
**kwargs,
if self.func:
new_argument_supported = signature(self.func).parameters.get("callbacks")
return (
self.func(
*args,
callbacks=run_manager.get_child() if run_manager else None,
**kwargs,
)
if new_argument_supported
else self.func(*args, **kwargs)
)
if new_argument_supported
else self.func(*args, **kwargs)
)
raise NotImplementedError("Tool does not support sync")
async def _arun(
self,
@ -523,7 +526,7 @@ class Tool(BaseTool):
# TODO: this is for backwards compatibility, remove in future
def __init__(
self, name: str, func: Callable, description: str, **kwargs: Any
self, name: str, func: Optional[Callable], description: str, **kwargs: Any
) -> None:
"""Initialize tool."""
super(Tool, self).__init__(
@ -533,17 +536,23 @@ class Tool(BaseTool):
@classmethod
def from_function(
cls,
func: Callable,
func: Optional[Callable],
name: str, # We keep these required to support backwards compatibility
description: str,
return_direct: bool = False,
args_schema: Optional[Type[BaseModel]] = None,
coroutine: Optional[
Callable[..., Awaitable[Any]]
] = None, # This is last for compatibility, but should be after func
**kwargs: Any,
) -> Tool:
"""Initialize tool from a function."""
if func is None and coroutine is None:
raise ValueError("Function and/or coroutine must be provided")
return cls(
name=name,
func=func,
coroutine=coroutine,
description=description,
return_direct=return_direct,
args_schema=args_schema,
@ -557,7 +566,7 @@ class StructuredTool(BaseTool):
description: str = ""
args_schema: Type[BaseModel] = Field(..., description="The tool schema.")
"""The input arguments' schema."""
func: Callable[..., Any]
func: Optional[Callable[..., Any]]
"""The function to run when the tool is called."""
coroutine: Optional[Callable[..., Awaitable[Any]]] = None
"""The asynchronous version of the function."""
@ -592,16 +601,18 @@ class StructuredTool(BaseTool):
**kwargs: Any,
) -> Any:
"""Use the tool."""
new_argument_supported = signature(self.func).parameters.get("callbacks")
return (
self.func(
*args,
callbacks=run_manager.get_child() if run_manager else None,
**kwargs,
if self.func:
new_argument_supported = signature(self.func).parameters.get("callbacks")
return (
self.func(
*args,
callbacks=run_manager.get_child() if run_manager else None,
**kwargs,
)
if new_argument_supported
else self.func(*args, **kwargs)
)
if new_argument_supported
else self.func(*args, **kwargs)
)
raise NotImplementedError("Tool does not support sync")
async def _arun(
self,
@ -628,7 +639,8 @@ class StructuredTool(BaseTool):
@classmethod
def from_function(
cls,
func: Callable,
func: Optional[Callable] = None,
coroutine: Optional[Callable[..., Awaitable[Any]]] = None,
name: Optional[str] = None,
description: Optional[str] = None,
return_direct: bool = False,
@ -642,6 +654,7 @@ class StructuredTool(BaseTool):
Args:
func: The function from which to create a tool
coroutine: The async function from which to create a tool
name: The name of the tool. Defaults to the function name
description: The description of the tool. Defaults to the function docstring
return_direct: Whether to return the result directly or as a callback
@ -662,21 +675,31 @@ class StructuredTool(BaseTool):
tool = StructuredTool.from_function(add)
tool.run(1, 2) # 3
"""
name = name or func.__name__
description = description or func.__doc__
assert (
description is not None
), "Function must have a docstring if description not provided."
if func is not None:
source_function = func
elif coroutine is not None:
source_function = coroutine
else:
raise ValueError("Function and/or coroutine must be provided")
name = name or source_function.__name__
description = description or source_function.__doc__
if description is None:
raise ValueError(
"Function must have a docstring if description not provided."
)
# Description example:
# search_api(query: str) - Searches the API for the query.
description = f"{name}{signature(func)} - {description.strip()}"
sig = signature(source_function)
description = f"{name}{sig} - {description.strip()}"
_args_schema = args_schema
if _args_schema is None and infer_schema:
_args_schema = create_schema_from_function(f"{name}Schema", func)
_args_schema = create_schema_from_function(f"{name}Schema", source_function)
return cls(
name=name,
func=func,
coroutine=coroutine,
args_schema=_args_schema,
description=description,
return_direct=return_direct,
@ -720,10 +743,18 @@ def tool(
"""
def _make_with_name(tool_name: str) -> Callable:
def _make_tool(func: Callable) -> BaseTool:
def _make_tool(dec_func: Callable) -> BaseTool:
if inspect.iscoroutinefunction(dec_func):
coroutine = dec_func
func = None
else:
coroutine = None
func = dec_func
if infer_schema or args_schema is not None:
return StructuredTool.from_function(
func,
coroutine,
name=tool_name,
return_direct=return_direct,
args_schema=args_schema,
@ -731,12 +762,17 @@ def tool(
)
# If someone doesn't want a schema applied, we must treat it as
# a simple string->string function
assert func.__doc__ is not None, "Function must have a docstring"
if func.__doc__ is None:
raise ValueError(
"Function must have a docstring if "
"description not provided and infer_schema is False."
)
return Tool(
name=tool_name,
func=func,
description=f"{tool_name} tool",
return_direct=return_direct,
coroutine=coroutine,
)
return _make_tool

@ -546,7 +546,7 @@ def test_tool_with_kwargs() -> None:
def test_missing_docstring() -> None:
"""Test error is raised when docstring is missing."""
# expect to throw a value error if there's no docstring
with pytest.raises(AssertionError, match="Function must have a docstring"):
with pytest.raises(ValueError, match="Function must have a docstring"):
@tool
def search_api(query: str) -> str:

Loading…
Cancel
Save