feature: image prompts

pull/1/head
Bryce 2 years ago
parent 0835b2db16
commit f782fac570

@@ -9,4 +9,11 @@ AI imagined images.
 - LDM - Latent Diffusion
 - Stable Diffusion
 -
+
+# Todo
+- add tests
+- add docs
+- remove yaml config
+- deploy to pypi
+- add image describe feature

@@ -5,6 +5,7 @@ import re
 import subprocess
 from contextlib import nullcontext

+import PIL
 import numpy as np
 import torch
 from PIL import Image
@@ -184,6 +185,8 @@ class ImaginePrompt:
         seed=None,
         prompt_strength=7.5,
         sampler_type="PLMS",
+        init_image=None,
+        init_image_strength=0.3,
         steps=50,
         height=512,
         width=512,
@@ -196,6 +199,8 @@ class ImaginePrompt:
             self.prompts = [WeightedPrompt(prompt, 1)]
         else:
             self.prompts = prompt
+        self.init_image = init_image
+        self.init_image_strength = init_image_strength
         self.prompts.sort(key=lambda p: p.weight, reverse=True)
         self.seed = random.randint(1, 1_000_000_000) if seed is None else seed
         self.prompt_strength = prompt_strength
@@ -214,6 +219,20 @@ class ImaginePrompt:
         return "|".join(str(p) for p in self.prompts)


+def load_img(path, max_height=512, max_width=512):
+    image = Image.open(path).convert("RGB")
+    w, h = image.size
+    print(f"loaded input image of size ({w}, {h}) from {path}")
+    resize_ratio = min(max_width / w, max_height / h)
+    w, h = int(w * resize_ratio), int(h * resize_ratio)
+    w, h = map(lambda x: x - x % 64, (w, h))  # resize to integer multiple of 64
+    image = image.resize((w, h), resample=PIL.Image.LANCZOS)
+    image = np.array(image).astype(np.float32) / 255.0
+    image = image[None].transpose(0, 3, 1, 2)
+    image = torch.from_numpy(image)
+    return 2.0 * image - 1.0, w, h
+
+
 def imagine(
     prompts,
     config="data/stable-diffusion-v1.yaml",
@@ -254,7 +273,6 @@ def imagine(
                     for wp in prompt.prompts
                 ]
             )
-            # c = model.get_learned_conditioning(prompt.prompt_text)

             shape = [
                 latent_channels,
@@ -263,41 +281,74 @@ def imagine(
             ]

             def img_callback(samples, i):
-                return
+                pass
                 samples = model.decode_first_stage(samples)
                 samples = torch.clamp((samples + 1.0) / 2.0, min=0.0, max=1.0)
+                steps_path = os.path.join(
+                    sample_path, "steps", f"{base_count:08}_S{prompt.seed}"
+                )
+                os.makedirs(steps_path, exist_ok=True)
                 for pred_x0 in samples:
                     pred_x0 = 255.0 * rearrange(pred_x0.cpu().numpy(), "c h w -> h w c")
-                    filename = f"{base_count:08}_S{seed}_step{i:04}.jpg"
+                    filename = f"{base_count:08}_S{prompt.seed}_step{i:04}.jpg"
                     Image.fromarray(pred_x0.astype(np.uint8)).save(
-                        os.path.join(sample_path, filename)
+                        os.path.join(steps_path, filename)
                     )

             start_code = None
-            if fixed_code:
-                start_code = torch.randn(
-                    [1, latent_channels, prompt.height, prompt.width],
-                    device=get_device(),
-                )
+            # if fixed_code:
+            #     start_code = torch.randn(
+            #         [1, latent_channels, prompt.height, prompt.width],
+            #         device=get_device(),
+            #     )
             sampler = get_sampler(prompt.sampler_type, model)
-            samples_ddim, _ = sampler.sample(
-                S=prompt.steps,
-                conditioning=c,
-                batch_size=1,
-                shape=shape,
-                verbose=False,
-                unconditional_guidance_scale=prompt.prompt_strength,
-                unconditional_conditioning=uc,
-                eta=ddim_eta,
-                x_T=start_code,
-                img_callback=img_callback,
-            )
+            if prompt.init_image:
+                generation_strength = 1 - prompt.init_image_strength
+                ddim_steps = int(prompt.steps / generation_strength)
+                sampler.make_schedule(
+                    ddim_num_steps=ddim_steps, ddim_eta=ddim_eta, verbose=False
+                )
+
+                t_enc = int(generation_strength * ddim_steps)
+                init_image, w, h = load_img(prompt.init_image)
+                init_image = init_image.to(get_device())
+                init_latent = model.get_first_stage_encoding(
+                    model.encode_first_stage(init_image)
+                )
+
+                # encode (scaled latent)
+                z_enc = sampler.stochastic_encode(
+                    init_latent, torch.tensor([t_enc]).to(get_device())
+                )
+                # decode it
+                samples = sampler.decode(
+                    z_enc,
+                    c,
+                    t_enc,
+                    unconditional_guidance_scale=prompt.prompt_strength,
+                    unconditional_conditioning=uc,
+                    img_callback=img_callback,
+                )
+            else:
+                samples, _ = sampler.sample(
+                    S=prompt.steps,
+                    conditioning=c,
+                    batch_size=1,
+                    shape=shape,
+                    verbose=False,
+                    unconditional_guidance_scale=prompt.prompt_strength,
+                    unconditional_conditioning=uc,
+                    eta=ddim_eta,
+                    x_T=start_code,
+                    img_callback=img_callback,
+                )

-            x_samples_ddim = model.decode_first_stage(samples_ddim)
-            x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
+            x_samples = model.decode_first_stage(samples)
+            x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)

             if not skip_save:
-                for x_sample in x_samples_ddim:
+                for x_sample in x_samples:
                     x_sample = 255.0 * rearrange(
                         x_sample.cpu().numpy(), "c h w -> h w c"
                     )

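Reviewer note: a minimal usage sketch of the new image-prompt path follows. It assumes the ImaginePrompt fields and imagine() entry point added above; the import location and the image path are placeholders, not part of this commit. With the default init_image_strength of 0.3, generation_strength is 0.7, so 50 requested steps become int(50 / 0.7) = 71 schedule steps and the init latent is noised to t_enc = int(0.7 * 71) = 49 before decoding.

from imaginairy import ImaginePrompt, imagine  # assumed import path

prompt = ImaginePrompt(
    "a bowl of fruit, oil painting",
    init_image="my_photo.jpg",       # placeholder path to the starting image
    init_image_strength=0.3,         # larger values stay closer to the init image
    steps=50,
    seed=1,
)
imagine([prompt])  # images are written under the sample path per the skip_save logic above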
@@ -61,7 +61,6 @@ class VQModel(pl.LightningModule):
         self.lr_g_factor = lr_g_factor


 class VQModelInterface(VQModel):
     def __init__(self, embed_dim, *args, **kwargs):
         super().__init__(embed_dim=embed_dim, *args, **kwargs)

@@ -218,7 +218,7 @@ class DDIMSampler:
                 )  # TODO: deterministic forward pass?
                 img = img_orig * mask + (1.0 - mask) * img

-            outs = self.p_sample_ddim(
+            img, pred_x0 = self.p_sample_ddim(
                 img,
                 cond,
                 ts,
@@ -232,7 +232,6 @@ class DDIMSampler:
                 unconditional_guidance_scale=unconditional_guidance_scale,
                 unconditional_conditioning=unconditional_conditioning,
             )
-            img, pred_x0 = outs
             if callback:
                 callback(i)
             if img_callback:
@@ -341,6 +340,7 @@ class DDIMSampler:
         unconditional_guidance_scale=1.0,
         unconditional_conditioning=None,
         use_original_steps=False,
+        img_callback=None,
     ):

         timesteps = (
@@ -361,7 +361,7 @@ class DDIMSampler:
             ts = torch.full(
                 (x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long
             )
-            x_dec, _ = self.p_sample_ddim(
+            x_dec, pred_x0 = self.p_sample_ddim(
                 x_dec,
                 cond,
                 ts,
@@ -370,4 +370,6 @@ class DDIMSampler:
                 unconditional_guidance_scale=unconditional_guidance_scale,
                 unconditional_conditioning=unconditional_conditioning,
             )
+            if img_callback:
+                img_callback(pred_x0, i)
         return x_dec

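With img_callback now threaded through DDIMSampler.decode, the per-step hook used for text-to-image sampling also fires during img2img. A minimal sketch of a compatible callback, assuming only the call signature shown above (img_callback(pred_x0, i) once per DDIM step); the logging is illustrative only:

def log_step(pred_x0, i):
    # pred_x0 is the latent-space prediction of the final image at step i;
    # it would need decode_first_stage() to become pixels.
    if i % 10 == 0:
        print(f"step {i}: latent shape {tuple(pred_x0.shape)}")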
@@ -43,33 +43,33 @@ def uniform_on_device(r1, r2, shape, device):

 class DDPM(pl.LightningModule):
     # classic DDPM with Gaussian diffusion, in image space
     def __init__(
         self,
         unet_config,
         timesteps=1000,
         beta_schedule="linear",
         loss_type="l2",
         ckpt_path=None,
         ignore_keys=[],
         load_only_unet=False,
         monitor="val/loss",
         first_stage_key="image",
         image_size=256,
         channels=3,
         log_every_t=100,
         clip_denoised=True,
         linear_start=1e-4,
         linear_end=2e-2,
         cosine_s=8e-3,
         given_betas=None,
         original_elbo_weight=0.0,
         v_posterior=0.0,  # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
         l_simple_weight=1.0,
         conditioning_key=None,
         parameterization="eps",  # all assuming fixed variance schedules
         scheduler_config=None,
         use_positional_encodings=False,
         learn_logvar=False,
         logvar_init=0.0,
     ):
         super().__init__()
         assert parameterization in [
@@ -122,13 +122,13 @@ class DDPM(pl.LightningModule):
         self.logvar = nn.Parameter(self.logvar, requires_grad=True)

     def register_schedule(
         self,
         given_betas=None,
         beta_schedule="linear",
         timesteps=1000,
         linear_start=1e-4,
         linear_end=2e-2,
         cosine_s=8e-3,
     ):
         if given_betas is not None:
             betas = given_betas
@@ -149,7 +149,7 @@ class DDPM(pl.LightningModule):
         self.linear_start = linear_start
         self.linear_end = linear_end
         assert (
             alphas_cumprod.shape[0] == self.num_timesteps
         ), "alphas have to be defined for each timestep"

         to_torch = partial(torch.tensor, dtype=torch.float32)
@@ -175,7 +175,7 @@ class DDPM(pl.LightningModule):
         # calculations for posterior q(x_{t-1} | x_t, x_0)
         posterior_variance = (1 - self.v_posterior) * betas * (
             1.0 - alphas_cumprod_prev
         ) / (1.0 - alphas_cumprod) + self.v_posterior * betas
         # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
         self.register_buffer("posterior_variance", to_torch(posterior_variance))
@@ -196,17 +196,17 @@ class DDPM(pl.LightningModule):
         )
         if self.parameterization == "eps":
-            lvlb_weights = self.betas ** 2 / (
+            lvlb_weights = self.betas**2 / (
                 2
                 * self.posterior_variance
                 * to_torch(alphas)
                 * (1 - self.alphas_cumprod)
             )
         elif self.parameterization == "x0":
             lvlb_weights = (
                 0.5
                 * np.sqrt(torch.Tensor(alphas_cumprod))
                 / (2.0 * 1 - torch.Tensor(alphas_cumprod))
             )
         else:
             raise NotImplementedError("mu not supported")
@@ -216,26 +216,27 @@ class DDPM(pl.LightningModule):
         assert not torch.isnan(self.lvlb_weights).all()


 class LatentDiffusion(DDPM):
     """main class"""

     def __init__(
         self,
         first_stage_config,
         cond_stage_config,
         num_timesteps_cond=None,
         cond_stage_key="image",
         cond_stage_trainable=False,
         concat_mode=True,
         cond_stage_forward=None,
         conditioning_key=None,
         scale_factor=1.0,
         scale_by_std=False,
         *args,
         **kwargs,
     ):
-        self.num_timesteps_cond = 1 if num_timesteps_cond is None else num_timesteps_cond
+        self.num_timesteps_cond = (
+            1 if num_timesteps_cond is None else num_timesteps_cond
+        )
         self.scale_by_std = scale_by_std
         assert self.num_timesteps_cond <= kwargs["timesteps"]
         # for backwards compatibility after implementation of DiffusionWrapper
@@ -269,7 +270,7 @@ class LatentDiffusion(DDPM):
         self.restarted_from_ckpt = True

     def make_cond_schedule(
         self,
     ):
         self.cond_ids = torch.full(
             size=(self.num_timesteps,),
@@ -282,13 +283,13 @@ class LatentDiffusion(DDPM):
         self.cond_ids[: self.num_timesteps_cond] = ids

     def register_schedule(
         self,
         given_betas=None,
         beta_schedule="linear",
         timesteps=1000,
         linear_start=1e-4,
         linear_end=2e-2,
         cosine_s=8e-3,
     ):
         super().register_schedule(
             given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s
@@ -327,7 +328,7 @@ class LatentDiffusion(DDPM):
         self.cond_stage_model = model

     def _get_denoise_row_from_list(
         self, samples, desc="", force_no_decoder_quantization=False
     ):
         denoise_row = []
         for zd in tqdm(samples, desc=desc):
@@ -357,7 +358,7 @@ class LatentDiffusion(DDPM):
     def get_learned_conditioning(self, c):
         if self.cond_stage_forward is None:
             if hasattr(self.cond_stage_model, "encode") and callable(
                 self.cond_stage_model.encode
             ):
                 c = self.cond_stage_model.encode(c)
                 if isinstance(c, DiagonalGaussianDistribution):
@@ -414,7 +415,7 @@ class LatentDiffusion(DDPM):
         return weighting

     def get_fold_unfold(
         self, x, kernel_size, stride, uf=1, df=1
     ):  # todo load once not every time, shorten code
         """
         :param x: img of size (bs, c, h, w)
@@ -499,14 +500,14 @@ class LatentDiffusion(DDPM):
     @torch.no_grad()
     def get_input(
         self,
         batch,
         k,
         return_first_stage_outputs=False,
         force_c_encode=False,
         cond_key=None,
         return_original_cond=False,
         bs=None,
     ):
         x = super().get_input(batch, k)
         if bs is not None:
@@ -631,6 +632,52 @@ class LatentDiffusion(DDPM):
         else:
             return self.first_stage_model.decode(z)

+    @torch.no_grad()
+    def encode_first_stage(self, x):
+        if hasattr(self, "split_input_params"):
+            if self.split_input_params["patch_distributed_vq"]:
+                ks = self.split_input_params["ks"]  # eg. (128, 128)
+                stride = self.split_input_params["stride"]  # eg. (64, 64)
+                df = self.split_input_params["vqf"]
+                self.split_input_params["original_image_size"] = x.shape[-2:]
+                bs, nc, h, w = x.shape
+                if ks[0] > h or ks[1] > w:
+                    ks = (min(ks[0], h), min(ks[1], w))
+                    print("reducing Kernel")
+
+                if stride[0] > h or stride[1] > w:
+                    stride = (min(stride[0], h), min(stride[1], w))
+                    print("reducing stride")
+
+                fold, unfold, normalization, weighting = self.get_fold_unfold(
+                    x, ks, stride, df=df
+                )
+                z = unfold(x)  # (bn, nc * prod(**ks), L)
+                # Reshape to img shape
+                z = z.view(
+                    (z.shape[0], -1, ks[0], ks[1], z.shape[-1])
+                )  # (bn, nc, ks[0], ks[1], L )
+
+                output_list = [
+                    self.first_stage_model.encode(z[:, :, :, :, i])
+                    for i in range(z.shape[-1])
+                ]
+
+                o = torch.stack(output_list, axis=-1)
+                o = o * weighting
+
+                # Reverse reshape to img shape
+                o = o.view((o.shape[0], -1, o.shape[-1]))  # (bn, nc * ks[0] * ks[1], L)
+                # stitch crops together
+                decoded = fold(o)
+                decoded = decoded / normalization
+                return decoded
+            else:
+                return self.first_stage_model.encode(x)
+        else:
+            return self.first_stage_model.encode(x)
+
     def apply_model(self, x_noisy, t, cond, return_ids=False):
         if isinstance(cond, dict):
@@ -664,8 +711,8 @@ class LatentDiffusion(DDPM):
             z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])]

         if (
             self.cond_stage_key in ["image", "LR_image", "segmentation", "bbox_img"]
             and self.model.conditioning_key
         ):  # todo check for completeness
             c_key = next(iter(cond.keys()))  # get key
             c = next(iter(cond.values()))  # get value
@@ -681,7 +728,7 @@ class LatentDiffusion(DDPM):
         elif self.cond_stage_key == "coordinates_bbox":
             assert (
                 "original_image_size" in self.split_input_params
             ), "BoudingBoxRescaling is missing original_image_size"

             # assuming padding of unfold is always 0 and its dilation is always 1
@@ -776,16 +823,16 @@ class LatentDiffusion(DDPM):
         return x_recon

     def p_mean_variance(
         self,
         x,
         c,
         t,
         clip_denoised: bool,
         return_codebook_ids=False,
         quantize_denoised=False,
         return_x0=False,
         score_corrector=None,
         corrector_kwargs=None,
     ):
         t_in = t
         model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)
@@ -822,19 +869,19 @@ class LatentDiffusion(DDPM):
     @torch.no_grad()
     def p_sample(
         self,
         x,
         c,
         t,
         clip_denoised=False,
         repeat_noise=False,
         return_codebook_ids=False,
         quantize_denoised=False,
         return_x0=False,
         temperature=1.0,
         noise_dropout=0.0,
         score_corrector=None,
         corrector_kwargs=None,
     ):
         b, *_, device = *x.shape, x.device
         outputs = self.p_mean_variance(
@@ -864,7 +911,7 @@ class LatentDiffusion(DDPM):
         if return_codebook_ids:
             return model_mean + nonzero_mask * (
                 0.5 * model_log_variance
             ).exp() * noise, logits.argmax(dim=1)
         if return_x0:
             return (

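The new encode_first_stage mirrors decode_first_stage, including the patched split_input_params branch; the img2img path above only exercises the plain branch. A rough round-trip sketch for orientation, assuming a loaded LatentDiffusion instance named model and the stock v1 autoencoder (4 latent channels, 8x spatial downsampling); these assumptions are not part of the diff:

import torch

init_image = torch.zeros(1, 3, 512, 512)                  # stand-in for a real image scaled to [-1, 1]
posterior = model.encode_first_stage(init_image)          # DiagonalGaussianDistribution from the autoencoder
init_latent = model.get_first_stage_encoding(posterior)   # scaled latent, roughly (1, 4, 64, 64)
reconstruction = model.decode_first_stage(init_latent)    # back to (1, 3, 512, 512) pixel space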
@@ -105,13 +105,13 @@ class FrozenClipImageEmbedder(nn.Module):
     def __init__(
         self,
-        model,
+        model_name,
         jit=False,
         device=get_device(),
         antialias=False,
     ):
         super().__init__()
-        self.model, _ = clip.load(name=model, device=device, jit=jit)
+        self.model, preprocess = clip.load(name=model_name, device=device, jit=jit)

         self.antialias = antialias

@@ -80,13 +80,13 @@ class Downsample(nn.Module):

 class ResnetBlock(nn.Module):
     def __init__(
         self,
         *,
         in_channels,
         out_channels=None,
         conv_shortcut=False,
         dropout,
         temb_channels=512,
     ):
         super().__init__()
         self.in_channels = in_channels
@@ -204,22 +204,22 @@ def make_attn(in_channels, attn_type="vanilla"):

 class Encoder(nn.Module):
     def __init__(
         self,
         *,
         ch,
         out_ch,
         ch_mult=(1, 2, 4, 8),
         num_res_blocks,
         attn_resolutions,
         dropout=0.0,
         resamp_with_conv=True,
         in_channels,
         resolution,
         z_channels,
         double_z=True,
         use_linear_attn=False,
         attn_type="vanilla",
         **ignore_kwargs,
     ):
         super().__init__()
         if use_linear_attn:
@@ -321,23 +321,23 @@ class Encoder(nn.Module):

 class Decoder(nn.Module):
     def __init__(
         self,
         *,
         ch,
         out_ch,
         ch_mult=(1, 2, 4, 8),
         num_res_blocks,
         attn_resolutions,
         dropout=0.0,
         resamp_with_conv=True,
         in_channels,
         resolution,
         z_channels,
         give_pre_end=False,
         tanh_out=False,
         use_linear_attn=False,
         attn_type="vanilla",
         **ignorekwargs,
     ):
         super().__init__()
         if use_linear_attn:
@@ -567,24 +567,24 @@ class Resize(nn.Module):

 class FirstStagePostProcessor(nn.Module):
     def __init__(
         self,
         ch_mult: list,
         in_channels,
         pretrained_model: nn.Module = None,
         reshape=False,
         n_channels=None,
         dropout=0.0,
         pretrained_config=None,
     ):
         super().__init__()
         if pretrained_config is None:
             assert (
                 pretrained_model is not None
             ), 'Either "pretrained_model" or "pretrained_config" must not be None'
             self.pretrained_model = pretrained_model
         else:
             assert (
                 pretrained_config is not None
             ), 'Either "pretrained_model" or "pretrained_config" must not be None'
             self.instantiate_pretrained(pretrained_config)

@@ -9,9 +9,10 @@
 import math

+import numpy as np
 import torch
 import torch.nn as nn
-import numpy as np
 from einops import repeat

 from imaginairy.utils import instantiate_from_config
@@ -52,12 +53,23 @@ def make_beta_schedule(
     return betas.numpy()


+def frange(start, stop, step):
+    """range but handles floats"""
+    x = start
+    while True:
+        if x >= stop:
+            return
+        yield x
+        x += step
+
+
 def make_ddim_timesteps(
     ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True
 ):
     if ddim_discr_method == "uniform":
-        c = num_ddpm_timesteps // num_ddim_timesteps
-        ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
+        c = num_ddpm_timesteps / num_ddim_timesteps
+        ddim_timesteps = [int(i) for i in frange(0, num_ddpm_timesteps - 1, c)]
+        ddim_timesteps = np.asarray(ddim_timesteps)
     elif ddim_discr_method == "quad":
         ddim_timesteps = (
             (np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2

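The uniform DDIM discretization now uses a float stride. Previously c = 1000 // 30 = 33 and range(0, 1000, 33) yielded 31 timesteps when 30 were requested; with frange the stride is 33.33 and exactly 30 come out. A small self-contained check (frange duplicated from the diff for illustration):

def frange(start, stop, step):
    """range but handles floats"""
    x = start
    while True:
        if x >= stop:
            return
        yield x
        x += step

old_steps = list(range(0, 1000, 1000 // 30))                  # 31 values: 0, 33, ..., 990
new_steps = [int(i) for i in frange(0, 1000 - 1, 1000 / 30)]  # 30 values: 0, 33, ..., 966
assert len(old_steps) == 31 and len(new_steps) == 30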