mirror of
https://github.com/brycedrennan/imaginAIry
synced 2024-11-17 09:25:47 +00:00
311 lines
11 KiB
Python
311 lines
11 KiB
Python
from functools import lru_cache
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from torch import nn
|
|
|
|
from imaginairy.log_utils import log_latent
|
|
from imaginairy.model_manager import hf_hub_download
|
|
from imaginairy.utils import get_device, platform_appropriate_autocast
|
|
from imaginairy.vendored import k_diffusion as K
|
|
from imaginairy.vendored.k_diffusion import layers
|
|
from imaginairy.vendored.k_diffusion.models.image_v1 import ImageDenoiserModelV1
|
|
from imaginairy.vendored.k_diffusion.utils import append_dims
|
|
|
|
|
|
class NoiseLevelAndTextConditionedUpscaler(nn.Module):
    """Adapter that feeds low-res, noise-level and text conditioning to a denoiser.

    The wrapped ``inner_model`` is called with the upsampled low-resolution
    input as ``unet_cond``, a concatenation of a Fourier embedding of the
    low-res noise level and the pooled text vector as ``mapping_cond``, and
    the text states plus their padding as cross-attention conditioning.
    """

    def __init__(self, inner_model, sigma_data=1.0, embed_dim=256):
        """Store the wrapped model and build the noise-level embedding.

        Args:
            inner_model: The module that performs the actual denoising.
            sigma_data: Value combined with ``low_res_sigma`` when scaling
                the low-res input.
            embed_dim: Output width requested from ``FourierFeatures``.
        """
        super().__init__()
        self.inner_model = inner_model
        self.sigma_data = sigma_data
        self.low_res_noise_embed = K.layers.FourierFeatures(1, embed_dim, std=2)

    def forward(self, inp, sigma, low_res, low_res_sigma, c, **kwargs):
        """Run the wrapped model with low-res and text conditioning attached.

        Args:
            inp: Forwarded unchanged as the first positional argument.
            sigma: Forwarded unchanged as the second positional argument.
            low_res: Low-resolution input; upsampled 2x with nearest-neighbour
                interpolation before being passed as ``unet_cond``.
            low_res_sigma: Noise level of ``low_res``. It is indexed with
                ``[:, None]``, so it must be at least 1-D
                (presumably one value per batch item -- TODO confirm).
            c: A 3-item sequence unpacked as
                ``(cross_cond, cross_cond_padding, pooler)``.
            **kwargs: Passed through to ``inner_model``.

        Returns:
            Whatever ``inner_model`` returns.
        """
        text_states, text_padding, pooled_text = c

        # 'MPS does not support power op with int64 input' -- force float32
        # on any tensor operand before it is squared below.
        data_std = self.sigma_data
        if isinstance(low_res_sigma, torch.Tensor):
            low_res_sigma = low_res_sigma.to(torch.float32)
        if isinstance(data_std, torch.Tensor):
            data_std = data_std.to(torch.float32)

        # Scale factor for the low-res input, broadcast to its rank.
        input_scale = 1 / (low_res_sigma**2 + data_std**2) ** 0.5
        input_scale = append_dims(input_scale, low_res.ndim)

        # Embed log(1 + sigma) of the low-res noise level.
        noise_features = self.low_res_noise_embed(low_res_sigma.log1p()[:, None])

        upsampled = F.interpolate(low_res, scale_factor=2, mode="nearest")
        scaled_low_res = upsampled * input_scale

        return self.inner_model(
            inp,
            sigma,
            unet_cond=scaled_low_res,
            mapping_cond=torch.cat([noise_features, pooled_text], dim=1),
            cross_cond=text_states,
            cross_cond_padding=text_padding,
            **kwargs,
        )
