Persistent Bash Shell (#3580)

Clean up linting and make more idiomatic by using an output parser

---------

Co-authored-by: FergusFettes <fergusfettes@gmail.com>
This commit is contained in:
Zander Chase 2023-04-26 15:20:28 -07:00 committed by GitHub
parent c5451f4298
commit ee670c448e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 464 additions and 37 deletions

View File

@ -39,11 +39,27 @@
"name": "stdout",
"output_type": "stream",
"text": [
"apify.ipynb\n",
"arxiv.ipynb\n",
"bash.ipynb\n",
"bing_search.ipynb\n",
"chatgpt_plugins.ipynb\n",
"ddg.ipynb\n",
"google_places.ipynb\n",
"google_search.ipynb\n",
"google_serper.ipynb\n",
"gradio_tools.ipynb\n",
"human_tools.ipynb\n",
"ifttt.ipynb\n",
"openweathermap.ipynb\n",
"python.ipynb\n",
"requests.ipynb\n",
"search_tools.ipynb\n",
"searx_search.ipynb\n",
"serpapi.ipynb\n",
"wikipedia.ipynb\n",
"wolfram_alpha.ipynb\n",
"zapier.ipynb\n",
"\n"
]
}
@ -54,9 +70,94 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"id": "e7896f8e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"apify.ipynb\n",
"arxiv.ipynb\n",
"bash.ipynb\n",
"bing_search.ipynb\n",
"chatgpt_plugins.ipynb\n",
"ddg.ipynb\n",
"google_places.ipynb\n",
"google_search.ipynb\n",
"google_serper.ipynb\n",
"gradio_tools.ipynb\n",
"human_tools.ipynb\n",
"ifttt.ipynb\n",
"openweathermap.ipynb\n",
"python.ipynb\n",
"requests.ipynb\n",
"search_tools.ipynb\n",
"searx_search.ipynb\n",
"serpapi.ipynb\n",
"wikipedia.ipynb\n",
"wolfram_alpha.ipynb\n",
"zapier.ipynb\n",
"\n"
]
}
],
"source": [
"bash.run(\"cd ..\")\n",
"# The commands are executed in a new subprocess each time, meaning that\n",
"# this call will return the same results as the last.\n",
"print(bash.run(\"ls\"))"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "851fee9f",
"metadata": {},
"source": [
"## Terminal Persistence\n",
"\n",
"By default, the bash command will be executed in a new subprocess each time. To retain a persistent bash session, we can use the `persistent=True` arg."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "4a93ea2c",
"metadata": {},
"outputs": [],
"source": [
"bash = BashProcess(persistent=True)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "a1e98b78",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"custom_tools.ipynb\t\tmulti_input_tool.ipynb\n",
"examples\t\t\ttool_input_validation.ipynb\n",
"getting_started.md\n"
]
}
],
"source": [
"bash.run(\"cd ..\")\n",
"# Note the list of files is different\n",
"print(bash.run(\"ls\"))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e13c1c9c",
"metadata": {},
"outputs": [],
"source": []
}
@ -77,7 +178,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
"version": "3.8.16"
}
},
"nbformat": 4,

View File

