diff --git a/imaginairy/api.py b/imaginairy/api.py
index d85ab2e..fdbc8a5 100755
--- a/imaginairy/api.py
+++ b/imaginairy/api.py
@@ -280,7 +280,6 @@ def imagine(
                 unconditional_guidance_scale=prompt.prompt_strength,
                 unconditional_conditioning=uc,
                 eta=ddim_eta,
-                initial_noise_tensor=start_code,
                 img_callback=_img_callback,
             )
 
diff --git a/imaginairy/enhancers/describe_image_clip.py b/imaginairy/enhancers/describe_image_clip.py
index a1a89a2..9a86dda 100644
--- a/imaginairy/enhancers/describe_image_clip.py
+++ b/imaginairy/enhancers/describe_image_clip.py
@@ -30,7 +30,7 @@ def find_img_text_similarity(image: Image.Image, phrases: Sequence):
 
 
 def find_embed_text_similarity(embed_features, phrases):
-    model, preprocess = get_model()
+    model, _ = get_model()
     text = clip.tokenize(phrases).to(device)
 
     with torch.no_grad():
diff --git a/imaginairy/modules/diffusion/ddpm.py b/imaginairy/modules/diffusion/ddpm.py
index 6052978..ab65c18 100644
--- a/imaginairy/modules/diffusion/ddpm.py
+++ b/imaginairy/modules/diffusion/ddpm.py
@@ -11,7 +11,7 @@ from functools import partial
 import numpy as np
 import pytorch_lightning as pl
 import torch
-import torch.nn as nn
+from torch import nn
 from einops import rearrange
 from torchvision.utils import make_grid
 from tqdm import tqdm
@@ -28,9 +28,11 @@ logger = logging.getLogger(__name__)
 __conditioning_keys__ = {"concat": "c_concat", "crossattn": "c_crossattn", "adm": "y"}
 
 
-def disabled_train(self, mode=True):
-    """Overwrite model.train with this function to make sure train/eval mode
-    does not change anymore."""
+def disabled_train(self):
+    """
+    Overwrite model.train with this function to make sure train/eval mode
+    does not change anymore.
+    """
     return self
 
 
@@ -265,6 +267,7 @@ class LatentDiffusion(DDPM):
         self.instantiate_first_stage(first_stage_config)
         self.instantiate_cond_stage(cond_stage_config)
         self.cond_stage_forward = cond_stage_forward
+        self.cond_ids = None
         self.clip_denoised = False
         self.bbox_tokenizer = None
 
@@ -346,13 +349,13 @@ class LatentDiffusion(DDPM):
         self.cond_stage_model = model
 
     def _get_denoise_row_from_list(
-        self, samples, desc="", force_no_decoder_quantization=False
+        self, samples, desc=""
     ):
         denoise_row = []
         for zd in tqdm(samples, desc=desc):
             denoise_row.append(
                 self.decode_first_stage(
-                    zd.to(self.device), force_not_quantize=force_no_decoder_quantization
+                    zd.to(self.device)
                 )
             )
         n_imgs_per_row = len(denoise_row)
@@ -547,7 +550,7 @@ class LatentDiffusion(DDPM):
             else:
                 xc = x
             if not self.cond_stage_trainable or force_c_encode:
-                if isinstance(xc, dict) or isinstance(xc, list):
+                if isinstance(xc, (dict, list)):
                     # import pudb; pudb.set_trace()
                     c = self.get_learned_conditioning(xc)
                 else:
@@ -577,7 +580,7 @@ class LatentDiffusion(DDPM):
         return out
 
     @torch.no_grad()
-    def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
+    def decode_first_stage(self, z, predict_cids=False):
         if predict_cids:
             if z.dim() == 4:
                 z = torch.argmax(z.exp(), dim=1).long()
