mirror of https://github.com/hwchase17/langchain
parent
29b9a890d4
commit
5e2d5047af
@ -0,0 +1 @@
|
||||
"""Chain that interprets a prompt and executes bash code to perform bash operations."""
|
@ -0,0 +1,125 @@
|
||||
"""Chain that interprets a prompt and executes bash operations."""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import warnings
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.schema import BasePromptTemplate, OutputParserException
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
|
||||
from langchain_experimental.llm_bash.bash import BashProcess
|
||||
from langchain_experimental.llm_bash.prompt import PROMPT
|
||||
from langchain_experimental.pydantic_v1 import Extra, Field, root_validator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LLMBashChain(Chain):
|
||||
"""Chain that interprets a prompt and executes bash operations.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.chains import LLMBashChain
|
||||
from langchain.llms import OpenAI
|
||||
llm_bash = LLMBashChain.from_llm(OpenAI())
|
||||
"""
|
||||
|
||||
llm_chain: LLMChain
|
||||
llm: Optional[BaseLanguageModel] = None
|
||||
"""[Deprecated] LLM wrapper to use."""
|
||||
input_key: str = "question" #: :meta private:
|
||||
output_key: str = "answer" #: :meta private:
|
||||
prompt: BasePromptTemplate = PROMPT
|
||||
"""[Deprecated]"""
|
||||
bash_process: BashProcess = Field(default_factory=BashProcess) #: :meta private:
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@root_validator(pre=True)
|
||||
def raise_deprecation(cls, values: Dict) -> Dict:
|
||||
if "llm" in values:
|
||||
warnings.warn(
|
||||
"Directly instantiating an LLMBashChain with an llm is deprecated. "
|
||||
"Please instantiate with llm_chain or using the from_llm class method."
|
||||
)
|
||||
if "llm_chain" not in values and values["llm"] is not None:
|
||||
prompt = values.get("prompt", PROMPT)
|
||||
values["llm_chain"] = LLMChain(llm=values["llm"], prompt=prompt)
|
||||
return values
|
||||
|
||||
@root_validator
|
||||
def validate_prompt(cls, values: Dict) -> Dict:
|
||||
if values["llm_chain"].prompt.output_parser is None:
|
||||
raise ValueError(
|
||||
"The prompt used by llm_chain is expected to have an output_parser."
|
||||
)
|
||||
return values
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
"""Expect input key.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.input_key]
|
||||
|
||||
@property
|
||||
def output_keys(self) -> List[str]:
|
||||
"""Expect output key.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.output_key]
|
||||
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, str]:
|
||||
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||
_run_manager.on_text(inputs[self.input_key], verbose=self.verbose)
|
||||
|
||||
t = self.llm_chain.predict(
|
||||
question=inputs[self.input_key], callbacks=_run_manager.get_child()
|
||||
)
|
||||
_run_manager.on_text(t, color="green", verbose=self.verbose)
|
||||
t = t.strip()
|
||||
try:
|
||||
parser = self.llm_chain.prompt.output_parser
|
||||
command_list = parser.parse(t) # type: ignore[union-attr]
|
||||
except OutputParserException as e:
|
||||
_run_manager.on_chain_error(e, verbose=self.verbose)
|
||||
raise e
|
||||
|
||||
if self.verbose:
|
||||
_run_manager.on_text("\nCode: ", verbose=self.verbose)
|
||||
_run_manager.on_text(
|
||||
str(command_list), color="yellow", verbose=self.verbose
|
||||
)
|
||||
output = self.bash_process.run(command_list)
|
||||
_run_manager.on_text("\nAnswer: ", verbose=self.verbose)
|
||||
_run_manager.on_text(output, color="yellow", verbose=self.verbose)
|
||||
return {self.output_key: output}
|
||||
|
||||
@property
|
||||
def _chain_type(self) -> str:
|
||||
return "llm_bash_chain"
|
||||
|
||||
@classmethod
|
||||
def from_llm(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
prompt: BasePromptTemplate = PROMPT,
|
||||
**kwargs: Any,
|
||||
) -> LLMBashChain:
|
||||
llm_chain = LLMChain(llm=llm, prompt=prompt)
|
||||
return cls(llm_chain=llm_chain, **kwargs)
|
@ -0,0 +1,184 @@
|
||||
"""Wrapper around subprocess to run commands."""
|
||||
from __future__ import annotations
|
||||
|
||||
import platform
|
||||
import re
|
||||
import subprocess
|
||||
from typing import TYPE_CHECKING, List, Union
|
||||
from uuid import uuid4
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import pexpect
|
||||
|
||||
|
||||
class BashProcess:
|
||||
"""
|
||||
Wrapper class for starting subprocesses.
|
||||
Uses the python built-in subprocesses.run()
|
||||
Persistent processes are **not** available
|
||||
on Windows systems, as pexpect makes use of
|
||||
Unix pseudoterminals (ptys). MacOS and Linux
|
||||
are okay.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.utilities.bash import BashProcess
|
||||
|
||||
bash = BashProcess(
|
||||
strip_newlines = False,
|
||||
return_err_output = False,
|
||||
persistent = False
|
||||
)
|
||||
bash.run('echo \'hello world\'')
|
||||
|
||||
"""
|
||||
|
||||
strip_newlines: bool = False
|
||||
"""Whether or not to run .strip() on the output"""
|
||||
return_err_output: bool = False
|
||||
"""Whether or not to return the output of a failed
|
||||
command, or just the error message and stacktrace"""
|
||||
persistent: bool = False
|
||||
"""Whether or not to spawn a persistent session
|
||||
NOTE: Unavailable for Windows environments"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
strip_newlines: bool = False,
|
||||
return_err_output: bool = False,
|
||||
persistent: bool = False,
|
||||
):
|
||||
"""
|
||||
Initializes with default settings
|
||||
"""
|
||||
self.strip_newlines = strip_newlines
|
||||
self.return_err_output = return_err_output
|
||||
self.prompt = ""
|
||||
self.process = None
|
||||
if persistent:
|
||||
self.prompt = str(uuid4())
|
||||
self.process = self._initialize_persistent_process(self, self.prompt)
|
||||
|
||||
@staticmethod
|
||||
def _lazy_import_pexpect() -> pexpect:
|
||||
"""Import pexpect only when needed."""
|
||||
if platform.system() == "Windows":
|
||||
raise ValueError(
|
||||
"Persistent bash processes are not yet supported on Windows."
|
||||
)
|
||||
try:
|
||||
import pexpect
|
||||
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"pexpect required for persistent bash processes."
|
||||
" To install, run `pip install pexpect`."
|
||||
)
|
||||
return pexpect
|
||||
|
||||
@staticmethod
|
||||
def _initialize_persistent_process(self: BashProcess, prompt: str) -> pexpect.spawn:
|
||||
# Start bash in a clean environment
|
||||
# Doesn't work on windows
|
||||
"""
|
||||
Initializes a persistent bash setting in a
|
||||
clean environment.
|
||||
NOTE: Unavailable on Windows
|
||||
|
||||
Args:
|
||||
Prompt(str): the bash command to execute
|
||||
""" # noqa: E501
|
||||
pexpect = self._lazy_import_pexpect()
|
||||
process = pexpect.spawn(
|
||||
"env", ["-i", "bash", "--norc", "--noprofile"], encoding="utf-8"
|
||||
)
|
||||
# Set the custom prompt
|
||||
process.sendline("PS1=" + prompt)
|
||||
|
||||
process.expect_exact(prompt, timeout=10)
|
||||
return process
|
||||
|
||||
def run(self, commands: Union[str, List[str]]) -> str:
|
||||
"""
|
||||
Run commands in either an existing persistent
|
||||
subprocess or on in a new subprocess environment.
|
||||
|
||||
Args:
|
||||
commands(List[str]): a list of commands to
|
||||
execute in the session
|
||||
""" # noqa: E501
|
||||
if isinstance(commands, str):
|
||||
commands = [commands]
|
||||
commands = ";".join(commands)
|
||||
if self.process is not None:
|
||||
return self._run_persistent(
|
||||
commands,
|
||||
)
|
||||
else:
|
||||
return self._run(commands)
|
||||
|
||||
def _run(self, command: str) -> str:
|
||||
"""
|
||||
Runs a command in a subprocess and returns
|
||||
the output.
|
||||
|
||||
Args:
|
||||
command: The command to run
|
||||
""" # noqa: E501
|
||||
try:
|
||||
output = subprocess.run(
|
||||
command,
|
||||
shell=True,
|
||||
check=True,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
).stdout.decode()
|
||||
except subprocess.CalledProcessError as error:
|
||||
if self.return_err_output:
|
||||
return error.stdout.decode()
|
||||
return str(error)
|
||||
if self.strip_newlines:
|
||||
output = output.strip()
|
||||
return output
|
||||
|
||||
def process_output(self, output: str, command: str) -> str:
|
||||
"""
|
||||
Uses regex to remove the command from the output
|
||||
|
||||
Args:
|
||||
output: a process' output string
|
||||
command: the executed command
|
||||
""" # noqa: E501
|
||||
pattern = re.escape(command) + r"\s*\n"
|
||||
output = re.sub(pattern, "", output, count=1)
|
||||
return output.strip()
|
||||
|
||||
def _run_persistent(self, command: str) -> str:
|
||||
"""
|
||||
Runs commands in a persistent environment
|
||||
and returns the output.
|
||||
|
||||
Args:
|
||||
command: the command to execute
|
||||
""" # noqa: E501
|
||||
pexpect = self._lazy_import_pexpect()
|
||||
if self.process is None:
|
||||
raise ValueError("Process not initialized")
|
||||
self.process.sendline(command)
|
||||
|
||||
# Clear the output with an empty string
|
||||
self.process.expect(self.prompt, timeout=10)
|
||||
self.process.sendline("")
|
||||
|
||||
try:
|
||||
self.process.expect([self.prompt, pexpect.EOF], timeout=10)
|
||||
except pexpect.TIMEOUT:
|
||||
return f"Timeout error while executing command {command}"
|
||||
if self.process.after == pexpect.EOF:
|
||||
return f"Exited with error status: {self.process.exitstatus}"
|
||||
output = self.process.before
|
||||
output = self.process_output(output, command)
|
||||
if self.strip_newlines:
|
||||
return output.strip()
|
||||
return output
|
@ -0,0 +1,64 @@
|
||||
# flake8: noqa
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import List
|
||||
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
from langchain.schema import BaseOutputParser, OutputParserException
|
||||
|
||||
_PROMPT_TEMPLATE = """If someone asks you to perform a task, your job is to come up with a series of bash commands that will perform the task. There is no need to put "#!/bin/bash" in your answer. Make sure to reason step by step, using this format:
|
||||
|
||||
Question: "copy the files in the directory named 'target' into a new directory at the same level as target called 'myNewDirectory'"
|
||||
|
||||
I need to take the following actions:
|
||||
- List all files in the directory
|
||||
- Create a new directory
|
||||
- Copy the files from the first directory into the second directory
|
||||
```bash
|
||||
ls
|
||||
mkdir myNewDirectory
|
||||
cp -r target/* myNewDirectory
|
||||
```
|
||||
|
||||
That is the format. Begin!
|
||||
|
||||
Question: {question}"""
|
||||
|
||||
|
||||
class BashOutputParser(BaseOutputParser):
|
||||
"""Parser for bash output."""
|
||||
|
||||
def parse(self, text: str) -> List[str]:
|
||||
if "```bash" in text:
|
||||
return self.get_code_blocks(text)
|
||||
else:
|
||||
raise OutputParserException(
|
||||
f"Failed to parse bash output. Got: {text}",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_code_blocks(t: str) -> List[str]:
|
||||
"""Get multiple code blocks from the LLM result."""
|
||||
code_blocks: List[str] = []
|
||||
# Bash markdown code blocks
|
||||
pattern = re.compile(r"```bash(.*?)(?:\n\s*)```", re.DOTALL)
|
||||
for match in pattern.finditer(t):
|
||||
matched = match.group(1).strip()
|
||||
if matched:
|
||||
code_blocks.extend(
|
||||
[line for line in matched.split("\n") if line.strip()]
|
||||
)
|
||||
|
||||
return code_blocks
|
||||
|
||||
@property
|
||||
def _type(self) -> str:
|
||||
return "bash"
|
||||
|
||||
|
||||
PROMPT = PromptTemplate(
|
||||
input_variables=["question"],
|
||||
template=_PROMPT_TEMPLATE,
|
||||
output_parser=BashOutputParser(),
|
||||
)
|
@ -0,0 +1,102 @@
|
||||
"""Test the bash utility."""
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain_experimental.llm_bash.bash import BashProcess
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.platform.startswith("win"), reason="Test not supported on Windows"
|
||||
)
|
||||
def test_pwd_command() -> None:
|
||||
"""Test correct functionality."""
|
||||
session = BashProcess()
|
||||
commands = ["pwd"]
|
||||
output = session.run(commands)
|
||||
|
||||
assert output == subprocess.check_output("pwd", shell=True).decode()
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="flaky on GHA, TODO to fix")
|
||||
@pytest.mark.skipif(
|
||||
sys.platform.startswith("win"), reason="Test not supported on Windows"
|
||||
)
|
||||
def test_pwd_command_persistent() -> None:
|
||||
"""Test correct functionality when the bash process is persistent."""
|
||||
session = BashProcess(persistent=True, strip_newlines=True)
|
||||
commands = ["pwd"]
|
||||
output = session.run(commands)
|
||||
|
||||
assert subprocess.check_output("pwd", shell=True).decode().strip() in output
|
||||
|
||||
session.run(["cd .."])
|
||||
new_output = session.run(["pwd"])
|
||||
# Assert that the new_output is a parent of the old output
|
||||
assert Path(output).parent == Path(new_output)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.platform.startswith("win"), reason="Test not supported on Windows"
|
||||
)
|
||||
def test_incorrect_command() -> None:
|
||||
"""Test handling of incorrect command."""
|
||||
session = BashProcess()
|
||||
output = session.run(["invalid_command"])
|
||||
assert output == "Command 'invalid_command' returned non-zero exit status 127."
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.platform.startswith("win"), reason="Test not supported on Windows"
|
||||
)
|
||||
def test_incorrect_command_return_err_output() -> None:
|
||||
"""Test optional returning of shell output on incorrect command."""
|
||||
session = BashProcess(return_err_output=True)
|
||||
output = session.run(["invalid_command"])
|
||||
assert re.match(
|
||||
r"^/bin/sh:.*invalid_command.*(?:not found|Permission denied).*$", output
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.platform.startswith("win"), reason="Test not supported on Windows"
|
||||
)
|
||||
def test_create_directory_and_files(tmp_path: Path) -> None:
|
||||
"""Test creation of a directory and files in a temporary directory."""
|
||||
session = BashProcess(strip_newlines=True)
|
||||
|
||||
# create a subdirectory in the temporary directory
|
||||
temp_dir = tmp_path / "test_dir"
|
||||
temp_dir.mkdir()
|
||||
|
||||
# run the commands in the temporary directory
|
||||
commands = [
|
||||
f"touch {temp_dir}/file1.txt",
|
||||
f"touch {temp_dir}/file2.txt",
|
||||
f"echo 'hello world' > {temp_dir}/file2.txt",
|
||||
f"cat {temp_dir}/file2.txt",
|
||||
]
|
||||
|
||||
output = session.run(commands)
|
||||
assert output == "hello world"
|
||||
|
||||
# check that the files were created in the temporary directory
|
||||
output = session.run([f"ls {temp_dir}"])
|
||||
assert output == "file1.txt\nfile2.txt"
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="flaky on GHA, TODO to fix")
|
||||
@pytest.mark.skipif(
|
||||
sys.platform.startswith("win"), reason="Test not supported on Windows"
|
||||
)
|
||||
def test_create_bash_persistent() -> None:
|
||||
"""Test the pexpect persistent bash terminal"""
|
||||
session = BashProcess(persistent=True)
|
||||
response = session.run("echo hello")
|
||||
response += session.run("echo world")
|
||||
|
||||
assert "hello" in response
|
||||
assert "world" in response
|
@ -0,0 +1,109 @@
|
||||
"""Test LLM Bash functionality."""
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
from langchain.schema import OutputParserException
|
||||
|
||||
from langchain_experimental.llm_bash.base import LLMBashChain
|
||||
from langchain_experimental.llm_bash.prompt import _PROMPT_TEMPLATE, BashOutputParser
|
||||
from tests.unit_tests.fake_llm import FakeLLM
|
||||
|
||||
_SAMPLE_CODE = """
|
||||
Unrelated text
|
||||
```bash
|
||||
echo hello
|
||||
```
|
||||
Unrelated text
|
||||
"""
|
||||
|
||||
|
||||
_SAMPLE_CODE_2_LINES = """
|
||||
Unrelated text
|
||||
```bash
|
||||
echo hello
|
||||
|
||||
echo world
|
||||
```
|
||||
Unrelated text
|
||||
"""
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def output_parser() -> BashOutputParser:
|
||||
"""Output parser for testing."""
|
||||
return BashOutputParser()
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.platform.startswith("win"), reason="Test not supported on Windows"
|
||||
)
|
||||
def test_simple_question() -> None:
|
||||
"""Test simple question that should not need python."""
|
||||
question = "Please write a bash script that prints 'Hello World' to the console."
|
||||
prompt = _PROMPT_TEMPLATE.format(question=question)
|
||||
queries = {prompt: "```bash\nexpr 1 + 1\n```"}
|
||||
fake_llm = FakeLLM(queries=queries)
|
||||
fake_llm_bash_chain = LLMBashChain.from_llm(fake_llm, input_key="q", output_key="a")
|
||||
output = fake_llm_bash_chain.run(question)
|
||||
assert output == "2\n"
|
||||
|
||||
|
||||
def test_get_code(output_parser: BashOutputParser) -> None:
|
||||
"""Test the parser."""
|
||||
code_lines = output_parser.parse(_SAMPLE_CODE)
|
||||
code = [c for c in code_lines if c.strip()]
|
||||
assert code == code_lines
|
||||
assert code == ["echo hello"]
|
||||
|
||||
code_lines = output_parser.parse(_SAMPLE_CODE + _SAMPLE_CODE_2_LINES)
|
||||
assert code_lines == ["echo hello", "echo hello", "echo world"]
|
||||
|
||||
|
||||
def test_parsing_error() -> None:
|
||||
"""Test that LLM Output without a bash block raises an exce"""
|
||||
question = "Please echo 'hello world' to the terminal."
|
||||
prompt = _PROMPT_TEMPLATE.format(question=question)
|
||||
queries = {
|
||||
prompt: """
|
||||
```text
|
||||
echo 'hello world'
|
||||
```
|
||||
"""
|
||||
}
|
||||
fake_llm = FakeLLM(queries=queries)
|
||||
fake_llm_bash_chain = LLMBashChain.from_llm(fake_llm, input_key="q", output_key="a")
|
||||
with pytest.raises(OutputParserException):
|
||||
fake_llm_bash_chain.run(question)
|
||||
|
||||
|
||||
def test_get_code_lines_mixed_blocks(output_parser: BashOutputParser) -> None:
|
||||
text = """
|
||||
Unrelated text
|
||||
```bash
|
||||
echo hello
|
||||
ls && pwd && ls
|
||||
```
|
||||
|
||||
```python
|
||||
print("hello")
|
||||
```
|
||||
|
||||
```bash
|
||||
echo goodbye
|
||||
```
|
||||
"""
|
||||
code_lines = output_parser.parse(text)
|
||||
assert code_lines == ["echo hello", "ls && pwd && ls", "echo goodbye"]
|
||||
|
||||
|
||||
def test_get_code_lines_simple_nested_ticks(output_parser: BashOutputParser) -> None:
|
||||
"""Test that backticks w/o a newline are ignored."""
|
||||
text = """
|
||||
Unrelated text
|
||||
```bash
|
||||
echo hello
|
||||
echo "```bash is in this string```"
|
||||
```
|
||||
"""
|
||||
code_lines = output_parser.parse(text)
|
||||
assert code_lines == ["echo hello", 'echo "```bash is in this string```"']
|
Loading…
Reference in New Issue