add LLMBashChain to experimental (#11305)

Add LLMBashChain to experimental
pull/11308/head
Eugene Yurtsev 1 year ago committed by GitHub
parent 29b9a890d4
commit 5e2d5047af
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1 @@
"""Chain that interprets a prompt and executes bash code to perform bash operations."""

@ -0,0 +1,125 @@
"""Chain that interprets a prompt and executes bash operations."""
from __future__ import annotations
import logging
import warnings
from typing import Any, Dict, List, Optional
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.schema import BasePromptTemplate, OutputParserException
from langchain.schema.language_model import BaseLanguageModel
from langchain_experimental.llm_bash.bash import BashProcess
from langchain_experimental.llm_bash.prompt import PROMPT
from langchain_experimental.pydantic_v1 import Extra, Field, root_validator
logger = logging.getLogger(__name__)
class LLMBashChain(Chain):
"""Chain that interprets a prompt and executes bash operations.
Example:
.. code-block:: python
from langchain.chains import LLMBashChain
from langchain.llms import OpenAI
llm_bash = LLMBashChain.from_llm(OpenAI())
"""
llm_chain: LLMChain
llm: Optional[BaseLanguageModel] = None
"""[Deprecated] LLM wrapper to use."""
input_key: str = "question" #: :meta private:
output_key: str = "answer" #: :meta private:
prompt: BasePromptTemplate = PROMPT
"""[Deprecated]"""
bash_process: BashProcess = Field(default_factory=BashProcess) #: :meta private:
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
@root_validator(pre=True)
def raise_deprecation(cls, values: Dict) -> Dict:
if "llm" in values:
warnings.warn(
"Directly instantiating an LLMBashChain with an llm is deprecated. "
"Please instantiate with llm_chain or using the from_llm class method."
)
if "llm_chain" not in values and values["llm"] is not None:
prompt = values.get("prompt", PROMPT)
values["llm_chain"] = LLMChain(llm=values["llm"], prompt=prompt)
return values
@root_validator
def validate_prompt(cls, values: Dict) -> Dict:
if values["llm_chain"].prompt.output_parser is None:
raise ValueError(
"The prompt used by llm_chain is expected to have an output_parser."
)
return values
@property
def input_keys(self) -> List[str]:
"""Expect input key.
:meta private:
"""
return [self.input_key]
@property
def output_keys(self) -> List[str]:
"""Expect output key.
:meta private:
"""
return [self.output_key]
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
_run_manager.on_text(inputs[self.input_key], verbose=self.verbose)
t = self.llm_chain.predict(
question=inputs[self.input_key], callbacks=_run_manager.get_child()
)
_run_manager.on_text(t, color="green", verbose=self.verbose)
t = t.strip()
try:
parser = self.llm_chain.prompt.output_parser
command_list = parser.parse(t) # type: ignore[union-attr]
except OutputParserException as e:
_run_manager.on_chain_error(e, verbose=self.verbose)
raise e
if self.verbose:
_run_manager.on_text("\nCode: ", verbose=self.verbose)
_run_manager.on_text(
str(command_list), color="yellow", verbose=self.verbose
)
output = self.bash_process.run(command_list)
_run_manager.on_text("\nAnswer: ", verbose=self.verbose)
_run_manager.on_text(output, color="yellow", verbose=self.verbose)
return {self.output_key: output}
@property
def _chain_type(self) -> str:
return "llm_bash_chain"
@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
prompt: BasePromptTemplate = PROMPT,
**kwargs: Any,
) -> LLMBashChain:
llm_chain = LLMChain(llm=llm, prompt=prompt)
return cls(llm_chain=llm_chain, **kwargs)

