From d7e6770de872a136dee0ba24fd3fc8d82610eb76 Mon Sep 17 00:00:00 2001 From: Holt Skinner <13262395+holtskinner@users.noreply.github.com> Date: Thu, 27 Jul 2023 19:13:49 -0500 Subject: [PATCH] refactor: Code refactoring & simplification for Google Cloud Enterprise Search retriever (#8369) Followup to https://github.com/langchain-ai/langchain/pull/7857 - Changes `_convert_search_response()` to use object attributes instead of converting to dictionary - Simplifies logic for readability --- .../google_cloud_enterprise_search.py | 43 +++++++------------ 1 file changed, 15 insertions(+), 28 deletions(-) diff --git a/libs/langchain/langchain/retrievers/google_cloud_enterprise_search.py b/libs/langchain/langchain/retrievers/google_cloud_enterprise_search.py index 2acb01b363..b609317a33 100644 --- a/libs/langchain/langchain/retrievers/google_cloud_enterprise_search.py +++ b/libs/langchain/langchain/retrievers/google_cloud_enterprise_search.py @@ -106,34 +106,23 @@ class GoogleCloudEnterpriseSearchRetriever(BaseRetriever): self, results: Sequence[SearchResult] ) -> List[Document]: """Converts a sequence of search results to a list of LangChain documents.""" - from google.protobuf.json_format import MessageToDict + documents: List[Document] = [] - documents = [] for result in results: - document_dict = MessageToDict(result.document._pb) - derived_struct_data = document_dict.get("derivedStructData", None) - if derived_struct_data: - doc_metadata = document_dict.get("structData", {}) - chunk_type = ( - "extractive_answers" - if self.get_extractive_answers - else "extractive_segments" + derived_struct_data = result.document.derived_struct_data + doc_metadata = result.document.struct_data + doc_metadata.source = derived_struct_data.link or "" + doc_metadata.id = result.document.id + + for chunk in ( + derived_struct_data.extractive_answers + or derived_struct_data.extractive_segments + ): + if hasattr(chunk, "page_number"): + doc_metadata.source += f":{chunk.page_number}" + documents.append( + Document(page_content=chunk.content, metadata=doc_metadata) ) - for chunk in derived_struct_data.get(chunk_type, []): - if chunk_type == "extractive_answers": - doc_metadata["source"] = ( - f"{derived_struct_data.get('link', '')}" - f":{chunk.get('pageNumber', '')}" - ) - else: - doc_metadata[ - "source" - ] = f"{derived_struct_data.get('link', '')}" - doc_metadata["id"] = document_dict["id"] - document = Document( - page_content=chunk.get("content", ""), metadata=doc_metadata - ) - documents.append(document) return documents @@ -162,7 +151,7 @@ class GoogleCloudEnterpriseSearchRetriever(BaseRetriever): extractive_content_spec=extractive_content_spec, ) - request = SearchRequest( + return SearchRequest( query=query, filter=self.filter, serving_config=self._serving_config, @@ -171,8 +160,6 @@ class GoogleCloudEnterpriseSearchRetriever(BaseRetriever): query_expansion_spec=query_expansion_spec, ) - return request - def _get_relevant_documents( self, query: str, *, run_manager: CallbackManagerForRetrieverRun ) -> List[Document]: