Harrison/functions in retrieval (#6463)

master
Harrison Chase 11 months ago committed by GitHub
parent dc4ffa8d9b
commit 02c0a1e77e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,387 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "71a43144",
"metadata": {},
"source": [
"# Retrieval QA using OpenAI functions\n",
"\n",
"OpenAI functions allows for structuring of response output. This is often useful in question answering when you want to not only get the final answer but also supporting evidence, citations, etc.\n",
"\n",
"In this notebook we show how to use an LLM chain which uses OpenAI functions as part of an overall retrieval pipeline."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "f059012e",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/harrisonchase/.pyenv/versions/3.9.1/envs/langchain/lib/python3.9/site-packages/deeplake/util/check_latest_version.py:32: UserWarning: A newer version of deeplake (3.6.4) is available. It's recommended that you update to the latest version using `pip install -U deeplake`.\n",
" warnings.warn(\n"
]
}
],
"source": [
"from langchain.chains import RetrievalQA\n",
"from langchain.document_loaders import TextLoader\n",
"from langchain.embeddings.openai import OpenAIEmbeddings\n",
"from langchain.text_splitter import CharacterTextSplitter\n",
"from langchain.vectorstores import Chroma"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "f10b831c",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using embedded DuckDB without persistence: data will be transient\n"
]
}
],
"source": [
"loader = TextLoader(\"../../state_of_the_union.txt\")\n",
"documents = loader.load()\n",
"text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n",
"texts = text_splitter.split_documents(documents)\n",
"for i, text in enumerate(texts):\n",
" text.metadata['source'] = f\"{i}-pl\"\n",
"embeddings = OpenAIEmbeddings()\n",
"docsearch = Chroma.from_documents(texts, embeddings)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "70f3a38c",
"metadata": {},
"outputs": [],
"source": [
"from langchain.chat_models import ChatOpenAI\n",
"from langchain.chains.combine_documents.stuff import StuffDocumentsChain\n",
"from langchain.prompts import PromptTemplate\n",
"from langchain.chains import create_qa_with_sources_chain"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "7b3e1731",
"metadata": {},
"outputs": [],
"source": [
"llm = ChatOpenAI(temperature=0, model=\"gpt-3.5-turbo-0613\")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "70a9ccff",
"metadata": {},
"outputs": [],
"source": [
"qa_chain = create_qa_with_sources_chain(llm)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "efcdb6fb",
"metadata": {},
"outputs": [],
"source": [
"doc_prompt = PromptTemplate(\n",
" template=\"Content: {page_content}\\nSource: {source}\",\n",
" input_variables=[\"page_content\", \"source\"],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "64a08263",
"metadata": {},
"outputs": [],
"source": [
"final_qa_chain = StuffDocumentsChain(\n",
" llm_chain=qa_chain, \n",
" document_variable_name='context',\n",
" document_prompt=doc_prompt,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "cb876c97",
"metadata": {},
"outputs": [],
"source": [
"retrieval_qa = RetrievalQA(\n",
" retriever=docsearch.as_retriever(),\n",
" combine_documents_chain=final_qa_chain\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "a75bad9b",
"metadata": {},
"outputs": [],
"source": [
"query = \"What did the president say about russia\""
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "9a60f109",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'{\\n \"answer\": \"The President expressed strong condemnation of Russia\\'s actions in Ukraine and announced measures to isolate Russia and provide support to Ukraine. He stated that Russia\\'s invasion of Ukraine will have long-term consequences for Russia and emphasized the commitment of the United States and its allies to defend NATO countries. The President also mentioned the imposition of sanctions on Russia and the release of oil reserves to help mitigate gas prices. Overall, the President\\'s remarks conveyed a firm stance against Russia\\'s aggression and a commitment to supporting Ukraine and protecting American interests.\",\\n \"sources\": [\"0-pl\", \"4-pl\", \"5-pl\", \"6-pl\"]\\n}'"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"retrieval_qa.run(query)"
]
},
{
"cell_type": "markdown",
"id": "a60f93a4",
"metadata": {},
"source": [
"## Using Pydantic\n",
"\n",
"If we want to, we can set the chain to return in Pydantic. Note that if downstream chains consume the output of this chain - including memory - they will generally expect it to be in string format, so you should only use this chain when it is the final chain."
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "3559727f",
"metadata": {},
"outputs": [],
"source": [
"qa_chain_pydantic = create_qa_with_sources_chain(llm, output_parser=\"pydantic\")"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "5a7997d1",
"metadata": {},
"outputs": [],
"source": [
"final_qa_chain_pydantic = StuffDocumentsChain(\n",
" llm_chain=qa_chain_pydantic, \n",
" document_variable_name='context',\n",
" document_prompt=doc_prompt,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "79368e40",
"metadata": {},
"outputs": [],
"source": [
"retrieval_qa_pydantic = RetrievalQA(\n",
" retriever=docsearch.as_retriever(),\n",
" combine_documents_chain=final_qa_chain_pydantic\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "6b8641de",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AnswerWithSources(answer=\"The President expressed strong condemnation of Russia's actions in Ukraine and announced measures to isolate Russia and provide support to Ukraine. He stated that Russia's invasion of Ukraine will have long-term consequences for Russia and emphasized the commitment of the United States and its allies to defend NATO countries. The President also mentioned the economic impact of sanctions on Russia and the release of oil reserves to help mitigate gas prices. Overall, the President conveyed a message of solidarity with Ukraine and a determination to protect American interests and support freedom.\", sources=['0-pl', '4-pl', '5-pl', '6-pl'])"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"retrieval_qa_pydantic.run(query)"
]
},
{
"cell_type": "markdown",
"id": "e4c15395",
"metadata": {},
"source": [
"## Using in ConversationalRetrievalChain\n",
"\n",
"We can also show what it's like to use this in the ConversationalRetrievalChain. Note that because this chain involves memory, we will NOT use the Pydantic return type."
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "18e5f090",
"metadata": {},
"outputs": [],
"source": [
"from langchain.chains import ConversationalRetrievalChain\n",
"from langchain.memory import ConversationBufferMemory\n",
"from langchain.chains import LLMChain\n",
"memory = ConversationBufferMemory(memory_key=\"chat_history\", return_messages=True)\n",
"_template = \"\"\"Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question, in its original language.\\\n",
"Make sure to avoid using any unclear pronouns.\n",
"\n",
"Chat History:\n",
"{chat_history}\n",
"Follow Up Input: {question}\n",
"Standalone question:\"\"\"\n",
"CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)\n",
"condense_question_chain = LLMChain(\n",
" llm=llm,\n",
" prompt=CONDENSE_QUESTION_PROMPT,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "975c3c2b",
"metadata": {},
"outputs": [],
"source": [
"qa = ConversationalRetrievalChain(\n",
" question_generator=condense_question_chain, \n",
" retriever=docsearch.as_retriever(),\n",
" memory=memory, \n",
" combine_docs_chain=final_qa_chain\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "784aee3a",
"metadata": {},
"outputs": [],
"source": [
"query = \"What did the president say about Ketanji Brown Jackson\"\n",
"result = qa({\"question\": query})"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "dfd0ccc1",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'question': 'What did the president say about Ketanji Brown Jackson',\n",
" 'chat_history': [HumanMessage(content='What did the president say about Ketanji Brown Jackson', additional_kwargs={}, example=False),\n",
" AIMessage(content='{\\n \"answer\": \"The President nominated Ketanji Brown Jackson as a Circuit Court of Appeals Judge and praised her as one of the nation\\'s top legal minds who will continue Justice Breyer\\'s legacy of excellence.\",\\n \"sources\": [\"31-pl\"]\\n}', additional_kwargs={}, example=False)],\n",
" 'answer': '{\\n \"answer\": \"The President nominated Ketanji Brown Jackson as a Circuit Court of Appeals Judge and praised her as one of the nation\\'s top legal minds who will continue Justice Breyer\\'s legacy of excellence.\",\\n \"sources\": [\"31-pl\"]\\n}'}"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"result"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "c93f805b",
"metadata": {},
"outputs": [],
"source": [
"query = \"what did he say about her predecessor?\"\n",
"result = qa({\"question\": query})"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "5d8612c0",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'question': 'what did he say about her predecessor?',\n",
" 'chat_history': [HumanMessage(content='What did the president say about Ketanji Brown Jackson', additional_kwargs={}, example=False),\n",
" AIMessage(content='{\\n \"answer\": \"The President nominated Ketanji Brown Jackson as a Circuit Court of Appeals Judge and praised her as one of the nation\\'s top legal minds who will continue Justice Breyer\\'s legacy of excellence.\",\\n \"sources\": [\"31-pl\"]\\n}', additional_kwargs={}, example=False),\n",
" HumanMessage(content='what did he say about her predecessor?', additional_kwargs={}, example=False),\n",
" AIMessage(content='{\\n \"answer\": \"The President honored Justice Stephen Breyer for his service as an Army veteran, Constitutional scholar, and retiring Justice of the United States Supreme Court.\",\\n \"sources\": [\"31-pl\"]\\n}', additional_kwargs={}, example=False)],\n",
" 'answer': '{\\n \"answer\": \"The President honored Justice Stephen Breyer for his service as an Army veteran, Constitutional scholar, and retiring Justice of the United States Supreme Court.\",\\n \"sources\": [\"31-pl\"]\\n}'}"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"result"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e8d07f6f",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

@ -31,6 +31,8 @@ from langchain.chains.openai_functions import (
create_citation_fuzzy_match_chain,
create_extraction_chain,
create_extraction_chain_pydantic,
create_qa_with_sources_chain,
create_qa_with_structure_chain,
create_tagging_chain,
create_tagging_chain_pydantic,
)
@ -99,6 +101,8 @@ __all__ = [
"create_tagging_chain_pydantic",
"load_chain",
"create_citation_fuzzy_match_chain",
"create_qa_with_structure_chain",
"create_qa_with_sources_chain",
"StuffDocumentsChain",
"MapRerankDocumentsChain",
"MapReduceDocumentsChain",

@ -5,6 +5,10 @@ from langchain.chains.openai_functions.extraction import (
create_extraction_chain,
create_extraction_chain_pydantic,
)
from langchain.chains.openai_functions.qa_with_structure import (
create_qa_with_sources_chain,
create_qa_with_structure_chain,
)
from langchain.chains.openai_functions.tagging import (
create_tagging_chain,
create_tagging_chain_pydantic,
@ -16,4 +20,6 @@ __all__ = [
"create_extraction_chain_pydantic",
"create_extraction_chain",
"create_citation_fuzzy_match_chain",
"create_qa_with_structure_chain",
"create_qa_with_sources_chain",
]

@ -0,0 +1,70 @@
from typing import Any, List
from pydantic import BaseModel, Field
from langchain.base_language import BaseLanguageModel
from langchain.chains.llm import LLMChain
from langchain.chains.openai_functions.utils import get_llm_kwargs
from langchain.output_parsers.openai_functions import (
OutputFunctionsParser,
PydanticOutputFunctionsParser,
)
from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain.schema import BaseLLMOutputParser, HumanMessage, SystemMessage
class AnswerWithSources(BaseModel):
"""An answer to the question being asked, with sources."""
answer: str = Field(..., description="Answer to the question that was asked")
sources: List[str] = Field(
..., description="List of sources used to answer the question"
)
def create_qa_with_structure_chain(
llm: BaseLanguageModel, schema: Any, output_parser: str = "base"
) -> LLMChain:
if output_parser == "pydantic":
_output_parser: BaseLLMOutputParser = PydanticOutputFunctionsParser(
pydantic_schema=schema
)
elif output_parser == "base":
_output_parser = OutputFunctionsParser()
else:
raise ValueError(
f"Got unexpected output_parser: {output_parser}. "
f"Should be one of `pydantic` or `base`."
)
schema = AnswerWithSources.schema()
function = {
"name": schema["title"],
"description": schema["description"],
"parameters": schema,
}
llm_kwargs = get_llm_kwargs(function)
messages = [
SystemMessage(
content=(
"You are a world class algorithm to answer "
"questions in a specific format."
)
),
HumanMessage(content="Answer question using the following context"),
HumanMessagePromptTemplate.from_template("{context}"),
HumanMessagePromptTemplate.from_template("Question: {question}"),
HumanMessage(content="Tips: Make sure to answer in the correct format"),
]
prompt = ChatPromptTemplate(messages=messages)
chain = LLMChain(
llm=llm,
prompt=prompt,
llm_kwargs=llm_kwargs,
output_parser=_output_parser,
)
return chain
def create_qa_with_sources_chain(llm: BaseLanguageModel, **kwargs: Any) -> LLMChain:
return create_qa_with_structure_chain(llm, AnswerWithSources, **kwargs)
Loading…
Cancel
Save