redis retriever (#2060)

This commit is contained in:
Harrison Chase 2023-03-27 19:51:23 -07:00 committed by GitHub
parent b7ebb8fe30
commit 76ecca4d53
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 212 additions and 26 deletions

View File

@ -14,7 +14,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 3, "execution_count": 1,
"id": "5831703b", "id": "5831703b",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -25,7 +25,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 7, "execution_count": 9,
"id": "9fbcc58f", "id": "9fbcc58f",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -33,26 +33,25 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Running Chroma using direct local API.\n", "Exiting: Cleaning up .chroma directory\n"
"Using DuckDB in-memory for database. Data will be transient.\n"
] ]
} }
], ],
"source": [ "source": [
"from langchain.text_splitter import CharacterTextSplitter\n", "from langchain.text_splitter import CharacterTextSplitter\n",
"from langchain.vectorstores import Chroma\n", "from langchain.vectorstores import FAISS\n",
"from langchain.embeddings import OpenAIEmbeddings\n", "from langchain.embeddings import OpenAIEmbeddings\n",
"\n", "\n",
"documents = loader.load()\n", "documents = loader.load()\n",
"text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n", "text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n",
"texts = text_splitter.split_documents(documents)\n", "texts = text_splitter.split_documents(documents)\n",
"embeddings = OpenAIEmbeddings()\n", "embeddings = OpenAIEmbeddings()\n",
"db = Chroma.from_documents(texts, embeddings)" "db = FAISS.from_documents(texts, embeddings)"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 8, "execution_count": 10,
"id": "0cbfb1af", "id": "0cbfb1af",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -62,10 +61,97 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 11,
"id": "fc12700b", "id": "fc12700b",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [
"docs = retriever.get_relevant_documents(\"what did he say abotu ketanji brown jackson\")"
]
},
{
"cell_type": "markdown",
"id": "79b783de",
"metadata": {},
"source": [
"By default, the vectorstore retriever uses similarity search. If the underlying vectorstore support maximum marginal relevance search, you can specify that as the search type."
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "44c7303e",
"metadata": {},
"outputs": [],
"source": [
"retriever = db.as_retriever(search_type=\"mmr\")"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "d16ceec6",
"metadata": {},
"outputs": [],
"source": [
"docs = retriever.get_relevant_documents(\"what did he say abotu ketanji brown jackson\")"
]
},
{
"cell_type": "markdown",
"id": "c23b7698",
"metadata": {},
"source": [
"You can also specify search kwargs like `k` to use when doing retrieval."
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "b5f44cdf",
"metadata": {},
"outputs": [],
"source": [
"retriever = db.as_retriever(search_kwargs={\"k\": 1})"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "56b6a545",
"metadata": {},
"outputs": [],
"source": [
"docs = retriever.get_relevant_documents(\"what did he say abotu ketanji brown jackson\")"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "b5416858",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(docs)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9a658023",
"metadata": {},
"outputs": [],
"source": [] "source": []
} }
], ],

View File

@ -22,7 +22,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 3, "execution_count": 2,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -37,7 +37,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 4, "execution_count": 3,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -46,16 +46,16 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 5, "execution_count": 4,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"'b564189668a343648996bd5a1d353d4e'" "'link'"
] ]
}, },
"execution_count": 5, "execution_count": 4,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@ -64,24 +64,15 @@
"rds.index_name" "rds.index_name"
] ]
}, },
{
"cell_type": "markdown",
"metadata": {},
"source": []
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 6, "execution_count": 5,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"In state after state, new laws have been passed, not only to suppress the vote, but to subvert entire elections. \n",
"\n",
"We cannot let this happen. \n",
"\n",
"Tonight. I call on the Senate to: Pass the Freedom to Vote Act. Pass the John Lewis Voting Rights Act. And while youre at it, pass the Disclose Act so Americans can know who is funding our elections. \n", "Tonight. I call on the Senate to: Pass the Freedom to Vote Act. Pass the John Lewis Voting Rights Act. And while youre at it, pass the Disclose Act so Americans can know who is funding our elections. \n",
"\n", "\n",
"Tonight, Id like to honor someone who has dedicated his life to serve this country: Justice Stephen Breyer—an Army veteran, Constitutional scholar, and retiring Justice of the United States Supreme Court. Justice Breyer, thank you for your service. \n", "Tonight, Id like to honor someone who has dedicated his life to serve this country: Justice Stephen Breyer—an Army veteran, Constitutional scholar, and retiring Justice of the United States Supreme Court. Justice Breyer, thank you for your service. \n",
@ -147,6 +138,79 @@
"results = rds.similarity_search(query)\n", "results = rds.similarity_search(query)\n",
"print(results[0].page_content)" "print(results[0].page_content)"
] ]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## RedisVectorStoreRetriever\n",
"\n",
"Here we go over different options for using the vector store as a retriever.\n",
"\n",
"There are three different search methods we can use to do retrieval. By default, it will use semantic similarity."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"retriever = rds.as_retriever()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"docs = retriever.get_relevant_documents(query)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can also use similarity_limit as a search method. This is only return documents if they are similar enough"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"retriever = rds.as_retriever(search_type=\"similarity_limit\")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[]"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Here we can see it doesn't return any results because there are no relevant documents\n",
"retriever.get_relevant_documents(\"where did ankush go to college?\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
} }
], ],
"metadata": { "metadata": {

View File

@ -126,8 +126,8 @@ class VectorStore(ABC):
) -> VectorStore: ) -> VectorStore:
"""Return VectorStore initialized from texts and embeddings.""" """Return VectorStore initialized from texts and embeddings."""
def as_retriever(self) -> VectorStoreRetriever: def as_retriever(self, **kwargs: Any) -> BaseRetriever:
return VectorStoreRetriever(vectorstore=self) return VectorStoreRetriever(vectorstore=self, **kwargs)
class VectorStoreRetriever(BaseRetriever, BaseModel): class VectorStoreRetriever(BaseRetriever, BaseModel):

