Add dangerous parameter to requests tool (#18697)

The tools are already documented as dangerous. It is not yet clear whether
adding an explicit opt-in parameter is necessary.
This commit is contained in:
Eugene Yurtsev 2024-03-07 15:10:56 -05:00 committed by GitHub
parent dad949eb99
commit e188d4ecb0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 162 additions and 28 deletions

View File

@ -41,15 +41,32 @@ class RequestsToolkit(BaseToolkit):
"""
requests_wrapper: TextRequestsWrapper
allow_dangerous_requests: bool = False
"""Allow dangerous requests. See documentation for details."""
def get_tools(self) -> List[BaseTool]:
    """Return a list of request tools.

    Every tool is constructed with this toolkit's wrapper and its
    ``allow_dangerous_requests`` flag; the tools refuse to initialize
    unless that flag is True, so it must be forwarded here.
    """
    # NOTE: the old constructor calls without allow_dangerous_requests were
    # stale diff residue; they would raise ValueError at construction time.
    return [
        RequestsGetTool(
            requests_wrapper=self.requests_wrapper,
            allow_dangerous_requests=self.allow_dangerous_requests,
        ),
        RequestsPostTool(
            requests_wrapper=self.requests_wrapper,
            allow_dangerous_requests=self.allow_dangerous_requests,
        ),
        RequestsPatchTool(
            requests_wrapper=self.requests_wrapper,
            allow_dangerous_requests=self.allow_dangerous_requests,
        ),
        RequestsPutTool(
            requests_wrapper=self.requests_wrapper,
            allow_dangerous_requests=self.allow_dangerous_requests,
        ),
        RequestsDeleteTool(
            requests_wrapper=self.requests_wrapper,
            allow_dangerous_requests=self.allow_dangerous_requests,
        ),
    ]
@ -66,6 +83,8 @@ class OpenAPIToolkit(BaseToolkit):
json_agent: Any
requests_wrapper: TextRequestsWrapper
allow_dangerous_requests: bool = False
"""Allow dangerous requests. See documentation for details."""
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
@ -74,7 +93,10 @@ class OpenAPIToolkit(BaseToolkit):
func=self.json_agent.run,
description=DESCRIPTION,
)
request_toolkit = RequestsToolkit(requests_wrapper=self.requests_wrapper)
request_toolkit = RequestsToolkit(
requests_wrapper=self.requests_wrapper,
allow_dangerous_requests=self.allow_dangerous_requests,
)
return [*request_toolkit.get_tools(), json_agent_tool]
@classmethod
@ -83,8 +105,13 @@ class OpenAPIToolkit(BaseToolkit):
llm: BaseLanguageModel,
json_spec: JsonSpec,
requests_wrapper: TextRequestsWrapper,
allow_dangerous_requests: bool = False,
**kwargs: Any,
) -> OpenAPIToolkit:
"""Create json agent from llm, then initialize."""
json_agent = create_json_agent(llm, JsonToolkit(spec=json_spec), **kwargs)
return cls(json_agent=json_agent, requests_wrapper=requests_wrapper)
return cls(
json_agent=json_agent,
requests_wrapper=requests_wrapper,
allow_dangerous_requests=allow_dangerous_requests,
)

View File

@ -28,6 +28,23 @@ class BaseRequestsTool(BaseModel):
requests_wrapper: GenericRequestsWrapper
allow_dangerous_requests: bool = False
def __init__(self, **kwargs: Any):
    """Initialize the tool.

    Raises:
        ValueError: If ``allow_dangerous_requests`` is not explicitly
            passed as True — callers must opt in to the risk.
    """
    if not kwargs.get("allow_dangerous_requests", False):
        # Typos fixed: "Request scan be" -> "Requests can be"; missing
        # spaces between concatenated string fragments ("internalserver",
        # "sandboxing.Please") restored.
        raise ValueError(
            "You must set allow_dangerous_requests to True to use this tool. "
            "Requests can be dangerous and can lead to security "
            "vulnerabilities. For example, users can ask a server to make a "
            "request to an internal server. It's recommended to use requests "
            "through a proxy server and avoid accepting inputs from "
            "untrusted sources without proper sandboxing. "
            "Please see: https://python.langchain.com/docs/security for "
            "further security information."
        )
    super().__init__(**kwargs)
class RequestsGetTool(BaseRequestsTool, BaseTool):
"""Tool for making a GET request to an API endpoint."""

View File

@ -72,34 +72,44 @@ def test_parse_input() -> None:
def test_requests_get_tool(mock_requests_wrapper: TextRequestsWrapper) -> None:
    """GET tool returns the wrapper's sync and async text responses."""
    # Stale pre-change construction (without allow_dangerous_requests) removed;
    # it would raise ValueError in BaseRequestsTool.__init__.
    tool = RequestsGetTool(
        requests_wrapper=mock_requests_wrapper, allow_dangerous_requests=True
    )
    assert tool.run("https://example.com") == "get_response"
    assert asyncio.run(tool.arun("https://example.com")) == "aget_response"
def test_requests_post_tool(mock_requests_wrapper: TextRequestsWrapper) -> None:
    """POST tool parses the JSON input and returns the wrapper's responses."""
    # Stale pre-change construction (without allow_dangerous_requests) removed.
    tool = RequestsPostTool(
        requests_wrapper=mock_requests_wrapper, allow_dangerous_requests=True
    )
    input_text = '{"url": "https://example.com", "data": {"key": "value"}}'
    assert tool.run(input_text) == "post {'key': 'value'}"
    assert asyncio.run(tool.arun(input_text)) == "apost {'key': 'value'}"
def test_requests_patch_tool(mock_requests_wrapper: TextRequestsWrapper) -> None:
    """PATCH tool parses the JSON input and returns the wrapper's responses."""
    # Stale pre-change construction (without allow_dangerous_requests) removed.
    tool = RequestsPatchTool(
        requests_wrapper=mock_requests_wrapper, allow_dangerous_requests=True
    )
    input_text = '{"url": "https://example.com", "data": {"key": "value"}}'
    assert tool.run(input_text) == "patch {'key': 'value'}"
    assert asyncio.run(tool.arun(input_text)) == "apatch {'key': 'value'}"
def test_requests_put_tool(mock_requests_wrapper: TextRequestsWrapper) -> None:
    """PUT tool parses the JSON input and returns the wrapper's responses."""
    # Stale pre-change construction (without allow_dangerous_requests) removed.
    tool = RequestsPutTool(
        requests_wrapper=mock_requests_wrapper, allow_dangerous_requests=True
    )
    input_text = '{"url": "https://example.com", "data": {"key": "value"}}'
    assert tool.run(input_text) == "put {'key': 'value'}"
    assert asyncio.run(tool.arun(input_text)) == "aput {'key': 'value'}"
def test_requests_delete_tool(mock_requests_wrapper: TextRequestsWrapper) -> None:
    """DELETE tool returns the wrapper's sync and async text responses."""
    # Stale pre-change construction (without allow_dangerous_requests) removed.
    tool = RequestsDeleteTool(
        requests_wrapper=mock_requests_wrapper, allow_dangerous_requests=True
    )
    assert tool.run("https://example.com") == "delete_response"
    assert asyncio.run(tool.arun("https://example.com")) == "adelete_response"
@ -154,7 +164,9 @@ def mock_json_requests_wrapper() -> JsonRequestsWrapper:
def test_requests_get_tool_json(
    mock_json_requests_wrapper: JsonRequestsWrapper,
) -> None:
    """GET tool with a JSON wrapper returns dict responses (sync and async)."""
    # Stale pre-change construction removed; the closing brace of the final
    # assert was truncated at the diff-hunk boundary and is restored here.
    tool = RequestsGetTool(
        requests_wrapper=mock_json_requests_wrapper, allow_dangerous_requests=True
    )
    assert tool.run("https://example.com") == {"response": "get_response"}
    assert asyncio.run(tool.arun("https://example.com")) == {
        "response": "aget_response"
    }
@ -164,7 +176,9 @@ def test_requests_get_tool_json(
def test_requests_post_tool_json(
    mock_json_requests_wrapper: JsonRequestsWrapper,
) -> None:
    """POST tool with a JSON wrapper returns dict responses (sync and async)."""
    # Stale pre-change construction (without allow_dangerous_requests) removed.
    tool = RequestsPostTool(
        requests_wrapper=mock_json_requests_wrapper, allow_dangerous_requests=True
    )
    input_text = '{"url": "https://example.com", "data": {"key": "value"}}'
    assert tool.run(input_text) == {"response": 'post {"key": "value"}'}
    assert asyncio.run(tool.arun(input_text)) == {"response": 'apost {"key": "value"}'}
@ -173,7 +187,9 @@ def test_requests_post_tool_json(
def test_requests_patch_tool_json(
    mock_json_requests_wrapper: JsonRequestsWrapper,
) -> None:
    """PATCH tool with a JSON wrapper returns dict responses (sync and async)."""
    # Stale pre-change construction (without allow_dangerous_requests) removed.
    tool = RequestsPatchTool(
        requests_wrapper=mock_json_requests_wrapper, allow_dangerous_requests=True
    )
    input_text = '{"url": "https://example.com", "data": {"key": "value"}}'
    assert tool.run(input_text) == {"response": 'patch {"key": "value"}'}
    assert asyncio.run(tool.arun(input_text)) == {"response": 'apatch {"key": "value"}'}
@ -182,7 +198,9 @@ def test_requests_patch_tool_json(
def test_requests_put_tool_json(
    mock_json_requests_wrapper: JsonRequestsWrapper,
) -> None:
    """PUT tool with a JSON wrapper returns dict responses (sync and async)."""
    # Stale pre-change construction (without allow_dangerous_requests) removed.
    tool = RequestsPutTool(
        requests_wrapper=mock_json_requests_wrapper, allow_dangerous_requests=True
    )
    input_text = '{"url": "https://example.com", "data": {"key": "value"}}'
    assert tool.run(input_text) == {"response": 'put {"key": "value"}'}
    assert asyncio.run(tool.arun(input_text)) == {"response": 'aput {"key": "value"}'}
@ -191,7 +209,9 @@ def test_requests_put_tool_json(
def test_requests_delete_tool_json(
    mock_json_requests_wrapper: JsonRequestsWrapper,
) -> None:
    """DELETE tool with a JSON wrapper returns dict responses (sync and async)."""
    # Stale pre-change construction removed; the closing brace of the final
    # assert was truncated at the diff-hunk boundary and is restored here.
    tool = RequestsDeleteTool(
        requests_wrapper=mock_json_requests_wrapper, allow_dangerous_requests=True
    )
    assert tool.run("https://example.com") == {"response": "delete_response"}
    assert asyncio.run(tool.arun("https://example.com")) == {
        "response": "adelete_response"
    }

View File

@ -106,23 +106,48 @@ from langchain_community.utilities.reddit_search import RedditSearchAPIWrapper
def _get_tools_requests_get() -> BaseTool:
    """Build the GET requests tool (gated upstream by ``allow_dangerous_tools``)."""
    # Stale pre-change `return` removed: it executed first, bypassed the opt-in
    # flag (raising ValueError), and made the lines below unreachable.
    # Dangerous requests are allowed here, because there's another flag that the user
    # has to provide in order to actually opt in.
    # This is a private function and should not be used directly.
    return RequestsGetTool(
        requests_wrapper=TextRequestsWrapper(), allow_dangerous_requests=True
    )
def _get_tools_requests_post() -> BaseTool:
    """Build the POST requests tool (gated upstream by ``allow_dangerous_tools``)."""
    # Stale pre-change `return` removed (see _get_tools_requests_get pattern).
    # Dangerous requests are allowed here, because there's another flag that the user
    # has to provide in order to actually opt in.
    # This is a private function and should not be used directly.
    return RequestsPostTool(
        requests_wrapper=TextRequestsWrapper(), allow_dangerous_requests=True
    )
def _get_tools_requests_patch() -> BaseTool:
    """Build the PATCH requests tool (gated upstream by ``allow_dangerous_tools``)."""
    # Stale pre-change `return` removed: it bypassed the opt-in flag and made
    # the real implementation below unreachable.
    # Dangerous requests are allowed here, because there's another flag that the user
    # has to provide in order to actually opt in.
    # This is a private function and should not be used directly.
    return RequestsPatchTool(
        requests_wrapper=TextRequestsWrapper(), allow_dangerous_requests=True
    )
def _get_tools_requests_put() -> BaseTool:
    """Build the PUT requests tool (gated upstream by ``allow_dangerous_tools``)."""
    # Stale pre-change `return` removed: it bypassed the opt-in flag and made
    # the real implementation below unreachable.
    # Dangerous requests are allowed here, because there's another flag that the user
    # has to provide in order to actually opt in.
    # This is a private function and should not be used directly.
    return RequestsPutTool(
        requests_wrapper=TextRequestsWrapper(), allow_dangerous_requests=True
    )
def _get_tools_requests_delete() -> BaseTool:
    """Build the DELETE requests tool (gated upstream by ``allow_dangerous_tools``)."""
    # Stale pre-change `return` removed: it bypassed the opt-in flag and made
    # the real implementation below unreachable.
    # Dangerous requests are allowed here, because there's another flag that the user
    # has to provide in order to actually opt in.
    # This is a private function and should not be used directly.
    return RequestsDeleteTool(
        requests_wrapper=TextRequestsWrapper(), allow_dangerous_requests=True
    )
def _get_terminal() -> BaseTool:
@ -134,6 +159,15 @@ def _get_sleep() -> BaseTool:
# Registry of always-safe tools available without any opt-in.
_BASE_TOOLS: Dict[str, Callable[[], BaseTool]] = {
    "sleep": _get_sleep,
}

DANGEROUS_TOOLS = {
    # Tools that contain some level of risk.
    # Please use with caution and read the documentation of these tools
    # to understand the risks and how to mitigate them.
    # Refer to https://python.langchain.com/docs/security
    # for more information.
    # NOTE(review): a diff-hunk marker sat mid-dict here and hid one entry;
    # "requests_patch" is restored by analogy with the other HTTP methods,
    # and the stale duplicate "sleep" entry (already in _BASE_TOOLS) removed.
    "requests": _get_tools_requests_get,  # preserved for backwards compatibility
    "requests_get": _get_tools_requests_get,
    "requests_post": _get_tools_requests_post,
    "requests_patch": _get_tools_requests_patch,
    "requests_put": _get_tools_requests_put,
    "requests_delete": _get_tools_requests_delete,
    "terminal": _get_terminal,
}
@ -541,6 +574,7 @@ def load_tools(
tool_names: List[str],
llm: Optional[BaseLanguageModel] = None,
callbacks: Callbacks = None,
allow_dangerous_tools: bool = False,
**kwargs: Any,
) -> List[BaseTool]:
"""Load tools based on their name.
@ -566,6 +600,15 @@ def load_tools(
llm: An optional language model, may be needed to initialize certain tools.
callbacks: Optional callback manager or list of callback handlers.
If not provided, default global callback manager will be used.
allow_dangerous_tools: Optional flag to allow dangerous tools.
Tools that contain some level of risk.
Please use with caution and read the documentation of these tools
to understand the risks and how to mitigate them.
Refer to https://python.langchain.com/docs/security
for more information.
Please note that this list may not be fully exhaustive.
It is your responsibility to understand which tools
you're using and the risks associated with them.
Returns:
List of tools.
@ -574,10 +617,26 @@ def load_tools(
callbacks = _handle_callbacks(
callback_manager=kwargs.get("callback_manager"), callbacks=callbacks
)
# print(_BASE_TOOLS)
# print(1)
for name in tool_names:
if name == "requests":
if name in DANGEROUS_TOOLS and not allow_dangerous_tools:
raise ValueError(
f"{name} is a dangerous tool. You cannot use it without opting in "
"by setting allow_dangerous_tools to True. "
"Most tools have some inherit risk to them merely because they are "
'allowed to interact with the "real world".'
"Please refer to LangChain security guidelines "
"to https://python.langchain.com/docs/security."
"Some tools have been designated as dangerous because they pose "
"risk that is not intuitively obvious. For example, a tool that "
"allows an agent to make requests to the web, can also be used "
"to make requests to a server that is only accessible from the "
"server hosting the code."
"Again, all tools carry some risk, and it's your responsibility to "
"understand which tools you're using and the risks associated with "
"them."
)
if name in {"requests"}:
warnings.warn(
"tool name `requests` is deprecated - "
"please use `requests_all` or specify the requests method"
@ -590,6 +649,8 @@ def load_tools(
tool_names.extend(requests_method_tools)
elif name in _BASE_TOOLS:
tools.append(_BASE_TOOLS[name]())
elif name in DANGEROUS_TOOLS:
tools.append(DANGEROUS_TOOLS[name]())
elif name in _LLM_TOOLS:
if llm is None:
raise ValueError(f"Tool {name} requires an LLM to be provided")
@ -628,4 +689,5 @@ def get_all_tool_names() -> List[str]:
+ list(_EXTRA_OPTIONAL_TOOLS)
+ list(_EXTRA_LLM_TOOLS)
+ list(_LLM_TOOLS)
+ list(DANGEROUS_TOOLS)
)

View File

@ -71,7 +71,11 @@ def test_load_tools_with_callback_manager_raises_deprecation_warning() -> None:
"""Test load_tools raises a deprecation for old callback manager kwarg."""
callback_manager = MagicMock()
with pytest.warns(DeprecationWarning, match="callback_manager is deprecated"):
tools = load_tools(["requests_get"], callback_manager=callback_manager)
tools = load_tools(
["requests_get"],
callback_manager=callback_manager,
allow_dangerous_tools=True,
)
assert len(tools) == 1
assert tools[0].callbacks == callback_manager
@ -79,7 +83,11 @@ def test_load_tools_with_callback_manager_raises_deprecation_warning() -> None:
def test_load_tools_with_callbacks_is_called() -> None:
"""Test callbacks are called when provided to load_tools fn."""
callbacks = [FakeCallbackHandler()]
tools = load_tools(["requests_get"], callbacks=callbacks) # type: ignore
tools = load_tools(
["requests_get"], # type: ignore
callbacks=callbacks, # type: ignore
allow_dangerous_tools=True,
)
assert len(tools) == 1
# Patch the requests.get() method to return a mock response
with unittest.mock.patch(