diff --git a/libs/core/langchain_core/indexing/__init__.py b/libs/core/langchain_core/indexing/__init__.py index d67ec5eec0..bc581297d7 100644 --- a/libs/core/langchain_core/indexing/__init__.py +++ b/libs/core/langchain_core/indexing/__init__.py @@ -5,11 +5,12 @@ a vectorstore while avoiding duplicated content and over-writing content if it's unchanged. """ from langchain_core.indexing.api import IndexingResult, aindex, index -from langchain_core.indexing.base import RecordManager +from langchain_core.indexing.base import InMemoryRecordManager, RecordManager __all__ = [ "aindex", "index", "IndexingResult", + "InMemoryRecordManager", "RecordManager", ] diff --git a/libs/core/langchain_core/indexing/base.py b/libs/core/langchain_core/indexing/base.py index ac73191bbe..aae9a32824 100644 --- a/libs/core/langchain_core/indexing/base.py +++ b/libs/core/langchain_core/indexing/base.py @@ -1,7 +1,8 @@ from __future__ import annotations +import time from abc import ABC, abstractmethod -from typing import List, Optional, Sequence +from typing import Dict, List, Optional, Sequence, TypedDict class RecordManager(ABC): @@ -215,3 +216,104 @@ class RecordManager(ABC): Args: keys: A list of keys to delete. 
class _Record(TypedDict):
    """Metadata kept for every key tracked by ``InMemoryRecordManager``."""

    # Group the key was registered under; ``None`` when no group was supplied.
    group_id: Optional[str]
    # Value of ``get_time()`` at the moment the key was last written.
    updated_at: float


class InMemoryRecordManager(RecordManager):
    """An in-memory record manager for testing purposes.

    Records live in a plain dictionary on the instance, so they are lost as
    soon as the instance is discarded.
    """

    def __init__(self, namespace: str) -> None:
        """Create an empty record manager.

        Args:
            namespace: Namespace forwarded to the base class and stored on
                the instance.
        """
        super().__init__(namespace)
        # Each key points to a dictionary
        # of {'group_id': group_id, 'updated_at': timestamp}
        self.records: Dict[str, _Record] = {}
        self.namespace = namespace

    def create_schema(self) -> None:
        """In-memory schema creation is simply ensuring the structure is initialized."""

    async def acreate_schema(self) -> None:
        """In-memory schema creation is simply ensuring the structure is initialized."""

    def get_time(self) -> float:
        """Get the current server time as a high resolution timestamp!"""
        return time.time()

    async def aget_time(self) -> float:
        """Get the current server time as a high resolution timestamp!"""
        return self.get_time()

    def update(
        self,
        keys: Sequence[str],
        *,
        group_ids: Optional[Sequence[Optional[str]]] = None,
        time_at_least: Optional[float] = None,
    ) -> None:
        """Insert the given keys, or refresh their timestamp if already known.

        Args:
            keys: Keys to upsert.
            group_ids: Optional group ids, positionally matched with ``keys``.
                When omitted (or empty) every key is stored with group
                ``None``.
            time_at_least: Optional lower bound the current time must have
                reached; guards against a clock that runs behind the caller.

        Raises:
            ValueError: If ``group_ids`` is given with a different length than
                ``keys``, or if ``time_at_least`` lies after the current time.
        """
        if group_ids and len(keys) != len(group_ids):
            raise ValueError("Length of keys must match length of group_ids")

        # Read the clock once: the check below does not depend on the key, and
        # every key written by one call should carry the same timestamp.
        now = self.get_time()
        if time_at_least is not None and time_at_least > now:
            raise ValueError("time_at_least must be in the past")

        for position, key in enumerate(keys):
            group_id = group_ids[position] if group_ids else None
            self.records[key] = {"group_id": group_id, "updated_at": now}

    async def aupdate(
        self,
        keys: Sequence[str],
        *,
        group_ids: Optional[Sequence[Optional[str]]] = None,
        time_at_least: Optional[float] = None,
    ) -> None:
        """Async counterpart of ``update``; delegates to the sync method."""
        self.update(keys, group_ids=group_ids, time_at_least=time_at_least)

    def exists(self, keys: Sequence[str]) -> List[bool]:
        """Return, for each key in order, whether a record is stored for it."""
        return [key in self.records for key in keys]

    async def aexists(self, keys: Sequence[str]) -> List[bool]:
        """Async counterpart of ``exists``; delegates to the sync method."""
        return self.exists(keys)

    def list_keys(
        self,
        *,
        before: Optional[float] = None,
        after: Optional[float] = None,
        group_ids: Optional[Sequence[str]] = None,
        limit: Optional[int] = None,
    ) -> List[str]:
        """List stored keys that satisfy every supplied filter.

        Args:
            before: Keep only keys updated strictly before this timestamp.
            after: Keep only keys updated strictly after this timestamp.
            group_ids: Keep only keys whose group id is one of these. An
                empty or missing sequence disables the filter.
            limit: Maximum number of keys to return; ``None`` means no limit.

        Returns:
            The matching keys, in the insertion order of the underlying dict.
        """
        # Build the lookup set once so membership is O(1) per record.
        wanted_groups = frozenset(group_ids) if group_ids else None

        matches: List[str] = []
        for key, record in self.records.items():
            # Compare against None explicitly so a 0.0 bound is still honoured.
            if before is not None and record["updated_at"] >= before:
                continue
            if after is not None and record["updated_at"] <= after:
                continue
            if wanted_groups is not None and record["group_id"] not in wanted_groups:
                continue
            matches.append(key)

        # ``limit=0`` must yield an empty list, hence no truthiness test here.
        if limit is not None:
            return matches[:limit]
        return matches

    async def alist_keys(
        self,
        *,
        before: Optional[float] = None,
        after: Optional[float] = None,
        group_ids: Optional[Sequence[str]] = None,
        limit: Optional[int] = None,
    ) -> List[str]:
        """Async counterpart of ``list_keys``; delegates to the sync method."""
        return self.list_keys(
            before=before, after=after, group_ids=group_ids, limit=limit
        )

    def delete_keys(self, keys: Sequence[str]) -> None:
        """Remove the given keys; keys that are not stored are ignored."""
        for key in keys:
            self.records.pop(key, None)

    async def adelete_keys(self, keys: Sequence[str]) -> None:
        """Async counterpart of ``delete_keys``; delegates to the sync method."""
        self.delete_keys(keys)
