From 00baddf34cb128f6d87b9bf5eb94e514ccbf832d Mon Sep 17 00:00:00 2001 From: Leonid Kuligin Date: Mon, 28 Aug 2023 15:38:56 +0200 Subject: [PATCH] fixed enterprise search returning an empty array --- .../retrievers/google_cloud_enterprise_search.py | 12 +++++++++--- .../test_google_cloud_enterprise_search.py | 3 +++ 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/libs/langchain/langchain/retrievers/google_cloud_enterprise_search.py b/libs/langchain/langchain/retrievers/google_cloud_enterprise_search.py index 3570047509..4e9c478d2b 100644 --- a/libs/langchain/langchain/retrievers/google_cloud_enterprise_search.py +++ b/libs/langchain/langchain/retrievers/google_cloud_enterprise_search.py @@ -114,7 +114,13 @@ class GoogleCloudEnterpriseSearchRetriever(BaseRetriever): def __init__(self, **data: Any) -> None: """Initializes private fields.""" - from google.cloud.discoveryengine_v1beta import SearchServiceClient + try: + from google.cloud.discoveryengine_v1beta import SearchServiceClient + except ImportError: + raise ImportError( + "google.cloud.discoveryengine is not installed." + "Please install it with pip install google-cloud-discoveryengine" + ) super().__init__(**data) self._client = SearchServiceClient(credentials=self.credentials) @@ -137,7 +143,7 @@ class GoogleCloudEnterpriseSearchRetriever(BaseRetriever): document_dict = MessageToDict( result.document._pb, preserving_proto_field_name=True ) - derived_struct_data = document_dict.get("derived_struct_data", None) + derived_struct_data = document_dict.get("derived_struct_data") if not derived_struct_data: continue @@ -150,7 +156,7 @@ class GoogleCloudEnterpriseSearchRetriever(BaseRetriever): else "extractive_segments" ) - for chunk in getattr(derived_struct_data, chunk_type, []): + for chunk in derived_struct_data.get(chunk_type, []): doc_metadata["source"] = derived_struct_data.get("link", "") if chunk_type == "extractive_answers": diff --git a/libs/langchain/tests/integration_tests/retrievers/test_google_cloud_enterprise_search.py b/libs/langchain/tests/integration_tests/retrievers/test_google_cloud_enterprise_search.py index 47f576ac29..86c80cfa27 100644 --- a/libs/langchain/tests/integration_tests/retrievers/test_google_cloud_enterprise_search.py +++ b/libs/langchain/tests/integration_tests/retrievers/test_google_cloud_enterprise_search.py @@ -24,6 +24,9 @@ def test_google_cloud_enterprise_search_get_relevant_documents() -> None: """Test the get_relevant_documents() method.""" retriever = GoogleCloudEnterpriseSearchRetriever() documents = retriever.get_relevant_documents("What are Alphabet's Other Bets?") + assert len(documents) > 0 for doc in documents: assert isinstance(doc, Document) assert doc.page_content + assert doc.metadata["id"] + assert doc.metadata["source"]