2022-09-09 04:51:25 +00:00
|
|
|
import logging
|
2022-09-08 03:59:30 +00:00
|
|
|
import os
|
|
|
|
import re
|
|
|
|
import subprocess
|
|
|
|
from contextlib import nullcontext
|
2022-09-09 04:51:25 +00:00
|
|
|
from functools import lru_cache
|
2022-09-08 03:59:30 +00:00
|
|
|
|
2022-09-09 04:30:20 +00:00
|
|
|
import PIL
|
2022-09-08 03:59:30 +00:00
|
|
|
import numpy as np
|
|
|
|
import torch
|
2022-09-11 10:08:51 +00:00
|
|
|
import torch.nn
|
2022-09-08 03:59:30 +00:00
|
|
|
from PIL import Image
|
|
|
|
from einops import rearrange
|
|
|
|
from omegaconf import OmegaConf
|
|
|
|
from pytorch_lightning import seed_everything
|
|
|
|
from torch import autocast
|
2022-09-10 07:32:31 +00:00
|
|
|
from transformers import cached_path
|
2022-09-08 03:59:30 +00:00
|
|
|
|
2022-09-11 07:58:56 +00:00
|
|
|
from imaginairy.modules.diffusion.ddim import DDIMSampler
|
|
|
|
from imaginairy.modules.diffusion.plms import PLMSSampler
|
2022-09-11 07:35:57 +00:00
|
|
|
from imaginairy.safety import is_nsfw
|
2022-09-10 05:14:04 +00:00
|
|
|
from imaginairy.schema import ImaginePrompt, ImagineResult
|
2022-09-10 07:32:31 +00:00
|
|
|
from imaginairy.utils import (
|
|
|
|
get_device,
|
|
|
|
instantiate_from_config,
|
|
|
|
fix_torch_nn_layer_norm,
|
|
|
|
)
|
|
|
|
|
2022-09-08 03:59:30 +00:00
|
|
|
# Directory of this module; bundled configs and downloads are located relative to it.
LIB_PATH = os.path.dirname(__file__)

logger = logging.getLogger(__name__)

# leave undocumented. I'd ask that no one publicize this flag
IMAGINAIRY_ALLOW_NSFW = (
    os.getenv("IMAGINAIRY_ALLOW_NSFW", "False") == "I AM A RESPONSIBLE ADULT"
)
|
|
|
|
|
|
|
|
|
2022-09-10 07:32:31 +00:00
|
|
|
def load_model_from_config(config):
    """Instantiate the model described by ``config`` and load the SD v1.4 checkpoint.

    The checkpoint URL is resolved to a local file via ``cached_path`` and
    loaded with ``map_location="cpu"``. Its ``state_dict`` is then copied into
    a freshly instantiated model with non-strict key matching. Mismatched keys
    are only logged at debug level.

    Args:
        config: OmegaConf config whose ``model`` node is handed to
            ``instantiate_from_config``.

    Returns:
        The model, moved to ``get_device()`` and switched to eval mode.
    """
    checkpoint_url = "https://www.googleapis.com/storage/v1/b/aai-blog-files/o/sd-v1-4.ckpt?alt=media"
    local_ckpt = cached_path(checkpoint_url)
    logger.info(f"Loading model onto {get_device()} backend...")
    logger.debug(f"Loading model from {local_ckpt}")

    # Deserialize onto the CPU first; the model is moved to the target device
    # only after the weights have been copied in.
    checkpoint = torch.load(local_ckpt, map_location="cpu")
    if "global_step" in checkpoint:
        logger.debug(f"Global Step: {checkpoint['global_step']}")
    state_dict = checkpoint["state_dict"]

    model = instantiate_from_config(config.model)
    # strict=False: tolerate key mismatches and report them instead of raising.
    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
    if missing_keys:
        logger.debug(f"missing keys: {missing_keys}")
    if unexpected_keys:
        logger.debug(f"unexpected keys: {unexpected_keys}")

    model.to(get_device())
    model.eval()
    return model
