Harrison/elastic search (#2419)

This commit is contained in:
Harrison Chase 2023-04-04 21:29:06 -07:00 committed by GitHub
parent 659c67e896
commit af7f20fa42
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 292 additions and 0 deletions

View File

@ -0,0 +1,164 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "ab66dd43",
"metadata": {},
"source": [
"# ElasticSearch BM25\n",
"\n",
"This notebook goes over how to use a retriever that under the hood uses ElasticSearcha and BM25.\n",
"\n",
"For more information on the details of BM25 see [this blog post](https://www.elastic.co/blog/practical-bm25-part-2-the-bm25-algorithm-and-its-variables)."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "393ac030",
"metadata": {},
"outputs": [],
"source": [
"from langchain.retrievers import ElasticSearchBM25Retriever"
]
},
{
"cell_type": "markdown",
"id": "aaf80e7f",
"metadata": {},
"source": [
"## Create New Retriever"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "bcb3c8c2",
"metadata": {},
"outputs": [],
"source": [
"elasticsearch_url=\"http://localhost:9200\"\n",
"retriever = ElasticSearchBM25Retriever.create(elasticsearch_url, \"langchain-index-3\")"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "b605284d",
"metadata": {},
"outputs": [],
"source": [
"# Alternatively, you can load an existing index\n",
"# import elasticsearch\n",
"# elasticsearch_url=\"http://localhost:9200\"\n",
"# retriever = ElasticSearchBM25Retriever(elasticsearch.Elasticsearch(elasticsearch_url), \"langchain-index\")"
]
},
{
"cell_type": "markdown",
"id": "1c518c42",
"metadata": {},
"source": [
"## Add texts (if necessary)\n",
"\n",
"We can optionally add texts to the retriever (if they aren't already in there)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "98b1c017",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['386c76c9-4355-4c12-aaeb-7b80054caf93',\n",
" 'fffd279c-a0c9-4158-a904-6e242c517c99',\n",
" '7f5528a3-18d0-43b0-894d-f6770a002219',\n",
" 'e2ef5e32-d5bd-44e2-b045-cfc5a8e0a0a1',\n",
" 'cce8ba48-e473-4235-bca2-2c8d65e73ccf']"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"retriever.add_texts([\"foo\", \"bar\", \"world\", \"hello\", \"foo bar\"])"
]
},
{
"cell_type": "markdown",
"id": "08437fa2",
"metadata": {},
"source": [
"## Use Retriever\n",
"\n",
"We can now use the retriever!"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "c0455218",
"metadata": {},
"outputs": [],
"source": [
"result = retriever.get_relevant_documents(\"foo\")"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "7dfa5c29",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[Document(page_content='foo', metadata={}),\n",
" Document(page_content='foo bar', metadata={})]"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"result"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "74bd9256",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@ -1,4 +1,5 @@
from langchain.retrievers.chatgpt_plugin_retriever import ChatGPTPluginRetriever from langchain.retrievers.chatgpt_plugin_retriever import ChatGPTPluginRetriever
from langchain.retrievers.elastic_search_bm25 import ElasticSearchBM25Retriever
from langchain.retrievers.metal import MetalRetriever from langchain.retrievers.metal import MetalRetriever
from langchain.retrievers.pinecone_hybrid_search import PineconeHybridSearchRetriever from langchain.retrievers.pinecone_hybrid_search import PineconeHybridSearchRetriever
from langchain.retrievers.remote_retriever import RemoteLangChainRetriever from langchain.retrievers.remote_retriever import RemoteLangChainRetriever
@ -8,4 +9,5 @@ __all__ = [
"RemoteLangChainRetriever", "RemoteLangChainRetriever",
"PineconeHybridSearchRetriever", "PineconeHybridSearchRetriever",
"MetalRetriever", "MetalRetriever",
"ElasticSearchBM25Retriever",
] ]

View File

@ -0,0 +1,126 @@
"""Wrapper around Elasticsearch vector database."""
from __future__ import annotations
import uuid
from typing import Any, Iterable, List
from langchain.docstore.document import Document
from langchain.schema import BaseRetriever
class ElasticSearchBM25Retriever(BaseRetriever):
"""Wrapper around Elasticsearch using BM25 as a retrieval method.
To connect to an Elasticsearch instance that requires login credentials,
including Elastic Cloud, use the Elasticsearch URL format
https://username:password@es_host:9243. For example, to connect to Elastic
Cloud, create the Elasticsearch URL with the required authentication details and
pass it to the ElasticVectorSearch constructor as the named parameter
elasticsearch_url.
You can obtain your Elastic Cloud URL and login credentials by logging in to the
Elastic Cloud console at https://cloud.elastic.co, selecting your deployment, and
navigating to the "Deployments" page.
To obtain your Elastic Cloud password for the default "elastic" user:
1. Log in to the Elastic Cloud console at https://cloud.elastic.co
2. Go to "Security" > "Users"
3. Locate the "elastic" user and click "Edit"
4. Click "Reset password"
5. Follow the prompts to reset the password
The format for Elastic Cloud URLs is
https://username:password@cluster_id.region_id.gcp.cloud.es.io:9243.
"""
def __init__(self, client: Any, index_name: str):
self.client = client
self.index_name = index_name
@classmethod
def create(
cls, elasticsearch_url: str, index_name: str, k1: float = 2.0, b: float = 0.75
) -> ElasticSearchBM25Retriever:
from elasticsearch import Elasticsearch
# Create an Elasticsearch client instance
es = Elasticsearch(elasticsearch_url)
# Define the index settings and mappings
index_settings = {
"settings": {
"analysis": {"analyzer": {"default": {"type": "standard"}}},
"similarity": {
"custom_bm25": {
"type": "BM25",
"k1": k1,
"b": b,
}
},
},
"mappings": {
"properties": {
"content": {
"type": "text",
"similarity": "custom_bm25", # Use the custom BM25 similarity
}
}
},
}
# Create the index with the specified settings and mappings
es.indices.create(index=index_name, body=index_settings)
return cls(es, index_name)
def add_texts(
self,
texts: Iterable[str],
refresh_indices: bool = True,
) -> List[str]:
"""Run more texts through the embeddings and add to the retriver.
Args:
texts: Iterable of strings to add to the retriever.
refresh_indices: bool to refresh ElasticSearch indices
Returns:
List of ids from adding the texts into the retriever.
"""
try:
from elasticsearch.helpers import bulk
except ImportError:
raise ValueError(
"Could not import elasticsearch python package. "
"Please install it with `pip install elasticsearch`."
)
requests = []
ids = []
for i, text in enumerate(texts):
_id = str(uuid.uuid4())
request = {
"_op_type": "index",
"_index": self.index_name,
"content": text,
"_id": _id,
}
ids.append(_id)
requests.append(request)
bulk(self.client, requests)
if refresh_indices:
self.client.indices.refresh(index=self.index_name)
return ids
def get_relevant_documents(self, query: str) -> List[Document]:
query_dict = {"query": {"match": {"content": query}}}
res = self.client.search(index=self.index_name, body=query_dict)
docs = []
for r in res["hits"]["hits"]:
docs.append(Document(page_content=r["_source"]["content"]))
return docs
async def aget_relevant_documents(self, query: str) -> List[Document]:
raise NotImplementedError