From d564308e0f459557cc6b779cbc171fa4a9bcef27 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Thu, 2 Feb 2023 08:44:02 -0800 Subject: [PATCH] rfc: instruct embeddings (#811) Co-authored-by: seanaedmiston --- .../combine_docs_examples/embeddings.ipynb | 60 +++++++++++++++- langchain/embeddings/__init__.py | 6 +- langchain/embeddings/huggingface.py | 70 +++++++++++++++++++ .../embeddings/test_huggingface.py | 22 +++++- 4 files changed, 155 insertions(+), 3 deletions(-) diff --git a/docs/modules/utils/combine_docs_examples/embeddings.ipynb b/docs/modules/utils/combine_docs_examples/embeddings.ipynb index fde4f4d197..9ddae5ae97 100644 --- a/docs/modules/utils/combine_docs_examples/embeddings.ipynb +++ b/docs/modules/utils/combine_docs_examples/embeddings.ipynb @@ -255,10 +255,68 @@ "query_result = embeddings.embed_query(text)" ] }, + { + "cell_type": "markdown", + "id": "59428e05", + "metadata": {}, + "source": [ + "## InstructEmbeddings\n", + "Let's load the HuggingFace instruct Embeddings class." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "92c5b61e", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.embeddings import HuggingFaceInstructEmbeddings" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "062547b9", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "load INSTRUCTOR_Transformer\n", + "max_seq_length 512\n" + ] + } + ], + "source": [ + "embeddings = HuggingFaceInstructEmbeddings(query_instruction=\"Represent the query for retrieval: \")" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "e1dcc4bd", + "metadata": {}, + "outputs": [], + "source": [ + "text = \"This is a test document.\"" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "90f0db94", + "metadata": {}, + "outputs": [], + "source": [ + "query_result = embeddings.embed_query(text)" + ] + }, { "cell_type": "code", "execution_count": null, - "id": "90f0db94", + "id": "a961cdb5", "metadata": {}, "outputs": [], "source": [] diff --git a/langchain/embeddings/__init__.py b/langchain/embeddings/__init__.py index c66434e437..ee981a64ee 100644 --- a/langchain/embeddings/__init__.py +++ b/langchain/embeddings/__init__.py @@ -3,7 +3,10 @@ import logging from typing import Any from langchain.embeddings.cohere import CohereEmbeddings -from langchain.embeddings.huggingface import HuggingFaceEmbeddings +from langchain.embeddings.huggingface import ( + HuggingFaceEmbeddings, + HuggingFaceInstructEmbeddings, +) from langchain.embeddings.huggingface_hub import HuggingFaceHubEmbeddings from langchain.embeddings.openai import OpenAIEmbeddings from langchain.embeddings.tensorflow_hub import TensorflowHubEmbeddings @@ -16,6 +19,7 @@ __all__ = [ "CohereEmbeddings", "HuggingFaceHubEmbeddings", "TensorflowHubEmbeddings", + "HuggingFaceInstructEmbeddings", ] diff --git a/langchain/embeddings/huggingface.py b/langchain/embeddings/huggingface.py index 98f9986ad2..095a2c9173 
100644 --- a/langchain/embeddings/huggingface.py +++ b/langchain/embeddings/huggingface.py @@ -6,6 +6,11 @@ from pydantic import BaseModel, Extra from langchain.embeddings.base import Embeddings DEFAULT_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2" +DEFAULT_INSTRUCT_MODEL = "hkunlp/instructor-large" +DEFAULT_EMBED_INSTRUCTION = "Represent the document for retrieval: " +DEFAULT_QUERY_INSTRUCTION = ( + "Represent the question for retrieving supporting documents: " +) class HuggingFaceEmbeddings(BaseModel, Embeddings): @@ -68,3 +73,68 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings): text = text.replace("\n", " ") embedding = self.client.encode(text) return embedding.tolist() + + +class HuggingFaceInstructEmbeddings(BaseModel, Embeddings): + """Wrapper around INSTRUCTOR embedding models. + + Requires the ``sentence_transformers`` and ``InstructorEmbedding`` python packages. + + Example: + .. code-block:: python + + from langchain.embeddings import HuggingFaceInstructEmbeddings + model_name = "hkunlp/instructor-large" + hf = HuggingFaceInstructEmbeddings(model_name=model_name) + """ + + client: Any #: :meta private: + model_name: str = DEFAULT_INSTRUCT_MODEL + """Model name to use.""" + embed_instruction: str = DEFAULT_EMBED_INSTRUCTION + """Instruction to use for embedding documents.""" + query_instruction: str = DEFAULT_QUERY_INSTRUCTION + """Instruction to use for embedding query.""" + + def __init__(self, **kwargs: Any): + """Initialize the INSTRUCTOR client.""" + super().__init__(**kwargs) + try: + from InstructorEmbedding import INSTRUCTOR + + self.client = INSTRUCTOR(self.model_name) + except ImportError as e: + raise ValueError("Dependencies for InstructorEmbedding not found.") from e + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + """Compute doc embeddings using a HuggingFace instruct model. 
+ + Args: + texts: The list of texts to embed. + + Returns: + List of embeddings, one for each text. + """ + instruction_pairs = [] + for text in texts: + instruction_pairs.append([self.embed_instruction, text]) + embeddings = self.client.encode(instruction_pairs) + return embeddings.tolist() + + def embed_query(self, text: str) -> List[float]: + """Compute query embeddings using a HuggingFace instruct model. + + Args: + text: The text to embed. + + Returns: + Embeddings for the text. + """ + instruction_pair = [self.query_instruction, text] + embedding = self.client.encode([instruction_pair])[0] + return embedding.tolist() diff --git a/tests/integration_tests/embeddings/test_huggingface.py b/tests/integration_tests/embeddings/test_huggingface.py index e71fbb0066..4c941580c6 100644 --- a/tests/integration_tests/embeddings/test_huggingface.py +++ b/tests/integration_tests/embeddings/test_huggingface.py @@ -1,7 +1,10 @@ """Test huggingface embeddings.""" import unittest -from langchain.embeddings.huggingface import HuggingFaceEmbeddings +from langchain.embeddings.huggingface import ( + HuggingFaceEmbeddings, + HuggingFaceInstructEmbeddings, +) @unittest.skip("This test causes a segfault.") @@ -21,3 +24,20 @@ def test_huggingface_embedding_query() -> None: embedding = HuggingFaceEmbeddings() output = embedding.embed_query(document) assert len(output) == 768 + + +def test_huggingface_instructor_embedding_documents() -> None: + """Test huggingface instructor embeddings for documents.""" + documents = ["foo bar"] + embedding = HuggingFaceInstructEmbeddings() + output = embedding.embed_documents(documents) + assert len(output) == 1 + assert len(output[0]) == 768 + + +def test_huggingface_instructor_embedding_query() -> None: + """Test huggingface instructor embeddings for a query.""" + query = "foo bar" + embedding = HuggingFaceInstructEmbeddings() + output = embedding.embed_query(query) + assert len(output) == 768