View File

@ -4,13 +4,15 @@ from __future__ import annotations
import json import json
import logging import logging
import uuid import uuid
from typing import Any, Callable, Iterable, List, Mapping, Optional, Tuple from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple
import numpy as np import numpy as np
from pydantic import BaseModel, Field, root_validator
from redis.client import Redis as RedisType from redis.client import Redis as RedisType
from langchain.docstore.document import Document from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings from langchain.embeddings.base import Embeddings
from langchain.schema import BaseRetriever
from langchain.utils import get_from_dict_or_env from langchain.utils import get_from_dict_or_env
from langchain.vectorstores.base import VectorStore from langchain.vectorstores.base import VectorStore
@ -339,3 +341,37 @@ class Redis(VectorStore):
) )
return cls(redis_url, index_name, embedding.embed_query) return cls(redis_url, index_name, embedding.embed_query)
def as_retriever(self, **kwargs: Any) -> BaseRetriever:
return RedisVectorStoreRetriever(vectorstore=self, **kwargs)
class RedisVectorStoreRetriever(BaseRetriever, BaseModel):
vectorstore: Redis
search_type: str = "similarity"
search_kwargs: dict = Field(default_factory=dict)
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
@root_validator()
def validate_search_type(cls, values: Dict) -> Dict:
"""Validate search type."""
if "search_type" in values:
search_type = values["search_type"]
if search_type not in ("similarity", "similarity_limit"):
raise ValueError(f"search_type of {search_type} not allowed.")
return values
def get_relevant_documents(self, query: str) -> List[Document]:
if self.search_type == "similarity":
docs = self.vectorstore.similarity_search(query, **self.search_kwargs)
elif self.search_type == "similarity_limit":
docs = self.vectorstore.similarity_search_limit_score(
query, **self.search_kwargs
)
else:
raise ValueError(f"search_type of {self.search_type} not allowed.")
return docs