|
|
|
@ -1,9 +1,11 @@
|
|
|
|
|
import hashlib
|
|
|
|
|
|
|
|
|
|
import pytest
|
|
|
|
|
from safetensors import safe_open
|
|
|
|
|
|
|
|
|
|
from imaginairy.modules.clip_embedders import FrozenCLIPEmbedder
|
|
|
|
|
from imaginairy.utils import get_device
|
|
|
|
|
from tests import TESTS_FOLDER
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def hash_tensor(t):
|
|
|
|
@ -15,11 +17,13 @@ def hash_tensor(t):
|
|
|
|
|
def test_text_conditioning():
|
|
|
|
|
embedder = FrozenCLIPEmbedder()
|
|
|
|
|
embedder.to(get_device())
|
|
|
|
|
neutral_embedding = embedder.encode([""])
|
|
|
|
|
hashed = hash_tensor(neutral_embedding)
|
|
|
|
|
assert hashed in {
|
|
|
|
|
"263e5ee7d2be087d816e094b80ffc546", # mps
|
|
|
|
|
"41818051d7c469fc57d0a940c9d24d82",
|
|
|
|
|
"b5f29fb26bceb60dcde19ec7ec5a0711",
|
|
|
|
|
"88245bdb2a83b49092407fc5b4c473ab", # ubuntu, torch 1.12.1 cu116
|
|
|
|
|
}
|
|
|
|
|
neutral_embedding = embedder.encode([""]).to("cpu")
|
|
|
|
|
with safe_open(
|
|
|
|
|
f"{TESTS_FOLDER}/data/neutral_clip_embedding_mps.safetensors",
|
|
|
|
|
framework="pt",
|
|
|
|
|
device="cpu",
|
|
|
|
|
) as f:
|
|
|
|
|
neutral_embedding_mps_expected = f.get_tensor("neutral_clip_embedding_mps")
|
|
|
|
|
|
|
|
|
|
diff = neutral_embedding - neutral_embedding_mps_expected
|
|
|
|
|
assert diff.sum() < 0.05
|
|
|
|
|