|
|
|
|
|
|
|
|
|
2022-09-09 04:30:20 +00:00
|
|
|
def load_img(path, max_height=512, max_width=512):
    """Load an init image as a normalized tensor sized for the diffusion model.

    The image is converted to RGB and scaled (up or down) to fit within
    ``max_width`` x ``max_height`` while keeping its aspect ratio. Each side is
    then rounded down to a multiple of 64, but never below 64. Without that
    floor, very thin images would round to a zero-sized side and the resize
    would fail.

    Args:
        path: Filename or file object accepted by ``PIL.Image.open``.
        max_height: Upper bound on the height before 64-rounding.
        max_width: Upper bound on the width before 64-rounding.

    Returns:
        ``(tensor, width, height)``. ``tensor`` has shape
        ``(1, 3, height, width)``, dtype float32, with values in ``[-1, 1]``.
    """
    # The context manager closes the underlying file. convert() returns an
    # independent, fully loaded image, so it stays usable afterwards.
    with Image.open(path) as source:
        image = source.convert("RGB")
    w, h = image.size
    logger.info(f"loaded input image of size ({w}, {h}) from {path}")
    resize_ratio = min(max_width / w, max_height / h)
    w, h = int(w * resize_ratio), int(h * resize_ratio)
    # Round each side down to an integer multiple of 64. Clamp to 64 so a
    # degenerate aspect ratio cannot yield a 0-pixel side.
    w, h = (max(64, x - x % 64) for x in (w, h))
    image = image.resize((w, h), resample=PIL.Image.LANCZOS)
    image = np.array(image).astype(np.float32) / 255.0
    # HWC -> NCHW with a batch dimension of 1.
    image = image[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(image)
    # Map [0, 1] -> [-1, 1].
    return 2.0 * image - 1.0, w, h
|
|
|
|
|
|
|
|
|
2022-09-11 10:08:51 +00:00
|
|
|
def patch_conv(**patch):
    """Monkeypatch ``torch.nn.Conv2d`` so new instances use the given kwargs.

    Every ``Conv2d`` constructed after this call has ``patch`` merged into its
    keyword arguments, and ``patch`` wins over keywords the caller passed.
    A typical use is ``padding_mode="circular"`` to make generated images
    tileable. The patch is process-wide and stays in place until the caller
    restores ``torch.nn.Conv2d.__init__``.

    Based on:
    https://github.com/replicate/cog-stable-diffusion/compare/main...TomMoore515:material_stable_diffusion:main

    Args:
        **patch: Keyword arguments forced onto every ``Conv2d`` constructor call.

    Returns:
        The ``__init__`` that was installed before patching. Restore it with
        ``torch.nn.Conv2d.__init__ = returned_value``.
    """
    from functools import wraps

    cls = torch.nn.Conv2d
    init = cls.__init__

    @wraps(init)
    def __init__(self, *args, **kwargs):
        # Merge instead of passing **kwargs and **patch side by side. The
        # latter raises "got multiple values for keyword argument" when the
        # caller already supplies one of the patched keywords.
        return init(self, *args, **{**kwargs, **patch})

    cls.__init__ = __init__
    return init
|
|
|
|
|
|
|
|
|
2022-09-09 04:51:25 +00:00
|
|
|
@lru_cache()
def load_model(tile_mode=False):
    """Load the Stable Diffusion v1 model, memoized per ``tile_mode`` value.

    Args:
        tile_mode: When true, the model's ``Conv2d`` layers are constructed
            with ``padding_mode="circular"`` so generated images tile.

    Returns:
        The model on ``get_device()``. Repeated calls with the same
        ``tile_mode`` return the cached object.
    """
    # patch_conv replaces Conv2d.__init__ process-wide. Without a restore,
    # every Conv2d built afterwards would silently get circular padding too,
    # including those of a later load_model(tile_mode=False).
    original_conv_init = torch.nn.Conv2d.__init__
    try:
        if tile_mode:
            # generated images are tileable
            patch_conv(padding_mode="circular")
        config = "configs/stable-diffusion-v1.yaml"
        config = OmegaConf.load(f"{LIB_PATH}/{config}")
        model = load_model_from_config(config)
    finally:
        # torch keeps padding_mode as an attribute of each Conv2d instance, so
        # layers built above keep circular padding after this restore.
        torch.nn.Conv2d.__init__ = original_conv_init

    model = model.to(get_device())
    return model
|
|
|
|
|
|
|
|
|
2022-09-10 05:14:04 +00:00
|
|
|
def imagine_image_files(
    prompts,
    outdir,
    latent_channels=4,
    downsampling_factor=8,
    precision="autocast",
    ddim_eta=0.0,
    record_step_images=False,
    output_file_extension="jpg",
    tile_mode=False,
):
    """Generate images for ``prompts`` and write them to ``outdir``.

    Each result is saved as ``outdir/<basefilename>.<ext>``. The base filename
    encodes a running counter, the seed, the sampler and step count, the
    prompt strength and a normalized prompt. Prompts with ``upscale`` set are
    also passed through Real-ESRGAN into ``outdir/upscaled``.

    Args:
        prompts: Forwarded to ``imagine_images``.
        outdir: Output directory; created if missing.
        latent_channels, downsampling_factor, precision, ddim_eta, tile_mode:
            Forwarded to ``imagine_images``.
        record_step_images: If true, decode every intermediate sample and
            save it as JPEG under ``outdir/steps/<counter>_S<seed>/``.
        output_file_extension: ``"jpg"`` or ``"png"`` (case-insensitive).
            Used as the extension of the final and upscaled image filenames.

    Raises:
        ValueError: If ``output_file_extension`` is neither jpg nor png.
    """
    # Validate before touching the filesystem.
    output_file_extension = output_file_extension.lower()
    if output_file_extension not in {"jpg", "png"}:
        raise ValueError("Must output a png or jpg")

    big_path = os.path.join(outdir, "upscaled")
    os.makedirs(outdir, exist_ok=True)
    os.makedirs(big_path, exist_ok=True)
    # Continue numbering after the existing entries. This count includes the
    # "upscaled" (and any "steps") subdirectory.
    base_count = len(os.listdir(outdir))
    step_count = 0

    def _record_steps(samples, i, model, prompt):
        """Decode one intermediate sample batch and save it as JPEG files."""
        nonlocal step_count
        step_count += 1
        samples = model.decode_first_stage(samples)
        samples = torch.clamp((samples + 1.0) / 2.0, min=0.0, max=1.0)
        steps_path = os.path.join(outdir, "steps", f"{base_count:08}_S{prompt.seed}")
        os.makedirs(steps_path, exist_ok=True)
        for pred_x0 in samples:
            pred_x0 = 255.0 * rearrange(pred_x0.cpu().numpy(), "c h w -> h w c")
            filename = f"{base_count:08}_S{prompt.seed}_step{step_count:04}.jpg"
            Image.fromarray(pred_x0.astype(np.uint8)).save(
                os.path.join(steps_path, filename)
            )

    img_callback = _record_steps if record_step_images else None
    for result in imagine_images(
        prompts,
        latent_channels=latent_channels,
        downsampling_factor=downsampling_factor,
        precision=precision,
        ddim_eta=ddim_eta,
        img_callback=img_callback,
        tile_mode=tile_mode,
    ):
        prompt = result.prompt
        basefilename = f"{base_count:06}_{prompt.seed}_{prompt.sampler_type}{prompt.steps}_PS{prompt.prompt_strength}_{prompt_normalized(prompt.prompt_text)}"
        filepath = os.path.join(outdir, f"{basefilename}.{output_file_extension}")

        result.save(filepath)
        logger.info(f" 🖼 saved to: {filepath}")
        if prompt.upscale:
            # A plain path string: a tuple here would break the upscaler's argv.
            bigfilepath = os.path.join(
                big_path, f"{basefilename}.{output_file_extension}"
            )
            enlarge_realesrgan2x(filepath, bigfilepath)
            logger.info(f" upscaled 🖼 saved to: {bigfilepath}")
        base_count += 1
        # Each image gets its own steps directory, so restart step numbering.
        step_count = 0
|
|
|
|
|
|
|
|
|
|
|
|
def imagine_images(
    prompts,
    latent_channels=4,
    downsampling_factor=8,
    precision="autocast",
    ddim_eta=0.0,
    img_callback=None,
    tile_mode=False,
):
    """Generate images for one or more prompts, yielding results lazily.

    Args:
        prompts: A prompt string, a single ``ImaginePrompt``, or an iterable
            of ``ImaginePrompt`` objects.
        latent_channels: Channel count of the sampled latent.
        downsampling_factor: Divisor from image height/width to latent size.
        precision: ``torch.autocast`` is used only when this is
            ``"autocast"`` and the device is ``"cuda"`` or ``"cpu"``.
            Otherwise no autocast context is entered.
        ddim_eta: ``eta`` used for the sampler schedule / ``sample`` call.
        img_callback: Optional callable, invoked as
            ``img_callback(samples, i, model, prompt)``. The sampler calls it
            with intermediate samples. In init-image mode it is also called
            with the encoded init latents.
        tile_mode: Use the model variant whose convolutions pad circularly.

    Yields:
        ``ImagineResult`` per generated image. Each image is NSFW-filtered
        (unless allowed) and face-fixed when the prompt requests it.
    """
    model = load_model(tile_mode=tile_mode)
    # model = model.half()
    prompts = [ImaginePrompt(prompts)] if isinstance(prompts, str) else prompts
    prompts = [prompts] if isinstance(prompts, ImaginePrompt) else prompts
    # Remains None unless a user callback is supplied. Direct calls below
    # must check it first.
    _img_callback = None

    precision_scope = (
        autocast
        if precision == "autocast" and get_device() in ("cuda", "cpu")
        else nullcontext
    )
    # Plain multi-item `with` rather than the parenthesized form, which older
    # Python versions parse as a tuple.
    with torch.no_grad(), precision_scope(get_device()), fix_torch_nn_layer_norm():
        for prompt in prompts:
            logger.info(f"Generating {prompt.prompt_description()}")
            seed_everything(prompt.seed)

            # needed when model is in half mode, remove if not using half mode
            # torch.set_default_tensor_type(torch.HalfTensor)

            # Empty-prompt conditioning for the sampler's unconditional
            # guidance. Skipped when the guidance scale is exactly 1.
            uc = None
            if prompt.prompt_strength != 1.0:
                uc = model.get_learned_conditioning([""])
            # Blend the weighted sub-prompts, normalizing weights to sum to 1.
            total_weight = sum(wp.weight for wp in prompt.prompts)
            c = sum(
                model.get_learned_conditioning(wp.text) * (wp.weight / total_weight)
                for wp in prompt.prompts
            )
            if img_callback:

                def _img_callback(samples, i):
                    img_callback(samples, i, model, prompt)

            sampler = get_sampler(prompt.sampler_type, model)
            if prompt.init_image:
                generation_strength = 1 - prompt.init_image_strength
                # Stretch the schedule so t_enc (the steps actually run) is
                # roughly prompt.steps.
                # NOTE(review): init_image_strength == 1 divides by zero here.
                ddim_steps = int(prompt.steps / generation_strength)
                sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=ddim_eta)

                t_enc = int(generation_strength * ddim_steps)
                init_image, _, _ = load_img(prompt.init_image)
                init_image = init_image.to(get_device())
                init_latent = model.encode_first_stage(init_image)
                noised_init_latent = model.get_first_stage_encoding(init_latent)
                if _img_callback:
                    _img_callback(init_latent.mean, 0)
                    _img_callback(noised_init_latent, 0)

                # encode (scaled latent)
                z_enc = sampler.stochastic_encode(
                    noised_init_latent,
                    torch.tensor([t_enc]).to(get_device()),
                )
                if _img_callback:
                    _img_callback(noised_init_latent, 0)

                # decode it
                samples = sampler.decode(
                    z_enc,
                    c,
                    t_enc,
                    unconditional_guidance_scale=prompt.prompt_strength,
                    unconditional_conditioning=uc,
                    img_callback=_img_callback,
                )
            else:
                shape = [
                    latent_channels,
                    prompt.height // downsampling_factor,
                    prompt.width // downsampling_factor,
                ]
                start_code = None
                samples, _ = sampler.sample(
                    S=prompt.steps,
                    conditioning=c,
                    batch_size=1,
                    shape=shape,
                    unconditional_guidance_scale=prompt.prompt_strength,
                    unconditional_conditioning=uc,
                    eta=ddim_eta,
                    x_T=start_code,
                    img_callback=_img_callback,
                )

            x_samples = model.decode_first_stage(samples)
            x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)

            for x_sample in x_samples:
                x_sample = 255.0 * rearrange(x_sample.cpu().numpy(), "c h w -> h w c")
                img = Image.fromarray(x_sample.astype(np.uint8))
                if not IMAGINAIRY_ALLOW_NSFW and is_nsfw(img, x_sample):
                    logger.info(" ⚠️ Filtering NSFW image")
                    img = Image.new("RGB", img.size, (228, 150, 150))
                if prompt.fix_faces:
                    img = fix_faces_GFPGAN(img)
                yield ImagineResult(img=img, prompt=prompt)
