diff --git a/README.md b/README.md
index 92b79df..0a53249 100644
--- a/README.md
+++ b/README.md
@@ -9,4 +9,11 @@ AI imagined images.
 - LDM - Latent Diffusion
 - Stable Diffusion
- 
\ No newline at end of file
+ 
+
+# Todo
+ - add tests
+ - add docs
+ - remove yaml config
+ - deploy to pypi
+ - add image describe feature
+ 
\ No newline at end of file
diff --git a/imaginairy/imagine.py b/imaginairy/imagine.py
index cb48d3f..eae6d95 100755
--- a/imaginairy/imagine.py
+++ b/imaginairy/imagine.py
@@ -5,6 +5,7 @@ import re
 import subprocess
 from contextlib import nullcontext
 
+import PIL
 import numpy as np
 import torch
 from PIL import Image
@@ -184,6 +185,8 @@ class ImaginePrompt:
         seed=None,
         prompt_strength=7.5,
         sampler_type="PLMS",
+        init_image=None,
+        init_image_strength=0.3,
         steps=50,
         height=512,
         width=512,
@@ -196,6 +199,8 @@
             self.prompts = [WeightedPrompt(prompt, 1)]
         else:
             self.prompts = prompt
+        self.init_image = init_image
+        self.init_image_strength = init_image_strength
         self.prompts.sort(key=lambda p: p.weight, reverse=True)
         self.seed = random.randint(1, 1_000_000_000) if seed is None else seed
         self.prompt_strength = prompt_strength
@@ -214,6 +219,20 @@
         return "|".join(str(p) for p in self.prompts)
 
 
+def load_img(path, max_height=512, max_width=512):
+    image = Image.open(path).convert("RGB")
+    w, h = image.size
+    print(f"loaded input image of size ({w}, {h}) from {path}")
+    resize_ratio = min(max_width / w, max_height / h)
+    w, h = int(w * resize_ratio), int(h * resize_ratio)
+    w, h = map(lambda x: x - x % 64, (w, h))  # resize to integer multiple of 64
+    image = image.resize((w, h), resample=PIL.Image.LANCZOS)
+    image = np.array(image).astype(np.float32) / 255.0
+    image = image[None].transpose(0, 3, 1, 2)
+    image = torch.from_numpy(image)
+    return 2.0 * image - 1.0, w, h
+
+
 def imagine(
     prompts,
     config="data/stable-diffusion-v1.yaml",
@@ -254,7 +273,6 @@
                     for wp in prompt.prompts
                 ]
             )
-            # c = model.get_learned_conditioning(prompt.prompt_text)
 
             shape = [
                 latent_channels,
@@ -263,41 +281,74 @@
             ]
 
             def img_callback(samples, i):
-                return
+                pass
 
                 samples = model.decode_first_stage(samples)
                 samples = torch.clamp((samples + 1.0) / 2.0, min=0.0, max=1.0)
+                steps_path = os.path.join(
+                    sample_path, "steps", f"{base_count:08}_S{prompt.seed}"
+                )
+                os.makedirs(steps_path, exist_ok=True)
                 for pred_x0 in samples:
                     pred_x0 = 255.0 * rearrange(pred_x0.cpu().numpy(), "c h w -> h w c")
-                    filename = f"{base_count:08}_S{seed}_step{i:04}.jpg"
+                    filename = f"{base_count:08}_S{prompt.seed}_step{i:04}.jpg"
                     Image.fromarray(pred_x0.astype(np.uint8)).save(
-                        os.path.join(sample_path, filename)
+                        os.path.join(steps_path, filename)
                     )
 
             start_code = None
-            if fixed_code:
-                start_code = torch.randn(
-                    [1, latent_channels, prompt.height, prompt.width],
-                    device=get_device(),
-                )
+            # if fixed_code:
+            #     start_code = torch.randn(
+            #         [1, latent_channels, prompt.height, prompt.width],
+            #         device=get_device(),
+            #     )
             sampler = get_sampler(prompt.sampler_type, model)
-            samples_ddim, _ = sampler.sample(
-                S=prompt.steps,
-                conditioning=c,
-                batch_size=1,
-                shape=shape,
-                verbose=False,
-                unconditional_guidance_scale=prompt.prompt_strength,
-                unconditional_conditioning=uc,
-                eta=ddim_eta,
-                x_T=start_code,
-                img_callback=img_callback,
-            )
+            if prompt.init_image:
+                generation_strength = 1 - prompt.init_image_strength
+                ddim_steps = int(prompt.steps / generation_strength)
+                sampler.make_schedule(
+                    ddim_num_steps=ddim_steps, ddim_eta=ddim_eta, verbose=False
+                )
+
+                t_enc = int(generation_strength * ddim_steps)
+                init_image, w, h = load_img(prompt.init_image)
+                init_image = init_image.to(get_device())
+                init_latent = model.get_first_stage_encoding(
+                    model.encode_first_stage(init_image)
+                )
+
+                # encode (scaled latent)
+                z_enc = sampler.stochastic_encode(
+                    init_latent, torch.tensor([t_enc]).to(get_device())
+                )
+                # decode it
+                samples = sampler.decode(
+                    z_enc,
+                    c,
+                    t_enc,
+                    unconditional_guidance_scale=prompt.prompt_strength,
+                    unconditional_conditioning=uc,
+                    img_callback=img_callback,
+                )
+            else:
+
+                samples, _ = sampler.sample(
+                    S=prompt.steps,
+                    conditioning=c,
+                    batch_size=1,
+                    shape=shape,
+                    verbose=False,
+                    unconditional_guidance_scale=prompt.prompt_strength,
+                    unconditional_conditioning=uc,
+                    eta=ddim_eta,
+                    x_T=start_code,
+                    img_callback=img_callback,
+                )
 
-            x_samples_ddim = model.decode_first_stage(samples_ddim)
-            x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
+            x_samples = model.decode_first_stage(samples)
+            x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
 
             if not skip_save:
-                for x_sample in x_samples_ddim:
+                for x_sample in x_samples:
                     x_sample = 255.0 * rearrange(
                         x_sample.cpu().numpy(), "c h w -> h w c"
                     )
diff --git a/imaginairy/models/autoencoder.py b/imaginairy/models/autoencoder.py
index 4b70e9d..d625011 100644
--- a/imaginairy/models/autoencoder.py
+++ b/imaginairy/models/autoencoder.py
@@ -61,7 +61,6 @@ class VQModel(pl.LightningModule):
 
         self.lr_g_factor = lr_g_factor
 
-
 class VQModelInterface(VQModel):
     def __init__(self, embed_dim, *args, **kwargs):
         super().__init__(embed_dim=embed_dim, *args, **kwargs)
diff --git a/imaginairy/models/diffusion/ddim.py b/imaginairy/models/diffusion/ddim.py
index b92efa2..308d0a6 100644
--- a/imaginairy/models/diffusion/ddim.py
+++ b/imaginairy/models/diffusion/ddim.py
@@ -218,7 +218,7 @@ class DDIMSampler:
                 )  # TODO: deterministic forward pass?
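
For reference, the new ImaginePrompt parameters above might be exercised like this once the patch is applied. This is a usage sketch, not part of the diff: the import path simply mirrors imaginairy/imagine.py, the prompt text and file name are invented, and sampler_type="DDIM" is an assumption (the img2img branch relies on the sampler exposing make_schedule, stochastic_encode and decode, and this diff only wires img_callback into DDIMSampler.decode).

# usage sketch -- illustrative only; values and file names are made up
from imaginairy.imagine import ImaginePrompt, imagine

prompt = ImaginePrompt(
    "a painting of a bowl of fruit",
    init_image="my_photo.jpg",  # load_img fits this inside 512x512 and snaps the size down to multiples of 64
    init_image_strength=0.3,    # higher values keep more of the input image
    sampler_type="DDIM",        # assumed; see note above
    steps=50,
)
imagine([prompt])  # remaining keyword arguments left at their defaults

Internally, the branch above re-noises the encoded init image up to t_enc = int((1 - init_image_strength) * ddim_steps) and denoises from there, so a strength of 0.3 preserves the rough composition while regenerating most of the detail.
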
img = img_orig * mask + (1.0 - mask) * img - outs = self.p_sample_ddim( + img, pred_x0 = self.p_sample_ddim( img, cond, ts, @@ -232,7 +232,6 @@ class DDIMSampler: unconditional_guidance_scale=unconditional_guidance_scale, unconditional_conditioning=unconditional_conditioning, ) - img, pred_x0 = outs if callback: callback(i) if img_callback: @@ -341,6 +340,7 @@ class DDIMSampler: unconditional_guidance_scale=1.0, unconditional_conditioning=None, use_original_steps=False, + img_callback=None, ): timesteps = ( @@ -361,7 +361,7 @@ class DDIMSampler: ts = torch.full( (x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long ) - x_dec, _ = self.p_sample_ddim( + x_dec, pred_x0 = self.p_sample_ddim( x_dec, cond, ts, @@ -370,4 +370,6 @@ class DDIMSampler: unconditional_guidance_scale=unconditional_guidance_scale, unconditional_conditioning=unconditional_conditioning, ) + if img_callback: + img_callback(pred_x0, i) return x_dec diff --git a/imaginairy/models/diffusion/ddpm.py b/imaginairy/models/diffusion/ddpm.py index 13df265..0584913 100644 --- a/imaginairy/models/diffusion/ddpm.py +++ b/imaginairy/models/diffusion/ddpm.py @@ -43,33 +43,33 @@ def uniform_on_device(r1, r2, shape, device): class DDPM(pl.LightningModule): # classic DDPM with Gaussian diffusion, in image space def __init__( - self, - unet_config, - timesteps=1000, - beta_schedule="linear", - loss_type="l2", - ckpt_path=None, - ignore_keys=[], - load_only_unet=False, - monitor="val/loss", - first_stage_key="image", - image_size=256, - channels=3, - log_every_t=100, - clip_denoised=True, - linear_start=1e-4, - linear_end=2e-2, - cosine_s=8e-3, - given_betas=None, - original_elbo_weight=0.0, - v_posterior=0.0, # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta - l_simple_weight=1.0, - conditioning_key=None, - parameterization="eps", # all assuming fixed variance schedules - scheduler_config=None, - use_positional_encodings=False, - learn_logvar=False, - logvar_init=0.0, + self, + unet_config, + timesteps=1000, + beta_schedule="linear", + loss_type="l2", + ckpt_path=None, + ignore_keys=[], + load_only_unet=False, + monitor="val/loss", + first_stage_key="image", + image_size=256, + channels=3, + log_every_t=100, + clip_denoised=True, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + given_betas=None, + original_elbo_weight=0.0, + v_posterior=0.0, # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta + l_simple_weight=1.0, + conditioning_key=None, + parameterization="eps", # all assuming fixed variance schedules + scheduler_config=None, + use_positional_encodings=False, + learn_logvar=False, + logvar_init=0.0, ): super().__init__() assert parameterization in [ @@ -122,13 +122,13 @@ class DDPM(pl.LightningModule): self.logvar = nn.Parameter(self.logvar, requires_grad=True) def register_schedule( - self, - given_betas=None, - beta_schedule="linear", - timesteps=1000, - linear_start=1e-4, - linear_end=2e-2, - cosine_s=8e-3, + self, + given_betas=None, + beta_schedule="linear", + timesteps=1000, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, ): if given_betas is not None: betas = given_betas @@ -149,7 +149,7 @@ class DDPM(pl.LightningModule): self.linear_start = linear_start self.linear_end = linear_end assert ( - alphas_cumprod.shape[0] == self.num_timesteps + alphas_cumprod.shape[0] == self.num_timesteps ), "alphas have to be defined for each timestep" to_torch = partial(torch.tensor, dtype=torch.float32) @@ -175,7 +175,7 @@ class 
DDPM(pl.LightningModule): # calculations for posterior q(x_{t-1} | x_t, x_0) posterior_variance = (1 - self.v_posterior) * betas * ( - 1.0 - alphas_cumprod_prev + 1.0 - alphas_cumprod_prev ) / (1.0 - alphas_cumprod) + self.v_posterior * betas # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) self.register_buffer("posterior_variance", to_torch(posterior_variance)) @@ -196,17 +196,17 @@ class DDPM(pl.LightningModule): ) if self.parameterization == "eps": - lvlb_weights = self.betas ** 2 / ( - 2 - * self.posterior_variance - * to_torch(alphas) - * (1 - self.alphas_cumprod) + lvlb_weights = self.betas**2 / ( + 2 + * self.posterior_variance + * to_torch(alphas) + * (1 - self.alphas_cumprod) ) elif self.parameterization == "x0": lvlb_weights = ( - 0.5 - * np.sqrt(torch.Tensor(alphas_cumprod)) - / (2.0 * 1 - torch.Tensor(alphas_cumprod)) + 0.5 + * np.sqrt(torch.Tensor(alphas_cumprod)) + / (2.0 * 1 - torch.Tensor(alphas_cumprod)) ) else: raise NotImplementedError("mu not supported") @@ -216,26 +216,27 @@ class DDPM(pl.LightningModule): assert not torch.isnan(self.lvlb_weights).all() - class LatentDiffusion(DDPM): """main class""" def __init__( - self, - first_stage_config, - cond_stage_config, - num_timesteps_cond=None, - cond_stage_key="image", - cond_stage_trainable=False, - concat_mode=True, - cond_stage_forward=None, - conditioning_key=None, - scale_factor=1.0, - scale_by_std=False, - *args, - **kwargs, + self, + first_stage_config, + cond_stage_config, + num_timesteps_cond=None, + cond_stage_key="image", + cond_stage_trainable=False, + concat_mode=True, + cond_stage_forward=None, + conditioning_key=None, + scale_factor=1.0, + scale_by_std=False, + *args, + **kwargs, ): - self.num_timesteps_cond = 1 if num_timesteps_cond is None else num_timesteps_cond + self.num_timesteps_cond = ( + 1 if num_timesteps_cond is None else num_timesteps_cond + ) self.scale_by_std = scale_by_std assert self.num_timesteps_cond <= kwargs["timesteps"] # for backwards compatibility after implementation of DiffusionWrapper @@ -269,7 +270,7 @@ class LatentDiffusion(DDPM): self.restarted_from_ckpt = True def make_cond_schedule( - self, + self, ): self.cond_ids = torch.full( size=(self.num_timesteps,), @@ -282,13 +283,13 @@ class LatentDiffusion(DDPM): self.cond_ids[: self.num_timesteps_cond] = ids def register_schedule( - self, - given_betas=None, - beta_schedule="linear", - timesteps=1000, - linear_start=1e-4, - linear_end=2e-2, - cosine_s=8e-3, + self, + given_betas=None, + beta_schedule="linear", + timesteps=1000, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, ): super().register_schedule( given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s @@ -327,7 +328,7 @@ class LatentDiffusion(DDPM): self.cond_stage_model = model def _get_denoise_row_from_list( - self, samples, desc="", force_no_decoder_quantization=False + self, samples, desc="", force_no_decoder_quantization=False ): denoise_row = [] for zd in tqdm(samples, desc=desc): @@ -357,7 +358,7 @@ class LatentDiffusion(DDPM): def get_learned_conditioning(self, c): if self.cond_stage_forward is None: if hasattr(self.cond_stage_model, "encode") and callable( - self.cond_stage_model.encode + self.cond_stage_model.encode ): c = self.cond_stage_model.encode(c) if isinstance(c, DiagonalGaussianDistribution): @@ -414,7 +415,7 @@ class LatentDiffusion(DDPM): return weighting def get_fold_unfold( - self, x, kernel_size, stride, uf=1, df=1 + self, x, kernel_size, stride, uf=1, df=1 ): # todo load once not every time, shorten 
code """ :param x: img of size (bs, c, h, w) @@ -499,14 +500,14 @@ class LatentDiffusion(DDPM): @torch.no_grad() def get_input( - self, - batch, - k, - return_first_stage_outputs=False, - force_c_encode=False, - cond_key=None, - return_original_cond=False, - bs=None, + self, + batch, + k, + return_first_stage_outputs=False, + force_c_encode=False, + cond_key=None, + return_original_cond=False, + bs=None, ): x = super().get_input(batch, k) if bs is not None: @@ -631,6 +632,52 @@ class LatentDiffusion(DDPM): else: return self.first_stage_model.decode(z) + @torch.no_grad() + def encode_first_stage(self, x): + if hasattr(self, "split_input_params"): + if self.split_input_params["patch_distributed_vq"]: + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. (64, 64) + df = self.split_input_params["vqf"] + self.split_input_params["original_image_size"] = x.shape[-2:] + bs, nc, h, w = x.shape + if ks[0] > h or ks[1] > w: + ks = (min(ks[0], h), min(ks[1], w)) + print("reducing Kernel") + + if stride[0] > h or stride[1] > w: + stride = (min(stride[0], h), min(stride[1], w)) + print("reducing stride") + + fold, unfold, normalization, weighting = self.get_fold_unfold( + x, ks, stride, df=df + ) + z = unfold(x) # (bn, nc * prod(**ks), L) + # Reshape to img shape + z = z.view( + (z.shape[0], -1, ks[0], ks[1], z.shape[-1]) + ) # (bn, nc, ks[0], ks[1], L ) + + output_list = [ + self.first_stage_model.encode(z[:, :, :, :, i]) + for i in range(z.shape[-1]) + ] + + o = torch.stack(output_list, axis=-1) + o = o * weighting + + # Reverse reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + decoded = fold(o) + decoded = decoded / normalization + return decoded + + else: + return self.first_stage_model.encode(x) + else: + return self.first_stage_model.encode(x) + def apply_model(self, x_noisy, t, cond, return_ids=False): if isinstance(cond, dict): @@ -664,8 +711,8 @@ class LatentDiffusion(DDPM): z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])] if ( - self.cond_stage_key in ["image", "LR_image", "segmentation", "bbox_img"] - and self.model.conditioning_key + self.cond_stage_key in ["image", "LR_image", "segmentation", "bbox_img"] + and self.model.conditioning_key ): # todo check for completeness c_key = next(iter(cond.keys())) # get key c = next(iter(cond.values())) # get value @@ -681,7 +728,7 @@ class LatentDiffusion(DDPM): elif self.cond_stage_key == "coordinates_bbox": assert ( - "original_image_size" in self.split_input_params + "original_image_size" in self.split_input_params ), "BoudingBoxRescaling is missing original_image_size" # assuming padding of unfold is always 0 and its dilation is always 1 @@ -776,16 +823,16 @@ class LatentDiffusion(DDPM): return x_recon def p_mean_variance( - self, - x, - c, - t, - clip_denoised: bool, - return_codebook_ids=False, - quantize_denoised=False, - return_x0=False, - score_corrector=None, - corrector_kwargs=None, + self, + x, + c, + t, + clip_denoised: bool, + return_codebook_ids=False, + quantize_denoised=False, + return_x0=False, + score_corrector=None, + corrector_kwargs=None, ): t_in = t model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids) @@ -822,19 +869,19 @@ class LatentDiffusion(DDPM): @torch.no_grad() def p_sample( - self, - x, - c, - t, - clip_denoised=False, - repeat_noise=False, - return_codebook_ids=False, - quantize_denoised=False, - return_x0=False, - temperature=1.0, - noise_dropout=0.0, - 
score_corrector=None, - corrector_kwargs=None, + self, + x, + c, + t, + clip_denoised=False, + repeat_noise=False, + return_codebook_ids=False, + quantize_denoised=False, + return_x0=False, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, ): b, *_, device = *x.shape, x.device outputs = self.p_mean_variance( @@ -864,7 +911,7 @@ class LatentDiffusion(DDPM): if return_codebook_ids: return model_mean + nonzero_mask * ( - 0.5 * model_log_variance + 0.5 * model_log_variance ).exp() * noise, logits.argmax(dim=1) if return_x0: return ( diff --git a/imaginairy/modules/clip_embedders.py b/imaginairy/modules/clip_embedders.py index 0bd4b79..a128ab1 100644 --- a/imaginairy/modules/clip_embedders.py +++ b/imaginairy/modules/clip_embedders.py @@ -105,13 +105,13 @@ class FrozenClipImageEmbedder(nn.Module): def __init__( self, - model, + model_name, jit=False, device=get_device(), antialias=False, ): super().__init__() - self.model, _ = clip.load(name=model, device=device, jit=jit) + self.model, preprocess = clip.load(name=model_name, device=device, jit=jit) self.antialias = antialias diff --git a/imaginairy/modules/diffusionmodules/model.py b/imaginairy/modules/diffusionmodules/model.py index 34b2667..95e68c2 100644 --- a/imaginairy/modules/diffusionmodules/model.py +++ b/imaginairy/modules/diffusionmodules/model.py @@ -80,13 +80,13 @@ class Downsample(nn.Module): class ResnetBlock(nn.Module): def __init__( - self, - *, - in_channels, - out_channels=None, - conv_shortcut=False, - dropout, - temb_channels=512, + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout, + temb_channels=512, ): super().__init__() self.in_channels = in_channels @@ -204,22 +204,22 @@ def make_attn(in_channels, attn_type="vanilla"): class Encoder(nn.Module): def __init__( - self, - *, - ch, - out_ch, - ch_mult=(1, 2, 4, 8), - num_res_blocks, - attn_resolutions, - dropout=0.0, - resamp_with_conv=True, - in_channels, - resolution, - z_channels, - double_z=True, - use_linear_attn=False, - attn_type="vanilla", - **ignore_kwargs, + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + double_z=True, + use_linear_attn=False, + attn_type="vanilla", + **ignore_kwargs, ): super().__init__() if use_linear_attn: @@ -321,23 +321,23 @@ class Encoder(nn.Module): class Decoder(nn.Module): def __init__( - self, - *, - ch, - out_ch, - ch_mult=(1, 2, 4, 8), - num_res_blocks, - attn_resolutions, - dropout=0.0, - resamp_with_conv=True, - in_channels, - resolution, - z_channels, - give_pre_end=False, - tanh_out=False, - use_linear_attn=False, - attn_type="vanilla", - **ignorekwargs, + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + give_pre_end=False, + tanh_out=False, + use_linear_attn=False, + attn_type="vanilla", + **ignorekwargs, ): super().__init__() if use_linear_attn: @@ -567,24 +567,24 @@ class Resize(nn.Module): class FirstStagePostProcessor(nn.Module): def __init__( - self, - ch_mult: list, - in_channels, - pretrained_model: nn.Module = None, - reshape=False, - n_channels=None, - dropout=0.0, - pretrained_config=None, + self, + ch_mult: list, + in_channels, + pretrained_model: nn.Module = None, + reshape=False, + n_channels=None, + dropout=0.0, + pretrained_config=None, ): super().__init__() if pretrained_config is None: assert ( - 
                    pretrained_model is not None
+                pretrained_model is not None
             ), 'Either "pretrained_model" or "pretrained_config" must not be None'
             self.pretrained_model = pretrained_model
         else:
             assert (
-                    pretrained_config is not None
+                pretrained_config is not None
             ), 'Either "pretrained_model" or "pretrained_config" must not be None'
             self.instantiate_pretrained(pretrained_config)
 
diff --git a/imaginairy/modules/diffusionmodules/util.py b/imaginairy/modules/diffusionmodules/util.py
index c48c233..91e1eb7 100644
--- a/imaginairy/modules/diffusionmodules/util.py
+++ b/imaginairy/modules/diffusionmodules/util.py
@@ -9,9 +9,10 @@
 import math
+
+import numpy as np
 import torch
 import torch.nn as nn
-import numpy as np
 from einops import repeat
 
 from imaginairy.utils import instantiate_from_config
@@ -52,12 +53,23 @@ def make_beta_schedule(
     return betas.numpy()
 
 
+def frange(start, stop, step):
+    """range but handles floats"""
+    x = start
+    while True:
+        if x >= stop:
+            return
+        yield x
+        x += step
+
+
 def make_ddim_timesteps(
     ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True
 ):
     if ddim_discr_method == "uniform":
-        c = num_ddpm_timesteps // num_ddim_timesteps
-        ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
+        c = num_ddpm_timesteps / num_ddim_timesteps
+        ddim_timesteps = [int(i) for i in frange(0, num_ddpm_timesteps - 1, c)]
+        ddim_timesteps = np.asarray(ddim_timesteps)
     elif ddim_discr_method == "quad":
         ddim_timesteps = (
             (np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2
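
The change to the uniform branch of make_ddim_timesteps ties back to the img2img path: ddim_steps = int(prompt.steps / (1 - init_image_strength)) is an arbitrary count that rarely divides the 1000 DDPM timesteps evenly, and with the old integer stride the schedule came back longer than requested. A small worked example, for illustration only (frange is the helper added in this hunk; the numbers follow from the defaults in this diff):

# illustration only, not part of the patch
from imaginairy.modules.diffusionmodules.util import frange  # available after this change

steps, init_image_strength = 50, 0.3
ddim_steps = int(steps / (1 - init_image_strength))  # 71
t_enc = int((1 - init_image_strength) * ddim_steps)  # 49 denoising steps run on the re-noised latent

num_ddpm_timesteps = 1000
old_stride = num_ddpm_timesteps // ddim_steps                           # 14
old = list(range(0, num_ddpm_timesteps, old_stride))                    # 72 timesteps, one more than requested
new_stride = num_ddpm_timesteps / ddim_steps                            # ~14.08
new = [int(i) for i in frange(0, num_ddpm_timesteps - 1, new_stride)]   # exactly 71 timesteps
assert len(new) == ddim_steps
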