@ -0,0 +1,184 @@
"""Wrapper around subprocess to run commands."""
from __future__ import annotations
import platform
import re
import subprocess
from typing import TYPE_CHECKING, List, Union
from uuid import uuid4
if TYPE_CHECKING:
import pexpect
class BashProcess:
"""
Wrapper class for starting subprocesses.
Uses the python built-in subprocesses.run()
Persistent processes are **not** available
on Windows systems, as pexpect makes use of
Unix pseudoterminals (ptys). MacOS and Linux
are okay.
Example:
.. code-block:: python
from langchain.utilities.bash import BashProcess
bash = BashProcess(
strip_newlines = False,
return_err_output = False,
persistent = False
)
bash.run('echo \'hello world\'')
"""
strip_newlines: bool = False
"""Whether or not to run .strip() on the output"""
return_err_output: bool = False
"""Whether or not to return the output of a failed
command, or just the error message and stacktrace"""
persistent: bool = False
"""Whether or not to spawn a persistent session
NOTE: Unavailable for Windows environments"""
def __init__(
self,
strip_newlines: bool = False,
return_err_output: bool = False,
persistent: bool = False,
):
"""
Initializes with default settings
"""
self.strip_newlines = strip_newlines
self.return_err_output = return_err_output
self.prompt = ""
self.process = None
if persistent:
self.prompt = str(uuid4())
self.process = self._initialize_persistent_process(self, self.prompt)
@staticmethod
def _lazy_import_pexpect() -> pexpect:
"""Import pexpect only when needed."""
if platform.system() == "Windows":
raise ValueError(
"Persistent bash processes are not yet supported on Windows."
)
try:
import pexpect
except ImportError:
raise ImportError(
"pexpect required for persistent bash processes."
" To install, run `pip install pexpect`."
)
return pexpect
@staticmethod
def _initialize_persistent_process(self: BashProcess, prompt: str) -> pexpect.spawn:
# Start bash in a clean environment
# Doesn't work on windows
"""
Initializes a persistent bash setting in a
clean environment.
NOTE: Unavailable on Windows
Args:
Prompt(str): the bash command to execute
""" # noqa: E501
pexpect = self._lazy_import_pexpect()
process = pexpect.spawn(
"env", ["-i", "bash", "--norc", "--noprofile"], encoding="utf-8"
)
# Set the custom prompt
process.sendline("PS1=" + prompt)
process.expect_exact(prompt, timeout=10)
return process
def run(self, commands: Union[str, List[str]]) -> str:
"""
Run commands in either an existing persistent
subprocess or on in a new subprocess environment.
Args:
commands(List[str]): a list of commands to
execute in the session
""" # noqa: E501
if isinstance(commands, str):
commands = [commands]
commands = ";".join(commands)
if self.process is not None:
return self._run_persistent(
commands,
)
else:
return self._run(commands)
def _run(self, command: str) -> str:
"""
Runs a command in a subprocess and returns
the output.
Args:
command: The command to run
""" # noqa: E501
try:
output = subprocess.run(
command,
shell=True,
check=True,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
).stdout.decode()
except subprocess.CalledProcessError as error:
if self.return_err_output:
return error.stdout.decode()
return str(error)
if self.strip_newlines:
output = output.strip()
return output
def process_output(self, output: str, command: str) -> str:
"""
Uses regex to remove the command from the output
Args:
output: a process' output string
command: the executed command
""" # noqa: E501
pattern = re.escape(command) + r"\s*\n"
output = re.sub(pattern, "", output, count=1)
return output.strip()
def _run_persistent(self, command: str) -> str:
"""
Runs commands in a persistent environment
and returns the output.
Args:
command: the command to execute
""" # noqa: E501
pexpect = self._lazy_import_pexpect()
if self.process is None:
raise ValueError("Process not initialized")
self.process.sendline(command)
# Clear the output with an empty string
self.process.expect(self.prompt, timeout=10)
self.process.sendline("")
try:
self.process.expect([self.prompt, pexpect.EOF], timeout=10)
except pexpect.TIMEOUT:
return f"Timeout error while executing command {command}"
if self.process.after == pexpect.EOF:
return f"Exited with error status: {self.process.exitstatus}"
output = self.process.before
output = self.process_output(output, command)
if self.strip_newlines:
return output.strip()
return output

