langchain/tests/unit_tests/test_python.py
Abhik Singla 955bd2e1db
Fixed AST Python REPL for ChatGPT multiline commands (#2406)
Resolves issue https://github.com/hwchase17/langchain/issues/2252

---------

Co-authored-by: Abhik Singla <abhiksingla@microsoft.com>
2023-04-10 21:25:03 -07:00

93 lines
2.3 KiB
Python

"""Test functionality of Python REPL."""
import sys
import pytest
from langchain.python import PythonREPL
from langchain.tools.python.tool import PythonAstREPLTool, PythonREPLTool
# A ChatGPT-style reply: a multi-line snippet wrapped in a triple-backtick fence.
# The body of `multiply` must stay indented, otherwise the snippet is not valid
# Python. When executed, the snippet prints 30.
_SAMPLE_CODE = """
```
def multiply():
    print(5*6)
multiply()
```
"""
# A fenced multi-line snippet whose function returns its result instead of
# printing it. The snippet ends with the bare expression `multiply()`.
# The body of `multiply` must stay indented, otherwise the snippet is not
# valid Python.
_AST_SAMPLE_CODE = """
```
def multiply():
    return(5*6)
multiply()
```
"""
def test_python_repl() -> None:
    """Check that REPL state persists across runs when no globals/locals are given."""
    repl = PythonREPL()

    # Seed the REPL's local namespace with a first assignment.
    repl.run("foo = 1")
    namespace = repl.locals
    assert namespace is not None
    assert namespace["foo"] == 1

    # A follow-up command must be able to read the value bound by the first one.
    repl.run("bar = foo * 2")
    namespace = repl.locals
    assert namespace is not None
    assert namespace["bar"] == 2
def test_python_repl_no_previous_variables() -> None:
    """Ensure names bound in the calling scope do not leak into the REPL."""
    # Deliberately unused here: the REPL must not be able to see it.
    foo = 3  # noqa: F841
    result = PythonREPL().run("print(foo)")
    assert result == "name 'foo' is not defined"
def test_python_repl_pass_in_locals() -> None:
    """Verify that caller-supplied locals are visible to executed code."""
    repl = PythonREPL(_locals={"foo": 4})
    repl.run("bar = foo * 2")
    assert repl.locals is not None
    assert repl.locals["bar"] == 8
def test_functionality() -> None:
    """Check that a one-liner's printed output is returned, trailing newline included."""
    output = PythonREPL().run("print(1 + 1)")
    assert output == "2\n"
def test_functionality_multiline() -> None:
    """Run a fenced, multi-line ChatGPT-style snippet through ``PythonREPLTool``.

    The snippet prints its result, so the tool is expected to return the
    captured stdout.
    """
    tool = PythonREPLTool(python_repl=PythonREPL())
    assert tool.run(_SAMPLE_CODE) == "30\n"
# NOTE(review): the 3.9 floor is presumably needed by the AST tool's use of
# newer `ast` APIs; confirm against PythonAstREPLTool before relaxing it.
@pytest.mark.skipif(
    sys.version_info < (3, 9), reason="Python 3.9+ is required for this test"
)
def test_python_ast_repl_multiline() -> None:
    """Run a fenced, multi-line snippet through ``PythonAstREPLTool``.

    The snippet returns (rather than prints) ``5*6`` and ends with the bare
    expression ``multiply()``. The test therefore expects the integer ``30``,
    not captured stdout.
    """
    tool = PythonAstREPLTool()
    output = tool.run(_AST_SAMPLE_CODE)
    assert output == 30
def test_function() -> None:
    """Check that a function defined in one run can be called in a later run."""
    chain = PythonREPL()

    # Defining a function prints nothing, so the captured output is empty.
    # One explicit literal replaces the previous single-line implicit
    # concatenation ("..." "..."). That form looked like a lost newline or a
    # missing comma, but it executed the same one-line definition.
    code = "def add(a, b): return a + b"
    output = chain.run(code)
    assert output == ""

    # The definition must persist in the REPL's namespace for the next command.
    code = "print(add(1, 2))"
    output = chain.run(code)
    assert output == "3\n"