From 20f530e9c5c58000f1ad941d1f563d773c42bb1c Mon Sep 17 00:00:00 2001 From: Zander Chase <130414180+vowelparrot@users.noreply.github.com> Date: Sun, 23 Apr 2023 18:25:20 -0700 Subject: [PATCH] Add Sentence Transformers Embeddings (#3409) Add embeddings based on the sentence transformers library. Add a notebook and integration tests. Co-authored-by: khimaros --- .../examples/sentence_transformers.ipynb | 120 ++++++++++++++++++ langchain/embeddings/__init__.py | 2 + langchain/embeddings/sentence_transformer.py | 63 +++++++++ pyproject.toml | 4 +- .../embeddings/test_sentence_transformer.py | 38 ++++++ 5 files changed, 226 insertions(+), 1 deletion(-) create mode 100644 docs/modules/models/text_embedding/examples/sentence_transformers.ipynb create mode 100644 langchain/embeddings/sentence_transformer.py create mode 100644 tests/integration_tests/embeddings/test_sentence_transformer.py diff --git a/docs/modules/models/text_embedding/examples/sentence_transformers.ipynb b/docs/modules/models/text_embedding/examples/sentence_transformers.ipynb new file mode 100644 index 00000000..eda1c7dd --- /dev/null +++ b/docs/modules/models/text_embedding/examples/sentence_transformers.ipynb @@ -0,0 +1,120 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "id": "ed47bb62", + "metadata": {}, + "source": [ + "# Sentence Transformers Embeddings\n", + "\n", + "Let's generate embeddings using the [SentenceTransformers](https://www.sbert.net/) integration. SentenceTransformers is a python package that can generate text and image embeddings, originating from [Sentence-BERT](https://arxiv.org/abs/1908.10084)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "06c9f47d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. 
Disabling parallelism to avoid deadlocks...\n", + "To disable this warning, you can either:\n", + "\t- Avoid using `tokenizers` before the fork if possible\n", + "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n" + ] + } + ], + "source": [ + "!pip install sentence_transformers > /dev/null" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "861521a9", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.embeddings import SentenceTransformerEmbeddings " + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "ff9be586", + "metadata": {}, + "outputs": [], + "source": [ + "embeddings = SentenceTransformerEmbeddings(model=\"all-MiniLM-L6-v2\")" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "d0a98ae9", + "metadata": {}, + "outputs": [], + "source": [ + "text = \"This is a test document.\"" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "5d6c682b", + "metadata": {}, + "outputs": [], + "source": [ + "query_result = embeddings.embed_query(text)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "bb5e74c0", + "metadata": {}, + "outputs": [], + "source": [ + "doc_result = embeddings.embed_documents([text, \"This is not a test document.\"])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "aaad49f8", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.2" + }, + "vscode": { + "interpreter": { + "hash": "7377c2ccc78bc62c2683122d48c8cd1fb85a53850a1b1fc29736ed39852c9885" + } + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git 
a/langchain/embeddings/__init__.py b/langchain/embeddings/__init__.py
index b46c9de4..edcd11ff 100644
--- a/langchain/embeddings/__init__.py
+++ b/langchain/embeddings/__init__.py
@@ -22,6 +22,7 @@ from langchain.embeddings.self_hosted_hugging_face import (
     SelfHostedHuggingFaceEmbeddings,
     SelfHostedHuggingFaceInstructEmbeddings,
 )
+from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings
 from langchain.embeddings.tensorflow_hub import TensorflowHubEmbeddings
 
 logger = logging.getLogger(__name__)
@@ -42,6 +43,7 @@ __all__ = [
     "FakeEmbeddings",
     "AlephAlphaAsymmetricSemanticEmbedding",
     "AlephAlphaSymmetricSemanticEmbedding",
+    "SentenceTransformerEmbeddings",
 ]
 
 
diff --git a/langchain/embeddings/sentence_transformer.py b/langchain/embeddings/sentence_transformer.py
new file mode 100644
index 00000000..b3bba97e
--- /dev/null
+++ b/langchain/embeddings/sentence_transformer.py
@@ -0,0 +1,77 @@
+"""Wrapper around sentence transformer embedding models."""
+from typing import Any, Dict, List, Optional
+
+from pydantic import BaseModel, Extra, Field, root_validator
+
+from langchain.embeddings.base import Embeddings
+
+
+class SentenceTransformerEmbeddings(BaseModel, Embeddings):
+    """Embeddings computed locally by a sentence_transformers model.
+
+    The model named by ``model`` is loaded during validation and stored in
+    ``embedding_function``, whose ``encode`` method produces the vectors.
+    """
+
+    embedding_function: Any  #: :meta private:
+
+    model: Optional[str] = Field("all-MiniLM-L6-v2", alias="model")
+    """Transformer model to use."""
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        extra = Extra.forbid
+
+    @root_validator()
+    def validate_environment(cls, values: Dict) -> Dict:
+        """Import sentence_transformers and load the configured model.
+
+        Raises:
+            ModuleNotFoundError: If sentence_transformers is not installed.
+            NameError: If the model cannot be loaded.
+        """
+        model = values["model"]
+
+        # Import in its own ``try``: an ImportError raised while loading the
+        # model must not be misreported as a missing package.
+        try:
+            from sentence_transformers import SentenceTransformer
+        except ImportError as exc:
+            raise ModuleNotFoundError(
+                "Could not import sentence_transformers library. "
+                "Please install the sentence_transformers library to "
+                "use this embedding model: pip install sentence_transformers"
+            ) from exc
+
+        try:
+            values["embedding_function"] = SentenceTransformer(model)
+        except Exception as exc:
+            raise NameError(
+                f"Could not load SentenceTransformer model {model}."
+            ) from exc
+
+        return values
+
+    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+        """Embed a list of documents using the SentenceTransformer model.
+
+        Args:
+            texts: The list of texts to embed.
+
+        Returns:
+            List of embeddings, one for each text.
+        """
+        # ndarray.tolist() already yields nested lists of Python floats.
+        return self.embedding_function.encode(texts, convert_to_numpy=True).tolist()
+
+    def embed_query(self, text: str) -> List[float]:
+        """Embed a query using the SentenceTransformer model.
+
+        Args:
+            text: The text to embed.
+
+        Returns:
+            Embedding for the text.
+        """
+        return self.embed_documents([text])[0]
diff --git a/pyproject.toml b/pyproject.toml
index 750f54f7..de48aaf0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -117,6 +117,7 @@ torch = "^1.0.0"
 chromadb = "^0.3.21"
 tiktoken = "^0.3.3"
 python-dotenv = "^1.0.0"
+sentence-transformers = "^2"
 gptcache = "^0.1.9"
 promptlayer = "^0.1.80"
 
@@ -144,7 +145,8 @@ llms = ["anthropic", "cohere", "openai", "nlpcloud", "huggingface_hub", "manifes
 qdrant = ["qdrant-client"]
 openai = ["openai"]
 cohere = ["cohere"]
-all = ["anthropic", "cohere", "openai", "nlpcloud", "huggingface_hub", "jina", "manifest-ml", "elasticsearch", "opensearch-py", "google-search-results", "faiss-cpu", "sentence_transformers", "transformers", "spacy", "nltk", "wikipedia", "beautifulsoup4", "tiktoken", "torch", "jinja2", "pinecone-client", "pinecone-text", "weaviate-client", "redis", "google-api-python-client", "wolframalpha", "qdrant-client", "tensorflow-text", "pypdf", "networkx", "nomic", "aleph-alpha-client", "deeplake", "pgvector", "psycopg2-binary", "boto3", "pyowm", "pytesseract", "html2text", "atlassian-python-api", "gptcache", "duckduckgo-search", "arxiv", "azure-identity", "clickhouse-connect"]
+embeddings = ["sentence-transformers"]
+all = ["anthropic", "cohere", "openai", "nlpcloud", "huggingface_hub", "jina", "manifest-ml", "elasticsearch", "opensearch-py", "google-search-results", "faiss-cpu", "sentence-transformers", "transformers", "spacy", "nltk", "wikipedia", "beautifulsoup4", "tiktoken", "torch", "jinja2", "pinecone-client", "pinecone-text", "weaviate-client", "redis", "google-api-python-client", "wolframalpha", "qdrant-client", "tensorflow-text", "pypdf", "networkx", "nomic", "aleph-alpha-client", "deeplake", "pgvector", "psycopg2-binary", "boto3", "pyowm", "pytesseract", "html2text", "atlassian-python-api", "gptcache", "duckduckgo-search", "arxiv", "azure-identity", "clickhouse-connect"]
 
 [tool.ruff]
 select = [
diff --git a/tests/integration_tests/embeddings/test_sentence_transformer.py b/tests/integration_tests/embeddings/test_sentence_transformer.py
new file mode 100644
index 00000000..ce253ef4
--- /dev/null
+++ b/tests/integration_tests/embeddings/test_sentence_transformer.py
@@ -0,0 +1,33 @@
+# flake8: noqa
+"""Test sentence_transformer embeddings."""
+
+from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings
+from langchain.vectorstores import Chroma
+
+
+def test_sentence_transformer_embedding_documents() -> None:
+    """Embedding one document yields a single 384-dimensional vector."""
+    embedder = SentenceTransformerEmbeddings()
+    vectors = embedder.embed_documents(["foo bar"])
+    assert len(vectors) == 1
+    assert len(vectors[0]) == 384
+
+
+def test_sentence_transformer_embedding_query() -> None:
+    """Embedding a query yields a 384-dimensional vector."""
+    embedder = SentenceTransformerEmbeddings()
+    query_vector = embedder.embed_query("what the foo is a bar?")
+    assert len(query_vector) == 384
+
+
+def test_sentence_transformer_db_query() -> None:
+    """A Chroma search by query vector ranks the related text first."""
+    embedder = SentenceTransformerEmbeddings()
+    related = "we will foo your bar until you can't foo any more"
+    unrelated = "the quick brown fox jumped over the lazy dog"
+    query_vector = embedder.embed_query("what the foo is a bar?")
+    assert len(query_vector) == 384
+    store = Chroma(embedding_function=embedder)
+    store.add_texts([related, unrelated])
+    hits = store.similarity_search_by_vector(query_vector, k=2)
+    assert hits[0].page_content == related