|
2022-09-08 03:59:30 +00:00
|
|
|
|
|
|
|
|
|
|
|
# Any run of characters outside the filename-safe set collapses to one "_".
_FILENAME_UNSAFE_RUN = re.compile(r"[^a-zA-Z0-9.,]+")


def prompt_normalized(prompt):
    """Return a filename-safe slug of ``prompt``.

    Each run of characters other than ASCII letters, digits, ``.`` and ``,``
    is replaced by a single underscore. The result is truncated to 130
    characters.
    """
    slug = _FILENAME_UNSAFE_RUN.sub("_", prompt)
    return slug[:130]
|
|
|
|
|
|
|
|
|
|
|
|
# Third-party assets live here (the Real-ESRGAN binary and GFPGAN weights).
# Keep the trailing slash: paths are built by plain string concatenation.
DOWNLOADED_FILES_PATH = f"{LIB_PATH}/../downloads/"
ESRGAN_PATH = f"{DOWNLOADED_FILES_PATH}realesrgan-ncnn-vulkan/realesrgan-ncnn-vulkan"
|
|
|
|
|
|
|
|
|
|
|
|
def enlarge_realesrgan2x(src, dst, *, executable=None):
    """Upscale the image at ``src`` into ``dst`` using the realesrgan-ncnn-vulkan CLI.

    Runs the binary with the ``realesrgan-x4plus`` model and blocks until it
    exits.

    NOTE(review): despite "2x" in the name, no ``-s`` scale flag is passed.
    The output scale is whatever the CLI defaults to for this model; confirm
    it matches expectations.

    Args:
        src: Input image path.
        dst: Output image path. Must be a str or path-like, not a tuple.
        executable: Path to the upscaler binary. Defaults to ``ESRGAN_PATH``.

    Raises:
        FileNotFoundError: If the binary cannot be found.
        subprocess.CalledProcessError: If the upscaler exits with a non-zero
            status, so a failed upscale is not mistaken for a success.
    """
    if executable is None:
        executable = ESRGAN_PATH
    # argv list with no shell, so paths with spaces/metacharacters are safe.
    subprocess.run(
        [executable, "-i", src, "-o", dst, "-n", "realesrgan-x4plus"],
        check=True,
    )
