"""Image generation entry points: run prompts through a diffusion model and save results."""
import logging
import os
import re

import numpy as np
import torch
import torch.nn
from einops import rearrange, repeat
from PIL import Image, ImageDraw, ImageOps
from pytorch_lightning import seed_everything

from imaginairy.enhancers.clip_masking import get_img_mask
from imaginairy.enhancers.describe_image_blip import generate_caption
from imaginairy.enhancers.face_restoration_codeformer import enhance_faces
from imaginairy.enhancers.upscale_realesrgan import upscale_image
from imaginairy.img_utils import (
    make_gif_image,
    pillow_fit_image_within,
    pillow_img_to_torch_image,
)
from imaginairy.log_utils import (
    ImageLoggingContext,
    log_conditioning,
    log_img,
    log_latent,
)
from imaginairy.model_manager import get_diffusion_model
from imaginairy.modules.midas.utils import AddMiDaS
from imaginairy.outpaint import outpaint_arg_str_parse, prepare_image_for_outpaint
from imaginairy.safety import SafetyMode, create_safety_score
from imaginairy.samplers import SAMPLER_LOOKUP
from imaginairy.samplers.base import NoiseSchedule, noise_an_image
from imaginairy.samplers.editing import CFGEditingDenoiser
from imaginairy.schema import ImaginePrompt, ImagineResult
from imaginairy.utils import (
    fix_torch_group_norm,
    fix_torch_nn_layer_norm,
    get_device,
    platform_appropriate_autocast,
    randn_seeded,
)

logger = logging.getLogger(__name__)

# leave undocumented. I'd ask that no one publicize this flag. Just want a
# slight barrier to entry. Please don't use this in any way that's gonna cause
# the media or politicians to freak out about AI...
IMAGINAIRY_SAFETY_MODE = os.getenv("IMAGINAIRY_SAFETY_MODE", SafetyMode.STRICT)
# Map the legacy environment values onto SafetyMode members; any other value
# is passed through to create_safety_score() unchanged.
if IMAGINAIRY_SAFETY_MODE in {"disabled", "classify"}:
    IMAGINAIRY_SAFETY_MODE = SafetyMode.RELAXED
elif IMAGINAIRY_SAFETY_MODE == "filter":
    IMAGINAIRY_SAFETY_MODE = SafetyMode.STRICT

# we put this in the global scope so it can be used in the interactive shell
_most_recent_result = None


