test: more flexible embedding test

This commit is contained in:
Bryce 2023-01-02 00:29:09 -08:00 committed by Bryce Drennan
parent 1381c7fed4
commit eb40842078
4 changed files with 17 additions and 10 deletions

View File

@ -47,7 +47,7 @@ click-shell==2.1
# via imaginAIry (setup.py)
contourpy==1.0.6
# via matplotlib
coverage==7.0.1
coverage==7.0.2
# via -r requirements-dev.in
cycler==0.11.0
# via matplotlib
@ -310,6 +310,8 @@ rsa==4.9
# via google-auth
ruff==0.0.206
# via -r requirements-dev.in
safetensors==0.2.7
# via imaginAIry (setup.py)
scikit-image==0.19.3
# via basicsr
scipy==1.9.3
@ -328,7 +330,7 @@ six==1.16.0
# python-dateutil
snowballstemmer==2.2.0
# via pydocstyle
tb-nightly==2.12.0a20230101
tb-nightly==2.12.0a20230102
# via
# basicsr
# gfpgan

View File

@ -51,6 +51,7 @@ setup(
"open-clip-torch",
"requests",
"einops==0.3.0",
"safetensors",
"timm>=0.4.12", # for vendored blip
"torchdiffeq",
"transformers==4.19.2",

Binary file not shown.

View File

@ -1,9 +1,11 @@
import hashlib
import pytest
from safetensors import safe_open
from imaginairy.modules.clip_embedders import FrozenCLIPEmbedder
from imaginairy.utils import get_device
from tests import TESTS_FOLDER
def hash_tensor(t):
@ -15,11 +17,13 @@ def hash_tensor(t):
def test_text_conditioning():
embedder = FrozenCLIPEmbedder()
embedder.to(get_device())
neutral_embedding = embedder.encode([""])
hashed = hash_tensor(neutral_embedding)
assert hashed in {
"263e5ee7d2be087d816e094b80ffc546", # mps
"41818051d7c469fc57d0a940c9d24d82",
"b5f29fb26bceb60dcde19ec7ec5a0711",
"88245bdb2a83b49092407fc5b4c473ab", # ubuntu, torch 1.12.1 cu116
}
neutral_embedding = embedder.encode([""]).to("cpu")
with safe_open(
f"{TESTS_FOLDER}/data/neutral_clip_embedding_mps.safetensors",
framework="pt",
device="cpu",
) as f:
neutral_embedding_mps_expected = f.get_tensor("neutral_clip_embedding_mps")
diff = neutral_embedding - neutral_embedding_mps_expected
assert diff.sum() < 0.05