feature: clean up terminal output

- record timing and memory usage of various steps
- re-use logging context for composition images
- load sdxl weights in a more VRAM-efficient way
- switch to diffusers weights as the default weights for sd15
pull/435/head
Authored by Bryce 5 months ago; committed by Bryce Drennan
parent 0d78b8271f
commit 9e3403df89
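
Illustrative usage of the reworked performance logging (a sketch only — `output_perf` and `performance_stats` are introduced in the hunks below; the exact import path and prompt arguments are assumptions, not part of this diff):

    from imaginairy.api.generate_refiners import generate_single_image
    from imaginairy.schema import ImaginePrompt

    # output_perf=True promotes the per-step timing/VRAM lines from DEBUG to INFO
    result = generate_single_image(ImaginePrompt("a red barn"), output_perf=True)

    # the same numbers are available programmatically
    for step, stats in result.performance_stats.items():
        print(step, f"{stats['duration']:.2f}s", f"{stats['memory_peak'] / 1e6:.1f}MB")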

@@ -84,7 +84,6 @@ def imagine_image_files(
         f"{base_count:06}_{prompt.seed}_{prompt.solver_type.replace('_', '')}{prompt.steps}_"
         f"PS{prompt.prompt_strength}{img_str}_{prompt_normalized(prompt.prompt_text)}"
     )
     for image_type in result.images:
         subpath = os.path.join(outdir, image_type)
         os.makedirs(subpath, exist_ok=True)
@@ -92,7 +91,7 @@ def imagine_image_files(
             subpath, f"{basefilename}_[{image_type}].{output_file_extension}"
         )
         result.save(filepath, image_type=image_type)
-        logger.info(f" [{image_type}] saved to: {filepath}")
+        logger.info(f" {image_type:<22} {filepath}")
         if image_type == return_filename_type:
             result_filenames.append(filepath)
     if videogen:
@@ -123,7 +122,8 @@ def imagine_image_files(
             start_pause_duration_ms=1500,
             end_pause_duration_ms=1000,
         )
-        logger.info(f" [gif] {len(frames)} frames saved to: {filepath}")
+        image_type = "gif"
+        logger.info(f" {image_type:<22} {filepath}")
     if make_compare_gif and prompt.init_image:
         subpath = os.path.join(outdir, "gif")
         os.makedirs(subpath, exist_ok=True)
@@ -137,7 +137,8 @@ def imagine_image_files(
             imgs=frames,
             outpath=filepath,
         )
-        logger.info(f" [gif-comparison] saved to: {filepath}")
+        image_type = "gif"
+        logger.info(f" {image_type:<22} {filepath}")
     base_count += 1
     del result
@@ -192,9 +193,8 @@ def imagine(
     ), fix_torch_nn_layer_norm(), fix_torch_group_norm():
         for i, prompt in enumerate(prompts):
             concrete_prompt = prompt.make_concrete_copy()
-            logger.info(
-                f"🖼 Generating {i + 1}/{num_prompts}: {concrete_prompt.prompt_description()}"
-            )
+            prog_text = f"{i + 1}/{num_prompts}"
+            logger.info(f"🖼 {prog_text} {concrete_prompt.prompt_description()}")
             for attempt in range(unsafe_retry_count + 1):
                 if attempt > 0 and isinstance(concrete_prompt.seed, int):
                     concrete_prompt.seed += 100_000_000 + attempt
@@ -204,7 +204,6 @@ def imagine(
                     progress_img_callback=progress_img_callback,
                     progress_img_interval_steps=progress_img_interval_steps,
                     progress_img_interval_min_s=progress_img_interval_min_s,
-                    half_mode=half_mode,
                     add_caption=add_caption,
                     dtype=torch.float16 if half_mode else torch.float32,
                 )

@@ -475,7 +475,7 @@ def _generate_single_image_compvis(
         is_nsfw=safety_score.is_nsfw,
         safety_score=safety_score,
         result_images=result_images,
-        timings=lc.get_timings(),
+        performance_stats=lc.get_performance_stats(),
         progress_latents=progress_latents.copy(),
     )

@@ -1,10 +1,13 @@
 """Functions for generating refined images"""

 import logging
+from contextlib import nullcontext
+from typing import Any

 from imaginairy.config import CONTROL_CONFIG_SHORTCUTS
 from imaginairy.schema import ControlInput, ImaginePrompt, MaskMode
 from imaginairy.utils import clear_gpu_cache
+from imaginairy.utils.log_utils import ImageLoggingContext

 logger = logging.getLogger(__name__)
@@ -18,7 +21,9 @@ def generate_single_image(
     add_caption=False,
     return_latent=False,
     dtype=None,
-    half_mode=None,
+    logging_context: ImageLoggingContext | None = None,
+    output_perf=False,
+    image_name="",
 ):
     import torch.nn
     from PIL import Image, ImageOps
@@ -59,41 +64,48 @@ def generate_single_image(
     from imaginairy.utils.safety import create_safety_score

     if dtype is None:
-        dtype = torch.float16 if half_mode else torch.float32
+        dtype = torch.float16

     get_device()
     clear_gpu_cache()
     prompt = prompt.make_concrete_copy()

-    sd = get_diffusion_model_refiners(
-        weights_config=prompt.model_weights,
-        for_inpainting=prompt.should_use_inpainting
-        and prompt.inpaint_method == "finetune",
-        dtype=dtype,
-    )
+    if not logging_context:
+
+        def latent_logger(latents):
+            progress_latents.append(latents)
+
+        lc = ImageLoggingContext(
+            prompt=prompt,
+            debug_img_callback=debug_img_callback,
+            progress_img_callback=progress_img_callback,
+            progress_img_interval_steps=progress_img_interval_steps,
+            progress_img_interval_min_s=progress_img_interval_min_s,
+            progress_latent_callback=latent_logger
+            if prompt.collect_progress_latents
+            else None,
+        )
+        _context: Any = lc
+    else:
+        lc = logging_context
+        _context = nullcontext()
+
+    with _context:
+        with lc.timing("model-load"):
+            sd = get_diffusion_model_refiners(
+                weights_config=prompt.model_weights,
+                for_inpainting=prompt.should_use_inpainting
+                and prompt.inpaint_method == "finetune",
+                dtype=dtype,
+            )
+        lc.model = sd

-    seed_everything(prompt.seed)
-    downsampling_factor = 8
-    latent_channels = 4
-    batch_size = 1
-
-    mask_image = None
-    mask_image_orig = None
-
-    def latent_logger(latents):
-        progress_latents.append(latents)
-
-    with ImageLoggingContext(
-        prompt=prompt,
-        model=sd,
-        debug_img_callback=debug_img_callback,
-        progress_img_callback=progress_img_callback,
-        progress_img_interval_steps=progress_img_interval_steps,
-        progress_img_interval_min_s=progress_img_interval_min_s,
-        progress_latent_callback=latent_logger
-        if prompt.collect_progress_latents
-        else None,
-    ) as lc:
+        seed_everything(prompt.seed)
+        downsampling_factor = 8
+        latent_channels = 4
+        batch_size = 1
+        mask_image = None
+        mask_image_orig = None
+
         sd.set_tile_mode(prompt.tile_mode)

         result_images: dict[str, torch.Tensor | None | Image.Image] = {}
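
The hunk above is the heart of the "re-use logging context for composition images" bullet: the function either owns a fresh ImageLoggingContext or borrows the caller's and enters a no-op nullcontext() instead, so a borrowed context is never exited early. A reduced sketch of the pattern (names simplified, not the literal code):

    from contextlib import nullcontext

    def work(prompt, logging_context=None):
        if logging_context is None:
            lc = ImageLoggingContext(prompt=prompt)  # we own it: enter/exit here
            ctx = lc
        else:
            lc = logging_context  # borrowed: the caller controls its lifetime
            ctx = nullcontext()
        with ctx:
            with lc.timing("model-load"):  # all timings land in one shared table
                pass  # ...the expensive work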
@@ -178,63 +190,14 @@ def generate_single_image(
         assert prompt.seed is not None
         noise = randn_seeded(seed=prompt.seed, size=shape).to(
-            get_device(), dtype=sd.dtype
+            sd.unet.device, dtype=sd.unet.dtype
         )
         noised_latent = noise

         controlnets = []
         if control_modes:
-            for control_input in control_inputs:
-                controlnet, control_image_t, control_image_disp = prep_control_input(
-                    control_input=control_input,
-                    sd=sd,
-                    init_image_t=init_image_t,
-                    fit_width=prompt.width,
-                    fit_height=prompt.height,
-                )
-                result_images[f"control-{control_input.mode}"] = control_image_disp
-                controlnets.append((controlnet, control_image_t))
-
-        if prompt.allow_compose_phase:
-            cutoff_size = get_model_default_image_size(prompt.model_architecture)
-            cutoff_size = (int(cutoff_size[0] * 1.30), int(cutoff_size[1] * 1.30))
-            compose_kwargs = {
-                "prompt": prompt,
-                "target_height": prompt.height,
-                "target_width": prompt.width,
-                "cutoff": cutoff_size,
-                "dtype": dtype,
-            }
-
-            if prompt.init_image:
-                compose_kwargs.update(
-                    {
-                        "target_height": init_image.height,
-                        "target_width": init_image.width,
-                    }
-                )
-            comp_image, comp_img_orig = _generate_composition_image(**compose_kwargs)
-
-            if comp_image is not None:
-                prompt.fix_faces = False  # done in composition
-                result_images["composition"] = comp_img_orig
-                result_images["composition-upscaled"] = comp_image
-                composition_strength = prompt.composition_strength
-                first_step = int((prompt.steps) * composition_strength)
-                noise_step = int((prompt.steps - 1) * composition_strength)
-                log_img(comp_img_orig, "comp_image")
-                log_img(comp_image, "comp_image_upscaled")
-                comp_image_t = pillow_img_to_torch_image(comp_image)
-                comp_image_t = comp_image_t.to(sd.lda.device, dtype=sd.lda.dtype)
-                init_latent = sd.lda.encode(comp_image_t)
-
-                compose_control_inputs: list[ControlInput]
-                if prompt.model_weights.architecture.primary_alias == "sdxl":
-                    compose_control_inputs = []
-                else:
-                    compose_control_inputs = [
-                        ControlInput(mode="details", image=comp_image, strength=1),
-                    ]
-                for control_input in compose_control_inputs:
+            with lc.timing("control-image-prep"):
+                for control_input in control_inputs:
                     (
                         controlnet,
                         control_image_t,
@@ -242,16 +205,73 @@ def generate_single_image(
                     ) = prep_control_input(
                         control_input=control_input,
                         sd=sd,
-                        init_image_t=None,
+                        init_image_t=init_image_t,
                         fit_width=prompt.width,
                         fit_height=prompt.height,
                     )
                     result_images[f"control-{control_input.mode}"] = control_image_disp
                     controlnets.append((controlnet, control_image_t))

+        if prompt.allow_compose_phase:
+            with lc.timing("composition"):
+                cutoff_size = get_model_default_image_size(prompt.model_architecture)
+                cutoff_size = (int(cutoff_size[0] * 1.30), int(cutoff_size[1] * 1.30))
+                compose_kwargs = {
+                    "prompt": prompt,
+                    "target_height": prompt.height,
+                    "target_width": prompt.width,
+                    "cutoff": cutoff_size,
+                    "dtype": dtype,
+                }
+
+                if prompt.init_image:
+                    compose_kwargs.update(
+                        {
+                            "target_height": init_image.height,
+                            "target_width": init_image.width,
+                        }
+                    )
+                comp_image, comp_img_orig = _generate_composition_image(
+                    **compose_kwargs, logging_context=lc
+                )
+
+                if comp_image is not None:
+                    prompt.fix_faces = False  # done in composition
+                    result_images["composition"] = comp_img_orig
+                    result_images["composition-upscaled"] = comp_image
+                    composition_strength = prompt.composition_strength
+                    first_step = int((prompt.steps) * composition_strength)
+                    noise_step = int((prompt.steps - 1) * composition_strength)
+                    log_img(comp_img_orig, "comp_image")
+                    log_img(comp_image, "comp_image_upscaled")
+                    comp_image_t = pillow_img_to_torch_image(comp_image)
+                    comp_image_t = comp_image_t.to(sd.lda.device, dtype=sd.lda.dtype)
+                    init_latent = sd.lda.encode(comp_image_t)
+
+                    compose_control_inputs: list[ControlInput]
+                    if prompt.model_weights.architecture.primary_alias == "sdxl":
+                        compose_control_inputs = []
+                    else:
+                        compose_control_inputs = [
+                            ControlInput(mode="details", image=comp_image, strength=1),
+                        ]
+                    for control_input in compose_control_inputs:
+                        (
+                            controlnet,
+                            control_image_t,
+                            control_image_disp,
+                        ) = prep_control_input(
+                            control_input=control_input,
+                            sd=sd,
+                            init_image_t=None,
+                            fit_width=prompt.width,
+                            fit_height=prompt.height,
+                        )
+                        result_images[
+                            f"control-{control_input.mode}"
+                        ] = control_image_disp
+                        controlnets.append((controlnet, control_image_t))
+
         for controlnet, control_image_t in controlnets:
+            msg = f"Injecting controlnet {controlnet.name}. setting to device: {sd.unet.device}, dtype: {sd.unet.dtype}"
+            print(msg)
             controlnet.set_controlnet_condition(
                 control_image_t.to(device=sd.unet.device, dtype=sd.unet.dtype)
             )
@@ -263,7 +283,7 @@ def generate_single_image(
         else:
             msg = f"Unknown solver type: {prompt.solver_type}"
             raise ValueError(msg)
-        sd.scheduler.to(device=sd.device, dtype=sd.dtype)
+        sd.scheduler.to(device=sd.unet.device, dtype=sd.unet.dtype)
         sd.set_num_inference_steps(prompt.steps)

         if hasattr(sd, "mask_latents") and mask_image is not None:
@@ -288,60 +308,66 @@ def generate_single_image(
                 x=init_latent, noise=noise, step=sd.steps[noise_step]
             )

-        text_conditioning_kwargs = sd.calculate_text_conditioning_kwargs(
-            positive_prompts=prompt.prompts,
-            negative_prompts=prompt.negative_prompt,
-            positive_conditioning_override=prompt.conditioning,
-        )
-
-        for k, v in text_conditioning_kwargs.items():
-            text_conditioning_kwargs[k] = v.to(
-                device=sd.unet.device, dtype=sd.unet.dtype
-            )
+        with lc.timing("text-conditioning"):
+            text_conditioning_kwargs = sd.calculate_text_conditioning_kwargs(
+                positive_prompts=prompt.prompts,
+                negative_prompts=prompt.negative_prompt,
+                positive_conditioning_override=prompt.conditioning,
+            )
+
+            for k, v in text_conditioning_kwargs.items():
+                text_conditioning_kwargs[k] = v.to(
+                    device=sd.unet.device, dtype=sd.unet.dtype
+                )
         x = noised_latent
         x = x.to(device=sd.unet.device, dtype=sd.unet.dtype)
         clear_gpu_cache()

-        for step in tqdm(sd.steps[first_step:], bar_format=" {l_bar}{bar}{r_bar}"):
-            log_latent(x, "noisy_latent")
-            x = sd(
-                x,
-                step=step,
-                condition_scale=prompt.prompt_strength,
-                **text_conditioning_kwargs,
-            )
-            # trying to clear memory. not sure if this helps
-            sd.unet.set_context(context="self_attention_map", value={})
-            sd.unet._reset_context()
-            clear_gpu_cache()
+        with lc.timing("unet"):
+            for step in tqdm(
+                sd.steps[first_step:], bar_format=" {l_bar}{bar}{r_bar}", leave=False
+            ):
+                log_latent(x, "noisy_latent")
+                x = sd(
+                    x,
+                    step=step,
+                    condition_scale=prompt.prompt_strength,
+                    **text_conditioning_kwargs,
+                )
+                # trying to clear memory. not sure if this helps
+                sd.unet.set_context(context="self_attention_map", value={})
+                sd.unet._reset_context()
+                clear_gpu_cache()

         logger.debug("Decoding image")
         if x.device != sd.lda.device:
             sd.lda.to(x.device)
         clear_gpu_cache()
-        gen_img = sd.lda.decode_latents(x.to(dtype=sd.lda.dtype))
+        with lc.timing("decode-img"):
+            gen_img = sd.lda.decode_latents(x.to(dtype=sd.lda.dtype))

         if mask_image_orig and init_image:
-            result_images["pre-reconstitution"] = gen_img
-            mask_final = mask_image_orig.copy()
-            # mask_final = ImageOps.invert(mask_final)
-
-            log_img(mask_final, "reconstituting mask")
-            # gen_img = Image.composite(gen_img, init_image, mask_final)
-            gen_img = combine_image(
-                original_img=init_image,
-                generated_img=gen_img,
-                mask_img=mask_final,
-            )
-            log_img(gen_img, "reconstituted image")
+            with lc.timing("combine-image"):
+                result_images["pre-reconstitution"] = gen_img
+                mask_final = mask_image_orig.copy()
+                # mask_final = ImageOps.invert(mask_final)
+
+                log_img(mask_final, "reconstituting mask")
+                # gen_img = Image.composite(gen_img, init_image, mask_final)
+                gen_img = combine_image(
+                    original_img=init_image,
+                    generated_img=gen_img,
+                    mask_img=mask_final,
+                )
+                log_img(gen_img, "reconstituted image")

         upscaled_img = None
         rebuilt_orig_img = None

         if add_caption:
-            caption = generate_caption(gen_img)
-            logger.info(f"Generated caption: {caption}")
+            with lc.timing("caption-img"):
+                caption = generate_caption(gen_img)
+                logger.info(f"Generated caption: {caption}")

         with lc.timing("safety-filter"):
             safety_score = create_safety_score(
@@ -352,13 +378,17 @@ def generate_single_image(
         progress_latents.clear()
         if not safety_score.is_filtered:
             if prompt.fix_faces:
-                logger.info("Fixing 😊 's in 🖼 using CodeFormer...")
-                with lc.timing("face enhancement"):
-                    gen_img = enhance_faces(gen_img, fidelity=prompt.fix_faces_fidelity)
+                with lc.timing("face-enhancement"):
+                    logger.info("Fixing 😊 's in 🖼 using CodeFormer...")
+                    with lc.timing("face-enhancement"):
+                        gen_img = enhance_faces(
+                            gen_img, fidelity=prompt.fix_faces_fidelity
+                        )
             if prompt.upscale:
-                logger.info("Upscaling 🖼 using real-ESRGAN...")
                 with lc.timing("upscaling"):
-                    upscaled_img = upscale_image(gen_img)
+                    logger.info("Upscaling 🖼 using real-ESRGAN...")
+                    with lc.timing("upscaling"):
+                        upscaled_img = upscale_image(gen_img)

         # put the newly generated patch back into the original, full-size image
         if prompt.mask_modify_original and mask_image_orig and starting_image:
@@ -390,13 +420,19 @@ def generate_single_image(
             is_nsfw=safety_score.is_nsfw,
             safety_score=safety_score,
             result_images=result_images,
-            timings=lc.get_timings(),
+            performance_stats=lc.get_performance_stats(),
             progress_latents=[],  # todo
         )

         _most_recent_result = result
-        if result.timings:
-            logger.info(f"Image Generated. Timings: {result.timings_str()}")
+        _image_name = f"{image_name} " if image_name else ""
+        logger.info(f"Generated {_image_name}image in {result.total_time():.1f}s")
+        if result.performance_stats:
+            log = logger.info if output_perf else logger.debug
+            log(f" Timings: {result.timings_str()}")
+            log(f" Peak VRAM: {result.gpu_str('memory_peak')}")
+            log(f" Ending VRAM: {result.gpu_str('memory_end')}")
         for controlnet, _ in controlnets:
             controlnet.eject()
         clear_gpu_cache()
@@ -495,6 +531,7 @@ def _generate_composition_image(
     target_width,
     cutoff: tuple[int, int] = (512, 512),
     dtype=None,
+    logging_context=None,
 ):
     from PIL import Image
@@ -530,7 +567,13 @@ def _generate_composition_image(
         },
     )

-    result = generate_single_image(composition_prompt, dtype=dtype)
+    result = generate_single_image(
+        composition_prompt,
+        dtype=dtype,
+        logging_context=logging_context,
+        output_perf=False,
+        image_name="composition",
+    )

     img = result.images["generated"]
     while img.width < target_width:
         from imaginairy.enhancers.upscale_realesrgan import upscale_image
@@ -538,9 +581,11 @@ def _generate_composition_image(
     if prompt.fix_faces:
         from imaginairy.enhancers.face_restoration_codeformer import enhance_faces

-        img = enhance_faces(img, fidelity=prompt.fix_faces_fidelity)
-
-    img = upscale_image(img, ultrasharp=True)
+        with logging_context.timing("face-enhancement"):
+            logger.info("Fixing 😊 's in 🖼 using CodeFormer...")
+            img = enhance_faces(img, fidelity=prompt.fix_faces_fidelity)
+    with logging_context.timing("upscaling"):
+        img = upscale_image(img, ultrasharp=True)

     img = img.resize(
         (target_width, target_height),

@@ -109,8 +109,11 @@ def _imagine_cmd(
         raise ValueError(msg)

     total_image_count = len(prompt_texts) * max(len(init_images), 1) * repeats
+    img_msg = ""
+    if len(init_images) > 0:
+        img_msg = f" and {len(init_images)} image(s)"
     logger.info(
-        f"Received {len(prompt_texts)} prompt(s) and {len(init_images)} input image(s). Will repeat the generations {repeats} times to create {total_image_count} images."
+        f"Received {len(prompt_texts)} prompt(s){img_msg}. Will repeat these {repeats} times to create {total_image_count} images.\n"
     )

     from imaginairy.api import imagine_image_files

@@ -124,7 +124,7 @@ MODEL_WEIGHT_CONFIGS = [
         aliases=MODEL_ARCHITECTURE_LOOKUP["sd15"].aliases,
         architecture=MODEL_ARCHITECTURE_LOOKUP["sd15"],
         defaults={"negative_prompt": DEFAULT_NEGATIVE_PROMPT},
-        weights_location="https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/889b629140e71758e1e0006e355c331a5744b4bf/v1-5-pruned-emaonly.ckpt",
+        weights_location="https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/889b629140e71758e1e0006e355c331a5744b4bf/",
     ),
     ModelWeightsConfig(
         name="Stable Diffusion 1.5 - Inpainting",
@@ -151,7 +151,7 @@ MODEL_WEIGHT_CONFIGS = [
         name="OpenJourney V4",
         aliases=["openjourney-v4", "oj4", "ojv4", "openjourney4", "openjourney", "oj"],
         architecture=MODEL_ARCHITECTURE_LOOKUP["sd15"],
-        weights_location="https://huggingface.co/prompthero/openjourney/resolve/e291118e93d5423dc88ac1ed93c02362b17d698f/mdjrny-v4.safetensors",
+        weights_location="https://huggingface.co/prompthero/openjourney/tree/f4572661b028c732b2b97c8fbdc32fa5db3afe03/",
         defaults={"negative_prompt": "poor quality"},
     ),
     ModelWeightsConfig(

@@ -74,7 +74,7 @@ def enhance_faces(img, fidelity=0):
     num_det_faces = face_helper.get_face_landmarks_5(
         only_center_face=False, resize=640, eye_dist_threshold=5
     )
-    logger.info(f"Enhancing {num_det_faces} faces")
+    logger.debug(f"Enhancing {num_det_faces} faces")

     # align and warp each face
     face_helper.align_warp_face()

@@ -3,11 +3,18 @@
 import logging
 import math
 from functools import lru_cache
-from typing import List, Literal
+from typing import Any, List, Literal

+import refiners.fluxion.layers as fl
 import torch
 from refiners.fluxion.layers.attentions import ScaledDotProductAttention
 from refiners.fluxion.layers.chain import ChainError
+from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
+from refiners.foundationals.latent_diffusion.model import (
+    TLatentDiffusionModel,
+)
+from refiners.foundationals.latent_diffusion.schedulers.ddim import DDIM
+from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
 from refiners.foundationals.latent_diffusion.self_attention_guidance import (
     SelfAttentionMap,
 )
@@ -25,7 +32,11 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import (
     SDXLAutoencoder,
     StableDiffusion_XL as RefinerStableDiffusion_XL,
 )
-from torch import Tensor, nn
+from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import (
+    DoubleTextEncoder,
+)
+from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
+from torch import Tensor, device as Device, dtype as DType, nn
 from torch.nn import functional as F

 from imaginairy.schema import WeightedPrompt
@@ -83,6 +94,45 @@ class TileModeMixin(nn.Module):

 class StableDiffusion_1(TileModeMixin, RefinerStableDiffusion_1):
+    def __init__(
+        self,
+        unet: SD1UNet | None = None,
+        lda: SD1Autoencoder | None = None,
+        clip_text_encoder: CLIPTextEncoderL | None = None,
+        scheduler: Scheduler | None = None,
+        device: Device | str = "cpu",
+        dtype: DType = torch.float32,
+    ) -> None:
+        unet = unet or SD1UNet(in_channels=4)
+        lda = lda or SD1Autoencoder()
+        clip_text_encoder = clip_text_encoder or CLIPTextEncoderL()
+        scheduler = scheduler or DDIM(num_inference_steps=50)
+
+        fl.Module.__init__(self)
+        # all this is to allow us to make structural copies without unnecessary device or dtype shuffling
+        # since default behavior was to put everything on the same device and dtype and we want the option to
+        # not alter them from whatever they're already set to
+        self.unet = unet
+        self.lda = lda
+        self.clip_text_encoder = clip_text_encoder
+        self.scheduler = scheduler
+
+        to_kwargs: dict[str, Any] = {}
+        if device is not None:
+            device = device if isinstance(device, Device) else Device(device=device)
+            to_kwargs["device"] = device
+        if dtype is not None:
+            to_kwargs["dtype"] = dtype
+        self.device = device
+        self.dtype = dtype
+        if to_kwargs:
+            self.unet = unet.to(**to_kwargs)
+            self.lda = lda.to(**to_kwargs)
+            self.clip_text_encoder = clip_text_encoder.to(**to_kwargs)
+            self.scheduler = scheduler.to(**to_kwargs)
+
     def calculate_text_conditioning_kwargs(
         self,
         positive_prompts: List[WeightedPrompt],
@@ -122,6 +172,48 @@ class StableDiffusion_1(TileModeMixin, RefinerStableDiffusion_1):

 class StableDiffusion_XL(TileModeMixin, RefinerStableDiffusion_XL):
+    def __init__(
+        self,
+        unet: SDXLUNet | None = None,
+        lda: SDXLAutoencoder | None = None,
+        clip_text_encoder: DoubleTextEncoder | None = None,
+        scheduler: Scheduler | None = None,
+        device: Device | str | None = "cpu",
+        dtype: DType | None = None,
+    ) -> None:
+        unet = unet or SDXLUNet(in_channels=4)
+        lda = lda or SDXLAutoencoder()
+        clip_text_encoder = clip_text_encoder or DoubleTextEncoder()
+        scheduler = scheduler or DDIM(num_inference_steps=30)
+
+        fl.Module.__init__(self)
+        # all this is to allow us to make structural copies without unnecessary device or dtype shuffling
+        # since default behavior was to put everything on the same device and dtype and we want the option to
+        # not alter them from whatever they're already set to
+        self.unet = unet
+        self.lda = lda
+        self.clip_text_encoder = clip_text_encoder
+        self.scheduler = scheduler
+
+        to_kwargs: dict[str, Any] = {}
+        if device is not None:
+            device = device if isinstance(device, Device) else Device(device=device)
+            to_kwargs["device"] = device
+        if dtype is not None:
+            to_kwargs["dtype"] = dtype
+        self.device = device  # type: ignore
+        self.dtype = dtype  # type: ignore
+
+        self.unet = unet
+        self.lda = lda
+        self.clip_text_encoder = clip_text_encoder
+        self.scheduler = scheduler
+        if to_kwargs:
+            self.unet = self.unet.to(**to_kwargs)
+            self.lda = self.lda.to(**to_kwargs)
+            self.clip_text_encoder = self.clip_text_encoder.to(**to_kwargs)
+            self.scheduler = self.scheduler.to(**to_kwargs)
+
     def forward(  # type: ignore
         self,
         x: Tensor,
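
The duplicated __init__ exists so submodules can be handed in pre-placed: when device or dtype is None, no blanket .to() call is issued and each part keeps whatever placement and precision it already has. The load_sdxl_diffusers_weights hunk further down relies on exactly this, roughly:

    # unet arrives fp16, lda is forced to fp32; dtype=None preserves the mix
    sd = StableDiffusion_XL(
        device=device, dtype=None, lda=lda, unet=unet, clip_text_encoder=text_encoder
    )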
@@ -144,6 +236,22 @@ class StableDiffusion_XL(TileModeMixin, RefinerStableDiffusion_XL):
             **kwargs,
         )

+    def structural_copy(self: TLatentDiffusionModel) -> TLatentDiffusionModel:
+        logger.debug("Making structural copy of StableDiffusion_XL model")
+
+        sd = self.__class__(
+            unet=self.unet.structural_copy(),
+            lda=self.lda.structural_copy(),
+            clip_text_encoder=self.clip_text_encoder,
+            scheduler=self.scheduler,
+            device=self.device,
+            dtype=None,  # type: ignore
+        )
+        logger.debug(
+            f"dtype: {sd.dtype} unet-dtype:{sd.unet.dtype} lda-dtype:{sd.lda.dtype} text-encoder-dtype:{sd.clip_text_encoder.dtype} scheduler-dtype:{sd.scheduler.dtype}"
+        )
+        return sd
+
     def calculate_text_conditioning_kwargs(
         self,
         positive_prompts: List[WeightedPrompt],
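
A note on why structural_copy shares the text encoder and scheduler but copies unet/lda: the copy comes out free of injected adapters (e.g. controlnets) without allocating a second text encoder, and dtype=None keeps each part's existing precision. The commented-out call in get_diffusion_model_refiners below hints at the intended use — a hedged sketch:

    base = get_diffusion_model_refiners(weights_config, dtype=torch.float16)
    clean = base.structural_copy()  # adapter-free twin of the loaded model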
@@ -452,7 +560,7 @@ def monkeypatch_sd1controlnetadapter():
             device=target.device,
             dtype=target.dtype,
         )
-        print(
+        logger.debug(
             f"controlnet: {name} loaded to device {target.device} and type {target.dtype}"
         )

@@ -362,6 +362,7 @@ class ImaginePrompt(BaseModel, protected_namespaces=()):
             composition_strength=composition_strength,
             inpaint_method=inpaint_method,
         )
+        self._default_negative_prompt = None

     @field_validator("prompt", "negative_prompt", mode="before")
     def make_into_weighted_prompts(
@@ -401,16 +402,20 @@ class ImaginePrompt(BaseModel, protected_namespaces=()):
         v.sort(key=lambda p: p.weight, reverse=True)
         return v

+    @property
+    def default_negative_prompt(self):
+        default_negative_prompt = config.DEFAULT_NEGATIVE_PROMPT
+        if self.model_weights:
+            default_negative_prompt = self.model_weights.defaults.get(
+                "negative_prompt", default_negative_prompt
+            )
+        return default_negative_prompt
+
     @model_validator(mode="after")
     def validate_negative_prompt(self):
         if self.negative_prompt == []:
-            default_negative_prompt = config.DEFAULT_NEGATIVE_PROMPT
-            if self.model_weights:
-                default_negative_prompt = self.model_weights.defaults.get(
-                    "negative_prompt", default_negative_prompt
-                )
-            self.negative_prompt = [WeightedPrompt(text=default_negative_prompt)]
+            self.negative_prompt = [WeightedPrompt(text=self.default_negative_prompt)]
         return self

     @field_validator("prompt_strength", mode="before")
@@ -667,15 +672,27 @@ class ImaginePrompt(BaseModel, protected_namespaces=()):
         return self.model_weights.architecture

     def prompt_description(self):
+        if self.negative_prompt_text == self.default_negative_prompt:
+            neg_prompt = "DEFAULT-NEGATIVE-PROMPT"
+        else:
+            neg_prompt = f'"{self.negative_prompt_text}"'
+        from termcolor import colored
+
+        prompt_text = colored(self.prompt_text, "green")
         return (
-            f'"{self.prompt_text}" {self.width}x{self.height}px '
-            f'negative-prompt:"{self.negative_prompt_text}" '
+            f'"{prompt_text}"\n'
+            "    "
+            f"negative-prompt:{neg_prompt}\n"
+            "    "
+            f"size:{self.width}x{self.height}px "
             f"seed:{self.seed} "
             f"prompt-strength:{self.prompt_strength} "
             f"steps:{self.steps} solver-type:{self.solver_type} "
             f"init-image-strength:{self.init_image_strength} "
             f"arch:{self.model_architecture.aliases[0]} "
-            f"weights: {self.model_weights.aliases[0]}"
+            f"weights:{self.model_weights.aliases[0]}"
         )

     def logging_dict(self):
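
With the new format a generation log line would render roughly like this (illustrative values only; the prompt text is additionally colored green via termcolor):

    🖼 1/1 "a red barn"
        negative-prompt:DEFAULT-NEGATIVE-PROMPT
        size:512x512px seed:42 prompt-strength:7.5 steps:15 solver-type:ddim init-image-strength:0.2 arch:sd15 weights:sd15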
@@ -724,7 +741,7 @@ class ImagineResult:
         is_nsfw,
         safety_score,
         result_images=None,
-        timings=None,
+        performance_stats=None,
         progress_latents=None,
     ):
         import torch
@@ -750,7 +767,7 @@ class ImagineResult:
                 r_img = torch_img_to_pillow_img(r_img)
             self.images[img_type] = r_img

-        self.timings = timings
+        self.performance_stats = performance_stats
         self.progress_latents = progress_latents

         # for backward compat
@@ -771,9 +788,24 @@ class ImagineResult:
         }

     def timings_str(self) -> str:
-        if not self.timings:
+        if not self.performance_stats:
             return ""
-        return " ".join(f"{k}:{v:.2f}s" for k, v in self.timings.items())
+        return " ".join(
+            f"{k}:{v['duration']:.2f}s" for k, v in self.performance_stats.items()
+        )
+
+    def total_time(self) -> float:
+        if not self.performance_stats:
+            return 0
+        return self.performance_stats["total"]["duration"]
+
+    def gpu_str(self, stat_name="memory_peak") -> str:
+        if not self.performance_stats:
+            return ""
+        return " ".join(
+            f"{k}:{v[stat_name]/(10**6):.1f}MB"
+            for k, v in self.performance_stats.items()
+        )

     def _exif(self) -> "Image.Exif":
         from PIL import Image
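
Given the stats dict produced by ImageLoggingContext.get_performance_stats(), these helpers format one entry per step — a sketch with invented numbers (keys trimmed to the ones used):

    stats = {
        "model-load": {"duration": 2.10, "memory_peak": 2.2e9, "memory_end": 2.1e9},
        "total": {"duration": 9.80, "memory_peak": 5.3e9, "memory_end": 2.1e9},
    }
    result.performance_stats = stats
    result.timings_str()           # 'model-load:2.10s total:9.80s'
    result.gpu_str("memory_peak")  # 'model-load:2200.0MB total:5300.0MB'
    result.total_time()            # 9.8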

@@ -5,6 +5,7 @@ import logging.config
 import re
 import time
 import warnings
+from typing import Callable

 _CURRENT_LOGGING_CONTEXT = None
@@ -53,23 +54,65 @@ def increment_step():

 class TimingContext:
-    def __init__(self, logging_context, description):
-        self.logging_context = logging_context
+    """Tracks time and memory usage of a block of code"""
+
+    def __init__(
+        self,
+        description: str,
+        device: str | None = None,
+        callback_fn: Callable | None = None,
+    ):
+        from imaginairy.utils import get_device
+
         self.description = description
+        self._device = device or get_device()
+        self.callback_fn = callback_fn
+
         self.start_time = None
+        self.end_time = None
+        self.duration = 0
+
+        self.memory_start = 0
+        self.memory_end = 0
+        self.memory_peak = 0
+
+    def start(self):
+        if self._device == "cuda":
+            import torch
+
+            torch.cuda.reset_peak_memory_stats()
+            self.memory_start = torch.cuda.memory_allocated()
+        self.end_time = None
+        self.start_time = time.time()
+
+    def stop(self):
+        self.end_time = time.time()
+        self.duration += self.end_time - self.start_time
+        if self._device == "cuda":
+            import torch
+
+            self.memory_end = torch.cuda.memory_allocated()
+            self.memory_peak = max(
+                torch.cuda.max_memory_allocated() - self.memory_start, self.memory_peak
+            )
+
+        if self.callback_fn is not None:
+            self.callback_fn(self)

     def __enter__(self):
-        self.start_time = time.time()
+        self.start()
+        return self

     def __exit__(self, exc_type, exc_value, traceback):
-        self.logging_context.timings[self.description] = time.time() - self.start_time
+        self.stop()


 class ImageLoggingContext:
     def __init__(
         self,
         prompt,
-        model,
+        model=None,
         debug_img_callback=None,
         img_outdir=None,
         progress_img_callback=None,
@@ -88,29 +131,60 @@ class ImageLoggingContext:
         self.progress_img_interval_min_s = progress_img_interval_min_s
         self.progress_latent_callback = progress_latent_callback

-        self.start_ts = time.perf_counter()
-        self.timings = {}
+        self.summary_context = TimingContext("total")
+        self.summary_context.start()
+        self.timing_contexts = {}
         self.last_progress_img_ts = 0
         self.last_progress_img_step = -1000

         self._prev_log_context = None

     def __enter__(self):
+        self.start()
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.stop()
+
+    def start(self):
         global _CURRENT_LOGGING_CONTEXT
         self._prev_log_context = _CURRENT_LOGGING_CONTEXT
         _CURRENT_LOGGING_CONTEXT = self
         return self

-    def __exit__(self, exc_type, exc_val, exc_tb):
+    def stop(self):
         global _CURRENT_LOGGING_CONTEXT
         _CURRENT_LOGGING_CONTEXT = self._prev_log_context

     def timing(self, description):
-        return TimingContext(self, description)
+        if description not in self.timing_contexts:
+
+            def cb(context):
+                self.timing_contexts[description] = context
+
+            tc = TimingContext(description, callback_fn=cb)
+            self.timing_contexts[description] = tc
+        return self.timing_contexts[description]
+
+    def get_performance_stats(self) -> dict[str, dict[str, float]]:
+        # calculate max peak seen in any timing context
+        self.summary_context.stop()
+        self.timing_contexts["total"] = self.summary_context
+        self.summary_context.memory_peak = max(
+            max(context.memory_peak, context.memory_start, context.memory_end)
+            for context in self.timing_contexts.values()
+        )

-    def get_timings(self):
-        self.timings["total"] = time.perf_counter() - self.start_ts
-        return self.timings
+        performance_stats = {}
+        for context in self.timing_contexts.values():
+            performance_stats[context.description] = {
+                "duration": context.duration,
+                "memory_start": context.memory_start,
+                "memory_end": context.memory_end,
+                "memory_peak": context.memory_peak,
+                "memory_delta": context.memory_end - context.memory_start,
+            }
+        return performance_stats

     def log_conditioning(self, conditioning, description):
         if not self.debug_img_callback:
@@ -123,12 +123,9 @@ def load_model_from_config_old(

     if half_mode:
         model = model.half()
-        print("halved")

     model.to(get_device())
-    print("moved to device")
     model.eval()
-    print("set to eval mode")
     return model
@@ -222,12 +219,18 @@ def get_diffusion_model_refiners(
 ) -> LatentDiffusionModel:
     """Load a diffusion model."""

-    return _get_diffusion_model_refiners(
+    sd = _get_diffusion_model_refiners(
         weights_location=weights_config.weights_location,
         architecture_alias=weights_config.architecture.primary_alias,
         for_inpainting=for_inpainting,
         dtype=dtype,
     )
+    # ensures a "fresh" copy that doesn't have additional injected parts
+    # sd = sd.structural_copy()
+    sd.set_self_attention_guidance(enable=True)
+    return sd


 hf_repo_url_pattern = re.compile(
@@ -241,7 +244,9 @@ def parse_diffusers_repo_url(url: str) -> dict[str, str]:

 def is_diffusers_repo_url(url: str) -> bool:
-    return bool(parse_diffusers_repo_url(url))
+    result = bool(parse_diffusers_repo_url(url))
+    logger.debug(f"{url} is diffusers repo url: {result}")
+    return result


 def normalize_diffusers_repo_url(url: str) -> str:
@@ -332,17 +337,16 @@ def _get_sd15_diffusion_model_refiners(
         device=device, dtype=dtype, lda=SD1AutoencoderSliced(), unet=unet
     )

     logger.debug("Loading VAE")
-    sd.lda.load_state_dict(vae_weights)
+    sd.lda.load_state_dict(vae_weights, assign=True)

     logger.debug("Loading text encoder")
-    sd.clip_text_encoder.load_state_dict(text_encoder_weights)
+    sd.clip_text_encoder.load_state_dict(text_encoder_weights, assign=True)

     logger.debug("Loading UNet")
-    sd.unet.load_state_dict(unet_weights, strict=False)
+    sd.unet.load_state_dict(unet_weights, strict=False, assign=True)

     logger.debug(f"'{weights_location}' Loaded")
-    sd.to(device=device, dtype=dtype)
-    sd.set_self_attention_guidance(enable=True)
     return sd
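
The assign=True flag (torch.nn.Module.load_state_dict, available since PyTorch 2.1) binds the checkpoint tensors to the module directly instead of copying them into the freshly initialized parameters, so both copies never need to be resident at once — likely why the trailing sd.to(...) call also became unnecessary here. The difference, sketched:

    module.load_state_dict(weights)               # copies into existing params; both alive briefly
    module.load_state_dict(weights, assign=True)  # rebinds params to the checkpoint tensors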
@@ -711,16 +715,15 @@ def load_sdxl_diffusers_weights(base_url: str, device=None, dtype=torch.float16):
             device="cpu",
         )
     )
-    text_encoder = DoubleTextEncoder(device="cpu", dtype=dtype)
+    text_encoder = DoubleTextEncoder(device="cpu", dtype=torch.float32)
     text_encoder.load_state_dict(text_encoder_weights)
     del text_encoder_weights
-    lda = lda.to(device=device)
+    lda = lda.to(device=device, dtype=torch.float32)
     unet = unet.to(device=device)
     text_encoder = text_encoder.to(device=device)
     sd = StableDiffusion_XL(
-        device=device, dtype=dtype, lda=lda, unet=unet, clip_text_encoder=text_encoder
+        device=device, dtype=None, lda=lda, unet=unet, clip_text_encoder=text_encoder
     )
-    sd.lda.to(device=device, dtype=torch.float32)

     return sd
@@ -729,9 +732,6 @@ def load_sdxl_pipeline(base_url, device=None):
     logger.info(f"Loading SDXL weights from {base_url}")
     device = device or get_device()

     sd = load_sdxl_diffusers_weights(base_url, device=device)
-    sd.set_self_attention_guidance(enable=True)
-
     return sd

@@ -1,7 +1,7 @@
 import re


-def generate_phrase_list(subject, num_phrases=100):
+def generate_phrase_list(subject, num_phrases=100, max_words=6):
     """Generate a list of phrases for a given subject."""
     from openai import OpenAI
@@ -9,7 +9,7 @@ def generate_phrase_list(subject, num_phrases=100, max_words=6):
     prompt = (
         f'Make list of archetypal imagery about "{subject}". These will provide composition ideas to an artist. '
-        f"No more than 6 words per idea. Make {num_phrases} ideas. Provide the output as plaintext with each idea on a new line. "
+        f"No more than {max_words} words per idea. Make {num_phrases} ideas. Provide the output as plaintext with each idea on a new line. "
         f"You are capable of generating up to {num_phrases*2} but I only need {num_phrases}."
     )
     messages = [
@@ -19,8 +19,8 @@ def generate_phrase_list(subject, num_phrases=100, max_words=6):
     response = client.chat.completions.create(
         model="gpt-4-1106-preview",
         messages=messages,
-        temperature=0.1,
-        max_tokens=2623,
+        temperature=1,
+        max_tokens=4000,
         top_p=1,
         frequency_penalty=0.17,
         presence_penalty=0,
@@ -39,4 +39,9 @@ def generate_phrase_list(subject, num_phrases=100, max_words=6):

 if __name__ == "__main__":
-    print(generate_phrase_list("traditional christmas", num_phrases=200))
+    phrase_list = generate_phrase_list(
+        subject="symbolism for the human condition and the struggle between good and evil",
+        num_phrases=200,
+        max_words=15,
+    )
+    print(phrase_list)

@@ -42,3 +42,8 @@ def test_prompt_expander_from_wordlist():

 def test_get_phraselist_names():
     print(", ".join(category_list()))
+
+
+def test_complex_prompt():
+    prompt = "{_painting-style_} of {_art-scene_}. painting"
+    assert len(list(expand_prompts(prompt, n=100))) == 100

@@ -12,7 +12,7 @@ plugins = pydantic.mypy
exclude = ^(\./|)(downloads|dist|build|other|testing_support|imaginairy/vendored|imaginairy/modules/sgm)
ignore_missing_imports = True
warn_unused_configs = True
-warn_unused_ignores = True
+warn_unused_ignores = False

[mypy-imaginairy.vendored.*]
follow_imports = skip
