2022-09-17 05:21:20 +00:00
|
|
|
import hashlib
|
|
|
|
|
2022-10-17 02:46:32 +00:00
|
|
|
import pytest
|
|
|
|
|
2022-09-17 05:21:20 +00:00
|
|
|
from imaginairy.modules.clip_embedders import FrozenCLIPEmbedder
|
|
|
|
from imaginairy.utils import get_device
|
|
|
|
|
|
|
|
|
|
|
|
def hash_tensor(t):
    """Return the MD5 hex digest of a tensor's raw byte content.

    The tensor is moved to the CPU, detached, converted to a NumPy array and
    serialised with ``tobytes()`` before hashing, so the digest reflects the
    exact stored values.

    Args:
        t: Object exposing ``.cpu().detach().numpy()`` (presumably a
            ``torch.Tensor`` -- confirm against callers).

    Returns:
        The digest as a hexadecimal string.
    """
    # MD5 serves purely as a content fingerprint here.
    raw_bytes = t.cpu().detach().numpy().tobytes()
    hasher = hashlib.md5()
    hasher.update(raw_bytes)
    return hasher.hexdigest()
|
|
|
|
|
|
|
|
|
2022-10-17 02:46:32 +00:00
|
|
|
@pytest.mark.skipif(get_device() == "cpu", reason="Too slow to run on CPU")
def test_text_conditioning():
    """Check that the CLIP embedding of an empty prompt matches a known digest.

    The embedder is built, moved to the active device, and asked to encode a
    single empty string; the resulting tensor's hash must be one of the
    accepted reference values. Skipped entirely when the device is "cpu".
    """
    # Several digests are accepted; only the first is labelled with its
    # platform in the original source (presumably the others come from
    # different backends -- confirm before pruning).
    accepted_digests = {
        "263e5ee7d2be087d816e094b80ffc546",  # mps
        "41818051d7c469fc57d0a940c9d24d82",
        "b5f29fb26bceb60dcde19ec7ec5a0711",
    }

    clip_embedder = FrozenCLIPEmbedder()
    clip_embedder.to(get_device())

    empty_prompt_embedding = clip_embedder.encode([""])
    digest = hash_tensor(empty_prompt_embedding)

    assert digest in accepted_digests
|