mirror of
https://github.com/brycedrennan/imaginAIry
synced 2024-11-05 12:00:15 +00:00
Merge pull request #22 from brycedrennan/pylama_lint_updates
refactor: implements changes to comply with pylama
This commit is contained in:
commit
eb54b9ca7f
@ -280,7 +280,6 @@ def imagine(
|
||||
unconditional_guidance_scale=prompt.prompt_strength,
|
||||
unconditional_conditioning=uc,
|
||||
eta=ddim_eta,
|
||||
initial_noise_tensor=start_code,
|
||||
img_callback=_img_callback,
|
||||
)
|
||||
|
||||
|
@ -30,7 +30,7 @@ def find_img_text_similarity(image: Image.Image, phrases: Sequence):
|
||||
|
||||
|
||||
def find_embed_text_similarity(embed_features, phrases):
|
||||
model, preprocess = get_model()
|
||||
model, _ = get_model()
|
||||
text = clip.tokenize(phrases).to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
|
@ -11,8 +11,8 @@ from functools import partial
|
||||
import numpy as np
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from einops import rearrange
|
||||
from torch import nn
|
||||
from torchvision.utils import make_grid
|
||||
from tqdm import tqdm
|
||||
|
||||
@ -28,9 +28,11 @@ logger = logging.getLogger(__name__)
|
||||
__conditioning_keys__ = {"concat": "c_concat", "crossattn": "c_crossattn", "adm": "y"}
|
||||
|
||||
|
||||
def disabled_train(self, mode=True):
|
||||
"""Overwrite model.train with this function to make sure train/eval mode
|
||||
does not change anymore."""
|
||||
def disabled_train(self):
|
||||
"""
|
||||
Overwrite model.train with this function to make sure train/eval mode
|
||||
does not change anymore.
|
||||
"""
|
||||
return self
|
||||
|
||||
|
||||
@ -265,6 +267,7 @@ class LatentDiffusion(DDPM):
|
||||
self.instantiate_first_stage(first_stage_config)
|
||||
self.instantiate_cond_stage(cond_stage_config)
|
||||
self.cond_stage_forward = cond_stage_forward
|
||||
self.cond_ids = None
|
||||
self.clip_denoised = False
|
||||
self.bbox_tokenizer = None
|
||||
|
||||
@ -345,16 +348,10 @@ class LatentDiffusion(DDPM):
|
||||
model = instantiate_from_config(config)
|
||||
self.cond_stage_model = model
|
||||
|
||||
def _get_denoise_row_from_list(
|
||||
self, samples, desc="", force_no_decoder_quantization=False
|
||||
):
|
||||
def _get_denoise_row_from_list(self, samples, desc=""):
|
||||
denoise_row = []
|
||||
for zd in tqdm(samples, desc=desc):
|
||||
denoise_row.append(
|
||||
self.decode_first_stage(
|
||||
zd.to(self.device), force_not_quantize=force_no_decoder_quantization
|
||||
)
|
||||
)
|
||||
denoise_row.append(self.decode_first_stage(zd.to(self.device)))
|
||||
n_imgs_per_row = len(denoise_row)
|
||||
denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W
|
||||
denoise_grid = rearrange(denoise_row, "n b c h w -> b n c h w")
|
||||
@ -547,7 +544,7 @@ class LatentDiffusion(DDPM):
|
||||
else:
|
||||
xc = x
|
||||
if not self.cond_stage_trainable or force_c_encode:
|
||||
if isinstance(xc, dict) or isinstance(xc, list):
|
||||
if isinstance(xc, (dict, list)):
|
||||
# import pudb; pudb.set_trace()
|
||||
c = self.get_learned_conditioning(xc)
|
||||
else:
|
||||
@ -577,7 +574,7 @@ class LatentDiffusion(DDPM):
|
||||
return out
|
||||
|
||||
@torch.no_grad()
|
||||
def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
|
||||
def decode_first_stage(self, z, predict_cids=False):
|
||||
if predict_cids:
|
||||
if z.dim() == 4:
|
||||
z = torch.argmax(z.exp(), dim=1).long()
|
||||
|
@ -262,9 +262,9 @@ def conv_nd(dims, *args, **kwargs):
|
||||
"""
|
||||
if dims == 1:
|
||||
return nn.Conv1d(*args, **kwargs)
|
||||
elif dims == 2:
|
||||
if dims == 2:
|
||||
return nn.Conv2d(*args, **kwargs)
|
||||
elif dims == 3:
|
||||
if dims == 3:
|
||||
return nn.Conv3d(*args, **kwargs)
|
||||
raise ValueError(f"unsupported dimensions: {dims}")
|
||||
|
||||
|
@ -13,9 +13,10 @@ from einops import repeat
|
||||
from torch import autocast
|
||||
|
||||
from imaginairy.utils import get_device, pillow_img_to_torch_image
|
||||
from imaginairy.vendored import k_diffusion as K
|
||||
|
||||
|
||||
def pil_img_to_latent(model, img, batch_size=1, device="cuda", half=True):
|
||||
def pil_img_to_latent(model, img, batch_size=1, half=True):
|
||||
# init_image = pil_img_to_torch(img, half=half).to(device)
|
||||
init_image = pillow_img_to_torch_image(img).to(get_device())
|
||||
init_image = repeat(init_image, "1 ... -> b ...", b=batch_size)
|
||||
@ -40,11 +41,7 @@ def find_noise_for_image(model, pil_img, prompt, steps=50, cond_scale=1.0, half=
|
||||
)
|
||||
|
||||
|
||||
def find_noise_for_latent(
|
||||
model, img_latent, prompt, steps=50, cond_scale=1.0, half=True
|
||||
):
|
||||
from imaginairy.vendored import k_diffusion as K
|
||||
|
||||
def find_noise_for_latent(model, img_latent, prompt, steps=50, cond_scale=1.0):
|
||||
x = img_latent
|
||||
|
||||
_autocast = autocast if get_device() in ("cuda", "cpu") else nullcontext
|
||||
|
@ -15,7 +15,7 @@ def safety_models():
|
||||
return safety_feature_extractor, safety_checker
|
||||
|
||||
|
||||
def is_nsfw(img, x_sample):
|
||||
def is_nsfw(img):
|
||||
safety_feature_extractor, safety_checker = safety_models()
|
||||
safety_checker_input = safety_feature_extractor([img], return_tensors="pt")
|
||||
clip_input = safety_checker_input.pixel_values
|
||||
|
@ -32,11 +32,12 @@ def get_sampler(sampler_type, model):
|
||||
sampler_type = sampler_type.lower()
|
||||
if sampler_type == "plms":
|
||||
return PLMSSampler(model)
|
||||
elif sampler_type == "ddim":
|
||||
if sampler_type == "ddim":
|
||||
return DDIMSampler(model)
|
||||
elif sampler_type.startswith("k_"):
|
||||
if sampler_type.startswith("k_"):
|
||||
sampler_type = _k_sampler_type_lookup[sampler_type]
|
||||
return KDiffusionSampler(model, sampler_type)
|
||||
raise ValueError("invalid sampler_type")
|
||||
|
||||
|
||||
class CFGDenoiser(nn.Module):
|
||||
|
@ -127,8 +127,8 @@ class DDIMSampler:
|
||||
x_T=None,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None,
|
||||
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
||||
**kwargs,
|
||||
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
||||
):
|
||||
if isinstance(conditioning, dict):
|
||||
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
|
||||
@ -376,7 +376,7 @@ class DDIMSampler:
|
||||
# x_dec = x_dec.detach() + (original_loss * 0.1) ** 2
|
||||
# cond_grad = -torch.autograd.grad(original_loss, x_dec)[0]
|
||||
# x_dec = x_dec.detach() + cond_grad * sigma_t ** 2
|
||||
## x_dec_alt = x_dec + (original_loss * 0.1) ** 2
|
||||
# x_dec_alt = x_dec + (original_loss * 0.1) ** 2
|
||||
|
||||
log_latent(x_dec, f"x_dec {i}")
|
||||
log_latent(pred_x0, f"pred_x0 {i}")
|
||||
|
@ -20,13 +20,14 @@ logger = logging.getLogger(__name__)
|
||||
class PLMSSampler:
|
||||
"""probabilistic least-mean-squares"""
|
||||
|
||||
def __init__(self, model, **kwargs):
|
||||
def __init__(self, model):
|
||||
self.model = model
|
||||
self.ddpm_num_timesteps = model.num_timesteps
|
||||
self.device_available = get_device()
|
||||
self.ddim_timesteps = None
|
||||
|
||||
def register_buffer(self, name, attr):
|
||||
if type(attr) == torch.Tensor:
|
||||
if isinstance(attr, torch.Tensor):
|
||||
if attr.device != torch.device(self.device_available):
|
||||
attr = attr.to(torch.float32).to(torch.device(self.device_available))
|
||||
setattr(self, name, attr)
|
||||
@ -43,7 +44,9 @@ class PLMSSampler:
|
||||
assert (
|
||||
alphas_cumprod.shape[0] == self.ddpm_num_timesteps
|
||||
), "alphas have to be defined for each timestep"
|
||||
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
|
||||
|
||||
def to_torch(x):
|
||||
return x.clone().detach().to(torch.float32).to(self.model.device)
|
||||
|
||||
self.register_buffer("betas", to_torch(self.model.betas))
|
||||
self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
|
||||
@ -431,7 +434,7 @@ class PLMSSampler:
|
||||
# x_dec = x_dec.detach() + (original_loss * 0.1) ** 2
|
||||
# cond_grad = -torch.autograd.grad(original_loss, x_dec)[0]
|
||||
# x_dec = x_dec.detach() + cond_grad * sigma_t ** 2
|
||||
## x_dec_alt = x_dec + (original_loss * 0.1) ** 2
|
||||
# x_dec_alt = x_dec + (original_loss * 0.1) ** 2
|
||||
|
||||
old_eps.append(e_t)
|
||||
if len(old_eps) >= 4:
|
||||
|
@ -9,11 +9,11 @@ from tests import TESTS_FOLDER
|
||||
def test_is_nsfw():
|
||||
img = Image.open(f"{TESTS_FOLDER}/data/safety.jpg")
|
||||
latent = _pil_to_latent(img)
|
||||
assert is_nsfw(img, latent)
|
||||
assert is_nsfw(img)
|
||||
|
||||
img = Image.open(f"{TESTS_FOLDER}/data/girl_with_a_pearl_earring.jpg")
|
||||
latent = _pil_to_latent(img)
|
||||
assert not is_nsfw(img, latent)
|
||||
assert not is_nsfw(img)
|
||||
|
||||
|
||||
def _pil_to_latent(img):
|
||||
|
Loading…
Reference in New Issue
Block a user