2022-09-09 04:51:25 +00:00
|
|
|
import logging
|
2022-09-08 03:59:30 +00:00
|
|
|
import os
|
|
|
|
import re
|
2022-09-09 04:51:25 +00:00
|
|
|
from functools import lru_cache
|
2022-09-08 03:59:30 +00:00
|
|
|
|
|
|
|
import numpy as np
|
2022-09-24 05:58:48 +00:00
|
|
|
import PIL
|
2022-09-08 03:59:30 +00:00
|
|
|
import torch
|
2022-09-11 10:08:51 +00:00
|
|
|
import torch.nn
|
2022-09-08 03:59:30 +00:00
|
|
|
from einops import rearrange
|
|
|
|
from omegaconf import OmegaConf
|
2022-09-18 13:07:07 +00:00
|
|
|
from PIL import Image, ImageDraw, ImageFilter, ImageOps
|
2022-09-08 03:59:30 +00:00
|
|
|
from pytorch_lightning import seed_everything
|
2022-09-10 07:32:31 +00:00
|
|
|
from transformers import cached_path
|
2022-09-08 03:59:30 +00:00
|
|
|
|
2022-09-18 13:07:07 +00:00
|
|
|
from imaginairy.enhancers.clip_masking import get_img_mask
|
2022-09-20 04:15:38 +00:00
|
|
|
from imaginairy.enhancers.describe_image_blip import generate_caption
|
2022-09-13 07:27:53 +00:00
|
|
|
from imaginairy.enhancers.face_restoration_codeformer import enhance_faces
|
|
|
|
from imaginairy.enhancers.upscale_realesrgan import upscale_image
|
2022-10-11 02:50:11 +00:00
|
|
|
from imaginairy.img_utils import pillow_fit_image_within, pillow_img_to_torch_image
|
|
|
|
from imaginairy.log_utils import (
|
2022-09-18 13:07:07 +00:00
|
|
|
ImageLoggingContext,
|
|
|
|
log_conditioning,
|
|
|
|
log_img,
|
|
|
|
log_latent,
|
|
|
|
)
|
2022-10-10 08:22:11 +00:00
|
|
|
from imaginairy.safety import SafetyMode, create_safety_score
|
2022-09-14 07:40:25 +00:00
|
|
|
from imaginairy.samplers.base import get_sampler
|
2022-10-06 04:43:00 +00:00
|
|
|
from imaginairy.samplers.plms import PLMSSchedule
|
2022-09-10 05:14:04 +00:00
|
|
|
from imaginairy.schema import ImaginePrompt, ImagineResult
|
2022-09-10 07:32:31 +00:00
|
|
|
from imaginairy.utils import (
|
2022-09-22 05:03:12 +00:00
|
|
|
fix_torch_group_norm,
|
2022-09-11 20:58:14 +00:00
|
|
|
fix_torch_nn_layer_norm,
|
2022-09-10 07:32:31 +00:00
|
|
|
get_device,
|
|
|
|
instantiate_from_config,
|
2022-09-22 05:38:44 +00:00
|
|
|
platform_appropriate_autocast,
|
2022-09-10 07:32:31 +00:00
|
|
|
)
|
|
|
|
|
2022-09-08 03:59:30 +00:00
|
|
|
# Directory that contains this module; load_model() resolves its config file
# relative to it.
LIB_PATH = os.path.dirname(__file__)
logger = logging.getLogger(__name__)

# Intentionally left undocumented. I'd ask that no one publicize this flag;
# it is only meant to be a slight barrier to entry. Please don't use it in any
# way that's going to cause the media or politicians to freak out about AI...
#
# The raw environment value is translated as follows:
#   "filter"                  -> SafetyMode.STRICT
#   "disabled" / "classify"   -> SafetyMode.RELAXED
# Any other value (including the SafetyMode.STRICT default used when the
# variable is unset) is kept exactly as read.
IMAGINAIRY_SAFETY_MODE = os.getenv("IMAGINAIRY_SAFETY_MODE", SafetyMode.STRICT)
if IMAGINAIRY_SAFETY_MODE == "filter":
    IMAGINAIRY_SAFETY_MODE = SafetyMode.STRICT
elif IMAGINAIRY_SAFETY_MODE in {"classify", "disabled"}:
    IMAGINAIRY_SAFETY_MODE = SafetyMode.RELAXED

