diff --git a/langchain/tools/python/tool.py b/langchain/tools/python/tool.py index 0b2ac846..5cf91ad7 100644 --- a/langchain/tools/python/tool.py +++ b/langchain/tools/python/tool.py @@ -50,6 +50,7 @@ class PythonAstREPLTool(BaseTool): ) globals: Optional[Dict] = Field(default_factory=dict) locals: Optional[Dict] = Field(default_factory=dict) + sanitize_input: bool = True @root_validator(pre=True) def validate_python_version(cls, values: Dict) -> Dict: @@ -65,6 +66,9 @@ class PythonAstREPLTool(BaseTool): def _run(self, query: str) -> str: """Use the tool.""" try: + if self.sanitize_input: + # Remove the triple backticks from the query. + query = query.strip().strip("```") tree = ast.parse(query) module = ast.Module(tree.body[:-1], type_ignores=[]) exec(ast.unparse(module), self.globals, self.locals) # type: ignore diff --git a/tests/unit_tests/test_python.py b/tests/unit_tests/test_python.py index fab0c88d..9874a0e6 100644 --- a/tests/unit_tests/test_python.py +++ b/tests/unit_tests/test_python.py @@ -1,7 +1,10 @@ """Test functionality of Python REPL.""" +import sys + +import pytest from langchain.python import PythonREPL -from langchain.tools.python.tool import PythonREPLTool +from langchain.tools.python.tool import PythonAstREPLTool, PythonREPLTool _SAMPLE_CODE = """ ``` @@ -11,6 +14,14 @@ multiply() ``` """ +_AST_SAMPLE_CODE = """ +``` +def multiply(): + return(5*6) +multiply() +``` +""" + def test_python_repl() -> None: """Test functionality when globals/locals are not provided.""" @@ -60,6 +71,15 @@ def test_functionality_multiline() -> None: assert output == "30\n" +def test_python_ast_repl_multiline() -> None: + """Test correct functionality for ChatGPT multiline commands.""" + if sys.version_info < (3, 9): + pytest.skip("Python 3.9+ is required for this test") + tool = PythonAstREPLTool() + output = tool.run(_AST_SAMPLE_CODE) + assert output == 30 + + def test_function() -> None: """Test correct functionality.""" chain = PythonREPL()