diff --git a/langchain/tools/python/tool.py b/langchain/tools/python/tool.py index e9062ecc..e5de9c1b 100644 --- a/langchain/tools/python/tool.py +++ b/langchain/tools/python/tool.py @@ -19,6 +19,17 @@ def _get_default_python_repl() -> PythonREPL: return PythonREPL(_globals=globals(), _locals=None) +_MD_PY_BLOCK = "```python" + + +def sanitize_input(query: str) -> str: + query = query.strip() + if query[: len(_MD_PY_BLOCK)] == _MD_PY_BLOCK: + query = query[len(_MD_PY_BLOCK) :].strip() + query = query.strip("`").strip() + return query + + class PythonREPLTool(BaseTool): """A tool for running python code in a REPL.""" @@ -39,7 +50,7 @@ class PythonREPLTool(BaseTool): ) -> Any: """Use the tool.""" if self.sanitize_input: - query = query.strip().strip("```") + query = sanitize_input(query) return self.python_repl.run(query) async def _arun( @@ -84,8 +95,7 @@ class PythonAstREPLTool(BaseTool): """Use the tool.""" try: if self.sanitize_input: - # Remove the triple backticks from the query. - query = query.strip().strip("```") + query = sanitize_input(query) tree = ast.parse(query) module = ast.Module(tree.body[:-1], type_ignores=[]) exec(ast.unparse(module), self.globals, self.locals) # type: ignore diff --git a/tests/unit_tests/tools/python/test_python.py b/tests/unit_tests/tools/python/test_python.py index a44719c6..412e6946 100644 --- a/tests/unit_tests/tools/python/test_python.py +++ b/tests/unit_tests/tools/python/test_python.py @@ -3,7 +3,11 @@ import sys import pytest -from langchain.tools.python.tool import PythonAstREPLTool, PythonREPLTool +from langchain.tools.python.tool import ( + PythonAstREPLTool, + PythonREPLTool, + sanitize_input, +) def test_python_repl_tool_single_input() -> None: @@ -21,3 +25,30 @@ def test_python_ast_repl_tool_single_input() -> None: tool = PythonAstREPLTool() assert tool.is_single_input assert tool.run("1 + 1") == 2 + + +def test_sanitize_input() -> None: + query = """ + ``` + p = 5 + ``` + """ + expected = "p = 5" + actual = sanitize_input(query) + assert expected == actual + + query = """ + ```python + p = 5 + ``` + """ + expected = "p = 5" + actual = sanitize_input(query) + assert expected == actual + + query = """ + p = 5 + """ + expected = "p = 5" + actual = sanitize_input(query) + assert expected == actual