Return output of PythonAstREPLTool when falling back to exec() (#2780)

When the code ran by the PythonAstREPLTool contains multiple statements
it will fallback to exec() instead of using eval(). With this change, it
will also return the output of the code in the same way the
PythonREPLTool will.
This commit is contained in:
KullTC 2023-04-13 06:22:46 +02:00 committed by GitHub
parent 1bb0706955
commit 64596b23b9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 32 additions and 2 deletions

View File

@ -2,6 +2,7 @@
import ast
import sys
from io import StringIO
from typing import Dict, Optional
from pydantic import Field, root_validator
@ -77,8 +78,16 @@ class PythonAstREPLTool(BaseTool):
try:
return eval(module_end_str, self.globals, self.locals)
except Exception:
old_stdout = sys.stdout
sys.stdout = mystdout = StringIO()
try:
exec(module_end_str, self.globals, self.locals)
return ""
sys.stdout = old_stdout
output = mystdout.getvalue()
except Exception as e:
sys.stdout = old_stdout
output = str(e)
return output
except Exception as e:
return str(e)

View File

@ -22,6 +22,17 @@ multiply()
```
"""
_AST_SAMPLE_CODE_EXECUTE = """
```
def multiply(a, b):
return(5*6)
a = 5
b = 6
multiply(a, b)
```
"""
def test_python_repl() -> None:
"""Test functionality when globals/locals are not provided."""
@ -80,6 +91,16 @@ def test_python_ast_repl_multiline() -> None:
assert output == 30
def test_python_ast_repl_multi_statement() -> None:
"""Test correct functionality for ChatGPT multi statement commands."""
if sys.version_info < (3, 9):
pytest.skip("Python 3.9+ is required for this test")
tool = PythonAstREPLTool()
output = tool.run(_AST_SAMPLE_CODE_EXECUTE)
print(output)
assert output == 30
def test_function() -> None:
"""Test correct functionality."""
chain = PythonREPL()