community[patch]: mmr search for Rockset vectorstore integration (#16908)

- **Description:** Adding support for mmr search in the Rockset
vectorstore integration.
  - **Issue:** N/A
  - **Dependencies:** N/A
  - **Twitter handle:** `@_morgan_adams_`

---------

Co-authored-by: Rockset API Bot <admin@rockset.io>
Co-authored-by: Bagatur <baskaryan@gmail.com>
Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
This commit is contained in:
morgana 2024-03-29 11:45:22 -07:00 committed by GitHub
parent f51e6a35ba
commit 074ad5095f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 110 additions and 6 deletions

View File

@ -5,11 +5,14 @@ from copy import deepcopy
from enum import Enum from enum import Enum
from typing import Any, Iterable, List, Optional, Tuple from typing import Any, Iterable, List, Optional, Tuple
import numpy as np
from langchain_core.documents import Document from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings from langchain_core.embeddings import Embeddings
from langchain_core.runnables import run_in_executor from langchain_core.runnables import run_in_executor
from langchain_core.vectorstores import VectorStore from langchain_core.vectorstores import VectorStore
from langchain_community.vectorstores.utils import maximal_marginal_relevance
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -254,7 +257,12 @@ class Rockset(VectorStore):
"""Accepts a query_embedding (vector), and returns documents with """Accepts a query_embedding (vector), and returns documents with
similar embeddings along with their relevance scores.""" similar embeddings along with their relevance scores."""
q_str = self._build_query_sql(embedding, distance_func, k, where_str) exclude_embeddings = True
if "exclude_embeddings" in kwargs:
exclude_embeddings = kwargs["exclude_embeddings"]
q_str = self._build_query_sql(
embedding, distance_func, k, where_str, exclude_embeddings
)
try: try:
query_response = self._client.Queries.query(sql={"query": q_str}) query_response = self._client.Queries.query(sql={"query": q_str})
except Exception as e: except Exception as e:
@ -290,6 +298,60 @@ class Rockset(VectorStore):
) )
return finalResult return finalResult
def max_marginal_relevance_search(
self,
query: str,
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
*,
where_str: Optional[str] = None,
**kwargs: Any,
) -> List[Document]:
"""Return docs selected using the maximal marginal relevance.
Maximal marginal relevance optimizes for similarity to query AND diversity
among selected documents.
Args:
query: Text to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
fetch_k: Number of Documents to fetch to pass to MMR algorithm.
distance_func (DistanceFunction): how to compute distance between two
vectors in Rockset.
lambda_mult: Number between 0 and 1 that determines the degree
of diversity among the results with 0 corresponding
to maximum diversity and 1 to minimum diversity.
Defaults to 0.5.
where_str: where clause for the sql query
Returns:
List of Documents selected by maximal marginal relevance.
"""
query_embedding = self._embeddings.embed_query(query)
initial_docs = self.similarity_search_by_vector(
query_embedding,
k=fetch_k,
where_str=where_str,
exclude_embeddings=False,
**kwargs,
)
embeddings = [doc.metadata[self._embedding_key] for doc in initial_docs]
selected_indices = maximal_marginal_relevance(
np.array(query_embedding),
embeddings,
lambda_mult=lambda_mult,
k=k,
)
# remove embeddings key before returning for cleanup to be consistent with
# other search functions
for i in selected_indices:
del initial_docs[i].metadata[self._embedding_key]
return [initial_docs[i] for i in selected_indices]
# Helper functions # Helper functions
def _build_query_sql( def _build_query_sql(
@ -298,6 +360,7 @@ class Rockset(VectorStore):
distance_func: DistanceFunction, distance_func: DistanceFunction,
k: int = 4, k: int = 4,
where_str: Optional[str] = None, where_str: Optional[str] = None,
exclude_embeddings: bool = True,
) -> str: ) -> str:
"""Builds Rockset SQL query to query similar vectors to query_vector""" """Builds Rockset SQL query to query similar vectors to query_vector"""
@ -305,8 +368,11 @@ class Rockset(VectorStore):
distance_str = f"""{distance_func.value}({self._embedding_key}, \ distance_str = f"""{distance_func.value}({self._embedding_key}, \
[{q_embedding_str}]) as dist""" [{q_embedding_str}]) as dist"""
where_str = f"WHERE {where_str}\n" if where_str else "" where_str = f"WHERE {where_str}\n" if where_str else ""
select_embedding = (
f" EXCEPT({self._embedding_key})," if exclude_embeddings else ","
)
return f"""\ return f"""\
SELECT * EXCEPT({self._embedding_key}), {distance_str} SELECT *{select_embedding} {distance_str}
FROM {self._workspace}.{self._collection_name} FROM {self._workspace}.{self._collection_name}
{where_str}\ {where_str}\
ORDER BY dist {distance_func.order_by()} ORDER BY dist {distance_func.order_by()}