@ -24,8 +24,8 @@
"\n",
"```bash\n",
"echo \"Hello World\"\n",
"```\u001b[0m['```bash', 'echo \"Hello World\"', '```']\n",
"\n",
"```\u001b[0m\n",
"Code: \u001b[33;1m\u001b[1;3m['echo \"Hello World\"']\u001b[0m\n",
"Answer: \u001b[33;1m\u001b[1;3mHello World\n",
"\u001b[0m\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
@ -65,7 +65,7 @@
},
{
"cell_type": "code",
"execution_count": 28,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
@ -93,7 +93,7 @@
},
{
"cell_type": "code",
"execution_count": 29,
"execution_count": 3,
"metadata": {},
"outputs": [
{
@ -107,8 +107,8 @@
"\n",
"```bash\n",
"printf \"Hello World\\n\"\n",
"```\u001b[0m['```bash', 'printf \"Hello World\\\\n\"', '```']\n",
"\n",
"```\u001b[0m\n",
"Code: \u001b[33;1m\u001b[1;3m['printf \"Hello World\\\\n\"']\u001b[0m\n",
"Answer: \u001b[33;1m\u001b[1;3mHello World\n",
"\u001b[0m\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
@ -120,7 +120,7 @@
"'Hello World\\n'"
]
},
"execution_count": 29,
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
@ -132,6 +132,114 @@
"\n",
"bash_chain.run(text)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Persistent Terminal\n",
"\n",
"By default, the chain will run in a separate subprocess each time it is called. This behavior can be changed by instantiating with a persistent bash process."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new LLMBashChain chain...\u001b[0m\n",
"List the current directory then move up a level.\u001b[32;1m\u001b[1;3m\n",
"\n",
"```bash\n",
"ls\n",
"cd ..\n",
"```\u001b[0m\n",
"Code: \u001b[33;1m\u001b[1;3m['ls', 'cd ..']\u001b[0m\n",
"Answer: \u001b[33;1m\u001b[1;3mapi.ipynb\t\t\tllm_summarization_checker.ipynb\n",
"constitutional_chain.ipynb\tmoderation.ipynb\n",
"llm_bash.ipynb\t\t\topenai_openapi.yaml\n",
"llm_checker.ipynb\t\topenapi.ipynb\n",
"llm_math.ipynb\t\t\tpal.ipynb\n",
"llm_requests.ipynb\t\tsqlite.ipynb\u001b[0m\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"'api.ipynb\\t\\t\\tllm_summarization_checker.ipynb\\r\\nconstitutional_chain.ipynb\\tmoderation.ipynb\\r\\nllm_bash.ipynb\\t\\t\\topenai_openapi.yaml\\r\\nllm_checker.ipynb\\t\\topenapi.ipynb\\r\\nllm_math.ipynb\\t\\t\\tpal.ipynb\\r\\nllm_requests.ipynb\\t\\tsqlite.ipynb'"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from langchain.utilities.bash import BashProcess\n",
"\n",
"\n",
"persistent_process = BashProcess(persistent=True)\n",
"bash_chain = LLMBashChain.from_bash_process(llm=llm, bash_process=persistent_process, verbose=True)\n",
"\n",
"text = \"List the current directory then move up a level.\"\n",
"\n",
"bash_chain.run(text)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new LLMBashChain chain...\u001b[0m\n",
"List the current directory then move up a level.\u001b[32;1m\u001b[1;3m\n",
"\n",
"```bash\n",
"ls\n",
"cd ..\n",
"```\u001b[0m\n",
"Code: \u001b[33;1m\u001b[1;3m['ls', 'cd ..']\u001b[0m\n",
"Answer: \u001b[33;1m\u001b[1;3mexamples\t\tgetting_started.ipynb\tindex_examples\n",
"generic\t\t\thow_to_guides.rst\u001b[0m\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"'examples\\t\\tgetting_started.ipynb\\tindex_examples\\r\\ngeneric\\t\\t\\thow_to_guides.rst'"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Run the same command again and see that the state is maintained between calls\n",
"bash_chain.run(text)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
@ -150,7 +258,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
"version": "3.8.16"
}
},
"nbformat": 4,

View File

@ -1,15 +1,46 @@
"""Chain that interprets a prompt and executes bash code to perform bash operations."""
from typing import Dict, List
import logging
import re
from typing import Any, Dict, List
from pydantic import Extra
from pydantic import Extra, Field
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.llm_bash.prompt import PROMPT
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel
from langchain.schema import BaseLanguageModel, BaseOutputParser, OutputParserException
from langchain.utilities.bash import BashProcess
logger = logging.getLogger(__name__)
class BashOutputParser(BaseOutputParser):
    """Extract shell commands from the ```bash fenced blocks of an LLM reply."""

    def parse(self, text: str) -> List[str]:
        """Return the non-blank command lines found in ``text``.

        Args:
            text: Raw LLM output, expected to contain at least one
                ```bash fenced block.

        Returns:
            The command lines of every bash block, in order of appearance.

        Raises:
            OutputParserException: If ``text`` has no ```bash fence, or if
                no command line could be extracted from it (for example
                the fence is never closed, or the block is empty).
        """
        code_blocks = self.get_code_blocks(text) if "```bash" in text else []
        if not code_blocks:
            # Without this check a reply with an unterminated or empty
            # fence would be handed back as an empty command list rather
            # than being reported as unparseable.
            raise OutputParserException(
                f"Failed to parse bash output. Got: {text}",
            )
        return code_blocks

    @staticmethod
    def get_code_blocks(t: str) -> List[str]:
        """Get multiple code blocks from the LLM result.

        Args:
            t: Text that may contain any number of ```bash fenced blocks.

        Returns:
            Every non-blank line inside the bash blocks, flattened into a
            single list. Empty when no complete block is found.
        """
        # The closing fence is only recognised after a newline (plus
        # optional whitespace), so backticks embedded in the middle of a
        # command line do not terminate the block.
        pattern = re.compile(r"```bash(.*?)(?:\n\s*)```", re.DOTALL)
        code_blocks: List[str] = []
        for match in pattern.finditer(t):
            block_lines = match.group(1).strip().split("\n")
            code_blocks.extend(line for line in block_lines if line.strip())
        return code_blocks
