Harrison/pinecone hybrid (#2405)

This commit is contained in:
Harrison Chase 2023-04-04 14:09:57 -07:00 committed by GitHub
parent 2b975de94d
commit 41832042cc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 398 additions and 1 deletions

View File

@ -0,0 +1,254 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "ab66dd43",
"metadata": {},
"source": [
"# Pinecone Hybrid Search\n",
"\n",
"This notebook goes over how to use a retriever that under the hood uses Pinecone and Hybrid Search.\n",
"\n",
"The logic of this retriever is largely taken from [this blog post](https://www.pinecone.io/learn/hybrid-search-intro/)"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "393ac030",
"metadata": {},
"outputs": [],
"source": [
"from langchain.retrievers import PineconeHybridSearchRetriever"
]
},
{
"cell_type": "markdown",
"id": "aaf80e7f",
"metadata": {},
"source": [
"## Setup Pinecone"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "15390796",
"metadata": {},
"outputs": [],
"source": [
"import pinecone # !pip install pinecone-client\n",
"\n",
"pinecone.init(\n",
" api_key=\"...\", # API key here\n",
" environment=\"...\" # find next to api key in console\n",
")\n",
"# choose a name for your index\n",
"index_name = \"...\""
]
},
{
"cell_type": "markdown",
"id": "95d5d7f9",
"metadata": {},
"source": [
"You should only have to do this part once."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cfa3a8d8",
"metadata": {},
"outputs": [],
"source": [
"# create the index\n",
"pinecone.create_index(\n",
" name = index_name,\n",
" dimension = 1536, # dimensionality of dense model\n",
" metric = \"dotproduct\",\n",
" pod_type = \"s1\"\n",
")"
]
},
{
"cell_type": "markdown",
"id": "e01549af",
"metadata": {},
"source": [
"Now that its created, we can use it"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "bcb3c8c2",
"metadata": {},
"outputs": [],
"source": [
"index = pinecone.Index(index_name)"
]
},
{
"cell_type": "markdown",
"id": "dbc025d6",
"metadata": {},
"source": [
"## Get embeddings and tokenizers\n",
"\n",
"Embeddings are used for the dense vectors, tokenizer is used for the sparse vector"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "2f63c911",
"metadata": {},
"outputs": [],
"source": [
"from langchain.embeddings import OpenAIEmbeddings\n",
"embeddings = OpenAIEmbeddings()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "c3f030e5",
"metadata": {},
"outputs": [],
"source": [
"from transformers import BertTokenizerFast # !pip install transformers\n",
"\n",
"# load bert tokenizer from huggingface\n",
"tokenizer = BertTokenizerFast.from_pretrained(\n",
" 'bert-base-uncased'\n",
")"
]
},
{
"cell_type": "markdown",
"id": "5462801e",
"metadata": {},
"source": [
"## Load Retriever\n",
"\n",
"We can now construct the retriever!"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "ac77d835",
"metadata": {},
"outputs": [],
"source": [
"retriever = PineconeHybridSearchRetriever(embeddings=embeddings, index=index, tokenizer=tokenizer)"
]
},
{
"cell_type": "markdown",
"id": "1c518c42",
"metadata": {},
"source": [
"## Add texts (if necessary)\n",
"\n",
"We can optionally add texts to the retriever (if they aren't already in there)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "98b1c017",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "4d6f3ee7ca754d07a1a18d100d99e0cd",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/1 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"retriever.add_texts([\"foo\", \"bar\", \"world\", \"hello\"])"
]
},
{
"cell_type": "markdown",
"id": "08437fa2",
"metadata": {},
"source": [
"## Use Retriever\n",
"\n",
"We can now use the retriever!"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "c0455218",
"metadata": {},
"outputs": [],
"source": [
"result = retriever.get_relevant_documents(\"foo\")"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "7dfa5c29",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Document(page_content='foo', metadata={})"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"result[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "74bd9256",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@ -1,5 +1,11 @@
from langchain.retrievers.chatgpt_plugin_retriever import ChatGPTPluginRetriever from langchain.retrievers.chatgpt_plugin_retriever import ChatGPTPluginRetriever
from langchain.retrievers.metal import MetalRetriever from langchain.retrievers.metal import MetalRetriever
from langchain.retrievers.pinecone_hybrid_search import PineconeHybridSearchRetriever
from langchain.retrievers.remote_retriever import RemoteLangChainRetriever from langchain.retrievers.remote_retriever import RemoteLangChainRetriever
__all__ = ["ChatGPTPluginRetriever", "RemoteLangChainRetriever", "MetalRetriever"] __all__ = [
"ChatGPTPluginRetriever",
"RemoteLangChainRetriever",
"PineconeHybridSearchRetriever",
"MetalRetriever",
]

View File

@ -0,0 +1,137 @@
"""Taken from: https://www.pinecone.io/learn/hybrid-search-intro/"""
from collections import Counter
from typing import Any, Dict, List, Tuple
from pydantic import BaseModel, Extra
from langchain.embeddings.base import Embeddings
from langchain.schema import BaseRetriever, Document
def build_dict(input_batch: List[List[int]]) -> List[Dict]:
# store a batch of sparse embeddings
sparse_emb = []
# iterate through input batch
for token_ids in input_batch:
indices = []
values = []
# convert the input_ids list to a dictionary of key to frequency values
d = dict(Counter(token_ids))
for idx in d:
indices.append(idx)
values.append(d[idx])
sparse_emb.append({"indices": indices, "values": values})
# return sparse_emb list
return sparse_emb
def create_index(
contexts: List[str], index: Any, embeddings: Embeddings, tokenizer: Any
) -> None:
batch_size = 32
_iterator = range(0, len(contexts), batch_size)
try:
from tqdm.auto import tqdm
_iterator = tqdm(_iterator)
except ImportError:
pass
for i in _iterator:
# find end of batch
i_end = min(i + batch_size, len(contexts))
# extract batch
context_batch = contexts[i:i_end]
# create unique IDs
ids = [str(x) for x in range(i, i_end)]
# add context passages as metadata
meta = [{"context": context} for context in context_batch]
# create dense vectors
dense_embeds = embeddings.embed_documents(context_batch)
# create sparse vectors
sparse_embeds = generate_sparse_vectors(context_batch, tokenizer)
for s in sparse_embeds:
s["values"] = [float(s1) for s1 in s["values"]]
vectors = []
# loop through the data and create dictionaries for upserts
for _id, sparse, dense, metadata in zip(ids, sparse_embeds, dense_embeds, meta):
vectors.append(
{
"id": _id,
"sparse_values": sparse,
"values": dense,
"metadata": metadata,
}
)
# upload the documents to the new hybrid index
index.upsert(vectors)
def generate_sparse_vectors(context_batch: List[str], tokenizer: Any) -> List[Dict]:
# create batch of input_ids
inputs = tokenizer(
context_batch,
padding=True,
truncation=True,
max_length=512, # special_tokens=False
)["input_ids"]
# create sparse dictionaries
sparse_embeds = build_dict(inputs)
return sparse_embeds
def hybrid_scale(
dense: List[float], sparse: Dict, alpha: float
) -> Tuple[List[float], Dict]:
# check alpha value is in range
if alpha < 0 or alpha > 1:
raise ValueError("Alpha must be between 0 and 1")
# scale sparse and dense vectors to create hybrid search vecs
hsparse = {
"indices": sparse["indices"],
"values": [v * (1 - alpha) for v in sparse["values"]],
}
hdense = [v * alpha for v in dense]
return hdense, hsparse
class PineconeHybridSearchRetriever(BaseRetriever, BaseModel):
embeddings: Embeddings
index: Any
tokenizer: Any
top_k: int = 4
alpha: float = 0.5
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
def add_texts(self, texts: List[str]) -> None:
create_index(texts, self.index, self.embeddings, self.tokenizer)
def get_relevant_documents(self, query: str) -> List[Document]:
sparse_vec = generate_sparse_vectors([query], self.tokenizer)[0]
# convert the question into a dense vector
dense_vec = self.embeddings.embed_query(query)
# scale alpha with hybrid_scale
dense_vec, sparse_vec = hybrid_scale(dense_vec, sparse_vec, self.alpha)
sparse_vec["values"] = [float(s1) for s1 in sparse_vec["values"]]
# query pinecone with the query parameters
result = self.index.query(
vector=dense_vec,
sparse_vector=sparse_vec,
top_k=self.top_k,
include_metadata=True,
)
final_result = []
for res in result["matches"]:
final_result.append(Document(page_content=res["metadata"]["context"]))
# return search results as json
return final_result
async def aget_relevant_documents(self, query: str) -> List[Document]:
raise NotImplementedError