Harrison/new search (#4359)

Co-authored-by: Jiaping(JP) Zhang <vincentzhangv@gmail.com>
parallel_dir_loader
Harrison Chase 1 year ago committed by GitHub
parent 545ae8b756
commit 3ce29cb4a6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -25,18 +25,10 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 2,
"id": "9fbcc58f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Exiting: Cleaning up .chroma directory\n"
]
}
],
"outputs": [],
"source": [
"from langchain.text_splitter import CharacterTextSplitter\n",
"from langchain.vectorstores import FAISS\n",
@ -74,6 +66,7 @@
"id": "79b783de",
"metadata": {},
"source": [
"## Maximum Marginal Relevance Retrieval\n",
"By default, the vectorstore retriever uses similarity search. If the underlying vectorstore support maximum marginal relevance search, you can specify that as the search type."
]
},
@ -97,11 +90,42 @@
"docs = retriever.get_relevant_documents(\"what did he say abotu ketanji brown jackson\")"
]
},
{
"cell_type": "markdown",
"id": "2d958271",
"metadata": {},
"source": [
"## Similarity Score Threshold Retrieval\n",
"\n",
"You can also a retrieval method that sets a similarity score threshold and only returns documents with a score above that threshold"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "d4272ad8",
"metadata": {},
"outputs": [],
"source": [
"retriever = db.as_retriever(search_type=\"similarity_score_threshold\", search_kwargs={\"score_threshold\": .5})"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "438e761d",
"metadata": {},
"outputs": [],
"source": [
"docs = retriever.get_relevant_documents(\"what did he say abotu ketanji brown jackson\")"
]
},
{
"cell_type": "markdown",
"id": "c23b7698",
"metadata": {},
"source": [
"## Specifying top k\n",
"You can also specify search kwargs like `k` to use when doing retrieval."
]
},
@ -171,7 +195,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.3"
"version": "3.9.1"
}
},
"nbformat": 4,

@ -2,6 +2,7 @@
from __future__ import annotations
import asyncio
import warnings
from abc import ABC, abstractmethod
from functools import partial
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, TypeVar
@ -116,6 +117,16 @@ class VectorStore(ABC):
"""Return docs and relevance scores in the range [0, 1].
0 is dissimilar, 1 is most similar.
Args:
query: input text
k: Number of Documents to return. Defaults to 4.
**kwargs: kwargs to be passed to similarity search. Should include:
score_threshold: Optional, a floating point value between 0 to 1 to
filter the resulting set of retrieved docs
Returns:
List of Tuples of (doc, similarity_score)
"""
docs_and_similarities = self._similarity_search_with_relevance_scores(
query, k=k, **kwargs
@ -124,10 +135,23 @@ class VectorStore(ABC):
similarity < 0.0 or similarity > 1.0
for _, similarity in docs_and_similarities
):
raise ValueError(
warnings.warn(
"Relevance scores must be between"
f" 0 and 1, got {docs_and_similarities}"
)
score_threshold = kwargs.get("score_threshold")
if score_threshold is not None:
docs_and_similarities = [
(doc, similarity)
for doc, similarity in docs_and_similarities
if similarity >= score_threshold
]
if len(docs_and_similarities) == 0:
warnings.warn(
f"No relevant docs were retrieved using the relevance score\
threshold {score_threshold}"
)
return docs_and_similarities
def _similarity_search_with_relevance_scores(
@ -324,13 +348,29 @@ class VectorStoreRetriever(BaseRetriever, BaseModel):
"""Validate search type."""
if "search_type" in values:
search_type = values["search_type"]
if search_type not in ("similarity", "mmr"):
if search_type not in ("similarity", "similarity_score_threshold", "mmr"):
raise ValueError(f"search_type of {search_type} not allowed.")
if search_type == "similarity_score_threshold":
score_threshold = values["search_kwargs"].get("score_threshold")
if (score_threshold is None) or (
not isinstance(score_threshold, float)
):
raise ValueError(
"`score_threshold` is not specified with a float value(0~1) "
"in `search_kwargs`."
)
return values
def get_relevant_documents(self, query: str) -> List[Document]:
if self.search_type == "similarity":
docs = self.vectorstore.similarity_search(query, **self.search_kwargs)
elif self.search_type == "similarity_score_threshold":
docs_and_similarities = (
self.vectorstore.similarity_search_with_relevance_scores(
query, **self.search_kwargs
)
)
docs = [doc for doc, _ in docs_and_similarities]
elif self.search_type == "mmr":
docs = self.vectorstore.max_marginal_relevance_search(
query, **self.search_kwargs

Loading…
Cancel
Save