fixed GoogleCloudEnterpriseSearchRetriever returning an empty array (#9858)

`GoogleCloudEnterpriseSearchRetriever` returned an empty array of
documents earlier, fixed
This commit is contained in:
Bagatur 2023-08-29 17:49:48 -07:00 committed by GitHub
commit d966ba63e2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 12 additions and 3 deletions

View File

@ -114,7 +114,13 @@ class GoogleCloudEnterpriseSearchRetriever(BaseRetriever):
def __init__(self, **data: Any) -> None: def __init__(self, **data: Any) -> None:
"""Initializes private fields.""" """Initializes private fields."""
try:
from google.cloud.discoveryengine_v1beta import SearchServiceClient from google.cloud.discoveryengine_v1beta import SearchServiceClient
except ImportError:
raise ImportError(
"google.cloud.discoveryengine is not installed."
"Please install it with pip install google-cloud-discoveryengine"
)
super().__init__(**data) super().__init__(**data)
self._client = SearchServiceClient(credentials=self.credentials) self._client = SearchServiceClient(credentials=self.credentials)
@ -137,7 +143,7 @@ class GoogleCloudEnterpriseSearchRetriever(BaseRetriever):
document_dict = MessageToDict( document_dict = MessageToDict(
result.document._pb, preserving_proto_field_name=True result.document._pb, preserving_proto_field_name=True
) )
derived_struct_data = document_dict.get("derived_struct_data", None) derived_struct_data = document_dict.get("derived_struct_data")
if not derived_struct_data: if not derived_struct_data:
continue continue
@ -150,7 +156,7 @@ class GoogleCloudEnterpriseSearchRetriever(BaseRetriever):
else "extractive_segments" else "extractive_segments"
) )
for chunk in getattr(derived_struct_data, chunk_type, []): for chunk in derived_struct_data.get(chunk_type, []):
doc_metadata["source"] = derived_struct_data.get("link", "") doc_metadata["source"] = derived_struct_data.get("link", "")
if chunk_type == "extractive_answers": if chunk_type == "extractive_answers":

View File

@ -24,6 +24,9 @@ def test_google_cloud_enterprise_search_get_relevant_documents() -> None:
"""Test the get_relevant_documents() method.""" """Test the get_relevant_documents() method."""
retriever = GoogleCloudEnterpriseSearchRetriever() retriever = GoogleCloudEnterpriseSearchRetriever()
documents = retriever.get_relevant_documents("What are Alphabet's Other Bets?") documents = retriever.get_relevant_documents("What are Alphabet's Other Bets?")
assert len(documents) > 0
for doc in documents: for doc in documents:
assert isinstance(doc, Document) assert isinstance(doc, Document)
assert doc.page_content assert doc.page_content
assert doc.metadata["id"]
assert doc.metadata["source"]