feature: add ImagineResult. step output option

remove verbose args
Bryce 2022-09-09 22:14:04 -07:00
parent 47c6bcee59
commit 66c640ce7b
9 changed files with 253 additions and 144 deletions
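In short: the old all-in-one imagine() is split into a generator, imagine_images(), which yields ImagineResult objects, and a file-writing wrapper, imagine_image_files(), which gains a record_steps option that saves an image per diffusion step. A minimal usage sketch based on the tests added below (output paths and prompt text are illustrative):

from imaginairy.imagine import imagine_image_files, imagine_images
from imaginairy.schema import ImaginePrompt

# get results in memory; each yielded ImagineResult carries .img and .prompt
prompt = ImaginePrompt("a scenic landscape", steps=20, seed=1)
result = next(imagine_images(prompt))
result.img.save("scenic_landscape.jpg")

# or write files directly; record_steps=True also writes one image per sampler step
imagine_image_files(prompt, outdir="./outputs", record_steps=True)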

imaginairy/cmds.py

@@ -2,7 +2,8 @@
import click
-from imaginairy.imagine import ImaginePrompt, imagine as imagine_f
+from imaginairy.imagine import imagine as imagine_f
+from imaginairy.schema import ImaginePrompt
@click.command()
@@ -10,7 +11,9 @@ from imaginairy.imagine import ImaginePrompt, imagine as imagine_f
"prompt_texts", default=None, help="text to render to an image", nargs=-1
)
@click.option("--outdir", default="./outputs", help="where to write results to")
@click.option("-r", "--repeats", default=1, type=int, help="How many times to repeat the renders")
@click.option(
"-r", "--repeats", default=1, type=int, help="How many times to repeat the renders"
)
@click.option(
"-h",
"--height",

imaginairy/imagine.py

@@ -1,7 +1,5 @@
-import argparse
import logging
import os
-import random
import re
import subprocess
from contextlib import nullcontext
@@ -18,13 +16,14 @@ from torch import autocast
from imaginairy.models.diffusion.ddim import DDIMSampler
from imaginairy.models.diffusion.plms import PLMSSampler
+from imaginairy.schema import ImaginePrompt, ImagineResult
from imaginairy.utils import get_device, instantiate_from_config
LIB_PATH = os.path.dirname(__file__)
logger = logging.getLogger(__name__)
-def load_model_from_config(config, ckpt, verbose=False):
+def load_model_from_config(config, ckpt):
logger.info(f"Loading model from {ckpt}")
pl_sd = torch.load(ckpt, map_location="cpu")
if "global_step" in pl_sd:
@@ -32,9 +31,9 @@ def load_model_from_config(config, ckpt, verbose=False):
sd = pl_sd["state_dict"]
model = instantiate_from_config(config.model)
m, u = model.load_state_dict(sd, strict=False)
-    if len(m) > 0 and verbose:
+    if len(m) > 0:
logger.info(f"missing keys: {m}")
-    if len(u) > 0 and verbose:
+    if len(u) > 0:
logger.info(f"unexpected keys: {u}")
model.cuda()
@@ -42,56 +41,6 @@
return model
-class WeightedPrompt:
-    def __init__(self, text, weight=1):
-        self.text = text
-        self.weight = weight
-    def __str__(self):
-        return f"{self.weight}*({self.text})"
-class ImaginePrompt:
-    def __init__(
-        self,
-        prompt=None,
-        seed=None,
-        prompt_strength=7.5,
-        sampler_type="PLMS",
-        init_image=None,
-        init_image_strength=0.3,
-        steps=50,
-        height=512,
-        width=512,
-        upscale=True,
-        fix_faces=True,
-        parts=None,
-    ):
-        prompt = prompt if prompt is not None else "a scenic landscape"
-        if isinstance(prompt, str):
-            self.prompts = [WeightedPrompt(prompt, 1)]
-        else:
-            self.prompts = prompt
-        self.init_image = init_image
-        self.init_image_strength = init_image_strength
-        self.prompts.sort(key=lambda p: p.weight, reverse=True)
-        self.seed = random.randint(1, 1_000_000_000) if seed is None else seed
-        self.prompt_strength = prompt_strength
-        self.sampler_type = sampler_type
-        self.steps = steps
-        self.height = height
-        self.width = width
-        self.upscale = upscale
-        self.fix_faces = fix_faces
-        self.parts = parts or {}
-    @property
-    def prompt_text(self):
-        if len(self.prompts) == 1:
-            return self.prompts[0].text
-        return "|".join(str(p) for p in self.prompts)
def load_img(path, max_height=512, max_width=512):
image = Image.open(path).convert("RGB")
w, h = image.size
@@ -108,8 +57,8 @@ def load_img(path, max_height=512, max_width=512):
@lru_cache()
def load_model():
config = ("data/stable-diffusion-v1.yaml",)
ckpt = ("data/stable-diffusion-v1-4.ckpt",)
config = "data/stable-diffusion-v1.yaml"
ckpt = "data/stable-diffusion-v1-4.ckpt"
config = OmegaConf.load(f"{LIB_PATH}/{config}")
model = load_model_from_config(config, f"{LIB_PATH}/{ckpt}")
@@ -117,24 +66,69 @@
return model
-def imagine(
+def imagine_image_files(
prompts,
outdir="outputs/txt2img-samples",
outdir,
latent_channels=4,
downsampling_factor=8,
precision="autocast",
-    skip_save=False,
ddim_eta=0.0,
+    record_steps=False
):
+    big_path = os.path.join(outdir, "upscaled")
+    os.makedirs(outdir, exist_ok=True)
+    os.makedirs(big_path, exist_ok=True)
+    base_count = len(os.listdir(outdir))
+    step_count = 0
+    def _record_steps(samples, i, model, prompt):
+        nonlocal step_count
+        step_count += 1
+        samples = model.decode_first_stage(samples)
+        samples = torch.clamp((samples + 1.0) / 2.0, min=0.0, max=1.0)
+        steps_path = os.path.join(outdir, "steps", f"{base_count:08}_S{prompt.seed}")
+        os.makedirs(steps_path, exist_ok=True)
+        for pred_x0 in samples:
+            pred_x0 = 255.0 * rearrange(pred_x0.cpu().numpy(), "c h w -> h w c")
+            filename = f"{base_count:08}_S{prompt.seed}_step{step_count:04}.jpg"
+            Image.fromarray(pred_x0.astype(np.uint8)).save(
+                os.path.join(steps_path, filename)
+            )
+    img_callback = _record_steps if record_steps else None
+    for result in imagine_images(
+        prompts,
+        latent_channels=latent_channels,
+        downsampling_factor=downsampling_factor,
+        precision=precision,
+        ddim_eta=ddim_eta,
+        img_callback=img_callback,
+    ):
+        prompt = result.prompt
+        img = result.img
+        basefilename = f"{base_count:06}_{prompt.seed}_{prompt.sampler_type}{prompt.steps}_PS{prompt.prompt_strength}_{prompt_normalized(prompt.prompt_text)}"
+        filepath = os.path.join(outdir, f"{basefilename}.jpg")
+        img.save(filepath)
+        if prompt.upscale:
+            enlarge_realesrgan2x(
+                filepath,
+                os.path.join(big_path, basefilename) + ".jpg",
+            )
+        base_count += 1
+def imagine_images(
+    prompts,
+    latent_channels=4,
+    downsampling_factor=8,
+    precision="autocast",
+    ddim_eta=0.0,
+    img_callback=None,
+):
model = load_model()
-    os.makedirs(outdir, exist_ok=True)
-    outpath = outdir
-    sample_path = os.path.join(outpath)
-    big_path = os.path.join(sample_path, "esrgan")
-    os.makedirs(sample_path, exist_ok=True)
-    os.makedirs(big_path, exist_ok=True)
-    base_count = len(os.listdir(sample_path))
prompts = [ImaginePrompt(prompts)] if isinstance(prompts, str) else prompts
+    prompts = [prompts] if isinstance(prompts, ImaginePrompt) else prompts
+    _img_callback = None
precision_scope = autocast if precision == "autocast" else nullcontext
with (torch.no_grad(), precision_scope("cuda")):
@@ -150,6 +144,9 @@ def imagine(
for wp in prompt.prompts
]
)
+            if img_callback:
+                def _img_callback(samples, i):
+                    img_callback(samples, i, model, prompt)
shape = [
latent_channels,
@@ -157,41 +154,29 @@
prompt.width // downsampling_factor,
]
-            def img_callback(samples, i):
-                pass
-                samples = model.decode_first_stage(samples)
-                samples = torch.clamp((samples + 1.0) / 2.0, min=0.0, max=1.0)
-                steps_path = os.path.join(
-                    sample_path, "steps", f"{base_count:08}_S{prompt.seed}"
-                )
-                os.makedirs(steps_path, exist_ok=True)
-                for pred_x0 in samples:
-                    pred_x0 = 255.0 * rearrange(pred_x0.cpu().numpy(), "c h w -> h w c")
-                    filename = f"{base_count:08}_S{prompt.seed}_step{i:04}.jpg"
-                    Image.fromarray(pred_x0.astype(np.uint8)).save(
-                        os.path.join(steps_path, filename)
-                    )
start_code = None
sampler = get_sampler(prompt.sampler_type, model)
if prompt.init_image:
generation_strength = 1 - prompt.init_image_strength
ddim_steps = int(prompt.steps / generation_strength)
sampler.make_schedule(
-                    ddim_num_steps=ddim_steps, ddim_eta=ddim_eta, verbose=False
+                    ddim_num_steps=ddim_steps, ddim_eta=ddim_eta
)
t_enc = int(generation_strength * ddim_steps)
init_image, w, h = load_img(prompt.init_image)
init_image = init_image.to(get_device())
-                init_latent = model.get_first_stage_encoding(
-                    model.encode_first_stage(init_image)
-                )
+                init_latent = model.encode_first_stage(init_image)
+                noised_init_latent = model.get_first_stage_encoding(init_latent)
+                _img_callback(init_latent.mean, 0)
+                _img_callback(noised_init_latent, 0)
# encode (scaled latent)
z_enc = sampler.stochastic_encode(
-                    init_latent, torch.tensor([t_enc]).to(get_device())
+                    noised_init_latent, torch.tensor([t_enc]).to(get_device()),
)
+                _img_callback(noised_init_latent, 0)
# decode it
samples = sampler.decode(
z_enc,
@@ -199,7 +184,7 @@
t_enc,
unconditional_guidance_scale=prompt.prompt_strength,
unconditional_conditioning=uc,
-                    img_callback=img_callback,
+                    img_callback=_img_callback,
)
else:
@@ -208,35 +193,27 @@
conditioning=c,
batch_size=1,
shape=shape,
-                    verbose=False,
unconditional_guidance_scale=prompt.prompt_strength,
unconditional_conditioning=uc,
eta=ddim_eta,
x_T=start_code,
-                    img_callback=img_callback,
+                    img_callback=_img_callback,
)
x_samples = model.decode_first_stage(samples)
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
-            if not skip_save:
for x_sample in x_samples:
-                x_sample = 255.0 * rearrange(
-                    x_sample.cpu().numpy(), "c h w -> h w c"
-                )
-                basefilename = f"{base_count:06}_{prompt.seed}_{prompt.sampler_type}{prompt.steps}_PS{prompt.prompt_strength}_{prompt_normalized(prompt.prompt_text)}"
-                filepath = os.path.join(sample_path, f"{basefilename}.jpg")
+                x_sample = 255.0 * rearrange(x_sample.cpu().numpy(), "c h w -> h w c")
img = Image.fromarray(x_sample.astype(np.uint8))
if prompt.fix_faces:
img = fix_faces_GFPGAN(img)
-                img.save(filepath)
-                if prompt.upscale:
-                    enlarge_realesrgan2x(
-                        filepath,
-                        os.path.join(big_path, basefilename) + ".jpg",
-                    )
-                base_count += 1
-            return f"{basefilename}.jpg"
+                # if prompt.upscale:
+                #     enlarge_realesrgan2x(
+                #         filepath,
+                #         os.path.join(big_path, basefilename) + ".jpg",
+                #     )
+                yield ImagineResult(img=img, prompt=prompt)
def prompt_normalized(prompt):
@@ -287,7 +264,3 @@ def fix_faces_GFPGAN(image):
res = Image.fromarray(restored_img)
return res
-if __name__ == "__main__":
-    main()

imaginairy/models/diffusion/ddim.py

@@ -31,13 +31,12 @@ class DDIMSampler:
setattr(self, name, attr)
def make_schedule(
-        self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True
+        self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0
):
self.ddim_timesteps = make_ddim_timesteps(
ddim_discr_method=ddim_discretize,
num_ddim_timesteps=ddim_num_steps,
num_ddpm_timesteps=self.ddpm_num_timesteps,
-            verbose=verbose,
)
alphas_cumprod = self.model.alphas_cumprod
assert (
@@ -75,7 +74,6 @@ class DDIMSampler:
alphacums=alphas_cumprod.cpu(),
ddim_timesteps=self.ddim_timesteps,
eta=ddim_eta,
-            verbose=verbose,
)
self.register_buffer("ddim_sigmas", ddim_sigmas)
self.register_buffer("ddim_alphas", ddim_alphas)
@@ -108,7 +106,6 @@ class DDIMSampler:
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
-        verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.0,
@@ -129,7 +126,7 @@ class DDIMSampler:
f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
)
-        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
+        self.make_schedule(ddim_num_steps=S, ddim_eta=eta)
# sampling
C, H, W = shape
size = (batch_size, C, H, W)
@@ -230,8 +227,6 @@ class DDIMSampler:
quantize_denoised=quantize_denoised,
temperature=temperature,
noise_dropout=noise_dropout,
-                score_corrector=score_corrector,
-                corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
)
@@ -239,6 +234,7 @@ class DDIMSampler:
callback(i)
if img_callback:
img_callback(pred_x0, i)
+                img_callback(pred_x0, i)
if index % log_every_t == 0 or index == total_steps - 1:
intermediates["x_inter"].append(img)
@@ -246,7 +242,7 @@ class DDIMSampler:
return img, intermediates
-    @torch.no_grad()
+    # @torch.no_grad()
def p_sample_ddim(
self,
x,
@@ -258,27 +254,22 @@ class DDIMSampler:
quantize_denoised=False,
temperature=1.0,
noise_dropout=0.0,
-        score_corrector=None,
-        corrector_kwargs=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
+        loss_function=None
):
b, *_, device = *x.shape, x.device
if unconditional_conditioning is None or unconditional_guidance_scale == 1.0:
-            e_t = self.model.apply_model(x, t, c)
+            with torch.no_grad():
+                noise_pred = self.model.apply_model(x, t, c)
else:
x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2)
c_in = torch.cat([unconditional_conditioning, c])
-            e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
-            e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
-        if score_corrector is not None:
-            assert self.model.parameterization == "eps"
-            e_t = score_corrector.modify_score(
-                self.model, e_t, x, t, c, **corrector_kwargs
-            )
+            # with torch.no_grad():
+            noise_pred_uncond, noise_pred = self.model.apply_model(x_in, t_in, c_in).chunk(2)
+            noise_pred = noise_pred_uncond + unconditional_guidance_scale * (noise_pred - noise_pred_uncond)
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
alphas_prev = (
@@ -305,11 +296,12 @@ class DDIMSampler:
)
# current prediction for x_0
-        pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+        pred_x0 = (x - sqrt_one_minus_at * noise_pred) / a_t.sqrt()
if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
# direction pointing to x_t
-        dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
+        dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * noise_pred
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
if noise_dropout > 0.0:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
@@ -344,6 +336,8 @@ class DDIMSampler:
unconditional_conditioning=None,
use_original_steps=False,
img_callback=None,
+        score_corrector=None,
+        temperature=1.0
):
timesteps = (
@@ -359,6 +353,7 @@ class DDIMSampler:
iterator = tqdm(time_range, desc="Decoding image", total=total_steps)
x_dec = x_latent
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full(
@@ -372,7 +367,15 @@ class DDIMSampler:
use_original_steps=use_original_steps,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
+                temperature=temperature
)
+            # original_loss = ((x_dec - x_latent).abs().mean()*70)
+            # sigma_t = torch.full((1, 1, 1, 1), self.ddim_sigmas[index], device=get_device())
+            # x_dec = x_dec.detach() + (original_loss * 0.1) ** 2
+            # cond_grad = -torch.autograd.grad(original_loss, x_dec)[0]
+            # x_dec = x_dec.detach() + cond_grad * sigma_t ** 2
+            ## x_dec_alt = x_dec + (original_loss * 0.1) ** 2
if img_callback:
img_callback(x_dec, i)
+                img_callback(pred_x0, i)
return x_dec

imaginairy/models/diffusion/plms.py

@@ -29,16 +29,13 @@ class PLMSSampler(object):
attr = attr.to(torch.float32).to(torch.device(self.device_available))
setattr(self, name, attr)
-    def make_schedule(
-        self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True
-    ):
+    def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0):
if ddim_eta != 0:
raise ValueError("ddim_eta must be 0 for PLMS")
self.ddim_timesteps = make_ddim_timesteps(
ddim_discr_method=ddim_discretize,
num_ddim_timesteps=ddim_num_steps,
num_ddpm_timesteps=self.ddpm_num_timesteps,
-            verbose=verbose,
)
alphas_cumprod = self.model.alphas_cumprod
assert (
@@ -76,7 +73,6 @@ class PLMSSampler(object):
alphacums=alphas_cumprod.cpu(),
ddim_timesteps=self.ddim_timesteps,
eta=ddim_eta,
-            verbose=verbose,
)
self.register_buffer("ddim_sigmas", ddim_sigmas)
self.register_buffer("ddim_alphas", ddim_alphas)
@@ -109,7 +105,6 @@ class PLMSSampler(object):
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
-        verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.0,
@@ -130,7 +125,7 @@ class PLMSSampler(object):
f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
)
-        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
+        self.make_schedule(ddim_num_steps=S, ddim_eta=eta)
# sampling
C, H, W = shape
size = (batch_size, C, H, W)
@@ -252,6 +247,7 @@ class PLMSSampler(object):
if callback:
callback(i)
if img_callback:
+                img_callback(img, i)
img_callback(pred_x0, i)
if index % log_every_t == 0 or index == total_steps - 1:

imaginairy/modules/clip_embedders.py

@@ -8,7 +8,7 @@ from transformers import CLIPTokenizer, CLIPTextModel
from imaginairy.utils import get_device
-class FrozenCLIPEmbedder:
+class FrozenCLIPEmbedder(nn.Module):
"""Uses the CLIP transformer encoder for text (from Hugging Face)"""
def __init__(

imaginairy/schema.py Normal file

@@ -0,0 +1,70 @@
import hashlib
import random
import numpy
class WeightedPrompt:
def __init__(self, text, weight=1):
self.text = text
self.weight = weight
def __str__(self):
return f"{self.weight}*({self.text})"
class ImaginePrompt:
def __init__(
self,
prompt=None,
seed=None,
prompt_strength=7.5,
sampler_type="PLMS",
init_image=None,
init_image_strength=0.3,
steps=50,
height=512,
width=512,
upscale=False,
fix_faces=False,
parts=None,
):
prompt = prompt if prompt is not None else "a scenic landscape"
if isinstance(prompt, str):
self.prompts = [WeightedPrompt(prompt, 1)]
else:
self.prompts = prompt
self.init_image = init_image
self.init_image_strength = init_image_strength
self.prompts.sort(key=lambda p: p.weight, reverse=True)
self.seed = random.randint(1, 1_000_000_000) if seed is None else seed
self.prompt_strength = prompt_strength
self.sampler_type = sampler_type
self.steps = steps
self.height = height
self.width = width
self.upscale = upscale
self.fix_faces = fix_faces
self.parts = parts or {}
@property
def prompt_text(self):
if len(self.prompts) == 1:
return self.prompts[0].text
return "|".join(str(p) for p in self.prompts)
class ImagineResult:
def __init__(self, img, prompt):
self.img = img
self.prompt = prompt
def cv2_img(self):
open_cv_image = numpy.array(self.img)
# Convert RGB to BGR
open_cv_image = open_cv_image[:, :, ::-1].copy()
return open_cv_image
# return cv2.cvtColor(numpy.array(self.img), cv2.COLOR_RGB2BGR)
def md5(self):
return hashlib.md5(self.img.tobytes()).hexdigest()
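The new ImagineResult is a thin wrapper; a quick sketch of its helpers (the blank image here is a stand-in):

from PIL import Image
from imaginairy.schema import ImaginePrompt, ImagineResult

result = ImagineResult(img=Image.new("RGB", (64, 64)), prompt=ImaginePrompt("test"))
fingerprint = result.md5()  # hashes the raw pixel bytes, deterministic for identical output; the tests below assert against it
bgr = result.cv2_img()  # flips RGB to BGR and returns a numpy array for OpenCV-style consumers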

requirements-dev.in Normal file

@@ -0,0 +1 @@
pytest

tests/__init__.py Normal file

@@ -0,0 +1,3 @@
import os.path
TESTS_FOLDER = os.path.dirname(__file__)

tests/test_imagine.py Normal file

@@ -0,0 +1,60 @@
from imaginairy.imagine import imagine_images, imagine_image_files
from imaginairy.schema import ImaginePrompt, WeightedPrompt
from . import TESTS_FOLDER
def test_imagine():
prompt = ImaginePrompt("a scenic landscape", width=512, height=256, steps=20, seed=1)
result = next(imagine_images(prompt))
assert result.md5() == '4c5957c498881d365cfcf13014812af0'
result.img.save(f"{TESTS_FOLDER}/test_output/scenic_landscape.png")
def test_img_to_img():
prompt = ImaginePrompt(
"a photo of a beach",
init_image=f"{TESTS_FOLDER}/data/beach_at_sainte_adresse.jpg",
init_image_strength=0.5,
width=512,
height=512,
steps=50,
seed=1,
sampler_type="DDIM",
)
out_folder = f"{TESTS_FOLDER}/test_output"
out_folder = '/home/bryce/Mounts/drennanfiles/art/tests'
imagine_image_files(prompt, outdir=out_folder)
def test_img_to_file():
prompt = ImaginePrompt(
[WeightedPrompt("an old growth forest, diffuse light poking through the canopy. high-resolution, nature photography, nat geo photo")],
# init_image=f"{TESTS_FOLDER}/data/beach_at_sainte_adresse.jpg",
init_image_strength=0.5,
width=512+64,
height=512-64,
steps=50,
# seed=2,
sampler_type="PLMS",
upscale=True
)
out_folder = f"{TESTS_FOLDER}/test_output"
out_folder = '/home/bryce/Mounts/drennanfiles/art/tests'
imagine_image_files(prompt, outdir=out_folder)
def test_img_conditioning():
prompt = ImaginePrompt(
"photo",
init_image=f"{TESTS_FOLDER}/data/beach_at_sainte_adresse.jpg",
init_image_strength=0.5,
width=512+64,
height=512-64,
steps=50,
# seed=2,
sampler_type="PLMS",
upscale=True
)
out_folder = f"{TESTS_FOLDER}/test_output"
out_folder = '/home/bryce/Mounts/drennanfiles/art/tests'
imagine_image_files(prompt, outdir=out_folder, record_steps=True)
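With pytest now pinned in requirements-dev.in, the new suite can be run from the repo root with, for example:

pytest tests/test_imagine.py -k test_imagine

Note that these tests load the full model, so they assume the Stable Diffusion checkpoint referenced in load_model() is present on disk.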