Add async methods for the AstraDB VectorStore (#16391)

- **Description**: fully async versions are available for astrapy 0.7+.
For older astrapy versions or if the user provides a sync client without
an async one, the async methods will call the sync ones wrapped in
`run_in_executor`
  - **Twitter handle:** cbornet_
pull/16763/head
Christophe Bornet 5 months ago committed by GitHub
parent f8f2649f12
commit 744070ee85
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

File diff suppressed because it is too large Load Diff

@ -148,6 +148,33 @@ class TestAstraDB:
)
v_store_2.delete_collection()
async def test_astradb_vectorstore_create_delete_async(self) -> None:
"""Create and delete."""
emb = SomeEmbeddings(dimension=2)
# creation by passing the connection secrets
v_store = AstraDB(
embedding=emb,
collection_name="lc_test_1_async",
token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
)
await v_store.adelete_collection()
# Creation by passing a ready-made astrapy client:
from astrapy.db import AsyncAstraDB
astra_db_client = AsyncAstraDB(
token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
)
v_store_2 = AstraDB(
embedding=emb,
collection_name="lc_test_2_async",
async_astra_db_client=astra_db_client,
)
await v_store_2.adelete_collection()
def test_astradb_vectorstore_pre_delete_collection(self) -> None:
"""Create and delete."""
emb = SomeEmbeddings(dimension=2)
@ -183,6 +210,41 @@ class TestAstraDB:
finally:
v_store.delete_collection()
async def test_astradb_vectorstore_pre_delete_collection_async(self) -> None:
"""Create and delete."""
emb = SomeEmbeddings(dimension=2)
# creation by passing the connection secrets
v_store = AstraDB(
embedding=emb,
collection_name="lc_test_pre_del_async",
token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
)
try:
await v_store.aadd_texts(
texts=["aa"],
metadatas=[
{"k": "a", "ord": 0},
],
ids=["a"],
)
res1 = await v_store.asimilarity_search("aa", k=5)
assert len(res1) == 1
v_store = AstraDB(
embedding=emb,
pre_delete_collection=True,
collection_name="lc_test_pre_del_async",
token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
)
res1 = await v_store.asimilarity_search("aa", k=5)
assert len(res1) == 0
finally:
await v_store.adelete_collection()
def test_astradb_vectorstore_from_x(self) -> None:
"""from_texts and from_documents methods."""
emb = SomeEmbeddings(dimension=2)
@ -200,7 +262,7 @@ class TestAstraDB:
finally:
v_store.delete_collection()
# from_texts
# from_documents
v_store_2 = AstraDB.from_documents(
[
Document(page_content="Hee"),
@ -217,6 +279,42 @@ class TestAstraDB:
finally:
v_store_2.delete_collection()
async def test_astradb_vectorstore_from_x_async(self) -> None:
"""from_texts and from_documents methods."""
emb = SomeEmbeddings(dimension=2)
# from_texts
v_store = await AstraDB.afrom_texts(
texts=["Hi", "Ho"],
embedding=emb,
collection_name="lc_test_ft_async",
token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
)
try:
assert (await v_store.asimilarity_search("Ho", k=1))[0].page_content == "Ho"
finally:
await v_store.adelete_collection()
# from_documents
v_store_2 = await AstraDB.afrom_documents(
[
Document(page_content="Hee"),
Document(page_content="Hoi"),
],
embedding=emb,
collection_name="lc_test_fd_async",
token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
)
try:
assert (await v_store_2.asimilarity_search("Hoi", k=1))[
0
].page_content == "Hoi"
finally:
await v_store_2.adelete_collection()
def test_astradb_vectorstore_crud(self, store_someemb: AstraDB) -> None:
"""Basic add/delete/update behaviour."""
res0 = store_someemb.similarity_search("Abc", k=2)
@ -275,25 +373,106 @@ class TestAstraDB:
res4 = store_someemb.similarity_search("ww", k=1, filter={"k": "w"})
assert res4[0].metadata["ord"] == 205
def test_astradb_vectorstore_mmr(self, store_parseremb: AstraDB) -> None:
"""
MMR testing. We work on the unit circle with angle multiples
of 2*pi/20 and prepare a store with known vectors for a controlled
MMR outcome.
"""
async def test_astradb_vectorstore_crud_async(self, store_someemb: AstraDB) -> None:
"""Basic add/delete/update behaviour."""
res0 = await store_someemb.asimilarity_search("Abc", k=2)
assert res0 == []
# write and check again
await store_someemb.aadd_texts(
texts=["aa", "bb", "cc"],
metadatas=[
{"k": "a", "ord": 0},
{"k": "b", "ord": 1},
{"k": "c", "ord": 2},
],
ids=["a", "b", "c"],
)
res1 = await store_someemb.asimilarity_search("Abc", k=5)
assert {doc.page_content for doc in res1} == {"aa", "bb", "cc"}
# partial overwrite and count total entries
await store_someemb.aadd_texts(
texts=["cc", "dd"],
metadatas=[
{"k": "c_new", "ord": 102},
{"k": "d_new", "ord": 103},
],
ids=["c", "d"],
)
res2 = await store_someemb.asimilarity_search("Abc", k=10)
assert len(res2) == 4
# pick one that was just updated and check its metadata
res3 = await store_someemb.asimilarity_search_with_score_id(
query="cc", k=1, filter={"k": "c_new"}
)
print(str(res3))
doc3, score3, id3 = res3[0]
assert doc3.page_content == "cc"
assert doc3.metadata == {"k": "c_new", "ord": 102}
assert score3 > 0.999 # leaving some leeway for approximations...
assert id3 == "c"
# delete and count again
del1_res = await store_someemb.adelete(["b"])
assert del1_res is True
del2_res = await store_someemb.adelete(["a", "c", "Z!"])
assert del2_res is False # a non-existing ID was supplied
assert len(await store_someemb.asimilarity_search("xy", k=10)) == 1
# clear store
await store_someemb.aclear()
assert await store_someemb.asimilarity_search("Abc", k=2) == []
# add_documents with "ids" arg passthrough
await store_someemb.aadd_documents(
[
Document(page_content="vv", metadata={"k": "v", "ord": 204}),
Document(page_content="ww", metadata={"k": "w", "ord": 205}),
],
ids=["v", "w"],
)
assert len(await store_someemb.asimilarity_search("xy", k=10)) == 2
res4 = await store_someemb.asimilarity_search("ww", k=1, filter={"k": "w"})
assert res4[0].metadata["ord"] == 205
@staticmethod
def _v_from_i(i: int, N: int) -> str:
angle = 2 * math.pi * i / N
vector = [math.cos(angle), math.sin(angle)]
return json.dumps(vector)
def test_astradb_vectorstore_mmr(self, store_parseremb: AstraDB) -> None:
"""
MMR testing. We work on the unit circle with angle multiples
of 2*pi/20 and prepare a store with known vectors for a controlled
MMR outcome.
"""
i_vals = [0, 4, 5, 13]
N_val = 20
store_parseremb.add_texts(
[_v_from_i(i, N_val) for i in i_vals], metadatas=[{"i": i} for i in i_vals]
[self._v_from_i(i, N_val) for i in i_vals],
metadatas=[{"i": i} for i in i_vals],
)
res1 = store_parseremb.max_marginal_relevance_search(
_v_from_i(3, N_val),
self._v_from_i(3, N_val),
k=2,
fetch_k=3,
)
res_i_vals = {doc.metadata["i"] for doc in res1}
assert res_i_vals == {0, 4}
async def test_astradb_vectorstore_mmr_async(
self, store_parseremb: AstraDB
) -> None:
"""
MMR testing. We work on the unit circle with angle multiples
of 2*pi/20 and prepare a store with known vectors for a controlled
MMR outcome.
"""
i_vals = [0, 4, 5, 13]
N_val = 20
await store_parseremb.aadd_texts(
[self._v_from_i(i, N_val) for i in i_vals],
metadatas=[{"i": i} for i in i_vals],
)
res1 = await store_parseremb.amax_marginal_relevance_search(
self._v_from_i(3, N_val),
k=2,
fetch_k=3,
)
@ -381,6 +560,25 @@ class TestAstraDB:
sco_near, sco_far = scores
assert abs(1 - sco_near) < 0.001 and abs(sco_far) < 0.001
async def test_astradb_vectorstore_similarity_scale_async(
self, store_parseremb: AstraDB
) -> None:
"""Scale of the similarity scores."""
await store_parseremb.aadd_texts(
texts=[
json.dumps([1, 1]),
json.dumps([-1, -1]),
],
ids=["near", "far"],
)
res1 = await store_parseremb.asimilarity_search_with_score(
json.dumps([0.5, 0.5]),
k=2,
)
scores = [sco for _, sco in res1]
sco_near, sco_far = scores
assert abs(1 - sco_near) < 0.001 and abs(sco_far) < 0.001
def test_astradb_vectorstore_massive_delete(self, store_someemb: AstraDB) -> None:
"""Larger-scale bulk deletes."""
M = 50
@ -458,6 +656,40 @@ class TestAstraDB:
finally:
v_store.delete_collection()
async def test_astradb_vectorstore_custom_params_async(self) -> None:
"""Custom batch size and concurrency params."""
emb = SomeEmbeddings(dimension=2)
v_store = AstraDB(
embedding=emb,
collection_name="lc_test_c_async",
token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
batch_size=17,
bulk_insert_batch_concurrency=13,
bulk_insert_overwrite_concurrency=7,
bulk_delete_concurrency=19,
)
try:
# add_texts
N = 50
texts = [str(i + 1 / 7.0) for i in range(N)]
ids = ["doc_%i" % i for i in range(N)]
await v_store.aadd_texts(texts=texts, ids=ids)
await v_store.aadd_texts(
texts=texts,
ids=ids,
batch_size=19,
batch_concurrency=7,
overwrite_concurrency=13,
)
#
await v_store.adelete(ids[: N // 2])
await v_store.adelete(ids[N // 2 :], concurrency=23)
#
finally:
await v_store.adelete_collection()
def test_astradb_vectorstore_metrics(self) -> None:
"""
Different choices of similarity metric.

Loading…
Cancel
Save