(WIP) add HyDE (#393)

Co-authored-by: cameronccohen <cameron.c.cohen@gmail.com>
Co-authored-by: Cameron Cohen <cameron.cohen@quantco.com>
Harrison Chase 1 year ago committed by GitHub
parent 543db9c2df
commit 6b60c509ac

@@ -0,0 +1,242 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "ccb74c9b",
"metadata": {},
"source": [
"# Hypothetical Document Embeddings\n",
"This notebook goes over how to use Hypothetical Document Embeddings (HyDE), as described in [this paper](https://arxiv.org/abs/2212.10496). \n",
"\n",
"At a high level, HyDE is an embedding technique that takes queries, generates a hypothetical answer, and then embeds that generated document and uses that as the final example. \n",
"\n",
"In order to use HyDE, we therefor need to provide a base embedding model, as well as an LLMChain that can be used to generate those documents. By default, the HyDE class comes with some default prompts to use (see the paper for more details on them), but we can also create our own."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "546e87ee",
"metadata": {},
"outputs": [],
"source": [
"from langchain.llms import OpenAI\n",
"from langchain.embeddings import OpenAIEmbeddings, HypotheticalDocumentEmbedder\n",
"from langchain.chains import LLMChain\n",
"from langchain.prompts import PromptTemplate"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "c0ea895f",
"metadata": {},
"outputs": [],
"source": [
"base_embeddings = OpenAIEmbeddings()\n",
"llm = OpenAI()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "50729989",
"metadata": {},
"outputs": [],
"source": [
"# Load with `web_search` prompt\n",
"embeddings = HypotheticalDocumentEmbedder.from_llm(llm, base_embeddings, \"web_search\")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "3aa573d6",
"metadata": {},
"outputs": [],
"source": [
"# Now we can use it as any embedding class!\n",
"result = embeddings.embed_query(\"Where is the Taj Mahal?\")"
]
},
{
"cell_type": "markdown",
"id": "c7a0b556",
"metadata": {},
"source": [
"## Multiple generations\n",
"We can also generate multiple documents and then combine the embeddings for those. By default, we combine those by taking the average. We can do this by changing the LLM we use to generate documents to return multiple things."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "05da7060",
"metadata": {},
"outputs": [],
"source": [
"multi_llm = OpenAI(n=4, best_of=4)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "9b1e12bd",
"metadata": {},
"outputs": [],
"source": [
"embeddings = HypotheticalDocumentEmbedder.from_llm(multi_llm, base_embeddings, \"web_search\")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "a60cd343",
"metadata": {},
"outputs": [],
"source": [
"result = embeddings.embed_query(\"Where is the Taj Mahal?\")"
]
},
{
"cell_type": "markdown",
"id": "1da90437",
"metadata": {},
"source": [
"## Using our own prompts\n",
"Besides using preconfigured prompts, we can also easily construct our own prompts and use those in the LLMChain that is generating the documents. This can be useful if we know the domain our queries will be in, as we can condition the prompt to generate text more similar to that.\n",
"\n",
"In the example below, let's condition it generate text about a state of the union address (because we will use that in the next example)."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "0b4a650f",
"metadata": {},
"outputs": [],
"source": [
"prompt_template = \"\"\"Please answer the user's question about the most recent state of the union address\n",
"Question: {question}\n",
"Answer:\"\"\"\n",
"prompt = PromptTemplate(input_variables=[\"question\"], template=prompt_template)\n",
"llm_chain = LLMChain(llm=llm, prompt=prompt)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "7f7e2b86",
"metadata": {},
"outputs": [],
"source": [
"embeddings = HypotheticalDocumentEmbedder(llm_chain=llm_chain, base_embeddings=base_embeddings)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "6dd83424",
"metadata": {},
"outputs": [],
"source": [
"result = embeddings.embed_query(\"Where is the Taj Mahal?\")"
]
},
{
"cell_type": "markdown",
"id": "31388123",
"metadata": {},
"source": [
"## Using HyDE\n",
"Now that we have HyDE, we can use it as we would any other embedding class! Here is using it to find similar passages in the state of the union example."
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "97719b29",
"metadata": {},
"outputs": [],
"source": [
"from langchain.text_splitter import CharacterTextSplitter\n",
"from langchain.vectorstores import FAISS\n",
"\n",
"with open('../state_of_the_union.txt') as f:\n",
" state_of_the_union = f.read()\n",
"text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n",
"texts = text_splitter.split_text(state_of_the_union)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "bfcfc039",
"metadata": {},
"outputs": [],
"source": [
"docsearch = FAISS.from_texts(texts, embeddings)\n",
"\n",
"query = \"What did the president say about Ketanji Brown Jackson\"\n",
"docs = docsearch.similarity_search(query)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "632af7f2",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"In state after state, new laws have been passed, not only to suppress the vote, but to subvert entire elections. \n",
"\n",
"We cannot let this happen. \n",
"\n",
"Tonight. I call on the Senate to: Pass the Freedom to Vote Act. Pass the John Lewis Voting Rights Act. And while youre at it, pass the Disclose Act so Americans can know who is funding our elections. \n",
"\n",
"Tonight, Id like to honor someone who has dedicated his life to serve this country: Justice Stephen Breyer—an Army veteran, Constitutional scholar, and retiring Justice of the United States Supreme Court. Justice Breyer, thank you for your service. \n",
"\n",
"One of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court. \n",
"\n",
"And I did that 4 days ago, when I nominated Circuit Court of Appeals Judge Ketanji Brown Jackson. One of our nations top legal minds, who will continue Justice Breyers legacy of excellence. \n"
]
}
],
"source": [
"print(docs[0].page_content)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b9e57b93",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
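Aside: for readers skimming the diff, here is a minimal, dependency-free sketch of the idea the notebook above demonstrates. fake_llm and fake_embed are hypothetical stand-ins, not LangChain APIs; the point is only the shape of the HyDE flow: query, then generated hypothetical document, then embedding of that document.

from typing import List

def fake_llm(prompt: str) -> str:
    # Stand-in for a real LLM call; returns a canned "hypothetical document".
    return "The Taj Mahal is an ivory-white marble mausoleum in Agra, India."

def fake_embed(text: str) -> List[float]:
    # Stand-in for a real embedding model: folds characters into a small vector.
    vec = [0.0] * 8
    for i, ch in enumerate(text):
        vec[i % 8] += ord(ch) / 1000.0
    return vec

def hyde_embed_query(query: str) -> List[float]:
    # The core HyDE move: embed a generated answer instead of the raw query,
    # so the query vector lands nearer to real answer documents.
    hypothetical_doc = fake_llm(
        f"Please write a passage to answer the question\nQuestion: {query}\nPassage:"
    )
    return fake_embed(hypothetical_doc)

print(hyde_embed_query("Where is the Taj Mahal?"))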

@@ -19,7 +19,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 2,
+"execution_count": 1,
 "id": "e82c4685",
 "metadata": {},
 "outputs": [],
@@ -42,7 +42,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 3,
+"execution_count": 2,
 "id": "79ff6737",
 "metadata": {},
 "outputs": [],
@@ -57,17 +57,17 @@
 },
 {
 "cell_type": "code",
-"execution_count": 4,
+"execution_count": 3,
 "id": "38547666",
 "metadata": {},
 "outputs": [
 {
 "data": {
 "text/plain": [
-"'Madam Speaker, Madam Vice President, our First Lady and Second Gentleman. Members of Congress and the Cabinet. Justices of the Supreme Court. My fellow Americans. \\n\\nLast year COVID-19 kept us apart. This year we are finally together again. \\n\\nTonight, we meet as Democrats Republicans and Independents. But most importantly as Americans. \\n\\nWith a duty to one another to the American people to the Constitution. \\n\\nAnd with an unwavering resolve that freedom will always triumph over tyranny. \\n\\nSix days ago, Russias Vladimir Putin sought to shake the foundations of the free world thinking he could make it bend to his menacing ways. But he badly miscalculated. \\n\\nHe thought he could roll into Ukraine and the world would roll over. Instead he met a wall of strength he never imagined. \\n\\nHe met the Ukrainian people. \\n\\nFrom President Zelenskyy to every Ukrainian, their fearlessness, their courage, their determination, inspires the world. '"
+"'Madam Speaker, Madam Vice President, our First Lady and Second Gentleman. Members of Congress and the Cabinet. Justices of the Supreme Court. My fellow Americans. \\n\\nLast year COVID-19 kept us apart. This year we are finally together again. \\n\\nTonight, we meet as Democrats Republicans and Independents. But most importantly as Americans. \\n\\nWith a duty to one another to the American people to the Constitution. \\n\\nAnd with an unwavering resolve that freedom will always triumph over tyranny. \\n\\nSix days ago, Russias Vladimir Putin sought to shake the foundations of the free world thinking he could make it bend to his menacing ways. But he badly miscalculated. \\n\\nHe thought he could roll into Ukraine and the world would roll over. Instead he met a wall of strength he never imagined. \\n\\nHe met the Ukrainian people. \\n\\nFrom President Zelenskyy to every Ukrainian, their fearlessness, their courage, their determination, inspires the world. \\n\\nGroups of citizens blocking tanks with their bodies. Everyone from students to retirees teachers turned soldiers defending their homeland. '"
 ]
 },
-"execution_count": 4,
+"execution_count": 3,
 "metadata": {},
 "output_type": "execute_result"
 }
@@ -5,7 +5,7 @@ from pydantic import BaseModel, Extra
 import langchain
 from langchain.chains.base import Chain
-from langchain.llms.base import BaseLLM
+from langchain.llms.base import BaseLLM, LLMResult
 from langchain.prompts.base import BasePromptTemplate
@@ -51,8 +51,8 @@ class LLMChain(Chain, BaseModel):
         """
         return [self.output_key]
 
-    def apply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]:
-        """Utilize the LLM generate method for speed gains."""
+    def generate(self, input_list: List[Dict[str, Any]]) -> LLMResult:
+        """Generate LLM result from inputs."""
         stop = None
         if "stop" in input_list[0]:
             stop = input_list[0]["stop"]
@@ -68,6 +68,11 @@ class LLMChain(Chain, BaseModel):
             )
             prompts.append(prompt)
         response = self.llm.generate(prompts, stop=stop)
+        return response
+
+    def apply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]:
+        """Utilize the LLM generate method for speed gains."""
+        response = self.generate(input_list)
         outputs = []
         for generation in response.generations:
             # Get the text of the top generated string.
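The refactor above splits the old apply into a generate method that returns the raw LLMResult (so callers such as HypotheticalDocumentEmbedder can read every generation) plus an apply that keeps its old contract. A simplified, self-contained sketch of that relationship; the dataclasses here are assumptions standing in for the real langchain types:

from dataclasses import dataclass
from typing import Any, Dict, List

@dataclass
class Generation:
    text: str

@dataclass
class LLMResult:
    generations: List[List[Generation]]  # one inner list per input prompt

class SketchLLMChain:
    def generate(self, input_list: List[Dict[str, Any]]) -> LLMResult:
        # Stand-in for self.llm.generate(prompts, stop=stop); returns every
        # generation so callers like HyDE can combine all of them.
        return LLMResult(
            generations=[[Generation("foo"), Generation("bar")] for _ in input_list]
        )

    def apply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]:
        # apply keeps its old behavior by reducing each prompt's generations
        # to the top string.
        response = self.generate(input_list)
        return [{"text": gens[0].text} for gens in response.generations]

print(SketchLLMChain().apply([{"question": "hi"}]))  # [{'text': 'foo'}]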

@@ -2,6 +2,7 @@
 from langchain.embeddings.cohere import CohereEmbeddings
 from langchain.embeddings.huggingface import HuggingFaceEmbeddings
 from langchain.embeddings.huggingface_hub import HuggingFaceHubEmbeddings
+from langchain.embeddings.hyde.base import HypotheticalDocumentEmbedder
 from langchain.embeddings.openai import OpenAIEmbeddings
 
 __all__ = [
@@ -9,4 +10,5 @@ __all__ = [
     "HuggingFaceEmbeddings",
     "CohereEmbeddings",
     "HuggingFaceHubEmbeddings",
+    "HypotheticalDocumentEmbedder",
 ]

@@ -0,0 +1,4 @@
"""Hypothetical Document Embeddings.

https://arxiv.org/abs/2212.10496
"""

@@ -0,0 +1,56 @@
"""Hypothetical Document Embeddings.

https://arxiv.org/abs/2212.10496
"""
from __future__ import annotations

from typing import List

import numpy as np
from pydantic import BaseModel, Extra

from langchain.chains.llm import LLMChain
from langchain.embeddings.base import Embeddings
from langchain.embeddings.hyde.prompts import PROMPT_MAP
from langchain.llms.base import BaseLLM


class HypotheticalDocumentEmbedder(Embeddings, BaseModel):
    """Generate a hypothetical document for the query, and then embed that.

    Based on https://arxiv.org/abs/2212.10496
    """

    base_embeddings: Embeddings
    llm_chain: LLMChain

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Call the base embeddings."""
        return self.base_embeddings.embed_documents(texts)

    def embed_query(self, text: str) -> List[float]:
        """Generate a hypothetical document and embed it."""
        var_name = self.llm_chain.input_keys[0]
        result = self.llm_chain.generate([{var_name: text}])
        documents = [generation.text for generation in result.generations[0]]
        embeddings = self.embed_documents(documents)
        return self.combine_embeddings(embeddings)

    def combine_embeddings(self, embeddings: List[List[float]]) -> List[float]:
        """Combine embeddings into final embeddings."""
        return list(np.array(embeddings).mean(axis=0))

    @classmethod
    def from_llm(
        cls, llm: BaseLLM, base_embeddings: Embeddings, prompt_key: str
    ) -> HypotheticalDocumentEmbedder:
        """Load and use LLMChain for a specific prompt key."""
        prompt = PROMPT_MAP[prompt_key]
        llm_chain = LLMChain(llm=llm, prompt=prompt)
        return cls(base_embeddings=base_embeddings, llm_chain=llm_chain)
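combine_embeddings above takes the element-wise mean across the generated documents' embeddings. A tiny standalone check of that behavior (shapes assumed: n_docs x dim):

import numpy as np

# Three hypothetical-document embeddings, dim=2 for illustration.
embeddings = [[1.0, 0.0], [0.0, 1.0], [2.0, 1.0]]
combined = list(np.array(embeddings).mean(axis=0))
print(combined)  # [1.0, 0.6666666666666666]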

@@ -0,0 +1,47 @@
# flake8: noqa
from langchain.prompts.prompt import PromptTemplate

web_search_template = """Please write a passage to answer the question
Question: {QUESTION}
Passage:"""
web_search = PromptTemplate(template=web_search_template, input_variables=["QUESTION"])

sci_fact_template = """Please write a scientific paper passage to support/refute the claim
Claim: {Claim}
Passage:"""
sci_fact = PromptTemplate(template=sci_fact_template, input_variables=["Claim"])

arguana_template = """Please write a counter argument for the passage
Passage: {PASSAGE}
Counter Argument:"""
arguana = PromptTemplate(template=arguana_template, input_variables=["PASSAGE"])

trec_covid_template = """Please write a scientific paper passage to answer the question
Question: {QUESTION}
Passage:"""
trec_covid = PromptTemplate(template=trec_covid_template, input_variables=["QUESTION"])

fiqa_template = """Please write a financial article passage to answer the question
Question: {QUESTION}
Passage:"""
fiqa = PromptTemplate(template=fiqa_template, input_variables=["QUESTION"])

dbpedia_entity_template = """Please write a passage to answer the question.
Question: {QUESTION}
Passage:"""
dbpedia_entity = PromptTemplate(
    template=dbpedia_entity_template, input_variables=["QUESTION"]
)

trec_news_template = """Please write a news passage about the topic.
Topic: {TOPIC}
Passage:"""
trec_news = PromptTemplate(template=trec_news_template, input_variables=["TOPIC"])

mr_tydi_template = """Please write a passage in Swahili/Korean/Japanese/Bengali to answer the question in detail.
Question: {QUESTION}
Passage:"""
mr_tydi = PromptTemplate(template=mr_tydi_template, input_variables=["QUESTION"])

PROMPT_MAP = {
    "web_search": web_search,
    "sci_fact": sci_fact,
    "arguana": arguana,
    "trec_covid": trec_covid,
    "fiqa": fiqa,
    "dbpedia_entity": dbpedia_entity,
    "trec_news": trec_news,
    "mr_tydi": mr_tydi,
}
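Since PROMPT_MAP is a plain dict, a domain-specific prompt can be wired in either by building the LLMChain directly (as the notebook does) or, as a sketch, by registering a new key before calling from_llm. The "legal" key and template below are hypothetical, not part of this commit:

from langchain.embeddings.hyde.prompts import PROMPT_MAP
from langchain.prompts.prompt import PromptTemplate

# Hypothetical domain-specific prompt, registered under a new key.
legal_template = """Please write a legal opinion passage to answer the question
Question: {QUESTION}
Passage:"""
PROMPT_MAP["legal"] = PromptTemplate(
    template=legal_template, input_variables=["QUESTION"]
)
# Then: HypotheticalDocumentEmbedder.from_llm(llm, base_embeddings, "legal")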

@@ -0,0 +1,57 @@
"""Test HyDE."""
from typing import List, Optional

import numpy as np
from pydantic import BaseModel

from langchain.embeddings.base import Embeddings
from langchain.embeddings.hyde.base import HypotheticalDocumentEmbedder
from langchain.embeddings.hyde.prompts import PROMPT_MAP
from langchain.llms.base import BaseLLM, LLMResult
from langchain.schema import Generation


class FakeEmbeddings(Embeddings):
    """Fake embedding class for tests."""

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Return random floats."""
        return [list(np.random.uniform(0, 1, 10)) for _ in range(10)]

    def embed_query(self, text: str) -> List[float]:
        """Return random floats."""
        return list(np.random.uniform(0, 1, 10))


class FakeLLM(BaseLLM, BaseModel):
    """Fake LLM wrapper for testing purposes."""

    n: int = 1

    def _generate(
        self, prompts: List[str], stop: Optional[List[str]] = None
    ) -> LLMResult:
        return LLMResult(generations=[[Generation(text="foo") for _ in range(self.n)]])

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "fake"


def test_hyde_from_llm() -> None:
    """Test loading HyDE from all prompts."""
    for key in PROMPT_MAP:
        embedding = HypotheticalDocumentEmbedder.from_llm(
            FakeLLM(), FakeEmbeddings(), key
        )
        embedding.embed_query("foo")


def test_hyde_from_llm_with_multiple_n() -> None:
    """Test loading HyDE from all prompts with multiple generations."""
    for key in PROMPT_MAP:
        embedding = HypotheticalDocumentEmbedder.from_llm(
            FakeLLM(n=8), FakeEmbeddings(), key
        )
        embedding.embed_query("foo")