From acfda4d1d8b3cd98de381ff58ba7fd6b91c6c204 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Sat, 1 Apr 2023 12:54:06 -0700 Subject: [PATCH] Harrison/multiline commands (#2280) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Marc Päpper --- langchain/tools/python/tool.py | 3 +++ tests/unit_tests/test_python.py | 17 +++++++++++++++++ 2 files changed, 20 insertions(+) diff --git a/langchain/tools/python/tool.py b/langchain/tools/python/tool.py index 0a1adb4628..0b2ac8460b 100644 --- a/langchain/tools/python/tool.py +++ b/langchain/tools/python/tool.py @@ -25,9 +25,12 @@ class PythonREPLTool(BaseTool): "with `print(...)`." ) python_repl: PythonREPL = Field(default_factory=_get_default_python_repl) + sanitize_input: bool = True def _run(self, query: str) -> str: """Use the tool.""" + if self.sanitize_input: + query = query.strip().strip("```") return self.python_repl.run(query) async def _arun(self, query: str) -> str: diff --git a/tests/unit_tests/test_python.py b/tests/unit_tests/test_python.py index d1eb3ae921..fab0c88d20 100644 --- a/tests/unit_tests/test_python.py +++ b/tests/unit_tests/test_python.py @@ -1,6 +1,15 @@ """Test functionality of Python REPL.""" from langchain.python import PythonREPL +from langchain.tools.python.tool import PythonREPLTool + +_SAMPLE_CODE = """ +``` +def multiply(): + print(5*6) +multiply() +``` +""" def test_python_repl() -> None: @@ -43,6 +52,14 @@ def test_functionality() -> None: assert output == "2\n" +def test_functionality_multiline() -> None: + """Test correct functionality for ChatGPT multiline commands.""" + chain = PythonREPL() + tool = PythonREPLTool(python_repl=chain) + output = tool.run(_SAMPLE_CODE) + assert output == "30\n" + + def test_function() -> None: """Test correct functionality.""" chain = PythonREPL()