From 955bd2e1db8d008d628963cb8d2bad5c1d354744 Mon Sep 17 00:00:00 2001 From: Abhik Singla Date: Mon, 10 Apr 2023 21:25:03 -0700 Subject: [PATCH] Fixed Ast Python Repl for Chatgpt multiline commands (#2406) Resolves issue https://github.com/hwchase17/langchain/issues/2252 --------- Co-authored-by: Abhik Singla --- langchain/tools/python/tool.py | 4 ++++ tests/unit_tests/test_python.py | 22 +++++++++++++++++++++- 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/langchain/tools/python/tool.py b/langchain/tools/python/tool.py index 0b2ac846..5cf91ad7 100644 --- a/langchain/tools/python/tool.py +++ b/langchain/tools/python/tool.py @@ -50,6 +50,7 @@ class PythonAstREPLTool(BaseTool): ) globals: Optional[Dict] = Field(default_factory=dict) locals: Optional[Dict] = Field(default_factory=dict) + sanitize_input: bool = True @root_validator(pre=True) def validate_python_version(cls, values: Dict) -> Dict: @@ -65,6 +66,9 @@ class PythonAstREPLTool(BaseTool): def _run(self, query: str) -> str: """Use the tool.""" try: + if self.sanitize_input: + # Remove the triple backticks from the query. + query = query.strip().strip("```") tree = ast.parse(query) module = ast.Module(tree.body[:-1], type_ignores=[]) exec(ast.unparse(module), self.globals, self.locals) # type: ignore diff --git a/tests/unit_tests/test_python.py b/tests/unit_tests/test_python.py index fab0c88d..9874a0e6 100644 --- a/tests/unit_tests/test_python.py +++ b/tests/unit_tests/test_python.py @@ -1,7 +1,10 @@ """Test functionality of Python REPL.""" +import sys + +import pytest from langchain.python import PythonREPL -from langchain.tools.python.tool import PythonREPLTool +from langchain.tools.python.tool import PythonAstREPLTool, PythonREPLTool _SAMPLE_CODE = """ ``` @@ -11,6 +14,14 @@ multiply() ``` """ +_AST_SAMPLE_CODE = """ +``` +def multiply(): + return(5*6) +multiply() +``` +""" + def test_python_repl() -> None: """Test functionality when globals/locals are not provided.""" @@ -60,6 +71,15 @@ def test_functionality_multiline() -> None: assert output == "30\n" +def test_python_ast_repl_multiline() -> None: + """Test correct functionality for ChatGPT multiline commands.""" + if sys.version_info < (3, 9): + pytest.skip("Python 3.9+ is required for this test") + tool = PythonAstREPLTool() + output = tool.run(_AST_SAMPLE_CODE) + assert output == 30 + + def test_function() -> None: """Test correct functionality.""" chain = PythonREPL()