mirror of
https://github.com/hwchase17/langchain
synced 2024-11-10 01:10:59 +00:00
05d31a2f00
Add missing type annotations to objects in community. These missing type annotations will raise type errors in pydantic 2.
244 lines
7.8 KiB
Python
244 lines
7.8 KiB
Python
from __future__ import annotations
|
|
|
|
import ast
|
|
import json
|
|
import os
|
|
from io import StringIO
|
|
from sys import version_info
|
|
from typing import IO, TYPE_CHECKING, Any, Callable, List, Optional, Type, Union
|
|
|
|
from langchain_core.callbacks import (
|
|
AsyncCallbackManagerForToolRun,
|
|
CallbackManager,
|
|
CallbackManagerForToolRun,
|
|
)
|
|
from langchain_core.pydantic_v1 import BaseModel, Field, PrivateAttr
|
|
|
|
from langchain_community.tools import BaseTool, Tool
|
|
from langchain_community.tools.e2b_data_analysis.unparse import Unparser
|
|
|
|
if TYPE_CHECKING:
|
|
from e2b import EnvVars
|
|
from e2b.templates.data_analysis import Artifact
|
|
|
|
base_description = """Evaluates python code in a sandbox environment. \
|
|
The environment is long running and exists across multiple executions. \
|
|
You must send the whole script every time and print your outputs. \
|
|
Script should be pure python code that can be evaluated. \
|
|
It should be in python format NOT markdown. \
|
|
The code should NOT be wrapped in backticks. \
|
|
All python packages including requests, matplotlib, scipy, numpy, pandas, \
|
|
etc are available. Create and display chart using `plt.show()`."""
|
|
|
|
|
|
def _unparse(tree: ast.AST) -> str:
|
|
"""Unparse the AST."""
|
|
if version_info.minor < 9:
|
|
s = StringIO()
|
|
Unparser(tree, file=s)
|
|
source_code = s.getvalue()
|
|
s.close()
|
|
else:
|
|
source_code = ast.unparse(tree) # type: ignore[attr-defined]
|
|
return source_code
|
|
|
|
|
|
def add_last_line_print(code: str) -> str:
    """Add print statement to the last line if it's missing.

    Sometimes, the LLM-generated code doesn't have `print(variable_name)`,
    instead the LLM tries to print the variable only by writing
    `variable_name` (as you would in REPL, for example).

    This method checks the AST of the generated Python code and wraps the
    final bare expression in a `print(...)` call when needed. Code whose last
    statement is not an expression (or is already a `print(...)` call) is
    returned unchanged apart from AST round-trip normalization.
    """
    tree = ast.parse(code)
    # Empty scripts have no statements to wrap; return them untouched instead
    # of raising IndexError on `tree.body[-1]`.
    if not tree.body:
        return code
    node = tree.body[-1]
    # Already an explicit `print(...)` call -> nothing to rewrite.
    if isinstance(node, ast.Expr) and isinstance(node.value, ast.Call):
        if isinstance(node.value.func, ast.Name) and node.value.func.id == "print":
            return _unparse(tree)

    if isinstance(node, ast.Expr):
        # Replace the bare trailing expression with `print(<expr>)`.
        tree.body[-1] = ast.Expr(
            value=ast.Call(
                func=ast.Name(id="print", ctx=ast.Load()),
                args=[node.value],
                keywords=[],
            )
        )

    return _unparse(tree)
|
|
|
|
|
|
class UploadedFile(BaseModel):
    """Description of the uploaded path with its remote path."""

    # Original (local) file name, e.g. "data.csv".
    name: str
    # Path of the file inside the sandbox filesystem.
    remote_path: str
    # Free-form, caller-supplied description surfaced to the LLM.
    description: str
|
|
|
|
|
class E2BDataAnalysisToolArguments(BaseModel):
    """Arguments for the E2BDataAnalysisTool."""

    # Required (`...`) script to execute; `example` and `description` are
    # emitted into the JSON schema the LLM sees when choosing arguments.
    python_code: str = Field(
        ...,
        example="print('Hello World')",
        description=(
            "The python script to be evaluated. "
            "The contents will be in main.py. "
            "It should not be in markdown format."
        ),
    )
