forked from Archives/langchain
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
272 lines
8.8 KiB
Python
272 lines
8.8 KiB
Python
"""Wrapper around scikit-learn NearestNeighbors implementation.

The vector store can be persisted in json, bson or parquet format.
"""
|
|
|
|
import importlib
|
|
import json
|
|
import math
|
|
import os
|
|
from abc import ABC, abstractmethod
|
|
from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple, Type
|
|
from uuid import uuid4
|
|
|
|
from langchain.docstore.document import Document
|
|
from langchain.embeddings.base import Embeddings
|
|
from langchain.vectorstores.base import VectorStore
|
|
|
|
|
|
def guard_import(
    module_name: str, *, pip_name: Optional[str] = None, package: Optional[str] = None
) -> Any:
    """Dynamically import a module, raising a helpful error if it is missing.

    Args:
        module_name: Name of the module to import.
        pip_name: Name to show in the ``pip install`` hint when it differs
            from ``module_name``; defaults to ``module_name``.
        package: Forwarded to ``importlib.import_module`` as its ``package``
            argument.

    Returns:
        The imported module object.

    Raises:
        ImportError: If the module cannot be imported. The original error is
            attached as ``__cause__``.
    """
    try:
        return importlib.import_module(module_name, package)
    except ImportError as exc:
        # Chain explicitly so the underlying failure (e.g. a broken
        # transitive dependency) stays visible in the traceback.
        raise ImportError(
            f"Could not import {module_name} python package. "
            f"Please install it with `pip install {pip_name or module_name}`."
        ) from exc
|
|
|
|
|
|
class BaseSerializer(ABC):
    """Interface for objects that write data to, and read it back from, a path.

    Concrete subclasses must implement ``extension``, ``save`` and ``load``.
    """

    def __init__(self, persist_path: str) -> None:
        """Remember the location that ``save`` and ``load`` operate on."""
        self.persist_path = persist_path

    @classmethod
    @abstractmethod
    def extension(cls) -> str:
        """Return the suggested file extension for this format, without a dot."""

    @abstractmethod
    def save(self, data: Any) -> None:
        """Write ``data`` to ``self.persist_path``."""

    @abstractmethod
    def load(self) -> Any:
        """Read the data stored at ``self.persist_path`` and return it."""
|
|
|
|
|
|
class JsonSerializer(BaseSerializer):
    """Serializes data in json using the json package from python standard library."""

    @classmethod
    def extension(cls) -> str:
        """Return the file extension for this format: ``"json"``."""
        return "json"

    def save(self, data: Any) -> None:
        """Write ``data`` as JSON text to ``self.persist_path``.

        Any existing file at that path is overwritten.
        """
        # Explicit UTF-8 keeps the file independent of the platform's locale
        # encoding. json.dump escapes non-ASCII by default, so the bytes
        # written are the same as before.
        with open(self.persist_path, "w", encoding="utf-8") as fp:
            json.dump(data, fp)

    def load(self) -> Any:
        """Parse the JSON file at ``self.persist_path`` and return its content."""
        # Read as UTF-8 rather than the locale default, so files containing
        # raw non-ASCII characters load the same way on every platform.
        with open(self.persist_path, "r", encoding="utf-8") as fp:
            return json.load(fp)
|
|
|
|
|
|
class BsonSerializer(BaseSerializer):
    """Persists data as binary json, delegating to the ``bson`` python package."""

    def __init__(self, persist_path: str) -> None:
        """Store the target path and import ``bson`` lazily via ``guard_import``."""
        super().__init__(persist_path)
        # Imported here rather than at module level so that the dependency is
        # only required when this serializer is actually instantiated.
        self.bson = guard_import("bson")

    @classmethod
    def extension(cls) -> str:
        """Return the file extension for this format: ``"bson"``."""
        return "bson"

    def save(self, data: Any) -> None:
        """Encode ``data`` with ``bson.dumps`` and write the bytes to disk."""
        payload = self.bson.dumps(data)
        with open(self.persist_path, "wb") as handle:
            handle.write(payload)

    def load(self) -> Any:
        """Read the file at ``self.persist_path`` and decode it with ``bson.loads``."""
        with open(self.persist_path, "rb") as handle:
            raw = handle.read()
            return self.bson.loads(raw)
