mirror of https://github.com/hwchase17/langchain
Add SKLearnVectorStore (#5305)
# Add SKLearnVectorStore This PR adds SKLearnVectorStore, a simple vector store based on NearestNeighbors implementations in the scikit-learn package. This provides a simple drop-in vector store implementation with minimal dependencies (scikit-learn is typically installed in a data scientist / ml engineer environment). The vector store can be persisted and loaded from json, bson and parquet format. SKLearnVectorStore has a soft (dynamic) dependency on the scikit-learn, numpy and pandas packages. Persisting to bson requires the bson package, persisting to parquet requires the pyarrow package. ## Before submitting Integration tests are provided under `tests/integration_tests/vectorstores/test_sklearn.py` Sample usage notebook is provided under `docs/modules/indexes/vectorstores/examples/sklearn.ipynb` Co-authored-by: Dev 2049 <dev.dev2049@gmail.com> pull/5364/head
parent
e2742953a6
commit
5f4552391f
@ -0,0 +1,23 @@
|
||||
# scikit-learn
|
||||
|
||||
This page covers how to use the scikit-learn package within LangChain.
|
||||
It is broken into two parts: installation and setup, and then references to specific scikit-learn wrappers.
|
||||
|
||||
## Installation and Setup
|
||||
|
||||
- Install the Python package with `pip install scikit-learn`
|
||||
|
||||
## Wrappers
|
||||
|
||||
### VectorStore
|
||||
|
||||
`SKLearnVectorStore` provides a simple wrapper around the nearest neighbor implementation in the
|
||||
scikit-learn package, allowing you to use it as a vectorstore.
|
||||
|
||||
To import this vectorstore:
|
||||
|
||||
```python
|
||||
from langchain.vectorstores import SKLearnVectorStore
|
||||
```
|
||||
|
||||
For a more detailed walkthrough of the SKLearnVectorStore wrapper, see [this notebook](../modules/indexes/vectorstores/examples/sklearn.ipynb).
|
@ -0,0 +1,271 @@
|
||||
""" Wrapper around scikit-learn NearestNeighbors implementation.
|
||||
|
||||
The vector store can be persisted in json, bson or parquet format.
|
||||
"""
|
||||
|
||||
import importlib
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple, Type
|
||||
from uuid import uuid4
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.vectorstores.base import VectorStore
|
||||
|
||||
|
||||
def guard_import(
    module_name: str, *, pip_name: Optional[str] = None, package: Optional[str] = None
) -> Any:
    """Dynamically import a module, raising a helpful error if it is missing.

    Args:
        module_name: Name of the module to import, as accepted by
            ``importlib.import_module``.
        pip_name: Name to suggest in the ``pip install`` hint. Falls back to
            ``module_name`` when not given.
        package: Anchor package forwarded to ``importlib.import_module``
            (only relevant for relative module names).

    Returns:
        The imported module object.

    Raises:
        ImportError: If the module cannot be imported. The original error is
            attached as the explicit cause.
    """
    try:
        module = importlib.import_module(module_name, package)
    except ImportError as err:
        # Chain explicitly so the traceback shows the underlying failure as the
        # direct cause instead of an unrelated-looking "during handling" error.
        raise ImportError(
            f"Could not import {module_name} python package. "
            f"Please install it with `pip install {pip_name or module_name}`."
        ) from err
    return module
|
||||
|
||||
|
||||
class BaseSerializer(ABC):
    """Common interface for objects that store data at, and read it back
    from, a fixed path."""

    def __init__(self, persist_path: str) -> None:
        # Location that concrete save()/load() implementations operate on.
        self.persist_path = persist_path

    @classmethod
    @abstractmethod
    def extension(cls) -> str:
        """Return the file extension suggested for this format, without the dot."""

    @abstractmethod
    def save(self, data: Any) -> None:
        """Write ``data`` to ``persist_path``."""

    @abstractmethod
    def load(self) -> Any:
        """Read and return the data stored at ``persist_path``."""
|
||||
|
||||
|
||||
class JsonSerializer(BaseSerializer):
    """Stores data as JSON text using the standard-library ``json`` module."""

    @classmethod
    def extension(cls) -> str:
        return "json"

    def save(self, data: Any) -> None:
        """Dump ``data`` as JSON into the file at ``persist_path``."""
        with open(self.persist_path, "w") as handle:
            json.dump(data, handle)

    def load(self) -> Any:
        """Parse the JSON file at ``persist_path`` and return its content."""
        with open(self.persist_path, "r") as handle:
            content = json.load(handle)
        return content
|
||||
|
||||
|
||||
class BsonSerializer(BaseSerializer):
    """Stores data as binary JSON using the ``bson`` python package."""

    def __init__(self, persist_path: str) -> None:
        super().__init__(persist_path)
        # Imported lazily so that bson is only required when this format is used.
        # NOTE(review): relies on the module exposing `dumps`/`loads`; confirm
        # which `bson` distribution is expected.
        self.bson = guard_import("bson")

    @classmethod
    def extension(cls) -> str:
        return "bson"

    def save(self, data: Any) -> None:
        """Encode ``data`` with bson and write the bytes to ``persist_path``."""
        with open(self.persist_path, "wb") as handle:
            encoded = self.bson.dumps(data)
            handle.write(encoded)

    def load(self) -> Any:
        """Read the bytes at ``persist_path`` and decode them with bson."""
        with open(self.persist_path, "rb") as handle:
            raw = handle.read()
            return self.bson.loads(raw)
|
||||
|
||||
|
||||
class ParquetSerializer(BaseSerializer):
    """Serializes data in Apache Parquet format using the pyarrow package."""

    def __init__(self, persist_path: str) -> None:
        super().__init__(persist_path)
        # Imported lazily so pandas/pyarrow are only required for this format.
        self.pd = guard_import("pandas")
        self.pa = guard_import("pyarrow")
        self.pq = guard_import("pyarrow.parquet")

    @classmethod
    def extension(cls) -> str:
        return "parquet"

    def save(self, data: Any) -> None:
        """Write ``data`` to ``persist_path`` as a Parquet table.

        ``data`` is turned into a pandas DataFrame and then a pyarrow table.
        If a file already exists at ``persist_path`` it is moved aside to
        ``<persist_path>-backup`` while the new file is written, restored if
        writing fails, and deleted once writing succeeds.

        Raises:
            Exception: Whatever the parquet writer raised; the previous file is
                put back in place before the error propagates.
        """
        df = self.pd.DataFrame(data)
        table = self.pa.Table.from_pandas(df)
        if not os.path.exists(self.persist_path):
            self.pq.write_table(table, self.persist_path)
            return

        backup_path = str(self.persist_path) + "-backup"
        # os.replace overwrites an existing destination on every platform,
        # whereas os.rename fails on Windows if e.g. a stale backup is present.
        os.replace(self.persist_path, backup_path)
        try:
            self.pq.write_table(table, self.persist_path)
        except Exception:
            # A failed write may have left a partial file behind; replace it
            # with the backup so the previous state is restored.
            os.replace(backup_path, self.persist_path)
            raise
        else:
            os.remove(backup_path)

    def load(self) -> Any:
        """Read the Parquet file and return a dict mapping each column name
        to the list of that column's values."""
        table = self.pq.read_table(self.persist_path)
        df = table.to_pandas()
        return {col: series.tolist() for col, series in df.items()}
|
||||
|
||||
|
||||
# Registry of the available serializer classes, keyed by the format name that
# callers pass as `serializer` (identical to each class's file extension).
SERIALIZER_MAP: Dict[str, Type[BaseSerializer]] = {
    serializer_cls.extension(): serializer_cls
    for serializer_cls in (JsonSerializer, BsonSerializer, ParquetSerializer)
}
|
||||
|
||||
|
||||
class SKLearnVectorStoreException(RuntimeError):
    """Error raised by SKLearnVectorStore, e.g. when the store holds no data
    or was created without a persist_path."""
|
||||
|
||||
|
||||
class SKLearnVectorStore(VectorStore):
    """A simple in-memory vector store based on the scikit-learn library
    NearestNeighbors implementation.

    Ids, texts, metadata and embeddings are kept in parallel lists that share
    the same index. When a ``persist_path`` is given, the store can be written
    out with :meth:`persist` and is loaded on construction if a file already
    exists at that path.
    """

    def __init__(
        self,
        embedding: Embeddings,
        *,
        persist_path: Optional[str] = None,
        serializer: Literal["json", "bson", "parquet"] = "json",
        metric: str = "cosine",
        **kwargs: Any,
    ) -> None:
        """Create the store and load persisted data if it is available.

        Args:
            embedding: Embedding function used for documents and queries.
            persist_path: Optional file path used by :meth:`persist` and for
                loading previously persisted data.
            serializer: Key into ``SERIALIZER_MAP`` selecting the file format.
                Only used when ``persist_path`` is given.
            metric: Distance metric handed to ``NearestNeighbors``.
            **kwargs: Extra keyword arguments forwarded to ``NearestNeighbors``.
        """
        np = guard_import("numpy")
        sklearn_neighbors = guard_import("sklearn.neighbors", pip_name="scikit-learn")

        # non-persistent properties
        self._np = np
        self._neighbors = sklearn_neighbors.NearestNeighbors(metric=metric, **kwargs)
        self._neighbors_fitted = False
        self._embedding_function = embedding
        self._persist_path = persist_path
        self._serializer: Optional[BaseSerializer] = None
        if self._persist_path is not None:
            serializer_cls = SERIALIZER_MAP[serializer]
            self._serializer = serializer_cls(persist_path=self._persist_path)

        # data properties
        self._embeddings: List[List[float]] = []
        self._texts: List[str] = []
        self._metadatas: List[dict] = []
        self._ids: List[str] = []

        # cache properties
        self._embeddings_np: Any = np.asarray([])

        if self._persist_path is not None and os.path.isfile(self._persist_path):
            self._load()

    def persist(self) -> None:
        """Save ids, texts, metadata and embeddings through the serializer.

        Raises:
            SKLearnVectorStoreException: If the store was created without a
                ``persist_path``.
        """
        if self._serializer is None:
            raise SKLearnVectorStoreException(
                "You must specify a persist_path on creation to persist the "
                "collection."
            )
        data = {
            "ids": self._ids,
            "texts": self._texts,
            "metadatas": self._metadatas,
            "embeddings": self._embeddings,
        }
        self._serializer.save(data)

    def _load(self) -> None:
        """Replace the in-memory data with the serializer's content and refit
        the neighbors index.

        Raises:
            SKLearnVectorStoreException: If the store was created without a
                ``persist_path``, or the loaded data holds no embeddings.
        """
        if self._serializer is None:
            raise SKLearnVectorStoreException(
                "You must specify a persist_path on creation to load the "
                "collection."
            )
        data = self._serializer.load()
        self._embeddings = data["embeddings"]
        self._texts = data["texts"]
        self._metadatas = data["metadatas"]
        self._ids = data["ids"]
        self._update_neighbors()

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Embed ``texts``, add them to the store and refit the index.

        Args:
            texts: Texts to add.
            metadatas: Optional metadata dicts, one per text. Empty dicts are
                used when omitted.
            ids: Optional ids, one per text. Random UUID4 strings are
                generated when omitted.

        Returns:
            The ids of the added texts.

        Raises:
            ValueError: If ``ids`` or ``metadatas`` is given with a length
                different from the number of texts.
            SKLearnVectorStoreException: If the store is still empty afterwards.
        """
        _texts = list(texts)
        _ids = ids or [str(uuid4()) for _ in _texts]
        # One fresh dict per text: `[{}] * n` would make every entry alias the
        # same dict, so mutating one document's metadata would change them all.
        _metadatas = metadatas or [{} for _ in _texts]
        if len(_ids) != len(_texts) or len(_metadatas) != len(_texts):
            raise ValueError(
                "texts, metadatas and ids must all have the same length, got "
                f"{len(_texts)} texts, {len(_metadatas)} metadatas and "
                f"{len(_ids)} ids."
            )
        # Embed before touching any state so that a failing embedding call
        # cannot leave the parallel lists with different lengths.
        _embeddings = self._embedding_function.embed_documents(_texts)

        self._texts.extend(_texts)
        self._embeddings.extend(_embeddings)
        self._metadatas.extend(_metadatas)
        self._ids.extend(_ids)
        self._update_neighbors()
        return _ids

    def _update_neighbors(self) -> None:
        """Refit the NearestNeighbors index on all stored embeddings.

        Raises:
            SKLearnVectorStoreException: If there are no embeddings to fit.
        """
        if len(self._embeddings) == 0:
            raise SKLearnVectorStoreException(
                "No data was added to SKLearnVectorStore."
            )
        self._embeddings_np = self._np.asarray(self._embeddings)
        self._neighbors.fit(self._embeddings_np)
        self._neighbors_fitted = True

    def similarity_search_with_score(
        self, query: str, *, k: int = 4, **kwargs: Any
    ) -> List[Tuple[Document, float]]:
        """Return the stored documents nearest to ``query`` with their distances.

        Args:
            query: Text to embed and search for.
            k: Maximum number of documents to return.

        Returns:
            ``(document, distance)`` pairs in the order returned by
            ``kneighbors``. Each document's metadata also carries its ``id``.

        Raises:
            SKLearnVectorStoreException: If no data has been added yet.
        """
        if not self._neighbors_fitted:
            raise SKLearnVectorStoreException(
                "No data was added to SKLearnVectorStore."
            )
        query_embedding = self._embedding_function.embed_query(query)
        # kneighbors rejects requests for more neighbors than fitted samples,
        # so cap k (the default of 4 would otherwise fail on small stores).
        n_neighbors = min(k, len(self._embeddings))
        neigh_dists, neigh_idxs = self._neighbors.kneighbors(
            [query_embedding], n_neighbors=n_neighbors
        )
        res = []
        for idx, dist in zip(neigh_idxs[0], neigh_dists[0]):
            _idx = int(idx)
            metadata = {"id": self._ids[_idx], **self._metadatas[_idx]}
            doc = Document(page_content=self._texts[_idx], metadata=metadata)
            res.append((doc, dist))
        return res

    def similarity_search(
        self, query: str, k: int = 4, **kwargs: Any
    ) -> List[Document]:
        """Return the documents nearest to ``query``, without their distances."""
        docs_scores = self.similarity_search_with_score(query, k=k, **kwargs)
        return [doc for doc, _ in docs_scores]

    def _similarity_search_with_relevance_scores(
        self, query: str, k: int = 4, **kwargs: Any
    ) -> List[Tuple[Document, float]]:
        """Return the nearest documents with a score of ``exp(-distance)``,
        so a distance of 0 maps to a score of 1."""
        docs_dists = self.similarity_search_with_score(query=query, k=k, **kwargs)
        # exp(-d) equals 1 / exp(d) but underflows to 0.0 for very large
        # distances instead of raising OverflowError.
        return [(doc, math.exp(-dist)) for doc, dist in docs_dists]

    @classmethod
    def from_texts(
        cls,
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        persist_path: Optional[str] = None,
        **kwargs: Any,
    ) -> "SKLearnVectorStore":
        """Create a store and add ``texts`` to it.

        Remaining keyword arguments (e.g. ``serializer``, ``metric``) are
        forwarded to the constructor.
        """
        # Instantiate through cls so that subclasses get their own type back.
        vs = cls(embedding, persist_path=persist_path, **kwargs)
        vs.add_texts(texts, metadatas=metadatas, ids=ids)
        return vs
|
@ -0,0 +1,76 @@
|
||||
"""Test SKLearnVectorStore functionality."""
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.vectorstores import SKLearnVectorStore
|
||||
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
|
||||
|
||||
|
||||
@pytest.mark.requires("numpy", "sklearn")
def test_sklearn() -> None:
    """Build a store from three texts and check the nearest hit for "foo"."""
    store = SKLearnVectorStore.from_texts(
        ["foo", "bar", "baz"], embedding=FakeEmbeddings()
    )
    hits = store.similarity_search("foo", k=1)
    assert len(hits) == 1
    assert hits[0].page_content == "foo"
|
||||
|
||||
|
||||
@pytest.mark.requires("numpy", "sklearn")
def test_sklearn_with_metadatas() -> None:
    """Build a store with per-text metadata and check it comes back on search."""
    corpus = ["foo", "bar", "baz"]
    page_metadata = [{"page": str(position)} for position, _ in enumerate(corpus)]
    store = SKLearnVectorStore.from_texts(
        corpus, embedding=FakeEmbeddings(), metadatas=page_metadata
    )
    hits = store.similarity_search("foo", k=1)
    assert hits[0].metadata["page"] == "0"
|
||||
|
||||
|
||||
@pytest.mark.requires("numpy", "sklearn")
def test_sklearn_with_metadatas_with_scores() -> None:
    """Build a store with metadata and check the relevance-scored search."""
    corpus = ["foo", "bar", "baz"]
    page_metadata = [{"page": str(position)} for position, _ in enumerate(corpus)]
    store = SKLearnVectorStore.from_texts(
        corpus, embedding=FakeEmbeddings(), metadatas=page_metadata
    )
    scored_hits = store.similarity_search_with_relevance_scores("foo", k=1)
    assert len(scored_hits) == 1
    top_doc, top_score = scored_hits[0]
    assert top_doc.page_content == "foo"
    assert top_doc.metadata["page"] == "0"
    assert top_score == 1
|
||||
|
||||
|
||||
@pytest.mark.requires("numpy", "sklearn")
def test_sklearn_with_persistence(tmp_path: Path) -> None:
    """Test end to end construction and search, with persistence."""
    # tmp_path is pytest's pathlib-based temp dir fixture, matching the Path
    # annotation; the extension matches the json serializer used below.
    persist_path = tmp_path / "foo.json"
    texts = ["foo", "bar", "baz"]
    docsearch = SKLearnVectorStore.from_texts(
        texts,
        FakeEmbeddings(),
        persist_path=str(persist_path),
        serializer="json",
    )

    output = docsearch.similarity_search("foo", k=1)
    assert len(output) == 1
    assert output[0].page_content == "foo"

    docsearch.persist()

    # Get a new VectorStore from the persisted file
    docsearch = SKLearnVectorStore(
        embedding=FakeEmbeddings(), persist_path=str(persist_path), serializer="json"
    )
    output = docsearch.similarity_search("foo", k=1)
    assert len(output) == 1
    assert output[0].page_content == "foo"
|
Loading…
Reference in New Issue