|
|
|
@ -5,6 +5,7 @@ import re
|
|
|
|
|
import subprocess
|
|
|
|
|
from contextlib import nullcontext
|
|
|
|
|
|
|
|
|
|
import PIL
|
|
|
|
|
import numpy as np
|
|
|
|
|
import torch
|
|
|
|
|
from PIL import Image
|
|
|
|
@ -184,6 +185,8 @@ class ImaginePrompt:
|
|
|
|
|
seed=None,
|
|
|
|
|
prompt_strength=7.5,
|
|
|
|
|
sampler_type="PLMS",
|
|
|
|
|
init_image=None,
|
|
|
|
|
init_image_strength=0.3,
|
|
|
|
|
steps=50,
|
|
|
|
|
height=512,
|
|
|
|
|
width=512,
|
|
|
|
@ -196,6 +199,8 @@ class ImaginePrompt:
|
|
|
|
|
self.prompts = [WeightedPrompt(prompt, 1)]
|
|
|
|
|
else:
|
|
|
|
|
self.prompts = prompt
|
|
|
|
|
self.init_image = init_image
|
|
|
|
|
self.init_image_strength = init_image_strength
|
|
|
|
|
self.prompts.sort(key=lambda p: p.weight, reverse=True)
|
|
|
|
|
self.seed = random.randint(1, 1_000_000_000) if seed is None else seed
|
|
|
|
|
self.prompt_strength = prompt_strength
|
|
|
|
@ -214,6 +219,20 @@ class ImaginePrompt:
|
|
|
|
|
return "|".join(str(p) for p in self.prompts)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_img(path, max_height=512, max_width=512):
    """Load an image file and return it as a normalized torch tensor.

    The image is opened as RGB and scaled (up or down) by the largest single
    ratio that keeps it within ``max_width`` x ``max_height``. Each side is
    then rounded down to a multiple of 64, but never below 64, so an extreme
    aspect ratio cannot collapse a side to zero pixels.

    Args:
        path: Image location; anything accepted by ``PIL.Image.open``.
        max_height: Upper bound for the scaled height, in pixels.
        max_width: Upper bound for the scaled width, in pixels.

    Returns:
        A tuple ``(tensor, w, h)`` where ``tensor`` is a float32 tensor of
        shape ``(1, 3, h, w)`` with values in ``[-1, 1]`` and ``w``/``h`` are
        the final pixel dimensions after resizing.
    """
    # Close the underlying file as soon as the pixel data has been copied out;
    # ``convert`` returns an independent in-memory image.
    with Image.open(path) as source:
        image = source.convert("RGB")
    w, h = image.size
    print(f"loaded input image of size ({w}, {h}) from {path}")

    # One ratio for both axes preserves the aspect ratio while fitting the box.
    resize_ratio = min(max_width / w, max_height / h)
    w, h = int(w * resize_ratio), int(h * resize_ratio)

    # Snap each side down to an integer multiple of 64 (presumably what the
    # downstream model expects -- confirm). Clamp to 64 so a very wide or very
    # tall input does not round to 0, which ``resize`` rejects.
    w, h = (max(64, side - side % 64) for side in (w, h))
    image = image.resize((w, h), resample=PIL.Image.LANCZOS)

    # HWC uint8 -> float32 in [0, 1], then add a batch axis and move to NCHW.
    pixels = np.array(image).astype(np.float32) / 255.0
    pixels = pixels[None].transpose(0, 3, 1, 2)
    tensor = torch.from_numpy(pixels)

    # Rescale [0, 1] -> [-1, 1].
    return 2.0 * tensor - 1.0, w, h
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def imagine(
|
|
|
|
|
prompts,
|
|
|
|
|
config="data/stable-diffusion-v1.yaml",
|
|
|
|
@ -254,7 +273,6 @@ def imagine(
|
|
|
|
|
for wp in prompt.prompts
|
|
|
|
|
]
|
|
|
|
|
)
|
|
|
|
|
# c = model.get_learned_conditioning(prompt.prompt_text)
|
|
|
|
|
|
|
|
|
|
shape = [
|
|
|
|
|
latent_channels,
|
|
|
|
@ -263,24 +281,57 @@ def imagine(
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
def img_callback(samples, i):
|
|
|
|
|
return
|
|
|
|
|
pass
|
|
|
|
|
samples = model.decode_first_stage(samples)
|
|
|
|
|
samples = torch.clamp((samples + 1.0) / 2.0, min=0.0, max=1.0)
|
|
|
|
|
steps_path = os.path.join(
|
|
|
|
|
sample_path, "steps", f"{base_count:08}_S{prompt.seed}"
|
|
|
|
|
)
|
|
|
|
|
os.makedirs(steps_path, exist_ok=True)
|
|
|
|
|
for pred_x0 in samples:
|
|
|
|
|
pred_x0 = 255.0 * rearrange(pred_x0.cpu().numpy(), "c h w -> h w c")
|
|
|
|
|
filename = f"{base_count:08}_S{seed}_step{i:04}.jpg"
|
|
|
|
|
filename = f"{base_count:08}_S{prompt.seed}_step{i:04}.jpg"
|
|
|
|
|
Image.fromarray(pred_x0.astype(np.uint8)).save(
|
|
|
|
|
os.path.join(sample_path, filename)
|
|
|
|
|
os.path.join(steps_path, filename)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
start_code = None
|
|
|
|
|
if fixed_code:
|
|
|
|
|
start_code = torch.randn(
|
|
|
|
|
[1, latent_channels, prompt.height, prompt.width],
|
|
|
|
|
device=get_device(),
|
|
|
|
|
)
|
|
|
|
|
# if fixed_code:
|
|
|
|
|
# start_code = torch.randn(
|
|
|
|
|
# [1, latent_channels, prompt.height, prompt.width],
|
|
|
|
|
# device=get_device(),
|
|
|
|
|
# )
|
|
|
|
|
sampler = get_sampler(prompt.sampler_type, model)
|
|
|
|
|
samples_ddim, _ = sampler.sample(
|
|
|
|
|
if prompt.init_image:
|
|
|
|
|
generation_strength = 1 - prompt.init_image_strength
|
|
|
|
|
ddim_steps = int(prompt.steps / generation_strength)
|
|
|
|
|
sampler.make_schedule(
|
|
|
|
|
ddim_num_steps=ddim_steps, ddim_eta=ddim_eta, verbose=False
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
t_enc = int(generation_strength * ddim_steps)
|
|
|
|
|
init_image, w, h = load_img(prompt.init_image)
|
|
|
|
|
init_image = init_image.to(get_device())
|
|
|
|
|
init_latent = model.get_first_stage_encoding(
|
|
|
|
|
model.encode_first_stage(init_image)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# encode (scaled latent)
|
|
|
|
|
z_enc = sampler.stochastic_encode(
|
|
|
|
|
init_latent, torch.tensor([t_enc]).to(get_device())
|
|
|
|
|
)
|
|
|
|
|
# decode it
|
|
|
|
|
samples = sampler.decode(
|
|
|
|
|
z_enc,
|
|
|
|
|
c,
|
|
|
|
|
t_enc,
|
|
|
|
|
unconditional_guidance_scale=prompt.prompt_strength,
|
|
|
|
|
unconditional_conditioning=uc,
|
|
|
|
|
img_callback=img_callback,
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
|
|
samples, _ = sampler.sample(
|
|
|
|
|
S=prompt.steps,
|
|
|
|
|
conditioning=c,
|
|
|
|
|
batch_size=1,
|
|
|
|
@ -293,11 +344,11 @@ def imagine(
|
|
|
|
|
img_callback=img_callback,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
x_samples_ddim = model.decode_first_stage(samples_ddim)
|
|
|
|
|
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
|
|
|
|
x_samples = model.decode_first_stage(samples)
|
|
|
|
|
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
|
|
|
|
|
|
|
|
|
|
if not skip_save:
|
|
|
|
|
for x_sample in x_samples_ddim:
|
|
|
|
|
for x_sample in x_samples:
|
|
|
|
|
x_sample = 255.0 * rearrange(
|
|
|
|
|
x_sample.cpu().numpy(), "c h w -> h w c"
|
|
|
|
|
)
|
|
|
|
|