diff --git a/langchain/tools/python/tool.py b/langchain/tools/python/tool.py index 5cf91ad7..8f32643d 100644 --- a/langchain/tools/python/tool.py +++ b/langchain/tools/python/tool.py @@ -2,6 +2,7 @@ import ast import sys +from io import StringIO from typing import Dict, Optional from pydantic import Field, root_validator @@ -77,8 +78,16 @@ class PythonAstREPLTool(BaseTool): try: return eval(module_end_str, self.globals, self.locals) except Exception: - exec(module_end_str, self.globals, self.locals) - return "" + old_stdout = sys.stdout + sys.stdout = mystdout = StringIO() + try: + exec(module_end_str, self.globals, self.locals) + sys.stdout = old_stdout + output = mystdout.getvalue() + except Exception as e: + sys.stdout = old_stdout + output = str(e) + return output except Exception as e: return str(e) diff --git a/tests/unit_tests/test_python.py b/tests/unit_tests/test_python.py index 9874a0e6..e133cd2f 100644 --- a/tests/unit_tests/test_python.py +++ b/tests/unit_tests/test_python.py @@ -22,6 +22,17 @@ multiply() ``` """ +_AST_SAMPLE_CODE_EXECUTE = """ +``` +def multiply(a, b): + return(5*6) +a = 5 +b = 6 + +multiply(a, b) +``` +""" + def test_python_repl() -> None: """Test functionality when globals/locals are not provided.""" @@ -80,6 +91,16 @@ def test_python_ast_repl_multiline() -> None: assert output == 30 +def test_python_ast_repl_multi_statement() -> None: + """Test correct functionality for ChatGPT multi statement commands.""" + if sys.version_info < (3, 9): + pytest.skip("Python 3.9+ is required for this test") + tool = PythonAstREPLTool() + output = tool.run(_AST_SAMPLE_CODE_EXECUTE) + print(output) + assert output == 30 + + def test_function() -> None: """Test correct functionality.""" chain = PythonREPL()