From 0b542a970620e6e731d3bfbc37d3961362baef8e Mon Sep 17 00:00:00 2001
From: Jeff Vestal <53237856+jeffvestal@users.noreply.github.com>
Date: Tue, 23 May 2023 16:50:33 -0500
Subject: [PATCH] Add ElasticsearchEmbeddings class for generating embeddings
 using Elasticsearch models (#3401)

This PR introduces a new module, `langchain/embeddings/elasticsearch.py`, which provides a wrapper around Elasticsearch embedding models. The new `ElasticsearchEmbeddings` class allows users to generate embeddings for documents and query texts using a [model deployed in an Elasticsearch cluster](https://www.elastic.co/guide/en/machine-learning/current/ml-nlp-model-ref.html#ml-nlp-model-ref-text-embedding).

### Main features:

1. The `ElasticsearchEmbeddings` class is initialized with an Elasticsearch ML client (`MlClient`) and a `model_id`, and generates embeddings through [infer_trained_model](https://elasticsearch-py.readthedocs.io/en/v8.7.0/api.html?highlight=trained%20model%20infer#elasticsearch.client.MlClient.infer_trained_model). A `from_credentials()` classmethod builds the client from a cloud ID and basic-auth credentials, passed directly or read from the `ES_CLOUD_ID`, `ES_USER`, and `ES_PASSWORD` environment variables.
2. The `embed_documents()` method generates embeddings for a list of documents, and the `embed_query()` method generates an embedding for a single query text.
3. The class supports custom input text field names in case the deployed model expects a different field name than the default `text_field`.
4. The implementation is compatible with any model deployed in Elasticsearch that generates embeddings as output.

### Benefits:

1. Simplifies the process of generating embeddings using Elasticsearch models.
2. Provides a clean and intuitive interface to interact with the Elasticsearch ML client.
3. Allows users to easily integrate Elasticsearch-generated embeddings into their LangChain applications; a short usage sketch and an example notebook are included below.

Related issue: https://github.com/hwchase17/langchain/issues/3400
---------

Co-authored-by: Dev 2049
---
 .../examples/elasticsearch.ipynb              | 142 ++++++++++++++++++
 langchain/embeddings/__init__.py              |   2 +
 langchain/embeddings/elasticsearch.py         | 155 ++++++++++++++++++
 .../embeddings/test_elasticsearch.py          |  30 ++++
 4 files changed, 329 insertions(+)
 create mode 100644 docs/modules/models/text_embedding/examples/elasticsearch.ipynb
 create mode 100644 langchain/embeddings/elasticsearch.py
 create mode 100644 tests/integration_tests/embeddings/test_elasticsearch.py

diff --git a/docs/modules/models/text_embedding/examples/elasticsearch.ipynb b/docs/modules/models/text_embedding/examples/elasticsearch.ipynb
new file mode 100644
index 00000000..6b025652
--- /dev/null
+++ b/docs/modules/models/text_embedding/examples/elasticsearch.ipynb
@@ -0,0 +1,142 @@
+{
+  "nbformat": 4,
+  "nbformat_minor": 0,
+  "metadata": {
+    "colab": {
+      "provenance": []
+    },
+    "kernelspec": {
+      "name": "python3",
+      "display_name": "Python 3"
+    },
+    "language_info": {
+      "name": "python"
+    }
+  },
+  "cells": [
+    {
+      "cell_type": "code",
+      "source": [
+        "!pip install elasticsearch langchain"
+      ],
+      "metadata": {
+        "id": "OOiBBjc0Kd-6"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "import os\n",
+        "%env ES_CLOUDID=\n",
+        "%env ES_USER=\n",
+        "%env ES_PASS=\n",
+        "\n",
+        "es_cloudid = os.environ.get(\"ES_CLOUDID\")\n",
+        "es_user = os.environ.get(\"ES_USER\")\n",
+        "es_pass = os.environ.get(\"ES_PASS\")"
+      ],
+      "metadata": {
+        "id": "Wr8unljAKdCh"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "from elasticsearch import Elasticsearch\n",
+        "\n",
+        "# Connect to Elasticsearch\n",
+        "es_connection = Elasticsearch(cloud_id=es_cloudid, basic_auth=(es_user, es_pass))"
+      ],
+      "metadata": {
+        "id": "YIDsrBqTKs85"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# Define the model ID and input field name (if different from default)\n",
+        "model_id = \"your_model_id\"\n",
+        "input_field = \"your_input_field\"  # Optional, only if different from 'text_field'"
+      ],
+      "metadata": {
+        "id": "sfFhnFHOKvbM"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "from elasticsearch.client import MlClient\n",
+        "from langchain.embeddings import ElasticsearchEmbeddings\n",
+        "# Initialize the ElasticsearchEmbeddings instance\n",
+        "embeddings_generator = ElasticsearchEmbeddings(MlClient(es_connection), model_id, input_field=input_field)"
+      ],
+      "metadata": {
+        "id": "V-pCgqLCKvYs"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# Generate embeddings for a list of documents\n",
+        "documents = [\n",
+        "    \"This is an example document.\",\n",
+        "    \"Another example document to generate embeddings for.\",\n",
+        "]\n",
+        "document_embeddings = embeddings_generator.embed_documents(documents)"
+      ],
+      "metadata": {
+        "id": "lJg2iRDWKvV_"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# Print the generated document embeddings\n",
+        "for i, doc_embedding in enumerate(document_embeddings):\n",
+        "    print(f\"Embedding for document {i + 1}: {doc_embedding}\")"
+      ],
+      "metadata": {
+        "id": "R3sYQlh3KvTQ"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# Generate an embedding for a single query text\n",
+        "query_text = \"What is the meaning of life?\"\n",
+        "query_embedding = embeddings_generator.embed_query(query_text)"
+      ],
+      "metadata": {
+        "id": "n0un5Vc0KvQd"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# Print the generated query embedding\n",
+        "print(f\"Embedding for query: {query_embedding}\")"
+      ],
+      "metadata": {
+        "id": "PANph6pmKvLD"
+      },
+      "execution_count": null,
+      "outputs": []
+    }
+  ]
+}
\ No newline at end of file
diff --git a/langchain/embeddings/__init__.py b/langchain/embeddings/__init__.py
index 1e123f12..5ba96520 100644
--- a/langchain/embeddings/__init__.py
+++ b/langchain/embeddings/__init__.py
@@ -7,6 +7,7 @@ from langchain.embeddings.aleph_alpha import (
     AlephAlphaSymmetricSemanticEmbedding,
 )
 from langchain.embeddings.cohere import CohereEmbeddings
+from langchain.embeddings.elasticsearch import ElasticsearchEmbeddings
 from langchain.embeddings.fake import FakeEmbeddings
 from langchain.embeddings.google_palm import GooglePalmEmbeddings
 from langchain.embeddings.huggingface import (
@@ -32,6 +33,7 @@ __all__ = [
     "OpenAIEmbeddings",
     "HuggingFaceEmbeddings",
     "CohereEmbeddings",
+    "ElasticsearchEmbeddings",
     "JinaEmbeddings",
     "LlamaCppEmbeddings",
     "HuggingFaceHubEmbeddings",
diff --git a/langchain/embeddings/elasticsearch.py b/langchain/embeddings/elasticsearch.py
new file mode 100644
index 00000000..78d7dec0
--- /dev/null
+++ b/langchain/embeddings/elasticsearch.py
@@ -0,0 +1,155 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, List, Optional
+
+from langchain.utils import get_from_env
+
+if TYPE_CHECKING:
+    from elasticsearch.client import MlClient
+
+from langchain.embeddings.base import Embeddings
+
+
+class ElasticsearchEmbeddings(Embeddings):
+    """
+    Wrapper around Elasticsearch embedding models.
+
+    This class provides an interface to generate embeddings using a model deployed
+    in an Elasticsearch cluster. It requires an Elasticsearch connection object
+    and the model_id of the model deployed in the cluster.
+
+    In Elasticsearch you need to have an embedding model loaded and deployed.
+    - https://www.elastic.co/guide/en/elasticsearch/reference/current/infer-trained-model.html
+    - https://www.elastic.co/guide/en/machine-learning/current/ml-nlp-deploy-models.html
+    """  # noqa: E501
+
+    def __init__(
+        self,
+        client: MlClient,
+        model_id: str,
+        *,
+        input_field: str = "text_field",
+    ):
+        """
+        Initialize the ElasticsearchEmbeddings instance.
+
+        Args:
+            client (MlClient): An Elasticsearch ML client object.
+            model_id (str): The model_id of the model deployed in the Elasticsearch
+                cluster.
+            input_field (str): The name of the key for the input text field in the
+                document. Defaults to 'text_field'.
+        """
+        self.client = client
+        self.model_id = model_id
+        self.input_field = input_field
+
+    @classmethod
+    def from_credentials(
+        cls,
+        model_id: str,
+        *,
+        es_cloud_id: Optional[str] = None,
+        es_user: Optional[str] = None,
+        es_password: Optional[str] = None,
+        input_field: str = "text_field",
+    ) -> ElasticsearchEmbeddings:
+        """Instantiate embeddings from Elasticsearch credentials.
+
+        Args:
+            model_id (str): The model_id of the model deployed in the Elasticsearch
+                cluster.
+            input_field (str): The name of the key for the input text field in the
+                document. Defaults to 'text_field'.
+            es_cloud_id (str, optional): The Elasticsearch cloud ID to connect to.
+            es_user (str, optional): Elasticsearch username.
+            es_password (str, optional): Elasticsearch password.
+
+        Example Usage:
+            from langchain.embeddings import ElasticsearchEmbeddings
+
+            # Define the model ID and input field name (if different from default)
+            model_id = "your_model_id"
+            # Optional, only if different from 'text_field'
+            input_field = "your_input_field"
+
+            # Credentials can be passed in two ways. Either set the env vars
+            # ES_CLOUD_ID, ES_USER, ES_PASSWORD and they will be automatically pulled
+            # in, or pass them in directly as kwargs.
+            embeddings = ElasticsearchEmbeddings.from_credentials(
+                model_id,
+                input_field=input_field,
+                # es_cloud_id="foo",
+                # es_user="bar",
+                # es_password="baz",
+            )
+
+            documents = [
+                "This is an example document.",
+                "Another example document to generate embeddings for.",
+            ]
+            embeddings.embed_documents(documents)
+        """
+        try:
+            from elasticsearch import Elasticsearch
+            from elasticsearch.client import MlClient
+        except ImportError:
+            raise ImportError(
+                "elasticsearch package not found, please install with 'pip install "
+                "elasticsearch'"
+            )
+
+        es_cloud_id = es_cloud_id or get_from_env("es_cloud_id", "ES_CLOUD_ID")
+        es_user = es_user or get_from_env("es_user", "ES_USER")
+        es_password = es_password or get_from_env("es_password", "ES_PASSWORD")
+
+        # Connect to Elasticsearch
+        es_connection = Elasticsearch(
+            cloud_id=es_cloud_id, basic_auth=(es_user, es_password)
+        )
+        client = MlClient(es_connection)
+        return cls(client, model_id, input_field=input_field)
+
+    def _embedding_func(self, texts: List[str]) -> List[List[float]]:
+        """
+        Generate embeddings for the given texts using the Elasticsearch model.
+
+        Args:
+            texts (List[str]): A list of text strings to generate embeddings for.
+
+        Returns:
+            List[List[float]]: A list of embeddings, one for each text in the input
+                list.
+        """
+        response = self.client.infer_trained_model(
+            model_id=self.model_id, docs=[{self.input_field: text} for text in texts]
+        )
+
+        embeddings = [doc["predicted_value"] for doc in response["inference_results"]]
+        return embeddings
+
+    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+        """
+        Generate embeddings for a list of documents.
+
+        Args:
+            texts (List[str]): A list of document text strings to generate embeddings
+                for.
+
+        Returns:
+            List[List[float]]: A list of embeddings, one for each document in the input
+                list.
+        """
+        return self._embedding_func(texts)
+
+    def embed_query(self, text: str) -> List[float]:
+        """
+        Generate an embedding for a single query text.
+
+        Args:
+            text (str): The query text to generate an embedding for.
+
+        Returns:
+            List[float]: The embedding for the input query text.
+        """
+        return self._embedding_func([text])[0]
diff --git a/tests/integration_tests/embeddings/test_elasticsearch.py b/tests/integration_tests/embeddings/test_elasticsearch.py
new file mode 100644
index 00000000..2c016831
--- /dev/null
+++ b/tests/integration_tests/embeddings/test_elasticsearch.py
@@ -0,0 +1,30 @@
+"""Test elasticsearch_embeddings embeddings."""
+
+import pytest
+
+from langchain.embeddings.elasticsearch import ElasticsearchEmbeddings
+
+
+@pytest.fixture
+def model_id() -> str:
+    # Replace with your actual model_id
+    return "your_model_id"
+
+
+def test_elasticsearch_embedding_documents(model_id: str) -> None:
+    """Test Elasticsearch embedding documents."""
+    documents = ["foo bar", "bar foo", "foo"]
+    embedding = ElasticsearchEmbeddings.from_credentials(model_id)
+    output = embedding.embed_documents(documents)
+    assert len(output) == 3
+    assert len(output[0]) == 768  # Change 768 to the expected embedding size
+    assert len(output[1]) == 768  # Change 768 to the expected embedding size
+    assert len(output[2]) == 768  # Change 768 to the expected embedding size
+
+
+def test_elasticsearch_embedding_query(model_id: str) -> None:
+    """Test Elasticsearch embedding query."""
+    document = "foo bar"
+    embedding = ElasticsearchEmbeddings.from_credentials(model_id)
+    output = embedding.embed_query(document)
+    assert len(output) == 768  # Change 768 to the expected embedding size
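Note on running the integration tests above: they assume a text embedding model is already deployed in the cluster and that `ES_CLOUD_ID`, `ES_USER`, and `ES_PASSWORD` are set. A small sketch (not part of the patch) for confirming the deployment and finding the embedding size to use in place of 768; the connection values and `"your_model_id"` are placeholders, and the call simply mirrors what `_embedding_func` does internally:

```python
from elasticsearch import Elasticsearch
from elasticsearch.client import MlClient

# Placeholders -- substitute a real cloud ID, credentials, and deployed model_id.
es_connection = Elasticsearch(
    cloud_id="<cloud_id>", basic_auth=("<user>", "<password>")
)
ml_client = MlClient(es_connection)

# Same inference call ElasticsearchEmbeddings makes under the hood.
response = ml_client.infer_trained_model(
    model_id="your_model_id",
    docs=[{"text_field": "smoke test sentence"}],
)
vector = response["inference_results"][0]["predicted_value"]
print(len(vector))  # expected embedding size for the assertions above
```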