diff --git a/libs/community/langchain_community/chains/pebblo_retrieval/base.py b/libs/community/langchain_community/chains/pebblo_retrieval/base.py index 97c939b4fc..02d3553c4b 100644 --- a/libs/community/langchain_community/chains/pebblo_retrieval/base.py +++ b/libs/community/langchain_community/chains/pebblo_retrieval/base.py @@ -124,6 +124,11 @@ class PebbloRetrievalQA(Chain): ), "doc": doc.page_content, "vector_db": self.retriever.vectorstore.__class__.__name__, + **( + {"pb_checksum": doc.metadata.get("pb_checksum")} + if doc.metadata.get("pb_checksum") + else {} + ), } for doc in docs if isinstance(doc, Document) @@ -457,25 +462,24 @@ class PebbloRetrievalQA(Chain): if self.api_key: if self.classifier_location == "local": if pebblo_resp: - payload["response"] = ( - json.loads(pebblo_resp.text) - .get("retrieval_data", {}) - .get("response", {}) - ) - payload["context"] = ( - json.loads(pebblo_resp.text) - .get("retrieval_data", {}) - .get("context", []) - ) - payload["prompt"] = ( - json.loads(pebblo_resp.text) - .get("retrieval_data", {}) - .get("prompt", {}) - ) + resp = json.loads(pebblo_resp.text) + if resp: + payload["response"].update( + resp.get("retrieval_data", {}).get("response", {}) + ) + payload["response"].pop("data") + payload["prompt"].update( + resp.get("retrieval_data", {}).get("prompt", {}) + ) + payload["prompt"].pop("data") + context = payload["context"] + for context_data in context: + context_data.pop("doc") + payload["context"] = context else: - payload["response"] = None - payload["context"] = None - payload["prompt"] = None + payload["response"] = {} + payload["prompt"] = {} + payload["context"] = [] headers.update({"x-api-key": self.api_key}) pebblo_cloud_url = f"{PEBBLO_CLOUD_URL}{PROMPT_URL}" try: diff --git a/libs/community/langchain_community/chains/pebblo_retrieval/models.py b/libs/community/langchain_community/chains/pebblo_retrieval/models.py index 3b7f94d44c..13a54537b9 100644 --- a/libs/community/langchain_community/chains/pebblo_retrieval/models.py +++ b/libs/community/langchain_community/chains/pebblo_retrieval/models.py @@ -129,6 +129,7 @@ class Context(BaseModel): retrieved_from: Optional[str] doc: Optional[str] vector_db: str + pb_checksum: Optional[str] class Prompt(BaseModel): diff --git a/libs/community/langchain_community/document_loaders/pebblo.py b/libs/community/langchain_community/document_loaders/pebblo.py index a695582e2f..2e31b370cb 100644 --- a/libs/community/langchain_community/document_loaders/pebblo.py +++ b/libs/community/langchain_community/document_loaders/pebblo.py @@ -5,7 +5,7 @@ import logging import os import uuid from http import HTTPStatus -from typing import Any, Dict, Iterator, List, Optional, Union +from typing import Any, Dict, Iterator, List, Optional import requests # type: ignore from langchain_core.documents import Document @@ -61,7 +61,7 @@ class PebbloSafeLoader(BaseLoader): self.source_path = get_loader_full_path(self.loader) self.source_owner = PebbloSafeLoader.get_file_owner_from_path(self.source_path) self.docs: List[Document] = [] - self.docs_with_id: Union[List[IndexedDocument], List[Document], List] = [] + self.docs_with_id: List[IndexedDocument] = [] loader_name = str(type(self.loader)).split(".")[-1].split("'")[0] self.source_type = get_loader_type(loader_name) self.source_path_size = self.get_source_size(self.source_path) @@ -89,17 +89,13 @@ class PebbloSafeLoader(BaseLoader): list: Documents fetched from load method of the wrapped `loader`. """ self.docs = self.loader.load() - # Add pebblo-specific metadata to docs - self._add_pebblo_specific_metadata() - if not self.load_semantic: - self._classify_doc(self.docs, loading_end=True) - return self.docs self.docs_with_id = self._index_docs() - classified_docs = self._classify_doc(self.docs_with_id, loading_end=True) - self.docs_with_id = self._add_semantic_to_docs( - self.docs_with_id, classified_docs - ) - self.docs = self._unindex_docs(self.docs_with_id) # type: ignore + classified_docs = self._classify_doc(loading_end=True) + self._add_pebblo_specific_metadata(classified_docs) + if self.load_semantic: + self.docs = self._add_semantic_to_docs(classified_docs) + else: + self.docs = self._unindex_docs() # type: ignore return self.docs def lazy_load(self) -> Iterator[Document]: @@ -125,19 +121,14 @@ class PebbloSafeLoader(BaseLoader): self.docs = [] break self.docs = list((doc,)) - # Add pebblo-specific metadata to docs - self._add_pebblo_specific_metadata() - if not self.load_semantic: - self._classify_doc(self.docs, loading_end=True) - yield self.docs[0] + self.docs_with_id = self._index_docs() + classified_doc = self._classify_doc() + self._add_pebblo_specific_metadata(classified_doc) + if self.load_semantic: + self.docs = self._add_semantic_to_docs(classified_doc) else: - self.docs_with_id = self._index_docs() - classified_doc = self._classify_doc(self.docs) - self.docs_with_id = self._add_semantic_to_docs( - self.docs_with_id, classified_doc - ) - self.docs = self._unindex_docs(self.docs_with_id) # type: ignore - yield self.docs[0] + self.docs = self._unindex_docs() + yield self.docs[0] @classmethod def set_discover_sent(cls) -> None: @@ -147,13 +138,12 @@ class PebbloSafeLoader(BaseLoader): def set_loader_sent(cls) -> None: cls._loader_sent = True - def _classify_doc(self, loaded_docs: list, loading_end: bool = False) -> list: + def _classify_doc(self, loading_end: bool = False) -> dict: """Send documents fetched from loader to pebblo-server. Then send classified documents to Daxa cloud(If api_key is present). Internal method. Args: - loaded_docs (list): List of documents fetched from loader's load operation. loading_end (bool, optional): Flag indicating the halt of data loading by loader. Defaults to False. """ @@ -163,9 +153,8 @@ class PebbloSafeLoader(BaseLoader): } if loading_end is True: PebbloSafeLoader.set_loader_sent() - doc_content = [doc.dict() for doc in loaded_docs] + doc_content = [doc.dict() for doc in self.docs_with_id] docs = [] - classified_docs = [] for doc in doc_content: doc_metadata = doc.get("metadata", {}) doc_authorized_identities = doc_metadata.get("authorized_identities", []) @@ -183,12 +172,12 @@ class PebbloSafeLoader(BaseLoader): page_content = str(doc.get("page_content")) page_content_size = self.calculate_content_size(page_content) self.source_aggregate_size += page_content_size - doc_id = doc.get("id", None) or 0 + doc_id = doc.get("pb_id", None) or 0 docs.append( { "doc": page_content, "source_path": doc_source_path, - "id": doc_id, + "pb_id": doc_id, "last_modified": doc.get("metadata", {}).get("last_modified"), "file_owner": doc_source_owner, **( @@ -221,6 +210,7 @@ class PebbloSafeLoader(BaseLoader): self.source_aggregate_size ) payload = Doc(**payload).dict(exclude_unset=True) + classified_docs = {} # Raw payload to be sent to classifier if self.classifier_location == "local": load_doc_url = f"{self.classifier_url}{LOADER_DOC_URL}" @@ -228,7 +218,10 @@ class PebbloSafeLoader(BaseLoader): pebblo_resp = requests.post( load_doc_url, headers=headers, json=payload, timeout=300 ) - classified_docs = json.loads(pebblo_resp.text).get("docs", None) + + # Updating the structure of pebblo response docs for efficient searching + for classified_doc in json.loads(pebblo_resp.text).get("docs", []): + classified_docs.update({classified_doc["pb_id"]: classified_doc}) if pebblo_resp.status_code not in [ HTTPStatus.OK, HTTPStatus.BAD_GATEWAY, @@ -257,7 +250,21 @@ class PebbloSafeLoader(BaseLoader): if self.api_key: if self.classifier_location == "local": - payload["docs"] = classified_docs + docs = payload["docs"] + for doc_data in docs: + classified_data = classified_docs.get(doc_data["pb_id"], {}) + doc_data.update( + { + "pb_checksum": classified_data.get("pb_checksum", None), + "loader_source_path": classified_data.get( + "loader_source_path", None + ), + "entities": classified_data.get("entities", {}), + "topics": classified_data.get("topics", {}), + } + ) + doc_data.pop("doc") + headers.update({"x-api-key": self.api_key}) pebblo_cloud_url = f"{PEBBLO_CLOUD_URL}{LOADER_DOC_URL}" try: @@ -453,33 +460,29 @@ class PebbloSafeLoader(BaseLoader): List[IndexedDocument]: A list of IndexedDocument objects with unique IDs. """ docs_with_id = [ - IndexedDocument(id=hex(i)[2:], **doc.dict()) + IndexedDocument(pb_id=str(i), **doc.dict()) for i, doc in enumerate(self.docs) ] return docs_with_id - def _add_semantic_to_docs( - self, docs_with_id: List[IndexedDocument], classified_docs: List[dict] - ) -> List[Document]: + def _add_semantic_to_docs(self, classified_docs: Dict) -> List[Document]: """ Adds semantic metadata to the given list of documents. Args: - docs_with_id (List[IndexedDocument]): A list of IndexedDocument objects - containing the documents with their IDs. - classified_docs (List[dict]): A list of dictionaries containing the - classified documents. + classified_docs (Dict): A dictionary of dictionaries containing the + classified documents with pb_id as key. Returns: List[Document]: A list of Document objects with added semantic metadata. """ indexed_docs = { - doc.id: Document(page_content=doc.page_content, metadata=doc.metadata) - for doc in docs_with_id + doc.pb_id: Document(page_content=doc.page_content, metadata=doc.metadata) + for doc in self.docs_with_id } - for classified_doc in classified_docs: - doc_id = classified_doc.get("id") + for classified_doc in classified_docs.values(): + doc_id = classified_doc.get("pb_id") if doc_id in indexed_docs: self._add_semantic_to_doc(indexed_docs[doc_id], classified_doc) @@ -487,19 +490,16 @@ class PebbloSafeLoader(BaseLoader): return semantic_metadata_docs - def _unindex_docs(self, docs_with_id: List[IndexedDocument]) -> List[Document]: + def _unindex_docs(self) -> List[Document]: """ Converts a list of IndexedDocument objects to a list of Document objects. - Args: - docs_with_id (List[IndexedDocument]): A list of IndexedDocument objects. - Returns: List[Document]: A list of Document objects. """ docs = [ Document(page_content=doc.page_content, metadata=doc.metadata) - for i, doc in enumerate(docs_with_id) + for i, doc in enumerate(self.docs_with_id) ] return docs @@ -522,12 +522,16 @@ class PebbloSafeLoader(BaseLoader): ) return doc - def _add_pebblo_specific_metadata(self) -> None: + def _add_pebblo_specific_metadata(self, classified_docs: dict) -> None: """Add Pebblo specific metadata to documents.""" - for doc in self.docs: + for doc in self.docs_with_id: doc_metadata = doc.metadata doc_metadata["full_path"] = get_full_path( doc_metadata.get( "full_path", doc_metadata.get("source", self.source_path) ) ) + doc_metadata["pb_id"] = doc.pb_id + doc_metadata["pb_checksum"] = classified_docs.get(doc.pb_id, {}).get( + "pb_checksum", None + ) diff --git a/libs/community/langchain_community/utilities/pebblo.py b/libs/community/langchain_community/utilities/pebblo.py index a9bbb47241..568efae7b6 100644 --- a/libs/community/langchain_community/utilities/pebblo.py +++ b/libs/community/langchain_community/utilities/pebblo.py @@ -66,7 +66,7 @@ logger = logging.getLogger(__name__) class IndexedDocument(Document): """Pebblo Indexed Document.""" - id: str + pb_id: str """Unique ID of the document.""" diff --git a/libs/community/tests/unit_tests/document_loaders/test_pebblo.py b/libs/community/tests/unit_tests/document_loaders/test_pebblo.py index 1cee8a849d..d0a71faae7 100644 --- a/libs/community/tests/unit_tests/document_loaders/test_pebblo.py +++ b/libs/community/tests/unit_tests/document_loaders/test_pebblo.py @@ -65,12 +65,26 @@ def test_csv_loader_load_valid_data(mocker: MockerFixture) -> None: full_file_path = os.path.abspath(file_path) expected_docs = [ Document( + metadata={ + "source": full_file_path, + "row": 0, + "full_path": full_file_path, + "pb_id": "0", + # For UT as here we are not calculating checksum + "pb_checksum": None, + }, page_content="column1: value1\ncolumn2: value2\ncolumn3: value3", - metadata={"source": file_path, "row": 0, "full_path": full_file_path}, ), Document( + metadata={ + "source": full_file_path, + "row": 1, + "full_path": full_file_path, + "pb_id": "1", + # For UT as here we are not calculating checksum + "pb_checksum": None, + }, page_content="column1: value4\ncolumn2: value5\ncolumn3: value6", - metadata={"source": file_path, "row": 1, "full_path": full_file_path}, ), ]