mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
Add Shell Tool (#3335)
Create an official bash shell tool to replace the dynamically generated one
This commit is contained in:
parent
334c162f16
commit
5042bd40d3
@ -5,158 +5,158 @@
|
|||||||
"id": "8f210ec3",
|
"id": "8f210ec3",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"# Bash\n",
|
"# Shell Tool\n",
|
||||||
"It can often be useful to have an LLM generate bash commands, and then run them. A common use case for this is letting the LLM interact with your local file system. We provide an easy util to execute bash commands."
|
"\n",
|
||||||
|
"Giving agents access to the shell is powerful (though risky outside a sandboxed environment).\n",
|
||||||
|
"\n",
|
||||||
|
"The LLM can use it to execute any shell commands. A common use case for this is letting the LLM interact with your local file system."
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 1,
|
"execution_count": 1,
|
||||||
"id": "f7b3767b",
|
"id": "f7b3767b",
|
||||||
"metadata": {},
|
"metadata": {
|
||||||
|
"tags": []
|
||||||
|
},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"from langchain.utilities import BashProcess"
|
"from langchain.tools import ShellTool\n",
|
||||||
|
"\n",
|
||||||
|
"shell_tool = ShellTool()"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 2,
|
"execution_count": 2,
|
||||||
"id": "cf1c92f0",
|
"id": "c92ac832-556b-4f66-baa4-b78f965dfba0",
|
||||||
"metadata": {},
|
"metadata": {
|
||||||
"outputs": [],
|
"tags": []
|
||||||
|
},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Hello World!\n",
|
||||||
|
"\n",
|
||||||
|
"real\t0m0.000s\n",
|
||||||
|
"user\t0m0.000s\n",
|
||||||
|
"sys\t0m0.000s\n",
|
||||||
|
"\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"/Users/wfh/code/lc/lckg/langchain/tools/shell/tool.py:34: UserWarning: The shell tool has no safeguards by default. Use at your own risk.\n",
|
||||||
|
" warnings.warn(\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"bash = BashProcess()"
|
"print(shell_tool.run({\"commands\": [\"echo 'Hello World!'\", \"time\"]}))"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "2fa952fc",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"### Use with Agents\n",
|
||||||
|
"\n",
|
||||||
|
"As with all tools, these can be given to an agent to accomplish more complex tasks. Let's have the agent fetch some links from a web page."
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 3,
|
"execution_count": 3,
|
||||||
"id": "2fa952fc",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"apify.ipynb\n",
|
|
||||||
"arxiv.ipynb\n",
|
|
||||||
"bash.ipynb\n",
|
|
||||||
"bing_search.ipynb\n",
|
|
||||||
"chatgpt_plugins.ipynb\n",
|
|
||||||
"ddg.ipynb\n",
|
|
||||||
"google_places.ipynb\n",
|
|
||||||
"google_search.ipynb\n",
|
|
||||||
"google_serper.ipynb\n",
|
|
||||||
"gradio_tools.ipynb\n",
|
|
||||||
"human_tools.ipynb\n",
|
|
||||||
"ifttt.ipynb\n",
|
|
||||||
"openweathermap.ipynb\n",
|
|
||||||
"python.ipynb\n",
|
|
||||||
"requests.ipynb\n",
|
|
||||||
"search_tools.ipynb\n",
|
|
||||||
"searx_search.ipynb\n",
|
|
||||||
"serpapi.ipynb\n",
|
|
||||||
"wikipedia.ipynb\n",
|
|
||||||
"wolfram_alpha.ipynb\n",
|
|
||||||
"zapier.ipynb\n",
|
|
||||||
"\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"print(bash.run(\"ls\"))"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 4,
|
|
||||||
"id": "e7896f8e",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"apify.ipynb\n",
|
|
||||||
"arxiv.ipynb\n",
|
|
||||||
"bash.ipynb\n",
|
|
||||||
"bing_search.ipynb\n",
|
|
||||||
"chatgpt_plugins.ipynb\n",
|
|
||||||
"ddg.ipynb\n",
|
|
||||||
"google_places.ipynb\n",
|
|
||||||
"google_search.ipynb\n",
|
|
||||||
"google_serper.ipynb\n",
|
|
||||||
"gradio_tools.ipynb\n",
|
|
||||||
"human_tools.ipynb\n",
|
|
||||||
"ifttt.ipynb\n",
|
|
||||||
"openweathermap.ipynb\n",
|
|
||||||
"python.ipynb\n",
|
|
||||||
"requests.ipynb\n",
|
|
||||||
"search_tools.ipynb\n",
|
|
||||||
"searx_search.ipynb\n",
|
|
||||||
"serpapi.ipynb\n",
|
|
||||||
"wikipedia.ipynb\n",
|
|
||||||
"wolfram_alpha.ipynb\n",
|
|
||||||
"zapier.ipynb\n",
|
|
||||||
"\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"bash.run(\"cd ..\")\n",
|
|
||||||
"# The commands are executed in a new subprocess each time, meaning that\n",
|
|
||||||
"# this call will return the same results as the last.\n",
|
|
||||||
"print(bash.run(\"ls\"))"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"attachments": {},
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"id": "851fee9f",
|
"id": "851fee9f",
|
||||||
"metadata": {},
|
"metadata": {
|
||||||
"source": [
|
"tags": []
|
||||||
"## Terminal Persistance\n",
|
|
||||||
"\n",
|
|
||||||
"By default, the bash command will be executed in a new subprocess each time. To retain a persistent bash session, we can use the `persistent=True` arg."
|
|
||||||
]
|
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 5,
|
|
||||||
"id": "4a93ea2c",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"bash = BashProcess(persistent=True)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 6,
|
|
||||||
"id": "a1e98b78",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"custom_tools.ipynb\t\tmulti_input_tool.ipynb\n",
|
"\n",
|
||||||
"examples\t\t\ttool_input_validation.ipynb\n",
|
"\n",
|
||||||
"getting_started.md\n"
|
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
||||||
|
"\u001b[32;1m\u001b[1;3mQuestion: What is the task?\n",
|
||||||
|
"Thought: We need to download the langchain.com webpage and extract all the URLs from it. Then we need to sort the URLs and return them.\n",
|
||||||
|
"Action:\n",
|
||||||
|
"```\n",
|
||||||
|
"{\n",
|
||||||
|
" \"action\": \"shell\",\n",
|
||||||
|
" \"action_input\": {\n",
|
||||||
|
" \"commands\": [\n",
|
||||||
|
" \"curl -s https://langchain.com | grep -o 'http[s]*://[^\\\" ]*' | sort\"\n",
|
||||||
|
" ]\n",
|
||||||
|
" }\n",
|
||||||
|
"}\n",
|
||||||
|
"```\n",
|
||||||
|
"\u001b[0m"
|
||||||
]
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"/Users/wfh/code/lc/lckg/langchain/tools/shell/tool.py:34: UserWarning: The shell tool has no safeguards by default. Use at your own risk.\n",
|
||||||
|
" warnings.warn(\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"\n",
|
||||||
|
"Observation: \u001b[36;1m\u001b[1;3mhttps://blog.langchain.dev/\n",
|
||||||
|
"https://discord.gg/6adMQxSpJS\n",
|
||||||
|
"https://docs.langchain.com/docs/\n",
|
||||||
|
"https://github.com/hwchase17/chat-langchain\n",
|
||||||
|
"https://github.com/hwchase17/langchain\n",
|
||||||
|
"https://github.com/hwchase17/langchainjs\n",
|
||||||
|
"https://github.com/sullivan-sean/chat-langchainjs\n",
|
||||||
|
"https://js.langchain.com/docs/\n",
|
||||||
|
"https://python.langchain.com/en/latest/\n",
|
||||||
|
"https://twitter.com/langchainai\n",
|
||||||
|
"\u001b[0m\n",
|
||||||
|
"Thought:\u001b[32;1m\u001b[1;3mThe URLs have been successfully extracted and sorted. We can return the list of URLs as the final answer.\n",
|
||||||
|
"Final Answer: [\"https://blog.langchain.dev/\", \"https://discord.gg/6adMQxSpJS\", \"https://docs.langchain.com/docs/\", \"https://github.com/hwchase17/chat-langchain\", \"https://github.com/hwchase17/langchain\", \"https://github.com/hwchase17/langchainjs\", \"https://github.com/sullivan-sean/chat-langchainjs\", \"https://js.langchain.com/docs/\", \"https://python.langchain.com/en/latest/\", \"https://twitter.com/langchainai\"]\u001b[0m\n",
|
||||||
|
"\n",
|
||||||
|
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"'[\"https://blog.langchain.dev/\", \"https://discord.gg/6adMQxSpJS\", \"https://docs.langchain.com/docs/\", \"https://github.com/hwchase17/chat-langchain\", \"https://github.com/hwchase17/langchain\", \"https://github.com/hwchase17/langchainjs\", \"https://github.com/sullivan-sean/chat-langchainjs\", \"https://js.langchain.com/docs/\", \"https://python.langchain.com/en/latest/\", \"https://twitter.com/langchainai\"]'"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 3,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"bash.run(\"cd ..\")\n",
|
"from langchain.chat_models import ChatOpenAI\n",
|
||||||
"# Note the list of files is different\n",
|
"from langchain.agents import initialize_agent\n",
|
||||||
"print(bash.run(\"ls\"))"
|
"from langchain.agents import AgentType\n",
|
||||||
|
"\n",
|
||||||
|
"llm = ChatOpenAI(temperature=0)\n",
|
||||||
|
"\n",
|
||||||
|
"shell_tool.description = shell_tool.description + f\"args {shell_tool.args}\".replace(\"{\", \"{{\").replace(\"}\", \"}}\")\n",
|
||||||
|
"self_ask_with_search = initialize_agent([shell_tool], llm, agent=AgentType.CHAT_ZERO_SHOT_REACT_DESCRIPTION, verbose=True)\n",
|
||||||
|
"self_ask_with_search.run(\"Download the langchain.com webpage and grep for all urls. Return only a sorted list of them. Be sure to use double quotes.\")"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"id": "e13c1c9c",
|
"id": "8d0ea3ac-0890-4e39-9cec-74bd80b4b8b8",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": []
|
"source": []
|
||||||
|
@ -27,6 +27,7 @@ from langchain.tools.requests.tool import (
|
|||||||
RequestsPutTool,
|
RequestsPutTool,
|
||||||
)
|
)
|
||||||
from langchain.tools.searx_search.tool import SearxSearchResults, SearxSearchRun
|
from langchain.tools.searx_search.tool import SearxSearchResults, SearxSearchRun
|
||||||
|
from langchain.tools.shell.tool import ShellTool
|
||||||
from langchain.tools.wikipedia.tool import WikipediaQueryRun
|
from langchain.tools.wikipedia.tool import WikipediaQueryRun
|
||||||
from langchain.tools.wolfram_alpha.tool import WolframAlphaQueryRun
|
from langchain.tools.wolfram_alpha.tool import WolframAlphaQueryRun
|
||||||
from langchain.utilities import ArxivAPIWrapper
|
from langchain.utilities import ArxivAPIWrapper
|
||||||
@ -67,11 +68,7 @@ def _get_tools_requests_delete() -> BaseTool:
|
|||||||
|
|
||||||
|
|
||||||
def _get_terminal() -> BaseTool:
|
def _get_terminal() -> BaseTool:
|
||||||
return Tool(
|
return ShellTool()
|
||||||
name="Terminal",
|
|
||||||
description="Executes commands in a terminal. Input should be valid commands, and the output will be any output from running that command.",
|
|
||||||
func=BashProcess().run,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
_BASE_TOOLS: Dict[str, Callable[[], BaseTool]] = {
|
_BASE_TOOLS: Dict[str, Callable[[], BaseTool]] = {
|
||||||
|
@ -26,9 +26,11 @@ from langchain.tools.playwright import (
|
|||||||
NavigateTool,
|
NavigateTool,
|
||||||
)
|
)
|
||||||
from langchain.tools.plugin import AIPluginTool
|
from langchain.tools.plugin import AIPluginTool
|
||||||
|
from langchain.tools.shell.tool import ShellTool
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"AIPluginTool",
|
"AIPluginTool",
|
||||||
|
"APIOperation",
|
||||||
"BaseBrowserTool",
|
"BaseBrowserTool",
|
||||||
"BaseTool",
|
"BaseTool",
|
||||||
"BaseTool",
|
"BaseTool",
|
||||||
@ -55,6 +57,6 @@ __all__ = [
|
|||||||
"NavigateTool",
|
"NavigateTool",
|
||||||
"OpenAPISpec",
|
"OpenAPISpec",
|
||||||
"ReadFileTool",
|
"ReadFileTool",
|
||||||
|
"ShellTool",
|
||||||
"WriteFileTool",
|
"WriteFileTool",
|
||||||
"APIOperation",
|
|
||||||
]
|
]
|
||||||
|
5
langchain/tools/shell/__init__.py
Normal file
5
langchain/tools/shell/__init__.py
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
"""Shell tool."""
|
||||||
|
|
||||||
|
from langchain.tools.shell.tool import ShellTool
|
||||||
|
|
||||||
|
__all__ = ["ShellTool"]
|
71
langchain/tools/shell/tool.py
Normal file
71
langchain/tools/shell/tool.py
Normal file
@ -0,0 +1,71 @@
|
|||||||
|
import asyncio
|
||||||
|
import platform
|
||||||
|
import warnings
|
||||||
|
from typing import List, Type
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, root_validator
|
||||||
|
|
||||||
|
from langchain.tools.base import BaseTool
|
||||||
|
from langchain.utilities.bash import BashProcess
|
||||||
|
|
||||||
|
|
||||||
|
class ShellInput(BaseModel):
|
||||||
|
"""Commands for the Bash Shell tool."""
|
||||||
|
|
||||||
|
commands: List[str] = Field(
|
||||||
|
...,
|
||||||
|
description="List of shell commands to run. Deserialized using json.loads",
|
||||||
|
)
|
||||||
|
"""List of shell commands to run."""
|
||||||
|
|
||||||
|
@root_validator
|
||||||
|
def _validate_commands(cls, values: dict) -> dict:
|
||||||
|
"""Validate commands."""
|
||||||
|
# TODO: Add real validators
|
||||||
|
commands = values.get("commands")
|
||||||
|
if not isinstance(commands, list):
|
||||||
|
values["commands"] = [commands]
|
||||||
|
# Warn that the bash tool is not safe
|
||||||
|
warnings.warn(
|
||||||
|
"The shell tool has no safeguards by default. Use at your own risk."
|
||||||
|
)
|
||||||
|
return values
|
||||||
|
|
||||||
|
|
||||||
|
def _get_default_bash_processs() -> BashProcess:
|
||||||
|
"""Get file path from string."""
|
||||||
|
return BashProcess(return_err_output=True)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_platform() -> str:
|
||||||
|
"""Get platform."""
|
||||||
|
system = platform.system()
|
||||||
|
if system == "Darwin":
|
||||||
|
return "MacOS"
|
||||||
|
return system
|
||||||
|
|
||||||
|
|
||||||
|
class ShellTool(BaseTool):
|
||||||
|
"""Tool to run shell commands."""
|
||||||
|
|
||||||
|
process: BashProcess = Field(default_factory=_get_default_bash_processs)
|
||||||
|
"""Bash process to run commands."""
|
||||||
|
|
||||||
|
name: str = "terminal"
|
||||||
|
"""Name of tool."""
|
||||||
|
|
||||||
|
description: str = f"Run shell commands on this {_get_platform()} machine."
|
||||||
|
"""Description of tool."""
|
||||||
|
|
||||||
|
args_schema: Type[BaseModel] = ShellInput
|
||||||
|
"""Schema for input arguments."""
|
||||||
|
|
||||||
|
def _run(self, commands: List[str]) -> str:
|
||||||
|
"""Run commands and return final output."""
|
||||||
|
return self.process.run(commands)
|
||||||
|
|
||||||
|
async def _arun(self, commands: List[str]) -> str:
|
||||||
|
"""Run commands asynchronously and return final output."""
|
||||||
|
return await asyncio.get_event_loop().run_in_executor(
|
||||||
|
None, self.process.run, commands
|
||||||
|
)
|
0
tests/unit_tests/tools/shell/__init__.py
Normal file
0
tests/unit_tests/tools/shell/__init__.py
Normal file
43
tests/unit_tests/tools/shell/test_shell.py
Normal file
43
tests/unit_tests/tools/shell/test_shell.py
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
import warnings
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from langchain.tools.shell.tool import ShellInput, ShellTool
|
||||||
|
|
||||||
|
# Test data
|
||||||
|
test_commands = ["echo 'Hello, World!'", "echo 'Another command'"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_shell_input_validation() -> None:
|
||||||
|
shell_input = ShellInput(commands=test_commands)
|
||||||
|
assert isinstance(shell_input.commands, list)
|
||||||
|
assert len(shell_input.commands) == 2
|
||||||
|
|
||||||
|
with warnings.catch_warnings(record=True) as w:
|
||||||
|
ShellInput(commands=test_commands)
|
||||||
|
assert len(w) == 1
|
||||||
|
assert (
|
||||||
|
str(w[-1].message)
|
||||||
|
== "The shell tool has no safeguards by default. Use at your own risk."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_shell_tool_init() -> None:
|
||||||
|
shell_tool = ShellTool()
|
||||||
|
assert shell_tool.name == "terminal"
|
||||||
|
assert isinstance(shell_tool.description, str)
|
||||||
|
assert shell_tool.args_schema == ShellInput
|
||||||
|
assert shell_tool.process is not None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_shell_tool_arun() -> None:
|
||||||
|
shell_tool = ShellTool()
|
||||||
|
result = await shell_tool._arun(commands=test_commands)
|
||||||
|
assert result.strip() == "Hello, World!\nAnother command"
|
||||||
|
|
||||||
|
|
||||||
|
def test_shell_tool_run() -> None:
|
||||||
|
shell_tool = ShellTool()
|
||||||
|
result = shell_tool._run(commands=test_commands)
|
||||||
|
assert result.strip() == "Hello, World!\nAnother command"
|
Loading…
Reference in New Issue
Block a user