tests: support distributed test runs

- switch to comparing against expected images instead of exact hashes; allow fuzzy matches
- feature: more consistent seeds across devices
Bryce 2022-10-16 16:42:46 -07:00 committed by Bryce Drennan
parent 4ba1965db8
commit dcf953383e
73 changed files with 213 additions and 182 deletions


@@ -71,4 +71,4 @@ jobs:
- name: Test with pytest
timeout-minutes: 10
run: |
pytest
pytest --durations=50 -v

.gitignore (vendored): 6 lines changed

@@ -17,4 +17,8 @@ dist
tests/test_output
gfpgan/**
.python-version
._.DS_Store
._.DS_Store
tests/vastai_cli.py
/tests/test_output_local_cuda/
/testing_support/
.unison*


@@ -33,6 +33,7 @@ from imaginairy.utils import (
get_device,
instantiate_from_config,
platform_appropriate_autocast,
randn_seeded,
)
LIB_PATH = os.path.dirname(__file__)
@@ -65,7 +66,18 @@ def load_model_from_config(
else:
ckpt_path = model_weights_location
logger.info(f"Loading model {ckpt_path} onto {get_device()} backend...")
pl_sd = torch.load(ckpt_path, map_location="cpu")
pl_sd = None
try:
pl_sd = torch.load(ckpt_path, map_location="cpu")
except RuntimeError as e:
if "PytorchStreamReader failed reading zip archive" in str(e):
if model_weights_location.startswith("http"):
logger.warning("Corrupt checkpoint. deleting and re-downloading...")
os.remove(ckpt_path)
ckpt_path = cached_path(model_weights_location)
pl_sd = torch.load(ckpt_path, map_location="cpu")
if pl_sd is None:
raise e
if "global_step" in pl_sd:
logger.debug(f"Global Step: {pl_sd['global_step']}")
sd = pl_sd["state_dict"]
@@ -266,7 +278,8 @@ def imagine(
log_latent(init_latent, "init_latent")
# encode (scaled latent)
seed_everything(prompt.seed)
noise = torch.randn_like(init_latent, device="cpu").to(get_device())
noise = randn_seeded(seed=prompt.seed, size=init_latent.size())
noise = noise.to(get_device())
schedule = NoiseSchedule(
model_num_timesteps=model.num_timesteps,
@@ -280,14 +293,14 @@
# (or setting steps=1000)
init_latent_noised = noise
else:
init_latent_noised = noise_an_image(
init_latent,
torch.tensor([t_enc - 1]).to(get_device()),
schedule=schedule,
noise=noise,
)
log_latent(init_latent_noised, "init_latent_noised")
log_latent(init_latent_noised, "init_latent_noised")
samples = sampler.sample(
num_steps=prompt.steps,


@@ -172,6 +172,9 @@ def disable_transformers_custom_logging():
_logger.handlers = []
_logger.propagate = True
_logger.setLevel(logging.NOTSET)
modeling_logger.handlers = []
modeling_logger.propagate = True
modeling_logger.setLevel(logging.ERROR)
def disable_pytorch_lighting_custom_logging():


@@ -153,7 +153,7 @@ def mask_blend(noisy_latent, orig_latent, mask, mask_noise, ts, model):
hint_strength = 1
# if we're in the first 10% of the steps then don't fully noise the parts
# of the image we're not changing so that the algorithm can learn from the context
if ts > 900:
if ts > 1000:
hinted_orig_latent = (
noised_orig_latent * (1 - hint_strength) + orig_latent * hint_strength
)


@@ -186,3 +186,27 @@ def get_cached_url_path(url):
with open(dest_path, "wb") as f:
f.write(r.content)
return dest_path
def randn_seeded(seed: int, size: List[int]) -> Tensor:
"""Generate a random tensor with a given seed"""
g_cpu = torch.Generator()
g_cpu.manual_seed(seed)
noise = torch.randn(
size,
device="cpu",
generator=g_cpu,
)
return noise
def check_torch_working():
"""Check that torch is working"""
try:
torch.randn(1, device=get_device())
except RuntimeError as e:
if "CUDA" in str(e):
raise RuntimeError(
"CUDA is not working. Make sure you have a GPU and CUDA installed."
) from e
raise e
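
Illustrative sketch (not part of the diff): with randn_seeded, the same seed should produce identical noise no matter which backend the tensor is later moved to, because sampling always happens on a seeded CPU generator, whereas the old torch.randn_like path depended on the device RNG. A minimal check, assuming the imports shown in this commit:

import torch
from imaginairy.utils import get_device, randn_seeded

noise_a = randn_seeded(seed=1, size=[1, 4, 64, 64])
noise_b = randn_seeded(seed=1, size=[1, 4, 64, 64]).to(get_device())
# moving to cuda/mps and back is a lossless copy, so the values match exactly
assert torch.equal(noise_a, noise_b.cpu())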


@@ -4,7 +4,7 @@
#
# pip-compile --output-file=requirements-dev.txt requirements-dev.in setup.py
#
absl-py==1.2.0
absl-py==1.3.0
# via
# tb-nightly
# tensorboard
@@ -52,7 +52,7 @@ coverage==6.5.0
# via -r requirements-dev.in
cycler==0.11.0
# via matplotlib
diffusers==0.4.1
diffusers==0.5.1
# via imaginAIry (setup.py)
dill==0.3.5.1
# via pylint
@@ -102,7 +102,7 @@ grpcio==1.49.1
# via
# tb-nightly
# tensorboard
huggingface-hub==0.10.0
huggingface-hub==0.10.1
# via
# diffusers
# timm
@@ -159,9 +159,9 @@ mypy-extensions==0.4.3
# typing-inspect
networkx==2.8.7
# via scikit-image
numba==0.56.2
numba==0.56.3
# via facexlib
numpy==1.23.3
numpy==1.23.4
# via
# basicsr
# contourpy
@@ -303,7 +303,7 @@ requests==2.28.1
# transformers
requests-oauthlib==1.3.1
# via google-auth-oauthlib
responses==0.21.0
responses==0.22.0
# via -r requirements-dev.in
rsa==4.9
# via google-auth
@@ -324,7 +324,7 @@ six==1.16.0
# python-dateutil
snowballstemmer==2.2.0
# via pydocstyle
tb-nightly==2.11.0a20221010
tb-nightly==2.11.0a20221016
# via
# basicsr
# gfpgan
@@ -338,12 +338,14 @@ tensorboard-plugin-wit==1.8.1
# via
# tb-nightly
# tensorboard
tifffile==2022.8.12
tifffile==2022.10.10
# via scikit-image
timm==0.6.11
# via imaginAIry (setup.py)
tokenizers==0.12.1
# via transformers
toml==0.10.2
# via responses
tomli==2.0.1
# via
# black
@@ -395,6 +397,8 @@ transformers==4.19.2
# via imaginAIry (setup.py)
typer==0.6.1
# via pycln
types-toml==0.10.8
# via responses
typing-extensions==4.4.0
# via
# huggingface-hub


@@ -1,6 +1,7 @@
import logging
import os
import sys
from shutil import rmtree
import pytest
import responses
@@ -8,6 +9,7 @@ from urllib3 import HTTPConnectionPool
from imaginairy import api
from imaginairy.log_utils import suppress_annoying_logs_and_warnings
from imaginairy.samplers.base import SAMPLER_TYPE_OPTIONS
from imaginairy.utils import (
fix_torch_group_norm,
fix_torch_nn_layer_norm,
@@ -21,29 +23,47 @@ if "pytest" in str(sys.argv):
logger = logging.getLogger(__name__)
SAMPLERS_FOR_TESTING = SAMPLER_TYPE_OPTIONS
if get_device() == "mps:0":
SAMPLERS_FOR_TESTING = ["plms", "k_euler_a"]
elif get_device() == "cpu":
SAMPLERS_FOR_TESTING = []
@pytest.fixture(scope="session", autouse=True)
def pre_setup():
api.IMAGINAIRY_SAFETY_MODE = "disabled"
suppress_annoying_logs_and_warnings()
# test_output_folder = f"{TESTS_FOLDER}/test_output"
test_output_folder = f"{TESTS_FOLDER}/test_output"
# delete the testoutput folder and recreate it
# rmtree(test_output_folder)
os.makedirs(f"{TESTS_FOLDER}/test_output", exist_ok=True)
try:
rmtree(test_output_folder)
except FileNotFoundError:
pass
os.makedirs(test_output_folder, exist_ok=True)
orig_urlopen = HTTPConnectionPool.urlopen
def urlopen_tattle(self, method, url, *args, **kwargs):
# traceback.print_stack()
print(os.environ.get("PYTEST_CURRENT_TEST"))
print(f"{method} {self.host}{url}")
# current_test = os.environ.get("PYTEST_CURRENT_TEST", "")
# print(f"{current_test} {method} {self.host}{url}")
result = orig_urlopen(self, method, url, *args, **kwargs)
print(f"{method} {self.host}{url} DONE")
# raise HTTPError("NO NETWORK CALLS")
return result
HTTPConnectionPool.urlopen = urlopen_tattle
# tqdm.__init__ = partialmethod(tqdm.__init__, disable=True)
# real_randn = torch.randn
# def randn_tattle(*args, **kwargs):
# print("RANDN CALL RANDN CALL")
# traceback.print_stack()
# return real_randn(*args, **kwargs)
#
# torch.randn = randn_tattle
with fix_torch_nn_layer_norm(), fix_torch_group_norm(), platform_appropriate_autocast():
yield
@@ -56,11 +76,59 @@ def reset_get_device():
@pytest.fixture()
def filename_base_for_outputs(request):
filename_base = f"{TESTS_FOLDER}/test_output/{request.node.name}_{get_device()}_"
filename_base = f"{TESTS_FOLDER}/test_output/{request.node.name}_"
return filename_base
@pytest.fixture()
def filename_base_for_orig_outputs(request):
filename_base = f"{TESTS_FOLDER}/test_output/{request.node.originalname}_"
return filename_base
@pytest.fixture(params=SAMPLERS_FOR_TESTING)
def sampler_type(request):
return request.param
@pytest.fixture
def mocked_responses():
with responses.RequestsMock() as rsps:
yield rsps
def pytest_addoption(parser):
parser.addoption(
"--subset",
action="store",
default=None,
help="Runs an exclusive subset of tests: '1/3', '2/3', '3/3'. Useful for distributed testing",
)
@pytest.hookimpl()
def pytest_collection_modifyitems(config, items):
"""Only select a subset of tests to run, based on the --subset option."""
filtered_node_ids = set()
node_ids = [f.nodeid for f in items]
node_ids.sort()
subset = config.getoption("--subset")
if subset:
partition_no, total_partitions = subset.split("/")
partition_no, total_partitions = int(partition_no), int(total_partitions)
if partition_no < 1 or partition_no > total_partitions:
raise ValueError("Invalid subset")
for i, node_id in enumerate(node_ids):
if i % total_partitions == partition_no - 1:
filtered_node_ids.add(node_id)
items[:] = [i for i in items if i.nodeid in filtered_node_ids]
print(
f"Running subset {partition_no}/{total_partitions} {len(filtered_node_ids)} tests:"
)
filtered_node_ids = list(filtered_node_ids)
filtered_node_ids.sort()
for n in filtered_node_ids:
print(f" {n}")

[59 binary files not shown: newly added expected-output test images, roughly 230 to 600 KiB each.]


@@ -3,84 +3,28 @@ import os.path
import pytest
from imaginairy import LazyLoadingImage
from imaginairy.api import imagine, imagine_image_files, prompt_normalized
from imaginairy.api import imagine, imagine_image_files
from imaginairy.img_utils import pillow_fit_image_within
from imaginairy.samplers.base import SAMPLER_TYPE_OPTIONS
from imaginairy.schema import ImaginePrompt
from imaginairy.utils import get_device
from . import TESTS_FOLDER
device_sampler_type_test_cases = {
"mps:0": [
("plms", "78539ae3a3097dc8232da6d630551ab3"),
(
"ddim",
("828fc143cd40586347b2f8403c288c9b", "4c7905d4a36f6f9c456b7e074b52707e"),
),
("k_lms", "53d25e59add39c8447537be30e4eff4b"),
("k_dpm_2", "5108bceb58a38d88a585f37b2ba1b072"),
("k_dpm_2_a", "20396daa6c920d1cfd6db90e73558c01"),
("k_euler", "9ab4666ebe6c3aa68673912bb17fb2b1"),
("k_euler_a", "c4b03829cc93422801f3243a46bad4bc"),
("k_heun", "0d3aad6800d4a9a43f0b0514af9d23b5"),
],
"cuda": [
("plms", "b98e1248ad1f144d34122d8809b39fb8"),
("ddim", "a645ca24575ed3f18bf48f11354233bb"),
("k_lms", "3ddbdef45e3f38768730961771d01727"),
("k_dpm_2", "b6e88e16ec2c43e6382b1adec828479d"),
("k_dpm_2_a", "b0791770d48cb22d308ad76c72fb660f"),
("k_euler", "bcf375769d64d9ca224864d35565ac1d"),
("k_euler_a", "38b970ff6a67428efbf00df66a9e48f7"),
("k_heun", "ccbd0804c7ce2bb637c682951bd8b693"),
],
"cpu": [],
}
sampler_type_test_cases = device_sampler_type_test_cases[get_device()]
from .utils import assert_image_similar_to_expectation
@pytest.mark.parametrize("sampler_type,expected_md5", sampler_type_test_cases)
def test_imagine(sampler_type, expected_md5, filename_base_for_outputs):
prompt_text = "a scenic landscape"
def test_imagine(sampler_type, filename_base_for_outputs):
prompt_text = "a scenic old-growth forest with diffuse light poking through the canopy. high resolution nature photography"
prompt = ImaginePrompt(
prompt_text, width=512, height=256, steps=20, seed=1, sampler_type=sampler_type
prompt_text, width=512, height=512, steps=20, seed=1, sampler_type=sampler_type
)
result = next(imagine(prompt))
result.img.save(f"{filename_base_for_outputs}.jpg")
assert result.md5() in expected_md5
img_path = f"{filename_base_for_outputs}.png"
assert_image_similar_to_expectation(result.img, img_path=img_path, threshold=2800)
device_sampler_type_test_cases_img_2_img = {
"mps:0": {
("plms", "0d9c40c348cdac7bdc8d5a472f378f42"),
("ddim", "0d9c40c348cdac7bdc8d5a472f378f42"),
("k_lms", ""),
("k_dpm_2", ""),
("k_dpm_2_a", ""),
("k_euler", ""),
("k_euler_a", ""),
("k_heun", ""),
},
"cuda": {
("plms", "841723966344dd8678aee1ce5f9cbb3d"),
("ddim", "1f0d72370fabcf2ff716e4068d5b2360"),
("k_lms", ""),
("k_dpm_2", ""),
("k_dpm_2_a", ""),
("k_euler", ""),
("k_euler_a", ""),
("k_heun", ""),
},
}
sampler_type_test_cases_img_2_img = device_sampler_type_test_cases_img_2_img.get(
get_device(), []
)
@pytest.mark.skipif(get_device() == "cpu", reason="Too slow to run on CPU")
@pytest.mark.parametrize("sampler_type,expected_md5", sampler_type_test_cases_img_2_img)
def test_img2img_beach_to_sunset(sampler_type, expected_md5, filename_base_for_outputs):
def test_img2img_beach_to_sunset(
sampler_type, filename_base_for_outputs, filename_base_for_orig_outputs
):
img = LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/beach_at_sainte_adresse.jpg")
prompt = ImaginePrompt(
"a painting of beautiful cloudy sunset at the beach",
@@ -96,48 +40,17 @@ def test_img2img_beach_to_sunset(sampler_type, expected_md5, filename_base_for_o
sampler_type=sampler_type,
)
result = next(imagine(prompt))
img = pillow_fit_image_within(img)
img.save(f"{filename_base_for_outputs}__orig.jpg")
result.img.save(f"{filename_base_for_outputs}.jpg")
pillow_fit_image_within(img).save(f"{filename_base_for_orig_outputs}__orig.jpg")
img_path = f"{filename_base_for_outputs}.png"
assert_image_similar_to_expectation(result.img, img_path=img_path, threshold=2800)
device_sampler_type_test_cases_img_2_img = {
"mps:0": {
(
"plms",
("e9bb714771f7984e61debabc4bb3cd22", "af344c404de70da5db519869f8fcd0c1"),
),
(
"ddim",
("62bacc4ae391e6775a3723c88738ec61", "5f0d2ee426e1bb6ccc1d57dfdd8c73bf"),
),
("k_lms", tuple()),
("k_dpm_2", tuple()),
("k_dpm_2_a", tuple()),
("k_euler", tuple()),
("k_euler_a", tuple()),
("k_heun", tuple()),
},
"cuda": {
("plms", ("b8c7b52da977c1531a9a61c0a082404c",)),
("ddim", ("d6784710dd78e4cb628aba28322b04cf",)),
("k_lms", ("3246b588155f430a79d08a0b1c7287f5",)),
("k_dpm_2", ("724fa459adec6a7b3ebb523263dd5176",)),
("k_dpm_2_a", ("5c36fa9c051db80e3969c63d500340f4",)),
("k_euler", ("d6800b8a3e31f81fb3902d34ee786b33",)),
("k_euler_a", ("6477863f35d0c9032b959a9cc7a0b61c",)),
("k_heun", ("1ed62ad0cfd03dba8b487a36259833a3",)),
},
}
sampler_type_test_cases_img_2_img = device_sampler_type_test_cases_img_2_img.get(
get_device(), []
)
@pytest.mark.skipif(get_device() == "cpu", reason="Too slow to run on CPU")
@pytest.mark.parametrize("sampler_type,expected_md5", sampler_type_test_cases_img_2_img)
def test_img_to_img_from_url_cats(
sampler_type, expected_md5, filename_base_for_outputs, mocked_responses
sampler_type,
filename_base_for_outputs,
mocked_responses,
filename_base_for_orig_outputs,
):
with open(
os.path.join(TESTS_FOLDER, "data", "val2017-000000039769-cococats.jpg"), "rb"
@@ -165,17 +78,17 @@ def test_img_to_img_from_url_cats(
result = next(imagine(prompt))
img = pillow_fit_image_within(img)
img.save(f"{filename_base_for_outputs}__orig.jpg")
result.img.save(f"{filename_base_for_outputs}.jpg")
assert result.md5() in expected_md5
img.save(f"{filename_base_for_orig_outputs}__orig.jpg")
img_path = f"{filename_base_for_outputs}.png"
assert_image_similar_to_expectation(result.img, img_path=img_path, threshold=12000)
@pytest.mark.skipif(get_device() == "cpu", reason="Too slow to run on CPU")
@pytest.mark.parametrize("sampler_type", SAMPLER_TYPE_OPTIONS)
@pytest.mark.parametrize("init_strength", [0, 0.05, 0.2, 1])
def test_img_to_img_fruit_2_gold(
filename_base_for_outputs, sampler_type, init_strength
filename_base_for_outputs,
sampler_type,
init_strength,
filename_base_for_orig_outputs,
):
img = LazyLoadingImage(
filepath=os.path.join(TESTS_FOLDER, "data", "bowl_of_fruit.jpg")
@@ -195,29 +108,16 @@ def test_img_to_img_fruit_2_gold(
result = next(imagine(prompt))
img = pillow_fit_image_within(img)
img.save(f"{filename_base_for_outputs}__orig.jpg")
result.img.save(f"{filename_base_for_outputs}.jpg")
pillow_fit_image_within(img).save(f"{filename_base_for_orig_outputs}__orig.jpg")
img_path = f"{filename_base_for_outputs}.png"
assert_image_similar_to_expectation(result.img, img_path=img_path, threshold=9000)
@pytest.mark.skipif(get_device() == "cpu", reason="Too slow to run on CPU")
def test_img_to_img_fruit_2_gold_repeat():
"""Run this test manually to"""
img = LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/bowl_of_fruit.jpg")
outdir = f"{TESTS_FOLDER}/test_output/"
run_count = 1
def _record_step(img, description, step_count, prompt):
steps_path = os.path.join(
outdir,
f"steps_fruit_2_gold_repeat_{get_device()}_S{prompt.seed}_run_{run_count:02}",
)
os.makedirs(steps_path, exist_ok=True)
filename = f"fruit_2_gold_repeat_{get_device()}_S{prompt.seed}_step{step_count:04}_{prompt_normalized(description)[:40]}.jpg"
destination = os.path.join(steps_path, filename)
img.save(destination)
kwargs = dict(
prompt="a white bowl filled with gold coins. sharp focus",
prompt_strength=12,
@@ -237,8 +137,6 @@ def test_img_to_img_fruit_2_gold_repeat():
ImaginePrompt(**kwargs),
]
for result in imagine(prompts, img_callback=None):
img = pillow_fit_image_within(img)
img.save(f"{TESTS_FOLDER}/test_output/img2img_fruit_2_gold__orig.jpg")
result.img.save(
f"{TESTS_FOLDER}/test_output/img2img_fruit_2_gold_plms_{get_device()}_run-{run_count:02}.jpg"
)
@@ -261,7 +159,7 @@ def test_img_to_file():
@pytest.mark.skipif(get_device() == "cpu", reason="Too slow to run on CPU")
def test_inpainting_bench(filename_base_for_outputs):
def test_inpainting_bench(filename_base_for_outputs, filename_base_for_orig_outputs):
img = LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/bench2.png")
prompt = ImaginePrompt(
"a wise old man",
@@ -276,13 +174,15 @@ def test_inpainting_bench(filename_base_for_outputs):
)
result = next(imagine(prompt))
img = pillow_fit_image_within(img)
img.save(f"{filename_base_for_outputs}__orig.jpg")
result.img.save(f"{filename_base_for_outputs}.jpg")
pillow_fit_image_within(img).save(f"{filename_base_for_orig_outputs}_orig.jpg")
img_path = f"{filename_base_for_outputs}.png"
assert_image_similar_to_expectation(result.img, img_path=img_path, threshold=2800)
@pytest.mark.skipif(get_device() == "cpu", reason="Too slow to run on CPU")
def test_cliptext_inpainting_pearl_doctor(filename_base_for_outputs):
def test_cliptext_inpainting_pearl_doctor(
filename_base_for_outputs, filename_base_for_orig_outputs
):
img = LazyLoadingImage(
filepath=f"{TESTS_FOLDER}/data/girl_with_a_pearl_earring.jpg"
)
@@ -301,6 +201,6 @@ def test_cliptext_inpainting_pearl_doctor(filename_base_for_outputs):
)
result = next(imagine(prompt))
img = pillow_fit_image_within(img)
img.save(f"{filename_base_for_outputs}__orig.jpg")
result.img.save(f"{filename_base_for_outputs}_{prompt.seed}_01.jpg")
pillow_fit_image_within(img).save(f"{filename_base_for_orig_outputs}_orig.jpg")
img_path = f"{filename_base_for_outputs}.png"
assert_image_similar_to_expectation(result.img, img_path=img_path, threshold=2800)


@@ -1,5 +1,3 @@
import hashlib
import pytest
from PIL import Image
from pytorch_lightning import seed_everything
@@ -12,28 +10,20 @@ from imaginairy.enhancers.describe_image_clip import find_img_text_similarity
from imaginairy.enhancers.face_restoration_codeformer import enhance_faces
from imaginairy.utils import get_device
from tests import TESTS_FOLDER
from tests.utils import assert_image_similar_to_expectation
@pytest.mark.skipif(
get_device() == "cpu", reason="TypeError: Got unsupported ScalarType BFloat16"
)
def test_fix_faces():
img = Image.open(f"{TESTS_FOLDER}/data/distorted_face.png")
def test_fix_faces(filename_base_for_orig_outputs, filename_base_for_outputs):
distorted_img = Image.open(f"{TESTS_FOLDER}/data/distorted_face.png")
seed_everything(1)
img = enhance_faces(img)
img.save(f"{TESTS_FOLDER}/test_output/fixed_face.png")
if "mps" in get_device():
assert img_hash(img) == "a75991307eda675a26eeb7073f828e93"
else:
# probably different based on whether first run or not. looks the same either way
assert img_hash(img) in [
"c840cf3bfe5a7760734f425a3f8941cf",
"e56c1205bbc8f251be05773f2ba7fa24",
]
img = enhance_faces(distorted_img)
def img_hash(img):
return hashlib.md5(img.tobytes()).hexdigest()
distorted_img.save(f"{filename_base_for_orig_outputs}__orig.jpg")
img_path = f"{filename_base_for_outputs}.png"
assert_image_similar_to_expectation(img, img_path=img_path, threshold=2800)
@pytest.mark.skipif(get_device() == "cpu", reason="Too slow to run on CPU")
@@ -58,13 +48,13 @@ def test_clip_masking():
)
prompt = ImaginePrompt(
"a female firefighter in front of a burning building",
"",
init_image=img,
init_image_strength=0.95,
init_image_strength=0.5,
# lower steps for faster tests
steps=40,
mask_prompt="(head OR face){*5}",
mask_mode="replace",
mask_mode="keep",
upscale=False,
fix_faces=True,
)
@@ -72,7 +62,7 @@ def test_clip_masking():
result = next(imagine(prompt))
result.save(
f"{TESTS_FOLDER}/test_output/earring_mask_photo.png",
image_type="modified_original",
image_type="generated",
)

tests/utils.py (new file): 25 lines

@@ -0,0 +1,25 @@
import numpy as np
from PIL import Image
def assert_image_similar_to_expectation(img, img_path, threshold=100):
img.save(img_path)
expected_img_path = img_path.replace("/test_output/", "/expected_output/")
expected_img = Image.open(expected_img_path)
norm_sum_sq_diff = calc_norm_sum_sq_diff(img, expected_img)
if norm_sum_sq_diff > threshold:
diff_img = Image.fromarray(np.asarray(img) - np.asarray(expected_img))
diff_img.save(img_path + f"_diff_{norm_sum_sq_diff:.1f}.png")
expected_img.save(img_path + "_expected.png")
assert (
norm_sum_sq_diff < threshold
), f"{norm_sum_sq_diff:.3f} is bigger than threshold {threshold}"
def calc_norm_sum_sq_diff(img, img2):
sum_sq_diff = np.sum(
(np.asarray(img).astype("float") - np.asarray(img2).astype("float")) ** 2
)
norm_sum_sq_diff = sum_sq_diff / np.sqrt(sum_sq_diff)
return norm_sum_sq_diff
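
Illustrative note (not part of the diff): since sum_sq_diff / sqrt(sum_sq_diff) equals sqrt(sum_sq_diff) for any positive value, the metric above is simply the L2 norm of the pixel-wise difference between the generated and expected images, so thresholds such as 2800 allow small per-device rendering differences without accepting a visually different image. A quick sanity check of the algebra:

import numpy as np

s = np.float64(7.84e6)  # hypothetical sum of squared pixel differences
assert np.isclose(s / np.sqrt(s), np.sqrt(s))  # both equal 2800.0, a threshold used above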


@@ -8,7 +8,7 @@ filterwarnings =
[pylama]
format = pylint
skip = */.tox/*,*/.env/*,build/*,*/downloads/*,other/*,prolly_delete/*,downloads/*,imaginairy/vendored/*
skip = */.tox/*,*/.env/*,build/*,*/downloads/*,other/*,prolly_delete/*,downloads/*,imaginairy/vendored/*,testing_support/vastai_cli_official.py
linters = pylint,pycodestyle,pydocstyle,pyflakes,mypy
ignore =
Z999,C0103,C0301,C0114,C0115,C0116,