diff --git a/langchain/vectorstores/redis.py b/langchain/vectorstores/redis.py index 69660c3a..21b1e5e2 100644 --- a/langchain/vectorstores/redis.py +++ b/langchain/vectorstores/redis.py @@ -4,7 +4,7 @@ from __future__ import annotations import json import logging import uuid -from typing import Any, Callable, Iterable, List, Mapping, Optional +from typing import Any, Callable, Iterable, List, Mapping, Optional, Tuple import numpy as np from redis.client import Redis as RedisType @@ -86,6 +86,48 @@ class Redis(VectorStore): def similarity_search( self, query: str, k: int = 4, **kwargs: Any ) -> List[Document]: + docs_and_scores = self.similarity_search_with_score(query, k=k) + return [doc for doc, _ in docs_and_scores] + + def similarity_search_limit_score( + self, query: str, k: int = 4, score_threshold: float = 0.2, **kwargs: Any + ) -> List[Document]: + """ + Returns the most similar indexed documents to the query text. + + Args: + query (str): The query text for which to find similar documents. + k (int): The number of documents to return. Default is 4. + score_threshold (float): The minimum matching score required for a document + to be considered a match. Defaults to 0.2. + Because the similarity calculation algorithm is based on cosine similarity, + the smaller the angle, the higher the similarity. + + Returns: + List[Document]: A list of documents that are most similar to the query text, + including the match score for each document. + + Note: + If there are no documents that satisfy the score_threshold value, + an empty list is returned. + + """ + docs_and_scores = self.similarity_search_with_score(query, k=k) + + return [doc for doc, score in docs_and_scores if score < score_threshold] + + def similarity_search_with_score( + self, query: str, k: int = 4 + ) -> List[Tuple[Document, float]]: + """Return docs most similar to query. + + Args: + query: Text to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + + Returns: + List of Documents most similar to the query and score for each + """ try: from redis.commands.search.query import Query except ImportError: @@ -120,12 +162,17 @@ class Redis(VectorStore): # perform vector search results = self.client.ft(self.index_name).search(redis_query, params_dict) - documents = [ - Document(page_content=result.content, metadata=json.loads(result.metadata)) + docs = [ + ( + Document( + page_content=result.content, metadata=json.loads(result.metadata) + ), + float(result.vector_score), + ) for result in results.docs ] - return documents + return docs @classmethod def from_texts(