# Checkpoint location used whenever a caller does not supply one.
DEFAULT_MODEL_WEIGHTS_LOCATION = (
    "https://www.googleapis.com/storage/v1/b/aai-blog-files/o/sd-v1-4.ckpt?alt=media"
)
|
|
|
|
|
2022-09-11 07:35:57 +00:00
|
|
|
|
2022-10-06 04:50:20 +00:00
|
|
|
def load_model_from_config(
    config, model_weights_location=DEFAULT_MODEL_WEIGHTS_LOCATION
):
    """Instantiate the model described by ``config`` and load its weights.

    Args:
        config: Configuration object; ``config.model`` is handed to
            ``instantiate_from_config``.
        model_weights_location: Checkpoint path or URL. A falsy value
            (``None``, ``""``) falls back to ``DEFAULT_MODEL_WEIGHTS_LOCATION``.
            Locations starting with ``"http"`` are resolved through
            ``cached_path``; anything else is used as given.

    Returns:
        The model, moved to ``get_device()`` and switched to eval mode.
    """
    model_weights_location = (
        model_weights_location or DEFAULT_MODEL_WEIGHTS_LOCATION
    )
    is_remote = model_weights_location.startswith("http")
    ckpt_path = (
        cached_path(model_weights_location) if is_remote else model_weights_location
    )
    logger.info(f"Loading model {ckpt_path} onto {get_device()} backend...")

    # The checkpoint is always deserialized onto the CPU first; the model is
    # moved to the target device only after the weights are applied.
    # NOTE(review): torch.load unpickles the file, so only point this at
    # checkpoints from a trusted source.
    pl_sd = torch.load(ckpt_path, map_location="cpu")
    if "global_step" in pl_sd:
        logger.debug(f"Global Step: {pl_sd['global_step']}")

    model = instantiate_from_config(config.model)
    # strict=False: mismatched keys are reported at debug level, not raised.
    m, u = model.load_state_dict(pl_sd["state_dict"], strict=False)
    if m:
        logger.debug(f"missing keys: {m}")
    if u:
        logger.debug(f"unexpected keys: {u}")

    model.to(get_device())
    model.eval()
    return model
|
|
|
|
|
|
|
|
|
2022-09-09 04:51:25 +00:00
|
|
|
@lru_cache()
def load_model(model_weights_location=None):
    """Load the Stable Diffusion v1 model, memoized per weights location.

    Args:
        model_weights_location: Optional checkpoint path or URL, forwarded to
            ``load_model_from_config``.

    Returns:
        The loaded model on ``get_device()``. Because of ``lru_cache``,
        repeated calls with the same argument return the same object.
    """
    config_relpath = "configs/stable-diffusion-v1.yaml"
    config = OmegaConf.load(f"{LIB_PATH}/{config_relpath}")
    loaded = load_model_from_config(
        config, model_weights_location=model_weights_location
    )
    return loaded.to(get_device())