|
|
|
|
|
|
class ParquetSerializer(BaseSerializer):
    """Serializes data in Apache Parquet format using the pyarrow package."""

    def __init__(self, persist_path: str) -> None:
        """Store the target path and lazily import pandas and pyarrow.

        Raises:
            ImportError: If ``pandas``, ``pyarrow`` or ``pyarrow.parquet``
                cannot be imported.
        """
        super().__init__(persist_path)
        self.pd = guard_import("pandas")
        self.pa = guard_import("pyarrow")
        self.pq = guard_import("pyarrow.parquet")

    @classmethod
    def extension(cls) -> str:
        """Return the file extension for this format: ``"parquet"``."""
        return "parquet"

    def save(self, data: Any) -> None:
        """Write ``data`` to ``self.persist_path`` as a Parquet table.

        ``data`` is first turned into a pandas ``DataFrame``. If a file
        already exists at the target path it is moved to
        ``<persist_path>-backup`` for the duration of the write. The backup
        is restored if writing fails and removed if writing succeeds.

        Raises:
            Exception: Whatever ``pyarrow.parquet.write_table`` raised; it is
                re-raised after the previous file has been restored.
        """
        df = self.pd.DataFrame(data)
        table = self.pa.Table.from_pandas(df)

        if not os.path.exists(self.persist_path):
            self.pq.write_table(table, self.persist_path)
            return

        backup_path = str(self.persist_path) + "-backup"
        # os.replace overwrites an existing destination on every platform;
        # os.rename raises on Windows if a stale backup is already present.
        os.replace(self.persist_path, backup_path)
        try:
            self.pq.write_table(table, self.persist_path)
        except Exception:
            # A failed write may leave a partial file at persist_path, so the
            # rollback has to be able to overwrite it.
            os.replace(backup_path, self.persist_path)
            raise
        else:
            os.remove(backup_path)

    def load(self) -> Any:
        """Read the Parquet file and return its columns as plain lists.

        Returns:
            A dict mapping each column name to the list of that column's
            values.
        """
        table = self.pq.read_table(self.persist_path)
        df = table.to_pandas()
        return {col: series.tolist() for col, series in df.items()}
|
|
|
|
|
|
# Registry of available serializers, keyed by the name each class reports from
# ``extension()`` ("json", "bson" and "parquet" respectively).
SERIALIZER_MAP: Dict[str, Type[BaseSerializer]] = {
    serializer_cls.extension(): serializer_cls
    for serializer_cls in (JsonSerializer, BsonSerializer, ParquetSerializer)
}
|
|
|
|
|
|
class SKLearnVectorStoreException(RuntimeError):
    """Error raised by SKLearnVectorStore, e.g. when it holds no data or
    has no persist_path configured."""