|
|
|
|
|
|
@lru_cache(maxsize=1)
def get_upscaler_model(
    model_path,
    pooler_dim=768,
    train=False,
    device=None,
):
    """Build the latent upscaler, load its checkpoint and move it to ``device``.

    The most recent result is memoized by ``lru_cache(maxsize=1)``, keyed on
    the exact arguments passed.

    Args:
        model_path: Path of the checkpoint file handed to ``torch.load``.
        pooler_dim: Width of the pooled text vector; the noise-level embedding
            gets ``mapping_cond_dim - pooler_dim`` features.
        train: When false, the model is put in eval mode with gradients
            disabled.
        device: Target device. ``None`` (the default) resolves to
            ``get_device()`` when the function is called.

    Returns:
        A ``layers.DenoiserWithVariance`` wrapping the loaded model, on
        ``device``.
    """
    # Resolve the device lazily: a `device=get_device()` default would be
    # evaluated once at import time, probing hardware on import and freezing
    # the choice for the life of the process.
    if device is None:
        device = get_device()

    config = {
        "type": "image_v1",
        "input_channels": 4,
        "input_size": [48, 48],
        "patch_size": 1,
        "mapping_out": 768,
        "mapping_cond_dim": 896,
        "unet_cond_dim": 4,
        "depths": [4, 4, 4, 4],
        "channels": [384, 384, 768, 768],
        "self_attn_depths": [False, False, False, True],
        "cross_attn_depths": [False, True, True, True],
        "cross_cond_dim": 768,
        "has_variance": True,
        "dropout_rate": 0.0,
        "augment_prob": 0.0,
        "augment_wrapper": False,
        "sigma_data": 1.0,
        "sigma_min": 1e-2,
        "sigma_max": 20,
        "sigma_sample_density": {"type": "lognormal", "mean": -0.5, "std": 1.2},
        "skip_stages": 0,
    }

    model = ImageDenoiserModelV1(
        config["input_channels"],
        config["mapping_out"],
        config["depths"],
        config["channels"],
        config["self_attn_depths"],
        config["cross_attn_depths"],
        patch_size=config["patch_size"],
        dropout_rate=config["dropout_rate"],
        mapping_cond_dim=config["mapping_cond_dim"]
        + (9 if config["augment_wrapper"] else 0),
        unet_cond_dim=config["unet_cond_dim"],
        cross_cond_dim=config["cross_cond_dim"],
        skip_stages=config["skip_stages"],
        has_variance=config["has_variance"],
    )

    model = NoiseLevelAndTextConditionedUpscaler(
        model,
        sigma_data=config["sigma_data"],
        embed_dim=config["mapping_cond_dim"] - pooler_dim,
    )

    # SECURITY NOTE(review): torch.load unpickles the file, which can execute
    # arbitrary code. Only load checkpoints from a trusted source; consider
    # `weights_only=True` once the minimum supported torch version allows it.
    ckpt = torch.load(model_path, map_location="cpu")
    # Weights are read from the checkpoint's "model_ema" entry.
    model.load_state_dict(ckpt["model_ema"])

    model = layers.DenoiserWithVariance(model, sigma_data=config["sigma_data"])
    if not train:
        model = model.eval().requires_grad_(False)
    return model.to(device)
|
|
|
|
|
|
class CFGUpscaler(nn.Module):
    """Classifier-free-guidance wrapper around an upscaler model.

    Blends the model's output for the unconditional conditioning ``uc`` with
    its output for the supplied conditioning ``c`` according to ``cond_scale``.
    """

    def __init__(self, model, uc, cond_scale, device):
        """Keep the wrapped model, unconditional conditioning, scale and device."""
        super().__init__()
        self.inner_model = model
        self.uc = uc
        self.cond_scale = cond_scale
        self.device = device

    def forward(self, x, sigma, low_res, low_res_sigma, c):
        """Return ``uncond + (cond - uncond) * cond_scale``.

        When ``cond_scale`` is exactly 0.0 or 1.0 only one of the two
        predictions is needed, so the model is run once (with ``uc`` or ``c``
        respectively) instead of on a doubled batch.
        """
        scale = self.cond_scale

        if scale in (0.0, 1.0):
            # Single pass: pure unconditional (0.0) or pure conditional (1.0).
            conditioning = self.uc if scale == 0.0 else c
            return self.inner_model(
                x,
                sigma,
                low_res=low_res,
                low_res_sigma=low_res_sigma,
                c=conditioning,
            )

        # Stack every input twice so one call yields both predictions; the
        # first half of the batch is paired with `uc`, the second with `c`.
        paired_cond = [
            torch.cat([uncond_part, cond_part])
            for uncond_part, cond_part in zip(self.uc, c)
        ]
        stacked_out = self.inner_model(
            torch.cat((x, x)),
            torch.cat((sigma, sigma)),
            low_res=torch.cat((low_res, low_res)),
            low_res_sigma=torch.cat((low_res_sigma, low_res_sigma)),
            c=paired_cond,
        )
        uncond, cond = stacked_out.chunk(2)
        return uncond + (cond - uncond) * scale
