Wfh/async tool (#9878)

Co-authored-by: Daniel Brenot <dbrenot@pelmorex.com>
Co-authored-by: Daniel <daniel.alexander.brenot@gmail.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
William FH 2023-08-29 15:37:41 -07:00 committed by GitHub
parent 7bba1d911b
commit d799963870
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 69 additions and 33 deletions

View File

@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import inspect
import warnings import warnings
from abc import abstractmethod from abc import abstractmethod
from functools import partial from functools import partial
@ -437,7 +438,7 @@ class Tool(BaseTool):
"""Tool that takes in function or coroutine directly.""" """Tool that takes in function or coroutine directly."""
description: str = "" description: str = ""
func: Callable[..., str] func: Optional[Callable[..., str]]
"""The function to run when the tool is called.""" """The function to run when the tool is called."""
coroutine: Optional[Callable[..., Awaitable[str]]] = None coroutine: Optional[Callable[..., Awaitable[str]]] = None
"""The asynchronous version of the function.""" """The asynchronous version of the function."""
@ -488,16 +489,18 @@ class Tool(BaseTool):
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> Any:
"""Use the tool.""" """Use the tool."""
new_argument_supported = signature(self.func).parameters.get("callbacks") if self.func:
return ( new_argument_supported = signature(self.func).parameters.get("callbacks")
self.func( return (
*args, self.func(
callbacks=run_manager.get_child() if run_manager else None, *args,
**kwargs, callbacks=run_manager.get_child() if run_manager else None,
**kwargs,
)
if new_argument_supported
else self.func(*args, **kwargs)
) )
if new_argument_supported raise NotImplementedError("Tool does not support sync")
else self.func(*args, **kwargs)
)
async def _arun( async def _arun(
self, self,
@ -523,7 +526,7 @@ class Tool(BaseTool):
# TODO: this is for backwards compatibility, remove in future # TODO: this is for backwards compatibility, remove in future
def __init__( def __init__(
self, name: str, func: Callable, description: str, **kwargs: Any self, name: str, func: Optional[Callable], description: str, **kwargs: Any
) -> None: ) -> None:
"""Initialize tool.""" """Initialize tool."""
super(Tool, self).__init__( super(Tool, self).__init__(
@ -533,17 +536,23 @@ class Tool(BaseTool):
@classmethod @classmethod
def from_function( def from_function(
cls, cls,
func: Callable, func: Optional[Callable],
name: str, # We keep these required to support backwards compatibility name: str, # We keep these required to support backwards compatibility
description: str, description: str,
return_direct: bool = False, return_direct: bool = False,
args_schema: Optional[Type[BaseModel]] = None, args_schema: Optional[Type[BaseModel]] = None,
coroutine: Optional[
Callable[..., Awaitable[Any]]
] = None, # This is last for compatibility, but should be after func
**kwargs: Any, **kwargs: Any,
) -> Tool: ) -> Tool:
"""Initialize tool from a function.""" """Initialize tool from a function."""
if func is None and coroutine is None:
raise ValueError("Function and/or coroutine must be provided")
return cls( return cls(
name=name, name=name,
func=func, func=func,
coroutine=coroutine,
description=description, description=description,
return_direct=return_direct, return_direct=return_direct,
args_schema=args_schema, args_schema=args_schema,
@ -557,7 +566,7 @@ class StructuredTool(BaseTool):
description: str = "" description: str = ""
args_schema: Type[BaseModel] = Field(..., description="The tool schema.") args_schema: Type[BaseModel] = Field(..., description="The tool schema.")
"""The input arguments' schema.""" """The input arguments' schema."""
func: Callable[..., Any] func: Optional[Callable[..., Any]]
"""The function to run when the tool is called.""" """The function to run when the tool is called."""
coroutine: Optional[Callable[..., Awaitable[Any]]] = None coroutine: Optional[Callable[..., Awaitable[Any]]] = None
"""The asynchronous version of the function.""" """The asynchronous version of the function."""
@ -592,16 +601,18 @@ class StructuredTool(BaseTool):
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> Any:
"""Use the tool.""" """Use the tool."""
new_argument_supported = signature(self.func).parameters.get("callbacks") if self.func:
return ( new_argument_supported = signature(self.func).parameters.get("callbacks")
self.func( return (
*args, self.func(
callbacks=run_manager.get_child() if run_manager else None, *args,
**kwargs, callbacks=run_manager.get_child() if run_manager else None,
**kwargs,
)
if new_argument_supported
else self.func(*args, **kwargs)
) )
if new_argument_supported raise NotImplementedError("Tool does not support sync")
else self.func(*args, **kwargs)
)
async def _arun( async def _arun(
self, self,
@ -628,7 +639,8 @@ class StructuredTool(BaseTool):
@classmethod @classmethod
def from_function( def from_function(
cls, cls,
func: Callable, func: Optional[Callable] = None,
coroutine: Optional[Callable[..., Awaitable[Any]]] = None,
name: Optional[str] = None, name: Optional[str] = None,
description: Optional[str] = None, description: Optional[str] = None,
return_direct: bool = False, return_direct: bool = False,
@ -642,6 +654,7 @@ class StructuredTool(BaseTool):
Args: Args:
func: The function from which to create a tool func: The function from which to create a tool
coroutine: The async function from which to create a tool
name: The name of the tool. Defaults to the function name name: The name of the tool. Defaults to the function name
description: The description of the tool. Defaults to the function docstring description: The description of the tool. Defaults to the function docstring
return_direct: Whether to return the result directly or as a callback return_direct: Whether to return the result directly or as a callback
@ -662,21 +675,31 @@ class StructuredTool(BaseTool):
tool = StructuredTool.from_function(add) tool = StructuredTool.from_function(add)
tool.run(1, 2) # 3 tool.run(1, 2) # 3
""" """
name = name or func.__name__
description = description or func.__doc__ if func is not None:
assert ( source_function = func
description is not None elif coroutine is not None:
), "Function must have a docstring if description not provided." source_function = coroutine
else:
raise ValueError("Function and/or coroutine must be provided")
name = name or source_function.__name__
description = description or source_function.__doc__
if description is None:
raise ValueError(
"Function must have a docstring if description not provided."
)
# Description example: # Description example:
# search_api(query: str) - Searches the API for the query. # search_api(query: str) - Searches the API for the query.
description = f"{name}{signature(func)} - {description.strip()}" sig = signature(source_function)
description = f"{name}{sig} - {description.strip()}"
_args_schema = args_schema _args_schema = args_schema
if _args_schema is None and infer_schema: if _args_schema is None and infer_schema:
_args_schema = create_schema_from_function(f"{name}Schema", func) _args_schema = create_schema_from_function(f"{name}Schema", source_function)
return cls( return cls(
name=name, name=name,
func=func, func=func,
coroutine=coroutine,
args_schema=_args_schema, args_schema=_args_schema,
description=description, description=description,
return_direct=return_direct, return_direct=return_direct,
@ -720,10 +743,18 @@ def tool(
""" """
def _make_with_name(tool_name: str) -> Callable: def _make_with_name(tool_name: str) -> Callable:
def _make_tool(func: Callable) -> BaseTool: def _make_tool(dec_func: Callable) -> BaseTool:
if inspect.iscoroutinefunction(dec_func):
coroutine = dec_func
func = None
else:
coroutine = None
func = dec_func
if infer_schema or args_schema is not None: if infer_schema or args_schema is not None:
return StructuredTool.from_function( return StructuredTool.from_function(
func, func,
coroutine,
name=tool_name, name=tool_name,
return_direct=return_direct, return_direct=return_direct,
args_schema=args_schema, args_schema=args_schema,
@ -731,12 +762,17 @@ def tool(
) )
# If someone doesn't want a schema applied, we must treat it as # If someone doesn't want a schema applied, we must treat it as
# a simple string->string function # a simple string->string function
assert func.__doc__ is not None, "Function must have a docstring" if func.__doc__ is None:
raise ValueError(
"Function must have a docstring if "
"description not provided and infer_schema is False."
)
return Tool( return Tool(
name=tool_name, name=tool_name,
func=func, func=func,
description=f"{tool_name} tool", description=f"{tool_name} tool",
return_direct=return_direct, return_direct=return_direct,
coroutine=coroutine,
) )
return _make_tool return _make_tool

View File

@ -546,7 +546,7 @@ def test_tool_with_kwargs() -> None:
def test_missing_docstring() -> None: def test_missing_docstring() -> None:
"""Test error is raised when docstring is missing.""" """Test error is raised when docstring is missing."""
# expect to throw a value error if there's no docstring # expect to throw a value error if there's no docstring
with pytest.raises(AssertionError, match="Function must have a docstring"): with pytest.raises(ValueError, match="Function must have a docstring"):
@tool @tool
def search_api(query: str) -> str: def search_api(query: str) -> str: