Fixed Ast Python Repl for Chatgpt multiline commands (#2406)

Resolves issue https://github.com/hwchase17/langchain/issues/2252

---------

Co-authored-by: Abhik Singla <abhiksingla@microsoft.com>
fix_agent_callbacks
Abhik Singla 1 year ago committed by GitHub
parent 1271c00ff0
commit 955bd2e1db
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -50,6 +50,7 @@ class PythonAstREPLTool(BaseTool):
) )
globals: Optional[Dict] = Field(default_factory=dict) globals: Optional[Dict] = Field(default_factory=dict)
locals: Optional[Dict] = Field(default_factory=dict) locals: Optional[Dict] = Field(default_factory=dict)
sanitize_input: bool = True
@root_validator(pre=True) @root_validator(pre=True)
def validate_python_version(cls, values: Dict) -> Dict: def validate_python_version(cls, values: Dict) -> Dict:
@ -65,6 +66,9 @@ class PythonAstREPLTool(BaseTool):
def _run(self, query: str) -> str: def _run(self, query: str) -> str:
"""Use the tool.""" """Use the tool."""
try: try:
if self.sanitize_input:
# Remove the triple backticks from the query.
query = query.strip().strip("```")
tree = ast.parse(query) tree = ast.parse(query)
module = ast.Module(tree.body[:-1], type_ignores=[]) module = ast.Module(tree.body[:-1], type_ignores=[])
exec(ast.unparse(module), self.globals, self.locals) # type: ignore exec(ast.unparse(module), self.globals, self.locals) # type: ignore

@ -1,7 +1,10 @@
"""Test functionality of Python REPL.""" """Test functionality of Python REPL."""
import sys
import pytest
from langchain.python import PythonREPL from langchain.python import PythonREPL
from langchain.tools.python.tool import PythonREPLTool from langchain.tools.python.tool import PythonAstREPLTool, PythonREPLTool
_SAMPLE_CODE = """ _SAMPLE_CODE = """
``` ```
@ -11,6 +14,14 @@ multiply()
``` ```
""" """
_AST_SAMPLE_CODE = """
```
def multiply():
return(5*6)
multiply()
```
"""
def test_python_repl() -> None: def test_python_repl() -> None:
"""Test functionality when globals/locals are not provided.""" """Test functionality when globals/locals are not provided."""
@ -60,6 +71,15 @@ def test_functionality_multiline() -> None:
assert output == "30\n" assert output == "30\n"
def test_python_ast_repl_multiline() -> None:
"""Test correct functionality for ChatGPT multiline commands."""
if sys.version_info < (3, 9):
pytest.skip("Python 3.9+ is required for this test")
tool = PythonAstREPLTool()
output = tool.run(_AST_SAMPLE_CODE)
assert output == 30
def test_function() -> None: def test_function() -> None:
"""Test correct functionality.""" """Test correct functionality."""
chain = PythonREPL() chain = PythonREPL()

Loading…
Cancel
Save