View File

@ -96,16 +96,18 @@ class TestRockset:
client, embeddings, COLLECTION_NAME, TEXT_KEY, EMBEDDING_KEY, WORKSPACE client, embeddings, COLLECTION_NAME, TEXT_KEY, EMBEDDING_KEY, WORKSPACE
) )
def test_rockset_insert_and_search(self) -> None:
"""Test end to end vector search in Rockset"""
texts = ["foo", "bar", "baz"] texts = ["foo", "bar", "baz"]
metadatas = [{"metadata_index": i} for i in range(len(texts))] metadatas = [{"metadata_index": i} for i in range(len(texts))]
ids = self.rockset_vectorstore.add_texts( ids = cls.rockset_vectorstore.add_texts(
texts=texts, texts=texts,
metadatas=metadatas, metadatas=metadatas,
) )
assert len(ids) == len(texts) assert len(ids) == len(texts)
def test_rockset_search(self) -> None:
"""Test end-to-end vector search in Rockset"""
# Test that `foo` is closest to `foo` # Test that `foo` is closest to `foo`
output = self.rockset_vectorstore.similarity_search( output = self.rockset_vectorstore.similarity_search(
query="foo", distance_func=Rockset.DistanceFunction.COSINE_SIM, k=1 query="foo", distance_func=Rockset.DistanceFunction.COSINE_SIM, k=1
@ -121,6 +123,26 @@ class TestRockset:
) )
assert output == [Document(page_content="bar", metadata={"metadata_index": 1})] assert output == [Document(page_content="bar", metadata={"metadata_index": 1})]
def test_rockset_mmr_search(self) -> None:
"""Test end-to-end mmr search in Rockset"""
output = self.rockset_vectorstore.max_marginal_relevance_search(
query="foo",
distance_func=Rockset.DistanceFunction.COSINE_SIM,
fetch_k=1,
k=1,
)
assert output == [Document(page_content="foo", metadata={"metadata_index": 0})]
# Find closest vector to `foo` which is not `foo`
output = self.rockset_vectorstore.max_marginal_relevance_search(
query="foo",
distance_func=Rockset.DistanceFunction.COSINE_SIM,
fetch_k=3,
k=1,
where_str="metadata_index != 0",
)
assert output == [Document(page_content="bar", metadata={"metadata_index": 1})]
def test_add_documents_and_delete(self) -> None: def test_add_documents_and_delete(self) -> None:
""" "add_documents" and "delete" are requirements to support use """ "add_documents" and "delete" are requirements to support use
with RecordManager""" with RecordManager"""
@ -184,5 +206,21 @@ FROM {WORKSPACE}.{COLLECTION_NAME}
WHERE age >= 10 WHERE age >= 10
ORDER BY dist DESC ORDER BY dist DESC
LIMIT 4 LIMIT 4
"""
assert q_str == expected
def test_build_query_sql_with_select_embeddings(self) -> None:
vector = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
q_str = self.rockset_vectorstore._build_query_sql(
vector, Rockset.DistanceFunction.COSINE_SIM, 4, "age >= 10", False
)
vector_str = ",".join(map(str, vector))
expected = f"""\
SELECT *, \
COSINE_SIM({EMBEDDING_KEY}, [{vector_str}]) as dist
FROM {WORKSPACE}.{COLLECTION_NAME}
WHERE age >= 10
ORDER BY dist DESC
LIMIT 4
""" """
assert q_str == expected assert q_str == expected