Fixed Ast Python Repl for Chatgpt multiline commands (#2406)

Resolves issue https://github.com/hwchase17/langchain/issues/2252

---------

Co-authored-by: Abhik Singla <abhiksingla@microsoft.com>
fix_agent_callbacks
Abhik Singla 1 year ago committed by GitHub
parent 1271c00ff0
commit 955bd2e1db
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -50,6 +50,7 @@ class PythonAstREPLTool(BaseTool):
)
globals: Optional[Dict] = Field(default_factory=dict)
locals: Optional[Dict] = Field(default_factory=dict)
sanitize_input: bool = True
@root_validator(pre=True)
def validate_python_version(cls, values: Dict) -> Dict:
@ -65,6 +66,9 @@ class PythonAstREPLTool(BaseTool):
def _run(self, query: str) -> str:
"""Use the tool."""
try:
if self.sanitize_input:
# Remove the triple backticks from the query.
query = query.strip().strip("```")
tree = ast.parse(query)
module = ast.Module(tree.body[:-1], type_ignores=[])
exec(ast.unparse(module), self.globals, self.locals) # type: ignore

@ -1,7 +1,10 @@
"""Test functionality of Python REPL."""
import sys
import pytest
from langchain.python import PythonREPL
from langchain.tools.python.tool import PythonREPLTool
from langchain.tools.python.tool import PythonAstREPLTool, PythonREPLTool
_SAMPLE_CODE = """
```
@ -11,6 +14,14 @@ multiply()
```
"""
_AST_SAMPLE_CODE = """
```
def multiply():
return(5*6)
multiply()
```
"""
def test_python_repl() -> None:
"""Test functionality when globals/locals are not provided."""
@ -60,6 +71,15 @@ def test_functionality_multiline() -> None:
assert output == "30\n"
def test_python_ast_repl_multiline() -> None:
"""Test correct functionality for ChatGPT multiline commands."""
if sys.version_info < (3, 9):
pytest.skip("Python 3.9+ is required for this test")
tool = PythonAstREPLTool()
output = tool.run(_AST_SAMPLE_CODE)
assert output == 30
def test_function() -> None:
"""Test correct functionality."""
chain = PythonREPL()

Loading…
Cancel
Save