core[minor]: Adds an in-memory implementation of RecordManager (#13200)

**Description:**
langchain offers three technologies to save data:
- [vectorstore](https://python.langchain.com/docs/modules/data_connection/vectorstores/)
- [docstore](https://js.langchain.com/docs/api/schema/classes/Docstore)
- [record manager](https://python.langchain.com/docs/modules/data_connection/indexing)

If you want to combine these technologies in a sample persistence
strategy, you need a common implementation for each. `DocStore` provides
`InMemoryDocstore`.

We propose the class `InMemoryRecordManager` to complete the system.

This is the prelude to another pull request, which needs a consistent
combination of persistence components.

**Tag maintainer:**
@baskaryan

**Twitter handle:**
@pprados

---------

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
pull/23226/head
Philippe PRADOS 3 months ago committed by GitHub
parent 3ab49c0036
commit 8711c61298
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -5,11 +5,12 @@ a vectorstore while avoiding duplicated content and over-writing content
if it's unchanged. if it's unchanged.
""" """
from langchain_core.indexing.api import IndexingResult, aindex, index from langchain_core.indexing.api import IndexingResult, aindex, index
from langchain_core.indexing.base import RecordManager from langchain_core.indexing.base import InMemoryRecordManager, RecordManager
__all__ = [ __all__ = [
"aindex", "aindex",
"index", "index",
"IndexingResult", "IndexingResult",
"InMemoryRecordManager",
"RecordManager", "RecordManager",
] ]

@ -1,7 +1,8 @@
from __future__ import annotations from __future__ import annotations
import time
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import List, Optional, Sequence from typing import Dict, List, Optional, Sequence, TypedDict
class RecordManager(ABC): class RecordManager(ABC):
@ -215,3 +216,104 @@ class RecordManager(ABC):
Args: Args:
keys: A list of keys to delete. keys: A list of keys to delete.
""" """
class _Record(TypedDict):
    """Bookkeeping stored for each key tracked by ``InMemoryRecordManager``."""

    # Group the key was registered under; None when no group was supplied.
    group_id: Optional[str]
    # Value of ``get_time()`` at the moment the key was last updated.
    updated_at: float


class InMemoryRecordManager(RecordManager):
    """An in-memory record manager for testing purposes."""

    def __init__(self, namespace: str) -> None:
        """Create an empty record manager.

        Args:
            namespace: Namespace passed to the base class and kept on the instance.
        """
        super().__init__(namespace)
        # Each key points to a dictionary
        # of {'group_id': group_id, 'updated_at': timestamp}
        self.records: Dict[str, _Record] = {}
        self.namespace = namespace

    def create_schema(self) -> None:
        """In-memory schema creation is simply ensuring the structure is initialized."""

    async def acreate_schema(self) -> None:
        """In-memory schema creation is simply ensuring the structure is initialized."""

    def get_time(self) -> float:
        """Get the current server time as a high resolution timestamp!"""
        return time.time()

    async def aget_time(self) -> float:
        """Get the current server time as a high resolution timestamp!"""
        return self.get_time()

    def update(
        self,
        keys: Sequence[str],
        *,
        group_ids: Optional[Sequence[Optional[str]]] = None,
        time_at_least: Optional[float] = None,
    ) -> None:
        """Insert the given keys, or refresh them if they are already tracked.

        Args:
            keys: Keys to upsert.
            group_ids: Optional group id for each key, aligned with ``keys``.
                When omitted, every key is stored with a group id of ``None``.
            time_at_least: Optional lower bound; the update is rejected when
                this value is later than the manager's current time.

        Raises:
            ValueError: If ``group_ids`` is given and its length differs from
                ``keys``, or if ``time_at_least`` is later than the current time.
        """
        if group_ids is None:
            group_ids = [None] * len(keys)
        # Compare lengths even for an empty ``group_ids`` so a mismatched call
        # is reported instead of silently storing ``None`` groups.
        if len(keys) != len(group_ids):
            raise ValueError("Length of keys must match length of group_ids")
        if not keys:
            # Nothing to write; keep the call a no-op.
            return

        # Read the clock once so the check is loop-invariant and every key in
        # the batch receives the same timestamp.
        now = self.get_time()
        if time_at_least is not None and time_at_least > now:
            raise ValueError("time_at_least must be in the past")

        for key, group_id in zip(keys, group_ids):
            self.records[key] = {"group_id": group_id, "updated_at": now}

    async def aupdate(
        self,
        keys: Sequence[str],
        *,
        group_ids: Optional[Sequence[Optional[str]]] = None,
        time_at_least: Optional[float] = None,
    ) -> None:
        """Async counterpart of ``update``; delegates to the sync method."""
        self.update(keys, group_ids=group_ids, time_at_least=time_at_least)

    def exists(self, keys: Sequence[str]) -> List[bool]:
        """Return, for each key in order, whether it is currently tracked."""
        return [key in self.records for key in keys]

    async def aexists(self, keys: Sequence[str]) -> List[bool]:
        """Async counterpart of ``exists``; delegates to the sync method."""
        return self.exists(keys)

    def list_keys(
        self,
        *,
        before: Optional[float] = None,
        after: Optional[float] = None,
        group_ids: Optional[Sequence[str]] = None,
        limit: Optional[int] = None,
    ) -> List[str]:
        """List tracked keys that satisfy every supplied filter.

        A filter that is ``None`` or otherwise falsy (``0``, ``0.0``, an empty
        sequence) is not applied.

        Args:
            before: Keep only keys updated strictly before this timestamp.
            after: Keep only keys updated strictly after this timestamp.
            group_ids: Keep only keys whose group id is in this collection.
            limit: Slice bound applied to the matches (``matches[:limit]``).

        Returns:
            Matching keys in insertion order of the underlying dictionary.
        """
        matches = [
            key
            for key, record in self.records.items()
            if not (before and record["updated_at"] >= before)
            and not (after and record["updated_at"] <= after)
            and not (group_ids and record["group_id"] not in group_ids)
        ]
        return matches[:limit] if limit else matches

    async def alist_keys(
        self,
        *,
        before: Optional[float] = None,
        after: Optional[float] = None,
        group_ids: Optional[Sequence[str]] = None,
        limit: Optional[int] = None,
    ) -> List[str]:
        """Async counterpart of ``list_keys``; delegates to the sync method."""
        return self.list_keys(
            before=before, after=after, group_ids=group_ids, limit=limit
        )

    def delete_keys(self, keys: Sequence[str]) -> None:
        """Stop tracking the given keys; unknown keys are ignored."""
        for key in keys:
            # pop with a default removes the key in one lookup and tolerates
            # keys that were never recorded.
            self.records.pop(key, None)

    async def adelete_keys(self, keys: Sequence[str]) -> None:
        """Async counterpart of ``delete_keys``; delegates to the sync method."""
        self.delete_keys(keys)

@ -1,105 +0,0 @@
import time
from typing import Dict, List, Optional, Sequence, TypedDict
from langchain_core.indexing.base import RecordManager
class _Record(TypedDict):
    """Bookkeeping stored for each key tracked by ``InMemoryRecordManager``."""

    # Group the key was registered under; None when no group was supplied.
    group_id: Optional[str]
    # Value of ``get_time()`` at the moment the key was last updated.
    updated_at: float


class InMemoryRecordManager(RecordManager):
    """An in-memory record manager for testing purposes."""

    def __init__(self, namespace: str) -> None:
        """Create an empty record manager.

        Args:
            namespace: Namespace passed to the base class and kept on the instance.
        """
        super().__init__(namespace)
        # Each key points to a dictionary
        # of {'group_id': group_id, 'updated_at': timestamp}
        self.records: Dict[str, _Record] = {}
        self.namespace = namespace

    def create_schema(self) -> None:
        """In-memory schema creation is simply ensuring the structure is initialized."""

    async def acreate_schema(self) -> None:
        """In-memory schema creation is simply ensuring the structure is initialized."""

    def get_time(self) -> float:
        """Get the current server time as a high resolution timestamp!"""
        return time.time()

    async def aget_time(self) -> float:
        """Get the current server time as a high resolution timestamp!"""
        return self.get_time()

    def update(
        self,
        keys: Sequence[str],
        *,
        group_ids: Optional[Sequence[Optional[str]]] = None,
        time_at_least: Optional[float] = None,
    ) -> None:
        """Insert the given keys, or refresh them if they are already tracked.

        Args:
            keys: Keys to upsert.
            group_ids: Optional group id for each key, aligned with ``keys``.
                When omitted, every key is stored with a group id of ``None``.
            time_at_least: Optional lower bound; the update is rejected when
                this value is later than the manager's current time.

        Raises:
            ValueError: If ``group_ids`` is given and its length differs from
                ``keys``, or if ``time_at_least`` is later than the current time.
        """
        if group_ids is None:
            group_ids = [None] * len(keys)
        # Compare lengths even for an empty ``group_ids`` so a mismatched call
        # is reported instead of silently storing ``None`` groups.
        if len(keys) != len(group_ids):
            raise ValueError("Length of keys must match length of group_ids")
        if not keys:
            # Nothing to write; keep the call a no-op.
            return

        # Read the clock once so the check is loop-invariant and every key in
        # the batch receives the same timestamp.
        now = self.get_time()
        if time_at_least is not None and time_at_least > now:
            raise ValueError("time_at_least must be in the past")

        for key, group_id in zip(keys, group_ids):
            self.records[key] = {"group_id": group_id, "updated_at": now}

    async def aupdate(
        self,
        keys: Sequence[str],
        *,
        group_ids: Optional[Sequence[Optional[str]]] = None,
        time_at_least: Optional[float] = None,
    ) -> None:
        """Async counterpart of ``update``; delegates to the sync method."""
        self.update(keys, group_ids=group_ids, time_at_least=time_at_least)

    def exists(self, keys: Sequence[str]) -> List[bool]:
        """Return, for each key in order, whether it is currently tracked."""
        return [key in self.records for key in keys]

    async def aexists(self, keys: Sequence[str]) -> List[bool]:
        """Async counterpart of ``exists``; delegates to the sync method."""
        return self.exists(keys)

    def list_keys(
        self,
        *,
        before: Optional[float] = None,
        after: Optional[float] = None,
        group_ids: Optional[Sequence[str]] = None,
        limit: Optional[int] = None,
    ) -> List[str]:
        """List tracked keys that satisfy every supplied filter.

        A filter that is ``None`` or otherwise falsy (``0``, ``0.0``, an empty
        sequence) is not applied.

        Args:
            before: Keep only keys updated strictly before this timestamp.
            after: Keep only keys updated strictly after this timestamp.
            group_ids: Keep only keys whose group id is in this collection.
            limit: Slice bound applied to the matches (``matches[:limit]``).

        Returns:
            Matching keys in insertion order of the underlying dictionary.
        """
        matches = [
            key
            for key, record in self.records.items()
            if not (before and record["updated_at"] >= before)
            and not (after and record["updated_at"] <= after)
            and not (group_ids and record["group_id"] not in group_ids)
        ]
        return matches[:limit] if limit else matches

    async def alist_keys(
        self,
        *,
        before: Optional[float] = None,
        after: Optional[float] = None,
        group_ids: Optional[Sequence[str]] = None,
        limit: Optional[int] = None,
    ) -> List[str]:
        """Async counterpart of ``list_keys``; delegates to the sync method."""
        return self.list_keys(
            before=before, after=after, group_ids=group_ids, limit=limit
        )

    def delete_keys(self, keys: Sequence[str]) -> None:
        """Stop tracking the given keys; unknown keys are ignored."""
        for key in keys:
            # pop with a default removes the key in one lookup and tolerates
            # keys that were never recorded.
            self.records.pop(key, None)

    async def adelete_keys(self, keys: Sequence[str]) -> None:
        """Async counterpart of ``delete_keys``; delegates to the sync method."""
        self.delete_keys(keys)

@ -4,7 +4,7 @@ from unittest.mock import patch
import pytest import pytest
import pytest_asyncio import pytest_asyncio
from tests.unit_tests.indexing.in_memory import InMemoryRecordManager from langchain_core.indexing import InMemoryRecordManager
@pytest.fixture() @pytest.fixture()

@ -18,10 +18,9 @@ import pytest_asyncio
from langchain_core.document_loaders.base import BaseLoader from langchain_core.document_loaders.base import BaseLoader
from langchain_core.documents import Document from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings from langchain_core.embeddings import Embeddings
from langchain_core.indexing import aindex, index from langchain_core.indexing import InMemoryRecordManager, aindex, index
from langchain_core.indexing.api import _abatch, _HashedDocument from langchain_core.indexing.api import _abatch, _HashedDocument
from langchain_core.vectorstores import VST, VectorStore from langchain_core.vectorstores import VST, VectorStore
from tests.unit_tests.indexing.in_memory import InMemoryRecordManager
class ToyLoader(BaseLoader): class ToyLoader(BaseLoader):

@ -8,5 +8,6 @@ def test_all() -> None:
"aindex", "aindex",
"index", "index",
"IndexingResult", "IndexingResult",
"InMemoryRecordManager",
"RecordManager", "RecordManager",
] ]

Loading…
Cancel
Save