tests: add more tests

This commit is contained in:
Bryce 2022-09-17 15:49:38 -07:00
parent d7cbf6e416
commit 8238e59067
10 changed files with 548 additions and 162 deletions

View File

@ -18,10 +18,6 @@ def realesrgan_upsampler():
model_path = get_cached_url_path(url)
upsampler = RealESRGANer(scale=4, model_path=model_path, model=model, tile=0)
if get_device() == "cuda":
device = "cuda"
else:
device = "cpu"
device = get_device()
upsampler.device = torch.device(device)
@ -35,7 +31,3 @@ def upscale_image(img):
np_img = np.array(img, dtype=np.uint8)
upsampler_output, img_mode = realesrgan_upsampler().enhance(np_img[:, :, ::-1])
return Image.fromarray(upsampler_output[:, :, ::-1], mode=img_mode)
if __name__ == "__main__":
realesrgan_upsampler()

View File

@ -10,16 +10,12 @@ from imaginairy.modules.diffusion.util import checkpoint
from imaginairy.utils import get_device, get_device_name
def exists(val):
return val is not None
def uniq(arr):
return {el: True for el in arr}.keys()
def default(val, d):
if exists(val):
if val is not None:
return val
return d() if isfunction(d) else d
@ -50,7 +46,7 @@ class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
dim_out = dim_out if dim_out is not None else dim
project_in = (
nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
if not glu
@ -152,7 +148,7 @@ class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
context_dim = context_dim if context_dim is not None else query_dim
self.scale = dim_head**-0.5
self.heads = heads
@ -172,7 +168,7 @@ class CrossAttention(nn.Module):
h = self.heads
q = self.to_q(x)
context = default(context, x)
context = context if context is not None else x
k = self.to_k(context)
v = self.to_v(context)
@ -180,7 +176,7 @@ class CrossAttention(nn.Module):
sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
if exists(mask):
if mask is not None:
mask = rearrange(mask, "b ... -> b (...)")
max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, "b j -> (b h) () j", h=h)

View File

