From 49ca02711e0f40eb0d893934d390eadf2fc39476 Mon Sep 17 00:00:00 2001 From: Deepak S V <42609308+svdeepak99@users.noreply.github.com> Date: Mon, 22 May 2023 09:43:44 -0400 Subject: [PATCH] Improved query, print & exception handling in REPL Tool (#4997) Update to pull request https://github.com/hwchase17/langchain/pull/3215 Summary: 1) Improved the sanitization of query (using regex), by removing python command (since gpt-3.5-turbo sometimes assumes python console as a terminal, and runs python command first which causes error). Also sometimes 1 line python codes contain single backticks. 2) Added 7 new test cases. For more details, view the previous pull request. --------- Co-authored-by: Deepak S V --- langchain/tools/python/tool.py | 34 +++--- tests/unit_tests/tools/python/test_python.py | 110 +++++++++++++++++++ 2 files changed, 127 insertions(+), 17 deletions(-) diff --git a/langchain/tools/python/tool.py b/langchain/tools/python/tool.py index e5de9c1b..c53904e5 100644 --- a/langchain/tools/python/tool.py +++ b/langchain/tools/python/tool.py @@ -1,7 +1,9 @@ """A tool for running python code in a REPL.""" import ast +import re import sys +from contextlib import redirect_stdout from io import StringIO from typing import Any, Dict, Optional @@ -19,14 +21,13 @@ def _get_default_python_repl() -> PythonREPL: return PythonREPL(_globals=globals(), _locals=None) -_MD_PY_BLOCK = "```python" - - def sanitize_input(query: str) -> str: - query = query.strip() - if query[: len(_MD_PY_BLOCK)] == _MD_PY_BLOCK: - query = query[len(_MD_PY_BLOCK) :].strip() - query = query.strip("`").strip() + # Remove whitespace, backtick & python (if llm mistakes python console as terminal) + + # Removes `, whitespace & python from start + query = re.sub(r"^(\s|`)*(?i:python)?\s*", "", query) + # Removes whitespace & ` from end + query = re.sub(r"(\s|`)*$", "", query) return query @@ -101,19 +102,18 @@ class PythonAstREPLTool(BaseTool): exec(ast.unparse(module), self.globals, self.locals) # type: ignore module_end = ast.Module(tree.body[-1:], type_ignores=[]) module_end_str = ast.unparse(module_end) # type: ignore + io_buffer = StringIO() try: - return eval(module_end_str, self.globals, self.locals) + with redirect_stdout(io_buffer): + ret = eval(module_end_str, self.globals, self.locals) + if ret is None: + return io_buffer.getvalue() + else: + return ret except Exception: - old_stdout = sys.stdout - sys.stdout = mystdout = StringIO() - try: + with redirect_stdout(io_buffer): exec(module_end_str, self.globals, self.locals) - sys.stdout = old_stdout - output = mystdout.getvalue() - except Exception as e: - sys.stdout = old_stdout - output = repr(e) - return output + return io_buffer.getvalue() except Exception as e: return "{}: {}".format(type(e).__name__, str(e)) diff --git a/tests/unit_tests/tools/python/test_python.py b/tests/unit_tests/tools/python/test_python.py index 412e6946..c46f168f 100644 --- a/tests/unit_tests/tools/python/test_python.py +++ b/tests/unit_tests/tools/python/test_python.py @@ -1,6 +1,7 @@ """Test Python REPL Tools.""" import sys +import numpy as np import pytest from langchain.tools.python.tool import ( @@ -17,6 +18,18 @@ def test_python_repl_tool_single_input() -> None: assert int(tool.run("print(1 + 1)").strip()) == 2 +def test_python_repl_print() -> None: + program = """ +import numpy as np +v1 = np.array([1, 2, 3]) +v2 = np.array([4, 5, 6]) +dot_product = np.dot(v1, v2) +print("The dot product is {:d}.".format(dot_product)) + """ + tool = PythonREPLTool() + assert tool.run(program) == "The dot product is 32.\n" + + @pytest.mark.skipif( sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run." ) @@ -27,6 +40,103 @@ def test_python_ast_repl_tool_single_input() -> None: assert tool.run("1 + 1") == 2 +@pytest.mark.skipif( + sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run." +) +def test_python_ast_repl_return() -> None: + program = """ +``` +import numpy as np +v1 = np.array([1, 2, 3]) +v2 = np.array([4, 5, 6]) +dot_product = np.dot(v1, v2) +int(dot_product) +``` + """ + tool = PythonAstREPLTool() + assert tool.run(program) == 32 + + program = """ +```python +import numpy as np +v1 = np.array([1, 2, 3]) +v2 = np.array([4, 5, 6]) +dot_product = np.dot(v1, v2) +int(dot_product) +``` + """ + assert tool.run(program) == 32 + + +@pytest.mark.skipif( + sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run." +) +def test_python_ast_repl_print() -> None: + program = """python +string = "racecar" +if string == string[::-1]: + print(string, "is a palindrome") +else: + print(string, "is not a palindrome")""" + tool = PythonAstREPLTool() + assert tool.run(program) == "racecar is a palindrome\n" + + +@pytest.mark.skipif( + sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run." +) +def test_repl_print_python_backticks() -> None: + program = "`print('`python` is a great language.')`" + tool = PythonAstREPLTool() + assert tool.run(program) == "`python` is a great language.\n" + + +@pytest.mark.skipif( + sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run." +) +def test_python_ast_repl_raise_exception() -> None: + data = {"Name": ["John", "Alice"], "Age": [30, 25]} + program = """ +import pandas as pd +df = pd.DataFrame(data) +df['Gender'] + """ + tool = PythonAstREPLTool(locals={"data": data}) + expected_outputs = ( + "KeyError: 'Gender'", + "ModuleNotFoundError: No module named 'pandas'", + ) + assert tool.run(program) in expected_outputs + + +@pytest.mark.skipif( + sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run." +) +def test_python_ast_repl_one_line_print() -> None: + program = 'print("The square of {} is {:.2f}".format(3, 3**2))' + tool = PythonAstREPLTool() + assert tool.run(program) == "The square of 3 is 9.00\n" + + +@pytest.mark.skipif( + sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run." +) +def test_python_ast_repl_one_line_return() -> None: + arr = np.array([1, 2, 3, 4, 5]) + tool = PythonAstREPLTool(locals={"arr": arr}) + program = "`(arr**2).sum() # Returns sum of squares`" + assert tool.run(program) == 55 + + +@pytest.mark.skipif( + sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run." +) +def test_python_ast_repl_one_line_exception() -> None: + program = "[1, 2, 3][4]" + tool = PythonAstREPLTool() + assert tool.run(program) == "IndexError: list index out of range" + + def test_sanitize_input() -> None: query = """ ```