community: Fixing a performance issue with AzureSearch to perform batch embedding (#15594)

- **Description:** The Azure Cognitive Search vector store embeds documents
slowly because it does not use the batch embedding functionality. This
PR provides a fix that improves the performance of the AzureSearch class
when adding documents to the vector store.
  - **Issue:** #11313,
  - **Dependencies:** any dependencies required for this change,
- **Twitter handle:** we announce bigger features on Twitter. If your PR
gets announced, and you'd like a mention, we'll gladly shout you out!

Please make sure your PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` from the root
of the package you've modified to check this locally.

See contribution guidelines for more information on how to write/run
tests, lint, etc: https://python.langchain.com/docs/contributing/

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in
`docs/docs/integrations` directory.

If no one reviews your PR within a few days, please @-mention one of
@baskaryan, @eyurtsev, @hwchase17.
 -->
pull/15976/head
Mahdi Setayesh 6 months ago committed by GitHub
parent bc60203d0f
commit eb76f9c9fe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -248,7 +248,7 @@ class AzureSearch(VectorStore):
azure_search_endpoint: str,
azure_search_key: str,
index_name: str,
embedding_function: Callable,
embedding_function: Union[Callable, Embeddings],
search_type: str = "hybrid",
semantic_configuration_name: Optional[str] = None,
semantic_query_language: str = "en-us",
@ -270,6 +270,12 @@ class AzureSearch(VectorStore):
"""Initialize with necessary components."""
# Initialize base class
self.embedding_function = embedding_function
if isinstance(self.embedding_function, Embeddings):
self.embed_query = self.embedding_function.embed_query
else:
self.embed_query = self.embedding_function
default_fields = [
SimpleField(
name=FIELDS_ID,
@ -285,7 +291,7 @@ class AzureSearch(VectorStore):
name=FIELDS_CONTENT_VECTOR,
type=SearchFieldDataType.Collection(SearchFieldDataType.Single),
searchable=True,
vector_search_dimensions=len(embedding_function("Text")),
vector_search_dimensions=len(self.embed_query("Text")),
vector_search_configuration="default",
),
SearchableField(
@ -329,6 +335,20 @@ class AzureSearch(VectorStore):
"""Add texts data to an existing index."""
keys = kwargs.get("keys")
ids = []
# batching support if embedding function is an Embeddings object
if isinstance(self.embedding_function, Embeddings):
try:
embeddings = self.embedding_function.embed_documents(texts)
except NotImplementedError:
embeddings = [self.embedding_function.embed_query(x) for x in texts]
else:
embeddings = [self.embedding_function(x) for x in texts]
if len(embeddings) == 0:
logger.debug("Nothing to insert, skipping.")
return []
# Write data to index
data = []
for i, text in enumerate(texts):
@ -344,7 +364,7 @@ class AzureSearch(VectorStore):
FIELDS_ID: key,
FIELDS_CONTENT: text,
FIELDS_CONTENT_VECTOR: np.array(
self.embedding_function(text), dtype=np.float32
embeddings[i], dtype=np.float32
).tolist(),
FIELDS_METADATA: json.dumps(metadata),
}
@ -437,9 +457,7 @@ class AzureSearch(VectorStore):
search_text="",
vectors=[
Vector(
value=np.array(
self.embedding_function(query), dtype=np.float32
).tolist(),
value=np.array(self.embed_query(query), dtype=np.float32).tolist(),
k=k,
fields=FIELDS_CONTENT_VECTOR,
)
@ -508,9 +526,7 @@ class AzureSearch(VectorStore):
search_text=query,
vectors=[
Vector(
value=np.array(
self.embedding_function(query), dtype=np.float32
).tolist(),
value=np.array(self.embed_query(query), dtype=np.float32).tolist(),
k=k,
fields=FIELDS_CONTENT_VECTOR,
)
@ -600,9 +616,7 @@ class AzureSearch(VectorStore):
search_text=query,
vectors=[
Vector(
value=np.array(
self.embedding_function(query), dtype=np.float32
).tolist(),
value=np.array(self.embed_query(query), dtype=np.float32).tolist(),
k=50,
fields=FIELDS_CONTENT_VECTOR,
)
@ -684,7 +698,7 @@ class AzureSearch(VectorStore):
azure_search_endpoint,
azure_search_key,
index_name,
embedding.embed_query,
embedding,
)
azure_search.add_texts(texts, metadatas, **kwargs)
return azure_search

Loading…
Cancel
Save