|
|
|
@ -17,8 +17,7 @@ from transformers import cached_path
|
|
|
|
|
|
|
|
|
|
from imaginairy.modules.diffusion.ddim import DDIMSampler
|
|
|
|
|
from imaginairy.modules.diffusion.plms import PLMSSampler
|
|
|
|
|
from imaginairy.modules.find_noise import find_noise_for_latent
|
|
|
|
|
from imaginairy.safety import is_nsfw
|
|
|
|
|
from imaginairy.safety import is_nsfw, safety_models
|
|
|
|
|
from imaginairy.schema import ImaginePrompt, ImagineResult
|
|
|
|
|
from imaginairy.utils import (
|
|
|
|
|
fix_torch_nn_layer_norm,
|
|
|
|
@ -147,9 +146,18 @@ def imagine_images(
|
|
|
|
|
ddim_eta=0.0,
|
|
|
|
|
img_callback=None,
|
|
|
|
|
tile_mode=False,
|
|
|
|
|
half_mode=None,
|
|
|
|
|
):
|
|
|
|
|
model = load_model(tile_mode=tile_mode)
|
|
|
|
|
# model = model.half()
|
|
|
|
|
if not IMAGINAIRY_ALLOW_NSFW:
|
|
|
|
|
# needs to be loaded before we set default tensor type to half
|
|
|
|
|
safety_models()
|
|
|
|
|
# only run half-mode on cuda. run it by default
|
|
|
|
|
half_mode = True if half_mode is None and get_device() == "cuda" else False
|
|
|
|
|
if half_mode:
|
|
|
|
|
model = model.half()
|
|
|
|
|
# needed when model is in half mode, remove if not using half mode
|
|
|
|
|
torch.set_default_tensor_type(torch.HalfTensor)
|
|
|
|
|
prompts = [ImaginePrompt(prompts)] if isinstance(prompts, str) else prompts
|
|
|
|
|
prompts = [prompts] if isinstance(prompts, ImaginePrompt) else prompts
|
|
|
|
|
_img_callback = None
|
|
|
|
@ -164,9 +172,6 @@ def imagine_images(
|
|
|
|
|
logger.info(f"Generating {prompt.prompt_description()}")
|
|
|
|
|
seed_everything(prompt.seed)
|
|
|
|
|
|
|
|
|
|
# needed when model is in half mode, remove if not using half mode
|
|
|
|
|
# torch.set_default_tensor_type(torch.HalfTensor)
|
|
|
|
|
|
|
|
|
|
uc = None
|
|
|
|
|
if prompt.prompt_strength != 1.0:
|
|
|
|
|
uc = model.get_learned_conditioning(1 * [""])
|
|
|
|
@ -238,7 +243,9 @@ def imagine_images(
|
|
|
|
|
for x_sample in x_samples:
|
|
|
|
|
x_sample = 255.0 * rearrange(x_sample.cpu().numpy(), "c h w -> h w c")
|
|
|
|
|
img = Image.fromarray(x_sample.astype(np.uint8))
|
|
|
|
|
if not IMAGINAIRY_ALLOW_NSFW and is_nsfw(img, x_sample):
|
|
|
|
|
if not IMAGINAIRY_ALLOW_NSFW and is_nsfw(
|
|
|
|
|
img, x_sample, half_mode=half_mode
|
|
|
|
|
):
|
|
|
|
|
logger.info(" ⚠️ Filtering NSFW image")
|
|
|
|
|
img = Image.new("RGB", img.size, (228, 150, 150))
|
|
|
|
|
if prompt.fix_faces:
|
|
|
|
|