|
|
|
|
|
|
|
|
|
2022-09-10 05:14:04 +00:00
|
|
|
def imagine_image_files(
    prompts,
    outdir,
    latent_channels=4,
    downsampling_factor=8,
    precision="autocast",
    record_step_images=False,
    output_file_extension="jpg",
    print_caption=False,
    model_weights_path=None,
):
    """Run ``imagine`` and write every resulting image to disk under ``outdir``.

    Each image type exposed by a result is saved into its own subdirectory of
    ``outdir`` (named after the image type). Files are numbered starting from
    the count of entries already present in ``outdir/generated``.

    Args:
        prompts: Forwarded unchanged to ``imagine``.
        outdir: Root output directory; created as needed.
        latent_channels: Forwarded to ``imagine``.
        downsampling_factor: Forwarded to ``imagine``.
        precision: Forwarded to ``imagine``.
        record_step_images: When true, intermediate images are written to
            ``outdir/steps/...`` with their description drawn on them.
        output_file_extension: ``"jpg"`` or ``"png"`` (case-insensitive).
        print_caption: Forwarded to ``imagine`` as ``add_caption``.
        model_weights_path: Forwarded to ``imagine``.

    Raises:
        ValueError: If ``output_file_extension`` is neither jpg nor png.
    """
    generated_dir = os.path.join(outdir, "generated")
    os.makedirs(generated_dir, exist_ok=True)

    # Continue numbering after whatever is already in the "generated" folder.
    base_count = len(os.listdir(generated_dir))

    output_file_extension = output_file_extension.lower()
    if output_file_extension not in {"jpg", "png"}:
        raise ValueError("Must output a png or jpg")

    def _record_step(img, description, step_count, prompt):
        # Reads the enclosing ``base_count`` at call time, so step images are
        # grouped with the result currently being generated.
        step_dir = os.path.join(outdir, "steps", f"{base_count:08}_S{prompt.seed}")
        os.makedirs(step_dir, exist_ok=True)
        step_name = f"{base_count:08}_S{prompt.seed}_step{step_count:04}_{prompt_normalized(description)[:40]}.jpg"
        # Stamp the description onto the image itself before saving it.
        ImageDraw.Draw(img).text((10, 10), str(description))
        img.save(os.path.join(step_dir, step_name))

    step_callback = _record_step if record_step_images else None
    results = imagine(
        prompts,
        latent_channels=latent_channels,
        downsampling_factor=downsampling_factor,
        precision=precision,
        img_callback=step_callback,
        add_caption=print_caption,
        model_weights_path=model_weights_path,
    )
    for result in results:
        prompt = result.prompt
        img_str = (
            f"_img2img-{prompt.init_image_strength}" if prompt.init_image else ""
        )
        basefilename = f"{base_count:06}_{prompt.seed}_{prompt.sampler_type}{prompt.steps}_PS{prompt.prompt_strength}{img_str}_{prompt_normalized(prompt.prompt_text)}"

        for image_type in result.images:
            type_dir = os.path.join(outdir, image_type)
            os.makedirs(type_dir, exist_ok=True)
            filepath = os.path.join(
                type_dir, f"{basefilename}_[{image_type}].{output_file_extension}"
            )
            result.save(filepath, image_type=image_type)
            logger.info(f"🖼 [{image_type}] saved to: {filepath}")

        base_count += 1
        # Drop the reference before the generator produces the next result.
        del result
