You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
langchain/langchain/vectorstores/sklearn.py

349 lines
12 KiB
Python

""" Wrapper around scikit-learn NearestNeighbors implementation.
The vector store can be persisted in json, bson or parquet format.
"""
import json
import math
import os
from abc import ABC, abstractmethod
from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple, Type
from uuid import uuid4
from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain.utils import guard_import
from langchain.vectorstores.base import VectorStore
from langchain.vectorstores.utils import maximal_marginal_relevance
DEFAULT_K = 4  # Default `k`: number of Documents a search returns.
DEFAULT_FETCH_K = 20  # Default `fetch_k`: number of Documents initially fetched during MMR search.
class BaseSerializer(ABC):
    """Interface for objects that save data to, and load data from, one path."""

    def __init__(self, persist_path: str) -> None:
        # Concrete serializers read from and write to this location.
        self.persist_path = persist_path

    @classmethod
    @abstractmethod
    def extension(cls) -> str:
        """Return the file extension suggested for this format (no leading dot)."""

    @abstractmethod
    def save(self, data: Any) -> None:
        """Write ``data`` to ``persist_path``."""

    @abstractmethod
    def load(self) -> Any:
        """Read the data stored at ``persist_path`` and return it."""
class JsonSerializer(BaseSerializer):
    """Serializes data in json using the json package from python standard library."""

    @classmethod
    def extension(cls) -> str:
        """Return the file extension for this format (without dot)."""
        return "json"

    def save(self, data: Any) -> None:
        """Write ``data`` to ``persist_path`` as JSON text.

        The payload is serialized *before* the file is opened: opening in
        ``"w"`` mode truncates the file, so serializing first means that data
        which cannot be encoded as JSON raises without wiping a previously
        persisted file.
        """
        payload = json.dumps(data)
        # Explicit encoding so the file does not depend on the platform default.
        with open(self.persist_path, "w", encoding="utf-8") as fp:
            fp.write(payload)

    def load(self) -> Any:
        """Parse the JSON file at ``persist_path`` and return its content."""
        with open(self.persist_path, "r", encoding="utf-8") as fp:
            return json.load(fp)
class BsonSerializer(BaseSerializer):
    """Serializes data in binary json using the bson python package."""

    def __init__(self, persist_path: str) -> None:
        """Store ``persist_path`` and obtain the ``bson`` module via ``guard_import``."""
        super().__init__(persist_path)
        self.bson = guard_import("bson")

    @classmethod
    def extension(cls) -> str:
        """Return the file extension for this format (without dot)."""
        return "bson"

    def save(self, data: Any) -> None:
        """Write ``data`` to ``persist_path`` as bson-encoded bytes.

        Encoding happens *before* the file is opened: ``"wb"`` truncates the
        file, so a payload that fails to encode raises without destroying a
        previously persisted file.
        """
        payload = self.bson.dumps(data)
        with open(self.persist_path, "wb") as fp:
            fp.write(payload)

    def load(self) -> Any:
        """Read ``persist_path`` and return the decoded content."""
        with open(self.persist_path, "rb") as fp:
            raw = fp.read()
        return self.bson.loads(raw)
class ParquetSerializer(BaseSerializer):
    """Serializes data in Apache Parquet format using the pyarrow package."""

    def __init__(self, persist_path: str) -> None:
        """Store ``persist_path`` and obtain pandas / pyarrow via ``guard_import``."""
        super().__init__(persist_path)
        self.pd = guard_import("pandas")
        self.pa = guard_import("pyarrow")
        self.pq = guard_import("pyarrow.parquet")

    @classmethod
    def extension(cls) -> str:
        """Return the file extension for this format (without dot)."""
        return "parquet"

    def save(self, data: Any) -> None:
        """Write ``data`` to ``persist_path`` as a Parquet table.

        ``data`` is turned into a pandas DataFrame and then a pyarrow Table.
        If a file already exists at ``persist_path`` it is moved to
        ``<persist_path>-backup`` first; the backup is restored when writing
        fails and removed when writing succeeds.

        Raises:
            Exception: whatever ``write_table`` raised, re-raised unchanged
                after the previous file has been restored.
        """
        df = self.pd.DataFrame(data)
        table = self.pa.Table.from_pandas(df)

        if not os.path.exists(self.persist_path):
            self.pq.write_table(table, self.persist_path)
            return

        backup_path = str(self.persist_path) + "-backup"
        # os.replace overwrites an existing destination on every platform,
        # whereas os.rename fails on Windows when the destination exists
        # (e.g. a stale backup, or a partial file left by a failed write).
        os.replace(self.persist_path, backup_path)
        try:
            self.pq.write_table(table, self.persist_path)
        except Exception:
            os.replace(backup_path, self.persist_path)
            raise  # bare raise keeps the original traceback intact
        else:
            os.remove(backup_path)

    def load(self) -> Any:
        """Read the Parquet file and return a dict of column name -> list of values."""
        table = self.pq.read_table(self.persist_path)
        df = table.to_pandas()
        return {col: series.tolist() for col, series in df.items()}
