|
|
@ -24,7 +24,7 @@ from imaginairy.img_log import (
|
|
|
|
log_latent,
|
|
|
|
log_latent,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
from imaginairy.img_utils import pillow_fit_image_within, pillow_img_to_torch_image
|
|
|
|
from imaginairy.img_utils import pillow_fit_image_within, pillow_img_to_torch_image
|
|
|
|
from imaginairy.safety import is_nsfw
|
|
|
|
from imaginairy.safety import SafetyMode, create_safety_score
|
|
|
|
from imaginairy.samplers.base import get_sampler
|
|
|
|
from imaginairy.samplers.base import get_sampler
|
|
|
|
from imaginairy.samplers.plms import PLMSSchedule
|
|
|
|
from imaginairy.samplers.plms import PLMSSchedule
|
|
|
|
from imaginairy.schema import ImaginePrompt, ImagineResult
|
|
|
|
from imaginairy.schema import ImaginePrompt, ImagineResult
|
|
|
@ -40,16 +40,14 @@ LIB_PATH = os.path.dirname(__file__)
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SafetyMode:
|
|
|
|
|
|
|
|
DISABLED = "disabled"
|
|
|
|
|
|
|
|
CLASSIFY = "classify"
|
|
|
|
|
|
|
|
FILTER = "filter"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# leave undocumented. I'd ask that no one publicize this flag. Just want a
|
|
|
|
# leave undocumented. I'd ask that no one publicize this flag. Just want a
|
|
|
|
# slight barrier to entry. Please don't use this in any way that's gonna cause
|
|
|
|
# slight barrier to entry. Please don't use this in any way that's gonna cause
|
|
|
|
# the press or governments to freak out about AI...
|
|
|
|
# the media or politicians to freak out about AI...
|
|
|
|
IMAGINAIRY_SAFETY_MODE = os.getenv("IMAGINAIRY_SAFETY_MODE", SafetyMode.FILTER)
|
|
|
|
IMAGINAIRY_SAFETY_MODE = os.getenv("IMAGINAIRY_SAFETY_MODE", SafetyMode.STRICT)
|
|
|
|
|
|
|
|
if IMAGINAIRY_SAFETY_MODE in {"disabled", "classify"}:
|
|
|
|
|
|
|
|
IMAGINAIRY_SAFETY_MODE = SafetyMode.RELAXED
|
|
|
|
|
|
|
|
elif IMAGINAIRY_SAFETY_MODE == "filter":
|
|
|
|
|
|
|
|
IMAGINAIRY_SAFETY_MODE = SafetyMode.STRICT
|
|
|
|
|
|
|
|
|
|
|
|
DEFAULT_MODEL_WEIGHTS_LOCATION = (
|
|
|
|
DEFAULT_MODEL_WEIGHTS_LOCATION = (
|
|
|
|
"https://www.googleapis.com/storage/v1/b/aai-blog-files/o/sd-v1-4.ckpt?alt=media"
|
|
|
|
"https://www.googleapis.com/storage/v1/b/aai-blog-files/o/sd-v1-4.ckpt?alt=media"
|
|
|
@ -68,8 +66,7 @@ def load_model_from_config(
|
|
|
|
ckpt_path = cached_path(model_weights_location)
|
|
|
|
ckpt_path = cached_path(model_weights_location)
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
ckpt_path = model_weights_location
|
|
|
|
ckpt_path = model_weights_location
|
|
|
|
logger.info(f"Loading model onto {get_device()} backend...")
|
|
|
|
logger.info(f"Loading model {ckpt_path} onto {get_device()} backend...")
|
|
|
|
logger.debug(f"Loading model from {ckpt_path}")
|
|
|
|
|
|
|
|
pl_sd = torch.load(ckpt_path, map_location="cpu")
|
|
|
|
pl_sd = torch.load(ckpt_path, map_location="cpu")
|
|
|
|
if "global_step" in pl_sd:
|
|
|
|
if "global_step" in pl_sd:
|
|
|
|
logger.debug(f"Global Step: {pl_sd['global_step']}")
|
|
|
|
logger.debug(f"Global Step: {pl_sd['global_step']}")
|
|
|
@ -352,59 +349,61 @@ def imagine(
|
|
|
|
|
|
|
|
|
|
|
|
upscaled_img = None
|
|
|
|
upscaled_img = None
|
|
|
|
rebuilt_orig_img = None
|
|
|
|
rebuilt_orig_img = None
|
|
|
|
is_nsfw_img = None
|
|
|
|
|
|
|
|
if add_caption:
|
|
|
|
if add_caption:
|
|
|
|
caption = generate_caption(img)
|
|
|
|
caption = generate_caption(img)
|
|
|
|
logger.info(f" Generated caption: {caption}")
|
|
|
|
logger.info(f" Generated caption: {caption}")
|
|
|
|
if IMAGINAIRY_SAFETY_MODE != SafetyMode.DISABLED:
|
|
|
|
|
|
|
|
is_nsfw_img = is_nsfw(img)
|
|
|
|
safety_score = create_safety_score(
|
|
|
|
if is_nsfw_img and IMAGINAIRY_SAFETY_MODE == SafetyMode.FILTER:
|
|
|
|
img,
|
|
|
|
logger.info(" ⚠️ Filtering NSFW image")
|
|
|
|
safety_mode=IMAGINAIRY_SAFETY_MODE,
|
|
|
|
img = img.filter(ImageFilter.GaussianBlur(radius=40))
|
|
|
|
)
|
|
|
|
|
|
|
|
if not safety_score.is_filtered:
|
|
|
|
if prompt.fix_faces:
|
|
|
|
if prompt.fix_faces:
|
|
|
|
logger.info(" Fixing 😊 's in 🖼 using CodeFormer...")
|
|
|
|
logger.info(" Fixing 😊 's in 🖼 using CodeFormer...")
|
|
|
|
img = enhance_faces(img, fidelity=prompt.fix_faces_fidelity)
|
|
|
|
img = enhance_faces(img, fidelity=prompt.fix_faces_fidelity)
|
|
|
|
if prompt.upscale:
|
|
|
|
if prompt.upscale:
|
|
|
|
logger.info(" Upscaling 🖼 using real-ESRGAN...")
|
|
|
|
logger.info(" Upscaling 🖼 using real-ESRGAN...")
|
|
|
|
upscaled_img = upscale_image(img)
|
|
|
|
upscaled_img = upscale_image(img)
|
|
|
|
|
|
|
|
|
|
|
|
# put the newly generated patch back into the original, full size image
|
|
|
|
# put the newly generated patch back into the original, full size image
|
|
|
|
if (
|
|
|
|
if (
|
|
|
|
prompt.mask_modify_original
|
|
|
|
prompt.mask_modify_original
|
|
|
|
and mask_image_orig
|
|
|
|
and mask_image_orig
|
|
|
|
and prompt.init_image
|
|
|
|
and prompt.init_image
|
|
|
|
):
|
|
|
|
):
|
|
|
|
img_to_add_back_to_original = (
|
|
|
|
img_to_add_back_to_original = (
|
|
|
|
upscaled_img if upscaled_img else img
|
|
|
|
upscaled_img if upscaled_img else img
|
|
|
|
)
|
|
|
|
)
|
|
|
|
img_to_add_back_to_original = (
|
|
|
|
img_to_add_back_to_original = (
|
|
|
|
img_to_add_back_to_original.resize(
|
|
|
|
img_to_add_back_to_original.resize(
|
|
|
|
|
|
|
|
prompt.init_image.size,
|
|
|
|
|
|
|
|
resample=Image.Resampling.LANCZOS,
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mask_for_orig_size = mask_image_orig.resize(
|
|
|
|
prompt.init_image.size,
|
|
|
|
prompt.init_image.size,
|
|
|
|
resample=Image.Resampling.LANCZOS,
|
|
|
|
resample=Image.Resampling.LANCZOS,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
)
|
|
|
|
mask_for_orig_size = mask_for_orig_size.filter(
|
|
|
|
|
|
|
|
ImageFilter.GaussianBlur(radius=5)
|
|
|
|
mask_for_orig_size = mask_image_orig.resize(
|
|
|
|
)
|
|
|
|
prompt.init_image.size, resample=Image.Resampling.LANCZOS
|
|
|
|
log_img(mask_for_orig_size, "mask for original image size")
|
|
|
|
)
|
|
|
|
|
|
|
|
mask_for_orig_size = mask_for_orig_size.filter(
|
|
|
|
|
|
|
|
ImageFilter.GaussianBlur(radius=5)
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
log_img(mask_for_orig_size, "mask for original image size")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
rebuilt_orig_img = Image.composite(
|
|
|
|
rebuilt_orig_img = Image.composite(
|
|
|
|
prompt.init_image,
|
|
|
|
prompt.init_image,
|
|
|
|
img_to_add_back_to_original,
|
|
|
|
img_to_add_back_to_original,
|
|
|
|
mask_for_orig_size,
|
|
|
|
mask_for_orig_size,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
log_img(rebuilt_orig_img, "reconstituted original")
|
|
|
|
log_img(rebuilt_orig_img, "reconstituted original")
|
|
|
|
|
|
|
|
|
|
|
|
yield ImagineResult(
|
|
|
|
yield ImagineResult(
|
|
|
|
img=img,
|
|
|
|
img=img,
|
|
|
|
prompt=prompt,
|
|
|
|
prompt=prompt,
|
|
|
|
upscaled_img=upscaled_img,
|
|
|
|
upscaled_img=upscaled_img,
|
|
|
|
is_nsfw=is_nsfw_img,
|
|
|
|
is_nsfw=safety_score.is_nsfw,
|
|
|
|
|
|
|
|
safety_score=safety_score,
|
|
|
|
modified_original=rebuilt_orig_img,
|
|
|
|
modified_original=rebuilt_orig_img,
|
|
|
|
mask_binary=mask_image_orig,
|
|
|
|
mask_binary=mask_image_orig,
|
|
|
|
mask_grayscale=mask_grayscale,
|
|
|
|
mask_grayscale=mask_grayscale,
|
|
|
|