mirror of
https://github.com/hwchase17/langchain
synced 2024-11-10 01:10:59 +00:00
community: fix AzureSearch vectorstore asyncronous methods (#24921)
**Description**
Fix the asyncronous methods to retrieve documents from AzureSearch
VectorStore. The previous changes from [this
commit](ffe6ca986e
)
create a similar code for the syncronous methods and the asyncronous
ones but the asyncronous client return an asyncronous iterator
"AsyncSearchItemPaged" as said in the issue #24740.
To solve this issue, the syncronous iterators in asyncronous methods
where changed to asyncronous iterators.
@chrislrobert said in [this
comment](https://github.com/langchain-ai/langchain/issues/24740#issuecomment-2254168302)
that there was a still a flaw due to `with` blocks that close the client
after each call. I removed this `with` blocks in the `async_client`
following the same pattern as the sync `client`.
In order to close up the connections, a __del__ method is included to
gently close up clients once the vectorstore object is destroyed.
**Issue:** #24740 and #24064
**Dependencies:** No new dependencies for this change
**Example notebook:** I created a notebook just to test the changes work
and gives the same results as the syncronous methods for vector and
hybrid search. With these changes, the asyncronous methods in the
retriever work as well.
![image](https://github.com/user-attachments/assets/697e431b-9d7f-4d0d-b205-59d051ac2b67)
**Lint and test**: Passes the tests and the linter
This commit is contained in:
parent
6bc451b942
commit
9d08369442
@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import itertools
|
||||
import json
|
||||
@ -41,7 +42,12 @@ logger = logging.getLogger()
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from azure.search.documents import SearchClient, SearchItemPaged
|
||||
from azure.search.documents.aio import SearchClient as AsyncSearchClient
|
||||
from azure.search.documents.aio import (
|
||||
AsyncSearchItemPaged,
|
||||
)
|
||||
from azure.search.documents.aio import (
|
||||
SearchClient as AsyncSearchClient,
|
||||
)
|
||||
from azure.search.documents.indexes.models import (
|
||||
CorsOptions,
|
||||
ScoringProfile,
|
||||
@ -360,6 +366,31 @@ class AzureSearch(VectorStore):
|
||||
self._user_agent = user_agent
|
||||
self._cors_options = cors_options
|
||||
|
||||
def __del__(self) -> None:
|
||||
# Close the sync client
|
||||
if hasattr(self, "client") and self.client:
|
||||
self.client.close()
|
||||
|
||||
# Close the async client
|
||||
if hasattr(self, "async_client") and self.async_client:
|
||||
# Check if we're in an existing event loop
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_running():
|
||||
# Schedule the coroutine to close the async client
|
||||
loop.create_task(self.async_client.close())
|
||||
else:
|
||||
# If no event loop is running, run the coroutine directly
|
||||
loop.run_until_complete(self.async_client.close())
|
||||
except RuntimeError:
|
||||
# Handle the case where there's no event loop
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(self.async_client.close())
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
@property
|
||||
def embeddings(self) -> Optional[Embeddings]:
|
||||
# TODO: Support embedding object directly
|
||||
@ -518,21 +549,19 @@ class AzureSearch(VectorStore):
|
||||
ids.append(key)
|
||||
# Upload data in batches
|
||||
if len(data) == MAX_UPLOAD_BATCH_SIZE:
|
||||
async with self.async_client as async_client:
|
||||
response = await async_client.upload_documents(documents=data)
|
||||
# Check if all documents were successfully uploaded
|
||||
if not all(r.succeeded for r in response):
|
||||
raise LangChainException(response)
|
||||
# Reset data
|
||||
data = []
|
||||
response = await self.async_client.upload_documents(documents=data)
|
||||
# Check if all documents were successfully uploaded
|
||||
if not all(r.succeeded for r in response):
|
||||
raise LangChainException(response)
|
||||
# Reset data
|
||||
data = []
|
||||
|
||||
# Considering case where data is an exact multiple of batch-size entries
|
||||
if len(data) == 0:
|
||||
return ids
|
||||
|
||||
# Upload data to index
|
||||
async with self.async_client as async_client:
|
||||
response = await async_client.upload_documents(documents=data)
|
||||
response = await self.async_client.upload_documents(documents=data)
|
||||
# Check if all documents were successfully uploaded
|
||||
if all(r.succeeded for r in response):
|
||||
return ids
|
||||
@ -566,9 +595,8 @@ class AzureSearch(VectorStore):
|
||||
False otherwise.
|
||||
"""
|
||||
if ids:
|
||||
async with self.async_client as async_client:
|
||||
res = await async_client.delete_documents([{"id": i} for i in ids])
|
||||
return len(res) > 0
|
||||
res = await self.async_client.delete_documents([{"id": i} for i in ids])
|
||||
return len(res) > 0
|
||||
else:
|
||||
return False
|
||||
|
||||
@ -748,7 +776,7 @@ class AzureSearch(VectorStore):
|
||||
embedding, "", k, filters=filters, **kwargs
|
||||
)
|
||||
|
||||
return _results_to_documents(results)
|
||||
return await _aresults_to_documents(results)
|
||||
|
||||
def max_marginal_relevance_search_with_score(
|
||||
self,
|
||||
@ -897,7 +925,7 @@ class AzureSearch(VectorStore):
|
||||
embedding, query, k, filters=filters, **kwargs
|
||||
)
|
||||
|
||||
return _results_to_documents(results)
|
||||
return await _aresults_to_documents(results)
|
||||
|
||||
def hybrid_search_with_relevance_scores(
|
||||
self,
|
||||
@ -1050,7 +1078,7 @@ class AzureSearch(VectorStore):
|
||||
*,
|
||||
filters: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> SearchItemPaged[dict]:
|
||||
) -> AsyncSearchItemPaged[dict]:
|
||||
"""Perform vector or hybrid search in the Azure search index.
|
||||
|
||||
Args:
|
||||
@ -1064,20 +1092,19 @@ class AzureSearch(VectorStore):
|
||||
"""
|
||||
from azure.search.documents.models import VectorizedQuery
|
||||
|
||||
async with self.async_client as async_client:
|
||||
return await async_client.search(
|
||||
search_text=text_query,
|
||||
vector_queries=[
|
||||
VectorizedQuery(
|
||||
vector=np.array(embedding, dtype=np.float32).tolist(),
|
||||
k_nearest_neighbors=k,
|
||||
fields=FIELDS_CONTENT_VECTOR,
|
||||
)
|
||||
],
|
||||
filter=filters,
|
||||
top=k,
|
||||
**kwargs,
|
||||
)
|
||||
return await self.async_client.search(
|
||||
search_text=text_query,
|
||||
vector_queries=[
|
||||
VectorizedQuery(
|
||||
vector=np.array(embedding, dtype=np.float32).tolist(),
|
||||
k_nearest_neighbors=k,
|
||||
fields=FIELDS_CONTENT_VECTOR,
|
||||
)
|
||||
],
|
||||
filter=filters,
|
||||
top=k,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def semantic_hybrid_search(
|
||||
self, query: str, k: int = 4, **kwargs: Any
|
||||
@ -1289,71 +1316,68 @@ class AzureSearch(VectorStore):
|
||||
from azure.search.documents.models import VectorizedQuery
|
||||
|
||||
vector = await self._aembed_query(query)
|
||||
async with self.async_client as async_client:
|
||||
results = await async_client.search(
|
||||
search_text=query,
|
||||
vector_queries=[
|
||||
VectorizedQuery(
|
||||
vector=np.array(vector, dtype=np.float32).tolist(),
|
||||
k_nearest_neighbors=k,
|
||||
fields=FIELDS_CONTENT_VECTOR,
|
||||
)
|
||||
],
|
||||
filter=filters,
|
||||
query_type="semantic",
|
||||
semantic_configuration_name=self.semantic_configuration_name,
|
||||
query_caption="extractive",
|
||||
query_answer="extractive",
|
||||
top=k,
|
||||
**kwargs,
|
||||
)
|
||||
# Get Semantic Answers
|
||||
semantic_answers = (await results.get_answers()) or []
|
||||
semantic_answers_dict: Dict = {}
|
||||
for semantic_answer in semantic_answers:
|
||||
semantic_answers_dict[semantic_answer.key] = {
|
||||
"text": semantic_answer.text,
|
||||
"highlights": semantic_answer.highlights,
|
||||
}
|
||||
# Convert results to Document objects
|
||||
docs = [
|
||||
(
|
||||
Document(
|
||||
page_content=result.pop(FIELDS_CONTENT),
|
||||
metadata={
|
||||
**(
|
||||
json.loads(result[FIELDS_METADATA])
|
||||
if FIELDS_METADATA in result
|
||||
else {
|
||||
k: v
|
||||
for k, v in result.items()
|
||||
if k != FIELDS_CONTENT_VECTOR
|
||||
}
|
||||
),
|
||||
**{
|
||||
"captions": {
|
||||
"text": result.get("@search.captions", [{}])[
|
||||
0
|
||||
].text,
|
||||
"highlights": result.get("@search.captions", [{}])[
|
||||
0
|
||||
].highlights,
|
||||
}
|
||||
if result.get("@search.captions")
|
||||
else {},
|
||||
"answers": semantic_answers_dict.get(
|
||||
result.get(FIELDS_ID, ""),
|
||||
"",
|
||||
),
|
||||
},
|
||||
},
|
||||
),
|
||||
float(result["@search.score"]),
|
||||
float(result["@search.reranker_score"]),
|
||||
results = await self.async_client.search(
|
||||
search_text=query,
|
||||
vector_queries=[
|
||||
VectorizedQuery(
|
||||
vector=np.array(vector, dtype=np.float32).tolist(),
|
||||
k_nearest_neighbors=k,
|
||||
fields=FIELDS_CONTENT_VECTOR,
|
||||
)
|
||||
async for result in results
|
||||
]
|
||||
return docs
|
||||
],
|
||||
filter=filters,
|
||||
query_type="semantic",
|
||||
semantic_configuration_name=self.semantic_configuration_name,
|
||||
query_caption="extractive",
|
||||
query_answer="extractive",
|
||||
top=k,
|
||||
**kwargs,
|
||||
)
|
||||
# Get Semantic Answers
|
||||
semantic_answers = (await results.get_answers()) or []
|
||||
semantic_answers_dict: Dict = {}
|
||||
for semantic_answer in semantic_answers:
|
||||
semantic_answers_dict[semantic_answer.key] = {
|
||||
"text": semantic_answer.text,
|
||||
"highlights": semantic_answer.highlights,
|
||||
}
|
||||
# Convert results to Document objects
|
||||
docs = [
|
||||
(
|
||||
Document(
|
||||
page_content=result.pop(FIELDS_CONTENT),
|
||||
metadata={
|
||||
**(
|
||||
json.loads(result[FIELDS_METADATA])
|
||||
if FIELDS_METADATA in result
|
||||
else {
|
||||
k: v
|
||||
for k, v in result.items()
|
||||
if k != FIELDS_CONTENT_VECTOR
|
||||
}
|
||||
),
|
||||
**{
|
||||
"captions": {
|
||||
"text": result.get("@search.captions", [{}])[0].text,
|
||||
"highlights": result.get("@search.captions", [{}])[
|
||||
0
|
||||
].highlights,
|
||||
}
|
||||
if result.get("@search.captions")
|
||||
else {},
|
||||
"answers": semantic_answers_dict.get(
|
||||
result.get(FIELDS_ID, ""),
|
||||
"",
|
||||
),
|
||||
},
|
||||
},
|
||||
),
|
||||
float(result["@search.score"]),
|
||||
float(result["@search.reranker_score"]),
|
||||
)
|
||||
async for result in results
|
||||
]
|
||||
return docs
|
||||
|
||||
@classmethod
|
||||
def from_texts(
|
||||
@ -1629,6 +1653,19 @@ def _results_to_documents(
|
||||
return docs
|
||||
|
||||
|
||||
async def _aresults_to_documents(
|
||||
results: AsyncSearchItemPaged[Dict],
|
||||
) -> List[Tuple[Document, float]]:
|
||||
docs = [
|
||||
(
|
||||
_result_to_document(result),
|
||||
float(result["@search.score"]),
|
||||
)
|
||||
async for result in results
|
||||
]
|
||||
return docs
|
||||
|
||||
|
||||
async def _areorder_results_with_maximal_marginal_relevance(
|
||||
results: SearchItemPaged[Dict],
|
||||
query_embedding: np.ndarray,
|
||||
@ -1642,7 +1679,7 @@ async def _areorder_results_with_maximal_marginal_relevance(
|
||||
float(result["@search.score"]),
|
||||
result[FIELDS_CONTENT_VECTOR],
|
||||
)
|
||||
for result in results
|
||||
async for result in results
|
||||
]
|
||||
documents, scores, vectors = map(list, zip(*docs))
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user