|
|
|
|
|
|
class SKLearnVectorStore(VectorStore):
    """A simple in-memory vector store based on the scikit-learn library
    NearestNeighbors implementation.

    Texts, their embeddings, metadata dicts and ids are kept in four parallel
    lists. The nearest-neighbour index is refitted on the full embedding
    matrix every time texts are added or data is loaded from disk.
    """

    def __init__(
        self,
        embedding: Embeddings,
        *,
        persist_path: Optional[str] = None,
        serializer: Literal["json", "bson", "parquet"] = "json",
        metric: str = "cosine",
        **kwargs: Any,
    ) -> None:
        """Create the store and, if a persisted file exists, load it.

        Args:
            embedding: Object used to embed documents and queries.
            persist_path: File used by ``persist``; when it already exists as
                a file its content is loaded immediately. ``None`` disables
                persistence.
            serializer: Key into ``SERIALIZER_MAP`` selecting the on-disk
                format. Only consulted when ``persist_path`` is given.
            metric: Passed to ``NearestNeighbors`` as ``metric``.
            **kwargs: Forwarded unchanged to ``NearestNeighbors``.

        Raises:
            ImportError: If numpy or scikit-learn cannot be imported.
            KeyError: If ``serializer`` is not a key of ``SERIALIZER_MAP``.
        """
        np = guard_import("numpy")
        sklearn_neighbors = guard_import("sklearn.neighbors", pip_name="scikit-learn")

        # non-persistent properties
        self._np = np
        self._neighbors = sklearn_neighbors.NearestNeighbors(metric=metric, **kwargs)
        self._neighbors_fitted = False
        self._embedding_function = embedding
        self._persist_path = persist_path
        self._serializer: Optional[BaseSerializer] = None
        if self._persist_path is not None:
            serializer_cls = SERIALIZER_MAP[serializer]
            self._serializer = serializer_cls(persist_path=self._persist_path)

        # data properties (parallel lists: index i describes one document)
        self._embeddings: List[List[float]] = []
        self._texts: List[str] = []
        self._metadatas: List[dict] = []
        self._ids: List[str] = []

        # cache properties
        self._embeddings_np: Any = np.asarray([])

        if self._persist_path is not None and os.path.isfile(self._persist_path):
            self._load()

    def persist(self) -> None:
        """Write ids, texts, metadatas and embeddings through the serializer.

        Raises:
            SKLearnVectorStoreException: If the store was created without a
                ``persist_path``.
        """
        if self._serializer is None:
            raise SKLearnVectorStoreException(
                "You must specify a persist_path on creation to persist the "
                "collection."
            )
        data = {
            "ids": self._ids,
            "texts": self._texts,
            "metadatas": self._metadatas,
            "embeddings": self._embeddings,
        }
        self._serializer.save(data)

    def _load(self) -> None:
        """Replace the in-memory data with the serializer's content and refit.

        Raises:
            SKLearnVectorStoreException: If no serializer is configured, or
                if the loaded data contains no embeddings.
        """
        if self._serializer is None:
            raise SKLearnVectorStoreException(
                "You must specify a persist_path on creation to load the " "collection."
            )
        data = self._serializer.load()
        self._embeddings = data["embeddings"]
        self._texts = data["texts"]
        self._metadatas = data["metadatas"]
        self._ids = data["ids"]
        self._update_neighbors()

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Embed ``texts``, append them to the store and refit the index.

        Args:
            texts: Texts to add.
            metadatas: One metadata dict per text; when omitted or empty each
                text gets its own empty dict.
            ids: One id per text; when omitted or empty, random UUID4 strings
                are generated.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            The ids of the added texts, in input order.

        Raises:
            ValueError: If ``ids`` or ``metadatas`` is given with a length
                different from the number of texts.
            SKLearnVectorStoreException: If the store holds no embeddings
                after the call.
        """
        _texts = list(texts)
        _ids = ids or [str(uuid4()) for _ in _texts]
        # One fresh dict per text: ``[{}] * n`` would make every entry alias
        # the same dict object.
        _metadatas = metadatas or [{} for _ in _texts]

        # The parallel lists must stay aligned; a mismatch would silently
        # attach ids/metadata to the wrong documents.
        if len(_ids) != len(_texts):
            raise ValueError(
                f"Got {len(_ids)} ids for {len(_texts)} texts; lengths must match."
            )
        if len(_metadatas) != len(_texts):
            raise ValueError(
                f"Got {len(_metadatas)} metadatas for {len(_texts)} texts; "
                "lengths must match."
            )

        # Embed before touching any state so that a failing embedding call
        # leaves the store unchanged.
        new_embeddings = self._embedding_function.embed_documents(_texts)

        self._texts.extend(_texts)
        self._embeddings.extend(new_embeddings)
        self._metadatas.extend(_metadatas)
        self._ids.extend(_ids)
        self._update_neighbors()
        return _ids

    def _update_neighbors(self) -> None:
        """Rebuild the numpy embedding cache and refit the neighbour index.

        Raises:
            SKLearnVectorStoreException: If there are no embeddings to fit.
        """
        if len(self._embeddings) == 0:
            raise SKLearnVectorStoreException(
                "No data was added to SKLearnVectorStore."
            )
        self._embeddings_np = self._np.asarray(self._embeddings)
        self._neighbors.fit(self._embeddings_np)
        self._neighbors_fitted = True

    def similarity_search_with_score(
        self, query: str, *, k: int = 4, **kwargs: Any
    ) -> List[Tuple[Document, float]]:
        """Return the stored documents nearest to ``query`` with distances.

        Args:
            query: Text to embed and search for.
            k: Maximum number of results. If the store holds fewer documents,
                all of them are returned.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            ``(document, distance)`` pairs in the order produced by
            ``kneighbors``. Each document's metadata contains its stored
            metadata plus an ``"id"`` entry set to the document id; a stored
            ``"id"`` key takes precedence.

        Raises:
            SKLearnVectorStoreException: If the index has not been fitted.
        """
        if not self._neighbors_fitted:
            raise SKLearnVectorStoreException(
                "No data was added to SKLearnVectorStore."
            )
        query_embedding = self._embedding_function.embed_query(query)
        # kneighbors rejects n_neighbors larger than the number of fitted
        # samples, so cap the request at what the store actually holds.
        n_neighbors = min(k, len(self._embeddings_np))
        neigh_dists, neigh_idxs = self._neighbors.kneighbors(
            [query_embedding], n_neighbors=n_neighbors
        )
        res = []
        # Row 0 holds the results for the single query that was submitted.
        for idx, dist in zip(neigh_idxs[0], neigh_dists[0]):
            _idx = int(idx)
            metadata = {"id": self._ids[_idx], **self._metadatas[_idx]}
            doc = Document(page_content=self._texts[_idx], metadata=metadata)
            # Convert the numpy scalar to a plain float, as annotated.
            res.append((doc, float(dist)))
        return res

    def similarity_search(
        self, query: str, k: int = 4, **kwargs: Any
    ) -> List[Document]:
        """Return up to ``k`` documents nearest to ``query``, without scores."""
        docs_scores = self.similarity_search_with_score(query, k=k, **kwargs)
        return [doc for doc, _ in docs_scores]

    def _similarity_search_with_relevance_scores(
        self, query: str, k: int = 4, **kwargs: Any
    ) -> List[Tuple[Document, float]]:
        """Return nearest documents with distances mapped to relevance scores.

        Each distance ``d`` becomes ``1 / exp(d)``, so a distance of 0 gives a
        score of 1 and larger distances give smaller positive scores.
        """
        docs_dists = self.similarity_search_with_score(query=query, k=k, **kwargs)
        # A comprehension (rather than ``zip(*docs_dists)``) also works when
        # the result list is empty.
        return [(doc, 1 / math.exp(dist)) for doc, dist in docs_dists]

    @classmethod
    def from_texts(
        cls,
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        persist_path: Optional[str] = None,
        **kwargs: Any,
    ) -> "SKLearnVectorStore":
        """Build a store and add ``texts`` to it in one step.

        Args:
            texts: Texts to embed and add.
            embedding: Object used to embed documents and queries.
            metadatas: Optional metadata dicts, one per text.
            ids: Optional ids, one per text.
            persist_path: Optional persistence file, passed to the constructor.
            **kwargs: Forwarded to the constructor.

        Returns:
            A new instance of ``cls`` containing the texts.
        """
        # Instantiate ``cls`` rather than the hard-coded class so subclasses
        # get an instance of their own type.
        vs = cls(embedding, persist_path=persist_path, **kwargs)
        vs.add_texts(texts, metadatas=metadatas, ids=ids)
        return vs
|