diff --git a/libs/community/langchain_community/vectorstores/azuresearch.py b/libs/community/langchain_community/vectorstores/azuresearch.py index f3ab8ae1e7..992db81f22 100644 --- a/libs/community/langchain_community/vectorstores/azuresearch.py +++ b/libs/community/langchain_community/vectorstores/azuresearch.py @@ -248,7 +248,7 @@ class AzureSearch(VectorStore): azure_search_endpoint: str, azure_search_key: str, index_name: str, - embedding_function: Callable, + embedding_function: Union[Callable, Embeddings], search_type: str = "hybrid", semantic_configuration_name: Optional[str] = None, semantic_query_language: str = "en-us", @@ -270,6 +270,12 @@ class AzureSearch(VectorStore): """Initialize with necessary components.""" # Initialize base class self.embedding_function = embedding_function + + if isinstance(self.embedding_function, Embeddings): + self.embed_query = self.embedding_function.embed_query + else: + self.embed_query = self.embedding_function + default_fields = [ SimpleField( name=FIELDS_ID, @@ -285,7 +291,7 @@ class AzureSearch(VectorStore): name=FIELDS_CONTENT_VECTOR, type=SearchFieldDataType.Collection(SearchFieldDataType.Single), searchable=True, - vector_search_dimensions=len(embedding_function("Text")), + vector_search_dimensions=len(self.embed_query("Text")), vector_search_configuration="default", ), SearchableField( @@ -329,6 +335,20 @@ class AzureSearch(VectorStore): """Add texts data to an existing index.""" keys = kwargs.get("keys") ids = [] + + # batching support if embedding function is an Embeddings object + if isinstance(self.embedding_function, Embeddings): + try: + embeddings = self.embedding_function.embed_documents(texts) + except NotImplementedError: + embeddings = [self.embedding_function.embed_query(x) for x in texts] + else: + embeddings = [self.embedding_function(x) for x in texts] + + if len(embeddings) == 0: + logger.debug("Nothing to insert, skipping.") + return [] + # Write data to index data = [] for i, text in enumerate(texts): @@ -344,7 +364,7 @@ class AzureSearch(VectorStore): FIELDS_ID: key, FIELDS_CONTENT: text, FIELDS_CONTENT_VECTOR: np.array( - self.embedding_function(text), dtype=np.float32 + embeddings[i], dtype=np.float32 ).tolist(), FIELDS_METADATA: json.dumps(metadata), } @@ -437,9 +457,7 @@ class AzureSearch(VectorStore): search_text="", vectors=[ Vector( - value=np.array( - self.embedding_function(query), dtype=np.float32 - ).tolist(), + value=np.array(self.embed_query(query), dtype=np.float32).tolist(), k=k, fields=FIELDS_CONTENT_VECTOR, ) @@ -508,9 +526,7 @@ class AzureSearch(VectorStore): search_text=query, vectors=[ Vector( - value=np.array( - self.embedding_function(query), dtype=np.float32 - ).tolist(), + value=np.array(self.embed_query(query), dtype=np.float32).tolist(), k=k, fields=FIELDS_CONTENT_VECTOR, ) @@ -600,9 +616,7 @@ class AzureSearch(VectorStore): search_text=query, vectors=[ Vector( - value=np.array( - self.embedding_function(query), dtype=np.float32 - ).tolist(), + value=np.array(self.embed_query(query), dtype=np.float32).tolist(), k=50, fields=FIELDS_CONTENT_VECTOR, ) @@ -684,7 +698,7 @@ class AzureSearch(VectorStore): azure_search_endpoint, azure_search_key, index_name, - embedding.embed_query, + embedding, ) azure_search.add_texts(texts, metadatas, **kwargs) return azure_search