@@ -1,10 +1,13 @@
""" Functions for generating refined images """
import logging
from contextlib import nullcontext
from typing import Any
from imaginairy . config import CONTROL_CONFIG_SHORTCUTS
from imaginairy . schema import ControlInput , ImaginePrompt , MaskMode
from imaginairy . utils import clear_gpu_cache
from imaginairy . utils . log_utils import ImageLoggingContext
logger = logging . getLogger ( __name__ )
@@ -18,7 +21,9 @@ def generate_single_image(
add_caption = False ,
return_latent = False ,
dtype = None ,
half_mode = None ,
logging_context : ImageLoggingContext | None = None ,
output_perf = False ,
image_name = " " ,
) :
import torch . nn
from PIL import Image , ImageOps
@@ -59,41 +64,48 @@ def generate_single_image(
from imaginairy . utils . safety import create_safety_score
if dtype is None :
dtype = torch . float16 if half_mode else torch . float32
dtype = torch . float16
get_device ( )
clear_gpu_cache ( )
prompt = prompt . make_concrete_copy ( )
sd = get_diffusion_model_refiners (
weights_config = prompt . model_weights ,
for_inpainting = prompt . should_use_inpainting
and prompt . inpaint_method == " finetune " ,
dtype = dtype ,
)
if not logging_context :
def latent_logger ( latents ) :
progress_latents . append ( latents )
lc = ImageLoggingContext (
prompt = prompt ,
debug_img_callback = debug_img_callback ,
progress_img_callback = progress_img_callback ,
progress_img_interval_steps = progress_img_interval_steps ,
progress_img_interval_min_s = progress_img_interval_min_s ,
progress_latent_callback = latent_logger
if prompt . collect_progress_latents
else None ,
)
_context : Any = lc
else :
lc = logging_context
_context = nullcontext ( )
with _context :
with lc . timing ( " model-load " ) :
sd = get_diffusion_model_refiners (
weights_config = prompt . model_weights ,
for_inpainting = prompt . should_use_inpainting
and prompt . inpaint_method == " finetune " ,
dtype = dtype ,
)
lc . model = sd
seed_everything ( prompt . seed )
downsampling_factor = 8
latent_channels = 4
batch_size = 1
mask_image = None
mask_image_orig = None
seed_everything ( prompt . seed )
downsampling_factor = 8
latent_channels = 4
batch_size = 1
mask_image = None
mask_image_orig = None
def latent_logger ( latents ) :
progress_latents . append ( latents )
with ImageLoggingContext (
prompt = prompt ,
model = sd ,
debug_img_callback = debug_img_callback ,
progress_img_callback = progress_img_callback ,
progress_img_interval_steps = progress_img_interval_steps ,
progress_img_interval_min_s = progress_img_interval_min_s ,
progress_latent_callback = latent_logger
if prompt . collect_progress_latents
else None ,
) as lc :
sd . set_tile_mode ( prompt . tile_mode )
result_images : dict [ str , torch . Tensor | None | Image . Image ] = { }
@@ -178,63 +190,14 @@ def generate_single_image(
assert prompt . seed is not None
noise = randn_seeded ( seed = prompt . seed , size = shape ) . to (
get_device( ) , dtype = sd . dtype
sd. unet . device , dtype = sd . unet . dtype
)
noised_latent = noise
controlnets = [ ]
if control_modes :
for control_input in control_inputs :
controlnet , control_image_t , control_image_disp = prep_control_input (
control_input = control_input ,
sd = sd ,
init_image_t = init_image_t ,
fit_width = prompt . width ,
fit_height = prompt . height ,
)
result_images [ f " control- { control_input . mode } " ] = control_image_disp
controlnets . append ( ( controlnet , control_image_t ) )
if prompt . allow_compose_phase :
cutoff_size = get_model_default_image_size ( prompt . model_architecture )
cutoff_size = ( int ( cutoff_size [ 0 ] * 1.30 ) , int ( cutoff_size [ 1 ] * 1.30 ) )
compose_kwargs = {
" prompt " : prompt ,
" target_height " : prompt . height ,
" target_width " : prompt . width ,
" cutoff " : cutoff_size ,
" dtype " : dtype ,
}
if prompt . init_image :
compose_kwargs . update (
{
" target_height " : init_image . height ,
" target_width " : init_image . width ,
}
)
comp_image , comp_img_orig = _generate_composition_image ( * * compose_kwargs )
if comp_image is not None :
prompt . fix_faces = False # done in composition
result_images [ " composition " ] = comp_img_orig
result_images [ " composition-upscaled " ] = comp_image
composition_strength = prompt . composition_strength
first_step = int ( ( prompt . steps ) * composition_strength )
noise_step = int ( ( prompt . steps - 1 ) * composition_strength )
log_img ( comp_img_orig , " comp_image " )
log_img ( comp_image , " comp_image_upscaled " )
comp_image_t = pillow_img_to_torch_image ( comp_image )
comp_image_t = comp_image_t . to ( sd . lda . device , dtype = sd . lda . dtype )
init_latent = sd . lda . encode ( comp_image_t )
compose_control_inputs : list [ ControlInput ]
if prompt . model_weights . architecture . primary_alias == " sdxl " :
compose_control_inputs = [ ]
else :
compose_control_inputs = [
ControlInput ( mode = " details " , image = comp_image , strength = 1 ) ,
]
for control_input in compose_control_inputs :
with lc . timing ( " control-image-prep " ) :
for control_input in control_inputs :
(
controlnet ,
control_image_t ,
@@ -242,16 +205,73 @@ def generate_single_image(
) = prep_control_input (
control_input = control_input ,
sd = sd ,
init_image_t = None ,
init_image_t = init_image_t ,
fit_width = prompt . width ,
fit_height = prompt . height ,
)
result_images [ f " control- { control_input . mode } " ] = control_image_disp
controlnets . append ( ( controlnet , control_image_t ) )
if prompt . allow_compose_phase :
with lc . timing ( " composition " ) :
cutoff_size = get_model_default_image_size ( prompt . model_architecture )
cutoff_size = ( int ( cutoff_size [ 0 ] * 1.30 ) , int ( cutoff_size [ 1 ] * 1.30 ) )
compose_kwargs = {
" prompt " : prompt ,
" target_height " : prompt . height ,
" target_width " : prompt . width ,
" cutoff " : cutoff_size ,
" dtype " : dtype ,
}
if prompt . init_image :
compose_kwargs . update (
{
" target_height " : init_image . height ,
" target_width " : init_image . width ,
}
)
comp_image , comp_img_orig = _generate_composition_image (
* * compose_kwargs , logging_context = lc
)
if comp_image is not None :
prompt . fix_faces = False # done in composition
result_images [ " composition " ] = comp_img_orig
result_images [ " composition-upscaled " ] = comp_image
composition_strength = prompt . composition_strength
first_step = int ( ( prompt . steps ) * composition_strength )
noise_step = int ( ( prompt . steps - 1 ) * composition_strength )
log_img ( comp_img_orig , " comp_image " )
log_img ( comp_image , " comp_image_upscaled " )
comp_image_t = pillow_img_to_torch_image ( comp_image )
comp_image_t = comp_image_t . to ( sd . lda . device , dtype = sd . lda . dtype )
init_latent = sd . lda . encode ( comp_image_t )
compose_control_inputs : list [ ControlInput ]
if prompt . model_weights . architecture . primary_alias == " sdxl " :
compose_control_inputs = [ ]
else :
compose_control_inputs = [
ControlInput ( mode = " details " , image = comp_image , strength = 1 ) ,
]
for control_input in compose_control_inputs :
(
controlnet ,
control_image_t ,
control_image_disp ,
) = prep_control_input (
control_input = control_input ,
sd = sd ,
init_image_t = None ,
fit_width = prompt . width ,
fit_height = prompt . height ,
)
result_images [
f " control- { control_input . mode } "
] = control_image_disp
controlnets . append ( ( controlnet , control_image_t ) )
for controlnet , control_image_t in controlnets :
msg = f " Injecting controlnet { controlnet . name } . setting to device: { sd . unet . device } , dtype: { sd . unet . dtype } "
print ( msg )
controlnet . set_controlnet_condition (
control_image_t . to ( device = sd . unet . device , dtype = sd . unet . dtype )
)
@@ -263,7 +283,7 @@ def generate_single_image(
else :
msg = f " Unknown solver type: { prompt . solver_type } "
raise ValueError ( msg )
sd . scheduler . to ( device = sd . device, dtype = sd . dtype )
sd . scheduler . to ( device = sd . unet. device, dtype = sd . unet . dtype )
sd . set_num_inference_steps ( prompt . steps )
if hasattr ( sd , " mask_latents " ) and mask_image is not None :
@@ -288,60 +308,66 @@ def generate_single_image(
x = init_latent , noise = noise , step = sd . steps [ noise_step ]
)
text_conditioning_kwargs = sd . calculate_text_conditioning_kwargs (
positive_prompts = prompt . prompts ,
negative_prompts = prompt . negative_prompt ,
positive_conditioning_override = prompt . conditioning ,
)
for k , v in text_conditioning_kwargs . items ( ) :
text_conditioning_kwargs [ k ] = v . to (
device = sd . unet . device , dtype = sd . unet . dtype
with lc . timing ( " text-conditioning " ) :
text_conditioning_kwargs = sd . calculate_text_conditioning_kwargs (
positive_prompts = prompt . prompts ,
negative_prompts = prompt . negative_prompt ,
positive_conditioning_override = prompt . conditioning ,
)
for k , v in text_conditioning_kwargs . items ( ) :
text_conditioning_kwargs [ k ] = v . to (
device = sd . unet . device , dtype = sd . unet . dtype
)
x = noised_latent
x = x . to ( device = sd . unet . device , dtype = sd . unet . dtype )
clear_gpu_cache ( )
for step in tqdm ( sd . steps [ first_step : ] , bar_format = " {l_bar} {bar} {r_bar} " ) :
log_latent ( x , " noisy_latent " )
x = sd (
x ,
step = step ,
condition_scale = prompt . prompt_strength ,
* * text_conditioning_kwargs ,
)
# trying to clear memory. not sure if this helps
sd . unet . set_context ( context = " self_attention_map " , value = { } )
sd . unet . _reset_context ( )
clear_gpu_cache ( )
with lc . timing ( " unet " ) :
for step in tqdm (
sd . steps [ first_step : ] , bar_format = " {l_bar} {bar} {r_bar} " , leave = False
) :
log_latent ( x , " noisy_latent " )
x = sd (
x ,
step = step ,
condition_scale = prompt . prompt_strength ,
* * text_conditioning_kwargs ,
)
# trying to clear memory. not sure if this helps
sd . unet . set_context ( context = " self_attention_map " , value = { } )
sd . unet . _reset_context ( )
clear_gpu_cache ( )
logger . debug ( " Decoding image " )
if x . device != sd . lda . device :
sd . lda . to ( x . device )
clear_gpu_cache ( )
gen_img = sd . lda . decode_latents ( x . to ( dtype = sd . lda . dtype ) )
with lc . timing ( " decode-img " ) :
gen_img = sd . lda . decode_latents ( x . to ( dtype = sd . lda . dtype ) )
if mask_image_orig and init_image :
result_images [ " pre-reconstitution " ] = gen_img
mask_final = mask_image_orig . copy ( )
# mask_final = ImageOps.invert(mask_final)
log_img ( mask_final , " reconstituting mask " )
# gen_img = Image.composite(gen_img, init_image, mask_final)
gen_img = combine_image (
original_img = init_image ,
generated_img = gen_img ,
mask_img = mask_final ,
)
log_img ( gen_img , " reconstituted image " )
with lc . timing ( " combine-image " ) :
result_images [ " pre-reconstitution " ] = gen_img
mask_final = mask_image_orig . copy ( )
# mask_final = ImageOps.invert(mask_final)
log_img ( mask_final , " reconstituting mask " )
# gen_img = Image.composite(gen_img, init_image, mask_final)
gen_img = combine_image (
original_img = init_image ,
generated_img = gen_img ,
mask_img = mask_final ,
)
log_img ( gen_img , " reconstituted image " )
upscaled_img = None
rebuilt_orig_img = None
if add_caption :
caption = generate_caption ( gen_img )
logger . info ( f " Generated caption: { caption } " )
with lc . timing ( " caption-img " ) :
caption = generate_caption ( gen_img )
logger . info ( f " Generated caption: { caption } " )
with lc . timing ( " safety-filter " ) :
safety_score = create_safety_score (
@@ -352,13 +378,17 @@ def generate_single_image(
progress_latents . clear ( )
if not safety_score . is_filtered :
if prompt . fix_faces :
logger . info ( " Fixing 😊 ' s in 🖼 using CodeFormer... " )
with lc . timing ( " face enhancement " ) :
gen_img = enhance_faces ( gen_img , fidelity = prompt . fix_faces_fidelity )
with lc . timing ( " face-enhancement " ) :
logger . info ( " Fixing 😊 ' s in 🖼 using CodeFormer... " )
with lc . timing ( " face-enhancement " ) :
gen_img = enhance_faces (
gen_img , fidelity = prompt . fix_faces_fidelity
)
if prompt . upscale :
logger . info ( " Upscaling 🖼 using real-ESRGAN... " )
with lc . timing ( " upscaling " ) :
upscaled_img = upscale_image ( gen_img )
logger . info ( " Upscaling 🖼 using real-ESRGAN... " )
with lc . timing ( " upscaling " ) :
upscaled_img = upscale_image ( gen_img )
# put the newly generated patch back into the original, full-size image
if prompt . mask_modify_original and mask_image_orig and starting_image :
@@ -390,13 +420,19 @@ def generate_single_image(
is_nsfw = safety_score . is_nsfw ,
safety_score = safety_score ,
result_images = result_images ,
timings= lc . get_timing s( ) ,
performance_stats= lc . get_performance_stat s( ) ,
progress_latents = [ ] , # todo
)
_most_recent_result = result
if result . timings :
logger . info ( f " Image Generated. Timings: { result . timings_str ( ) } " )
_image_name = f " { image_name } " if image_name else " "
logger . info ( f " Generated { _image_name } image in { result . total_time ( ) : .1f } s " )
if result . performance_stats :
log = logger . info if output_perf else logger . debug
log ( f " Timings: { result . timings_str ( ) } " )
log ( f " Peak VRAM: { result . gpu_str ( ' memory_peak ' ) } " )
log ( f " Ending VRAM: { result . gpu_str ( ' memory_end ' ) } " )
for controlnet , _ in controlnets :
controlnet . eject ( )
clear_gpu_cache ( )
@@ -495,6 +531,7 @@ def _generate_composition_image(
target_width ,
cutoff : tuple [ int , int ] = ( 512 , 512 ) ,
dtype = None ,
logging_context = None ,
) :
from PIL import Image
@@ -530,7 +567,13 @@ def _generate_composition_image(
} ,
)
result = generate_single_image ( composition_prompt , dtype = dtype )
result = generate_single_image (
composition_prompt ,
dtype = dtype ,
logging_context = logging_context ,
output_perf = False ,
image_name = " composition " ,
)
img = result . images [ " generated " ]
while img . width < target_width :
from imaginairy . enhancers . upscale_realesrgan import upscale_image
@@ -538,9 +581,11 @@ def _generate_composition_image(
if prompt . fix_faces :
from imaginairy . enhancers . face_restoration_codeformer import enhance_faces
img = enhance_faces ( img , fidelity = prompt . fix_faces_fidelity )
img = upscale_image ( img , ultrasharp = True )
with logging_context . timing ( " face-enhancement " ) :
logger . info ( " Fixing 😊 ' s in 🖼 using CodeFormer... " )
img = enhance_faces ( img , fidelity = prompt . fix_faces_fidelity )
with logging_context . timing ( " upscaling " ) :
img = upscale_image ( img , ultrasharp = True )
img = img . resize (
( target_width , target_height ) ,