def imagine_image_files(
    prompts,
    outdir,
    precision="autocast",
    record_step_images=False,
    output_file_extension="jpg",
    print_caption=False,
    make_comparison_gif=False,
    return_filename_type="generated",
):
    """Generate images for ``prompts`` and write them beneath ``outdir``.

    Each image type of a result is saved into ``outdir/<image_type>/``.
    Results whose prompt has ``is_intermediate`` set are not saved.

    Args:
        prompts: Forwarded unchanged to :func:`imagine`.
        outdir: Base output directory; sub-directories are created on demand.
        precision: Forwarded to :func:`imagine`.
        record_step_images: When true, debug images emitted during generation
            are saved (always as ``.jpg``) under ``outdir/steps/``.
        output_file_extension: ``"jpg"`` or ``"png"`` (case-insensitive).
        print_caption: Forwarded to :func:`imagine` as ``add_caption``.
        make_comparison_gif: When true and the prompt has an init image, a gif
            of the generated image and the resized init image is written to
            ``outdir/gif/``.
        return_filename_type: Image type whose saved paths are returned.

    Returns:
        List of file paths of the saved images of ``return_filename_type``.

    Raises:
        ValueError: If ``output_file_extension`` is not ``jpg`` or ``png``.
    """
    generated_imgs_path = os.path.join(outdir, "generated")
    os.makedirs(generated_imgs_path, exist_ok=True)

    # Number new files after whatever is already in the "generated" folder.
    base_count = len(os.listdir(generated_imgs_path))
    output_file_extension = output_file_extension.lower()
    if output_file_extension not in {"jpg", "png"}:
        raise ValueError("Must output a png or jpg")

    def _record_step(img, description, image_count, step_count, prompt):
        """Save one debug image, stamped with its description, under ``steps/``."""
        steps_path = os.path.join(outdir, "steps", f"{base_count:08}_S{prompt.seed}")
        os.makedirs(steps_path, exist_ok=True)
        filename = f"{base_count:08}_S{prompt.seed}_{image_count:04}_step{step_count:03}_{prompt_normalized(description)[:40]}.jpg"

        destination = os.path.join(steps_path, filename)
        draw = ImageDraw.Draw(img)
        draw.text((10, 10), str(description))
        img.save(destination)

    result_filenames = []
    for result in imagine(
        prompts,
        precision=precision,
        debug_img_callback=_record_step if record_step_images else None,
        add_caption=print_caption,
    ):
        prompt = result.prompt
        if prompt.is_intermediate:
            # we don't save intermediate images
            continue
        img_str = ""
        if prompt.init_image:
            img_str = f"_img2img-{prompt.init_image_strength}"
        basefilename = (
            f"{base_count:06}_{prompt.seed}_{prompt.sampler_type.replace('_', '')}{prompt.steps}_"
            f"PS{prompt.prompt_strength}{img_str}_{prompt_normalized(prompt.prompt_text)}"
        )

        for image_type in result.images:
            subpath = os.path.join(outdir, image_type)
            os.makedirs(subpath, exist_ok=True)
            filepath = os.path.join(
                subpath, f"{basefilename}_[{image_type}].{output_file_extension}"
            )
            result.save(filepath, image_type=image_type)
            logger.info(f"[{image_type}] saved to: {filepath}")
            if image_type == return_filename_type:
                result_filenames.append(filepath)
        if make_comparison_gif and prompt.init_image:
            subpath = os.path.join(outdir, "gif")
            os.makedirs(subpath, exist_ok=True)
            filepath = os.path.join(subpath, f"{basefilename}.gif")
            resized_init_image = pillow_fit_image_within(
                prompt.init_image, prompt.width, prompt.height
            )
            make_gif_image(
                filepath,
                imgs=[result.images["generated"], resized_init_image],
                duration=1750,
            )
        base_count += 1
        # Drop our reference to the result before the next one is generated.
        del result

    return result_filenames


def _is_prev_reference(value):
    """Return True when ``value`` is a string starting with ``*prev``."""
    return isinstance(value, str) and value.startswith("*prev")


def _previous_image(reference):
    """Look up the image named by a ``*prev.<image_type>`` reference.

    The image is taken from ``_most_recent_result.images``.

    Raises:
        ValueError: If the reference does not split into exactly two
            dot-separated parts, or if no result has been generated yet.
    """
    _, img_type = reference.strip("*").split(".")
    if _most_recent_result is None:
        raise ValueError(
            f"Cannot resolve {reference!r}: no previous result has been generated yet"
        )
    return _most_recent_result.images[img_type]


