From 64596b23b99bb028acb9e8fe6a7eb27fea307311 Mon Sep 17 00:00:00 2001 From: KullTC Date: Thu, 13 Apr 2023 06:22:46 +0200 Subject: [PATCH] Return output of PythonAstREPLTool when falling back to exec() (#2780) When the code ran by the PythonAstREPLTool contains multiple statements it will fallback to exec() instead of using eval(). With this change, it will also return the output of the code in the same way the PythonREPLTool will. --- langchain/tools/python/tool.py | 13 +++++++++++-- tests/unit_tests/test_python.py | 21 +++++++++++++++++++++ 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/langchain/tools/python/tool.py b/langchain/tools/python/tool.py index 5cf91ad7..8f32643d 100644 --- a/langchain/tools/python/tool.py +++ b/langchain/tools/python/tool.py @@ -2,6 +2,7 @@ import ast import sys +from io import StringIO from typing import Dict, Optional from pydantic import Field, root_validator @@ -77,8 +78,16 @@ class PythonAstREPLTool(BaseTool): try: return eval(module_end_str, self.globals, self.locals) except Exception: - exec(module_end_str, self.globals, self.locals) - return "" + old_stdout = sys.stdout + sys.stdout = mystdout = StringIO() + try: + exec(module_end_str, self.globals, self.locals) + sys.stdout = old_stdout + output = mystdout.getvalue() + except Exception as e: + sys.stdout = old_stdout + output = str(e) + return output except Exception as e: return str(e) diff --git a/tests/unit_tests/test_python.py b/tests/unit_tests/test_python.py index 9874a0e6..e133cd2f 100644 --- a/tests/unit_tests/test_python.py +++ b/tests/unit_tests/test_python.py @@ -22,6 +22,17 @@ multiply() ``` """ +_AST_SAMPLE_CODE_EXECUTE = """ +``` +def multiply(a, b): + return(5*6) +a = 5 +b = 6 + +multiply(a, b) +``` +""" + def test_python_repl() -> None: """Test functionality when globals/locals are not provided.""" @@ -80,6 +91,16 @@ def test_python_ast_repl_multiline() -> None: assert output == 30 +def test_python_ast_repl_multi_statement() -> None: + """Test correct functionality for ChatGPT multi statement commands.""" + if sys.version_info < (3, 9): + pytest.skip("Python 3.9+ is required for this test") + tool = PythonAstREPLTool() + output = tool.run(_AST_SAMPLE_CODE_EXECUTE) + print(output) + assert output == 30 + + def test_function() -> None: """Test correct functionality.""" chain = PythonREPL()