|
|
|
|
|
|
class CLIPTokenizerTransform:
    """Callable that tokenizes text with a Hugging Face ``CLIPTokenizer``."""

    def __init__(self, version="openai/clip-vit-large-patch14", max_length=77):
        """Load the tokenizer for ``version`` and remember ``max_length``."""
        # Imported lazily so `transformers` is only needed when this is used.
        from transformers import CLIPTokenizer

        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.max_length = max_length

    def __call__(self, text):
        """Tokenize ``text`` and return ``(input_ids, attention_mask)``.

        The returned mask is ``1 - attention_mask`` of the tokenizer output,
        i.e. it is inverted relative to what the tokenizer produces. For a
        single ``str`` the first row of each tensor is returned; for any other
        input the tensors are returned whole.
        """
        encoded = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_length=True,
            return_overflowing_tokens=False,
            return_tensors="pt",
        )

        # A lone string yields a batch of one; pick row 0. Otherwise index
        # with Ellipsis, which selects everything.
        if isinstance(text, str):
            rows = 0
        else:
            rows = ...

        token_ids = encoded["input_ids"][rows]
        inverted_mask = 1 - encoded["attention_mask"][rows]
        return token_ids, inverted_mask
|
|
|
|
|
|
class CLIPEmbedder(nn.Module):
    """Text encoder built on the Hugging Face ``CLIPTextModel``."""

    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda"):
        """Load the frozen CLIP text model for ``version`` onto ``device``."""
        super().__init__()
        from transformers import CLIPTextModel, logging

        # Silence transformers' non-error log output before loading weights.
        logging.set_verbosity_error()

        text_model = CLIPTextModel.from_pretrained(version)
        text_model = text_model.eval().requires_grad_(False)
        self.transformer = text_model.to(device)

    @property
    def device(self):
        """Device reported by the underlying transformer."""
        return self.transformer.device

    def forward(self, tok_out):
        """Encode tokenizer output.

        Args:
            tok_out: A 2-item sequence ``(input_ids, cross_cond_padding)``.

        Returns:
            A 3-tuple of the transformer's last hidden state, the padding
            tensor moved to this module's device, and the transformer's
            pooled output.
        """
        token_ids, padding = tok_out
        target = self.device

        encoded = self.transformer(
            input_ids=token_ids.to(target),
            output_hidden_states=True,
        )
        last_hidden = encoded.hidden_states[-1]
        return last_hidden, padding.to(target), encoded.pooler_output
|
|
|
|
|
|
@lru_cache()
def clip_up_models():
    """Create the CLIP text encoder and tokenizer used by the upscaler.

    The result is memoized, so both objects are constructed only on the first
    call. Returns a ``(text_encoder, tokenizer)`` tuple.
    """
    with platform_appropriate_autocast():
        tokenizer = CLIPTokenizerTransform()
        text_encoder = CLIPEmbedder(device=get_device())
    return text_encoder, tokenizer
|
|
|
|
|
|
@torch.no_grad()
def condition_up(prompts):
    """Tokenize ``prompts`` and encode them with the cached CLIP models.

    Returns whatever the text encoder returns for the tokenized prompts.
    Gradients are disabled for the whole call.
    """
    text_encoder, tokenizer = clip_up_models()
    tokens = tokenizer(prompts)
    return text_encoder(tokens)
