Add Sentence Transformers Embeddings (#3409)

Add embeddings based on the sentence-transformers library.
Add a notebook and integration tests.

Co-authored-by: khimaros <me@khimaros.com>
Zander Chase authored 1 year ago, committed by GitHub
parent 73bc70b4fa
commit 20f530e9c5
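For reference, a minimal usage sketch of the API this commit adds, pieced together from the notebook and wrapper below (the model name shown is the wrapper's default):

from langchain.embeddings import SentenceTransformerEmbeddings

# Defaults to all-MiniLM-L6-v2 when no model name is passed.
embeddings = SentenceTransformerEmbeddings(model="all-MiniLM-L6-v2")

# A single query maps to one 384-dimensional vector (per the integration tests).
query_vector = embeddings.embed_query("This is a test document.")

# A batch of documents maps to one vector per input text.
doc_vectors = embeddings.embed_documents(
    ["This is a test document.", "This is not a test document."]
)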

@@ -0,0 +1,120 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"id": "ed47bb62",
"metadata": {},
"source": [
"# Sentence Transformers Embeddings\n",
"\n",
"Let's generate embeddings using the [SentenceTransformers](https://www.sbert.net/) integration. SentenceTransformers is a python package that can generate text and image embeddings, originating from [Sentence-BERT](https://arxiv.org/abs/1908.10084)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "06c9f47d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
"To disable this warning, you can either:\n",
"\t- Avoid using `tokenizers` before the fork if possible\n",
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
]
}
],
"source": [
"!pip install sentence_transformers > /dev/null"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "861521a9",
"metadata": {},
"outputs": [],
"source": [
"from langchain.embeddings import SentenceTransformerEmbeddings "
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "ff9be586",
"metadata": {},
"outputs": [],
"source": [
"embeddings = SentenceTransformerEmbeddings(model=\"all-MiniLM-L6-v2\")"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "d0a98ae9",
"metadata": {},
"outputs": [],
"source": [
"text = \"This is a test document.\""
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "5d6c682b",
"metadata": {},
"outputs": [],
"source": [
"query_result = embeddings.embed_query(text)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "bb5e74c0",
"metadata": {},
"outputs": [],
"source": [
"doc_result = embeddings.embed_documents([text, \"This is not a test document.\"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "aaad49f8",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.2"
},
"vscode": {
"interpreter": {
"hash": "7377c2ccc78bc62c2683122d48c8cd1fb85a53850a1b1fc29736ed39852c9885"
}
}
},
"nbformat": 4,
"nbformat_minor": 5
}

@@ -22,6 +22,7 @@ from langchain.embeddings.self_hosted_hugging_face import (
    SelfHostedHuggingFaceEmbeddings,
    SelfHostedHuggingFaceInstructEmbeddings,
)
from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from langchain.embeddings.tensorflow_hub import TensorflowHubEmbeddings

logger = logging.getLogger(__name__)
@@ -42,6 +43,7 @@ __all__ = [
    "FakeEmbeddings",
    "AlephAlphaAsymmetricSemanticEmbedding",
    "AlephAlphaSymmetricSemanticEmbedding",
    "SentenceTransformerEmbeddings",
]

@@ -0,0 +1,63 @@
"""Wrapper around sentence transformer embedding models."""
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, Extra, Field, root_validator

from langchain.embeddings.base import Embeddings


class SentenceTransformerEmbeddings(BaseModel, Embeddings):
    embedding_function: Any  #: :meta private:

    model: Optional[str] = Field("all-MiniLM-L6-v2", alias="model")
    """Transformer model to use."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that sentence_transformers library is installed."""
        model = values["model"]
        try:
            from sentence_transformers import SentenceTransformer

            values["embedding_function"] = SentenceTransformer(model)
        except ImportError:
            raise ModuleNotFoundError(
                "Could not import sentence_transformers library. "
                "Please install the sentence_transformers library to "
                "use this embedding model: pip install sentence_transformers"
            )
        except Exception:
            raise NameError(f"Could not load SentenceTransformer model {model}.")
        return values

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of documents using the SentenceTransformer model.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        embeddings = self.embedding_function.encode(
            texts, convert_to_numpy=True
        ).tolist()
        return [list(map(float, e)) for e in embeddings]

    def embed_query(self, text: str) -> List[float]:
        """Embed a query using the SentenceTransformer model.

        Args:
            text: The text to embed.

        Returns:
            Embedding for the text.
        """
        return self.embed_documents([text])[0]
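Under the hood, the wrapper delegates to the sentence_transformers package; a rough sketch of the equivalent raw calls it makes (the 384-dimension figure comes from the integration tests below):

from sentence_transformers import SentenceTransformer

# Load the same default model the wrapper uses.
model = SentenceTransformer("all-MiniLM-L6-v2")

# encode() returns a numpy array of shape (n_texts, 384) for this model;
# the wrapper converts it to plain Python lists of floats for the Embeddings interface.
vectors = model.encode(["foo bar", "the quick brown fox"], convert_to_numpy=True)
as_lists = [list(map(float, v)) for v in vectors]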

@@ -117,6 +117,7 @@ torch = "^1.0.0"
chromadb = "^0.3.21"
tiktoken = "^0.3.3"
python-dotenv = "^1.0.0"
sentence-transformers = "^2"
gptcache = "^0.1.9"
promptlayer = "^0.1.80"
@@ -144,7 +145,8 @@ llms = ["anthropic", "cohere", "openai", "nlpcloud", "huggingface_hub", "manifes
qdrant = ["qdrant-client"]
openai = ["openai"]
cohere = ["cohere"]
all = ["anthropic", "cohere", "openai", "nlpcloud", "huggingface_hub", "jina", "manifest-ml", "elasticsearch", "opensearch-py", "google-search-results", "faiss-cpu", "sentence_transformers", "transformers", "spacy", "nltk", "wikipedia", "beautifulsoup4", "tiktoken", "torch", "jinja2", "pinecone-client", "pinecone-text", "weaviate-client", "redis", "google-api-python-client", "wolframalpha", "qdrant-client", "tensorflow-text", "pypdf", "networkx", "nomic", "aleph-alpha-client", "deeplake", "pgvector", "psycopg2-binary", "boto3", "pyowm", "pytesseract", "html2text", "atlassian-python-api", "gptcache", "duckduckgo-search", "arxiv", "azure-identity", "clickhouse-connect"]
embeddings = ["sentence-transformers"]
all = ["anthropic", "cohere", "openai", "nlpcloud", "huggingface_hub", "jina", "manifest-ml", "elasticsearch", "opensearch-py", "google-search-results", "faiss-cpu", "sentence-transformers", "transformers", "spacy", "nltk", "wikipedia", "beautifulsoup4", "tiktoken", "torch", "jinja2", "pinecone-client", "pinecone-text", "weaviate-client", "redis", "google-api-python-client", "wolframalpha", "qdrant-client", "tensorflow-text", "pypdf", "networkx", "nomic", "aleph-alpha-client", "deeplake", "pgvector", "psycopg2-binary", "boto3", "pyowm", "pytesseract", "html2text", "atlassian-python-api", "gptcache", "duckduckgo-search", "arxiv", "azure-identity", "clickhouse-connect"]
[tool.ruff]
select = [

@@ -0,0 +1,38 @@
# flake8: noqa
"""Test sentence_transformer embeddings."""
from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from langchain.vectorstores import Chroma


def test_sentence_transformer_embedding_documents() -> None:
    """Test sentence_transformer embeddings."""
    embedding = SentenceTransformerEmbeddings()
    documents = ["foo bar"]
    output = embedding.embed_documents(documents)
    assert len(output) == 1
    assert len(output[0]) == 384


def test_sentence_transformer_embedding_query() -> None:
    """Test sentence_transformer embeddings."""
    embedding = SentenceTransformerEmbeddings()
    query = "what the foo is a bar?"
    query_vector = embedding.embed_query(query)
    assert len(query_vector) == 384


def test_sentence_transformer_db_query() -> None:
    """Test sentence_transformer similarity search."""
    embedding = SentenceTransformerEmbeddings()
    texts = [
        "we will foo your bar until you can't foo any more",
        "the quick brown fox jumped over the lazy dog",
    ]
    query = "what the foo is a bar?"
    query_vector = embedding.embed_query(query)
    assert len(query_vector) == 384
    db = Chroma(embedding_function=embedding)
    db.add_texts(texts)
    docs = db.similarity_search_by_vector(query_vector, k=2)
    assert docs[0].page_content == "we will foo your bar until you can't foo any more"
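Outside the test harness, the same retrieval flow could be wired up roughly as follows (a sketch assuming Chroma's in-memory defaults; the k value is illustrative):

from langchain.embeddings import SentenceTransformerEmbeddings
from langchain.vectorstores import Chroma

# Embedding wrapper added in this commit (defaults to all-MiniLM-L6-v2).
embeddings = SentenceTransformerEmbeddings()

# Index a few documents in an in-memory Chroma collection.
texts = [
    "we will foo your bar until you can't foo any more",
    "the quick brown fox jumped over the lazy dog",
]
db = Chroma.from_texts(texts, embeddings)

# Retrieve the closest document to a natural-language query.
docs = db.similarity_search("what the foo is a bar?", k=1)
print(docs[0].page_content)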