refactor: implements changes to comply with pylama

Jay Drennan 2022-09-22 11:56:18 -06:00
parent 8e844f2eae
commit df28bf8805
10 changed files with 33 additions and 30 deletions

View File

@@ -280,7 +280,6 @@ def imagine(
unconditional_guidance_scale=prompt.prompt_strength,
unconditional_conditioning=uc,
eta=ddim_eta,
-initial_noise_tensor=start_code,
img_callback=_img_callback,
)

View File

@@ -30,7 +30,7 @@ def find_img_text_similarity(image: Image.Image, phrases: Sequence):
def find_embed_text_similarity(embed_features, phrases):
-model, preprocess = get_model()
+model, _ = get_model()
text = clip.tokenize(phrases).to(device)
with torch.no_grad():

View File

@@ -11,7 +11,7 @@ from functools import partial
import numpy as np
import pytorch_lightning as pl
import torch
-import torch.nn as nn
+from torch import nn
from einops import rearrange
from torchvision.utils import make_grid
from tqdm import tqdm
@@ -28,9 +28,11 @@ logger = logging.getLogger(__name__)
__conditioning_keys__ = {"concat": "c_concat", "crossattn": "c_crossattn", "adm": "y"}
-def disabled_train(self, mode=True):
-"""Overwrite model.train with this function to make sure train/eval mode
-does not change anymore."""
+def disabled_train(self):
+"""
+Overwrite model.train with this function to make sure train/eval mode
+does not change anymore.
+"""
return self
@@ -265,6 +267,7 @@ class LatentDiffusion(DDPM):
self.instantiate_first_stage(first_stage_config)
self.instantiate_cond_stage(cond_stage_config)
self.cond_stage_forward = cond_stage_forward
+self.cond_ids = None
self.clip_denoised = False
self.bbox_tokenizer = None
@@ -346,13 +349,13 @@ class LatentDiffusion(DDPM):
self.cond_stage_model = model
def _get_denoise_row_from_list(
-self, samples, desc="", force_no_decoder_quantization=False
+self, samples, desc=""
):
denoise_row = []
for zd in tqdm(samples, desc=desc):
denoise_row.append(
self.decode_first_stage(
-zd.to(self.device), force_not_quantize=force_no_decoder_quantization
+zd.to(self.device)
)
)
n_imgs_per_row = len(denoise_row)
@@ -547,7 +550,7 @@ class LatentDiffusion(DDPM):
else:
xc = x
if not self.cond_stage_trainable or force_c_encode:
-if isinstance(xc, dict) or isinstance(xc, list):
+if isinstance(xc, (dict, list)):
# import pudb; pudb.set_trace()
c = self.get_learned_conditioning(xc)
else:
@@ -577,7 +580,7 @@ class LatentDiffusion(DDPM):
return out
@torch.no_grad()
-def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
+def decode_first_stage(self, z, predict_cids=False):
if predict_cids:
if z.dim() == 4:
z = torch.argmax(z.exp(), dim=1).long()

View File