class LLMBashChain(Chain):
"""Chain that interprets a prompt and executes bash code to perform bash operations.
@ -26,6 +57,8 @@ class LLMBashChain(Chain):
input_key: str = "question" #: :meta private:
output_key: str = "answer" #: :meta private:
prompt: BasePromptTemplate = PROMPT
output_parser: BaseOutputParser = Field(default_factory=BashOutputParser)
bash_process: BashProcess = Field(default_factory=BashProcess) #: :meta private:
class Config:
"""Configuration for this pydantic object."""
@ -51,29 +84,40 @@ class LLMBashChain(Chain):
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
llm_executor = LLMChain(prompt=self.prompt, llm=self.llm)
bash_executor = BashProcess()
self.callback_manager.on_text(inputs[self.input_key], verbose=self.verbose)
t = llm_executor.predict(question=inputs[self.input_key])
self.callback_manager.on_text(t, color="green", verbose=self.verbose)
t = t.strip()
if t.startswith("```bash"):
# Split the string into a list of substrings
command_list = t.split("\n")
print(command_list)
try:
command_list = self.output_parser.parse(t)
except OutputParserException as e:
self.callback_manager.on_chain_error(e, verbose=self.verbose)
raise e
# Remove the first and last substrings
command_list = [s for s in command_list[1:-1]]
output = bash_executor.run(command_list)
if self.verbose:
self.callback_manager.on_text("\nCode: ", verbose=self.verbose)
self.callback_manager.on_text(
str(command_list), color="yellow", verbose=self.verbose
)
output = self.bash_process.run(command_list)
self.callback_manager.on_text("\nAnswer: ", verbose=self.verbose)
self.callback_manager.on_text(output, color="yellow", verbose=self.verbose)
else:
raise ValueError(f"unknown format from LLM: {t}")
return {self.output_key: output}
@property
def _chain_type(self) -> str:
    """Return the identifier string of this chain type."""
    chain_type = "llm_bash_chain"
    return chain_type
@classmethod
def from_bash_process(
    cls,
    bash_process: BashProcess,
    llm: BaseLanguageModel,
    **kwargs: Any,
) -> "LLMBashChain":
    """Build a chain that executes its commands through ``bash_process``.

    Args:
        bash_process: Process wrapper the chain should run commands with.
        llm: Language model passed to the chain.
        **kwargs: Extra keyword arguments forwarded to the constructor.

    Returns:
        A new instance of ``cls``.
    """
    init_kwargs = dict(kwargs, llm=llm, bash_process=bash_process)
    return cls(**init_kwargs)

View File

@ -1,24 +1,59 @@
"""Wrapper around subprocess to run commands."""
import re
import subprocess
from typing import List, Union
from uuid import uuid4
import pexpect
class BashProcess:
"""Executes bash commands and returns the output."""
def __init__(self, strip_newlines: bool = False, return_err_output: bool = False):
def __init__(
self,
strip_newlines: bool = False,
return_err_output: bool = False,
persistent: bool = False,
):
"""Initialize with stripping newlines."""
self.strip_newlines = strip_newlines
self.return_err_output = return_err_output
self.prompt = ""
self.process = None
if persistent:
self.prompt = str(uuid4())
self.process = self._initialize_persistent_process(self.prompt)
@staticmethod
def _initialize_persistent_process(prompt: str) -> pexpect.spawn:
    """Spawn a bash child process and set its ``PS1`` to ``prompt``.

    Args:
        prompt: Text assigned to ``PS1``; the method waits (up to 10
            seconds) for this text to appear before returning.

    Returns:
        The spawned ``pexpect`` process.
    """
    # Start bash in a clean environment
    bash_args = ["-i", "bash", "--norc", "--noprofile"]
    shell = pexpect.spawn("env", bash_args, encoding="utf-8")
    # Set the custom prompt
    shell.sendline(f"PS1={prompt}")
    shell.expect_exact(prompt, timeout=10)
    return shell
def run(self, commands: Union[str, List[str]]) -> str:
    """Execute one command, or a list of commands, and return the output.

    A single string is treated as a one-element list; the commands are
    then joined with ``;`` into one command line. That line goes to the
    persistent process when one exists, otherwise to a one-off run.
    """
    command_parts = [commands] if isinstance(commands, str) else commands
    command_line = ";".join(command_parts)
    if self.process is None:
        return self._run(command_line)
    return self._run_persistent(
        command_line,
    )
def _run(self, command: str) -> str:
"""Run commands and return final output."""
try:
output = subprocess.run(
commands,
command,
shell=True,
check=True,
stdout=subprocess.PIPE,
@ -31,3 +66,31 @@ class BashProcess:
if self.strip_newlines:
output = output.strip()
return output
def process_output(self, output: str, command: str) -> str:
    """Strip the first echoed copy of ``command`` from ``output``.

    Only the first occurrence of the literal command text followed by
    optional whitespace and a newline is removed; the remainder is
    returned with surrounding whitespace stripped.
    """
    # Remove the command from the output using a regular expression
    echoed_command = re.compile(re.escape(command) + r"\s*\n")
    without_echo = echoed_command.sub("", output, count=1)
    return without_echo.strip()
def _run_persistent(self, command: str) -> str:
    """Send ``command`` to the persistent process and return its output.

    Returns a timeout message if the second wait for the prompt times
    out, and an exit-status message if the process reached EOF instead.

    Raises:
        ValueError: If no persistent process has been initialized.
    """
    process = self.process
    if process is None:
        raise ValueError("Process not initialized")
    process.sendline(command)
    # Clear the output with an empty string
    # NOTE(review): this first wait sits outside the try below, so a
    # timeout here propagates to the caller - confirm this is intended.
    process.expect(self.prompt, timeout=10)
    process.sendline("")

    try:
        process.expect([self.prompt, pexpect.EOF], timeout=10)
    except pexpect.TIMEOUT:
        return f"Timeout error while executing command {command}"
    if process.after == pexpect.EOF:
        return f"Exited with error status: {process.exitstatus}"
    cleaned = self.process_output(process.before, command)
    return cleaned.strip() if self.strip_newlines else cleaned

View File

@ -3,26 +3,107 @@ import sys
import pytest
from langchain.chains.llm_bash.base import LLMBashChain
from langchain.chains.llm_bash.base import BashOutputParser, LLMBashChain
from langchain.chains.llm_bash.prompt import _PROMPT_TEMPLATE
from langchain.schema import OutputParserException
from tests.unit_tests.llms.fake_llm import FakeLLM
# Sample LLM reply: a single one-command bash block surrounded by
# unrelated text.
_SAMPLE_CODE = """
Unrelated text
```bash
echo hello
```
Unrelated text
"""
# Sample LLM reply: a single bash block containing two commands,
# surrounded by unrelated text.
_SAMPLE_CODE_2_LINES = """
Unrelated text
```bash
echo hello
echo world
```
Unrelated text
"""
@pytest.fixture
def fake_llm_bash_chain() -> LLMBashChain:
"""Fake LLM Bash chain for testing."""
question = "Please write a bash script that prints 'Hello World' to the console."
prompt = _PROMPT_TEMPLATE.format(question=question)
queries = {prompt: "```bash\nexpr 1 + 1\n```"}
fake_llm = FakeLLM(queries=queries)
return LLMBashChain(llm=fake_llm, input_key="q", output_key="a")
def output_parser() -> BashOutputParser:
"""Output parser for testing."""
return BashOutputParser()
@pytest.mark.skipif(
sys.platform.startswith("win"), reason="Test not supported on Windows"
)
def test_simple_question(fake_llm_bash_chain: LLMBashChain) -> None:
def test_simple_question() -> None:
"""Test simple question that should not need python."""
question = "Please write a bash script that prints 'Hello World' to the console."
prompt = _PROMPT_TEMPLATE.format(question=question)
queries = {prompt: "```bash\nexpr 1 + 1\n```"}
fake_llm = FakeLLM(queries=queries)
fake_llm_bash_chain = LLMBashChain(llm=fake_llm, input_key="q", output_key="a")
output = fake_llm_bash_chain.run(question)
assert output == "2\n"
def test_get_code(output_parser: BashOutputParser) -> None:
    """Check parsing of one bash block and of two concatenated samples."""
    parsed = output_parser.parse(_SAMPLE_CODE)
    non_blank = [line for line in parsed if line.strip()]
    # The parser itself must already have dropped blank lines.
    assert non_blank == parsed
    assert non_blank == ["echo hello"]

    combined = output_parser.parse(_SAMPLE_CODE + _SAMPLE_CODE_2_LINES)
    assert combined == ["echo hello", "echo hello", "echo world"]
def test_parsing_error() -> None:
    """Test that LLM output without a bash block raises an exception."""
    question = "Please echo 'hello world' to the terminal."
    llm_reply = """
