mirror of
https://github.com/hwchase17/langchain
synced 2024-11-04 06:00:26 +00:00
Deterministic Fake Embedding Model (#8706)
Solves #8644. This embedding model outputs identical random embedding vectors whenever the input texts are identical, which is useful in unit tests. @baskaryan
This commit is contained in:
parent
2928a1a3c9
commit
4e8f11b36a
@ -27,7 +27,7 @@ from langchain.embeddings.deepinfra import DeepInfraEmbeddings
|
|||||||
from langchain.embeddings.edenai import EdenAiEmbeddings
|
from langchain.embeddings.edenai import EdenAiEmbeddings
|
||||||
from langchain.embeddings.elasticsearch import ElasticsearchEmbeddings
|
from langchain.embeddings.elasticsearch import ElasticsearchEmbeddings
|
||||||
from langchain.embeddings.embaas import EmbaasEmbeddings
|
from langchain.embeddings.embaas import EmbaasEmbeddings
|
||||||
from langchain.embeddings.fake import FakeEmbeddings
|
from langchain.embeddings.fake import DeterministicFakeEmbedding, FakeEmbeddings
|
||||||
from langchain.embeddings.google_palm import GooglePalmEmbeddings
|
from langchain.embeddings.google_palm import GooglePalmEmbeddings
|
||||||
from langchain.embeddings.gpt4all import GPT4AllEmbeddings
|
from langchain.embeddings.gpt4all import GPT4AllEmbeddings
|
||||||
from langchain.embeddings.huggingface import (
|
from langchain.embeddings.huggingface import (
|
||||||
@ -78,6 +78,7 @@ __all__ = [
|
|||||||
"SelfHostedHuggingFaceEmbeddings",
|
"SelfHostedHuggingFaceEmbeddings",
|
||||||
"SelfHostedHuggingFaceInstructEmbeddings",
|
"SelfHostedHuggingFaceInstructEmbeddings",
|
||||||
"FakeEmbeddings",
|
"FakeEmbeddings",
|
||||||
|
"DeterministicFakeEmbedding",
|
||||||
"AlephAlphaAsymmetricSemanticEmbedding",
|
"AlephAlphaAsymmetricSemanticEmbedding",
|
||||||
"AlephAlphaSymmetricSemanticEmbedding",
|
"AlephAlphaSymmetricSemanticEmbedding",
|
||||||
"SentenceTransformerEmbeddings",
|
"SentenceTransformerEmbeddings",
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import hashlib
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -20,3 +21,30 @@ class FakeEmbeddings(Embeddings, BaseModel):
|
|||||||
|
|
||||||
def embed_query(self, text: str) -> List[float]:
|
def embed_query(self, text: str) -> List[float]:
|
||||||
return self._get_embedding()
|
return self._get_embedding()
|
||||||
|
|
||||||
|
|
||||||
|
class DeterministicFakeEmbedding(Embeddings, BaseModel):
    """Fake embedding model that returns the same vector for the same text.

    Each text is hashed to an integer seed, and that seed drives a private
    random generator, so identical texts always map to identical vectors.
    """

    size: int
    """The size of the embedding vector."""

    def _get_embedding(self, seed: int) -> List[float]:
        """Return ``size`` normally distributed floats derived from ``seed``."""
        # Use a private generator instead of ``np.random.seed`` so computing
        # an embedding does not reseed NumPy's global random state for the
        # rest of the process. ``RandomState(seed)`` yields the same stream
        # as seeding the global legacy generator, so vectors are unchanged.
        rng = np.random.RandomState(seed)
        # ``tolist`` converts NumPy scalars to native Python floats.
        return rng.normal(size=self.size).tolist()

    def _get_seed(self, text: str) -> int:
        """Get a seed for the random generator, using the hash of the text.

        The SHA-256 digest of the UTF-8 encoded text is reduced modulo 10**8,
        which keeps the seed inside the range NumPy accepts.
        """
        digest = hashlib.sha256(text.encode("utf-8")).hexdigest()
        return int(digest, 16) % 10**8

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Return one deterministic embedding per text, in input order."""
        return [self._get_embedding(seed=self._get_seed(text)) for text in texts]

    def embed_query(self, text: str) -> List[float]:
        """Return the deterministic embedding for a single query text."""
        return self._get_embedding(seed=self._get_seed(text))
|
||||||
|
@ -0,0 +1,16 @@
|
|||||||
|
from langchain.embeddings import DeterministicFakeEmbedding
|
||||||
|
|
||||||
|
|
||||||
|
def test_deterministic_fake_embeddings() -> None:
    """
    Check that equal texts yield equal embedding vectors and that
    different texts yield different ones, for queries and documents.
    """
    embedder = DeterministicFakeEmbedding(size=10)
    greeting = "Hello world!"

    # Single-query path: two independent calls must agree.
    first_query = embedder.embed_query(greeting)
    repeat_query = embedder.embed_query(greeting)
    assert first_query == repeat_query

    same_text_query = embedder.embed_query(greeting)
    other_text_query = embedder.embed_query("Goodbye world!")
    assert same_text_query != other_text_query

    # Batch path: identical batches agree, differing batches do not.
    first_batch = embedder.embed_documents([greeting, greeting])
    repeat_batch = embedder.embed_documents([greeting, greeting])
    assert first_batch == repeat_batch

    same_batch = embedder.embed_documents([greeting, greeting])
    mixed_batch = embedder.embed_documents([greeting, "Goodbye world!"])
    assert same_batch != mixed_batch
|
Loading…
Reference in New Issue
Block a user