From 074ad5095ff3ea61b66f4ecbcba2915385a560d3 Mon Sep 17 00:00:00 2001 From: morgana Date: Fri, 29 Mar 2024 11:45:22 -0700 Subject: [PATCH] community[patch]: mmr search for Rockset vectorstore integration (#16908) - **Description:** Adding support for mmr search in the Rockset vectorstore integration. - **Issue:** N/A - **Dependencies:** N/A - **Twitter handle:** `@_morgan_adams_` --------- Co-authored-by: Rockset API Bot Co-authored-by: Bagatur Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com> --- .../vectorstores/rocksetdb.py | 70 ++++++++++++++++++- .../vectorstores/test_rocksetdb.py | 46 ++++++++++-- 2 files changed, 110 insertions(+), 6 deletions(-) diff --git a/libs/community/langchain_community/vectorstores/rocksetdb.py b/libs/community/langchain_community/vectorstores/rocksetdb.py index ffb9f7f7f2..263a211fba 100644 --- a/libs/community/langchain_community/vectorstores/rocksetdb.py +++ b/libs/community/langchain_community/vectorstores/rocksetdb.py @@ -5,11 +5,14 @@ from copy import deepcopy from enum import Enum from typing import Any, Iterable, List, Optional, Tuple +import numpy as np from langchain_core.documents import Document from langchain_core.embeddings import Embeddings from langchain_core.runnables import run_in_executor from langchain_core.vectorstores import VectorStore +from langchain_community.vectorstores.utils import maximal_marginal_relevance + logger = logging.getLogger(__name__) @@ -254,7 +257,12 @@ class Rockset(VectorStore): """Accepts a query_embedding (vector), and returns documents with similar embeddings along with their relevance scores.""" - q_str = self._build_query_sql(embedding, distance_func, k, where_str) + exclude_embeddings = True + if "exclude_embeddings" in kwargs: + exclude_embeddings = kwargs["exclude_embeddings"] + q_str = self._build_query_sql( + embedding, distance_func, k, where_str, exclude_embeddings + ) try: query_response = self._client.Queries.query(sql={"query": q_str}) except Exception as e: @@ -290,6 +298,60 @@ class Rockset(VectorStore): ) return finalResult + def max_marginal_relevance_search( + self, + query: str, + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + *, + where_str: Optional[str] = None, + **kwargs: Any, + ) -> List[Document]: + """Return docs selected using the maximal marginal relevance. + + Maximal marginal relevance optimizes for similarity to query AND diversity + among selected documents. + + Args: + query: Text to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + fetch_k: Number of Documents to fetch to pass to MMR algorithm. + distance_func (DistanceFunction): how to compute distance between two + vectors in Rockset. + lambda_mult: Number between 0 and 1 that determines the degree + of diversity among the results with 0 corresponding + to maximum diversity and 1 to minimum diversity. + Defaults to 0.5. + where_str: where clause for the sql query + Returns: + List of Documents selected by maximal marginal relevance. + """ + query_embedding = self._embeddings.embed_query(query) + initial_docs = self.similarity_search_by_vector( + query_embedding, + k=fetch_k, + where_str=where_str, + exclude_embeddings=False, + **kwargs, + ) + + embeddings = [doc.metadata[self._embedding_key] for doc in initial_docs] + + selected_indices = maximal_marginal_relevance( + np.array(query_embedding), + embeddings, + lambda_mult=lambda_mult, + k=k, + ) + + # remove embeddings key before returning for cleanup to be consistent with + # other search functions + for i in selected_indices: + del initial_docs[i].metadata[self._embedding_key] + + return [initial_docs[i] for i in selected_indices] + # Helper functions def _build_query_sql( @@ -298,6 +360,7 @@ class Rockset(VectorStore): distance_func: DistanceFunction, k: int = 4, where_str: Optional[str] = None, + exclude_embeddings: bool = True, ) -> str: """Builds Rockset SQL query to query similar vectors to query_vector""" @@ -305,8 +368,11 @@ class Rockset(VectorStore): distance_str = f"""{distance_func.value}({self._embedding_key}, \ [{q_embedding_str}]) as dist""" where_str = f"WHERE {where_str}\n" if where_str else "" + select_embedding = ( + f" EXCEPT({self._embedding_key})," if exclude_embeddings else "," + ) return f"""\ -SELECT * EXCEPT({self._embedding_key}), {distance_str} +SELECT *{select_embedding} {distance_str} FROM {self._workspace}.{self._collection_name} {where_str}\ ORDER BY dist {distance_func.order_by()} diff --git a/libs/community/tests/integration_tests/vectorstores/test_rocksetdb.py b/libs/community/tests/integration_tests/vectorstores/test_rocksetdb.py index 56950ce8c4..9be8e16a90 100644 --- a/libs/community/tests/integration_tests/vectorstores/test_rocksetdb.py +++ b/libs/community/tests/integration_tests/vectorstores/test_rocksetdb.py @@ -96,16 +96,18 @@ class TestRockset: client, embeddings, COLLECTION_NAME, TEXT_KEY, EMBEDDING_KEY, WORKSPACE ) - def test_rockset_insert_and_search(self) -> None: - """Test end to end vector search in Rockset""" - texts = ["foo", "bar", "baz"] metadatas = [{"metadata_index": i} for i in range(len(texts))] - ids = self.rockset_vectorstore.add_texts( + ids = cls.rockset_vectorstore.add_texts( texts=texts, metadatas=metadatas, ) + assert len(ids) == len(texts) + + def test_rockset_search(self) -> None: + """Test end-to-end vector search in Rockset""" + # Test that `foo` is closest to `foo` output = self.rockset_vectorstore.similarity_search( query="foo", distance_func=Rockset.DistanceFunction.COSINE_SIM, k=1 @@ -121,6 +123,26 @@ class TestRockset: ) assert output == [Document(page_content="bar", metadata={"metadata_index": 1})] + def test_rockset_mmr_search(self) -> None: + """Test end-to-end mmr search in Rockset""" + output = self.rockset_vectorstore.max_marginal_relevance_search( + query="foo", + distance_func=Rockset.DistanceFunction.COSINE_SIM, + fetch_k=1, + k=1, + ) + assert output == [Document(page_content="foo", metadata={"metadata_index": 0})] + + # Find closest vector to `foo` which is not `foo` + output = self.rockset_vectorstore.max_marginal_relevance_search( + query="foo", + distance_func=Rockset.DistanceFunction.COSINE_SIM, + fetch_k=3, + k=1, + where_str="metadata_index != 0", + ) + assert output == [Document(page_content="bar", metadata={"metadata_index": 1})] + def test_add_documents_and_delete(self) -> None: """ "add_documents" and "delete" are requirements to support use with RecordManager""" @@ -184,5 +206,21 @@ FROM {WORKSPACE}.{COLLECTION_NAME} WHERE age >= 10 ORDER BY dist DESC LIMIT 4 +""" + assert q_str == expected + + def test_build_query_sql_with_select_embeddings(self) -> None: + vector = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0] + q_str = self.rockset_vectorstore._build_query_sql( + vector, Rockset.DistanceFunction.COSINE_SIM, 4, "age >= 10", False + ) + vector_str = ",".join(map(str, vector)) + expected = f"""\ +SELECT *, \ +COSINE_SIM({EMBEDDING_KEY}, [{vector_str}]) as dist +FROM {WORKSPACE}.{COLLECTION_NAME} +WHERE age >= 10 +ORDER BY dist DESC +LIMIT 4 """ assert q_str == expected