mirror of https://github.com/brycedrennan/imaginAIry — synced 2024-10-31 03:20:40 +00:00
feature: add ImagineResult. step output option
remove verbose args
This commit is contained in:
parent
47c6bcee59
commit
66c640ce7b
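
A brief usage sketch of the API this commit introduces, based only on the names visible in the diff below; the output directory and prompt text are illustrative assumptions, not part of the commit:

    # Hedged usage sketch: imagine_images, imagine_image_files, ImaginePrompt,
    # and ImagineResult come from the diff below; "./outputs" is an assumption.
    from imaginairy.imagine import imagine_image_files, imagine_images
    from imaginairy.schema import ImaginePrompt

    prompt = ImaginePrompt("a scenic landscape", steps=20, seed=1)

    # imagine_images() now yields ImagineResult objects, so callers can work
    # with images in memory instead of reading files back from disk.
    for result in imagine_images(prompt):
        result.img.save("out.jpg")  # result.img is a PIL Image
        print(result.md5())         # content hash of the raw image bytes

    # record_steps=True is the new step output option: each diffusion step is
    # decoded and written under <outdir>/steps/ for inspection.
    imagine_image_files(prompt, outdir="./outputs", record_steps=True)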
@@ -2,7 +2,8 @@
 import click
 
-from imaginairy.imagine import ImaginePrompt, imagine as imagine_f
+from imaginairy.imagine import imagine as imagine_f
+from imaginairy.schema import ImaginePrompt
 
 
 @click.command()
@@ -10,7 +11,9 @@ from imaginairy.imagine import ImaginePrompt, imagine as imagine_f
     "prompt_texts", default=None, help="text to render to an image", nargs=-1
 )
 @click.option("--outdir", default="./outputs", help="where to write results to")
-@click.option("-r", "--repeats", default=1, type=int, help="How many times to repeat the renders")
+@click.option(
+    "-r", "--repeats", default=1, type=int, help="How many times to repeat the renders"
+)
 @click.option(
     "-h",
     "--height",
imaginairy/imagine.py
@@ -1,7 +1,5 @@
-import argparse
 import logging
 import os
-import random
 import re
 import subprocess
 from contextlib import nullcontext
@@ -18,13 +16,14 @@ from torch import autocast
 
 from imaginairy.models.diffusion.ddim import DDIMSampler
 from imaginairy.models.diffusion.plms import PLMSSampler
+from imaginairy.schema import ImaginePrompt, ImagineResult
 from imaginairy.utils import get_device, instantiate_from_config
 
 LIB_PATH = os.path.dirname(__file__)
 logger = logging.getLogger(__name__)
 
 
-def load_model_from_config(config, ckpt, verbose=False):
+def load_model_from_config(config, ckpt):
     logger.info(f"Loading model from {ckpt}")
     pl_sd = torch.load(ckpt, map_location="cpu")
     if "global_step" in pl_sd:
@@ -32,9 +31,9 @@ def load_model_from_config(config, ckpt, verbose=False):
     sd = pl_sd["state_dict"]
     model = instantiate_from_config(config.model)
     m, u = model.load_state_dict(sd, strict=False)
-    if len(m) > 0 and verbose:
+    if len(m) > 0:
         logger.info(f"missing keys: {m}")
-    if len(u) > 0 and verbose:
+    if len(u) > 0:
         logger.info(f"unexpected keys: {u}")
 
     model.cuda()
@@ -42,56 +41,6 @@ def load_model_from_config(config, ckpt, verbose=False):
     return model
 
 
-class WeightedPrompt:
-    def __init__(self, text, weight=1):
-        self.text = text
-        self.weight = weight
-
-    def __str__(self):
-        return f"{self.weight}*({self.text})"
-
-
-class ImaginePrompt:
-    def __init__(
-        self,
-        prompt=None,
-        seed=None,
-        prompt_strength=7.5,
-        sampler_type="PLMS",
-        init_image=None,
-        init_image_strength=0.3,
-        steps=50,
-        height=512,
-        width=512,
-        upscale=True,
-        fix_faces=True,
-        parts=None,
-    ):
-        prompt = prompt if prompt is not None else "a scenic landscape"
-        if isinstance(prompt, str):
-            self.prompts = [WeightedPrompt(prompt, 1)]
-        else:
-            self.prompts = prompt
-        self.init_image = init_image
-        self.init_image_strength = init_image_strength
-        self.prompts.sort(key=lambda p: p.weight, reverse=True)
-        self.seed = random.randint(1, 1_000_000_000) if seed is None else seed
-        self.prompt_strength = prompt_strength
-        self.sampler_type = sampler_type
-        self.steps = steps
-        self.height = height
-        self.width = width
-        self.upscale = upscale
-        self.fix_faces = fix_faces
-        self.parts = parts or {}
-
-    @property
-    def prompt_text(self):
-        if len(self.prompts) == 1:
-            return self.prompts[0].text
-        return "|".join(str(p) for p in self.prompts)
-
-
 def load_img(path, max_height=512, max_width=512):
     image = Image.open(path).convert("RGB")
     w, h = image.size
@@ -108,8 +57,8 @@ def load_img(path, max_height=512, max_width=512):
 
 @lru_cache()
 def load_model():
-    config = ("data/stable-diffusion-v1.yaml",)
-    ckpt = ("data/stable-diffusion-v1-4.ckpt",)
+    config = "data/stable-diffusion-v1.yaml"
+    ckpt = "data/stable-diffusion-v1-4.ckpt"
     config = OmegaConf.load(f"{LIB_PATH}/{config}")
     model = load_model_from_config(config, f"{LIB_PATH}/{ckpt}")
@@ -117,24 +66,69 @@ def load_model():
     return model
 
 
-def imagine(
+def imagine_image_files(
     prompts,
-    outdir="outputs/txt2img-samples",
+    outdir,
     latent_channels=4,
     downsampling_factor=8,
     precision="autocast",
-    skip_save=False,
     ddim_eta=0.0,
+    record_steps=False
 ):
+    big_path = os.path.join(outdir, "upscaled")
+    os.makedirs(outdir, exist_ok=True)
+    os.makedirs(big_path, exist_ok=True)
+    base_count = len(os.listdir(outdir))
+    step_count = 0
+
+    def _record_steps(samples, i, model, prompt):
+        nonlocal step_count
+        step_count += 1
+        samples = model.decode_first_stage(samples)
+        samples = torch.clamp((samples + 1.0) / 2.0, min=0.0, max=1.0)
+        steps_path = os.path.join(outdir, "steps", f"{base_count:08}_S{prompt.seed}")
+        os.makedirs(steps_path, exist_ok=True)
+        for pred_x0 in samples:
+            pred_x0 = 255.0 * rearrange(pred_x0.cpu().numpy(), "c h w -> h w c")
+            filename = f"{base_count:08}_S{prompt.seed}_step{step_count:04}.jpg"
+            Image.fromarray(pred_x0.astype(np.uint8)).save(
+                os.path.join(steps_path, filename)
+            )
+    img_callback = _record_steps if record_steps else None
+    for result in imagine_images(
+        prompts,
+        latent_channels=latent_channels,
+        downsampling_factor=downsampling_factor,
+        precision=precision,
+        ddim_eta=ddim_eta,
+        img_callback=img_callback,
+    ):
+        prompt = result.prompt
+        img = result.img
+        basefilename = f"{base_count:06}_{prompt.seed}_{prompt.sampler_type}{prompt.steps}_PS{prompt.prompt_strength}_{prompt_normalized(prompt.prompt_text)}"
+        filepath = os.path.join(outdir, f"{basefilename}.jpg")
+
+        img.save(filepath)
+        if prompt.upscale:
+            enlarge_realesrgan2x(
+                filepath,
+                os.path.join(big_path, basefilename) + ".jpg",
+            )
+        base_count += 1
+
+
+def imagine_images(
+    prompts,
+    latent_channels=4,
+    downsampling_factor=8,
+    precision="autocast",
+    ddim_eta=0.0,
+    img_callback=None,
+):
     model = load_model()
-    os.makedirs(outdir, exist_ok=True)
-    outpath = outdir
-
-    sample_path = os.path.join(outpath)
-    big_path = os.path.join(sample_path, "esrgan")
-    os.makedirs(sample_path, exist_ok=True)
-    os.makedirs(big_path, exist_ok=True)
-    base_count = len(os.listdir(sample_path))
     prompts = [ImaginePrompt(prompts)] if isinstance(prompts, str) else prompts
     prompts = [prompts] if isinstance(prompts, ImaginePrompt) else prompts
+    _img_callback = None
 
     precision_scope = autocast if precision == "autocast" else nullcontext
     with (torch.no_grad(), precision_scope("cuda")):
@@ -150,6 +144,9 @@ def imagine(
                     for wp in prompt.prompts
                 ]
             )
+            if img_callback:
+                def _img_callback(samples, i):
+                    img_callback(samples, i, model, prompt)
 
             shape = [
                 latent_channels,
@@ -157,41 +154,29 @@ def imagine(
                 prompt.width // downsampling_factor,
             ]
 
-            def img_callback(samples, i):
-                pass
-                samples = model.decode_first_stage(samples)
-                samples = torch.clamp((samples + 1.0) / 2.0, min=0.0, max=1.0)
-                steps_path = os.path.join(
-                    sample_path, "steps", f"{base_count:08}_S{prompt.seed}"
-                )
-                os.makedirs(steps_path, exist_ok=True)
-                for pred_x0 in samples:
-                    pred_x0 = 255.0 * rearrange(pred_x0.cpu().numpy(), "c h w -> h w c")
-                    filename = f"{base_count:08}_S{prompt.seed}_step{i:04}.jpg"
-                    Image.fromarray(pred_x0.astype(np.uint8)).save(
-                        os.path.join(steps_path, filename)
-                    )
-
             start_code = None
             sampler = get_sampler(prompt.sampler_type, model)
             if prompt.init_image:
                 generation_strength = 1 - prompt.init_image_strength
                 ddim_steps = int(prompt.steps / generation_strength)
                 sampler.make_schedule(
-                    ddim_num_steps=ddim_steps, ddim_eta=ddim_eta, verbose=False
+                    ddim_num_steps=ddim_steps, ddim_eta=ddim_eta
                 )
 
                 t_enc = int(generation_strength * ddim_steps)
                 init_image, w, h = load_img(prompt.init_image)
                 init_image = init_image.to(get_device())
-                init_latent = model.get_first_stage_encoding(
-                    model.encode_first_stage(init_image)
-                )
+                init_latent = model.encode_first_stage(init_image)
+                noised_init_latent = model.get_first_stage_encoding(init_latent)
+                _img_callback(init_latent.mean, 0)
+                _img_callback(noised_init_latent, 0)
 
                 # encode (scaled latent)
                 z_enc = sampler.stochastic_encode(
-                    init_latent, torch.tensor([t_enc]).to(get_device())
+                    noised_init_latent, torch.tensor([t_enc]).to(get_device()),
                 )
+                _img_callback(noised_init_latent, 0)
 
                 # decode it
                 samples = sampler.decode(
                     z_enc,
@@ -199,7 +184,7 @@ def imagine(
                     t_enc,
                     unconditional_guidance_scale=prompt.prompt_strength,
                     unconditional_conditioning=uc,
-                    img_callback=img_callback,
+                    img_callback=_img_callback,
                 )
             else:
@@ -208,35 +193,27 @@ def imagine(
                     conditioning=c,
                     batch_size=1,
                     shape=shape,
                     verbose=False,
                     unconditional_guidance_scale=prompt.prompt_strength,
                     unconditional_conditioning=uc,
                     eta=ddim_eta,
                     x_T=start_code,
-                    img_callback=img_callback,
+                    img_callback=_img_callback,
                 )
 
             x_samples = model.decode_first_stage(samples)
             x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
 
-            if not skip_save:
-                for x_sample in x_samples:
-                    x_sample = 255.0 * rearrange(
-                        x_sample.cpu().numpy(), "c h w -> h w c"
-                    )
-                    basefilename = f"{base_count:06}_{prompt.seed}_{prompt.sampler_type}{prompt.steps}_PS{prompt.prompt_strength}_{prompt_normalized(prompt.prompt_text)}"
-                    filepath = os.path.join(sample_path, f"{basefilename}.jpg")
-                    img = Image.fromarray(x_sample.astype(np.uint8))
-                    if prompt.fix_faces:
-                        img = fix_faces_GFPGAN(img)
-                    img.save(filepath)
-                    if prompt.upscale:
-                        enlarge_realesrgan2x(
-                            filepath,
-                            os.path.join(big_path, basefilename) + ".jpg",
-                        )
-                    base_count += 1
-                    return f"{basefilename}.jpg"
+            for x_sample in x_samples:
+                x_sample = 255.0 * rearrange(x_sample.cpu().numpy(), "c h w -> h w c")
+                img = Image.fromarray(x_sample.astype(np.uint8))
+                if prompt.fix_faces:
+                    img = fix_faces_GFPGAN(img)
+                # if prompt.upscale:
+                #     enlarge_realesrgan2x(
+                #         filepath,
+                #         os.path.join(big_path, basefilename) + ".jpg",
+                #     )
+                yield ImagineResult(img=img, prompt=prompt)
 
 
 def prompt_normalized(prompt):
@@ -287,7 +264,3 @@ def fix_faces_GFPGAN(image):
     res = Image.fromarray(restored_img)
 
     return res
-
-
-if __name__ == "__main__":
-    main()
imaginairy/models/diffusion/ddim.py
@@ -31,13 +31,12 @@ class DDIMSampler:
             setattr(self, name, attr)
 
     def make_schedule(
-        self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True
+        self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0
     ):
         self.ddim_timesteps = make_ddim_timesteps(
             ddim_discr_method=ddim_discretize,
             num_ddim_timesteps=ddim_num_steps,
             num_ddpm_timesteps=self.ddpm_num_timesteps,
-            verbose=verbose,
         )
         alphas_cumprod = self.model.alphas_cumprod
         assert (
@@ -75,7 +74,6 @@ class DDIMSampler:
             alphacums=alphas_cumprod.cpu(),
             ddim_timesteps=self.ddim_timesteps,
             eta=ddim_eta,
-            verbose=verbose,
         )
         self.register_buffer("ddim_sigmas", ddim_sigmas)
         self.register_buffer("ddim_alphas", ddim_alphas)
@@ -108,7 +106,6 @@ class DDIMSampler:
         noise_dropout=0.0,
         score_corrector=None,
         corrector_kwargs=None,
-        verbose=True,
         x_T=None,
         log_every_t=100,
         unconditional_guidance_scale=1.0,
@@ -129,7 +126,7 @@ class DDIMSampler:
                     f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
                 )
 
-        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
+        self.make_schedule(ddim_num_steps=S, ddim_eta=eta)
         # sampling
         C, H, W = shape
         size = (batch_size, C, H, W)
@@ -230,8 +227,6 @@ class DDIMSampler:
                 quantize_denoised=quantize_denoised,
                 temperature=temperature,
                 noise_dropout=noise_dropout,
-                score_corrector=score_corrector,
-                corrector_kwargs=corrector_kwargs,
                 unconditional_guidance_scale=unconditional_guidance_scale,
                 unconditional_conditioning=unconditional_conditioning,
             )
@@ -239,6 +234,7 @@ class DDIMSampler:
                 callback(i)
             if img_callback:
-                img_callback(pred_x0, i)
+                img_callback(pred_x0, i)
 
             if index % log_every_t == 0 or index == total_steps - 1:
                 intermediates["x_inter"].append(img)
@@ -246,7 +242,7 @@ class DDIMSampler:
 
         return img, intermediates
 
-    @torch.no_grad()
+    # @torch.no_grad()
     def p_sample_ddim(
         self,
         x,
@@ -258,27 +254,22 @@ class DDIMSampler:
         quantize_denoised=False,
         temperature=1.0,
         noise_dropout=0.0,
         score_corrector=None,
         corrector_kwargs=None,
         unconditional_guidance_scale=1.0,
         unconditional_conditioning=None,
+        loss_function=None
     ):
         b, *_, device = *x.shape, x.device
 
         if unconditional_conditioning is None or unconditional_guidance_scale == 1.0:
-            e_t = self.model.apply_model(x, t, c)
+            with torch.no_grad():
+                noise_pred = self.model.apply_model(x, t, c)
         else:
             x_in = torch.cat([x] * 2)
             t_in = torch.cat([t] * 2)
             c_in = torch.cat([unconditional_conditioning, c])
-            e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
-            e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
-
-        if score_corrector is not None:
-            assert self.model.parameterization == "eps"
-            e_t = score_corrector.modify_score(
-                self.model, e_t, x, t, c, **corrector_kwargs
-            )
+            # with torch.no_grad():
+            noise_pred_uncond, noise_pred = self.model.apply_model(x_in, t_in, c_in).chunk(2)
+            noise_pred = noise_pred_uncond + unconditional_guidance_scale * (noise_pred - noise_pred_uncond)
 
         alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
         alphas_prev = (
@@ -305,11 +296,12 @@ class DDIMSampler:
         )
 
         # current prediction for x_0
-        pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+        pred_x0 = (x - sqrt_one_minus_at * noise_pred) / a_t.sqrt()
         if quantize_denoised:
             pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
         # direction pointing to x_t
-        dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
+        dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * noise_pred
 
         noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
         if noise_dropout > 0.0:
             noise = torch.nn.functional.dropout(noise, p=noise_dropout)
@@ -344,6 +336,8 @@ class DDIMSampler:
         unconditional_conditioning=None,
         use_original_steps=False,
         img_callback=None,
+        score_corrector=None,
+        temperature=1.0
     ):
 
         timesteps = (
@@ -359,6 +353,7 @@ class DDIMSampler:
 
         iterator = tqdm(time_range, desc="Decoding image", total=total_steps)
         x_dec = x_latent
+
         for i, step in enumerate(iterator):
             index = total_steps - i - 1
             ts = torch.full(
@@ -372,7 +367,15 @@ class DDIMSampler:
                 use_original_steps=use_original_steps,
                 unconditional_guidance_scale=unconditional_guidance_scale,
                 unconditional_conditioning=unconditional_conditioning,
+                temperature=temperature
             )
+            # original_loss = ((x_dec - x_latent).abs().mean()*70)
+            # sigma_t = torch.full((1, 1, 1, 1), self.ddim_sigmas[index], device=get_device())
+            # x_dec = x_dec.detach() + (original_loss * 0.1) ** 2
+            # cond_grad = -torch.autograd.grad(original_loss, x_dec)[0]
+            # x_dec = x_dec.detach() + cond_grad * sigma_t ** 2
+            ## x_dec_alt = x_dec + (original_loss * 0.1) ** 2
             if img_callback:
-                img_callback(x_dec, i)
+                img_callback(pred_x0, i)
         return x_dec
imaginairy/models/diffusion/plms.py
@@ -29,16 +29,13 @@ class PLMSSampler(object):
             attr = attr.to(torch.float32).to(torch.device(self.device_available))
         setattr(self, name, attr)
 
-    def make_schedule(
-        self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True
-    ):
+    def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0):
         if ddim_eta != 0:
             raise ValueError("ddim_eta must be 0 for PLMS")
         self.ddim_timesteps = make_ddim_timesteps(
             ddim_discr_method=ddim_discretize,
             num_ddim_timesteps=ddim_num_steps,
             num_ddpm_timesteps=self.ddpm_num_timesteps,
-            verbose=verbose,
         )
         alphas_cumprod = self.model.alphas_cumprod
         assert (
@@ -76,7 +73,6 @@ class PLMSSampler(object):
             alphacums=alphas_cumprod.cpu(),
             ddim_timesteps=self.ddim_timesteps,
             eta=ddim_eta,
-            verbose=verbose,
         )
         self.register_buffer("ddim_sigmas", ddim_sigmas)
         self.register_buffer("ddim_alphas", ddim_alphas)
@@ -109,7 +105,6 @@ class PLMSSampler(object):
         noise_dropout=0.0,
         score_corrector=None,
         corrector_kwargs=None,
-        verbose=True,
         x_T=None,
         log_every_t=100,
         unconditional_guidance_scale=1.0,
@@ -130,7 +125,7 @@ class PLMSSampler(object):
                     f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
                 )
 
-        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
+        self.make_schedule(ddim_num_steps=S, ddim_eta=eta)
         # sampling
         C, H, W = shape
         size = (batch_size, C, H, W)
@@ -252,6 +247,7 @@ class PLMSSampler(object):
             if callback:
                 callback(i)
             if img_callback:
-                img_callback(img, i)
+                img_callback(pred_x0, i)
 
             if index % log_every_t == 0 or index == total_steps - 1:
@@ -8,7 +8,7 @@ from transformers import CLIPTokenizer, CLIPTextModel
 from imaginairy.utils import get_device
 
 
-class FrozenCLIPEmbedder:
+class FrozenCLIPEmbedder(nn.Module):
     """Uses the CLIP transformer encoder for text (from Hugging Face)"""
 
     def __init__(
imaginairy/schema.py (new file, 70 lines)
@@ -0,0 +1,70 @@
+import hashlib
+import random
+
+import numpy
+
+
+class WeightedPrompt:
+    def __init__(self, text, weight=1):
+        self.text = text
+        self.weight = weight
+
+    def __str__(self):
+        return f"{self.weight}*({self.text})"
+
+
+class ImaginePrompt:
+    def __init__(
+        self,
+        prompt=None,
+        seed=None,
+        prompt_strength=7.5,
+        sampler_type="PLMS",
+        init_image=None,
+        init_image_strength=0.3,
+        steps=50,
+        height=512,
+        width=512,
+        upscale=False,
+        fix_faces=False,
+        parts=None,
+    ):
+        prompt = prompt if prompt is not None else "a scenic landscape"
+        if isinstance(prompt, str):
+            self.prompts = [WeightedPrompt(prompt, 1)]
+        else:
+            self.prompts = prompt
+        self.init_image = init_image
+        self.init_image_strength = init_image_strength
+        self.prompts.sort(key=lambda p: p.weight, reverse=True)
+        self.seed = random.randint(1, 1_000_000_000) if seed is None else seed
+        self.prompt_strength = prompt_strength
+        self.sampler_type = sampler_type
+        self.steps = steps
+        self.height = height
+        self.width = width
+        self.upscale = upscale
+        self.fix_faces = fix_faces
+        self.parts = parts or {}
+
+    @property
+    def prompt_text(self):
+        if len(self.prompts) == 1:
+            return self.prompts[0].text
+        return "|".join(str(p) for p in self.prompts)
+
+
+class ImagineResult:
+    def __init__(self, img, prompt):
+        self.img = img
+        self.prompt = prompt
+
+    def cv2_img(self):
+        open_cv_image = numpy.array(self.img)
+        # Convert RGB to BGR
+        open_cv_image = open_cv_image[:, :, ::-1].copy()
+        return open_cv_image
+        # return cv2.cvtColor(numpy.array(self.img), cv2.COLOR_RGB2BGR)
+
+    def md5(self):
+        return hashlib.md5(self.img.tobytes()).hexdigest()
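
A note on the two helpers at the bottom of schema.py: cv2_img reverses the channel order because PIL stores pixels as RGB while OpenCV expects BGR, and md5 hashes the raw pixel bytes, which is what the test file below compares against. A minimal sketch checking the channel flip (assumes only Pillow and the new schema module):

    # Sketch: a pure-red RGB pixel should come out of cv2_img() as BGR (0, 0, 255).
    from PIL import Image

    from imaginairy.schema import ImaginePrompt, ImagineResult

    red = Image.new("RGB", (2, 2), color=(255, 0, 0))
    result = ImagineResult(img=red, prompt=ImaginePrompt("red square"))

    assert result.cv2_img()[0, 0].tolist() == [0, 0, 255]  # channels reversed
    assert len(result.md5()) == 32  # hex digest of the raw image bytes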
requirements-dev.in (new file, 1 line)
@@ -0,0 +1 @@
+pytest
tests/__init__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
+import os.path
+
+TESTS_FOLDER = os.path.dirname(__file__)
tests/test_imagine.py (new file, 60 lines)
@@ -0,0 +1,60 @@
+from imaginairy.imagine import imagine_images, imagine_image_files
+from imaginairy.schema import ImaginePrompt, WeightedPrompt
+from . import TESTS_FOLDER
+
+
+def test_imagine():
+    prompt = ImaginePrompt("a scenic landscape", width=512, height=256, steps=20, seed=1)
+    result = next(imagine_images(prompt))
+    assert result.md5() == '4c5957c498881d365cfcf13014812af0'
+    result.img.save(f"{TESTS_FOLDER}/test_output/scenic_landscape.png")
+
+
+def test_img_to_img():
+    prompt = ImaginePrompt(
+        "a photo of a beach",
+        init_image=f"{TESTS_FOLDER}/data/beach_at_sainte_adresse.jpg",
+        init_image_strength=0.5,
+        width=512,
+        height=512,
+        steps=50,
+        seed=1,
+        sampler_type="DDIM",
+    )
+    out_folder = f"{TESTS_FOLDER}/test_output"
+    out_folder = '/home/bryce/Mounts/drennanfiles/art/tests'
+    imagine_image_files(prompt, outdir=out_folder)
+
+
+def test_img_to_file():
+    prompt = ImaginePrompt(
+        [WeightedPrompt("an old growth forest, diffuse light poking through the canopy. high-resolution, nature photography, nat geo photo")],
+        # init_image=f"{TESTS_FOLDER}/data/beach_at_sainte_adresse.jpg",
+        init_image_strength=0.5,
+        width=512+64,
+        height=512-64,
+        steps=50,
+        # seed=2,
+        sampler_type="PLMS",
+        upscale=True
+    )
+    out_folder = f"{TESTS_FOLDER}/test_output"
+    out_folder = '/home/bryce/Mounts/drennanfiles/art/tests'
+    imagine_image_files(prompt, outdir=out_folder)
+
+
+def test_img_conditioning():
+    prompt = ImaginePrompt(
+        "photo",
+        init_image=f"{TESTS_FOLDER}/data/beach_at_sainte_adresse.jpg",
+        init_image_strength=0.5,
+        width=512+64,
+        height=512-64,
+        steps=50,
+        # seed=2,
+        sampler_type="PLMS",
+        upscale=True
+    )
+    out_folder = f"{TESTS_FOLDER}/test_output"
+    out_folder = '/home/bryce/Mounts/drennanfiles/art/tests'
+    imagine_image_files(prompt, outdir=out_folder, record_steps=True)
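
Note that each file-writing test above immediately overrides out_folder with a hard-coded local mount, so as written they only pass on the author's machine. A more portable shape (hypothetical, not part of this commit) would lean on pytest's built-in tmp_path fixture:

    # Hypothetical portable variant using pytest's tmp_path fixture, which
    # supplies a per-test temporary directory instead of a hard-coded path.
    from imaginairy.imagine import imagine_image_files
    from imaginairy.schema import ImaginePrompt


    def test_img_to_img_portable(tmp_path):
        prompt = ImaginePrompt("a photo of a beach", steps=5, seed=1)
        imagine_image_files(prompt, outdir=str(tmp_path))
        assert any(tmp_path.iterdir())  # at least one output was written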