forked from Archives/langchain
Persistent Bash Shell (#3580)
Clean up linting and make more idiomatic by using an output parser --------- Co-authored-by: FergusFettes <fergusfettes@gmail.com>
This commit is contained in:
parent
c5451f4298
commit
ee670c448e
@ -39,11 +39,27 @@
|
|||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
|
"apify.ipynb\n",
|
||||||
|
"arxiv.ipynb\n",
|
||||||
"bash.ipynb\n",
|
"bash.ipynb\n",
|
||||||
|
"bing_search.ipynb\n",
|
||||||
|
"chatgpt_plugins.ipynb\n",
|
||||||
|
"ddg.ipynb\n",
|
||||||
|
"google_places.ipynb\n",
|
||||||
"google_search.ipynb\n",
|
"google_search.ipynb\n",
|
||||||
|
"google_serper.ipynb\n",
|
||||||
|
"gradio_tools.ipynb\n",
|
||||||
|
"human_tools.ipynb\n",
|
||||||
|
"ifttt.ipynb\n",
|
||||||
|
"openweathermap.ipynb\n",
|
||||||
"python.ipynb\n",
|
"python.ipynb\n",
|
||||||
"requests.ipynb\n",
|
"requests.ipynb\n",
|
||||||
|
"search_tools.ipynb\n",
|
||||||
|
"searx_search.ipynb\n",
|
||||||
"serpapi.ipynb\n",
|
"serpapi.ipynb\n",
|
||||||
|
"wikipedia.ipynb\n",
|
||||||
|
"wolfram_alpha.ipynb\n",
|
||||||
|
"zapier.ipynb\n",
|
||||||
"\n"
|
"\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
@ -54,9 +70,94 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 4,
|
||||||
|
"id": "e7896f8e",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"apify.ipynb\n",
|
||||||
|
"arxiv.ipynb\n",
|
||||||
|
"bash.ipynb\n",
|
||||||
|
"bing_search.ipynb\n",
|
||||||
|
"chatgpt_plugins.ipynb\n",
|
||||||
|
"ddg.ipynb\n",
|
||||||
|
"google_places.ipynb\n",
|
||||||
|
"google_search.ipynb\n",
|
||||||
|
"google_serper.ipynb\n",
|
||||||
|
"gradio_tools.ipynb\n",
|
||||||
|
"human_tools.ipynb\n",
|
||||||
|
"ifttt.ipynb\n",
|
||||||
|
"openweathermap.ipynb\n",
|
||||||
|
"python.ipynb\n",
|
||||||
|
"requests.ipynb\n",
|
||||||
|
"search_tools.ipynb\n",
|
||||||
|
"searx_search.ipynb\n",
|
||||||
|
"serpapi.ipynb\n",
|
||||||
|
"wikipedia.ipynb\n",
|
||||||
|
"wolfram_alpha.ipynb\n",
|
||||||
|
"zapier.ipynb\n",
|
||||||
|
"\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"bash.run(\"cd ..\")\n",
|
||||||
|
"# The commands are executed in a new subprocess each time, meaning that\n",
|
||||||
|
"# this call will return the same results as the last.\n",
|
||||||
|
"print(bash.run(\"ls\"))"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"attachments": {},
|
||||||
|
"cell_type": "markdown",
|
||||||
"id": "851fee9f",
|
"id": "851fee9f",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Terminal Persistance\n",
|
||||||
|
"\n",
|
||||||
|
"By default, the bash command will be executed in a new subprocess each time. To retain a persistent bash session, we can use the `persistent=True` arg."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 5,
|
||||||
|
"id": "4a93ea2c",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"bash = BashProcess(persistent=True)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 6,
|
||||||
|
"id": "a1e98b78",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"custom_tools.ipynb\t\tmulti_input_tool.ipynb\n",
|
||||||
|
"examples\t\t\ttool_input_validation.ipynb\n",
|
||||||
|
"getting_started.md\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"bash.run(\"cd ..\")\n",
|
||||||
|
"# Note the list of files is different\n",
|
||||||
|
"print(bash.run(\"ls\"))"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "e13c1c9c",
|
||||||
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": []
|
"source": []
|
||||||
}
|
}
|
||||||
@ -77,7 +178,7 @@
|
|||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.10.9"
|
"version": "3.8.16"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
|
@ -24,8 +24,8 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"```bash\n",
|
"```bash\n",
|
||||||
"echo \"Hello World\"\n",
|
"echo \"Hello World\"\n",
|
||||||
"```\u001b[0m['```bash', 'echo \"Hello World\"', '```']\n",
|
"```\u001b[0m\n",
|
||||||
"\n",
|
"Code: \u001b[33;1m\u001b[1;3m['echo \"Hello World\"']\u001b[0m\n",
|
||||||
"Answer: \u001b[33;1m\u001b[1;3mHello World\n",
|
"Answer: \u001b[33;1m\u001b[1;3mHello World\n",
|
||||||
"\u001b[0m\n",
|
"\u001b[0m\n",
|
||||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||||
@ -65,7 +65,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 28,
|
"execution_count": 2,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
@ -93,7 +93,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 29,
|
"execution_count": 3,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
@ -107,8 +107,8 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"```bash\n",
|
"```bash\n",
|
||||||
"printf \"Hello World\\n\"\n",
|
"printf \"Hello World\\n\"\n",
|
||||||
"```\u001b[0m['```bash', 'printf \"Hello World\\\\n\"', '```']\n",
|
"```\u001b[0m\n",
|
||||||
"\n",
|
"Code: \u001b[33;1m\u001b[1;3m['printf \"Hello World\\\\n\"']\u001b[0m\n",
|
||||||
"Answer: \u001b[33;1m\u001b[1;3mHello World\n",
|
"Answer: \u001b[33;1m\u001b[1;3mHello World\n",
|
||||||
"\u001b[0m\n",
|
"\u001b[0m\n",
|
||||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||||
@ -120,7 +120,7 @@
|
|||||||
"'Hello World\\n'"
|
"'Hello World\\n'"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 29,
|
"execution_count": 3,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "execute_result"
|
"output_type": "execute_result"
|
||||||
}
|
}
|
||||||
@ -132,6 +132,114 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"bash_chain.run(text)"
|
"bash_chain.run(text)"
|
||||||
]
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"attachments": {},
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Persistent Terminal\n",
|
||||||
|
"\n",
|
||||||
|
"By default, the chain will run in a separate subprocess each time it is called. This behavior can be changed by instantiating with a persistent bash process."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 4,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"\u001b[1m> Entering new LLMBashChain chain...\u001b[0m\n",
|
||||||
|
"List the current directory then move up a level.\u001b[32;1m\u001b[1;3m\n",
|
||||||
|
"\n",
|
||||||
|
"```bash\n",
|
||||||
|
"ls\n",
|
||||||
|
"cd ..\n",
|
||||||
|
"```\u001b[0m\n",
|
||||||
|
"Code: \u001b[33;1m\u001b[1;3m['ls', 'cd ..']\u001b[0m\n",
|
||||||
|
"Answer: \u001b[33;1m\u001b[1;3mapi.ipynb\t\t\tllm_summarization_checker.ipynb\n",
|
||||||
|
"constitutional_chain.ipynb\tmoderation.ipynb\n",
|
||||||
|
"llm_bash.ipynb\t\t\topenai_openapi.yaml\n",
|
||||||
|
"llm_checker.ipynb\t\topenapi.ipynb\n",
|
||||||
|
"llm_math.ipynb\t\t\tpal.ipynb\n",
|
||||||
|
"llm_requests.ipynb\t\tsqlite.ipynb\u001b[0m\n",
|
||||||
|
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"'api.ipynb\\t\\t\\tllm_summarization_checker.ipynb\\r\\nconstitutional_chain.ipynb\\tmoderation.ipynb\\r\\nllm_bash.ipynb\\t\\t\\topenai_openapi.yaml\\r\\nllm_checker.ipynb\\t\\topenapi.ipynb\\r\\nllm_math.ipynb\\t\\t\\tpal.ipynb\\r\\nllm_requests.ipynb\\t\\tsqlite.ipynb'"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 4,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"from langchain.utilities.bash import BashProcess\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"persistent_process = BashProcess(persistent=True)\n",
|
||||||
|
"bash_chain = LLMBashChain.from_bash_process(llm=llm, bash_process=persistent_process, verbose=True)\n",
|
||||||
|
"\n",
|
||||||
|
"text = \"List the current directory then move up a level.\"\n",
|
||||||
|
"\n",
|
||||||
|
"bash_chain.run(text)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 5,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"\u001b[1m> Entering new LLMBashChain chain...\u001b[0m\n",
|
||||||
|
"List the current directory then move up a level.\u001b[32;1m\u001b[1;3m\n",
|
||||||
|
"\n",
|
||||||
|
"```bash\n",
|
||||||
|
"ls\n",
|
||||||
|
"cd ..\n",
|
||||||
|
"```\u001b[0m\n",
|
||||||
|
"Code: \u001b[33;1m\u001b[1;3m['ls', 'cd ..']\u001b[0m\n",
|
||||||
|
"Answer: \u001b[33;1m\u001b[1;3mexamples\t\tgetting_started.ipynb\tindex_examples\n",
|
||||||
|
"generic\t\t\thow_to_guides.rst\u001b[0m\n",
|
||||||
|
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"'examples\\t\\tgetting_started.ipynb\\tindex_examples\\r\\ngeneric\\t\\t\\thow_to_guides.rst'"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 5,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"# Run the same command again and see that the state is maintained between calls\n",
|
||||||
|
"bash_chain.run(text)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": []
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
@ -150,7 +258,7 @@
|
|||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.10.6"
|
"version": "3.8.16"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
|
@ -1,15 +1,46 @@
|
|||||||
"""Chain that interprets a prompt and executes bash code to perform bash operations."""
|
"""Chain that interprets a prompt and executes bash code to perform bash operations."""
|
||||||
from typing import Dict, List
|
import logging
|
||||||
|
import re
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
from pydantic import Extra
|
from pydantic import Extra, Field
|
||||||
|
|
||||||
from langchain.chains.base import Chain
|
from langchain.chains.base import Chain
|
||||||
from langchain.chains.llm import LLMChain
|
from langchain.chains.llm import LLMChain
|
||||||
from langchain.chains.llm_bash.prompt import PROMPT
|
from langchain.chains.llm_bash.prompt import PROMPT
|
||||||
from langchain.prompts.base import BasePromptTemplate
|
from langchain.prompts.base import BasePromptTemplate
|
||||||
from langchain.schema import BaseLanguageModel
|
from langchain.schema import BaseLanguageModel, BaseOutputParser, OutputParserException
|
||||||
from langchain.utilities.bash import BashProcess
|
from langchain.utilities.bash import BashProcess
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class BashOutputParser(BaseOutputParser):
|
||||||
|
"""Parser for bash output."""
|
||||||
|
|
||||||
|
def parse(self, text: str) -> List[str]:
|
||||||
|
if "```bash" in text:
|
||||||
|
return self.get_code_blocks(text)
|
||||||
|
else:
|
||||||
|
raise OutputParserException(
|
||||||
|
f"Failed to parse bash output. Got: {text}",
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_code_blocks(t: str) -> List[str]:
|
||||||
|
"""Get multiple code blocks from the LLM result."""
|
||||||
|
code_blocks: List[str] = []
|
||||||
|
# Bash markdown code blocks
|
||||||
|
pattern = re.compile(r"```bash(.*?)(?:\n\s*)```", re.DOTALL)
|
||||||
|
for match in pattern.finditer(t):
|
||||||
|
matched = match.group(1).strip()
|
||||||
|
if matched:
|
||||||
|
code_blocks.extend(
|
||||||
|
[line for line in matched.split("\n") if line.strip()]
|
||||||
|
)
|
||||||
|
|
||||||
|
return code_blocks
|
||||||
|
|
||||||
|
|
||||||
class LLMBashChain(Chain):
|
class LLMBashChain(Chain):
|
||||||
"""Chain that interprets a prompt and executes bash code to perform bash operations.
|
"""Chain that interprets a prompt and executes bash code to perform bash operations.
|
||||||
@ -26,6 +57,8 @@ class LLMBashChain(Chain):
|
|||||||
input_key: str = "question" #: :meta private:
|
input_key: str = "question" #: :meta private:
|
||||||
output_key: str = "answer" #: :meta private:
|
output_key: str = "answer" #: :meta private:
|
||||||
prompt: BasePromptTemplate = PROMPT
|
prompt: BasePromptTemplate = PROMPT
|
||||||
|
output_parser: BaseOutputParser = Field(default_factory=BashOutputParser)
|
||||||
|
bash_process: BashProcess = Field(default_factory=BashProcess) #: :meta private:
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
"""Configuration for this pydantic object."""
|
"""Configuration for this pydantic object."""
|
||||||
@ -51,29 +84,40 @@ class LLMBashChain(Chain):
|
|||||||
|
|
||||||
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||||
llm_executor = LLMChain(prompt=self.prompt, llm=self.llm)
|
llm_executor = LLMChain(prompt=self.prompt, llm=self.llm)
|
||||||
bash_executor = BashProcess()
|
|
||||||
self.callback_manager.on_text(inputs[self.input_key], verbose=self.verbose)
|
self.callback_manager.on_text(inputs[self.input_key], verbose=self.verbose)
|
||||||
|
|
||||||
t = llm_executor.predict(question=inputs[self.input_key])
|
t = llm_executor.predict(question=inputs[self.input_key])
|
||||||
self.callback_manager.on_text(t, color="green", verbose=self.verbose)
|
self.callback_manager.on_text(t, color="green", verbose=self.verbose)
|
||||||
|
|
||||||
t = t.strip()
|
t = t.strip()
|
||||||
if t.startswith("```bash"):
|
try:
|
||||||
# Split the string into a list of substrings
|
command_list = self.output_parser.parse(t)
|
||||||
command_list = t.split("\n")
|
except OutputParserException as e:
|
||||||
print(command_list)
|
self.callback_manager.on_chain_error(e, verbose=self.verbose)
|
||||||
|
raise e
|
||||||
|
|
||||||
# Remove the first and last substrings
|
if self.verbose:
|
||||||
command_list = [s for s in command_list[1:-1]]
|
self.callback_manager.on_text("\nCode: ", verbose=self.verbose)
|
||||||
output = bash_executor.run(command_list)
|
self.callback_manager.on_text(
|
||||||
|
str(command_list), color="yellow", verbose=self.verbose
|
||||||
|
)
|
||||||
|
|
||||||
|
output = self.bash_process.run(command_list)
|
||||||
|
|
||||||
self.callback_manager.on_text("\nAnswer: ", verbose=self.verbose)
|
self.callback_manager.on_text("\nAnswer: ", verbose=self.verbose)
|
||||||
self.callback_manager.on_text(output, color="yellow", verbose=self.verbose)
|
self.callback_manager.on_text(output, color="yellow", verbose=self.verbose)
|
||||||
|
|
||||||
else:
|
|
||||||
raise ValueError(f"unknown format from LLM: {t}")
|
|
||||||
return {self.output_key: output}
|
return {self.output_key: output}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _chain_type(self) -> str:
|
def _chain_type(self) -> str:
|
||||||
return "llm_bash_chain"
|
return "llm_bash_chain"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_bash_process(
|
||||||
|
cls,
|
||||||
|
bash_process: BashProcess,
|
||||||
|
llm: BaseLanguageModel,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> "LLMBashChain":
|
||||||
|
"""Create a LLMBashChain from a BashProcess."""
|
||||||
|
return cls(llm=llm, bash_process=bash_process, **kwargs)
|
||||||
|
@ -1,24 +1,59 @@
|
|||||||
"""Wrapper around subprocess to run commands."""
|
"""Wrapper around subprocess to run commands."""
|
||||||
|
import re
|
||||||
import subprocess
|
import subprocess
|
||||||
from typing import List, Union
|
from typing import List, Union
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import pexpect
|
||||||
|
|
||||||
|
|
||||||
class BashProcess:
|
class BashProcess:
|
||||||
"""Executes bash commands and returns the output."""
|
"""Executes bash commands and returns the output."""
|
||||||
|
|
||||||
def __init__(self, strip_newlines: bool = False, return_err_output: bool = False):
|
def __init__(
|
||||||
|
self,
|
||||||
|
strip_newlines: bool = False,
|
||||||
|
return_err_output: bool = False,
|
||||||
|
persistent: bool = False,
|
||||||
|
):
|
||||||
"""Initialize with stripping newlines."""
|
"""Initialize with stripping newlines."""
|
||||||
self.strip_newlines = strip_newlines
|
self.strip_newlines = strip_newlines
|
||||||
self.return_err_output = return_err_output
|
self.return_err_output = return_err_output
|
||||||
|
self.prompt = ""
|
||||||
|
self.process = None
|
||||||
|
if persistent:
|
||||||
|
self.prompt = str(uuid4())
|
||||||
|
self.process = self._initialize_persistent_process(self.prompt)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _initialize_persistent_process(prompt: str) -> pexpect.spawn:
|
||||||
|
# Start bash in a clean environment
|
||||||
|
process = pexpect.spawn(
|
||||||
|
"env", ["-i", "bash", "--norc", "--noprofile"], encoding="utf-8"
|
||||||
|
)
|
||||||
|
# Set the custom prompt
|
||||||
|
process.sendline("PS1=" + prompt)
|
||||||
|
|
||||||
|
process.expect_exact(prompt, timeout=10)
|
||||||
|
return process
|
||||||
|
|
||||||
def run(self, commands: Union[str, List[str]]) -> str:
|
def run(self, commands: Union[str, List[str]]) -> str:
|
||||||
"""Run commands and return final output."""
|
"""Run commands and return final output."""
|
||||||
if isinstance(commands, str):
|
if isinstance(commands, str):
|
||||||
commands = [commands]
|
commands = [commands]
|
||||||
commands = ";".join(commands)
|
commands = ";".join(commands)
|
||||||
|
if self.process is not None:
|
||||||
|
return self._run_persistent(
|
||||||
|
commands,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return self._run(commands)
|
||||||
|
|
||||||
|
def _run(self, command: str) -> str:
|
||||||
|
"""Run commands and return final output."""
|
||||||
try:
|
try:
|
||||||
output = subprocess.run(
|
output = subprocess.run(
|
||||||
commands,
|
command,
|
||||||
shell=True,
|
shell=True,
|
||||||
check=True,
|
check=True,
|
||||||
stdout=subprocess.PIPE,
|
stdout=subprocess.PIPE,
|
||||||
@ -31,3 +66,31 @@ class BashProcess:
|
|||||||
if self.strip_newlines:
|
if self.strip_newlines:
|
||||||
output = output.strip()
|
output = output.strip()
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
def process_output(self, output: str, command: str) -> str:
|
||||||
|
# Remove the command from the output using a regular expression
|
||||||
|
pattern = re.escape(command) + r"\s*\n"
|
||||||
|
output = re.sub(pattern, "", output, count=1)
|
||||||
|
return output.strip()
|
||||||
|
|
||||||
|
def _run_persistent(self, command: str) -> str:
|
||||||
|
"""Run commands and return final output."""
|
||||||
|
if self.process is None:
|
||||||
|
raise ValueError("Process not initialized")
|
||||||
|
self.process.sendline(command)
|
||||||
|
|
||||||
|
# Clear the output with an empty string
|
||||||
|
self.process.expect(self.prompt, timeout=10)
|
||||||
|
self.process.sendline("")
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.process.expect([self.prompt, pexpect.EOF], timeout=10)
|
||||||
|
except pexpect.TIMEOUT:
|
||||||
|
return f"Timeout error while executing command {command}"
|
||||||
|
if self.process.after == pexpect.EOF:
|
||||||
|
return f"Exited with error status: {self.process.exitstatus}"
|
||||||
|
output = self.process.before
|
||||||
|
output = self.process_output(output, command)
|
||||||
|
if self.strip_newlines:
|
||||||
|
return output.strip()
|
||||||
|
return output
|
||||||
|
@ -3,26 +3,107 @@ import sys
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from langchain.chains.llm_bash.base import LLMBashChain
|
from langchain.chains.llm_bash.base import BashOutputParser, LLMBashChain
|
||||||
from langchain.chains.llm_bash.prompt import _PROMPT_TEMPLATE
|
from langchain.chains.llm_bash.prompt import _PROMPT_TEMPLATE
|
||||||
|
from langchain.schema import OutputParserException
|
||||||
from tests.unit_tests.llms.fake_llm import FakeLLM
|
from tests.unit_tests.llms.fake_llm import FakeLLM
|
||||||
|
|
||||||
|
_SAMPLE_CODE = """
|
||||||
|
Unrelated text
|
||||||
|
```bash
|
||||||
|
echo hello
|
||||||
|
```
|
||||||
|
Unrelated text
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
_SAMPLE_CODE_2_LINES = """
|
||||||
|
Unrelated text
|
||||||
|
```bash
|
||||||
|
echo hello
|
||||||
|
|
||||||
|
echo world
|
||||||
|
```
|
||||||
|
Unrelated text
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def fake_llm_bash_chain() -> LLMBashChain:
|
def output_parser() -> BashOutputParser:
|
||||||
"""Fake LLM Bash chain for testing."""
|
"""Output parser for testing."""
|
||||||
question = "Please write a bash script that prints 'Hello World' to the console."
|
return BashOutputParser()
|
||||||
prompt = _PROMPT_TEMPLATE.format(question=question)
|
|
||||||
queries = {prompt: "```bash\nexpr 1 + 1\n```"}
|
|
||||||
fake_llm = FakeLLM(queries=queries)
|
|
||||||
return LLMBashChain(llm=fake_llm, input_key="q", output_key="a")
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
@pytest.mark.skipif(
|
||||||
sys.platform.startswith("win"), reason="Test not supported on Windows"
|
sys.platform.startswith("win"), reason="Test not supported on Windows"
|
||||||
)
|
)
|
||||||
def test_simple_question(fake_llm_bash_chain: LLMBashChain) -> None:
|
def test_simple_question() -> None:
|
||||||
"""Test simple question that should not need python."""
|
"""Test simple question that should not need python."""
|
||||||
question = "Please write a bash script that prints 'Hello World' to the console."
|
question = "Please write a bash script that prints 'Hello World' to the console."
|
||||||
|
prompt = _PROMPT_TEMPLATE.format(question=question)
|
||||||
|
queries = {prompt: "```bash\nexpr 1 + 1\n```"}
|
||||||
|
fake_llm = FakeLLM(queries=queries)
|
||||||
|
fake_llm_bash_chain = LLMBashChain(llm=fake_llm, input_key="q", output_key="a")
|
||||||
output = fake_llm_bash_chain.run(question)
|
output = fake_llm_bash_chain.run(question)
|
||||||
assert output == "2\n"
|
assert output == "2\n"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_code(output_parser: BashOutputParser) -> None:
|
||||||
|
"""Test the parser."""
|
||||||
|
code_lines = output_parser.parse(_SAMPLE_CODE)
|
||||||
|
code = [c for c in code_lines if c.strip()]
|
||||||
|
assert code == code_lines
|
||||||
|
assert code == ["echo hello"]
|
||||||
|
|
||||||
|
code_lines = output_parser.parse(_SAMPLE_CODE + _SAMPLE_CODE_2_LINES)
|
||||||
|
assert code_lines == ["echo hello", "echo hello", "echo world"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_parsing_error() -> None:
|
||||||
|
"""Test that LLM Output without a bash block raises an exce"""
|
||||||
|
question = "Please echo 'hello world' to the terminal."
|
||||||
|
prompt = _PROMPT_TEMPLATE.format(question=question)
|
||||||
|
queries = {
|
||||||
|
prompt: """
|
||||||
|
```text
|
||||||
|
echo 'hello world'
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
}
|
||||||
|
fake_llm = FakeLLM(queries=queries)
|
||||||
|
fake_llm_bash_chain = LLMBashChain(llm=fake_llm, input_key="q", output_key="a")
|
||||||
|
with pytest.raises(OutputParserException):
|
||||||
|
fake_llm_bash_chain.run(question)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_code_lines_mixed_blocks(output_parser: BashOutputParser) -> None:
|
||||||
|
text = """
|
||||||
|
Unrelated text
|
||||||
|
```bash
|
||||||
|
echo hello
|
||||||
|
ls && pwd && ls
|
||||||
|
```
|
||||||
|
|
||||||
|
```python
|
||||||
|
print("hello")
|
||||||
|
```
|
||||||
|
|
||||||
|
```bash
|
||||||
|
echo goodbye
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
code_lines = output_parser.parse(text)
|
||||||
|
assert code_lines == ["echo hello", "ls && pwd && ls", "echo goodbye"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_code_lines_simple_nested_ticks(output_parser: BashOutputParser) -> None:
|
||||||
|
"""Test that backticks w/o a newline are ignored."""
|
||||||
|
text = """
|
||||||
|
Unrelated text
|
||||||
|
```bash
|
||||||
|
echo hello
|
||||||
|
echo "```bash is in this string```"
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
code_lines = output_parser.parse(text)
|
||||||
|
assert code_lines == ["echo hello", 'echo "```bash is in this string```"']
|
||||||
|
@ -21,6 +21,23 @@ def test_pwd_command() -> None:
|
|||||||
assert output == subprocess.check_output("pwd", shell=True).decode()
|
assert output == subprocess.check_output("pwd", shell=True).decode()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
sys.platform.startswith("win"), reason="Test not supported on Windows"
|
||||||
|
)
|
||||||
|
def test_pwd_command_persistent() -> None:
|
||||||
|
"""Test correct functionality when the bash process is persistent."""
|
||||||
|
session = BashProcess(persistent=True, strip_newlines=True)
|
||||||
|
commands = ["pwd"]
|
||||||
|
output = session.run(commands)
|
||||||
|
|
||||||
|
assert subprocess.check_output("pwd", shell=True).decode().strip() in output
|
||||||
|
|
||||||
|
session.run(["cd .."])
|
||||||
|
new_output = session.run(["pwd"])
|
||||||
|
# Assert that the new_output is a parent of the old output
|
||||||
|
assert Path(output).parent == Path(new_output)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
@pytest.mark.skipif(
|
||||||
sys.platform.startswith("win"), reason="Test not supported on Windows"
|
sys.platform.startswith("win"), reason="Test not supported on Windows"
|
||||||
)
|
)
|
||||||
@ -66,3 +83,16 @@ def test_create_directory_and_files(tmp_path: Path) -> None:
|
|||||||
# check that the files were created in the temporary directory
|
# check that the files were created in the temporary directory
|
||||||
output = session.run([f"ls {temp_dir}"])
|
output = session.run([f"ls {temp_dir}"])
|
||||||
assert output == "file1.txt\nfile2.txt"
|
assert output == "file1.txt\nfile2.txt"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
sys.platform.startswith("win"), reason="Test not supported on Windows"
|
||||||
|
)
|
||||||
|
def test_create_bash_persistent() -> None:
|
||||||
|
"""Test the pexpect persistent bash terminal"""
|
||||||
|
session = BashProcess(persistent=True)
|
||||||
|
response = session.run("echo hello")
|
||||||
|
response += session.run("echo world")
|
||||||
|
|
||||||
|
assert "hello" in response
|
||||||
|
assert "world" in response
|
||||||
|
Loading…
Reference in New Issue
Block a user