feature: generate large images
Added a composition stage so large images are more coherentpull/259/head
parent
b93b6a4d7c
commit
2aef6089e0
@ -0,0 +1,299 @@
|
||||
from functools import lru_cache
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from pytorch_lightning import seed_everything
|
||||
from torch import nn
|
||||
|
||||
from imaginairy.model_manager import hf_hub_download
|
||||
from imaginairy.utils import get_device, platform_appropriate_autocast
|
||||
from imaginairy.vendored import k_diffusion as K
|
||||
from imaginairy.vendored.k_diffusion import layers
|
||||
from imaginairy.vendored.k_diffusion.models.image_v1 import ImageDenoiserModelV1
|
||||
from imaginairy.vendored.k_diffusion.utils import append_dims
|
||||
|
||||
|
||||
class NoiseLevelAndTextConditionedUpscaler(nn.Module):
|
||||
def __init__(self, inner_model, sigma_data=1.0, embed_dim=256):
|
||||
super().__init__()
|
||||
self.inner_model = inner_model
|
||||
self.sigma_data = sigma_data
|
||||
self.low_res_noise_embed = K.layers.FourierFeatures(1, embed_dim, std=2)
|
||||
|
||||
def forward(self, inp, sigma, low_res, low_res_sigma, c, **kwargs):
|
||||
cross_cond, cross_cond_padding, pooler = c
|
||||
c_in = 1 / (low_res_sigma**2 + self.sigma_data**2) ** 0.5
|
||||
c_noise = low_res_sigma.log1p()[:, None]
|
||||
c_in = append_dims(c_in, low_res.ndim)
|
||||
low_res_noise_embed = self.low_res_noise_embed(c_noise)
|
||||
low_res_in = F.interpolate(low_res, scale_factor=2, mode="nearest") * c_in
|
||||
mapping_cond = torch.cat([low_res_noise_embed, pooler], dim=1)
|
||||
return self.inner_model(
|
||||
inp,
|
||||
sigma,
|
||||
unet_cond=low_res_in,
|
||||
mapping_cond=mapping_cond,
|
||||
cross_cond=cross_cond,
|
||||
cross_cond_padding=cross_cond_padding,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_upscaler_model(
|
||||
model_path,
|
||||
pooler_dim=768,
|
||||
train=False,
|
||||
device=get_device(),
|
||||
):
|
||||
config = {
|
||||
"type": "image_v1",
|
||||
"input_channels": 4,
|
||||
"input_size": [48, 48],
|
||||
"patch_size": 1,
|
||||
"mapping_out": 768,
|
||||
"mapping_cond_dim": 896,
|
||||
"unet_cond_dim": 4,
|
||||
"depths": [4, 4, 4, 4],
|
||||
"channels": [384, 384, 768, 768],
|
||||
"self_attn_depths": [False, False, False, True],
|
||||
"cross_attn_depths": [False, True, True, True],
|
||||
"cross_cond_dim": 768,
|
||||
"has_variance": True,
|
||||
"dropout_rate": 0.0,
|
||||
"augment_prob": 0.0,
|
||||
"augment_wrapper": False,
|
||||
"sigma_data": 1.0,
|
||||
"sigma_min": 1e-2,
|
||||
"sigma_max": 20,
|
||||
"sigma_sample_density": {"type": "lognormal", "mean": -0.5, "std": 1.2},
|
||||
"skip_stages": 0,
|
||||
}
|
||||
|
||||
model = ImageDenoiserModelV1(
|
||||
config["input_channels"],
|
||||
config["mapping_out"],
|
||||
config["depths"],
|
||||
config["channels"],
|
||||
config["self_attn_depths"],
|
||||
config["cross_attn_depths"],
|
||||
patch_size=config["patch_size"],
|
||||
dropout_rate=config["dropout_rate"],
|
||||
mapping_cond_dim=config["mapping_cond_dim"]
|
||||
+ (9 if config["augment_wrapper"] else 0),
|
||||
unet_cond_dim=config["unet_cond_dim"],
|
||||
cross_cond_dim=config["cross_cond_dim"],
|
||||
skip_stages=config["skip_stages"],
|
||||
has_variance=config["has_variance"],
|
||||
)
|
||||
|
||||
model = NoiseLevelAndTextConditionedUpscaler(
|
||||
model,
|
||||
sigma_data=config["sigma_data"],
|
||||
embed_dim=config["mapping_cond_dim"] - pooler_dim,
|
||||
)
|
||||
ckpt = torch.load(model_path, map_location="cpu")
|
||||
model.load_state_dict(ckpt["model_ema"])
|
||||
model = layers.DenoiserWithVariance(model, sigma_data=config["sigma_data"])
|
||||
if not train:
|
||||
model = model.eval().requires_grad_(False)
|
||||
return model.to(device)
|
||||
|
||||
|
||||
class CFGUpscaler(nn.Module):
|
||||
def __init__(self, model, uc, cond_scale, device):
|
||||
super().__init__()
|
||||
self.inner_model = model
|
||||
self.uc = uc
|
||||
self.cond_scale = cond_scale
|
||||
self.device = device
|
||||
|
||||
def forward(self, x, sigma, low_res, low_res_sigma, c):
|
||||
if self.cond_scale in (0.0, 1.0):
|
||||
# Shortcut for when we don't need to run both.
|
||||
if self.cond_scale == 0.0:
|
||||
c_in = self.uc
|
||||
elif self.cond_scale == 1.0:
|
||||
c_in = c
|
||||
return self.inner_model(
|
||||
x, sigma, low_res=low_res, low_res_sigma=low_res_sigma, c=c_in
|
||||
)
|
||||
|
||||
x_in = torch.cat([x] * 2)
|
||||
sigma_in = torch.cat([sigma] * 2)
|
||||
low_res_in = torch.cat([low_res] * 2)
|
||||
low_res_sigma_in = torch.cat([low_res_sigma] * 2)
|
||||
c_in = [torch.cat([uc_item, c_item]) for uc_item, c_item in zip(self.uc, c)]
|
||||
uncond, cond = self.inner_model(
|
||||
x_in, sigma_in, low_res=low_res_in, low_res_sigma=low_res_sigma_in, c=c_in
|
||||
).chunk(2)
|
||||
return uncond + (cond - uncond) * self.cond_scale
|
||||
|
||||
|
||||
class CLIPTokenizerTransform:
|
||||
def __init__(self, version="openai/clip-vit-large-patch14", max_length=77):
|
||||
from transformers import CLIPTokenizer
|
||||
|
||||
self.tokenizer = CLIPTokenizer.from_pretrained(version)
|
||||
self.max_length = max_length
|
||||
|
||||
def __call__(self, text):
|
||||
indexer = 0 if isinstance(text, str) else ...
|
||||
tok_out = self.tokenizer(
|
||||
text,
|
||||
truncation=True,
|
||||
max_length=self.max_length,
|
||||
return_length=True,
|
||||
return_overflowing_tokens=False,
|
||||
padding="max_length",
|
||||
return_tensors="pt",
|
||||
)
|
||||
input_ids = tok_out["input_ids"][indexer]
|
||||
attention_mask = 1 - tok_out["attention_mask"][indexer]
|
||||
return input_ids, attention_mask
|
||||
|
||||
|
||||
class CLIPEmbedder(nn.Module):
|
||||
"""Uses the CLIP transformer encoder for text (from Hugging Face)."""
|
||||
|
||||
def __init__(self, version="openai/clip-vit-large-patch14", device="cuda"):
|
||||
super().__init__()
|
||||
from transformers import CLIPTextModel, logging
|
||||
|
||||
logging.set_verbosity_error()
|
||||
self.transformer = CLIPTextModel.from_pretrained(version)
|
||||
self.transformer = self.transformer.eval().requires_grad_(False).to(device)
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return self.transformer.device
|
||||
|
||||
def forward(self, tok_out):
|
||||
input_ids, cross_cond_padding = tok_out
|
||||
clip_out = self.transformer(
|
||||
input_ids=input_ids.to(self.device), output_hidden_states=True
|
||||
)
|
||||
return (
|
||||
clip_out.hidden_states[-1],
|
||||
cross_cond_padding.to(self.device),
|
||||
clip_out.pooler_output,
|
||||
)
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def clip_up_models():
|
||||
with platform_appropriate_autocast():
|
||||
tok_up = CLIPTokenizerTransform()
|
||||
text_encoder_up = CLIPEmbedder(device=get_device())
|
||||
return text_encoder_up, tok_up
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def condition_up(prompts):
|
||||
text_encoder_up, tok_up = clip_up_models()
|
||||
return text_encoder_up(tok_up(prompts))
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def upscale_latent(
|
||||
low_res_latent,
|
||||
upscale_prompt="",
|
||||
seed=0,
|
||||
steps=30,
|
||||
guidance_scale=1.0,
|
||||
batch_size=1,
|
||||
num_samples=1,
|
||||
# Amount of noise to add per step (0.0=deterministic). Used in all samplers except `k_euler`.
|
||||
eta=1.0,
|
||||
device=get_device(),
|
||||
):
|
||||
# Add noise to the latent vectors before upscaling. This theoretically can make the model work better on out-of-distribution inputs, but mostly just seems to make it match the input less, so it's turned off by default.
|
||||
noise_aug_level = 0 # @param {type: 'slider', min: 0.0, max: 0.6, step:0.025}
|
||||
noise_aug_type = "gaussian" # @param ["gaussian", "fake"]
|
||||
|
||||
# @markdown Sampler settings. `k_dpm_adaptive` uses an adaptive solver with error tolerance `tol_scale`, all other use a fixed number of steps.
|
||||
sampler = "k_dpm_2_ancestral" # @param ["k_euler", "k_euler_ancestral", "k_dpm_2_ancestral", "k_dpm_fast", "k_dpm_adaptive"]
|
||||
|
||||
tol_scale = 0.25 # @param {type: 'number'}
|
||||
|
||||
seed_everything(seed)
|
||||
|
||||
# uc = condition_up(batch_size * ["blurry, low resolution, 720p, grainy"])
|
||||
uc = condition_up(batch_size * [""])
|
||||
c = condition_up(batch_size * [upscale_prompt])
|
||||
|
||||
[_, C, H, W] = low_res_latent.shape
|
||||
|
||||
# Noise levels from stable diffusion.
|
||||
sigma_min, sigma_max = 0.029167532920837402, 14.614642143249512
|
||||
model_up = get_upscaler_model(
|
||||
model_path=hf_hub_download(
|
||||
"pcuenq/k-upscaler", "laion_text_cond_latent_upscaler_2_1_00470000_slim.pth"
|
||||
),
|
||||
device=device,
|
||||
)
|
||||
model_wrap = CFGUpscaler(model_up, uc, cond_scale=guidance_scale, device=device)
|
||||
low_res_sigma = torch.full([batch_size], noise_aug_level, device=device)
|
||||
x_shape = [batch_size, C, 2 * H, 2 * W]
|
||||
|
||||
def do_sample(noise, extra_args):
|
||||
# We take log-linear steps in noise-level from sigma_max to sigma_min, using one of the k diffusion samplers.
|
||||
sigmas = (
|
||||
torch.linspace(np.log(sigma_max), np.log(sigma_min), steps + 1)
|
||||
.exp()
|
||||
.to(device)
|
||||
)
|
||||
if sampler == "k_euler":
|
||||
return K.sampling.sample_euler(
|
||||
model_wrap, noise * sigma_max, sigmas, extra_args=extra_args
|
||||
)
|
||||
if sampler == "k_euler_ancestral":
|
||||
return K.sampling.sample_euler_ancestral(
|
||||
model_wrap, noise * sigma_max, sigmas, extra_args=extra_args, eta=eta
|
||||
)
|
||||
if sampler == "k_dpm_2_ancestral":
|
||||
return K.sampling.sample_dpm_2_ancestral(
|
||||
model_wrap, noise * sigma_max, sigmas, extra_args=extra_args, eta=eta
|
||||
)
|
||||
if sampler == "k_dpm_fast":
|
||||
return K.sampling.sample_dpm_fast(
|
||||
model_wrap,
|
||||
noise * sigma_max,
|
||||
sigma_min,
|
||||
sigma_max,
|
||||
steps,
|
||||
extra_args=extra_args,
|
||||
eta=eta,
|
||||
)
|
||||
if sampler == "k_dpm_adaptive":
|
||||
sampler_opts = {
|
||||
"s_noise": 1.0,
|
||||
"rtol": tol_scale * 0.05,
|
||||
"atol": tol_scale / 127.5,
|
||||
"pcoeff": 0.2,
|
||||
"icoeff": 0.4,
|
||||
"dcoeff": 0,
|
||||
}
|
||||
return K.sampling.sample_dpm_adaptive(
|
||||
model_wrap,
|
||||
noise * sigma_max,
|
||||
sigma_min,
|
||||
sigma_max,
|
||||
extra_args=extra_args,
|
||||
eta=eta,
|
||||
**sampler_opts,
|
||||
)
|
||||
raise ValueError(f"Unknown sampler {sampler}")
|
||||
|
||||
for _ in range((num_samples - 1) // batch_size + 1):
|
||||
if noise_aug_type == "gaussian":
|
||||
latent_noised = low_res_latent + noise_aug_level * torch.randn_like(
|
||||
low_res_latent
|
||||
)
|
||||
elif noise_aug_type == "fake":
|
||||
latent_noised = low_res_latent * (noise_aug_level**2 + 1) ** 0.5
|
||||
extra_args = {"low_res": latent_noised, "low_res_sigma": low_res_sigma, "c": c}
|
||||
noise = torch.randn(x_shape, device=device)
|
||||
up_latents = do_sample(noise, extra_args)
|
||||
return up_latents
|
Loading…
Reference in New Issue