Expand requests tool into individual methods for load_tools (#2254)

### Motivation / Context

When exploring `load_tools(["requests"] )`, I would have expected all
request method tools to be imported instead of just `RequestsGetTool`.

### Changes

Break `_get_requests` into multiple functions by request method. Each
function returns the `BaseTool` for that particular request method.

In `load_tools`, if the tool name "requests_all" is encountered, we
replace with all `_BASE_TOOLS` that starts with `requests_`.

This way, `load_tools(["requests"])` returns:
- RequestsGetTool
- RequestsPostTool
- RequestsPatchTool
- RequestsPutTool
- RequestsDeleteTool
This commit is contained in:
Mandy Gu 2023-04-03 18:59:52 -04:00 committed by GitHub
parent 28cedab1a4
commit c841b2cc51
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,6 +1,7 @@
# flake8: noqa # flake8: noqa
"""Load tools.""" """Load tools."""
from typing import Any, List, Optional from typing import Any, List, Optional
import warnings
from langchain.agents.tools import Tool from langchain.agents.tools import Tool
from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.base import BaseCallbackManager
@ -16,7 +17,13 @@ from langchain.tools.google_search.tool import GoogleSearchResults, GoogleSearch
from langchain.tools.searx_search.tool import SearxSearchResults, SearxSearchRun from langchain.tools.searx_search.tool import SearxSearchResults, SearxSearchRun
from langchain.tools.human.tool import HumanInputRun from langchain.tools.human.tool import HumanInputRun
from langchain.tools.python.tool import PythonREPLTool from langchain.tools.python.tool import PythonREPLTool
from langchain.tools.requests.tool import RequestsGetTool from langchain.tools.requests.tool import (
RequestsGetTool,
RequestsPostTool,
RequestsPatchTool,
RequestsPutTool,
RequestsDeleteTool,
)
from langchain.tools.wikipedia.tool import WikipediaQueryRun from langchain.tools.wikipedia.tool import WikipediaQueryRun
from langchain.tools.wolfram_alpha.tool import WolframAlphaQueryRun from langchain.tools.wolfram_alpha.tool import WolframAlphaQueryRun
from langchain.utilities.apify import ApifyWrapper from langchain.utilities.apify import ApifyWrapper
@ -34,10 +41,26 @@ def _get_python_repl() -> BaseTool:
return PythonREPLTool() return PythonREPLTool()
def _get_requests() -> BaseTool: def _get_tools_requests_get() -> BaseTool:
return RequestsGetTool(requests_wrapper=RequestsWrapper()) return RequestsGetTool(requests_wrapper=RequestsWrapper())
def _get_tools_requests_post() -> BaseTool:
return RequestsPostTool(requests_wrapper=RequestsWrapper())
def _get_tools_requests_patch() -> BaseTool:
return RequestsPatchTool(requests_wrapper=RequestsWrapper())
def _get_tools_requests_put() -> BaseTool:
return RequestsPutTool(requests_wrapper=RequestsWrapper())
def _get_tools_requests_delete() -> BaseTool:
return RequestsDeleteTool(requests_wrapper=RequestsWrapper())
def _get_terminal() -> BaseTool: def _get_terminal() -> BaseTool:
return Tool( return Tool(
name="Terminal", name="Terminal",
@ -48,7 +71,12 @@ def _get_terminal() -> BaseTool:
_BASE_TOOLS = { _BASE_TOOLS = {
"python_repl": _get_python_repl, "python_repl": _get_python_repl,
"requests": _get_requests, "requests": _get_tools_requests_get, # preserved for backwards compatability
"requests_get": _get_tools_requests_get,
"requests_post": _get_tools_requests_post,
"requests_patch": _get_tools_requests_patch,
"requests_put": _get_tools_requests_put,
"requests_delete": _get_tools_requests_delete,
"terminal": _get_terminal, "terminal": _get_terminal,
} }
@ -228,8 +256,21 @@ def load_tools(
List of tools. List of tools.
""" """
tools = [] tools = []
for name in tool_names: for name in tool_names:
if name in _BASE_TOOLS: if name == "requests":
warnings.warn(
"tool name `requests` is deprecated - "
"please use `requests_all` or specify the requests method"
)
if name == "requests_all":
# expand requests into various methods
requests_method_tools = [
_tool for _tool in _BASE_TOOLS if _tool.startswith("requests_")
]
tool_names.extend(requests_method_tools)
elif name in _BASE_TOOLS:
tools.append(_BASE_TOOLS[name]()) tools.append(_BASE_TOOLS[name]())
elif name in _LLM_TOOLS: elif name in _LLM_TOOLS:
if llm is None: if llm is None: