diff --git a/imaginairy/api.py b/imaginairy/api.py index c5fab28..55f0b13 100755 --- a/imaginairy/api.py +++ b/imaginairy/api.py @@ -295,6 +295,7 @@ def imagine( schedule=schedule, noise=noise, ) + log_latent(init_latent_noised, "init_latent_noised") samples = sampler.sample( diff --git a/imaginairy/samplers/base.py b/imaginairy/samplers/base.py index 07284c6..136eea5 100644 --- a/imaginairy/samplers/base.py +++ b/imaginairy/samplers/base.py @@ -74,7 +74,7 @@ class CFGDenoiser(nn.Module): if mask is not None: assert orig_latent is not None mask_inv = 1.0 - mask - noise_pred = (orig_latent * mask_inv) + (mask * noise_pred) + noise_pred = (orig_latent * mask) + (mask_inv * noise_pred) return noise_pred diff --git a/imaginairy/samplers/kdiff.py b/imaginairy/samplers/kdiff.py index 33a906a..0059474 100644 --- a/imaginairy/samplers/kdiff.py +++ b/imaginairy/samplers/kdiff.py @@ -32,7 +32,7 @@ class KDiffusionSampler: mask=None, orig_latent=None, initial_latent=None, - img_callback=None, + t_start=None, ): if positive_conditioning.shape[0] != batch_size: raise ValueError( @@ -62,6 +62,8 @@ class KDiffusionSampler: "cond": positive_conditioning, "uncond": neutral_conditioning, "cond_scale": guidance_scale, + "mask": mask, + "orig_latent": orig_latent, }, disable=False, callback=callback, diff --git a/imaginairy/schema.py b/imaginairy/schema.py index 82a6be4..923e4e2 100644 --- a/imaginairy/schema.py +++ b/imaginairy/schema.py @@ -105,7 +105,7 @@ class ImaginePrompt: upscale=False, fix_faces=False, fix_faces_fidelity=DEFAULT_FACE_FIDELITY, - sampler_type="PLMS", + sampler_type="plms", conditioning=None, tile_mode=False, ): diff --git a/tests/test_api.py b/tests/test_api.py index 947ef8b..adc7655 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -137,6 +137,7 @@ def test_img_to_img_from_url_cats( assert result.md5() == expected_md5 +# @pytest.mark.parametrize("sampler_type", SAMPLER_TYPE_OPTIONS) @pytest.mark.skipif(get_device() == "cpu", reason="Too slow to run on CPU") @pytest.mark.parametrize("sampler_type", ["ddim", "plms"]) @pytest.mark.parametrize("init_strength", [0, 0.05, 0.2, 1]) @@ -152,7 +153,7 @@ def test_img_to_img_fruit_2_gold( prompt_strength=12, init_image=img, init_image_strength=init_strength, - mask_prompt="(fruit OR stem{*5} OR fruit stem)", + mask_prompt="(fruit{*2} OR stem{*5} OR fruit stem{*3})", mask_mode="replace", steps=80, seed=1,