From 6b60c509acf785a564a530b8750a93c48d354876 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Wed, 21 Dec 2022 20:46:41 -0500 Subject: [PATCH] (WIP) add HyDE (#393) Co-authored-by: cameronccohen Co-authored-by: Cameron Cohen --- .../data_augmented_generation/hyde.ipynb | 242 ++++++++++++++++++ .../textsplitter.ipynb | 10 +- langchain/chains/llm.py | 11 +- langchain/embeddings/__init__.py | 2 + langchain/embeddings/hyde/__init__.py | 4 + langchain/embeddings/hyde/base.py | 56 ++++ langchain/embeddings/hyde/prompts.py | 47 ++++ tests/unit_tests/test_hyde.py | 57 +++++ 8 files changed, 421 insertions(+), 8 deletions(-) create mode 100644 docs/examples/data_augmented_generation/hyde.ipynb create mode 100644 langchain/embeddings/hyde/__init__.py create mode 100644 langchain/embeddings/hyde/base.py create mode 100644 langchain/embeddings/hyde/prompts.py create mode 100644 tests/unit_tests/test_hyde.py diff --git a/docs/examples/data_augmented_generation/hyde.ipynb b/docs/examples/data_augmented_generation/hyde.ipynb new file mode 100644 index 0000000000..fd963dd746 --- /dev/null +++ b/docs/examples/data_augmented_generation/hyde.ipynb @@ -0,0 +1,242 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "ccb74c9b", + "metadata": {}, + "source": [ + "# Hypothetical Document Embeddings\n", + "This notebook goes over how to use Hypothetical Document Embeddings (HyDE), as described in [this paper](https://arxiv.org/abs/2212.10496). \n", + "\n", + "At a high level, HyDE is an embedding technique that takes queries, generates a hypothetical answer, and then embeds that generated document and uses that as the final example. \n", + "\n", + "In order to use HyDE, we therefor need to provide a base embedding model, as well as an LLMChain that can be used to generate those documents. By default, the HyDE class comes with some default prompts to use (see the paper for more details on them), but we can also create our own." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "546e87ee", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.llms import OpenAI\n", + "from langchain.embeddings import OpenAIEmbeddings, HypotheticalDocumentEmbedder\n", + "from langchain.chains import LLMChain\n", + "from langchain.prompts import PromptTemplate" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "c0ea895f", + "metadata": {}, + "outputs": [], + "source": [ + "base_embeddings = OpenAIEmbeddings()\n", + "llm = OpenAI()" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "50729989", + "metadata": {}, + "outputs": [], + "source": [ + "# Load with `web_search` prompt\n", + "embeddings = HypotheticalDocumentEmbedder.from_llm(llm, base_embeddings, \"web_search\")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "3aa573d6", + "metadata": {}, + "outputs": [], + "source": [ + "# Now we can use it as any embedding class!\n", + "result = embeddings.embed_query(\"Where is the Taj Mahal?\")" + ] + }, + { + "cell_type": "markdown", + "id": "c7a0b556", + "metadata": {}, + "source": [ + "## Multiple generations\n", + "We can also generate multiple documents and then combine the embeddings for those. By default, we combine those by taking the average. We can do this by changing the LLM we use to generate documents to return multiple things." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "05da7060", + "metadata": {}, + "outputs": [], + "source": [ + "multi_llm = OpenAI(n=4, best_of=4)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "9b1e12bd", + "metadata": {}, + "outputs": [], + "source": [ + "embeddings = HypotheticalDocumentEmbedder.from_llm(multi_llm, base_embeddings, \"web_search\")" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "a60cd343", + "metadata": {}, + "outputs": [], + "source": [ + "result = embeddings.embed_query(\"Where is the Taj Mahal?\")" + ] + }, + { + "cell_type": "markdown", + "id": "1da90437", + "metadata": {}, + "source": [ + "## Using our own prompts\n", + "Besides using preconfigured prompts, we can also easily construct our own prompts and use those in the LLMChain that is generating the documents. This can be useful if we know the domain our queries will be in, as we can condition the prompt to generate text more similar to that.\n", + "\n", + "In the example below, let's condition it generate text about a state of the union address (because we will use that in the next example)." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "0b4a650f", + "metadata": {}, + "outputs": [], + "source": [ + "prompt_template = \"\"\"Please answer the user's question about the most recent state of the union address\n", + "Question: {question}\n", + "Answer:\"\"\"\n", + "prompt = PromptTemplate(input_variables=[\"question\"], template=prompt_template)\n", + "llm_chain = LLMChain(llm=llm, prompt=prompt)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "7f7e2b86", + "metadata": {}, + "outputs": [], + "source": [ + "embeddings = HypotheticalDocumentEmbedder(llm_chain=llm_chain, base_embeddings=base_embeddings)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "6dd83424", + "metadata": {}, + "outputs": [], + "source": [ + "result = embeddings.embed_query(\"Where is the Taj Mahal?\")" + ] + }, + { + "cell_type": "markdown", + "id": "31388123", + "metadata": {}, + "source": [ + "## Using HyDE\n", + "Now that we have HyDE, we can use it as we would any other embedding class! Here is using it to find similar passages in the state of the union example." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "97719b29", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.text_splitter import CharacterTextSplitter\n", + "from langchain.vectorstores import FAISS\n", + "\n", + "with open('../state_of_the_union.txt') as f:\n", + " state_of_the_union = f.read()\n", + "text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n", + "texts = text_splitter.split_text(state_of_the_union)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "bfcfc039", + "metadata": {}, + "outputs": [], + "source": [ + "docsearch = FAISS.from_texts(texts, embeddings)\n", + "\n", + "query = \"What did the president say about Ketanji Brown Jackson\"\n", + "docs = docsearch.similarity_search(query)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "632af7f2", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "In state after state, new laws have been passed, not only to suppress the vote, but to subvert entire elections. \n", + "\n", + "We cannot let this happen. \n", + "\n", + "Tonight. I call on the Senate to: Pass the Freedom to Vote Act. Pass the John Lewis Voting Rights Act. And while you’re at it, pass the Disclose Act so Americans can know who is funding our elections. \n", + "\n", + "Tonight, I’d like to honor someone who has dedicated his life to serve this country: Justice Stephen Breyer—an Army veteran, Constitutional scholar, and retiring Justice of the United States Supreme Court. Justice Breyer, thank you for your service. \n", + "\n", + "One of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court. \n", + "\n", + "And I did that 4 days ago, when I nominated Circuit Court of Appeals Judge Ketanji Brown Jackson. One of our nation’s top legal minds, who will continue Justice Breyer’s legacy of excellence. \n" + ] + } + ], + "source": [ + "print(docs[0].page_content)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b9e57b93", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.8" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/examples/data_augmented_generation/textsplitter.ipynb b/docs/examples/data_augmented_generation/textsplitter.ipynb index 77c5e04cb3..718faf3618 100644 --- a/docs/examples/data_augmented_generation/textsplitter.ipynb +++ b/docs/examples/data_augmented_generation/textsplitter.ipynb @@ -19,7 +19,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "id": "e82c4685", "metadata": {}, "outputs": [], @@ -42,7 +42,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "id": "79ff6737", "metadata": {}, "outputs": [], @@ -57,17 +57,17 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "id": "38547666", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "'Madam Speaker, Madam Vice President, our First Lady and Second Gentleman. Members of Congress and the Cabinet. Justices of the Supreme Court. My fellow Americans. \\n\\nLast year COVID-19 kept us apart. This year we are finally together again. \\n\\nTonight, we meet as Democrats Republicans and Independents. But most importantly as Americans. \\n\\nWith a duty to one another to the American people to the Constitution. \\n\\nAnd with an unwavering resolve that freedom will always triumph over tyranny. \\n\\nSix days ago, Russia’s Vladimir Putin sought to shake the foundations of the free world thinking he could make it bend to his menacing ways. But he badly miscalculated. \\n\\nHe thought he could roll into Ukraine and the world would roll over. Instead he met a wall of strength he never imagined. \\n\\nHe met the Ukrainian people. \\n\\nFrom President Zelenskyy to every Ukrainian, their fearlessness, their courage, their determination, inspires the world. '" + "'Madam Speaker, Madam Vice President, our First Lady and Second Gentleman. Members of Congress and the Cabinet. Justices of the Supreme Court. My fellow Americans. \\n\\nLast year COVID-19 kept us apart. This year we are finally together again. \\n\\nTonight, we meet as Democrats Republicans and Independents. But most importantly as Americans. \\n\\nWith a duty to one another to the American people to the Constitution. \\n\\nAnd with an unwavering resolve that freedom will always triumph over tyranny. \\n\\nSix days ago, Russia’s Vladimir Putin sought to shake the foundations of the free world thinking he could make it bend to his menacing ways. But he badly miscalculated. \\n\\nHe thought he could roll into Ukraine and the world would roll over. Instead he met a wall of strength he never imagined. \\n\\nHe met the Ukrainian people. \\n\\nFrom President Zelenskyy to every Ukrainian, their fearlessness, their courage, their determination, inspires the world. \\n\\nGroups of citizens blocking tanks with their bodies. Everyone from students to retirees teachers turned soldiers defending their homeland. '" ] }, - "execution_count": 4, + "execution_count": 3, "metadata": {}, "output_type": "execute_result" } diff --git a/langchain/chains/llm.py b/langchain/chains/llm.py index 138b62a502..f25f542329 100644 --- a/langchain/chains/llm.py +++ b/langchain/chains/llm.py @@ -5,7 +5,7 @@ from pydantic import BaseModel, Extra import langchain from langchain.chains.base import Chain -from langchain.llms.base import BaseLLM +from langchain.llms.base import BaseLLM, LLMResult from langchain.prompts.base import BasePromptTemplate @@ -51,8 +51,8 @@ class LLMChain(Chain, BaseModel): """ return [self.output_key] - def apply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]: - """Utilize the LLM generate method for speed gains.""" + def generate(self, input_list: List[Dict[str, Any]]) -> LLMResult: + """Generate LLM result from inputs.""" stop = None if "stop" in input_list[0]: stop = input_list[0]["stop"] @@ -68,6 +68,11 @@ class LLMChain(Chain, BaseModel): ) prompts.append(prompt) response = self.llm.generate(prompts, stop=stop) + return response + + def apply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]: + """Utilize the LLM generate method for speed gains.""" + response = self.generate(input_list) outputs = [] for generation in response.generations: # Get the text of the top generated string. diff --git a/langchain/embeddings/__init__.py b/langchain/embeddings/__init__.py index 6a57deb136..5a2019e8bf 100644 --- a/langchain/embeddings/__init__.py +++ b/langchain/embeddings/__init__.py @@ -2,6 +2,7 @@ from langchain.embeddings.cohere import CohereEmbeddings from langchain.embeddings.huggingface import HuggingFaceEmbeddings from langchain.embeddings.huggingface_hub import HuggingFaceHubEmbeddings +from langchain.embeddings.hyde.base import HypotheticalDocumentEmbedder from langchain.embeddings.openai import OpenAIEmbeddings __all__ = [ @@ -9,4 +10,5 @@ __all__ = [ "HuggingFaceEmbeddings", "CohereEmbeddings", "HuggingFaceHubEmbeddings", + "HypotheticalDocumentEmbedder", ] diff --git a/langchain/embeddings/hyde/__init__.py b/langchain/embeddings/hyde/__init__.py new file mode 100644 index 0000000000..946d0ab116 --- /dev/null +++ b/langchain/embeddings/hyde/__init__.py @@ -0,0 +1,4 @@ +"""Hypothetical Document Embeddings. + +https://arxiv.org/abs/2212.10496 +""" diff --git a/langchain/embeddings/hyde/base.py b/langchain/embeddings/hyde/base.py new file mode 100644 index 0000000000..dbad3535c2 --- /dev/null +++ b/langchain/embeddings/hyde/base.py @@ -0,0 +1,56 @@ +"""Hypothetical Document Embeddings. + +https://arxiv.org/abs/2212.10496 +""" +from __future__ import annotations + +from typing import List + +import numpy as np +from pydantic import BaseModel, Extra + +from langchain.chains.llm import LLMChain +from langchain.embeddings.base import Embeddings +from langchain.embeddings.hyde.prompts import PROMPT_MAP +from langchain.llms.base import BaseLLM + + +class HypotheticalDocumentEmbedder(Embeddings, BaseModel): + """Generate hypothetical document for query, and then embed that. + + Based on https://arxiv.org/abs/2212.10496 + """ + + base_embeddings: Embeddings + llm_chain: LLMChain + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + arbitrary_types_allowed = True + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + """Call the base embeddings.""" + return self.base_embeddings.embed_documents(texts) + + def embed_query(self, text: str) -> List[float]: + """Generate a hypothetical document and embedded it.""" + var_name = self.llm_chain.input_keys[0] + result = self.llm_chain.generate([{var_name: text}]) + documents = [generation.text for generation in result.generations[0]] + embeddings = self.embed_documents(documents) + return self.combine_embeddings(embeddings) + + def combine_embeddings(self, embeddings: List[List[float]]) -> List[float]: + """Combine embeddings into final embeddings.""" + return list(np.array(embeddings).mean(axis=0)) + + @classmethod + def from_llm( + cls, llm: BaseLLM, base_embeddings: Embeddings, prompt_key: str + ) -> HypotheticalDocumentEmbedder: + """Load and use LLMChain for a specific prompt key.""" + prompt = PROMPT_MAP[prompt_key] + llm_chain = LLMChain(llm=llm, prompt=prompt) + return cls(base_embeddings=base_embeddings, llm_chain=llm_chain) diff --git a/langchain/embeddings/hyde/prompts.py b/langchain/embeddings/hyde/prompts.py new file mode 100644 index 0000000000..746cce3a1d --- /dev/null +++ b/langchain/embeddings/hyde/prompts.py @@ -0,0 +1,47 @@ +# flake8: noqa +from langchain.prompts.prompt import PromptTemplate + +web_search_template = """Please write a passage to answer the question +Question: {QUESTION} +Passage:""" +web_search = PromptTemplate(template=web_search_template, input_variables=["QUESTION"]) +sci_fact_template = """Please write a scientific paper passage to support/refute the claim +Claim: {Claim} +Passage:""" +sci_fact = PromptTemplate(template=sci_fact_template, input_variables=["Claim"]) +arguana_template = """Please write a counter argument for the passage +Passage: {PASSAGE} +Counter Argument:""" +arguana = PromptTemplate(template=arguana_template, input_variables=["PASSAGE"]) +trec_covid_template = """Please write a scientific paper passage to answer the question +Question: {QUESTION} +Passage:""" +trec_covid = PromptTemplate(template=trec_covid_template, input_variables=["QUESTION"]) +fiqa_template = """Please write a financial article passage to answer the question +Question: {QUESTION} +Passage:""" +fiqa = PromptTemplate(template=fiqa_template, input_variables=["QUESTION"]) +dbpedia_entity_template = """Please write a passage to answer the question. +Question: {QUESTION} +Passage:""" +dbpedia_entity = PromptTemplate( + template=dbpedia_entity_template, input_variables=["QUESTION"] +) +trec_news_template = """Please write a news passage about the topic. +Topic: {TOPIC} +Passage:""" +trec_news = PromptTemplate(template=trec_news_template, input_variables=["TOPIC"]) +mr_tydi_template = """Please write a passage in Swahili/Korean/Japanese/Bengali to answer the question in detail. +Question: {QUESTION} +Passage:""" +mr_tydi = PromptTemplate(template=mr_tydi_template, input_variables=["QUESTION"]) +PROMPT_MAP = { + "web_search": web_search, + "sci_fact": sci_fact, + "arguana": arguana, + "trec_covid": trec_covid, + "fiqa": fiqa, + "dbpedia_entity": dbpedia_entity, + "trec_news": trec_news, + "mr_tydi": mr_tydi, +} diff --git a/tests/unit_tests/test_hyde.py b/tests/unit_tests/test_hyde.py new file mode 100644 index 0000000000..91df0f3407 --- /dev/null +++ b/tests/unit_tests/test_hyde.py @@ -0,0 +1,57 @@ +"""Test HyDE.""" +from typing import List, Optional + +import numpy as np +from pydantic import BaseModel + +from langchain.embeddings.base import Embeddings +from langchain.embeddings.hyde.base import HypotheticalDocumentEmbedder +from langchain.embeddings.hyde.prompts import PROMPT_MAP +from langchain.llms.base import BaseLLM, LLMResult +from langchain.schema import Generation + + +class FakeEmbeddings(Embeddings): + """Fake embedding class for tests.""" + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + """Return random floats.""" + return [list(np.random.uniform(0, 1, 10)) for _ in range(10)] + + def embed_query(self, text: str) -> List[float]: + """Return random floats.""" + return list(np.random.uniform(0, 1, 10)) + + +class FakeLLM(BaseLLM, BaseModel): + """Fake LLM wrapper for testing purposes.""" + + n: int = 1 + + def _generate( + self, prompts: List[str], stop: Optional[List[str]] = None + ) -> LLMResult: + return LLMResult(generations=[[Generation(text="foo") for _ in range(self.n)]]) + + @property + def _llm_type(self) -> str: + """Return type of llm.""" + return "fake" + + +def test_hyde_from_llm() -> None: + """Test loading HyDE from all prompts.""" + for key in PROMPT_MAP: + embedding = HypotheticalDocumentEmbedder.from_llm( + FakeLLM(), FakeEmbeddings(), key + ) + embedding.embed_query("foo") + + +def test_hyde_from_llm_with_multiple_n() -> None: + """Test loading HyDE from all prompts.""" + for key in PROMPT_MAP: + embedding = HypotheticalDocumentEmbedder.from_llm( + FakeLLM(n=8), FakeEmbeddings(), key + ) + embedding.embed_query("foo")