Mirror of https://github.com/hwchase17/langchain, synced 2024-11-10 01:10:59 +00:00
community: Refactor PebbloSafeLoader (#25582)
**Refactor PebbloSafeLoader**

- Created `APIWrapper` and moved API logic into it.
- Moved helper functions to the utility file.
- Created smaller functions and methods for better readability.
- Properly read environment variables.
- Removed unused code.

**Issue:** NA
**Dependencies:** NA
**Tests:** Updated
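For reviewers, the loader-side usage is unchanged by this refactor; a minimal sketch (the wrapped `CSVLoader`, file path, app name, owner, and key value below are illustrative, not part of this change):

```python
import os

from langchain_community.document_loaders.csv_loader import CSVLoader
from langchain_community.document_loaders.pebblo import PebbloSafeLoader

# The key may also be passed explicitly via api_key=...;
# this value is a hypothetical placeholder.
os.environ["PEBBLO_API_KEY"] = "<pebblo-api-key>"

loader = PebbloSafeLoader(
    CSVLoader("data/sens_data.csv"),  # illustrative wrapped loader
    name="acme-corp-rag-app",  # required app name
    owner="Joe Smith",  # optional owner
    description="Support productivity RAG application",  # optional
    classifier_location="local",  # classify via a local Pebblo server
)
documents = loader.load()
```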
parent 5e3a321f71
commit 1f1679e960
@@ -1,31 +1,25 @@
"""Pebblo's safe dataloader is a wrapper for document loaders"""

import json
import logging
import os
import uuid
from http import HTTPStatus
from typing import Any, Dict, Iterator, List, Optional
from typing import Dict, Iterator, List, Optional

import requests  # type: ignore
from langchain_core.documents import Document

from langchain_community.document_loaders.base import BaseLoader
from langchain_community.utilities.pebblo import (
    APP_DISCOVER_URL,
    BATCH_SIZE_BYTES,
    CLASSIFIER_URL,
    LOADER_DOC_URL,
    PEBBLO_CLOUD_URL,
    PLUGIN_VERSION,
    App,
    Doc,
    IndexedDocument,
    PebbloLoaderAPIWrapper,
    generate_size_based_batches,
    get_full_path,
    get_loader_full_path,
    get_loader_type,
    get_runtime,
    get_source_size,
)

logger = logging.getLogger(__name__)
@@ -37,7 +31,6 @@ class PebbloSafeLoader(BaseLoader):
    """

    _discover_sent: bool = False
    _loader_sent: bool = False

    def __init__(
        self,
@@ -54,22 +47,17 @@ class PebbloSafeLoader(BaseLoader):
        if not name or not isinstance(name, str):
            raise NameError("Must specify a valid name.")
        self.app_name = name
        self.api_key = os.environ.get("PEBBLO_API_KEY") or api_key
        self.load_id = str(uuid.uuid4())
        self.loader = langchain_loader
        self.load_semantic = os.environ.get("PEBBLO_LOAD_SEMANTIC") or load_semantic
        self.owner = owner
        self.description = description
        self.source_path = get_loader_full_path(self.loader)
        self.source_owner = PebbloSafeLoader.get_file_owner_from_path(self.source_path)
        self.docs: List[Document] = []
        self.docs_with_id: List[IndexedDocument] = []
        loader_name = str(type(self.loader)).split(".")[-1].split("'")[0]
        self.source_type = get_loader_type(loader_name)
        self.source_path_size = self.get_source_size(self.source_path)
        self.source_aggregate_size = 0
        self.classifier_url = classifier_url or CLASSIFIER_URL
        self.classifier_location = classifier_location
        self.source_path_size = get_source_size(self.source_path)
        self.batch_size = BATCH_SIZE_BYTES
        self.loader_details = {
            "loader": loader_name,
@@ -83,7 +71,13 @@ class PebbloSafeLoader(BaseLoader):
        }
        # generate app
        self.app = self._get_app_details()
        self._send_discover()
        # initialize Pebblo Loader API client
        self.pb_client = PebbloLoaderAPIWrapper(
            api_key=api_key,
            classifier_location=classifier_location,
            classifier_url=classifier_url,
        )
        self.pb_client.send_loader_discover(self.app)

    def load(self) -> List[Document]:
        """Load Documents.
@@ -113,7 +107,12 @@ class PebbloSafeLoader(BaseLoader):
            is_last_batch: bool = i == total_batches - 1
            self.docs = batch
            self.docs_with_id = self._index_docs()
            classified_docs = self._classify_doc(loading_end=is_last_batch)
            classified_docs = self.pb_client.classify_documents(
                self.docs_with_id,
                self.app,
                self.loader_details,
                loading_end=is_last_batch,
            )
            self._add_pebblo_specific_metadata(classified_docs)
            if self.load_semantic:
                batch_processed_docs = self._add_semantic_to_docs(classified_docs)
@@ -147,7 +146,9 @@ class PebbloSafeLoader(BaseLoader):
                break
            self.docs = list((doc,))
            self.docs_with_id = self._index_docs()
            classified_doc = self._classify_doc()
            classified_doc = self.pb_client.classify_documents(
                self.docs_with_id, self.app, self.loader_details
            )
            self._add_pebblo_specific_metadata(classified_doc)
            if self.load_semantic:
                self.docs = self._add_semantic_to_docs(classified_doc)
@@ -159,263 +160,6 @@ class PebbloSafeLoader(BaseLoader):
    def set_discover_sent(cls) -> None:
        cls._discover_sent = True

    @classmethod
    def set_loader_sent(cls) -> None:
        cls._loader_sent = True

    def _classify_doc(self, loading_end: bool = False) -> dict:
        """Send documents fetched from loader to pebblo-server. Then send
        classified documents to Daxa cloud(If api_key is present). Internal method.

        Args:

            loading_end (bool, optional): Flag indicating the halt of data
                loading by loader. Defaults to False.
        """
        headers = {
            "Accept": "application/json",
            "Content-Type": "application/json",
        }
        if loading_end is True:
            PebbloSafeLoader.set_loader_sent()
        doc_content = [doc.dict() for doc in self.docs_with_id]
        docs = []
        for doc in doc_content:
            doc_metadata = doc.get("metadata", {})
            doc_authorized_identities = doc_metadata.get("authorized_identities", [])
            doc_source_path = get_full_path(
                doc_metadata.get(
                    "full_path", doc_metadata.get("source", self.source_path)
                )
            )
            doc_source_owner = doc_metadata.get(
                "owner", PebbloSafeLoader.get_file_owner_from_path(doc_source_path)
            )
            doc_source_size = doc_metadata.get(
                "size", self.get_source_size(doc_source_path)
            )
            page_content = str(doc.get("page_content"))
            page_content_size = self.calculate_content_size(page_content)
            self.source_aggregate_size += page_content_size
            doc_id = doc.get("pb_id", None) or 0
            docs.append(
                {
                    "doc": page_content,
                    "source_path": doc_source_path,
                    "pb_id": doc_id,
                    "last_modified": doc.get("metadata", {}).get("last_modified"),
                    "file_owner": doc_source_owner,
                    **(
                        {"authorized_identities": doc_authorized_identities}
                        if doc_authorized_identities
                        else {}
                    ),
                    **(
                        {"source_path_size": doc_source_size}
                        if doc_source_size is not None
                        else {}
                    ),
                }
            )
        payload: Dict[str, Any] = {
            "name": self.app_name,
            "owner": self.owner,
            "docs": docs,
            "plugin_version": PLUGIN_VERSION,
            "load_id": self.load_id,
            "loader_details": self.loader_details,
            "loading_end": "false",
            "source_owner": self.source_owner,
            "classifier_location": self.classifier_location,
        }
        if loading_end is True:
            payload["loading_end"] = "true"
            if "loader_details" in payload:
                payload["loader_details"]["source_aggregate_size"] = (
                    self.source_aggregate_size
                )
        payload = Doc(**payload).dict(exclude_unset=True)
        classified_docs = {}
        # Raw payload to be sent to classifier
        if self.classifier_location == "local":
            load_doc_url = f"{self.classifier_url}{LOADER_DOC_URL}"
            try:
                pebblo_resp = requests.post(
                    load_doc_url, headers=headers, json=payload, timeout=300
                )

                # Updating the structure of pebblo response docs for efficient searching
                for classified_doc in json.loads(pebblo_resp.text).get("docs", []):
                    classified_docs.update({classified_doc["pb_id"]: classified_doc})
                if pebblo_resp.status_code not in [
                    HTTPStatus.OK,
                    HTTPStatus.BAD_GATEWAY,
                ]:
                    logger.warning(
                        "Received unexpected HTTP response code: %s",
                        pebblo_resp.status_code,
                    )
                logger.debug(
                    "send_loader_doc[local]: request url %s, body %s len %s\
                        response status %s body %s",
                    pebblo_resp.request.url,
                    str(pebblo_resp.request.body),
                    str(
                        len(
                            pebblo_resp.request.body if pebblo_resp.request.body else []
                        )
                    ),
                    str(pebblo_resp.status_code),
                    pebblo_resp.json(),
                )
            except requests.exceptions.RequestException:
                logger.warning("Unable to reach pebblo server.")
            except Exception as e:
                logger.warning("An Exception caught in _send_loader_doc: local %s", e)

        if self.api_key:
            if self.classifier_location == "local":
                docs = payload["docs"]
                for doc_data in docs:
                    classified_data = classified_docs.get(doc_data["pb_id"], {})
                    doc_data.update(
                        {
                            "pb_checksum": classified_data.get("pb_checksum", None),
                            "loader_source_path": classified_data.get(
                                "loader_source_path", None
                            ),
                            "entities": classified_data.get("entities", {}),
                            "topics": classified_data.get("topics", {}),
                        }
                    )
                    doc_data.pop("doc")

            headers.update({"x-api-key": self.api_key})
            pebblo_cloud_url = f"{PEBBLO_CLOUD_URL}{LOADER_DOC_URL}"
            try:
                pebblo_cloud_response = requests.post(
                    pebblo_cloud_url, headers=headers, json=payload, timeout=20
                )
                logger.debug(
                    "send_loader_doc[cloud]: request url %s, body %s len %s\
                        response status %s body %s",
                    pebblo_cloud_response.request.url,
                    str(pebblo_cloud_response.request.body),
                    str(
                        len(
                            pebblo_cloud_response.request.body
                            if pebblo_cloud_response.request.body
                            else []
                        )
                    ),
                    str(pebblo_cloud_response.status_code),
                    pebblo_cloud_response.json(),
                )
            except requests.exceptions.RequestException:
                logger.warning("Unable to reach Pebblo cloud server.")
            except Exception as e:
                logger.warning("An Exception caught in _send_loader_doc: cloud %s", e)
        elif self.classifier_location == "pebblo-cloud":
            logger.warning("API key is missing for sending docs to Pebblo cloud.")
            raise NameError("API key is missing for sending docs to Pebblo cloud.")

        return classified_docs

    @staticmethod
    def calculate_content_size(page_content: str) -> int:
        """Calculate the content size in bytes:
        - Encode the string to bytes using a specific encoding (e.g., UTF-8)
        - Get the length of the encoded bytes.

        Args:
            page_content (str): Data string.

        Returns:
            int: Size of string in bytes.
        """

        # Encode the content to bytes using UTF-8
        encoded_content = page_content.encode("utf-8")
        size = len(encoded_content)
        return size

    def _send_discover(self) -> None:
        """Send app discovery payload to pebblo-server. Internal method."""
        pebblo_resp = None
        headers = {
            "Accept": "application/json",
            "Content-Type": "application/json",
        }
        payload = self.app.dict(exclude_unset=True)
        # Raw discover payload to be sent to classifier
        if self.classifier_location == "local":
            app_discover_url = f"{self.classifier_url}{APP_DISCOVER_URL}"
            try:
                pebblo_resp = requests.post(
                    app_discover_url, headers=headers, json=payload, timeout=20
                )
                logger.debug(
                    "send_discover[local]: request url %s, body %s len %s\
                        response status %s body %s",
                    pebblo_resp.request.url,
                    str(pebblo_resp.request.body),
                    str(
                        len(
                            pebblo_resp.request.body if pebblo_resp.request.body else []
                        )
                    ),
                    str(pebblo_resp.status_code),
                    pebblo_resp.json(),
                )
                if pebblo_resp.status_code in [HTTPStatus.OK, HTTPStatus.BAD_GATEWAY]:
                    PebbloSafeLoader.set_discover_sent()
                else:
                    logger.warning(
                        f"Received unexpected HTTP response code:\
                            {pebblo_resp.status_code}"
                    )
            except requests.exceptions.RequestException:
                logger.warning("Unable to reach pebblo server.")
            except Exception as e:
                logger.warning("An Exception caught in _send_discover: local %s", e)

        if self.api_key:
            try:
                headers.update({"x-api-key": self.api_key})
                # If the pebblo_resp is None,
                # then the pebblo server version is not available
                if pebblo_resp:
                    pebblo_server_version = json.loads(pebblo_resp.text).get(
                        "pebblo_server_version"
                    )
                    payload.update({"pebblo_server_version": pebblo_server_version})

                payload.update({"pebblo_client_version": PLUGIN_VERSION})
                pebblo_cloud_url = f"{PEBBLO_CLOUD_URL}{APP_DISCOVER_URL}"
                pebblo_cloud_response = requests.post(
                    pebblo_cloud_url, headers=headers, json=payload, timeout=20
                )

                logger.debug(
                    "send_discover[cloud]: request url %s, body %s len %s\
                        response status %s body %s",
                    pebblo_cloud_response.request.url,
                    str(pebblo_cloud_response.request.body),
                    str(
                        len(
                            pebblo_cloud_response.request.body
                            if pebblo_cloud_response.request.body
                            else []
                        )
                    ),
                    str(pebblo_cloud_response.status_code),
                    pebblo_cloud_response.json(),
                )
            except requests.exceptions.RequestException:
                logger.warning("Unable to reach Pebblo cloud server.")
            except Exception as e:
                logger.warning("An Exception caught in _send_discover: cloud %s", e)

    def _get_app_details(self) -> App:
        """Fetch app details. Internal method.
@@ -434,49 +178,6 @@ class PebbloSafeLoader(BaseLoader):
        )
        return app

    @staticmethod
    def get_file_owner_from_path(file_path: str) -> str:
        """Fetch owner of local file path.

        Args:
            file_path (str): Local file path.

        Returns:
            str: Name of owner.
        """
        try:
            import pwd

            file_owner_uid = os.stat(file_path).st_uid
            file_owner_name = pwd.getpwuid(file_owner_uid).pw_name
        except Exception:
            file_owner_name = "unknown"
        return file_owner_name

    def get_source_size(self, source_path: str) -> int:
        """Fetch size of source path. Source can be a directory or a file.

        Args:
            source_path (str): Local path of data source.

        Returns:
            int: Source size in bytes.
        """
        if not source_path:
            return 0
        size = 0
        if os.path.isfile(source_path):
            size = os.path.getsize(source_path)
        elif os.path.isdir(source_path):
            total_size = 0
            for dirpath, _, filenames in os.walk(source_path):
                for f in filenames:
                    fp = os.path.join(dirpath, f)
                    if not os.path.islink(fp):
                        total_size += os.path.getsize(fp)
            size = total_size
        return size

    def _index_docs(self) -> List[IndexedDocument]:
        """
        Indexes the documents and returns a list of IndexedDocument objects.
@@ -1,25 +1,29 @@
from __future__ import annotations

import json
import logging
import os
import pathlib
import platform
from typing import List, Optional, Tuple
from enum import Enum
from http import HTTPStatus
from typing import Any, Dict, List, Optional, Tuple

from langchain_core.documents import Document
from langchain_core.env import get_runtime_environment
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.utils import get_from_dict_or_env
from requests import Response, request
from requests.exceptions import RequestException

from langchain_community.document_loaders.base import BaseLoader

logger = logging.getLogger(__name__)

PLUGIN_VERSION = "0.1.1"
CLASSIFIER_URL = os.getenv("PEBBLO_CLASSIFIER_URL", "http://localhost:8000")
PEBBLO_CLOUD_URL = os.getenv("PEBBLO_CLOUD_URL", "https://api.daxa.ai")

LOADER_DOC_URL = "/v1/loader/doc"
APP_DISCOVER_URL = "/v1/app/discover"
_DEFAULT_CLASSIFIER_URL = "http://localhost:8000"
_DEFAULT_PEBBLO_CLOUD_URL = "https://api.daxa.ai"
BATCH_SIZE_BYTES = 100 * 1024  # 100 KB

# Supported loaders for Pebblo safe data loading
@@ -59,9 +63,15 @@ LOADER_TYPE_MAPPING = {
    "cloud-folder": cloud_folder,
}

SUPPORTED_LOADERS = (*file_loader, *dir_loader, *in_memory)

logger = logging.getLogger(__name__)
class Routes(str, Enum):
    """Routes available for the Pebblo API as enumerator."""

    loader_doc = "/v1/loader/doc"
    loader_app_discover = "/v1/app/discover"
    retrieval_app_discover = "/v1/app/discover"
    prompt = "/v1/prompt"
    prompt_governance = "/v1/prompt/governance"


class IndexedDocument(Document):
@@ -342,3 +352,386 @@ def generate_size_based_batches(
            batches.append(current_batch)

    return batches


def get_file_owner_from_path(file_path: str) -> str:
    """Fetch owner of local file path.

    Args:
        file_path (str): Local file path.

    Returns:
        str: Name of owner.
    """
    try:
        import pwd

        file_owner_uid = os.stat(file_path).st_uid
        file_owner_name = pwd.getpwuid(file_owner_uid).pw_name
    except Exception:
        file_owner_name = "unknown"
    return file_owner_name


def get_source_size(source_path: str) -> int:
    """Fetch size of source path. Source can be a directory or a file.

    Args:
        source_path (str): Local path of data source.

    Returns:
        int: Source size in bytes.
    """
    if not source_path:
        return 0
    size = 0
    if os.path.isfile(source_path):
        size = os.path.getsize(source_path)
    elif os.path.isdir(source_path):
        total_size = 0
        for dirpath, _, filenames in os.walk(source_path):
            for f in filenames:
                fp = os.path.join(dirpath, f)
                if not os.path.islink(fp):
                    total_size += os.path.getsize(fp)
        size = total_size
    return size


def calculate_content_size(data: str) -> int:
    """Calculate the content size in bytes:
    - Encode the string to bytes using a specific encoding (e.g., UTF-8)
    - Get the length of the encoded bytes.

    Args:
        data (str): Data string.

    Returns:
        int: Size of string in bytes.
    """
    encoded_content = data.encode("utf-8")
    size = len(encoded_content)
    return size

class PebbloLoaderAPIWrapper(BaseModel):
    """Wrapper for Pebblo Loader API."""

    api_key: Optional[str]  # Use SecretStr
    """API key for Pebblo Cloud"""
    classifier_location: str = "local"
    """Location of the classifier, local or cloud. Defaults to 'local'"""
    classifier_url: Optional[str]
    """URL of the Pebblo Classifier"""
    cloud_url: Optional[str]
    """URL of the Pebblo Cloud"""

    def __init__(self, **kwargs: Any):
        """Validate that api key in environment."""
        kwargs["api_key"] = get_from_dict_or_env(
            kwargs, "api_key", "PEBBLO_API_KEY", ""
        )
        kwargs["classifier_url"] = get_from_dict_or_env(
            kwargs, "classifier_url", "PEBBLO_CLASSIFIER_URL", _DEFAULT_CLASSIFIER_URL
        )
        kwargs["cloud_url"] = get_from_dict_or_env(
            kwargs, "cloud_url", "PEBBLO_CLOUD_URL", _DEFAULT_PEBBLO_CLOUD_URL
        )
        super().__init__(**kwargs)

    def send_loader_discover(self, app: App) -> None:
        """
        Send app discovery request to Pebblo server & cloud.

        Args:
            app (App): App instance to be discovered.
        """
        pebblo_resp = None
        payload = app.dict(exclude_unset=True)

        if self.classifier_location == "local":
            # Send app details to local classifier
            headers = self._make_headers()
            app_discover_url = f"{self.classifier_url}{Routes.loader_app_discover}"
            pebblo_resp = self.make_request("POST", app_discover_url, headers, payload)

        if self.api_key:
            # Send app details to Pebblo cloud if api_key is present
            headers = self._make_headers(cloud_request=True)
            if pebblo_resp:
                pebblo_server_version = json.loads(pebblo_resp.text).get(
                    "pebblo_server_version"
                )
                payload.update({"pebblo_server_version": pebblo_server_version})

            payload.update({"pebblo_client_version": PLUGIN_VERSION})
            pebblo_cloud_url = f"{self.cloud_url}{Routes.loader_app_discover}"
            _ = self.make_request("POST", pebblo_cloud_url, headers, payload)

    def classify_documents(
        self,
        docs_with_id: List[IndexedDocument],
        app: App,
        loader_details: dict,
        loading_end: bool = False,
    ) -> dict:
        """
        Send documents to Pebblo server for classification.
        Then send classified documents to Daxa cloud(If api_key is present).

        Args:
            docs_with_id (List[IndexedDocument]): List of documents to be classified.
            app (App): App instance.
            loader_details (dict): Loader details.
            loading_end (bool): Boolean, indicating the halt of data loading by loader.
        """
        source_path = loader_details.get("source_path", "")
        source_owner = get_file_owner_from_path(source_path)
        # Prepare docs for classification
        docs, source_aggregate_size = self.prepare_docs_for_classification(
            docs_with_id, source_path
        )
        # Build payload for classification
        payload = self.build_classification_payload(
            app, docs, loader_details, source_owner, source_aggregate_size, loading_end
        )

        classified_docs = {}
        if self.classifier_location == "local":
            # Send docs to local classifier
            headers = self._make_headers()
            load_doc_url = f"{self.classifier_url}{Routes.loader_doc}"
            try:
                pebblo_resp = self.make_request(
                    "POST", load_doc_url, headers, payload, 300
                )

                if pebblo_resp:
                    # Updating structure of pebblo response docs for efficient searching
                    for classified_doc in json.loads(pebblo_resp.text).get("docs", []):
                        classified_docs.update(
                            {classified_doc["pb_id"]: classified_doc}
                        )
            except Exception as e:
                logger.warning("An Exception caught in classify_documents: local %s", e)

        if self.api_key:
            # Send docs to Pebblo cloud if api_key is present
            if self.classifier_location == "local":
                # If local classifier is used add the classified information
                # and remove doc content
                self.update_doc_data(payload["docs"], classified_docs)
            self.send_docs_to_pebblo_cloud(payload)
        elif self.classifier_location == "pebblo-cloud":
            logger.warning("API key is missing for sending docs to Pebblo cloud.")
            raise NameError("API key is missing for sending docs to Pebblo cloud.")

        return classified_docs

    def send_docs_to_pebblo_cloud(self, payload: dict) -> None:
        """
        Send documents to Pebblo cloud.

        Args:
            payload (dict): The payload containing documents to be sent.
        """
        headers = self._make_headers(cloud_request=True)
        pebblo_cloud_url = f"{self.cloud_url}{Routes.loader_doc}"
        try:
            _ = self.make_request("POST", pebblo_cloud_url, headers, payload)
        except Exception as e:
            logger.warning("An Exception caught in classify_documents: cloud %s", e)

    def _make_headers(self, cloud_request: bool = False) -> dict:
        """
        Generate headers for the request.

        args:
            cloud_request (bool): flag indicating whether the request is for Pebblo
                cloud.
        returns:
            dict: Headers for the request.

        """
        headers = {
            "Accept": "application/json",
            "Content-Type": "application/json",
        }
        if cloud_request:
            # Add API key for Pebblo cloud request
            if self.api_key:
                headers.update({"x-api-key": self.api_key})
            else:
                logger.warning("API key is missing for Pebblo cloud request.")
        return headers

    def build_classification_payload(
        self,
        app: App,
        docs: List[dict],
        loader_details: dict,
        source_owner: str,
        source_aggregate_size: int,
        loading_end: bool,
    ) -> dict:
        """
        Build the payload for document classification.

        Args:
            app (App): App instance.
            docs (List[dict]): List of documents to be classified.
            loader_details (dict): Loader details.
            source_owner (str): Owner of the source.
            source_aggregate_size (int): Aggregate size of the source.
            loading_end (bool): Boolean indicating the halt of data loading by loader.

        Returns:
            dict: Payload for document classification.
        """
        payload: Dict[str, Any] = {
            "name": app.name,
            "owner": app.owner,
            "docs": docs,
            "plugin_version": PLUGIN_VERSION,
            "load_id": app.load_id,
            "loader_details": loader_details,
            "loading_end": "false",
            "source_owner": source_owner,
            "classifier_location": self.classifier_location,
        }
        if loading_end is True:
            payload["loading_end"] = "true"
            if "loader_details" in payload:
                payload["loader_details"]["source_aggregate_size"] = (
                    source_aggregate_size
                )
        payload = Doc(**payload).dict(exclude_unset=True)
        return payload

    @staticmethod
    def make_request(
        method: str,
        url: str,
        headers: dict,
        payload: Optional[dict] = None,
        timeout: int = 20,
    ) -> Optional[Response]:
        """
        Make a request to the Pebblo API

        Args:
            method (str): HTTP method (GET, POST, PUT, DELETE, etc.).
            url (str): URL for the request.
            headers (dict): Headers for the request.
            payload (Optional[dict]): Payload for the request (for POST, PUT, etc.).
            timeout (int): Timeout for the request in seconds.

        Returns:
            Optional[Response]: Response object if the request is successful.
        """
        try:
            response = request(
                method=method, url=url, headers=headers, json=payload, timeout=timeout
            )
            logger.debug(
                "Request: method %s, url %s, len %s response status %s",
                method,
                response.request.url,
                str(len(response.request.body if response.request.body else [])),
                str(response.status_code),
            )

            if response.status_code >= HTTPStatus.INTERNAL_SERVER_ERROR:
                logger.warning(f"Pebblo Server: Error {response.status_code}")
            elif response.status_code >= HTTPStatus.BAD_REQUEST:
                logger.warning(f"Pebblo received an invalid payload: {response.text}")
            elif response.status_code != HTTPStatus.OK:
                logger.warning(
                    f"Pebblo returned an unexpected response code: "
                    f"{response.status_code}"
                )

            return response
        except RequestException:
            logger.warning("Unable to reach server %s", url)
        except Exception as e:
            logger.warning("An Exception caught in make_request: %s", e)
        return None

    @staticmethod
    def prepare_docs_for_classification(
        docs_with_id: List[IndexedDocument], source_path: str
    ) -> Tuple[List[dict], int]:
        """
        Prepare documents for classification.

        Args:
            docs_with_id (List[IndexedDocument]): List of documents to be classified.
            source_path (str): Source path of the documents.

        Returns:
            Tuple[List[dict], int]: Documents and the aggregate size of the source.
        """
        docs = []
        source_aggregate_size = 0
        doc_content = [doc.dict() for doc in docs_with_id]
        for doc in doc_content:
            doc_metadata = doc.get("metadata", {})
            doc_authorized_identities = doc_metadata.get("authorized_identities", [])
            doc_source_path = get_full_path(
                doc_metadata.get(
                    "full_path",
                    doc_metadata.get("source", source_path),
                )
            )
            doc_source_owner = doc_metadata.get(
                "owner", get_file_owner_from_path(doc_source_path)
            )
            doc_source_size = doc_metadata.get("size", get_source_size(doc_source_path))
            page_content = str(doc.get("page_content"))
            page_content_size = calculate_content_size(page_content)
            source_aggregate_size += page_content_size
            doc_id = doc.get("pb_id", None) or 0
            docs.append(
                {
                    "doc": page_content,
                    "source_path": doc_source_path,
                    "pb_id": doc_id,
                    "last_modified": doc.get("metadata", {}).get("last_modified"),
                    "file_owner": doc_source_owner,
                    **(
                        {"authorized_identities": doc_authorized_identities}
                        if doc_authorized_identities
                        else {}
                    ),
                    **(
                        {"source_path_size": doc_source_size}
                        if doc_source_size is not None
                        else {}
                    ),
                }
            )
        return docs, source_aggregate_size

    @staticmethod
    def update_doc_data(docs: List[dict], classified_docs: dict) -> None:
        """
        Update the document data with classified information.

        Args:
            docs (List[dict]): List of document data to be updated.
            classified_docs (dict): The dictionary containing classified documents.
        """
        for doc_data in docs:
            classified_data = classified_docs.get(doc_data["pb_id"], {})
            # Update the document data with classified information
            doc_data.update(
                {
                    "pb_checksum": classified_data.get("pb_checksum"),
                    "loader_source_path": classified_data.get("loader_source_path"),
                    "entities": classified_data.get("entities", {}),
                    "topics": classified_data.get("topics", {}),
                }
            )
            # Remove the document content
            doc_data.pop("doc")
@@ -144,4 +144,5 @@ def test_pebblo_safe_loader_api_key() -> None:
    )

    # Assert
    assert loader.api_key == api_key
    assert loader.pb_client.api_key == api_key
    assert loader.pb_client.classifier_location == "local"
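For reference, a minimal sketch of exercising the new `PebbloLoaderAPIWrapper` directly, in the spirit of the test above (the API key is a hypothetical placeholder; the expected defaults follow the constants in this diff and assume the corresponding environment variables are unset):

```python
from langchain_community.utilities.pebblo import PebbloLoaderAPIWrapper

# Each argument falls back to its environment variable when omitted:
# PEBBLO_API_KEY, PEBBLO_CLASSIFIER_URL, PEBBLO_CLOUD_URL.
pb_client = PebbloLoaderAPIWrapper(
    api_key="<pebblo-api-key>",  # hypothetical placeholder
    classifier_location="local",
    classifier_url="http://localhost:8000",  # default classifier URL
)
assert pb_client.classifier_location == "local"
assert pb_client.cloud_url == "https://api.daxa.ai"  # default, env unset
```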