feature: image prompts

pull/1/head
Bryce 2 years ago
parent 0835b2db16
commit f782fac570

@ -9,4 +9,11 @@ AI imagined images.
- LDM - Latent Diffusion
- Stable Diffusion
-
# Todo
- add tests
- add docs
- remove yaml config
- deploy to pypi
- add image describe feature

@ -5,6 +5,7 @@ import re
import subprocess
from contextlib import nullcontext
import PIL
import numpy as np
import torch
from PIL import Image
@ -184,6 +185,8 @@ class ImaginePrompt:
seed=None,
prompt_strength=7.5,
sampler_type="PLMS",
init_image=None,
init_image_strength=0.3,
steps=50,
height=512,
width=512,
@ -196,6 +199,8 @@ class ImaginePrompt:
self.prompts = [WeightedPrompt(prompt, 1)]
else:
self.prompts = prompt
self.init_image = init_image
self.init_image_strength = init_image_strength
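# init_image_strength: how strongly the init image is preserved; values closer to 1.0
# appear to keep the output closer to the original image (see the t_enc math in imagine())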
self.prompts.sort(key=lambda p: p.weight, reverse=True)
self.seed = random.randint(1, 1_000_000_000) if seed is None else seed
self.prompt_strength = prompt_strength
@ -214,6 +219,20 @@ class ImaginePrompt:
return "|".join(str(p) for p in self.prompts)
def load_img(path, max_height=512, max_width=512):
image = Image.open(path).convert("RGB")
w, h = image.size
print(f"loaded input image of size ({w}, {h}) from {path}")
resize_ratio = min(max_width / w, max_height / h)
w, h = int(w * resize_ratio), int(h * resize_ratio)
w, h = map(lambda x: x - x % 64, (w, h))  # round down to an integer multiple of 64
image = image.resize((w, h), resample=PIL.Image.LANCZOS)
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
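# tensor is now (1, 3, h, w) in [0, 1]; the final line rescales it to [-1, 1],
# the value range the first-stage autoencoder expects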
return 2.0 * image - 1.0, w, h
def imagine(
prompts,
config="data/stable-diffusion-v1.yaml",
@ -254,7 +273,6 @@ def imagine(
for wp in prompt.prompts
]
)
# c = model.get_learned_conditioning(prompt.prompt_text)
shape = [
latent_channels,
@ -263,24 +281,57 @@ def imagine(
]
def img_callback(samples, i):
return
pass
samples = model.decode_first_stage(samples)
samples = torch.clamp((samples + 1.0) / 2.0, min=0.0, max=1.0)
steps_path = os.path.join(
sample_path, "steps", f"{base_count:08}_S{prompt.seed}"
)
os.makedirs(steps_path, exist_ok=True)
for pred_x0 in samples:
pred_x0 = 255.0 * rearrange(pred_x0.cpu().numpy(), "c h w -> h w c")
filename = f"{base_count:08}_S{seed}_step{i:04}.jpg"
filename = f"{base_count:08}_S{prompt.seed}_step{i:04}.jpg"
Image.fromarray(pred_x0.astype(np.uint8)).save(
os.path.join(sample_path, filename)
os.path.join(steps_path, filename)
)
start_code = None
if fixed_code:
start_code = torch.randn(
[1, latent_channels, prompt.height, prompt.width],
device=get_device(),
)
# if fixed_code:
# start_code = torch.randn(
# [1, latent_channels, prompt.height, prompt.width],
# device=get_device(),
# )
sampler = get_sampler(prompt.sampler_type, model)
samples_ddim, _ = sampler.sample(
if prompt.init_image:
generation_strength = 1 - prompt.init_image_strength
ddim_steps = int(prompt.steps / generation_strength)
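# rough idea: stretch the DDIM schedule so that t_enc (below) works out to ~prompt.steps,
# i.e. denoising the partially-noised init latent still takes about the requested number
# of steps regardless of init_image_strength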
sampler.make_schedule(
ddim_num_steps=ddim_steps, ddim_eta=ddim_eta, verbose=False
)
t_enc = int(generation_strength * ddim_steps)
init_image, w, h = load_img(prompt.init_image)
init_image = init_image.to(get_device())
init_latent = model.get_first_stage_encoding(
model.encode_first_stage(init_image)
)
# encode (scaled latent): noise the init latent forward to timestep t_enc
z_enc = sampler.stochastic_encode(
init_latent, torch.tensor([t_enc]).to(get_device())
)
# decode: denoise from t_enc back to step 0 under the prompt conditioning
samples = sampler.decode(
z_enc,
c,
t_enc,
unconditional_guidance_scale=prompt.prompt_strength,
unconditional_conditioning=uc,
img_callback=img_callback,
)
else:
samples, _ = sampler.sample(
S=prompt.steps,
conditioning=c,
batch_size=1,
@ -293,11 +344,11 @@ def imagine(
img_callback=img_callback,
)
x_samples_ddim = model.decode_first_stage(samples_ddim)
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
x_samples = model.decode_first_stage(samples)
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
if not skip_save:
for x_sample in x_samples_ddim:
for x_sample in x_samples:
x_sample = 255.0 * rearrange(
x_sample.cpu().numpy(), "c h w -> h w c"
)

@ -61,7 +61,6 @@ class VQModel(pl.LightningModule):
self.lr_g_factor = lr_g_factor
class VQModelInterface(VQModel):
def __init__(self, embed_dim, *args, **kwargs):
super().__init__(embed_dim=embed_dim, *args, **kwargs)

@ -218,7 +218,7 @@ class DDIMSampler:
) # TODO: deterministic forward pass?
img = img_orig * mask + (1.0 - mask) * img
outs = self.p_sample_ddim(
img, pred_x0 = self.p_sample_ddim(
img,
cond,
ts,
@ -232,7 +232,6 @@ class DDIMSampler:
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
)
img, pred_x0 = outs
if callback:
callback(i)
if img_callback:
@ -341,6 +340,7 @@ class DDIMSampler:
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
use_original_steps=False,
img_callback=None,
):
timesteps = (
@ -361,7 +361,7 @@ class DDIMSampler:
ts = torch.full(
(x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long
)
x_dec, _ = self.p_sample_ddim(
x_dec, pred_x0 = self.p_sample_ddim(
x_dec,
cond,
ts,
@ -370,4 +370,6 @@ class DDIMSampler:
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
)
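# pass intermediate predictions to the caller so per-step previews can be saved,
# mirroring the img_callback hook already exposed by the sampling path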
if img_callback:
img_callback(pred_x0, i)
return x_dec

@ -196,7 +196,7 @@ class DDPM(pl.LightningModule):
)
if self.parameterization == "eps":
lvlb_weights = self.betas ** 2 / (
lvlb_weights = self.betas**2 / (
2
* self.posterior_variance
* to_torch(alphas)
@ -216,7 +216,6 @@ class DDPM(pl.LightningModule):
assert not torch.isnan(self.lvlb_weights).all()
class LatentDiffusion(DDPM):
"""main class"""
@ -235,7 +234,9 @@ class LatentDiffusion(DDPM):
*args,
**kwargs,
):
self.num_timesteps_cond = 1 if num_timesteps_cond is None else num_timesteps_cond
self.num_timesteps_cond = (
1 if num_timesteps_cond is None else num_timesteps_cond
)
self.scale_by_std = scale_by_std
assert self.num_timesteps_cond <= kwargs["timesteps"]
# for backwards compatibility after implementation of DiffusionWrapper
@ -631,6 +632,52 @@ class LatentDiffusion(DDPM):
else:
return self.first_stage_model.decode(z)
@torch.no_grad()
def encode_first_stage(self, x):
if hasattr(self, "split_input_params"):
if self.split_input_params["patch_distributed_vq"]:
ks = self.split_input_params["ks"] # eg. (128, 128)
stride = self.split_input_params["stride"] # eg. (64, 64)
df = self.split_input_params["vqf"]
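# ks/stride define the patch grid for patch-wise encoding; vqf is presumably the
# downsampling factor of the first-stage (VQ) model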
self.split_input_params["original_image_size"] = x.shape[-2:]
bs, nc, h, w = x.shape
if ks[0] > h or ks[1] > w:
ks = (min(ks[0], h), min(ks[1], w))
print("reducing Kernel")
if stride[0] > h or stride[1] > w:
stride = (min(stride[0], h), min(stride[1], w))
print("reducing stride")
fold, unfold, normalization, weighting = self.get_fold_unfold(
x, ks, stride, df=df
)
z = unfold(x) # (bn, nc * prod(**ks), L)
# Reshape to img shape
z = z.view(
(z.shape[0], -1, ks[0], ks[1], z.shape[-1])
) # (bn, nc, ks[0], ks[1], L )
output_list = [
self.first_stage_model.encode(z[:, :, :, :, i])
for i in range(z.shape[-1])
]
o = torch.stack(output_list, axis=-1)
o = o * weighting
# Reverse reshape to img shape
o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
# stitch crops together
decoded = fold(o)
decoded = decoded / normalization
return decoded
else:
return self.first_stage_model.encode(x)
else:
return self.first_stage_model.encode(x)
def apply_model(self, x_noisy, t, cond, return_ids=False):
if isinstance(cond, dict):

@ -105,13 +105,13 @@ class FrozenClipImageEmbedder(nn.Module):
def __init__(
self,
model,
model_name,
jit=False,
device=get_device(),
antialias=False,
):
super().__init__()
self.model, _ = clip.load(name=model, device=device, jit=jit)
self.model, preprocess = clip.load(name=model_name, device=device, jit=jit)
self.antialias = antialias

@ -9,9 +9,10 @@
import math
import numpy as np
import torch
import torch.nn as nn
import numpy as np
from einops import repeat
from imaginairy.utils import instantiate_from_config
@ -52,12 +53,23 @@ def make_beta_schedule(
return betas.numpy()
def frange(start, stop, step):
"""range but handles floats"""
x = start
while True:
if x >= stop:
return
yield x
x += step
def make_ddim_timesteps(
ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True
):
if ddim_discr_method == "uniform":
c = num_ddpm_timesteps // num_ddim_timesteps
ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
c = num_ddpm_timesteps / num_ddim_timesteps
ddim_timesteps = [int(i) for i in frange(0, num_ddpm_timesteps - 1, c)]
ddim_timesteps = np.asarray(ddim_timesteps)
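# using a float stride (via frange) yields exactly num_ddim_timesteps values spread
# evenly across the DDPM range; the old integer stride could produce a different count
# when num_ddpm_timesteps is not divisible by num_ddim_timesteps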
elif ddim_discr_method == "quad":
ddim_timesteps = (
(np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2
