From 5042bd40d300312c23354c178c552749c8eda632 Mon Sep 17 00:00:00 2001 From: Zander Chase <130414180+vowelparrot@users.noreply.github.com> Date: Fri, 28 Apr 2023 11:10:43 -0700 Subject: [PATCH] Add Shell Tool (#3335) Create an official bash shell tool to replace the dynamically generated one --- docs/modules/agents/tools/examples/bash.ipynb | 200 +++++++++--------- langchain/agents/load_tools.py | 7 +- langchain/tools/__init__.py | 4 +- langchain/tools/shell/__init__.py | 5 + langchain/tools/shell/tool.py | 71 +++++++ tests/unit_tests/tools/shell/__init__.py | 0 tests/unit_tests/tools/shell/test_shell.py | 43 ++++ 7 files changed, 224 insertions(+), 106 deletions(-) create mode 100644 langchain/tools/shell/__init__.py create mode 100644 langchain/tools/shell/tool.py create mode 100644 tests/unit_tests/tools/shell/__init__.py create mode 100644 tests/unit_tests/tools/shell/test_shell.py diff --git a/docs/modules/agents/tools/examples/bash.ipynb b/docs/modules/agents/tools/examples/bash.ipynb index 3ec32241..117f296b 100644 --- a/docs/modules/agents/tools/examples/bash.ipynb +++ b/docs/modules/agents/tools/examples/bash.ipynb @@ -5,158 +5,158 @@ "id": "8f210ec3", "metadata": {}, "source": [ - "# Bash\n", - "It can often be useful to have an LLM generate bash commands, and then run them. A common use case for this is letting the LLM interact with your local file system. We provide an easy util to execute bash commands." + "# Shell Tool\n", + "\n", + "Giving agents access to the shell is powerful (though risky outside a sandboxed environment).\n", + "\n", + "The LLM can use it to execute any shell commands. A common use case for this is letting the LLM interact with your local file system." ] }, { "cell_type": "code", "execution_count": 1, "id": "f7b3767b", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [], "source": [ - "from langchain.utilities import BashProcess" + "from langchain.tools import ShellTool\n", + "\n", + "shell_tool = ShellTool()" ] }, { "cell_type": "code", "execution_count": 2, - "id": "cf1c92f0", - "metadata": {}, - "outputs": [], - "source": [ - "bash = BashProcess()" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "2fa952fc", - "metadata": {}, + "id": "c92ac832-556b-4f66-baa4-b78f965dfba0", + "metadata": { + "tags": [] + }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "apify.ipynb\n", - "arxiv.ipynb\n", - "bash.ipynb\n", - "bing_search.ipynb\n", - "chatgpt_plugins.ipynb\n", - "ddg.ipynb\n", - "google_places.ipynb\n", - "google_search.ipynb\n", - "google_serper.ipynb\n", - "gradio_tools.ipynb\n", - "human_tools.ipynb\n", - "ifttt.ipynb\n", - "openweathermap.ipynb\n", - "python.ipynb\n", - "requests.ipynb\n", - "search_tools.ipynb\n", - "searx_search.ipynb\n", - "serpapi.ipynb\n", - "wikipedia.ipynb\n", - "wolfram_alpha.ipynb\n", - "zapier.ipynb\n", + "Hello World!\n", + "\n", + "real\t0m0.000s\n", + "user\t0m0.000s\n", + "sys\t0m0.000s\n", "\n" ] - } - ], - "source": [ - "print(bash.run(\"ls\"))" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "e7896f8e", - "metadata": {}, - "outputs": [ + }, { - "name": "stdout", + "name": "stderr", "output_type": "stream", "text": [ - "apify.ipynb\n", - "arxiv.ipynb\n", - "bash.ipynb\n", - "bing_search.ipynb\n", - "chatgpt_plugins.ipynb\n", - "ddg.ipynb\n", - "google_places.ipynb\n", - "google_search.ipynb\n", - "google_serper.ipynb\n", - "gradio_tools.ipynb\n", - "human_tools.ipynb\n", - "ifttt.ipynb\n", - "openweathermap.ipynb\n", - "python.ipynb\n", - "requests.ipynb\n", - "search_tools.ipynb\n", - "searx_search.ipynb\n", - "serpapi.ipynb\n", - "wikipedia.ipynb\n", - "wolfram_alpha.ipynb\n", - "zapier.ipynb\n", - "\n" + "/Users/wfh/code/lc/lckg/langchain/tools/shell/tool.py:34: UserWarning: The shell tool has no safeguards by default. Use at your own risk.\n", + " warnings.warn(\n" ] } ], "source": [ - "bash.run(\"cd ..\")\n", - "# The commands are executed in a new subprocess each time, meaning that\n", - "# this call will return the same results as the last.\n", - "print(bash.run(\"ls\"))" + "print(shell_tool.run({\"commands\": [\"echo 'Hello World!'\", \"time\"]}))" ] }, { - "attachments": {}, "cell_type": "markdown", - "id": "851fee9f", + "id": "2fa952fc", "metadata": {}, "source": [ - "## Terminal Persistance\n", + "### Use with Agents\n", "\n", - "By default, the bash command will be executed in a new subprocess each time. To retain a persistent bash session, we can use the `persistent=True` arg." + "As with all tools, these can be given to an agent to accomplish more complex tasks. Let's have the agent fetch some links from a web page." ] }, { "cell_type": "code", - "execution_count": 5, - "id": "4a93ea2c", - "metadata": {}, - "outputs": [], - "source": [ - "bash = BashProcess(persistent=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "a1e98b78", - "metadata": {}, + "execution_count": 3, + "id": "851fee9f", + "metadata": { + "tags": [] + }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "custom_tools.ipynb\t\tmulti_input_tool.ipynb\n", - "examples\t\t\ttool_input_validation.ipynb\n", - "getting_started.md\n" + "\n", + "\n", + "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", + "\u001b[32;1m\u001b[1;3mQuestion: What is the task?\n", + "Thought: We need to download the langchain.com webpage and extract all the URLs from it. Then we need to sort the URLs and return them.\n", + "Action:\n", + "```\n", + "{\n", + " \"action\": \"shell\",\n", + " \"action_input\": {\n", + " \"commands\": [\n", + " \"curl -s https://langchain.com | grep -o 'http[s]*://[^\\\" ]*' | sort\"\n", + " ]\n", + " }\n", + "}\n", + "```\n", + "\u001b[0m" ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/wfh/code/lc/lckg/langchain/tools/shell/tool.py:34: UserWarning: The shell tool has no safeguards by default. Use at your own risk.\n", + " warnings.warn(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Observation: \u001b[36;1m\u001b[1;3mhttps://blog.langchain.dev/\n", + "https://discord.gg/6adMQxSpJS\n", + "https://docs.langchain.com/docs/\n", + "https://github.com/hwchase17/chat-langchain\n", + "https://github.com/hwchase17/langchain\n", + "https://github.com/hwchase17/langchainjs\n", + "https://github.com/sullivan-sean/chat-langchainjs\n", + "https://js.langchain.com/docs/\n", + "https://python.langchain.com/en/latest/\n", + "https://twitter.com/langchainai\n", + "\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3mThe URLs have been successfully extracted and sorted. We can return the list of URLs as the final answer.\n", + "Final Answer: [\"https://blog.langchain.dev/\", \"https://discord.gg/6adMQxSpJS\", \"https://docs.langchain.com/docs/\", \"https://github.com/hwchase17/chat-langchain\", \"https://github.com/hwchase17/langchain\", \"https://github.com/hwchase17/langchainjs\", \"https://github.com/sullivan-sean/chat-langchainjs\", \"https://js.langchain.com/docs/\", \"https://python.langchain.com/en/latest/\", \"https://twitter.com/langchainai\"]\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "'[\"https://blog.langchain.dev/\", \"https://discord.gg/6adMQxSpJS\", \"https://docs.langchain.com/docs/\", \"https://github.com/hwchase17/chat-langchain\", \"https://github.com/hwchase17/langchain\", \"https://github.com/hwchase17/langchainjs\", \"https://github.com/sullivan-sean/chat-langchainjs\", \"https://js.langchain.com/docs/\", \"https://python.langchain.com/en/latest/\", \"https://twitter.com/langchainai\"]'" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ - "bash.run(\"cd ..\")\n", - "# Note the list of files is different\n", - "print(bash.run(\"ls\"))" + "from langchain.chat_models import ChatOpenAI\n", + "from langchain.agents import initialize_agent\n", + "from langchain.agents import AgentType\n", + "\n", + "llm = ChatOpenAI(temperature=0)\n", + "\n", + "shell_tool.description = shell_tool.description + f\"args {shell_tool.args}\".replace(\"{\", \"{{\").replace(\"}\", \"}}\")\n", + "self_ask_with_search = initialize_agent([shell_tool], llm, agent=AgentType.CHAT_ZERO_SHOT_REACT_DESCRIPTION, verbose=True)\n", + "self_ask_with_search.run(\"Download the langchain.com webpage and grep for all urls. Return only a sorted list of them. Be sure to use double quotes.\")" ] }, { "cell_type": "code", "execution_count": null, - "id": "e13c1c9c", + "id": "8d0ea3ac-0890-4e39-9cec-74bd80b4b8b8", "metadata": {}, "outputs": [], "source": [] diff --git a/langchain/agents/load_tools.py b/langchain/agents/load_tools.py index ade98c96..399e6b6c 100644 --- a/langchain/agents/load_tools.py +++ b/langchain/agents/load_tools.py @@ -27,6 +27,7 @@ from langchain.tools.requests.tool import ( RequestsPutTool, ) from langchain.tools.searx_search.tool import SearxSearchResults, SearxSearchRun +from langchain.tools.shell.tool import ShellTool from langchain.tools.wikipedia.tool import WikipediaQueryRun from langchain.tools.wolfram_alpha.tool import WolframAlphaQueryRun from langchain.utilities import ArxivAPIWrapper @@ -67,11 +68,7 @@ def _get_tools_requests_delete() -> BaseTool: def _get_terminal() -> BaseTool: - return Tool( - name="Terminal", - description="Executes commands in a terminal. Input should be valid commands, and the output will be any output from running that command.", - func=BashProcess().run, - ) + return ShellTool() _BASE_TOOLS: Dict[str, Callable[[], BaseTool]] = { diff --git a/langchain/tools/__init__.py b/langchain/tools/__init__.py index a95590cd..f8080cd5 100644 --- a/langchain/tools/__init__.py +++ b/langchain/tools/__init__.py @@ -26,9 +26,11 @@ from langchain.tools.playwright import ( NavigateTool, ) from langchain.tools.plugin import AIPluginTool +from langchain.tools.shell.tool import ShellTool __all__ = [ "AIPluginTool", + "APIOperation", "BaseBrowserTool", "BaseTool", "BaseTool", @@ -55,6 +57,6 @@ __all__ = [ "NavigateTool", "OpenAPISpec", "ReadFileTool", + "ShellTool", "WriteFileTool", - "APIOperation", ] diff --git a/langchain/tools/shell/__init__.py b/langchain/tools/shell/__init__.py new file mode 100644 index 00000000..991b8f6a --- /dev/null +++ b/langchain/tools/shell/__init__.py @@ -0,0 +1,5 @@ +"""Shell tool.""" + +from langchain.tools.shell.tool import ShellTool + +__all__ = ["ShellTool"] diff --git a/langchain/tools/shell/tool.py b/langchain/tools/shell/tool.py new file mode 100644 index 00000000..8f9ecaef --- /dev/null +++ b/langchain/tools/shell/tool.py @@ -0,0 +1,71 @@ +import asyncio +import platform +import warnings +from typing import List, Type + +from pydantic import BaseModel, Field, root_validator + +from langchain.tools.base import BaseTool +from langchain.utilities.bash import BashProcess + + +class ShellInput(BaseModel): + """Commands for the Bash Shell tool.""" + + commands: List[str] = Field( + ..., + description="List of shell commands to run. Deserialized using json.loads", + ) + """List of shell commands to run.""" + + @root_validator + def _validate_commands(cls, values: dict) -> dict: + """Validate commands.""" + # TODO: Add real validators + commands = values.get("commands") + if not isinstance(commands, list): + values["commands"] = [commands] + # Warn that the bash tool is not safe + warnings.warn( + "The shell tool has no safeguards by default. Use at your own risk." + ) + return values + + +def _get_default_bash_processs() -> BashProcess: + """Get file path from string.""" + return BashProcess(return_err_output=True) + + +def _get_platform() -> str: + """Get platform.""" + system = platform.system() + if system == "Darwin": + return "MacOS" + return system + + +class ShellTool(BaseTool): + """Tool to run shell commands.""" + + process: BashProcess = Field(default_factory=_get_default_bash_processs) + """Bash process to run commands.""" + + name: str = "terminal" + """Name of tool.""" + + description: str = f"Run shell commands on this {_get_platform()} machine." + """Description of tool.""" + + args_schema: Type[BaseModel] = ShellInput + """Schema for input arguments.""" + + def _run(self, commands: List[str]) -> str: + """Run commands and return final output.""" + return self.process.run(commands) + + async def _arun(self, commands: List[str]) -> str: + """Run commands asynchronously and return final output.""" + return await asyncio.get_event_loop().run_in_executor( + None, self.process.run, commands + ) diff --git a/tests/unit_tests/tools/shell/__init__.py b/tests/unit_tests/tools/shell/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit_tests/tools/shell/test_shell.py b/tests/unit_tests/tools/shell/test_shell.py new file mode 100644 index 00000000..48ab3f37 --- /dev/null +++ b/tests/unit_tests/tools/shell/test_shell.py @@ -0,0 +1,43 @@ +import warnings + +import pytest + +from langchain.tools.shell.tool import ShellInput, ShellTool + +# Test data +test_commands = ["echo 'Hello, World!'", "echo 'Another command'"] + + +def test_shell_input_validation() -> None: + shell_input = ShellInput(commands=test_commands) + assert isinstance(shell_input.commands, list) + assert len(shell_input.commands) == 2 + + with warnings.catch_warnings(record=True) as w: + ShellInput(commands=test_commands) + assert len(w) == 1 + assert ( + str(w[-1].message) + == "The shell tool has no safeguards by default. Use at your own risk." + ) + + +def test_shell_tool_init() -> None: + shell_tool = ShellTool() + assert shell_tool.name == "terminal" + assert isinstance(shell_tool.description, str) + assert shell_tool.args_schema == ShellInput + assert shell_tool.process is not None + + +@pytest.mark.asyncio +async def test_shell_tool_arun() -> None: + shell_tool = ShellTool() + result = await shell_tool._arun(commands=test_commands) + assert result.strip() == "Hello, World!\nAnother command" + + +def test_shell_tool_run() -> None: + shell_tool = ShellTool() + result = shell_tool._run(commands=test_commands) + assert result.strip() == "Hello, World!\nAnother command"