Optionally return shell output on incorrect command (#894) (#899)

This allows the LLM to correct its previous command by looking at the
error message output to the shell.

Additionally, this uses subprocess.run because that is now recommended
over subprocess.check_output:

https://docs.python.org/3/library/subprocess.html#using-the-subprocess-module

Co-authored-by: Amos Ng <me@amos.ng>
This commit is contained in:
Harrison Chase 2023-02-06 12:46:16 -08:00 committed by GitHub
parent 3aa53b44dd
commit 93a091cfb8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 18 additions and 2 deletions

View File

@ -6,9 +6,10 @@ from typing import List, Union
class BashProcess: class BashProcess:
"""Executes bash commands and returns the output.""" """Executes bash commands and returns the output."""
def __init__(self, strip_newlines: bool = False): def __init__(self, strip_newlines: bool = False, return_err_output: bool = False):
"""Initialize with stripping newlines.""" """Initialize with stripping newlines."""
self.strip_newlines = strip_newlines self.strip_newlines = strip_newlines
self.return_err_output = return_err_output
def run(self, commands: Union[str, List[str]]) -> str: def run(self, commands: Union[str, List[str]]) -> str:
"""Run commands and return final output.""" """Run commands and return final output."""
@ -16,8 +17,16 @@ class BashProcess:
commands = [commands] commands = [commands]
commands = ";".join(commands) commands = ";".join(commands)
try: try:
output = subprocess.check_output(commands, shell=True).decode() output = subprocess.run(
commands,
shell=True,
check=True,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
).stdout.decode()
except subprocess.CalledProcessError as error: except subprocess.CalledProcessError as error:
if self.return_err_output:
return error.stdout.decode()
return str(error) return str(error)
if self.strip_newlines: if self.strip_newlines:
output = output.strip() output = output.strip()

View File

@ -21,6 +21,13 @@ def test_incorrect_command() -> None:
assert output == "Command 'invalid_command' returned non-zero exit status 127." assert output == "Command 'invalid_command' returned non-zero exit status 127."
def test_incorrect_command_return_err_output() -> None:
"""Test optional returning of shell output on incorrect command."""
session = BashProcess(return_err_output=True)
output = session.run(["invalid_command"])
assert output == "/bin/sh: 1: invalid_command: not found\n"
def test_create_directory_and_files(tmp_path: Path) -> None: def test_create_directory_and_files(tmp_path: Path) -> None:
"""Test creation of a directory and files in a temporary directory.""" """Test creation of a directory and files in a temporary directory."""
session = BashProcess(strip_newlines=True) session = BashProcess(strip_newlines=True)