2022-09-17 05:21:20 +00:00
|
|
|
import hashlib
|
|
|
|
|
2022-10-17 02:46:32 +00:00
|
|
|
import pytest
|
2023-01-02 08:29:09 +00:00
|
|
|
from safetensors import safe_open
|
2022-10-17 02:46:32 +00:00
|
|
|
|
2022-09-17 05:21:20 +00:00
|
|
|
from imaginairy.modules.clip_embedders import FrozenCLIPEmbedder
|
|
|
|
from imaginairy.utils import get_device
|
2023-01-02 08:29:09 +00:00
|
|
|
from tests import TESTS_FOLDER
|
2022-09-17 05:21:20 +00:00
|
|
|
|
|
|
|
|
|
|
|
def hash_tensor(t):
    """Return the hex MD5 digest of a tensor's raw byte content.

    Used to cheaply fingerprint tensors for equality checks in tests.
    The tensor is detached and moved to CPU first so the bytes are
    device-independent.
    """
    raw_bytes = t.detach().cpu().numpy().tobytes()
    return hashlib.md5(raw_bytes).hexdigest()
|
|
|
|
|
|
|
|
|
2022-10-17 02:46:32 +00:00
|
|
|
@pytest.mark.skipif(get_device() == "cpu", reason="Too slow to run on CPU")
def test_text_conditioning():
    """The CLIP embedding of the empty prompt should match the stored reference.

    Encodes the neutral (empty-string) prompt with ``FrozenCLIPEmbedder`` and
    compares it against a reference tensor saved in the test data folder.

    Bug fix: the original asserted ``diff.sum() < 0.05`` on the *signed*
    elementwise differences, so large positive and negative deviations could
    cancel out and let a wrong embedding pass. Compare the sum of absolute
    differences instead.
    """
    embedder = FrozenCLIPEmbedder()
    embedder.to(get_device())
    # Move the result to CPU so it is comparable with the CPU-loaded reference.
    neutral_embedding = embedder.encode([""]).to("cpu")
    with safe_open(
        f"{TESTS_FOLDER}/data/neutral_clip_embedding_mps.safetensors",
        framework="pt",
        device="cpu",
    ) as f:
        neutral_embedding_mps_expected = f.get_tensor("neutral_clip_embedding_mps")

    diff = neutral_embedding - neutral_embedding_mps_expected
    # Absolute-difference total: catches any real divergence, regardless of sign.
    assert diff.abs().sum() < 0.05
|