class LLMBashChain(Chain):
    """Chain that asks an LLM for bash commands and then executes them.

    The LLM answer is turned into a list of commands by the output parser
    attached to ``llm_chain.prompt``; that list is handed to
    ``bash_process.run`` and the resulting text is returned under
    ``output_key``.

    Example:
        .. code-block:: python

            from langchain.llms import OpenAI
            from langchain_experimental.llm_bash.base import LLMBashChain

            llm_bash = LLMBashChain.from_llm(OpenAI())
    """

    llm_chain: LLMChain
    llm: Optional[BaseLanguageModel] = None
    """[Deprecated] LLM wrapper to use."""
    input_key: str = "question"  #: :meta private:
    output_key: str = "answer"  #: :meta private:
    prompt: BasePromptTemplate = PROMPT
    """[Deprecated]"""
    bash_process: BashProcess = Field(default_factory=BashProcess)  #: :meta private:

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @root_validator(pre=True)
    def raise_deprecation(cls, values: Dict) -> Dict:
        """Warn when ``llm`` is passed directly and derive ``llm_chain`` from it.

        Runs before field validation. If ``llm`` is given without an
        ``llm_chain``, an ``LLMChain`` is built from ``llm`` and the supplied
        ``prompt`` (falling back to the module ``PROMPT``).
        """
        if "llm" in values:
            warnings.warn(
                "Directly instantiating an LLMBashChain with an llm is deprecated. "
                "Please instantiate with llm_chain or using the from_llm class method."
            )
            if "llm_chain" not in values and values["llm"] is not None:
                prompt = values.get("prompt", PROMPT)
                values["llm_chain"] = LLMChain(llm=values["llm"], prompt=prompt)
        return values

    @root_validator
    def validate_prompt(cls, values: Dict) -> Dict:
        """Reject an ``llm_chain`` whose prompt carries no output parser.

        Raises:
            ValueError: if ``llm_chain.prompt.output_parser`` is ``None``.
        """
        if values["llm_chain"].prompt.output_parser is None:
            raise ValueError(
                "The prompt used by llm_chain is expected to have an output_parser."
            )
        return values

    @property
    def input_keys(self) -> List[str]:
        """Expect input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Expect output key.

        :meta private:
        """
        return [self.output_key]

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Ask the LLM for commands, run them, and return their output.

        Args:
            inputs: chain inputs; the question is read from ``input_key``.
            run_manager: callback manager for this run; a no-op manager is
                used when none is supplied.

        Returns:
            A dict mapping ``output_key`` to the text returned by
            ``bash_process.run``.

        Raises:
            OutputParserException: re-raised (after being reported through
                ``on_chain_error``) when the LLM answer cannot be parsed.
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        question = inputs[self.input_key]
        _run_manager.on_text(question, verbose=self.verbose)

        llm_output = self.llm_chain.predict(
            question=question, callbacks=_run_manager.get_child()
        )
        _run_manager.on_text(llm_output, color="green", verbose=self.verbose)
        llm_output = llm_output.strip()
        try:
            parser = self.llm_chain.prompt.output_parser
            command_list = parser.parse(llm_output)  # type: ignore[union-attr]
        except OutputParserException as e:
            _run_manager.on_chain_error(e, verbose=self.verbose)
            # Bare raise keeps the original exception and traceback intact.
            raise

        if self.verbose:
            _run_manager.on_text("\nCode: ", verbose=self.verbose)
            _run_manager.on_text(
                str(command_list), color="yellow", verbose=self.verbose
            )
        output = self.bash_process.run(command_list)
        _run_manager.on_text("\nAnswer: ", verbose=self.verbose)
        _run_manager.on_text(output, color="yellow", verbose=self.verbose)
        return {self.output_key: output}

    @property
    def _chain_type(self) -> str:
        return "llm_bash_chain"

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        prompt: BasePromptTemplate = PROMPT,
        **kwargs: Any,
    ) -> LLMBashChain:
        """Build the chain from an LLM and a prompt.

        Args:
            llm: language model wrapped in a new ``LLMChain``.
            prompt: prompt for that ``LLMChain``; defaults to the module
                ``PROMPT``.
            **kwargs: forwarded unchanged to the ``LLMBashChain`` constructor.
        """
        llm_chain = LLMChain(llm=llm, prompt=prompt)
        return cls(llm_chain=llm_chain, **kwargs)
