tests: support distributed test runs
- switch to expected images instead of hashes. allow fuzzy matches feature: more consistent seeds
2
.github/workflows/ci.yaml
vendored
@ -71,4 +71,4 @@ jobs:
|
||||
- name: Test with pytest
|
||||
timeout-minutes: 10
|
||||
run: |
|
||||
pytest
|
||||
pytest --durations=50 -v
|
6
.gitignore
vendored
@ -17,4 +17,8 @@ dist
|
||||
tests/test_output
|
||||
gfpgan/**
|
||||
.python-version
|
||||
._.DS_Store
|
||||
._.DS_Store
|
||||
tests/vastai_cli.py
|
||||
/tests/test_output_local_cuda/
|
||||
/testing_support/
|
||||
.unison*
|
@ -33,6 +33,7 @@ from imaginairy.utils import (
|
||||
get_device,
|
||||
instantiate_from_config,
|
||||
platform_appropriate_autocast,
|
||||
randn_seeded,
|
||||
)
|
||||
|
||||
LIB_PATH = os.path.dirname(__file__)
|
||||
@ -65,7 +66,18 @@ def load_model_from_config(
|
||||
else:
|
||||
ckpt_path = model_weights_location
|
||||
logger.info(f"Loading model {ckpt_path} onto {get_device()} backend...")
|
||||
pl_sd = torch.load(ckpt_path, map_location="cpu")
|
||||
pl_sd = None
|
||||
try:
|
||||
pl_sd = torch.load(ckpt_path, map_location="cpu")
|
||||
except RuntimeError as e:
|
||||
if "PytorchStreamReader failed reading zip archive" in str(e):
|
||||
if model_weights_location.startswith("http"):
|
||||
logger.warning("Corrupt checkpoint. deleting and re-downloading...")
|
||||
os.remove(ckpt_path)
|
||||
ckpt_path = cached_path(model_weights_location)
|
||||
pl_sd = torch.load(ckpt_path, map_location="cpu")
|
||||
if pl_sd is None:
|
||||
raise e
|
||||
if "global_step" in pl_sd:
|
||||
logger.debug(f"Global Step: {pl_sd['global_step']}")
|
||||
sd = pl_sd["state_dict"]
|
||||
@ -266,7 +278,8 @@ def imagine(
|
||||
log_latent(init_latent, "init_latent")
|
||||
# encode (scaled latent)
|
||||
seed_everything(prompt.seed)
|
||||
noise = torch.randn_like(init_latent, device="cpu").to(get_device())
|
||||
noise = randn_seeded(seed=prompt.seed, size=init_latent.size())
|
||||
noise = noise.to(get_device())
|
||||
|
||||
schedule = NoiseSchedule(
|
||||
model_num_timesteps=model.num_timesteps,
|
||||
@ -280,14 +293,14 @@ def imagine(
|
||||
# (or setting steps=1000)
|
||||
init_latent_noised = noise
|
||||
else:
|
||||
|
||||
init_latent_noised = noise_an_image(
|
||||
init_latent,
|
||||
torch.tensor([t_enc - 1]).to(get_device()),
|
||||
schedule=schedule,
|
||||
noise=noise,
|
||||
)
|
||||
log_latent(init_latent_noised, "init_latent_noised")
|
||||
|
||||
log_latent(init_latent_noised, "init_latent_noised")
|
||||
|
||||
samples = sampler.sample(
|
||||
num_steps=prompt.steps,
|
||||
|
@ -172,6 +172,9 @@ def disable_transformers_custom_logging():
|
||||
_logger.handlers = []
|
||||
_logger.propagate = True
|
||||
_logger.setLevel(logging.NOTSET)
|
||||
modeling_logger.handlers = []
|
||||
modeling_logger.propagate = True
|
||||
modeling_logger.setLevel(logging.ERROR)
|
||||
|
||||
|
||||
def disable_pytorch_lighting_custom_logging():
|
||||
|
@ -153,7 +153,7 @@ def mask_blend(noisy_latent, orig_latent, mask, mask_noise, ts, model):
|
||||
hint_strength = 1
|
||||
# if we're in the first 10% of the steps then don't fully noise the parts
|
||||
# of the image we're not changing so that the algorithm can learn from the context
|
||||
if ts > 900:
|
||||
if ts > 1000:
|
||||
hinted_orig_latent = (
|
||||
noised_orig_latent * (1 - hint_strength) + orig_latent * hint_strength
|
||||
)
|
||||
|
@ -186,3 +186,27 @@ def get_cached_url_path(url):
|
||||
with open(dest_path, "wb") as f:
|
||||
f.write(r.content)
|
||||
return dest_path
|
||||
|
||||
|
||||
def randn_seeded(seed: int, size: List[int]) -> Tensor:
|
||||
"""Generate a random tensor with a given seed"""
|
||||
g_cpu = torch.Generator()
|
||||
g_cpu.manual_seed(seed)
|
||||
noise = torch.randn(
|
||||
size,
|
||||
device="cpu",
|
||||
generator=g_cpu,
|
||||
)
|
||||
return noise
|
||||
|
||||
|
||||
def check_torch_working():
|
||||
"""Check that torch is working"""
|
||||
try:
|
||||
torch.randn(1, device=get_device())
|
||||
except RuntimeError as e:
|
||||
if "CUDA" in str(e):
|
||||
raise RuntimeError(
|
||||
"CUDA is not working. Make sure you have a GPU and CUDA installed."
|
||||
) from e
|
||||
raise e
|
||||
|
@ -4,7 +4,7 @@
|
||||
#
|
||||
# pip-compile --output-file=requirements-dev.txt requirements-dev.in setup.py
|
||||
#
|
||||
absl-py==1.2.0
|
||||
absl-py==1.3.0
|
||||
# via
|
||||
# tb-nightly
|
||||
# tensorboard
|
||||
@ -52,7 +52,7 @@ coverage==6.5.0
|
||||
# via -r requirements-dev.in
|
||||
cycler==0.11.0
|
||||
# via matplotlib
|
||||
diffusers==0.4.1
|
||||
diffusers==0.5.1
|
||||
# via imaginAIry (setup.py)
|
||||
dill==0.3.5.1
|
||||
# via pylint
|
||||
@ -102,7 +102,7 @@ grpcio==1.49.1
|
||||
# via
|
||||
# tb-nightly
|
||||
# tensorboard
|
||||
huggingface-hub==0.10.0
|
||||
huggingface-hub==0.10.1
|
||||
# via
|
||||
# diffusers
|
||||
# timm
|
||||
@ -159,9 +159,9 @@ mypy-extensions==0.4.3
|
||||
# typing-inspect
|
||||
networkx==2.8.7
|
||||
# via scikit-image
|
||||
numba==0.56.2
|
||||
numba==0.56.3
|
||||
# via facexlib
|
||||
numpy==1.23.3
|
||||
numpy==1.23.4
|
||||
# via
|
||||
# basicsr
|
||||
# contourpy
|
||||
@ -303,7 +303,7 @@ requests==2.28.1
|
||||
# transformers
|
||||
requests-oauthlib==1.3.1
|
||||
# via google-auth-oauthlib
|
||||
responses==0.21.0
|
||||
responses==0.22.0
|
||||
# via -r requirements-dev.in
|
||||
rsa==4.9
|
||||
# via google-auth
|
||||
@ -324,7 +324,7 @@ six==1.16.0
|
||||
# python-dateutil
|
||||
snowballstemmer==2.2.0
|
||||
# via pydocstyle
|
||||
tb-nightly==2.11.0a20221010
|
||||
tb-nightly==2.11.0a20221016
|
||||
# via
|
||||
# basicsr
|
||||
# gfpgan
|
||||
@ -338,12 +338,14 @@ tensorboard-plugin-wit==1.8.1
|
||||
# via
|
||||
# tb-nightly
|
||||
# tensorboard
|
||||
tifffile==2022.8.12
|
||||
tifffile==2022.10.10
|
||||
# via scikit-image
|
||||
timm==0.6.11
|
||||
# via imaginAIry (setup.py)
|
||||
tokenizers==0.12.1
|
||||
# via transformers
|
||||
toml==0.10.2
|
||||
# via responses
|
||||
tomli==2.0.1
|
||||
# via
|
||||
# black
|
||||
@ -395,6 +397,8 @@ transformers==4.19.2
|
||||
# via imaginAIry (setup.py)
|
||||
typer==0.6.1
|
||||
# via pycln
|
||||
types-toml==0.10.8
|
||||
# via responses
|
||||
typing-extensions==4.4.0
|
||||
# via
|
||||
# huggingface-hub
|
||||
|
@ -1,6 +1,7 @@
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from shutil import rmtree
|
||||
|
||||
import pytest
|
||||
import responses
|
||||
@ -8,6 +9,7 @@ from urllib3 import HTTPConnectionPool
|
||||
|
||||
from imaginairy import api
|
||||
from imaginairy.log_utils import suppress_annoying_logs_and_warnings
|
||||
from imaginairy.samplers.base import SAMPLER_TYPE_OPTIONS
|
||||
from imaginairy.utils import (
|
||||
fix_torch_group_norm,
|
||||
fix_torch_nn_layer_norm,
|
||||
@ -21,29 +23,47 @@ if "pytest" in str(sys.argv):
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SAMPLERS_FOR_TESTING = SAMPLER_TYPE_OPTIONS
|
||||
if get_device() == "mps:0":
|
||||
SAMPLERS_FOR_TESTING = ["plms", "k_euler_a"]
|
||||
elif get_device() == "cpu":
|
||||
SAMPLERS_FOR_TESTING = []
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def pre_setup():
|
||||
api.IMAGINAIRY_SAFETY_MODE = "disabled"
|
||||
suppress_annoying_logs_and_warnings()
|
||||
# test_output_folder = f"{TESTS_FOLDER}/test_output"
|
||||
test_output_folder = f"{TESTS_FOLDER}/test_output"
|
||||
|
||||
# delete the testoutput folder and recreate it
|
||||
# rmtree(test_output_folder)
|
||||
os.makedirs(f"{TESTS_FOLDER}/test_output", exist_ok=True)
|
||||
try:
|
||||
rmtree(test_output_folder)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
os.makedirs(test_output_folder, exist_ok=True)
|
||||
|
||||
orig_urlopen = HTTPConnectionPool.urlopen
|
||||
|
||||
def urlopen_tattle(self, method, url, *args, **kwargs):
|
||||
# traceback.print_stack()
|
||||
print(os.environ.get("PYTEST_CURRENT_TEST"))
|
||||
print(f"{method} {self.host}{url}")
|
||||
# current_test = os.environ.get("PYTEST_CURRENT_TEST", "")
|
||||
# print(f"{current_test} {method} {self.host}{url}")
|
||||
result = orig_urlopen(self, method, url, *args, **kwargs)
|
||||
print(f"{method} {self.host}{url} DONE")
|
||||
|
||||
# raise HTTPError("NO NETWORK CALLS")
|
||||
return result
|
||||
|
||||
HTTPConnectionPool.urlopen = urlopen_tattle
|
||||
# tqdm.__init__ = partialmethod(tqdm.__init__, disable=True)
|
||||
|
||||
# real_randn = torch.randn
|
||||
# def randn_tattle(*args, **kwargs):
|
||||
# print("RANDN CALL RANDN CALL")
|
||||
# traceback.print_stack()
|
||||
# return real_randn(*args, **kwargs)
|
||||
#
|
||||
# torch.randn = randn_tattle
|
||||
|
||||
with fix_torch_nn_layer_norm(), fix_torch_group_norm(), platform_appropriate_autocast():
|
||||
yield
|
||||
@ -56,11 +76,59 @@ def reset_get_device():
|
||||
|
||||
@pytest.fixture()
|
||||
def filename_base_for_outputs(request):
|
||||
filename_base = f"{TESTS_FOLDER}/test_output/{request.node.name}_{get_device()}_"
|
||||
filename_base = f"{TESTS_FOLDER}/test_output/{request.node.name}_"
|
||||
return filename_base
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def filename_base_for_orig_outputs(request):
|
||||
filename_base = f"{TESTS_FOLDER}/test_output/{request.node.originalname}_"
|
||||
return filename_base
|
||||
|
||||
|
||||
@pytest.fixture(params=SAMPLERS_FOR_TESTING)
|
||||
def sampler_type(request):
|
||||
|
||||
return request.param
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mocked_responses():
|
||||
with responses.RequestsMock() as rsps:
|
||||
yield rsps
|
||||
|
||||
|
||||
def pytest_addoption(parser):
|
||||
parser.addoption(
|
||||
"--subset",
|
||||
action="store",
|
||||
default=None,
|
||||
help="Runs an exclusive subset of tests: '1/3', '2/3', '3/3'. Useful for distributed testing",
|
||||
)
|
||||
|
||||
|
||||
@pytest.hookimpl()
|
||||
def pytest_collection_modifyitems(config, items):
|
||||
"""Only select a subset of tests to run, based on the --subset option."""
|
||||
filtered_node_ids = set()
|
||||
node_ids = [f.nodeid for f in items]
|
||||
node_ids.sort()
|
||||
subset = config.getoption("--subset")
|
||||
if subset:
|
||||
partition_no, total_partitions = subset.split("/")
|
||||
partition_no, total_partitions = int(partition_no), int(total_partitions)
|
||||
if partition_no < 1 or partition_no > total_partitions:
|
||||
raise ValueError("Invalid subset")
|
||||
for i, node_id in enumerate(node_ids):
|
||||
if i % total_partitions == partition_no - 1:
|
||||
filtered_node_ids.add(node_id)
|
||||
|
||||
items[:] = [i for i in items if i.nodeid in filtered_node_ids]
|
||||
|
||||
print(
|
||||
f"Running subset {partition_no}/{total_partitions} {len(filtered_node_ids)} tests:"
|
||||
)
|
||||
filtered_node_ids = list(filtered_node_ids)
|
||||
filtered_node_ids.sort()
|
||||
for n in filtered_node_ids:
|
||||
print(f" {n}")
|
||||
|
After Width: | Height: | Size: 66 KiB |
BIN
tests/expected_output/test_cliptext_inpainting_pearl_doctor_.png
Normal file
After Width: | Height: | Size: 332 KiB |
BIN
tests/expected_output/test_fix_faces_.png
Normal file
After Width: | Height: | Size: 279 KiB |
BIN
tests/expected_output/test_imagine[ddim]_.png
Normal file
After Width: | Height: | Size: 572 KiB |
BIN
tests/expected_output/test_imagine[k_dpm_2]_.png
Normal file
After Width: | Height: | Size: 584 KiB |
BIN
tests/expected_output/test_imagine[k_dpm_2_a]_.png
Normal file
After Width: | Height: | Size: 557 KiB |
BIN
tests/expected_output/test_imagine[k_euler]_.png
Normal file
After Width: | Height: | Size: 601 KiB |
BIN
tests/expected_output/test_imagine[k_euler_a]_.png
Normal file
After Width: | Height: | Size: 581 KiB |
BIN
tests/expected_output/test_imagine[k_heun]_.png
Normal file
After Width: | Height: | Size: 582 KiB |
BIN
tests/expected_output/test_imagine[k_lms]_.png
Normal file
After Width: | Height: | Size: 595 KiB |
BIN
tests/expected_output/test_imagine[plms]_.png
Normal file
After Width: | Height: | Size: 539 KiB |
BIN
tests/expected_output/test_img2img_beach_to_sunset[ddim]_.png
Normal file
After Width: | Height: | Size: 257 KiB |
BIN
tests/expected_output/test_img2img_beach_to_sunset[k_dpm_2]_.png
Normal file
After Width: | Height: | Size: 238 KiB |
After Width: | Height: | Size: 260 KiB |
BIN
tests/expected_output/test_img2img_beach_to_sunset[k_euler]_.png
Normal file
After Width: | Height: | Size: 233 KiB |
After Width: | Height: | Size: 252 KiB |
BIN
tests/expected_output/test_img2img_beach_to_sunset[k_heun]_.png
Normal file
After Width: | Height: | Size: 235 KiB |
BIN
tests/expected_output/test_img2img_beach_to_sunset[k_lms]_.png
Normal file
After Width: | Height: | Size: 236 KiB |
BIN
tests/expected_output/test_img2img_beach_to_sunset[plms]_.png
Normal file
After Width: | Height: | Size: 265 KiB |
BIN
tests/expected_output/test_img_to_img_from_url_cats[ddim]_.png
Normal file
After Width: | Height: | Size: 390 KiB |
After Width: | Height: | Size: 346 KiB |
After Width: | Height: | Size: 430 KiB |
After Width: | Height: | Size: 332 KiB |
After Width: | Height: | Size: 430 KiB |
BIN
tests/expected_output/test_img_to_img_from_url_cats[k_heun]_.png
Normal file
After Width: | Height: | Size: 337 KiB |
BIN
tests/expected_output/test_img_to_img_from_url_cats[k_lms]_.png
Normal file
After Width: | Height: | Size: 339 KiB |
BIN
tests/expected_output/test_img_to_img_from_url_cats[plms]_.png
Normal file
After Width: | Height: | Size: 393 KiB |
After Width: | Height: | Size: 262 KiB |
After Width: | Height: | Size: 264 KiB |
BIN
tests/expected_output/test_img_to_img_fruit_2_gold[ddim-0]_.png
Normal file
After Width: | Height: | Size: 230 KiB |
BIN
tests/expected_output/test_img_to_img_fruit_2_gold[ddim-1]_.png
Normal file
After Width: | Height: | Size: 241 KiB |
After Width: | Height: | Size: 264 KiB |
After Width: | Height: | Size: 264 KiB |
After Width: | Height: | Size: 230 KiB |
After Width: | Height: | Size: 240 KiB |
After Width: | Height: | Size: 260 KiB |
After Width: | Height: | Size: 265 KiB |
After Width: | Height: | Size: 238 KiB |
After Width: | Height: | Size: 240 KiB |
After Width: | Height: | Size: 266 KiB |
After Width: | Height: | Size: 265 KiB |
After Width: | Height: | Size: 228 KiB |
After Width: | Height: | Size: 240 KiB |
After Width: | Height: | Size: 260 KiB |
After Width: | Height: | Size: 267 KiB |
After Width: | Height: | Size: 245 KiB |
After Width: | Height: | Size: 240 KiB |
After Width: | Height: | Size: 266 KiB |
After Width: | Height: | Size: 265 KiB |
After Width: | Height: | Size: 228 KiB |
After Width: | Height: | Size: 240 KiB |
After Width: | Height: | Size: 266 KiB |
After Width: | Height: | Size: 265 KiB |
BIN
tests/expected_output/test_img_to_img_fruit_2_gold[k_lms-0]_.png
Normal file
After Width: | Height: | Size: 228 KiB |
BIN
tests/expected_output/test_img_to_img_fruit_2_gold[k_lms-1]_.png
Normal file
After Width: | Height: | Size: 240 KiB |
After Width: | Height: | Size: 255 KiB |
After Width: | Height: | Size: 263 KiB |
BIN
tests/expected_output/test_img_to_img_fruit_2_gold[plms-0]_.png
Normal file
After Width: | Height: | Size: 234 KiB |
BIN
tests/expected_output/test_img_to_img_fruit_2_gold[plms-1]_.png
Normal file
After Width: | Height: | Size: 241 KiB |
BIN
tests/expected_output/test_inpainting_bench_.png
Normal file
After Width: | Height: | Size: 556 KiB |
After Width: | Height: | Size: 556 KiB |
@ -3,84 +3,28 @@ import os.path
|
||||
import pytest
|
||||
|
||||
from imaginairy import LazyLoadingImage
|
||||
from imaginairy.api import imagine, imagine_image_files, prompt_normalized
|
||||
from imaginairy.api import imagine, imagine_image_files
|
||||
from imaginairy.img_utils import pillow_fit_image_within
|
||||
from imaginairy.samplers.base import SAMPLER_TYPE_OPTIONS
|
||||
from imaginairy.schema import ImaginePrompt
|
||||
from imaginairy.utils import get_device
|
||||
|
||||
from . import TESTS_FOLDER
|
||||
|
||||
device_sampler_type_test_cases = {
|
||||
"mps:0": [
|
||||
("plms", "78539ae3a3097dc8232da6d630551ab3"),
|
||||
(
|
||||
"ddim",
|
||||
("828fc143cd40586347b2f8403c288c9b", "4c7905d4a36f6f9c456b7e074b52707e"),
|
||||
),
|
||||
("k_lms", "53d25e59add39c8447537be30e4eff4b"),
|
||||
("k_dpm_2", "5108bceb58a38d88a585f37b2ba1b072"),
|
||||
("k_dpm_2_a", "20396daa6c920d1cfd6db90e73558c01"),
|
||||
("k_euler", "9ab4666ebe6c3aa68673912bb17fb2b1"),
|
||||
("k_euler_a", "c4b03829cc93422801f3243a46bad4bc"),
|
||||
("k_heun", "0d3aad6800d4a9a43f0b0514af9d23b5"),
|
||||
],
|
||||
"cuda": [
|
||||
("plms", "b98e1248ad1f144d34122d8809b39fb8"),
|
||||
("ddim", "a645ca24575ed3f18bf48f11354233bb"),
|
||||
("k_lms", "3ddbdef45e3f38768730961771d01727"),
|
||||
("k_dpm_2", "b6e88e16ec2c43e6382b1adec828479d"),
|
||||
("k_dpm_2_a", "b0791770d48cb22d308ad76c72fb660f"),
|
||||
("k_euler", "bcf375769d64d9ca224864d35565ac1d"),
|
||||
("k_euler_a", "38b970ff6a67428efbf00df66a9e48f7"),
|
||||
("k_heun", "ccbd0804c7ce2bb637c682951bd8b693"),
|
||||
],
|
||||
"cpu": [],
|
||||
}
|
||||
sampler_type_test_cases = device_sampler_type_test_cases[get_device()]
|
||||
from .utils import assert_image_similar_to_expectation
|
||||
|
||||
|
||||
@pytest.mark.parametrize("sampler_type,expected_md5", sampler_type_test_cases)
|
||||
def test_imagine(sampler_type, expected_md5, filename_base_for_outputs):
|
||||
prompt_text = "a scenic landscape"
|
||||
def test_imagine(sampler_type, filename_base_for_outputs):
|
||||
prompt_text = "a scenic old-growth forest with diffuse light poking through the canopy. high resolution nature photography"
|
||||
prompt = ImaginePrompt(
|
||||
prompt_text, width=512, height=256, steps=20, seed=1, sampler_type=sampler_type
|
||||
prompt_text, width=512, height=512, steps=20, seed=1, sampler_type=sampler_type
|
||||
)
|
||||
result = next(imagine(prompt))
|
||||
result.img.save(f"{filename_base_for_outputs}.jpg")
|
||||
assert result.md5() in expected_md5
|
||||
img_path = f"{filename_base_for_outputs}.png"
|
||||
assert_image_similar_to_expectation(result.img, img_path=img_path, threshold=2800)
|
||||
|
||||
|
||||
device_sampler_type_test_cases_img_2_img = {
|
||||
"mps:0": {
|
||||
("plms", "0d9c40c348cdac7bdc8d5a472f378f42"),
|
||||
("ddim", "0d9c40c348cdac7bdc8d5a472f378f42"),
|
||||
("k_lms", ""),
|
||||
("k_dpm_2", ""),
|
||||
("k_dpm_2_a", ""),
|
||||
("k_euler", ""),
|
||||
("k_euler_a", ""),
|
||||
("k_heun", ""),
|
||||
},
|
||||
"cuda": {
|
||||
("plms", "841723966344dd8678aee1ce5f9cbb3d"),
|
||||
("ddim", "1f0d72370fabcf2ff716e4068d5b2360"),
|
||||
("k_lms", ""),
|
||||
("k_dpm_2", ""),
|
||||
("k_dpm_2_a", ""),
|
||||
("k_euler", ""),
|
||||
("k_euler_a", ""),
|
||||
("k_heun", ""),
|
||||
},
|
||||
}
|
||||
sampler_type_test_cases_img_2_img = device_sampler_type_test_cases_img_2_img.get(
|
||||
get_device(), []
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(get_device() == "cpu", reason="Too slow to run on CPU")
|
||||
@pytest.mark.parametrize("sampler_type,expected_md5", sampler_type_test_cases_img_2_img)
|
||||
def test_img2img_beach_to_sunset(sampler_type, expected_md5, filename_base_for_outputs):
|
||||
def test_img2img_beach_to_sunset(
|
||||
sampler_type, filename_base_for_outputs, filename_base_for_orig_outputs
|
||||
):
|
||||
img = LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/beach_at_sainte_adresse.jpg")
|
||||
prompt = ImaginePrompt(
|
||||
"a painting of beautiful cloudy sunset at the beach",
|
||||
@ -96,48 +40,17 @@ def test_img2img_beach_to_sunset(sampler_type, expected_md5, filename_base_for_o
|
||||
sampler_type=sampler_type,
|
||||
)
|
||||
result = next(imagine(prompt))
|
||||
img = pillow_fit_image_within(img)
|
||||
img.save(f"{filename_base_for_outputs}__orig.jpg")
|
||||
result.img.save(f"{filename_base_for_outputs}.jpg")
|
||||
|
||||
pillow_fit_image_within(img).save(f"{filename_base_for_orig_outputs}__orig.jpg")
|
||||
img_path = f"{filename_base_for_outputs}.png"
|
||||
assert_image_similar_to_expectation(result.img, img_path=img_path, threshold=2800)
|
||||
|
||||
|
||||
device_sampler_type_test_cases_img_2_img = {
|
||||
"mps:0": {
|
||||
(
|
||||
"plms",
|
||||
("e9bb714771f7984e61debabc4bb3cd22", "af344c404de70da5db519869f8fcd0c1"),
|
||||
),
|
||||
(
|
||||
"ddim",
|
||||
("62bacc4ae391e6775a3723c88738ec61", "5f0d2ee426e1bb6ccc1d57dfdd8c73bf"),
|
||||
),
|
||||
("k_lms", tuple()),
|
||||
("k_dpm_2", tuple()),
|
||||
("k_dpm_2_a", tuple()),
|
||||
("k_euler", tuple()),
|
||||
("k_euler_a", tuple()),
|
||||
("k_heun", tuple()),
|
||||
},
|
||||
"cuda": {
|
||||
("plms", ("b8c7b52da977c1531a9a61c0a082404c",)),
|
||||
("ddim", ("d6784710dd78e4cb628aba28322b04cf",)),
|
||||
("k_lms", ("3246b588155f430a79d08a0b1c7287f5",)),
|
||||
("k_dpm_2", ("724fa459adec6a7b3ebb523263dd5176",)),
|
||||
("k_dpm_2_a", ("5c36fa9c051db80e3969c63d500340f4",)),
|
||||
("k_euler", ("d6800b8a3e31f81fb3902d34ee786b33",)),
|
||||
("k_euler_a", ("6477863f35d0c9032b959a9cc7a0b61c",)),
|
||||
("k_heun", ("1ed62ad0cfd03dba8b487a36259833a3",)),
|
||||
},
|
||||
}
|
||||
sampler_type_test_cases_img_2_img = device_sampler_type_test_cases_img_2_img.get(
|
||||
get_device(), []
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(get_device() == "cpu", reason="Too slow to run on CPU")
|
||||
@pytest.mark.parametrize("sampler_type,expected_md5", sampler_type_test_cases_img_2_img)
|
||||
def test_img_to_img_from_url_cats(
|
||||
sampler_type, expected_md5, filename_base_for_outputs, mocked_responses
|
||||
sampler_type,
|
||||
filename_base_for_outputs,
|
||||
mocked_responses,
|
||||
filename_base_for_orig_outputs,
|
||||
):
|
||||
with open(
|
||||
os.path.join(TESTS_FOLDER, "data", "val2017-000000039769-cococats.jpg"), "rb"
|
||||
@ -165,17 +78,17 @@ def test_img_to_img_from_url_cats(
|
||||
result = next(imagine(prompt))
|
||||
|
||||
img = pillow_fit_image_within(img)
|
||||
img.save(f"{filename_base_for_outputs}__orig.jpg")
|
||||
result.img.save(f"{filename_base_for_outputs}.jpg")
|
||||
|
||||
assert result.md5() in expected_md5
|
||||
img.save(f"{filename_base_for_orig_outputs}__orig.jpg")
|
||||
img_path = f"{filename_base_for_outputs}.png"
|
||||
assert_image_similar_to_expectation(result.img, img_path=img_path, threshold=12000)
|
||||
|
||||
|
||||
@pytest.mark.skipif(get_device() == "cpu", reason="Too slow to run on CPU")
|
||||
@pytest.mark.parametrize("sampler_type", SAMPLER_TYPE_OPTIONS)
|
||||
@pytest.mark.parametrize("init_strength", [0, 0.05, 0.2, 1])
|
||||
def test_img_to_img_fruit_2_gold(
|
||||
filename_base_for_outputs, sampler_type, init_strength
|
||||
filename_base_for_outputs,
|
||||
sampler_type,
|
||||
init_strength,
|
||||
filename_base_for_orig_outputs,
|
||||
):
|
||||
img = LazyLoadingImage(
|
||||
filepath=os.path.join(TESTS_FOLDER, "data", "bowl_of_fruit.jpg")
|
||||
@ -195,29 +108,16 @@ def test_img_to_img_fruit_2_gold(
|
||||
|
||||
result = next(imagine(prompt))
|
||||
|
||||
img = pillow_fit_image_within(img)
|
||||
img.save(f"{filename_base_for_outputs}__orig.jpg")
|
||||
result.img.save(f"{filename_base_for_outputs}.jpg")
|
||||
pillow_fit_image_within(img).save(f"{filename_base_for_orig_outputs}__orig.jpg")
|
||||
img_path = f"{filename_base_for_outputs}.png"
|
||||
assert_image_similar_to_expectation(result.img, img_path=img_path, threshold=9000)
|
||||
|
||||
|
||||
@pytest.mark.skipif(get_device() == "cpu", reason="Too slow to run on CPU")
|
||||
def test_img_to_img_fruit_2_gold_repeat():
|
||||
"""Run this test manually to"""
|
||||
img = LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/bowl_of_fruit.jpg")
|
||||
outdir = f"{TESTS_FOLDER}/test_output/"
|
||||
run_count = 1
|
||||
|
||||
def _record_step(img, description, step_count, prompt):
|
||||
steps_path = os.path.join(
|
||||
outdir,
|
||||
f"steps_fruit_2_gold_repeat_{get_device()}_S{prompt.seed}_run_{run_count:02}",
|
||||
)
|
||||
os.makedirs(steps_path, exist_ok=True)
|
||||
filename = f"fruit_2_gold_repeat_{get_device()}_S{prompt.seed}_step{step_count:04}_{prompt_normalized(description)[:40]}.jpg"
|
||||
|
||||
destination = os.path.join(steps_path, filename)
|
||||
img.save(destination)
|
||||
|
||||
kwargs = dict(
|
||||
prompt="a white bowl filled with gold coins. sharp focus",
|
||||
prompt_strength=12,
|
||||
@ -237,8 +137,6 @@ def test_img_to_img_fruit_2_gold_repeat():
|
||||
ImaginePrompt(**kwargs),
|
||||
]
|
||||
for result in imagine(prompts, img_callback=None):
|
||||
img = pillow_fit_image_within(img)
|
||||
img.save(f"{TESTS_FOLDER}/test_output/img2img_fruit_2_gold__orig.jpg")
|
||||
result.img.save(
|
||||
f"{TESTS_FOLDER}/test_output/img2img_fruit_2_gold_plms_{get_device()}_run-{run_count:02}.jpg"
|
||||
)
|
||||
@ -261,7 +159,7 @@ def test_img_to_file():
|
||||
|
||||
|
||||
@pytest.mark.skipif(get_device() == "cpu", reason="Too slow to run on CPU")
|
||||
def test_inpainting_bench(filename_base_for_outputs):
|
||||
def test_inpainting_bench(filename_base_for_outputs, filename_base_for_orig_outputs):
|
||||
img = LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/bench2.png")
|
||||
prompt = ImaginePrompt(
|
||||
"a wise old man",
|
||||
@ -276,13 +174,15 @@ def test_inpainting_bench(filename_base_for_outputs):
|
||||
)
|
||||
result = next(imagine(prompt))
|
||||
|
||||
img = pillow_fit_image_within(img)
|
||||
img.save(f"{filename_base_for_outputs}__orig.jpg")
|
||||
result.img.save(f"{filename_base_for_outputs}.jpg")
|
||||
pillow_fit_image_within(img).save(f"{filename_base_for_orig_outputs}_orig.jpg")
|
||||
img_path = f"{filename_base_for_outputs}.png"
|
||||
assert_image_similar_to_expectation(result.img, img_path=img_path, threshold=2800)
|
||||
|
||||
|
||||
@pytest.mark.skipif(get_device() == "cpu", reason="Too slow to run on CPU")
|
||||
def test_cliptext_inpainting_pearl_doctor(filename_base_for_outputs):
|
||||
def test_cliptext_inpainting_pearl_doctor(
|
||||
filename_base_for_outputs, filename_base_for_orig_outputs
|
||||
):
|
||||
img = LazyLoadingImage(
|
||||
filepath=f"{TESTS_FOLDER}/data/girl_with_a_pearl_earring.jpg"
|
||||
)
|
||||
@ -301,6 +201,6 @@ def test_cliptext_inpainting_pearl_doctor(filename_base_for_outputs):
|
||||
)
|
||||
result = next(imagine(prompt))
|
||||
|
||||
img = pillow_fit_image_within(img)
|
||||
img.save(f"{filename_base_for_outputs}__orig.jpg")
|
||||
result.img.save(f"{filename_base_for_outputs}_{prompt.seed}_01.jpg")
|
||||
pillow_fit_image_within(img).save(f"{filename_base_for_orig_outputs}_orig.jpg")
|
||||
img_path = f"{filename_base_for_outputs}.png"
|
||||
assert_image_similar_to_expectation(result.img, img_path=img_path, threshold=2800)
|
||||
|
@ -1,5 +1,3 @@
|
||||
import hashlib
|
||||
|
||||
import pytest
|
||||
from PIL import Image
|
||||
from pytorch_lightning import seed_everything
|
||||
@ -12,28 +10,20 @@ from imaginairy.enhancers.describe_image_clip import find_img_text_similarity
|
||||
from imaginairy.enhancers.face_restoration_codeformer import enhance_faces
|
||||
from imaginairy.utils import get_device
|
||||
from tests import TESTS_FOLDER
|
||||
from tests.utils import assert_image_similar_to_expectation
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
get_device() == "cpu", reason="TypeError: Got unsupported ScalarType BFloat16"
|
||||
)
|
||||
def test_fix_faces():
|
||||
img = Image.open(f"{TESTS_FOLDER}/data/distorted_face.png")
|
||||
def test_fix_faces(filename_base_for_orig_outputs, filename_base_for_outputs):
|
||||
distorted_img = Image.open(f"{TESTS_FOLDER}/data/distorted_face.png")
|
||||
seed_everything(1)
|
||||
img = enhance_faces(img)
|
||||
img.save(f"{TESTS_FOLDER}/test_output/fixed_face.png")
|
||||
if "mps" in get_device():
|
||||
assert img_hash(img) == "a75991307eda675a26eeb7073f828e93"
|
||||
else:
|
||||
# probably different based on whether first run or not. looks the same either way
|
||||
assert img_hash(img) in [
|
||||
"c840cf3bfe5a7760734f425a3f8941cf",
|
||||
"e56c1205bbc8f251be05773f2ba7fa24",
|
||||
]
|
||||
img = enhance_faces(distorted_img)
|
||||
|
||||
|
||||
def img_hash(img):
|
||||
return hashlib.md5(img.tobytes()).hexdigest()
|
||||
distorted_img.save(f"{filename_base_for_orig_outputs}__orig.jpg")
|
||||
img_path = f"{filename_base_for_outputs}.png"
|
||||
assert_image_similar_to_expectation(img, img_path=img_path, threshold=2800)
|
||||
|
||||
|
||||
@pytest.mark.skipif(get_device() == "cpu", reason="Too slow to run on CPU")
|
||||
@ -58,13 +48,13 @@ def test_clip_masking():
|
||||
)
|
||||
|
||||
prompt = ImaginePrompt(
|
||||
"a female firefighter in front of a burning building",
|
||||
"",
|
||||
init_image=img,
|
||||
init_image_strength=0.95,
|
||||
init_image_strength=0.5,
|
||||
# lower steps for faster tests
|
||||
steps=40,
|
||||
mask_prompt="(head OR face){*5}",
|
||||
mask_mode="replace",
|
||||
mask_mode="keep",
|
||||
upscale=False,
|
||||
fix_faces=True,
|
||||
)
|
||||
@ -72,7 +62,7 @@ def test_clip_masking():
|
||||
result = next(imagine(prompt))
|
||||
result.save(
|
||||
f"{TESTS_FOLDER}/test_output/earring_mask_photo.png",
|
||||
image_type="modified_original",
|
||||
image_type="generated",
|
||||
)
|
||||
|
||||
|
||||
|
25
tests/utils.py
Normal file
@ -0,0 +1,25 @@
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def assert_image_similar_to_expectation(img, img_path, threshold=100):
|
||||
img.save(img_path)
|
||||
expected_img_path = img_path.replace("/test_output/", "/expected_output/")
|
||||
expected_img = Image.open(expected_img_path)
|
||||
norm_sum_sq_diff = calc_norm_sum_sq_diff(img, expected_img)
|
||||
|
||||
if norm_sum_sq_diff > threshold:
|
||||
diff_img = Image.fromarray(np.asarray(img) - np.asarray(expected_img))
|
||||
diff_img.save(img_path + f"_diff_{norm_sum_sq_diff:.1f}.png")
|
||||
expected_img.save(img_path + "_expected.png")
|
||||
assert (
|
||||
norm_sum_sq_diff < threshold
|
||||
), f"{norm_sum_sq_diff:.3f} is bigger than threshold {threshold}"
|
||||
|
||||
|
||||
def calc_norm_sum_sq_diff(img, img2):
|
||||
sum_sq_diff = np.sum(
|
||||
(np.asarray(img).astype("float") - np.asarray(img2).astype("float")) ** 2
|
||||
)
|
||||
norm_sum_sq_diff = sum_sq_diff / np.sqrt(sum_sq_diff)
|
||||
return norm_sum_sq_diff
|
2
tox.ini
@ -8,7 +8,7 @@ filterwarnings =
|
||||
|
||||
[pylama]
|
||||
format = pylint
|
||||
skip = */.tox/*,*/.env/*,build/*,*/downloads/*,other/*,prolly_delete/*,downloads/*,imaginairy/vendored/*
|
||||
skip = */.tox/*,*/.env/*,build/*,*/downloads/*,other/*,prolly_delete/*,downloads/*,imaginairy/vendored/*,testing_support/vastai_cli_official.py
|
||||
linters = pylint,pycodestyle,pydocstyle,pyflakes,mypy
|
||||
ignore =
|
||||
Z999,C0103,C0301,C0114,C0115,C0116,
|
||||
|