From 8c0f391815eac61f2b5d1b993e9bc4795808696f Mon Sep 17 00:00:00 2001 From: Hamza Tahboub Date: Fri, 8 Sep 2023 15:14:44 -0700 Subject: [PATCH] Implemented MMR search for Redis (#10140) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Description: Implemented MMR search for Redis. Pretty straightforward, just using the already implemented MMR method on similarity search–fetched docs. Issue: #10059 Dependencies: None Twitter handle: @hamza_tahboub --------- Co-authored-by: Bagatur --- .../integrations/vectorstores/redis.ipynb | 23 +++- libs/langchain/langchain/utilities/redis.py | 4 + .../langchain/vectorstores/redis/base.py | 124 ++++++++++++++++-- .../vectorstores/test_redis.py | 42 +++++- 4 files changed, 179 insertions(+), 14 deletions(-) diff --git a/docs/extras/integrations/vectorstores/redis.ipynb b/docs/extras/integrations/vectorstores/redis.ipynb index ae17b0e4e6..f729be1599 100644 --- a/docs/extras/integrations/vectorstores/redis.ipynb +++ b/docs/extras/integrations/vectorstores/redis.ipynb @@ -413,7 +413,8 @@ "- ``similarity_search``: Find the most similar vectors to a given vector.\n", "- ``similarity_search_with_score``: Find the most similar vectors to a given vector and return the vector distance\n", "- ``similarity_search_limit_score``: Find the most similar vectors to a given vector and limit the number of results to the ``score_threshold``\n", - "- ``similarity_search_with_relevance_scores``: Find the most similar vectors to a given vector and return the vector similarities" + "- ``similarity_search_with_relevance_scores``: Find the most similar vectors to a given vector and return the vector similarities\n", + "- ``max_marginal_relevance_search``: Find the most similar vectors to a given vector while also optimizing for diversity" ] }, { @@ -596,6 +597,26 @@ "print(results[0].metadata)" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# use maximal marginal relevance search to diversify results\n", + "results = rds.max_marginal_relevance_search(\"foo\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# the lambda_mult parameter controls the diversity of the results, the lower the more diverse\n", + "results = rds.max_marginal_relevance_search(\"foo\", lambda_mult=0.1)" + ] + }, { "cell_type": "markdown", "metadata": {}, diff --git a/libs/langchain/langchain/utilities/redis.py b/libs/langchain/langchain/utilities/redis.py index a45391c8bc..605a611967 100644 --- a/libs/langchain/langchain/utilities/redis.py +++ b/libs/langchain/langchain/utilities/redis.py @@ -17,6 +17,10 @@ def _array_to_buffer(array: List[float], dtype: Any = np.float32) -> bytes: return np.array(array).astype(dtype).tobytes() +def _buffer_to_array(buffer: bytes, dtype: Any = np.float32) -> List[float]: + return np.frombuffer(buffer, dtype=dtype).tolist() + + class TokenEscaper: """ Escape punctuation within an input string. 
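The new _buffer_to_array helper above is the inverse of the existing _array_to_buffer: it decodes the raw bytes stored in a Redis hash field back into a list of floats, which the MMR re-ranking code further down needs for the prefetched documents. A minimal, self-contained round-trip sketch of the two helpers (the toy embedding values and the float32 tolerance are illustrative only, not taken from this patch):

from typing import Any, List

import numpy as np


def _array_to_buffer(array: List[float], dtype: Any = np.float32) -> bytes:
    # Encode a float vector as raw bytes, as stored in the Redis hash field.
    return np.array(array).astype(dtype).tobytes()


def _buffer_to_array(buffer: bytes, dtype: Any = np.float32) -> List[float]:
    # Decode the raw bytes back into a plain Python list of floats.
    return np.frombuffer(buffer, dtype=dtype).tolist()


# Round trip: what gets written to Redis can be recovered for MMR scoring.
embedding = [0.1, 0.2, 0.3]
raw = _array_to_buffer(embedding, np.float32)
decoded = _buffer_to_array(raw, np.float32)
assert np.allclose(embedding, decoded)  # equal up to float32 round-off
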
diff --git a/libs/langchain/langchain/vectorstores/redis/base.py b/libs/langchain/langchain/vectorstores/redis/base.py index a09ba44cba..fe973b4831 100644 --- a/libs/langchain/langchain/vectorstores/redis/base.py +++ b/libs/langchain/langchain/vectorstores/redis/base.py @@ -17,8 +17,10 @@ from typing import ( Tuple, Type, Union, + cast, ) +import numpy as np import yaml from langchain._api import deprecated @@ -30,6 +32,7 @@ from langchain.docstore.document import Document from langchain.embeddings.base import Embeddings from langchain.utilities.redis import ( _array_to_buffer, + _buffer_to_array, check_redis_module_exist, get_client, ) @@ -39,6 +42,7 @@ from langchain.vectorstores.redis.constants import ( REDIS_REQUIRED_MODULES, REDIS_TAG_SEPARATOR, ) +from langchain.vectorstores.utils import maximal_marginal_relevance logger = logging.getLogger(__name__) @@ -803,8 +807,10 @@ class Redis(VectorStore): + "score_threshold will be removed in a future release.", ) + query_embedding = self._embeddings.embed_query(query) + redis_query, params_dict = self._prepare_query( - query, + query_embedding, k=k, filter=filter, with_metadata=return_metadata, @@ -858,13 +864,48 @@ class Redis(VectorStore): Defaults to None. return_metadata (bool, optional): Whether to return metadata. Defaults to True. - distance_threshold (Optional[float], optional): Distance threshold - for vector distance from query vector. Defaults to None. + distance_threshold (Optional[float], optional): Maximum vector distance + between selected documents and the query vector. Defaults to None. Returns: List[Document]: A list of documents that are most similar to the query text. + """ + query_embedding = self._embeddings.embed_query(query) + return self.similarity_search_by_vector( + query_embedding, + k=k, + filter=filter, + return_metadata=return_metadata, + distance_threshold=distance_threshold, + **kwargs, + ) + + def similarity_search_by_vector( + self, + embedding: List[float], + k: int = 4, + filter: Optional[RedisFilterExpression] = None, + return_metadata: bool = True, + distance_threshold: Optional[float] = None, + **kwargs: Any, + ) -> List[Document]: + """Run similarity search between a query vector and the indexed vectors. + Args: + embedding (List[float]): The query vector for which to find similar + documents. + k (int): The number of documents to return. Default is 4. + filter (RedisFilterExpression, optional): Optional metadata filter. + Defaults to None. + return_metadata (bool, optional): Whether to return metadata. + Defaults to True. + distance_threshold (Optional[float], optional): Maximum vector distance + between selected documents and the query vector. Defaults to None. + + Returns: + List[Document]: A list of documents that are most similar to the query + text. """ try: import redis @@ -884,7 +925,7 @@ class Redis(VectorStore): ) redis_query, params_dict = self._prepare_query( - query, + embedding, k=k, filter=filter, distance_threshold=distance_threshold, @@ -920,6 +961,74 @@ class Redis(VectorStore): ) return docs + def max_marginal_relevance_search( + self, + query: str, + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + filter: Optional[RedisFilterExpression] = None, + return_metadata: bool = True, + distance_threshold: Optional[float] = None, + **kwargs: Any, + ) -> List[Document]: + """Return docs selected using the maximal marginal relevance. + + Maximal marginal relevance optimizes for similarity to query AND diversity + among selected documents. 
+ + Args: + query (str): Text to look up documents similar to. + k (int): Number of Documents to return. Defaults to 4. + fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. + lambda_mult (float): Number between 0 and 1 that determines the degree + of diversity among the results with 0 corresponding + to maximum diversity and 1 to minimum diversity. + Defaults to 0.5. + filter (RedisFilterExpression, optional): Optional metadata filter. + Defaults to None. + return_metadata (bool, optional): Whether to return metadata. + Defaults to True. + distance_threshold (Optional[float], optional): Maximum vector distance + between selected documents and the query vector. Defaults to None. + + Returns: + List[Document]: A list of Documents selected by maximal marginal relevance. + """ + # Embed the query + query_embedding = self._embeddings.embed_query(query) + + # Fetch the initial documents + prefetch_docs = self.similarity_search_by_vector( + query_embedding, + k=fetch_k, + filter=filter, + return_metadata=return_metadata, + distance_threshold=distance_threshold, + **kwargs, + ) + prefetch_ids = [doc.metadata["id"] for doc in prefetch_docs] + + # Get the embeddings for the fetched documents + prefetch_embeddings = [ + _buffer_to_array( + cast( + bytes, + self.client.hget(prefetch_id, self._schema.content_vector_key), + ), + dtype=self._schema.vector_dtype, + ) + for prefetch_id in prefetch_ids + ] + + # Select documents using maximal marginal relevance + selected_indices = maximal_marginal_relevance( + np.array(query_embedding), prefetch_embeddings, lambda_mult=lambda_mult, k=k + ) + selected_docs = [prefetch_docs[i] for i in selected_indices] + + return selected_docs + def _collect_metadata(self, result: "Document") -> Dict[str, Any]: """Collect metadata from Redis. 
@@ -952,19 +1061,16 @@ class Redis(VectorStore): def _prepare_query( self, - query: str, + query_embedding: List[float], k: int = 4, filter: Optional[RedisFilterExpression] = None, distance_threshold: Optional[float] = None, with_metadata: bool = True, with_distance: bool = False, ) -> Tuple["Query", Dict[str, Any]]: - # Creates embedding vector from user query - embedding = self._embeddings.embed_query(query) - # Creates Redis query params_dict: Dict[str, Union[str, bytes, float]] = { - "vector": _array_to_buffer(embedding, self._schema.vector_dtype), + "vector": _array_to_buffer(query_embedding, self._schema.vector_dtype), } # prepare return fields including score diff --git a/libs/langchain/tests/integration_tests/vectorstores/test_redis.py b/libs/langchain/tests/integration_tests/vectorstores/test_redis.py index 3b7a4c7acc..cbcd78d070 100644 --- a/libs/langchain/tests/integration_tests/vectorstores/test_redis.py +++ b/libs/langchain/tests/integration_tests/vectorstores/test_redis.py @@ -187,12 +187,21 @@ def test_redis_filters_1( documents, FakeEmbeddings(), redis_url=TEST_REDIS_URL ) - output = docsearch.similarity_search("foo", k=3, filter=filter_expr) + sim_output = docsearch.similarity_search("foo", k=3, filter=filter_expr) + mmr_output = docsearch.max_marginal_relevance_search( + "foo", k=3, fetch_k=5, filter=filter_expr + ) - assert len(output) == expected_length + assert len(sim_output) == expected_length + assert len(mmr_output) == expected_length if expected_nums is not None: - for out in output: + for out in sim_output: + assert ( + out.metadata["text"] in expected_nums + or int(out.metadata["num"]) in expected_nums + ) + for out in mmr_output: assert ( out.metadata["text"] in expected_nums or int(out.metadata["num"]) in expected_nums @@ -302,7 +311,6 @@ def test_similarity_search_limit_distance(texts: List[str]) -> None: def test_similarity_search_with_score_with_limit_distance(texts: List[str]) -> None: """Test similarity search with score with limit score.""" - docsearch = Redis.from_texts( texts, ConsistentFakeEmbeddings(), redis_url=TEST_REDIS_URL ) @@ -317,6 +325,32 @@ def test_similarity_search_with_score_with_limit_distance(texts: List[str]) -> N assert drop(docsearch.index_name) +def test_max_marginal_relevance_search(texts: List[str]) -> None: + """Test max marginal relevance search.""" + docsearch = Redis.from_texts(texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL) + + mmr_output = docsearch.max_marginal_relevance_search(texts[0], k=3, fetch_k=3) + sim_output = docsearch.similarity_search(texts[0], k=3) + assert mmr_output == sim_output + + mmr_output = docsearch.max_marginal_relevance_search(texts[0], k=2, fetch_k=3) + assert len(mmr_output) == 2 + assert mmr_output[0].page_content == texts[0] + assert mmr_output[1].page_content == texts[1] + + mmr_output = docsearch.max_marginal_relevance_search( + texts[0], k=2, fetch_k=3, lambda_mult=0.1 # more diversity + ) + assert len(mmr_output) == 2 + assert mmr_output[0].page_content == texts[0] + assert mmr_output[1].page_content == texts[2] + + # if fetch_k < k, then the output will be less than k + mmr_output = docsearch.max_marginal_relevance_search(texts[0], k=3, fetch_k=2) + assert len(mmr_output) == 2 + assert drop(docsearch.index_name) + + def test_delete(texts: List[str]) -> None: """Test deleting a new document""" docsearch = Redis.from_texts(texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL)
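
Putting the pieces together, the updated notebook cells and test_max_marginal_relevance_search suggest usage along these lines. This is a minimal sketch, assuming a local Redis Stack instance reachable at redis://localhost:6379 and OpenAIEmbeddings as the embedding model; the texts, index name, and parameter values are illustrative and not taken from the patch:

from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores.redis import Redis

# Assumes OPENAI_API_KEY is set; any Embeddings implementation would work.
texts = ["foo", "foo", "foo", "bar", "bar"]
rds = Redis.from_texts(
    texts,
    OpenAIEmbeddings(),
    redis_url="redis://localhost:6379",
    index_name="mmr-demo",  # hypothetical index name
)

# Plain similarity search: the top-k nearest neighbours, near-duplicates included.
sim_docs = rds.similarity_search("foo", k=3)

# MMR search: fetch `fetch_k` candidates by vector similarity, then re-rank them
# with maximal_marginal_relevance to trade relevance off against diversity.
# Lower lambda_mult values favour diversity.
mmr_docs = rds.max_marginal_relevance_search("foo", k=3, fetch_k=10, lambda_mult=0.5)

As the new integration test checks, max_marginal_relevance_search with fetch_k equal to k matches plain similarity search, a small lambda_mult (e.g. 0.1) pushes the selection toward diversity, and fetch_k smaller than k caps the number of returned documents at fetch_k.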