class BashProcess:
    """Run shell commands, one-shot or inside a persistent bash session.

    One-shot commands are executed with the built-in ``subprocess.run()``.
    Persistent processes are **not** available on Windows systems, as
    pexpect makes use of Unix pseudoterminals (ptys). MacOS and Linux
    are okay.

    Example:
        .. code-block:: python

            from langchain_experimental.llm_bash.bash import BashProcess

            bash = BashProcess(
                strip_newlines = False,
                return_err_output = False,
                persistent = False
            )
            bash.run('echo \'hello world\'')

    """

    strip_newlines: bool = False
    """Whether or not to run .strip() on the output"""
    return_err_output: bool = False
    """Whether or not to return the output of a failed
    command, or just the error message and stacktrace"""
    persistent: bool = False
    """Whether or not to spawn a persistent session
    NOTE: Unavailable for Windows environments"""

    def __init__(
        self,
        strip_newlines: bool = False,
        return_err_output: bool = False,
        persistent: bool = False,
    ):
        """Store the settings and, if requested, start a persistent session.

        Args:
            strip_newlines: run ``.strip()`` on returned output.
            return_err_output: on a failing command return its captured
                output instead of the ``CalledProcessError`` message.
            persistent: spawn a pexpect-driven bash session that is reused
                across ``run`` calls.
        """
        self.strip_newlines = strip_newlines
        self.return_err_output = return_err_output
        # Record the flag on the instance so the documented attribute reflects
        # how this object was actually configured.
        self.persistent = persistent
        self.prompt = ""
        self.process = None
        if persistent:
            # A random UUID serves as the shell prompt marker.
            self.prompt = str(uuid4())
            self.process = self._initialize_persistent_process(self, self.prompt)

    @staticmethod
    def _lazy_import_pexpect() -> pexpect:
        """Import pexpect only when needed.

        Raises:
            ValueError: on Windows, where persistent sessions are unsupported.
            ImportError: if pexpect is not installed.
        """
        if platform.system() == "Windows":
            raise ValueError(
                "Persistent bash processes are not yet supported on Windows."
            )
        try:
            import pexpect

        except ImportError as err:
            raise ImportError(
                "pexpect required for persistent bash processes."
                " To install, run `pip install pexpect`."
            ) from err
        return pexpect

    @staticmethod
    def _initialize_persistent_process(self: BashProcess, prompt: str) -> pexpect.spawn:
        """Spawn a persistent bash session in a clean environment.

        NOTE: Unavailable on Windows. This is a staticmethod that receives
        the instance explicitly; ``__init__`` calls it that way.

        Args:
            prompt: the string installed as the shell's ``PS1`` prompt.
        """
        pexpect = self._lazy_import_pexpect()
        # `env -i` starts bash with an empty environment; rc/profile are skipped.
        process = pexpect.spawn(
            "env", ["-i", "bash", "--norc", "--noprofile"], encoding="utf-8"
        )
        # Set the custom prompt
        process.sendline("PS1=" + prompt)

        process.expect_exact(prompt, timeout=10)
        return process

    def run(self, commands: Union[str, List[str]]) -> str:
        """Run commands in the persistent session, or in a new subprocess.

        Args:
            commands: a single command string or a list of commands; a list
                is joined with ``;`` into one command line.

        Returns:
            The text produced by the command(s).
        """
        if isinstance(commands, str):
            commands = [commands]
        commands = ";".join(commands)
        if self.process is not None:
            return self._run_persistent(
                commands,
            )
        else:
            return self._run(commands)

    def _run(self, command: str) -> str:
        """Run a command in a fresh subprocess and return its output.

        stderr is merged into stdout. On a non-zero exit status the result
        is the failed command's output when ``return_err_output`` is set,
        otherwise the ``CalledProcessError`` message.

        Args:
            command: The command to run
        """
        # NOTE(security): the command line is executed through the shell as-is.
        try:
            completed = subprocess.run(
                command,
                shell=True,
                check=True,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
            )
        except subprocess.CalledProcessError as error:
            if not self.return_err_output:
                return str(error)
            raw = error.stdout
        else:
            raw = completed.stdout
        # Commands may emit arbitrary bytes; replace undecodable ones rather
        # than raising UnicodeDecodeError.
        output = raw.decode(errors="replace")
        # Apply strip_newlines to failure output as well as success output.
        if self.strip_newlines:
            output = output.strip()
        return output

    def process_output(self, output: str, command: str) -> str:
        """Remove the first echo of the command from the output and strip it.

        Args:
            output: a process' output string
            command: the executed command
        """
        pattern = re.escape(command) + r"\s*\n"
        output = re.sub(pattern, "", output, count=1)
        return output.strip()

    def _run_persistent(self, command: str) -> str:
        """Run a command in the persistent session and return its output.

        Args:
            command: the command to execute

        Raises:
            ValueError: if no persistent process has been started.
        """
        pexpect = self._lazy_import_pexpect()
        if self.process is None:
            raise ValueError("Process not initialized")
        self.process.sendline(command)

        # Clear the output with an empty string
        # NOTE(review): this expect is outside the try below, so a
        # pexpect.TIMEOUT here propagates to the caller - confirm intended.
        self.process.expect(self.prompt, timeout=10)
        self.process.sendline("")

        try:
            self.process.expect([self.prompt, pexpect.EOF], timeout=10)
        except pexpect.TIMEOUT:
            return f"Timeout error while executing command {command}"
        if self.process.after == pexpect.EOF:
            return f"Exited with error status: {self.process.exitstatus}"
        output = self.process.before
        output = self.process_output(output, command)
        if self.strip_newlines:
            return output.strip()
        return output