@ -13,150 +13,6 @@ from imaginairy.utils import instantiate_from_config
logger = logging.getLogger(__name__)
class VectorQuantizer(nn.Module):
"""
Improved version over original VectorQuantizer, can be used as a drop-in replacement. Mostly
avoids costly matrix multiplications and allows for post-hoc remapping of indices.
https://github.com/CompVis/taming-transformers/blob/141eb746f567a731f71cd703796d4d53a323f45f/taming/modules/vqvae/quantize.py#L213
"""
# NOTE: due to a bug the beta term was applied to the wrong term. for
# backwards compatibility we use the buggy version by default, but you can
# specify legacy=False to fix it.
def __init__(
self,
n_e,
e_dim,
beta,
remap=None,
unknown_index="random",
sane_index_shape=False,
legacy=True,
):
super().__init__()
self.n_e = n_e
self.e_dim = e_dim
self.beta = beta
self.legacy = legacy
self.embedding = nn.Embedding(self.n_e, self.e_dim)
self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
self.remap = remap
if self.remap is not None:
self.register_buffer("used", torch.tensor(np.load(self.remap)))
self.re_embed = self.used.shape[0]
self.unknown_index = unknown_index # "random" or "extra" or integer
if self.unknown_index == "extra":
self.unknown_index = self.re_embed
self.re_embed = self.re_embed + 1
logger.info(
f"Remapping {self.n_e} indices to {self.re_embed} indices. "
f"Using {self.unknown_index} for unknown indices."
)
else:
self.re_embed = n_e
self.sane_index_shape = sane_index_shape
def remap_to_used(self, inds):
ishape = inds.shape
assert len(ishape) > 1
inds = inds.reshape(ishape[0], -1)
used = self.used.to(inds)
match = (inds[:, :, None] == used[None, None, ...]).long()
new = match.argmax(-1)
unknown = match.sum(2) < 1
if self.unknown_index == "random":
new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(
device=new.device
)
else:
new[unknown] = self.unknown_index
return new.reshape(ishape)
def unmap_to_all(self, inds):
ishape = inds.shape
assert len(ishape) > 1
inds = inds.reshape(ishape[0], -1)
used = self.used.to(inds)
if self.re_embed > self.used.shape[0]: # extra token
inds[inds >= self.used.shape[0]] = 0 # simply set to zero
back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
return back.reshape(ishape)
def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
assert rescale_logits == False, "Only for interface compatible with Gumbel"
assert return_logits == False, "Only for interface compatible with Gumbel"
# reshape z -> (batch, height, width, channel) and flatten
z = rearrange(z, "b c h w -> b h w c").contiguous()
z_flattened = z.view(-1, self.e_dim)
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
d = (
torch.sum(z_flattened**2, dim=1, keepdim=True)
+ torch.sum(self.embedding.weight**2, dim=1)
- 2
* torch.einsum(
"bd,dn->bn", z_flattened, rearrange(self.embedding.weight, "n d -> d n")
)
)
min_encoding_indices = torch.argmin(d, dim=1)
z_q = self.embedding(min_encoding_indices).view(z.shape)
perplexity = None
min_encodings = None
# compute loss for embedding
if not self.legacy:
loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean(
(z_q - z.detach()) ** 2
)
else:
loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean(
(z_q - z.detach()) ** 2
)
# preserve gradients
z_q = z + (z_q - z).detach()
# reshape back to match original input shape
z_q = rearrange(z_q, "b h w c -> b c h w").contiguous()
if self.remap is not None:
min_encoding_indices = min_encoding_indices.reshape(
z.shape[0], -1
) # add batch axis
min_encoding_indices = self.remap_to_used(min_encoding_indices)
min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten
if self.sane_index_shape:
min_encoding_indices = min_encoding_indices.reshape(
z_q.shape[0], z_q.shape[2], z_q.shape[3]
)
return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
def get_codebook_entry(self, indices, shape):
# shape specifying (batch, height, width, channel)
if self.remap is not None:
indices = indices.reshape(shape[0], -1) # add batch axis
indices = self.unmap_to_all(indices)
indices = indices.reshape(-1) # flatten again
# get quantized latent vectors
z_q = self.embedding(indices)
if shape is not None:
z_q = z_q.view(shape)
# reshape back to match original input shape
z_q = z_q.permute(0, 3, 1, 2).contiguous()
return z_q
class AutoencoderKL(pl.LightningModule):
def __init__(
self,

View File

@ -1,5 +1,6 @@
from functools import lru_cache
import numpy as np
from diffusers.pipelines.stable_diffusion.safety_checker import (
StableDiffusionSafetyChecker,
)
@ -20,6 +21,6 @@ def is_nsfw(img, x_sample):
clip_input = safety_checker_input.pixel_values
_, has_nsfw_concept = safety_checker(
images=x_sample[None, :], clip_input=clip_input
images=[np.empty((2, 2))], clip_input=clip_input
)
return has_nsfw_concept[0]

Binary file not shown.

After

Width:  |  Height:  |  Size: 286 KiB

BIN
tests/data/safety.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 8.6 KiB

13
tests/test_cmds.py Normal file
View File

@ -0,0 +1,13 @@
from click.testing import CliRunner
from imaginairy.cmds import imagine_cmd
from tests import TESTS_FOLDER
def test_imagine_cmd():
runner = CliRunner()
result = runner.invoke(
imagine_cmd,
["gold coins", "--steps", "5", "--outdir", f"{TESTS_FOLDER}/test_output"],
)
assert result.exit_code == 0

23
tests/test_enhancers.py Normal file
View File

@ -0,0 +1,23 @@
import hashlib
from PIL import Image
from pytorch_lightning import seed_everything
from imaginairy.enhancers.face_restoration_codeformer import enhance_faces
from imaginairy.utils import get_device
from tests import TESTS_FOLDER
def test_fix_faces():
img = Image.open(f"{TESTS_FOLDER}/data/distorted_face.png")
seed_everything(1)
img = enhance_faces(img)
img.save(f"{TESTS_FOLDER}/test_output/fixed_face.png")
if "mps" in get_device():
assert img_hash(img) == "a75991307eda675a26eeb7073f828e93"
else:
assert img_hash(img) == "5aa847a1464de75b158658a35800b6bf"
def img_hash(img):
return hashlib.md5(img.tobytes()).hexdigest()

481
tests/test_guidance.py Normal file
View File

@ -0,0 +1,481 @@
#
# import torch
#
# from imaginairy.utils import get_device
#
# torch.manual_seed(0)
# from transformers import CLIPTextModel, CLIPTokenizer
# from diffusers import AutoencoderKL, UNet2DConditionModel, LMSDiscreteScheduler
#
# from tqdm.auto import tqdm, trange
# from torch import autocast
# import PIL.Image as PImage
# from PIL import Image
# import numpy
# from torchvision import transforms
# import torchvision.transforms.functional as f
# import random
# import requests
# from io import BytesIO
#
# # import clip
# import open_clip as clip
# from torch import nn
# import torch.nn.functional as F
# import io
#
# offload_device = "cpu"
# model_name = "CompVis/stable-diffusion-v1-4"
# attention_slicing = True #@param {"type":"boolean"}
# unet_path = False
#
# vae = AutoencoderKL.from_pretrained(model_name, subfolder="vae", use_auth_token=True)
#
# tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
# try:
# text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder="text_encoder", use_auth_token=True)
# except:
# print("Text encoder could not be loaded from the repo specified for some reason, falling back to the vit-l repo")
# text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
#
# if unet_path!=None:
# # unet = UNet2DConditionModel.from_pretrained(unet_path)
# from huggingface_hub import hf_hub_download
# model_name = hf_hub_download(repo_id=unet_path, filename="unet.pt")
# unet = torch.jit.load(model_name)
# else:
# unet = UNet2DConditionModel.from_pretrained(model_name, subfolder="unet", use_auth_token=True)
# if attention_slicing:
# slice_size = unet.config.attention_head_dim // 2
# unet.set_attention_slice(slice_size)
#
# scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
#
# vae = vae.to(offload_device).half()
# text_encoder = text_encoder.to(offload_device).half()
# unet = unet.to(get_device()).half()
# class MakeCutouts(nn.Module):
# def __init__(self, cut_size, cutn, cut_pow=1.):
# super().__init__()
# self.cut_size = cut_size
# self.cutn = cutn
# self.cut_pow = cut_pow
#
# def forward(self, input):
# sideY, sideX = input.shape[2:4]
# max_size = min(sideX, sideY)
# min_size = min(sideX, sideY, self.cut_size)
# cutouts = []
# for _ in range(self.cutn):
# size = int(torch.rand([]) ** self.cut_pow * (max_size - min_size) + min_size)
# offsetx = torch.randint(0, sideX - size + 1, ())
# offsety = torch.randint(0, sideY - size + 1, ())
# cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
# cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size))
# return torch.cat(cutouts)
#
#
# to_tensor_tfm = transforms.ToTensor()
#
#
# # mismatch of tons of image encoding / decoding / loading functions i cant be asked to clean up right now
#
# def pil_to_latent(input_im):
# # Single image -> single latent in a batch (so size 1, 4, 64, 64)
# with torch.no_grad():
# with autocast("cuda"):
# latent = vae.encode(to_tensor_tfm(input_im.convert("RGB")).unsqueeze(0).to(
# get_device()) * 2 - 1).latent_dist # Note scaling
# # print(latent)
# return 0.18215 * latent.mode() # or .mean or .sample
#
#
# def latents_to_pil(latents):
# # bath of latents -> list of images
# latents = (1 / 0.18215) * latents
# with torch.no_grad():
# image = vae.decode(latents)
# image = (image / 2 + 0.5).clamp(0, 1)
# image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
# images = (image * 255).round().astype("uint8")
# pil_images = [Image.fromarray(image) for image in images]
# return pil_images
#
#
# def get_latent_from_url(url, size=(512, 512)):
# response = requests.get(url)
# img = PImage.open(BytesIO(response.content))
# img = img.resize(size).convert("RGB")
# latent = pil_to_latent(img)
# return latent
#
#
# def scale_and_decode(latents):
# with autocast("cuda"):
# # scale and decode the image latents with vae
# latents = 1 / 0.18215 * latents
# with torch.no_grad():
# image = vae.decode(latents).sample.squeeze(0)
# image = f.to_pil_image((image / 2 + 0.5).clamp(0, 1))
# return image
#
#
# def fetch(url_or_path):
# import io
# if str(url_or_path).startswith('http://') or str(url_or_path).startswith('https://'):
# r = requests.get(url_or_path)
# r.raise_for_status()
# fd = io.BytesIO()
# fd.write(r.content)
# fd.seek(0)
# return PImage.open(fd).convert('RGB')
# return PImage.open(open(url_or_path, 'rb')).convert('RGB')
#
#
# """
# grabs all text up to the first occurrence of ':'
# uses the grabbed text as a sub-prompt, and takes the value following ':' as weight
# if ':' has no value defined, defaults to 1.0
# repeats until no text remaining
# """
#
#
# def split_weighted_subprompts(text, split=":"):
# remaining = len(text)
# prompts = []
# weights = []
# while remaining > 0:
# if split in text:
# idx = text.index(split) # first occurrence from start
# # grab up to index as sub-prompt
# prompt = text[:idx]
# remaining -= idx
# # remove from main text
# text = text[idx + 1:]
# # find value for weight
# if " " in text:
# idx = text.index(" ") # first occurence
# else: # no space, read to end
# idx = len(text)
# if idx != 0:
# try:
# weight = float(text[:idx])
# except: # couldn't treat as float
# print(f"Warning: '{text[:idx]}' is not a value, are you missing a space?")
# weight = 1.0
# else: # no value found
# weight = 1.0
# # remove from main text
# remaining -= idx
# text = text[idx + 1:]
# # append the sub-prompt and its weight
# prompts.append(prompt)
# weights.append(weight)
# else: # no : found
# if len(text) > 0: # there is still text though
# # take remainder as weight 1
# prompts.append(text)
# weights.append(1.0)
# remaining = 0
# print(prompts, weights)
# return prompts, weights
#
#
# # from some stackoverflow comment
# import numpy as np
#
#
# def lerp(a, b, x):
# "linear interpolation"
# return a + x * (b - a)
#
#
# def fade(t):
# "6t^5 - 15t^4 + 10t^3"
# return 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3
#
#
# def gradient(h, x, y):
# "grad converts h to the right gradient vector and return the dot product with (x,y)"
# vectors = np.array([[0, 1], [0, -1], [1, 0], [-1, 0]])
# g = vectors[h % 4]
# return g[:, :, 0] * x + g[:, :, 1] * y
#
#
# def perlin(x, y, seed=0):
# # permutation table
# np.random.seed(seed)
# p = np.arange(256, dtype=int)
# np.random.shuffle(p)
# p = np.stack([p, p]).flatten()
# # coordinates of the top-left
# xi, yi = x.astype(int), y.astype(int)
# # internal coordinates
# xf, yf = x - xi, y - yi
# # fade factors
# u, v = fade(xf), fade(yf)
# # noise components
# n00 = gradient(p[p[xi] + yi], xf, yf)
# n01 = gradient(p[p[xi] + yi + 1], xf, yf - 1)
# n11 = gradient(p[p[xi + 1] + yi + 1], xf - 1, yf - 1)
# n10 = gradient(p[p[xi + 1] + yi], xf - 1, yf)
# # combine noises
# x1 = lerp(n00, n10, u)
# x2 = lerp(n01, n11, u) # FIX1: I was using n10 instead of n01
# return lerp(x1, x2, v) # FIX2: I also had to reverse x1 and x2 here
#
#
# def sample(args):
# global in_channels
# global text_encoder # uugghhhghhghgh
# global vae # UUGHGHHGHGH
# global unet # .hggfkgjks;ldjf
# # prompt = args.prompt
# prompts, weights = split_weighted_subprompts(args.prompt)
# h, w = args.size
# steps = args.steps
# scale = args.scale
# classifier_guidance = args.classifier_guidance
# use_init = len(args.init_img) > 1
# if args.seed != -1:
# seed = args.seed
# generator = torch.manual_seed(seed)
# else:
# seed = random.randint(0, 10_000)
# generator = torch.manual_seed(seed)
# print(f"Generating with seed {seed}...")
#
# # tokenize / encode text
# tokens = [tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True,
# return_tensors="pt") for prompt in prompts]
# with torch.no_grad():
# # move CLIP to cuda
# text_encoder = text_encoder.to(get_device())
# text_embeddings = [text_encoder(tok.input_ids.to(get_device()))[0].unsqueeze(0) for tok in tokens]
# text_embeddings = [text_embeddings[i] * weights[i] for i in range(len(text_embeddings))]
# text_embeddings = torch.cat(text_embeddings, 0).sum(0)
# max_length = 77
# uncond_input = tokenizer(
# [""], padding="max_length", max_length=max_length, return_tensors="pt"
# )
# uncond_embeddings = text_encoder(uncond_input.input_ids.to(get_device()))[0]
# text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
# # move it back to CPU so there's more vram for generating
# text_encoder = text_encoder.to(offload_device)
# images = []
#
# if args.lpips_guidance:
# import lpips
# lpips_model = lpips.LPIPS(net='vgg').to(get_device())
# init = to_tensor_tfm(fetch(args.init_img).resize(args.size)).to(get_device())
#
# for batch_n in trange(args.batches):
# with autocast("cuda"):
# # unet = unet.to(get_device())
# scheduler.set_timesteps(steps)
# if not use_init or args.start_step == 0:
# latents = torch.randn(
# (1, in_channels, h // 8, w // 8),
# generator=generator
# )
# latents = latents.to(get_device())
# latents = latents * scheduler.sigmas[0]
# start_step = args.start_step
# else:
# # Start step
# start_step = args.start_step - 1
# start_sigma = scheduler.sigmas[start_step]
# start_timestep = int(scheduler.timesteps[start_step])
#
# # Prep latents
# vae = vae.to(get_device())
# encoded = get_latent_from_url(args.init_img)
# if not classifier_guidance:
# vae = vae.to(offload_device)
#
# noise = torch.randn_like(encoded)
# sigmas = scheduler.match_shape(scheduler.sigmas[start_step], noise)
# noisy_samples = encoded + noise * sigmas
#
# latents = noisy_samples.to(get_device()).half()
#
# if args.perlin_multi != 0:
# linx = np.linspace(0, 5, h // 8, endpoint=False)
# liny = np.linspace(0, 5, w // 8, endpoint=False)
# x, y = np.meshgrid(liny, linx)
# p = [np.expand_dims(perlin(x, y, seed=i), 0) for i in range(4)] # reproducable seed
# p = np.concatenate(p, 0)
# p = torch.tensor(p).unsqueeze(0).cuda()
# latents = latents + (p * args.perlin_multi).to(get_device()).half()
#
# for i, t in tqdm(enumerate(scheduler.timesteps), total=steps):
# if i > start_step:
# latent_model_input = torch.cat([latents] * 2)
# sigma = scheduler.sigmas[i]
# latent_model_input = latent_model_input / ((sigma ** 2 + 1) ** 0.5)
#
# with torch.no_grad():
# # noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]
# # noise_pred = unet(latent_model_input, torch.tensor(t, dtype=torch.float32).cuda().half(), text_embeddings)#["sample"]
# noise_pred = unet(latent_model_input, torch.tensor(t, dtype=torch.float32).cuda(),
# text_embeddings) # ["sample"]
#
# # cfg
# noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
# noise_pred = noise_pred_uncond + scale * (noise_pred_text - noise_pred_uncond)
#
# # cg
# if classifier_guidance:
# # vae = vae.to(get_device())
# if vae.device != latents.device:
# vae = vae.to(latents.device)
# latents = latents.detach().requires_grad_()
# latents_x0 = latents - sigma * noise_pred
# denoised_images = vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5
# if args.loss_scale != 0:
# loss = args.loss_fn(denoised_images) * args.loss_scale
# else:
# loss = 0
# init_losses = lpips_model(denoised_images, init)
# loss = loss + init_losses.sum() * args.lpips_scale
#
# cond_grad = -torch.autograd.grad(loss, latents)[0]
# latents = latents.detach() + cond_grad * sigma ** 2
# # vae = vae.to(offload_device)
#
# latents = scheduler.step(noise_pred, i, latents)["prev_sample"]
# vae = vae.to(get_device())
# output_image = scale_and_decode(latents)
# vae = vae.to(offload_device)
# images.append(output_image)
#
# import gc
# gc.collect()
# torch.cuda.empty_cache()
#
# images[-1].save(f"output/{batch_n}.png")
#
# return images
#
# def test_guided_image():
# prompt = "tardigrade portrait [intricate] [artstation]" # @param {"type":"string"}
#
# # prompt = add_suffixes(prompt)
#
# init_img = "" # @param {"type":"string"}
# size = [640, 640] # @param
# steps = 65 # @param
# start_step = 0 # @param
# perlin_multi = 0.4 # @param
# scale = 7 # @param
# seed = -1 # @param
# batches = 4 # @param
# # @markdown ---
#
# # @markdown ### Classifier Guidance
# # @markdown `classifier_guidance` is whether or not to use the loss function in the previous cell to guide the image (slows down image generation a lot) <br>
# # @markdown it also is very hit-and-miss in terms of quality, but can be really really good, try setting batches high and then taking a nap <br>
# # @markdown `lpips_guidance` is for if you're using an init_img, it'll let you start closer to the beginning while trying to keep the overall shapes similar
# # @markdown `lpips_scale` is similar to `loss_scale` but it's how much to push the model to keep the shapes the same <br>
# # @markdown `loss_scale` is how much to guide according to that loss function <br>
# # @markdown `clip_text_prompt` is a prompt for CLIP to optimize towards, if using classifier guidance (supports weighting with `prompt:weight`) <br>
# # @markdown `clip_image_prompt` is an image url for CLIP to optimize towards if using classifier guidance (supports weighting with `url|weight` because of colons coming up in urls) <br>
# # @markdown for `clip_model_name` and `clip_model_pretrained` check out the openclip repository https://github.com/mlfoundations/open_clip <br>
# # @markdown `cutn` is the amount of permutations of the image to show to clip (can help with stability) <br>
# # @markdown `accumulate` is how many times to run the image through the clip model, can help if you can only fit low cutn on the machine <br>
# # @markdown *you cannot use the textual inversion tokens with the clip text prompt* <br>
# # @markdown *also clip guidance sucks for most things except removing very small details that dont make sense*
# classifier_guidance = True # @param {"type":"boolean"}
# lpips_guidance = False # @param {"type":"boolean"}
# lpips_scale = 0 # @param
# loss_scale = 1. # @param
#
# class BlankClass():
# def __init__(self):
# bruh = 'BRUH'
#
# args = BlankClass()
# args.prompt = prompt
# args.init_img = init_img
# args.size = size
# args.steps = steps
# args.start_step = start_step
# args.scale = scale
# args.perlin_multi = perlin_multi
# args.seed = seed
# args.batches = batches
# args.classifier_guidance = classifier_guidance
# args.lpips_guidance = lpips_guidance
# args.lpips_scale = lpips_scale
# args.loss_scale = loss_scale
#
# loss_scale = 1
# # make_cutouts = MakeCutouts(224, 16)
#
# clip_text_prompt = "tardigrade portrait [intricate] [artstation]" # @param {"type":"string"}
# # clip_text_prompt = add_suffixes(clip_text_prompt)
# clip_image_prompt = "" # @param {"type":"string"}
#
# if loss_scale != 0:
# # clip_model = clip.load("ViT-B/32", jit=False)[0].eval().requires_grad_(False).to(get_device())
# clip_model_name = "ViT-B-32" # @param {"type":"string"}
# clip_model_pretrained = "laion2b_s34b_b79k" # @param {"type":"string"}
# clip_model, _, preprocess = clip.create_model_and_transforms(clip_model_name, pretrained=clip_model_pretrained)
# clip_model = clip_model.eval().requires_grad_(False).to(get_device())
#
# cutn = 4 # @param
# make_cutouts = MakeCutouts(clip_model.visual.image_size if type(clip_model.visual.image_size) != tuple else
# clip_model.visual.image_size[0], cutn)
#
# target = None
# if len(clip_text_prompt) > 1:
# clip_text_prompt, clip_text_weights = split_weighted_subprompts(clip_text_prompt)
# target = clip_model.encode_text(clip.tokenize(clip_text_prompt).to(get_device())) * torch.tensor(
# clip_text_weights).view(len(clip_text_prompt), 1).to(get_device())
# if len(clip_image_prompt) > 1:
# clip_image_prompt, clip_image_weights = split_weighted_subprompts(clip_image_prompt, split="|")
# # pesky spaces
# clip_image_prompt = [p.replace(" ", "") for p in clip_image_prompt]
# images = [fetch(image) for image in clip_image_prompt]
# images = [f.to_tensor(i).unsqueeze(0) for i in images]
# images = [make_cutouts(i) for i in images]
# encodings = [clip_model.encode_image(i.to(get_device())).mean(0) for i in images]
#
# for i in range(len(encodings)):
# encodings[i] = (encodings[i] * clip_image_weights[i]).unsqueeze(0)
# # print(encodings.shape)
# encodings = torch.cat(encodings, 0)
# encoding = encodings.sum(0)
#
# if target != None:
# target = target + encoding
# else:
# target = encoding
# target = target.half().to(get_device())
#
# # free a little memory, we dont use the text encoder after this so just delete it
# clip_model.transformer = None
# import gc
# gc.collect()
# torch.cuda.empty_cache()
#
# def spherical_distance(x, y):
# x = F.normalize(x, dim=-1)
# y = F.normalize(y, dim=-1)
# l = (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2).mean()
# return l
#
# def loss_fn(x):
# with torch.autocast("cuda"):
# cutouts = make_cutouts(x)
# encoding = clip_model.encode_image(cutouts.float()).half()
# loss = spherical_distance(encoding, target)
# return loss.mean()
#
# args.loss_fn = loss_fn
#
#
# dtype = torch.float16
# with torch.amp.autocast(device_type=get_device(), dtype=dtype):
# output = sample(args)
# print("Done!")

24
tests/test_safety.py Normal file
View File

@ -0,0 +1,24 @@
from PIL import Image
from imaginairy.api import load_model
from imaginairy.safety import is_nsfw
from imaginairy.utils import get_device, pillow_img_to_torch_image
from tests import TESTS_FOLDER
def test_is_nsfw():
img = Image.open(f"{TESTS_FOLDER}/data/safety.jpg")
latent = _pil_to_latent(img)
assert is_nsfw(img, latent)
img = Image.open(f"{TESTS_FOLDER}/data/girl_with_a_pearl_earring.jpg")
latent = _pil_to_latent(img)
assert not is_nsfw(img, latent)
def _pil_to_latent(img):
model = load_model()
img, w, h = pillow_img_to_torch_image(img)
img = img.to(get_device())
latent = model.get_first_stage_encoding(model.encode_first_stage(img))
return latent