Add Shell Tool (#3335)

Create an official bash shell tool to replace the dynamically generated one
This commit is contained in:
Zander Chase 2023-04-28 11:10:43 -07:00 committed by GitHub
parent 334c162f16
commit 5042bd40d3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 242 additions and 124 deletions

View File

@ -5,158 +5,158 @@
"id": "8f210ec3", "id": "8f210ec3",
"metadata": {}, "metadata": {},
"source": [ "source": [
"# Bash\n", "# Shell Tool\n",
"It can often be useful to have an LLM generate bash commands, and then run them. A common use case for this is letting the LLM interact with your local file system. We provide an easy util to execute bash commands." "\n",
"Giving agents access to the shell is powerful (though risky outside a sandboxed environment).\n",
"\n",
"The LLM can use it to execute any shell commands. A common use case for this is letting the LLM interact with your local file system."
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 1, "execution_count": 1,
"id": "f7b3767b", "id": "f7b3767b",
"metadata": {}, "metadata": {
"tags": []
},
"outputs": [], "outputs": [],
"source": [ "source": [
"from langchain.utilities import BashProcess" "from langchain.tools import ShellTool\n",
"\n",
"shell_tool = ShellTool()"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 2, "execution_count": 2,
"id": "cf1c92f0", "id": "c92ac832-556b-4f66-baa4-b78f965dfba0",
"metadata": {}, "metadata": {
"outputs": [], "tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Hello World!\n",
"\n",
"real\t0m0.000s\n",
"user\t0m0.000s\n",
"sys\t0m0.000s\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/wfh/code/lc/lckg/langchain/tools/shell/tool.py:34: UserWarning: The shell tool has no safeguards by default. Use at your own risk.\n",
" warnings.warn(\n"
]
}
],
"source": [ "source": [
"bash = BashProcess()" "print(shell_tool.run({\"commands\": [\"echo 'Hello World!'\", \"time\"]}))"
]
},
{
"cell_type": "markdown",
"id": "2fa952fc",
"metadata": {},
"source": [
"### Use with Agents\n",
"\n",
"As with all tools, these can be given to an agent to accomplish more complex tasks. Let's have the agent fetch some links from a web page."
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 3, "execution_count": 3,
"id": "2fa952fc",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"apify.ipynb\n",
"arxiv.ipynb\n",
"bash.ipynb\n",
"bing_search.ipynb\n",
"chatgpt_plugins.ipynb\n",
"ddg.ipynb\n",
"google_places.ipynb\n",
"google_search.ipynb\n",
"google_serper.ipynb\n",
"gradio_tools.ipynb\n",
"human_tools.ipynb\n",
"ifttt.ipynb\n",
"openweathermap.ipynb\n",
"python.ipynb\n",
"requests.ipynb\n",
"search_tools.ipynb\n",
"searx_search.ipynb\n",
"serpapi.ipynb\n",
"wikipedia.ipynb\n",
"wolfram_alpha.ipynb\n",
"zapier.ipynb\n",
"\n"
]
}
],
"source": [
"print(bash.run(\"ls\"))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "e7896f8e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"apify.ipynb\n",
"arxiv.ipynb\n",
"bash.ipynb\n",
"bing_search.ipynb\n",
"chatgpt_plugins.ipynb\n",
"ddg.ipynb\n",
"google_places.ipynb\n",
"google_search.ipynb\n",
"google_serper.ipynb\n",
"gradio_tools.ipynb\n",
"human_tools.ipynb\n",
"ifttt.ipynb\n",
"openweathermap.ipynb\n",
"python.ipynb\n",
"requests.ipynb\n",
"search_tools.ipynb\n",
"searx_search.ipynb\n",
"serpapi.ipynb\n",
"wikipedia.ipynb\n",
"wolfram_alpha.ipynb\n",
"zapier.ipynb\n",
"\n"
]
}
],
"source": [
"bash.run(\"cd ..\")\n",
"# The commands are executed in a new subprocess each time, meaning that\n",
"# this call will return the same results as the last.\n",
"print(bash.run(\"ls\"))"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "851fee9f", "id": "851fee9f",
"metadata": {}, "metadata": {
"source": [ "tags": []
"## Terminal Persistance\n",
"\n",
"By default, the bash command will be executed in a new subprocess each time. To retain a persistent bash session, we can use the `persistent=True` arg."
]
}, },
{
"cell_type": "code",
"execution_count": 5,
"id": "4a93ea2c",
"metadata": {},
"outputs": [],
"source": [
"bash = BashProcess(persistent=True)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "a1e98b78",
"metadata": {},
"outputs": [ "outputs": [
{ {
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"custom_tools.ipynb\t\tmulti_input_tool.ipynb\n", "\n",
"examples\t\t\ttool_input_validation.ipynb\n", "\n",
"getting_started.md\n" "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3mQuestion: What is the task?\n",
"Thought: We need to download the langchain.com webpage and extract all the URLs from it. Then we need to sort the URLs and return them.\n",
"Action:\n",
"```\n",
"{\n",
" \"action\": \"shell\",\n",
" \"action_input\": {\n",
" \"commands\": [\n",
" \"curl -s https://langchain.com | grep -o 'http[s]*://[^\\\" ]*' | sort\"\n",
" ]\n",
" }\n",
"}\n",
"```\n",
"\u001b[0m"
] ]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/wfh/code/lc/lckg/langchain/tools/shell/tool.py:34: UserWarning: The shell tool has no safeguards by default. Use at your own risk.\n",
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Observation: \u001b[36;1m\u001b[1;3mhttps://blog.langchain.dev/\n",
"https://discord.gg/6adMQxSpJS\n",
"https://docs.langchain.com/docs/\n",
"https://github.com/hwchase17/chat-langchain\n",
"https://github.com/hwchase17/langchain\n",
"https://github.com/hwchase17/langchainjs\n",
"https://github.com/sullivan-sean/chat-langchainjs\n",
"https://js.langchain.com/docs/\n",
"https://python.langchain.com/en/latest/\n",
"https://twitter.com/langchainai\n",
"\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3mThe URLs have been successfully extracted and sorted. We can return the list of URLs as the final answer.\n",
"Final Answer: [\"https://blog.langchain.dev/\", \"https://discord.gg/6adMQxSpJS\", \"https://docs.langchain.com/docs/\", \"https://github.com/hwchase17/chat-langchain\", \"https://github.com/hwchase17/langchain\", \"https://github.com/hwchase17/langchainjs\", \"https://github.com/sullivan-sean/chat-langchainjs\", \"https://js.langchain.com/docs/\", \"https://python.langchain.com/en/latest/\", \"https://twitter.com/langchainai\"]\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"'[\"https://blog.langchain.dev/\", \"https://discord.gg/6adMQxSpJS\", \"https://docs.langchain.com/docs/\", \"https://github.com/hwchase17/chat-langchain\", \"https://github.com/hwchase17/langchain\", \"https://github.com/hwchase17/langchainjs\", \"https://github.com/sullivan-sean/chat-langchainjs\", \"https://js.langchain.com/docs/\", \"https://python.langchain.com/en/latest/\", \"https://twitter.com/langchainai\"]'"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
} }
], ],
"source": [ "source": [
"bash.run(\"cd ..\")\n", "from langchain.chat_models import ChatOpenAI\n",
"# Note the list of files is different\n", "from langchain.agents import initialize_agent\n",
"print(bash.run(\"ls\"))" "from langchain.agents import AgentType\n",
"\n",
"llm = ChatOpenAI(temperature=0)\n",
"\n",
"shell_tool.description = shell_tool.description + f\"args {shell_tool.args}\".replace(\"{\", \"{{\").replace(\"}\", \"}}\")\n",
"self_ask_with_search = initialize_agent([shell_tool], llm, agent=AgentType.CHAT_ZERO_SHOT_REACT_DESCRIPTION, verbose=True)\n",
"self_ask_with_search.run(\"Download the langchain.com webpage and grep for all urls. Return only a sorted list of them. Be sure to use double quotes.\")"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"id": "e13c1c9c", "id": "8d0ea3ac-0890-4e39-9cec-74bd80b4b8b8",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [] "source": []

View File

@ -27,6 +27,7 @@ from langchain.tools.requests.tool import (
RequestsPutTool, RequestsPutTool,
) )
from langchain.tools.searx_search.tool import SearxSearchResults, SearxSearchRun from langchain.tools.searx_search.tool import SearxSearchResults, SearxSearchRun
from langchain.tools.shell.tool import ShellTool
from langchain.tools.wikipedia.tool import WikipediaQueryRun from langchain.tools.wikipedia.tool import WikipediaQueryRun
from langchain.tools.wolfram_alpha.tool import WolframAlphaQueryRun from langchain.tools.wolfram_alpha.tool import WolframAlphaQueryRun
from langchain.utilities import ArxivAPIWrapper from langchain.utilities import ArxivAPIWrapper
@ -67,11 +68,7 @@ def _get_tools_requests_delete() -> BaseTool:
def _get_terminal() -> BaseTool: def _get_terminal() -> BaseTool:
return Tool( return ShellTool()
name="Terminal",
description="Executes commands in a terminal. Input should be valid commands, and the output will be any output from running that command.",
func=BashProcess().run,
)
_BASE_TOOLS: Dict[str, Callable[[], BaseTool]] = { _BASE_TOOLS: Dict[str, Callable[[], BaseTool]] = {

View File

@ -26,9 +26,11 @@ from langchain.tools.playwright import (
NavigateTool, NavigateTool,
) )
from langchain.tools.plugin import AIPluginTool from langchain.tools.plugin import AIPluginTool
from langchain.tools.shell.tool import ShellTool
__all__ = [ __all__ = [
"AIPluginTool", "AIPluginTool",
"APIOperation",
"BaseBrowserTool", "BaseBrowserTool",
"BaseTool", "BaseTool",
"BaseTool", "BaseTool",
@ -55,6 +57,6 @@ __all__ = [
"NavigateTool", "NavigateTool",
"OpenAPISpec", "OpenAPISpec",
"ReadFileTool", "ReadFileTool",
"ShellTool",
"WriteFileTool", "WriteFileTool",
"APIOperation",
] ]

View File

@ -0,0 +1,5 @@
"""Shell tool."""
from langchain.tools.shell.tool import ShellTool
__all__ = ["ShellTool"]

View File

@ -0,0 +1,71 @@
import asyncio
import platform
import warnings
from typing import List, Type
from pydantic import BaseModel, Field, root_validator
from langchain.tools.base import BaseTool
from langchain.utilities.bash import BashProcess
class ShellInput(BaseModel):
"""Commands for the Bash Shell tool."""
commands: List[str] = Field(
...,
description="List of shell commands to run. Deserialized using json.loads",
)
"""List of shell commands to run."""
@root_validator
def _validate_commands(cls, values: dict) -> dict:
"""Validate commands."""
# TODO: Add real validators
commands = values.get("commands")
if not isinstance(commands, list):
values["commands"] = [commands]
# Warn that the bash tool is not safe
warnings.warn(
"The shell tool has no safeguards by default. Use at your own risk."
)
return values
def _get_default_bash_processs() -> BashProcess:
"""Get file path from string."""
return BashProcess(return_err_output=True)
def _get_platform() -> str:
"""Get platform."""
system = platform.system()
if system == "Darwin":
return "MacOS"
return system
class ShellTool(BaseTool):
"""Tool to run shell commands."""
process: BashProcess = Field(default_factory=_get_default_bash_processs)
"""Bash process to run commands."""
name: str = "terminal"
"""Name of tool."""
description: str = f"Run shell commands on this {_get_platform()} machine."
"""Description of tool."""
args_schema: Type[BaseModel] = ShellInput
"""Schema for input arguments."""
def _run(self, commands: List[str]) -> str:
"""Run commands and return final output."""
return self.process.run(commands)
async def _arun(self, commands: List[str]) -> str:
"""Run commands asynchronously and return final output."""
return await asyncio.get_event_loop().run_in_executor(
None, self.process.run, commands
)

View File

View File

@ -0,0 +1,43 @@
import warnings
import pytest
from langchain.tools.shell.tool import ShellInput, ShellTool
# Test data
test_commands = ["echo 'Hello, World!'", "echo 'Another command'"]
def test_shell_input_validation() -> None:
shell_input = ShellInput(commands=test_commands)
assert isinstance(shell_input.commands, list)
assert len(shell_input.commands) == 2
with warnings.catch_warnings(record=True) as w:
ShellInput(commands=test_commands)
assert len(w) == 1
assert (
str(w[-1].message)
== "The shell tool has no safeguards by default. Use at your own risk."
)
def test_shell_tool_init() -> None:
shell_tool = ShellTool()
assert shell_tool.name == "terminal"
assert isinstance(shell_tool.description, str)
assert shell_tool.args_schema == ShellInput
assert shell_tool.process is not None
@pytest.mark.asyncio
async def test_shell_tool_arun() -> None:
shell_tool = ShellTool()
result = await shell_tool._arun(commands=test_commands)
assert result.strip() == "Hello, World!\nAnother command"
def test_shell_tool_run() -> None:
shell_tool = ShellTool()
result = shell_tool._run(commands=test_commands)
assert result.strip() == "Hello, World!\nAnother command"