|
|
|
@ -11,7 +11,7 @@ from langchain.callbacks.manager import (
|
|
|
|
|
AsyncCallbackManagerForToolRun,
|
|
|
|
|
CallbackManagerForToolRun,
|
|
|
|
|
)
|
|
|
|
|
from langchain.pydantic_v1 import BaseModel, Field
|
|
|
|
|
from langchain.pydantic_v1 import BaseModel, Field, PrivateAttr
|
|
|
|
|
from langchain.tools import BaseTool, Tool
|
|
|
|
|
from langchain.tools.e2b_data_analysis.unparse import Unparser
|
|
|
|
|
|
|
|
|
@ -97,7 +97,7 @@ class E2BDataAnalysisTool(BaseTool):
|
|
|
|
|
name = "e2b_data_analysis"
|
|
|
|
|
args_schema: Type[BaseModel] = E2BDataAnalysisToolArguments
|
|
|
|
|
session: Any
|
|
|
|
|
uploaded_files: List[UploadedFile] = Field(default_factory=list)
|
|
|
|
|
_uploaded_files: List[UploadedFile] = PrivateAttr(default_factory=list)
|
|
|
|
|
|
|
|
|
|
def __init__(
|
|
|
|
|
self,
|
|
|
|
@ -119,7 +119,8 @@ class E2BDataAnalysisTool(BaseTool):
|
|
|
|
|
|
|
|
|
|
# If no API key is provided, E2B will try to read it from the environment
|
|
|
|
|
# variable E2B_API_KEY
|
|
|
|
|
session = DataAnalysis(
|
|
|
|
|
super().__init__(description=base_description, **kwargs)
|
|
|
|
|
self.session = DataAnalysis(
|
|
|
|
|
api_key=api_key,
|
|
|
|
|
cwd=cwd,
|
|
|
|
|
env_vars=env_vars,
|
|
|
|
@ -128,21 +129,19 @@ class E2BDataAnalysisTool(BaseTool):
|
|
|
|
|
on_exit=on_exit,
|
|
|
|
|
on_artifact=on_artifact,
|
|
|
|
|
)
|
|
|
|
|
super().__init__(session=session, description=base_description, **kwargs)
|
|
|
|
|
self.uploaded_files = []
|
|
|
|
|
|
|
|
|
|
def close(self) -> None:
|
|
|
|
|
"""Close the cloud sandbox."""
|
|
|
|
|
self.uploaded_files = []
|
|
|
|
|
self._uploaded_files = []
|
|
|
|
|
self.session.close()
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def uploaded_files_description(self) -> str:
|
|
|
|
|
if len(self.uploaded_files) == 0:
|
|
|
|
|
if len(self._uploaded_files) == 0:
|
|
|
|
|
return ""
|
|
|
|
|
lines = ["The following files available in the sandbox:"]
|
|
|
|
|
|
|
|
|
|
for f in self.uploaded_files:
|
|
|
|
|
for f in self._uploaded_files:
|
|
|
|
|
if f.description == "":
|
|
|
|
|
lines.append(f"- path: `{f.remote_path}`")
|
|
|
|
|
else:
|
|
|
|
@ -206,15 +205,19 @@ class E2BDataAnalysisTool(BaseTool):
|
|
|
|
|
remote_path=remote_path,
|
|
|
|
|
description=description,
|
|
|
|
|
)
|
|
|
|
|
self.uploaded_files.append(f)
|
|
|
|
|
self._uploaded_files.append(f)
|
|
|
|
|
self.description = self.description + "\n" + self.uploaded_files_description
|
|
|
|
|
return f
|
|
|
|
|
|
|
|
|
|
def remove_uploaded_file(self, uploaded_file: UploadedFile) -> None:
|
|
|
|
|
"""Remove uploaded file from the sandbox."""
|
|
|
|
|
self.session.filesystem.remove(uploaded_file.remote_path)
|
|
|
|
|
self.uploaded_files = [
|
|
|
|
|
f for f in self.uploaded_files if f.remote_path != uploaded_file.remote_path
|
|
|
|
|
self._uploaded_files = [
|
|
|
|
|
f
|
|
|
|
|
for f in self._uploaded_files
|
|
|
|
|
if f.remote_path != uploaded_file.remote_path
|
|
|
|
|
]
|
|
|
|
|
self.description = self.description + "\n" + self.uploaded_files_description
|
|
|
|
|
|
|
|
|
|
def as_tool(self) -> Tool:
|
|
|
|
|
return Tool.from_function(
|
|
|
|
|