langchain/libs/community/langchain_community/vectorstores/rocksetdb.py
morgana 074ad5095f
community[patch]: mmr search for Rockset vectorstore integration (#16908)
- **Description:** Adding support for mmr search in the Rockset
vectorstore integration.
  - **Issue:** N/A
  - **Dependencies:** N/A
  - **Twitter handle:** `@_morgan_adams_`

---------

Co-authored-by: Rockset API Bot <admin@rockset.io>
Co-authored-by: Bagatur <baskaryan@gmail.com>
Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
2024-03-29 18:45:22 +00:00

419 lines
15 KiB
Python

from __future__ import annotations
import logging
from copy import deepcopy
from enum import Enum
from typing import Any, Iterable, List, Optional, Tuple
import numpy as np
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.runnables import run_in_executor
from langchain_core.vectorstores import VectorStore
from langchain_community.vectorstores.utils import maximal_marginal_relevance
logger = logging.getLogger(__name__)
class Rockset(VectorStore):
"""`Rockset` vector store.
To use, you should have the `rockset` python package installed. Note that to use
this, the collection being used must already exist in your Rockset instance.
You must also ensure you use a Rockset ingest transformation to apply
`VECTOR_ENFORCE` on the column being used to store `embedding_key` in the
collection.
See: https://rockset.com/blog/introducing-vector-search-on-rockset/ for more details
Everything below assumes `commons` Rockset workspace.
Example:
.. code-block:: python
from langchain_community.vectorstores import Rockset
from langchain_community.embeddings.openai import OpenAIEmbeddings
import rockset
# Make sure you use the right host (region) for your Rockset instance
# and APIKEY has both read-write access to your collection.
rs = rockset.RocksetClient(host=rockset.Regions.use1a1, api_key="***")
collection_name = "langchain_demo"
embeddings = OpenAIEmbeddings()
vectorstore = Rockset(rs, collection_name, embeddings,
"description", "description_embedding")
"""
def __init__(
self,
client: Any,
embeddings: Embeddings,
collection_name: str,
text_key: str,
embedding_key: str,
workspace: str = "commons",
):
"""Initialize with Rockset client.
Args:
client: Rockset client object
collection: Rockset collection to insert docs / query
embeddings: Langchain Embeddings object to use to generate
embedding for given text.
text_key: column in Rockset collection to use to store the text
embedding_key: column in Rockset collection to use to store the embedding.
Note: We must apply `VECTOR_ENFORCE()` on this column via
Rockset ingest transformation.
"""
try:
from rockset import RocksetClient
except ImportError:
raise ImportError(
"Could not import rockset client python package. "
"Please install it with `pip install rockset`."
)
if not isinstance(client, RocksetClient):
raise ValueError(
f"client should be an instance of rockset.RocksetClient, "
f"got {type(client)}"
)
# TODO: check that `collection_name` exists in rockset. Create if not.
self._client = client
self._collection_name = collection_name
self._embeddings = embeddings
self._text_key = text_key
self._embedding_key = embedding_key
self._workspace = workspace
try:
self._client.set_application("langchain")
except AttributeError:
# ignore
pass
@property
def embeddings(self) -> Embeddings:
return self._embeddings
def add_texts(
self,
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
ids: Optional[List[str]] = None,
batch_size: int = 32,
**kwargs: Any,
) -> List[str]:
"""Run more texts through the embeddings and add to the vectorstore
Args:
texts: Iterable of strings to add to the vectorstore.
metadatas: Optional list of metadatas associated with the texts.
ids: Optional list of ids to associate with the texts.
batch_size: Send documents in batches to rockset.
Returns:
List of ids from adding the texts into the vectorstore.
"""
batch: list[dict] = []
stored_ids = []
for i, text in enumerate(texts):
if len(batch) == batch_size:
stored_ids += self._write_documents_to_rockset(batch)
batch = []
doc = {}
if metadatas and len(metadatas) > i:
doc = deepcopy(metadatas[i])
if ids and len(ids) > i:
doc["_id"] = ids[i]
doc[self._text_key] = text
doc[self._embedding_key] = self._embeddings.embed_query(text)
batch.append(doc)
if len(batch) > 0:
stored_ids += self._write_documents_to_rockset(batch)
batch = []
return stored_ids
@classmethod
def from_texts(
cls,
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
client: Any = None,
collection_name: str = "",
text_key: str = "",
embedding_key: str = "",
ids: Optional[List[str]] = None,
batch_size: int = 32,
**kwargs: Any,
) -> Rockset:
"""Create Rockset wrapper with existing texts.
This is intended as a quicker way to get started.
"""
# Sanitize inputs
assert client is not None, "Rockset Client cannot be None"
assert collection_name, "Collection name cannot be empty"
assert text_key, "Text key name cannot be empty"
assert embedding_key, "Embedding key cannot be empty"
rockset = cls(client, embedding, collection_name, text_key, embedding_key)
rockset.add_texts(texts, metadatas, ids, batch_size)
return rockset
# Rockset supports these vector distance functions.
class DistanceFunction(Enum):
COSINE_SIM = "COSINE_SIM"
EUCLIDEAN_DIST = "EUCLIDEAN_DIST"
DOT_PRODUCT = "DOT_PRODUCT"
# how to sort results for "similarity"
def order_by(self) -> str:
if self.value == "EUCLIDEAN_DIST":
return "ASC"
return "DESC"
def similarity_search_with_relevance_scores(
self,
query: str,
k: int = 4,
distance_func: DistanceFunction = DistanceFunction.COSINE_SIM,
where_str: Optional[str] = None,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
"""Perform a similarity search with Rockset
Args:
query (str): Text to look up documents similar to.
distance_func (DistanceFunction): how to compute distance between two
vectors in Rockset.
k (int, optional): Top K neighbors to retrieve. Defaults to 4.
where_str (Optional[str], optional): Metadata filters supplied as a
SQL `where` condition string. Defaults to None.
eg. "price<=70.0 AND brand='Nintendo'"
NOTE: Please do not let end-user to fill this and always be aware
of SQL injection.
Returns:
List[Tuple[Document, float]]: List of documents with their relevance score
"""
return self.similarity_search_by_vector_with_relevance_scores(
self._embeddings.embed_query(query),
k,
distance_func,
where_str,
**kwargs,
)
def similarity_search(
self,
query: str,
k: int = 4,
distance_func: DistanceFunction = DistanceFunction.COSINE_SIM,
where_str: Optional[str] = None,
**kwargs: Any,
) -> List[Document]:
"""Same as `similarity_search_with_relevance_scores` but
doesn't return the scores.
"""
return self.similarity_search_by_vector(
self._embeddings.embed_query(query),
k,
distance_func,
where_str,
**kwargs,
)
def similarity_search_by_vector(
self,
embedding: List[float],
k: int = 4,
distance_func: DistanceFunction = DistanceFunction.COSINE_SIM,
where_str: Optional[str] = None,
**kwargs: Any,
) -> List[Document]:
"""Accepts a query_embedding (vector), and returns documents with
similar embeddings."""
docs_and_scores = self.similarity_search_by_vector_with_relevance_scores(
embedding, k, distance_func, where_str, **kwargs
)
return [doc for doc, _ in docs_and_scores]
def similarity_search_by_vector_with_relevance_scores(
self,
embedding: List[float],
k: int = 4,
distance_func: DistanceFunction = DistanceFunction.COSINE_SIM,
where_str: Optional[str] = None,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
"""Accepts a query_embedding (vector), and returns documents with
similar embeddings along with their relevance scores."""
exclude_embeddings = True
if "exclude_embeddings" in kwargs:
exclude_embeddings = kwargs["exclude_embeddings"]
q_str = self._build_query_sql(
embedding, distance_func, k, where_str, exclude_embeddings
)
try:
query_response = self._client.Queries.query(sql={"query": q_str})
except Exception as e:
logger.error("Exception when querying Rockset: %s\n", e)
return []
finalResult: list[Tuple[Document, float]] = []
for document in query_response.results:
metadata = {}
assert isinstance(
document, dict
), "document should be of type `dict[str,Any]`. But found: `{}`".format(
type(document)
)
for k, v in document.items():
if k == self._text_key:
assert isinstance(v, str), (
"page content stored in column `{}` must be of type `str`. "
"But found: `{}`"
).format(self._text_key, type(v))
page_content = v
elif k == "dist":
assert isinstance(v, float), (
"Computed distance between vectors must of type `float`. "
"But found {}"
).format(type(v))
score = v
elif k not in ["_id", "_event_time", "_meta"]:
# These columns are populated by Rockset when documents are
# inserted. No need to return them in metadata dict.
metadata[k] = v
finalResult.append(
(Document(page_content=page_content, metadata=metadata), score)
)
return finalResult
def max_marginal_relevance_search(
self,
query: str,
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
*,
where_str: Optional[str] = None,
**kwargs: Any,
) -> List[Document]:
"""Return docs selected using the maximal marginal relevance.
Maximal marginal relevance optimizes for similarity to query AND diversity
among selected documents.
Args:
query: Text to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
fetch_k: Number of Documents to fetch to pass to MMR algorithm.
distance_func (DistanceFunction): how to compute distance between two
vectors in Rockset.
lambda_mult: Number between 0 and 1 that determines the degree
of diversity among the results with 0 corresponding
to maximum diversity and 1 to minimum diversity.
Defaults to 0.5.
where_str: where clause for the sql query
Returns:
List of Documents selected by maximal marginal relevance.
"""
query_embedding = self._embeddings.embed_query(query)
initial_docs = self.similarity_search_by_vector(
query_embedding,
k=fetch_k,
where_str=where_str,
exclude_embeddings=False,
**kwargs,
)
embeddings = [doc.metadata[self._embedding_key] for doc in initial_docs]
selected_indices = maximal_marginal_relevance(
np.array(query_embedding),
embeddings,
lambda_mult=lambda_mult,
k=k,
)
# remove embeddings key before returning for cleanup to be consistent with
# other search functions
for i in selected_indices:
del initial_docs[i].metadata[self._embedding_key]
return [initial_docs[i] for i in selected_indices]
# Helper functions
def _build_query_sql(
self,
query_embedding: List[float],
distance_func: DistanceFunction,
k: int = 4,
where_str: Optional[str] = None,
exclude_embeddings: bool = True,
) -> str:
"""Builds Rockset SQL query to query similar vectors to query_vector"""
q_embedding_str = ",".join(map(str, query_embedding))
distance_str = f"""{distance_func.value}({self._embedding_key}, \
[{q_embedding_str}]) as dist"""
where_str = f"WHERE {where_str}\n" if where_str else ""
select_embedding = (
f" EXCEPT({self._embedding_key})," if exclude_embeddings else ","
)
return f"""\
SELECT *{select_embedding} {distance_str}
FROM {self._workspace}.{self._collection_name}
{where_str}\
ORDER BY dist {distance_func.order_by()}
LIMIT {str(k)}
"""
def _write_documents_to_rockset(self, batch: List[dict]) -> List[str]:
add_doc_res = self._client.Documents.add_documents(
collection=self._collection_name, data=batch, workspace=self._workspace
)
return [doc_status._id for doc_status in add_doc_res.data]
def delete_texts(self, ids: List[str]) -> None:
"""Delete a list of docs from the Rockset collection"""
try:
from rockset.models import DeleteDocumentsRequestData
except ImportError:
raise ImportError(
"Could not import rockset client python package. "
"Please install it with `pip install rockset`."
)
self._client.Documents.delete_documents(
collection=self._collection_name,
data=[DeleteDocumentsRequestData(id=i) for i in ids],
workspace=self._workspace,
)
def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]:
try:
if ids is None:
ids = []
self.delete_texts(ids)
except Exception as e:
logger.error("Exception when deleting docs from Rockset: %s\n", e)
return False
return True
async def adelete(
self, ids: Optional[List[str]] = None, **kwargs: Any
) -> Optional[bool]:
return await run_in_executor(None, self.delete, ids, **kwargs)