@@ -262,9 +262,9 @@ def conv_nd(dims, *args, **kwargs):
"""
if dims == 1:
return nn.Conv1d(*args, **kwargs)
-elif dims == 2:
+if dims == 2:
return nn.Conv2d(*args, **kwargs)
-elif dims == 3:
+if dims == 3:
return nn.Conv3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")

View File

@@ -13,9 +13,9 @@ from einops import repeat
from torch import autocast
from imaginairy.utils import get_device, pillow_img_to_torch_image
from imaginairy.vendored import k_diffusion as K
-def pil_img_to_latent(model, img, batch_size=1, device="cuda", half=True):
+def pil_img_to_latent(model, img, batch_size=1, half=True):
# init_image = pil_img_to_torch(img, half=half).to(device)
init_image = pillow_img_to_torch_image(img).to(get_device())
init_image = repeat(init_image, "1 ... -> b ...", b=batch_size)
@@ -40,11 +40,7 @@ def find_noise_for_image(model, pil_img, prompt, steps=50, cond_scale=1.0, half=
)
-def find_noise_for_latent(
-model, img_latent, prompt, steps=50, cond_scale=1.0, half=True
-):
-from imaginairy.vendored import k_diffusion as K
+def find_noise_for_latent(model, img_latent, prompt, steps=50, cond_scale=1.0):
x = img_latent
_autocast = autocast if get_device() in ("cuda", "cpu") else nullcontext

View File

@@ -15,7 +15,7 @@ def safety_models():
return safety_feature_extractor, safety_checker
-def is_nsfw(img, x_sample):
+def is_nsfw(img):
safety_feature_extractor, safety_checker = safety_models()
safety_checker_input = safety_feature_extractor([img], return_tensors="pt")
clip_input = safety_checker_input.pixel_values

View File

@@ -3,6 +3,7 @@ from torch import nn
+from imaginairy.utils import get_device
SAMPLER_TYPE_OPTIONS = [
"plms",
"ddim",
@@ -32,11 +33,12 @@ def get_sampler(sampler_type, model):
sampler_type = sampler_type.lower()
if sampler_type == "plms":
return PLMSSampler(model)
-elif sampler_type == "ddim":
+if sampler_type == "ddim":
return DDIMSampler(model)
-elif sampler_type.startswith("k_"):
+if sampler_type.startswith("k_"):
sampler_type = _k_sampler_type_lookup[sampler_type]
return KDiffusionSampler(model, sampler_type)
+raise ValueError("invalid sampler_type")
class CFGDenoiser(nn.Module):

View File

@@ -127,8 +127,8 @@ class DDIMSampler:
x_T=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
-# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
**kwargs,
+# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
):
if isinstance(conditioning, dict):
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
@@ -376,7 +376,7 @@ class DDIMSampler:
# x_dec = x_dec.detach() + (original_loss * 0.1) ** 2
# cond_grad = -torch.autograd.grad(original_loss, x_dec)[0]
# x_dec = x_dec.detach() + cond_grad * sigma_t ** 2
-## x_dec_alt = x_dec + (original_loss * 0.1) ** 2
+# x_dec_alt = x_dec + (original_loss * 0.1) ** 2
log_latent(x_dec, f"x_dec {i}")
log_latent(pred_x0, f"pred_x0 {i}")

View File

@@ -20,13 +20,14 @@ logger = logging.getLogger(__name__)
class PLMSSampler:
"""probabilistic least-mean-squares"""
-def __init__(self, model, **kwargs):
+def __init__(self, model):
self.model = model
self.ddpm_num_timesteps = model.num_timesteps
self.device_available = get_device()
+self.ddim_timesteps = None
def register_buffer(self, name, attr):
-if type(attr) == torch.Tensor:
+if isinstance(attr, torch.Tensor):
if attr.device != torch.device(self.device_available):
attr = attr.to(torch.float32).to(torch.device(self.device_available))
setattr(self, name, attr)
@@ -43,7 +44,9 @@ class PLMSSampler:
assert (
alphas_cumprod.shape[0] == self.ddpm_num_timesteps
), "alphas have to be defined for each timestep"
-to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
+def to_torch(x):
+return x.clone().detach().to(torch.float32).to(self.model.device)
self.register_buffer("betas", to_torch(self.model.betas))
self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
@@ -431,7 +434,7 @@ class PLMSSampler:
# x_dec = x_dec.detach() + (original_loss * 0.1) ** 2
# cond_grad = -torch.autograd.grad(original_loss, x_dec)[0]
# x_dec = x_dec.detach() + cond_grad * sigma_t ** 2
-## x_dec_alt = x_dec + (original_loss * 0.1) ** 2
+# x_dec_alt = x_dec + (original_loss * 0.1) ** 2
old_eps.append(e_t)
if len(old_eps) >= 4:

View File

@@ -9,11 +9,11 @@ from tests import TESTS_FOLDER
def test_is_nsfw():
img = Image.open(f"{TESTS_FOLDER}/data/safety.jpg")
latent = _pil_to_latent(img)
-assert is_nsfw(img, latent)
+assert is_nsfw(img)
img = Image.open(f"{TESTS_FOLDER}/data/girl_with_a_pearl_earring.jpg")
latent = _pil_to_latent(img)
-assert not is_nsfw(img, latent)
+assert not is_nsfw(img)
def _pil_to_latent(img):