Create MultiQueryRetriever (#6833)

Distance-based vector database retrieval embeds (represents) queries in
high-dimensional space and finds similar embedded documents based on
"distance". But, retrieval may produce difference results with subtle
changes in query wording or if the embeddings do not capture the
semantics of the data well. Prompt engineering / tuning is sometimes
done to manually address these problems, but can be tedious.

The `MultiQueryRetriever` automates the process of prompt tuning by
using an LLM to generate multiple queries from different perspectives
for a given user input query. For each query, it retrieves a set of
relevant documents and takes the unique union across all queries to get
a larger set of potentially relevant documents. By generating multiple
perspectives on the same question, the `MultiQueryRetriever` might be
able to overcome some of the limitations of the distance-based retrieval
and get a richer set of results.

---------

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
pull/6857/head
Lance Martin 1 year ago committed by GitHub
parent 3ca1a387c2
commit 3f9900a864
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,214 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "8cc82b48",
"metadata": {},
"source": [
"# MultiQueryRetriever\n",
"\n",
"Distance-based vector database retrieval embeds (represents) queries in high-dimensional space and finds similar embedded documents based on \"distance\". But, retrieval may produce difference results with subtle changes in query wording or if the embeddings do not capture the semantics of the data well. Prompt engineering / tuning is sometimes done to manually address these problems, but can be tedious.\n",
"\n",
"The `MultiQueryRetriever` automates the process of prompt tuning by using an LLM to generate multiple queries from different perspectives for a given user input query. For each query, it retrieves a set of relevant documents and takes the unique union across all queries to get a larger set of potentially relevant documents. By generating multiple perspectives on the same question, the `MultiQueryRetriever` might be able to overcome some of the limitations of the distance-based retrieval and get a richer set of results."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "c2f3f5f2",
"metadata": {},
"outputs": [],
"source": [
"# Build a sample vectorDB\n",
"from langchain.vectorstores import Chroma\n",
"from langchain.document_loaders import PyPDFLoader\n",
"from langchain.embeddings.openai import OpenAIEmbeddings\n",
"from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
"\n",
"# Load PDF\n",
"path=\"path-to-files\"\n",
"loaders = [\n",
" PyPDFLoader(path+\"docs/cs229_lectures/MachineLearning-Lecture01.pdf\"),\n",
" PyPDFLoader(path+\"docs/cs229_lectures/MachineLearning-Lecture02.pdf\"),\n",
" PyPDFLoader(path+\"docs/cs229_lectures/MachineLearning-Lecture03.pdf\")\n",
"]\n",
"docs = []\n",
"for loader in loaders:\n",
" docs.extend(loader.load())\n",
" \n",
"# Split\n",
"text_splitter = RecursiveCharacterTextSplitter(chunk_size = 1500,chunk_overlap = 150)\n",
"splits = text_splitter.split_documents(docs)\n",
"\n",
"# VectorDB\n",
"embedding = OpenAIEmbeddings()\n",
"vectordb = Chroma.from_documents(documents=splits,embedding=embedding)"
]
},
{
"cell_type": "markdown",
"id": "cca8f56c",
"metadata": {},
"source": [
"`Simple usage`\n",
"\n",
"Specify the LLM to use for query generation, and the retriver will do the rest."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "edbca101",
"metadata": {},
"outputs": [],
"source": [
"from langchain.chat_models import ChatOpenAI\n",
"from langchain.retrievers.multi_query import MultiQueryRetriever\n",
"question=\"What does the course say about regression?\"\n",
"num_queries=3\n",
"llm = ChatOpenAI(temperature=0)\n",
"retriever_from_llm = MultiQueryRetriever.from_llm(retriever=vectordb.as_retriever(),llm=llm)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "e5203612",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:root:Generated queries: [\"1. What is the course's perspective on regression?\", '2. How does the course discuss regression?', '3. What information does the course provide about regression?']\n"
]
},
{
"data": {
"text/plain": [
"6"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"unique_docs = retriever_from_llm.get_relevant_documents(question=\"What does the course say about regression?\")\n",
"len(unique_docs)"
]
},
{
"cell_type": "markdown",
"id": "c54a282f",
"metadata": {},
"source": [
"`Supplying your own prompt`\n",
"\n",
"You can also supply a prompt along with an output parser to split the results into a list of queries."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "d9afb0ca",
"metadata": {},
"outputs": [],
"source": [
"from typing import List\n",
"from langchain import LLMChain\n",
"from pydantic import BaseModel, Field\n",
"from langchain.prompts import PromptTemplate\n",
"from langchain.output_parsers import PydanticOutputParser\n",
"\n",
"# Output parser will split the LLM result into a list of queries\n",
"class LineList(BaseModel):\n",
" # \"lines\" is the key (attribute name) of the parsed output\n",
" lines: List[str] = Field(description=\"Lines of text\")\n",
"\n",
"class LineListOutputParser(PydanticOutputParser):\n",
" def __init__(self) -> None:\n",
" super().__init__(pydantic_object=LineList)\n",
" def parse(self, text: str) -> LineList:\n",
" lines = text.strip().split(\"\\n\")\n",
" return LineList(lines=lines)\n",
"\n",
"output_parser = LineListOutputParser()\n",
" \n",
"QUERY_PROMPT = PromptTemplate(\n",
" input_variables=[\"question\"],\n",
" template=\"\"\"You are an AI language model assistant. Your task is to generate five \n",
" different versions of the given user question to retrieve relevant documents from a vector \n",
" database. By generating multiple perspectives on the user question, your goal is to help\n",
" the user overcome some of the limitations of the distance-based similarity search. \n",
" Provide these alternative questions seperated by newlines.\n",
" Original question: {question}\"\"\",\n",
")\n",
"llm = ChatOpenAI(temperature=0)\n",
"\n",
"# Chain\n",
"llm_chain = LLMChain(llm=llm,prompt=QUERY_PROMPT,output_parser=output_parser)\n",
" \n",
"# Other inputs\n",
"question=\"What does the course say about regression?\""
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "6660d7ee",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:root:Generated queries: [\"1. What is the course's perspective on regression?\", '2. Can you provide information on regression as discussed in the course?', '3. How does the course cover the topic of regression?', \"4. What are the course's teachings on regression?\", '5. In relation to the course, what is mentioned about regression?']\n"
]
},
{
"data": {
"text/plain": [
"8"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Run\n",
"retriever = MultiQueryRetriever(retriever=vectordb.as_retriever(), \n",
" llm_chain=llm_chain,\n",
" parser_key=\"lines\") # \"lines\" is the key (attribute name) of the parsed output\n",
"\n",
"# Results\n",
"unique_docs = retriever.get_relevant_documents(question=\"What does the course say about regression?\")\n",
"len(unique_docs)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

@ -14,6 +14,7 @@ from langchain.retrievers.llama_index import (
from langchain.retrievers.merger_retriever import MergerRetriever
from langchain.retrievers.metal import MetalRetriever
from langchain.retrievers.milvus import MilvusRetriever
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain.retrievers.pinecone_hybrid_search import PineconeHybridSearchRetriever
from langchain.retrievers.pupmed import PubMedRetriever
from langchain.retrievers.remote_retriever import RemoteLangChainRetriever
@ -43,6 +44,7 @@ __all__ = [
"MergerRetriever",
"MetalRetriever",
"MilvusRetriever",
"MultiQueryRetriever",
"PineconeHybridSearchRetriever",
"PubMedRetriever",
"RemoteLangChainRetriever",

@ -0,0 +1,158 @@
import logging
from typing import List
from pydantic import BaseModel, Field
from langchain.chains.llm import LLMChain
from langchain.llms.base import BaseLLM
from langchain.output_parsers.pydantic import PydanticOutputParser
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import BaseRetriever, Document
logging.basicConfig(level=logging.INFO)
class LineList(BaseModel):
lines: List[str] = Field(description="Lines of text")
class LineListOutputParser(PydanticOutputParser):
def __init__(self) -> None:
super().__init__(pydantic_object=LineList)
def parse(self, text: str) -> LineList:
lines = text.strip().split("\n")
return LineList(lines=lines)
# Default prompt
DEFAULT_QUERY_PROMPT = PromptTemplate(
input_variables=["question"],
template="""You are an AI language model assistant. Your task is
to generate 3 different versions of the given user
question to retrieve relevant documents from a vector database.
By generating multiple perspectives on the user question,
your goal is to help the user overcome some of the limitations
of distance-based similarity search. Provide these alternative
questions seperated by newlines. Original question: {question}""",
)
class MultiQueryRetriever(BaseRetriever):
"""Given a user query, use an LLM to write a set of queries.
Retrieve docs for each query. Rake the unique union of all retrieved docs."""
def __init__(
self,
retriever: BaseRetriever,
llm_chain: LLMChain,
verbose: bool = True,
parser_key: str = "lines",
) -> None:
"""Initialize MultiQueryRetriever.
Args:
retriever: retriever to query documents from
llm_chain: llm_chain for query generation
verbose: show the queries that we generated to the user
parser_key: attribute name for the parsed output
Returns:
MultiQueryRetriever
"""
self.retriever = retriever
self.llm_chain = llm_chain
self.verbose = verbose
self.parser_key = parser_key
@classmethod
def from_llm(
cls,
retriever: BaseRetriever,
llm: BaseLLM,
prompt: PromptTemplate = DEFAULT_QUERY_PROMPT,
parser_key: str = "lines",
) -> "MultiQueryRetriever":
"""Initialize from llm using default template.
Args:
retriever: retriever to query documents from
llm: llm for query generation using DEFAULT_QUERY_PROMPT
Returns:
MultiQueryRetriever
"""
output_parser = LineListOutputParser()
llm_chain = LLMChain(llm=llm, prompt=prompt, output_parser=output_parser)
return cls(
retriever=retriever,
llm_chain=llm_chain,
parser_key=parser_key,
)
def get_relevant_documents(self, question: str) -> List[Document]:
"""Get relevated documents given a user query.
Args:
question: user query
Returns:
Unique union of relevant documents from all generated queries
"""
queries = self.generate_queries(question)
documents = self.retrieve_documents(queries)
unique_documents = self.unique_union(documents)
return unique_documents
async def aget_relevant_documents(self, query: str) -> List[Document]:
raise NotImplementedError
def generate_queries(self, question: str) -> List[str]:
"""Generate queries based upon user input.
Args:
question: user query
Returns:
List of LLM generated queries that are similar to the user input
"""
response = self.llm_chain({"question": question})
lines = getattr(response["text"], self.parser_key, [])
if self.verbose:
logging.info(f"Generated queries: {lines}")
return lines
def retrieve_documents(self, queries: List[str]) -> List[Document]:
"""Run all LLM generated queries.
Args:
queries: query list
Returns:
List of retrived Documents
"""
documents = []
for query in queries:
docs = self.retriever.get_relevant_documents(query)
documents.extend(docs)
return documents
def unique_union(self, documents: List[Document]) -> List[Document]:
"""Get uniqe Documents.
Args:
documents: List of retrived Documents
Returns:
List of unique retrived Documents
"""
# Create a dictionary with page_content as keys to remove duplicates
# TODO: Add Document ID property (e.g., UUID)
unique_documents_dict = {
(doc.page_content, tuple(sorted(doc.metadata.items()))): doc
for doc in documents
}
unique_documents = list(unique_documents_dict.values())
return unique_documents
Loading…
Cancel
Save