diff --git a/imaginairy/modules/diffusion/util.py b/imaginairy/modules/diffusion/util.py
index 9b0c38a..3963646 100644
--- a/imaginairy/modules/diffusion/util.py
+++ b/imaginairy/modules/diffusion/util.py
@@ -262,9 +262,9 @@ def conv_nd(dims, *args, **kwargs):
     """
     if dims == 1:
         return nn.Conv1d(*args, **kwargs)
-    elif dims == 2:
+    if dims == 2:
         return nn.Conv2d(*args, **kwargs)
-    elif dims == 3:
+    if dims == 3:
         return nn.Conv3d(*args, **kwargs)
     raise ValueError(f"unsupported dimensions: {dims}")
 
diff --git a/imaginairy/modules/find_noise.py b/imaginairy/modules/find_noise.py
index 394573d..bb7857a 100644
--- a/imaginairy/modules/find_noise.py
+++ b/imaginairy/modules/find_noise.py
@@ -13,9 +13,9 @@ from einops import repeat
 from torch import autocast
 
 from imaginairy.utils import get_device, pillow_img_to_torch_image
+from imaginairy.vendored import k_diffusion as K
 
-
-def pil_img_to_latent(model, img, batch_size=1, device="cuda", half=True):
+def pil_img_to_latent(model, img, batch_size=1, half=True):
     # init_image = pil_img_to_torch(img, half=half).to(device)
     init_image = pillow_img_to_torch_image(img).to(get_device())
     init_image = repeat(init_image, "1 ... -> b ...", b=batch_size)
@@ -40,11 +40,7 @@ def find_noise_for_image(model, pil_img, prompt, steps=50, cond_scale=1.0, half=
     )
 
 
-def find_noise_for_latent(
-    model, img_latent, prompt, steps=50, cond_scale=1.0, half=True
-):
-    from imaginairy.vendored import k_diffusion as K
-
+def find_noise_for_latent(model, img_latent, prompt, steps=50, cond_scale=1.0):
     x = img_latent
 
     _autocast = autocast if get_device() in ("cuda", "cpu") else nullcontext
diff --git a/imaginairy/safety.py b/imaginairy/safety.py
index ca147b7..c0d7468 100644
--- a/imaginairy/safety.py
+++ b/imaginairy/safety.py
@@ -15,7 +15,7 @@ def safety_models():
     return safety_feature_extractor, safety_checker
 
 
-def is_nsfw(img, x_sample):
+def is_nsfw(img):
     safety_feature_extractor, safety_checker = safety_models()
     safety_checker_input = safety_feature_extractor([img], return_tensors="pt")
     clip_input = safety_checker_input.pixel_values
diff --git a/imaginairy/samplers/base.py b/imaginairy/samplers/base.py
index 0908b36..8a69c2f 100644
--- a/imaginairy/samplers/base.py
+++ b/imaginairy/samplers/base.py
@@ -3,6 +3,7 @@ from torch import nn
 
 from imaginairy.utils import get_device
 
+
 SAMPLER_TYPE_OPTIONS = [
     "plms",
     "ddim",
@@ -32,11 +33,12 @@ def get_sampler(sampler_type, model):
     sampler_type = sampler_type.lower()
     if sampler_type == "plms":
         return PLMSSampler(model)
-    elif sampler_type == "ddim":
+    if sampler_type == "ddim":
         return DDIMSampler(model)
-    elif sampler_type.startswith("k_"):
+    if sampler_type.startswith("k_"):
         sampler_type = _k_sampler_type_lookup[sampler_type]
         return KDiffusionSampler(model, sampler_type)
+    raise ValueError("invalid sampler_type")
 
 
 class CFGDenoiser(nn.Module):
diff --git a/imaginairy/samplers/ddim.py b/imaginairy/samplers/ddim.py
index 20e966f..078bde3 100644
--- a/imaginairy/samplers/ddim.py
+++ b/imaginairy/samplers/ddim.py
@@ -127,8 +127,8 @@ class DDIMSampler:
         x_T=None,
         unconditional_guidance_scale=1.0,
         unconditional_conditioning=None,
-        # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
         **kwargs,
+        # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
     ):
         if isinstance(conditioning, dict):
             cbs = conditioning[list(conditioning.keys())[0]].shape[0]
@@ -376,7 +376,7 @@ class DDIMSampler:
             # x_dec = x_dec.detach() + (original_loss * 0.1) ** 2
             # cond_grad = -torch.autograd.grad(original_loss, x_dec)[0]
             # x_dec = x_dec.detach() + cond_grad * sigma_t ** 2
-            ## x_dec_alt = x_dec + (original_loss * 0.1) ** 2
+            # x_dec_alt = x_dec + (original_loss * 0.1) ** 2
             log_latent(x_dec, f"x_dec {i}")
             log_latent(pred_x0, f"pred_x0 {i}")
 
diff --git a/imaginairy/samplers/plms.py b/imaginairy/samplers/plms.py
index 7e9bb1c..978e9ff 100644
--- a/imaginairy/samplers/plms.py
+++ b/imaginairy/samplers/plms.py
@@ -20,13 +20,14 @@ logger = logging.getLogger(__name__)
 class PLMSSampler:
     """probabilistic least-mean-squares"""
 
-    def __init__(self, model, **kwargs):
+    def __init__(self, model):
         self.model = model
         self.ddpm_num_timesteps = model.num_timesteps
         self.device_available = get_device()
+        self.ddim_timesteps = None
 
     def register_buffer(self, name, attr):
-        if type(attr) == torch.Tensor:
+        if isinstance(attr, torch.Tensor):
             if attr.device != torch.device(self.device_available):
                 attr = attr.to(torch.float32).to(torch.device(self.device_available))
             setattr(self, name, attr)
@@ -43,7 +44,9 @@ class PLMSSampler:
         assert (
             alphas_cumprod.shape[0] == self.ddpm_num_timesteps
         ), "alphas have to be defined for each timestep"
-        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
+
+        def to_torch(x):
+            return x.clone().detach().to(torch.float32).to(self.model.device)
 
         self.register_buffer("betas", to_torch(self.model.betas))
         self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
@@ -431,7 +434,7 @@ class PLMSSampler:
             # x_dec = x_dec.detach() + (original_loss * 0.1) ** 2
             # cond_grad = -torch.autograd.grad(original_loss, x_dec)[0]
             # x_dec = x_dec.detach() + cond_grad * sigma_t ** 2
-            ## x_dec_alt = x_dec + (original_loss * 0.1) ** 2
+            # x_dec_alt = x_dec + (original_loss * 0.1) ** 2
 
             old_eps.append(e_t)
             if len(old_eps) >= 4:
diff --git a/tests/test_safety.py b/tests/test_safety.py
index d3434f8..4d95209 100644
--- a/tests/test_safety.py
+++ b/tests/test_safety.py
@@ -9,11 +9,11 @@ from tests import TESTS_FOLDER
 def test_is_nsfw():
     img = Image.open(f"{TESTS_FOLDER}/data/safety.jpg")
     latent = _pil_to_latent(img)
-    assert is_nsfw(img, latent)
+    assert is_nsfw(img)
 
     img = Image.open(f"{TESTS_FOLDER}/data/girl_with_a_pearl_earring.jpg")
     latent = _pil_to_latent(img)
-    assert not is_nsfw(img, latent)
+    assert not is_nsfw(img)
 
 
 def _pil_to_latent(img):