feature: clean up terminal output

- record timing and memory usage of the various steps
- re-use logging context for composition images
- load sdxl weights in a more VRAM-efficient way
- switch the default sd15 weights to the diffusers format
pull/435/head
Bryce authored 5 months ago, committed by Bryce Drennan
parent 0d78b8271f
commit 9e3403df89

@@ -84,7 +84,6 @@ def imagine_image_files(
f"{base_count:06}_{prompt.seed}_{prompt.solver_type.replace('_', '')}{prompt.steps}_"
f"PS{prompt.prompt_strength}{img_str}_{prompt_normalized(prompt.prompt_text)}"
)
for image_type in result.images:
subpath = os.path.join(outdir, image_type)
os.makedirs(subpath, exist_ok=True)
@@ -92,7 +91,7 @@ def imagine_image_files(
subpath, f"{basefilename}_[{image_type}].{output_file_extension}"
)
result.save(filepath, image_type=image_type)
logger.info(f" [{image_type}] saved to: {filepath}")
logger.info(f" {image_type:<22} {filepath}")
if image_type == return_filename_type:
result_filenames.append(filepath)
if videogen:
@@ -123,7 +122,8 @@ def imagine_image_files(
start_pause_duration_ms=1500,
end_pause_duration_ms=1000,
)
logger.info(f" [gif] {len(frames)} frames saved to: {filepath}")
image_type = "gif"
logger.info(f" {image_type:<22} {filepath}")
if make_compare_gif and prompt.init_image:
subpath = os.path.join(outdir, "gif")
os.makedirs(subpath, exist_ok=True)
@@ -137,7 +137,8 @@ def imagine_image_files(
imgs=frames,
outpath=filepath,
)
logger.info(f" [gif-comparison] saved to: {filepath}")
image_type = "gif"
logger.info(f" {image_type:<22} {filepath}")
base_count += 1
del result
@@ -192,9 +193,8 @@ def imagine(
), fix_torch_nn_layer_norm(), fix_torch_group_norm():
for i, prompt in enumerate(prompts):
concrete_prompt = prompt.make_concrete_copy()
logger.info(
f"🖼 Generating {i + 1}/{num_prompts}: {concrete_prompt.prompt_description()}"
)
prog_text = f"{i + 1}/{num_prompts}"
logger.info(f"🖼 {prog_text} {concrete_prompt.prompt_description()}")
for attempt in range(unsafe_retry_count + 1):
if attempt > 0 and isinstance(concrete_prompt.seed, int):
concrete_prompt.seed += 100_000_000 + attempt
@@ -204,7 +204,6 @@ def imagine(
progress_img_callback=progress_img_callback,
progress_img_interval_steps=progress_img_interval_steps,
progress_img_interval_min_s=progress_img_interval_min_s,
half_mode=half_mode,
add_caption=add_caption,
dtype=torch.float16 if half_mode else torch.float32,
)
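Aside: the logging hunks above replace bracketed labels like "[gif]" with a fixed-width column. The :<22 format spec left-aligns the image type within 22 characters so the file paths line up vertically. A quick illustration (file names invented):

for image_type, filepath in [
    ("generated", "outputs/generated/000001_abc.jpg"),
    ("gif-comparison", "outputs/gif/000001_[gif-comparison].gif"),
]:
    print(f"    {image_type:<22} {filepath}")
# prints:
#     generated              outputs/generated/000001_abc.jpg
#     gif-comparison         outputs/gif/000001_[gif-comparison].gif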

@@ -475,7 +475,7 @@ def _generate_single_image_compvis(
is_nsfw=safety_score.is_nsfw,
safety_score=safety_score,
result_images=result_images,
timings=lc.get_timings(),
performance_stats=lc.get_performance_stats(),
progress_latents=progress_latents.copy(),
)

@@ -1,10 +1,13 @@
"""Functions for generating refined images"""
import logging
from contextlib import nullcontext
from typing import Any
from imaginairy.config import CONTROL_CONFIG_SHORTCUTS
from imaginairy.schema import ControlInput, ImaginePrompt, MaskMode
from imaginairy.utils import clear_gpu_cache
from imaginairy.utils.log_utils import ImageLoggingContext
logger = logging.getLogger(__name__)
@@ -18,7 +21,9 @@ def generate_single_image(
add_caption=False,
return_latent=False,
dtype=None,
half_mode=None,
logging_context: ImageLoggingContext | None = None,
output_perf=False,
image_name="",
):
import torch.nn
from PIL import Image, ImageOps
@@ -59,41 +64,48 @@ def generate_single_image(
from imaginairy.utils.safety import create_safety_score
if dtype is None:
dtype = torch.float16 if half_mode else torch.float32
dtype = torch.float16
get_device()
clear_gpu_cache()
prompt = prompt.make_concrete_copy()
sd = get_diffusion_model_refiners(
weights_config=prompt.model_weights,
for_inpainting=prompt.should_use_inpainting
and prompt.inpaint_method == "finetune",
dtype=dtype,
)
if not logging_context:
def latent_logger(latents):
progress_latents.append(latents)
lc = ImageLoggingContext(
prompt=prompt,
debug_img_callback=debug_img_callback,
progress_img_callback=progress_img_callback,
progress_img_interval_steps=progress_img_interval_steps,
progress_img_interval_min_s=progress_img_interval_min_s,
progress_latent_callback=latent_logger
if prompt.collect_progress_latents
else None,
)
_context: Any = lc
else:
lc = logging_context
_context = nullcontext()
with _context:
with lc.timing("model-load"):
sd = get_diffusion_model_refiners(
weights_config=prompt.model_weights,
for_inpainting=prompt.should_use_inpainting
and prompt.inpaint_method == "finetune",
dtype=dtype,
)
lc.model = sd
seed_everything(prompt.seed)
downsampling_factor = 8
latent_channels = 4
batch_size = 1
mask_image = None
mask_image_orig = None
seed_everything(prompt.seed)
downsampling_factor = 8
latent_channels = 4
batch_size = 1
mask_image = None
mask_image_orig = None
def latent_logger(latents):
progress_latents.append(latents)
with ImageLoggingContext(
prompt=prompt,
model=sd,
debug_img_callback=debug_img_callback,
progress_img_callback=progress_img_callback,
progress_img_interval_steps=progress_img_interval_steps,
progress_img_interval_min_s=progress_img_interval_min_s,
progress_latent_callback=latent_logger
if prompt.collect_progress_latents
else None,
) as lc:
sd.set_tile_mode(prompt.tile_mode)
result_images: dict[str, torch.Tensor | None | Image.Image] = {}
@@ -178,63 +190,14 @@ def generate_single_image(
assert prompt.seed is not None
noise = randn_seeded(seed=prompt.seed, size=shape).to(
get_device(), dtype=sd.dtype
sd.unet.device, dtype=sd.unet.dtype
)
noised_latent = noise
controlnets = []
if control_modes:
for control_input in control_inputs:
controlnet, control_image_t, control_image_disp = prep_control_input(
control_input=control_input,
sd=sd,
init_image_t=init_image_t,
fit_width=prompt.width,
fit_height=prompt.height,
)
result_images[f"control-{control_input.mode}"] = control_image_disp
controlnets.append((controlnet, control_image_t))
if prompt.allow_compose_phase:
cutoff_size = get_model_default_image_size(prompt.model_architecture)
cutoff_size = (int(cutoff_size[0] * 1.30), int(cutoff_size[1] * 1.30))
compose_kwargs = {
"prompt": prompt,
"target_height": prompt.height,
"target_width": prompt.width,
"cutoff": cutoff_size,
"dtype": dtype,
}
if prompt.init_image:
compose_kwargs.update(
{
"target_height": init_image.height,
"target_width": init_image.width,
}
)
comp_image, comp_img_orig = _generate_composition_image(**compose_kwargs)
if comp_image is not None:
prompt.fix_faces = False # done in composition
result_images["composition"] = comp_img_orig
result_images["composition-upscaled"] = comp_image
composition_strength = prompt.composition_strength
first_step = int((prompt.steps) * composition_strength)
noise_step = int((prompt.steps - 1) * composition_strength)
log_img(comp_img_orig, "comp_image")
log_img(comp_image, "comp_image_upscaled")
comp_image_t = pillow_img_to_torch_image(comp_image)
comp_image_t = comp_image_t.to(sd.lda.device, dtype=sd.lda.dtype)
init_latent = sd.lda.encode(comp_image_t)
compose_control_inputs: list[ControlInput]
if prompt.model_weights.architecture.primary_alias == "sdxl":
compose_control_inputs = []
else:
compose_control_inputs = [
ControlInput(mode="details", image=comp_image, strength=1),
]
for control_input in compose_control_inputs:
with lc.timing("control-image-prep"):
for control_input in control_inputs:
(
controlnet,
control_image_t,
@@ -242,16 +205,73 @@ def generate_single_image(
) = prep_control_input(
control_input=control_input,
sd=sd,
init_image_t=None,
init_image_t=init_image_t,
fit_width=prompt.width,
fit_height=prompt.height,
)
result_images[f"control-{control_input.mode}"] = control_image_disp
controlnets.append((controlnet, control_image_t))
if prompt.allow_compose_phase:
with lc.timing("composition"):
cutoff_size = get_model_default_image_size(prompt.model_architecture)
cutoff_size = (int(cutoff_size[0] * 1.30), int(cutoff_size[1] * 1.30))
compose_kwargs = {
"prompt": prompt,
"target_height": prompt.height,
"target_width": prompt.width,
"cutoff": cutoff_size,
"dtype": dtype,
}
if prompt.init_image:
compose_kwargs.update(
{
"target_height": init_image.height,
"target_width": init_image.width,
}
)
comp_image, comp_img_orig = _generate_composition_image(
**compose_kwargs, logging_context=lc
)
if comp_image is not None:
prompt.fix_faces = False # done in composition
result_images["composition"] = comp_img_orig
result_images["composition-upscaled"] = comp_image
composition_strength = prompt.composition_strength
first_step = int((prompt.steps) * composition_strength)
noise_step = int((prompt.steps - 1) * composition_strength)
log_img(comp_img_orig, "comp_image")
log_img(comp_image, "comp_image_upscaled")
comp_image_t = pillow_img_to_torch_image(comp_image)
comp_image_t = comp_image_t.to(sd.lda.device, dtype=sd.lda.dtype)
init_latent = sd.lda.encode(comp_image_t)
compose_control_inputs: list[ControlInput]
if prompt.model_weights.architecture.primary_alias == "sdxl":
compose_control_inputs = []
else:
compose_control_inputs = [
ControlInput(mode="details", image=comp_image, strength=1),
]
for control_input in compose_control_inputs:
(
controlnet,
control_image_t,
control_image_disp,
) = prep_control_input(
control_input=control_input,
sd=sd,
init_image_t=None,
fit_width=prompt.width,
fit_height=prompt.height,
)
result_images[
f"control-{control_input.mode}"
] = control_image_disp
controlnets.append((controlnet, control_image_t))
for controlnet, control_image_t in controlnets:
msg = f"Injecting controlnet {controlnet.name}. setting to device: {sd.unet.device}, dtype: {sd.unet.dtype}"
print(msg)
controlnet.set_controlnet_condition(
control_image_t.to(device=sd.unet.device, dtype=sd.unet.dtype)
)
@@ -263,7 +283,7 @@ def generate_single_image(
else:
msg = f"Unknown solver type: {prompt.solver_type}"
raise ValueError(msg)
sd.scheduler.to(device=sd.device, dtype=sd.dtype)
sd.scheduler.to(device=sd.unet.device, dtype=sd.unet.dtype)
sd.set_num_inference_steps(prompt.steps)
if hasattr(sd, "mask_latents") and mask_image is not None:
@@ -288,60 +308,66 @@ def generate_single_image(
x=init_latent, noise=noise, step=sd.steps[noise_step]
)
text_conditioning_kwargs = sd.calculate_text_conditioning_kwargs(
positive_prompts=prompt.prompts,
negative_prompts=prompt.negative_prompt,
positive_conditioning_override=prompt.conditioning,
)
for k, v in text_conditioning_kwargs.items():
text_conditioning_kwargs[k] = v.to(
device=sd.unet.device, dtype=sd.unet.dtype
with lc.timing("text-conditioning"):
text_conditioning_kwargs = sd.calculate_text_conditioning_kwargs(
positive_prompts=prompt.prompts,
negative_prompts=prompt.negative_prompt,
positive_conditioning_override=prompt.conditioning,
)
for k, v in text_conditioning_kwargs.items():
text_conditioning_kwargs[k] = v.to(
device=sd.unet.device, dtype=sd.unet.dtype
)
x = noised_latent
x = x.to(device=sd.unet.device, dtype=sd.unet.dtype)
clear_gpu_cache()
for step in tqdm(sd.steps[first_step:], bar_format=" {l_bar}{bar}{r_bar}"):
log_latent(x, "noisy_latent")
x = sd(
x,
step=step,
condition_scale=prompt.prompt_strength,
**text_conditioning_kwargs,
)
# trying to clear memory. not sure if this helps
sd.unet.set_context(context="self_attention_map", value={})
sd.unet._reset_context()
clear_gpu_cache()
with lc.timing("unet"):
for step in tqdm(
sd.steps[first_step:], bar_format=" {l_bar}{bar}{r_bar}", leave=False
):
log_latent(x, "noisy_latent")
x = sd(
x,
step=step,
condition_scale=prompt.prompt_strength,
**text_conditioning_kwargs,
)
# trying to clear memory. not sure if this helps
sd.unet.set_context(context="self_attention_map", value={})
sd.unet._reset_context()
clear_gpu_cache()
logger.debug("Decoding image")
if x.device != sd.lda.device:
sd.lda.to(x.device)
clear_gpu_cache()
gen_img = sd.lda.decode_latents(x.to(dtype=sd.lda.dtype))
with lc.timing("decode-img"):
gen_img = sd.lda.decode_latents(x.to(dtype=sd.lda.dtype))
if mask_image_orig and init_image:
result_images["pre-reconstitution"] = gen_img
mask_final = mask_image_orig.copy()
# mask_final = ImageOps.invert(mask_final)
log_img(mask_final, "reconstituting mask")
# gen_img = Image.composite(gen_img, init_image, mask_final)
gen_img = combine_image(
original_img=init_image,
generated_img=gen_img,
mask_img=mask_final,
)
log_img(gen_img, "reconstituted image")
with lc.timing("combine-image"):
result_images["pre-reconstitution"] = gen_img
mask_final = mask_image_orig.copy()
# mask_final = ImageOps.invert(mask_final)
log_img(mask_final, "reconstituting mask")
# gen_img = Image.composite(gen_img, init_image, mask_final)
gen_img = combine_image(
original_img=init_image,
generated_img=gen_img,
mask_img=mask_final,
)
log_img(gen_img, "reconstituted image")
upscaled_img = None
rebuilt_orig_img = None
if add_caption:
caption = generate_caption(gen_img)
logger.info(f"Generated caption: {caption}")
with lc.timing("caption-img"):
caption = generate_caption(gen_img)
logger.info(f"Generated caption: {caption}")
with lc.timing("safety-filter"):
safety_score = create_safety_score(
@@ -352,13 +378,17 @@ def generate_single_image(
progress_latents.clear()
if not safety_score.is_filtered:
if prompt.fix_faces:
logger.info("Fixing 😊 's in 🖼 using CodeFormer...")
with lc.timing("face enhancement"):
gen_img = enhance_faces(gen_img, fidelity=prompt.fix_faces_fidelity)
with lc.timing("face-enhancement"):
logger.info("Fixing 😊 's in 🖼 using CodeFormer...")
with lc.timing("face-enhancement"):
gen_img = enhance_faces(
gen_img, fidelity=prompt.fix_faces_fidelity
)
if prompt.upscale:
logger.info("Upscaling 🖼 using real-ESRGAN...")
with lc.timing("upscaling"):
upscaled_img = upscale_image(gen_img)
logger.info("Upscaling 🖼 using real-ESRGAN...")
with lc.timing("upscaling"):
upscaled_img = upscale_image(gen_img)
# put the newly generated patch back into the original, full-size image
if prompt.mask_modify_original and mask_image_orig and starting_image:
@@ -390,13 +420,19 @@ def generate_single_image(
is_nsfw=safety_score.is_nsfw,
safety_score=safety_score,
result_images=result_images,
timings=lc.get_timings(),
performance_stats=lc.get_performance_stats(),
progress_latents=[], # todo
)
_most_recent_result = result
if result.timings:
logger.info(f"Image Generated. Timings: {result.timings_str()}")
_image_name = f"{image_name} " if image_name else ""
logger.info(f"Generated {_image_name}image in {result.total_time():.1f}s")
if result.performance_stats:
log = logger.info if output_perf else logger.debug
log(f" Timings: {result.timings_str()}")
log(f" Peak VRAM: {result.gpu_str('memory_peak')}")
log(f" Ending VRAM: {result.gpu_str('memory_end')}")
for controlnet, _ in controlnets:
controlnet.eject()
clear_gpu_cache()
@@ -495,6 +531,7 @@ def _generate_composition_image(
target_width,
cutoff: tuple[int, int] = (512, 512),
dtype=None,
logging_context=None,
):
from PIL import Image
@@ -530,7 +567,13 @@ def _generate_composition_image(
},
)
result = generate_single_image(composition_prompt, dtype=dtype)
result = generate_single_image(
composition_prompt,
dtype=dtype,
logging_context=logging_context,
output_perf=False,
image_name="composition",
)
img = result.images["generated"]
while img.width < target_width:
from imaginairy.enhancers.upscale_realesrgan import upscale_image
@ -538,9 +581,11 @@ def _generate_composition_image(
if prompt.fix_faces:
from imaginairy.enhancers.face_restoration_codeformer import enhance_faces
img = enhance_faces(img, fidelity=prompt.fix_faces_fidelity)
img = upscale_image(img, ultrasharp=True)
with logging_context.timing("face-enhancement"):
logger.info("Fixing 😊 's in 🖼 using CodeFormer...")
img = enhance_faces(img, fidelity=prompt.fix_faces_fidelity)
with logging_context.timing("upscaling"):
img = upscale_image(img, ultrasharp=True)
img = img.resize(
(target_width, target_height),
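The key change in this file is that composition generation now shares the caller's ImageLoggingContext instead of creating its own. Because lc.timing() caches one TimingContext per description and stop() adds to duration rather than resetting it, time spent in the composition pass accumulates into the parent image's stats. A sketch of the mechanism (the two workload helpers are invented stand-ins):

from imaginairy.utils.log_utils import ImageLoggingContext

lc = ImageLoggingContext(prompt=prompt)  # prompt: an ImaginePrompt, assumed in scope

with lc.timing("unet"):
    run_main_denoising_loop()            # hypothetical stand-in for the real loop

with lc.timing("unet"):                  # same description -> same cached TimingContext
    run_composition_denoising_loop()     # its time is added to "unet", not overwritten

stats = lc.get_performance_stats()
# stats["unet"]["duration"] covers both loops; stats["total"] spans the whole
# context, with its memory_peak taken as the max across all tracked blocks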

@@ -109,8 +109,11 @@ def _imagine_cmd(
raise ValueError(msg)
total_image_count = len(prompt_texts) * max(len(init_images), 1) * repeats
img_msg = ""
if len(init_images) > 0:
img_msg = f" and {len(init_images)} image(s)"
logger.info(
f"Received {len(prompt_texts)} prompt(s) and {len(init_images)} input image(s). Will repeat the generations {repeats} times to create {total_image_count} images."
f"Received {len(prompt_texts)} prompt(s){img_msg}. Will repeat these {repeats} times to create {total_image_count} images.\n"
)
from imaginairy.api import imagine_image_files

@@ -124,7 +124,7 @@ MODEL_WEIGHT_CONFIGS = [
aliases=MODEL_ARCHITECTURE_LOOKUP["sd15"].aliases,
architecture=MODEL_ARCHITECTURE_LOOKUP["sd15"],
defaults={"negative_prompt": DEFAULT_NEGATIVE_PROMPT},
weights_location="https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/889b629140e71758e1e0006e355c331a5744b4bf/v1-5-pruned-emaonly.ckpt",
weights_location="https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/889b629140e71758e1e0006e355c331a5744b4bf/",
),
ModelWeightsConfig(
name="Stable Diffusion 1.5 - Inpainting",
@@ -151,7 +151,7 @@ MODEL_WEIGHT_CONFIGS = [
name="OpenJourney V4",
aliases=["openjourney-v4", "oj4", "ojv4", "openjourney4", "openjourney", "oj"],
architecture=MODEL_ARCHITECTURE_LOOKUP["sd15"],
weights_location="https://huggingface.co/prompthero/openjourney/resolve/e291118e93d5423dc88ac1ed93c02362b17d698f/mdjrny-v4.safetensors",
weights_location="https://huggingface.co/prompthero/openjourney/tree/f4572661b028c732b2b97c8fbdc32fa5db3afe03/",
defaults={"negative_prompt": "poor quality"},
),
ModelWeightsConfig(
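Both weights_location changes above move from a single checkpoint file ("resolve/<commit>/file.ckpt") to a diffusers repo pinned at a commit ("tree/<commit>/"). The hf_repo_url_pattern that recognizes these URLs appears later in this diff, but its body is elided; a hypothetical pattern in the same spirit:

import re

pat = re.compile(
    r"https://huggingface\.co/(?P<author>[^/]+)/(?P<repo>[^/]+)"
    r"/tree/(?P<ref>[A-Fa-f0-9]+)/?$"
)
m = pat.match(
    "https://huggingface.co/runwayml/stable-diffusion-v1-5"
    "/tree/889b629140e71758e1e0006e355c331a5744b4bf/"
)
assert m is not None and m.group("ref") == "889b629140e71758e1e0006e355c331a5744b4bf"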

@@ -74,7 +74,7 @@ def enhance_faces(img, fidelity=0):
num_det_faces = face_helper.get_face_landmarks_5(
only_center_face=False, resize=640, eye_dist_threshold=5
)
logger.info(f"Enhancing {num_det_faces} faces")
logger.debug(f"Enhancing {num_det_faces} faces")
# align and warp each face
face_helper.align_warp_face()

@@ -3,11 +3,18 @@
import logging
import math
from functools import lru_cache
from typing import List, Literal
from typing import Any, List, Literal
import refiners.fluxion.layers as fl
import torch
from refiners.fluxion.layers.attentions import ScaledDotProductAttention
from refiners.fluxion.layers.chain import ChainError
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
from refiners.foundationals.latent_diffusion.model import (
TLatentDiffusionModel,
)
from refiners.foundationals.latent_diffusion.schedulers.ddim import DDIM
from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
from refiners.foundationals.latent_diffusion.self_attention_guidance import (
SelfAttentionMap,
)
@@ -25,7 +32,11 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import (
SDXLAutoencoder,
StableDiffusion_XL as RefinerStableDiffusion_XL,
)
from torch import Tensor, nn
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import (
DoubleTextEncoder,
)
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
from torch import Tensor, device as Device, dtype as DType, nn
from torch.nn import functional as F
from imaginairy.schema import WeightedPrompt
@@ -83,6 +94,45 @@ class TileModeMixin(nn.Module):
class StableDiffusion_1(TileModeMixin, RefinerStableDiffusion_1):
def __init__(
self,
unet: SD1UNet | None = None,
lda: SD1Autoencoder | None = None,
clip_text_encoder: CLIPTextEncoderL | None = None,
scheduler: Scheduler | None = None,
device: Device | str = "cpu",
dtype: DType = torch.float32,
) -> None:
unet = unet or SD1UNet(in_channels=4)
lda = lda or SD1Autoencoder()
clip_text_encoder = clip_text_encoder or CLIPTextEncoderL()
scheduler = scheduler or DDIM(num_inference_steps=50)
fl.Module.__init__(self)
# all of this allows structural copies without unnecessary device/dtype shuffling:
# the default behavior put everything on the same device and dtype, and we want
# the option to leave each submodule as it is already set
self.unet = unet
self.lda = lda
self.clip_text_encoder = clip_text_encoder
self.scheduler = scheduler
to_kwargs: dict[str, Any] = {}
if device is not None:
device = device if isinstance(device, Device) else Device(device=device)
to_kwargs["device"] = device
if dtype is not None:
to_kwargs["dtype"] = dtype
self.device = device
self.dtype = dtype
if to_kwargs:
self.unet = unet.to(**to_kwargs)
self.lda = lda.to(**to_kwargs)
self.clip_text_encoder = clip_text_encoder.to(**to_kwargs)
self.scheduler = scheduler.to(**to_kwargs)
def calculate_text_conditioning_kwargs(
self,
positive_prompts: List[WeightedPrompt],
@@ -122,6 +172,48 @@ class StableDiffusion_1(TileModeMixin, RefinerStableDiffusion_1):
class StableDiffusion_XL(TileModeMixin, RefinerStableDiffusion_XL):
def __init__(
self,
unet: SDXLUNet | None = None,
lda: SDXLAutoencoder | None = None,
clip_text_encoder: DoubleTextEncoder | None = None,
scheduler: Scheduler | None = None,
device: Device | str | None = "cpu",
dtype: DType | None = None,
) -> None:
unet = unet or SDXLUNet(in_channels=4)
lda = lda or SDXLAutoencoder()
clip_text_encoder = clip_text_encoder or DoubleTextEncoder()
scheduler = scheduler or DDIM(num_inference_steps=30)
fl.Module.__init__(self)
# all of this allows structural copies without unnecessary device/dtype shuffling:
# the default behavior put everything on the same device and dtype, and we want
# the option to leave each submodule as it is already set
self.unet = unet
self.lda = lda
self.clip_text_encoder = clip_text_encoder
self.scheduler = scheduler
to_kwargs: dict[str, Any] = {}
if device is not None:
device = device if isinstance(device, Device) else Device(device=device)
to_kwargs["device"] = device
if dtype is not None:
to_kwargs["dtype"] = dtype
self.device = device # type: ignore
self.dtype = dtype # type: ignore
self.unet = unet
self.lda = lda
self.clip_text_encoder = clip_text_encoder
self.scheduler = scheduler
if to_kwargs:
self.unet = self.unet.to(**to_kwargs)
self.lda = self.lda.to(**to_kwargs)
self.clip_text_encoder = self.clip_text_encoder.to(**to_kwargs)
self.scheduler = self.scheduler.to(**to_kwargs)
def forward( # type: ignore
self,
x: Tensor,
@@ -144,6 +236,22 @@ class StableDiffusion_XL(TileModeMixin, RefinerStableDiffusion_XL):
**kwargs,
)
def structural_copy(self: TLatentDiffusionModel) -> TLatentDiffusionModel:
logger.debug("Making structural copy of StableDiffusion_XL model")
sd = self.__class__(
unet=self.unet.structural_copy(),
lda=self.lda.structural_copy(),
clip_text_encoder=self.clip_text_encoder,
scheduler=self.scheduler,
device=self.device,
dtype=None, # type: ignore
)
logger.debug(
f"dtype: {sd.dtype} unet-dtype:{sd.unet.dtype} lda-dtype:{sd.lda.dtype} text-encoder-dtype:{sd.clip_text_encoder.dtype} scheduler-dtype:{sd.scheduler.dtype}"
)
return sd
def calculate_text_conditioning_kwargs(
self,
positive_prompts: List[WeightedPrompt],
@@ -452,7 +560,7 @@ def monkeypatch_sd1controlnetadapter():
device=target.device,
dtype=target.dtype,
)
print(
logger.debug(
f"controlnet: {name} loaded to device {target.device} and type {target.dtype}"
)
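The rewritten StableDiffusion_1 and StableDiffusion_XL constructors above share one idea: submodules are moved or cast only when a device/dtype was explicitly requested, so structural copies can keep whatever placement and precision each part already has. A minimal sketch of the pattern (the place helper is invented):

import torch
from torch import nn

def place(module: nn.Module, device=None, dtype=None) -> nn.Module:
    to_kwargs = {}
    if device is not None:
        to_kwargs["device"] = torch.device(device)
    if dtype is not None:
        to_kwargs["dtype"] = dtype
    # no kwargs -> module returned untouched, keeping its current device/dtype
    return module.to(**to_kwargs) if to_kwargs else module

unet = place(nn.Linear(8, 8), device="cpu", dtype=torch.float16)  # moved and cast
lda = place(nn.Linear(8, 8))                                      # left exactly as-is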

@@ -362,6 +362,7 @@ class ImaginePrompt(BaseModel, protected_namespaces=()):
composition_strength=composition_strength,
inpaint_method=inpaint_method,
)
self._default_negative_prompt = None
@field_validator("prompt", "negative_prompt", mode="before")
def make_into_weighted_prompts(
@@ -401,16 +402,20 @@ class ImaginePrompt(BaseModel, protected_namespaces=()):
v.sort(key=lambda p: p.weight, reverse=True)
return v
@property
def default_negative_prompt(self):
default_negative_prompt = config.DEFAULT_NEGATIVE_PROMPT
if self.model_weights:
default_negative_prompt = self.model_weights.defaults.get(
"negative_prompt", default_negative_prompt
)
return default_negative_prompt
@model_validator(mode="after")
def validate_negative_prompt(self):
if self.negative_prompt == []:
default_negative_prompt = config.DEFAULT_NEGATIVE_PROMPT
if self.model_weights:
default_negative_prompt = self.model_weights.defaults.get(
"negative_prompt", default_negative_prompt
)
self.negative_prompt = [WeightedPrompt(text=self.default_negative_prompt)]
self.negative_prompt = [WeightedPrompt(text=default_negative_prompt)]
return self
@field_validator("prompt_strength", mode="before")
@@ -667,15 +672,27 @@ class ImaginePrompt(BaseModel, protected_namespaces=()):
return self.model_weights.architecture
def prompt_description(self):
if self.negative_prompt_text == self.default_negative_prompt:
neg_prompt = "DEFAULT-NEGATIVE-PROMPT"
else:
neg_prompt = f'"{self.negative_prompt_text}"'
from termcolor import colored
prompt_text = colored(self.prompt_text, "green")
return (
f'"{self.prompt_text}" {self.width}x{self.height}px '
f'negative-prompt:"{self.negative_prompt_text}" '
f'"{prompt_text}"\n'
" "
f"negative-prompt:{neg_prompt}\n"
" "
f"size:{self.width}x{self.height}px "
f"seed:{self.seed} "
f"prompt-strength:{self.prompt_strength} "
f"steps:{self.steps} solver-type:{self.solver_type} "
f"init-image-strength:{self.init_image_strength} "
f"arch:{self.model_architecture.aliases[0]} "
f"weights: {self.model_weights.aliases[0]}"
f"weights:{self.model_weights.aliases[0]}"
)
def logging_dict(self):
@@ -724,7 +741,7 @@ class ImagineResult:
is_nsfw,
safety_score,
result_images=None,
timings=None,
performance_stats=None,
progress_latents=None,
):
import torch
@@ -750,7 +767,7 @@ class ImagineResult:
r_img = torch_img_to_pillow_img(r_img)
self.images[img_type] = r_img
self.timings = timings
self.performance_stats = performance_stats
self.progress_latents = progress_latents
# for backward compat
@@ -771,9 +788,24 @@ class ImagineResult:
}
def timings_str(self) -> str:
if not self.timings:
if not self.performance_stats:
return ""
return " ".join(f"{k}:{v:.2f}s" for k, v in self.timings.items())
return " ".join(
f"{k}:{v['duration']:.2f}s" for k, v in self.performance_stats.items()
)
def total_time(self) -> float:
if not self.performance_stats:
return 0
return self.performance_stats["total"]["duration"]
def gpu_str(self, stat_name="memory_peak") -> str:
if not self.performance_stats:
return ""
return " ".join(
f"{k}:{v[stat_name]/(10**6):.1f}MB"
for k, v in self.performance_stats.items()
)
def _exif(self) -> "Image.Exif":
from PIL import Image
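ImagineResult now carries a nested performance_stats dict instead of the flat timings dict. A hypothetical payload (numbers invented) showing what the new helpers produce:

performance_stats = {
    "unet": {"duration": 4.21, "memory_start": 2.1e9,
             "memory_end": 2.3e9, "memory_peak": 5.9e9},
    "total": {"duration": 5.80, "memory_start": 2.0e9,
              "memory_end": 2.3e9, "memory_peak": 5.9e9},
}
# timings_str()           -> "unet:4.21s total:5.80s"
# total_time()            -> 5.8
# gpu_str("memory_peak")  -> "unet:5900.0MB total:5900.0MB"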

@@ -5,6 +5,7 @@ import logging.config
import re
import time
import warnings
from typing import Callable
_CURRENT_LOGGING_CONTEXT = None
@@ -53,23 +54,65 @@ def increment_step():
class TimingContext:
def __init__(self, logging_context, description):
self.logging_context = logging_context
"""Tracks time and memory usage of a block of code"""
def __init__(
self,
description: str,
device: str | None = None,
callback_fn: Callable | None = None,
):
from imaginairy.utils import get_device
self.description = description
self._device = device or get_device()
self.callback_fn = callback_fn
self.start_time = None
self.end_time = None
self.duration = 0
def __enter__(self):
self.memory_start = 0
self.memory_end = 0
self.memory_peak = 0
def start(self):
if self._device == "cuda":
import torch
torch.cuda.reset_peak_memory_stats()
self.memory_start = torch.cuda.memory_allocated()
self.end_time = None
self.start_time = time.time()
def stop(self):
self.end_time = time.time()
self.duration += self.end_time - self.start_time
if self._device == "cuda":
import torch
self.memory_end = torch.cuda.memory_allocated()
self.memory_peak = max(
torch.cuda.max_memory_allocated() - self.memory_start, self.memory_peak
)
if self.callback_fn is not None:
self.callback_fn(self)
def __enter__(self):
self.start()
return self
def __exit__(self, exc_type, exc_value, traceback):
self.logging_context.timings[self.description] = time.time() - self.start_time
self.stop()
class ImageLoggingContext:
def __init__(
self,
prompt,
model,
model=None,
debug_img_callback=None,
img_outdir=None,
progress_img_callback=None,
@@ -88,29 +131,60 @@ class ImageLoggingContext:
self.progress_img_interval_min_s = progress_img_interval_min_s
self.progress_latent_callback = progress_latent_callback
self.start_ts = time.perf_counter()
self.timings = {}
self.summary_context = TimingContext("total")
self.summary_context.start()
self.timing_contexts = {}
self.last_progress_img_ts = 0
self.last_progress_img_step = -1000
self._prev_log_context = None
def __enter__(self):
self.start()
def __exit__(self, exc_type, exc_val, exc_tb):
self.stop()
def start(self):
global _CURRENT_LOGGING_CONTEXT
self._prev_log_context = _CURRENT_LOGGING_CONTEXT
_CURRENT_LOGGING_CONTEXT = self
return self
def __exit__(self, exc_type, exc_val, exc_tb):
def stop(self):
global _CURRENT_LOGGING_CONTEXT
_CURRENT_LOGGING_CONTEXT = self._prev_log_context
def timing(self, description):
return TimingContext(self, description)
if description not in self.timing_contexts:
def cb(context):
self.timing_contexts[description] = context
tc = TimingContext(description, callback_fn=cb)
self.timing_contexts[description] = tc
return self.timing_contexts[description]
def get_performance_stats(self) -> dict[str, dict[str, float]]:
# calculate max peak seen in any timing context
self.summary_context.stop()
self.timing_contexts["total"] = self.summary_context
self.summary_context.memory_peak = max(
max(context.memory_peak, context.memory_start, context.memory_end)
for context in self.timing_contexts.values()
)
def get_timings(self):
self.timings["total"] = time.perf_counter() - self.start_ts
return self.timings
performance_stats = {}
for context in self.timing_contexts.values():
performance_stats[context.description] = {
"duration": context.duration,
"memory_start": context.memory_start,
"memory_end": context.memory_end,
"memory_peak": context.memory_peak,
"memory_delta": context.memory_end - context.memory_start,
}
return performance_stats
def log_conditioning(self, conditioning, description):
if not self.debug_img_callback:
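The reworked TimingContext above can also be used on its own; it tracks wall time plus the CUDA memory counters (on CPU the memory fields stay at 0). A usage sketch (the workload is invented):

from imaginairy.utils.log_utils import TimingContext

with TimingContext("decode-img", device="cuda") as tc:
    decode_latents()                       # hypothetical workload
print(f"{tc.description}: {tc.duration:.2f}s, "
      f"peak {tc.memory_peak / 10**6:.1f}MB")

# start()/stop() may also be called manually, and a context can be re-entered:
# duration and memory_peak accumulate across runs, which is what lets
# ImageLoggingContext.timing() hand out one shared context per description.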

@@ -123,12 +123,9 @@ def load_model_from_config_old(
if half_mode:
model = model.half()
print("halved")
model.to(get_device())
print("moved to device")
model.eval()
print("set to eval mode")
return model
@@ -222,12 +219,18 @@ def get_diffusion_model_refiners(
) -> LatentDiffusionModel:
"""Load a diffusion model."""
return _get_diffusion_model_refiners(
sd = _get_diffusion_model_refiners(
weights_location=weights_config.weights_location,
architecture_alias=weights_config.architecture.primary_alias,
for_inpainting=for_inpainting,
dtype=dtype,
)
# ensures a "fresh" copy that doesn't have additional injected parts
# sd = sd.structural_copy()
sd.set_self_attention_guidance(enable=True)
return sd
hf_repo_url_pattern = re.compile(
@@ -241,7 +244,9 @@ def parse_diffusers_repo_url(url: str) -> dict[str, str]:
def is_diffusers_repo_url(url: str) -> bool:
return bool(parse_diffusers_repo_url(url))
result = bool(parse_diffusers_repo_url(url))
logger.debug(f"{url} is diffusers repo url: {result}")
return result
def normalize_diffusers_repo_url(url: str) -> str:
@@ -332,17 +337,16 @@ def _get_sd15_diffusion_model_refiners(
device=device, dtype=dtype, lda=SD1AutoencoderSliced(), unet=unet
)
logger.debug("Loading VAE")
sd.lda.load_state_dict(vae_weights)
sd.lda.load_state_dict(vae_weights, assign=True)
logger.debug("Loading text encoder")
sd.clip_text_encoder.load_state_dict(text_encoder_weights)
sd.clip_text_encoder.load_state_dict(text_encoder_weights, assign=True)
logger.debug("Loading UNet")
sd.unet.load_state_dict(unet_weights, strict=False)
sd.unet.load_state_dict(unet_weights, strict=False, assign=True)
logger.debug(f"'{weights_location}' Loaded")
sd.set_self_attention_guidance(enable=True)
sd.to(device=device, dtype=dtype)
return sd
@@ -711,16 +715,15 @@ def load_sdxl_diffusers_weights(base_url: str, device=None, dtype=torch.float16)
device="cpu",
)
)
text_encoder = DoubleTextEncoder(device="cpu", dtype=dtype)
text_encoder = DoubleTextEncoder(device="cpu", dtype=torch.float32)
text_encoder.load_state_dict(text_encoder_weights)
del text_encoder_weights
lda = lda.to(device=device)
lda = lda.to(device=device, dtype=torch.float32)
unet = unet.to(device=device)
text_encoder = text_encoder.to(device=device)
sd = StableDiffusion_XL(
device=device, dtype=dtype, lda=lda, unet=unet, clip_text_encoder=text_encoder
device=device, dtype=None, lda=lda, unet=unet, clip_text_encoder=text_encoder
)
sd.lda.to(device=device, dtype=torch.float32)
return sd
@@ -729,9 +732,6 @@ def load_sdxl_pipeline(base_url, device=None):
logger.info(f"Loading SDXL weights from {base_url}")
device = device or get_device()
sd = load_sdxl_diffusers_weights(base_url, device=device)
sd.set_self_attention_guidance(enable=True)
return sd
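Two memory-related threads run through this file: the SDXL loader now keeps the VAE and text encoder in float32 while the UNet stays float16 (which the dtype=None constructor change supports), and the sd15 load_state_dict calls gain assign=True. The latter is a PyTorch >= 2.1 option that makes the module adopt the checkpoint tensors directly instead of copying them into pre-allocated parameters, avoiding a second full copy during loading. A sketch of the idea (checkpoint path invented; pairing with a meta-device build is the common pattern, not necessarily what this code does):

import torch
from torch import nn

with torch.device("meta"):                 # build the module without real weight storage
    model = nn.Linear(4096, 4096)

state = torch.load("weights.pt")           # hypothetical checkpoint file
model.load_state_dict(state, assign=True)  # parameters replaced, not copied into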

@@ -1,7 +1,7 @@
import re
def generate_phrase_list(subject, num_phrases=100):
def generate_phrase_list(subject, num_phrases=100, max_words=6):
"""Generate a list of phrases for a given subject."""
from openai import OpenAI
@@ -9,7 +9,7 @@ def generate_phrase_list(subject, num_phrases=100):
prompt = (
f'Make list of archetypal imagery about "{subject}". These will provide composition ideas to an artist. '
f"No more than 6 words per idea. Make {num_phrases} ideas. Provide the output as plaintext with each idea on a new line. "
f"No more than {max_words} words per idea. Make {num_phrases} ideas. Provide the output as plaintext with each idea on a new line. "
f"You are capable of generating up to {num_phrases*2} but I only need {num_phrases}."
)
messages = [
@@ -19,8 +19,8 @@ def generate_phrase_list(subject, num_phrases=100):
response = client.chat.completions.create(
model="gpt-4-1106-preview",
messages=messages,
temperature=0.1,
max_tokens=2623,
temperature=1,
max_tokens=4000,
top_p=1,
frequency_penalty=0.17,
presence_penalty=0,
@@ -39,4 +39,9 @@ def generate_phrase_list(subject, num_phrases=100):
if __name__ == "__main__":
print(generate_phrase_list("traditional christmas", num_phrases=200))
phrase_list = generate_phrase_list(
subject="symbolism for the human condition and the struggle between good and evil",
num_phrases=200,
max_words=15,
)
print(phrase_list)

@@ -42,3 +42,8 @@ def test_prompt_expander_from_wordlist():
def test_get_phraselist_names():
print(", ".join(category_list()))
def test_complex_prompt():
prompt = "{_painting-style_} of {_art-scene_}. painting"
assert len(list(expand_prompts(prompt, n=100))) == 100

@@ -12,7 +12,7 @@ plugins = pydantic.mypy
exclude = ^(\./|)(downloads|dist|build|other|testing_support|imaginairy/vendored|imaginairy/modules/sgm)
ignore_missing_imports = True
warn_unused_configs = True
warn_unused_ignores = True
warn_unused_ignores = False
[mypy-imaginairy.vendored.*]
follow_imports = skip
