|
|
|
@ -5,6 +5,7 @@ import pytest
|
|
|
|
|
from imaginairy import LazyLoadingImage
|
|
|
|
|
from imaginairy.api import imagine, imagine_image_files, prompt_normalized
|
|
|
|
|
from imaginairy.img_utils import pillow_fit_image_within
|
|
|
|
|
from imaginairy.samplers.base import SAMPLER_TYPE_OPTIONS
|
|
|
|
|
from imaginairy.schema import ImaginePrompt
|
|
|
|
|
from imaginairy.utils import get_device
|
|
|
|
|
|
|
|
|
@ -13,7 +14,10 @@ from . import TESTS_FOLDER
|
|
|
|
|
device_sampler_type_test_cases = {
|
|
|
|
|
"mps:0": [
|
|
|
|
|
("plms", "78539ae3a3097dc8232da6d630551ab3"),
|
|
|
|
|
("ddim", ("828fc143cd40586347b2f8403c288c9b", "4c7905d4a36f6f9c456b7e074b52707e")),
|
|
|
|
|
(
|
|
|
|
|
"ddim",
|
|
|
|
|
("828fc143cd40586347b2f8403c288c9b", "4c7905d4a36f6f9c456b7e074b52707e"),
|
|
|
|
|
),
|
|
|
|
|
("k_lms", "53d25e59add39c8447537be30e4eff4b"),
|
|
|
|
|
("k_dpm_2", "5108bceb58a38d88a585f37b2ba1b072"),
|
|
|
|
|
("k_dpm_2_a", "20396daa6c920d1cfd6db90e73558c01"),
|
|
|
|
@ -51,10 +55,22 @@ device_sampler_type_test_cases_img_2_img = {
|
|
|
|
|
"mps:0": {
|
|
|
|
|
("plms", "0d9c40c348cdac7bdc8d5a472f378f42"),
|
|
|
|
|
("ddim", "0d9c40c348cdac7bdc8d5a472f378f42"),
|
|
|
|
|
("k_lms", ""),
|
|
|
|
|
("k_dpm_2", ""),
|
|
|
|
|
("k_dpm_2_a", ""),
|
|
|
|
|
("k_euler", ""),
|
|
|
|
|
("k_euler_a", ""),
|
|
|
|
|
("k_heun", ""),
|
|
|
|
|
},
|
|
|
|
|
"cuda": {
|
|
|
|
|
("plms", "841723966344dd8678aee1ce5f9cbb3d"),
|
|
|
|
|
("ddim", "1f0d72370fabcf2ff716e4068d5b2360"),
|
|
|
|
|
("k_lms", ""),
|
|
|
|
|
("k_dpm_2", ""),
|
|
|
|
|
("k_dpm_2_a", ""),
|
|
|
|
|
("k_euler", ""),
|
|
|
|
|
("k_euler_a", ""),
|
|
|
|
|
("k_heun", ""),
|
|
|
|
|
},
|
|
|
|
|
}
|
|
|
|
|
sampler_type_test_cases_img_2_img = device_sampler_type_test_cases_img_2_img.get(
|
|
|
|
@ -87,12 +103,30 @@ def test_img2img_beach_to_sunset(sampler_type, expected_md5, filename_base_for_o
|
|
|
|
|
|
|
|
|
|
device_sampler_type_test_cases_img_2_img = {
|
|
|
|
|
"mps:0": {
|
|
|
|
|
("plms", ("e9bb714771f7984e61debabc4bb3cd22", "af344c404de70da5db519869f8fcd0c1")),
|
|
|
|
|
("ddim", ("62bacc4ae391e6775a3723c88738ec61", "5f0d2ee426e1bb6ccc1d57dfdd8c73bf")),
|
|
|
|
|
(
|
|
|
|
|
"plms",
|
|
|
|
|
("e9bb714771f7984e61debabc4bb3cd22", "af344c404de70da5db519869f8fcd0c1"),
|
|
|
|
|
),
|
|
|
|
|
(
|
|
|
|
|
"ddim",
|
|
|
|
|
("62bacc4ae391e6775a3723c88738ec61", "5f0d2ee426e1bb6ccc1d57dfdd8c73bf"),
|
|
|
|
|
),
|
|
|
|
|
("k_lms", tuple()),
|
|
|
|
|
("k_dpm_2", tuple()),
|
|
|
|
|
("k_dpm_2_a", tuple()),
|
|
|
|
|
("k_euler", tuple()),
|
|
|
|
|
("k_euler_a", tuple()),
|
|
|
|
|
("k_heun", tuple()),
|
|
|
|
|
},
|
|
|
|
|
"cuda": {
|
|
|
|
|
("plms", "b8c7b52da977c1531a9a61c0a082404c"),
|
|
|
|
|
("ddim", "d6784710dd78e4cb628aba28322b04cf"),
|
|
|
|
|
("plms", ("b8c7b52da977c1531a9a61c0a082404c",)),
|
|
|
|
|
("ddim", ("d6784710dd78e4cb628aba28322b04cf",)),
|
|
|
|
|
("k_lms", ("3246b588155f430a79d08a0b1c7287f5",)),
|
|
|
|
|
("k_dpm_2", ("724fa459adec6a7b3ebb523263dd5176",)),
|
|
|
|
|
("k_dpm_2_a", ("5c36fa9c051db80e3969c63d500340f4",)),
|
|
|
|
|
("k_euler", ("d6800b8a3e31f81fb3902d34ee786b33",)),
|
|
|
|
|
("k_euler_a", ("6477863f35d0c9032b959a9cc7a0b61c",)),
|
|
|
|
|
("k_heun", ("1ed62ad0cfd03dba8b487a36259833a3",)),
|
|
|
|
|
},
|
|
|
|
|
}
|
|
|
|
|
sampler_type_test_cases_img_2_img = device_sampler_type_test_cases_img_2_img.get(
|
|
|
|
@ -137,9 +171,8 @@ def test_img_to_img_from_url_cats(
|
|
|
|
|
assert result.md5() in expected_md5
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# @pytest.mark.parametrize("sampler_type", SAMPLER_TYPE_OPTIONS)
|
|
|
|
|
@pytest.mark.skipif(get_device() == "cpu", reason="Too slow to run on CPU")
|
|
|
|
|
@pytest.mark.parametrize("sampler_type", ["ddim", "plms"])
|
|
|
|
|
@pytest.mark.parametrize("sampler_type", SAMPLER_TYPE_OPTIONS)
|
|
|
|
|
@pytest.mark.parametrize("init_strength", [0, 0.05, 0.2, 1])
|
|
|
|
|
def test_img_to_img_fruit_2_gold(
|
|
|
|
|
filename_base_for_outputs, sampler_type, init_strength
|
|
|
|
@ -153,7 +186,7 @@ def test_img_to_img_fruit_2_gold(
|
|
|
|
|
prompt_strength=12,
|
|
|
|
|
init_image=img,
|
|
|
|
|
init_image_strength=init_strength,
|
|
|
|
|
mask_prompt="(fruit{*2} OR stem{*5} OR fruit stem{*3})",
|
|
|
|
|
mask_prompt="(fruit{*2} OR stem{*10} OR fruit stem{*3})",
|
|
|
|
|
mask_mode="replace",
|
|
|
|
|
steps=80,
|
|
|
|
|
seed=1,
|
|
|
|
|