From 76ecca4d5383a798f0a5db576b8c4b5dd70c3b42 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Mon, 27 Mar 2023 19:51:23 -0700 Subject: [PATCH] redis retriever (#2060) --- .../examples/vectorstore-retriever.ipynb | 102 ++++++++++++++++-- .../indexes/vectorstores/examples/redis.ipynb | 94 +++++++++++++--- langchain/vectorstores/base.py | 4 +- langchain/vectorstores/redis.py | 38 ++++++- 4 files changed, 212 insertions(+), 26 deletions(-) diff --git a/docs/modules/indexes/retrievers/examples/vectorstore-retriever.ipynb b/docs/modules/indexes/retrievers/examples/vectorstore-retriever.ipynb index d96ab341..cc9c948f 100644 --- a/docs/modules/indexes/retrievers/examples/vectorstore-retriever.ipynb +++ b/docs/modules/indexes/retrievers/examples/vectorstore-retriever.ipynb @@ -14,7 +14,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 1, "id": "5831703b", "metadata": {}, "outputs": [], @@ -25,7 +25,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 9, "id": "9fbcc58f", "metadata": {}, "outputs": [ @@ -33,26 +33,25 @@ "name": "stdout", "output_type": "stream", "text": [ - "Running Chroma using direct local API.\n", - "Using DuckDB in-memory for database. Data will be transient.\n" + "Exiting: Cleaning up .chroma directory\n" ] } ], "source": [ "from langchain.text_splitter import CharacterTextSplitter\n", - "from langchain.vectorstores import Chroma\n", + "from langchain.vectorstores import FAISS\n", "from langchain.embeddings import OpenAIEmbeddings\n", "\n", "documents = loader.load()\n", "text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n", "texts = text_splitter.split_documents(documents)\n", "embeddings = OpenAIEmbeddings()\n", - "db = Chroma.from_documents(texts, embeddings)" + "db = FAISS.from_documents(texts, embeddings)" ] }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 10, "id": "0cbfb1af", "metadata": {}, "outputs": [], @@ -62,10 +61,97 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "id": "fc12700b", "metadata": {}, "outputs": [], + "source": [ + "docs = retriever.get_relevant_documents(\"what did he say abotu ketanji brown jackson\")" + ] + }, + { + "cell_type": "markdown", + "id": "79b783de", + "metadata": {}, + "source": [ + "By default, the vectorstore retriever uses similarity search. If the underlying vectorstore support maximum marginal relevance search, you can specify that as the search type." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "44c7303e", + "metadata": {}, + "outputs": [], + "source": [ + "retriever = db.as_retriever(search_type=\"mmr\")" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "d16ceec6", + "metadata": {}, + "outputs": [], + "source": [ + "docs = retriever.get_relevant_documents(\"what did he say abotu ketanji brown jackson\")" + ] + }, + { + "cell_type": "markdown", + "id": "c23b7698", + "metadata": {}, + "source": [ + "You can also specify search kwargs like `k` to use when doing retrieval." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "b5f44cdf", + "metadata": {}, + "outputs": [], + "source": [ + "retriever = db.as_retriever(search_kwargs={\"k\": 1})" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "56b6a545", + "metadata": {}, + "outputs": [], + "source": [ + "docs = retriever.get_relevant_documents(\"what did he say abotu ketanji brown jackson\")" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "b5416858", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "1" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(docs)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9a658023", + "metadata": {}, + "outputs": [], "source": [] } ], diff --git a/docs/modules/indexes/vectorstores/examples/redis.ipynb b/docs/modules/indexes/vectorstores/examples/redis.ipynb index 5fe54f1d..cbcb5eb3 100644 --- a/docs/modules/indexes/vectorstores/examples/redis.ipynb +++ b/docs/modules/indexes/vectorstores/examples/redis.ipynb @@ -22,7 +22,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -37,7 +37,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -46,16 +46,16 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "'b564189668a343648996bd5a1d353d4e'" + "'link'" ] }, - "execution_count": 5, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } @@ -64,24 +64,15 @@ "rds.index_name" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [] - }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "In state after state, new laws have been passed, not only to suppress the vote, but to subvert entire elections. \n", - "\n", - "We cannot let this happen. \n", - "\n", "Tonight. I call on the Senate to: Pass the Freedom to Vote Act. Pass the John Lewis Voting Rights Act. And while you’re at it, pass the Disclose Act so Americans can know who is funding our elections. \n", "\n", "Tonight, I’d like to honor someone who has dedicated his life to serve this country: Justice Stephen Breyer—an Army veteran, Constitutional scholar, and retiring Justice of the United States Supreme Court. Justice Breyer, thank you for your service. \n", @@ -147,6 +138,79 @@ "results = rds.similarity_search(query)\n", "print(results[0].page_content)" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## RedisVectorStoreRetriever\n", + "\n", + "Here we go over different options for using the vector store as a retriever.\n", + "\n", + "There are three different search methods we can use to do retrieval. By default, it will use semantic similarity." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "retriever = rds.as_retriever()" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "docs = retriever.get_relevant_documents(query)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can also use similarity_limit as a search method. This is only return documents if they are similar enough" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "retriever = rds.as_retriever(search_type=\"similarity_limit\")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[]" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Here we can see it doesn't return any results because there are no relevant documents\n", + "retriever.get_relevant_documents(\"where did ankush go to college?\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/langchain/vectorstores/base.py b/langchain/vectorstores/base.py index 9e80b3d5..87caddd1 100644 --- a/langchain/vectorstores/base.py +++ b/langchain/vectorstores/base.py @@ -126,8 +126,8 @@ class VectorStore(ABC): ) -> VectorStore: """Return VectorStore initialized from texts and embeddings.""" - def as_retriever(self) -> VectorStoreRetriever: - return VectorStoreRetriever(vectorstore=self) + def as_retriever(self, **kwargs: Any) -> BaseRetriever: + return VectorStoreRetriever(vectorstore=self, **kwargs) class VectorStoreRetriever(BaseRetriever, BaseModel): diff --git a/langchain/vectorstores/redis.py b/langchain/vectorstores/redis.py index 21b1e5e2..5e0a6986 100644 --- a/langchain/vectorstores/redis.py +++ b/langchain/vectorstores/redis.py @@ -4,13 +4,15 @@ from __future__ import annotations import json import logging import uuid -from typing import Any, Callable, Iterable, List, Mapping, Optional, Tuple +from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple import numpy as np +from pydantic import BaseModel, Field, root_validator from redis.client import Redis as RedisType from langchain.docstore.document import Document from langchain.embeddings.base import Embeddings +from langchain.schema import BaseRetriever from langchain.utils import get_from_dict_or_env from langchain.vectorstores.base import VectorStore @@ -339,3 +341,37 @@ class Redis(VectorStore): ) return cls(redis_url, index_name, embedding.embed_query) + + def as_retriever(self, **kwargs: Any) -> BaseRetriever: + return RedisVectorStoreRetriever(vectorstore=self, **kwargs) + + +class RedisVectorStoreRetriever(BaseRetriever, BaseModel): + vectorstore: Redis + search_type: str = "similarity" + search_kwargs: dict = Field(default_factory=dict) + + class Config: + """Configuration for this pydantic object.""" + + arbitrary_types_allowed = True + + @root_validator() + def validate_search_type(cls, values: Dict) -> Dict: + """Validate search type.""" + if "search_type" in values: + search_type = values["search_type"] + if search_type not in ("similarity", "similarity_limit"): + raise ValueError(f"search_type of {search_type} not allowed.") + return values + + def get_relevant_documents(self, query: str) -> List[Document]: + if self.search_type == "similarity": + docs = self.vectorstore.similarity_search(query, **self.search_kwargs) + elif self.search_type == "similarity_limit": + docs = self.vectorstore.similarity_search_limit_score( + query, **self.search_kwargs + ) + else: + raise ValueError(f"search_type of {self.search_type} not allowed.") + return docs