""" Wrapper around scikit-learn NearestNeighbors implementation.
|
|
|
|
The vector store can be persisted in json, bson or parquet format.
|
|
"""
|
|
|
|

import json
import math
import os
from abc import ABC, abstractmethod
from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple, Type
from uuid import uuid4

from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.utils import guard_import
from langchain_core.vectorstores import VectorStore

from langchain_community.vectorstores.utils import maximal_marginal_relevance

DEFAULT_K = 4  # Number of Documents to return.
DEFAULT_FETCH_K = 20  # Number of Documents to initially fetch during MMR search.


class BaseSerializer(ABC):
    """Base class for serializing data."""

    def __init__(self, persist_path: str) -> None:
        self.persist_path = persist_path

    @classmethod
    @abstractmethod
    def extension(cls) -> str:
        """The file extension suggested by this serializer (without dot)."""

    @abstractmethod
    def save(self, data: Any) -> None:
        """Save the data to the persist_path."""

    @abstractmethod
    def load(self) -> Any:
        """Load the data from the persist_path."""


class JsonSerializer(BaseSerializer):
    """Serialize data in JSON using the json package from python standard library."""

    @classmethod
    def extension(cls) -> str:
        return "json"

    def save(self, data: Any) -> None:
        with open(self.persist_path, "w") as fp:
            json.dump(data, fp)

    def load(self) -> Any:
        with open(self.persist_path, "r") as fp:
            return json.load(fp)
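

# A minimal round-trip sketch for the serializer interface (illustrative only,
# not part of the library; the payload mirrors what SKLearnVectorStore.persist
# writes):
#
#     serializer = JsonSerializer("/tmp/store.json")
#     serializer.save({"ids": ["1"], "texts": ["hello"], "metadatas": [{}],
#                      "embeddings": [[0.1, 0.2]]})
#     assert serializer.load()["texts"] == ["hello"]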


class BsonSerializer(BaseSerializer):
    """Serialize data in Binary JSON using the `bson` python package."""

    def __init__(self, persist_path: str) -> None:
        super().__init__(persist_path)
        self.bson = guard_import("bson")

    @classmethod
    def extension(cls) -> str:
        return "bson"

    def save(self, data: Any) -> None:
        with open(self.persist_path, "wb") as fp:
            fp.write(self.bson.dumps(data))

    def load(self) -> Any:
        with open(self.persist_path, "rb") as fp:
            return self.bson.loads(fp.read())


class ParquetSerializer(BaseSerializer):
    """Serialize data in `Apache Parquet` format using the `pyarrow` package."""

    def __init__(self, persist_path: str) -> None:
        super().__init__(persist_path)
        self.pd = guard_import("pandas")
        self.pa = guard_import("pyarrow")
        self.pq = guard_import("pyarrow.parquet")

    @classmethod
    def extension(cls) -> str:
        return "parquet"

    def save(self, data: Any) -> None:
        df = self.pd.DataFrame(data)
        table = self.pa.Table.from_pandas(df)
        if os.path.exists(self.persist_path):
            # Move the existing file aside so a failed write cannot corrupt it;
            # restore it on failure, discard it once the new write succeeds.
            backup_path = str(self.persist_path) + "-backup"
            os.rename(self.persist_path, backup_path)
            try:
                self.pq.write_table(table, self.persist_path)
            except Exception as exc:
                os.rename(backup_path, self.persist_path)
                raise exc
            else:
                os.remove(backup_path)
        else:
            self.pq.write_table(table, self.persist_path)

    def load(self) -> Any:
        table = self.pq.read_table(self.persist_path)
        df = table.to_pandas()
        return {col: series.tolist() for col, series in df.items()}


SERIALIZER_MAP: Dict[str, Type[BaseSerializer]] = {
    "json": JsonSerializer,
    "bson": BsonSerializer,
    "parquet": ParquetSerializer,
}
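
# Illustrative sketch of how a serializer can be resolved from the map above
# (the `fmt` name is hypothetical, not part of this module):
#
#     fmt: Literal["json", "bson", "parquet"] = "parquet"
#     serializer = SERIALIZER_MAP[fmt](persist_path="/tmp/store.parquet")
#     assert serializer.extension() == fmt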


class SKLearnVectorStoreException(RuntimeError):
    """Exception raised by SKLearnVectorStore."""


class SKLearnVectorStore(VectorStore):
    """Simple in-memory vector store based on the `scikit-learn` library
    `NearestNeighbors`."""

    def __init__(
        self,
        embedding: Embeddings,
        *,
        persist_path: Optional[str] = None,
        serializer: Literal["json", "bson", "parquet"] = "json",
        metric: str = "cosine",
        **kwargs: Any,
    ) -> None:
        np = guard_import("numpy")
        sklearn_neighbors = guard_import("sklearn.neighbors", pip_name="scikit-learn")

        # non-persistent properties
        self._np = np
        self._neighbors = sklearn_neighbors.NearestNeighbors(metric=metric, **kwargs)
        self._neighbors_fitted = False
        self._embedding_function = embedding
        self._persist_path = persist_path
        self._serializer: Optional[BaseSerializer] = None
        if self._persist_path is not None:
            serializer_cls = SERIALIZER_MAP[serializer]
            self._serializer = serializer_cls(persist_path=self._persist_path)

        # data properties
        self._embeddings: List[List[float]] = []
        self._texts: List[str] = []
        self._metadatas: List[dict] = []
        self._ids: List[str] = []

        # cache properties
        self._embeddings_np: Any = np.asarray([])

        if self._persist_path is not None and os.path.isfile(self._persist_path):
            self._load()
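
    # Construction sketch (illustrative only; FakeEmbeddings stands in for any
    # Embeddings implementation, and extra kwargs are forwarded to
    # sklearn.neighbors.NearestNeighbors):
    #
    #     from langchain_community.embeddings import FakeEmbeddings
    #
    #     store = SKLearnVectorStore(
    #         FakeEmbeddings(size=16),
    #         persist_path="/tmp/store.json",
    #         serializer="json",
    #         metric="cosine",
    #     )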

    @property
    def embeddings(self) -> Embeddings:
        return self._embedding_function

    def persist(self) -> None:
        if self._serializer is None:
            raise SKLearnVectorStoreException(
                "You must specify a persist_path on creation to persist the "
                "collection."
            )
        data = {
            "ids": self._ids,
            "texts": self._texts,
            "metadatas": self._metadatas,
            "embeddings": self._embeddings,
        }
        self._serializer.save(data)

    def _load(self) -> None:
        if self._serializer is None:
            raise SKLearnVectorStoreException(
                "You must specify a persist_path on creation to load the "
                "collection."
            )
        data = self._serializer.load()
        self._embeddings = data["embeddings"]
        self._texts = data["texts"]
        self._metadatas = data["metadatas"]
        self._ids = data["ids"]
        self._update_neighbors()

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> List[str]:
        _texts = list(texts)
        _ids = ids or [str(uuid4()) for _ in _texts]
        self._texts.extend(_texts)
        self._embeddings.extend(self._embedding_function.embed_documents(_texts))
        self._metadatas.extend(metadatas or ([{}] * len(_texts)))
        self._ids.extend(_ids)
        self._update_neighbors()
        return _ids

    def _update_neighbors(self) -> None:
        if len(self._embeddings) == 0:
            raise SKLearnVectorStoreException(
                "No data was added to SKLearnVectorStore."
            )
        self._embeddings_np = self._np.asarray(self._embeddings)
        self._neighbors.fit(self._embeddings_np)
        self._neighbors_fitted = True

    def _similarity_index_search_with_score(
        self, query_embedding: List[float], *, k: int = DEFAULT_K, **kwargs: Any
    ) -> List[Tuple[int, float]]:
        """Search k embeddings similar to the query embedding. Returns a list of
        (index, distance) tuples."""
        if not self._neighbors_fitted:
            raise SKLearnVectorStoreException(
                "No data was added to SKLearnVectorStore."
            )
        neigh_dists, neigh_idxs = self._neighbors.kneighbors(
            [query_embedding], n_neighbors=k
        )
        return list(zip(neigh_idxs[0], neigh_dists[0]))

    def similarity_search_with_score(
        self, query: str, *, k: int = DEFAULT_K, **kwargs: Any
    ) -> List[Tuple[Document, float]]:
        query_embedding = self._embedding_function.embed_query(query)
        indices_dists = self._similarity_index_search_with_score(
            query_embedding, k=k, **kwargs
        )
        return [
            (
                Document(
                    page_content=self._texts[idx],
                    metadata={"id": self._ids[idx], **self._metadatas[idx]},
                ),
                dist,
            )
            for idx, dist in indices_dists
        ]

    def similarity_search(
        self, query: str, k: int = DEFAULT_K, **kwargs: Any
    ) -> List[Document]:
        docs_scores = self.similarity_search_with_score(query, k=k, **kwargs)
        return [doc for doc, _ in docs_scores]

    def _similarity_search_with_relevance_scores(
        self, query: str, k: int = DEFAULT_K, **kwargs: Any
    ) -> List[Tuple[Document, float]]:
        docs_dists = self.similarity_search_with_score(query, k=k, **kwargs)
        docs, dists = zip(*docs_dists)
        # Map each distance to a relevance score in (0, 1] via exp(-distance):
        # a distance of 0 yields 1.0 and larger distances decay toward 0.
        scores = [1 / math.exp(dist) for dist in dists]
        return list(zip(list(docs), scores))
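
    # Worked example of the distance-to-relevance mapping used above:
    #
    #     >>> [1 / math.exp(d) for d in (0.0, 0.5, 1.0)]
    #     [1.0, 0.6065306597126334, 0.36787944117144233]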

    def max_marginal_relevance_search_by_vector(
        self,
        embedding: List[float],
        k: int = DEFAULT_K,
        fetch_k: int = DEFAULT_FETCH_K,
        lambda_mult: float = 0.5,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs selected using the maximal marginal relevance.

        Maximal marginal relevance optimizes for similarity to query AND diversity
        among selected documents.

        Args:
            embedding: Embedding to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
            lambda_mult: Number between 0 and 1 that determines the degree
                of diversity among the results with 0 corresponding
                to maximum diversity and 1 to minimum diversity.
                Defaults to 0.5.

        Returns:
            List of Documents selected by maximal marginal relevance.
        """
        indices_dists = self._similarity_index_search_with_score(
            embedding, k=fetch_k, **kwargs
        )
        indices, _ = zip(*indices_dists)
        # Select the rows of the cached embedding matrix for the fetched indices.
        result_embeddings = self._embeddings_np[indices,]
        mmr_selected = maximal_marginal_relevance(
            self._np.array(embedding, dtype=self._np.float32),
            result_embeddings,
            k=k,
            lambda_mult=lambda_mult,
        )
        mmr_indices = [indices[i] for i in mmr_selected]
        return [
            Document(
                page_content=self._texts[idx],
                metadata={"id": self._ids[idx], **self._metadatas[idx]},
            )
            for idx in mmr_indices
        ]

    def max_marginal_relevance_search(
        self,
        query: str,
        k: int = DEFAULT_K,
        fetch_k: int = DEFAULT_FETCH_K,
        lambda_mult: float = 0.5,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs selected using the maximal marginal relevance.

        Maximal marginal relevance optimizes for similarity to query AND diversity
        among selected documents.

        Args:
            query: Text to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
            lambda_mult: Number between 0 and 1 that determines the degree
                of diversity among the results with 0 corresponding
                to maximum diversity and 1 to minimum diversity.
                Defaults to 0.5.

        Returns:
            List of Documents selected by maximal marginal relevance.
        """
        if self._embedding_function is None:
            raise ValueError(
                "For MMR search, you must specify an embedding function on creation."
            )

        embedding = self._embedding_function.embed_query(query)
        docs = self.max_marginal_relevance_search_by_vector(
            embedding, k, fetch_k, lambda_mult=lambda_mult
        )
        return docs
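
    # MMR usage sketch (illustrative only; `store` as in the construction sketch
    # above): fetch the 20 nearest candidates, then pick 4 that balance query
    # similarity against mutual diversity.
    #
    #     docs = store.max_marginal_relevance_search(
    #         "query text", k=4, fetch_k=20, lambda_mult=0.5
    #     )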

    @classmethod
    def from_texts(
        cls,
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        persist_path: Optional[str] = None,
        **kwargs: Any,
    ) -> "SKLearnVectorStore":
        vs = SKLearnVectorStore(embedding, persist_path=persist_path, **kwargs)
        vs.add_texts(texts, metadatas=metadatas, ids=ids)
        return vs
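

# End-to-end sketch (illustrative only; FakeEmbeddings is an assumed stand-in
# for a real Embeddings implementation):
#
#     from langchain_community.embeddings import FakeEmbeddings
#
#     store = SKLearnVectorStore.from_texts(
#         ["dogs bark", "cats meow"],
#         FakeEmbeddings(size=16),
#         persist_path="/tmp/store.json",
#     )
#     print(store.similarity_search("bark", k=1))
#     store.persist()  # write the collection to /tmp/store.json
#
#     # A new instance pointed at the same path reloads the persisted data:
#     reloaded = SKLearnVectorStore(
#         FakeEmbeddings(size=16), persist_path="/tmp/store.json"
#     )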