mirror of https://github.com/hwchase17/langchain
Integrate Rockset as Vectorstore (#6216)
This PR adds Rockset as a vectorstore for langchain. [Rockset](https://rockset.com/blog/introducing-vector-search-on-rockset/) is a real time OLAP database which provides a fast and efficient vector search functionality. Further since it is entirely schemaless, it can store metadata in separate columns thereby allowing fast metadata filters during vector similarity search (as opposed to storing the entire metadata in a single JSON column). It currently supports three distance functions: `COSINE_SIMILARITY`, `EUCLIDEAN_DISTANCE`, and `DOT_PRODUCT`. This PR adds `rockset` client as an optional dependency. We would love a twitter shoutout, our handle is https://twitter.com/RocksetCloud --------- Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>octoml/master^2
parent
ab7ecc9c30
commit
94c7899257
@ -0,0 +1,327 @@
|
||||
"""Wrapper around Rockset vector database."""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import Any, Iterable, List, Optional, Tuple
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.vectorstores.base import VectorStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Rockset(VectorStore):
|
||||
"""Wrapper arpund Rockset vector database.
|
||||
|
||||
To use, you should have the `rockset` python package installed. Note that to use
|
||||
this, the collection being used must already exist in your Rockset instance.
|
||||
You must also ensure you use a Rockset ingest transformation to apply
|
||||
`VECTOR_ENFORCE` on the column being used to store `embedding_key` in the
|
||||
collection.
|
||||
See: https://rockset.com/blog/introducing-vector-search-on-rockset/ for more details
|
||||
|
||||
Everything below assumes `commons` Rockset workspace.
|
||||
TODO: Add support for workspace args.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.vectorstores import Rockset
|
||||
from langchain.embeddings.openai import OpenAIEmbeddings
|
||||
import rockset
|
||||
|
||||
# Make sure you use the right host (region) for your Rockset instance
|
||||
# and APIKEY has both read-write access to your collection.
|
||||
|
||||
rs = rockset.RocksetClient(host=rockset.Regions.use1a1, api_key="***")
|
||||
collection_name = "langchain_demo"
|
||||
embeddings = OpenAIEmbeddings()
|
||||
vectorstore = Rockset(rs, collection_name, embeddings,
|
||||
"description", "description_embedding")
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client: Any,
|
||||
embeddings: Embeddings,
|
||||
collection_name: str,
|
||||
text_key: str,
|
||||
embedding_key: str,
|
||||
):
|
||||
"""Initialize with Rockset client.
|
||||
Args:
|
||||
client: Rockset client object
|
||||
collection: Rockset collection to insert docs / query
|
||||
embeddings: Langchain Embeddings object to use to generate
|
||||
embedding for given text.
|
||||
text_key: column in Rockset collection to use to store the text
|
||||
embedding_key: column in Rockset collection to use to store the embedding.
|
||||
Note: We must apply `VECTOR_ENFORCE()` on this column via
|
||||
Rockset ingest transformation.
|
||||
|
||||
"""
|
||||
try:
|
||||
from rockset import RocksetClient
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import rockset client python package. "
|
||||
"Please install it with `pip install rockset`."
|
||||
)
|
||||
|
||||
if not isinstance(client, RocksetClient):
|
||||
raise ValueError(
|
||||
f"client should be an instance of rockset.RocksetClient, "
|
||||
f"got {type(client)}"
|
||||
)
|
||||
# TODO: check that `collection_name` exists in rockset. Create if not.
|
||||
self._client = client
|
||||
self._collection_name = collection_name
|
||||
self._embeddings = embeddings
|
||||
self._text_key = text_key
|
||||
self._embedding_key = embedding_key
|
||||
|
||||
def add_texts(
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
ids: Optional[List[str]] = None,
|
||||
batch_size: int = 32,
|
||||
**kwargs: Any,
|
||||
) -> List[str]:
|
||||
"""Run more texts through the embeddings and add to the vectorstore
|
||||
|
||||
Args:
|
||||
texts: Iterable of strings to add to the vectorstore.
|
||||
metadatas: Optional list of metadatas associated with the texts.
|
||||
ids: Optional list of ids to associate with the texts.
|
||||
batch_size: Send documents in batches to rockset.
|
||||
|
||||
Returns:
|
||||
List of ids from adding the texts into the vectorstore.
|
||||
|
||||
"""
|
||||
batch: list[dict] = []
|
||||
stored_ids = []
|
||||
|
||||
for i, text in enumerate(texts):
|
||||
if len(batch) == batch_size:
|
||||
stored_ids += self._write_documents_to_rockset(batch)
|
||||
batch = []
|
||||
doc = {}
|
||||
if metadatas and len(metadatas) > i:
|
||||
doc = metadatas[i]
|
||||
if ids and len(ids) > i:
|
||||
doc["_id"] = ids[i]
|
||||
doc[self._text_key] = text
|
||||
doc[self._embedding_key] = self._embeddings.embed_query(text)
|
||||
batch.append(doc)
|
||||
if len(batch) > 0:
|
||||
stored_ids += self._write_documents_to_rockset(batch)
|
||||
batch = []
|
||||
return stored_ids
|
||||
|
||||
@classmethod
|
||||
def from_texts(
|
||||
cls,
|
||||
texts: List[str],
|
||||
embedding: Embeddings,
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
client: Any = None,
|
||||
collection_name: str = "",
|
||||
text_key: str = "",
|
||||
embedding_key: str = "",
|
||||
ids: Optional[List[str]] = None,
|
||||
batch_size: int = 32,
|
||||
**kwargs: Any,
|
||||
) -> Rockset:
|
||||
"""Create Rockset wrapper with existing texts.
|
||||
This is intended as a quicker way to get started.
|
||||
"""
|
||||
|
||||
# Sanitize imputs
|
||||
assert client is not None, "Rockset Client cannot be None"
|
||||
assert collection_name, "Collection name cannot be empty"
|
||||
assert text_key, "Text key name cannot be empty"
|
||||
assert embedding_key, "Embedding key cannot be empty"
|
||||
|
||||
rockset = cls(client, embedding, collection_name, text_key, embedding_key)
|
||||
rockset.add_texts(texts, metadatas, ids, batch_size)
|
||||
return rockset
|
||||
|
||||
# Rockset supports these vector distance functions.
|
||||
class DistanceFunction(Enum):
|
||||
COSINE_SIM = "COSINE_SIM"
|
||||
EUCLIDEAN_DIST = "EUCLIDEAN_DIST"
|
||||
DOT_PRODUCT = "DOT_PRODUCT"
|
||||
|
||||
# how to sort results for "similarity"
|
||||
def order_by(self) -> str:
|
||||
if self.value == "EUCLIDEAN_DIST":
|
||||
return "ASC"
|
||||
return "DESC"
|
||||
|
||||
def similarity_search_with_relevance_scores(
|
||||
self,
|
||||
query: str,
|
||||
k: int = 4,
|
||||
distance_func: DistanceFunction = DistanceFunction.COSINE_SIM,
|
||||
where_str: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
"""Perform a similarity search with Rockset
|
||||
|
||||
Args:
|
||||
query (str): Text to look up documents similar to.
|
||||
distance_func (DistanceFunction): how to compute distance between two
|
||||
vectors in Rockset.
|
||||
k (int, optional): Top K neighbors to retrieve. Defaults to 4.
|
||||
where_str (Optional[str], optional): Metadata filters supplied as a
|
||||
SQL `where` condition string. Defaults to None.
|
||||
eg. "price<=70.0 AND brand='Nintendo'"
|
||||
|
||||
NOTE: Please do not let end-user to fill this and always be aware
|
||||
of SQL injection.
|
||||
|
||||
Returns:
|
||||
List[Tuple[Document, float]]: List of documents with their relevance score
|
||||
"""
|
||||
return self.similarity_search_by_vector_with_relevance_scores(
|
||||
self._embeddings.embed_query(query),
|
||||
k,
|
||||
distance_func,
|
||||
where_str,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def similarity_search(
|
||||
self,
|
||||
query: str,
|
||||
k: int = 4,
|
||||
distance_func: DistanceFunction = DistanceFunction.COSINE_SIM,
|
||||
where_str: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Same as `similarity_search_with_relevance_scores` but
|
||||
doesn't return the scores.
|
||||
"""
|
||||
return self.similarity_search_by_vector(
|
||||
self._embeddings.embed_query(query),
|
||||
k,
|
||||
distance_func,
|
||||
where_str,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def similarity_search_by_vector(
|
||||
self,
|
||||
embedding: List[float],
|
||||
k: int = 4,
|
||||
distance_func: DistanceFunction = DistanceFunction.COSINE_SIM,
|
||||
where_str: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Accepts a query_embedding (vector), and returns documents with
|
||||
similar embeddings."""
|
||||
|
||||
docs_and_scores = self.similarity_search_by_vector_with_relevance_scores(
|
||||
embedding, k, distance_func, where_str, **kwargs
|
||||
)
|
||||
return [doc for doc, _ in docs_and_scores]
|
||||
|
||||
def similarity_search_by_vector_with_relevance_scores(
|
||||
self,
|
||||
embedding: List[float],
|
||||
k: int = 4,
|
||||
distance_func: DistanceFunction = DistanceFunction.COSINE_SIM,
|
||||
where_str: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
"""Accepts a query_embedding (vector), and returns documents with
|
||||
similar embeddings along with their relevance scores."""
|
||||
|
||||
q_str = self._build_query_sql(embedding, distance_func, k, where_str)
|
||||
try:
|
||||
query_response = self._client.Queries.query(sql={"query": q_str})
|
||||
except Exception as e:
|
||||
logger.error("Exception when querying Rockset: %s\n", e)
|
||||
return []
|
||||
finalResult: list[Tuple[Document, float]] = []
|
||||
for document in query_response.results:
|
||||
metadata = {}
|
||||
assert isinstance(
|
||||
document, dict
|
||||
), "document should be of type `dict[str,Any]`. But found: `{}`".format(
|
||||
type(document)
|
||||
)
|
||||
for k, v in document.items():
|
||||
if k == self._text_key:
|
||||
assert isinstance(
|
||||
v, str
|
||||
), "page content stored in column `{}` must be of type `str`. \
|
||||
But found: `{}`".format(
|
||||
self._text_key, type(v)
|
||||
)
|
||||
page_content = v
|
||||
elif k == "dist":
|
||||
assert isinstance(
|
||||
v, float
|
||||
), "Computed distance between vectors must of type `float`. \
|
||||
But found {}".format(
|
||||
type(v)
|
||||
)
|
||||
score = v
|
||||
elif k not in ["_id", "_event_time", "_meta"]:
|
||||
# These columns are populated by Rockset when documents are
|
||||
# inserted. No need to return them in metadata dict.
|
||||
metadata[k] = v
|
||||
finalResult.append(
|
||||
(Document(page_content=page_content, metadata=metadata), score)
|
||||
)
|
||||
return finalResult
|
||||
|
||||
# Helper functions
|
||||
|
||||
def _build_query_sql(
|
||||
self,
|
||||
query_embedding: List[float],
|
||||
distance_func: DistanceFunction,
|
||||
k: int = 4,
|
||||
where_str: Optional[str] = None,
|
||||
) -> str:
|
||||
"""Builds Rockset SQL query to query similar vectors to query_vector"""
|
||||
|
||||
q_embedding_str = ",".join(map(str, query_embedding))
|
||||
distance_str = f"""{distance_func.value}({self._embedding_key}, \
|
||||
[{q_embedding_str}]) as dist"""
|
||||
where_str = f"WHERE {where_str}\n" if where_str else ""
|
||||
return f"""\
|
||||
SELECT * EXCEPT({self._embedding_key}), {distance_str}
|
||||
FROM {self._collection_name}
|
||||
{where_str}\
|
||||
ORDER BY dist {distance_func.order_by()}
|
||||
LIMIT {str(k)}
|
||||
"""
|
||||
|
||||
def _write_documents_to_rockset(self, batch: List[dict]) -> List[str]:
|
||||
add_doc_res = self._client.Documents.add_documents(
|
||||
collection=self._collection_name, data=batch
|
||||
)
|
||||
return [doc_status._id for doc_status in add_doc_res.data]
|
||||
|
||||
def delete_texts(self, ids: List[str]) -> None:
|
||||
"""Delete a list of docs from the Rockset collection"""
|
||||
try:
|
||||
from rockset.models import DeleteDocumentsRequestData
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import rockset client python package. "
|
||||
"Please install it with `pip install rockset`."
|
||||
)
|
||||
|
||||
self._client.Documents.delete_documents(
|
||||
collection=self._collection_name,
|
||||
data=[DeleteDocumentsRequestData(id=i) for i in ids],
|
||||
)
|
@ -0,0 +1,155 @@
|
||||
import logging
|
||||
import os
|
||||
|
||||
import rockset
|
||||
import rockset.models
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.vectorstores.rocksetdb import Rockset
|
||||
from tests.integration_tests.vectorstores.fake_embeddings import (
|
||||
ConsistentFakeEmbeddings,
|
||||
fake_texts,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# To run these tests, make sure you have a collection with the name `langchain_demo`
|
||||
# and the following ingest transformation:
|
||||
#
|
||||
# SELECT
|
||||
# _input.* EXCEPT(_meta),
|
||||
# VECTOR_ENFORCE(_input.description_embedding, 10, 'float') as
|
||||
# description_embedding
|
||||
# FROM
|
||||
# _input
|
||||
#
|
||||
# We're using FakeEmbeddings utility to create text embeddings.
|
||||
# It generates vector embeddings of length 10.
|
||||
#
|
||||
# Set env ROCKSET_DELETE_DOCS_ON_START=1 if you want to delete all docs from
|
||||
# the collection before running any test. Be careful, this will delete any
|
||||
# existing documents in your Rockset collection.
|
||||
#
|
||||
# See https://rockset.com/blog/introducing-vector-search-on-rockset/ for more details.
|
||||
|
||||
collection_name = "langchain_demo"
|
||||
text_key = "description"
|
||||
embedding_key = "description_embedding"
|
||||
|
||||
|
||||
class TestRockset:
|
||||
rockset_vectorstore: Rockset
|
||||
|
||||
@classmethod
|
||||
def setup_class(cls) -> None:
|
||||
assert os.environ.get("ROCKSET_API_KEY") is not None
|
||||
assert os.environ.get("ROCKSET_REGION") is not None
|
||||
|
||||
api_key = os.environ.get("ROCKSET_API_KEY")
|
||||
region = os.environ.get("ROCKSET_REGION")
|
||||
if region == "use1a1":
|
||||
host = rockset.Regions.use1a1
|
||||
elif region == "usw2a1":
|
||||
host = rockset.Regions.usw2a1
|
||||
elif region == "euc1a1":
|
||||
host = rockset.Regions.euc1a1
|
||||
elif region == "dev":
|
||||
host = rockset.DevRegions.usw2a1
|
||||
else:
|
||||
logger.warn(
|
||||
"Using ROCKSET_REGION:%s as it is.. \
|
||||
You should know what you're doing...",
|
||||
region,
|
||||
)
|
||||
|
||||
host = region
|
||||
|
||||
client = rockset.RocksetClient(host, api_key)
|
||||
if os.environ.get("ROCKSET_DELETE_DOCS_ON_START") == "1":
|
||||
logger.info(
|
||||
"Deleting all existing documents from the Rockset collection %s",
|
||||
collection_name,
|
||||
)
|
||||
|
||||
query_response = client.Queries.query(
|
||||
sql={"query": "select _id from {}".format(collection_name)}
|
||||
)
|
||||
ids = [
|
||||
str(r["_id"])
|
||||
for r in getattr(
|
||||
query_response, query_response.attribute_map["results"]
|
||||
)
|
||||
]
|
||||
logger.info("Existing ids in collection: %s", ids)
|
||||
client.Documents.delete_documents(
|
||||
collection=collection_name,
|
||||
data=[rockset.models.DeleteDocumentsRequestData(id=i) for i in ids],
|
||||
)
|
||||
|
||||
embeddings = ConsistentFakeEmbeddings()
|
||||
embeddings.embed_documents(fake_texts)
|
||||
cls.rockset_vectorstore = Rockset(
|
||||
client, embeddings, collection_name, text_key, embedding_key
|
||||
)
|
||||
|
||||
def test_rockset_insert_and_search(self) -> None:
|
||||
"""Test end to end vector search in Rockset"""
|
||||
|
||||
texts = ["foo", "bar", "baz"]
|
||||
metadatas = [{"metadata_index": i} for i in range(len(texts))]
|
||||
ids = self.rockset_vectorstore.add_texts(
|
||||
texts=texts,
|
||||
metadatas=metadatas,
|
||||
)
|
||||
assert len(ids) == len(texts)
|
||||
# Test that `foo` is closest to `foo`
|
||||
output = self.rockset_vectorstore.similarity_search(
|
||||
query="foo", distance_func=Rockset.DistanceFunction.COSINE_SIM, k=1
|
||||
)
|
||||
assert output == [Document(page_content="foo", metadata={"metadata_index": 0})]
|
||||
|
||||
# Find closest vector to `foo` which is not `foo`
|
||||
output = self.rockset_vectorstore.similarity_search(
|
||||
query="foo",
|
||||
distance_func=Rockset.DistanceFunction.COSINE_SIM,
|
||||
k=1,
|
||||
where_str="metadata_index != 0",
|
||||
)
|
||||
assert output == [Document(page_content="bar", metadata={"metadata_index": 1})]
|
||||
|
||||
def test_build_query_sql(self) -> None:
|
||||
vector = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
|
||||
q_str = self.rockset_vectorstore._build_query_sql(
|
||||
vector,
|
||||
Rockset.DistanceFunction.COSINE_SIM,
|
||||
4,
|
||||
)
|
||||
vector_str = ",".join(map(str, vector))
|
||||
expected = f"""\
|
||||
SELECT * EXCEPT(description_embedding), \
|
||||
COSINE_SIM(description_embedding, [{vector_str}]) as dist
|
||||
FROM langchain_demo
|
||||
ORDER BY dist DESC
|
||||
LIMIT 4
|
||||
"""
|
||||
assert q_str == expected
|
||||
|
||||
def test_build_query_sql_with_where(self) -> None:
|
||||
vector = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
|
||||
q_str = self.rockset_vectorstore._build_query_sql(
|
||||
vector,
|
||||
Rockset.DistanceFunction.COSINE_SIM,
|
||||
4,
|
||||
"age >= 10",
|
||||
)
|
||||
vector_str = ",".join(map(str, vector))
|
||||
expected = f"""\
|
||||
SELECT * EXCEPT(description_embedding), \
|
||||
COSINE_SIM(description_embedding, [{vector_str}]) as dist
|
||||
FROM langchain_demo
|
||||
WHERE age >= 10
|
||||
ORDER BY dist DESC
|
||||
LIMIT 4
|
||||
"""
|
||||
assert q_str == expected
|
Loading…
Reference in New Issue