community: Refactor PebbloSafeLoader (#25582)

**Refactor PebbloSafeLoader**
  - Created `PebbloLoaderAPIWrapper` and moved the API logic into it.
  - Moved helper functions to the `pebblo` utility module.
  - Split large methods into smaller functions and methods for readability.
  - Read environment variables consistently (via `get_from_dict_or_env`).
  - Removed unused code.

**Issue:** NA
**Dependencies:** NA
**Tests:** Updated
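
For context, a minimal usage sketch of the refactored loader (the `CSVLoader`, file path, and app metadata below are hypothetical placeholders). Calling code is unchanged; the loader now delegates all Pebblo server/cloud communication to its `pb_client`, a `PebbloLoaderAPIWrapper`:

```python
# Hedged sketch: the source file and app metadata are placeholders.
from langchain_community.document_loaders import CSVLoader, PebbloSafeLoader

loader = PebbloSafeLoader(
    CSVLoader("data/records.csv"),  # any LangChain document loader
    name="my-rag-app",              # required app name
    owner="Data Owner",             # optional owner (placeholder)
    description="Demo app",         # optional description (placeholder)
    api_key="pb-api-key",           # or set the PEBBLO_API_KEY env var
    classifier_location="local",    # "local" (default) or "pebblo-cloud"
)
docs = loader.load()  # documents are classified via loader.pb_client
```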
Rajendra Kadam authored on 2024-08-22 21:16:52 +05:30, committed by GitHub
parent 5e3a321f71
commit 1f1679e960
3 changed files with 422 additions and 327 deletions

File: libs/community/langchain_community/document_loaders/pebblo.py

@@ -1,31 +1,25 @@
"""Pebblo's safe dataloader is a wrapper for document loaders"""
import json
import logging
import os
import uuid
from http import HTTPStatus
from typing import Any, Dict, Iterator, List, Optional
from typing import Dict, Iterator, List, Optional
import requests # type: ignore
from langchain_core.documents import Document
from langchain_community.document_loaders.base import BaseLoader
from langchain_community.utilities.pebblo import (
APP_DISCOVER_URL,
BATCH_SIZE_BYTES,
CLASSIFIER_URL,
LOADER_DOC_URL,
PEBBLO_CLOUD_URL,
PLUGIN_VERSION,
App,
Doc,
IndexedDocument,
PebbloLoaderAPIWrapper,
generate_size_based_batches,
get_full_path,
get_loader_full_path,
get_loader_type,
get_runtime,
get_source_size,
)
logger = logging.getLogger(__name__)
@@ -37,7 +31,6 @@ class PebbloSafeLoader(BaseLoader):
"""
_discover_sent: bool = False
_loader_sent: bool = False
def __init__(
self,
@@ -54,22 +47,17 @@ class PebbloSafeLoader(BaseLoader):
if not name or not isinstance(name, str):
raise NameError("Must specify a valid name.")
self.app_name = name
self.api_key = os.environ.get("PEBBLO_API_KEY") or api_key
self.load_id = str(uuid.uuid4())
self.loader = langchain_loader
self.load_semantic = os.environ.get("PEBBLO_LOAD_SEMANTIC") or load_semantic
self.owner = owner
self.description = description
self.source_path = get_loader_full_path(self.loader)
self.source_owner = PebbloSafeLoader.get_file_owner_from_path(self.source_path)
self.docs: List[Document] = []
self.docs_with_id: List[IndexedDocument] = []
loader_name = str(type(self.loader)).split(".")[-1].split("'")[0]
self.source_type = get_loader_type(loader_name)
self.source_path_size = self.get_source_size(self.source_path)
self.source_aggregate_size = 0
self.classifier_url = classifier_url or CLASSIFIER_URL
self.classifier_location = classifier_location
self.source_path_size = get_source_size(self.source_path)
self.batch_size = BATCH_SIZE_BYTES
self.loader_details = {
"loader": loader_name,
@@ -83,7 +71,13 @@ class PebbloSafeLoader(BaseLoader):
}
# generate app
self.app = self._get_app_details()
self._send_discover()
# initialize Pebblo Loader API client
self.pb_client = PebbloLoaderAPIWrapper(
api_key=api_key,
classifier_location=classifier_location,
classifier_url=classifier_url,
)
self.pb_client.send_loader_discover(self.app)
def load(self) -> List[Document]:
"""Load Documents.
@@ -113,7 +107,12 @@ class PebbloSafeLoader(BaseLoader):
is_last_batch: bool = i == total_batches - 1
self.docs = batch
self.docs_with_id = self._index_docs()
classified_docs = self._classify_doc(loading_end=is_last_batch)
classified_docs = self.pb_client.classify_documents(
self.docs_with_id,
self.app,
self.loader_details,
loading_end=is_last_batch,
)
self._add_pebblo_specific_metadata(classified_docs)
if self.load_semantic:
batch_processed_docs = self._add_semantic_to_docs(classified_docs)
@@ -147,7 +146,9 @@ class PebbloSafeLoader(BaseLoader):
break
self.docs = list((doc,))
self.docs_with_id = self._index_docs()
classified_doc = self._classify_doc()
classified_doc = self.pb_client.classify_documents(
self.docs_with_id, self.app, self.loader_details
)
self._add_pebblo_specific_metadata(classified_doc)
if self.load_semantic:
self.docs = self._add_semantic_to_docs(classified_doc)
@@ -159,263 +160,6 @@ class PebbloSafeLoader(BaseLoader):
def set_discover_sent(cls) -> None:
cls._discover_sent = True
@classmethod
def set_loader_sent(cls) -> None:
cls._loader_sent = True
def _classify_doc(self, loading_end: bool = False) -> dict:
"""Send documents fetched from loader to pebblo-server. Then send
classified documents to Daxa cloud(If api_key is present). Internal method.
Args:
loading_end (bool, optional): Flag indicating the halt of data
loading by loader. Defaults to False.
"""
headers = {
"Accept": "application/json",
"Content-Type": "application/json",
}
if loading_end is True:
PebbloSafeLoader.set_loader_sent()
doc_content = [doc.dict() for doc in self.docs_with_id]
docs = []
for doc in doc_content:
doc_metadata = doc.get("metadata", {})
doc_authorized_identities = doc_metadata.get("authorized_identities", [])
doc_source_path = get_full_path(
doc_metadata.get(
"full_path", doc_metadata.get("source", self.source_path)
)
)
doc_source_owner = doc_metadata.get(
"owner", PebbloSafeLoader.get_file_owner_from_path(doc_source_path)
)
doc_source_size = doc_metadata.get(
"size", self.get_source_size(doc_source_path)
)
page_content = str(doc.get("page_content"))
page_content_size = self.calculate_content_size(page_content)
self.source_aggregate_size += page_content_size
doc_id = doc.get("pb_id", None) or 0
docs.append(
{
"doc": page_content,
"source_path": doc_source_path,
"pb_id": doc_id,
"last_modified": doc.get("metadata", {}).get("last_modified"),
"file_owner": doc_source_owner,
**(
{"authorized_identities": doc_authorized_identities}
if doc_authorized_identities
else {}
),
**(
{"source_path_size": doc_source_size}
if doc_source_size is not None
else {}
),
}
)
payload: Dict[str, Any] = {
"name": self.app_name,
"owner": self.owner,
"docs": docs,
"plugin_version": PLUGIN_VERSION,
"load_id": self.load_id,
"loader_details": self.loader_details,
"loading_end": "false",
"source_owner": self.source_owner,
"classifier_location": self.classifier_location,
}
if loading_end is True:
payload["loading_end"] = "true"
if "loader_details" in payload:
payload["loader_details"]["source_aggregate_size"] = (
self.source_aggregate_size
)
payload = Doc(**payload).dict(exclude_unset=True)
classified_docs = {}
# Raw payload to be sent to classifier
if self.classifier_location == "local":
load_doc_url = f"{self.classifier_url}{LOADER_DOC_URL}"
try:
pebblo_resp = requests.post(
load_doc_url, headers=headers, json=payload, timeout=300
)
# Updating the structure of pebblo response docs for efficient searching
for classified_doc in json.loads(pebblo_resp.text).get("docs", []):
classified_docs.update({classified_doc["pb_id"]: classified_doc})
if pebblo_resp.status_code not in [
HTTPStatus.OK,
HTTPStatus.BAD_GATEWAY,
]:
logger.warning(
"Received unexpected HTTP response code: %s",
pebblo_resp.status_code,
)
logger.debug(
"send_loader_doc[local]: request url %s, body %s len %s\
response status %s body %s",
pebblo_resp.request.url,
str(pebblo_resp.request.body),
str(
len(
pebblo_resp.request.body if pebblo_resp.request.body else []
)
),
str(pebblo_resp.status_code),
pebblo_resp.json(),
)
except requests.exceptions.RequestException:
logger.warning("Unable to reach pebblo server.")
except Exception as e:
logger.warning("An Exception caught in _send_loader_doc: local %s", e)
if self.api_key:
if self.classifier_location == "local":
docs = payload["docs"]
for doc_data in docs:
classified_data = classified_docs.get(doc_data["pb_id"], {})
doc_data.update(
{
"pb_checksum": classified_data.get("pb_checksum", None),
"loader_source_path": classified_data.get(
"loader_source_path", None
),
"entities": classified_data.get("entities", {}),
"topics": classified_data.get("topics", {}),
}
)
doc_data.pop("doc")
headers.update({"x-api-key": self.api_key})
pebblo_cloud_url = f"{PEBBLO_CLOUD_URL}{LOADER_DOC_URL}"
try:
pebblo_cloud_response = requests.post(
pebblo_cloud_url, headers=headers, json=payload, timeout=20
)
logger.debug(
"send_loader_doc[cloud]: request url %s, body %s len %s\
response status %s body %s",
pebblo_cloud_response.request.url,
str(pebblo_cloud_response.request.body),
str(
len(
pebblo_cloud_response.request.body
if pebblo_cloud_response.request.body
else []
)
),
str(pebblo_cloud_response.status_code),
pebblo_cloud_response.json(),
)
except requests.exceptions.RequestException:
logger.warning("Unable to reach Pebblo cloud server.")
except Exception as e:
logger.warning("An Exception caught in _send_loader_doc: cloud %s", e)
elif self.classifier_location == "pebblo-cloud":
logger.warning("API key is missing for sending docs to Pebblo cloud.")
raise NameError("API key is missing for sending docs to Pebblo cloud.")
return classified_docs
@staticmethod
def calculate_content_size(page_content: str) -> int:
"""Calculate the content size in bytes:
- Encode the string to bytes using a specific encoding (e.g., UTF-8)
- Get the length of the encoded bytes.
Args:
page_content (str): Data string.
Returns:
int: Size of string in bytes.
"""
# Encode the content to bytes using UTF-8
encoded_content = page_content.encode("utf-8")
size = len(encoded_content)
return size
def _send_discover(self) -> None:
"""Send app discovery payload to pebblo-server. Internal method."""
pebblo_resp = None
headers = {
"Accept": "application/json",
"Content-Type": "application/json",
}
payload = self.app.dict(exclude_unset=True)
# Raw discover payload to be sent to classifier
if self.classifier_location == "local":
app_discover_url = f"{self.classifier_url}{APP_DISCOVER_URL}"
try:
pebblo_resp = requests.post(
app_discover_url, headers=headers, json=payload, timeout=20
)
logger.debug(
"send_discover[local]: request url %s, body %s len %s\
response status %s body %s",
pebblo_resp.request.url,
str(pebblo_resp.request.body),
str(
len(
pebblo_resp.request.body if pebblo_resp.request.body else []
)
),
str(pebblo_resp.status_code),
pebblo_resp.json(),
)
if pebblo_resp.status_code in [HTTPStatus.OK, HTTPStatus.BAD_GATEWAY]:
PebbloSafeLoader.set_discover_sent()
else:
logger.warning(
f"Received unexpected HTTP response code:\
{pebblo_resp.status_code}"
)
except requests.exceptions.RequestException:
logger.warning("Unable to reach pebblo server.")
except Exception as e:
logger.warning("An Exception caught in _send_discover: local %s", e)
if self.api_key:
try:
headers.update({"x-api-key": self.api_key})
# If the pebblo_resp is None,
# then the pebblo server version is not available
if pebblo_resp:
pebblo_server_version = json.loads(pebblo_resp.text).get(
"pebblo_server_version"
)
payload.update({"pebblo_server_version": pebblo_server_version})
payload.update({"pebblo_client_version": PLUGIN_VERSION})
pebblo_cloud_url = f"{PEBBLO_CLOUD_URL}{APP_DISCOVER_URL}"
pebblo_cloud_response = requests.post(
pebblo_cloud_url, headers=headers, json=payload, timeout=20
)
logger.debug(
"send_discover[cloud]: request url %s, body %s len %s\
response status %s body %s",
pebblo_cloud_response.request.url,
str(pebblo_cloud_response.request.body),
str(
len(
pebblo_cloud_response.request.body
if pebblo_cloud_response.request.body
else []
)
),
str(pebblo_cloud_response.status_code),
pebblo_cloud_response.json(),
)
except requests.exceptions.RequestException:
logger.warning("Unable to reach Pebblo cloud server.")
except Exception as e:
logger.warning("An Exception caught in _send_discover: cloud %s", e)
def _get_app_details(self) -> App:
"""Fetch app details. Internal method.
@@ -434,49 +178,6 @@ class PebbloSafeLoader(BaseLoader):
)
return app
@staticmethod
def get_file_owner_from_path(file_path: str) -> str:
"""Fetch owner of local file path.
Args:
file_path (str): Local file path.
Returns:
str: Name of owner.
"""
try:
import pwd
file_owner_uid = os.stat(file_path).st_uid
file_owner_name = pwd.getpwuid(file_owner_uid).pw_name
except Exception:
file_owner_name = "unknown"
return file_owner_name
def get_source_size(self, source_path: str) -> int:
"""Fetch size of source path. Source can be a directory or a file.
Args:
source_path (str): Local path of data source.
Returns:
int: Source size in bytes.
"""
if not source_path:
return 0
size = 0
if os.path.isfile(source_path):
size = os.path.getsize(source_path)
elif os.path.isdir(source_path):
total_size = 0
for dirpath, _, filenames in os.walk(source_path):
for f in filenames:
fp = os.path.join(dirpath, f)
if not os.path.islink(fp):
total_size += os.path.getsize(fp)
size = total_size
return size
def _index_docs(self) -> List[IndexedDocument]:
"""
Indexes the documents and returns a list of IndexedDocument objects.

File: libs/community/langchain_community/utilities/pebblo.py

@@ -1,25 +1,29 @@
from __future__ import annotations
import json
import logging
import os
import pathlib
import platform
from typing import List, Optional, Tuple
from enum import Enum
from http import HTTPStatus
from typing import Any, Dict, List, Optional, Tuple
from langchain_core.documents import Document
from langchain_core.env import get_runtime_environment
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.utils import get_from_dict_or_env
from requests import Response, request
from requests.exceptions import RequestException
from langchain_community.document_loaders.base import BaseLoader
logger = logging.getLogger(__name__)
PLUGIN_VERSION = "0.1.1"
CLASSIFIER_URL = os.getenv("PEBBLO_CLASSIFIER_URL", "http://localhost:8000")
PEBBLO_CLOUD_URL = os.getenv("PEBBLO_CLOUD_URL", "https://api.daxa.ai")
LOADER_DOC_URL = "/v1/loader/doc"
APP_DISCOVER_URL = "/v1/app/discover"
_DEFAULT_CLASSIFIER_URL = "http://localhost:8000"
_DEFAULT_PEBBLO_CLOUD_URL = "https://api.daxa.ai"
BATCH_SIZE_BYTES = 100 * 1024 # 100 KB
# Supported loaders for Pebblo safe data loading
@@ -59,9 +63,15 @@ LOADER_TYPE_MAPPING = {
"cloud-folder": cloud_folder,
}
SUPPORTED_LOADERS = (*file_loader, *dir_loader, *in_memory)
logger = logging.getLogger(__name__)
class Routes(str, Enum):
"""Routes available for the Pebblo API as enumerator."""
loader_doc = "/v1/loader/doc"
loader_app_discover = "/v1/app/discover"
retrieval_app_discover = "/v1/app/discover"
prompt = "/v1/prompt"
prompt_governance = "/v1/prompt/governance"
class IndexedDocument(Document):
@@ -342,3 +352,386 @@ def generate_size_based_batches(
batches.append(current_batch)
return batches
def get_file_owner_from_path(file_path: str) -> str:
"""Fetch owner of local file path.
Args:
file_path (str): Local file path.
Returns:
str: Name of owner.
"""
try:
import pwd
file_owner_uid = os.stat(file_path).st_uid
file_owner_name = pwd.getpwuid(file_owner_uid).pw_name
except Exception:
file_owner_name = "unknown"
return file_owner_name
def get_source_size(source_path: str) -> int:
"""Fetch size of source path. Source can be a directory or a file.
Args:
source_path (str): Local path of data source.
Returns:
int: Source size in bytes.
"""
if not source_path:
return 0
size = 0
if os.path.isfile(source_path):
size = os.path.getsize(source_path)
elif os.path.isdir(source_path):
total_size = 0
for dirpath, _, filenames in os.walk(source_path):
for f in filenames:
fp = os.path.join(dirpath, f)
if not os.path.islink(fp):
total_size += os.path.getsize(fp)
size = total_size
return size
def calculate_content_size(data: str) -> int:
"""Calculate the content size in bytes:
- Encode the string to bytes using a specific encoding (e.g., UTF-8)
- Get the length of the encoded bytes.
Args:
data (str): Data string.
Returns:
int: Size of string in bytes.
"""
encoded_content = data.encode("utf-8")
size = len(encoded_content)
return size
class PebbloLoaderAPIWrapper(BaseModel):
"""Wrapper for Pebblo Loader API."""
api_key: Optional[str] # Use SecretStr
"""API key for Pebblo Cloud"""
classifier_location: str = "local"
"""Location of the classifier, local or cloud. Defaults to 'local'"""
classifier_url: Optional[str]
"""URL of the Pebblo Classifier"""
cloud_url: Optional[str]
"""URL of the Pebblo Cloud"""
def __init__(self, **kwargs: Any):
"""Validate that api key in environment."""
kwargs["api_key"] = get_from_dict_or_env(
kwargs, "api_key", "PEBBLO_API_KEY", ""
)
kwargs["classifier_url"] = get_from_dict_or_env(
kwargs, "classifier_url", "PEBBLO_CLASSIFIER_URL", _DEFAULT_CLASSIFIER_URL
)
kwargs["cloud_url"] = get_from_dict_or_env(
kwargs, "cloud_url", "PEBBLO_CLOUD_URL", _DEFAULT_PEBBLO_CLOUD_URL
)
super().__init__(**kwargs)
def send_loader_discover(self, app: App) -> None:
"""
Send app discovery request to Pebblo server & cloud.
Args:
app (App): App instance to be discovered.
"""
pebblo_resp = None
payload = app.dict(exclude_unset=True)
if self.classifier_location == "local":
# Send app details to local classifier
headers = self._make_headers()
app_discover_url = f"{self.classifier_url}{Routes.loader_app_discover}"
pebblo_resp = self.make_request("POST", app_discover_url, headers, payload)
if self.api_key:
# Send app details to Pebblo cloud if api_key is present
headers = self._make_headers(cloud_request=True)
if pebblo_resp:
pebblo_server_version = json.loads(pebblo_resp.text).get(
"pebblo_server_version"
)
payload.update({"pebblo_server_version": pebblo_server_version})
payload.update({"pebblo_client_version": PLUGIN_VERSION})
pebblo_cloud_url = f"{self.cloud_url}{Routes.loader_app_discover}"
_ = self.make_request("POST", pebblo_cloud_url, headers, payload)
def classify_documents(
self,
docs_with_id: List[IndexedDocument],
app: App,
loader_details: dict,
loading_end: bool = False,
) -> dict:
"""
Send documents to Pebblo server for classification.
Then send classified documents to Daxa cloud(If api_key is present).
Args:
docs_with_id (List[IndexedDocument]): List of documents to be classified.
app (App): App instance.
loader_details (dict): Loader details.
loading_end (bool): Boolean, indicating the halt of data loading by loader.
"""
source_path = loader_details.get("source_path", "")
source_owner = get_file_owner_from_path(source_path)
# Prepare docs for classification
docs, source_aggregate_size = self.prepare_docs_for_classification(
docs_with_id, source_path
)
# Build payload for classification
payload = self.build_classification_payload(
app, docs, loader_details, source_owner, source_aggregate_size, loading_end
)
classified_docs = {}
if self.classifier_location == "local":
# Send docs to local classifier
headers = self._make_headers()
load_doc_url = f"{self.classifier_url}{Routes.loader_doc}"
try:
pebblo_resp = self.make_request(
"POST", load_doc_url, headers, payload, 300
)
if pebblo_resp:
# Updating structure of pebblo response docs for efficient searching
for classified_doc in json.loads(pebblo_resp.text).get("docs", []):
classified_docs.update(
{classified_doc["pb_id"]: classified_doc}
)
except Exception as e:
logger.warning("An Exception caught in classify_documents: local %s", e)
if self.api_key:
# Send docs to Pebblo cloud if api_key is present
if self.classifier_location == "local":
# If local classifier is used add the classified information
# and remove doc content
self.update_doc_data(payload["docs"], classified_docs)
self.send_docs_to_pebblo_cloud(payload)
elif self.classifier_location == "pebblo-cloud":
logger.warning("API key is missing for sending docs to Pebblo cloud.")
raise NameError("API key is missing for sending docs to Pebblo cloud.")
return classified_docs
def send_docs_to_pebblo_cloud(self, payload: dict) -> None:
"""
Send documents to Pebblo cloud.
Args:
payload (dict): The payload containing documents to be sent.
"""
headers = self._make_headers(cloud_request=True)
pebblo_cloud_url = f"{self.cloud_url}{Routes.loader_doc}"
try:
_ = self.make_request("POST", pebblo_cloud_url, headers, payload)
except Exception as e:
logger.warning("An Exception caught in classify_documents: cloud %s", e)
def _make_headers(self, cloud_request: bool = False) -> dict:
"""
Generate headers for the request.
args:
cloud_request (bool): flag indicating whether the request is for Pebblo
cloud.
returns:
dict: Headers for the request.
"""
headers = {
"Accept": "application/json",
"Content-Type": "application/json",
}
if cloud_request:
# Add API key for Pebblo cloud request
if self.api_key:
headers.update({"x-api-key": self.api_key})
else:
logger.warning("API key is missing for Pebblo cloud request.")
return headers
def build_classification_payload(
self,
app: App,
docs: List[dict],
loader_details: dict,
source_owner: str,
source_aggregate_size: int,
loading_end: bool,
) -> dict:
"""
Build the payload for document classification.
Args:
app (App): App instance.
docs (List[dict]): List of documents to be classified.
loader_details (dict): Loader details.
source_owner (str): Owner of the source.
source_aggregate_size (int): Aggregate size of the source.
loading_end (bool): Boolean indicating the halt of data loading by loader.
Returns:
dict: Payload for document classification.
"""
payload: Dict[str, Any] = {
"name": app.name,
"owner": app.owner,
"docs": docs,
"plugin_version": PLUGIN_VERSION,
"load_id": app.load_id,
"loader_details": loader_details,
"loading_end": "false",
"source_owner": source_owner,
"classifier_location": self.classifier_location,
}
if loading_end is True:
payload["loading_end"] = "true"
if "loader_details" in payload:
payload["loader_details"]["source_aggregate_size"] = (
source_aggregate_size
)
payload = Doc(**payload).dict(exclude_unset=True)
return payload
@staticmethod
def make_request(
method: str,
url: str,
headers: dict,
payload: Optional[dict] = None,
timeout: int = 20,
) -> Optional[Response]:
"""
Make a request to the Pebblo API
Args:
method (str): HTTP method (GET, POST, PUT, DELETE, etc.).
url (str): URL for the request.
headers (dict): Headers for the request.
payload (Optional[dict]): Payload for the request (for POST, PUT, etc.).
timeout (int): Timeout for the request in seconds.
Returns:
Optional[Response]: Response object if the request is successful.
"""
try:
response = request(
method=method, url=url, headers=headers, json=payload, timeout=timeout
)
logger.debug(
"Request: method %s, url %s, len %s response status %s",
method,
response.request.url,
str(len(response.request.body if response.request.body else [])),
str(response.status_code),
)
if response.status_code >= HTTPStatus.INTERNAL_SERVER_ERROR:
logger.warning(f"Pebblo Server: Error {response.status_code}")
elif response.status_code >= HTTPStatus.BAD_REQUEST:
logger.warning(f"Pebblo received an invalid payload: {response.text}")
elif response.status_code != HTTPStatus.OK:
logger.warning(
f"Pebblo returned an unexpected response code: "
f"{response.status_code}"
)
return response
except RequestException:
logger.warning("Unable to reach server %s", url)
except Exception as e:
logger.warning("An Exception caught in make_request: %s", e)
return None
@staticmethod
def prepare_docs_for_classification(
docs_with_id: List[IndexedDocument], source_path: str
) -> Tuple[List[dict], int]:
"""
Prepare documents for classification.
Args:
docs_with_id (List[IndexedDocument]): List of documents to be classified.
source_path (str): Source path of the documents.
Returns:
Tuple[List[dict], int]: Documents and the aggregate size of the source.
"""
docs = []
source_aggregate_size = 0
doc_content = [doc.dict() for doc in docs_with_id]
for doc in doc_content:
doc_metadata = doc.get("metadata", {})
doc_authorized_identities = doc_metadata.get("authorized_identities", [])
doc_source_path = get_full_path(
doc_metadata.get(
"full_path",
doc_metadata.get("source", source_path),
)
)
doc_source_owner = doc_metadata.get(
"owner", get_file_owner_from_path(doc_source_path)
)
doc_source_size = doc_metadata.get("size", get_source_size(doc_source_path))
page_content = str(doc.get("page_content"))
page_content_size = calculate_content_size(page_content)
source_aggregate_size += page_content_size
doc_id = doc.get("pb_id", None) or 0
docs.append(
{
"doc": page_content,
"source_path": doc_source_path,
"pb_id": doc_id,
"last_modified": doc.get("metadata", {}).get("last_modified"),
"file_owner": doc_source_owner,
**(
{"authorized_identities": doc_authorized_identities}
if doc_authorized_identities
else {}
),
**(
{"source_path_size": doc_source_size}
if doc_source_size is not None
else {}
),
}
)
return docs, source_aggregate_size
@staticmethod
def update_doc_data(docs: List[dict], classified_docs: dict) -> None:
"""
Update the document data with classified information.
Args:
docs (List[dict]): List of document data to be updated.
classified_docs (dict): The dictionary containing classified documents.
"""
for doc_data in docs:
classified_data = classified_docs.get(doc_data["pb_id"], {})
# Update the document data with classified information
doc_data.update(
{
"pb_checksum": classified_data.get("pb_checksum"),
"loader_source_path": classified_data.get("loader_source_path"),
"entities": classified_data.get("entities", {}),
"topics": classified_data.get("topics", {}),
}
)
# Remove the document content
doc_data.pop("doc")

File: libs/community/tests/unit_tests/document_loaders/test_pebblo.py

@@ -144,4 +144,5 @@ def test_pebblo_safe_loader_api_key() -> None:
)
# Assert
assert loader.api_key == api_key
assert loader.pb_client.api_key == api_key
assert loader.pb_client.classifier_location == "local"
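
As the updated assertions suggest, the wrapper resolves its settings with `get_from_dict_or_env`. A minimal sketch of the environment-variable fallback (the key value below is a placeholder):

```python
import os

# Placeholder key; PebbloLoaderAPIWrapper falls back to PEBBLO_API_KEY,
# PEBBLO_CLASSIFIER_URL, and PEBBLO_CLOUD_URL when arguments are omitted.
os.environ["PEBBLO_API_KEY"] = "pb-api-key"

from langchain_community.utilities.pebblo import PebbloLoaderAPIWrapper

pb_client = PebbloLoaderAPIWrapper()
assert pb_client.api_key == "pb-api-key"
assert pb_client.classifier_location == "local"  # field default
assert pb_client.classifier_url == "http://localhost:8000"  # _DEFAULT_CLASSIFIER_URL
```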