diff --git a/docs/extras/integrations/retrievers/re_phrase.ipynb b/docs/extras/integrations/retrievers/re_phrase.ipynb
new file mode 100644
index 0000000000..31221663e3
--- /dev/null
+++ b/docs/extras/integrations/retrievers/re_phrase.ipynb
@@ -0,0 +1,222 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "e8624be2",
+   "metadata": {},
+   "source": [
+    "# RePhraseQueryRetriever\n",
+    "\n",
+    "Simple retriever that applies an LLM between the user input and the query that is passed to the retriever.\n",
+    "\n",
+    "It can be used to pre-process the user input in any way.\n",
+    "\n",
+    "The default prompt used in the `from_llm` classmethod:\n",
+    "\n",
+    "```\n",
+    "DEFAULT_TEMPLATE = \"\"\"You are an assistant tasked with taking a natural language \\\n",
+    "query from a user and converting it into a query for a vectorstore. \\\n",
+    "In this process, you strip out information that is not relevant for \\\n",
+    "the retrieval task. Here is the user query: {question}\"\"\"\n",
+    "```\n",
+    "\n",
+    "Create a vectorstore."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "1bfa6834",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain.document_loaders import WebBaseLoader\n",
+    "\n",
+    "loader = WebBaseLoader(\"https://lilianweng.github.io/posts/2023-06-23-agent/\")\n",
+    "data = loader.load()\n",
+    "\n",
+    "from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
+    "\n",
+    "text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=0)\n",
+    "all_splits = text_splitter.split_documents(data)\n",
+    "\n",
+    "from langchain.vectorstores import Chroma\n",
+    "from langchain.embeddings import OpenAIEmbeddings\n",
+    "\n",
+    "vectorstore = Chroma.from_documents(documents=all_splits, embedding=OpenAIEmbeddings())"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "id": "d0b51556",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import logging\n",
+    "\n",
+    "logging.basicConfig()\n",
+    "logging.getLogger(\"langchain.retrievers.re_phraser\").setLevel(logging.INFO)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "20e1e787",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain.chat_models import ChatOpenAI\n",
+    "from langchain.retrievers import RePhraseQueryRetriever"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "88c0a972",
+   "metadata": {},
+   "source": [
+    "## Using the default prompt"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "id": "503994bd",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "llm = ChatOpenAI(temperature=0)\n",
+    "retriever_from_llm = RePhraseQueryRetriever.from_llm(\n",
+    "    retriever=vectorstore.as_retriever(), llm=llm\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "id": "8d17ecc9",
+   "metadata": {
+    "scrolled": false
+   },
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "INFO:langchain.retrievers.re_phraser:Re-phrased question: The user query can be converted into a query for a vectorstore as follows:\n",
+      "\n",
+      "\"approaches to Task Decomposition\"\n"
+     ]
+    }
+   ],
+   "source": [
+    "docs = retriever_from_llm.get_relevant_documents(\n",
+    "    \"Hi I'm Lance. What are the approaches to Task Decomposition?\"\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "id": "76d54f1a",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "INFO:langchain.retrievers.re_phraser:Re-phrased question: Query for vectorstore: \"Types of Memory\"\n"
+     ]
+    }
+   ],
+   "source": [
+    "docs = retriever_from_llm.get_relevant_documents(\n",
+    "    \"I live in San Francisco. What are the Types of Memory?\"\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "0513a6e2",
+   "metadata": {},
+   "source": [
+    "## Supply a prompt"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "id": "410d6a64",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain import LLMChain\n",
+    "from langchain.prompts import PromptTemplate\n",
+    "\n",
+    "QUERY_PROMPT = PromptTemplate(\n",
+    "    input_variables=[\"question\"],\n",
+    "    template=\"\"\"You are an assistant tasked with taking a natural language query from a user\n",
+    "    and converting it into a query for a vectorstore. In the process, strip out all \n",
+    "    information that is not relevant for the retrieval task and return a new, simplified\n",
+    "    question for vectorstore retrieval. The new user query should be in pirate speech.\n",
+    "    Here is the user query: {question} \"\"\",\n",
+    ")\n",
+    "llm = ChatOpenAI(temperature=0)\n",
+    "llm_chain = LLMChain(llm=llm, prompt=QUERY_PROMPT)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "id": "2dbffdd3",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "retriever_from_llm_chain = RePhraseQueryRetriever(\n",
+    "    retriever=vectorstore.as_retriever(), llm_chain=llm_chain\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "id": "103b4be3",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "INFO:langchain.retrievers.re_phraser:Re-phrased question: Ahoy matey! What be Maximum Inner Product Search, ye scurvy dog?\n"
+     ]
+    }
+   ],
+   "source": [
+    "docs = retriever_from_llm_chain.get_relevant_documents(\n",
+    "    \"Hi I'm Lance. What is Maximum Inner Product Search?\"\n",
+    ")"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.1"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/libs/langchain/langchain/retrievers/__init__.py b/libs/langchain/langchain/retrievers/__init__.py
index c820daa982..f0922460a1 100644
--- a/libs/langchain/langchain/retrievers/__init__.py
+++ b/libs/langchain/langchain/retrievers/__init__.py
@@ -42,6 +42,7 @@ from langchain.retrievers.milvus import MilvusRetriever
 from langchain.retrievers.multi_query import MultiQueryRetriever
 from langchain.retrievers.pinecone_hybrid_search import PineconeHybridSearchRetriever
 from langchain.retrievers.pubmed import PubMedRetriever
+from langchain.retrievers.re_phraser import RePhraseQueryRetriever
 from langchain.retrievers.remote_retriever import RemoteLangChainRetriever
 from langchain.retrievers.self_query.base import SelfQueryRetriever
 from langchain.retrievers.svm import SVMRetriever
@@ -86,6 +87,7 @@ __all__ = [
     "ZepRetriever",
     "ZillizRetriever",
     "DocArrayRetriever",
+    "RePhraseQueryRetriever",
     "WebResearchRetriever",
     "EnsembleRetriever",
 ]
diff --git a/libs/langchain/langchain/retrievers/re_phraser.py b/libs/langchain/langchain/retrievers/re_phraser.py
new file mode 100644
index 0000000000..50baea43f4
--- /dev/null
+++ b/libs/langchain/langchain/retrievers/re_phraser.py
@@ -0,0 +1,87 @@
+import logging
+from typing import List
+
+from langchain.callbacks.manager import (
+    AsyncCallbackManagerForRetrieverRun,
+    CallbackManagerForRetrieverRun,
+)
+from langchain.chains.llm import LLMChain
+from langchain.llms.base import BaseLLM
+from langchain.prompts.prompt import PromptTemplate
+from langchain.schema import BaseRetriever, Document
+
+logger = logging.getLogger(__name__)
+
+# Default template
+DEFAULT_TEMPLATE = """You are an assistant tasked with taking a natural language \
+query from a user and converting it into a query for a vectorstore. \
+In this process, you strip out information that is not relevant for \
+the retrieval task. Here is the user query: {question}"""
+
+# Default prompt
+DEFAULT_QUERY_PROMPT = PromptTemplate.from_template(DEFAULT_TEMPLATE)
+
+
+class RePhraseQueryRetriever(BaseRetriever):
+
+    """Given a user query, use an LLM to re-phrase it.
+    Then, retrieve docs for the re-phrased query."""
+
+    retriever: BaseRetriever
+    llm_chain: LLMChain
+
+    @classmethod
+    def from_llm(
+        cls,
+        retriever: BaseRetriever,
+        llm: BaseLLM,
+        prompt: PromptTemplate = DEFAULT_QUERY_PROMPT,
+    ) -> "RePhraseQueryRetriever":
+        """Initialize from an LLM using the default template.
+
+        The prompt used here expects a single input: `question`
+
+        Args:
+            retriever: retriever to query documents from
+            llm: llm used to re-phrase the user query
+            prompt: prompt template used to re-phrase the query
+
+        Returns:
+            RePhraseQueryRetriever
+        """
+
+        llm_chain = LLMChain(llm=llm, prompt=prompt)
+        return cls(
+            retriever=retriever,
+            llm_chain=llm_chain,
+        )
+
+    def _get_relevant_documents(
+        self,
+        query: str,
+        *,
+        run_manager: CallbackManagerForRetrieverRun,
+    ) -> List[Document]:
+        """Get relevant documents given a user question.
+
+        Args:
+            query: user question
+
+        Returns:
+            Relevant documents for re-phrased question
+        """
+        response = self.llm_chain(query, callbacks=run_manager.get_child())
+        re_phrased_question = response["text"]
+        logger.info(f"Re-phrased question: {re_phrased_question}")
+        docs = self.retriever.get_relevant_documents(
+            re_phrased_question, callbacks=run_manager.get_child()
+        )
+        return docs
+
+    async def _aget_relevant_documents(
+        self,
+        query: str,
+        *,
+        run_manager: AsyncCallbackManagerForRetrieverRun,
+    ) -> List[Document]:
+        raise NotImplementedError