Add async methods to BaseStore (#16669)

- **Description:**

The BaseStore methods are currently blocking. Some implementations
(AstraDBStore, RedisStore) would benefit from having async methods.
Also once we have async methods for BaseStore, we can implement the
async `aembed_documents` in CacheBackedEmbeddings to cache the
embeddings asynchronously.

* adds async methods amget, amset, amedelete and ayield_keys to
BaseStore
  * implements the async methods for InMemoryStore
  * adds tests for InMemoryStore async methods

- **Twitter handle:** cbornet_
pull/12972/head^2
Christophe Bornet 5 months ago committed by GitHub
parent 17e886388b
commit a0ec045495
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -1,5 +1,17 @@
from abc import ABC, abstractmethod
from typing import Generic, Iterator, List, Optional, Sequence, Tuple, TypeVar, Union
from typing import (
AsyncIterator,
Generic,
Iterator,
List,
Optional,
Sequence,
Tuple,
TypeVar,
Union,
)
from langchain_core.runnables import run_in_executor
K = TypeVar("K")
V = TypeVar("V")
@ -20,6 +32,18 @@ class BaseStore(Generic[K, V], ABC):
If a key is not found, the corresponding value will be None.
"""
async def amget(self, keys: Sequence[K]) -> List[Optional[V]]:
"""Get the values associated with the given keys.
Args:
keys (Sequence[K]): A sequence of keys.
Returns:
A sequence of optional values associated with the keys.
If a key is not found, the corresponding value will be None.
"""
return await run_in_executor(None, self.mget, keys)
@abstractmethod
def mset(self, key_value_pairs: Sequence[Tuple[K, V]]) -> None:
"""Set the values for the given keys.
@ -28,6 +52,14 @@ class BaseStore(Generic[K, V], ABC):
key_value_pairs (Sequence[Tuple[K, V]]): A sequence of key-value pairs.
"""
async def amset(self, key_value_pairs: Sequence[Tuple[K, V]]) -> None:
"""Set the values for the given keys.
Args:
key_value_pairs (Sequence[Tuple[K, V]]): A sequence of key-value pairs.
"""
return await run_in_executor(None, self.mset, key_value_pairs)
@abstractmethod
def mdelete(self, keys: Sequence[K]) -> None:
"""Delete the given keys and their associated values.
@ -36,6 +68,14 @@ class BaseStore(Generic[K, V], ABC):
keys (Sequence[K]): A sequence of keys to delete.
"""
async def amdelete(self, keys: Sequence[K]) -> None:
"""Delete the given keys and their associated values.
Args:
keys (Sequence[K]): A sequence of keys to delete.
"""
return await run_in_executor(None, self.mdelete, keys)
@abstractmethod
def yield_keys(
self, *, prefix: Optional[str] = None
@ -52,5 +92,27 @@ class BaseStore(Generic[K, V], ABC):
depending on what makes more sense for the given store.
"""
async def ayield_keys(
self, *, prefix: Optional[str] = None
) -> Union[AsyncIterator[K], AsyncIterator[str]]:
"""Get an iterator over keys that match the given prefix.
Args:
prefix (str): The prefix to match.
Returns:
Iterator[K | str]: An iterator over keys that match the given prefix.
This method is allowed to return an iterator over either K or str
depending on what makes more sense for the given store.
"""
iterator = await run_in_executor(None, self.yield_keys, prefix=prefix)
done = object()
while True:
item = await run_in_executor(None, lambda it: next(it, done), iterator)
if item is done:
break
yield item
ByteStore = BaseStore[str, bytes]

@ -5,6 +5,7 @@ primarily for unit testing purposes.
"""
from typing import (
Any,
AsyncIterator,
Dict,
Generic,
Iterator,
@ -60,6 +61,18 @@ class InMemoryBaseStore(BaseStore[str, V], Generic[V]):
"""
return [self.store.get(key) for key in keys]
async def amget(self, keys: Sequence[str]) -> List[Optional[V]]:
"""Get the values associated with the given keys.
Args:
keys (Sequence[str]): A sequence of keys.
Returns:
A sequence of optional values associated with the keys.
If a key is not found, the corresponding value will be None.
"""
return self.mget(keys)
def mset(self, key_value_pairs: Sequence[Tuple[str, V]]) -> None:
"""Set the values for the given keys.
@ -72,6 +85,17 @@ class InMemoryBaseStore(BaseStore[str, V], Generic[V]):
for key, value in key_value_pairs:
self.store[key] = value
async def amset(self, key_value_pairs: Sequence[Tuple[str, V]]) -> None:
"""Set the values for the given keys.
Args:
key_value_pairs (Sequence[Tuple[str, V]]): A sequence of key-value pairs.
Returns:
None
"""
return self.mset(key_value_pairs)
def mdelete(self, keys: Sequence[str]) -> None:
"""Delete the given keys and their associated values.
@ -82,6 +106,14 @@ class InMemoryBaseStore(BaseStore[str, V], Generic[V]):
if key in self.store:
del self.store[key]
async def amdelete(self, keys: Sequence[str]) -> None:
"""Delete the given keys and their associated values.
Args:
keys (Sequence[str]): A sequence of keys to delete.
"""
self.mdelete(keys)
def yield_keys(self, prefix: Optional[str] = None) -> Iterator[str]:
"""Get an iterator over keys that match the given prefix.
@ -98,6 +130,23 @@ class InMemoryBaseStore(BaseStore[str, V], Generic[V]):
if key.startswith(prefix):
yield key
async def ayield_keys(self, prefix: Optional[str] = None) -> AsyncIterator[str]:
"""Get an async iterator over keys that match the given prefix.
Args:
prefix (str, optional): The prefix to match. Defaults to None.
Returns:
AsyncIterator[str]: An async iterator over keys that match the given prefix.
"""
if prefix is None:
for key in self.store.keys():
yield key
else:
for key in self.store.keys():
if key.startswith(prefix):
yield key
InMemoryStore = InMemoryBaseStore[Any]
InMemoryByteStore = InMemoryBaseStore[bytes]

@ -13,6 +13,18 @@ def test_mget() -> None:
assert non_existent_value == [None]
async def test_amget() -> None:
store = InMemoryStore()
await store.amset([("key1", "value1"), ("key2", "value2")])
values = await store.amget(["key1", "key2"])
assert values == ["value1", "value2"]
# Test non-existent key
non_existent_value = await store.amget(["key3"])
assert non_existent_value == [None]
def test_mset() -> None:
store = InMemoryStore()
store.mset([("key1", "value1"), ("key2", "value2")])
@ -21,6 +33,14 @@ def test_mset() -> None:
assert values == ["value1", "value2"]
async def test_amset() -> None:
store = InMemoryStore()
await store.amset([("key1", "value1"), ("key2", "value2")])
values = await store.amget(["key1", "key2"])
assert values == ["value1", "value2"]
def test_mdelete() -> None:
store = InMemoryStore()
store.mset([("key1", "value1"), ("key2", "value2")])
@ -34,6 +54,19 @@ def test_mdelete() -> None:
store.mdelete(["key3"]) # No error should be raised
async def test_amdelete() -> None:
store = InMemoryStore()
await store.amset([("key1", "value1"), ("key2", "value2")])
await store.amdelete(["key1"])
values = await store.amget(["key1", "key2"])
assert values == [None, "value2"]
# Test deleting non-existent key
await store.amdelete(["key3"]) # No error should be raised
def test_yield_keys() -> None:
store = InMemoryStore()
store.mset([("key1", "value1"), ("key2", "value2"), ("key3", "value3")])
@ -46,3 +79,17 @@ def test_yield_keys() -> None:
keys_with_invalid_prefix = list(store.yield_keys(prefix="x"))
assert keys_with_invalid_prefix == []
async def test_ayield_keys() -> None:
store = InMemoryStore()
await store.amset([("key1", "value1"), ("key2", "value2"), ("key3", "value3")])
keys = [key async for key in store.ayield_keys()]
assert set(keys) == {"key1", "key2", "key3"}
keys_with_prefix = [key async for key in store.ayield_keys(prefix="key")]
assert set(keys_with_prefix) == {"key1", "key2", "key3"}
keys_with_invalid_prefix = [key async for key in store.ayield_keys(prefix="x")]
assert keys_with_invalid_prefix == []

Loading…
Cancel
Save