Add mmr support to redis retriever (#10556)

pull/10586/head
Bagatur 1 year ago committed by GitHub
parent ccf71e23e8
commit 7f3f6097e7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -158,7 +158,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
@ -178,7 +178,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
@ -242,7 +242,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 7,
"metadata": {
"tags": []
},
@ -253,7 +253,7 @@
"rds = Redis.from_texts(\n",
" texts,\n",
" embeddings,\n",
" metadatas=metadats,\n",
" metadatas=metadata,\n",
" redis_url=\"redis://localhost:6379\",\n",
" index_name=\"users\"\n",
")"
@ -597,7 +597,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
@ -607,7 +607,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
@ -1110,6 +1110,38 @@
"retriever.get_relevant_documents(\"foo\")"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"retriever = rds.as_retriever(search_type=\"mmr\", search_kwargs={\"fetch_k\": 20, \"k\": 4, \"lambda_mult\": 0.1})"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[Document(page_content='foo', metadata={'id': 'doc:users:8f6b673b390647809d510112cde01a27', 'user': 'john', 'job': 'engineer', 'credit_score': 'high', 'age': '18'}),\n",
" Document(page_content='bar', metadata={'id': 'doc:users:93521560735d42328b48c9c6f6418d6a', 'user': 'tyler', 'job': 'engineer', 'credit_score': 'high', 'age': '100'}),\n",
" Document(page_content='foo', metadata={'id': 'doc:users:125ecd39d07845eabf1a699d44134a5b', 'user': 'nancy', 'job': 'doctor', 'credit_score': 'high', 'age': '94'}),\n",
" Document(page_content='foo', metadata={'id': 'doc:users:d6200ab3764c466082fde3eaab972a2a', 'user': 'derrick', 'job': 'doctor', 'credit_score': 'low', 'age': '45'})]"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"retriever.get_relevant_documents(\"foo\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
@ -1227,7 +1259,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.3"
"version": "3.9.1"
}
},
"nbformat": 4,

@ -1425,6 +1425,7 @@ class RedisVectorStoreRetriever(VectorStoreRetriever):
"similarity",
"similarity_distance_threshold",
"similarity_score_threshold",
"mmr",
]
"""Allowed search types."""
@ -1438,7 +1439,6 @@ class RedisVectorStoreRetriever(VectorStoreRetriever):
) -> List[Document]:
if self.search_type == "similarity":
docs = self.vectorstore.similarity_search(query, **self.search_kwargs)
elif self.search_type == "similarity_distance_threshold":
if self.search_kwargs["distance_threshold"] is None:
raise ValueError(
@ -1454,6 +1454,10 @@ class RedisVectorStoreRetriever(VectorStoreRetriever):
)
)
docs = [doc for doc, _ in docs_and_similarities]
elif self.search_type == "mmr":
docs = self.vectorstore.max_marginal_relevance_search(
query, **self.search_kwargs
)
else:
raise ValueError(f"search_type of {self.search_type} not allowed.")
return docs

Loading…
Cancel
Save