forked from Archives/langchain
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "ccb74c9b",
   "metadata": {},
   "source": [
    "# Hypothetical Document Embeddings\n",
    "This notebook shows how to use Hypothetical Document Embeddings (HyDE), as described in [this paper](https://arxiv.org/abs/2212.10496).\n",
    "\n",
    "At a high level, HyDE is an embedding technique that takes a query, generates a hypothetical document that answers it, embeds that generated document, and uses the resulting vector as the final query embedding.\n",
    "\n",
    "To use HyDE, we therefore need to provide a base embedding model as well as an LLMChain that can generate those documents. The HyDE class comes with some default prompts (see the paper for more details on them), but we can also create our own."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "546e87ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.llms import OpenAI\n",
    "from langchain.embeddings import OpenAIEmbeddings\n",
    "from langchain.chains import LLMChain, HypotheticalDocumentEmbedder\n",
    "from langchain.prompts import PromptTemplate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "c0ea895f",
   "metadata": {},
   "outputs": [],
   "source": [
    "base_embeddings = OpenAIEmbeddings()\n",
    "llm = OpenAI()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "33bd6905",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "50729989",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the embedder with the built-in `web_search` prompt\n",
    "embeddings = HypotheticalDocumentEmbedder.from_llm(llm, base_embeddings, \"web_search\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "3aa573d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The HyDE embedder can now be used like any other embedding class\n",
    "result = embeddings.embed_query(\"Where is the Taj Mahal?\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c7a0b556",
   "metadata": {},
   "source": [
    "## Multiple generations\n",
    "We can also generate multiple documents and then combine their embeddings. By default, the embeddings are combined by taking their average. To do this, we change the LLM used to generate documents so that it returns multiple completions per prompt."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "05da7060",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ask the LLM for 4 completions per prompt, so that 4 hypothetical documents are generated\n",
    "multi_llm = OpenAI(n=4, best_of=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "9b1e12bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "embeddings = HypotheticalDocumentEmbedder.from_llm(multi_llm, base_embeddings, \"web_search\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "a60cd343",
   "metadata": {},
   "outputs": [],
   "source": [
    "result = embeddings.embed_query(\"Where is the Taj Mahal?\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "1da90437",
   "metadata": {},
   "source": [
    "## Using our own prompts\n",
    "Besides using the preconfigured prompts, we can also construct our own prompts and use them in the LLMChain that generates the documents. This is useful when we know the domain our queries will be in, because we can condition the prompt to generate text that more closely resembles that domain.\n",
    "\n",
    "In the example below, we condition it to generate text about a State of the Union address (because we will use that text in the next example)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "0b4a650f",
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt_template = \"\"\"Please answer the user's question about the most recent state of the union address\n",
    "Question: {question}\n",
    "Answer:\"\"\"\n",
    "prompt = PromptTemplate(input_variables=[\"question\"], template=prompt_template)\n",
    "llm_chain = LLMChain(llm=llm, prompt=prompt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "7f7e2b86",
   "metadata": {},
   "outputs": [],
   "source": [
    "embeddings = HypotheticalDocumentEmbedder(llm_chain=llm_chain, base_embeddings=base_embeddings)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "6dd83424",
   "metadata": {},
   "outputs": [],
   "source": [
    "result = embeddings.embed_query(\"What did the president say about Ketanji Brown Jackson\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "31388123",
   "metadata": {},
   "source": [
    "## Using HyDE\n",
    "Now that we have HyDE, we can use it like any other embedding class. Here we use it to find similar passages in the State of the Union example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "97719b29",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.text_splitter import CharacterTextSplitter\n",
    "from langchain.vectorstores import Chroma\n",
    "\n",
    "with open('../../state_of_the_union.txt') as f:\n",
    "    state_of_the_union = f.read()\n",
    "text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n",
    "texts = text_splitter.split_text(state_of_the_union)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "bfcfc039",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running Chroma using direct local API.\n",
      "Using DuckDB in-memory for database. Data will be transient.\n"
     ]
    }
   ],
   "source": [
    "docsearch = Chroma.from_texts(texts, embeddings)\n",
    "\n",
    "query = \"What did the president say about Ketanji Brown Jackson\"\n",
    "docs = docsearch.similarity_search(query)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "632af7f2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "In state after state, new laws have been passed, not only to suppress the vote, but to subvert entire elections. \n",
      "\n",
      "We cannot let this happen. \n",
      "\n",
      "Tonight. I call on the Senate to: Pass the Freedom to Vote Act. Pass the John Lewis Voting Rights Act. And while you’re at it, pass the Disclose Act so Americans can know who is funding our elections. \n",
      "\n",
      "Tonight, I’d like to honor someone who has dedicated his life to serve this country: Justice Stephen Breyer—an Army veteran, Constitutional scholar, and retiring Justice of the United States Supreme Court. Justice Breyer, thank you for your service. \n",
      "\n",
      "One of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court. \n",
      "\n",
      "And I did that 4 days ago, when I nominated Circuit Court of Appeals Judge Ketanji Brown Jackson. One of our nation’s top legal minds, who will continue Justice Breyer’s legacy of excellence.\n"
     ]
    }
   ],
   "source": [
    "print(docs[0].page_content)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9e57b93",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12 (main, Mar 26 2022, 15:51:15) \n[Clang 13.1.6 (clang-1316.0.21.2)]"
  },
  "vscode": {
   "interpreter": {
    "hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}