|
|
|
|
|
|
class E2BDataAnalysisTool(BaseTool):
    """Tool for running python code in a sandboxed environment for data analysis."""

    # Tool name exposed to the agent.
    name: str = "e2b_data_analysis"
    # Pydantic schema describing the single `python_code` argument.
    args_schema: Type[BaseModel] = E2BDataAnalysisToolArguments
    # Live e2b `DataAnalysis` sandbox session (typed `Any` to avoid a hard
    # import of the optional `e2b` package).
    session: Any
    # Tool description shown to the LLM; extended as files are uploaded.
    description: str
    # Bookkeeping of files pushed into the sandbox via `upload_file`.
    _uploaded_files: List[UploadedFile] = PrivateAttr(default_factory=list)

    def __init__(
        self,
        api_key: Optional[str] = None,
        cwd: Optional[str] = None,
        env_vars: Optional[EnvVars] = None,
        on_stdout: Optional[Callable[[str], Any]] = None,
        on_stderr: Optional[Callable[[str], Any]] = None,
        on_artifact: Optional[Callable[[Artifact], Any]] = None,
        on_exit: Optional[Callable[[int], Any]] = None,
        **kwargs: Any,
    ):
        """Create the tool and open a new e2b `DataAnalysis` sandbox session.

        Raises:
            ImportError: If the optional `e2b` package is not installed.
        """
        try:
            from e2b import DataAnalysis
        except ImportError as e:
            raise ImportError(
                "Unable to import e2b, please install with `pip install e2b`."
            ) from e

        # If no API key is provided, E2B will try to read it from the environment
        # variable E2B_API_KEY
        super().__init__(description=base_description, **kwargs)
        self.session = DataAnalysis(
            api_key=api_key,
            cwd=cwd,
            env_vars=env_vars,
            on_stdout=on_stdout,
            on_stderr=on_stderr,
            on_exit=on_exit,
            on_artifact=on_artifact,
        )

    def close(self) -> None:
        """Close the cloud sandbox."""
        # Drop local bookkeeping; the remote files go away with the session.
        self._uploaded_files = []
        self.session.close()

    @property
    def uploaded_files_description(self) -> str:
        """Human-readable listing of uploaded files, appended to `description`."""
        if len(self._uploaded_files) == 0:
            return ""
        lines = ["The following files available in the sandbox:"]

        for f in self._uploaded_files:
            if f.description == "":
                lines.append(f"- path: `{f.remote_path}`")
            else:
                lines.append(
                    f"- path: `{f.remote_path}` \n description: `{f.description}`"
                )
        return "\n".join(lines)

    def _run(
        self,
        python_code: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
        callbacks: Optional[CallbackManager] = None,
    ) -> str:
        """Execute `python_code` in the sandbox and return a JSON string
        with `stdout`, `stderr` and `artifacts` (artifact names) keys."""
        # Ensure a bare trailing expression still produces visible output.
        python_code = add_last_line_print(python_code)

        if callbacks is not None:
            # NOTE(review): `CallbackManager.metadata` is normally a dict, so
            # getattr on it likely always yields None — confirm the intended
            # way to pass a per-call `on_artifact` hook here.
            on_artifact = getattr(callbacks.metadata, "on_artifact", None)
        else:
            on_artifact = None

        stdout, stderr, artifacts = self.session.run_python(
            python_code, on_artifact=on_artifact
        )

        out = {
            "stdout": stdout,
            "stderr": stderr,
            "artifacts": list(map(lambda artifact: artifact.name, artifacts)),
        }
        return json.dumps(out)

    async def _arun(
        self,
        python_code: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> str:
        """Async execution is not supported by this tool."""
        raise NotImplementedError("e2b_data_analysis does not support async")

    def run_command(
        self,
        cmd: str,
    ) -> dict:
        """Run shell command in the sandbox.

        Returns a dict with `stdout`, `stderr` and `exit_code` keys.
        """
        proc = self.session.process.start(cmd)
        # Block until the process finishes and collect its output.
        output = proc.wait()
        return {
            "stdout": output.stdout,
            "stderr": output.stderr,
            "exit_code": output.exit_code,
        }

    def install_python_packages(self, package_names: Union[str, List[str]]) -> None:
        """Install python packages in the sandbox."""
        self.session.install_python_packages(package_names)

    def install_system_packages(self, package_names: Union[str, List[str]]) -> None:
        """Install system packages (via apt) in the sandbox."""
        self.session.install_system_packages(package_names)

    def download_file(self, remote_path: str) -> bytes:
        """Download file from the sandbox."""
        return self.session.download_file(remote_path)

    def upload_file(self, file: IO, description: str) -> UploadedFile:
        """Upload file to the sandbox.

        The file is uploaded to the '/home/user/<filename>' path."""
        remote_path = self.session.upload_file(file)

        f = UploadedFile(
            name=os.path.basename(file.name),
            remote_path=remote_path,
            description=description,
        )
        self._uploaded_files.append(f)
        # NOTE(review): this appends the *full* current listing each time, so
        # after several uploads `description` repeats earlier entries —
        # confirm whether rebuilding from `base_description` was intended.
        self.description = self.description + "\n" + self.uploaded_files_description
        return f

    def remove_uploaded_file(self, uploaded_file: UploadedFile) -> None:
        """Remove uploaded file from the sandbox."""
        self.session.filesystem.remove(uploaded_file.remote_path)
        self._uploaded_files = [
            f
            for f in self._uploaded_files
            if f.remote_path != uploaded_file.remote_path
        ]
        # NOTE(review): same accumulation issue as in `upload_file` — the new
        # listing is appended, not substituted; confirm intended.
        self.description = self.description + "\n" + self.uploaded_files_description

    def as_tool(self) -> Tool:
        """Expose this tool as a plain `Tool` wrapping `_run`."""
        return Tool.from_function(
            func=self._run,
            name=self.name,
            description=self.description,
            args_schema=self.args_schema,
        )
|