community[patch]: mmr search for Rockset vectorstore integration (#16908)

- **Description:** Adds support for maximal marginal relevance (MMR) search
to the Rockset vectorstore integration.
  - **Issue:** N/A
  - **Dependencies:** N/A
  - **Twitter handle:** `@_morgan_adams_`

---------

Co-authored-by: Rockset API Bot <admin@rockset.io>
Co-authored-by: Bagatur <baskaryan@gmail.com>
Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
pull/16805/head^2
morgana 6 months ago committed by GitHub
parent f51e6a35ba
commit 074ad5095f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -5,11 +5,14 @@ from copy import deepcopy
from enum import Enum
from typing import Any, Iterable, List, Optional, Tuple
import numpy as np
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.runnables import run_in_executor
from langchain_core.vectorstores import VectorStore
from langchain_community.vectorstores.utils import maximal_marginal_relevance
logger = logging.getLogger(__name__)
@ -254,7 +257,12 @@ class Rockset(VectorStore):
"""Accepts a query_embedding (vector), and returns documents with
similar embeddings along with their relevance scores."""
q_str = self._build_query_sql(embedding, distance_func, k, where_str)
exclude_embeddings = True
if "exclude_embeddings" in kwargs:
exclude_embeddings = kwargs["exclude_embeddings"]
q_str = self._build_query_sql(
embedding, distance_func, k, where_str, exclude_embeddings
)
try:
query_response = self._client.Queries.query(sql={"query": q_str})
except Exception as e:
@ -290,6 +298,60 @@ class Rockset(VectorStore):
)
return finalResult
def max_marginal_relevance_search(
    self,
    query: str,
    k: int = 4,
    fetch_k: int = 20,
    lambda_mult: float = 0.5,
    *,
    where_str: Optional[str] = None,
    **kwargs: Any,
) -> List[Document]:
    """Return docs selected using the maximal marginal relevance.

    Maximal marginal relevance optimizes for similarity to query AND diversity
    among selected documents.

    Args:
        query: Text to look up documents similar to.
        k: Number of Documents to return. Defaults to 4.
        fetch_k: Number of Documents to fetch to pass to MMR algorithm.
        lambda_mult: Number between 0 and 1 that determines the degree
            of diversity among the results with 0 corresponding
            to maximum diversity and 1 to minimum diversity.
            Defaults to 0.5.
        where_str: where clause for the sql query
        **kwargs: Forwarded unchanged to ``similarity_search_by_vector``
            (for example ``distance_func``). A caller-supplied
            ``exclude_embeddings`` is ignored: this method always fetches
            the stored vectors because the MMR step needs them.

    Returns:
        List of Documents selected by maximal marginal relevance, with the
        embedding field removed from their metadata.
    """
    # Drop any caller-supplied value so the explicit exclude_embeddings=False
    # below cannot raise "got multiple values for keyword argument".
    kwargs.pop("exclude_embeddings", None)

    query_embedding = self._embeddings.embed_query(query)
    candidates = self.similarity_search_by_vector(
        query_embedding,
        k=fetch_k,
        where_str=where_str,
        exclude_embeddings=False,
        **kwargs,
    )
    if not candidates:
        # Nothing matched; there is nothing for MMR to rank.
        return []

    candidate_embeddings = [
        doc.metadata[self._embedding_key] for doc in candidates
    ]
    selected_indices = maximal_marginal_relevance(
        np.array(query_embedding),
        candidate_embeddings,
        lambda_mult=lambda_mult,
        k=k,
    )

    # Strip the embedding from each returned document so the result shape is
    # consistent with the other search functions.
    selected_docs: List[Document] = []
    for index in selected_indices:
        doc = candidates[index]
        doc.metadata.pop(self._embedding_key, None)
        selected_docs.append(doc)
    return selected_docs
