From ee670c448e8aacbbc95df4c94da13d8509092b9a Mon Sep 17 00:00:00 2001 From: Zander Chase <130414180+vowelparrot@users.noreply.github.com> Date: Wed, 26 Apr 2023 15:20:28 -0700 Subject: [PATCH] Persistent Bash Shell (#3580) Clean up linting and make more idiomatic by using an output parser --------- Co-authored-by: FergusFettes --- docs/modules/agents/tools/examples/bash.ipynb | 105 ++++++++++++++- docs/modules/chains/examples/llm_bash.ipynb | 124 ++++++++++++++++-- langchain/chains/llm_bash/base.py | 76 ++++++++--- langchain/utilities/bash.py | 67 +++++++++- tests/unit_tests/chains/test_llm_bash.py | 99 ++++++++++++-- tests/unit_tests/test_bash.py | 30 +++++ 6 files changed, 464 insertions(+), 37 deletions(-) diff --git a/docs/modules/agents/tools/examples/bash.ipynb b/docs/modules/agents/tools/examples/bash.ipynb index e16e930f..3ec32241 100644 --- a/docs/modules/agents/tools/examples/bash.ipynb +++ b/docs/modules/agents/tools/examples/bash.ipynb @@ -39,11 +39,27 @@ "name": "stdout", "output_type": "stream", "text": [ + "apify.ipynb\n", + "arxiv.ipynb\n", "bash.ipynb\n", + "bing_search.ipynb\n", + "chatgpt_plugins.ipynb\n", + "ddg.ipynb\n", + "google_places.ipynb\n", "google_search.ipynb\n", + "google_serper.ipynb\n", + "gradio_tools.ipynb\n", + "human_tools.ipynb\n", + "ifttt.ipynb\n", + "openweathermap.ipynb\n", "python.ipynb\n", "requests.ipynb\n", + "search_tools.ipynb\n", + "searx_search.ipynb\n", "serpapi.ipynb\n", + "wikipedia.ipynb\n", + "wolfram_alpha.ipynb\n", + "zapier.ipynb\n", "\n" ] } @@ -54,9 +70,94 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, + "id": "e7896f8e", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "apify.ipynb\n", + "arxiv.ipynb\n", + "bash.ipynb\n", + "bing_search.ipynb\n", + "chatgpt_plugins.ipynb\n", + "ddg.ipynb\n", + "google_places.ipynb\n", + "google_search.ipynb\n", + "google_serper.ipynb\n", + "gradio_tools.ipynb\n", + "human_tools.ipynb\n", + "ifttt.ipynb\n", + "openweathermap.ipynb\n", + "python.ipynb\n", + "requests.ipynb\n", + "search_tools.ipynb\n", + "searx_search.ipynb\n", + "serpapi.ipynb\n", + "wikipedia.ipynb\n", + "wolfram_alpha.ipynb\n", + "zapier.ipynb\n", + "\n" + ] + } + ], + "source": [ + "bash.run(\"cd ..\")\n", + "# The commands are executed in a new subprocess each time, meaning that\n", + "# this call will return the same results as the last.\n", + "print(bash.run(\"ls\"))" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", "id": "851fee9f", "metadata": {}, + "source": [ + "## Terminal Persistance\n", + "\n", + "By default, the bash command will be executed in a new subprocess each time. To retain a persistent bash session, we can use the `persistent=True` arg." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "4a93ea2c", + "metadata": {}, + "outputs": [], + "source": [ + "bash = BashProcess(persistent=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "a1e98b78", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "custom_tools.ipynb\t\tmulti_input_tool.ipynb\n", + "examples\t\t\ttool_input_validation.ipynb\n", + "getting_started.md\n" + ] + } + ], + "source": [ + "bash.run(\"cd ..\")\n", + "# Note the list of files is different\n", + "print(bash.run(\"ls\"))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e13c1c9c", + "metadata": {}, "outputs": [], "source": [] } @@ -77,7 +178,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.9" + "version": "3.8.16" } }, "nbformat": 4, diff --git a/docs/modules/chains/examples/llm_bash.ipynb b/docs/modules/chains/examples/llm_bash.ipynb index 6a5f4d04..c2cb0fe6 100644 --- a/docs/modules/chains/examples/llm_bash.ipynb +++ b/docs/modules/chains/examples/llm_bash.ipynb @@ -24,8 +24,8 @@ "\n", "```bash\n", "echo \"Hello World\"\n", - "```\u001b[0m['```bash', 'echo \"Hello World\"', '```']\n", - "\n", + "```\u001b[0m\n", + "Code: \u001b[33;1m\u001b[1;3m['echo \"Hello World\"']\u001b[0m\n", "Answer: \u001b[33;1m\u001b[1;3mHello World\n", "\u001b[0m\n", "\u001b[1m> Finished chain.\u001b[0m\n" @@ -65,7 +65,7 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -93,7 +93,7 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -107,8 +107,8 @@ "\n", "```bash\n", "printf \"Hello World\\n\"\n", - "```\u001b[0m['```bash', 'printf \"Hello World\\\\n\"', '```']\n", - "\n", + "```\u001b[0m\n", + "Code: \u001b[33;1m\u001b[1;3m['printf \"Hello World\\\\n\"']\u001b[0m\n", "Answer: \u001b[33;1m\u001b[1;3mHello World\n", "\u001b[0m\n", "\u001b[1m> Finished chain.\u001b[0m\n" @@ -120,7 +120,7 @@ "'Hello World\\n'" ] }, - "execution_count": 29, + "execution_count": 3, "metadata": {}, "output_type": "execute_result" } @@ -132,6 +132,114 @@ "\n", "bash_chain.run(text)" ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Persistent Terminal\n", + "\n", + "By default, the chain will run in a separate subprocess each time it is called. This behavior can be changed by instantiating with a persistent bash process." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new LLMBashChain chain...\u001b[0m\n", + "List the current directory then move up a level.\u001b[32;1m\u001b[1;3m\n", + "\n", + "```bash\n", + "ls\n", + "cd ..\n", + "```\u001b[0m\n", + "Code: \u001b[33;1m\u001b[1;3m['ls', 'cd ..']\u001b[0m\n", + "Answer: \u001b[33;1m\u001b[1;3mapi.ipynb\t\t\tllm_summarization_checker.ipynb\n", + "constitutional_chain.ipynb\tmoderation.ipynb\n", + "llm_bash.ipynb\t\t\topenai_openapi.yaml\n", + "llm_checker.ipynb\t\topenapi.ipynb\n", + "llm_math.ipynb\t\t\tpal.ipynb\n", + "llm_requests.ipynb\t\tsqlite.ipynb\u001b[0m\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "'api.ipynb\\t\\t\\tllm_summarization_checker.ipynb\\r\\nconstitutional_chain.ipynb\\tmoderation.ipynb\\r\\nllm_bash.ipynb\\t\\t\\topenai_openapi.yaml\\r\\nllm_checker.ipynb\\t\\topenapi.ipynb\\r\\nllm_math.ipynb\\t\\t\\tpal.ipynb\\r\\nllm_requests.ipynb\\t\\tsqlite.ipynb'" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from langchain.utilities.bash import BashProcess\n", + "\n", + "\n", + "persistent_process = BashProcess(persistent=True)\n", + "bash_chain = LLMBashChain.from_bash_process(llm=llm, bash_process=persistent_process, verbose=True)\n", + "\n", + "text = \"List the current directory then move up a level.\"\n", + "\n", + "bash_chain.run(text)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new LLMBashChain chain...\u001b[0m\n", + "List the current directory then move up a level.\u001b[32;1m\u001b[1;3m\n", + "\n", + "```bash\n", + "ls\n", + "cd ..\n", + "```\u001b[0m\n", + "Code: \u001b[33;1m\u001b[1;3m['ls', 'cd ..']\u001b[0m\n", + "Answer: \u001b[33;1m\u001b[1;3mexamples\t\tgetting_started.ipynb\tindex_examples\n", + "generic\t\t\thow_to_guides.rst\u001b[0m\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "'examples\\t\\tgetting_started.ipynb\\tindex_examples\\r\\ngeneric\\t\\t\\thow_to_guides.rst'" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Run the same command again and see that the state is maintained between calls\n", + "bash_chain.run(text)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { @@ -150,7 +258,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.6" + "version": "3.8.16" } }, "nbformat": 4, diff --git a/langchain/chains/llm_bash/base.py b/langchain/chains/llm_bash/base.py index 9a9f44b7..c2ae218f 100644 --- a/langchain/chains/llm_bash/base.py +++ b/langchain/chains/llm_bash/base.py @@ -1,15 +1,46 @@ """Chain that interprets a prompt and executes bash code to perform bash operations.""" -from typing import Dict, List +import logging +import re +from typing import Any, Dict, List -from pydantic import Extra +from pydantic import Extra, Field from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.chains.llm_bash.prompt import PROMPT from langchain.prompts.base import BasePromptTemplate -from langchain.schema import BaseLanguageModel +from 
langchain.schema import BaseLanguageModel, BaseOutputParser, OutputParserException from langchain.utilities.bash import BashProcess +logger = logging.getLogger(__name__) + + +class BashOutputParser(BaseOutputParser): + """Parser for bash output.""" + + def parse(self, text: str) -> List[str]: + if "```bash" in text: + return self.get_code_blocks(text) + else: + raise OutputParserException( + f"Failed to parse bash output. Got: {text}", + ) + + @staticmethod + def get_code_blocks(t: str) -> List[str]: + """Get multiple code blocks from the LLM result.""" + code_blocks: List[str] = [] + # Bash markdown code blocks + pattern = re.compile(r"```bash(.*?)(?:\n\s*)```", re.DOTALL) + for match in pattern.finditer(t): + matched = match.group(1).strip() + if matched: + code_blocks.extend( + [line for line in matched.split("\n") if line.strip()] + ) + + return code_blocks + class LLMBashChain(Chain): """Chain that interprets a prompt and executes bash code to perform bash operations. @@ -26,6 +57,8 @@ class LLMBashChain(Chain): input_key: str = "question" #: :meta private: output_key: str = "answer" #: :meta private: prompt: BasePromptTemplate = PROMPT + output_parser: BaseOutputParser = Field(default_factory=BashOutputParser) + bash_process: BashProcess = Field(default_factory=BashProcess) #: :meta private: class Config: """Configuration for this pydantic object.""" @@ -51,29 +84,40 @@ class LLMBashChain(Chain): def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: llm_executor = LLMChain(prompt=self.prompt, llm=self.llm) - bash_executor = BashProcess() + self.callback_manager.on_text(inputs[self.input_key], verbose=self.verbose) t = llm_executor.predict(question=inputs[self.input_key]) self.callback_manager.on_text(t, color="green", verbose=self.verbose) - t = t.strip() - if t.startswith("```bash"): - # Split the string into a list of substrings - command_list = t.split("\n") - print(command_list) + try: + command_list = self.output_parser.parse(t) + except OutputParserException as e: + self.callback_manager.on_chain_error(e, verbose=self.verbose) + raise e - # Remove the first and last substrings - command_list = [s for s in command_list[1:-1]] - output = bash_executor.run(command_list) + if self.verbose: + self.callback_manager.on_text("\nCode: ", verbose=self.verbose) + self.callback_manager.on_text( + str(command_list), color="yellow", verbose=self.verbose + ) - self.callback_manager.on_text("\nAnswer: ", verbose=self.verbose) - self.callback_manager.on_text(output, color="yellow", verbose=self.verbose) + output = self.bash_process.run(command_list) - else: - raise ValueError(f"unknown format from LLM: {t}") + self.callback_manager.on_text("\nAnswer: ", verbose=self.verbose) + self.callback_manager.on_text(output, color="yellow", verbose=self.verbose) return {self.output_key: output} @property def _chain_type(self) -> str: return "llm_bash_chain" + + @classmethod + def from_bash_process( + cls, + bash_process: BashProcess, + llm: BaseLanguageModel, + **kwargs: Any, + ) -> "LLMBashChain": + """Create a LLMBashChain from a BashProcess.""" + return cls(llm=llm, bash_process=bash_process, **kwargs) diff --git a/langchain/utilities/bash.py b/langchain/utilities/bash.py index d4bcf73d..2a37f1e7 100644 --- a/langchain/utilities/bash.py +++ b/langchain/utilities/bash.py @@ -1,24 +1,59 @@ """Wrapper around subprocess to run commands.""" +import re import subprocess from typing import List, Union +from uuid import uuid4 + +import pexpect class BashProcess: """Executes bash commands and returns the 
output.""" - def __init__(self, strip_newlines: bool = False, return_err_output: bool = False): + def __init__( + self, + strip_newlines: bool = False, + return_err_output: bool = False, + persistent: bool = False, + ): """Initialize with stripping newlines.""" self.strip_newlines = strip_newlines self.return_err_output = return_err_output + self.prompt = "" + self.process = None + if persistent: + self.prompt = str(uuid4()) + self.process = self._initialize_persistent_process(self.prompt) + + @staticmethod + def _initialize_persistent_process(prompt: str) -> pexpect.spawn: + # Start bash in a clean environment + process = pexpect.spawn( + "env", ["-i", "bash", "--norc", "--noprofile"], encoding="utf-8" + ) + # Set the custom prompt + process.sendline("PS1=" + prompt) + + process.expect_exact(prompt, timeout=10) + return process def run(self, commands: Union[str, List[str]]) -> str: """Run commands and return final output.""" if isinstance(commands, str): commands = [commands] commands = ";".join(commands) + if self.process is not None: + return self._run_persistent( + commands, + ) + else: + return self._run(commands) + + def _run(self, command: str) -> str: + """Run commands and return final output.""" try: output = subprocess.run( - commands, + command, shell=True, check=True, stdout=subprocess.PIPE, @@ -31,3 +66,31 @@ class BashProcess: if self.strip_newlines: output = output.strip() return output + + def process_output(self, output: str, command: str) -> str: + # Remove the command from the output using a regular expression + pattern = re.escape(command) + r"\s*\n" + output = re.sub(pattern, "", output, count=1) + return output.strip() + + def _run_persistent(self, command: str) -> str: + """Run commands and return final output.""" + if self.process is None: + raise ValueError("Process not initialized") + self.process.sendline(command) + + # Clear the output with an empty string + self.process.expect(self.prompt, timeout=10) + self.process.sendline("") + + try: + self.process.expect([self.prompt, pexpect.EOF], timeout=10) + except pexpect.TIMEOUT: + return f"Timeout error while executing command {command}" + if self.process.after == pexpect.EOF: + return f"Exited with error status: {self.process.exitstatus}" + output = self.process.before + output = self.process_output(output, command) + if self.strip_newlines: + return output.strip() + return output diff --git a/tests/unit_tests/chains/test_llm_bash.py b/tests/unit_tests/chains/test_llm_bash.py index fdf3c6fe..3e20e356 100644 --- a/tests/unit_tests/chains/test_llm_bash.py +++ b/tests/unit_tests/chains/test_llm_bash.py @@ -3,26 +3,107 @@ import sys import pytest -from langchain.chains.llm_bash.base import LLMBashChain +from langchain.chains.llm_bash.base import BashOutputParser, LLMBashChain from langchain.chains.llm_bash.prompt import _PROMPT_TEMPLATE +from langchain.schema import OutputParserException from tests.unit_tests.llms.fake_llm import FakeLLM +_SAMPLE_CODE = """ +Unrelated text +```bash +echo hello +``` +Unrelated text +""" + + +_SAMPLE_CODE_2_LINES = """ +Unrelated text +```bash +echo hello + +echo world +``` +Unrelated text +""" + @pytest.fixture -def fake_llm_bash_chain() -> LLMBashChain: - """Fake LLM Bash chain for testing.""" - question = "Please write a bash script that prints 'Hello World' to the console." 
- prompt = _PROMPT_TEMPLATE.format(question=question) - queries = {prompt: "```bash\nexpr 1 + 1\n```"} - fake_llm = FakeLLM(queries=queries) - return LLMBashChain(llm=fake_llm, input_key="q", output_key="a") +def output_parser() -> BashOutputParser: + """Output parser for testing.""" + return BashOutputParser() @pytest.mark.skipif( sys.platform.startswith("win"), reason="Test not supported on Windows" ) -def test_simple_question(fake_llm_bash_chain: LLMBashChain) -> None: +def test_simple_question() -> None: """Test simple question that should not need python.""" question = "Please write a bash script that prints 'Hello World' to the console." + prompt = _PROMPT_TEMPLATE.format(question=question) + queries = {prompt: "```bash\nexpr 1 + 1\n```"} + fake_llm = FakeLLM(queries=queries) + fake_llm_bash_chain = LLMBashChain(llm=fake_llm, input_key="q", output_key="a") output = fake_llm_bash_chain.run(question) assert output == "2\n" + + +def test_get_code(output_parser: BashOutputParser) -> None: + """Test the parser.""" + code_lines = output_parser.parse(_SAMPLE_CODE) + code = [c for c in code_lines if c.strip()] + assert code == code_lines + assert code == ["echo hello"] + + code_lines = output_parser.parse(_SAMPLE_CODE + _SAMPLE_CODE_2_LINES) + assert code_lines == ["echo hello", "echo hello", "echo world"] + + +def test_parsing_error() -> None: + """Test that LLM Output without a bash block raises an exce""" + question = "Please echo 'hello world' to the terminal." + prompt = _PROMPT_TEMPLATE.format(question=question) + queries = { + prompt: """ +```text +echo 'hello world' +``` +""" + } + fake_llm = FakeLLM(queries=queries) + fake_llm_bash_chain = LLMBashChain(llm=fake_llm, input_key="q", output_key="a") + with pytest.raises(OutputParserException): + fake_llm_bash_chain.run(question) + + +def test_get_code_lines_mixed_blocks(output_parser: BashOutputParser) -> None: + text = """ +Unrelated text +```bash +echo hello +ls && pwd && ls +``` + +```python +print("hello") +``` + +```bash +echo goodbye +``` +""" + code_lines = output_parser.parse(text) + assert code_lines == ["echo hello", "ls && pwd && ls", "echo goodbye"] + + +def test_get_code_lines_simple_nested_ticks(output_parser: BashOutputParser) -> None: + """Test that backticks w/o a newline are ignored.""" + text = """ +Unrelated text +```bash +echo hello +echo "```bash is in this string```" +``` +""" + code_lines = output_parser.parse(text) + assert code_lines == ["echo hello", 'echo "```bash is in this string```"'] diff --git a/tests/unit_tests/test_bash.py b/tests/unit_tests/test_bash.py index e1891adf..29b1ce3b 100644 --- a/tests/unit_tests/test_bash.py +++ b/tests/unit_tests/test_bash.py @@ -21,6 +21,23 @@ def test_pwd_command() -> None: assert output == subprocess.check_output("pwd", shell=True).decode() +@pytest.mark.skipif( + sys.platform.startswith("win"), reason="Test not supported on Windows" +) +def test_pwd_command_persistent() -> None: + """Test correct functionality when the bash process is persistent.""" + session = BashProcess(persistent=True, strip_newlines=True) + commands = ["pwd"] + output = session.run(commands) + + assert subprocess.check_output("pwd", shell=True).decode().strip() in output + + session.run(["cd .."]) + new_output = session.run(["pwd"]) + # Assert that the new_output is a parent of the old output + assert Path(output).parent == Path(new_output) + + @pytest.mark.skipif( sys.platform.startswith("win"), reason="Test not supported on Windows" ) @@ -66,3 +83,16 @@ def 
test_create_directory_and_files(tmp_path: Path) -> None: # check that the files were created in the temporary directory output = session.run([f"ls {temp_dir}"]) assert output == "file1.txt\nfile2.txt" + + +@pytest.mark.skipif( + sys.platform.startswith("win"), reason="Test not supported on Windows" +) +def test_create_bash_persistent() -> None: + """Test the pexpect persistent bash terminal""" + session = BashProcess(persistent=True) + response = session.run("echo hello") + response += session.run("echo world") + + assert "hello" in response + assert "world" in response
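
Below is a minimal usage sketch (not part of the patch) that ties the new pieces together: the persistent `BashProcess` introduced in `langchain/utilities/bash.py` and the `BashOutputParser` added in `langchain/chains/llm_bash/base.py`. It assumes this patch is applied and `pexpect` is installed; the sample "LLM reply" string is invented purely for illustration.

```python
# Illustrative sketch only -- assumes this patch is applied and pexpect is installed.
from langchain.chains.llm_bash.base import BashOutputParser
from langchain.utilities.bash import BashProcess

# Persistent session: the same pexpect-backed shell is reused across run()
# calls, so state such as the working directory carries over between calls.
session = BashProcess(persistent=True, strip_newlines=True)
print(session.run("pwd"))
session.run("cd /tmp")
print(session.run("pwd"))  # reflects the earlier `cd`

# The parser pulls command lines out of bash-fenced code blocks in the model's
# reply and raises OutputParserException when no such block is present.
# The reply below is a made-up example.
ticks = "`" * 3
fake_llm_reply = f"Sure, here you go:\n{ticks}bash\necho hello\necho world\n{ticks}\n"
commands = BashOutputParser().parse(fake_llm_reply)
print(commands)               # ['echo hello', 'echo world']
print(session.run(commands))  # commands are joined with ';' and run in the session

# To drive this from a model, the patch also adds
# LLMBashChain.from_bash_process(llm=..., bash_process=session).
```

With the default non-persistent `BashProcess`, each `run()` call spawns a fresh subprocess, so the `cd` above would not carry over; `persistent=True` is what makes the working-directory change visible to the later `pwd`.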