Mirror of https://github.com/hwchase17/langchain, synced 2024-11-02 09:40:22 +00:00
074ad5095f
- **Description:** Adding support for MMR (maximal marginal relevance) search in the Rockset vectorstore integration.
- **Issue:** N/A
- **Dependencies:** N/A
- **Twitter handle:** `@_morgan_adams_`

---------

Co-authored-by: Rockset API Bot <admin@rockset.io>
Co-authored-by: Bagatur <baskaryan@gmail.com>
Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
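For context, a minimal sketch of how the new MMR search can be invoked. The query text below is made up, and `vectorstore` is assumed to be a `Rockset` vector store constructed as in the class docstring further down:

```python
# Hypothetical usage of the new MMR search; `vectorstore` is a Rockset
# vector store wired to an existing collection (see the class docstring).
docs = vectorstore.max_marginal_relevance_search(
    "family-friendly game consoles",  # query text, made up for illustration
    k=4,              # number of documents to return
    fetch_k=20,       # candidates fetched before MMR re-ranking
    lambda_mult=0.5,  # 0 = maximum diversity, 1 = minimum diversity
    where_str="price<=70.0 AND brand='Nintendo'",  # optional SQL filter
)
```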
419 lines · 15 KiB · Python
from __future__ import annotations

import logging
from copy import deepcopy
from enum import Enum
from typing import Any, Iterable, List, Optional, Tuple

import numpy as np
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.runnables import run_in_executor
from langchain_core.vectorstores import VectorStore

from langchain_community.vectorstores.utils import maximal_marginal_relevance

logger = logging.getLogger(__name__)


class Rockset(VectorStore):
    """`Rockset` vector store.

    To use, you should have the `rockset` python package installed. Note that the
    collection being used must already exist in your Rockset instance. You must
    also ensure that you use a Rockset ingest transformation to apply
    `VECTOR_ENFORCE` on the column used to store the embedding (`embedding_key`)
    in the collection; a rough sketch of such a transformation follows this
    docstring.
    See https://rockset.com/blog/introducing-vector-search-on-rockset/ for more details.

    Everything below assumes the `commons` Rockset workspace.

    Example:
        .. code-block:: python

            from langchain_community.vectorstores import Rockset
            from langchain_community.embeddings.openai import OpenAIEmbeddings
            import rockset

            # Make sure you use the right host (region) for your Rockset instance
            # and that your API key has read-write access to your collection.

            rs = rockset.RocksetClient(host=rockset.Regions.use1a1, api_key="***")
            collection_name = "langchain_demo"
            embeddings = OpenAIEmbeddings()
            vectorstore = Rockset(
                rs, embeddings, collection_name, "description", "description_embedding"
            )
    """

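    # A rough sketch (an assumption, not something this module sets up for you)
    # of the ingest transformation mentioned above, using the column name from
    # the docstring example and 1536 dimensions (the OpenAI embedding size);
    # adjust both to your collection:
    #
    #   SELECT _input.* EXCEPT(_meta),
    #       VECTOR_ENFORCE(_input.description_embedding, 1536, 'float') AS description_embedding
    #   FROM _input
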
    def __init__(
        self,
        client: Any,
        embeddings: Embeddings,
        collection_name: str,
        text_key: str,
        embedding_key: str,
        workspace: str = "commons",
    ):
        """Initialize with Rockset client.

        Args:
            client: Rockset client object.
            embeddings: LangChain Embeddings object used to generate
                embeddings for the given text.
            collection_name: Rockset collection to insert docs into / query.
            text_key: column in the Rockset collection used to store the text.
            embedding_key: column in the Rockset collection used to store the
                embedding.
                Note: We must apply `VECTOR_ENFORCE()` on this column via a
                Rockset ingest transformation.
        """
        try:
            from rockset import RocksetClient
        except ImportError:
            raise ImportError(
                "Could not import rockset client python package. "
                "Please install it with `pip install rockset`."
            )

        if not isinstance(client, RocksetClient):
            raise ValueError(
                f"client should be an instance of rockset.RocksetClient, "
                f"got {type(client)}"
            )
        # TODO: check that `collection_name` exists in rockset. Create if not.
        self._client = client
        self._collection_name = collection_name
        self._embeddings = embeddings
        self._text_key = text_key
        self._embedding_key = embedding_key
        self._workspace = workspace

        try:
            self._client.set_application("langchain")
        except AttributeError:
            # Older rockset clients do not expose set_application; ignore.
            pass

    @property
    def embeddings(self) -> Embeddings:
        return self._embeddings

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        batch_size: int = 32,
        **kwargs: Any,
    ) -> List[str]:
        """Run more texts through the embeddings and add them to the vectorstore.

        Args:
            texts: Iterable of strings to add to the vectorstore.
            metadatas: Optional list of metadatas associated with the texts.
            ids: Optional list of ids to associate with the texts.
            batch_size: Send documents to Rockset in batches of this size.

        Returns:
            List of ids from adding the texts into the vectorstore.
        """
        batch: list[dict] = []
        stored_ids = []

        for i, text in enumerate(texts):
            if len(batch) == batch_size:
                stored_ids += self._write_documents_to_rockset(batch)
                batch = []
            doc = {}
            if metadatas and len(metadatas) > i:
                doc = deepcopy(metadatas[i])
            if ids and len(ids) > i:
                doc["_id"] = ids[i]
            doc[self._text_key] = text
            doc[self._embedding_key] = self._embeddings.embed_query(text)
            batch.append(doc)
        if len(batch) > 0:
            stored_ids += self._write_documents_to_rockset(batch)
            batch = []
        return stored_ids

    @classmethod
    def from_texts(
        cls,
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        client: Any = None,
        collection_name: str = "",
        text_key: str = "",
        embedding_key: str = "",
        ids: Optional[List[str]] = None,
        batch_size: int = 32,
        **kwargs: Any,
    ) -> Rockset:
        """Create a Rockset wrapper with existing texts.

        This is intended as a quicker way to get started.
        """

        # Sanitize inputs
        assert client is not None, "Rockset Client cannot be None"
        assert collection_name, "Collection name cannot be empty"
        assert text_key, "Text key name cannot be empty"
        assert embedding_key, "Embedding key cannot be empty"

        rockset = cls(client, embedding, collection_name, text_key, embedding_key)
        rockset.add_texts(texts, metadatas, ids, batch_size)
        return rockset

    # Rockset supports these vector distance functions.
    class DistanceFunction(Enum):
        COSINE_SIM = "COSINE_SIM"
        EUCLIDEAN_DIST = "EUCLIDEAN_DIST"
        DOT_PRODUCT = "DOT_PRODUCT"

        # How to sort results for "similarity".
        def order_by(self) -> str:
            if self.value == "EUCLIDEAN_DIST":
                return "ASC"
            return "DESC"

    def similarity_search_with_relevance_scores(
        self,
        query: str,
        k: int = 4,
        distance_func: DistanceFunction = DistanceFunction.COSINE_SIM,
        where_str: Optional[str] = None,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Perform a similarity search with Rockset.

        Args:
            query (str): Text to look up documents similar to.
            distance_func (DistanceFunction): how to compute the distance between
                two vectors in Rockset.
            k (int, optional): Top K neighbors to retrieve. Defaults to 4.
            where_str (Optional[str], optional): Metadata filters supplied as a
                SQL `where` condition string. Defaults to None.
                e.g. "price<=70.0 AND brand='Nintendo'"

                NOTE: Please do not let end users fill this in, and always be
                aware of SQL injection.

        Returns:
            List[Tuple[Document, float]]: List of documents with their relevance score.
        """
        return self.similarity_search_by_vector_with_relevance_scores(
            self._embeddings.embed_query(query),
            k,
            distance_func,
            where_str,
            **kwargs,
        )

    def similarity_search(
        self,
        query: str,
        k: int = 4,
        distance_func: DistanceFunction = DistanceFunction.COSINE_SIM,
        where_str: Optional[str] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Same as `similarity_search_with_relevance_scores` but
        doesn't return the scores.
        """
        return self.similarity_search_by_vector(
            self._embeddings.embed_query(query),
            k,
            distance_func,
            where_str,
            **kwargs,
        )

    def similarity_search_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        distance_func: DistanceFunction = DistanceFunction.COSINE_SIM,
        where_str: Optional[str] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Accepts a query embedding (vector), and returns documents with
        similar embeddings."""

        docs_and_scores = self.similarity_search_by_vector_with_relevance_scores(
            embedding, k, distance_func, where_str, **kwargs
        )
        return [doc for doc, _ in docs_and_scores]

    def similarity_search_by_vector_with_relevance_scores(
        self,
        embedding: List[float],
        k: int = 4,
        distance_func: DistanceFunction = DistanceFunction.COSINE_SIM,
        where_str: Optional[str] = None,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Accepts a query embedding (vector), and returns documents with
        similar embeddings along with their relevance scores."""

        exclude_embeddings = True
        if "exclude_embeddings" in kwargs:
            exclude_embeddings = kwargs["exclude_embeddings"]

        q_str = self._build_query_sql(
            embedding, distance_func, k, where_str, exclude_embeddings
        )
        try:
            query_response = self._client.Queries.query(sql={"query": q_str})
        except Exception as e:
            logger.error("Exception when querying Rockset: %s\n", e)
            return []
        final_result: list[Tuple[Document, float]] = []
        for document in query_response.results:
            metadata = {}
            assert isinstance(
                document, dict
            ), "document should be of type `dict[str,Any]`. But found: `{}`".format(
                type(document)
            )
            for key, value in document.items():
                if key == self._text_key:
                    assert isinstance(value, str), (
                        "page content stored in column `{}` must be of type `str`. "
                        "But found: `{}`"
                    ).format(self._text_key, type(value))
                    page_content = value
                elif key == "dist":
                    assert isinstance(value, float), (
                        "Computed distance between vectors must be of type `float`. "
                        "But found {}"
                    ).format(type(value))
                    score = value
                elif key not in ["_id", "_event_time", "_meta"]:
                    # These columns are populated by Rockset when documents are
                    # inserted. No need to return them in the metadata dict.
                    metadata[key] = value
            final_result.append(
                (Document(page_content=page_content, metadata=metadata), score)
            )
        return final_result

    def max_marginal_relevance_search(
        self,
        query: str,
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        *,
        where_str: Optional[str] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs selected using the maximal marginal relevance.

        Maximal marginal relevance optimizes for similarity to the query AND
        diversity among the selected documents.

        Args:
            query: Text to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            fetch_k: Number of Documents to fetch to pass to the MMR algorithm.
            distance_func (DistanceFunction): how to compute the distance between
                two vectors in Rockset.
            lambda_mult: Number between 0 and 1 that determines the degree
                of diversity among the results, with 0 corresponding
                to maximum diversity and 1 to minimum diversity.
                Defaults to 0.5.
            where_str: where clause for the SQL query.
        Returns:
            List of Documents selected by maximal marginal relevance.
        """
        query_embedding = self._embeddings.embed_query(query)
        initial_docs = self.similarity_search_by_vector(
            query_embedding,
            k=fetch_k,
            where_str=where_str,
            exclude_embeddings=False,
            **kwargs,
        )

        embeddings = [doc.metadata[self._embedding_key] for doc in initial_docs]

        selected_indices = maximal_marginal_relevance(
            np.array(query_embedding),
            embeddings,
            lambda_mult=lambda_mult,
            k=k,
        )

        # Remove the embedding key before returning, to be consistent with the
        # other search functions.
        for i in selected_indices:
            del initial_docs[i].metadata[self._embedding_key]

        return [initial_docs[i] for i in selected_indices]

    # Helper functions

    def _build_query_sql(
        self,
        query_embedding: List[float],
        distance_func: DistanceFunction,
        k: int = 4,
        where_str: Optional[str] = None,
        exclude_embeddings: bool = True,
    ) -> str:
        """Builds the Rockset SQL query that finds vectors similar to query_embedding."""

        q_embedding_str = ",".join(map(str, query_embedding))
        distance_str = f"""{distance_func.value}({self._embedding_key}, \
[{q_embedding_str}]) as dist"""
        where_str = f"WHERE {where_str}\n" if where_str else ""
        select_embedding = (
            f" EXCEPT({self._embedding_key})," if exclude_embeddings else ","
        )
        return f"""\
SELECT *{select_embedding} {distance_str}
FROM {self._workspace}.{self._collection_name}
{where_str}\
ORDER BY dist {distance_func.order_by()}
LIMIT {str(k)}
"""

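    # For illustration only: with a 3-dimensional embedding, COSINE_SIM, k=4,
    # no where clause, exclude_embeddings=True, and the column/collection names
    # from the class docstring example, the SQL built above looks roughly like:
    #
    #   SELECT * EXCEPT(description_embedding), COSINE_SIM(description_embedding, [0.1,0.2,0.3]) as dist
    #   FROM commons.langchain_demo
    #   ORDER BY dist DESC
    #   LIMIT 4
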
    def _write_documents_to_rockset(self, batch: List[dict]) -> List[str]:
        add_doc_res = self._client.Documents.add_documents(
            collection=self._collection_name, data=batch, workspace=self._workspace
        )
        return [doc_status._id for doc_status in add_doc_res.data]

    def delete_texts(self, ids: List[str]) -> None:
        """Delete a list of docs from the Rockset collection."""
        try:
            from rockset.models import DeleteDocumentsRequestData
        except ImportError:
            raise ImportError(
                "Could not import rockset client python package. "
                "Please install it with `pip install rockset`."
            )

        self._client.Documents.delete_documents(
            collection=self._collection_name,
            data=[DeleteDocumentsRequestData(id=i) for i in ids],
            workspace=self._workspace,
        )

    def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]:
        try:
            if ids is None:
                ids = []
            self.delete_texts(ids)
        except Exception as e:
            logger.error("Exception when deleting docs from Rockset: %s\n", e)
            return False

        return True

    async def adelete(
        self, ids: Optional[List[str]] = None, **kwargs: Any
    ) -> Optional[bool]:
        return await run_in_executor(None, self.delete, ids, **kwargs)
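
# A minimal end-to-end sketch of the write/delete APIs above (illustrative only;
# assumes `vectorstore` was constructed as in the class docstring):
#
#   ids = vectorstore.add_texts(["hello rockset", "hello langchain"])
#   vectorstore.delete(ids)      # returns True on success, False if the call failed
#   # or, from async code:
#   # await vectorstore.adelete(ids)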