mirror of
https://github.com/brycedrennan/imaginAIry
synced 2024-10-31 03:20:40 +00:00
test: more flexible embedding test
This commit is contained in:
parent
1381c7fed4
commit
eb40842078
@ -47,7 +47,7 @@ click-shell==2.1
|
||||
# via imaginAIry (setup.py)
|
||||
contourpy==1.0.6
|
||||
# via matplotlib
|
||||
coverage==7.0.1
|
||||
coverage==7.0.2
|
||||
# via -r requirements-dev.in
|
||||
cycler==0.11.0
|
||||
# via matplotlib
|
||||
@ -310,6 +310,8 @@ rsa==4.9
|
||||
# via google-auth
|
||||
ruff==0.0.206
|
||||
# via -r requirements-dev.in
|
||||
safetensors==0.2.7
|
||||
# via imaginAIry (setup.py)
|
||||
scikit-image==0.19.3
|
||||
# via basicsr
|
||||
scipy==1.9.3
|
||||
@ -328,7 +330,7 @@ six==1.16.0
|
||||
# python-dateutil
|
||||
snowballstemmer==2.2.0
|
||||
# via pydocstyle
|
||||
tb-nightly==2.12.0a20230101
|
||||
tb-nightly==2.12.0a20230102
|
||||
# via
|
||||
# basicsr
|
||||
# gfpgan
|
||||
|
1
setup.py
1
setup.py
@ -51,6 +51,7 @@ setup(
|
||||
"open-clip-torch",
|
||||
"requests",
|
||||
"einops==0.3.0",
|
||||
"safetensors",
|
||||
"timm>=0.4.12", # for vendored blip
|
||||
"torchdiffeq",
|
||||
"transformers==4.19.2",
|
||||
|
BIN
tests/data/neutral_clip_embedding_mps.safetensors
Normal file
BIN
tests/data/neutral_clip_embedding_mps.safetensors
Normal file
Binary file not shown.
@ -1,9 +1,11 @@
|
||||
import hashlib
|
||||
|
||||
import pytest
|
||||
from safetensors import safe_open
|
||||
|
||||
from imaginairy.modules.clip_embedders import FrozenCLIPEmbedder
|
||||
from imaginairy.utils import get_device
|
||||
from tests import TESTS_FOLDER
|
||||
|
||||
|
||||
def hash_tensor(t):
|
||||
@ -15,11 +17,13 @@ def hash_tensor(t):
|
||||
def test_text_conditioning():
|
||||
embedder = FrozenCLIPEmbedder()
|
||||
embedder.to(get_device())
|
||||
neutral_embedding = embedder.encode([""])
|
||||
hashed = hash_tensor(neutral_embedding)
|
||||
assert hashed in {
|
||||
"263e5ee7d2be087d816e094b80ffc546", # mps
|
||||
"41818051d7c469fc57d0a940c9d24d82",
|
||||
"b5f29fb26bceb60dcde19ec7ec5a0711",
|
||||
"88245bdb2a83b49092407fc5b4c473ab", # ubuntu, torch 1.12.1 cu116
|
||||
}
|
||||
neutral_embedding = embedder.encode([""]).to("cpu")
|
||||
with safe_open(
|
||||
f"{TESTS_FOLDER}/data/neutral_clip_embedding_mps.safetensors",
|
||||
framework="pt",
|
||||
device="cpu",
|
||||
) as f:
|
||||
neutral_embedding_mps_expected = f.get_tensor("neutral_clip_embedding_mps")
|
||||
|
||||
diff = neutral_embedding - neutral_embedding_mps_expected
|
||||
assert diff.sum() < 0.05
|
||||
|
Loading…
Reference in New Issue
Block a user