def imagine(
    prompts,
    precision="autocast",
    debug_img_callback=None,
    progress_img_callback=None,
    progress_img_interval_steps=3,
    progress_img_interval_min_s=0.1,
    half_mode=None,
    add_caption=False,
):
    """Generate images for one or more prompts, yielding an ``ImagineResult`` each.

    This is a generator: work for a prompt only happens when the next result
    is requested. After each result is built it is also stored in the module
    global ``_most_recent_result``.

    A prompt whose ``init_image`` or ``mask_image`` is a string starting with
    ``*prev`` (e.g. ``*prev.generated``) has that attribute replaced, in place,
    by the matching image of the most recent result.

    Args:
        prompts: A prompt string, a single ``ImaginePrompt``, or an iterable
            of ``ImaginePrompt`` objects.
        precision: Passed to ``platform_appropriate_autocast``.
        debug_img_callback: Passed to ``ImageLoggingContext``.
        progress_img_callback: Passed to ``ImageLoggingContext``.
        progress_img_interval_steps: Passed to ``ImageLoggingContext``.
        progress_img_interval_min_s: Passed to ``ImageLoggingContext``.
        half_mode: Passed to ``get_diffusion_model``.
        add_caption: When true, a caption is generated for each image and
            written to the log.

    Yields:
        ImagineResult: One result per decoded sample.

    Raises:
        ValueError: If a ``*prev`` reference is malformed or is used before
            any result exists.
    """
    global _most_recent_result  # noqa
    latent_channels = 4
    downsampling_factor = 8
    batch_size = 1

    # Accept a bare string or a single ImaginePrompt as well as an iterable.
    prompts = [ImaginePrompt(prompts)] if isinstance(prompts, str) else prompts
    prompts = [prompts] if isinstance(prompts, ImaginePrompt) else prompts

    try:
        num_prompts = str(len(prompts))
    except TypeError:
        # The iterable has no len(); only affects the progress log message.
        num_prompts = "?"

    if get_device() == "cpu":
        logger.info("Running in CPU mode. it's gonna be slooooooow.")

    with torch.no_grad(), platform_appropriate_autocast(
        precision
    ), fix_torch_nn_layer_norm(), fix_torch_group_norm():
        for i, prompt in enumerate(prompts):
            # handle prompt pulling in previous values
            if _is_prev_reference(prompt.init_image):
                prompt.init_image = _previous_image(prompt.init_image)

            if _is_prev_reference(prompt.mask_image):
                prompt.mask_image = _previous_image(prompt.mask_image)

            logger.info(
                f"Generating 🖼 {i + 1}/{num_prompts}: {prompt.prompt_description()}"
            )
            model = get_diffusion_model(
                weights_location=prompt.model,
                config_path=prompt.model_config_path,
                half_mode=half_mode,
                for_inpainting=prompt.mask_image
                or prompt.mask_prompt
                or prompt.outpaint,
            )
            has_depth_channel = hasattr(model, "depth_stage_key")
            with ImageLoggingContext(
                prompt=prompt,
                model=model,
                debug_img_callback=debug_img_callback,
                progress_img_callback=progress_img_callback,
                progress_img_interval_steps=progress_img_interval_steps,
                progress_img_interval_min_s=progress_img_interval_min_s,
            ) as lc:
                seed_everything(prompt.seed)
                model.tile_mode(prompt.tile_mode)
                with lc.timing("conditioning"):
                    # need to expand if doing batches
                    neutral_conditioning = _prompts_to_embeddings(
                        prompt.negative_prompt, model
                    )
                    log_conditioning(neutral_conditioning, "neutral conditioning")
                    # A conditioning supplied on the prompt takes precedence
                    # over embedding the prompt text.
                    if prompt.conditioning is not None:
                        positive_conditioning = prompt.conditioning
                    else:
                        positive_conditioning = _prompts_to_embeddings(
                            prompt.prompts, model
                        )
                    log_conditioning(positive_conditioning, "positive conditioning")

                # Latent shape for text-only generation; replaced by the
                # encoded init image's shape when an init image is used.
                shape = [
                    batch_size,
                    latent_channels,
                    prompt.height // downsampling_factor,
                    prompt.width // downsampling_factor,
                ]
                SamplerCls = SAMPLER_LOOKUP[prompt.sampler_type.lower()]
                sampler = SamplerCls(model)
                mask = mask_image = mask_image_orig = mask_grayscale = None
                t_enc = init_latent = init_latent_noised = None
                starting_image = None
                if prompt.init_image:
                    starting_image = prompt.init_image
                    generation_strength = 1 - prompt.init_image_strength
                    # Models with the "edit" conditioning key use every step;
                    # otherwise the step count is scaled by generation strength.
                    if model.cond_stage_key == "edit":
                        t_enc = prompt.steps
                    else:
                        t_enc = int(prompt.steps * generation_strength)

                    # A mask prompt takes precedence over an explicit mask image.
                    if prompt.mask_prompt:
                        mask_image, mask_grayscale = get_img_mask(
                            starting_image, prompt.mask_prompt, threshold=0.1
                        )
                    elif prompt.mask_image:
                        mask_image = prompt.mask_image.convert("L")
                    if prompt.outpaint:
                        outpaint_kwargs = outpaint_arg_str_parse(prompt.outpaint)
                        starting_image, mask_image = prepare_image_for_outpaint(
                            starting_image, mask_image, **outpaint_kwargs
                        )

                    init_image = pillow_fit_image_within(
                        starting_image,
                        max_height=prompt.height,
                        max_width=prompt.width,
                    )
                    if mask_image is not None:
                        mask_image = pillow_fit_image_within(
                            mask_image,
                            max_height=prompt.height,
                            max_width=prompt.width,
                            convert="L",
                        )

                        log_img(mask_image, "init mask")

                        if prompt.mask_mode == ImaginePrompt.MaskMode.REPLACE:
                            mask_image = ImageOps.invert(mask_image)

                        log_img(
                            Image.composite(init_image, mask_image, mask_image),
                            "mask overlay",
                        )
                        # Keep the full-resolution mask for compositing later;
                        # the sampler gets a copy shrunk to latent resolution.
                        mask_image_orig = mask_image
                        mask_image = mask_image.resize(
                            (
                                mask_image.width // downsampling_factor,
                                mask_image.height // downsampling_factor,
                            ),
                            resample=Image.Resampling.LANCZOS,
                        )
                        log_img(mask_image, "latent_mask")

                        # Scale 0..255 to 0..1 and add two leading axes.
                        mask = np.array(mask_image)
                        mask = mask.astype(np.float32) / 255.0
                        mask = mask[None, None]
                        mask = torch.from_numpy(mask)
                        mask = mask.to(get_device())

                    init_image_t = pillow_img_to_torch_image(init_image)
                    init_image_t = init_image_t.to(get_device())
                    init_latent = model.get_first_stage_encoding(
                        model.encode_first_stage(init_image_t)
                    )
                    shape = init_latent.shape

                    log_latent(init_latent, "init_latent")
                    # encode (scaled latent)
                    seed_everything(prompt.seed)
                    noise = randn_seeded(seed=prompt.seed, size=init_latent.size())
                    noise = noise.to(get_device())

                    schedule = NoiseSchedule(
                        model_num_timesteps=model.num_timesteps,
                        ddim_num_steps=prompt.steps,
                        model_alphas_cumprod=model.alphas_cumprod,
                        ddim_discretize="uniform",
                    )
                    if generation_strength >= 1:
                        # prompt strength gets converted to time encodings,
                        # which means you can't get to true 0 without this hack
                        # (or setting steps=1000)
                        init_latent_noised = noise
                    else:
                        init_latent_noised = noise_an_image(
                            init_latent,
                            torch.tensor([t_enc - 1]).to(get_device()),
                            schedule=schedule,
                            noise=noise,
                        )
                    batch_size = 1
                    log_latent(init_latent_noised, "init_latent_noised")
                batch = {
                    "txt": batch_size * [prompt.prompt_text],
                }
                c_cat = []
                c_cat_neutral = None
                depth_image_display = None
                if has_depth_channel and starting_image:
                    midas_model = AddMiDaS()
                    # Scale 0..255 pixel values to -1..1.
                    _init_image_d = np.array(starting_image.convert("RGB"))
                    _init_image_d = (
                        torch.from_numpy(_init_image_d).to(dtype=torch.float32) / 127.5
                        - 1.0
                    )
                    depth_image = midas_model(_init_image_d)
                    depth_image = torch.from_numpy(depth_image[None, ...])
                    batch[model.depth_stage_key] = depth_image.to(device=get_device())
                    _init_image_d = rearrange(_init_image_d, "h w c -> 1 c h w")
                    batch["jpg"] = _init_image_d
                    for ck in model.concat_keys:
                        cc = batch[ck]
                        cc = model.depth_model(cc)
                        depth_min, depth_max = torch.amin(
                            cc, dim=[1, 2, 3], keepdim=True
                        ), torch.amax(cc, dim=[1, 2, 3], keepdim=True)
                        # Normalise to 0..1 to build the 8-bit preview image.
                        # NOTE(review): depth_max == depth_min would divide by zero.
                        display_depth = (cc - depth_min) / (depth_max - depth_min)
                        depth_image_display = Image.fromarray(
                            (display_depth[0, 0, ...].cpu().numpy() * 255.0).astype(
                                np.uint8
                            )
                        )
                        # Resize to the latent's spatial size, then rescale to -1..1.
                        cc = torch.nn.functional.interpolate(
                            cc,
                            size=shape[2:],
                            mode="bicubic",
                            align_corners=False,
                        )
                        depth_min, depth_max = torch.amin(
                            cc, dim=[1, 2, 3], keepdim=True
                        ), torch.amax(cc, dim=[1, 2, 3], keepdim=True)
                        cc = 2.0 * (cc - depth_min) / (depth_max - depth_min) - 1.0
                        c_cat.append(cc)
                    c_cat = [torch.cat(c_cat, dim=1)]

                if mask_image_orig and not has_depth_channel:
                    mask_t = pillow_img_to_torch_image(
                        ImageOps.invert(mask_image_orig)
                    ).to(get_device())
                    inverted_mask = 1 - mask
                    # Zero out init-image pixels wherever the inverted mask is >= 0.5.
                    masked_image_t = init_image_t * (mask_t < 0.5)
                    batch.update(
                        {
                            "image": repeat(
                                init_image_t.to(device=get_device()),
                                "1 ... -> n ...",
                                n=batch_size,
                            ),
                            "txt": batch_size * [prompt.prompt_text],
                            "mask": repeat(
                                inverted_mask.to(device=get_device()),
                                "1 ... -> n ...",
                                n=batch_size,
                            ),
                            "masked_image": repeat(
                                masked_image_t.to(device=get_device()),
                                "1 ... -> n ...",
                                n=batch_size,
                            ),
                        }
                    )

                    for concat_key in getattr(model, "concat_keys", []):
                        cc = batch[concat_key].float()
                        if concat_key != model.masked_image_key:
                            # Resize to the latent's spatial size.
                            bchw = [batch_size, 4, shape[2], shape[3]]
                            cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
                        else:
                            # The masked image is encoded into latent space instead.
                            cc = model.get_first_stage_encoding(
                                model.encode_first_stage(cc)
                            )
                        c_cat.append(cc)
                    if c_cat:
                        c_cat = [torch.cat(c_cat, dim=1)]
                denoiser_cls = None
                if model.cond_stage_key == "edit":
                    c_cat = [model.encode_first_stage(init_image_t).mode()]
                    c_cat_neutral = [torch.zeros_like(init_latent)]
                    denoiser_cls = CFGEditingDenoiser
                # Unless set above, the neutral branch shares the concat conditioning.
                if c_cat_neutral is None:
                    c_cat_neutral = c_cat

                positive_conditioning = {
                    "c_concat": c_cat,
                    "c_crossattn": [positive_conditioning],
                }
                neutral_conditioning = {
                    "c_concat": c_cat_neutral,
                    "c_crossattn": [neutral_conditioning],
                }
                with lc.timing("sampling"):
                    samples = sampler.sample(
                        num_steps=prompt.steps,
                        initial_latent=init_latent_noised,
                        positive_conditioning=positive_conditioning,
                        neutral_conditioning=neutral_conditioning,
                        guidance_scale=prompt.prompt_strength,
                        t_start=t_enc,
                        mask=mask,
                        orig_latent=init_latent,
                        shape=shape,
                        batch_size=1,
                        denoiser_cls=denoiser_cls,
                    )
                # from torch.nn.functional import interpolate
                # samples = interpolate(samples, scale_factor=2, mode='nearest')
                x_samples = model.decode_first_stage(samples)
                # Map decoded values from -1..1 into 0..1.
                x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)

                for x_sample in x_samples:
                    x_sample = x_sample.to(torch.float32)
                    x_sample = 255.0 * rearrange(
                        x_sample.cpu().numpy(), "c h w -> h w c"
                    )
                    x_sample_8_orig = x_sample.astype(np.uint8)
                    img = Image.fromarray(x_sample_8_orig)
                    if mask_image_orig and init_image:
                        # Composite the generated image with the init image
                        # using the inverted full-resolution mask.
                        # mask_final = mask_image_orig.filter(
                        #     ImageFilter.GaussianBlur(radius=3)
                        # )
                        mask_final = mask_image_orig.copy()
                        log_img(mask_final, "reconstituting mask")
                        mask_final = ImageOps.invert(mask_final)
                        img = Image.composite(img, init_image, mask_final)
                        log_img(img, "reconstituted image")

                    upscaled_img = None
                    rebuilt_orig_img = None

                    if add_caption:
                        caption = generate_caption(img)
                        logger.info(f"Generated caption: {caption}")

                    with lc.timing("safety-filter"):
                        safety_score = create_safety_score(
                            img,
                            safety_mode=IMAGINAIRY_SAFETY_MODE,
                        )
                    # Post-processing is skipped for filtered images.
                    if not safety_score.is_filtered:
                        if prompt.fix_faces:
                            logger.info("Fixing 😊 's in 🖼 using CodeFormer...")
                            with lc.timing("face enhancement"):
                                img = enhance_faces(
                                    img, fidelity=prompt.fix_faces_fidelity
                                )
                        if prompt.upscale:
                            logger.info("Upscaling 🖼 using real-ESRGAN...")
                            with lc.timing("upscaling"):
                                upscaled_img = upscale_image(img)

                        # put the newly generated patch back into the original, full size image
                        if (
                            prompt.mask_modify_original
                            and mask_image_orig
                            and starting_image
                        ):
                            img_to_add_back_to_original = (
                                upscaled_img if upscaled_img else img
                            )
                            img_to_add_back_to_original = (
                                img_to_add_back_to_original.resize(
                                    starting_image.size,
                                    resample=Image.Resampling.LANCZOS,
                                )
                            )

                            mask_for_orig_size = mask_image_orig.resize(
                                starting_image.size,
                                resample=Image.Resampling.LANCZOS,
                            )
                            # mask_for_orig_size = mask_for_orig_size.filter(
                            #     ImageFilter.GaussianBlur(radius=5)
                            # )
                            log_img(mask_for_orig_size, "mask for original image size")

                            rebuilt_orig_img = Image.composite(
                                starting_image,
                                img_to_add_back_to_original,
                                mask_for_orig_size,
                            )
                            log_img(rebuilt_orig_img, "reconstituted original")

                    result = ImagineResult(
                        img=img,
                        prompt=prompt,
                        upscaled_img=upscaled_img,
                        is_nsfw=safety_score.is_nsfw,
                        safety_score=safety_score,
                        modified_original=rebuilt_orig_img,
                        mask_binary=mask_image_orig,
                        mask_grayscale=mask_grayscale,
                        depth_image=depth_image_display,
                        timings=lc.get_timings(),
                    )
                    # Remember the result so later prompts can use "*prev" references.
                    _most_recent_result = result
                    logger.info(f"Image Generated. Timings: {result.timings_str()}")
                    yield result


def _prompts_to_embeddings(prompts, model):
    """Blend the learned conditionings of weighted prompts into one.

    Each item of ``prompts`` must expose ``text`` and ``weight``. Every
    prompt's conditioning is scaled by its share of the total weight and the
    scaled conditionings are summed.

    NOTE(review): a total weight of zero raises ``ZeroDivisionError``, and an
    empty ``prompts`` returns the integer ``0`` rather than a conditioning.
    """
    total_weight = sum(wp.weight for wp in prompts)
    conditioning = sum(
        model.get_learned_conditioning(wp.text) * (wp.weight / total_weight)
        for wp in prompts
    )
    return conditioning


def prompt_normalized(prompt):
    """Return ``prompt`` reduced to filename-friendly text, at most 130 chars.

    Every run of characters other than ASCII letters, digits, ``.``, ``,``,
    ``[``, ``]`` and ``-`` is replaced by a single underscore. Non-string
    input is converted with ``str()`` first.
    """
    return re.sub(r"[^a-zA-Z0-9.,\[\]-]+", "_", str(prompt))[:130]