forked from Archives/langchain
Expand requests tool into individual methods for load_tools (#2254)
### Motivation / Context When exploring `load_tools(["requests"] )`, I would have expected all request method tools to be imported instead of just `RequestsGetTool`. ### Changes Break `_get_requests` into multiple functions by request method. Each function returns the `BaseTool` for that particular request method. In `load_tools`, if the tool name "requests_all" is encountered, we replace with all `_BASE_TOOLS` that starts with `requests_`. This way, `load_tools(["requests"])` returns: - RequestsGetTool - RequestsPostTool - RequestsPatchTool - RequestsPutTool - RequestsDeleteTool
This commit is contained in:
parent
28cedab1a4
commit
c841b2cc51
@ -1,6 +1,7 @@
|
|||||||
# flake8: noqa
|
# flake8: noqa
|
||||||
"""Load tools."""
|
"""Load tools."""
|
||||||
from typing import Any, List, Optional
|
from typing import Any, List, Optional
|
||||||
|
import warnings
|
||||||
|
|
||||||
from langchain.agents.tools import Tool
|
from langchain.agents.tools import Tool
|
||||||
from langchain.callbacks.base import BaseCallbackManager
|
from langchain.callbacks.base import BaseCallbackManager
|
||||||
@ -16,7 +17,13 @@ from langchain.tools.google_search.tool import GoogleSearchResults, GoogleSearch
|
|||||||
from langchain.tools.searx_search.tool import SearxSearchResults, SearxSearchRun
|
from langchain.tools.searx_search.tool import SearxSearchResults, SearxSearchRun
|
||||||
from langchain.tools.human.tool import HumanInputRun
|
from langchain.tools.human.tool import HumanInputRun
|
||||||
from langchain.tools.python.tool import PythonREPLTool
|
from langchain.tools.python.tool import PythonREPLTool
|
||||||
from langchain.tools.requests.tool import RequestsGetTool
|
from langchain.tools.requests.tool import (
|
||||||
|
RequestsGetTool,
|
||||||
|
RequestsPostTool,
|
||||||
|
RequestsPatchTool,
|
||||||
|
RequestsPutTool,
|
||||||
|
RequestsDeleteTool,
|
||||||
|
)
|
||||||
from langchain.tools.wikipedia.tool import WikipediaQueryRun
|
from langchain.tools.wikipedia.tool import WikipediaQueryRun
|
||||||
from langchain.tools.wolfram_alpha.tool import WolframAlphaQueryRun
|
from langchain.tools.wolfram_alpha.tool import WolframAlphaQueryRun
|
||||||
from langchain.utilities.apify import ApifyWrapper
|
from langchain.utilities.apify import ApifyWrapper
|
||||||
@ -34,10 +41,26 @@ def _get_python_repl() -> BaseTool:
|
|||||||
return PythonREPLTool()
|
return PythonREPLTool()
|
||||||
|
|
||||||
|
|
||||||
def _get_requests() -> BaseTool:
|
def _get_tools_requests_get() -> BaseTool:
|
||||||
return RequestsGetTool(requests_wrapper=RequestsWrapper())
|
return RequestsGetTool(requests_wrapper=RequestsWrapper())
|
||||||
|
|
||||||
|
|
||||||
|
def _get_tools_requests_post() -> BaseTool:
|
||||||
|
return RequestsPostTool(requests_wrapper=RequestsWrapper())
|
||||||
|
|
||||||
|
|
||||||
|
def _get_tools_requests_patch() -> BaseTool:
|
||||||
|
return RequestsPatchTool(requests_wrapper=RequestsWrapper())
|
||||||
|
|
||||||
|
|
||||||
|
def _get_tools_requests_put() -> BaseTool:
|
||||||
|
return RequestsPutTool(requests_wrapper=RequestsWrapper())
|
||||||
|
|
||||||
|
|
||||||
|
def _get_tools_requests_delete() -> BaseTool:
|
||||||
|
return RequestsDeleteTool(requests_wrapper=RequestsWrapper())
|
||||||
|
|
||||||
|
|
||||||
def _get_terminal() -> BaseTool:
|
def _get_terminal() -> BaseTool:
|
||||||
return Tool(
|
return Tool(
|
||||||
name="Terminal",
|
name="Terminal",
|
||||||
@ -48,7 +71,12 @@ def _get_terminal() -> BaseTool:
|
|||||||
|
|
||||||
_BASE_TOOLS = {
|
_BASE_TOOLS = {
|
||||||
"python_repl": _get_python_repl,
|
"python_repl": _get_python_repl,
|
||||||
"requests": _get_requests,
|
"requests": _get_tools_requests_get, # preserved for backwards compatability
|
||||||
|
"requests_get": _get_tools_requests_get,
|
||||||
|
"requests_post": _get_tools_requests_post,
|
||||||
|
"requests_patch": _get_tools_requests_patch,
|
||||||
|
"requests_put": _get_tools_requests_put,
|
||||||
|
"requests_delete": _get_tools_requests_delete,
|
||||||
"terminal": _get_terminal,
|
"terminal": _get_terminal,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -228,8 +256,21 @@ def load_tools(
|
|||||||
List of tools.
|
List of tools.
|
||||||
"""
|
"""
|
||||||
tools = []
|
tools = []
|
||||||
|
|
||||||
for name in tool_names:
|
for name in tool_names:
|
||||||
if name in _BASE_TOOLS:
|
if name == "requests":
|
||||||
|
warnings.warn(
|
||||||
|
"tool name `requests` is deprecated - "
|
||||||
|
"please use `requests_all` or specify the requests method"
|
||||||
|
)
|
||||||
|
|
||||||
|
if name == "requests_all":
|
||||||
|
# expand requests into various methods
|
||||||
|
requests_method_tools = [
|
||||||
|
_tool for _tool in _BASE_TOOLS if _tool.startswith("requests_")
|
||||||
|
]
|
||||||
|
tool_names.extend(requests_method_tools)
|
||||||
|
elif name in _BASE_TOOLS:
|
||||||
tools.append(_BASE_TOOLS[name]())
|
tools.append(_BASE_TOOLS[name]())
|
||||||
elif name in _LLM_TOOLS:
|
elif name in _LLM_TOOLS:
|
||||||
if llm is None:
|
if llm is None:
|
||||||
|
Loading…
Reference in New Issue
Block a user