community[patch]: Matching engine, return doc id (#14930)

pull/14814/head^2
Bagatur 10 months ago committed by GitHub
parent 8a3360edf6
commit 345acb26ac
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@@ -22,7 +22,7 @@ if TYPE_CHECKING:
from langchain_community.embeddings import TensorflowHubEmbeddings from langchain_community.embeddings import TensorflowHubEmbeddings
logger = logging.getLogger() logger = logging.getLogger(__name__)
class MatchingEngine(VectorStore): class MatchingEngine(VectorStore):
@@ -49,6 +49,8 @@ class MatchingEngine(VectorStore):
gcs_client: storage.Client, gcs_client: storage.Client,
gcs_bucket_name: str, gcs_bucket_name: str,
credentials: Optional[Credentials] = None, credentials: Optional[Credentials] = None,
*,
document_id_key: Optional[str] = None,
): ):
"""Google Vertex AI Vector Search (previously Matching Engine) """Google Vertex AI Vector Search (previously Matching Engine)
implementation of the vector store. implementation of the vector store.
@@ -78,6 +80,9 @@ class MatchingEngine(VectorStore):
gcs_client: The GCS client. gcs_client: The GCS client.
gcs_bucket_name: The GCS bucket name. gcs_bucket_name: The GCS bucket name.
credentials (Optional): Created GCP credentials. credentials (Optional): Created GCP credentials.
document_id_key (Optional): Key for storing document ID in document
metadata. If None, document ID will not be returned in document
metadata.
""" """
super().__init__() super().__init__()
self._validate_google_libraries_installation() self._validate_google_libraries_installation()
@@ -89,6 +94,7 @@ class MatchingEngine(VectorStore):
self.gcs_client = gcs_client self.gcs_client = gcs_client
self.credentials = credentials self.credentials = credentials
self.gcs_bucket_name = gcs_bucket_name self.gcs_bucket_name = gcs_bucket_name
self.document_id_key = document_id_key
@property @property
def embeddings(self) -> Embeddings: def embeddings(self) -> Embeddings:
@@ -229,6 +235,7 @@ class MatchingEngine(VectorStore):
List[Tuple[Document, float]]: List of documents most similar to List[Tuple[Document, float]]: List of documents most similar to
the query text and cosine distance in float for each. the query text and cosine distance in float for each.
Lower score represents more similarity. Lower score represents more similarity.
""" """
filter = filter or [] filter = filter or []
@@ -255,19 +262,27 @@ class MatchingEngine(VectorStore):
if len(response) == 0: if len(response) == 0:
return [] return []
results = [] docs: List[Tuple[Document, float]] = []
# I'm only getting the first one because queries receives an array # I'm only getting the first one because queries receives an array
# and the similarity_search method only receives one query. This # and the similarity_search method only receives one query. This
# means that the match method will always return an array with only # means that the match method will always return an array with only
# one element. # one element.
for doc in response[0]: for result in response[0]:
page_content = self._download_from_gcs(f"documents/{doc.id}") page_content = self._download_from_gcs(f"documents/{result.id}")
results.append((Document(page_content=page_content), doc.distance)) # TODO: return all metadata.
metadata = {}
if self.document_id_key is not None:
metadata[self.document_id_key] = result.id
document = Document(
page_content=page_content,
metadata=metadata,
)
docs.append((document, result.distance))
logger.debug("Downloaded documents for query.") logger.debug("Downloaded documents for query.")
return results return docs
def similarity_search( def similarity_search(
self, self,
@@ -382,6 +397,7 @@ class MatchingEngine(VectorStore):
endpoint_id: str, endpoint_id: str,
credentials_path: Optional[str] = None, credentials_path: Optional[str] = None,
embedding: Optional[Embeddings] = None, embedding: Optional[Embeddings] = None,
**kwargs: Any,
) -> "MatchingEngine": ) -> "MatchingEngine":
"""Takes the object creation out of the constructor. """Takes the object creation out of the constructor.
@@ -397,6 +413,7 @@ class MatchingEngine(VectorStore):
the local file system. the local file system.
embedding: The :class:`Embeddings` that will be used for embedding: The :class:`Embeddings` that will be used for
embedding the texts. embedding the texts.
kwargs: Additional keyword arguments to pass to MatchingEngine.__init__().
Returns: Returns:
A configured MatchingEngine with the texts added to the index. A configured MatchingEngine with the texts added to the index.
@@ -419,6 +436,7 @@ class MatchingEngine(VectorStore):
gcs_client=gcs_client, gcs_client=gcs_client,
credentials=credentials, credentials=credentials,
gcs_bucket_name=gcs_bucket_name, gcs_bucket_name=gcs_bucket_name,
**kwargs,
) )
@classmethod @classmethod

Loading…
Cancel
Save