From 58220cda7270f917d51231e9c2886390abe6e27c Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Tue, 10 Oct 2023 14:54:09 -0400 Subject: [PATCH] Remove LLM Bash and related bash utilities (#11619) Deprecate LLMBash and related bash utilities --- libs/langchain/langchain/__init__.py | 12 +- libs/langchain/langchain/chains/__init__.py | 2 - .../langchain/chains/llm_bash/__init__.py | 13 +- .../langchain/chains/llm_bash/base.py | 138 ------------- .../langchain/chains/llm_bash/prompt.py | 64 ------ libs/langchain/langchain/chains/loading.py | 5 +- libs/langchain/langchain/tools/shell/tool.py | 17 +- .../langchain/langchain/utilities/__init__.py | 2 - libs/langchain/langchain/utilities/bash.py | 183 ------------------ .../tests/unit_tests/chains/test_llm_bash.py | 109 ----------- libs/langchain/tests/unit_tests/test_bash.py | 102 ---------- .../unit_tests/tools/shell/test_shell.py | 35 +++- 12 files changed, 63 insertions(+), 619 deletions(-) delete mode 100644 libs/langchain/langchain/chains/llm_bash/base.py delete mode 100644 libs/langchain/langchain/chains/llm_bash/prompt.py delete mode 100644 libs/langchain/langchain/utilities/bash.py delete mode 100644 libs/langchain/tests/unit_tests/chains/test_llm_bash.py delete mode 100644 libs/langchain/tests/unit_tests/test_bash.py diff --git a/libs/langchain/langchain/__init__.py b/libs/langchain/langchain/__init__.py index feef1a7ee8..8879c101b2 100644 --- a/libs/langchain/langchain/__init__.py +++ b/libs/langchain/langchain/__init__.py @@ -72,11 +72,15 @@ def __getattr__(name: str) -> Any: return ConversationChain elif name == "LLMBashChain": - from langchain.chains import LLMBashChain + raise ImportError( + "This module has been moved to langchain-experimental. " + "For more details: " + "https://github.com/langchain-ai/langchain/discussions/11352." + "To access this code, install it with `pip install langchain-experimental`." + "`from langchain_experimental.llm_bash.base " + "import LLMBashChain`" + ) - _warn_on_import(name) - - return LLMBashChain elif name == "LLMChain": from langchain.chains import LLMChain diff --git a/libs/langchain/langchain/chains/__init__.py b/libs/langchain/langchain/chains/__init__.py index 4bb5242729..f03927512f 100644 --- a/libs/langchain/langchain/chains/__init__.py +++ b/libs/langchain/langchain/chains/__init__.py @@ -44,7 +44,6 @@ from langchain.chains.graph_qa.neptune_cypher import NeptuneOpenCypherQAChain from langchain.chains.graph_qa.sparql import GraphSparqlQAChain from langchain.chains.hyde.base import HypotheticalDocumentEmbedder from langchain.chains.llm import LLMChain -from langchain.chains.llm_bash.base import LLMBashChain from langchain.chains.llm_checker.base import LLMCheckerChain from langchain.chains.llm_math.base import LLMMathChain from langchain.chains.llm_requests import LLMRequestsChain @@ -94,7 +93,6 @@ __all__ = [ "HugeGraphQAChain", "HypotheticalDocumentEmbedder", "KuzuQAChain", - "LLMBashChain", "LLMChain", "LLMCheckerChain", "LLMMathChain", diff --git a/libs/langchain/langchain/chains/llm_bash/__init__.py b/libs/langchain/langchain/chains/llm_bash/__init__.py index e1e848a1a8..74f7c29d89 100644 --- a/libs/langchain/langchain/chains/llm_bash/__init__.py +++ b/libs/langchain/langchain/chains/llm_bash/__init__.py @@ -1 +1,12 @@ -"""Chain that interprets a prompt and executes bash code to perform bash operations.""" +def raise_on_import() -> None: + """Raise an error on import since is deprecated.""" + raise ImportError( + "This module has been moved to langchain-experimental. " + "For more details: https://github.com/langchain-ai/langchain/discussions/11352." + "To access this code, install it with `pip install langchain-experimental`." + "`from langchain_experimental.llm_bash.base " + "import LLMBashChain`" + ) + + +raise_on_import() diff --git a/libs/langchain/langchain/chains/llm_bash/base.py b/libs/langchain/langchain/chains/llm_bash/base.py deleted file mode 100644 index 0dc4c413cf..0000000000 --- a/libs/langchain/langchain/chains/llm_bash/base.py +++ /dev/null @@ -1,138 +0,0 @@ -"""Chain that interprets a prompt and executes bash operations.""" -from __future__ import annotations - -import logging -import warnings -from typing import Any, Dict, List, Optional - -from langchain._api import warn_deprecated -from langchain.callbacks.manager import CallbackManagerForChainRun -from langchain.chains.base import Chain -from langchain.chains.llm import LLMChain -from langchain.chains.llm_bash.prompt import PROMPT -from langchain.pydantic_v1 import Extra, Field, root_validator -from langchain.schema import BasePromptTemplate, OutputParserException -from langchain.schema.language_model import BaseLanguageModel -from langchain.utilities.bash import BashProcess - -logger = logging.getLogger(__name__) - - -class LLMBashChain(Chain): - """Chain that interprets a prompt and executes bash operations. - - Warning: - This chain can execute arbitrary code using bash. - This can be dangerous if not properly sandboxed. - - Example: - - .. code-block:: python - - from langchain.chains import LLMBashChain - from langchain.llms import OpenAI - llm_bash = LLMBashChain.from_llm(OpenAI()) - """ - - llm_chain: LLMChain - llm: Optional[BaseLanguageModel] = None - """[Deprecated] LLM wrapper to use.""" - input_key: str = "question" #: :meta private: - output_key: str = "answer" #: :meta private: - prompt: BasePromptTemplate = PROMPT - """[Deprecated]""" - bash_process: BashProcess = Field(default_factory=BashProcess) #: :meta private: - - class Config: - """Configuration for this pydantic object.""" - - extra = Extra.forbid - arbitrary_types_allowed = True - - @root_validator(pre=True) - def raise_deprecation(cls, values: Dict) -> Dict: - if "llm" in values: - warnings.warn( - "Directly instantiating an LLMBashChain with an llm is deprecated. " - "Please instantiate with llm_chain or using the from_llm class method." - ) - if "llm_chain" not in values and values["llm"] is not None: - prompt = values.get("prompt", PROMPT) - values["llm_chain"] = LLMChain(llm=values["llm"], prompt=prompt) - return values - - @root_validator - def validate_prompt(cls, values: Dict) -> Dict: - if values["llm_chain"].prompt.output_parser is None: - raise ValueError( - "The prompt used by llm_chain is expected to have an output_parser." - ) - return values - - @property - def input_keys(self) -> List[str]: - """Expect input key. - - :meta private: - """ - return [self.input_key] - - @property - def output_keys(self) -> List[str]: - """Expect output key. - - :meta private: - """ - return [self.output_key] - - def _call( - self, - inputs: Dict[str, Any], - run_manager: Optional[CallbackManagerForChainRun] = None, - ) -> Dict[str, str]: - warn_deprecated( - since="0.0.308", - message=( - "On 2023-10-12 the LLMBashChain " - "will be moved to langchain-experimental" - ), - pending=True, - ) - _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() - _run_manager.on_text(inputs[self.input_key], verbose=self.verbose) - - t = self.llm_chain.predict( - question=inputs[self.input_key], callbacks=_run_manager.get_child() - ) - _run_manager.on_text(t, color="green", verbose=self.verbose) - t = t.strip() - try: - parser = self.llm_chain.prompt.output_parser - command_list = parser.parse(t) # type: ignore[union-attr] - except OutputParserException as e: - _run_manager.on_chain_error(e, verbose=self.verbose) - raise e - - if self.verbose: - _run_manager.on_text("\nCode: ", verbose=self.verbose) - _run_manager.on_text( - str(command_list), color="yellow", verbose=self.verbose - ) - output = self.bash_process.run(command_list) - _run_manager.on_text("\nAnswer: ", verbose=self.verbose) - _run_manager.on_text(output, color="yellow", verbose=self.verbose) - return {self.output_key: output} - - @property - def _chain_type(self) -> str: - return "llm_bash_chain" - - @classmethod - def from_llm( - cls, - llm: BaseLanguageModel, - prompt: BasePromptTemplate = PROMPT, - **kwargs: Any, - ) -> LLMBashChain: - llm_chain = LLMChain(llm=llm, prompt=prompt) - return cls(llm_chain=llm_chain, **kwargs) diff --git a/libs/langchain/langchain/chains/llm_bash/prompt.py b/libs/langchain/langchain/chains/llm_bash/prompt.py deleted file mode 100644 index 72951d2fe9..0000000000 --- a/libs/langchain/langchain/chains/llm_bash/prompt.py +++ /dev/null @@ -1,64 +0,0 @@ -# flake8: noqa -from __future__ import annotations - -import re -from typing import List - -from langchain.prompts.prompt import PromptTemplate -from langchain.schema import BaseOutputParser, OutputParserException - -_PROMPT_TEMPLATE = """If someone asks you to perform a task, your job is to come up with a series of bash commands that will perform the task. There is no need to put "#!/bin/bash" in your answer. Make sure to reason step by step, using this format: - -Question: "copy the files in the directory named 'target' into a new directory at the same level as target called 'myNewDirectory'" - -I need to take the following actions: -- List all files in the directory -- Create a new directory -- Copy the files from the first directory into the second directory -```bash -ls -mkdir myNewDirectory -cp -r target/* myNewDirectory -``` - -That is the format. Begin! - -Question: {question}""" - - -class BashOutputParser(BaseOutputParser): - """Parser for bash output.""" - - def parse(self, text: str) -> List[str]: - if "```bash" in text: - return self.get_code_blocks(text) - else: - raise OutputParserException( - f"Failed to parse bash output. Got: {text}", - ) - - @staticmethod - def get_code_blocks(t: str) -> List[str]: - """Get multiple code blocks from the LLM result.""" - code_blocks: List[str] = [] - # Bash markdown code blocks - pattern = re.compile(r"```bash(.*?)(?:\n\s*)```", re.DOTALL) - for match in pattern.finditer(t): - matched = match.group(1).strip() - if matched: - code_blocks.extend( - [line for line in matched.split("\n") if line.strip()] - ) - - return code_blocks - - @property - def _type(self) -> str: - return "bash" - - -PROMPT = PromptTemplate( - input_variables=["question"], - template=_PROMPT_TEMPLATE, - output_parser=BashOutputParser(), -) diff --git a/libs/langchain/langchain/chains/loading.py b/libs/langchain/langchain/chains/loading.py index f8fe1396c8..ab8f1c519c 100644 --- a/libs/langchain/langchain/chains/loading.py +++ b/libs/langchain/langchain/chains/loading.py @@ -15,7 +15,6 @@ from langchain.chains.combine_documents.stuff import StuffDocumentsChain from langchain.chains.graph_qa.cypher import GraphCypherQAChain from langchain.chains.hyde.base import HypotheticalDocumentEmbedder from langchain.chains.llm import LLMChain -from langchain.chains.llm_bash.base import LLMBashChain from langchain.chains.llm_checker.base import LLMCheckerChain from langchain.chains.llm_math.base import LLMMathChain from langchain.chains.llm_requests import LLMRequestsChain @@ -183,7 +182,9 @@ def _load_reduce_documents_chain(config: dict, **kwargs: Any) -> ReduceDocuments ) -def _load_llm_bash_chain(config: dict, **kwargs: Any) -> LLMBashChain: +def _load_llm_bash_chain(config: dict, **kwargs: Any) -> Any: + from langchain_experimental.llm_bash.base import LLMBashChain + llm_chain = None if "llm_chain" in config: llm_chain_config = config.pop("llm_chain") diff --git a/libs/langchain/langchain/tools/shell/tool.py b/libs/langchain/langchain/tools/shell/tool.py index 3ad4c3ed98..89522a66d8 100644 --- a/libs/langchain/langchain/tools/shell/tool.py +++ b/libs/langchain/langchain/tools/shell/tool.py @@ -1,7 +1,7 @@ import asyncio import platform import warnings -from typing import List, Optional, Type, Union +from typing import Any, List, Optional, Type, Union from langchain.callbacks.manager import ( AsyncCallbackManagerForToolRun, @@ -9,7 +9,6 @@ from langchain.callbacks.manager import ( ) from langchain.pydantic_v1 import BaseModel, Field, root_validator from langchain.tools.base import BaseTool -from langchain.utilities.bash import BashProcess class ShellInput(BaseModel): @@ -35,8 +34,16 @@ class ShellInput(BaseModel): return values -def _get_default_bash_processs() -> BashProcess: - """Get file path from string.""" +def _get_default_bash_process() -> Any: + """Get default bash process.""" + try: + from langchain_experimental.llm_bash.bash import BashProcess + except ImportError: + raise ImportError( + "BashProcess has been moved to langchain experimental." + "To use this tool, install langchain-experimental " + "with `pip install langchain-experimental`." + ) return BashProcess(return_err_output=True) @@ -51,7 +58,7 @@ def _get_platform() -> str: class ShellTool(BaseTool): """Tool to run shell commands.""" - process: BashProcess = Field(default_factory=_get_default_bash_processs) + process: Any = Field(default_factory=_get_default_bash_process) """Bash process to run commands.""" name: str = "terminal" diff --git a/libs/langchain/langchain/utilities/__init__.py b/libs/langchain/langchain/utilities/__init__.py index a091c8d7f4..21f8568780 100644 --- a/libs/langchain/langchain/utilities/__init__.py +++ b/libs/langchain/langchain/utilities/__init__.py @@ -8,7 +8,6 @@ from langchain.utilities.apify import ApifyWrapper from langchain.utilities.arcee import ArceeWrapper from langchain.utilities.arxiv import ArxivAPIWrapper from langchain.utilities.awslambda import LambdaWrapper -from langchain.utilities.bash import BashProcess from langchain.utilities.bibtex import BibtexparserWrapper from langchain.utilities.bing_search import BingSearchAPIWrapper from langchain.utilities.brave_search import BraveSearchWrapper @@ -44,7 +43,6 @@ __all__ = [ "ApifyWrapper", "ArceeWrapper", "ArxivAPIWrapper", - "BashProcess", "BibtexparserWrapper", "BingSearchAPIWrapper", "BraveSearchWrapper", diff --git a/libs/langchain/langchain/utilities/bash.py b/libs/langchain/langchain/utilities/bash.py deleted file mode 100644 index bbca0a7ebb..0000000000 --- a/libs/langchain/langchain/utilities/bash.py +++ /dev/null @@ -1,183 +0,0 @@ -"""Wrapper around subprocess to run commands.""" -from __future__ import annotations - -import platform -import re -import subprocess -from typing import TYPE_CHECKING, List, Union -from uuid import uuid4 - -if TYPE_CHECKING: - import pexpect - - -class BashProcess: - """ - Wrapper class for starting subprocesses. - Uses the python built-in subprocesses.run() - Persistent processes are **not** available - on Windows systems, as pexpect makes use of - Unix pseudoterminals (ptys). MacOS and Linux - are okay. - - Example: - .. code-block:: python - - from langchain.utilities.bash import BashProcess - bash = BashProcess( - strip_newlines = False, - return_err_output = False, - persistent = False - ) - bash.run('echo \'hello world\'') - - """ - - strip_newlines: bool = False - """Whether or not to run .strip() on the output""" - return_err_output: bool = False - """Whether or not to return the output of a failed - command, or just the error message and stacktrace""" - persistent: bool = False - """Whether or not to spawn a persistent session - NOTE: Unavailable for Windows environments""" - - def __init__( - self, - strip_newlines: bool = False, - return_err_output: bool = False, - persistent: bool = False, - ): - """ - Initializes with default settings - """ - self.strip_newlines = strip_newlines - self.return_err_output = return_err_output - self.prompt = "" - self.process = None - if persistent: - self.prompt = str(uuid4()) - self.process = self._initialize_persistent_process(self, self.prompt) - - @staticmethod - def _lazy_import_pexpect() -> pexpect: - """Import pexpect only when needed.""" - if platform.system() == "Windows": - raise ValueError( - "Persistent bash processes are not yet supported on Windows." - ) - try: - import pexpect - - except ImportError: - raise ImportError( - "pexpect required for persistent bash processes." - " To install, run `pip install pexpect`." - ) - return pexpect - - @staticmethod - def _initialize_persistent_process(self: BashProcess, prompt: str) -> pexpect.spawn: - # Start bash in a clean environment - # Doesn't work on windows - """ - Initializes a persistent bash setting in a - clean environment. - NOTE: Unavailable on Windows - - Args: - Prompt(str): the bash command to execute - """ # noqa: E501 - pexpect = self._lazy_import_pexpect() - process = pexpect.spawn( - "env", ["-i", "bash", "--norc", "--noprofile"], encoding="utf-8" - ) - # Set the custom prompt - process.sendline("PS1=" + prompt) - - process.expect_exact(prompt, timeout=10) - return process - - def run(self, commands: Union[str, List[str]]) -> str: - """ - Run commands in either an existing persistent - subprocess or on in a new subprocess environment. - - Args: - commands(List[str]): a list of commands to - execute in the session - """ # noqa: E501 - if isinstance(commands, str): - commands = [commands] - commands = ";".join(commands) - if self.process is not None: - return self._run_persistent( - commands, - ) - else: - return self._run(commands) - - def _run(self, command: str) -> str: - """ - Runs a command in a subprocess and returns - the output. - - Args: - command: The command to run - """ # noqa: E501 - try: - output = subprocess.run( - command, - shell=True, - check=True, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - ).stdout.decode() - except subprocess.CalledProcessError as error: - if self.return_err_output: - return error.stdout.decode() - return str(error) - if self.strip_newlines: - output = output.strip() - return output - - def process_output(self, output: str, command: str) -> str: - """ - Uses regex to remove the command from the output - - Args: - output: a process' output string - command: the executed command - """ # noqa: E501 - pattern = re.escape(command) + r"\s*\n" - output = re.sub(pattern, "", output, count=1) - return output.strip() - - def _run_persistent(self, command: str) -> str: - """ - Runs commands in a persistent environment - and returns the output. - - Args: - command: the command to execute - """ # noqa: E501 - pexpect = self._lazy_import_pexpect() - if self.process is None: - raise ValueError("Process not initialized") - self.process.sendline(command) - - # Clear the output with an empty string - self.process.expect(self.prompt, timeout=10) - self.process.sendline("") - - try: - self.process.expect([self.prompt, pexpect.EOF], timeout=10) - except pexpect.TIMEOUT: - return f"Timeout error while executing command {command}" - if self.process.after == pexpect.EOF: - return f"Exited with error status: {self.process.exitstatus}" - output = self.process.before - output = self.process_output(output, command) - if self.strip_newlines: - return output.strip() - return output diff --git a/libs/langchain/tests/unit_tests/chains/test_llm_bash.py b/libs/langchain/tests/unit_tests/chains/test_llm_bash.py deleted file mode 100644 index e6ee11d09f..0000000000 --- a/libs/langchain/tests/unit_tests/chains/test_llm_bash.py +++ /dev/null @@ -1,109 +0,0 @@ -"""Test LLM Bash functionality.""" -import sys - -import pytest - -from langchain.chains.llm_bash.base import LLMBashChain -from langchain.chains.llm_bash.prompt import _PROMPT_TEMPLATE, BashOutputParser -from langchain.schema import OutputParserException -from tests.unit_tests.llms.fake_llm import FakeLLM - -_SAMPLE_CODE = """ -Unrelated text -```bash -echo hello -``` -Unrelated text -""" - - -_SAMPLE_CODE_2_LINES = """ -Unrelated text -```bash -echo hello - -echo world -``` -Unrelated text -""" - - -@pytest.fixture -def output_parser() -> BashOutputParser: - """Output parser for testing.""" - return BashOutputParser() - - -@pytest.mark.skipif( - sys.platform.startswith("win"), reason="Test not supported on Windows" -) -def test_simple_question() -> None: - """Test simple question that should not need python.""" - question = "Please write a bash script that prints 'Hello World' to the console." - prompt = _PROMPT_TEMPLATE.format(question=question) - queries = {prompt: "```bash\nexpr 1 + 1\n```"} - fake_llm = FakeLLM(queries=queries) - fake_llm_bash_chain = LLMBashChain.from_llm(fake_llm, input_key="q", output_key="a") - output = fake_llm_bash_chain.run(question) - assert output == "2\n" - - -def test_get_code(output_parser: BashOutputParser) -> None: - """Test the parser.""" - code_lines = output_parser.parse(_SAMPLE_CODE) - code = [c for c in code_lines if c.strip()] - assert code == code_lines - assert code == ["echo hello"] - - code_lines = output_parser.parse(_SAMPLE_CODE + _SAMPLE_CODE_2_LINES) - assert code_lines == ["echo hello", "echo hello", "echo world"] - - -def test_parsing_error() -> None: - """Test that LLM Output without a bash block raises an exce""" - question = "Please echo 'hello world' to the terminal." - prompt = _PROMPT_TEMPLATE.format(question=question) - queries = { - prompt: """ -```text -echo 'hello world' -``` -""" - } - fake_llm = FakeLLM(queries=queries) - fake_llm_bash_chain = LLMBashChain.from_llm(fake_llm, input_key="q", output_key="a") - with pytest.raises(OutputParserException): - fake_llm_bash_chain.run(question) - - -def test_get_code_lines_mixed_blocks(output_parser: BashOutputParser) -> None: - text = """ -Unrelated text -```bash -echo hello -ls && pwd && ls -``` - -```python -print("hello") -``` - -```bash -echo goodbye -``` -""" - code_lines = output_parser.parse(text) - assert code_lines == ["echo hello", "ls && pwd && ls", "echo goodbye"] - - -def test_get_code_lines_simple_nested_ticks(output_parser: BashOutputParser) -> None: - """Test that backticks w/o a newline are ignored.""" - text = """ -Unrelated text -```bash -echo hello -echo "```bash is in this string```" -``` -""" - code_lines = output_parser.parse(text) - assert code_lines == ["echo hello", 'echo "```bash is in this string```"'] diff --git a/libs/langchain/tests/unit_tests/test_bash.py b/libs/langchain/tests/unit_tests/test_bash.py deleted file mode 100644 index 2b05dffab8..0000000000 --- a/libs/langchain/tests/unit_tests/test_bash.py +++ /dev/null @@ -1,102 +0,0 @@ -"""Test the bash utility.""" -import re -import subprocess -import sys -from pathlib import Path - -import pytest - -from langchain.utilities.bash import BashProcess - - -@pytest.mark.skipif( - sys.platform.startswith("win"), reason="Test not supported on Windows" -) -def test_pwd_command() -> None: - """Test correct functionality.""" - session = BashProcess() - commands = ["pwd"] - output = session.run(commands) - - assert output == subprocess.check_output("pwd", shell=True).decode() - - -@pytest.mark.skip(reason="flaky on GHA, TODO to fix") -@pytest.mark.skipif( - sys.platform.startswith("win"), reason="Test not supported on Windows" -) -def test_pwd_command_persistent() -> None: - """Test correct functionality when the bash process is persistent.""" - session = BashProcess(persistent=True, strip_newlines=True) - commands = ["pwd"] - output = session.run(commands) - - assert subprocess.check_output("pwd", shell=True).decode().strip() in output - - session.run(["cd .."]) - new_output = session.run(["pwd"]) - # Assert that the new_output is a parent of the old output - assert Path(output).parent == Path(new_output) - - -@pytest.mark.skipif( - sys.platform.startswith("win"), reason="Test not supported on Windows" -) -def test_incorrect_command() -> None: - """Test handling of incorrect command.""" - session = BashProcess() - output = session.run(["invalid_command"]) - assert output == "Command 'invalid_command' returned non-zero exit status 127." - - -@pytest.mark.skipif( - sys.platform.startswith("win"), reason="Test not supported on Windows" -) -def test_incorrect_command_return_err_output() -> None: - """Test optional returning of shell output on incorrect command.""" - session = BashProcess(return_err_output=True) - output = session.run(["invalid_command"]) - assert re.match( - r"^/bin/sh:.*invalid_command.*(?:not found|Permission denied).*$", output - ) - - -@pytest.mark.skipif( - sys.platform.startswith("win"), reason="Test not supported on Windows" -) -def test_create_directory_and_files(tmp_path: Path) -> None: - """Test creation of a directory and files in a temporary directory.""" - session = BashProcess(strip_newlines=True) - - # create a subdirectory in the temporary directory - temp_dir = tmp_path / "test_dir" - temp_dir.mkdir() - - # run the commands in the temporary directory - commands = [ - f"touch {temp_dir}/file1.txt", - f"touch {temp_dir}/file2.txt", - f"echo 'hello world' > {temp_dir}/file2.txt", - f"cat {temp_dir}/file2.txt", - ] - - output = session.run(commands) - assert output == "hello world" - - # check that the files were created in the temporary directory - output = session.run([f"ls {temp_dir}"]) - assert output == "file1.txt\nfile2.txt" - - -@pytest.mark.skip(reason="flaky on GHA, TODO to fix") -@pytest.mark.skipif( - sys.platform.startswith("win"), reason="Test not supported on Windows" -) -def test_create_bash_persistent() -> None: - """Test the pexpect persistent bash terminal""" - session = BashProcess(persistent=True) - response = session.run("echo hello") - response += session.run("echo world") - - assert "hello" in response - assert "world" in response diff --git a/libs/langchain/tests/unit_tests/tools/shell/test_shell.py b/libs/langchain/tests/unit_tests/tools/shell/test_shell.py index c444203080..3d41907b14 100644 --- a/libs/langchain/tests/unit_tests/tools/shell/test_shell.py +++ b/libs/langchain/tests/unit_tests/tools/shell/test_shell.py @@ -1,4 +1,5 @@ import warnings +from typing import List import pytest @@ -22,8 +23,25 @@ def test_shell_input_validation() -> None: ) +class PlaceholderProcess: + def __init__(self, output: str = "") -> None: + self._commands: List[str] = [] + self.output = output + + def _run(self, commands: List[str]) -> str: + self._commands = commands + return self.output + + def run(self, commands: List[str]) -> str: + return self._run(commands) + + async def arun(self, commands: List[str]) -> str: + return self._run(commands) + + def test_shell_tool_init() -> None: - shell_tool = ShellTool() + placeholder = PlaceholderProcess() + shell_tool = ShellTool(process=placeholder) assert shell_tool.name == "terminal" assert isinstance(shell_tool.description, str) assert shell_tool.args_schema == ShellInput @@ -31,19 +49,22 @@ def test_shell_tool_init() -> None: def test_shell_tool_run() -> None: - shell_tool = ShellTool() + placeholder = PlaceholderProcess(output="hello") + shell_tool = ShellTool(process=placeholder) result = shell_tool._run(commands=test_commands) - assert result.strip() == "Hello, World!\nAnother command" + assert result.strip() == "hello" @pytest.mark.asyncio async def test_shell_tool_arun() -> None: - shell_tool = ShellTool() + placeholder = PlaceholderProcess(output="hello") + shell_tool = ShellTool(process=placeholder) result = await shell_tool._arun(commands=test_commands) - assert result.strip() == "Hello, World!\nAnother command" + assert result.strip() == "hello" def test_shell_tool_run_str() -> None: - shell_tool = ShellTool() + placeholder = PlaceholderProcess(output="hello") + shell_tool = ShellTool(process=placeholder) result = shell_tool._run(commands="echo 'Hello, World!'") - assert result.strip() == "Hello, World!" + assert result.strip() == "hello"