langchain/libs/experimental/tests/unit_tests/test_python.py


experimental: clean python repl input (experimental: Added code for PythonREPL) (#20930)

Update python.py (experimental: Added code for PythonREPL)

Added code for PythonREPL, defining a static method `sanitize_input` that takes the string `query` as input and returns a sanitized string. The purpose of this method is to remove unwanted characters from the input string. Specifically:

1. Remove the whitespace (`\s`) at the beginning and end of the string.
2. Remove the backticks (`` ` ``) at the beginning and end of the string.
3. Remove the keyword "python" at the beginning of the string (case insensitive), because the user may have typed it.

This method uses regular expressions (regex) to implement the sanitizing.

It all started with this code:

    from langchain.agents import Tool
    from langchain_experimental.utilities import PythonREPL

    python_repl = PythonREPL()
    repl_tool = Tool(
        name="python_repl",
        description="Remove redundant formatting marks at the beginning and end of source code from input.Use a Python shell to execute python commands. If you want to see the output of a value, you should print it out with `print(...)`.",
        func=python_repl.run,
    )

When I ask the agent to write a piece of code for me and execute it with the tool defined above, I always get an error:

    SyntaxError('invalid syntax', ('<string>', 1, 1,'In', 1, 2))

After checking, I found that PythonREPL does less formatting of its input code than the soon-to-be deprecated PythonREPL tool, so I added this step to it. Now, whatever code I ask the agent to write, it executes smoothly and returns the output. I had previously tried to solve this problem by modifying the prompt, but that did not work; adding a simple format check resolves it.

---------

Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
import unittest

from langchain_experimental.utilities import PythonREPL


class TestSanitizeInput(unittest.TestCase):
    """Checks that PythonREPL.sanitize_input strips wrapper text around code."""

    def test_whitespace_removal(self) -> None:
        # Leading and trailing whitespace is stripped.
        query = " print('Hello, world!') "
        sanitized_query = PythonREPL.sanitize_input(query)
        self.assertEqual(sanitized_query, "print('Hello, world!')")

    def test_python_removal(self) -> None:
        # A leading "python" keyword is stripped.
        query = "python print('Hello, world!') "
        sanitized_query = PythonREPL.sanitize_input(query)
        self.assertEqual(sanitized_query, "print('Hello, world!')")

    def test_backtick_removal(self) -> None:
        # Surrounding backticks are stripped.
        query = "`print('Hello, world!')`"
        sanitized_query = PythonREPL.sanitize_input(query)
        self.assertEqual(sanitized_query, "print('Hello, world!')")

    def test_combined_removal(self) -> None:
        # Whitespace, backticks, and the keyword are stripped together.
        query = " `python print('Hello, world!')` "
        sanitized_query = PythonREPL.sanitize_input(query)
        self.assertEqual(sanitized_query, "print('Hello, world!')")

    def test_mixed_case_removal(self) -> None:
        # The "python" keyword is matched case-insensitively.
        query = " pYtHoN print('Hello, world!') "
        sanitized_query = PythonREPL.sanitize_input(query)
        self.assertEqual(sanitized_query, "print('Hello, world!')")


if __name__ == "__main__":
    unittest.main()
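
The three sanitizing steps listed in the commit message can be sketched with two regular-expression substitutions. This is a minimal illustration, not necessarily the code in `langchain_experimental`: the function name `sanitize_input_sketch` and both patterns are assumptions chosen to satisfy the test cases in this file.

```python
import re


def sanitize_input_sketch(query: str) -> str:
    """Hypothetical stand-in for PythonREPL.sanitize_input (an assumption,
    not the library's implementation)."""
    # Strip leading whitespace and backticks, then an optional
    # case-insensitive "python" keyword and any whitespace after it.
    query = re.sub(r"^(\s|`)*(?i:python)?\s*", "", query)
    # Strip trailing whitespace and backticks.
    query = re.sub(r"(\s|`)*$", "", query)
    return query


if __name__ == "__main__":
    print(sanitize_input_sketch(" `python print('Hello, world!')` "))
    # prints: print('Hello, world!')
```

One limitation of this sketch: the leading pattern removes "python" even when it is only the prefix of a longer identifier, so an input such as `pythonic = 1` would be altered. A stricter pattern would be needed if that case matters.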