You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
langchain/libs/partners/azure-dynamic-sessions/langchain_azure_dynamic_ses.../tools/sessions.py

276 lines
9.0 KiB
Python

import importlib.metadata
import json
import os
import re
import urllib
import urllib.parse
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from io import BytesIO
from typing import Any, BinaryIO, Callable, List, Optional
from uuid import uuid4

import requests
from azure.core.credentials import AccessToken
from azure.identity import DefaultAzureCredential
from langchain_core.tools import BaseTool
# Look up the version of the installed distribution so it can be reported in
# the User-Agent header; use a placeholder when no distribution metadata for
# this package can be found.
try:
    _package_version = importlib.metadata.version("langchain-azure-dynamic-sessions")
except importlib.metadata.PackageNotFoundError:
    _package_version = "0.0.0"

# Sent as the ``User-Agent`` header on every request issued by this module.
USER_AGENT = f"langchain-azure-dynamic-sessions/{_package_version} (Language=Python)"
def _access_token_provider_factory() -> Callable[[], Optional[str]]:
    """Create a token provider that caches its token between calls.

    The returned callable keeps the most recently acquired token in the
    enclosing scope and requests a new one only when there is no cached token
    or the cached one expires within the next five minutes.

    Returns:
        Callable[[], Optional[str]]: The access token provider function
    """
    cached_token: Optional[AccessToken] = None

    def access_token_provider() -> Optional[str]:
        nonlocal cached_token

        needs_refresh = True
        if cached_token is not None:
            expires_at = datetime.fromtimestamp(cached_token.expires_on, timezone.utc)
            # Refresh slightly early so a token is never used right at expiry.
            refresh_deadline = datetime.now(timezone.utc) + timedelta(minutes=5)
            needs_refresh = expires_at < refresh_deadline

        if needs_refresh:
            credential = DefaultAzureCredential()
            cached_token = credential.get_token("https://dynamicsessions.io/.default")

        return cached_token.token

    return access_token_provider
# Leading whitespace/backticks, then an optional stand-alone "python" token
# (any case), then any whitespace after it. The ``\b`` keeps identifiers that
# merely start with "python" (e.g. ``python_version``) intact.
_LEADING_NOISE_RE = re.compile(r"^[\s`]*(?:(?i:python)\b)?\s*")
# Trailing whitespace/backticks.
_TRAILING_NOISE_RE = re.compile(r"[\s`]+$")


def _sanitize_input(query: str) -> str:
    """Sanitize input to the python REPL.

    Strips leading and trailing whitespace and backticks (e.g. Markdown code
    fences), plus a leading ``python`` token in case the LLM treats the Python
    console as a terminal or emits a fenced code block with a language tag.

    The ``python`` token is removed only when it is a whole word, so code that
    begins with an identifier such as ``python_version`` or ``pythonic`` is
    left untouched.

    Args:
        query: The query to sanitize

    Returns:
        str: The sanitized query
    """
    # Removes `, whitespace & a stand-alone "python" from the start
    query = _LEADING_NOISE_RE.sub("", query)
    # Removes whitespace & ` from the end
    query = _TRAILING_NOISE_RE.sub("", query)
    return query
@dataclass
class RemoteFileMetadata:
    """Metadata for a file in the session.

    Attributes:
        filename: The filename relative to `/mnt/data`.
        size_in_bytes: The size of the file in bytes.
    """

    filename: str
    size_in_bytes: int

    @property
    def full_path(self) -> str:
        """Return the file's location under `/mnt/data`."""
        return f"/mnt/data/{self.filename}"

    @staticmethod
    def from_dict(data: dict) -> "RemoteFileMetadata":
        """Build a RemoteFileMetadata from a dictionary.

        The values are read from the nested ``"properties"`` mapping (keys
        ``"filename"`` and ``"size"``); anything missing becomes ``None``.
        """
        props = data.get("properties", {})
        name = props.get("filename")
        size = props.get("size")
        return RemoteFileMetadata(filename=name, size_in_bytes=size)