|
|
|
|
|
|
@torch.no_grad()
def upscale_latent(
    low_res_latent,
    upscale_prompt="",
    seed=0,
    steps=15,
    guidance_scale=1.0,
    batch_size=1,
    num_samples=1,
    # Amount of noise to add per step (0.0=deterministic). Used in all samplers except `k_euler`.
    eta=0.1,
    device=None,
):
    """Upscale a latent with the text-conditioned k-diffusion latent upscaler.

    Args:
        low_res_latent: 4-D latent tensor; its last three dimensions are read
            as ``C, H, W``.
        upscale_prompt: Text prompt used as the conditioning.
        seed: Accepted for interface compatibility but currently unused; the
            seeding call is disabled.
        steps: Number of sampling steps (``steps + 1`` sigmas are generated).
        guidance_scale: Classifier-free-guidance scale given to ``CFGUpscaler``.
        batch_size: Batch size of the sampled noise and of the conditioning.
        num_samples: Total samples requested; sampling is repeated
            ``ceil(num_samples / batch_size)`` times.
        eta: Noise added per step by the ancestral/DPM samplers.
        device: Device to run on. ``None`` (the default) resolves to
            ``get_device()`` at call time.

    Returns:
        The output of the final sampling pass, started from noise of shape
        ``[batch_size, C, 2 * H, 2 * W]``.
        NOTE(review): when ``num_samples > batch_size`` the results of earlier
        passes are overwritten and only the last one is returned.

    Raises:
        ValueError: If ``num_samples`` is less than 1, or if the internally
            configured sampler or noise-augmentation type is not recognised.
    """
    # Without at least one pass there would be nothing to return.
    if num_samples < 1:
        raise ValueError(f"num_samples must be at least 1, got {num_samples}")

    # Resolve lazily; a `device=get_device()` default would be evaluated once
    # at import time instead of when the caller actually asks for an upscale.
    if device is None:
        device = get_device()

    # Add noise to the latent vectors before upscaling. This theoretically can
    # make the model work better on out-of-distribution inputs, but mostly just
    # seems to make it match the input less, so it's turned off by default.
    noise_aug_level = 0  # valid range 0.0 - 0.6
    noise_aug_type = "fake"  # one of "gaussian", "fake"

    # Sampler settings. `k_dpm_adaptive` uses an adaptive solver with error
    # tolerance `tol_scale`, all others use a fixed number of steps.
    # One of: "k_euler", "k_euler_ancestral", "k_dpm_2_ancestral",
    # "k_dpm_fast", "k_dpm_adaptive".
    sampler = "k_dpm_2_ancestral"
    tol_scale = 0.25

    # Unconditional (empty prompt) and conditional text embeddings.
    uc = condition_up(batch_size * [""])
    c = condition_up(batch_size * [upscale_prompt])

    [_, C, H, W] = low_res_latent.shape

    # Noise levels from stable diffusion.
    sigma_min, sigma_max = 0.029167532920837402, 14.614642143249512
    model_up = get_upscaler_model(
        model_path=hf_hub_download(
            "pcuenq/k-upscaler", "laion_text_cond_latent_upscaler_2_1_00470000_slim.pth"
        ),
        device=device,
    )
    model_wrap = CFGUpscaler(model_up, uc, cond_scale=guidance_scale, device=device)

    # Explicit float dtype: an int fill value would make torch.full return an
    # int64 tensor, which some backends (e.g. MPS) cannot raise to a power.
    low_res_sigma = torch.full(
        [batch_size], noise_aug_level, dtype=torch.float32, device=device
    )
    x_shape = [batch_size, C, 2 * H, 2 * W]

    def do_sample(noise, extra_args):
        """Run the configured k-diffusion sampler starting from ``noise``."""
        # We take log-linear steps in noise-level from sigma_max to sigma_min,
        # using one of the k diffusion samplers.
        sigmas = (
            torch.linspace(np.log(sigma_max), np.log(sigma_min), steps + 1)
            .exp()
            .to(device)
        )
        if sampler == "k_euler":
            return K.sampling.sample_euler(
                model_wrap, noise * sigma_max, sigmas, extra_args=extra_args
            )
        if sampler == "k_euler_ancestral":
            return K.sampling.sample_euler_ancestral(
                model_wrap, noise * sigma_max, sigmas, extra_args=extra_args, eta=eta
            )
        if sampler == "k_dpm_2_ancestral":
            return K.sampling.sample_dpm_2_ancestral(
                model_wrap, noise * sigma_max, sigmas, extra_args=extra_args, eta=eta
            )
        if sampler == "k_dpm_fast":
            return K.sampling.sample_dpm_fast(
                model_wrap,
                noise * sigma_max,
                sigma_min,
                sigma_max,
                steps,
                extra_args=extra_args,
                eta=eta,
            )
        if sampler == "k_dpm_adaptive":
            sampler_opts = {
                "s_noise": 1.0,
                "rtol": tol_scale * 0.05,
                "atol": tol_scale / 127.5,
                "pcoeff": 0.2,
                "icoeff": 0.4,
                "dcoeff": 0,
            }
            return K.sampling.sample_dpm_adaptive(
                model_wrap,
                noise * sigma_max,
                sigma_min,
                sigma_max,
                extra_args=extra_args,
                eta=eta,
                **sampler_opts,
            )
        raise ValueError(f"Unknown sampler {sampler}")

    for _ in range((num_samples - 1) // batch_size + 1):
        if noise_aug_type == "gaussian":
            latent_noised = low_res_latent + noise_aug_level * torch.randn_like(
                low_res_latent
            )
        elif noise_aug_type == "fake":
            # No noise is added; the latent is only rescaled as if it had been.
            latent_noised = low_res_latent * (noise_aug_level**2 + 1) ** 0.5
        else:
            # Fail clearly instead of hitting an unbound `latent_noised` below.
            raise ValueError(f"Unknown noise_aug_type {noise_aug_type}")
        extra_args = {
            "low_res": latent_noised,  # noqa
            "low_res_sigma": low_res_sigma,
            "c": c,
        }
        noise = torch.randn(x_shape, device=device)
        up_latents = do_sample(noise, extra_args)
        log_latent(low_res_latent, "low_res_latent")

    return up_latents
|