diff --git a/langchain/agents/load_tools.py b/langchain/agents/load_tools.py index 269c5e48..c20322fb 100644 --- a/langchain/agents/load_tools.py +++ b/langchain/agents/load_tools.py @@ -1,6 +1,7 @@ # flake8: noqa """Load tools.""" from typing import Any, List, Optional +import warnings from langchain.agents.tools import Tool from langchain.callbacks.base import BaseCallbackManager @@ -16,7 +17,13 @@ from langchain.tools.google_search.tool import GoogleSearchResults, GoogleSearch from langchain.tools.searx_search.tool import SearxSearchResults, SearxSearchRun from langchain.tools.human.tool import HumanInputRun from langchain.tools.python.tool import PythonREPLTool -from langchain.tools.requests.tool import RequestsGetTool +from langchain.tools.requests.tool import ( + RequestsGetTool, + RequestsPostTool, + RequestsPatchTool, + RequestsPutTool, + RequestsDeleteTool, +) from langchain.tools.wikipedia.tool import WikipediaQueryRun from langchain.tools.wolfram_alpha.tool import WolframAlphaQueryRun from langchain.utilities.apify import ApifyWrapper @@ -34,10 +41,26 @@ def _get_python_repl() -> BaseTool: return PythonREPLTool() -def _get_requests() -> BaseTool: +def _get_tools_requests_get() -> BaseTool: return RequestsGetTool(requests_wrapper=RequestsWrapper()) +def _get_tools_requests_post() -> BaseTool: + return RequestsPostTool(requests_wrapper=RequestsWrapper()) + + +def _get_tools_requests_patch() -> BaseTool: + return RequestsPatchTool(requests_wrapper=RequestsWrapper()) + + +def _get_tools_requests_put() -> BaseTool: + return RequestsPutTool(requests_wrapper=RequestsWrapper()) + + +def _get_tools_requests_delete() -> BaseTool: + return RequestsDeleteTool(requests_wrapper=RequestsWrapper()) + + def _get_terminal() -> BaseTool: return Tool( name="Terminal", @@ -48,7 +71,12 @@ def _get_terminal() -> BaseTool: _BASE_TOOLS = { "python_repl": _get_python_repl, - "requests": _get_requests, + "requests": _get_tools_requests_get, # preserved for backwards compatability + "requests_get": _get_tools_requests_get, + "requests_post": _get_tools_requests_post, + "requests_patch": _get_tools_requests_patch, + "requests_put": _get_tools_requests_put, + "requests_delete": _get_tools_requests_delete, "terminal": _get_terminal, } @@ -228,8 +256,21 @@ def load_tools( List of tools. """ tools = [] + for name in tool_names: - if name in _BASE_TOOLS: + if name == "requests": + warnings.warn( + "tool name `requests` is deprecated - " + "please use `requests_all` or specify the requests method" + ) + + if name == "requests_all": + # expand requests into various methods + requests_method_tools = [ + _tool for _tool in _BASE_TOOLS if _tool.startswith("requests_") + ] + tool_names.extend(requests_method_tools) + elif name in _BASE_TOOLS: tools.append(_BASE_TOOLS[name]()) elif name in _LLM_TOOLS: if llm is None: