From eb76f9c9feddf7d2f0041d528509dd8f4e259272 Mon Sep 17 00:00:00 2001 From: Mahdi Setayesh Date: Fri, 12 Jan 2024 10:58:55 -0800 Subject: [PATCH] community: Fixing a performance issue with AzureSearch to perform batch embedding (#15594) - **Description:** The Azure Cognitive Search vector DB store embeds documents slowly because it does not use the batch embedding functionality. This PR provides a fix that improves the performance of the AzureSearch class when adding documents to the vector store, - **Issue:** #11313, - **Dependencies:** any dependencies required for this change, - **Twitter handle:** we announce bigger features on Twitter. If your PR gets announced, and you'd like a mention, we'll gladly shout you out! Please make sure your PR is passing linting and testing before submitting. Run `make format`, `make lint` and `make test` from the root of the package you've modified to check this locally. See contribution guidelines for more information on how to write/run tests, lint, etc: https://python.langchain.com/docs/contributing/ If you're adding a new integration, please include: 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. It lives in the `docs/docs/integrations` directory. If no one reviews your PR within a few days, please @-mention one of @baskaryan, @eyurtsev, @hwchase17.
--> --- .../vectorstores/azuresearch.py | 40 +++++++++++++------ 1 file changed, 27 insertions(+), 13 deletions(-) diff --git a/libs/community/langchain_community/vectorstores/azuresearch.py b/libs/community/langchain_community/vectorstores/azuresearch.py index f3ab8ae1e7..992db81f22 100644 --- a/libs/community/langchain_community/vectorstores/azuresearch.py +++ b/libs/community/langchain_community/vectorstores/azuresearch.py @@ -248,7 +248,7 @@ class AzureSearch(VectorStore): azure_search_endpoint: str, azure_search_key: str, index_name: str, - embedding_function: Callable, + embedding_function: Union[Callable, Embeddings], search_type: str = "hybrid", semantic_configuration_name: Optional[str] = None, semantic_query_language: str = "en-us", @@ -270,6 +270,12 @@ class AzureSearch(VectorStore): """Initialize with necessary components.""" # Initialize base class self.embedding_function = embedding_function + + if isinstance(self.embedding_function, Embeddings): + self.embed_query = self.embedding_function.embed_query + else: + self.embed_query = self.embedding_function + default_fields = [ SimpleField( name=FIELDS_ID, @@ -285,7 +291,7 @@ class AzureSearch(VectorStore): name=FIELDS_CONTENT_VECTOR, type=SearchFieldDataType.Collection(SearchFieldDataType.Single), searchable=True, - vector_search_dimensions=len(embedding_function("Text")), + vector_search_dimensions=len(self.embed_query("Text")), vector_search_configuration="default", ), SearchableField( @@ -329,6 +335,20 @@ class AzureSearch(VectorStore): """Add texts data to an existing index.""" keys = kwargs.get("keys") ids = [] + + # batching support if embedding function is an Embeddings object + if isinstance(self.embedding_function, Embeddings): + try: + embeddings = self.embedding_function.embed_documents(texts) + except NotImplementedError: + embeddings = [self.embedding_function.embed_query(x) for x in texts] + else: + embeddings = [self.embedding_function(x) for x in texts] + + if len(embeddings) == 0: + 
logger.debug("Nothing to insert, skipping.") + return [] + # Write data to index data = [] for i, text in enumerate(texts): @@ -344,7 +364,7 @@ class AzureSearch(VectorStore): FIELDS_ID: key, FIELDS_CONTENT: text, FIELDS_CONTENT_VECTOR: np.array( - self.embedding_function(text), dtype=np.float32 + embeddings[i], dtype=np.float32 ).tolist(), FIELDS_METADATA: json.dumps(metadata), } @@ -437,9 +457,7 @@ class AzureSearch(VectorStore): search_text="", vectors=[ Vector( - value=np.array( - self.embedding_function(query), dtype=np.float32 - ).tolist(), + value=np.array(self.embed_query(query), dtype=np.float32).tolist(), k=k, fields=FIELDS_CONTENT_VECTOR, ) @@ -508,9 +526,7 @@ class AzureSearch(VectorStore): search_text=query, vectors=[ Vector( - value=np.array( - self.embedding_function(query), dtype=np.float32 - ).tolist(), + value=np.array(self.embed_query(query), dtype=np.float32).tolist(), k=k, fields=FIELDS_CONTENT_VECTOR, ) @@ -600,9 +616,7 @@ class AzureSearch(VectorStore): search_text=query, vectors=[ Vector( - value=np.array( - self.embedding_function(query), dtype=np.float32 - ).tolist(), + value=np.array(self.embed_query(query), dtype=np.float32).tolist(), k=50, fields=FIELDS_CONTENT_VECTOR, ) @@ -684,7 +698,7 @@ class AzureSearch(VectorStore): azure_search_endpoint, azure_search_key, index_name, - embedding.embed_query, + embedding, ) azure_search.add_texts(texts, metadatas, **kwargs) return azure_search