perf: get "half" mode working when running on cuda

pull/1/head
Bryce 2 years ago
parent b9c00dd8de
commit 292d1bcab5

@ -17,8 +17,7 @@ from transformers import cached_path
from imaginairy.modules.diffusion.ddim import DDIMSampler
from imaginairy.modules.diffusion.plms import PLMSSampler
from imaginairy.modules.find_noise import find_noise_for_latent
from imaginairy.safety import is_nsfw
from imaginairy.safety import is_nsfw, safety_models
from imaginairy.schema import ImaginePrompt, ImagineResult
from imaginairy.utils import (
fix_torch_nn_layer_norm,
@ -147,9 +146,18 @@ def imagine_images(
ddim_eta=0.0,
img_callback=None,
tile_mode=False,
half_mode=None,
):
model = load_model(tile_mode=tile_mode)
# model = model.half()
if not IMAGINAIRY_ALLOW_NSFW:
# needs to be loaded before we set default tensor type to half
safety_models()
# only run half-mode on cuda. run it by default
half_mode = True if half_mode is None and get_device() == "cuda" else False
if half_mode:
model = model.half()
# needed when model is in half mode, remove if not using half mode
torch.set_default_tensor_type(torch.HalfTensor)
prompts = [ImaginePrompt(prompts)] if isinstance(prompts, str) else prompts
prompts = [prompts] if isinstance(prompts, ImaginePrompt) else prompts
_img_callback = None
@ -164,9 +172,6 @@ def imagine_images(
logger.info(f"Generating {prompt.prompt_description()}")
seed_everything(prompt.seed)
# needed when model is in half mode, remove if not using half mode
# torch.set_default_tensor_type(torch.HalfTensor)
uc = None
if prompt.prompt_strength != 1.0:
uc = model.get_learned_conditioning(1 * [""])
@ -238,7 +243,9 @@ def imagine_images(
for x_sample in x_samples:
x_sample = 255.0 * rearrange(x_sample.cpu().numpy(), "c h w -> h w c")
img = Image.fromarray(x_sample.astype(np.uint8))
if not IMAGINAIRY_ALLOW_NSFW and is_nsfw(img, x_sample):
if not IMAGINAIRY_ALLOW_NSFW and is_nsfw(
img, x_sample, half_mode=half_mode
):
logger.info(" ⚠️ Filtering NSFW image")
img = Image.new("RGB", img.size, (228, 150, 150))
if prompt.fix_faces:

@ -287,6 +287,7 @@ class BasicTransformerBlock(nn.Module):
)
def _forward(self, x, context=None):
x = x.contiguous() if x.device.type == "mps" else x
x = self.attn1(self.norm1(x)) + x
x = self.attn2(self.norm2(x), context=context) + x
x = self.ff(self.norm3(x)) + x

@ -14,10 +14,12 @@ def safety_models():
return safety_feature_extractor, safety_checker
def is_nsfw(img, x_sample):
def is_nsfw(img, x_sample, half_mode=False):
safety_feature_extractor, safety_checker = safety_models()
safety_checker_input = safety_feature_extractor([img], return_tensors="pt")
clip_input = safety_checker_input.pixel_values
_, has_nsfw_concept = safety_checker(
images=x_sample[None, :], clip_input=safety_checker_input.pixel_values
images=x_sample[None, :], clip_input=clip_input
)
return has_nsfw_concept[0]

@ -6,7 +6,6 @@ from functools import lru_cache
from typing import List, Optional
import numpy as np
import PIL
import torch
from PIL import Image
from torch import Tensor
@ -102,7 +101,7 @@ def fix_torch_nn_layer_norm():
def img_path_to_torch_image(path, max_height=512, max_width=512):
image = Image.open(path).convert("RGB")
logger.info(f"loaded input image of size {image.size} from {path}")
logger.info(f"Loaded input 🖼 of size {image.size} from {path}")
return pillow_img_to_torch_image(image, max_height=max_height, max_width=max_width)
@ -111,7 +110,7 @@ def pillow_img_to_torch_image(image, max_height=512, max_width=512):
resize_ratio = min(max_width / w, max_height / h)
w, h = int(w * resize_ratio), int(h * resize_ratio)
w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 32
image = image.resize((w, h), resample=PIL.Image.LANCZOS)
image = image.resize((w, h), resample=Image.Resampling.LANCZOS)
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)

Loading…
Cancel
Save