# Lookup table from serializer name to serializer class. Each class is keyed
# by the string its own ``extension()`` returns ("json", "bson", "parquet").
SERIALIZER_MAP: Dict[str, Type[BaseSerializer]] = {
    serializer_cls.extension(): serializer_cls
    for serializer_cls in (JsonSerializer, BsonSerializer, ParquetSerializer)
}
class SKLearnVectorStoreException(RuntimeError):
    """Error type raised by SKLearnVectorStore."""
class SKLearnVectorStore(VectorStore):
    """A simple in-memory vector store based on the scikit-learn library
    NearestNeighbors implementation.

    Texts, metadatas, ids and embeddings are held in four parallel lists. The
    NearestNeighbors index is refitted over all stored embeddings every time
    data is added or loaded. When a ``persist_path`` is given the four lists
    can be written with ``persist()`` and are loaded again on construction if
    the file exists.
    """

    def __init__(
        self,
        embedding: Embeddings,
        *,
        persist_path: Optional[str] = None,
        serializer: Literal["json", "bson", "parquet"] = "json",
        metric: str = "cosine",
        **kwargs: Any,
    ) -> None:
        """Create the store and load persisted data if present.

        Args:
            embedding: Object used to embed documents and queries.
            persist_path: File used by ``persist()``; if it already exists its
                content is loaded immediately. ``None`` disables persistence.
            serializer: Key into ``SERIALIZER_MAP`` selecting the file format.
                Only consulted when ``persist_path`` is given.
            metric: Passed to ``NearestNeighbors`` as ``metric``.
            **kwargs: Forwarded unchanged to ``NearestNeighbors``.
        """
        np = guard_import("numpy")
        sklearn_neighbors = guard_import("sklearn.neighbors", pip_name="scikit-learn")

        # non-persistent properties
        self._np = np
        self._neighbors = sklearn_neighbors.NearestNeighbors(metric=metric, **kwargs)
        self._neighbors_fitted = False
        self._embedding_function = embedding
        self._persist_path = persist_path
        self._serializer: Optional[BaseSerializer] = None
        if self._persist_path is not None:
            serializer_cls = SERIALIZER_MAP[serializer]
            self._serializer = serializer_cls(persist_path=self._persist_path)

        # data properties (parallel lists, one entry per stored text)
        self._embeddings: List[List[float]] = []
        self._texts: List[str] = []
        self._metadatas: List[dict] = []
        self._ids: List[str] = []

        # cache properties
        self._embeddings_np: Any = np.asarray([])

        if self._persist_path is not None and os.path.isfile(self._persist_path):
            self._load()

    def persist(self) -> None:
        """Save ids, texts, metadatas and embeddings through the serializer.

        Raises:
            SKLearnVectorStoreException: if the store was created without a
                ``persist_path``.
        """
        if self._serializer is None:
            raise SKLearnVectorStoreException(
                "You must specify a persist_path on creation to persist the "
                "collection."
            )
        data = {
            "ids": self._ids,
            "texts": self._texts,
            "metadatas": self._metadatas,
            "embeddings": self._embeddings,
        }
        self._serializer.save(data)

    def _load(self) -> None:
        """Replace the stored data with the serializer's content and refit.

        Raises:
            SKLearnVectorStoreException: if the store was created without a
                ``persist_path``, or if the loaded data has no embeddings.
        """
        if self._serializer is None:
            raise SKLearnVectorStoreException(
                "You must specify a persist_path on creation to load the collection."
            )
        data = self._serializer.load()
        self._embeddings = data["embeddings"]
        self._texts = data["texts"]
        self._metadatas = data["metadatas"]
        self._ids = data["ids"]
        self._update_neighbors()

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Embed ``texts``, append them to the store and refit the index.

        Args:
            texts: Texts to add.
            metadatas: One metadata dict per text; empty dicts when omitted.
            ids: One id per text; random UUID4 strings when omitted.
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            The ids of the added texts.

        Raises:
            ValueError: if ``metadatas`` or ``ids`` is given with a length
                different from ``texts``.
            SKLearnVectorStoreException: if the store holds no embeddings
                after the call (nothing was ever added).
        """
        _texts = list(texts)
        _ids = list(ids) if ids else [str(uuid4()) for _ in _texts]
        # A fresh dict per text: ``[{}] * n`` would alias one dict n times, so
        # mutating one entry's metadata would change all of them.
        _metadatas = list(metadatas) if metadatas else [{} for _ in _texts]
        if len(_metadatas) != len(_texts):
            raise ValueError(
                f"Got {len(_metadatas)} metadatas for {len(_texts)} texts; "
                "lengths must match."
            )
        if len(_ids) != len(_texts):
            raise ValueError(
                f"Got {len(_ids)} ids for {len(_texts)} texts; lengths must match."
            )

        # Embed before touching any state so a failing embedding call cannot
        # leave the parallel lists with different lengths.
        new_embeddings = self._embedding_function.embed_documents(_texts)

        self._texts.extend(_texts)
        self._embeddings.extend(new_embeddings)
        self._metadatas.extend(_metadatas)
        self._ids.extend(_ids)
        self._update_neighbors()
        return _ids

    def _update_neighbors(self) -> None:
        """Rebuild the numpy cache and refit NearestNeighbors on all embeddings.

        Raises:
            SKLearnVectorStoreException: if there are no embeddings to fit.
        """
        if len(self._embeddings) == 0:
            raise SKLearnVectorStoreException(
                "No data was added to SKLearnVectorStore."
            )
        self._embeddings_np = self._np.asarray(self._embeddings)
        self._neighbors.fit(self._embeddings_np)
        self._neighbors_fitted = True

    def _document_at(self, idx: int) -> Document:
        """Build the Document for the entry at position ``idx``.

        The stored id is exposed under the ``"id"`` metadata key; a stored
        metadata entry named ``"id"`` overrides it.
        """
        return Document(
            page_content=self._texts[idx],
            metadata={"id": self._ids[idx], **self._metadatas[idx]},
        )

    def _similarity_index_search_with_score(
        self, query_embedding: List[float], *, k: int = DEFAULT_K, **kwargs: Any
    ) -> List[Tuple[int, float]]:
        """Search k embeddings similar to the query embedding. Returns a list of
        (index, distance) tuples.

        At most as many tuples as there are stored embeddings are returned.

        Raises:
            SKLearnVectorStoreException: if no data has been added yet.
        """
        if not self._neighbors_fitted:
            raise SKLearnVectorStoreException(
                "No data was added to SKLearnVectorStore."
            )
        # kneighbors rejects n_neighbors larger than the number of fitted
        # samples, which would make e.g. the default fetch_k fail on a small
        # store; cap the request at what is available.
        n_neighbors = min(k, len(self._embeddings))
        neigh_dists, neigh_idxs = self._neighbors.kneighbors(
            [query_embedding], n_neighbors=n_neighbors
        )
        # Convert numpy scalars to the plain int/float types promised by the
        # return annotation.
        return [
            (int(idx), float(dist))
            for idx, dist in zip(neigh_idxs[0], neigh_dists[0])
        ]

    def similarity_search_with_score(
        self, query: str, *, k: int = DEFAULT_K, **kwargs: Any
    ) -> List[Tuple[Document, float]]:
        """Return up to ``k`` (Document, distance) pairs nearest to ``query``."""
        query_embedding = self._embedding_function.embed_query(query)
        indices_dists = self._similarity_index_search_with_score(
            query_embedding, k=k, **kwargs
        )
        return [(self._document_at(idx), dist) for idx, dist in indices_dists]

    def similarity_search(
        self, query: str, k: int = DEFAULT_K, **kwargs: Any
    ) -> List[Document]:
        """Return up to ``k`` Documents nearest to ``query``."""
        docs_scores = self.similarity_search_with_score(query, k=k, **kwargs)
        return [doc for doc, _ in docs_scores]

    def _similarity_search_with_relevance_scores(
        self, query: str, k: int = DEFAULT_K, **kwargs: Any
    ) -> List[Tuple[Document, float]]:
        """Return (Document, score) pairs where ``score = 1 / exp(distance)``."""
        docs_dists = self.similarity_search_with_score(query, k=k, **kwargs)
        # Pairwise comprehension instead of zip(*...) so an empty result list
        # yields [] rather than an unpacking error.
        return [(doc, 1 / math.exp(dist)) for doc, dist in docs_dists]

    def max_marginal_relevance_search_by_vector(
        self,
        embedding: List[float],
        k: int = DEFAULT_K,
        fetch_k: int = DEFAULT_FETCH_K,
        lambda_mult: float = 0.5,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs selected using the maximal marginal relevance.

        Maximal marginal relevance optimizes for similarity to query AND diversity
        among selected documents.

        Args:
            embedding: Embedding to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
            lambda_mult: Number between 0 and 1 that determines the degree
                of diversity among the results with 0 corresponding
                to maximum diversity and 1 to minimum diversity.
                Defaults to 0.5.

        Returns:
            List of Documents selected by maximal marginal relevance.
        """
        indices_dists = self._similarity_index_search_with_score(
            embedding, k=fetch_k, **kwargs
        )
        indices = [idx for idx, _ in indices_dists]
        # Rows of the cached embedding matrix for the fetched candidates.
        result_embeddings = self._embeddings_np[indices]
        mmr_selected = maximal_marginal_relevance(
            self._np.array(embedding, dtype=self._np.float32),
            result_embeddings,
            k=k,
            lambda_mult=lambda_mult,
        )
        # mmr_selected indexes into the candidate list; map back to store
        # positions before building the Documents.
        return [self._document_at(indices[i]) for i in mmr_selected]

    def max_marginal_relevance_search(
        self,
        query: str,
        k: int = DEFAULT_K,
        fetch_k: int = DEFAULT_FETCH_K,
        lambda_mult: float = 0.5,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs selected using the maximal marginal relevance.

        Maximal marginal relevance optimizes for similarity to query AND diversity
        among selected documents.

        Args:
            query: Text to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
            lambda_mult: Number between 0 and 1 that determines the degree
                of diversity among the results with 0 corresponding
                to maximum diversity and 1 to minimum diversity.
                Defaults to 0.5.

        Returns:
            List of Documents selected by maximal marginal relevance.

        Raises:
            ValueError: if the store has no embedding function.
        """
        if self._embedding_function is None:
            raise ValueError(
                "For MMR search, you must specify an embedding function on creation."
            )
        embedding = self._embedding_function.embed_query(query)
        # Pass lambda_mult under its real name: a misspelled keyword is
        # swallowed by **kwargs and the caller's value silently ignored.
        return self.max_marginal_relevance_search_by_vector(
            embedding,
            k=k,
            fetch_k=fetch_k,
            lambda_mult=lambda_mult,
            **kwargs,
        )

    @classmethod
    def from_texts(
        cls,
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        persist_path: Optional[str] = None,
        **kwargs: Any,
    ) -> "SKLearnVectorStore":
        """Create a store and add ``texts`` to it.

        Args:
            texts: Texts to add.
            embedding: Object used to embed documents and queries.
            metadatas: Optional metadata dicts, one per text.
            ids: Optional ids, one per text.
            persist_path: Optional persistence file; existing content is
                loaded by the constructor before ``texts`` are appended.
            **kwargs: Forwarded to the constructor.

        Returns:
            The populated store, an instance of ``cls``.
        """
        # Instantiate cls (not the hard-coded base class) so subclasses get an
        # instance of their own type.
        vs = cls(embedding, persist_path=persist_path, **kwargs)
        vs.add_texts(texts, metadatas=metadatas, ids=ids)
        return vs