Deterministic Fake Embedding Model (#8706)

Solves #8644 
This embedding model outputs identical random embedding vectors whenever the input texts are identical, which makes it useful in unit tests.
@baskaryan
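
A minimal usage sketch (not part of the diff itself): the import path matches the export added in this change, and size is the field defined on the new class; the assertions illustrate the determinism guarantee described above.

from langchain.embeddings import DeterministicFakeEmbedding

# Stand-in for a real embedding model in unit tests: same text -> same vector.
embedder = DeterministicFakeEmbedding(size=1536)

first = embedder.embed_query("hello world")
second = embedder.embed_query("hello world")
assert first == second      # identical input text gives an identical vector
assert len(first) == 1536   # vectors have the configured dimensionality

# embed_documents derives each vector from its text in the same way,
# so it agrees with embed_query for the same string.
assert embedder.embed_documents(["hello world"])[0] == first
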
Yoshi 2023-08-03 13:36:45 -07:00 committed by GitHub
parent 2928a1a3c9
commit 4e8f11b36a
3 changed files with 46 additions and 1 deletion

File: langchain/embeddings/__init__.py

@@ -27,7 +27,7 @@ from langchain.embeddings.deepinfra import DeepInfraEmbeddings
 from langchain.embeddings.edenai import EdenAiEmbeddings
 from langchain.embeddings.elasticsearch import ElasticsearchEmbeddings
 from langchain.embeddings.embaas import EmbaasEmbeddings
-from langchain.embeddings.fake import FakeEmbeddings
+from langchain.embeddings.fake import DeterministicFakeEmbedding, FakeEmbeddings
 from langchain.embeddings.google_palm import GooglePalmEmbeddings
 from langchain.embeddings.gpt4all import GPT4AllEmbeddings
 from langchain.embeddings.huggingface import (
@@ -78,6 +78,7 @@ __all__ = [
     "SelfHostedHuggingFaceEmbeddings",
     "SelfHostedHuggingFaceInstructEmbeddings",
     "FakeEmbeddings",
+    "DeterministicFakeEmbedding",
     "AlephAlphaAsymmetricSemanticEmbedding",
     "AlephAlphaSymmetricSemanticEmbedding",
     "SentenceTransformerEmbeddings",

File: langchain/embeddings/fake.py

@@ -1,3 +1,4 @@
+import hashlib
 from typing import List

 import numpy as np
@@ -20,3 +21,30 @@ class FakeEmbeddings(Embeddings, BaseModel):

     def embed_query(self, text: str) -> List[float]:
         return self._get_embedding()
+
+
+class DeterministicFakeEmbedding(Embeddings, BaseModel):
+    """
+    Fake embedding model that always returns
+    the same embedding vector for the same text.
+    """
+
+    size: int
+    """The size of the embedding vector."""
+
+    def _get_embedding(self, seed: int) -> List[float]:
+        # set the seed for the random generator
+        np.random.seed(seed)
+        return list(np.random.normal(size=self.size))
+
+    def _get_seed(self, text: str) -> int:
+        """
+        Get a seed for the random generator, using the hash of the text.
+        """
+        return int(hashlib.sha256(text.encode("utf-8")).hexdigest(), 16) % 10**8
+
+    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+        return [self._get_embedding(seed=self._get_seed(_)) for _ in texts]
+
+    def embed_query(self, text: str) -> List[float]:
+        return self._get_embedding(seed=self._get_seed(text))
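
For clarity, a standalone sketch of the seeding scheme used above (the helper names _seed_for and _vector_for are illustrative, not part of the change): the text is hashed with SHA-256, reduced modulo 10**8 so the value fits NumPy's legacy seed range, and the global RNG is re-seeded before sampling, so the vector is a pure function of the text and the configured size.

import hashlib
from typing import List

import numpy as np


def _seed_for(text: str) -> int:
    # SHA-256 of the text, reduced modulo 10**8 so it fits NumPy's legacy seed range.
    return int(hashlib.sha256(text.encode("utf-8")).hexdigest(), 16) % 10**8


def _vector_for(text: str, size: int) -> List[float]:
    # Re-seeding the global RNG before sampling makes the vector a pure function of the text.
    np.random.seed(_seed_for(text))
    return list(np.random.normal(size=size))


# The same text yields the same vector in any run or process.
assert _vector_for("Hello world!", 8) == _vector_for("Hello world!", 8)
assert _vector_for("Hello world!", 8) != _vector_for("Goodbye world!", 8)

Because every call re-seeds NumPy's global random state, this trades RNG isolation for reproducibility, which is acceptable for a test-only fake.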

File: new unit test for DeterministicFakeEmbedding

@@ -0,0 +1,16 @@
+from langchain.embeddings import DeterministicFakeEmbedding
+
+
+def test_deterministic_fake_embeddings() -> None:
+    """
+    Test that the deterministic fake embeddings return the same
+    embedding vector for the same text.
+    """
+    fake = DeterministicFakeEmbedding(size=10)
+    text = "Hello world!"
+    assert fake.embed_query(text) == fake.embed_query(text)
+    assert fake.embed_query(text) != fake.embed_query("Goodbye world!")
+    assert fake.embed_documents([text, text]) == fake.embed_documents([text, text])
+    assert fake.embed_documents([text, text]) != fake.embed_documents(
+        [text, "Goodbye world!"]
+    )