@ -1,10 +1,13 @@
""" Functions for generating refined images """
import logging
from contextlib import nullcontext
from typing import Any
from imaginairy . config import CONTROL_CONFIG_SHORTCUTS
from imaginairy . schema import ControlInput , ImaginePrompt , MaskMode
from imaginairy . utils import clear_gpu_cache
from imaginairy . utils . log_utils import ImageLoggingContext
logger = logging . getLogger ( __name__ )
@ -18,7 +21,9 @@ def generate_single_image(
add_caption = False ,
return_latent = False ,
dtype = None ,
half_mode = None ,
logging_context : ImageLoggingContext | None = None ,
output_perf = False ,
image_name = " " ,
) :
import torch . nn
from PIL import Image , ImageOps
@ -59,19 +64,40 @@ def generate_single_image(
from imaginairy . utils . safety import create_safety_score
if dtype is None :
dtype = torch . float16 if half_mode else torch . float32
dtype = torch . float16
get_device ( )
clear_gpu_cache ( )
prompt = prompt . make_concrete_copy ( )
if not logging_context :
def latent_logger ( latents ) :
progress_latents . append ( latents )
lc = ImageLoggingContext (
prompt = prompt ,
debug_img_callback = debug_img_callback ,
progress_img_callback = progress_img_callback ,
progress_img_interval_steps = progress_img_interval_steps ,
progress_img_interval_min_s = progress_img_interval_min_s ,
progress_latent_callback = latent_logger
if prompt . collect_progress_latents
else None ,
)
_context : Any = lc
else :
lc = logging_context
_context = nullcontext ( )
with _context :
with lc . timing ( " model-load " ) :
sd = get_diffusion_model_refiners (
weights_config = prompt . model_weights ,
for_inpainting = prompt . should_use_inpainting
and prompt . inpaint_method == " finetune " ,
dtype = dtype ,
)
lc . model = sd
seed_everything ( prompt . seed )
downsampling_factor = 8
latent_channels = 4
@ -80,20 +106,6 @@ def generate_single_image(
mask_image = None
mask_image_orig = None
def latent_logger ( latents ) :
progress_latents . append ( latents )
with ImageLoggingContext (
prompt = prompt ,
model = sd ,
debug_img_callback = debug_img_callback ,
progress_img_callback = progress_img_callback ,
progress_img_interval_steps = progress_img_interval_steps ,
progress_img_interval_min_s = progress_img_interval_min_s ,
progress_latent_callback = latent_logger
if prompt . collect_progress_latents
else None ,
) as lc :
sd . set_tile_mode ( prompt . tile_mode )
result_images : dict [ str , torch . Tensor | None | Image . Image ] = { }
@ -178,14 +190,19 @@ def generate_single_image(
assert prompt . seed is not None
noise = randn_seeded ( seed = prompt . seed , size = shape ) . to (
get_device( ) , dtype = sd . dtype
sd. unet . device , dtype = sd . unet . dtype
)
noised_latent = noise
controlnets = [ ]
if control_modes :
with lc . timing ( " control-image-prep " ) :
for control_input in control_inputs :
controlnet , control_image_t , control_image_disp = prep_control_input (
(
controlnet ,
control_image_t ,
control_image_disp ,
) = prep_control_input (
control_input = control_input ,
sd = sd ,
init_image_t = init_image_t ,
@ -196,6 +213,7 @@ def generate_single_image(
controlnets . append ( ( controlnet , control_image_t ) )
if prompt . allow_compose_phase :
with lc . timing ( " composition " ) :
cutoff_size = get_model_default_image_size ( prompt . model_architecture )
cutoff_size = ( int ( cutoff_size [ 0 ] * 1.30 ) , int ( cutoff_size [ 1 ] * 1.30 ) )
compose_kwargs = {
@ -213,7 +231,9 @@ def generate_single_image(
" target_width " : init_image . width ,
}
)
comp_image , comp_img_orig = _generate_composition_image ( * * compose_kwargs )
comp_image , comp_img_orig = _generate_composition_image (
* * compose_kwargs , logging_context = lc
)
if comp_image is not None :
prompt . fix_faces = False # done in composition
@ -246,12 +266,12 @@ def generate_single_image(
fit_width = prompt . width ,
fit_height = prompt . height ,
)
result_images [ f " control- { control_input . mode } " ] = control_image_disp
result_images [
f " control- { control_input . mode } "
] = control_image_disp
controlnets . append ( ( controlnet , control_image_t ) )
for controlnet , control_image_t in controlnets :
msg = f " Injecting controlnet { controlnet . name } . setting to device: { sd . unet . device } , dtype: { sd . unet . dtype } "
print ( msg )
controlnet . set_controlnet_condition (
control_image_t . to ( device = sd . unet . device , dtype = sd . unet . dtype )
)
@ -263,7 +283,7 @@ def generate_single_image(
else :
msg = f " Unknown solver type: { prompt . solver_type } "
raise ValueError ( msg )
sd . scheduler . to ( device = sd . device, dtype = sd . dtype )
sd . scheduler . to ( device = sd . unet. device, dtype = sd . unet . dtype )
sd . set_num_inference_steps ( prompt . steps )
if hasattr ( sd , " mask_latents " ) and mask_image is not None :
@ -288,6 +308,7 @@ def generate_single_image(
x = init_latent , noise = noise , step = sd . steps [ noise_step ]
)
with lc . timing ( " text-conditioning " ) :
text_conditioning_kwargs = sd . calculate_text_conditioning_kwargs (
positive_prompts = prompt . prompts ,
negative_prompts = prompt . negative_prompt ,
@ -302,7 +323,10 @@ def generate_single_image(
x = x . to ( device = sd . unet . device , dtype = sd . unet . dtype )
clear_gpu_cache ( )
for step in tqdm ( sd . steps [ first_step : ] , bar_format = " {l_bar} {bar} {r_bar} " ) :
with lc . timing ( " unet " ) :
for step in tqdm (
sd . steps [ first_step : ] , bar_format = " {l_bar} {bar} {r_bar} " , leave = False
) :
log_latent ( x , " noisy_latent " )
x = sd (
x ,
@ -319,10 +343,11 @@ def generate_single_image(
if x . device != sd . lda . device :
sd . lda . to ( x . device )
clear_gpu_cache ( )
with lc . timing ( " decode-img " ) :
gen_img = sd . lda . decode_latents ( x . to ( dtype = sd . lda . dtype ) )
if mask_image_orig and init_image :
with lc . timing ( " combine-image " ) :
result_images [ " pre-reconstitution " ] = gen_img
mask_final = mask_image_orig . copy ( )
# mask_final = ImageOps.invert(mask_final)
@ -340,6 +365,7 @@ def generate_single_image(
rebuilt_orig_img = None
if add_caption :
with lc . timing ( " caption-img " ) :
caption = generate_caption ( gen_img )
logger . info ( f " Generated caption: { caption } " )
@ -352,10 +378,14 @@ def generate_single_image(
progress_latents . clear ( )
if not safety_score . is_filtered :
if prompt . fix_faces :
with lc . timing ( " face-enhancement " ) :
logger . info ( " Fixing 😊 ' s in 🖼 using CodeFormer... " )
with lc . timing ( " face enhancement " ) :
gen_img = enhance_faces ( gen_img , fidelity = prompt . fix_faces_fidelity )
with lc . timing ( " face-enhancement " ) :
gen_img = enhance_faces (
gen_img , fidelity = prompt . fix_faces_fidelity
)
if prompt . upscale :
with lc . timing ( " upscaling " ) :
logger . info ( " Upscaling 🖼 using real-ESRGAN... " )
with lc . timing ( " upscaling " ) :
upscaled_img = upscale_image ( gen_img )
@ -390,13 +420,19 @@ def generate_single_image(
is_nsfw = safety_score . is_nsfw ,
safety_score = safety_score ,
result_images = result_images ,
timings= lc . get_timing s( ) ,
performance_stats= lc . get_performance_stat s( ) ,
progress_latents = [ ] , # todo
)
_most_recent_result = result
if result . timings :
logger . info ( f " Image Generated. Timings: { result . timings_str ( ) } " )
_image_name = f " { image_name } " if image_name else " "
logger . info ( f " Generated { _image_name } image in { result . total_time ( ) : .1f } s " )
if result . performance_stats :
log = logger . info if output_perf else logger . debug
log ( f " Timings: { result . timings_str ( ) } " )
log ( f " Peak VRAM: { result . gpu_str ( ' memory_peak ' ) } " )
log ( f " Ending VRAM: { result . gpu_str ( ' memory_end ' ) } " )
for controlnet , _ in controlnets :
controlnet . eject ( )
clear_gpu_cache ( )
@ -495,6 +531,7 @@ def _generate_composition_image(
target_width ,
cutoff : tuple [ int , int ] = ( 512 , 512 ) ,
dtype = None ,
logging_context = None ,
) :
from PIL import Image
@ -530,7 +567,13 @@ def _generate_composition_image(
} ,
)
result = generate_single_image ( composition_prompt , dtype = dtype )
result = generate_single_image (
composition_prompt ,
dtype = dtype ,
logging_context = logging_context ,
output_perf = False ,
image_name = " composition " ,
)
img = result . images [ " generated " ]
while img . width < target_width :
from imaginairy . enhancers . upscale_realesrgan import upscale_image
@ -538,8 +581,10 @@ def _generate_composition_image(
if prompt . fix_faces :
from imaginairy . enhancers . face_restoration_codeformer import enhance_faces
with logging_context . timing ( " face-enhancement " ) :
logger . info ( " Fixing 😊 ' s in 🖼 using CodeFormer... " )
img = enhance_faces ( img , fidelity = prompt . fix_faces_fidelity )
with logging_context . timing ( " upscaling " ) :
img = upscale_image ( img , ultrasharp = True )
img = img . resize (