forked from Archives/langchain
Persistent Bash Shell (#3580)
Clean up linting and make more idiomatic by using an output parser --------- Co-authored-by: FergusFettes <fergusfettes@gmail.com>
This commit is contained in:
parent
c5451f4298
commit
ee670c448e
@ -39,11 +39,27 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"apify.ipynb\n",
|
||||
"arxiv.ipynb\n",
|
||||
"bash.ipynb\n",
|
||||
"bing_search.ipynb\n",
|
||||
"chatgpt_plugins.ipynb\n",
|
||||
"ddg.ipynb\n",
|
||||
"google_places.ipynb\n",
|
||||
"google_search.ipynb\n",
|
||||
"google_serper.ipynb\n",
|
||||
"gradio_tools.ipynb\n",
|
||||
"human_tools.ipynb\n",
|
||||
"ifttt.ipynb\n",
|
||||
"openweathermap.ipynb\n",
|
||||
"python.ipynb\n",
|
||||
"requests.ipynb\n",
|
||||
"search_tools.ipynb\n",
|
||||
"searx_search.ipynb\n",
|
||||
"serpapi.ipynb\n",
|
||||
"wikipedia.ipynb\n",
|
||||
"wolfram_alpha.ipynb\n",
|
||||
"zapier.ipynb\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
@ -54,9 +70,94 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 4,
|
||||
"id": "e7896f8e",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"apify.ipynb\n",
|
||||
"arxiv.ipynb\n",
|
||||
"bash.ipynb\n",
|
||||
"bing_search.ipynb\n",
|
||||
"chatgpt_plugins.ipynb\n",
|
||||
"ddg.ipynb\n",
|
||||
"google_places.ipynb\n",
|
||||
"google_search.ipynb\n",
|
||||
"google_serper.ipynb\n",
|
||||
"gradio_tools.ipynb\n",
|
||||
"human_tools.ipynb\n",
|
||||
"ifttt.ipynb\n",
|
||||
"openweathermap.ipynb\n",
|
||||
"python.ipynb\n",
|
||||
"requests.ipynb\n",
|
||||
"search_tools.ipynb\n",
|
||||
"searx_search.ipynb\n",
|
||||
"serpapi.ipynb\n",
|
||||
"wikipedia.ipynb\n",
|
||||
"wolfram_alpha.ipynb\n",
|
||||
"zapier.ipynb\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"bash.run(\"cd ..\")\n",
|
||||
"# The commands are executed in a new subprocess each time, meaning that\n",
|
||||
"# this call will return the same results as the last.\n",
|
||||
"print(bash.run(\"ls\"))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"id": "851fee9f",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Terminal Persistance\n",
|
||||
"\n",
|
||||
"By default, the bash command will be executed in a new subprocess each time. To retain a persistent bash session, we can use the `persistent=True` arg."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "4a93ea2c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"bash = BashProcess(persistent=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "a1e98b78",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"custom_tools.ipynb\t\tmulti_input_tool.ipynb\n",
|
||||
"examples\t\t\ttool_input_validation.ipynb\n",
|
||||
"getting_started.md\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"bash.run(\"cd ..\")\n",
|
||||
"# Note the list of files is different\n",
|
||||
"print(bash.run(\"ls\"))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e13c1c9c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
@ -77,7 +178,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.9"
|
||||
"version": "3.8.16"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
@ -24,8 +24,8 @@
|
||||
"\n",
|
||||
"```bash\n",
|
||||
"echo \"Hello World\"\n",
|
||||
"```\u001b[0m['```bash', 'echo \"Hello World\"', '```']\n",
|
||||
"\n",
|
||||
"```\u001b[0m\n",
|
||||
"Code: \u001b[33;1m\u001b[1;3m['echo \"Hello World\"']\u001b[0m\n",
|
||||
"Answer: \u001b[33;1m\u001b[1;3mHello World\n",
|
||||
"\u001b[0m\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||
@ -65,7 +65,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 28,
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -93,7 +93,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 29,
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
@ -107,8 +107,8 @@
|
||||
"\n",
|
||||
"```bash\n",
|
||||
"printf \"Hello World\\n\"\n",
|
||||
"```\u001b[0m['```bash', 'printf \"Hello World\\\\n\"', '```']\n",
|
||||
"\n",
|
||||
"```\u001b[0m\n",
|
||||
"Code: \u001b[33;1m\u001b[1;3m['printf \"Hello World\\\\n\"']\u001b[0m\n",
|
||||
"Answer: \u001b[33;1m\u001b[1;3mHello World\n",
|
||||
"\u001b[0m\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||
@ -120,7 +120,7 @@
|
||||
"'Hello World\\n'"
|
||||
]
|
||||
},
|
||||
"execution_count": 29,
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@ -132,6 +132,114 @@
|
||||
"\n",
|
||||
"bash_chain.run(text)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Persistent Terminal\n",
|
||||
"\n",
|
||||
"By default, the chain will run in a separate subprocess each time it is called. This behavior can be changed by instantiating with a persistent bash process."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new LLMBashChain chain...\u001b[0m\n",
|
||||
"List the current directory then move up a level.\u001b[32;1m\u001b[1;3m\n",
|
||||
"\n",
|
||||
"```bash\n",
|
||||
"ls\n",
|
||||
"cd ..\n",
|
||||
"```\u001b[0m\n",
|
||||
"Code: \u001b[33;1m\u001b[1;3m['ls', 'cd ..']\u001b[0m\n",
|
||||
"Answer: \u001b[33;1m\u001b[1;3mapi.ipynb\t\t\tllm_summarization_checker.ipynb\n",
|
||||
"constitutional_chain.ipynb\tmoderation.ipynb\n",
|
||||
"llm_bash.ipynb\t\t\topenai_openapi.yaml\n",
|
||||
"llm_checker.ipynb\t\topenapi.ipynb\n",
|
||||
"llm_math.ipynb\t\t\tpal.ipynb\n",
|
||||
"llm_requests.ipynb\t\tsqlite.ipynb\u001b[0m\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'api.ipynb\\t\\t\\tllm_summarization_checker.ipynb\\r\\nconstitutional_chain.ipynb\\tmoderation.ipynb\\r\\nllm_bash.ipynb\\t\\t\\topenai_openapi.yaml\\r\\nllm_checker.ipynb\\t\\topenapi.ipynb\\r\\nllm_math.ipynb\\t\\t\\tpal.ipynb\\r\\nllm_requests.ipynb\\t\\tsqlite.ipynb'"
|
||||
]
|
||||
},
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from langchain.utilities.bash import BashProcess\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"persistent_process = BashProcess(persistent=True)\n",
|
||||
"bash_chain = LLMBashChain.from_bash_process(llm=llm, bash_process=persistent_process, verbose=True)\n",
|
||||
"\n",
|
||||
"text = \"List the current directory then move up a level.\"\n",
|
||||
"\n",
|
||||
"bash_chain.run(text)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new LLMBashChain chain...\u001b[0m\n",
|
||||
"List the current directory then move up a level.\u001b[32;1m\u001b[1;3m\n",
|
||||
"\n",
|
||||
"```bash\n",
|
||||
"ls\n",
|
||||
"cd ..\n",
|
||||
"```\u001b[0m\n",
|
||||
"Code: \u001b[33;1m\u001b[1;3m['ls', 'cd ..']\u001b[0m\n",
|
||||
"Answer: \u001b[33;1m\u001b[1;3mexamples\t\tgetting_started.ipynb\tindex_examples\n",
|
||||
"generic\t\t\thow_to_guides.rst\u001b[0m\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'examples\\t\\tgetting_started.ipynb\\tindex_examples\\r\\ngeneric\\t\\t\\thow_to_guides.rst'"
|
||||
]
|
||||
},
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Run the same command again and see that the state is maintained between calls\n",
|
||||
"bash_chain.run(text)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
@ -150,7 +258,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.6"
|
||||
"version": "3.8.16"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
@ -1,15 +1,46 @@
|
||||
"""Chain that interprets a prompt and executes bash code to perform bash operations."""
|
||||
from typing import Dict, List
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from pydantic import Extra
|
||||
from pydantic import Extra, Field
|
||||
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.llm_bash.prompt import PROMPT
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.schema import BaseLanguageModel
|
||||
from langchain.schema import BaseLanguageModel, BaseOutputParser, OutputParserException
|
||||
from langchain.utilities.bash import BashProcess
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BashOutputParser(BaseOutputParser):
|
||||
"""Parser for bash output."""
|
||||
|
||||
def parse(self, text: str) -> List[str]:
|
||||
if "```bash" in text:
|
||||
return self.get_code_blocks(text)
|
||||
else:
|
||||
raise OutputParserException(
|
||||
f"Failed to parse bash output. Got: {text}",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_code_blocks(t: str) -> List[str]:
|
||||
"""Get multiple code blocks from the LLM result."""
|
||||
code_blocks: List[str] = []
|
||||
# Bash markdown code blocks
|
||||
pattern = re.compile(r"```bash(.*?)(?:\n\s*)```", re.DOTALL)
|
||||
for match in pattern.finditer(t):
|
||||
matched = match.group(1).strip()
|
||||
if matched:
|
||||
code_blocks.extend(
|
||||
[line for line in matched.split("\n") if line.strip()]
|
||||
)
|
||||
|
||||
return code_blocks
|
||||
|
||||
|
||||
class LLMBashChain(Chain):
|
||||
"""Chain that interprets a prompt and executes bash code to perform bash operations.
|
||||
@ -26,6 +57,8 @@ class LLMBashChain(Chain):
|
||||
input_key: str = "question" #: :meta private:
|
||||
output_key: str = "answer" #: :meta private:
|
||||
prompt: BasePromptTemplate = PROMPT
|
||||
output_parser: BaseOutputParser = Field(default_factory=BashOutputParser)
|
||||
bash_process: BashProcess = Field(default_factory=BashProcess) #: :meta private:
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
@ -51,29 +84,40 @@ class LLMBashChain(Chain):
|
||||
|
||||
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||
llm_executor = LLMChain(prompt=self.prompt, llm=self.llm)
|
||||
bash_executor = BashProcess()
|
||||
|
||||
self.callback_manager.on_text(inputs[self.input_key], verbose=self.verbose)
|
||||
|
||||
t = llm_executor.predict(question=inputs[self.input_key])
|
||||
self.callback_manager.on_text(t, color="green", verbose=self.verbose)
|
||||
|
||||
t = t.strip()
|
||||
if t.startswith("```bash"):
|
||||
# Split the string into a list of substrings
|
||||
command_list = t.split("\n")
|
||||
print(command_list)
|
||||
try:
|
||||
command_list = self.output_parser.parse(t)
|
||||
except OutputParserException as e:
|
||||
self.callback_manager.on_chain_error(e, verbose=self.verbose)
|
||||
raise e
|
||||
|
||||
# Remove the first and last substrings
|
||||
command_list = [s for s in command_list[1:-1]]
|
||||
output = bash_executor.run(command_list)
|
||||
if self.verbose:
|
||||
self.callback_manager.on_text("\nCode: ", verbose=self.verbose)
|
||||
self.callback_manager.on_text(
|
||||
str(command_list), color="yellow", verbose=self.verbose
|
||||
)
|
||||
|
||||
output = self.bash_process.run(command_list)
|
||||
|
||||
self.callback_manager.on_text("\nAnswer: ", verbose=self.verbose)
|
||||
self.callback_manager.on_text(output, color="yellow", verbose=self.verbose)
|
||||
|
||||
else:
|
||||
raise ValueError(f"unknown format from LLM: {t}")
|
||||
return {self.output_key: output}
|
||||
|
||||
@property
|
||||
def _chain_type(self) -> str:
|
||||
return "llm_bash_chain"
|
||||
|
||||
@classmethod
|
||||
def from_bash_process(
|
||||
cls,
|
||||
bash_process: BashProcess,
|
||||
llm: BaseLanguageModel,
|
||||
**kwargs: Any,
|
||||
) -> "LLMBashChain":
|
||||
"""Create a LLMBashChain from a BashProcess."""
|
||||
return cls(llm=llm, bash_process=bash_process, **kwargs)
|
||||
|
@ -1,24 +1,59 @@
|
||||
"""Wrapper around subprocess to run commands."""
|
||||
import re
|
||||
import subprocess
|
||||
from typing import List, Union
|
||||
from uuid import uuid4
|
||||
|
||||
import pexpect
|
||||
|
||||
|
||||
class BashProcess:
|
||||
"""Executes bash commands and returns the output."""
|
||||
|
||||
def __init__(self, strip_newlines: bool = False, return_err_output: bool = False):
|
||||
def __init__(
|
||||
self,
|
||||
strip_newlines: bool = False,
|
||||
return_err_output: bool = False,
|
||||
persistent: bool = False,
|
||||
):
|
||||
"""Initialize with stripping newlines."""
|
||||
self.strip_newlines = strip_newlines
|
||||
self.return_err_output = return_err_output
|
||||
self.prompt = ""
|
||||
self.process = None
|
||||
if persistent:
|
||||
self.prompt = str(uuid4())
|
||||
self.process = self._initialize_persistent_process(self.prompt)
|
||||
|
||||
@staticmethod
|
||||
def _initialize_persistent_process(prompt: str) -> pexpect.spawn:
|
||||
# Start bash in a clean environment
|
||||
process = pexpect.spawn(
|
||||
"env", ["-i", "bash", "--norc", "--noprofile"], encoding="utf-8"
|
||||
)
|
||||
# Set the custom prompt
|
||||
process.sendline("PS1=" + prompt)
|
||||
|
||||
process.expect_exact(prompt, timeout=10)
|
||||
return process
|
||||
|
||||
def run(self, commands: Union[str, List[str]]) -> str:
|
||||
"""Run commands and return final output."""
|
||||
if isinstance(commands, str):
|
||||
commands = [commands]
|
||||
commands = ";".join(commands)
|
||||
if self.process is not None:
|
||||
return self._run_persistent(
|
||||
commands,
|
||||
)
|
||||
else:
|
||||
return self._run(commands)
|
||||
|
||||
def _run(self, command: str) -> str:
|
||||
"""Run commands and return final output."""
|
||||
try:
|
||||
output = subprocess.run(
|
||||
commands,
|
||||
command,
|
||||
shell=True,
|
||||
check=True,
|
||||
stdout=subprocess.PIPE,
|
||||
@ -31,3 +66,31 @@ class BashProcess:
|
||||
if self.strip_newlines:
|
||||
output = output.strip()
|
||||
return output
|
||||
|
||||
def process_output(self, output: str, command: str) -> str:
|
||||
# Remove the command from the output using a regular expression
|
||||
pattern = re.escape(command) + r"\s*\n"
|
||||
output = re.sub(pattern, "", output, count=1)
|
||||
return output.strip()
|
||||
|
||||
def _run_persistent(self, command: str) -> str:
|
||||
"""Run commands and return final output."""
|
||||
if self.process is None:
|
||||
raise ValueError("Process not initialized")
|
||||
self.process.sendline(command)
|
||||
|
||||
# Clear the output with an empty string
|
||||
self.process.expect(self.prompt, timeout=10)
|
||||
self.process.sendline("")
|
||||
|
||||
try:
|
||||
self.process.expect([self.prompt, pexpect.EOF], timeout=10)
|
||||
except pexpect.TIMEOUT:
|
||||
return f"Timeout error while executing command {command}"
|
||||
if self.process.after == pexpect.EOF:
|
||||
return f"Exited with error status: {self.process.exitstatus}"
|
||||
output = self.process.before
|
||||
output = self.process_output(output, command)
|
||||
if self.strip_newlines:
|
||||
return output.strip()
|
||||
return output
|
||||
|
@ -3,26 +3,107 @@ import sys
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.chains.llm_bash.base import LLMBashChain
|
||||
from langchain.chains.llm_bash.base import BashOutputParser, LLMBashChain
|
||||
from langchain.chains.llm_bash.prompt import _PROMPT_TEMPLATE
|
||||
from langchain.schema import OutputParserException
|
||||
from tests.unit_tests.llms.fake_llm import FakeLLM
|
||||
|
||||
_SAMPLE_CODE = """
|
||||
Unrelated text
|
||||
```bash
|
||||
echo hello
|
||||
```
|
||||
Unrelated text
|
||||
"""
|
||||
|
||||
|
||||
_SAMPLE_CODE_2_LINES = """
|
||||
Unrelated text
|
||||
```bash
|
||||
echo hello
|
||||
|
||||
echo world
|
||||
```
|
||||
Unrelated text
|
||||
"""
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_llm_bash_chain() -> LLMBashChain:
|
||||
"""Fake LLM Bash chain for testing."""
|
||||
question = "Please write a bash script that prints 'Hello World' to the console."
|
||||
prompt = _PROMPT_TEMPLATE.format(question=question)
|
||||
queries = {prompt: "```bash\nexpr 1 + 1\n```"}
|
||||
fake_llm = FakeLLM(queries=queries)
|
||||
return LLMBashChain(llm=fake_llm, input_key="q", output_key="a")
|
||||
def output_parser() -> BashOutputParser:
|
||||
"""Output parser for testing."""
|
||||
return BashOutputParser()
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.platform.startswith("win"), reason="Test not supported on Windows"
|
||||
)
|
||||
def test_simple_question(fake_llm_bash_chain: LLMBashChain) -> None:
|
||||
def test_simple_question() -> None:
|
||||
"""Test simple question that should not need python."""
|
||||
question = "Please write a bash script that prints 'Hello World' to the console."
|
||||
prompt = _PROMPT_TEMPLATE.format(question=question)
|
||||
queries = {prompt: "```bash\nexpr 1 + 1\n```"}
|
||||
fake_llm = FakeLLM(queries=queries)
|
||||
fake_llm_bash_chain = LLMBashChain(llm=fake_llm, input_key="q", output_key="a")
|
||||
output = fake_llm_bash_chain.run(question)
|
||||
assert output == "2\n"
|
||||
|
||||
|
||||
def test_get_code(output_parser: BashOutputParser) -> None:
|
||||
"""Test the parser."""
|
||||
code_lines = output_parser.parse(_SAMPLE_CODE)
|
||||
code = [c for c in code_lines if c.strip()]
|
||||
assert code == code_lines
|
||||
assert code == ["echo hello"]
|
||||
|
||||
code_lines = output_parser.parse(_SAMPLE_CODE + _SAMPLE_CODE_2_LINES)
|
||||
assert code_lines == ["echo hello", "echo hello", "echo world"]
|
||||
|
||||
|
||||
def test_parsing_error() -> None:
|
||||
"""Test that LLM Output without a bash block raises an exce"""
|
||||
question = "Please echo 'hello world' to the terminal."
|
||||
prompt = _PROMPT_TEMPLATE.format(question=question)
|
||||
queries = {
|
||||
prompt: """
|
||||
```text
|
||||
echo 'hello world'
|
||||
```
|
||||
"""
|
||||
}
|
||||
fake_llm = FakeLLM(queries=queries)
|
||||
fake_llm_bash_chain = LLMBashChain(llm=fake_llm, input_key="q", output_key="a")
|
||||
with pytest.raises(OutputParserException):
|
||||
fake_llm_bash_chain.run(question)
|
||||
|
||||
|
||||
def test_get_code_lines_mixed_blocks(output_parser: BashOutputParser) -> None:
|
||||
text = """
|
||||
Unrelated text
|
||||
```bash
|
||||
echo hello
|
||||
ls && pwd && ls
|
||||
```
|
||||
|
||||
```python
|
||||
print("hello")
|
||||
```
|
||||
|
||||
```bash
|
||||
echo goodbye
|
||||
```
|
||||
"""
|
||||
code_lines = output_parser.parse(text)
|
||||
assert code_lines == ["echo hello", "ls && pwd && ls", "echo goodbye"]
|
||||
|
||||
|
||||
def test_get_code_lines_simple_nested_ticks(output_parser: BashOutputParser) -> None:
|
||||
"""Test that backticks w/o a newline are ignored."""
|
||||
text = """
|
||||
Unrelated text
|
||||
```bash
|
||||
echo hello
|
||||
echo "```bash is in this string```"
|
||||
```
|
||||
"""
|
||||
code_lines = output_parser.parse(text)
|
||||
assert code_lines == ["echo hello", 'echo "```bash is in this string```"']
|
||||
|
@ -21,6 +21,23 @@ def test_pwd_command() -> None:
|
||||
assert output == subprocess.check_output("pwd", shell=True).decode()
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.platform.startswith("win"), reason="Test not supported on Windows"
|
||||
)
|
||||
def test_pwd_command_persistent() -> None:
|
||||
"""Test correct functionality when the bash process is persistent."""
|
||||
session = BashProcess(persistent=True, strip_newlines=True)
|
||||
commands = ["pwd"]
|
||||
output = session.run(commands)
|
||||
|
||||
assert subprocess.check_output("pwd", shell=True).decode().strip() in output
|
||||
|
||||
session.run(["cd .."])
|
||||
new_output = session.run(["pwd"])
|
||||
# Assert that the new_output is a parent of the old output
|
||||
assert Path(output).parent == Path(new_output)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.platform.startswith("win"), reason="Test not supported on Windows"
|
||||
)
|
||||
@ -66,3 +83,16 @@ def test_create_directory_and_files(tmp_path: Path) -> None:
|
||||
# check that the files were created in the temporary directory
|
||||
output = session.run([f"ls {temp_dir}"])
|
||||
assert output == "file1.txt\nfile2.txt"
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.platform.startswith("win"), reason="Test not supported on Windows"
|
||||
)
|
||||
def test_create_bash_persistent() -> None:
|
||||
"""Test the pexpect persistent bash terminal"""
|
||||
session = BashProcess(persistent=True)
|
||||
response = session.run("echo hello")
|
||||
response += session.run("echo world")
|
||||
|
||||
assert "hello" in response
|
||||
assert "world" in response
|
||||
|
Loading…
Reference in New Issue
Block a user