class BashOutputParser(BaseOutputParser):
    """Extract shell commands from fenced ``bash`` blocks in LLM output."""

    def parse(self, text: str) -> List[str]:
        """Return the command lines found in ``text``.

        Raises:
            OutputParserException: if ``text`` contains no bash fence marker.
        """
        if "```bash" not in text:
            raise OutputParserException(
                f"Failed to parse bash output. Got: {text}",
            )
        return self.get_code_blocks(text)

    @staticmethod
    def get_code_blocks(t: str) -> List[str]:
        """Collect the non-blank lines of every bash code block in ``t``."""
        commands: List[str] = []
        # A block only closes on a fence that follows a newline, so backticks
        # in the middle of a command line do not end it.
        for body in re.findall(r"```bash(.*?)(?:\n\s*)```", t, re.DOTALL):
            commands.extend(
                line for line in body.strip().split("\n") if line.strip()
            )
        return commands

    @property
    def _type(self) -> str:
        return "bash"
@pytest.mark.skipif(
    sys.platform.startswith("win"), reason="Test not supported on Windows"
)
def test_incorrect_command() -> None:
    """By default a failing command yields the CalledProcessError message."""
    result = BashProcess().run(["invalid_command"])
    assert result == "Command 'invalid_command' returned non-zero exit status 127."


@pytest.mark.skipif(
    sys.platform.startswith("win"), reason="Test not supported on Windows"
)
def test_incorrect_command_return_err_output() -> None:
    """With return_err_output set, the shell's own error text is returned."""
    shell_error = BashProcess(return_err_output=True).run(["invalid_command"])
    expected = r"^/bin/sh:.*invalid_command.*(?:not found|Permission denied).*$"
    assert re.match(expected, shell_error)


@pytest.mark.skipif(
    sys.platform.startswith("win"), reason="Test not supported on Windows"
)
def test_create_directory_and_files(tmp_path: Path) -> None:
    """Create files through the shell and read them back."""
    bash = BashProcess(strip_newlines=True)

    # Work inside a dedicated subdirectory of pytest's temporary directory.
    work_dir = tmp_path / "test_dir"
    work_dir.mkdir()

    written = bash.run(
        [
            f"touch {work_dir}/file1.txt",
            f"touch {work_dir}/file2.txt",
            f"echo 'hello world' > {work_dir}/file2.txt",
            f"cat {work_dir}/file2.txt",
        ]
    )
    assert written == "hello world"

    # Both files must now exist in the temporary directory.
    listing = bash.run([f"ls {work_dir}"])
    assert listing == "file1.txt\nfile2.txt"
_SAMPLE_CODE = """
Unrelated text
```bash
echo hello
```
Unrelated text
"""


_SAMPLE_CODE_2_LINES = """
Unrelated text
```bash
echo hello

echo world
```
Unrelated text
"""


@pytest.fixture
def output_parser() -> BashOutputParser:
    """Provide a fresh BashOutputParser for each test."""
    return BashOutputParser()


@pytest.mark.skipif(
    sys.platform.startswith("win"), reason="Test not supported on Windows"
)
def test_simple_question() -> None:
    """Run the chain end to end with a fake LLM that answers with `expr 1 + 1`."""
    question = "Please write a bash script that prints 'Hello World' to the console."
    canned_answers = {
        _PROMPT_TEMPLATE.format(question=question): "```bash\nexpr 1 + 1\n```"
    }
    chain = LLMBashChain.from_llm(
        FakeLLM(queries=canned_answers), input_key="q", output_key="a"
    )
    assert chain.run(question) == "2\n"


def test_get_code(output_parser: BashOutputParser) -> None:
    """The parser returns only non-blank command lines, across blocks."""
    single_block = output_parser.parse(_SAMPLE_CODE)
    non_blank = [line for line in single_block if line.strip()]
    assert non_blank == single_block
    assert non_blank == ["echo hello"]

    two_blocks = output_parser.parse(_SAMPLE_CODE + _SAMPLE_CODE_2_LINES)
    assert two_blocks == ["echo hello", "echo hello", "echo world"]


def test_parsing_error() -> None:
    """LLM output without a bash block raises an OutputParserException."""
    question = "Please echo 'hello world' to the terminal."
    canned_answers = {
        _PROMPT_TEMPLATE.format(question=question): """
```text
echo 'hello world'
```
"""
    }
    chain = LLMBashChain.from_llm(
        FakeLLM(queries=canned_answers), input_key="q", output_key="a"
    )
    with pytest.raises(OutputParserException):
        chain.run(question)


def test_get_code_lines_mixed_blocks(output_parser: BashOutputParser) -> None:
    """Only bash-fenced blocks contribute commands; other fences are skipped."""
    text = """
Unrelated text
```bash
echo hello
ls && pwd && ls
```

```python
print("hello")
```

```bash
echo goodbye
```
"""
    assert output_parser.parse(text) == [
        "echo hello",
        "ls && pwd && ls",
        "echo goodbye",
    ]


def test_get_code_lines_simple_nested_ticks(output_parser: BashOutputParser) -> None:
    """Test that backticks w/o a newline are ignored."""
    text = """
Unrelated text
```bash
echo hello
echo "```bash is in this string```"
```
"""
    assert output_parser.parse(text) == [
        "echo hello",
        'echo "```bash is in this string```"',
    ]