From 9d9198de0b213c19138d7f8a5fc9bb164e4419ca Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Sat, 21 Oct 2023 09:31:10 -0700 Subject: [PATCH] rewrite (#12111) --- cookbook/rewrite.ipynb | 351 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 351 insertions(+) create mode 100644 cookbook/rewrite.ipynb diff --git a/cookbook/rewrite.ipynb b/cookbook/rewrite.ipynb new file mode 100644 index 0000000000..9631b300fb --- /dev/null +++ b/cookbook/rewrite.ipynb @@ -0,0 +1,351 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "59fb0852", + "metadata": {}, + "source": [ + "# Rewrite-Retrieve-Read\n", + "\n", + "**Rewrite-Retrieve-Read** is a method proposed in the paper [Query Rewriting for Retrieval-Augmented Large Language Models](https://arxiv.org/pdf/2305.14283.pdf)\n", + "\n", + "> Because the original query can not be always optimal to retrieve for the LLM, especially in the real world... we first prompt an LLM to rewrite the queries, then conduct retrieval-augmented reading\n", + "\n", + "We show how you can easily do that with LangChain Expression Language" + ] + }, + { + "cell_type": "markdown", + "id": "e11473f2", + "metadata": {}, + "source": [ + "## Baseline\n", + "\n", + "Baseline RAG (**Retrieve-and-read**) can be done like the following:" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "ea9022ff", + "metadata": {}, + "outputs": [], + "source": [ + "from operator import itemgetter\n", + "\n", + "from langchain.prompts import ChatPromptTemplate\n", + "from langchain.chat_models import ChatOpenAI\n", + "from langchain.schema.output_parser import StrOutputParser\n", + "from langchain.schema.runnable import RunnablePassthrough, RunnableLambda\n", + "from langchain.utilities import DuckDuckGoSearchAPIWrapper" + ] + }, + { + "cell_type": "code", + "execution_count": 90, + "id": "75aceda8", + "metadata": {}, + "outputs": [], + "source": [ + "template = \"\"\"Answer the users question based only on the following context:\n", + "\n", + "\n", + "{context}\n", + "\n", + "\n", + "Question: {question}\n", + "\"\"\"\n", + "prompt = ChatPromptTemplate.from_template(template)\n", + "\n", + "model = ChatOpenAI(temperature=0)\n", + "\n", + "search = DuckDuckGoSearchAPIWrapper()\n", + "\n", + "\n", + "def retriever(query):\n", + " return search.run(query)" + ] + }, + { + "cell_type": "code", + "execution_count": 91, + "id": "db8a2a8f", + "metadata": {}, + "outputs": [], + "source": [ + "chain = (\n", + " {\"context\": retriever, \"question\": RunnablePassthrough()} \n", + " | prompt \n", + " | model \n", + " | StrOutputParser()\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 92, + "id": "002efdfa", + "metadata": {}, + "outputs": [], + "source": [ + "simple_query = \"what is langchain?\"" + ] + }, + { + "cell_type": "code", + "execution_count": 93, + "id": "2b74cc69", + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "data": { + "text/plain": [ + "'LangChain is a powerful framework and versatile Python library that simplifies the development of language-based applications. It offers a suite of features for artificial general intelligence, including document analysis and summarization, as well as the ability to build chatbots that interact with users naturally. It is an open-source library that enables developers and researchers to create, experiment with, and analyze language models and agents. LangChain provides a generic interface to many foundation models, prompt management, and acts as a central interface to other components like prompt templates, other language models, external data, and other tools via agents. Overall, LangChain is designed to help developers build end-to-end applications using language models and offers a range of tools, components, and interfaces to simplify the process.'" + ] + }, + "execution_count": 93, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "chain.invoke(simple_query)" + ] + }, + { + "cell_type": "markdown", + "id": "09dedf51", + "metadata": {}, + "source": [ + "While this is fine for well formatted queries, it can break down for more complicated queries" + ] + }, + { + "cell_type": "code", + "execution_count": 94, + "id": "7687cbf4", + "metadata": {}, + "outputs": [], + "source": [ + "distracted_query = \"man that sam bankman fried trial was crazy! what is langchain?\"" + ] + }, + { + "cell_type": "code", + "execution_count": 95, + "id": "9ef1f1aa", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'Based on the given context, there is no information about \"langchain.\"'" + ] + }, + "execution_count": 95, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "chain.invoke(distracted_query)" + ] + }, + { + "cell_type": "markdown", + "id": "1a7df277", + "metadata": {}, + "source": [ + "This is because the retriever does a bad job with these \"distracted\" queries" + ] + }, + { + "cell_type": "code", + "execution_count": 96, + "id": "72df8d50", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'Sam Bankman-Fried, FTX\\'s founder, responded with a single word: \"Oof.\". Less than a year later, Mr. Bankman-Fried, 31, is on trial in federal court in Manhattan, fighting criminal charges ... NEW YORK, Oct 18 (Reuters) - A U.S. judge on Wednesday overruled objections by Sam Bankman-Fried\\'s lawyers and allowed jurors in the FTX founder\\'s fraud trial to see a profane message he... Business FTX founder Sam Bankman-Fried\\'s trial is about to start. Here\\'s what you need to know In testimony on Tuesday and Wednesday that got tearful at times, Ellison accused... Sam Bankman-Fried, who was once hailed as a virtuoso in cryptocurrency trading, is on trial over the collapse of FTX, the financial exchange he founded. Bankman-Fried is accused of... Business Oct 2, 2023 11:29 AM The Trial of Sam Bankman-Fried, Explained White-collar defendants use three main defenses: \"It wasn\\'t me, I didn\\'t mean it, and the people that say I did are...'" + ] + }, + "execution_count": 96, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "retriever(distracted_query)" + ] + }, + { + "cell_type": "markdown", + "id": "987fbfd1", + "metadata": {}, + "source": [ + "## Rewrite-Retrieve-Read Implementation\n", + "\n", + "The main part is a rewriter to rewrite the search query" + ] + }, + { + "cell_type": "code", + "execution_count": 107, + "id": "fbe530b7", + "metadata": {}, + "outputs": [], + "source": [ + "# template = \"\"\"Provide a better search query for \\\n", + "# web search engine to answer the given question, end \\\n", + "# the queries with ’**’. Question: \\\n", + "# {x} Answer:\"\"\"\n", + "# prompt = ChatPromptTemplate.from_template(template)" + ] + }, + { + "cell_type": "code", + "execution_count": 109, + "id": "01604b7d", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain import hub\n", + "\n", + "prompt = hub.pull(\"langchain-ai/rewrite\")" + ] + }, + { + "cell_type": "code", + "execution_count": 110, + "id": "d4fa60fb", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Provide a better search query for web search engine to answer the given question, end the queries with ’**’. Question {x} Answer:\n" + ] + } + ], + "source": [ + "print(prompt.template)" + ] + }, + { + "cell_type": "code", + "execution_count": 111, + "id": "3e505a44", + "metadata": {}, + "outputs": [], + "source": [ + "# Parser to remove the `**`\n", + "\n", + "def _parse(text):\n", + " return text.strip(\"**\")" + ] + }, + { + "cell_type": "code", + "execution_count": 112, + "id": "4f4c12e6", + "metadata": {}, + "outputs": [], + "source": [ + "rewriter = prompt | ChatOpenAI(temperature=0) | StrOutputParser() | _parse" + ] + }, + { + "cell_type": "code", + "execution_count": 113, + "id": "ddbe5ac2", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'What is the definition and purpose of Langchain?'" + ] + }, + "execution_count": 113, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "rewriter.invoke({\"x\": distracted_query})" + ] + }, + { + "cell_type": "code", + "execution_count": 101, + "id": "23a4f3f0", + "metadata": {}, + "outputs": [], + "source": [ + "rewrite_retrieve_read_chain = (\n", + " {\n", + " \"context\": {\"x\": RunnablePassthrough()} | rewriter | retriever,\n", + " \"question\": RunnablePassthrough()} \n", + " | prompt \n", + " | model \n", + " | StrOutputParser()\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 102, + "id": "3c66e5c9", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'Based on the given context, LangChain is an open-source framework designed to simplify the creation of applications using large language models (LLMs). It provides a standard interface for chains, integrations with other tools, and end-to-end chains for common applications. LangChain enables LLM models to generate responses based on up-to-date online information and simplifies the organization of large volumes of data for easy access by LLMs. It is an AI framework with unique features that simplify the development of language-based applications.'" + ] + }, + "execution_count": 102, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "rewrite_retrieve_read_chain.invoke(distracted_query)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "974a17b3", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.1" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}