mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
Wfh/async tool (#9878)
Co-authored-by: Daniel Brenot <dbrenot@pelmorex.com> Co-authored-by: Daniel <daniel.alexander.brenot@gmail.com> Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
7bba1d911b
commit
d799963870
@ -2,6 +2,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import inspect
|
||||||
import warnings
|
import warnings
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from functools import partial
|
from functools import partial
|
||||||
@ -437,7 +438,7 @@ class Tool(BaseTool):
|
|||||||
"""Tool that takes in function or coroutine directly."""
|
"""Tool that takes in function or coroutine directly."""
|
||||||
|
|
||||||
description: str = ""
|
description: str = ""
|
||||||
func: Callable[..., str]
|
func: Optional[Callable[..., str]]
|
||||||
"""The function to run when the tool is called."""
|
"""The function to run when the tool is called."""
|
||||||
coroutine: Optional[Callable[..., Awaitable[str]]] = None
|
coroutine: Optional[Callable[..., Awaitable[str]]] = None
|
||||||
"""The asynchronous version of the function."""
|
"""The asynchronous version of the function."""
|
||||||
@ -488,16 +489,18 @@ class Tool(BaseTool):
|
|||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Use the tool."""
|
"""Use the tool."""
|
||||||
new_argument_supported = signature(self.func).parameters.get("callbacks")
|
if self.func:
|
||||||
return (
|
new_argument_supported = signature(self.func).parameters.get("callbacks")
|
||||||
self.func(
|
return (
|
||||||
*args,
|
self.func(
|
||||||
callbacks=run_manager.get_child() if run_manager else None,
|
*args,
|
||||||
**kwargs,
|
callbacks=run_manager.get_child() if run_manager else None,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
if new_argument_supported
|
||||||
|
else self.func(*args, **kwargs)
|
||||||
)
|
)
|
||||||
if new_argument_supported
|
raise NotImplementedError("Tool does not support sync")
|
||||||
else self.func(*args, **kwargs)
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _arun(
|
async def _arun(
|
||||||
self,
|
self,
|
||||||
@ -523,7 +526,7 @@ class Tool(BaseTool):
|
|||||||
|
|
||||||
# TODO: this is for backwards compatibility, remove in future
|
# TODO: this is for backwards compatibility, remove in future
|
||||||
def __init__(
|
def __init__(
|
||||||
self, name: str, func: Callable, description: str, **kwargs: Any
|
self, name: str, func: Optional[Callable], description: str, **kwargs: Any
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize tool."""
|
"""Initialize tool."""
|
||||||
super(Tool, self).__init__(
|
super(Tool, self).__init__(
|
||||||
@ -533,17 +536,23 @@ class Tool(BaseTool):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def from_function(
|
def from_function(
|
||||||
cls,
|
cls,
|
||||||
func: Callable,
|
func: Optional[Callable],
|
||||||
name: str, # We keep these required to support backwards compatibility
|
name: str, # We keep these required to support backwards compatibility
|
||||||
description: str,
|
description: str,
|
||||||
return_direct: bool = False,
|
return_direct: bool = False,
|
||||||
args_schema: Optional[Type[BaseModel]] = None,
|
args_schema: Optional[Type[BaseModel]] = None,
|
||||||
|
coroutine: Optional[
|
||||||
|
Callable[..., Awaitable[Any]]
|
||||||
|
] = None, # This is last for compatibility, but should be after func
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Tool:
|
) -> Tool:
|
||||||
"""Initialize tool from a function."""
|
"""Initialize tool from a function."""
|
||||||
|
if func is None and coroutine is None:
|
||||||
|
raise ValueError("Function and/or coroutine must be provided")
|
||||||
return cls(
|
return cls(
|
||||||
name=name,
|
name=name,
|
||||||
func=func,
|
func=func,
|
||||||
|
coroutine=coroutine,
|
||||||
description=description,
|
description=description,
|
||||||
return_direct=return_direct,
|
return_direct=return_direct,
|
||||||
args_schema=args_schema,
|
args_schema=args_schema,
|
||||||
@ -557,7 +566,7 @@ class StructuredTool(BaseTool):
|
|||||||
description: str = ""
|
description: str = ""
|
||||||
args_schema: Type[BaseModel] = Field(..., description="The tool schema.")
|
args_schema: Type[BaseModel] = Field(..., description="The tool schema.")
|
||||||
"""The input arguments' schema."""
|
"""The input arguments' schema."""
|
||||||
func: Callable[..., Any]
|
func: Optional[Callable[..., Any]]
|
||||||
"""The function to run when the tool is called."""
|
"""The function to run when the tool is called."""
|
||||||
coroutine: Optional[Callable[..., Awaitable[Any]]] = None
|
coroutine: Optional[Callable[..., Awaitable[Any]]] = None
|
||||||
"""The asynchronous version of the function."""
|
"""The asynchronous version of the function."""
|
||||||
@ -592,16 +601,18 @@ class StructuredTool(BaseTool):
|
|||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Use the tool."""
|
"""Use the tool."""
|
||||||
new_argument_supported = signature(self.func).parameters.get("callbacks")
|
if self.func:
|
||||||
return (
|
new_argument_supported = signature(self.func).parameters.get("callbacks")
|
||||||
self.func(
|
return (
|
||||||
*args,
|
self.func(
|
||||||
callbacks=run_manager.get_child() if run_manager else None,
|
*args,
|
||||||
**kwargs,
|
callbacks=run_manager.get_child() if run_manager else None,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
if new_argument_supported
|
||||||
|
else self.func(*args, **kwargs)
|
||||||
)
|
)
|
||||||
if new_argument_supported
|
raise NotImplementedError("Tool does not support sync")
|
||||||
else self.func(*args, **kwargs)
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _arun(
|
async def _arun(
|
||||||
self,
|
self,
|
||||||
@ -628,7 +639,8 @@ class StructuredTool(BaseTool):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def from_function(
|
def from_function(
|
||||||
cls,
|
cls,
|
||||||
func: Callable,
|
func: Optional[Callable] = None,
|
||||||
|
coroutine: Optional[Callable[..., Awaitable[Any]]] = None,
|
||||||
name: Optional[str] = None,
|
name: Optional[str] = None,
|
||||||
description: Optional[str] = None,
|
description: Optional[str] = None,
|
||||||
return_direct: bool = False,
|
return_direct: bool = False,
|
||||||
@ -642,6 +654,7 @@ class StructuredTool(BaseTool):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
func: The function from which to create a tool
|
func: The function from which to create a tool
|
||||||
|
coroutine: The async function from which to create a tool
|
||||||
name: The name of the tool. Defaults to the function name
|
name: The name of the tool. Defaults to the function name
|
||||||
description: The description of the tool. Defaults to the function docstring
|
description: The description of the tool. Defaults to the function docstring
|
||||||
return_direct: Whether to return the result directly or as a callback
|
return_direct: Whether to return the result directly or as a callback
|
||||||
@ -662,21 +675,31 @@ class StructuredTool(BaseTool):
|
|||||||
tool = StructuredTool.from_function(add)
|
tool = StructuredTool.from_function(add)
|
||||||
tool.run(1, 2) # 3
|
tool.run(1, 2) # 3
|
||||||
"""
|
"""
|
||||||
name = name or func.__name__
|
|
||||||
description = description or func.__doc__
|
if func is not None:
|
||||||
assert (
|
source_function = func
|
||||||
description is not None
|
elif coroutine is not None:
|
||||||
), "Function must have a docstring if description not provided."
|
source_function = coroutine
|
||||||
|
else:
|
||||||
|
raise ValueError("Function and/or coroutine must be provided")
|
||||||
|
name = name or source_function.__name__
|
||||||
|
description = description or source_function.__doc__
|
||||||
|
if description is None:
|
||||||
|
raise ValueError(
|
||||||
|
"Function must have a docstring if description not provided."
|
||||||
|
)
|
||||||
|
|
||||||
# Description example:
|
# Description example:
|
||||||
# search_api(query: str) - Searches the API for the query.
|
# search_api(query: str) - Searches the API for the query.
|
||||||
description = f"{name}{signature(func)} - {description.strip()}"
|
sig = signature(source_function)
|
||||||
|
description = f"{name}{sig} - {description.strip()}"
|
||||||
_args_schema = args_schema
|
_args_schema = args_schema
|
||||||
if _args_schema is None and infer_schema:
|
if _args_schema is None and infer_schema:
|
||||||
_args_schema = create_schema_from_function(f"{name}Schema", func)
|
_args_schema = create_schema_from_function(f"{name}Schema", source_function)
|
||||||
return cls(
|
return cls(
|
||||||
name=name,
|
name=name,
|
||||||
func=func,
|
func=func,
|
||||||
|
coroutine=coroutine,
|
||||||
args_schema=_args_schema,
|
args_schema=_args_schema,
|
||||||
description=description,
|
description=description,
|
||||||
return_direct=return_direct,
|
return_direct=return_direct,
|
||||||
@ -720,10 +743,18 @@ def tool(
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def _make_with_name(tool_name: str) -> Callable:
|
def _make_with_name(tool_name: str) -> Callable:
|
||||||
def _make_tool(func: Callable) -> BaseTool:
|
def _make_tool(dec_func: Callable) -> BaseTool:
|
||||||
|
if inspect.iscoroutinefunction(dec_func):
|
||||||
|
coroutine = dec_func
|
||||||
|
func = None
|
||||||
|
else:
|
||||||
|
coroutine = None
|
||||||
|
func = dec_func
|
||||||
|
|
||||||
if infer_schema or args_schema is not None:
|
if infer_schema or args_schema is not None:
|
||||||
return StructuredTool.from_function(
|
return StructuredTool.from_function(
|
||||||
func,
|
func,
|
||||||
|
coroutine,
|
||||||
name=tool_name,
|
name=tool_name,
|
||||||
return_direct=return_direct,
|
return_direct=return_direct,
|
||||||
args_schema=args_schema,
|
args_schema=args_schema,
|
||||||
@ -731,12 +762,17 @@ def tool(
|
|||||||
)
|
)
|
||||||
# If someone doesn't want a schema applied, we must treat it as
|
# If someone doesn't want a schema applied, we must treat it as
|
||||||
# a simple string->string function
|
# a simple string->string function
|
||||||
assert func.__doc__ is not None, "Function must have a docstring"
|
if func.__doc__ is None:
|
||||||
|
raise ValueError(
|
||||||
|
"Function must have a docstring if "
|
||||||
|
"description not provided and infer_schema is False."
|
||||||
|
)
|
||||||
return Tool(
|
return Tool(
|
||||||
name=tool_name,
|
name=tool_name,
|
||||||
func=func,
|
func=func,
|
||||||
description=f"{tool_name} tool",
|
description=f"{tool_name} tool",
|
||||||
return_direct=return_direct,
|
return_direct=return_direct,
|
||||||
|
coroutine=coroutine,
|
||||||
)
|
)
|
||||||
|
|
||||||
return _make_tool
|
return _make_tool
|
||||||
|
@ -546,7 +546,7 @@ def test_tool_with_kwargs() -> None:
|
|||||||
def test_missing_docstring() -> None:
|
def test_missing_docstring() -> None:
|
||||||
"""Test error is raised when docstring is missing."""
|
"""Test error is raised when docstring is missing."""
|
||||||
# expect to throw a value error if there's no docstring
|
# expect to throw a value error if there's no docstring
|
||||||
with pytest.raises(AssertionError, match="Function must have a docstring"):
|
with pytest.raises(ValueError, match="Function must have a docstring"):
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
def search_api(query: str) -> str:
|
def search_api(query: str) -> str:
|
||||||
|
Loading…
Reference in New Issue
Block a user