@ -0,0 +1,64 @@
# flake8: noqa
from __future__ import annotations
import re
from typing import List
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import BaseOutputParser, OutputParserException
_PROMPT_TEMPLATE = """If someone asks you to perform a task, your job is to come up with a series of bash commands that will perform the task. There is no need to put "#!/bin/bash" in your answer. Make sure to reason step by step, using this format:
Question: "copy the files in the directory named 'target' into a new directory at the same level as target called 'myNewDirectory'"
I need to take the following actions:
- List all files in the directory
- Create a new directory
- Copy the files from the first directory into the second directory
```bash
ls
mkdir myNewDirectory
cp -r target/* myNewDirectory
```
That is the format. Begin!
Question: {question}"""
class BashOutputParser(BaseOutputParser):
"""Parser for bash output."""
def parse(self, text: str) -> List[str]:
if "```bash" in text:
return self.get_code_blocks(text)
else:
raise OutputParserException(
f"Failed to parse bash output. Got: {text}",
)
@staticmethod
def get_code_blocks(t: str) -> List[str]:
"""Get multiple code blocks from the LLM result."""
code_blocks: List[str] = []
# Bash markdown code blocks
pattern = re.compile(r"```bash(.*?)(?:\n\s*)```", re.DOTALL)
for match in pattern.finditer(t):
matched = match.group(1).strip()
if matched:
code_blocks.extend(
[line for line in matched.split("\n") if line.strip()]
)
return code_blocks
@property
def _type(self) -> str:
return "bash"
PROMPT = PromptTemplate(
input_variables=["question"],
template=_PROMPT_TEMPLATE,
output_parser=BashOutputParser(),
)

@ -0,0 +1,102 @@
"""Test the bash utility."""
import re
import subprocess
import sys
from pathlib import Path
import pytest
from langchain_experimental.llm_bash.bash import BashProcess
@pytest.mark.skipif(
sys.platform.startswith("win"), reason="Test not supported on Windows"
)
def test_pwd_command() -> None:
"""Test correct functionality."""
session = BashProcess()
commands = ["pwd"]
output = session.run(commands)
assert output == subprocess.check_output("pwd", shell=True).decode()
@pytest.mark.skip(reason="flaky on GHA, TODO to fix")
@pytest.mark.skipif(
sys.platform.startswith("win"), reason="Test not supported on Windows"
)
def test_pwd_command_persistent() -> None:
"""Test correct functionality when the bash process is persistent."""
session = BashProcess(persistent=True, strip_newlines=True)
commands = ["pwd"]
output = session.run(commands)
assert subprocess.check_output("pwd", shell=True).decode().strip() in output
session.run(["cd .."])
new_output = session.run(["pwd"])
# Assert that the new_output is a parent of the old output
assert Path(output).parent == Path(new_output)
@pytest.mark.skipif(
sys.platform.startswith("win"), reason="Test not supported on Windows"
)
def test_incorrect_command() -> None:
"""Test handling of incorrect command."""
session = BashProcess()
output = session.run(["invalid_command"])
assert output == "Command 'invalid_command' returned non-zero exit status 127."
@pytest.mark.skipif(
sys.platform.startswith("win"), reason="Test not supported on Windows"
)
def test_incorrect_command_return_err_output() -> None:
"""Test optional returning of shell output on incorrect command."""
session = BashProcess(return_err_output=True)
output = session.run(["invalid_command"])
assert re.match(
r"^/bin/sh:.*invalid_command.*(?:not found|Permission denied).*$", output
)
@pytest.mark.skipif(
sys.platform.startswith("win"), reason="Test not supported on Windows"
)
def test_create_directory_and_files(tmp_path: Path) -> None:
"""Test creation of a directory and files in a temporary directory."""
session = BashProcess(strip_newlines=True)
# create a subdirectory in the temporary directory
temp_dir = tmp_path / "test_dir"
temp_dir.mkdir()
# run the commands in the temporary directory
commands = [
f"touch {temp_dir}/file1.txt",
f"touch {temp_dir}/file2.txt",
f"echo 'hello world' > {temp_dir}/file2.txt",
f"cat {temp_dir}/file2.txt",
]
output = session.run(commands)
assert output == "hello world"
# check that the files were created in the temporary directory
output = session.run([f"ls {temp_dir}"])
assert output == "file1.txt\nfile2.txt"
@pytest.mark.skip(reason="flaky on GHA, TODO to fix")
@pytest.mark.skipif(
sys.platform.startswith("win"), reason="Test not supported on Windows"
)
def test_create_bash_persistent() -> None:
"""Test the pexpect persistent bash terminal"""
session = BashProcess(persistent=True)
response = session.run("echo hello")
response += session.run("echo world")
assert "hello" in response
assert "world" in response

