mirror of https://github.com/hwchase17/langchain
Harrison/pinecone hybrid (#2405)
parent
2b975de94d
commit
41832042cc
@ -0,0 +1,254 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "ab66dd43",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Pinecone Hybrid Search\n",
|
||||
"\n",
|
||||
"This notebook goes over how to use a retriever that under the hood uses Pinecone and Hybrid Search.\n",
|
||||
"\n",
|
||||
"The logic of this retriever is largely taken from [this blog post](https://www.pinecone.io/learn/hybrid-search-intro/)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "393ac030",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.retrievers import PineconeHybridSearchRetriever"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "aaf80e7f",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Setup Pinecone"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "15390796",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import pinecone # !pip install pinecone-client\n",
|
||||
"\n",
|
||||
"pinecone.init(\n",
|
||||
" api_key=\"...\", # API key here\n",
|
||||
" environment=\"...\" # find next to api key in console\n",
|
||||
")\n",
|
||||
"# choose a name for your index\n",
|
||||
"index_name = \"...\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "95d5d7f9",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"You should only have to do this part once."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "cfa3a8d8",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# create the index\n",
|
||||
"pinecone.create_index(\n",
|
||||
" name = index_name,\n",
|
||||
" dimension = 1536, # dimensionality of dense model\n",
|
||||
" metric = \"dotproduct\",\n",
|
||||
" pod_type = \"s1\"\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "e01549af",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Now that it's created, we can use it"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "bcb3c8c2",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"index = pinecone.Index(index_name)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "dbc025d6",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Get embeddings and tokenizers\n",
|
||||
"\n",
|
||||
"Embeddings are used for the dense vectors; the tokenizer is used for the sparse vectors"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "2f63c911",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.embeddings import OpenAIEmbeddings\n",
|
||||
"embeddings = OpenAIEmbeddings()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "c3f030e5",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from transformers import BertTokenizerFast # !pip install transformers\n",
|
||||
"\n",
|
||||
"# load bert tokenizer from huggingface\n",
|
||||
"tokenizer = BertTokenizerFast.from_pretrained(\n",
|
||||
" 'bert-base-uncased'\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "5462801e",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Load Retriever\n",
|
||||
"\n",
|
||||
"We can now construct the retriever!"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "ac77d835",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"retriever = PineconeHybridSearchRetriever(embeddings=embeddings, index=index, tokenizer=tokenizer)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "1c518c42",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Add texts (if necessary)\n",
|
||||
"\n",
|
||||
"We can optionally add texts to the retriever (if they aren't already in there)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "98b1c017",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "4d6f3ee7ca754d07a1a18d100d99e0cd",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
" 0%| | 0/1 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"retriever.add_texts([\"foo\", \"bar\", \"world\", \"hello\"])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "08437fa2",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Use Retriever\n",
|
||||
"\n",
|
||||
"We can now use the retriever!"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "c0455218",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"result = retriever.get_relevant_documents(\"foo\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"id": "7dfa5c29",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"Document(page_content='foo', metadata={})"
|
||||
]
|
||||
},
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"result[0]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "74bd9256",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.1"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
@ -1,5 +1,11 @@
|
||||
from langchain.retrievers.chatgpt_plugin_retriever import ChatGPTPluginRetriever
|
||||
from langchain.retrievers.metal import MetalRetriever
|
||||
from langchain.retrievers.pinecone_hybrid_search import PineconeHybridSearchRetriever
|
||||
from langchain.retrievers.remote_retriever import RemoteLangChainRetriever
|
||||
|
||||
__all__ = ["ChatGPTPluginRetriever", "RemoteLangChainRetriever", "MetalRetriever"]
|
||||
__all__ = [
|
||||
"ChatGPTPluginRetriever",
|
||||
"RemoteLangChainRetriever",
|
||||
"PineconeHybridSearchRetriever",
|
||||
"MetalRetriever",
|
||||
]
|
||||
|
@ -0,0 +1,137 @@
|
||||
"""Taken from: https://www.pinecone.io/learn/hybrid-search-intro/"""
|
||||
from collections import Counter
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from pydantic import BaseModel, Extra
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.schema import BaseRetriever, Document
|
||||
|
||||
|
||||
def build_dict(input_batch: List[List[int]]) -> List[Dict]:
    """Convert batches of token ids into sparse term-frequency vectors.

    Args:
        input_batch: One list of token ids per text.

    Returns:
        One dict per input list holding two parallel lists: ``"indices"``
        contains the distinct token ids in order of first appearance and
        ``"values"`` contains how many times each of those ids occurs.
    """
    sparse_emb = []
    for token_ids in input_batch:
        # A Counter keeps first-seen order, so its keys and values line up
        # position by position without an intermediate dict copy.
        counts = Counter(token_ids)
        sparse_emb.append(
            {"indices": list(counts.keys()), "values": list(counts.values())}
        )
    return sparse_emb
|
||||
|
||||
|
||||
def create_index(
    contexts: List[str], index: Any, embeddings: Embeddings, tokenizer: Any
) -> None:
    """Embed texts and upsert them into a hybrid index, 32 texts at a time.

    For every text a dense vector (from ``embeddings.embed_documents``) and a
    sparse vector (from ``generate_sparse_vectors``) are built and passed to
    ``index.upsert`` together with the text itself stored as metadata under
    the ``"context"`` key.

    Args:
        contexts: Texts to add to the index.
        index: Index object; only its ``upsert`` method is called here.
        embeddings: Model used to produce the dense vectors.
        tokenizer: Tokenizer forwarded to ``generate_sparse_vectors``.
    """
    import hashlib

    batch_size = 32
    _iterator = range(0, len(contexts), batch_size)
    try:
        from tqdm.auto import tqdm

        _iterator = tqdm(_iterator)
    except ImportError:
        # tqdm is optional: without it the batches are processed silently.
        pass

    for i in _iterator:
        # slicing clamps at the end of the list, so the last batch may be short
        context_batch = contexts[i : i + batch_size]
        # Derive each id from the text itself. Positional ids ("0", "1", ...)
        # restart at zero on every call, so a second call would reuse the ids
        # of the first and replace its vectors on upsert. Identical texts map
        # to the same id.
        ids = [
            hashlib.sha256(context.encode("utf-8")).hexdigest()
            for context in context_batch
        ]
        # keep the raw text as metadata so queries can return it
        meta = [{"context": context} for context in context_batch]
        dense_embeds = embeddings.embed_documents(context_batch)
        sparse_embeds = generate_sparse_vectors(context_batch, tokenizer)
        for s in sparse_embeds:
            # counts come back as ints; convert them to floats before upload
            s["values"] = [float(s1) for s1 in s["values"]]

        vectors = [
            {
                "id": _id,
                "sparse_values": sparse,
                "values": dense,
                "metadata": metadata,
            }
            for _id, sparse, dense, metadata in zip(
                ids, sparse_embeds, dense_embeds, meta
            )
        ]

        # upload this batch to the hybrid index
        index.upsert(vectors)
|
||||
|
||||
|
||||
def generate_sparse_vectors(context_batch: List[str], tokenizer: Any) -> List[Dict]:
    """Tokenize texts and turn their token ids into sparse frequency vectors.

    Args:
        context_batch: Texts to encode.
        tokenizer: Callable that accepts the batch plus ``padding``,
            ``truncation`` and ``max_length`` keyword arguments and returns a
            mapping with an ``"input_ids"`` entry (one id list per text).

    Returns:
        One ``{"indices": [...], "values": [...]}`` dict per text, as built
        by ``build_dict``.
    """
    # Padding is disabled on purpose: the ids are only counted, never stacked
    # into a tensor, and padded batches would add pad-token counts that depend
    # on the longest text in the batch rather than on the text itself.
    # NOTE(review): any special tokens the tokenizer inserts are still
    # counted -- confirm whether that is intended.
    inputs = tokenizer(
        context_batch,
        padding=False,
        truncation=True,
        max_length=512,
    )["input_ids"]
    return build_dict(inputs)
|
||||
|
||||
|
||||
def hybrid_scale(
    dense: List[float], sparse: Dict, alpha: float
) -> Tuple[List[float], Dict]:
    """Weight a dense/sparse vector pair for hybrid search.

    Args:
        dense: Dense vector; every component is multiplied by ``alpha``.
        sparse: Dict with ``"indices"`` and ``"values"`` entries; every value
            is multiplied by ``1 - alpha`` and the indices are passed through.
        alpha: Weight of the dense part, from 0 (sparse only) to 1 (dense
            only).

    Returns:
        The scaled dense vector and a new sparse dict, in that order.

    Raises:
        ValueError: If ``alpha`` is below 0 or above 1.
    """
    if alpha > 1 or alpha < 0:
        raise ValueError("Alpha must be between 0 and 1")

    sparse_weight = 1 - alpha
    scaled_sparse = {
        "indices": sparse["indices"],
        "values": [value * sparse_weight for value in sparse["values"]],
    }
    scaled_dense = [component * alpha for component in dense]
    return scaled_dense, scaled_sparse
|
||||
|
||||
|
||||
class PineconeHybridSearchRetriever(BaseRetriever, BaseModel):
    """Retriever that queries an index with a dense and a sparse vector."""

    # model that produces the dense vectors
    embeddings: Embeddings
    # index handle; ``query`` is called on it when retrieving
    index: Any
    # tokenizer used to build the sparse vectors
    tokenizer: Any
    # number of matches requested per query
    top_k: int = 4
    # weight of the dense vector; the sparse vector gets ``1 - alpha``
    alpha: float = 0.5

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    def add_texts(self, texts: List[str]) -> None:
        """Embed ``texts`` and store them in the index via ``create_index``."""
        create_index(
            contexts=texts,
            index=self.index,
            embeddings=self.embeddings,
            tokenizer=self.tokenizer,
        )

    def get_relevant_documents(self, query: str) -> List[Document]:
        """Return the documents the index matches for ``query``.

        The query is encoded as a sparse and a dense vector, both are weighted
        with ``alpha`` through ``hybrid_scale``, and the index is asked for the
        ``top_k`` matches including metadata. Each match becomes a
        ``Document`` whose content is the match's ``"context"`` metadata.
        """
        query_sparse = generate_sparse_vectors([query], self.tokenizer)[0]
        query_dense = self.embeddings.embed_query(query)
        scaled_dense, scaled_sparse = hybrid_scale(
            query_dense, query_sparse, self.alpha
        )
        # make sure every sparse weight is a float before it is sent
        scaled_sparse["values"] = [float(weight) for weight in scaled_sparse["values"]]
        response = self.index.query(
            vector=scaled_dense,
            sparse_vector=scaled_sparse,
            top_k=self.top_k,
            include_metadata=True,
        )
        return [
            Document(page_content=match["metadata"]["context"])
            for match in response["matches"]
        ]

    async def aget_relevant_documents(self, query: str) -> List[Document]:
        """Async retrieval is not implemented; always raises."""
        raise NotImplementedError
|
Loading…
Reference in New Issue