|
|
|
|
|
|
|
|
|
|
|
|
def get_sampler(sampler_type, model):
    """Return a sampler for ``sampler_type`` bound to ``model``.

    Args:
        sampler_type: ``"plms"`` or ``"ddim"`` (case-insensitive).
        model: The diffusion model the sampler drives.

    Returns:
        A ``PLMSSampler`` or ``DDIMSampler`` instance.

    Raises:
        ValueError: For any other sampler type. Returning ``None`` here would
            only surface later as an opaque ``AttributeError``.
    """
    normalized = sampler_type.upper()
    if normalized == "PLMS":
        return PLMSSampler(model)
    if normalized == "DDIM":
        return DDIMSampler(model)
    raise ValueError(
        f"Invalid sampler type: {sampler_type!r} (expected 'PLMS' or 'DDIM')"
    )
|
|
|
|
|
|
|
|
|
|
|
|
@lru_cache(maxsize=1)
def gfpgan_model():
    """Return the GFPGAN v1.3 face restorer, constructing it only once.

    Construction loads the weights from ``DOWNLOADED_FILES_PATH``. It is
    memoized so that fixing faces on many images does not reload the model
    each time. The restorer is bound to ``get_device()`` as of the first
    call; use ``gfpgan_model.cache_clear()`` to force a rebuild.

    NOTE(review): assumes ``GFPGANer.enhance`` keeps no state between calls,
    which makes reusing one instance safe. Confirm against the gfpgan version
    in use.
    """
    from gfpgan import GFPGANer

    return GFPGANer(
        model_path=DOWNLOADED_FILES_PATH
        + "GFPGAN/experiments/pretrained_models/GFPGANv1.3.pth",
        upscale=1,
        arch="clean",
        channel_multiplier=2,
        bg_upsampler=None,
        device=torch.device(get_device()),
    )
|
|
|
|
|
|
|
|
|
|
|
|
def fix_faces_GFPGAN(image):
    """Restore faces in a PIL image with GFPGAN; return a new RGB PIL image.

    ``GFPGANer.enhance`` works on OpenCV-style BGR arrays: internally it
    converts BGR->RGB before the network and RGB->BGR after it. The RGB
    pixels from PIL are therefore flipped to BGR going in and back to RGB
    coming out. Previously the network saw swapped channels.
    """
    image = image.convert("RGB")
    # Contiguous copy, because OpenCV routines may reject negative-stride views.
    bgr_input = np.ascontiguousarray(np.array(image, dtype=np.uint8)[:, :, ::-1])
    _cropped_faces, _restored_faces, restored_bgr = gfpgan_model().enhance(
        bgr_input,
        has_aligned=False,
        only_center_face=False,
        paste_back=True,
    )
    return Image.fromarray(np.ascontiguousarray(restored_bgr[:, :, ::-1]))
|