class SessionsPythonREPLTool(BaseTool):
    """A tool for running Python code in an Azure Container Apps dynamic sessions
    code interpreter.

    Example:
        .. code-block:: python

            from langchain_azure_dynamic_sessions import SessionsPythonREPLTool

            tool = SessionsPythonREPLTool(pool_management_endpoint="...")
            result = tool.invoke("6 * 7")
    """

    name: str = "Python_REPL"
    description: str = (
        "A Python shell. Use this to execute python commands "
        "when you need to perform calculations or computations. "
        "Input should be a valid python command. "
        "Returns a JSON object with the result, stdout, and stderr. "
    )

    sanitize_input: bool = True
    """Whether to sanitize input to the python REPL."""

    pool_management_endpoint: str
    """The management endpoint of the session pool. Should end with a '/'."""

    access_token_provider: Callable[
        [], Optional[str]
    ] = _access_token_provider_factory()
    """A function that returns the access token to use for the session pool."""

    # NOTE(review): this default is evaluated once, when the class body runs,
    # so every instance created without an explicit ``session_id`` gets the
    # same value within a process. Pass ``session_id`` explicitly to isolate
    # instances from each other.
    session_id: str = str(uuid4())
    """The session ID to use for the code interpreter. Defaults to a random UUID
    generated once when the class is defined."""

    def _build_url(self, path: str) -> str:
        """Build the full request URL for an API path.

        Appends ``path`` to the pool management endpoint (adding a trailing
        '/' to the endpoint if it is missing) followed by the URL-encoded
        session identifier and the API version as query parameters.

        Args:
            path: The API path relative to the pool management endpoint.

        Returns:
            str: The full URL, including the query string.

        Raises:
            ValueError: If ``pool_management_endpoint`` is empty.
        """
        pool_management_endpoint = self.pool_management_endpoint
        if not pool_management_endpoint:
            raise ValueError("pool_management_endpoint is not set")
        if not pool_management_endpoint.endswith("/"):
            pool_management_endpoint += "/"
        encoded_session_id = urllib.parse.quote(self.session_id)
        query = f"identifier={encoded_session_id}&api-version=2024-02-02-preview"
        # Extend an existing query string if the endpoint already carries one.
        query_separator = "&" if "?" in pool_management_endpoint else "?"
        return pool_management_endpoint + path + query_separator + query

    def _auth_headers(self) -> dict:
        """Build the headers sent with every request to the session pool."""
        access_token = self.access_token_provider()
        return {
            "Authorization": f"Bearer {access_token}",
            "User-Agent": USER_AGENT,
        }

    def execute(self, python_code: str) -> Any:
        """Execute Python code in the session.

        Args:
            python_code: The code to run. It is sanitized first when
                ``sanitize_input`` is enabled.

        Returns:
            The ``"properties"`` object of the JSON response (an empty dict
            if the response has none).

        Raises:
            requests.HTTPError: If the service responds with an error status.
        """
        if self.sanitize_input:
            python_code = _sanitize_input(python_code)

        headers = {**self._auth_headers(), "Content-Type": "application/json"}
        api_url = self._build_url("code/execute")
        body = {
            "properties": {
                "codeInputType": "inline",
                "executionType": "synchronous",
                "code": python_code,
            }
        }

        response = requests.post(api_url, headers=headers, json=body)
        response.raise_for_status()
        response_json = response.json()
        return response_json.get("properties", {})

    def _run(self, python_code: str) -> Any:
        """Run the code and return result, stdout and stderr as a JSON string."""
        response = self.execute(python_code)

        # if the result is an image, remove the base64 data
        result = response.get("result")
        if isinstance(result, dict):
            if result.get("type") == "image" and "base64_data" in result:
                result.pop("base64_data")

        return json.dumps(
            {
                "result": result,
                "stdout": response.get("stdout"),
                "stderr": response.get("stderr"),
            },
            indent=2,
        )

    def _upload_stream(
        self, file_data: BinaryIO, remote_file_path: Optional[str]
    ) -> RemoteFileMetadata:
        """POST an open binary stream to the session as a multipart upload."""
        headers = self._auth_headers()
        api_url = self._build_url("files/upload")
        files = [("file", (remote_file_path, file_data, "application/octet-stream"))]

        response = requests.request(
            "POST", api_url, headers=headers, data={}, files=files
        )
        response.raise_for_status()
        response_json = response.json()
        return RemoteFileMetadata.from_dict(response_json["value"][0])

    def upload_file(
        self,
        *,
        data: Optional[BinaryIO] = None,
        remote_file_path: Optional[str] = None,
        local_file_path: Optional[str] = None,
    ) -> RemoteFileMetadata:
        """Upload a file to the session.

        Exactly one of ``data`` and ``local_file_path`` must be provided.

        Args:
            data: The data to upload. The stream is not closed by this method.
            remote_file_path: The path to upload the file to, relative to
                `/mnt/data`. If local_file_path is provided, this is defaulted
                to its filename. No default is derived when uploading from
                ``data``, so callers should supply it in that case.
            local_file_path: The path to the local file to upload.

        Returns:
            RemoteFileMetadata: The metadata for the uploaded file

        Raises:
            ValueError: If both or neither of ``data`` and ``local_file_path``
                are provided.
            requests.HTTPError: If the service responds with an error status.
        """
        if data is not None and local_file_path:
            raise ValueError("data and local_file_path cannot be provided together")
        if data is None and not local_file_path:
            raise ValueError("either data or local_file_path must be provided")

        if data is not None:
            return self._upload_stream(data, remote_file_path)

        if not remote_file_path:
            remote_file_path = os.path.basename(local_file_path)
        # The context manager guarantees the file handle is released even if
        # the request fails.
        with open(local_file_path, "rb") as file_data:
            return self._upload_stream(file_data, remote_file_path)

    def download_file(
        self, *, remote_file_path: str, local_file_path: Optional[str] = None
    ) -> BinaryIO:
        """Download a file from the session.

        Args:
            remote_file_path: The path to download the file from,
                relative to `/mnt/data`.
            local_file_path: The path to save the downloaded file to.
                If provided, the content is written there in addition to
                being returned.

        Returns:
            BinaryIO: An in-memory ``BytesIO`` holding the downloaded content.

        Raises:
            requests.HTTPError: If the service responds with an error status.
        """
        headers = self._auth_headers()
        encoded_remote_file_path = urllib.parse.quote(remote_file_path)
        api_url = self._build_url(f"files/content/{encoded_remote_file_path}")

        response = requests.get(api_url, headers=headers)
        response.raise_for_status()

        if local_file_path:
            with open(local_file_path, "wb") as f:
                f.write(response.content)

        return BytesIO(response.content)

    def list_files(self) -> List[RemoteFileMetadata]:
        """List the files in the session.

        Returns:
            list[RemoteFileMetadata]: The metadata for the files in the session

        Raises:
            requests.HTTPError: If the service responds with an error status.
        """
        headers = self._auth_headers()
        api_url = self._build_url("files")

        response = requests.get(api_url, headers=headers)
        response.raise_for_status()
        response_json = response.json()
        return [RemoteFileMetadata.from_dict(entry) for entry in response_json["value"]]