|
|
@ -50,6 +50,7 @@ class PythonAstREPLTool(BaseTool):
|
|
|
|
)
|
|
|
|
)
|
|
|
|
globals: Optional[Dict] = Field(default_factory=dict)
|
|
|
|
globals: Optional[Dict] = Field(default_factory=dict)
|
|
|
|
locals: Optional[Dict] = Field(default_factory=dict)
|
|
|
|
locals: Optional[Dict] = Field(default_factory=dict)
|
|
|
|
|
|
|
|
sanitize_input: bool = True
|
|
|
|
|
|
|
|
|
|
|
|
@root_validator(pre=True)
|
|
|
|
@root_validator(pre=True)
|
|
|
|
def validate_python_version(cls, values: Dict) -> Dict:
|
|
|
|
def validate_python_version(cls, values: Dict) -> Dict:
|
|
|
@ -65,6 +66,9 @@ class PythonAstREPLTool(BaseTool):
|
|
|
|
def _run(self, query: str) -> str:
|
|
|
|
def _run(self, query: str) -> str:
|
|
|
|
"""Use the tool."""
|
|
|
|
"""Use the tool."""
|
|
|
|
try:
|
|
|
|
try:
|
|
|
|
|
|
|
|
if self.sanitize_input:
|
|
|
|
|
|
|
|
# Remove the triple backticks from the query.
|
|
|
|
|
|
|
|
query = query.strip().strip("```")
|
|
|
|
tree = ast.parse(query)
|
|
|
|
tree = ast.parse(query)
|
|
|
|
module = ast.Module(tree.body[:-1], type_ignores=[])
|
|
|
|
module = ast.Module(tree.body[:-1], type_ignores=[])
|
|
|
|
exec(ast.unparse(module), self.globals, self.locals) # type: ignore
|
|
|
|
exec(ast.unparse(module), self.globals, self.locals) # type: ignore
|
|
|
|