@ -6,7 +6,7 @@ import numpy as np
import torch
import torch
import torch . nn
import torch . nn
from einops import rearrange , repeat
from einops import rearrange , repeat
from PIL import Image , ImageDraw , Image Filter, Image Ops
from PIL import Image , ImageDraw , Image Ops
from pytorch_lightning import seed_everything
from pytorch_lightning import seed_everything
from imaginairy . enhancers . clip_masking import get_img_mask
from imaginairy . enhancers . clip_masking import get_img_mask
@ -51,6 +51,9 @@ if IMAGINAIRY_SAFETY_MODE in {"disabled", "classify"}:
elif IMAGINAIRY_SAFETY_MODE == " filter " :
elif IMAGINAIRY_SAFETY_MODE == " filter " :
IMAGINAIRY_SAFETY_MODE = SafetyMode . STRICT
IMAGINAIRY_SAFETY_MODE = SafetyMode . STRICT
# Kept at module scope so the most recent result can be reached from the interactive shell.
_most_recent_result = None
def imagine_image_files (
def imagine_image_files (
prompts ,
prompts ,
@ -88,6 +91,9 @@ def imagine_image_files(
add_caption = print_caption ,
add_caption = print_caption ,
) :
) :
prompt = result . prompt
prompt = result . prompt
if prompt . is_intermediate :
# we don't save intermediate images
continue
img_str = " "
img_str = " "
if prompt . init_image :
if prompt . init_image :
img_str = f " _img2img- { prompt . init_image_strength } "
img_str = f " _img2img- { prompt . init_image_strength } "
@ -103,7 +109,7 @@ def imagine_image_files(
subpath , f " { basefilename } _[ { image_type } ]. { output_file_extension } "
subpath , f " { basefilename } _[ { image_type } ]. { output_file_extension } "
)
)
result . save ( filepath , image_type = image_type )
result . save ( filepath , image_type = image_type )
logger . info ( f " 🖼 [{ image_type } ] saved to: { filepath } " )
logger . info ( f " [{ image_type } ] saved to: { filepath } " )
if image_type == return_filename_type :
if image_type == return_filename_type :
result_filenames . append ( filepath )
result_filenames . append ( filepath )
if make_comparison_gif and prompt . init_image :
if make_comparison_gif and prompt . init_image :
@ -134,6 +140,7 @@ def imagine(
half_mode = None ,
half_mode = None ,
add_caption = False ,
add_caption = False ,
) :
) :
global _most_recent_result # noqa
latent_channels = 4
latent_channels = 4
downsampling_factor = 8
downsampling_factor = 8
batch_size = 1
batch_size = 1
@ -153,6 +160,18 @@ def imagine(
precision
precision
) , fix_torch_nn_layer_norm ( ) , fix_torch_group_norm ( ) :
) , fix_torch_nn_layer_norm ( ) , fix_torch_group_norm ( ) :
for i , prompt in enumerate ( prompts ) :
for i , prompt in enumerate ( prompts ) :
# Resolve "*prev.<image_type>" references in init_image / mask_image to the matching image of the most recent result.
if isinstance ( prompt . init_image , str ) and prompt . init_image . startswith (
" *prev "
) :
_ , img_type = prompt . init_image . strip ( " * " ) . split ( " . " )
prompt . init_image = _most_recent_result . images [ img_type ]
if isinstance ( prompt . mask_image , str ) and prompt . mask_image . startswith (
" *prev "
) :
_ , img_type = prompt . mask_image . strip ( " * " ) . split ( " . " )
prompt . mask_image = _most_recent_result . images [ img_type ]
logger . info (
logger . info (
f " Generating 🖼 { i + 1 } / { num_prompts } : { prompt . prompt_description ( ) } "
f " Generating 🖼 { i + 1 } / { num_prompts } : { prompt . prompt_description ( ) } "
)
)
@ -419,9 +438,10 @@ def imagine(
x_sample_8_orig = x_sample . astype ( np . uint8 )
x_sample_8_orig = x_sample . astype ( np . uint8 )
img = Image . fromarray ( x_sample_8_orig )
img = Image . fromarray ( x_sample_8_orig )
if mask_image_orig and init_image :
if mask_image_orig and init_image :
mask_final = mask_image_orig . filter (
# mask_final = mask_image_orig.filter(
ImageFilter . GaussianBlur ( radius = 3 )
# ImageFilter.GaussianBlur(radius=3)
)
# )
mask_final = mask_image_orig . copy ( )
log_img ( mask_final , " reconstituting mask " )
log_img ( mask_final , " reconstituting mask " )
mask_final = ImageOps . invert ( mask_final )
mask_final = ImageOps . invert ( mask_final )
img = Image . composite ( img , init_image , mask_final )
img = Image . composite ( img , init_image , mask_final )
@ -471,9 +491,9 @@ def imagine(
starting_image . size ,
starting_image . size ,
resample = Image . Resampling . LANCZOS ,
resample = Image . Resampling . LANCZOS ,
)
)
mask_for_orig_size = mask_for_orig_size . filter (
# mask_for_orig_size = mask_for_orig_size.filter (
ImageFilter . GaussianBlur ( radius = 5 )
# ImageFilter.GaussianBlur(radius=5)
)
# )
log_img ( mask_for_orig_size , " mask for original image size " )
log_img ( mask_for_orig_size , " mask for original image size " )
rebuilt_orig_img = Image . composite (
rebuilt_orig_img = Image . composite (
@ -495,6 +515,7 @@ def imagine(
depth_image = depth_image_display ,
depth_image = depth_image_display ,
timings = lc . get_timings ( ) ,
timings = lc . get_timings ( ) ,
)
)
_most_recent_result = result
logger . info ( f " Image Generated. Timings: { result . timings_str ( ) } " )
logger . info ( f " Image Generated. Timings: { result . timings_str ( ) } " )
yield result
yield result