feature: image prompts

pull/1/head
Bryce 2 years ago
parent 0835b2db16
commit f782fac570

@ -9,4 +9,11 @@ AI imagined images.
- LDM - Latent Diffusion
- Stable Diffusion
-
# Todo
- add tests
- add docs
- remove yaml config
- deploy to pypi
- add image describe feature
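For orientation, a minimal sketch of how the new image-prompt options added in this commit might be exercised. The import path, the first positional argument, and the filename are assumptions; the keyword arguments (`init_image`, `init_image_strength`, `prompt_strength`, `sampler_type`, `steps`) come straight from the hunks below, and this is a sketch rather than a verified invocation:

```python
from imaginairy.imagine import ImaginePrompt, imagine  # import path assumed

prompts = [
    ImaginePrompt(
        "a fox in the style of a renaissance painting",
        init_image="fox_photo.jpg",   # new in this commit: start from an existing image
        init_image_strength=0.3,      # new: higher values stay closer to the init image
        prompt_strength=7.5,
        sampler_type="DDIM",          # the init-image branch relies on DDIM stochastic encode/decode
        steps=50,
    )
]
imagine(prompts)  # per this diff, outputs are written under sample_path unless skip_save is set
```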

@ -5,6 +5,7 @@ import re
import subprocess
from contextlib import nullcontext
import PIL
import numpy as np
import torch
from PIL import Image
@ -184,6 +185,8 @@ class ImaginePrompt:
seed=None,
prompt_strength=7.5,
sampler_type="PLMS",
init_image=None,
init_image_strength=0.3,
steps=50,
height=512,
width=512,
@ -196,6 +199,8 @@ class ImaginePrompt:
self.prompts = [WeightedPrompt(prompt, 1)]
else:
self.prompts = prompt
self.init_image = init_image
self.init_image_strength = init_image_strength
self.prompts.sort(key=lambda p: p.weight, reverse=True)
self.seed = random.randint(1, 1_000_000_000) if seed is None else seed
self.prompt_strength = prompt_strength
@ -214,6 +219,20 @@ class ImaginePrompt:
return "|".join(str(p) for p in self.prompts)
def load_img(path, max_height=512, max_width=512):
image = Image.open(path).convert("RGB")
w, h = image.size
print(f"loaded input image of size ({w}, {h}) from {path}")
resize_ratio = min(max_width / w, max_height / h)
w, h = int(w * resize_ratio), int(h * resize_ratio)
w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64
image = image.resize((w, h), resample=PIL.Image.LANCZOS)
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
return 2.0 * image - 1.0, w, h
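The new `load_img` helper resizes the image so both sides are multiples of 64, converts it to a `(1, 3, h, w)` float tensor, and rescales pixel values from [0, 1] to [-1, 1] for the first-stage encoder. A quick sanity check of that contract (the filename is hypothetical):

```python
image, w, h = load_img("fox_photo.jpg")            # hypothetical input file
assert image.shape == (1, 3, h, w)                 # batch of one, channels first
assert w % 64 == 0 and h % 64 == 0                 # rounded down to multiples of 64
assert image.min() >= -1.0 and image.max() <= 1.0  # rescaled from [0, 1] to [-1, 1]
```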
def imagine(
prompts,
config="data/stable-diffusion-v1.yaml",
@ -254,7 +273,6 @@ def imagine(
for wp in prompt.prompts
]
)
# c = model.get_learned_conditioning(prompt.prompt_text)
shape = [
latent_channels,
@ -263,41 +281,74 @@ def imagine(
]
def img_callback(samples, i):
return
pass
samples = model.decode_first_stage(samples)
samples = torch.clamp((samples + 1.0) / 2.0, min=0.0, max=1.0)
steps_path = os.path.join(
sample_path, "steps", f"{base_count:08}_S{prompt.seed}"
)
os.makedirs(steps_path, exist_ok=True)
for pred_x0 in samples:
pred_x0 = 255.0 * rearrange(pred_x0.cpu().numpy(), "c h w -> h w c")
filename = f"{base_count:08}_S{seed}_step{i:04}.jpg"
filename = f"{base_count:08}_S{prompt.seed}_step{i:04}.jpg"
Image.fromarray(pred_x0.astype(np.uint8)).save(
os.path.join(sample_path, filename)
os.path.join(steps_path, filename)
)
start_code = None
if fixed_code:
start_code = torch.randn(
[1, latent_channels, prompt.height, prompt.width],
device=get_device(),
)
# if fixed_code:
# start_code = torch.randn(
# [1, latent_channels, prompt.height, prompt.width],
# device=get_device(),
# )
sampler = get_sampler(prompt.sampler_type, model)
samples_ddim, _ = sampler.sample(
S=prompt.steps,
conditioning=c,
batch_size=1,
shape=shape,
verbose=False,
unconditional_guidance_scale=prompt.prompt_strength,
unconditional_conditioning=uc,
eta=ddim_eta,
x_T=start_code,
img_callback=img_callback,
)
if prompt.init_image:
generation_strength = 1 - prompt.init_image_strength
ddim_steps = int(prompt.steps / generation_strength)
sampler.make_schedule(
ddim_num_steps=ddim_steps, ddim_eta=ddim_eta, verbose=False
)
t_enc = int(generation_strength * ddim_steps)
init_image, w, h = load_img(prompt.init_image)
init_image = init_image.to(get_device())
init_latent = model.get_first_stage_encoding(
model.encode_first_stage(init_image)
)
# encode (scaled latent)
z_enc = sampler.stochastic_encode(
init_latent, torch.tensor([t_enc]).to(get_device())
)
# decode it
samples = sampler.decode(
z_enc,
c,
t_enc,
unconditional_guidance_scale=prompt.prompt_strength,
unconditional_conditioning=uc,
img_callback=img_callback,
)
else:
samples, _ = sampler.sample(
S=prompt.steps,
conditioning=c,
batch_size=1,
shape=shape,
verbose=False,
unconditional_guidance_scale=prompt.prompt_strength,
unconditional_conditioning=uc,
eta=ddim_eta,
x_T=start_code,
img_callback=img_callback,
)
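The init-image branch stretches the DDIM schedule so that, after skipping the portion of the schedule covered by the init image, roughly the requested number of denoising steps still runs. Tracing the arithmetic with the defaults from this diff:

```python
init_image_strength = 0.3
steps = 50

generation_strength = 1 - init_image_strength  # 0.7
ddim_steps = int(steps / generation_strength)  # 71 steps in the stretched schedule
t_enc = int(generation_strength * ddim_steps)  # 49 denoising steps actually executed
```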
x_samples_ddim = model.decode_first_stage(samples_ddim)
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
x_samples = model.decode_first_stage(samples)
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
if not skip_save:
for x_sample in x_samples_ddim:
for x_sample in x_samples:
x_sample = 255.0 * rearrange(
x_sample.cpu().numpy(), "c h w -> h w c"
)

@ -61,7 +61,6 @@ class VQModel(pl.LightningModule):
self.lr_g_factor = lr_g_factor
class VQModelInterface(VQModel):
def __init__(self, embed_dim, *args, **kwargs):
super().__init__(embed_dim=embed_dim, *args, **kwargs)

@ -218,7 +218,7 @@ class DDIMSampler:
) # TODO: deterministic forward pass?
img = img_orig * mask + (1.0 - mask) * img
outs = self.p_sample_ddim(
img, pred_x0 = self.p_sample_ddim(
img,
cond,
ts,
@ -232,7 +232,6 @@ class DDIMSampler:
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
)
img, pred_x0 = outs
if callback:
callback(i)
if img_callback:
@ -341,6 +340,7 @@ class DDIMSampler:
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
use_original_steps=False,
img_callback=None,
):
timesteps = (
@ -361,7 +361,7 @@ class DDIMSampler:
ts = torch.full(
(x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long
)
x_dec, _ = self.p_sample_ddim(
x_dec, pred_x0 = self.p_sample_ddim(
x_dec,
cond,
ts,
@ -370,4 +370,6 @@ class DDIMSampler:
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
)
if img_callback:
img_callback(pred_x0, i)
return x_dec
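`DDIMSampler.decode` now accepts the same `img_callback` hook as `sample`, receiving the current estimate of the denoised latent and the step index. A minimal callback sketch (not part of the diff):

```python
def log_step(pred_x0, i):
    # pred_x0: the model's current prediction of the fully denoised latent
    # i: index of the (reversed) DDIM step that produced it
    print(f"step {i}: latent range [{pred_x0.min():.3f}, {pred_x0.max():.3f}]")
```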

@ -43,33 +43,33 @@ def uniform_on_device(r1, r2, shape, device):
class DDPM(pl.LightningModule):
# classic DDPM with Gaussian diffusion, in image space
def __init__(
self,
unet_config,
timesteps=1000,
beta_schedule="linear",
loss_type="l2",
ckpt_path=None,
ignore_keys=[],
load_only_unet=False,
monitor="val/loss",
first_stage_key="image",
image_size=256,
channels=3,
log_every_t=100,
clip_denoised=True,
linear_start=1e-4,
linear_end=2e-2,
cosine_s=8e-3,
given_betas=None,
original_elbo_weight=0.0,
v_posterior=0.0, # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
l_simple_weight=1.0,
conditioning_key=None,
parameterization="eps", # all assuming fixed variance schedules
scheduler_config=None,
use_positional_encodings=False,
learn_logvar=False,
logvar_init=0.0,
self,
unet_config,
timesteps=1000,
beta_schedule="linear",
loss_type="l2",
ckpt_path=None,
ignore_keys=[],
load_only_unet=False,
monitor="val/loss",
first_stage_key="image",
image_size=256,
channels=3,
log_every_t=100,
clip_denoised=True,
linear_start=1e-4,
linear_end=2e-2,
cosine_s=8e-3,
given_betas=None,
original_elbo_weight=0.0,
v_posterior=0.0, # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
l_simple_weight=1.0,
conditioning_key=None,
parameterization="eps", # all assuming fixed variance schedules
scheduler_config=None,
use_positional_encodings=False,
learn_logvar=False,
logvar_init=0.0,
):
super().__init__()
assert parameterization in [
@ -122,13 +122,13 @@ class DDPM(pl.LightningModule):
self.logvar = nn.Parameter(self.logvar, requires_grad=True)
def register_schedule(
self,
given_betas=None,
beta_schedule="linear",
timesteps=1000,
linear_start=1e-4,
linear_end=2e-2,
cosine_s=8e-3,
self,
given_betas=None,
beta_schedule="linear",
timesteps=1000,
linear_start=1e-4,
linear_end=2e-2,
cosine_s=8e-3,
):
if given_betas is not None:
betas = given_betas
@ -149,7 +149,7 @@ class DDPM(pl.LightningModule):
self.linear_start = linear_start
self.linear_end = linear_end
assert (
alphas_cumprod.shape[0] == self.num_timesteps
alphas_cumprod.shape[0] == self.num_timesteps
), "alphas have to be defined for each timestep"
to_torch = partial(torch.tensor, dtype=torch.float32)
@ -175,7 +175,7 @@ class DDPM(pl.LightningModule):
# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = (1 - self.v_posterior) * betas * (
1.0 - alphas_cumprod_prev
1.0 - alphas_cumprod_prev
) / (1.0 - alphas_cumprod) + self.v_posterior * betas
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
self.register_buffer("posterior_variance", to_torch(posterior_variance))
@ -196,17 +196,17 @@ class DDPM(pl.LightningModule):
)
if self.parameterization == "eps":
lvlb_weights = self.betas ** 2 / (
2
* self.posterior_variance
* to_torch(alphas)
* (1 - self.alphas_cumprod)
lvlb_weights = self.betas**2 / (
2
* self.posterior_variance
* to_torch(alphas)
* (1 - self.alphas_cumprod)
)
elif self.parameterization == "x0":
lvlb_weights = (
0.5
* np.sqrt(torch.Tensor(alphas_cumprod))
/ (2.0 * 1 - torch.Tensor(alphas_cumprod))
0.5
* np.sqrt(torch.Tensor(alphas_cumprod))
/ (2.0 * 1 - torch.Tensor(alphas_cumprod))
)
else:
raise NotImplementedError("mu not supported")
@ -216,26 +216,27 @@ class DDPM(pl.LightningModule):
assert not torch.isnan(self.lvlb_weights).all()
class LatentDiffusion(DDPM):
"""main class"""
def __init__(
self,
first_stage_config,
cond_stage_config,
num_timesteps_cond=None,
cond_stage_key="image",
cond_stage_trainable=False,
concat_mode=True,
cond_stage_forward=None,
conditioning_key=None,
scale_factor=1.0,
scale_by_std=False,
*args,
**kwargs,
self,
first_stage_config,
cond_stage_config,
num_timesteps_cond=None,
cond_stage_key="image",
cond_stage_trainable=False,
concat_mode=True,
cond_stage_forward=None,
conditioning_key=None,
scale_factor=1.0,
scale_by_std=False,
*args,
**kwargs,
):
self.num_timesteps_cond = 1 if num_timesteps_cond is None else num_timesteps_cond
self.num_timesteps_cond = (
1 if num_timesteps_cond is None else num_timesteps_cond
)
self.scale_by_std = scale_by_std
assert self.num_timesteps_cond <= kwargs["timesteps"]
# for backwards compatibility after implementation of DiffusionWrapper
@ -269,7 +270,7 @@ class LatentDiffusion(DDPM):
self.restarted_from_ckpt = True
def make_cond_schedule(
self,
self,
):
self.cond_ids = torch.full(
size=(self.num_timesteps,),
@ -282,13 +283,13 @@ class LatentDiffusion(DDPM):
self.cond_ids[: self.num_timesteps_cond] = ids
def register_schedule(
self,
given_betas=None,
beta_schedule="linear",
timesteps=1000,
linear_start=1e-4,
linear_end=2e-2,
cosine_s=8e-3,
self,
given_betas=None,
beta_schedule="linear",
timesteps=1000,
linear_start=1e-4,
linear_end=2e-2,
cosine_s=8e-3,
):
super().register_schedule(
given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s
@ -327,7 +328,7 @@ class LatentDiffusion(DDPM):
self.cond_stage_model = model
def _get_denoise_row_from_list(
self, samples, desc="", force_no_decoder_quantization=False
self, samples, desc="", force_no_decoder_quantization=False
):
denoise_row = []
for zd in tqdm(samples, desc=desc):
@ -357,7 +358,7 @@ class LatentDiffusion(DDPM):
def get_learned_conditioning(self, c):
if self.cond_stage_forward is None:
if hasattr(self.cond_stage_model, "encode") and callable(
self.cond_stage_model.encode
self.cond_stage_model.encode
):
c = self.cond_stage_model.encode(c)
if isinstance(c, DiagonalGaussianDistribution):
@ -414,7 +415,7 @@ class LatentDiffusion(DDPM):
return weighting
def get_fold_unfold(
self, x, kernel_size, stride, uf=1, df=1
self, x, kernel_size, stride, uf=1, df=1
): # todo load once not every time, shorten code
"""
:param x: img of size (bs, c, h, w)
@ -499,14 +500,14 @@ class LatentDiffusion(DDPM):
@torch.no_grad()
def get_input(
self,
batch,
k,
return_first_stage_outputs=False,
force_c_encode=False,
cond_key=None,
return_original_cond=False,
bs=None,
self,
batch,
k,
return_first_stage_outputs=False,
force_c_encode=False,
cond_key=None,
return_original_cond=False,
bs=None,
):
x = super().get_input(batch, k)
if bs is not None:
@ -631,6 +632,52 @@ class LatentDiffusion(DDPM):
else:
return self.first_stage_model.decode(z)
@torch.no_grad()
def encode_first_stage(self, x):
if hasattr(self, "split_input_params"):
if self.split_input_params["patch_distributed_vq"]:
ks = self.split_input_params["ks"] # eg. (128, 128)
stride = self.split_input_params["stride"] # eg. (64, 64)
df = self.split_input_params["vqf"]
self.split_input_params["original_image_size"] = x.shape[-2:]
bs, nc, h, w = x.shape
if ks[0] > h or ks[1] > w:
ks = (min(ks[0], h), min(ks[1], w))
print("reducing Kernel")
if stride[0] > h or stride[1] > w:
stride = (min(stride[0], h), min(stride[1], w))
print("reducing stride")
fold, unfold, normalization, weighting = self.get_fold_unfold(
x, ks, stride, df=df
)
z = unfold(x) # (bn, nc * prod(**ks), L)
# Reshape to img shape
z = z.view(
(z.shape[0], -1, ks[0], ks[1], z.shape[-1])
) # (bn, nc, ks[0], ks[1], L )
output_list = [
self.first_stage_model.encode(z[:, :, :, :, i])
for i in range(z.shape[-1])
]
o = torch.stack(output_list, axis=-1)
o = o * weighting
# Reverse reshape to img shape
o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
# stitch crops together
decoded = fold(o)
decoded = decoded / normalization
return decoded
else:
return self.first_stage_model.encode(x)
else:
return self.first_stage_model.encode(x)
def apply_model(self, x_noisy, t, cond, return_ids=False):
if isinstance(cond, dict):
@ -664,8 +711,8 @@ class LatentDiffusion(DDPM):
z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])]
if (
self.cond_stage_key in ["image", "LR_image", "segmentation", "bbox_img"]
and self.model.conditioning_key
self.cond_stage_key in ["image", "LR_image", "segmentation", "bbox_img"]
and self.model.conditioning_key
): # todo check for completeness
c_key = next(iter(cond.keys())) # get key
c = next(iter(cond.values())) # get value
@ -681,7 +728,7 @@ class LatentDiffusion(DDPM):
elif self.cond_stage_key == "coordinates_bbox":
assert (
"original_image_size" in self.split_input_params
"original_image_size" in self.split_input_params
), "BoudingBoxRescaling is missing original_image_size"
# assuming padding of unfold is always 0 and its dilation is always 1
@ -776,16 +823,16 @@ class LatentDiffusion(DDPM):
return x_recon
def p_mean_variance(
self,
x,
c,
t,
clip_denoised: bool,
return_codebook_ids=False,
quantize_denoised=False,
return_x0=False,
score_corrector=None,
corrector_kwargs=None,
self,
x,
c,
t,
clip_denoised: bool,
return_codebook_ids=False,
quantize_denoised=False,
return_x0=False,
score_corrector=None,
corrector_kwargs=None,
):
t_in = t
model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)
@ -822,19 +869,19 @@ class LatentDiffusion(DDPM):
@torch.no_grad()
def p_sample(
self,
x,
c,
t,
clip_denoised=False,
repeat_noise=False,
return_codebook_ids=False,
quantize_denoised=False,
return_x0=False,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
self,
x,
c,
t,
clip_denoised=False,
repeat_noise=False,
return_codebook_ids=False,
quantize_denoised=False,
return_x0=False,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
):
b, *_, device = *x.shape, x.device
outputs = self.p_mean_variance(
@ -864,7 +911,7 @@ class LatentDiffusion(DDPM):
if return_codebook_ids:
return model_mean + nonzero_mask * (
0.5 * model_log_variance
0.5 * model_log_variance
).exp() * noise, logits.argmax(dim=1)
if return_x0:
return (

@ -105,13 +105,13 @@ class FrozenClipImageEmbedder(nn.Module):
def __init__(
self,
model,
model_name,
jit=False,
device=get_device(),
antialias=False,
):
super().__init__()
self.model, _ = clip.load(name=model, device=device, jit=jit)
self.model, preprocess = clip.load(name=model_name, device=device, jit=jit)
self.antialias = antialias

@ -80,13 +80,13 @@ class Downsample(nn.Module):
class ResnetBlock(nn.Module):
def __init__(
self,
*,
in_channels,
out_channels=None,
conv_shortcut=False,
dropout,
temb_channels=512,
self,
*,
in_channels,
out_channels=None,
conv_shortcut=False,
dropout,
temb_channels=512,
):
super().__init__()
self.in_channels = in_channels
@ -204,22 +204,22 @@ def make_attn(in_channels, attn_type="vanilla"):
class Encoder(nn.Module):
def __init__(
self,
*,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
in_channels,
resolution,
z_channels,
double_z=True,
use_linear_attn=False,
attn_type="vanilla",
**ignore_kwargs,
self,
*,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
in_channels,
resolution,
z_channels,
double_z=True,
use_linear_attn=False,
attn_type="vanilla",
**ignore_kwargs,
):
super().__init__()
if use_linear_attn:
@ -321,23 +321,23 @@ class Encoder(nn.Module):
class Decoder(nn.Module):
def __init__(
self,
*,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
in_channels,
resolution,
z_channels,
give_pre_end=False,
tanh_out=False,
use_linear_attn=False,
attn_type="vanilla",
**ignorekwargs,
self,
*,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
in_channels,
resolution,
z_channels,
give_pre_end=False,
tanh_out=False,
use_linear_attn=False,
attn_type="vanilla",
**ignorekwargs,
):
super().__init__()
if use_linear_attn:
@ -567,24 +567,24 @@ class Resize(nn.Module):
class FirstStagePostProcessor(nn.Module):
def __init__(
self,
ch_mult: list,
in_channels,
pretrained_model: nn.Module = None,
reshape=False,
n_channels=None,
dropout=0.0,
pretrained_config=None,
self,
ch_mult: list,
in_channels,
pretrained_model: nn.Module = None,
reshape=False,
n_channels=None,
dropout=0.0,
pretrained_config=None,
):
super().__init__()
if pretrained_config is None:
assert (
pretrained_model is not None
pretrained_model is not None
), 'Either "pretrained_model" or "pretrained_config" must not be None'
self.pretrained_model = pretrained_model
else:
assert (
pretrained_config is not None
pretrained_config is not None
), 'Either "pretrained_model" or "pretrained_config" must not be None'
self.instantiate_pretrained(pretrained_config)

@ -9,9 +9,10 @@
import math
import numpy as np
import torch
import torch.nn as nn
import numpy as np
from einops import repeat
from imaginairy.utils import instantiate_from_config
@ -52,12 +53,23 @@ def make_beta_schedule(
return betas.numpy()
def frange(start, stop, step):
"""range but handles floats"""
x = start
while True:
if x >= stop:
return
yield x
x += step
def make_ddim_timesteps(
ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True
):
if ddim_discr_method == "uniform":
c = num_ddpm_timesteps // num_ddim_timesteps
ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
c = num_ddpm_timesteps / num_ddim_timesteps
ddim_timesteps = [int(i) for i in frange(0, num_ddpm_timesteps - 1, c)]
ddim_timesteps = np.asarray(ddim_timesteps)
elif ddim_discr_method == "quad":
ddim_timesteps = (
(np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2
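The uniform branch now builds the timestep list with a float stride via `frange` instead of integer division, so the sampler gets exactly `num_ddim_timesteps` entries even when the requested count does not divide the DDPM schedule evenly (as happens with the stretched img2img schedule above). A quick illustration, assuming the `frange` helper from this hunk is in scope:

```python
import numpy as np

num_ddpm_timesteps, num_ddim_timesteps = 1000, 71  # e.g. the stretched img2img schedule

# old behaviour: integer stride rounds down, yielding 72 timesteps instead of 71
c_old = num_ddpm_timesteps // num_ddim_timesteps   # 14
old = np.asarray(list(range(0, num_ddpm_timesteps, c_old)))
print(len(old))  # 72

# new behaviour: float stride via frange yields exactly 71 timesteps
c_new = num_ddpm_timesteps / num_ddim_timesteps    # ~14.08
new = np.asarray([int(i) for i in frange(0, num_ddpm_timesteps - 1, c_new)])
print(len(new))  # 71
```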
