Add async methods to CacheBackedEmbeddings (#16873)

Adds async methods to CacheBackedEmbeddings
pull/17091/head^2
Christophe Bornet 4 months ago committed by GitHub
parent dd68a8716e
commit a8f530bc4d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -128,6 +128,41 @@ class CacheBackedEmbeddings(Embeddings):
List[List[float]], vectors
) # Nones should have been resolved by now
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
    """Asynchronously embed a list of texts, consulting the cache first.

    Every text is looked up in ``document_embedding_store``. Texts whose
    lookup yields ``None`` are sent, in one call, to the underlying
    embedder; the resulting vectors are written back to the store and
    slotted into the positions of the corresponding texts.

    Args:
        texts: A list of texts to embed.

    Returns:
        A list of embeddings, one per input text, in input order.
    """
    cached: List[Union[List[float], None]] = (
        await self.document_embedding_store.amget(texts)
    )

    # Pair each cache miss with its position so the freshly computed
    # vector can be placed back at the right index afterwards.
    pending = [(pos, texts[pos]) for pos, hit in enumerate(cached) if hit is None]

    if pending:
        pending_texts = [text for _, text in pending]
        fresh = await self.underlying_embeddings.aembed_documents(pending_texts)
        # Persist the new vectors before patching the result list.
        await self.document_embedding_store.amset(list(zip(pending_texts, fresh)))
        for (pos, _), embedding in zip(pending, fresh):
            cached[pos] = embedding

    # Every None is expected to have been replaced by this point.
    return cast(List[List[float]], cached)
def embed_query(self, text: str) -> List[float]:
"""Embed query text.
@ -148,6 +183,26 @@ class CacheBackedEmbeddings(Embeddings):
"""
return self.underlying_embeddings.embed_query(text)
async def aembed_query(self, text: str) -> List[float]:
    """Asynchronously embed a query text.

    Queries are not cached at the moment: the call is forwarded directly
    to the underlying embedder. Caching queries would be easy to
    implement, but it may make sense to hold off until the most common
    usage patterns are clear.

    If the cache has an eviction policy, sharing one cache between
    documents and queries needs some care. Generally, evicting cached
    query embeddings is fine, whereas cached document embeddings should
    be kept.

    Args:
        text: The text to embed.

    Returns:
        The embedding for the given text.
    """
    embedder = self.underlying_embeddings
    query_vector = await embedder.aembed_query(text)
    return query_vector
@classmethod
def from_bytes_store(
cls,

@ -47,3 +47,23 @@ def test_embed_query(cache_embeddings: CacheBackedEmbeddings) -> None:
vector = cache_embeddings.embed_query(text)
expected_vector = [5.0, 6.0]
assert vector == expected_vector
async def test_aembed_documents(cache_embeddings: CacheBackedEmbeddings) -> None:
    """Check the async document path returns the expected vectors and fills the store."""
    texts = ["1", "22", "a", "333"]
    expected_vectors: List[List[float]] = [[1, 2.0], [2.0, 3.0], [1.0, 2.0], [3.0, 4.0]]

    vectors = await cache_embeddings.aembed_documents(texts)
    assert vectors == expected_vectors

    # Gather every key now present in the backing store.
    store = cache_embeddings.document_embedding_store
    stored_keys = []
    async for stored_key in store.ayield_keys():
        stored_keys.append(stored_key)

    assert len(stored_keys) == 4
    # The UUID part of the key is expected to be the same for the same text.
    assert stored_keys[0] == "test_namespace812b86c1-8ebf-5483-95c6-c95cf2b52d12"
async def test_aembed_query(cache_embeddings: CacheBackedEmbeddings) -> None:
    """Check the async query path returns the expected vector for a query text."""
    expected_vector = [5.0, 6.0]
    result = await cache_embeddings.aembed_query("query_text")
    assert result == expected_vector

Loading…
Cancel
Save