Add base storage interface, 2 implementations and utility encoder (#8895)

This PR defines an abstract interface for key value stores.

It provides 2 implementations: 
1. Local File System
2. In memory -- used to facilitate testing

It also provides an encoder utility to help take care of serialization
from arbitrary data to data that can be stored by the given store
wfh/async_eval_default
Eugene Yurtsev 1 year ago committed by GitHub
parent 7543a3d70e
commit 15f650ae8c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,6 +1,7 @@
"""**Schemas** are the LangChain Base Classes and Interfaces.""" """**Schemas** are the LangChain Base Classes and Interfaces."""
from langchain.schema.agent import AgentAction, AgentFinish from langchain.schema.agent import AgentAction, AgentFinish
from langchain.schema.document import BaseDocumentTransformer, Document from langchain.schema.document import BaseDocumentTransformer, Document
from langchain.schema.exceptions import LangChainException
from langchain.schema.memory import BaseChatMessageHistory, BaseMemory from langchain.schema.memory import BaseChatMessageHistory, BaseMemory
from langchain.schema.messages import ( from langchain.schema.messages import (
AIMessage, AIMessage,
@ -31,12 +32,14 @@ from langchain.schema.output_parser import (
from langchain.schema.prompt import PromptValue from langchain.schema.prompt import PromptValue
from langchain.schema.prompt_template import BasePromptTemplate, format_document from langchain.schema.prompt_template import BasePromptTemplate, format_document
from langchain.schema.retriever import BaseRetriever from langchain.schema.retriever import BaseRetriever
from langchain.schema.storage import BaseStore
RUN_KEY = "__run" RUN_KEY = "__run"
Memory = BaseMemory Memory = BaseMemory
__all__ = [ __all__ = [
"BaseMemory", "BaseMemory",
"BaseStore",
"BaseChatMessageHistory", "BaseChatMessageHistory",
"AgentFinish", "AgentFinish",
"AgentAction", "AgentAction",
@ -59,6 +62,7 @@ __all__ = [
"ChatGeneration", "ChatGeneration",
"Generation", "Generation",
"PromptValue", "PromptValue",
"LangChainException",
"BaseRetriever", "BaseRetriever",
"RUN_KEY", "RUN_KEY",
"Memory", "Memory",

@ -0,0 +1,2 @@
class LangChainException(Exception):
"""General LangChain exception."""

@ -0,0 +1,53 @@
from abc import ABC, abstractmethod
from typing import Generic, Iterator, List, Optional, Sequence, Tuple, TypeVar, Union
K = TypeVar("K")
V = TypeVar("V")
class BaseStore(Generic[K, V], ABC):
"""Abstract interface for a key-value store."""
@abstractmethod
def mget(self, keys: Sequence[K]) -> List[Optional[V]]:
"""Get the values associated with the given keys.
Args:
keys (Sequence[K]): A sequence of keys.
Returns:
A sequence of optional values associated with the keys.
If a key is not found, the corresponding value will be None.
"""
@abstractmethod
def mset(self, key_value_pairs: Sequence[Tuple[K, V]]) -> None:
"""Set the values for the given keys.
Args:
key_value_pairs (Sequence[Tuple[K, V]]): A sequence of key-value pairs.
"""
@abstractmethod
def mdelete(self, keys: Sequence[K]) -> None:
"""Delete the given keys and their associated values.
Args:
keys (Sequence[K]): A sequence of keys to delete.
"""
@abstractmethod
def yield_keys(
self, *, prefix: Optional[str] = None
) -> Union[Iterator[K], Iterator[str]]:
"""Get an iterator over keys that match the given prefix.
Args:
prefix (str): The prefix to match.
Returns:
Iterator[K | str]: An iterator over keys that match the given prefix.
This method is allowed to return an iterator over either K or str
depending on what makes more sense for the given store.
"""

@ -0,0 +1,17 @@
"""Implementations of key-value stores and storage helpers.
Module provides implementations of various key-value stores that conform
to a simple key-value interface.
The primary goal of these storages is to support implementation of caching.
"""
from langchain.storage.encoder_backed import EncoderBackedStore
from langchain.storage.file_system import LocalFileStore
from langchain.storage.in_memory import InMemoryStore
__all__ = [
"EncoderBackedStore",
"LocalFileStore",
"InMemoryStore",
]

@ -0,0 +1,95 @@
from typing import (
Any,
Callable,
Iterator,
List,
Optional,
Sequence,
Tuple,
TypeVar,
Union,
)
from langchain.schema import BaseStore
K = TypeVar("K")
V = TypeVar("V")
class EncoderBackedStore(BaseStore[K, V]):
"""Wraps a store with key and value encoders/decoders.
Examples that uses JSON for encoding/decoding:
.. code-block:: python
import json
def key_encoder(key: int) -> str:
return json.dumps(key)
def value_serializer(value: float) -> str:
return json.dumps(value)
def value_deserializer(serialized_value: str) -> float:
return json.loads(serialized_value)
# Create an instance of the abstract store
abstract_store = MyCustomStore()
# Create an instance of the encoder-backed store
store = EncoderBackedStore(
store=abstract_store,
key_encoder=key_encoder,
value_serializer=value_serializer,
value_deserializer=value_deserializer
)
# Use the encoder-backed store methods
store.mset([(1, 3.14), (2, 2.718)])
values = store.mget([1, 2]) # Retrieves [3.14, 2.718]
store.mdelete([1, 2]) # Deletes the keys 1 and 2
"""
def __init__(
self,
store: BaseStore[str, Any],
key_encoder: Callable[[K], str],
value_serializer: Callable[[V], bytes],
value_deserializer: Callable[[Any], V],
) -> None:
"""Initialize an EncodedStore."""
self.store = store
self.key_encoder = key_encoder
self.value_serializer = value_serializer
self.value_deserializer = value_deserializer
def mget(self, keys: Sequence[K]) -> List[Optional[V]]:
"""Get the values associated with the given keys."""
encoded_keys: List[str] = [self.key_encoder(key) for key in keys]
values = self.store.mget(encoded_keys)
return [
self.value_deserializer(value) if value is not None else value
for value in values
]
def mset(self, key_value_pairs: Sequence[Tuple[K, V]]) -> None:
"""Set the values for the given keys."""
encoded_pairs = [
(self.key_encoder(key), self.value_serializer(value))
for key, value in key_value_pairs
]
self.store.mset(encoded_pairs)
def mdelete(self, keys: Sequence[K]) -> None:
"""Delete the given keys and their associated values."""
encoded_keys = [self.key_encoder(key) for key in keys]
self.store.mdelete(encoded_keys)
def yield_keys(
self, *, prefix: Optional[str] = None
) -> Union[Iterator[K], Iterator[str]]:
"""Get an iterator over keys that match the given prefix."""
# For the time being this does not return K, but str
# it's for debugging purposes. Should fix this.
yield from self.store.yield_keys(prefix=prefix)

@ -0,0 +1,5 @@
from langchain.schema import LangChainException
class InvalidKeyException(LangChainException):
"""Raised when a key is invalid; e.g., uses incorrect characters."""

@ -0,0 +1,120 @@
import re
from pathlib import Path
from typing import Iterator, List, Optional, Sequence, Tuple, Union
from langchain.schema import BaseStore
from langchain.storage.exceptions import InvalidKeyException
class LocalFileStore(BaseStore[str, bytes]):
"""BaseStore interface that works on the local file system.
Examples:
Create a LocalFileStore instance and perform operations on it:
.. code-block:: python
from langchain.storage import LocalFileStore
# Instantiate the LocalFileStore with the root path
file_store = LocalFileStore("/path/to/root")
# Set values for keys
file_store.mset([("key1", b"value1"), ("key2", b"value2")])
# Get values for keys
values = file_store.mget(["key1", "key2"]) # Returns [b"value1", b"value2"]
# Delete keys
file_store.mdelete(["key1"])
# Iterate over keys
for key in file_store.yield_keys():
print(key)
"""
def __init__(self, root_path: Union[str, Path]) -> None:
"""Implement the BaseStore interface for the local file system.
Args:
root_path (Union[str, Path]): The root path of the file store. All keys are
interpreted as paths relative to this root.
"""
self.root_path = Path(root_path)
def _get_full_path(self, key: str) -> Path:
"""Get the full path for a given key relative to the root path.
Args:
key (str): The key relative to the root path.
Returns:
Path: The full path for the given key.
"""
if not re.match(r"^[a-zA-Z0-9_.\-/]+$", key):
raise InvalidKeyException(f"Invalid characters in key: {key}")
return self.root_path / key
def mget(self, keys: Sequence[str]) -> List[Optional[bytes]]:
"""Get the values associated with the given keys.
Args:
keys: A sequence of keys.
Returns:
A sequence of optional values associated with the keys.
If a key is not found, the corresponding value will be None.
"""
values: List[Optional[bytes]] = []
for key in keys:
full_path = self._get_full_path(key)
if full_path.exists():
value = full_path.read_bytes()
values.append(value)
else:
values.append(None)
return values
def mset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None:
"""Set the values for the given keys.
Args:
key_value_pairs: A sequence of key-value pairs.
Returns:
None
"""
for key, value in key_value_pairs:
full_path = self._get_full_path(key)
full_path.parent.mkdir(parents=True, exist_ok=True)
full_path.write_bytes(value)
def mdelete(self, keys: Sequence[str]) -> None:
"""Delete the given keys and their associated values.
Args:
keys (Sequence[str]): A sequence of keys to delete.
Returns:
None
"""
for key in keys:
full_path = self._get_full_path(key)
if full_path.exists():
full_path.unlink()
def yield_keys(self, prefix: Optional[str] = None) -> Iterator[str]:
"""Get an iterator over keys that match the given prefix.
Args:
prefix (Optional[str]): The prefix to match.
Returns:
Iterator[str]: An iterator over keys that match the given prefix.
"""
prefix_path = self._get_full_path(prefix) if prefix else self.root_path
for file in prefix_path.rglob("*"):
if file.is_file():
relative_path = file.relative_to(self.root_path)
yield str(relative_path)

@ -0,0 +1,85 @@
"""In memory store that is not thread safe and has no eviction policy.
This is a simple implementation of the BaseStore using a dictionary that is useful
primarily for unit testing purposes.
"""
from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple
from langchain.schema import BaseStore
class InMemoryStore(BaseStore[str, Any]):
"""In-memory implementation of the BaseStore using a dictionary.
Attributes:
store (Dict[str, Any]): The underlying dictionary that stores
the key-value pairs.
Examples:
... code-block:: python
from langchain.storage import InMemoryStore
store = InMemoryStore()
store.mset([('key1', 'value1'), ('key2', 'value2')])
store.mget(['key1', 'key2'])
# ['value1', 'value2']
store.mdelete(['key1'])
list(store.yield_keys())
# ['key2']
list(store.yield_keys(prefix='k'))
# ['key2']
"""
def __init__(self) -> None:
"""Initialize an empty store."""
self.store: Dict[str, Any] = {}
def mget(self, keys: Sequence[str]) -> List[Optional[Any]]:
"""Get the values associated with the given keys.
Args:
keys (Sequence[str]): A sequence of keys.
Returns:
A sequence of optional values associated with the keys.
If a key is not found, the corresponding value will be None.
"""
return [self.store.get(key) for key in keys]
def mset(self, key_value_pairs: Sequence[Tuple[str, Any]]) -> None:
"""Set the values for the given keys.
Args:
key_value_pairs (Sequence[Tuple[str, V]]): A sequence of key-value pairs.
Returns:
None
"""
for key, value in key_value_pairs:
self.store[key] = value
def mdelete(self, keys: Sequence[str]) -> None:
"""Delete the given keys and their associated values.
Args:
keys (Sequence[str]): A sequence of keys to delete.
"""
for key in keys:
self.store.pop(key, None)
def yield_keys(self, prefix: Optional[str] = None) -> Iterator[str]:
"""Get an iterator over keys that match the given prefix.
Args:
prefix (str, optional): The prefix to match. Defaults to None.
Returns:
Iterator[str]: An iterator over keys that match the given prefix.
"""
if prefix is None:
yield from self.store.keys()
else:
for key in self.store.keys():
if key.startswith(prefix):
yield key

@ -0,0 +1,78 @@
import tempfile
from typing import Generator
import pytest
from langchain.storage.exceptions import InvalidKeyException
from langchain.storage.file_system import LocalFileStore
@pytest.fixture
def file_store() -> Generator[LocalFileStore, None, None]:
# Create a temporary directory for testing
with tempfile.TemporaryDirectory() as temp_dir:
# Instantiate the LocalFileStore with the temporary directory as the root path
store = LocalFileStore(temp_dir)
yield store
def test_mset_and_mget(file_store: LocalFileStore) -> None:
# Set values for keys
key_value_pairs = [("key1", b"value1"), ("key2", b"value2")]
file_store.mset(key_value_pairs)
# Get values for keys
values = file_store.mget(["key1", "key2"])
# Assert that the retrieved values match the original values
assert values == [b"value1", b"value2"]
def test_mdelete(file_store: LocalFileStore) -> None:
# Set values for keys
key_value_pairs = [("key1", b"value1"), ("key2", b"value2")]
file_store.mset(key_value_pairs)
# Delete keys
file_store.mdelete(["key1"])
# Check if the deleted key is present
values = file_store.mget(["key1"])
# Assert that the value is None after deletion
assert values == [None]
def test_set_invalid_key(file_store: LocalFileStore) -> None:
"""Test that an exception is raised when an invalid key is set."""
# Set a key-value pair
key = "crying-cat/😿"
value = b"This is a test value"
with pytest.raises(InvalidKeyException):
file_store.mset([(key, value)])
def test_set_key_and_verify_content(file_store: LocalFileStore) -> None:
"""Test that the content of the file is the same as the value set."""
# Set a key-value pair
key = "test_key"
value = b"This is a test value"
file_store.mset([(key, value)])
# Verify the content of the actual file
full_path = file_store._get_full_path(key)
assert full_path.exists()
assert full_path.read_bytes() == b"This is a test value"
def test_yield_keys(file_store: LocalFileStore) -> None:
# Set values for keys
key_value_pairs = [("key1", b"value1"), ("subdir/key2", b"value2")]
file_store.mset(key_value_pairs)
# Iterate over keys
keys = list(file_store.yield_keys())
# Assert that the yielded keys match the expected keys
expected_keys = ["key1", "subdir/key2"]
assert keys == expected_keys

@ -0,0 +1,48 @@
from langchain.storage.in_memory import InMemoryStore
def test_mget() -> None:
store = InMemoryStore()
store.mset([("key1", "value1"), ("key2", "value2")])
values = store.mget(["key1", "key2"])
assert values == ["value1", "value2"]
# Test non-existent key
non_existent_value = store.mget(["key3"])
assert non_existent_value == [None]
def test_mset() -> None:
store = InMemoryStore()
store.mset([("key1", "value1"), ("key2", "value2")])
values = store.mget(["key1", "key2"])
assert values == ["value1", "value2"]
def test_mdelete() -> None:
store = InMemoryStore()
store.mset([("key1", "value1"), ("key2", "value2")])
store.mdelete(["key1"])
values = store.mget(["key1", "key2"])
assert values == [None, "value2"]
# Test deleting non-existent key
store.mdelete(["key3"]) # No error should be raised
def test_yield_keys() -> None:
store = InMemoryStore()
store.mset([("key1", "value1"), ("key2", "value2"), ("key3", "value3")])
keys = list(store.yield_keys())
assert set(keys) == {"key1", "key2", "key3"}
keys_with_prefix = list(store.yield_keys(prefix="key"))
assert set(keys_with_prefix) == {"key1", "key2", "key3"}
keys_with_invalid_prefix = list(store.yield_keys(prefix="x"))
assert keys_with_invalid_prefix == []
Loading…
Cancel
Save