mirror of
https://github.com/brycedrennan/imaginAIry
synced 2024-10-31 03:20:40 +00:00
tests: add more tests
This commit is contained in:
parent
d7cbf6e416
commit
8238e59067
@ -18,10 +18,6 @@ def realesrgan_upsampler():
|
||||
model_path = get_cached_url_path(url)
|
||||
upsampler = RealESRGANer(scale=4, model_path=model_path, model=model, tile=0)
|
||||
|
||||
if get_device() == "cuda":
|
||||
device = "cuda"
|
||||
else:
|
||||
device = "cpu"
|
||||
device = get_device()
|
||||
|
||||
upsampler.device = torch.device(device)
|
||||
@ -35,7 +31,3 @@ def upscale_image(img):
|
||||
np_img = np.array(img, dtype=np.uint8)
|
||||
upsampler_output, img_mode = realesrgan_upsampler().enhance(np_img[:, :, ::-1])
|
||||
return Image.fromarray(upsampler_output[:, :, ::-1], mode=img_mode)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
realesrgan_upsampler()
|
||||
|
@ -10,16 +10,12 @@ from imaginairy.modules.diffusion.util import checkpoint
|
||||
from imaginairy.utils import get_device, get_device_name
|
||||
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
|
||||
def uniq(arr):
|
||||
return {el: True for el in arr}.keys()
|
||||
|
||||
|
||||
def default(val, d):
|
||||
if exists(val):
|
||||
if val is not None:
|
||||
return val
|
||||
return d() if isfunction(d) else d
|
||||
|
||||
@ -50,7 +46,7 @@ class FeedForward(nn.Module):
|
||||
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
|
||||
super().__init__()
|
||||
inner_dim = int(dim * mult)
|
||||
dim_out = default(dim_out, dim)
|
||||
dim_out = dim_out if dim_out is not None else dim
|
||||
project_in = (
|
||||
nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
|
||||
if not glu
|
||||
@ -152,7 +148,7 @@ class CrossAttention(nn.Module):
|
||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
context_dim = default(context_dim, query_dim)
|
||||
context_dim = context_dim if context_dim is not None else query_dim
|
||||
|
||||
self.scale = dim_head**-0.5
|
||||
self.heads = heads
|
||||
@ -172,7 +168,7 @@ class CrossAttention(nn.Module):
|
||||
h = self.heads
|
||||
|
||||
q = self.to_q(x)
|
||||
context = default(context, x)
|
||||
context = context if context is not None else x
|
||||
k = self.to_k(context)
|
||||
v = self.to_v(context)
|
||||
|
||||
@ -180,7 +176,7 @@ class CrossAttention(nn.Module):
|
||||
|
||||
sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
|
||||
|
||||
if exists(mask):
|
||||
if mask is not None:
|
||||
mask = rearrange(mask, "b ... -> b (...)")
|
||||
max_neg_value = -torch.finfo(sim.dtype).max
|
||||
mask = repeat(mask, "b j -> (b h) () j", h=h)
|
||||
|
@ -13,150 +13,6 @@ from imaginairy.utils import instantiate_from_config
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class VectorQuantizer(nn.Module):
|
||||
"""
|
||||
Improved version over original VectorQuantizer, can be used as a drop-in replacement. Mostly
|
||||
avoids costly matrix multiplications and allows for post-hoc remapping of indices.
|
||||
|
||||
https://github.com/CompVis/taming-transformers/blob/141eb746f567a731f71cd703796d4d53a323f45f/taming/modules/vqvae/quantize.py#L213
|
||||
"""
|
||||
|
||||
# NOTE: due to a bug the beta term was applied to the wrong term. for
|
||||
# backwards compatibility we use the buggy version by default, but you can
|
||||
# specify legacy=False to fix it.
|
||||
def __init__(
|
||||
self,
|
||||
n_e,
|
||||
e_dim,
|
||||
beta,
|
||||
remap=None,
|
||||
unknown_index="random",
|
||||
sane_index_shape=False,
|
||||
legacy=True,
|
||||
):
|
||||
super().__init__()
|
||||
self.n_e = n_e
|
||||
self.e_dim = e_dim
|
||||
self.beta = beta
|
||||
self.legacy = legacy
|
||||
|
||||
self.embedding = nn.Embedding(self.n_e, self.e_dim)
|
||||
self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
|
||||
|
||||
self.remap = remap
|
||||
if self.remap is not None:
|
||||
self.register_buffer("used", torch.tensor(np.load(self.remap)))
|
||||
self.re_embed = self.used.shape[0]
|
||||
self.unknown_index = unknown_index # "random" or "extra" or integer
|
||||
if self.unknown_index == "extra":
|
||||
self.unknown_index = self.re_embed
|
||||
self.re_embed = self.re_embed + 1
|
||||
logger.info(
|
||||
f"Remapping {self.n_e} indices to {self.re_embed} indices. "
|
||||
f"Using {self.unknown_index} for unknown indices."
|
||||
)
|
||||
else:
|
||||
self.re_embed = n_e
|
||||
|
||||
self.sane_index_shape = sane_index_shape
|
||||
|
||||
def remap_to_used(self, inds):
|
||||
ishape = inds.shape
|
||||
assert len(ishape) > 1
|
||||
inds = inds.reshape(ishape[0], -1)
|
||||
used = self.used.to(inds)
|
||||
match = (inds[:, :, None] == used[None, None, ...]).long()
|
||||
new = match.argmax(-1)
|
||||
unknown = match.sum(2) < 1
|
||||
if self.unknown_index == "random":
|
||||
new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(
|
||||
device=new.device
|
||||
)
|
||||
else:
|
||||
new[unknown] = self.unknown_index
|
||||
return new.reshape(ishape)
|
||||
|
||||
def unmap_to_all(self, inds):
|
||||
ishape = inds.shape
|
||||
assert len(ishape) > 1
|
||||
inds = inds.reshape(ishape[0], -1)
|
||||
used = self.used.to(inds)
|
||||
if self.re_embed > self.used.shape[0]: # extra token
|
||||
inds[inds >= self.used.shape[0]] = 0 # simply set to zero
|
||||
back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
|
||||
return back.reshape(ishape)
|
||||
|
||||
def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
|
||||
assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
|
||||
assert rescale_logits == False, "Only for interface compatible with Gumbel"
|
||||
assert return_logits == False, "Only for interface compatible with Gumbel"
|
||||
# reshape z -> (batch, height, width, channel) and flatten
|
||||
z = rearrange(z, "b c h w -> b h w c").contiguous()
|
||||
z_flattened = z.view(-1, self.e_dim)
|
||||
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
|
||||
|
||||
d = (
|
||||
torch.sum(z_flattened**2, dim=1, keepdim=True)
|
||||
+ torch.sum(self.embedding.weight**2, dim=1)
|
||||
- 2
|
||||
* torch.einsum(
|
||||
"bd,dn->bn", z_flattened, rearrange(self.embedding.weight, "n d -> d n")
|
||||
)
|
||||
)
|
||||
|
||||
min_encoding_indices = torch.argmin(d, dim=1)
|
||||
z_q = self.embedding(min_encoding_indices).view(z.shape)
|
||||
perplexity = None
|
||||
min_encodings = None
|
||||
|
||||
# compute loss for embedding
|
||||
if not self.legacy:
|
||||
loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean(
|
||||
(z_q - z.detach()) ** 2
|
||||
)
|
||||
else:
|
||||
loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean(
|
||||
(z_q - z.detach()) ** 2
|
||||
)
|
||||
|
||||
# preserve gradients
|
||||
z_q = z + (z_q - z).detach()
|
||||
|
||||
# reshape back to match original input shape
|
||||
z_q = rearrange(z_q, "b h w c -> b c h w").contiguous()
|
||||
|
||||
if self.remap is not None:
|
||||
min_encoding_indices = min_encoding_indices.reshape(
|
||||
z.shape[0], -1
|
||||
) # add batch axis
|
||||
min_encoding_indices = self.remap_to_used(min_encoding_indices)
|
||||
min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten
|
||||
|
||||
if self.sane_index_shape:
|
||||
min_encoding_indices = min_encoding_indices.reshape(
|
||||
z_q.shape[0], z_q.shape[2], z_q.shape[3]
|
||||
)
|
||||
|
||||
return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
|
||||
|
||||
def get_codebook_entry(self, indices, shape):
|
||||
# shape specifying (batch, height, width, channel)
|
||||
if self.remap is not None:
|
||||
indices = indices.reshape(shape[0], -1) # add batch axis
|
||||
indices = self.unmap_to_all(indices)
|
||||
indices = indices.reshape(-1) # flatten again
|
||||
|
||||
# get quantized latent vectors
|
||||
z_q = self.embedding(indices)
|
||||
|
||||
if shape is not None:
|
||||
z_q = z_q.view(shape)
|
||||
# reshape back to match original input shape
|
||||
z_q = z_q.permute(0, 3, 1, 2).contiguous()
|
||||
|
||||
return z_q
|
||||
|
||||
|
||||
class AutoencoderKL(pl.LightningModule):
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -1,5 +1,6 @@
|
||||
from functools import lru_cache
|
||||
|
||||
import numpy as np
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import (
|
||||
StableDiffusionSafetyChecker,
|
||||
)
|
||||
@ -20,6 +21,6 @@ def is_nsfw(img, x_sample):
|
||||
clip_input = safety_checker_input.pixel_values
|
||||
|
||||
_, has_nsfw_concept = safety_checker(
|
||||
images=x_sample[None, :], clip_input=clip_input
|
||||
images=[np.empty((2, 2))], clip_input=clip_input
|
||||
)
|
||||
return has_nsfw_concept[0]
|
||||
|
BIN
tests/data/distorted_face.png
Normal file
BIN
tests/data/distorted_face.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 286 KiB |
BIN
tests/data/safety.jpg
Normal file
BIN
tests/data/safety.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 8.6 KiB |
13
tests/test_cmds.py
Normal file
13
tests/test_cmds.py
Normal file
@ -0,0 +1,13 @@
|
||||
from click.testing import CliRunner
|
||||
|
||||
from imaginairy.cmds import imagine_cmd
|
||||
from tests import TESTS_FOLDER
|
||||
|
||||
|
||||
def test_imagine_cmd():
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(
|
||||
imagine_cmd,
|
||||
["gold coins", "--steps", "5", "--outdir", f"{TESTS_FOLDER}/test_output"],
|
||||
)
|
||||
assert result.exit_code == 0
|
23
tests/test_enhancers.py
Normal file
23
tests/test_enhancers.py
Normal file
@ -0,0 +1,23 @@
|
||||
import hashlib
|
||||
|
||||
from PIL import Image
|
||||
from pytorch_lightning import seed_everything
|
||||
|
||||
from imaginairy.enhancers.face_restoration_codeformer import enhance_faces
|
||||
from imaginairy.utils import get_device
|
||||
from tests import TESTS_FOLDER
|
||||
|
||||
|
||||
def test_fix_faces():
|
||||
img = Image.open(f"{TESTS_FOLDER}/data/distorted_face.png")
|
||||
seed_everything(1)
|
||||
img = enhance_faces(img)
|
||||
img.save(f"{TESTS_FOLDER}/test_output/fixed_face.png")
|
||||
if "mps" in get_device():
|
||||
assert img_hash(img) == "a75991307eda675a26eeb7073f828e93"
|
||||
else:
|
||||
assert img_hash(img) == "5aa847a1464de75b158658a35800b6bf"
|
||||
|
||||
|
||||
def img_hash(img):
|
||||
return hashlib.md5(img.tobytes()).hexdigest()
|
481
tests/test_guidance.py
Normal file
481
tests/test_guidance.py
Normal file
@ -0,0 +1,481 @@
|
||||
#
|
||||
# import torch
|
||||
#
|
||||
# from imaginairy.utils import get_device
|
||||
#
|
||||
# torch.manual_seed(0)
|
||||
# from transformers import CLIPTextModel, CLIPTokenizer
|
||||
# from diffusers import AutoencoderKL, UNet2DConditionModel, LMSDiscreteScheduler
|
||||
#
|
||||
# from tqdm.auto import tqdm, trange
|
||||
# from torch import autocast
|
||||
# import PIL.Image as PImage
|
||||
# from PIL import Image
|
||||
# import numpy
|
||||
# from torchvision import transforms
|
||||
# import torchvision.transforms.functional as f
|
||||
# import random
|
||||
# import requests
|
||||
# from io import BytesIO
|
||||
#
|
||||
# # import clip
|
||||
# import open_clip as clip
|
||||
# from torch import nn
|
||||
# import torch.nn.functional as F
|
||||
# import io
|
||||
#
|
||||
# offload_device = "cpu"
|
||||
# model_name = "CompVis/stable-diffusion-v1-4"
|
||||
# attention_slicing = True #@param {"type":"boolean"}
|
||||
# unet_path = False
|
||||
#
|
||||
# vae = AutoencoderKL.from_pretrained(model_name, subfolder="vae", use_auth_token=True)
|
||||
#
|
||||
# tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
|
||||
# try:
|
||||
# text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder="text_encoder", use_auth_token=True)
|
||||
# except:
|
||||
# print("Text encoder could not be loaded from the repo specified for some reason, falling back to the vit-l repo")
|
||||
# text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
|
||||
#
|
||||
# if unet_path!=None:
|
||||
# # unet = UNet2DConditionModel.from_pretrained(unet_path)
|
||||
# from huggingface_hub import hf_hub_download
|
||||
# model_name = hf_hub_download(repo_id=unet_path, filename="unet.pt")
|
||||
# unet = torch.jit.load(model_name)
|
||||
# else:
|
||||
# unet = UNet2DConditionModel.from_pretrained(model_name, subfolder="unet", use_auth_token=True)
|
||||
# if attention_slicing:
|
||||
# slice_size = unet.config.attention_head_dim // 2
|
||||
# unet.set_attention_slice(slice_size)
|
||||
#
|
||||
# scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
|
||||
#
|
||||
# vae = vae.to(offload_device).half()
|
||||
# text_encoder = text_encoder.to(offload_device).half()
|
||||
# unet = unet.to(get_device()).half()
|
||||
# class MakeCutouts(nn.Module):
|
||||
# def __init__(self, cut_size, cutn, cut_pow=1.):
|
||||
# super().__init__()
|
||||
# self.cut_size = cut_size
|
||||
# self.cutn = cutn
|
||||
# self.cut_pow = cut_pow
|
||||
#
|
||||
# def forward(self, input):
|
||||
# sideY, sideX = input.shape[2:4]
|
||||
# max_size = min(sideX, sideY)
|
||||
# min_size = min(sideX, sideY, self.cut_size)
|
||||
# cutouts = []
|
||||
# for _ in range(self.cutn):
|
||||
# size = int(torch.rand([]) ** self.cut_pow * (max_size - min_size) + min_size)
|
||||
# offsetx = torch.randint(0, sideX - size + 1, ())
|
||||
# offsety = torch.randint(0, sideY - size + 1, ())
|
||||
# cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
|
||||
# cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size))
|
||||
# return torch.cat(cutouts)
|
||||
#
|
||||
#
|
||||
# to_tensor_tfm = transforms.ToTensor()
|
||||
#
|
||||
#
|
||||
# # mismatch of tons of image encoding / decoding / loading functions i cant be asked to clean up right now
|
||||
#
|
||||
# def pil_to_latent(input_im):
|
||||
# # Single image -> single latent in a batch (so size 1, 4, 64, 64)
|
||||
# with torch.no_grad():
|
||||
# with autocast("cuda"):
|
||||
# latent = vae.encode(to_tensor_tfm(input_im.convert("RGB")).unsqueeze(0).to(
|
||||
# get_device()) * 2 - 1).latent_dist # Note scaling
|
||||
# # print(latent)
|
||||
# return 0.18215 * latent.mode() # or .mean or .sample
|
||||
#
|
||||
#
|
||||
# def latents_to_pil(latents):
|
||||
# # bath of latents -> list of images
|
||||
# latents = (1 / 0.18215) * latents
|
||||
# with torch.no_grad():
|
||||
# image = vae.decode(latents)
|
||||
# image = (image / 2 + 0.5).clamp(0, 1)
|
||||
# image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
|
||||
# images = (image * 255).round().astype("uint8")
|
||||
# pil_images = [Image.fromarray(image) for image in images]
|
||||
# return pil_images
|
||||
#
|
||||
#
|
||||
# def get_latent_from_url(url, size=(512, 512)):
|
||||
# response = requests.get(url)
|
||||
# img = PImage.open(BytesIO(response.content))
|
||||
# img = img.resize(size).convert("RGB")
|
||||
# latent = pil_to_latent(img)
|
||||
# return latent
|
||||
#
|
||||
#
|
||||
# def scale_and_decode(latents):
|
||||
# with autocast("cuda"):
|
||||
# # scale and decode the image latents with vae
|
||||
# latents = 1 / 0.18215 * latents
|
||||
# with torch.no_grad():
|
||||
# image = vae.decode(latents).sample.squeeze(0)
|
||||
# image = f.to_pil_image((image / 2 + 0.5).clamp(0, 1))
|
||||
# return image
|
||||
#
|
||||
#
|
||||
# def fetch(url_or_path):
|
||||
# import io
|
||||
# if str(url_or_path).startswith('http://') or str(url_or_path).startswith('https://'):
|
||||
# r = requests.get(url_or_path)
|
||||
# r.raise_for_status()
|
||||
# fd = io.BytesIO()
|
||||
# fd.write(r.content)
|
||||
# fd.seek(0)
|
||||
# return PImage.open(fd).convert('RGB')
|
||||
# return PImage.open(open(url_or_path, 'rb')).convert('RGB')
|
||||
#
|
||||
#
|
||||
# """
|
||||
# grabs all text up to the first occurrence of ':'
|
||||
# uses the grabbed text as a sub-prompt, and takes the value following ':' as weight
|
||||
# if ':' has no value defined, defaults to 1.0
|
||||
# repeats until no text remaining
|
||||
# """
|
||||
#
|
||||
#
|
||||
# def split_weighted_subprompts(text, split=":"):
|
||||
# remaining = len(text)
|
||||
# prompts = []
|
||||
# weights = []
|
||||
# while remaining > 0:
|
||||
# if split in text:
|
||||
# idx = text.index(split) # first occurrence from start
|
||||
# # grab up to index as sub-prompt
|
||||
# prompt = text[:idx]
|
||||
# remaining -= idx
|
||||
# # remove from main text
|
||||
# text = text[idx + 1:]
|
||||
# # find value for weight
|
||||
# if " " in text:
|
||||
# idx = text.index(" ") # first occurence
|
||||
# else: # no space, read to end
|
||||
# idx = len(text)
|
||||
# if idx != 0:
|
||||
# try:
|
||||
# weight = float(text[:idx])
|
||||
# except: # couldn't treat as float
|
||||
# print(f"Warning: '{text[:idx]}' is not a value, are you missing a space?")
|
||||
# weight = 1.0
|
||||
# else: # no value found
|
||||
# weight = 1.0
|
||||
# # remove from main text
|
||||
# remaining -= idx
|
||||
# text = text[idx + 1:]
|
||||
# # append the sub-prompt and its weight
|
||||
# prompts.append(prompt)
|
||||
# weights.append(weight)
|
||||
# else: # no : found
|
||||
# if len(text) > 0: # there is still text though
|
||||
# # take remainder as weight 1
|
||||
# prompts.append(text)
|
||||
# weights.append(1.0)
|
||||
# remaining = 0
|
||||
# print(prompts, weights)
|
||||
# return prompts, weights
|
||||
#
|
||||
#
|
||||
# # from some stackoverflow comment
|
||||
# import numpy as np
|
||||
#
|
||||
#
|
||||
# def lerp(a, b, x):
|
||||
# "linear interpolation"
|
||||
# return a + x * (b - a)
|
||||
#
|
||||
#
|
||||
# def fade(t):
|
||||
# "6t^5 - 15t^4 + 10t^3"
|
||||
# return 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3
|
||||
#
|
||||
#
|
||||
# def gradient(h, x, y):
|
||||
# "grad converts h to the right gradient vector and return the dot product with (x,y)"
|
||||
# vectors = np.array([[0, 1], [0, -1], [1, 0], [-1, 0]])
|
||||
# g = vectors[h % 4]
|
||||
# return g[:, :, 0] * x + g[:, :, 1] * y
|
||||
#
|
||||
#
|
||||
# def perlin(x, y, seed=0):
|
||||
# # permutation table
|
||||
# np.random.seed(seed)
|
||||
# p = np.arange(256, dtype=int)
|
||||
# np.random.shuffle(p)
|
||||
# p = np.stack([p, p]).flatten()
|
||||
# # coordinates of the top-left
|
||||
# xi, yi = x.astype(int), y.astype(int)
|
||||
# # internal coordinates
|
||||
# xf, yf = x - xi, y - yi
|
||||
# # fade factors
|
||||
# u, v = fade(xf), fade(yf)
|
||||
# # noise components
|
||||
# n00 = gradient(p[p[xi] + yi], xf, yf)
|
||||
# n01 = gradient(p[p[xi] + yi + 1], xf, yf - 1)
|
||||
# n11 = gradient(p[p[xi + 1] + yi + 1], xf - 1, yf - 1)
|
||||
# n10 = gradient(p[p[xi + 1] + yi], xf - 1, yf)
|
||||
# # combine noises
|
||||
# x1 = lerp(n00, n10, u)
|
||||
# x2 = lerp(n01, n11, u) # FIX1: I was using n10 instead of n01
|
||||
# return lerp(x1, x2, v) # FIX2: I also had to reverse x1 and x2 here
|
||||
#
|
||||
#
|
||||
# def sample(args):
|
||||
# global in_channels
|
||||
# global text_encoder # uugghhhghhghgh
|
||||
# global vae # UUGHGHHGHGH
|
||||
# global unet # .hggfkgjks;ldjf
|
||||
# # prompt = args.prompt
|
||||
# prompts, weights = split_weighted_subprompts(args.prompt)
|
||||
# h, w = args.size
|
||||
# steps = args.steps
|
||||
# scale = args.scale
|
||||
# classifier_guidance = args.classifier_guidance
|
||||
# use_init = len(args.init_img) > 1
|
||||
# if args.seed != -1:
|
||||
# seed = args.seed
|
||||
# generator = torch.manual_seed(seed)
|
||||
# else:
|
||||
# seed = random.randint(0, 10_000)
|
||||
# generator = torch.manual_seed(seed)
|
||||
# print(f"Generating with seed {seed}...")
|
||||
#
|
||||
# # tokenize / encode text
|
||||
# tokens = [tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True,
|
||||
# return_tensors="pt") for prompt in prompts]
|
||||
# with torch.no_grad():
|
||||
# # move CLIP to cuda
|
||||
# text_encoder = text_encoder.to(get_device())
|
||||
# text_embeddings = [text_encoder(tok.input_ids.to(get_device()))[0].unsqueeze(0) for tok in tokens]
|
||||
# text_embeddings = [text_embeddings[i] * weights[i] for i in range(len(text_embeddings))]
|
||||
# text_embeddings = torch.cat(text_embeddings, 0).sum(0)
|
||||
# max_length = 77
|
||||
# uncond_input = tokenizer(
|
||||
# [""], padding="max_length", max_length=max_length, return_tensors="pt"
|
||||
# )
|
||||
# uncond_embeddings = text_encoder(uncond_input.input_ids.to(get_device()))[0]
|
||||
# text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
||||
# # move it back to CPU so there's more vram for generating
|
||||
# text_encoder = text_encoder.to(offload_device)
|
||||
# images = []
|
||||
#
|
||||
# if args.lpips_guidance:
|
||||
# import lpips
|
||||
# lpips_model = lpips.LPIPS(net='vgg').to(get_device())
|
||||
# init = to_tensor_tfm(fetch(args.init_img).resize(args.size)).to(get_device())
|
||||
#
|
||||
# for batch_n in trange(args.batches):
|
||||
# with autocast("cuda"):
|
||||
# # unet = unet.to(get_device())
|
||||
# scheduler.set_timesteps(steps)
|
||||
# if not use_init or args.start_step == 0:
|
||||
# latents = torch.randn(
|
||||
# (1, in_channels, h // 8, w // 8),
|
||||
# generator=generator
|
||||
# )
|
||||
# latents = latents.to(get_device())
|
||||
# latents = latents * scheduler.sigmas[0]
|
||||
# start_step = args.start_step
|
||||
# else:
|
||||
# # Start step
|
||||
# start_step = args.start_step - 1
|
||||
# start_sigma = scheduler.sigmas[start_step]
|
||||
# start_timestep = int(scheduler.timesteps[start_step])
|
||||
#
|
||||
# # Prep latents
|
||||
# vae = vae.to(get_device())
|
||||
# encoded = get_latent_from_url(args.init_img)
|
||||
# if not classifier_guidance:
|
||||
# vae = vae.to(offload_device)
|
||||
#
|
||||
# noise = torch.randn_like(encoded)
|
||||
# sigmas = scheduler.match_shape(scheduler.sigmas[start_step], noise)
|
||||
# noisy_samples = encoded + noise * sigmas
|
||||
#
|
||||
# latents = noisy_samples.to(get_device()).half()
|
||||
#
|
||||
# if args.perlin_multi != 0:
|
||||
# linx = np.linspace(0, 5, h // 8, endpoint=False)
|
||||
# liny = np.linspace(0, 5, w // 8, endpoint=False)
|
||||
# x, y = np.meshgrid(liny, linx)
|
||||
# p = [np.expand_dims(perlin(x, y, seed=i), 0) for i in range(4)] # reproducable seed
|
||||
# p = np.concatenate(p, 0)
|
||||
# p = torch.tensor(p).unsqueeze(0).cuda()
|
||||
# latents = latents + (p * args.perlin_multi).to(get_device()).half()
|
||||
#
|
||||
# for i, t in tqdm(enumerate(scheduler.timesteps), total=steps):
|
||||
# if i > start_step:
|
||||
# latent_model_input = torch.cat([latents] * 2)
|
||||
# sigma = scheduler.sigmas[i]
|
||||
# latent_model_input = latent_model_input / ((sigma ** 2 + 1) ** 0.5)
|
||||
#
|
||||
# with torch.no_grad():
|
||||
# # noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]
|
||||
# # noise_pred = unet(latent_model_input, torch.tensor(t, dtype=torch.float32).cuda().half(), text_embeddings)#["sample"]
|
||||
# noise_pred = unet(latent_model_input, torch.tensor(t, dtype=torch.float32).cuda(),
|
||||
# text_embeddings) # ["sample"]
|
||||
#
|
||||
# # cfg
|
||||
# noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
# noise_pred = noise_pred_uncond + scale * (noise_pred_text - noise_pred_uncond)
|
||||
#
|
||||
# # cg
|
||||
# if classifier_guidance:
|
||||
# # vae = vae.to(get_device())
|
||||
# if vae.device != latents.device:
|
||||
# vae = vae.to(latents.device)
|
||||
# latents = latents.detach().requires_grad_()
|
||||
# latents_x0 = latents - sigma * noise_pred
|
||||
# denoised_images = vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5
|
||||
# if args.loss_scale != 0:
|
||||
# loss = args.loss_fn(denoised_images) * args.loss_scale
|
||||
# else:
|
||||
# loss = 0
|
||||
# init_losses = lpips_model(denoised_images, init)
|
||||
# loss = loss + init_losses.sum() * args.lpips_scale
|
||||
#
|
||||
# cond_grad = -torch.autograd.grad(loss, latents)[0]
|
||||
# latents = latents.detach() + cond_grad * sigma ** 2
|
||||
# # vae = vae.to(offload_device)
|
||||
#
|
||||
# latents = scheduler.step(noise_pred, i, latents)["prev_sample"]
|
||||
# vae = vae.to(get_device())
|
||||
# output_image = scale_and_decode(latents)
|
||||
# vae = vae.to(offload_device)
|
||||
# images.append(output_image)
|
||||
#
|
||||
# import gc
|
||||
# gc.collect()
|
||||
# torch.cuda.empty_cache()
|
||||
#
|
||||
# images[-1].save(f"output/{batch_n}.png")
|
||||
#
|
||||
# return images
|
||||
#
|
||||
# def test_guided_image():
|
||||
# prompt = "tardigrade portrait [intricate] [artstation]" # @param {"type":"string"}
|
||||
#
|
||||
# # prompt = add_suffixes(prompt)
|
||||
#
|
||||
# init_img = "" # @param {"type":"string"}
|
||||
# size = [640, 640] # @param
|
||||
# steps = 65 # @param
|
||||
# start_step = 0 # @param
|
||||
# perlin_multi = 0.4 # @param
|
||||
# scale = 7 # @param
|
||||
# seed = -1 # @param
|
||||
# batches = 4 # @param
|
||||
# # @markdown ---
|
||||
#
|
||||
# # @markdown ### Classifier Guidance
|
||||
# # @markdown `classifier_guidance` is whether or not to use the loss function in the previous cell to guide the image (slows down image generation a lot) <br>
|
||||
# # @markdown it also is very hit-and-miss in terms of quality, but can be really really good, try setting batches high and then taking a nap <br>
|
||||
# # @markdown `lpips_guidance` is for if you're using an init_img, it'll let you start closer to the beginning while trying to keep the overall shapes similar
|
||||
# # @markdown `lpips_scale` is similar to `loss_scale` but it's how much to push the model to keep the shapes the same <br>
|
||||
# # @markdown `loss_scale` is how much to guide according to that loss function <br>
|
||||
# # @markdown `clip_text_prompt` is a prompt for CLIP to optimize towards, if using classifier guidance (supports weighting with `prompt:weight`) <br>
|
||||
# # @markdown `clip_image_prompt` is an image url for CLIP to optimize towards if using classifier guidance (supports weighting with `url|weight` because of colons coming up in urls) <br>
|
||||
# # @markdown for `clip_model_name` and `clip_model_pretrained` check out the openclip repository https://github.com/mlfoundations/open_clip <br>
|
||||
# # @markdown `cutn` is the amount of permutations of the image to show to clip (can help with stability) <br>
|
||||
# # @markdown `accumulate` is how many times to run the image through the clip model, can help if you can only fit low cutn on the machine <br>
|
||||
# # @markdown *you cannot use the textual inversion tokens with the clip text prompt* <br>
|
||||
# # @markdown *also clip guidance sucks for most things except removing very small details that dont make sense*
|
||||
# classifier_guidance = True # @param {"type":"boolean"}
|
||||
# lpips_guidance = False # @param {"type":"boolean"}
|
||||
# lpips_scale = 0 # @param
|
||||
# loss_scale = 1. # @param
|
||||
#
|
||||
# class BlankClass():
|
||||
# def __init__(self):
|
||||
# bruh = 'BRUH'
|
||||
#
|
||||
# args = BlankClass()
|
||||
# args.prompt = prompt
|
||||
# args.init_img = init_img
|
||||
# args.size = size
|
||||
# args.steps = steps
|
||||
# args.start_step = start_step
|
||||
# args.scale = scale
|
||||
# args.perlin_multi = perlin_multi
|
||||
# args.seed = seed
|
||||
# args.batches = batches
|
||||
# args.classifier_guidance = classifier_guidance
|
||||
# args.lpips_guidance = lpips_guidance
|
||||
# args.lpips_scale = lpips_scale
|
||||
# args.loss_scale = loss_scale
|
||||
#
|
||||
# loss_scale = 1
|
||||
# # make_cutouts = MakeCutouts(224, 16)
|
||||
#
|
||||
# clip_text_prompt = "tardigrade portrait [intricate] [artstation]" # @param {"type":"string"}
|
||||
# # clip_text_prompt = add_suffixes(clip_text_prompt)
|
||||
# clip_image_prompt = "" # @param {"type":"string"}
|
||||
#
|
||||
# if loss_scale != 0:
|
||||
# # clip_model = clip.load("ViT-B/32", jit=False)[0].eval().requires_grad_(False).to(get_device())
|
||||
# clip_model_name = "ViT-B-32" # @param {"type":"string"}
|
||||
# clip_model_pretrained = "laion2b_s34b_b79k" # @param {"type":"string"}
|
||||
# clip_model, _, preprocess = clip.create_model_and_transforms(clip_model_name, pretrained=clip_model_pretrained)
|
||||
# clip_model = clip_model.eval().requires_grad_(False).to(get_device())
|
||||
#
|
||||
# cutn = 4 # @param
|
||||
# make_cutouts = MakeCutouts(clip_model.visual.image_size if type(clip_model.visual.image_size) != tuple else
|
||||
# clip_model.visual.image_size[0], cutn)
|
||||
#
|
||||
# target = None
|
||||
# if len(clip_text_prompt) > 1:
|
||||
# clip_text_prompt, clip_text_weights = split_weighted_subprompts(clip_text_prompt)
|
||||
# target = clip_model.encode_text(clip.tokenize(clip_text_prompt).to(get_device())) * torch.tensor(
|
||||
# clip_text_weights).view(len(clip_text_prompt), 1).to(get_device())
|
||||
# if len(clip_image_prompt) > 1:
|
||||
# clip_image_prompt, clip_image_weights = split_weighted_subprompts(clip_image_prompt, split="|")
|
||||
# # pesky spaces
|
||||
# clip_image_prompt = [p.replace(" ", "") for p in clip_image_prompt]
|
||||
# images = [fetch(image) for image in clip_image_prompt]
|
||||
# images = [f.to_tensor(i).unsqueeze(0) for i in images]
|
||||
# images = [make_cutouts(i) for i in images]
|
||||
# encodings = [clip_model.encode_image(i.to(get_device())).mean(0) for i in images]
|
||||
#
|
||||
# for i in range(len(encodings)):
|
||||
# encodings[i] = (encodings[i] * clip_image_weights[i]).unsqueeze(0)
|
||||
# # print(encodings.shape)
|
||||
# encodings = torch.cat(encodings, 0)
|
||||
# encoding = encodings.sum(0)
|
||||
#
|
||||
# if target != None:
|
||||
# target = target + encoding
|
||||
# else:
|
||||
# target = encoding
|
||||
# target = target.half().to(get_device())
|
||||
#
|
||||
# # free a little memory, we dont use the text encoder after this so just delete it
|
||||
# clip_model.transformer = None
|
||||
# import gc
|
||||
# gc.collect()
|
||||
# torch.cuda.empty_cache()
|
||||
#
|
||||
# def spherical_distance(x, y):
|
||||
# x = F.normalize(x, dim=-1)
|
||||
# y = F.normalize(y, dim=-1)
|
||||
# l = (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2).mean()
|
||||
# return l
|
||||
#
|
||||
# def loss_fn(x):
|
||||
# with torch.autocast("cuda"):
|
||||
# cutouts = make_cutouts(x)
|
||||
# encoding = clip_model.encode_image(cutouts.float()).half()
|
||||
# loss = spherical_distance(encoding, target)
|
||||
# return loss.mean()
|
||||
#
|
||||
# args.loss_fn = loss_fn
|
||||
#
|
||||
#
|
||||
# dtype = torch.float16
|
||||
# with torch.amp.autocast(device_type=get_device(), dtype=dtype):
|
||||
# output = sample(args)
|
||||
# print("Done!")
|
24
tests/test_safety.py
Normal file
24
tests/test_safety.py
Normal file
@ -0,0 +1,24 @@
|
||||
from PIL import Image
|
||||
|
||||
from imaginairy.api import load_model
|
||||
from imaginairy.safety import is_nsfw
|
||||
from imaginairy.utils import get_device, pillow_img_to_torch_image
|
||||
from tests import TESTS_FOLDER
|
||||
|
||||
|
||||
def test_is_nsfw():
|
||||
img = Image.open(f"{TESTS_FOLDER}/data/safety.jpg")
|
||||
latent = _pil_to_latent(img)
|
||||
assert is_nsfw(img, latent)
|
||||
|
||||
img = Image.open(f"{TESTS_FOLDER}/data/girl_with_a_pearl_earring.jpg")
|
||||
latent = _pil_to_latent(img)
|
||||
assert not is_nsfw(img, latent)
|
||||
|
||||
|
||||
def _pil_to_latent(img):
|
||||
model = load_model()
|
||||
img, w, h = pillow_img_to_torch_image(img)
|
||||
img = img.to(get_device())
|
||||
latent = model.get_first_stage_encoding(model.encode_first_stage(img))
|
||||
return latent
|
Loading…
Reference in New Issue
Block a user