{
"cells": [
{
"cell_type": "markdown",
"id": "efc5be67",
"metadata": {},
"source": [
"# Retrieval Question Answering with Sources\n",
"\n",
"This notebook goes over how to do question-answering with sources with a chat model over a vector database. It does this by using the `RetrievalQAWithSourcesChain`, which does the lookup of the documents from a vector database. \n",
"\n",
"This notebook is very similar to the example of using an LLM in the RetrievalQAWithSources. The only differences here are (1) using a ChatModel, and (2) passing in a ChatPromptTemplate (optimized for chat models)."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "1c613960",
"metadata": {},
"outputs": [],
"source": [
"from langchain.embeddings.openai import OpenAIEmbeddings\n",
"from langchain.embeddings.cohere import CohereEmbeddings\n",
"from langchain.text_splitter import CharacterTextSplitter\n",
"from langchain.vectorstores.elastic_vector_search import ElasticVectorSearch\n",
"from langchain.vectorstores import Chroma"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "17d1306e",
"metadata": {},
"outputs": [],
"source": [
"with open('../../state_of_the_union.txt') as f:\n",
" state_of_the_union = f.read()\n",
"text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n",
"texts = text_splitter.split_text(state_of_the_union)\n",
"\n",
"embeddings = OpenAIEmbeddings()"
]
},
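{
"cell_type": "markdown",
"id": "a1b2c3d4",
"metadata": {},
"source": [
"As a quick, optional sanity check (not required for the chain itself), we can look at how many chunks the splitter produced and preview one. The exact count depends on the source text and the `chunk_size` above."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b2c3d4e5",
"metadata": {},
"outputs": [],
"source": [
"# Optional sanity check: how many chunks did we get, and what do they look like?\n",
"print(f\"Split into {len(texts)} chunks\")\n",
"print(texts[0][:200])"
]
},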
{
"cell_type": "code",
"execution_count": 3,
"id": "0e745d99",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Running Chroma using direct local API.\n",
"Using DuckDB in-memory for database. Data will be transient.\n",
"Running Chroma using direct local API.\n",
"Using DuckDB in-memory for database. Data will be transient.\n"
]
}
],
"source": [
"docsearch = Chroma.from_texts(texts, embeddings, metadatas=[{\"source\": f\"{i}-pl\"} for i in range(len(texts))])"
]
},
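{
"cell_type": "markdown",
"id": "c3d4e5f6",
"metadata": {},
"source": [
"Before wiring up the chain, it can be worth confirming that the vector store returns sensible documents for a sample query. This is an optional check using the standard `similarity_search` method; the query string here is just an example."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d4e5f6a7",
"metadata": {},
"outputs": [],
"source": [
"# Optional check: fetch the most similar chunks for a sample query.\n",
"docs = docsearch.similarity_search(\"Justice Breyer\", k=2)\n",
"for doc in docs:\n",
"    print(doc.metadata[\"source\"], doc.page_content[:100])"
]
},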
{
"cell_type": "code",
"execution_count": 9,
"id": "8aa571ae",
"metadata": {},
"outputs": [],
"source": [
"from langchain.chains import RetrievalQAWithSourcesChain"
]
},
{
"cell_type": "markdown",
"id": "1f73b14a",
"metadata": {},
"source": [
"We can now set up the chat model and chat model specific prompt"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "9643c775",
"metadata": {},
"outputs": [],
"source": [
"from langchain.chat_models import ChatOpenAI\n",
"from langchain.prompts.chat import (\n",
" ChatPromptTemplate,\n",
" SystemMessagePromptTemplate,\n",
" AIMessagePromptTemplate,\n",
" HumanMessagePromptTemplate,\n",
")\n",
"from langchain.schema import (\n",
" AIMessage,\n",
" HumanMessage,\n",
" SystemMessage\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "ed00e906",
"metadata": {},
"outputs": [],
"source": [
"system_template=\"\"\"Use the following pieces of context to answer the users question. \n",
"If you don't know the answer, just say that you don't know, don't try to make up an answer.\n",
"ALWAYS return a \"SOURCES\" part in your answer.\n",
"The \"SOURCES\" part should be a reference to the source of the document from which you got your answer.\n",
"\n",
"Example of your response should be:\n",
"\n",
"```\n",
"The answer is foo\n",
"SOURCES: xyz\n",
"```\n",
"\n",
"Begin!\n",
"----------------\n",
"{summaries}\"\"\"\n",
"messages = [\n",
" SystemMessagePromptTemplate.from_template(system_template),\n",
" HumanMessagePromptTemplate.from_template(\"{question}\")\n",
"]\n",
"prompt = ChatPromptTemplate.from_messages(messages)"
]
},
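{
"cell_type": "markdown",
"id": "e5f6a7b8",
"metadata": {},
"source": [
"To see exactly what the chat model will receive, we can format the prompt with placeholder values. This step is purely illustrative; the `summaries` and `question` values below are dummies."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f6a7b8c9",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative only: fill the template with dummy values to see the message layout.\n",
"example_messages = prompt.format_prompt(\n",
"    summaries=\"Content: foo\\nSource: 0-pl\",\n",
"    question=\"What is foo?\",\n",
").to_messages()\n",
"for message in example_messages:\n",
"    print(type(message).__name__, \"->\", message.content[:60])"
]
},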
{
"cell_type": "code",
"execution_count": 10,
"id": "aa859d4c",
"metadata": {},
"outputs": [],
"source": [
"chain_type_kwargs = {\"prompt\": prompt}\n",
"chain = RetrievalQAWithSourcesChain.from_chain_type(\n",
" ChatOpenAI(temperature=0), \n",
" chain_type=\"stuff\", \n",
" retriever=docsearch.as_retriever(),\n",
" chain_type_kwargs=chain_type_kwargs\n",
")"
]
},
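{
"cell_type": "markdown",
"id": "0a1b2c3d",
"metadata": {},
"source": [
"`chain_type=\"stuff\"` places all retrieved documents into a single prompt, so it can overflow the model's context window on large retrievals. If that happens, another chain type such as `map_reduce` can be swapped in. A minimal sketch follows; it uses the default map-reduce prompts rather than the chat prompt above, which is specific to the stuff chain."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1b2c3d4e",
"metadata": {},
"outputs": [],
"source": [
"# Sketch: the same chain with map_reduce instead of stuff.\n",
"# chain_type_kwargs is omitted because our chat prompt is stuff-specific.\n",
"mr_chain = RetrievalQAWithSourcesChain.from_chain_type(\n",
"    ChatOpenAI(temperature=0),\n",
"    chain_type=\"map_reduce\",\n",
"    retriever=docsearch.as_retriever(),\n",
")"
]
},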
{
"cell_type": "code",
"execution_count": 11,
"id": "8ba36fa7",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'answer': 'The President honored Justice Stephen Breyer, an Army veteran, Constitutional scholar, and retiring Justice of the United States Supreme Court, for his dedicated service to the country. \\n',\n",
" 'sources': '31-pl'}"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chain({\"question\": \"What did the president say about Justice Breyer\"}, return_only_outputs=True)"
]
},
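{
"cell_type": "markdown",
"id": "2c3d4e5f",
"metadata": {},
"source": [
"If you also want the retrieved documents themselves (not just the source strings), the chain accepts a `return_source_documents` flag. A brief sketch:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3d4e5f60",
"metadata": {},
"outputs": [],
"source": [
"# Sketch: return the retrieved documents alongside the answer.\n",
"chain_with_docs = RetrievalQAWithSourcesChain.from_chain_type(\n",
"    ChatOpenAI(temperature=0),\n",
"    chain_type=\"stuff\",\n",
"    retriever=docsearch.as_retriever(),\n",
"    chain_type_kwargs=chain_type_kwargs,\n",
"    return_source_documents=True,\n",
")\n",
"result = chain_with_docs({\"question\": \"What did the president say about Justice Breyer\"})\n",
"print(result[\"sources\"])\n",
"print(len(result[\"source_documents\"]))"
]
}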
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
},
"vscode": {
"interpreter": {
"hash": "b1677b440931f40d89ef8be7bf03acb108ce003de0ac9b18e8d43753ea2e7103"
}
}
},
"nbformat": 4,
"nbformat_minor": 5
}