|
|
|
@ -19,6 +19,17 @@ def _get_default_python_repl() -> PythonREPL:
|
|
|
|
|
return PythonREPL(_globals=globals(), _locals=None)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
_MD_PY_BLOCK = "```python"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def sanitize_input(query: str) -> str:
|
|
|
|
|
query = query.strip()
|
|
|
|
|
if query[: len(_MD_PY_BLOCK)] == _MD_PY_BLOCK:
|
|
|
|
|
query = query[len(_MD_PY_BLOCK) :].strip()
|
|
|
|
|
query = query.strip("`").strip()
|
|
|
|
|
return query
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PythonREPLTool(BaseTool):
|
|
|
|
|
"""A tool for running python code in a REPL."""
|
|
|
|
|
|
|
|
|
@ -39,7 +50,7 @@ class PythonREPLTool(BaseTool):
|
|
|
|
|
) -> Any:
|
|
|
|
|
"""Use the tool."""
|
|
|
|
|
if self.sanitize_input:
|
|
|
|
|
query = query.strip().strip("```")
|
|
|
|
|
query = sanitize_input(query)
|
|
|
|
|
return self.python_repl.run(query)
|
|
|
|
|
|
|
|
|
|
async def _arun(
|
|
|
|
@ -84,8 +95,7 @@ class PythonAstREPLTool(BaseTool):
|
|
|
|
|
"""Use the tool."""
|
|
|
|
|
try:
|
|
|
|
|
if self.sanitize_input:
|
|
|
|
|
# Remove the triple backticks from the query.
|
|
|
|
|
query = query.strip().strip("```")
|
|
|
|
|
query = sanitize_input(query)
|
|
|
|
|
tree = ast.parse(query)
|
|
|
|
|
module = ast.Module(tree.body[:-1], type_ignores=[])
|
|
|
|
|
exec(ast.unparse(module), self.globals, self.locals) # type: ignore
|
|
|
|
|