Harrison/new search (#4359)

Co-authored-by: Jiaping(JP) Zhang <vincentzhangv@gmail.com>
parallel_dir_loader
Harrison Chase 1 year ago committed by GitHub
parent 545ae8b756
commit 3ce29cb4a6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -25,18 +25,10 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 9, "execution_count": 2,
"id": "9fbcc58f", "id": "9fbcc58f",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [],
{
"name": "stdout",
"output_type": "stream",
"text": [
"Exiting: Cleaning up .chroma directory\n"
]
}
],
"source": [ "source": [
"from langchain.text_splitter import CharacterTextSplitter\n", "from langchain.text_splitter import CharacterTextSplitter\n",
"from langchain.vectorstores import FAISS\n", "from langchain.vectorstores import FAISS\n",
@ -74,6 +66,7 @@
"id": "79b783de", "id": "79b783de",
"metadata": {}, "metadata": {},
"source": [ "source": [
"## Maximum Marginal Relevance Retrieval\n",
"By default, the vectorstore retriever uses similarity search. If the underlying vectorstore support maximum marginal relevance search, you can specify that as the search type." "By default, the vectorstore retriever uses similarity search. If the underlying vectorstore support maximum marginal relevance search, you can specify that as the search type."
] ]
}, },
@ -97,11 +90,42 @@
"docs = retriever.get_relevant_documents(\"what did he say abotu ketanji brown jackson\")" "docs = retriever.get_relevant_documents(\"what did he say abotu ketanji brown jackson\")"
] ]
}, },
{
"cell_type": "markdown",
"id": "2d958271",
"metadata": {},
"source": [
"## Similarity Score Threshold Retrieval\n",
"\n",
"You can also a retrieval method that sets a similarity score threshold and only returns documents with a score above that threshold"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "d4272ad8",
"metadata": {},
"outputs": [],
"source": [
"retriever = db.as_retriever(search_type=\"similarity_score_threshold\", search_kwargs={\"score_threshold\": .5})"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "438e761d",
"metadata": {},
"outputs": [],
"source": [
"docs = retriever.get_relevant_documents(\"what did he say abotu ketanji brown jackson\")"
]
},
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "c23b7698", "id": "c23b7698",
"metadata": {}, "metadata": {},
"source": [ "source": [
"## Specifying top k\n",
"You can also specify search kwargs like `k` to use when doing retrieval." "You can also specify search kwargs like `k` to use when doing retrieval."
] ]
}, },
@ -171,7 +195,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.11.3" "version": "3.9.1"
} }
}, },
"nbformat": 4, "nbformat": 4,

@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from functools import partial from functools import partial
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, TypeVar from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, TypeVar
@ -116,6 +117,16 @@ class VectorStore(ABC):
"""Return docs and relevance scores in the range [0, 1]. """Return docs and relevance scores in the range [0, 1].
0 is dissimilar, 1 is most similar. 0 is dissimilar, 1 is most similar.
Args:
query: input text
k: Number of Documents to return. Defaults to 4.
**kwargs: kwargs to be passed to similarity search. Should include:
score_threshold: Optional, a floating point value between 0 to 1 to
filter the resulting set of retrieved docs
Returns:
List of Tuples of (doc, similarity_score)
""" """
docs_and_similarities = self._similarity_search_with_relevance_scores( docs_and_similarities = self._similarity_search_with_relevance_scores(
query, k=k, **kwargs query, k=k, **kwargs
@ -124,10 +135,23 @@ class VectorStore(ABC):
similarity < 0.0 or similarity > 1.0 similarity < 0.0 or similarity > 1.0
for _, similarity in docs_and_similarities for _, similarity in docs_and_similarities
): ):
raise ValueError( warnings.warn(
"Relevance scores must be between" "Relevance scores must be between"
f" 0 and 1, got {docs_and_similarities}" f" 0 and 1, got {docs_and_similarities}"
) )
score_threshold = kwargs.get("score_threshold")
if score_threshold is not None:
docs_and_similarities = [
(doc, similarity)
for doc, similarity in docs_and_similarities
if similarity >= score_threshold
]
if len(docs_and_similarities) == 0:
warnings.warn(
f"No relevant docs were retrieved using the relevance score\
threshold {score_threshold}"
)
return docs_and_similarities return docs_and_similarities
def _similarity_search_with_relevance_scores( def _similarity_search_with_relevance_scores(
@ -324,13 +348,29 @@ class VectorStoreRetriever(BaseRetriever, BaseModel):
"""Validate search type.""" """Validate search type."""
if "search_type" in values: if "search_type" in values:
search_type = values["search_type"] search_type = values["search_type"]
if search_type not in ("similarity", "mmr"): if search_type not in ("similarity", "similarity_score_threshold", "mmr"):
raise ValueError(f"search_type of {search_type} not allowed.") raise ValueError(f"search_type of {search_type} not allowed.")
if search_type == "similarity_score_threshold":
score_threshold = values["search_kwargs"].get("score_threshold")
if (score_threshold is None) or (
not isinstance(score_threshold, float)
):
raise ValueError(
"`score_threshold` is not specified with a float value(0~1) "
"in `search_kwargs`."
)
return values return values
def get_relevant_documents(self, query: str) -> List[Document]: def get_relevant_documents(self, query: str) -> List[Document]:
if self.search_type == "similarity": if self.search_type == "similarity":
docs = self.vectorstore.similarity_search(query, **self.search_kwargs) docs = self.vectorstore.similarity_search(query, **self.search_kwargs)
elif self.search_type == "similarity_score_threshold":
docs_and_similarities = (
self.vectorstore.similarity_search_with_relevance_scores(
query, **self.search_kwargs
)
)
docs = [doc for doc, _ in docs_and_similarities]
elif self.search_type == "mmr": elif self.search_type == "mmr":
docs = self.vectorstore.max_marginal_relevance_search( docs = self.vectorstore.max_marginal_relevance_search(
query, **self.search_kwargs query, **self.search_kwargs

Loading…
Cancel
Save