get_time(self) -> float: - """Get the current server time as a high resolution timestamp!""" - return time.time() - - async def aget_time(self) -> float: - """Get the current server time as a high resolution timestamp!""" - return self.get_time() - - def update( - self, - keys: Sequence[str], - *, - group_ids: Optional[Sequence[Optional[str]]] = None, - time_at_least: Optional[float] = None, - ) -> None: - if group_ids and len(keys) != len(group_ids): - raise ValueError("Length of keys must match length of group_ids") - for index, key in enumerate(keys): - group_id = group_ids[index] if group_ids else None - if time_at_least and time_at_least > self.get_time(): - raise ValueError("time_at_least must be in the past") - self.records[key] = {"group_id": group_id, "updated_at": self.get_time()} - - async def aupdate( - self, - keys: Sequence[str], - *, - group_ids: Optional[Sequence[Optional[str]]] = None, - time_at_least: Optional[float] = None, - ) -> None: - self.update(keys, group_ids=group_ids, time_at_least=time_at_least) - - def exists(self, keys: Sequence[str]) -> List[bool]: - return [key in self.records for key in keys] - - async def aexists(self, keys: Sequence[str]) -> List[bool]: - return self.exists(keys) - - def list_keys( - self, - *, - before: Optional[float] = None, - after: Optional[float] = None, - group_ids: Optional[Sequence[str]] = None, - limit: Optional[int] = None, - ) -> List[str]: - result = [] - for key, data in self.records.items(): - if before and data["updated_at"] >= before: - continue - if after and data["updated_at"] <= after: - continue - if group_ids and data["group_id"] not in group_ids: - continue - result.append(key) - if limit: - return result[:limit] - return result - - async def alist_keys( - self, - *, - before: Optional[float] = None, - after: Optional[float] = None, - group_ids: Optional[Sequence[str]] = None, - limit: Optional[int] = None, - ) -> List[str]: - return self.list_keys( - before=before, after=after, 
group_ids=group_ids, limit=limit - ) - - def delete_keys(self, keys: Sequence[str]) -> None: - for key in keys: - if key in self.records: - del self.records[key] - - async def adelete_keys(self, keys: Sequence[str]) -> None: - self.delete_keys(keys) diff --git a/libs/core/tests/unit_tests/indexing/test_in_memory_record_manager.py b/libs/core/tests/unit_tests/indexing/test_in_memory_record_manager.py index ea88724513..1dd001068f 100644 --- a/libs/core/tests/unit_tests/indexing/test_in_memory_record_manager.py +++ b/libs/core/tests/unit_tests/indexing/test_in_memory_record_manager.py @@ -4,7 +4,7 @@ from unittest.mock import patch import pytest import pytest_asyncio -from tests.unit_tests.indexing.in_memory import InMemoryRecordManager +from langchain_core.indexing import InMemoryRecordManager @pytest.fixture() diff --git a/libs/core/tests/unit_tests/indexing/test_indexing.py b/libs/core/tests/unit_tests/indexing/test_indexing.py index 701204363a..c7170e22f7 100644 --- a/libs/core/tests/unit_tests/indexing/test_indexing.py +++ b/libs/core/tests/unit_tests/indexing/test_indexing.py @@ -18,10 +18,9 @@ import pytest_asyncio from langchain_core.document_loaders.base import BaseLoader from langchain_core.documents import Document from langchain_core.embeddings import Embeddings -from langchain_core.indexing import aindex, index +from langchain_core.indexing import InMemoryRecordManager, aindex, index from langchain_core.indexing.api import _abatch, _HashedDocument from langchain_core.vectorstores import VST, VectorStore -from tests.unit_tests.indexing.in_memory import InMemoryRecordManager class ToyLoader(BaseLoader): diff --git a/libs/core/tests/unit_tests/indexing/test_public_api.py b/libs/core/tests/unit_tests/indexing/test_public_api.py index 8c0b367dd8..89c52cf681 100644 --- a/libs/core/tests/unit_tests/indexing/test_public_api.py +++ b/libs/core/tests/unit_tests/indexing/test_public_api.py @@ -8,5 +8,6 @@ def test_all() -> None: "aindex", "index", "IndexingResult", 
+ "InMemoryRecordManager", "RecordManager", ]