Integrate E2B's data analysis/code interpreter (#12011)

This PR adds [E2B's](https://e2b.dev/) data analysis/code interpreter
sandbox as a tool
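
A minimal usage sketch mirroring the notebook added in this PR (it assumes `E2B_API_KEY` and `OPENAI_API_KEY` are set as environment variables and a local `netflix.csv` exists):

```python
from langchain.chat_models import ChatOpenAI
from langchain.tools import E2BDataAnalysisTool
from langchain.agents import initialize_agent, AgentType

# Stream the sandbox output while the agent works
tool = E2BDataAnalysisTool(
    on_stdout=lambda stdout: print("stdout:", stdout),
    on_stderr=lambda stderr: print("stderr:", stderr),
)

# Upload a CSV so the agent can analyze it inside the sandbox
with open("./netflix.csv") as f:
    tool.upload_file(file=f, description="Data about Netflix tv shows")

agent = initialize_agent(
    [tool.as_tool()],
    ChatOpenAI(model="gpt-4", temperature=0),
    agent=AgentType.OPENAI_FUNCTIONS,
    verbose=True,
)
agent.run("What are the 5 longest movies released between 2000 and 2010?")

tool.close()
```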

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
Co-authored-by: Jakub Novak <jakub@e2b.dev>

@ -0,0 +1,367 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# E2B Data Analysis\n",
"\n",
"[E2B's cloud environments](https://e2b.dev) are great runtime sandboxes for LLMs.\n",
"\n",
"E2B's Data Analysis sandbox allows for safe code execution in a sandboxed environment. This is ideal for building tools such as code interpreters, or Advanced Data Analysis like in ChatGPT.\n",
"\n",
"E2B Data Analysis sandbox allows you to:\n",
"- Run Python code\n",
"- Generate charts via matplotlib\n",
"- Install Python packages dynamically durint runtime\n",
"- Install system packages dynamically during runtime\n",
"- Run shell commands\n",
"- Upload and download files\n",
"\n",
"We'll create a simple OpenAI agent that will use E2B's Data Analysis sandbox to perform analysis on a uploaded files using Python."
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Get your OpenAI API key and [E2B API key here](https://e2b.dev/docs/getting-started/api-key) and set them as environment variables.\n",
"\n",
"You can find the full API documentation [here](https://e2b.dev/docs).\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You'll need to install `e2b` to get started:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!pip install langchain e2b"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"from langchain.chat_models import ChatOpenAI\n",
"from langchain.tools import E2BDataAnalysisTool\n",
"from langchain.agents import initialize_agent, AgentType\n",
"\n",
"os.environ[\"E2B_API_KEY\"] = \"<E2B_API_KEY>\"\n",
"os.environ[\"OPENAI_API_KEY\"] = \"<OPENAI_API_KEY>\""
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"When creating an instance of the `E2BDataAnalysisTool`, you can pass callbacks to listen to the output of the sandbox. This is useful, for example, when creating more responsive UI. Especially with the combination of streaming output from LLMs."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Artifacts are charts created by matplotlib when `plt.show()` is called\n",
"def save_artifact(artifact):\n",
" print(\"New matplotlib chart generated:\", artifact.name)\n",
" # Download the artifact as `bytes` and leave it up to the user to display them (on frontend, for example)\n",
" file = artifact.download()\n",
" basename = os.path.basename(artifact.name)\n",
"\n",
" # Save the chart to the `charts` directory\n",
" with open(f\"./charts/{basename}\", \"wb\") as f:\n",
" f.write(file)\n",
"\n",
"e2b_data_analysis_tool = E2BDataAnalysisTool(\n",
" # Pass environment variables to the sandbox\n",
" env_vars={\"MY_SECRET\": \"secret_value\"},\n",
" on_stdout=lambda stdout: print(\"stdout:\", stdout),\n",
" on_stderr=lambda stderr: print(\"stderr:\", stderr),\n",
" on_artifact=save_artifact,\n",
")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Upload an example CSV data file to the sandbox so we can analyze it with our agent. You can use for example [this file](https://storage.googleapis.com/e2b-examples/netflix.csv) about Netflix tv shows."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"name='netflix.csv' remote_path='/home/user/netflix.csv' description='Data about Netflix tv shows including their title, category, director, release date, casting, age rating, etc.'\n"
]
}
],
"source": [
"with open(\"./netflix.csv\") as f:\n",
" remote_path = e2b_data_analysis_tool.upload_file(\n",
" file=f,\n",
" description=\"Data about Netflix tv shows including their title, category, director, release date, casting, age rating, etc.\",\n",
" )\n",
" print(remote_path)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Create a `Tool` object and initialize the Langchain agent."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"tools = [e2b_data_analysis_tool.as_tool()]\n",
"\n",
"llm = ChatOpenAI(model=\"gpt-4\", temperature=0)\n",
"agent = initialize_agent(\n",
" tools, llm, agent=AgentType.OPENAI_FUNCTIONS, verbose=True, handle_parsing_errors=True\n",
")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we can ask the agent questions about the CSV file we uploaded earlier."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m\n",
"Invoking: `e2b_data_analysis` with `{'python_code': \"import pandas as pd\\n\\n# Load the data\\nnetflix_data = pd.read_csv('/home/user/netflix.csv')\\n\\n# Convert the 'release_year' column to integer\\nnetflix_data['release_year'] = netflix_data['release_year'].astype(int)\\n\\n# Filter the data for movies released between 2000 and 2010\\nfiltered_data = netflix_data[(netflix_data['release_year'] >= 2000) & (netflix_data['release_year'] <= 2010) & (netflix_data['type'] == 'Movie')]\\n\\n# Remove rows where 'duration' is not available\\nfiltered_data = filtered_data[filtered_data['duration'].notna()]\\n\\n# Convert the 'duration' column to integer\\nfiltered_data['duration'] = filtered_data['duration'].str.replace(' min','').astype(int)\\n\\n# Get the top 5 longest movies\\nlongest_movies = filtered_data.nlargest(5, 'duration')\\n\\n# Create a bar chart\\nimport matplotlib.pyplot as plt\\n\\nplt.figure(figsize=(10,5))\\nplt.barh(longest_movies['title'], longest_movies['duration'], color='skyblue')\\nplt.xlabel('Duration (minutes)')\\nplt.title('Top 5 Longest Movies on Netflix (2000-2010)')\\nplt.gca().invert_yaxis()\\nplt.savefig('/home/user/longest_movies.png')\\n\\nlongest_movies[['title', 'duration']]\"}`\n",
"\n",
"\n",
"\u001b[0mstdout: title duration\n",
"stdout: 1019 Lagaan 224\n",
"stdout: 4573 Jodhaa Akbar 214\n",
"stdout: 2731 Kabhi Khushi Kabhie Gham 209\n",
"stdout: 2632 No Direction Home: Bob Dylan 208\n",
"stdout: 2126 What's Your Raashee? 203\n",
"\u001b[36;1m\u001b[1;3m{'stdout': \" title duration\\n1019 Lagaan 224\\n4573 Jodhaa Akbar 214\\n2731 Kabhi Khushi Kabhie Gham 209\\n2632 No Direction Home: Bob Dylan 208\\n2126 What's Your Raashee? 203\", 'stderr': ''}\u001b[0m\u001b[32;1m\u001b[1;3mThe 5 longest movies on Netflix released between 2000 and 2010 are:\n",
"\n",
"1. Lagaan - 224 minutes\n",
"2. Jodhaa Akbar - 214 minutes\n",
"3. Kabhi Khushi Kabhie Gham - 209 minutes\n",
"4. No Direction Home: Bob Dylan - 208 minutes\n",
"5. What's Your Raashee? - 203 minutes\n",
"\n",
"Here is the chart showing their lengths:\n",
"\n",
"![Longest Movies](sandbox:/home/user/longest_movies.png)\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"\"The 5 longest movies on Netflix released between 2000 and 2010 are:\\n\\n1. Lagaan - 224 minutes\\n2. Jodhaa Akbar - 214 minutes\\n3. Kabhi Khushi Kabhie Gham - 209 minutes\\n4. No Direction Home: Bob Dylan - 208 minutes\\n5. What's Your Raashee? - 203 minutes\\n\\nHere is the chart showing their lengths:\\n\\n![Longest Movies](sandbox:/home/user/longest_movies.png)\""
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"agent.run(\"What are the 5 longest movies on netflix released between 2000 and 2010? Create a chart with their lengths.\")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"E2B also allows you to install both Python and system (via `apt`) packages dynamically during runtime like this:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"stdout: Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (2.1.1)\n",
"stdout: Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas) (2.8.2)\n",
"stdout: Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas) (2023.3.post1)\n",
"stdout: Requirement already satisfied: numpy>=1.22.4 in /usr/local/lib/python3.10/dist-packages (from pandas) (1.26.1)\n",
"stdout: Requirement already satisfied: tzdata>=2022.1 in /usr/local/lib/python3.10/dist-packages (from pandas) (2023.3)\n",
"stdout: Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.2->pandas) (1.16.0)\n"
]
}
],
"source": [
"# Install Python package\n",
"e2b_data_analysis_tool.install_python_packages('pandas')"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Additionally, you can download any file from the sandbox like this:"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# The path is a remote path in the sandbox\n",
"files_in_bytes = e2b_data_analysis_tool.download_file('/home/user/netflix.csv')"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Lastly, you can run any shell command inside the sandbox via `run_command`."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"stderr: \n",
"stderr: WARNING: apt does not have a stable CLI interface. Use with caution in scripts.\n",
"stderr: \n",
"stdout: Hit:1 http://security.ubuntu.com/ubuntu jammy-security InRelease\n",
"stdout: Hit:2 http://archive.ubuntu.com/ubuntu jammy InRelease\n",
"stdout: Hit:3 http://archive.ubuntu.com/ubuntu jammy-updates InRelease\n",
"stdout: Hit:4 http://archive.ubuntu.com/ubuntu jammy-backports InRelease\n",
"stdout: Reading package lists...\n",
"stdout: Building dependency tree...\n",
"stdout: Reading state information...\n",
"stdout: All packages are up to date.\n",
"stdout: Reading package lists...\n",
"stdout: Building dependency tree...\n",
"stdout: Reading state information...\n",
"stdout: Suggested packages:\n",
"stdout: sqlite3-doc\n",
"stdout: The following NEW packages will be installed:\n",
"stdout: sqlite3\n",
"stdout: 0 upgraded, 1 newly installed, 0 to remove and 0 not upgraded.\n",
"stdout: Need to get 768 kB of archives.\n",
"stdout: After this operation, 1873 kB of additional disk space will be used.\n",
"stdout: Get:1 http://archive.ubuntu.com/ubuntu jammy-updates/main amd64 sqlite3 amd64 3.37.2-2ubuntu0.1 [768 kB]\n",
"stderr: debconf: delaying package configuration, since apt-utils is not installed\n",
"stdout: Fetched 768 kB in 0s (2258 kB/s)\n",
"stdout: Selecting previously unselected package sqlite3.\n",
"(Reading database ... 23999 files and directories currently installed.)\n",
"stdout: Preparing to unpack .../sqlite3_3.37.2-2ubuntu0.1_amd64.deb ...\n",
"stdout: Unpacking sqlite3 (3.37.2-2ubuntu0.1) ...\n",
"stdout: Setting up sqlite3 (3.37.2-2ubuntu0.1) ...\n",
"stdout: 3.37.2 2022-01-06 13:25:41 872ba256cbf61d9290b571c0e6d82a20c224ca3ad82971edc46b29818d5dalt1\n",
"version: 3.37.2 2022-01-06 13:25:41 872ba256cbf61d9290b571c0e6d82a20c224ca3ad82971edc46b29818d5dalt1\n",
"error: \n",
"exit code: 0\n"
]
}
],
"source": [
"# Install SQLite\n",
"e2b_data_analysis_tool.run_command(\"sudo apt update\")\n",
"e2b_data_analysis_tool.install_system_packages(\"sqlite3\")\n",
"\n",
"# Check the SQLite version\n",
"output = e2b_data_analysis_tool.run_command(\"sqlite3 --version\")\n",
"print(\"version: \", output[\"stdout\"])\n",
"print(\"error: \", output[\"stderr\"])\n",
"print(\"exit code: \", output[\"exit_code\"])"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"When your agent is finished, don't forget to close the sandbox"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"e2b_data_analysis_tool.close()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 4
}

@ -639,6 +639,12 @@ def _import_bearly_tool() -> Any:
return BearlyInterpreterTool
def _import_e2b_data_analysis() -> Any:
from langchain.tools.e2b_data_analysis.tool import E2BDataAnalysisTool
return E2BDataAnalysisTool
def __getattr__(name: str) -> Any:
if name == "AINAppOps":
return _import_ainetwork_app()
@ -846,6 +852,8 @@ def __getattr__(name: str) -> Any:
return _import_zapier_tool_ZapierNLARunAction()
elif name == "BearlyInterpreterTool":
return _import_bearly_tool()
elif name == "E2BDataAnalysisTool":
return _import_e2b_data_analysis()
else:
raise AttributeError(f"Could not find: {name}")
@ -958,4 +966,5 @@ __all__ = [
"tool",
"format_tool_to_openai_function",
"BearlyInterpreterTool",
"E2BDataAnalysisTool",
]

@ -0,0 +1,231 @@
from __future__ import annotations
import ast
import json
import os
from io import StringIO
from sys import version_info
from typing import IO, TYPE_CHECKING, Any, Callable, List, Optional, Type
from langchain.callbacks.manager import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
from langchain.pydantic_v1 import BaseModel, Field
from langchain.tools import BaseTool, Tool
from langchain.tools.e2b_data_analysis.unparse import Unparser
if TYPE_CHECKING:
from e2b import EnvVars
from e2b.templates.data_analysis import Artifact
base_description = """Evaluates python code in a sandbox environment. \
The environment is long running and exists across multiple executions. \
You must send the whole script every time and print your outputs. \
Script should be pure python code that can be evaluated. \
It should be in python format NOT markdown. \
The code should NOT be wrapped in backticks. \
All python packages including requests, matplotlib, scipy, numpy, pandas, \
etc are available. \
If you output any files, write them to the "/home/user" directory \
path."""
def _unparse(tree: ast.AST) -> str:
"""Unparse the AST."""
if version_info.minor < 9:
s = StringIO()
Unparser(tree, file=s)
source_code = s.getvalue()
s.close()
else:
source_code = ast.unparse(tree) # type: ignore[attr-defined]
return source_code
def add_last_line_print(code: str) -> str:
"""Add print statement to the last line if it's missing.
Sometimes the LLM-generated code doesn't end with `print(variable_name)`; instead, the
LLM tries to print the variable just by writing `variable_name` on the last line (as you
would in a REPL, for example).
This method checks the AST of the generated Python code and adds the print
statement to the last line if it's missing.
"""
tree = ast.parse(code)
node = tree.body[-1]
if isinstance(node, ast.Expr) and isinstance(node.value, ast.Call):
if isinstance(node.value.func, ast.Name) and node.value.func.id == "print":
return _unparse(tree)
if isinstance(node, ast.Expr):
tree.body[-1] = ast.Expr(
value=ast.Call(
func=ast.Name(id="print", ctx=ast.Load()),
args=[node.value],
keywords=[],
)
)
return _unparse(tree)
class UploadedFile(BaseModel):
"""Description of the uploaded path with its remote path."""
name: str
remote_path: str
description: str
class E2BDataAnalysisToolArguments(BaseModel):
"""Arguments for the E2BDataAnalysisTool."""
python_code: str = Field(
...,
example="print('Hello World')",
description=(
"The python script to be evaluated. "
"The contents will be in main.py. "
"It should not be in markdown format."
),
)
class E2BDataAnalysisTool(BaseTool):
"""Tool for running python code in a sandboxed environment for data analysis."""
name = "e2b_data_analysis"
args_schema: Type[BaseModel] = E2BDataAnalysisToolArguments
session: Any
_uploaded_files: List[UploadedFile] = Field(default_factory=list)
def __init__(
self,
api_key: Optional[str] = None,
cwd: Optional[str] = None,
env_vars: Optional[EnvVars] = None,
on_stdout: Optional[Callable[[str], Any]] = None,
on_stderr: Optional[Callable[[str], Any]] = None,
on_artifact: Optional[Callable[[Artifact], Any]] = None,
on_exit: Optional[Callable[[int], Any]] = None,
**kwargs: Any,
):
try:
from e2b import DataAnalysis
except ImportError as e:
raise ImportError(
"Unable to import e2b, please install with `pip install e2b`."
) from e
# If no API key is provided, E2B will try to read it from the environment
# variable E2B_API_KEY
session = DataAnalysis(
api_key=api_key,
cwd=cwd,
env_vars=env_vars,
on_stdout=on_stdout,
on_stderr=on_stderr,
on_exit=on_exit,
on_artifact=on_artifact,
)
super().__init__(session=session, **kwargs)
self.description = (
base_description + "\n\n" + self.uploaded_files_description
).strip()
def close(self) -> None:
"""Close the cloud sandbox."""
self._uploaded_files = []
self.session.close()
@property
def uploaded_files_description(self) -> str:
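# This text is appended to the tool's description so the LLM knows which
# files it can read, e.g.:
#   The following files are available in the sandbox:
#   - path: `/home/user/netflix.csv`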
if len(self._uploaded_files) == 0:
return ""
lines = ["The following files available in the sandbox:"]
for f in self._uploaded_files:
if f.description == "":
lines.append(f"- path: `{f.remote_path}`")
else:
lines.append(
f"- path: `{f.remote_path}` \n description: `{f.description}`"
)
return "\n".join(lines)
def _run(
self, python_code: str, run_manager: Optional[CallbackManagerForToolRun] = None
) -> str:
python_code = add_last_line_print(python_code)
stdout, stderr, _ = self.session.run_python(python_code)
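# Both streams are returned to the agent as a JSON string,
# e.g. '{"stdout": "42", "stderr": ""}'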
out = {
"stdout": stdout,
"stderr": stderr,
}
return json.dumps(out)
async def _arun(
self,
python_code: str,
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
) -> str:
raise NotImplementedError("e2b_data_analysis does not support async")
def run_command(
self,
cmd: str,
) -> dict:
"""Run shell command in the sandbox."""
proc = self.session.process.start(cmd)
output = proc.wait()
return {
"stdout": output.stdout,
"stderr": output.stderr,
"exit_code": output.exit_code,
}
def install_python_packages(self, package_names: str | List[str]) -> None:
"""Install python packages in the sandbox."""
self.session.install_python_packages(package_names)
def install_system_packages(self, package_names: str | List[str]) -> None:
"""Install system packages (via apt) in the sandbox."""
self.session.install_system_packages(package_names)
def download_file(self, remote_path: str) -> bytes:
"""Download file from the sandbox."""
return self.session.download_file(remote_path)
def upload_file(self, file: IO, description: str) -> UploadedFile:
"""Upload file to the sandbox.
The file is uploaded to the '/home/user/<filename>' path."""
remote_path = self.session.upload_file(file)
f = UploadedFile(
name=os.path.basename(file.name),
remote_path=remote_path,
description=description,
)
self._uploaded_files.append(f)
return f
def remove_uploaded_file(self, uploaded_file: UploadedFile) -> None:
"""Remove uploaded file from the sandbox."""
self.session.filesystem.remove(uploaded_file.remote_path)
self._uploaded_files = [
f
for f in self._uploaded_files
if f.remote_path != uploaded_file.remote_path
]
def as_tool(self) -> Tool:
return Tool.from_function(
func=self._run,
name=self.name,
description=self.description,
args_schema=self.args_schema,
)

@ -0,0 +1,736 @@
# mypy: disable-error-code=no-untyped-def
# Because Python <3.9 doesn't support ast.unparse,
# we copied the unparse functionality from here:
# https://github.com/python/cpython/blob/3.8/Tools/parser/unparse.py
"Usage: unparse.py <path to source file>"
import ast
import io
import sys
import tokenize
# Large float and imaginary literals get turned into infinities in the AST.
# We unparse those infinities to INFSTR.
INFSTR = "1e" + repr(sys.float_info.max_10_exp + 1)
def interleave(inter, f, seq):
"""Call f on each item in seq, calling inter() in between."""
seq = iter(seq)
try:
f(next(seq))
except StopIteration:
pass
else:
for x in seq:
inter()
f(x)
class Unparser:
"""Methods in this class recursively traverse an AST and
output source code for the abstract syntax; original formatting
is disregarded."""
def __init__(self, tree, file=sys.stdout):
"""Unparser(tree, file=sys.stdout) -> None.
Print the source for tree to file."""
self.f = file
self._indent = 0
self.dispatch(tree)
self.f.flush()
def fill(self, text=""):
"Indent a piece of text, according to the current indentation level"
self.f.write("\n" + " " * self._indent + text)
def write(self, text):
"Append a piece of text to the current line."
self.f.write(text)
def enter(self):
"Print ':', and increase the indentation."
self.write(":")
self._indent += 1
def leave(self):
"Decrease the indentation level."
self._indent -= 1
def dispatch(self, tree):
"Dispatcher function, dispatching tree type T to method _T."
if isinstance(tree, list):
for t in tree:
self.dispatch(t)
return
meth = getattr(self, "_" + tree.__class__.__name__)
meth(tree)
############### Unparsing methods ######################
# There should be one method per concrete grammar type #
# Constructors should be grouped by sum type. Ideally, #
# this would follow the order in the grammar, but #
# currently doesn't. #
########################################################
def _Module(self, tree):
for stmt in tree.body:
self.dispatch(stmt)
# stmt
def _Expr(self, tree):
self.fill()
self.dispatch(tree.value)
def _NamedExpr(self, tree):
self.write("(")
self.dispatch(tree.target)
self.write(" := ")
self.dispatch(tree.value)
self.write(")")
def _Import(self, t):
self.fill("import ")
interleave(lambda: self.write(", "), self.dispatch, t.names)
def _ImportFrom(self, t):
self.fill("from ")
self.write("." * t.level)
if t.module:
self.write(t.module)
self.write(" import ")
interleave(lambda: self.write(", "), self.dispatch, t.names)
def _Assign(self, t):
self.fill()
for target in t.targets:
self.dispatch(target)
self.write(" = ")
self.dispatch(t.value)
def _AugAssign(self, t):
self.fill()
self.dispatch(t.target)
self.write(" " + self.binop[t.op.__class__.__name__] + "= ")
self.dispatch(t.value)
def _AnnAssign(self, t):
self.fill()
if not t.simple and isinstance(t.target, ast.Name):
self.write("(")
self.dispatch(t.target)
if not t.simple and isinstance(t.target, ast.Name):
self.write(")")
self.write(": ")
self.dispatch(t.annotation)
if t.value:
self.write(" = ")
self.dispatch(t.value)
def _Return(self, t):
self.fill("return")
if t.value:
self.write(" ")
self.dispatch(t.value)
def _Pass(self, t):
self.fill("pass")
def _Break(self, t):
self.fill("break")
def _Continue(self, t):
self.fill("continue")
def _Delete(self, t):
self.fill("del ")
interleave(lambda: self.write(", "), self.dispatch, t.targets)
def _Assert(self, t):
self.fill("assert ")
self.dispatch(t.test)
if t.msg:
self.write(", ")
self.dispatch(t.msg)
def _Global(self, t):
self.fill("global ")
interleave(lambda: self.write(", "), self.write, t.names)
def _Nonlocal(self, t):
self.fill("nonlocal ")
interleave(lambda: self.write(", "), self.write, t.names)
def _Await(self, t):
self.write("(")
self.write("await")
if t.value:
self.write(" ")
self.dispatch(t.value)
self.write(")")
def _Yield(self, t):
self.write("(")
self.write("yield")
if t.value:
self.write(" ")
self.dispatch(t.value)
self.write(")")
def _YieldFrom(self, t):
self.write("(")
self.write("yield from")
if t.value:
self.write(" ")
self.dispatch(t.value)
self.write(")")
def _Raise(self, t):
self.fill("raise")
if not t.exc:
assert not t.cause
return
self.write(" ")
self.dispatch(t.exc)
if t.cause:
self.write(" from ")
self.dispatch(t.cause)
def _Try(self, t):
self.fill("try")
self.enter()
self.dispatch(t.body)
self.leave()
for ex in t.handlers:
self.dispatch(ex)
if t.orelse:
self.fill("else")
self.enter()
self.dispatch(t.orelse)
self.leave()
if t.finalbody:
self.fill("finally")
self.enter()
self.dispatch(t.finalbody)
self.leave()
def _ExceptHandler(self, t):
self.fill("except")
if t.type:
self.write(" ")
self.dispatch(t.type)
if t.name:
self.write(" as ")
self.write(t.name)
self.enter()
self.dispatch(t.body)
self.leave()
def _ClassDef(self, t):
self.write("\n")
for deco in t.decorator_list:
self.fill("@")
self.dispatch(deco)
self.fill("class " + t.name)
self.write("(")
comma = False
for e in t.bases:
if comma:
self.write(", ")
else:
comma = True
self.dispatch(e)
for e in t.keywords:
if comma:
self.write(", ")
else:
comma = True
self.dispatch(e)
self.write(")")
self.enter()
self.dispatch(t.body)
self.leave()
def _FunctionDef(self, t):
self.__FunctionDef_helper(t, "def")
def _AsyncFunctionDef(self, t):
self.__FunctionDef_helper(t, "async def")
def __FunctionDef_helper(self, t, fill_suffix):
self.write("\n")
for deco in t.decorator_list:
self.fill("@")
self.dispatch(deco)
def_str = fill_suffix + " " + t.name + "("
self.fill(def_str)
self.dispatch(t.args)
self.write(")")
if t.returns:
self.write(" -> ")
self.dispatch(t.returns)
self.enter()
self.dispatch(t.body)
self.leave()
def _For(self, t):
self.__For_helper("for ", t)
def _AsyncFor(self, t):
self.__For_helper("async for ", t)
def __For_helper(self, fill, t):
self.fill(fill)
self.dispatch(t.target)
self.write(" in ")
self.dispatch(t.iter)
self.enter()
self.dispatch(t.body)
self.leave()
if t.orelse:
self.fill("else")
self.enter()
self.dispatch(t.orelse)
self.leave()
def _If(self, t):
self.fill("if ")
self.dispatch(t.test)
self.enter()
self.dispatch(t.body)
self.leave()
# collapse nested ifs into equivalent elifs.
while t.orelse and len(t.orelse) == 1 and isinstance(t.orelse[0], ast.If):
t = t.orelse[0]
self.fill("elif ")
self.dispatch(t.test)
self.enter()
self.dispatch(t.body)
self.leave()
# final else
if t.orelse:
self.fill("else")
self.enter()
self.dispatch(t.orelse)
self.leave()
def _While(self, t):
self.fill("while ")
self.dispatch(t.test)
self.enter()
self.dispatch(t.body)
self.leave()
if t.orelse:
self.fill("else")
self.enter()
self.dispatch(t.orelse)
self.leave()
def _With(self, t):
self.fill("with ")
interleave(lambda: self.write(", "), self.dispatch, t.items)
self.enter()
self.dispatch(t.body)
self.leave()
def _AsyncWith(self, t):
self.fill("async with ")
interleave(lambda: self.write(", "), self.dispatch, t.items)
self.enter()
self.dispatch(t.body)
self.leave()
# expr
def _JoinedStr(self, t):
self.write("f")
string = io.StringIO()
self._fstring_JoinedStr(t, string.write)
self.write(repr(string.getvalue()))
def _FormattedValue(self, t):
self.write("f")
string = io.StringIO()
self._fstring_FormattedValue(t, string.write)
self.write(repr(string.getvalue()))
def _fstring_JoinedStr(self, t, write):
for value in t.values:
meth = getattr(self, "_fstring_" + type(value).__name__)
meth(value, write)
def _fstring_Constant(self, t, write):
assert isinstance(t.value, str)
value = t.value.replace("{", "{{").replace("}", "}}")
write(value)
def _fstring_FormattedValue(self, t, write):
write("{")
expr = io.StringIO()
Unparser(t.value, expr)
expr = expr.getvalue().rstrip("\n")
if expr.startswith("{"):
write(" ") # Separate pair of opening brackets as "{ {"
write(expr)
if t.conversion != -1:
conversion = chr(t.conversion)
assert conversion in "sra"
write(f"!{conversion}")
if t.format_spec:
write(":")
meth = getattr(self, "_fstring_" + type(t.format_spec).__name__)
meth(t.format_spec, write)
write("}")
def _Name(self, t):
self.write(t.id)
def _write_constant(self, value):
if isinstance(value, (float, complex)):
# Substitute overflowing decimal literal for AST infinities.
self.write(repr(value).replace("inf", INFSTR))
else:
self.write(repr(value))
def _Constant(self, t):
value = t.value
if isinstance(value, tuple):
self.write("(")
if len(value) == 1:
self._write_constant(value[0])
self.write(",")
else:
interleave(lambda: self.write(", "), self._write_constant, value)
self.write(")")
elif value is ...:
self.write("...")
else:
if t.kind == "u":
self.write("u")
self._write_constant(t.value)
def _List(self, t):
self.write("[")
interleave(lambda: self.write(", "), self.dispatch, t.elts)
self.write("]")
def _ListComp(self, t):
self.write("[")
self.dispatch(t.elt)
for gen in t.generators:
self.dispatch(gen)
self.write("]")
def _GeneratorExp(self, t):
self.write("(")
self.dispatch(t.elt)
for gen in t.generators:
self.dispatch(gen)
self.write(")")
def _SetComp(self, t):
self.write("{")
self.dispatch(t.elt)
for gen in t.generators:
self.dispatch(gen)
self.write("}")
def _DictComp(self, t):
self.write("{")
self.dispatch(t.key)
self.write(": ")
self.dispatch(t.value)
for gen in t.generators:
self.dispatch(gen)
self.write("}")
def _comprehension(self, t):
if t.is_async:
self.write(" async for ")
else:
self.write(" for ")
self.dispatch(t.target)
self.write(" in ")
self.dispatch(t.iter)
for if_clause in t.ifs:
self.write(" if ")
self.dispatch(if_clause)
def _IfExp(self, t):
self.write("(")
self.dispatch(t.body)
self.write(" if ")
self.dispatch(t.test)
self.write(" else ")
self.dispatch(t.orelse)
self.write(")")
def _Set(self, t):
assert t.elts # should be at least one element
self.write("{")
interleave(lambda: self.write(", "), self.dispatch, t.elts)
self.write("}")
def _Dict(self, t):
self.write("{")
def write_key_value_pair(k, v):
self.dispatch(k)
self.write(": ")
self.dispatch(v)
def write_item(item):
k, v = item
if k is None:
# for dictionary unpacking operator in dicts {**{'y': 2}}
# see PEP 448 for details
self.write("**")
self.dispatch(v)
else:
write_key_value_pair(k, v)
interleave(lambda: self.write(", "), write_item, zip(t.keys, t.values))
self.write("}")
def _Tuple(self, t):
self.write("(")
if len(t.elts) == 1:
elt = t.elts[0]
self.dispatch(elt)
self.write(",")
else:
interleave(lambda: self.write(", "), self.dispatch, t.elts)
self.write(")")
unop = {"Invert": "~", "Not": "not", "UAdd": "+", "USub": "-"}
def _UnaryOp(self, t):
self.write("(")
self.write(self.unop[t.op.__class__.__name__])
self.write(" ")
self.dispatch(t.operand)
self.write(")")
binop = {
"Add": "+",
"Sub": "-",
"Mult": "*",
"MatMult": "@",
"Div": "/",
"Mod": "%",
"LShift": "<<",
"RShift": ">>",
"BitOr": "|",
"BitXor": "^",
"BitAnd": "&",
"FloorDiv": "//",
"Pow": "**",
}
def _BinOp(self, t):
self.write("(")
self.dispatch(t.left)
self.write(" " + self.binop[t.op.__class__.__name__] + " ")
self.dispatch(t.right)
self.write(")")
cmpops = {
"Eq": "==",
"NotEq": "!=",
"Lt": "<",
"LtE": "<=",
"Gt": ">",
"GtE": ">=",
"Is": "is",
"IsNot": "is not",
"In": "in",
"NotIn": "not in",
}
def _Compare(self, t):
self.write("(")
self.dispatch(t.left)
for o, e in zip(t.ops, t.comparators):
self.write(" " + self.cmpops[o.__class__.__name__] + " ")
self.dispatch(e)
self.write(")")
boolops = {ast.And: "and", ast.Or: "or"}
def _BoolOp(self, t):
self.write("(")
s = " %s " % self.boolops[t.op.__class__]
interleave(lambda: self.write(s), self.dispatch, t.values)
self.write(")")
def _Attribute(self, t):
self.dispatch(t.value)
# Special case: 3.__abs__() is a syntax error, so if t.value
# is an integer literal then we need to either parenthesize
# it or add an extra space to get 3 .__abs__().
if isinstance(t.value, ast.Constant) and isinstance(t.value.value, int):
self.write(" ")
self.write(".")
self.write(t.attr)
def _Call(self, t):
self.dispatch(t.func)
self.write("(")
comma = False
for e in t.args:
if comma:
self.write(", ")
else:
comma = True
self.dispatch(e)
for e in t.keywords:
if comma:
self.write(", ")
else:
comma = True
self.dispatch(e)
self.write(")")
def _Subscript(self, t):
self.dispatch(t.value)
self.write("[")
if (
isinstance(t.slice, ast.Index)
and isinstance(t.slice.value, ast.Tuple)
and t.slice.value.elts
):
if len(t.slice.value.elts) == 1:
elt = t.slice.value.elts[0]
self.dispatch(elt)
self.write(",")
else:
interleave(lambda: self.write(", "), self.dispatch, t.slice.value.elts)
else:
self.dispatch(t.slice)
self.write("]")
def _Starred(self, t):
self.write("*")
self.dispatch(t.value)
# slice
def _Ellipsis(self, t):
self.write("...")
def _Index(self, t):
self.dispatch(t.value)
def _Slice(self, t):
if t.lower:
self.dispatch(t.lower)
self.write(":")
if t.upper:
self.dispatch(t.upper)
if t.step:
self.write(":")
self.dispatch(t.step)
def _ExtSlice(self, t):
if len(t.dims) == 1:
elt = t.dims[0]
self.dispatch(elt)
self.write(",")
else:
interleave(lambda: self.write(", "), self.dispatch, t.dims)
# argument
def _arg(self, t):
self.write(t.arg)
if t.annotation:
self.write(": ")
self.dispatch(t.annotation)
# others
def _arguments(self, t):
first = True
# normal arguments
all_args = t.posonlyargs + t.args
defaults = [None] * (len(all_args) - len(t.defaults)) + t.defaults
for index, elements in enumerate(zip(all_args, defaults), 1):
a, d = elements
if first:
first = False
else:
self.write(", ")
self.dispatch(a)
if d:
self.write("=")
self.dispatch(d)
if index == len(t.posonlyargs):
self.write(", /")
# varargs, or bare '*' if no varargs but keyword-only arguments present
if t.vararg or t.kwonlyargs:
if first:
first = False
else:
self.write(", ")
self.write("*")
if t.vararg:
self.write(t.vararg.arg)
if t.vararg.annotation:
self.write(": ")
self.dispatch(t.vararg.annotation)
# keyword-only arguments
if t.kwonlyargs:
for a, d in zip(t.kwonlyargs, t.kw_defaults):
if first:
first = False
else:
self.write(", ")
self.dispatch(a),
if d:
self.write("=")
self.dispatch(d)
# kwargs
if t.kwarg:
if first:
first = False
else:
self.write(", ")
self.write("**" + t.kwarg.arg)
if t.kwarg.annotation:
self.write(": ")
self.dispatch(t.kwarg.annotation)
def _keyword(self, t):
if t.arg is None:
self.write("**")
else:
self.write(t.arg)
self.write("=")
self.dispatch(t.value)
def _Lambda(self, t):
self.write("(")
self.write("lambda ")
self.dispatch(t.args)
self.write(": ")
self.dispatch(t.body)
self.write(")")
def _alias(self, t):
self.write(t.name)
if t.asname:
self.write(" as " + t.asname)
def _withitem(self, t):
self.dispatch(t.context_expr)
if t.optional_vars:
self.write(" as ")
self.dispatch(t.optional_vars)
def roundtrip(filename, output=sys.stdout):
with open(filename, "rb") as pyfile:
encoding = tokenize.detect_encoding(pyfile.readline)[0]
with open(filename, "r", encoding=encoding) as pyfile:
source = pyfile.read()
tree = compile(source, filename, "exec", ast.PyCF_ONLY_AST)
Unparser(tree, output)

@ -109,6 +109,7 @@ _EXPECTED = [
"format_tool_to_openai_function",
"tool",
"BearlyInterpreterTool",
"E2BDataAnalysisTool",
]
