forked from Archives/langchain
redis retriever (#2060)
This commit is contained in:
parent
b7ebb8fe30
commit
76ecca4d53
@ -14,7 +14,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 3,
|
"execution_count": 1,
|
||||||
"id": "5831703b",
|
"id": "5831703b",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@ -25,7 +25,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 7,
|
"execution_count": 9,
|
||||||
"id": "9fbcc58f",
|
"id": "9fbcc58f",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@ -33,26 +33,25 @@
|
|||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"Running Chroma using direct local API.\n",
|
"Exiting: Cleaning up .chroma directory\n"
|
||||||
"Using DuckDB in-memory for database. Data will be transient.\n"
|
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"from langchain.text_splitter import CharacterTextSplitter\n",
|
"from langchain.text_splitter import CharacterTextSplitter\n",
|
||||||
"from langchain.vectorstores import Chroma\n",
|
"from langchain.vectorstores import FAISS\n",
|
||||||
"from langchain.embeddings import OpenAIEmbeddings\n",
|
"from langchain.embeddings import OpenAIEmbeddings\n",
|
||||||
"\n",
|
"\n",
|
||||||
"documents = loader.load()\n",
|
"documents = loader.load()\n",
|
||||||
"text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n",
|
"text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n",
|
||||||
"texts = text_splitter.split_documents(documents)\n",
|
"texts = text_splitter.split_documents(documents)\n",
|
||||||
"embeddings = OpenAIEmbeddings()\n",
|
"embeddings = OpenAIEmbeddings()\n",
|
||||||
"db = Chroma.from_documents(texts, embeddings)"
|
"db = FAISS.from_documents(texts, embeddings)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 8,
|
"execution_count": 10,
|
||||||
"id": "0cbfb1af",
|
"id": "0cbfb1af",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@ -62,10 +61,97 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 11,
|
||||||
"id": "fc12700b",
|
"id": "fc12700b",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"docs = retriever.get_relevant_documents(\"what did he say abotu ketanji brown jackson\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "79b783de",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"By default, the vectorstore retriever uses similarity search. If the underlying vectorstore supports maximum marginal relevance search, you can specify that as the search type."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 12,
|
||||||
|
"id": "44c7303e",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"retriever = db.as_retriever(search_type=\"mmr\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 13,
|
||||||
|
"id": "d16ceec6",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"docs = retriever.get_relevant_documents(\"what did he say abotu ketanji brown jackson\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "c23b7698",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"You can also specify search kwargs like `k` to use when doing retrieval."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 14,
|
||||||
|
"id": "b5f44cdf",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"retriever = db.as_retriever(search_kwargs={\"k\": 1})"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 15,
|
||||||
|
"id": "56b6a545",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"docs = retriever.get_relevant_documents(\"what did he say abotu ketanji brown jackson\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 16,
|
||||||
|
"id": "b5416858",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"1"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 16,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"len(docs)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "9a658023",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
"source": []
|
"source": []
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
@ -22,7 +22,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 3,
|
"execution_count": 2,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
@ -37,7 +37,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 4,
|
"execution_count": 3,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
@ -46,16 +46,16 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 5,
|
"execution_count": 4,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"text/plain": [
|
"text/plain": [
|
||||||
"'b564189668a343648996bd5a1d353d4e'"
|
"'link'"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 5,
|
"execution_count": 4,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "execute_result"
|
"output_type": "execute_result"
|
||||||
}
|
}
|
||||||
@ -64,24 +64,15 @@
|
|||||||
"rds.index_name"
|
"rds.index_name"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {},
|
|
||||||
"source": []
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 6,
|
"execution_count": 5,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"In state after state, new laws have been passed, not only to suppress the vote, but to subvert entire elections. \n",
|
|
||||||
"\n",
|
|
||||||
"We cannot let this happen. \n",
|
|
||||||
"\n",
|
|
||||||
"Tonight. I call on the Senate to: Pass the Freedom to Vote Act. Pass the John Lewis Voting Rights Act. And while you’re at it, pass the Disclose Act so Americans can know who is funding our elections. \n",
|
"Tonight. I call on the Senate to: Pass the Freedom to Vote Act. Pass the John Lewis Voting Rights Act. And while you’re at it, pass the Disclose Act so Americans can know who is funding our elections. \n",
|
||||||
"\n",
|
"\n",
|
||||||
"Tonight, I’d like to honor someone who has dedicated his life to serve this country: Justice Stephen Breyer—an Army veteran, Constitutional scholar, and retiring Justice of the United States Supreme Court. Justice Breyer, thank you for your service. \n",
|
"Tonight, I’d like to honor someone who has dedicated his life to serve this country: Justice Stephen Breyer—an Army veteran, Constitutional scholar, and retiring Justice of the United States Supreme Court. Justice Breyer, thank you for your service. \n",
|
||||||
@ -147,6 +138,79 @@
|
|||||||
"results = rds.similarity_search(query)\n",
|
"results = rds.similarity_search(query)\n",
|
||||||
"print(results[0].page_content)"
|
"print(results[0].page_content)"
|
||||||
]
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## RedisVectorStoreRetriever\n",
|
||||||
|
"\n",
|
||||||
|
"Here we go over different options for using the vector store as a retriever.\n",
|
||||||
|
"\n",
|
||||||
|
"There are two different search methods we can use to do retrieval. By default, it will use semantic similarity."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 6,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"retriever = rds.as_retriever()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 7,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"docs = retriever.get_relevant_documents(query)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"We can also use similarity_limit as a search method. This will only return documents if they are similar enough."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 5,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"retriever = rds.as_retriever(search_type=\"similarity_limit\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 6,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"[]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 6,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"# Here we can see it doesn't return any results because there are no relevant documents\n",
|
||||||
|
"retriever.get_relevant_documents(\"where did ankush go to college?\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": []
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
|
@ -126,8 +126,8 @@ class VectorStore(ABC):
|
|||||||
) -> VectorStore:
|
) -> VectorStore:
|
||||||
"""Return VectorStore initialized from texts and embeddings."""
|
"""Return VectorStore initialized from texts and embeddings."""
|
||||||
|
|
||||||
def as_retriever(self) -> VectorStoreRetriever:
|
def as_retriever(self, **kwargs: Any) -> BaseRetriever:
|
||||||
return VectorStoreRetriever(vectorstore=self)
|
return VectorStoreRetriever(vectorstore=self, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class VectorStoreRetriever(BaseRetriever, BaseModel):
|
class VectorStoreRetriever(BaseRetriever, BaseModel):
|
||||||
|
@ -4,13 +4,15 @@ from __future__ import annotations
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any, Callable, Iterable, List, Mapping, Optional, Tuple
|
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from pydantic import BaseModel, Field, root_validator
|
||||||
from redis.client import Redis as RedisType
|
from redis.client import Redis as RedisType
|
||||||
|
|
||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
from langchain.embeddings.base import Embeddings
|
from langchain.embeddings.base import Embeddings
|
||||||
|
from langchain.schema import BaseRetriever
|
||||||
from langchain.utils import get_from_dict_or_env
|
from langchain.utils import get_from_dict_or_env
|
||||||
from langchain.vectorstores.base import VectorStore
|
from langchain.vectorstores.base import VectorStore
|
||||||
|
|
||||||
@ -339,3 +341,37 @@ class Redis(VectorStore):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return cls(redis_url, index_name, embedding.embed_query)
|
return cls(redis_url, index_name, embedding.embed_query)
|
||||||
|
|
||||||
|
def as_retriever(self, **kwargs: Any) -> BaseRetriever:
|
||||||
|
return RedisVectorStoreRetriever(vectorstore=self, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class RedisVectorStoreRetriever(BaseRetriever, BaseModel):
|
||||||
|
vectorstore: Redis
|
||||||
|
search_type: str = "similarity"
|
||||||
|
search_kwargs: dict = Field(default_factory=dict)
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
"""Configuration for this pydantic object."""
|
||||||
|
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
@root_validator()
|
||||||
|
def validate_search_type(cls, values: Dict) -> Dict:
|
||||||
|
"""Validate search type."""
|
||||||
|
if "search_type" in values:
|
||||||
|
search_type = values["search_type"]
|
||||||
|
if search_type not in ("similarity", "similarity_limit"):
|
||||||
|
raise ValueError(f"search_type of {search_type} not allowed.")
|
||||||
|
return values
|
||||||
|
|
||||||
|
def get_relevant_documents(self, query: str) -> List[Document]:
|
||||||
|
if self.search_type == "similarity":
|
||||||
|
docs = self.vectorstore.similarity_search(query, **self.search_kwargs)
|
||||||
|
elif self.search_type == "similarity_limit":
|
||||||
|
docs = self.vectorstore.similarity_search_limit_score(
|
||||||
|
query, **self.search_kwargs
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"search_type of {self.search_type} not allowed.")
|
||||||
|
return docs
|
||||||
|
Loading…
Reference in New Issue
Block a user