mirror of https://github.com/hwchase17/langchain
(WIP) add HyDE (#393)
Co-authored-by: cameronccohen <cameron.c.cohen@gmail.com>
Co-authored-by: Cameron Cohen <cameron.cohen@quantco.com>
parent 543db9c2df
commit 6b60c509ac
@ -0,0 +1,4 @@
"""Hypothetical Document Embeddings.

https://arxiv.org/abs/2212.10496
"""
@ -0,0 +1,56 @@ langchain/embeddings/hyde/base.py
"""Hypothetical Document Embeddings.

https://arxiv.org/abs/2212.10496
"""
from __future__ import annotations

from typing import List

import numpy as np
from pydantic import BaseModel, Extra

from langchain.chains.llm import LLMChain
from langchain.embeddings.base import Embeddings
from langchain.embeddings.hyde.prompts import PROMPT_MAP
from langchain.llms.base import BaseLLM


class HypotheticalDocumentEmbedder(Embeddings, BaseModel):
    """Generate a hypothetical document for the query, and then embed that.

    Based on https://arxiv.org/abs/2212.10496
    """

    base_embeddings: Embeddings
    llm_chain: LLMChain

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Call the base embeddings."""
        return self.base_embeddings.embed_documents(texts)

    def embed_query(self, text: str) -> List[float]:
        """Generate a hypothetical document and embed it."""
        var_name = self.llm_chain.input_keys[0]
        result = self.llm_chain.generate([{var_name: text}])
        documents = [generation.text for generation in result.generations[0]]
        embeddings = self.embed_documents(documents)
        return self.combine_embeddings(embeddings)

    def combine_embeddings(self, embeddings: List[List[float]]) -> List[float]:
        """Combine embeddings into final embeddings."""
        return list(np.array(embeddings).mean(axis=0))

    @classmethod
    def from_llm(
        cls, llm: BaseLLM, base_embeddings: Embeddings, prompt_key: str
    ) -> HypotheticalDocumentEmbedder:
        """Load and use LLMChain for a specific prompt key."""
        prompt = PROMPT_MAP[prompt_key]
        llm_chain = LLMChain(llm=llm, prompt=prompt)
        return cls(base_embeddings=base_embeddings, llm_chain=llm_chain)
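For orientation, a minimal usage sketch of the embedder defined above. The OpenAI and OpenAIEmbeddings wrappers are assumptions here (they are not part of this diff); "web_search" is one of the prompt keys defined in PROMPT_MAP in the next file.

from langchain.embeddings import OpenAIEmbeddings  # assumption: not part of this diff
from langchain.embeddings.hyde.base import HypotheticalDocumentEmbedder
from langchain.llms import OpenAI  # assumption: not part of this diff

# Build the embedder from an LLM, a base embedding model, and a prompt key.
embeddings = HypotheticalDocumentEmbedder.from_llm(
    OpenAI(), OpenAIEmbeddings(), "web_search"
)

# embed_query generates a hypothetical passage for the query, embeds it with the
# base embeddings, and returns the combined (mean) vector.
vector = embeddings.embed_query("Where is the Taj Mahal?")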
@ -0,0 +1,47 @@ langchain/embeddings/hyde/prompts.py
# flake8: noqa
from langchain.prompts.prompt import PromptTemplate

web_search_template = """Please write a passage to answer the question
Question: {QUESTION}
Passage:"""
web_search = PromptTemplate(template=web_search_template, input_variables=["QUESTION"])
sci_fact_template = """Please write a scientific paper passage to support/refute the claim
Claim: {Claim}
Passage:"""
sci_fact = PromptTemplate(template=sci_fact_template, input_variables=["Claim"])
arguana_template = """Please write a counter argument for the passage
Passage: {PASSAGE}
Counter Argument:"""
arguana = PromptTemplate(template=arguana_template, input_variables=["PASSAGE"])
trec_covid_template = """Please write a scientific paper passage to answer the question
Question: {QUESTION}
Passage:"""
trec_covid = PromptTemplate(template=trec_covid_template, input_variables=["QUESTION"])
fiqa_template = """Please write a financial article passage to answer the question
Question: {QUESTION}
Passage:"""
fiqa = PromptTemplate(template=fiqa_template, input_variables=["QUESTION"])
dbpedia_entity_template = """Please write a passage to answer the question.
Question: {QUESTION}
Passage:"""
dbpedia_entity = PromptTemplate(
    template=dbpedia_entity_template, input_variables=["QUESTION"]
)
trec_news_template = """Please write a news passage about the topic.
Topic: {TOPIC}
Passage:"""
trec_news = PromptTemplate(template=trec_news_template, input_variables=["TOPIC"])
mr_tydi_template = """Please write a passage in Swahili/Korean/Japanese/Bengali to answer the question in detail.
Question: {QUESTION}
Passage:"""
mr_tydi = PromptTemplate(template=mr_tydi_template, input_variables=["QUESTION"])
PROMPT_MAP = {
    "web_search": web_search,
    "sci_fact": sci_fact,
    "arguana": arguana,
    "trec_covid": trec_covid,
    "fiqa": fiqa,
    "dbpedia_entity": dbpedia_entity,
    "trec_news": trec_news,
    "mr_tydi": mr_tydi,
}
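Since the embedder only needs an LLMChain, a prompt outside PROMPT_MAP can also be used by constructing the chain directly. A sketch, where the custom template and the OpenAI wrappers are assumptions (not part of this diff):

from langchain.chains.llm import LLMChain
from langchain.embeddings import OpenAIEmbeddings  # assumption: not part of this diff
from langchain.embeddings.hyde.base import HypotheticalDocumentEmbedder
from langchain.llms import OpenAI  # assumption: not part of this diff
from langchain.prompts.prompt import PromptTemplate

# Hypothetical prompt, not one of the keys defined above.
answer_template = """Please write a short passage to answer the question
Question: {question}
Passage:"""
answer_prompt = PromptTemplate(template=answer_template, input_variables=["question"])

llm_chain = LLMChain(llm=OpenAI(), prompt=answer_prompt)
embeddings = HypotheticalDocumentEmbedder(
    base_embeddings=OpenAIEmbeddings(),
    llm_chain=llm_chain,
)
# embed_query reads the chain's single input key ("question" here).
vector = embeddings.embed_query("What did the paper propose?")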
@ -0,0 +1,57 @@
"""Test HyDE."""
from typing import List, Optional

import numpy as np
from pydantic import BaseModel

from langchain.embeddings.base import Embeddings
from langchain.embeddings.hyde.base import HypotheticalDocumentEmbedder
from langchain.embeddings.hyde.prompts import PROMPT_MAP
from langchain.llms.base import BaseLLM, LLMResult
from langchain.schema import Generation


class FakeEmbeddings(Embeddings):
    """Fake embedding class for tests."""

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Return random floats."""
        return [list(np.random.uniform(0, 1, 10)) for _ in range(10)]

    def embed_query(self, text: str) -> List[float]:
        """Return random floats."""
        return list(np.random.uniform(0, 1, 10))


class FakeLLM(BaseLLM, BaseModel):
    """Fake LLM wrapper for testing purposes."""

    n: int = 1

    def _generate(
        self, prompts: List[str], stop: Optional[List[str]] = None
    ) -> LLMResult:
        return LLMResult(generations=[[Generation(text="foo") for _ in range(self.n)]])

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "fake"


def test_hyde_from_llm() -> None:
    """Test loading HyDE from all prompts."""
    for key in PROMPT_MAP:
        embedding = HypotheticalDocumentEmbedder.from_llm(
            FakeLLM(), FakeEmbeddings(), key
        )
        embedding.embed_query("foo")


def test_hyde_from_llm_with_multiple_n() -> None:
    """Test loading HyDE from all prompts with multiple generations per query."""
    for key in PROMPT_MAP:
        embedding = HypotheticalDocumentEmbedder.from_llm(
            FakeLLM(n=8), FakeEmbeddings(), key
        )
        embedding.embed_query("foo")
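The multiple-n test above exercises embed_query with several generations per prompt, which are reduced to a single vector by combine_embeddings (an element-wise mean). A standalone sketch of that combination step:

import numpy as np

# Embeddings of three hypothetical documents (e.g. from FakeLLM(n=3), each
# generation embedded separately).
doc_embeddings = [
    [1.0, 0.0, 2.0],
    [3.0, 0.0, 4.0],
    [5.0, 6.0, 0.0],
]

# Same combination as HypotheticalDocumentEmbedder.combine_embeddings:
# element-wise mean across the generated documents.
query_vector = list(np.array(doc_embeddings).mean(axis=0))
print(query_vector)  # [3.0, 2.0, 2.0]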