From 4e8f11b36ae7c3b35dd9fa9114e3b06d4198b908 Mon Sep 17 00:00:00 2001 From: Yoshi <74702693+yuhuishi-convect@users.noreply.github.com> Date: Thu, 3 Aug 2023 13:36:45 -0700 Subject: [PATCH] Deterministic Fake Embedding Model (#8706) Solves #8644 This embedding model outputs identical random embedding vectors when the input texts are identical. Useful in unit tests. @baskaryan --- .../langchain/embeddings/__init__.py | 3 +- libs/langchain/langchain/embeddings/fake.py | 28 +++++++++++++++++++ .../test_deterministic_embedding.py | 16 +++++++++++ 3 files changed, 46 insertions(+), 1 deletion(-) create mode 100644 libs/langchain/tests/unit_tests/embeddings/test_deterministic_embedding.py diff --git a/libs/langchain/langchain/embeddings/__init__.py b/libs/langchain/langchain/embeddings/__init__.py index eeb623471b..ee572fd185 100644 --- a/libs/langchain/langchain/embeddings/__init__.py +++ b/libs/langchain/langchain/embeddings/__init__.py @@ -27,7 +27,7 @@ from langchain.embeddings.deepinfra import DeepInfraEmbeddings from langchain.embeddings.edenai import EdenAiEmbeddings from langchain.embeddings.elasticsearch import ElasticsearchEmbeddings from langchain.embeddings.embaas import EmbaasEmbeddings -from langchain.embeddings.fake import FakeEmbeddings +from langchain.embeddings.fake import DeterministicFakeEmbedding, FakeEmbeddings from langchain.embeddings.google_palm import GooglePalmEmbeddings from langchain.embeddings.gpt4all import GPT4AllEmbeddings from langchain.embeddings.huggingface import ( @@ -78,6 +78,7 @@ __all__ = [ "SelfHostedHuggingFaceEmbeddings", "SelfHostedHuggingFaceInstructEmbeddings", "FakeEmbeddings", + "DeterministicFakeEmbedding", "AlephAlphaAsymmetricSemanticEmbedding", "AlephAlphaSymmetricSemanticEmbedding", "SentenceTransformerEmbeddings", diff --git a/libs/langchain/langchain/embeddings/fake.py b/libs/langchain/langchain/embeddings/fake.py index 65bf7cfa21..1b8311dec7 100644 --- a/libs/langchain/langchain/embeddings/fake.py +++ 
b/libs/langchain/langchain/embeddings/fake.py @@ -1,3 +1,4 @@ +import hashlib from typing import List import numpy as np @@ -20,3 +21,30 @@ class FakeEmbeddings(Embeddings, BaseModel): def embed_query(self, text: str) -> List[float]: return self._get_embedding() + + +class DeterministicFakeEmbedding(Embeddings, BaseModel): + """ + Fake embedding model that always returns + the same embedding vector for the same text. + """ + + size: int + """The size of the embedding vector.""" + + def _get_embedding(self, seed: int) -> List[float]: + # set the seed for the random generator + np.random.seed(seed) + return list(np.random.normal(size=self.size)) + + def _get_seed(self, text: str) -> int: + """ + Get a seed for the random generator, using the hash of the text. + """ + return int(hashlib.sha256(text.encode("utf-8")).hexdigest(), 16) % 10**8 + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + return [self._get_embedding(seed=self._get_seed(_)) for _ in texts] + + def embed_query(self, text: str) -> List[float]: + return self._get_embedding(seed=self._get_seed(text)) diff --git a/libs/langchain/tests/unit_tests/embeddings/test_deterministic_embedding.py b/libs/langchain/tests/unit_tests/embeddings/test_deterministic_embedding.py new file mode 100644 index 0000000000..6ed8cde63c --- /dev/null +++ b/libs/langchain/tests/unit_tests/embeddings/test_deterministic_embedding.py @@ -0,0 +1,16 @@ +from langchain.embeddings import DeterministicFakeEmbedding + + +def test_deterministic_fake_embeddings() -> None: + """ + Test that the deterministic fake embeddings return the same + embedding vector for the same text. + """ + fake = DeterministicFakeEmbedding(size=10) + text = "Hello world!" 
+ assert fake.embed_query(text) == fake.embed_query(text) + assert fake.embed_query(text) != fake.embed_query("Goodbye world!") + assert fake.embed_documents([text, text]) == fake.embed_documents([text, text]) + assert fake.embed_documents([text, text]) != fake.embed_documents( + [text, "Goodbye world!"] + )