mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
Harrison/pinecone hybrid (#2405)
This commit is contained in:
parent
2b975de94d
commit
41832042cc
@ -0,0 +1,254 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "ab66dd43",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# Pinecone Hybrid Search\n",
|
||||||
|
"\n",
|
||||||
|
"This notebook goes over how to use a retriever that under the hood uses Pinecone and Hybrid Search.\n",
|
||||||
|
"\n",
|
||||||
|
"The logic of this retriever is largely taken from [this blog post](https://www.pinecone.io/learn/hybrid-search-intro/)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"id": "393ac030",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from langchain.retrievers import PineconeHybridSearchRetriever"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "aaf80e7f",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Setup Pinecone"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"id": "15390796",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import pinecone # !pip install pinecone-client\n",
|
||||||
|
"\n",
|
||||||
|
"pinecone.init(\n",
|
||||||
|
" api_key=\"...\", # API key here\n",
|
||||||
|
" environment=\"...\" # find next to api key in console\n",
|
||||||
|
")\n",
|
||||||
|
"# choose a name for your index\n",
|
||||||
|
"index_name = \"...\""
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "95d5d7f9",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"You should only have to do this part once."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "cfa3a8d8",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# create the index\n",
|
||||||
|
"pinecone.create_index(\n",
|
||||||
|
" name = index_name,\n",
|
||||||
|
" dimension = 1536, # dimensionality of dense model\n",
|
||||||
|
" metric = \"dotproduct\",\n",
|
||||||
|
" pod_type = \"s1\"\n",
|
||||||
|
")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "e01549af",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"Now that it's created, we can use it."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 4,
|
||||||
|
"id": "bcb3c8c2",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"index = pinecone.Index(index_name)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "dbc025d6",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Get embeddings and tokenizers\n",
|
||||||
|
"\n",
|
||||||
|
"Embeddings are used for the dense vectors; the tokenizer is used for the sparse vectors."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 5,
|
||||||
|
"id": "2f63c911",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from langchain.embeddings import OpenAIEmbeddings\n",
|
||||||
|
"embeddings = OpenAIEmbeddings()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 6,
|
||||||
|
"id": "c3f030e5",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from transformers import BertTokenizerFast # !pip install transformers\n",
|
||||||
|
"\n",
|
||||||
|
"# load bert tokenizer from huggingface\n",
|
||||||
|
"tokenizer = BertTokenizerFast.from_pretrained(\n",
|
||||||
|
" 'bert-base-uncased'\n",
|
||||||
|
")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "5462801e",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Load Retriever\n",
|
||||||
|
"\n",
|
||||||
|
"We can now construct the retriever!"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 7,
|
||||||
|
"id": "ac77d835",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"retriever = PineconeHybridSearchRetriever(embeddings=embeddings, index=index, tokenizer=tokenizer)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "1c518c42",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Add texts (if necessary)\n",
|
||||||
|
"\n",
|
||||||
|
"We can optionally add texts to the retriever (if they aren't already in there)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 8,
|
||||||
|
"id": "98b1c017",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
|
"model_id": "4d6f3ee7ca754d07a1a18d100d99e0cd",
|
||||||
|
"version_major": 2,
|
||||||
|
"version_minor": 0
|
||||||
|
},
|
||||||
|
"text/plain": [
|
||||||
|
" 0%| | 0/1 [00:00<?, ?it/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"retriever.add_texts([\"foo\", \"bar\", \"world\", \"hello\"])"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "08437fa2",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Use Retriever\n",
|
||||||
|
"\n",
|
||||||
|
"We can now use the retriever!"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 9,
|
||||||
|
"id": "c0455218",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"result = retriever.get_relevant_documents(\"foo\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 10,
|
||||||
|
"id": "7dfa5c29",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"Document(page_content='foo', metadata={})"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 10,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"result[0]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "74bd9256",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": []
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3 (ipykernel)",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.9.1"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
@ -1,5 +1,11 @@
|
|||||||
from langchain.retrievers.chatgpt_plugin_retriever import ChatGPTPluginRetriever
|
from langchain.retrievers.chatgpt_plugin_retriever import ChatGPTPluginRetriever
|
||||||
from langchain.retrievers.metal import MetalRetriever
|
from langchain.retrievers.metal import MetalRetriever
|
||||||
|
from langchain.retrievers.pinecone_hybrid_search import PineconeHybridSearchRetriever
|
||||||
from langchain.retrievers.remote_retriever import RemoteLangChainRetriever
|
from langchain.retrievers.remote_retriever import RemoteLangChainRetriever
|
||||||
|
|
||||||
__all__ = ["ChatGPTPluginRetriever", "RemoteLangChainRetriever", "MetalRetriever"]
|
__all__ = [
|
||||||
|
"ChatGPTPluginRetriever",
|
||||||
|
"RemoteLangChainRetriever",
|
||||||
|
"PineconeHybridSearchRetriever",
|
||||||
|
"MetalRetriever",
|
||||||
|
]
|
||||||
|
137
langchain/retrievers/pinecone_hybrid_search.py
Normal file
137
langchain/retrievers/pinecone_hybrid_search.py
Normal file
@ -0,0 +1,137 @@
|
|||||||
|
"""Taken from: https://www.pinecone.io/learn/hybrid-search-intro/"""
|
||||||
|
from collections import Counter
|
||||||
|
from typing import Any, Dict, List, Tuple
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Extra
|
||||||
|
|
||||||
|
from langchain.embeddings.base import Embeddings
|
||||||
|
from langchain.schema import BaseRetriever, Document
|
||||||
|
|
||||||
|
|
||||||
|
def build_dict(input_batch: List[List[int]]) -> List[Dict]:
    """Convert a batch of token-id lists into sparse-vector dictionaries.

    Args:
        input_batch: One list of token ids per input text.

    Returns:
        One dict per input list, with ``"indices"`` holding the distinct
        token ids in first-seen order and ``"values"`` holding how many
        times each of those ids occurred.
    """
    sparse_emb = []
    for token_ids in input_batch:
        # Counter keeps insertion order, so keys and values stay aligned.
        frequencies = Counter(token_ids)
        sparse_emb.append(
            {
                "indices": list(frequencies.keys()),
                "values": list(frequencies.values()),
            }
        )
    return sparse_emb
|
||||||
|
|
||||||
|
|
||||||
|
def create_index(
    contexts: List[str], index: Any, embeddings: Embeddings, tokenizer: Any
) -> None:
    """Embed ``contexts`` in batches of 32 and upsert them into ``index``.

    Each upserted record carries a dense vector from
    ``embeddings.embed_documents``, a sparse vector from
    ``generate_sparse_vectors`` and the original text under the
    ``"context"`` metadata key.

    Args:
        contexts: Texts to add.
        index: Object exposing ``upsert(vectors)``.
        embeddings: Dense embedding model.
        tokenizer: Tokenizer forwarded to ``generate_sparse_vectors``.

    Note:
        Record ids are the string form of each text's position in
        ``contexts``, so every call starts again at ``"0"``.
        NOTE(review): a second call presumably overwrites the records
        written by the first one -- confirm against the index's upsert
        semantics.
    """
    batch_size = 32
    total = len(contexts)
    batch_starts = range(0, total, batch_size)
    try:
        # Progress bar is optional; the wrapping call stays inside the try so
        # an ImportError raised while constructing the bar is ignored as well.
        from tqdm.auto import tqdm

        batch_starts = tqdm(batch_starts)
    except ImportError:
        pass

    for start in batch_starts:
        stop = min(start + batch_size, total)
        batch = contexts[start:stop]

        dense_embeds = embeddings.embed_documents(batch)
        sparse_embeds = generate_sparse_vectors(batch, tokenizer)
        # Replace the integer token counts with floats before upserting.
        for sparse in sparse_embeds:
            sparse["values"] = [float(count) for count in sparse["values"]]

        vectors = [
            {
                "id": str(position),
                "sparse_values": sparse,
                "values": dense,
                "metadata": {"context": text},
            }
            for position, text, sparse, dense in zip(
                range(start, stop), batch, sparse_embeds, dense_embeds
            )
        ]

        index.upsert(vectors)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_sparse_vectors(context_batch: List[str], tokenizer: Any) -> List[Dict]:
    """Tokenize ``context_batch`` and build one sparse vector per text.

    Args:
        context_batch: Texts to encode.
        tokenizer: Callable accepting the batch plus ``padding``,
            ``truncation`` and ``max_length`` keyword arguments and returning
            a mapping with an ``"input_ids"`` entry.

    Returns:
        The result of ``build_dict`` applied to the tokenizer's input ids.

    Note:
        Every id the tokenizer returns is counted.
        NOTE(review): with ``padding=True`` this presumably includes pad ids
        for the shorter texts in a batch -- confirm.
    """
    encoded = tokenizer(
        context_batch,
        padding=True,
        truncation=True,
        max_length=512,
    )
    return build_dict(encoded["input_ids"])
|
||||||
|
|
||||||
|
|
||||||
|
def hybrid_scale(
    dense: List[float], sparse: Dict, alpha: float
) -> Tuple[List[float], Dict]:
    """Weight a dense/sparse vector pair for hybrid search.

    Args:
        dense: Dense vector; every component is multiplied by ``alpha``.
        sparse: Dict with ``"indices"`` and ``"values"``; every value is
            multiplied by ``1 - alpha``. The indices list is reused as-is.
        alpha: Weight of the dense part, between 0 and 1 inclusive.

    Returns:
        ``(scaled_dense, scaled_sparse)``.

    Raises:
        ValueError: If ``alpha`` is below 0 or above 1.
    """
    if alpha > 1 or alpha < 0:
        raise ValueError("Alpha must be between 0 and 1")

    sparse_weight = 1 - alpha
    scaled_sparse = {
        "indices": sparse["indices"],
        "values": [value * sparse_weight for value in sparse["values"]],
    }
    scaled_dense = [component * alpha for component in dense]
    return scaled_dense, scaled_sparse
|
||||||
|
|
||||||
|
|
||||||
|
class PineconeHybridSearchRetriever(BaseRetriever, BaseModel):
    """Retriever that queries an index with a dense and a sparse vector."""

    # Model producing the dense vectors for documents and queries.
    embeddings: Embeddings
    # Index handle; its `upsert` and `query` methods are used.
    index: Any
    # Tokenizer used to derive the sparse token-count vectors.
    tokenizer: Any
    # Number of matches requested per query.
    top_k: int = 4
    # Weight of the dense vector; the sparse vector is weighted by 1 - alpha.
    alpha: float = 0.5

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    def add_texts(self, texts: List[str]) -> None:
        """Embed ``texts`` and upsert them into the index."""
        create_index(texts, self.index, self.embeddings, self.tokenizer)

    def get_relevant_documents(self, query: str) -> List[Document]:
        """Return the documents matching ``query``.

        Args:
            query: Text to search for.

        Returns:
            One ``Document`` per match, whose ``page_content`` is the
            ``"context"`` metadata stored with the matched record.
        """
        # Sparse side first, then the dense embedding of the same query.
        query_sparse = generate_sparse_vectors([query], self.tokenizer)[0]
        query_dense = self.embeddings.embed_query(query)

        # Apply the alpha weighting to both representations.
        scaled_dense, scaled_sparse = hybrid_scale(
            query_dense, query_sparse, self.alpha
        )
        scaled_sparse["values"] = [float(value) for value in scaled_sparse["values"]]

        response = self.index.query(
            vector=scaled_dense,
            sparse_vector=scaled_sparse,
            top_k=self.top_k,
            include_metadata=True,
        )
        return [
            Document(page_content=match["metadata"]["context"])
            for match in response["matches"]
        ]

    async def aget_relevant_documents(self, query: str) -> List[Document]:
        """Async retrieval is not implemented for this retriever."""
        raise NotImplementedError
|
Loading…
Reference in New Issue
Block a user