From 3f9900a864ef9cdb6444a7674b7cfdf6f89fb626 Mon Sep 17 00:00:00 2001 From: Lance Martin <122662504+rlancemartin@users.noreply.github.com> Date: Tue, 27 Jun 2023 22:59:40 -0700 Subject: [PATCH] Create MultiQueryRetriever (#6833) Distance-based vector database retrieval embeds (represents) queries in high-dimensional space and finds similar embedded documents based on "distance". But, retrieval may produce difference results with subtle changes in query wording or if the embeddings do not capture the semantics of the data well. Prompt engineering / tuning is sometimes done to manually address these problems, but can be tedious. The `MultiQueryRetriever` automates the process of prompt tuning by using an LLM to generate multiple queries from different perspectives for a given user input query. For each query, it retrieves a set of relevant documents and takes the unique union across all queries to get a larger set of potentially relevant documents. By generating multiple perspectives on the same question, the `MultiQueryRetriever` might be able to overcome some of the limitations of the distance-based retrieval and get a richer set of results. --------- Co-authored-by: Harrison Chase --- .../how_to/MultiQueryRetriever.ipynb | 214 ++++++++++++++++++ langchain/retrievers/__init__.py | 2 + langchain/retrievers/multi_query.py | 158 +++++++++++++ 3 files changed, 374 insertions(+) create mode 100644 docs/extras/modules/data_connection/retrievers/how_to/MultiQueryRetriever.ipynb create mode 100644 langchain/retrievers/multi_query.py diff --git a/docs/extras/modules/data_connection/retrievers/how_to/MultiQueryRetriever.ipynb b/docs/extras/modules/data_connection/retrievers/how_to/MultiQueryRetriever.ipynb new file mode 100644 index 0000000000..12e24fc77d --- /dev/null +++ b/docs/extras/modules/data_connection/retrievers/how_to/MultiQueryRetriever.ipynb @@ -0,0 +1,214 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "8cc82b48", + "metadata": {}, + "source": [ + "# MultiQueryRetriever\n", + "\n", + "Distance-based vector database retrieval embeds (represents) queries in high-dimensional space and finds similar embedded documents based on \"distance\". But, retrieval may produce difference results with subtle changes in query wording or if the embeddings do not capture the semantics of the data well. Prompt engineering / tuning is sometimes done to manually address these problems, but can be tedious.\n", + "\n", + "The `MultiQueryRetriever` automates the process of prompt tuning by using an LLM to generate multiple queries from different perspectives for a given user input query. For each query, it retrieves a set of relevant documents and takes the unique union across all queries to get a larger set of potentially relevant documents. By generating multiple perspectives on the same question, the `MultiQueryRetriever` might be able to overcome some of the limitations of the distance-based retrieval and get a richer set of results." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "c2f3f5f2", + "metadata": {}, + "outputs": [], + "source": [ + "# Build a sample vectorDB\n", + "from langchain.vectorstores import Chroma\n", + "from langchain.document_loaders import PyPDFLoader\n", + "from langchain.embeddings.openai import OpenAIEmbeddings\n", + "from langchain.text_splitter import RecursiveCharacterTextSplitter\n", + "\n", + "# Load PDF\n", + "path=\"path-to-files\"\n", + "loaders = [\n", + " PyPDFLoader(path+\"docs/cs229_lectures/MachineLearning-Lecture01.pdf\"),\n", + " PyPDFLoader(path+\"docs/cs229_lectures/MachineLearning-Lecture02.pdf\"),\n", + " PyPDFLoader(path+\"docs/cs229_lectures/MachineLearning-Lecture03.pdf\")\n", + "]\n", + "docs = []\n", + "for loader in loaders:\n", + " docs.extend(loader.load())\n", + " \n", + "# Split\n", + "text_splitter = RecursiveCharacterTextSplitter(chunk_size = 1500,chunk_overlap = 150)\n", + "splits = text_splitter.split_documents(docs)\n", + "\n", + "# VectorDB\n", + "embedding = OpenAIEmbeddings()\n", + "vectordb = Chroma.from_documents(documents=splits,embedding=embedding)" + ] + }, + { + "cell_type": "markdown", + "id": "cca8f56c", + "metadata": {}, + "source": [ + "`Simple usage`\n", + "\n", + "Specify the LLM to use for query generation, and the retriver will do the rest." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "edbca101", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.chat_models import ChatOpenAI\n", + "from langchain.retrievers.multi_query import MultiQueryRetriever\n", + "question=\"What does the course say about regression?\"\n", + "num_queries=3\n", + "llm = ChatOpenAI(temperature=0)\n", + "retriever_from_llm = MultiQueryRetriever.from_llm(retriever=vectordb.as_retriever(),llm=llm)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "e5203612", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:root:Generated queries: [\"1. What is the course's perspective on regression?\", '2. How does the course discuss regression?', '3. What information does the course provide about regression?']\n" + ] + }, + { + "data": { + "text/plain": [ + "6" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "unique_docs = retriever_from_llm.get_relevant_documents(question=\"What does the course say about regression?\")\n", + "len(unique_docs)" + ] + }, + { + "cell_type": "markdown", + "id": "c54a282f", + "metadata": {}, + "source": [ + "`Supplying your own prompt`\n", + "\n", + "You can also supply a prompt along with an output parser to split the results into a list of queries." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "d9afb0ca", + "metadata": {}, + "outputs": [], + "source": [ + "from typing import List\n", + "from langchain import LLMChain\n", + "from pydantic import BaseModel, Field\n", + "from langchain.prompts import PromptTemplate\n", + "from langchain.output_parsers import PydanticOutputParser\n", + "\n", + "# Output parser will split the LLM result into a list of queries\n", + "class LineList(BaseModel):\n", + " # \"lines\" is the key (attribute name) of the parsed output\n", + " lines: List[str] = Field(description=\"Lines of text\")\n", + "\n", + "class LineListOutputParser(PydanticOutputParser):\n", + " def __init__(self) -> None:\n", + " super().__init__(pydantic_object=LineList)\n", + " def parse(self, text: str) -> LineList:\n", + " lines = text.strip().split(\"\\n\")\n", + " return LineList(lines=lines)\n", + "\n", + "output_parser = LineListOutputParser()\n", + " \n", + "QUERY_PROMPT = PromptTemplate(\n", + " input_variables=[\"question\"],\n", + " template=\"\"\"You are an AI language model assistant. Your task is to generate five \n", + " different versions of the given user question to retrieve relevant documents from a vector \n", + " database. By generating multiple perspectives on the user question, your goal is to help\n", + " the user overcome some of the limitations of the distance-based similarity search. \n", + " Provide these alternative questions seperated by newlines.\n", + " Original question: {question}\"\"\",\n", + ")\n", + "llm = ChatOpenAI(temperature=0)\n", + "\n", + "# Chain\n", + "llm_chain = LLMChain(llm=llm,prompt=QUERY_PROMPT,output_parser=output_parser)\n", + " \n", + "# Other inputs\n", + "question=\"What does the course say about regression?\"" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "6660d7ee", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:root:Generated queries: [\"1. What is the course's perspective on regression?\", '2. Can you provide information on regression as discussed in the course?', '3. How does the course cover the topic of regression?', \"4. What are the course's teachings on regression?\", '5. In relation to the course, what is mentioned about regression?']\n" + ] + }, + { + "data": { + "text/plain": [ + "8" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Run\n", + "retriever = MultiQueryRetriever(retriever=vectordb.as_retriever(), \n", + " llm_chain=llm_chain,\n", + " parser_key=\"lines\") # \"lines\" is the key (attribute name) of the parsed output\n", + "\n", + "# Results\n", + "unique_docs = retriever.get_relevant_documents(question=\"What does the course say about regression?\")\n", + "len(unique_docs)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.16" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/langchain/retrievers/__init__.py b/langchain/retrievers/__init__.py index c872651b0a..5fc3d0b8c0 100644 --- a/langchain/retrievers/__init__.py +++ b/langchain/retrievers/__init__.py @@ -14,6 +14,7 @@ from langchain.retrievers.llama_index import ( from langchain.retrievers.merger_retriever import MergerRetriever from langchain.retrievers.metal import MetalRetriever from langchain.retrievers.milvus import MilvusRetriever +from langchain.retrievers.multi_query import MultiQueryRetriever from langchain.retrievers.pinecone_hybrid_search import PineconeHybridSearchRetriever from langchain.retrievers.pupmed import PubMedRetriever from langchain.retrievers.remote_retriever import RemoteLangChainRetriever @@ -43,6 +44,7 @@ __all__ = [ "MergerRetriever", "MetalRetriever", "MilvusRetriever", + "MultiQueryRetriever", "PineconeHybridSearchRetriever", "PubMedRetriever", "RemoteLangChainRetriever", diff --git a/langchain/retrievers/multi_query.py b/langchain/retrievers/multi_query.py new file mode 100644 index 0000000000..78a2624b67 --- /dev/null +++ b/langchain/retrievers/multi_query.py @@ -0,0 +1,158 @@ +import logging +from typing import List + +from pydantic import BaseModel, Field + +from langchain.chains.llm import LLMChain +from langchain.llms.base import BaseLLM +from langchain.output_parsers.pydantic import PydanticOutputParser +from langchain.prompts.prompt import PromptTemplate +from langchain.schema import BaseRetriever, Document + +logging.basicConfig(level=logging.INFO) + + +class LineList(BaseModel): + lines: List[str] = Field(description="Lines of text") + + +class LineListOutputParser(PydanticOutputParser): + def __init__(self) -> None: + super().__init__(pydantic_object=LineList) + + def parse(self, text: str) -> LineList: + lines = text.strip().split("\n") + return LineList(lines=lines) + + +# Default prompt +DEFAULT_QUERY_PROMPT = PromptTemplate( + input_variables=["question"], + template="""You are an AI language model assistant. Your task is + to generate 3 different versions of the given user + question to retrieve relevant documents from a vector database. + By generating multiple perspectives on the user question, + your goal is to help the user overcome some of the limitations + of distance-based similarity search. Provide these alternative + questions seperated by newlines. Original question: {question}""", +) + + +class MultiQueryRetriever(BaseRetriever): + + """Given a user query, use an LLM to write a set of queries. + Retrieve docs for each query. Rake the unique union of all retrieved docs.""" + + def __init__( + self, + retriever: BaseRetriever, + llm_chain: LLMChain, + verbose: bool = True, + parser_key: str = "lines", + ) -> None: + """Initialize MultiQueryRetriever. + + Args: + retriever: retriever to query documents from + llm_chain: llm_chain for query generation + verbose: show the queries that we generated to the user + parser_key: attribute name for the parsed output + + Returns: + MultiQueryRetriever + """ + self.retriever = retriever + self.llm_chain = llm_chain + self.verbose = verbose + self.parser_key = parser_key + + @classmethod + def from_llm( + cls, + retriever: BaseRetriever, + llm: BaseLLM, + prompt: PromptTemplate = DEFAULT_QUERY_PROMPT, + parser_key: str = "lines", + ) -> "MultiQueryRetriever": + """Initialize from llm using default template. + + Args: + retriever: retriever to query documents from + llm: llm for query generation using DEFAULT_QUERY_PROMPT + + Returns: + MultiQueryRetriever + """ + output_parser = LineListOutputParser() + llm_chain = LLMChain(llm=llm, prompt=prompt, output_parser=output_parser) + return cls( + retriever=retriever, + llm_chain=llm_chain, + parser_key=parser_key, + ) + + def get_relevant_documents(self, question: str) -> List[Document]: + """Get relevated documents given a user query. + + Args: + question: user query + + Returns: + Unique union of relevant documents from all generated queries + """ + queries = self.generate_queries(question) + documents = self.retrieve_documents(queries) + unique_documents = self.unique_union(documents) + return unique_documents + + async def aget_relevant_documents(self, query: str) -> List[Document]: + raise NotImplementedError + + def generate_queries(self, question: str) -> List[str]: + """Generate queries based upon user input. + + Args: + question: user query + + Returns: + List of LLM generated queries that are similar to the user input + """ + response = self.llm_chain({"question": question}) + lines = getattr(response["text"], self.parser_key, []) + if self.verbose: + logging.info(f"Generated queries: {lines}") + return lines + + def retrieve_documents(self, queries: List[str]) -> List[Document]: + """Run all LLM generated queries. + + Args: + queries: query list + + Returns: + List of retrived Documents + """ + documents = [] + for query in queries: + docs = self.retriever.get_relevant_documents(query) + documents.extend(docs) + return documents + + def unique_union(self, documents: List[Document]) -> List[Document]: + """Get uniqe Documents. + + Args: + documents: List of retrived Documents + + Returns: + List of unique retrived Documents + """ + # Create a dictionary with page_content as keys to remove duplicates + # TODO: Add Document ID property (e.g., UUID) + unique_documents_dict = { + (doc.page_content, tuple(sorted(doc.metadata.items()))): doc + for doc in documents + } + + unique_documents = list(unique_documents_dict.values()) + return unique_documents