diff --git a/libs/community/langchain_community/vectorstores/azuresearch.py b/libs/community/langchain_community/vectorstores/azuresearch.py index a3acc952a2..5d7a6fc8ed 100644 --- a/libs/community/langchain_community/vectorstores/azuresearch.py +++ b/libs/community/langchain_community/vectorstores/azuresearch.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import base64 import itertools import json @@ -41,7 +42,12 @@ logger = logging.getLogger() if TYPE_CHECKING: from azure.search.documents import SearchClient, SearchItemPaged - from azure.search.documents.aio import SearchClient as AsyncSearchClient + from azure.search.documents.aio import ( + AsyncSearchItemPaged, + ) + from azure.search.documents.aio import ( + SearchClient as AsyncSearchClient, + ) from azure.search.documents.indexes.models import ( CorsOptions, ScoringProfile, @@ -360,6 +366,31 @@ class AzureSearch(VectorStore): self._user_agent = user_agent self._cors_options = cors_options + def __del__(self) -> None: + # Close the sync client + if hasattr(self, "client") and self.client: + self.client.close() + + # Close the async client + if hasattr(self, "async_client") and self.async_client: + # Check if we're in an existing event loop + try: + loop = asyncio.get_event_loop() + if loop.is_running(): + # Schedule the coroutine to close the async client + loop.create_task(self.async_client.close()) + else: + # If no event loop is running, run the coroutine directly + loop.run_until_complete(self.async_client.close()) + except RuntimeError: + # Handle the case where there's no event loop + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + loop.run_until_complete(self.async_client.close()) + finally: + loop.close() + @property def embeddings(self) -> Optional[Embeddings]: # TODO: Support embedding object directly @@ -518,21 +549,19 @@ class AzureSearch(VectorStore): ids.append(key) # Upload data in batches if len(data) == MAX_UPLOAD_BATCH_SIZE: - async with self.async_client as async_client: - response = await async_client.upload_documents(documents=data) - # Check if all documents were successfully uploaded - if not all(r.succeeded for r in response): - raise LangChainException(response) - # Reset data - data = [] + response = await self.async_client.upload_documents(documents=data) + # Check if all documents were successfully uploaded + if not all(r.succeeded for r in response): + raise LangChainException(response) + # Reset data + data = [] # Considering case where data is an exact multiple of batch-size entries if len(data) == 0: return ids # Upload data to index - async with self.async_client as async_client: - response = await async_client.upload_documents(documents=data) + response = await self.async_client.upload_documents(documents=data) # Check if all documents were successfully uploaded if all(r.succeeded for r in response): return ids @@ -566,9 +595,8 @@ class AzureSearch(VectorStore): False otherwise. """ if ids: - async with self.async_client as async_client: - res = await async_client.delete_documents([{"id": i} for i in ids]) - return len(res) > 0 + res = await self.async_client.delete_documents([{"id": i} for i in ids]) + return len(res) > 0 else: return False @@ -748,7 +776,7 @@ class AzureSearch(VectorStore): embedding, "", k, filters=filters, **kwargs ) - return _results_to_documents(results) + return await _aresults_to_documents(results) def max_marginal_relevance_search_with_score( self, @@ -897,7 +925,7 @@ class AzureSearch(VectorStore): embedding, query, k, filters=filters, **kwargs ) - return _results_to_documents(results) + return await _aresults_to_documents(results) def hybrid_search_with_relevance_scores( self, @@ -1050,7 +1078,7 @@ class AzureSearch(VectorStore): *, filters: Optional[str] = None, **kwargs: Any, - ) -> SearchItemPaged[dict]: + ) -> AsyncSearchItemPaged[dict]: """Perform vector or hybrid search in the Azure search index. Args: @@ -1064,20 +1092,19 @@ class AzureSearch(VectorStore): """ from azure.search.documents.models import VectorizedQuery - async with self.async_client as async_client: - return await async_client.search( - search_text=text_query, - vector_queries=[ - VectorizedQuery( - vector=np.array(embedding, dtype=np.float32).tolist(), - k_nearest_neighbors=k, - fields=FIELDS_CONTENT_VECTOR, - ) - ], - filter=filters, - top=k, - **kwargs, - ) + return await self.async_client.search( + search_text=text_query, + vector_queries=[ + VectorizedQuery( + vector=np.array(embedding, dtype=np.float32).tolist(), + k_nearest_neighbors=k, + fields=FIELDS_CONTENT_VECTOR, + ) + ], + filter=filters, + top=k, + **kwargs, + ) def semantic_hybrid_search( self, query: str, k: int = 4, **kwargs: Any @@ -1289,71 +1316,68 @@ class AzureSearch(VectorStore): from azure.search.documents.models import VectorizedQuery vector = await self._aembed_query(query) - async with self.async_client as async_client: - results = await async_client.search( - search_text=query, - vector_queries=[ - VectorizedQuery( - vector=np.array(vector, dtype=np.float32).tolist(), - k_nearest_neighbors=k, - fields=FIELDS_CONTENT_VECTOR, - ) - ], - filter=filters, - query_type="semantic", - semantic_configuration_name=self.semantic_configuration_name, - query_caption="extractive", - query_answer="extractive", - top=k, - **kwargs, - ) - # Get Semantic Answers - semantic_answers = (await results.get_answers()) or [] - semantic_answers_dict: Dict = {} - for semantic_answer in semantic_answers: - semantic_answers_dict[semantic_answer.key] = { - "text": semantic_answer.text, - "highlights": semantic_answer.highlights, - } - # Convert results to Document objects - docs = [ - ( - Document( - page_content=result.pop(FIELDS_CONTENT), - metadata={ - **( - json.loads(result[FIELDS_METADATA]) - if FIELDS_METADATA in result - else { - k: v - for k, v in result.items() - if k != FIELDS_CONTENT_VECTOR - } - ), - **{ - "captions": { - "text": result.get("@search.captions", [{}])[ - 0 - ].text, - "highlights": result.get("@search.captions", [{}])[ - 0 - ].highlights, - } - if result.get("@search.captions") - else {}, - "answers": semantic_answers_dict.get( - result.get(FIELDS_ID, ""), - "", - ), - }, - }, - ), - float(result["@search.score"]), - float(result["@search.reranker_score"]), + results = await self.async_client.search( + search_text=query, + vector_queries=[ + VectorizedQuery( + vector=np.array(vector, dtype=np.float32).tolist(), + k_nearest_neighbors=k, + fields=FIELDS_CONTENT_VECTOR, ) - async for result in results - ] - return docs + ], + filter=filters, + query_type="semantic", + semantic_configuration_name=self.semantic_configuration_name, + query_caption="extractive", + query_answer="extractive", + top=k, + **kwargs, + ) + # Get Semantic Answers + semantic_answers = (await results.get_answers()) or [] + semantic_answers_dict: Dict = {} + for semantic_answer in semantic_answers: + semantic_answers_dict[semantic_answer.key] = { + "text": semantic_answer.text, + "highlights": semantic_answer.highlights, + } + # Convert results to Document objects + docs = [ + ( + Document( + page_content=result.pop(FIELDS_CONTENT), + metadata={ + **( + json.loads(result[FIELDS_METADATA]) + if FIELDS_METADATA in result + else { + k: v + for k, v in result.items() + if k != FIELDS_CONTENT_VECTOR + } + ), + **{ + "captions": { + "text": result.get("@search.captions", [{}])[0].text, + "highlights": result.get("@search.captions", [{}])[ + 0 + ].highlights, + } + if result.get("@search.captions") + else {}, + "answers": semantic_answers_dict.get( + result.get(FIELDS_ID, ""), + "", + ), + }, + }, + ), + float(result["@search.score"]), + float(result["@search.reranker_score"]), + ) + async for result in results + ] + return docs @classmethod def from_texts( @@ -1629,6 +1653,19 @@ def _results_to_documents( return docs +async def _aresults_to_documents( + results: AsyncSearchItemPaged[Dict], +) -> List[Tuple[Document, float]]: + docs = [ + ( + _result_to_document(result), + float(result["@search.score"]), + ) + async for result in results + ] + return docs + + async def _areorder_results_with_maximal_marginal_relevance( results: SearchItemPaged[Dict], query_embedding: np.ndarray, @@ -1642,7 +1679,7 @@ async def _areorder_results_with_maximal_marginal_relevance( float(result["@search.score"]), result[FIELDS_CONTENT_VECTOR], ) - for result in results + async for result in results ] documents, scores, vectors = map(list, zip(*docs))