mirror of
https://github.com/hwchase17/langchain
synced 2024-10-31 15:20:26 +00:00
Add dangerous parameter to requests tool (#18697)
The tools are already documented as dangerous. Not clear whether adding an opt-in parameter is necessary or not
This commit is contained in:
parent
dad949eb99
commit
e188d4ecb0
@ -41,15 +41,32 @@ class RequestsToolkit(BaseToolkit):
|
||||
"""
|
||||
|
||||
requests_wrapper: TextRequestsWrapper
|
||||
allow_dangerous_requests: bool = False
|
||||
"""Allow dangerous requests. See documentation for details."""
|
||||
|
||||
def get_tools(self) -> List[BaseTool]:
|
||||
"""Return a list of tools."""
|
||||
return [
|
||||
RequestsGetTool(requests_wrapper=self.requests_wrapper),
|
||||
RequestsPostTool(requests_wrapper=self.requests_wrapper),
|
||||
RequestsPatchTool(requests_wrapper=self.requests_wrapper),
|
||||
RequestsPutTool(requests_wrapper=self.requests_wrapper),
|
||||
RequestsDeleteTool(requests_wrapper=self.requests_wrapper),
|
||||
RequestsGetTool(
|
||||
requests_wrapper=self.requests_wrapper,
|
||||
allow_dangerous_requests=self.allow_dangerous_requests,
|
||||
),
|
||||
RequestsPostTool(
|
||||
requests_wrapper=self.requests_wrapper,
|
||||
allow_dangerous_requests=self.allow_dangerous_requests,
|
||||
),
|
||||
RequestsPatchTool(
|
||||
requests_wrapper=self.requests_wrapper,
|
||||
allow_dangerous_requests=self.allow_dangerous_requests,
|
||||
),
|
||||
RequestsPutTool(
|
||||
requests_wrapper=self.requests_wrapper,
|
||||
allow_dangerous_requests=self.allow_dangerous_requests,
|
||||
),
|
||||
RequestsDeleteTool(
|
||||
requests_wrapper=self.requests_wrapper,
|
||||
allow_dangerous_requests=self.allow_dangerous_requests,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@ -66,6 +83,8 @@ class OpenAPIToolkit(BaseToolkit):
|
||||
|
||||
json_agent: Any
|
||||
requests_wrapper: TextRequestsWrapper
|
||||
allow_dangerous_requests: bool = False
|
||||
"""Allow dangerous requests. See documentation for details."""
|
||||
|
||||
def get_tools(self) -> List[BaseTool]:
|
||||
"""Get the tools in the toolkit."""
|
||||
@ -74,7 +93,10 @@ class OpenAPIToolkit(BaseToolkit):
|
||||
func=self.json_agent.run,
|
||||
description=DESCRIPTION,
|
||||
)
|
||||
request_toolkit = RequestsToolkit(requests_wrapper=self.requests_wrapper)
|
||||
request_toolkit = RequestsToolkit(
|
||||
requests_wrapper=self.requests_wrapper,
|
||||
allow_dangerous_requests=self.allow_dangerous_requests,
|
||||
)
|
||||
return [*request_toolkit.get_tools(), json_agent_tool]
|
||||
|
||||
@classmethod
|
||||
@ -83,8 +105,13 @@ class OpenAPIToolkit(BaseToolkit):
|
||||
llm: BaseLanguageModel,
|
||||
json_spec: JsonSpec,
|
||||
requests_wrapper: TextRequestsWrapper,
|
||||
allow_dangerous_requests: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> OpenAPIToolkit:
|
||||
"""Create json agent from llm, then initialize."""
|
||||
json_agent = create_json_agent(llm, JsonToolkit(spec=json_spec), **kwargs)
|
||||
return cls(json_agent=json_agent, requests_wrapper=requests_wrapper)
|
||||
return cls(
|
||||
json_agent=json_agent,
|
||||
requests_wrapper=requests_wrapper,
|
||||
allow_dangerous_requests=allow_dangerous_requests,
|
||||
)
|
||||
|
@ -28,6 +28,23 @@ class BaseRequestsTool(BaseModel):
|
||||
|
||||
requests_wrapper: GenericRequestsWrapper
|
||||
|
||||
allow_dangerous_requests: bool = False
|
||||
|
||||
def __init__(self, **kwargs: Any):
|
||||
"""Initialize the tool."""
|
||||
if not kwargs.get("allow_dangerous_requests", False):
|
||||
raise ValueError(
|
||||
"You must set allow_dangerous_requests to True to use this tool. "
|
||||
"Request scan be dangerous and can lead to security vulnerabilities. "
|
||||
"For example, users can ask a server to make a request to an internal"
|
||||
"server. It's recommended to use requests through a proxy server "
|
||||
"and avoid accepting inputs from untrusted sources without proper "
|
||||
"sandboxing."
|
||||
"Please see: https://python.langchain.com/docs/security for "
|
||||
"further security information."
|
||||
)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
class RequestsGetTool(BaseRequestsTool, BaseTool):
|
||||
"""Tool for making a GET request to an API endpoint."""
|
||||
|
@ -72,34 +72,44 @@ def test_parse_input() -> None:
|
||||
|
||||
|
||||
def test_requests_get_tool(mock_requests_wrapper: TextRequestsWrapper) -> None:
|
||||
tool = RequestsGetTool(requests_wrapper=mock_requests_wrapper)
|
||||
tool = RequestsGetTool(
|
||||
requests_wrapper=mock_requests_wrapper, allow_dangerous_requests=True
|
||||
)
|
||||
assert tool.run("https://example.com") == "get_response"
|
||||
assert asyncio.run(tool.arun("https://example.com")) == "aget_response"
|
||||
|
||||
|
||||
def test_requests_post_tool(mock_requests_wrapper: TextRequestsWrapper) -> None:
|
||||
tool = RequestsPostTool(requests_wrapper=mock_requests_wrapper)
|
||||
tool = RequestsPostTool(
|
||||
requests_wrapper=mock_requests_wrapper, allow_dangerous_requests=True
|
||||
)
|
||||
input_text = '{"url": "https://example.com", "data": {"key": "value"}}'
|
||||
assert tool.run(input_text) == "post {'key': 'value'}"
|
||||
assert asyncio.run(tool.arun(input_text)) == "apost {'key': 'value'}"
|
||||
|
||||
|
||||
def test_requests_patch_tool(mock_requests_wrapper: TextRequestsWrapper) -> None:
|
||||
tool = RequestsPatchTool(requests_wrapper=mock_requests_wrapper)
|
||||
tool = RequestsPatchTool(
|
||||
requests_wrapper=mock_requests_wrapper, allow_dangerous_requests=True
|
||||
)
|
||||
input_text = '{"url": "https://example.com", "data": {"key": "value"}}'
|
||||
assert tool.run(input_text) == "patch {'key': 'value'}"
|
||||
assert asyncio.run(tool.arun(input_text)) == "apatch {'key': 'value'}"
|
||||
|
||||
|
||||
def test_requests_put_tool(mock_requests_wrapper: TextRequestsWrapper) -> None:
|
||||
tool = RequestsPutTool(requests_wrapper=mock_requests_wrapper)
|
||||
tool = RequestsPutTool(
|
||||
requests_wrapper=mock_requests_wrapper, allow_dangerous_requests=True
|
||||
)
|
||||
input_text = '{"url": "https://example.com", "data": {"key": "value"}}'
|
||||
assert tool.run(input_text) == "put {'key': 'value'}"
|
||||
assert asyncio.run(tool.arun(input_text)) == "aput {'key': 'value'}"
|
||||
|
||||
|
||||
def test_requests_delete_tool(mock_requests_wrapper: TextRequestsWrapper) -> None:
|
||||
tool = RequestsDeleteTool(requests_wrapper=mock_requests_wrapper)
|
||||
tool = RequestsDeleteTool(
|
||||
requests_wrapper=mock_requests_wrapper, allow_dangerous_requests=True
|
||||
)
|
||||
assert tool.run("https://example.com") == "delete_response"
|
||||
assert asyncio.run(tool.arun("https://example.com")) == "adelete_response"
|
||||
|
||||
@ -154,7 +164,9 @@ def mock_json_requests_wrapper() -> JsonRequestsWrapper:
|
||||
def test_requests_get_tool_json(
|
||||
mock_json_requests_wrapper: JsonRequestsWrapper,
|
||||
) -> None:
|
||||
tool = RequestsGetTool(requests_wrapper=mock_json_requests_wrapper)
|
||||
tool = RequestsGetTool(
|
||||
requests_wrapper=mock_json_requests_wrapper, allow_dangerous_requests=True
|
||||
)
|
||||
assert tool.run("https://example.com") == {"response": "get_response"}
|
||||
assert asyncio.run(tool.arun("https://example.com")) == {
|
||||
"response": "aget_response"
|
||||
@ -164,7 +176,9 @@ def test_requests_get_tool_json(
|
||||
def test_requests_post_tool_json(
|
||||
mock_json_requests_wrapper: JsonRequestsWrapper,
|
||||
) -> None:
|
||||
tool = RequestsPostTool(requests_wrapper=mock_json_requests_wrapper)
|
||||
tool = RequestsPostTool(
|
||||
requests_wrapper=mock_json_requests_wrapper, allow_dangerous_requests=True
|
||||
)
|
||||
input_text = '{"url": "https://example.com", "data": {"key": "value"}}'
|
||||
assert tool.run(input_text) == {"response": 'post {"key": "value"}'}
|
||||
assert asyncio.run(tool.arun(input_text)) == {"response": 'apost {"key": "value"}'}
|
||||
@ -173,7 +187,9 @@ def test_requests_post_tool_json(
|
||||
def test_requests_patch_tool_json(
|
||||
mock_json_requests_wrapper: JsonRequestsWrapper,
|
||||
) -> None:
|
||||
tool = RequestsPatchTool(requests_wrapper=mock_json_requests_wrapper)
|
||||
tool = RequestsPatchTool(
|
||||
requests_wrapper=mock_json_requests_wrapper, allow_dangerous_requests=True
|
||||
)
|
||||
input_text = '{"url": "https://example.com", "data": {"key": "value"}}'
|
||||
assert tool.run(input_text) == {"response": 'patch {"key": "value"}'}
|
||||
assert asyncio.run(tool.arun(input_text)) == {"response": 'apatch {"key": "value"}'}
|
||||
@ -182,7 +198,9 @@ def test_requests_patch_tool_json(
|
||||
def test_requests_put_tool_json(
|
||||
mock_json_requests_wrapper: JsonRequestsWrapper,
|
||||
) -> None:
|
||||
tool = RequestsPutTool(requests_wrapper=mock_json_requests_wrapper)
|
||||
tool = RequestsPutTool(
|
||||
requests_wrapper=mock_json_requests_wrapper, allow_dangerous_requests=True
|
||||
)
|
||||
input_text = '{"url": "https://example.com", "data": {"key": "value"}}'
|
||||
assert tool.run(input_text) == {"response": 'put {"key": "value"}'}
|
||||
assert asyncio.run(tool.arun(input_text)) == {"response": 'aput {"key": "value"}'}
|
||||
@ -191,7 +209,9 @@ def test_requests_put_tool_json(
|
||||
def test_requests_delete_tool_json(
|
||||
mock_json_requests_wrapper: JsonRequestsWrapper,
|
||||
) -> None:
|
||||
tool = RequestsDeleteTool(requests_wrapper=mock_json_requests_wrapper)
|
||||
tool = RequestsDeleteTool(
|
||||
requests_wrapper=mock_json_requests_wrapper, allow_dangerous_requests=True
|
||||
)
|
||||
assert tool.run("https://example.com") == {"response": "delete_response"}
|
||||
assert asyncio.run(tool.arun("https://example.com")) == {
|
||||
"response": "adelete_response"
|
||||
|
@ -106,23 +106,48 @@ from langchain_community.utilities.reddit_search import RedditSearchAPIWrapper
|
||||
|
||||
|
||||
def _get_tools_requests_get() -> BaseTool:
|
||||
return RequestsGetTool(requests_wrapper=TextRequestsWrapper())
|
||||
# Dangerous requests are allowed here, because there's another flag that the user
|
||||
# has to provide in order to actually opt in.
|
||||
# This is a private function and should not be used directly.
|
||||
return RequestsGetTool(
|
||||
requests_wrapper=TextRequestsWrapper(), allow_dangerous_requests=True
|
||||
)
|
||||
|
||||
|
||||
def _get_tools_requests_post() -> BaseTool:
|
||||
return RequestsPostTool(requests_wrapper=TextRequestsWrapper())
|
||||
# Dangerous requests are allowed here, because there's another flag that the user
|
||||
# has to provide in order to actually opt in.
|
||||
# This is a private function and should not be used directly.
|
||||
return RequestsPostTool(
|
||||
requests_wrapper=TextRequestsWrapper(), allow_dangerous_requests=True
|
||||
)
|
||||
|
||||
|
||||
def _get_tools_requests_patch() -> BaseTool:
|
||||
return RequestsPatchTool(requests_wrapper=TextRequestsWrapper())
|
||||
# Dangerous requests are allowed here, because there's another flag that the user
|
||||
# has to provide in order to actually opt in.
|
||||
# This is a private function and should not be used directly.
|
||||
return RequestsPatchTool(
|
||||
requests_wrapper=TextRequestsWrapper(), allow_dangerous_requests=True
|
||||
)
|
||||
|
||||
|
||||
def _get_tools_requests_put() -> BaseTool:
|
||||
return RequestsPutTool(requests_wrapper=TextRequestsWrapper())
|
||||
# Dangerous requests are allowed here, because there's another flag that the user
|
||||
# has to provide in order to actually opt in.
|
||||
# This is a private function and should not be used directly.
|
||||
return RequestsPutTool(
|
||||
requests_wrapper=TextRequestsWrapper(), allow_dangerous_requests=True
|
||||
)
|
||||
|
||||
|
||||
def _get_tools_requests_delete() -> BaseTool:
|
||||
return RequestsDeleteTool(requests_wrapper=TextRequestsWrapper())
|
||||
# Dangerous requests are allowed here, because there's another flag that the user
|
||||
# has to provide in order to actually opt in.
|
||||
# This is a private function and should not be used directly.
|
||||
return RequestsDeleteTool(
|
||||
requests_wrapper=TextRequestsWrapper(), allow_dangerous_requests=True
|
||||
)
|
||||
|
||||
|
||||
def _get_terminal() -> BaseTool:
|
||||
@ -134,6 +159,15 @@ def _get_sleep() -> BaseTool:
|
||||
|
||||
|
||||
_BASE_TOOLS: Dict[str, Callable[[], BaseTool]] = {
|
||||
"sleep": _get_sleep,
|
||||
}
|
||||
|
||||
DANGEROUS_TOOLS = {
|
||||
# Tools that contain some level of risk.
|
||||
# Please use with caution and read the documentation of these tools
|
||||
# to understand the risks and how to mitigate them.
|
||||
# Refer to https://python.langchain.com/docs/security
|
||||
# for more information.
|
||||
"requests": _get_tools_requests_get, # preserved for backwards compatibility
|
||||
"requests_get": _get_tools_requests_get,
|
||||
"requests_post": _get_tools_requests_post,
|
||||
@ -141,7 +175,6 @@ _BASE_TOOLS: Dict[str, Callable[[], BaseTool]] = {
|
||||
"requests_put": _get_tools_requests_put,
|
||||
"requests_delete": _get_tools_requests_delete,
|
||||
"terminal": _get_terminal,
|
||||
"sleep": _get_sleep,
|
||||
}
|
||||
|
||||
|
||||
@ -541,6 +574,7 @@ def load_tools(
|
||||
tool_names: List[str],
|
||||
llm: Optional[BaseLanguageModel] = None,
|
||||
callbacks: Callbacks = None,
|
||||
allow_dangerous_tools: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> List[BaseTool]:
|
||||
"""Load tools based on their name.
|
||||
@ -566,6 +600,15 @@ def load_tools(
|
||||
llm: An optional language model, may be needed to initialize certain tools.
|
||||
callbacks: Optional callback manager or list of callback handlers.
|
||||
If not provided, default global callback manager will be used.
|
||||
allow_dangerous_tools: Optional flag to allow dangerous tools.
|
||||
Tools that contain some level of risk.
|
||||
Please use with caution and read the documentation of these tools
|
||||
to understand the risks and how to mitigate them.
|
||||
Refer to https://python.langchain.com/docs/security
|
||||
for more information.
|
||||
Please note that this list may not be fully exhaustive.
|
||||
It is your responsibility to understand which tools
|
||||
you're using and the risks associated with them.
|
||||
|
||||
Returns:
|
||||
List of tools.
|
||||
@ -574,10 +617,26 @@ def load_tools(
|
||||
callbacks = _handle_callbacks(
|
||||
callback_manager=kwargs.get("callback_manager"), callbacks=callbacks
|
||||
)
|
||||
# print(_BASE_TOOLS)
|
||||
# print(1)
|
||||
for name in tool_names:
|
||||
if name == "requests":
|
||||
if name in DANGEROUS_TOOLS and not allow_dangerous_tools:
|
||||
raise ValueError(
|
||||
f"{name} is a dangerous tool. You cannot use it without opting in "
|
||||
"by setting allow_dangerous_tools to True. "
|
||||
"Most tools have some inherit risk to them merely because they are "
|
||||
'allowed to interact with the "real world".'
|
||||
"Please refer to LangChain security guidelines "
|
||||
"to https://python.langchain.com/docs/security."
|
||||
"Some tools have been designated as dangerous because they pose "
|
||||
"risk that is not intuitively obvious. For example, a tool that "
|
||||
"allows an agent to make requests to the web, can also be used "
|
||||
"to make requests to a server that is only accessible from the "
|
||||
"server hosting the code."
|
||||
"Again, all tools carry some risk, and it's your responsibility to "
|
||||
"understand which tools you're using and the risks associated with "
|
||||
"them."
|
||||
)
|
||||
|
||||
if name in {"requests"}:
|
||||
warnings.warn(
|
||||
"tool name `requests` is deprecated - "
|
||||
"please use `requests_all` or specify the requests method"
|
||||
@ -590,6 +649,8 @@ def load_tools(
|
||||
tool_names.extend(requests_method_tools)
|
||||
elif name in _BASE_TOOLS:
|
||||
tools.append(_BASE_TOOLS[name]())
|
||||
elif name in DANGEROUS_TOOLS:
|
||||
tools.append(DANGEROUS_TOOLS[name]())
|
||||
elif name in _LLM_TOOLS:
|
||||
if llm is None:
|
||||
raise ValueError(f"Tool {name} requires an LLM to be provided")
|
||||
@ -628,4 +689,5 @@ def get_all_tool_names() -> List[str]:
|
||||
+ list(_EXTRA_OPTIONAL_TOOLS)
|
||||
+ list(_EXTRA_LLM_TOOLS)
|
||||
+ list(_LLM_TOOLS)
|
||||
+ list(DANGEROUS_TOOLS)
|
||||
)
|
||||
|
@ -71,7 +71,11 @@ def test_load_tools_with_callback_manager_raises_deprecation_warning() -> None:
|
||||
"""Test load_tools raises a deprecation for old callback manager kwarg."""
|
||||
callback_manager = MagicMock()
|
||||
with pytest.warns(DeprecationWarning, match="callback_manager is deprecated"):
|
||||
tools = load_tools(["requests_get"], callback_manager=callback_manager)
|
||||
tools = load_tools(
|
||||
["requests_get"],
|
||||
callback_manager=callback_manager,
|
||||
allow_dangerous_tools=True,
|
||||
)
|
||||
assert len(tools) == 1
|
||||
assert tools[0].callbacks == callback_manager
|
||||
|
||||
@ -79,7 +83,11 @@ def test_load_tools_with_callback_manager_raises_deprecation_warning() -> None:
|
||||
def test_load_tools_with_callbacks_is_called() -> None:
|
||||
"""Test callbacks are called when provided to load_tools fn."""
|
||||
callbacks = [FakeCallbackHandler()]
|
||||
tools = load_tools(["requests_get"], callbacks=callbacks) # type: ignore
|
||||
tools = load_tools(
|
||||
["requests_get"], # type: ignore
|
||||
callbacks=callbacks, # type: ignore
|
||||
allow_dangerous_tools=True,
|
||||
)
|
||||
assert len(tools) == 1
|
||||
# Patch the requests.get() method to return a mock response
|
||||
with unittest.mock.patch(
|
||||
|
Loading…
Reference in New Issue
Block a user