# Helper functions
def _build_query_sql(
@ -298,6 +360,7 @@ class Rockset(VectorStore):
distance_func: DistanceFunction,
k: int = 4,
where_str: Optional[str] = None,
exclude_embeddings: bool = True,
) -> str:
"""Builds Rockset SQL query to query similar vectors to query_vector"""
@ -305,8 +368,11 @@ class Rockset(VectorStore):
distance_str = f"""{distance_func.value}({self._embedding_key}, \
[{q_embedding_str}]) as dist"""
where_str = f"WHERE {where_str}\n" if where_str else ""
select_embedding = (
f" EXCEPT({self._embedding_key})," if exclude_embeddings else ","
)
return f"""\
SELECT * EXCEPT({self._embedding_key}), {distance_str}
SELECT *{select_embedding} {distance_str}
FROM {self._workspace}.{self._collection_name}
{where_str}\
ORDER BY dist {distance_func.order_by()}

@ -96,16 +96,18 @@ class TestRockset:
client, embeddings, COLLECTION_NAME, TEXT_KEY, EMBEDDING_KEY, WORKSPACE
)
def test_rockset_insert_and_search(self) -> None:
"""Test end to end vector search in Rockset"""
texts = ["foo", "bar", "baz"]
metadatas = [{"metadata_index": i} for i in range(len(texts))]
ids = self.rockset_vectorstore.add_texts(
ids = cls.rockset_vectorstore.add_texts(
texts=texts,
metadatas=metadatas,
)
assert len(ids) == len(texts)
def test_rockset_search(self) -> None:
"""Test end-to-end vector search in Rockset"""
# Test that `foo` is closest to `foo`
output = self.rockset_vectorstore.similarity_search(
query="foo", distance_func=Rockset.DistanceFunction.COSINE_SIM, k=1
@ -121,6 +123,26 @@ class TestRockset:
)
assert output == [Document(page_content="bar", metadata={"metadata_index": 1})]
def test_rockset_mmr_search(self) -> None:
    """Exercise max_marginal_relevance_search end to end against Rockset."""
    store = self.rockset_vectorstore
    cosine = Rockset.DistanceFunction.COSINE_SIM

    # With a single candidate fetched, `foo` itself must be the result.
    expected_foo = [Document(page_content="foo", metadata={"metadata_index": 0})]
    nearest = store.max_marginal_relevance_search(
        query="foo",
        distance_func=cosine,
        fetch_k=1,
        k=1,
    )
    assert nearest == expected_foo

    # Filtering out `foo` via the where clause leaves `bar` as the closest match.
    expected_bar = [Document(page_content="bar", metadata={"metadata_index": 1})]
    nearest_excluding_foo = store.max_marginal_relevance_search(
        query="foo",
        distance_func=cosine,
        fetch_k=3,
        k=1,
        where_str="metadata_index != 0",
    )
    assert nearest_excluding_foo == expected_bar
def test_add_documents_and_delete(self) -> None:
""" "add_documents" and "delete" are requirements to support use
with RecordManager"""
@ -184,5 +206,21 @@ FROM {WORKSPACE}.{COLLECTION_NAME}
WHERE age >= 10
ORDER BY dist DESC
LIMIT 4
"""
assert q_str == expected
def test_build_query_sql_with_select_embeddings(self) -> None:
    """Check the generated SQL keeps the embedding column when
    exclude_embeddings is False (plain ``SELECT *,`` with no EXCEPT clause)."""
    query_vector = [float(n) for n in range(1, 11)]
    rendered_vector = ",".join(str(component) for component in query_vector)
    expected = f"""\
SELECT *, \
COSINE_SIM({EMBEDDING_KEY}, [{rendered_vector}]) as dist
FROM {WORKSPACE}.{COLLECTION_NAME}
WHERE age >= 10
ORDER BY dist DESC
LIMIT 4
"""
    # The final positional argument is exclude_embeddings=False.
    actual = self.rockset_vectorstore._build_query_sql(
        query_vector, Rockset.DistanceFunction.COSINE_SIM, 4, "age >= 10", False
    )
    assert actual == expected

Loading…
Cancel
Save