```text
echo 'hello world'
```
"""
    fake_llm = FakeLLM(queries={_PROMPT_TEMPLATE.format(question=question): llm_reply})
    chain = LLMBashChain(llm=fake_llm, input_key="q", output_key="a")
    with pytest.raises(OutputParserException):
        chain.run(question)
def test_get_code_lines_mixed_blocks(output_parser: BashOutputParser) -> None:
    """Test that only bash blocks are collected when other fences are present."""
    llm_reply = """
Unrelated text
```bash
echo hello
ls && pwd && ls
```
```python
print("hello")
```
```bash
echo goodbye
```
"""
    expected = ["echo hello", "ls && pwd && ls", "echo goodbye"]
    assert output_parser.parse(llm_reply) == expected
def test_get_code_lines_simple_nested_ticks(output_parser: BashOutputParser) -> None:
    """Test that backticks not preceded by a newline do not end the block."""
    llm_reply = """
Unrelated text
```bash
echo hello
echo "```bash is in this string```"
```
"""
    expected = ["echo hello", 'echo "```bash is in this string```"']
    assert output_parser.parse(llm_reply) == expected

View File

@ -21,6 +21,23 @@ def test_pwd_command() -> None:
assert output == subprocess.check_output("pwd", shell=True).decode()
@pytest.mark.skipif(
    sys.platform.startswith("win"), reason="Test not supported on Windows"
)
def test_pwd_command_persistent() -> None:
    """Test that a persistent session keeps its working directory between runs."""
    session = BashProcess(persistent=True, strip_newlines=True)
    start_dir = session.run(["pwd"])
    expected_dir = subprocess.check_output("pwd", shell=True).decode().strip()
    assert expected_dir in start_dir

    session.run(["cd .."])
    parent_dir = session.run(["pwd"])
    # After `cd ..` the session must report the parent of the first directory.
    assert Path(start_dir).parent == Path(parent_dir)
@pytest.mark.skipif(
sys.platform.startswith("win"), reason="Test not supported on Windows"
)
@ -66,3 +83,16 @@ def test_create_directory_and_files(tmp_path: Path) -> None:
# check that the files were created in the temporary directory
output = session.run([f"ls {temp_dir}"])
assert output == "file1.txt\nfile2.txt"
@pytest.mark.skipif(
    sys.platform.startswith("win"), reason="Test not supported on Windows"
)
def test_create_bash_persistent() -> None:
    """Test that the pexpect-backed persistent terminal returns command output."""
    session = BashProcess(persistent=True)
    outputs = [session.run("echo hello"), session.run("echo world")]
    combined = "".join(outputs)
    for word in ("hello", "world"):
        assert word in combined