@ -0,0 +1,109 @@
"""Test LLM Bash functionality."""
import sys
import pytest
from langchain.schema import OutputParserException
from langchain_experimental.llm_bash.base import LLMBashChain
from langchain_experimental.llm_bash.prompt import _PROMPT_TEMPLATE, BashOutputParser
from tests.unit_tests.fake_llm import FakeLLM
_SAMPLE_CODE = """
Unrelated text
```bash
echo hello
```
Unrelated text
"""
_SAMPLE_CODE_2_LINES = """
Unrelated text
```bash
echo hello
echo world
```
Unrelated text
"""
@pytest.fixture
def output_parser() -> BashOutputParser:
"""Output parser for testing."""
return BashOutputParser()
@pytest.mark.skipif(
sys.platform.startswith("win"), reason="Test not supported on Windows"
)
def test_simple_question() -> None:
"""Test simple question that should not need python."""
question = "Please write a bash script that prints 'Hello World' to the console."
prompt = _PROMPT_TEMPLATE.format(question=question)
queries = {prompt: "```bash\nexpr 1 + 1\n```"}
fake_llm = FakeLLM(queries=queries)
fake_llm_bash_chain = LLMBashChain.from_llm(fake_llm, input_key="q", output_key="a")
output = fake_llm_bash_chain.run(question)
assert output == "2\n"
def test_get_code(output_parser: BashOutputParser) -> None:
"""Test the parser."""
code_lines = output_parser.parse(_SAMPLE_CODE)
code = [c for c in code_lines if c.strip()]
assert code == code_lines
assert code == ["echo hello"]
code_lines = output_parser.parse(_SAMPLE_CODE + _SAMPLE_CODE_2_LINES)
assert code_lines == ["echo hello", "echo hello", "echo world"]
def test_parsing_error() -> None:
"""Test that LLM Output without a bash block raises an exce"""
question = "Please echo 'hello world' to the terminal."
prompt = _PROMPT_TEMPLATE.format(question=question)
queries = {
prompt: """
```text
echo 'hello world'
```
"""
}
fake_llm = FakeLLM(queries=queries)
fake_llm_bash_chain = LLMBashChain.from_llm(fake_llm, input_key="q", output_key="a")
with pytest.raises(OutputParserException):
fake_llm_bash_chain.run(question)
def test_get_code_lines_mixed_blocks(output_parser: BashOutputParser) -> None:
text = """
Unrelated text
```bash
echo hello
ls && pwd && ls
```
```python
print("hello")
```
```bash
echo goodbye
```
"""
code_lines = output_parser.parse(text)
assert code_lines == ["echo hello", "ls && pwd && ls", "echo goodbye"]
def test_get_code_lines_simple_nested_ticks(output_parser: BashOutputParser) -> None:
"""Test that backticks w/o a newline are ignored."""
text = """
Unrelated text
```bash
echo hello
echo "```bash is in this string```"
```
"""
code_lines = output_parser.parse(text)
assert code_lines == ["echo hello", 'echo "```bash is in this string```"']
Loading…
Cancel
Save