|
2022-09-10 05:14:04 +00:00
|
|
|
|
|
|
|
|
2022-09-13 07:27:53 +00:00
|
|
|
def imagine(
    prompts,
    latent_channels=4,
    downsampling_factor=8,
    precision="autocast",
    img_callback=None,
    half_mode=None,
    add_caption=False,
    model_weights_path=None,
):
    """Generate images for one or more prompts, yielding a result per image.

    This is a generator: nothing runs until it is iterated.

    Args:
        prompts: A prompt string, a single ``ImaginePrompt``, or an iterable
            of ``ImaginePrompt`` objects.
        latent_channels: Channel count of the latent ``shape`` passed to the
            sampler for text-to-image generation.
        downsampling_factor: Divisor applied to the prompt's pixel height and
            width to get the latent size; also used to shrink the mask.
        precision: Forwarded to ``platform_appropriate_autocast``.
        img_callback: Forwarded to ``ImageLoggingContext`` for each prompt.
        half_mode: ``None`` (default) or a truthy value enables half precision
            when ``get_device()`` is ``"cuda"``; a falsy non-``None`` value
            disables it. Half precision is never used on other devices.
        add_caption: When true, a caption is generated for every image and
            logged at info level.
        model_weights_path: Forwarded to ``load_model``.

    Yields:
        ImagineResult: One per decoded sample. Prompts whose ``init_image``
        raises ``PIL.UnidentifiedImageError`` are skipped with a warning.
    """
    model = load_model(model_weights_location=model_weights_path)

    # Only run half-mode on cuda; run it by default. The previous expression
    # (`half_mode is None and ...`) silently turned an explicit
    # ``half_mode=True`` into False, so callers could never opt in.
    half_mode = (half_mode is None or bool(half_mode)) and get_device() == "cuda"
    if half_mode:
        model = model.half()
        # needed when model is in half mode, remove if not using half mode
        # torch.set_default_tensor_type(torch.HalfTensor)

    # Normalize the accepted input forms to an iterable of ImaginePrompt.
    prompts = [ImaginePrompt(prompts)] if isinstance(prompts, str) else prompts
    prompts = [prompts] if isinstance(prompts, ImaginePrompt) else prompts

    # The samplers are always called without a callback; the user supplied
    # ``img_callback`` is handed to ImageLoggingContext instead.
    _img_callback = None

    if get_device() == "cpu":
        logger.info("Running in CPU mode. it's gonna be slooooooow.")

    with torch.no_grad(), platform_appropriate_autocast(
        precision
    ), fix_torch_nn_layer_norm(), fix_torch_group_norm():
        for prompt in prompts:
            logger.info(f"Generating 🖼 : {prompt.prompt_description()}")
            with ImageLoggingContext(
                prompt=prompt,
                model=model,
                img_callback=img_callback,
            ):
                seed_everything(prompt.seed)
                model.tile_mode(prompt.tile_mode)

                # The empty-prompt conditioning is only computed when the
                # prompt strength differs from 1.0.
                uc = None
                if prompt.prompt_strength != 1.0:
                    uc = model.get_learned_conditioning(1 * [""])
                    log_conditioning(uc, "neutral conditioning")

                if prompt.conditioning is not None:
                    c = prompt.conditioning
                else:
                    # Weighted average of the conditionings of all sub-prompts.
                    total_weight = sum(wp.weight for wp in prompt.prompts)
                    c = sum(
                        model.get_learned_conditioning(wp.text)
                        * (wp.weight / total_weight)
                        for wp in prompt.prompts
                    )
                log_conditioning(c, "positive conditioning")

                shape = [
                    1,
                    latent_channels,
                    prompt.height // downsampling_factor,
                    prompt.width // downsampling_factor,
                ]

                # img2img is only run with the ddim or plms samplers here.
                if prompt.init_image and prompt.sampler_type not in ("ddim", "plms"):
                    sampler_type = "plms"
                    logger.info(" Sampler type switched to plms for img2img")
                else:
                    sampler_type = prompt.sampler_type

                sampler = get_sampler(sampler_type, model)

                # Reset per prompt so mask state never leaks between prompts.
                mask, mask_image, mask_image_orig, mask_grayscale = (
                    None,
                    None,
                    None,
                    None,
                )

                if prompt.init_image:
                    generation_strength = 1 - prompt.init_image_strength
                    t_enc = int(prompt.steps * generation_strength)
                    try:
                        init_image = pillow_fit_image_within(
                            prompt.init_image,
                            max_height=prompt.height,
                            max_width=prompt.width,
                        )
                    except PIL.UnidentifiedImageError:
                        logger.warning(f"Could not load image: {prompt.init_image}")
                        continue

                    init_image_t = pillow_img_to_torch_image(init_image)

                    # A text-derived mask takes precedence over a supplied one.
                    if prompt.mask_prompt:
                        mask_image, mask_grayscale = get_img_mask(
                            init_image, prompt.mask_prompt, threshold=0.1
                        )
                    elif prompt.mask_image:
                        mask_image = prompt.mask_image.convert("L")

                    if mask_image is not None:
                        log_img(mask_image, "init mask")
                        if prompt.mask_mode == ImaginePrompt.MaskMode.REPLACE:
                            mask_image = ImageOps.invert(mask_image)

                        log_img(
                            Image.composite(init_image, mask_image, mask_image),
                            "mask overlay",
                        )
                        # Keep the full-resolution mask for compositing later;
                        # the sampler gets a latent-sized copy.
                        mask_image_orig = mask_image
                        mask_image = mask_image.resize(
                            (
                                mask_image.width // downsampling_factor,
                                mask_image.height // downsampling_factor,
                            ),
                            resample=Image.Resampling.LANCZOS,
                        )
                        log_img(mask_image, "latent_mask")

                        # Scale to [0, 1] floats and add two leading axes.
                        mask = np.array(mask_image)
                        mask = mask.astype(np.float32) / 255.0
                        mask = mask[None, None]
                        mask = torch.from_numpy(mask)
                        mask = mask.to(get_device())

                    init_image_t = init_image_t.to(get_device())
                    init_latent = model.get_first_stage_encoding(
                        model.encode_first_stage(init_image_t)
                    )

                    log_latent(init_latent, "init_latent")
                    # encode (scaled latent)
                    seed_everything(prompt.seed)
                    # Noise is drawn on the CPU and then moved to the device.
                    noise = torch.randn_like(init_latent, device="cpu").to(get_device())
                    # todo: this isn't the right scheduler for everything...
                    schedule = PLMSSchedule(
                        ddpm_num_timesteps=model.num_timesteps,
                        ddim_num_steps=prompt.steps,
                        alphas_cumprod=model.alphas_cumprod,
                        ddim_discretize="uniform",
                    )
                    if generation_strength >= 1:
                        # prompt strength gets converted to time encodings,
                        # which means you can't get to true 0 without this hack
                        # (or setting steps=1000)
                        z_enc = noise
                    else:
                        z_enc = sampler.noise_an_image(
                            init_latent,
                            torch.tensor([t_enc - 1]).to(get_device()),
                            schedule=schedule,
                            noise=noise,
                        )
                    log_latent(z_enc, "z_enc")

                    # decode it
                    samples = sampler.decode(
                        initial_latent=z_enc,
                        cond=c,
                        t_start=t_enc,
                        schedule=schedule,
                        unconditional_guidance_scale=prompt.prompt_strength,
                        unconditional_conditioning=uc,
                        img_callback=_img_callback,
                        mask=mask,
                        orig_latent=init_latent,
                    )
                else:
                    samples = sampler.sample(
                        num_steps=prompt.steps,
                        conditioning=c,
                        batch_size=1,
                        shape=shape,
                        unconditional_guidance_scale=prompt.prompt_strength,
                        unconditional_conditioning=uc,
                        img_callback=_img_callback,
                    )

                x_samples = model.decode_first_stage(samples)
                # Map from [-1, 1] to [0, 1], clamping anything outside.
                x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)

                for x_sample in x_samples:
                    x_sample = x_sample.to(torch.float32)
                    x_sample = 255.0 * rearrange(
                        x_sample.cpu().numpy(), "c h w -> h w c"
                    )
                    x_sample_8_orig = x_sample.astype(np.uint8)
                    img = Image.fromarray(x_sample_8_orig)

                    # With a mask, blend the generated image with the fitted
                    # init image using a blurred, inverted copy of the mask.
                    if mask_image_orig and init_image:
                        mask_final = mask_image_orig.filter(
                            ImageFilter.GaussianBlur(radius=3)
                        )
                        log_img(mask_final, "reconstituting mask")
                        mask_final = ImageOps.invert(mask_final)
                        img = Image.composite(img, init_image, mask_final)
                        log_img(img, "reconstituted image")

                    upscaled_img = None
                    rebuilt_orig_img = None

                    if add_caption:
                        caption = generate_caption(img)
                        logger.info(f"Generated caption: {caption}")

                    safety_score = create_safety_score(
                        img,
                        safety_mode=IMAGINAIRY_SAFETY_MODE,
                    )
                    # Enhancement steps are skipped for filtered images.
                    if not safety_score.is_filtered:
                        if prompt.fix_faces:
                            logger.info("Fixing 😊 's in 🖼 using CodeFormer...")
                            img = enhance_faces(img, fidelity=prompt.fix_faces_fidelity)
                        if prompt.upscale:
                            logger.info("Upscaling 🖼 using real-ESRGAN...")
                            upscaled_img = upscale_image(img)

                        # put the newly generated patch back into the original, full size image
                        if (
                            prompt.mask_modify_original
                            and mask_image_orig
                            and prompt.init_image
                        ):
                            img_to_add_back_to_original = (
                                upscaled_img if upscaled_img else img
                            )
                            img_to_add_back_to_original = (
                                img_to_add_back_to_original.resize(
                                    prompt.init_image.size,
                                    resample=Image.Resampling.LANCZOS,
                                )
                            )

                            mask_for_orig_size = mask_image_orig.resize(
                                prompt.init_image.size,
                                resample=Image.Resampling.LANCZOS,
                            )
                            mask_for_orig_size = mask_for_orig_size.filter(
                                ImageFilter.GaussianBlur(radius=5)
                            )
                            log_img(mask_for_orig_size, "mask for original image size")

                            rebuilt_orig_img = Image.composite(
                                prompt.init_image,
                                img_to_add_back_to_original,
                                mask_for_orig_size,
                            )
                            log_img(rebuilt_orig_img, "reconstituted original")

                    yield ImagineResult(
                        img=img,
                        prompt=prompt,
                        upscaled_img=upscaled_img,
                        is_nsfw=safety_score.is_nsfw,
                        safety_score=safety_score,
                        modified_original=rebuilt_orig_img,
                        mask_binary=mask_image_orig,
                        mask_grayscale=mask_grayscale,
                    )
|
2022-09-08 03:59:30 +00:00
|
|
|
|
|
|
|
|
|
|
|
def prompt_normalized(prompt):
    """Return a filename-friendly version of ``prompt``.

    Every run of characters other than ASCII letters, digits, ``.``, ``,``,
    ``[``, ``]`` and ``-`` is collapsed into a single underscore, and the
    result is truncated to at most 130 characters.
    """
    disallowed_run = r"[^a-zA-Z0-9.,\[\]-]+"
    sanitized = re.sub(disallowed_run, "_", prompt)
    return sanitized[:130]
|