mirror of
https://github.com/brycedrennan/imaginAIry
synced 2024-10-31 03:20:40 +00:00
feature: k-diffusion samplers
- improved image logging functionality. can just stick log_latent wherever you want - improved some variable naming - moved all the samplers together - vendored k-diffusion library
This commit is contained in:
parent
20ac04d9df
commit
b4a3b8c2b3
14
Makefile
14
Makefile
@ -77,15 +77,23 @@ vendor_openai_clip:
|
||||
echo "vendored from git@github.com:openai/CLIP.git" | tee ./imaginairy/vendored/clip/readme.txt
|
||||
|
||||
revendorize:
|
||||
make vendorize REPO=git@github.com:openai/CLIP.git PKG=clip
|
||||
#make vendorize REPO=git@github.com:xinntao/Real-ESRGAN.git PKG=realesrgan
|
||||
make vendorize REPO=git@github.com:openai/CLIP.git PKG=clip COMMIT=d50d76daa670286dd6cacf3bcd80b5e4823fc8e1
|
||||
make vendorize REPO=git@github.com:crowsonkb/k-diffusion.git PKG=k_diffusion COMMIT=1a0703dfb7d24d8806267c3e7ccc4caf67fd1331
|
||||
#sed -i'' -e 's/^import\sclip/from\simaginairy.vendored\simport\sclip/g' imaginairy/vendored/k_diffusion/evaluation.py
|
||||
rm imaginairy/vendored/k_diffusion/evaluation.py
|
||||
touch imaginairy/vendored/k_diffusion/evaluation.py
|
||||
rm imaginairy/vendored/k_diffusion/config.py
|
||||
touch imaginairy/vendored/k_diffusion/config.py
|
||||
# without this most of the k-diffusion samplers didn't work
|
||||
sed -i'' -e 's#return (x - denoised) / utils.append_dims(sigma, x.ndim)#return (x - denoised) / sigma#g' imaginairy/vendored/k_diffusion/sampling.py
|
||||
make af
|
||||
|
||||
|
||||
|
||||
vendorize: ## vendorize a github repo. `make vendorize REPO=git@github.com:openai/CLIP.git PKG=clip`
|
||||
mkdir -p ./downloads
|
||||
-cd ./downloads && git clone $(REPO) $(PKG)
|
||||
cd ./downloads/$(PKG) && git pull
|
||||
cd ./downloads/$(PKG) && git fetch && git checkout $(COMMIT)
|
||||
rm -rf ./imaginairy/vendored/$(PKG)
|
||||
cp -R ./downloads/$(PKG)/$(PKG) imaginairy/vendored/
|
||||
git --git-dir ./downloads/$(PKG)/.git rev-parse HEAD | tee ./imaginairy/vendored/$(PKG)/clip-commit-hash.txt
|
||||
|
@ -141,7 +141,7 @@ imagine_image_files(prompts, outdir="./my-art")
|
||||
- ✅ init-image at command line
|
||||
- prompt expansion
|
||||
- Image Generation Features
|
||||
- add k-diffusion sampling methods
|
||||
- ✅ add k-diffusion sampling methods
|
||||
- upscaling
|
||||
- ✅ realesrgan
|
||||
- ldm
|
||||
|
@ -16,9 +16,9 @@ from transformers import cached_path
|
||||
|
||||
from imaginairy.enhancers.face_restoration_codeformer import enhance_faces
|
||||
from imaginairy.enhancers.upscale_realesrgan import upscale_image
|
||||
from imaginairy.modules.diffusion.ddim import DDIMSampler
|
||||
from imaginairy.modules.diffusion.plms import PLMSSampler
|
||||
from imaginairy.img_log import LatentLoggingContext, log_latent
|
||||
from imaginairy.safety import is_nsfw, safety_models
|
||||
from imaginairy.samplers.base import get_sampler
|
||||
from imaginairy.schema import ImaginePrompt, ImagineResult
|
||||
from imaginairy.utils import (
|
||||
fix_torch_nn_layer_norm,
|
||||
@ -57,7 +57,10 @@ def load_model_from_config(config):
|
||||
|
||||
|
||||
def patch_conv(**patch):
|
||||
"""https://github.com/replicate/cog-stable-diffusion/compare/main...TomMoore515:material_stable_diffusion:main"""
|
||||
"""
|
||||
Patch to enable tiling mode
|
||||
https://github.com/replicate/cog-stable-diffusion/compare/main...TomMoore515:material_stable_diffusion:main
|
||||
"""
|
||||
cls = torch.nn.Conv2d
|
||||
init = cls.__init__
|
||||
|
||||
@ -96,34 +99,26 @@ def imagine_image_files(
|
||||
os.makedirs(outdir, exist_ok=True)
|
||||
|
||||
base_count = len(os.listdir(outdir))
|
||||
step_count = 0
|
||||
output_file_extension = output_file_extension.lower()
|
||||
if output_file_extension not in {"jpg", "png"}:
|
||||
raise ValueError("Must output a png or jpg")
|
||||
|
||||
def _record_steps(samples, description, model, prompt):
|
||||
nonlocal step_count
|
||||
step_count += 1
|
||||
samples = model.decode_first_stage(samples)
|
||||
samples = torch.clamp((samples + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
def _record_step(img, description, step_count, prompt):
|
||||
steps_path = os.path.join(outdir, "steps", f"{base_count:08}_S{prompt.seed}")
|
||||
os.makedirs(steps_path, exist_ok=True)
|
||||
for pred_x0 in samples:
|
||||
pred_x0 = 255.0 * rearrange(pred_x0.cpu().numpy(), "c h w -> h w c")
|
||||
filename = f"{base_count:08}_S{prompt.seed}_step{step_count:04}.jpg"
|
||||
img = Image.fromarray(pred_x0.astype(np.uint8))
|
||||
draw = ImageDraw.Draw(img)
|
||||
draw.text((10, 10), str(description))
|
||||
img.save(os.path.join(steps_path, filename))
|
||||
filename = f"{base_count:08}_S{prompt.seed}_step{step_count:04}.jpg"
|
||||
destination = os.path.join(steps_path, filename)
|
||||
draw = ImageDraw.Draw(img)
|
||||
draw.text((10, 10), str(description))
|
||||
img.save(destination)
|
||||
|
||||
img_callback = _record_steps if record_step_images else None
|
||||
for result in imagine(
|
||||
prompts,
|
||||
latent_channels=latent_channels,
|
||||
downsampling_factor=downsampling_factor,
|
||||
precision=precision,
|
||||
ddim_eta=ddim_eta,
|
||||
img_callback=img_callback,
|
||||
img_callback=_record_step if record_step_images else None,
|
||||
tile_mode=tile_mode,
|
||||
):
|
||||
prompt = result.prompt
|
||||
@ -164,6 +159,7 @@ def imagine(
|
||||
prompts = [ImaginePrompt(prompts)] if isinstance(prompts, str) else prompts
|
||||
prompts = [prompts] if isinstance(prompts, ImaginePrompt) else prompts
|
||||
_img_callback = None
|
||||
step_count = 0
|
||||
|
||||
precision_scope = (
|
||||
autocast
|
||||
@ -172,108 +168,101 @@ def imagine(
|
||||
)
|
||||
with (torch.no_grad(), precision_scope(get_device()), fix_torch_nn_layer_norm()):
|
||||
for prompt in prompts:
|
||||
logger.info(f"Generating {prompt.prompt_description()}")
|
||||
seed_everything(prompt.seed)
|
||||
with LatentLoggingContext(
|
||||
prompt=prompt, model=model, img_callback=img_callback
|
||||
):
|
||||
logger.info(f"Generating {prompt.prompt_description()}")
|
||||
seed_everything(prompt.seed)
|
||||
|
||||
uc = None
|
||||
if prompt.prompt_strength != 1.0:
|
||||
uc = model.get_learned_conditioning(1 * [""])
|
||||
total_weight = sum(wp.weight for wp in prompt.prompts)
|
||||
c = sum(
|
||||
[
|
||||
model.get_learned_conditioning(wp.text) * (wp.weight / total_weight)
|
||||
for wp in prompt.prompts
|
||||
uc = None
|
||||
if prompt.prompt_strength != 1.0:
|
||||
uc = model.get_learned_conditioning(1 * [""])
|
||||
total_weight = sum(wp.weight for wp in prompt.prompts)
|
||||
c = sum(
|
||||
[
|
||||
model.get_learned_conditioning(wp.text)
|
||||
* (wp.weight / total_weight)
|
||||
for wp in prompt.prompts
|
||||
]
|
||||
)
|
||||
|
||||
shape = [
|
||||
latent_channels,
|
||||
prompt.height // downsampling_factor,
|
||||
prompt.width // downsampling_factor,
|
||||
]
|
||||
)
|
||||
|
||||
def _img_callback(samples, description):
|
||||
if img_callback:
|
||||
img_callback(samples, description, model, prompt)
|
||||
start_code = None
|
||||
sampler = get_sampler(prompt.sampler_type, model)
|
||||
if prompt.init_image:
|
||||
generation_strength = 1 - prompt.init_image_strength
|
||||
ddim_steps = int(prompt.steps / generation_strength)
|
||||
sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=ddim_eta)
|
||||
|
||||
shape = [
|
||||
latent_channels,
|
||||
prompt.height // downsampling_factor,
|
||||
prompt.width // downsampling_factor,
|
||||
]
|
||||
t_enc = int(generation_strength * ddim_steps)
|
||||
init_image, w, h = img_path_to_torch_image(prompt.init_image)
|
||||
init_image = init_image.to(get_device())
|
||||
init_latent = model.get_first_stage_encoding(
|
||||
model.encode_first_stage(init_image)
|
||||
)
|
||||
log_latent(init_latent, "init_latent")
|
||||
|
||||
start_code = None
|
||||
sampler = get_sampler(prompt.sampler_type, model)
|
||||
if prompt.init_image:
|
||||
generation_strength = 1 - prompt.init_image_strength
|
||||
ddim_steps = int(prompt.steps / generation_strength)
|
||||
sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=ddim_eta)
|
||||
# encode (scaled latent)
|
||||
z_enc = sampler.stochastic_encode(
|
||||
init_latent, torch.tensor([t_enc]).to(get_device())
|
||||
)
|
||||
log_latent(z_enc, "z_enc")
|
||||
|
||||
t_enc = int(generation_strength * ddim_steps)
|
||||
init_image, w, h = img_path_to_torch_image(prompt.init_image)
|
||||
init_image = init_image.to(get_device())
|
||||
init_latent = model.get_first_stage_encoding(
|
||||
model.encode_first_stage(init_image)
|
||||
)
|
||||
_img_callback(init_latent, "init_latent")
|
||||
# decode it
|
||||
samples = sampler.decode(
|
||||
z_enc,
|
||||
c,
|
||||
t_enc,
|
||||
unconditional_guidance_scale=prompt.prompt_strength,
|
||||
unconditional_conditioning=uc,
|
||||
img_callback=_img_callback,
|
||||
)
|
||||
else:
|
||||
|
||||
# encode (scaled latent)
|
||||
z_enc = sampler.stochastic_encode(
|
||||
init_latent, torch.tensor([t_enc]).to(get_device())
|
||||
)
|
||||
_img_callback(z_enc, "z_enc")
|
||||
samples, _ = sampler.sample(
|
||||
num_steps=prompt.steps,
|
||||
conditioning=c,
|
||||
batch_size=1,
|
||||
shape=shape,
|
||||
unconditional_guidance_scale=prompt.prompt_strength,
|
||||
unconditional_conditioning=uc,
|
||||
eta=ddim_eta,
|
||||
initial_noise_tensor=start_code,
|
||||
img_callback=_img_callback,
|
||||
)
|
||||
|
||||
# decode it
|
||||
samples = sampler.decode(
|
||||
z_enc,
|
||||
c,
|
||||
t_enc,
|
||||
unconditional_guidance_scale=prompt.prompt_strength,
|
||||
unconditional_conditioning=uc,
|
||||
img_callback=_img_callback,
|
||||
)
|
||||
else:
|
||||
x_samples = model.decode_first_stage(samples)
|
||||
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
|
||||
samples, _ = sampler.sample(
|
||||
S=prompt.steps,
|
||||
conditioning=c,
|
||||
batch_size=1,
|
||||
shape=shape,
|
||||
unconditional_guidance_scale=prompt.prompt_strength,
|
||||
unconditional_conditioning=uc,
|
||||
eta=ddim_eta,
|
||||
x_T=start_code,
|
||||
img_callback=_img_callback,
|
||||
)
|
||||
for x_sample in x_samples:
|
||||
x_sample = 255.0 * rearrange(
|
||||
x_sample.cpu().numpy(), "c h w -> h w c"
|
||||
)
|
||||
x_sample_8_orig = x_sample.astype(np.uint8)
|
||||
img = Image.fromarray(x_sample_8_orig)
|
||||
upscaled_img = None
|
||||
if not IMAGINAIRY_ALLOW_NSFW and is_nsfw(
|
||||
img, x_sample, half_mode=half_mode
|
||||
):
|
||||
logger.info(" ⚠️ Filtering NSFW image")
|
||||
img = img.filter(ImageFilter.GaussianBlur(radius=40))
|
||||
|
||||
x_samples = model.decode_first_stage(samples)
|
||||
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
if prompt.fix_faces:
|
||||
logger.info(" Fixing 😊 's in 🖼 using GFPGAN...")
|
||||
img = enhance_faces(img, fidelity=0.2)
|
||||
if prompt.upscale:
|
||||
logger.info(" Upscaling 🖼 using real-ESRGAN...")
|
||||
upscaled_img = upscale_image(img)
|
||||
|
||||
for x_sample in x_samples:
|
||||
x_sample = 255.0 * rearrange(x_sample.cpu().numpy(), "c h w -> h w c")
|
||||
x_sample_8_orig = x_sample.astype(np.uint8)
|
||||
img = Image.fromarray(x_sample_8_orig)
|
||||
upscaled_img = None
|
||||
if not IMAGINAIRY_ALLOW_NSFW and is_nsfw(
|
||||
img, x_sample, half_mode=half_mode
|
||||
):
|
||||
logger.info(" ⚠️ Filtering NSFW image")
|
||||
img = img.filter(ImageFilter.GaussianBlur(radius=40))
|
||||
|
||||
if prompt.fix_faces:
|
||||
logger.info(" Fixing 😊 's in 🖼 using GFPGAN...")
|
||||
img = enhance_faces(img, fidelity=0.2)
|
||||
if prompt.upscale:
|
||||
logger.info(" Upscaling 🖼 using real-ESRGAN...")
|
||||
upscaled_img = upscale_image(img)
|
||||
|
||||
yield ImagineResult(img=img, prompt=prompt, upscaled_img=upscaled_img)
|
||||
yield ImagineResult(
|
||||
img=img, prompt=prompt, upscaled_img=upscaled_img
|
||||
)
|
||||
|
||||
|
||||
def prompt_normalized(prompt):
|
||||
return re.sub(r"[^a-zA-Z0-9.,]+", "_", prompt)[:130]
|
||||
|
||||
|
||||
DOWNLOADED_FILES_PATH = f"{LIB_PATH}/../downloads/"
|
||||
|
||||
|
||||
def get_sampler(sampler_type, model):
|
||||
sampler_type = sampler_type.upper()
|
||||
if sampler_type == "PLMS":
|
||||
return PLMSSampler(model)
|
||||
elif sampler_type == "DDIM":
|
||||
return DDIMSampler(model)
|
||||
|
@ -100,8 +100,19 @@ def configure_logging(level="INFO"):
|
||||
@click.option("--fix-faces-method", default="gfpgan", type=click.Choice(["gfpgan"]))
|
||||
@click.option(
|
||||
"--sampler-type",
|
||||
default="PLMS",
|
||||
type=click.Choice(["PLMS", "DDIM"]),
|
||||
default="plms",
|
||||
type=click.Choice(
|
||||
[
|
||||
"plms",
|
||||
"ddim",
|
||||
"k_lms",
|
||||
"k_dpm_2",
|
||||
"k_dpm_2_a",
|
||||
"k_euler",
|
||||
"k_euler_a",
|
||||
"k_heun",
|
||||
]
|
||||
),
|
||||
help="What sampling strategy to use",
|
||||
)
|
||||
@click.option("--ddim-eta", default=0.0, type=float)
|
||||
@ -154,7 +165,7 @@ def imagine_cmd(
|
||||
logger.info(
|
||||
f"🤖🧠 imaginAIry received {len(prompt_texts)} prompt(s) and will repeat them {repeats} times to create {total_image_count} images."
|
||||
)
|
||||
if init_image and sampler_type != "DDIM":
|
||||
if init_image and sampler_type == "DDIM":
|
||||
sampler_type = "DDIM"
|
||||
|
||||
prompts = []
|
||||
|
55
imaginairy/img_log.py
Normal file
55
imaginairy/img_log.py
Normal file
@ -0,0 +1,55 @@
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from PIL import Image
|
||||
|
||||
_CURRENT_LOGGING_CONTEXT = None
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def log_latent(latents, description):
    """Forward a latent tensor to the active ``LatentLoggingContext``, if any.

    Does nothing when no logging context is currently installed. Otherwise
    the tensor is checked for NaN/Inf values (an error is logged when any are
    found, but processing continues) and then handed to the context's
    ``log_latents`` method.

    Args:
        latents: torch tensor to inspect and log.
        description: label for the tensor. It is coerced with ``str`` before
            being truncated for the error message, so a non-string label
            (for example a step index) does not raise ``TypeError``.
    """
    context = _CURRENT_LOGGING_CONTEXT
    if context is None:
        return
    if torch.isnan(latents).any() or torch.isinf(latents).any():
        # Only short prefixes are logged to keep the message readable.
        logger.error(
            "Inf/NaN values showing in transformer.%s %s",
            repr(latents)[:50],
            str(description)[:50],
        )
    context.log_latents(latents, description)
|
||||
|
||||
|
||||
class LatentLoggingContext:
    """Context manager that routes ``log_latent`` calls to an image callback.

    While the ``with`` block is active, the instance is installed as the
    module-level current logging context. On exit, the context that was
    active at entry is restored (rather than unconditionally cleared), so
    nested contexts unwind correctly.

    Args:
        prompt: prompt object passed through to ``img_callback`` unchanged.
        model: object providing ``decode_first_stage``, used to turn latents
            into image tensors.
        img_callback: optional callable invoked as
            ``img_callback(img, description, step_count, prompt)`` for every
            decoded image. When falsy, nothing is decoded or recorded.
    """

    def __init__(self, prompt, model, img_callback=None):
        self.prompt = prompt
        self.model = model
        self.step_count = 0
        self.img_callback = img_callback
        # Context that was active when this one was entered; restored on exit.
        self._previous_context = None

    def __enter__(self):
        global _CURRENT_LOGGING_CONTEXT
        self._previous_context = _CURRENT_LOGGING_CONTEXT
        _CURRENT_LOGGING_CONTEXT = self
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        global _CURRENT_LOGGING_CONTEXT
        _CURRENT_LOGGING_CONTEXT = self._previous_context
        self._previous_context = None

    def log_latents(self, samples, description):
        """Decode ``samples`` and pass each resulting image to the callback.

        Returns immediately when no callback is configured or when the
        second dimension of ``samples`` is not 4. Otherwise the step counter
        is incremented once per call and every image in the batch is
        reported with the same step number.
        """
        if not self.img_callback:
            return
        if samples.shape[1] != 4:
            # Tensors without exactly 4 entries on axis 1 are skipped silently.
            return
        self.step_count += 1
        description = f"{description} - {samples.shape}"
        samples = self.model.decode_first_stage(samples)
        # Map the decoded values from [-1, 1] into [0, 1], clipping overshoot.
        samples = torch.clamp((samples + 1.0) / 2.0, min=0.0, max=1.0)
        for pred_x0 in samples:
            pred_x0 = 255.0 * rearrange(pred_x0.cpu().numpy(), "c h w -> h w c")
            img = Image.fromarray(pred_x0.astype(np.uint8))
            self.img_callback(img, description, self.step_count, self.prompt)
|
0
imaginairy/samplers/__init__.py
Normal file
0
imaginairy/samplers/__init__.py
Normal file
90
imaginairy/samplers/base.py
Normal file
90
imaginairy/samplers/base.py
Normal file
@ -0,0 +1,90 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from imaginairy.samplers.ddim import DDIMSampler
|
||||
from imaginairy.samplers.kdiff import KDiffusionSampler
|
||||
from imaginairy.samplers.plms import PLMSSampler
|
||||
from imaginairy.utils import get_device
|
||||
|
||||
# Maps the user-facing ``k_*`` sampler names to the name that is handed to
# ``KDiffusionSampler`` as its ``sampler_name``.
_k_sampler_type_lookup = {
    "k_dpm_2": "dpm_2",
    "k_dpm_2_a": "dpm_2_ancestral",
    "k_euler": "euler",
    "k_euler_a": "euler_ancestral",
    "k_heun": "heun",
    "k_lms": "lms",
}


def get_sampler(sampler_type, model):
    """Build the sampler object matching ``sampler_type`` for ``model``.

    Args:
        sampler_type: sampler name, matched case-insensitively. Supported
            values are ``"plms"``, ``"ddim"`` and the ``k_*`` names listed in
            ``_k_sampler_type_lookup``.
        model: model object handed to the sampler constructor.

    Returns:
        A ``PLMSSampler``, ``DDIMSampler`` or ``KDiffusionSampler`` instance.

    Raises:
        KeyError: if ``sampler_type`` starts with ``k_`` but is not a key of
            ``_k_sampler_type_lookup``.
        ValueError: if ``sampler_type`` matches no known sampler. Previously
            this case fell through and silently returned ``None``.
    """
    sampler_type = sampler_type.lower()
    if sampler_type == "plms":
        return PLMSSampler(model)
    if sampler_type == "ddim":
        return DDIMSampler(model)
    if sampler_type.startswith("k_"):
        return KDiffusionSampler(model, _k_sampler_type_lookup[sampler_type])
    raise ValueError(f"Unknown sampler type: {sampler_type!r}")
|
||||
|
||||
|
||||
class CFGDenoiser(nn.Module):
    """Wrap a denoiser so one call yields a guidance-blended prediction.

    The wrapped model is invoked once on a doubled batch (unconditional
    half first, conditional half second) and the two halves of its output
    are combined as ``uncond + (cond - uncond) * cond_scale``.
    """

    def __init__(self, model):
        super().__init__()
        self.inner_model = model

    def forward(self, x, sigma, uncond, cond, cond_scale):
        """Return the blended output of the inner model for ``x``.

        ``x`` and ``sigma`` are each concatenated with themselves, and
        ``uncond``/``cond`` are concatenated in that order and passed to the
        inner model as the ``cond`` keyword argument.
        """
        doubled_x = torch.cat([x, x])
        doubled_sigma = torch.cat([sigma, sigma])
        stacked_cond = torch.cat([uncond, cond])
        prediction = self.inner_model(doubled_x, doubled_sigma, cond=stacked_cond)
        uncond_out, cond_out = prediction.chunk(2)
        return uncond_out + (cond_out - uncond_out) * cond_scale
|
||||
|
||||
|
||||
class DiffusionSampler:
    """
    wip

    hope to enforce an api upon samplers

    Args:
        noise_prediction_model: model object; must provide ``get_sigmas``.
            It is also wrapped in a ``CFGDenoiser``.
        sampler_func: callable invoked as
            ``sampler_func(model, x, sigmas, extra_args=..., disable=False)``.
        device: device that freshly drawn noise is moved to. When ``None``
            (the default) it is resolved with ``get_device()`` at
            construction time instead of at import time.
    """

    def __init__(self, noise_prediction_model, sampler_func, device=None):
        self.noise_prediction_model = noise_prediction_model
        self.cfg_noise_prediction_model = CFGDenoiser(noise_prediction_model)
        self.sampler_func = sampler_func
        # A call in the signature default would run once, when the module is
        # imported; resolve the device lazily instead.
        self.device = get_device() if device is None else device

    def sample(
        self,
        num_steps,
        text_conditioning,
        batch_size,
        shape,
        unconditional_guidance_scale,
        unconditional_conditioning,
        eta,
        initial_noise_tensor=None,
        img_callback=None,
    ):
        """Run ``sampler_func`` and return ``(samples, None)``.

        Args:
            num_steps: passed to the model's ``get_sigmas``.
            text_conditioning: forwarded to the denoiser as ``cond``.
            batch_size: leading dimension of the noise tensor to draw.
            shape: remaining dimensions of the noise tensor.
            unconditional_guidance_scale: forwarded as ``cond_scale``.
            unconditional_conditioning: forwarded as ``uncond``.
            eta: accepted for signature parity; not used here.
            initial_noise_tensor: optional starting noise. When ``None``,
                standard-normal noise of size ``(batch_size, *shape)`` is
                drawn.
            img_callback: accepted for signature parity; not used here.

        Returns:
            Tuple of the sampler function's result and ``None``.
        """
        size = (batch_size, *shape)

        if initial_noise_tensor is None:
            # Drawn on the CPU and then moved to the target device
            # (presumably so a seed gives the same noise on every backend;
            # TODO confirm).
            initial_noise_tensor = torch.randn(size, device="cpu").to(self.device)

        sigmas = self.noise_prediction_model.get_sigmas(num_steps)
        # Scale the unit-variance noise up to the first noise level.
        x = initial_noise_tensor * sigmas[0]

        samples = self.sampler_func(
            self.cfg_noise_prediction_model,
            x,
            sigmas,
            extra_args={
                "cond": text_conditioning,
                "uncond": unconditional_conditioning,
                "cond_scale": unconditional_guidance_scale,
            },
            disable=False,
        )

        return samples, None
|
@ -5,6 +5,7 @@ import numpy as np
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from imaginairy.img_log import log_latent
|
||||
from imaginairy.modules.diffusion.util import (
|
||||
extract_into_tensor,
|
||||
make_ddim_sampling_parameters,
|
||||
@ -89,7 +90,7 @@ class DDIMSampler:
|
||||
@torch.no_grad()
|
||||
def sample(
|
||||
self,
|
||||
S,
|
||||
num_steps,
|
||||
batch_size,
|
||||
shape,
|
||||
conditioning=None,
|
||||
@ -124,7 +125,7 @@ class DDIMSampler:
|
||||
f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
|
||||
)
|
||||
|
||||
self.make_schedule(ddim_num_steps=S, ddim_eta=eta)
|
||||
self.make_schedule(ddim_num_steps=num_steps, ddim_eta=eta)
|
||||
# sampling
|
||||
C, H, W = shape
|
||||
size = (batch_size, C, H, W)
|
||||
@ -178,6 +179,7 @@ class DDIMSampler:
|
||||
img = torch.randn(shape, device="cpu").to(device)
|
||||
else:
|
||||
img = x_T
|
||||
log_latent(img, "initial noise")
|
||||
|
||||
if timesteps is None:
|
||||
timesteps = (
|
||||
@ -231,9 +233,9 @@ class DDIMSampler:
|
||||
)
|
||||
if callback:
|
||||
callback(i)
|
||||
if img_callback:
|
||||
img_callback(pred_x0, i)
|
||||
img_callback(pred_x0, i)
|
||||
|
||||
log_latent(img, "img")
|
||||
log_latent(pred_x0, "pred_x0")
|
||||
|
||||
if index % log_every_t == 0 or index == total_steps - 1:
|
||||
intermediates["x_inter"].append(img)
|
||||
@ -378,7 +380,7 @@ class DDIMSampler:
|
||||
# cond_grad = -torch.autograd.grad(original_loss, x_dec)[0]
|
||||
# x_dec = x_dec.detach() + cond_grad * sigma_t ** 2
|
||||
## x_dec_alt = x_dec + (original_loss * 0.1) ** 2
|
||||
if img_callback:
|
||||
img_callback(x_dec, f"x_dec {i}")
|
||||
img_callback(pred_x0, f"pred_x0 {i}")
|
||||
|
||||
log_latent(x_dec, f"x_dec {i}")
|
||||
log_latent(pred_x0, f"pred_x0 {i}")
|
||||
return x_dec
|
97
imaginairy/samplers/kdiff.py
Normal file
97
imaginairy/samplers/kdiff.py
Normal file
@ -0,0 +1,97 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from imaginairy.img_log import log_latent
|
||||
from imaginairy.utils import get_device
|
||||
from imaginairy.vendored.k_diffusion import sampling as k_sampling
|
||||
from imaginairy.vendored.k_diffusion.external import CompVisDenoiser
|
||||
|
||||
|
||||
class CFGMaskedDenoiser(nn.Module):
    """Guidance-blending denoiser wrapper with optional masking.

    Works like a plain guidance wrapper (one doubled-batch call, then
    ``uncond + (cond - uncond) * cond_scale``). When a mask is given, the
    result is additionally mixed with ``x0``: ``x0`` is used where the mask
    is 0 and the denoised prediction where it is 1.
    """

    def __init__(self, model):
        super().__init__()
        self.inner_model = model

    def forward(self, x, sigma, uncond, cond, cond_scale, mask, x0, xi):
        """Return the (optionally masked) guided prediction.

        ``xi`` is accepted but not used by this implementation. ``x0`` must
        not be ``None`` when ``mask`` is provided.
        """
        doubled_x = torch.cat([x, x])
        doubled_sigma = torch.cat([sigma, sigma])
        stacked_cond = torch.cat([uncond, cond])
        uncond_out, cond_out = self.inner_model(
            doubled_x, doubled_sigma, cond=stacked_cond
        ).chunk(2)
        guided = uncond_out + (cond_out - uncond_out) * cond_scale

        if mask is None:
            return guided

        assert x0 is not None
        keep_original = 1.0 - mask
        return (x0 * keep_original) + (mask * guided)
|
||||
|
||||
|
||||
class CFGDenoiser(nn.Module):
    """Denoiser wrapper that blends unconditional and conditional outputs.

    A single call to the wrapped model handles both branches: inputs are
    duplicated along the batch axis, with the unconditional conditioning in
    the first half and the conditional conditioning in the second half.
    """

    def __init__(self, model):
        super().__init__()
        self.inner_model = model

    def forward(self, x, sigma, uncond, cond, cond_scale):
        """Return ``uncond_out + (cond_out - uncond_out) * cond_scale``."""
        batch = torch.cat([x, x])
        noise_levels = torch.cat([sigma, sigma])
        conditioning = torch.cat([uncond, cond])
        halves = self.inner_model(batch, noise_levels, cond=conditioning).chunk(2)
        uncond_out, cond_out = halves
        return uncond_out + (cond_out - uncond_out) * cond_scale
|
||||
|
||||
|
||||
class KDiffusionSampler:
    """Sampler that drives a vendored k-diffusion ``sample_*`` function.

    The model is wrapped in a ``CompVisDenoiser``; the sampling function is
    looked up on the vendored ``sampling`` module by name
    (``sample_<sampler_name>``).
    """

    def __init__(self, model, sampler_name):
        self.model = model
        self.cv_denoiser = CompVisDenoiser(model)
        self.sampler_name = sampler_name
        self.sampler_func = getattr(k_sampling, f"sample_{sampler_name}")

    def sample(
        self,
        num_steps,
        conditioning,
        batch_size,
        shape,
        unconditional_guidance_scale,
        unconditional_conditioning,
        eta,
        initial_noise_tensor=None,
        img_callback=None,
    ):
        """Run the configured sampler and return ``(samples, None)``.

        ``eta`` and ``img_callback`` are accepted but not used here;
        intermediate tensors are reported through ``log_latent`` instead.
        When ``initial_noise_tensor`` is ``None``, standard-normal noise of
        size ``(batch_size, *shape)`` is drawn on the CPU and moved to the
        active device.
        """
        size = (batch_size, *shape)

        if initial_noise_tensor is None:
            initial_noise_tensor = torch.randn(size, device="cpu").to(get_device())
        log_latent(initial_noise_tensor, "initial_noise_tensor")

        sigmas = self.cv_denoiser.get_sigmas(num_steps)

        # Scale the noise to the first sigma before handing it to the sampler.
        x = initial_noise_tensor * sigmas[0]
        log_latent(x, "initial_sigma_noised_tensor")
        model_wrap_cfg = CFGDenoiser(self.cv_denoiser)

        def callback(data):
            # Report both the current state and the denoised estimate.
            log_latent(data["x"], "x")
            log_latent(data["denoised"], "denoised")

        guidance_args = {
            "cond": conditioning,
            "uncond": unconditional_conditioning,
            "cond_scale": unconditional_guidance_scale,
        }
        samples = self.sampler_func(
            model_wrap_cfg,
            x,
            sigmas,
            extra_args=guidance_args,
            disable=False,
            callback=callback,
        )

        return samples, None
|
@ -90,7 +90,7 @@ class PLMSSampler(object):
|
||||
@torch.no_grad()
|
||||
def sample(
|
||||
self,
|
||||
S,
|
||||
num_steps,
|
||||
batch_size,
|
||||
shape,
|
||||
conditioning=None,
|
||||
@ -125,7 +125,7 @@ class PLMSSampler(object):
|
||||
f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
|
||||
)
|
||||
|
||||
self.make_schedule(ddim_num_steps=S, ddim_eta=eta)
|
||||
self.make_schedule(ddim_num_steps=num_steps, ddim_eta=eta)
|
||||
# sampling
|
||||
C, H, W = shape
|
||||
size = (batch_size, C, H, W)
|
||||
@ -248,8 +248,8 @@ class PLMSSampler(object):
|
||||
if callback:
|
||||
callback(i)
|
||||
if img_callback:
|
||||
img_callback(img, i)
|
||||
img_callback(pred_x0, i)
|
||||
img_callback(img, "img")
|
||||
img_callback(pred_x0, "pred_x0")
|
||||
|
||||
if index % log_every_t == 0 or index == total_steps - 1:
|
||||
intermediates["x_inter"].append(img)
|
12
imaginairy/vendored/k_diffusion/__init__.py
Normal file
12
imaginairy/vendored/k_diffusion/__init__.py
Normal file
@ -0,0 +1,12 @@
|
||||
from . import (
|
||||
augmentation,
|
||||
config,
|
||||
evaluation,
|
||||
external,
|
||||
gns,
|
||||
layers,
|
||||
models,
|
||||
sampling,
|
||||
utils,
|
||||
)
|
||||
from .layers import Denoiser
|
113
imaginairy/vendored/k_diffusion/augmentation.py
Normal file
113
imaginairy/vendored/k_diffusion/augmentation.py
Normal file
@ -0,0 +1,113 @@
|
||||
import math
|
||||
import operator
|
||||
from functools import reduce
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from skimage import transform
|
||||
from torch import nn
|
||||
|
||||
|
||||
def translate2d(tx, ty):
    """Return the 3x3 float32 homogeneous matrix translating by ``(tx, ty)``."""
    rows = [
        [1, 0, tx],
        [0, 1, ty],
        [0, 0, 1],
    ]
    return torch.tensor(rows, dtype=torch.float32)
|
||||
|
||||
|
||||
def scale2d(sx, sy):
    """Return the 3x3 float32 homogeneous matrix scaling by ``(sx, sy)``."""
    rows = [
        [sx, 0, 0],
        [0, sy, 0],
        [0, 0, 1],
    ]
    return torch.tensor(rows, dtype=torch.float32)
|
||||
|
||||
|
||||
def rotate2d(theta):
    """Return the 3x3 float32 homogeneous rotation matrix for ``theta``.

    ``theta`` is passed to ``torch.cos``/``torch.sin``, so it must be a
    tensor (angle in radians).
    """
    cos_t = torch.cos(theta)
    sin_t = torch.sin(theta)
    neg_sin_t = torch.sin(-theta)
    rows = [
        [cos_t, neg_sin_t, 0],
        [sin_t, cos_t, 0],
        [0, 0, 1],
    ]
    return torch.tensor(rows, dtype=torch.float32)
|
||||
|
||||
|
||||
class KarrasAugmentationPipeline:
    """Random geometric augmentation that also reports what it applied.

    Calling the pipeline on an image draws a random affine transform
    (flips, isotropic scaling, rotation, anisotropic scaling, translation),
    warps the image with it, and returns the warped image, the unwarped
    image and a 9-element tensor describing the drawn parameters.

    Args:
        a_prob: probability with which each of the y-flip, scaling,
            rotation, anisotropy and translation steps is enabled.
        a_scale: base of the exponential used for isotropic scaling.
        a_aniso: base of the exponential used for anisotropic scaling.
        a_trans: translation magnitude as a fraction of the image size.
    """

    def __init__(self, a_prob=0.12, a_scale=2**0.2, a_aniso=2**0.2, a_trans=1 / 8):
        self.a_prob = a_prob
        self.a_scale = a_scale
        self.a_aniso = a_aniso
        self.a_trans = a_trans

    def __call__(self, image):
        """Augment ``image`` and return ``(image, image_orig, cond)``.

        ``image`` must expose a two-element ``size`` and be convertible with
        ``np.array`` (presumably a PIL image — TODO confirm). Both returned
        images are tensors with the channel axis moved to the front and
        values mapped through ``x / 255 * 2 - 1``. ``cond`` stacks the nine
        drawn augmentation parameters.
        """
        # NOTE(review): PIL's Image.size is (width, height); if image is a
        # PIL image the names h and w are swapped here — confirm.
        h, w = image.size
        # Move the origin to the image centre before applying the transforms.
        mats = [translate2d(h / 2 - 0.5, w / 2 - 0.5)]

        # x-flip
        # Always drawn (no a_prob gate): flips with probability 1/2.
        a0 = torch.randint(2, []).float()
        mats.append(scale2d(1 - 2 * a0, 1))
        # y-flip
        # `do` is 1.0 when this step is enabled and 0.0 otherwise, which
        # zeroes the drawn parameter and makes the matrix an identity.
        do = (torch.rand([]) < self.a_prob).float()
        a1 = torch.randint(2, []).float() * do
        mats.append(scale2d(1, 1 - 2 * a1))
        # scaling
        do = (torch.rand([]) < self.a_prob).float()
        a2 = torch.randn([]) * do
        mats.append(scale2d(self.a_scale**a2, self.a_scale**a2))
        # rotation
        do = (torch.rand([]) < self.a_prob).float()
        # Uniform angle in [-pi, pi).
        a3 = (torch.rand([]) * 2 * math.pi - math.pi) * do
        mats.append(rotate2d(-a3))
        # anisotropy
        do = (torch.rand([]) < self.a_prob).float()
        a4 = (torch.rand([]) * 2 * math.pi - math.pi) * do
        a5 = torch.randn([]) * do
        # Rotate by a4, stretch one axis and shrink the other, rotate back.
        mats.append(rotate2d(a4))
        mats.append(scale2d(self.a_aniso**a5, self.a_aniso**-a5))
        mats.append(rotate2d(-a4))
        # translation
        do = (torch.rand([]) < self.a_prob).float()
        a6 = torch.randn([]) * do
        a7 = torch.randn([]) * do
        mats.append(translate2d(self.a_trans * w * a6, self.a_trans * h * a7))

        # form the transformation matrix and conditioning vector
        # Undo the initial centring, then multiply all matrices in order.
        mats.append(translate2d(-h / 2 + 0.5, -w / 2 + 0.5))
        mat = reduce(operator.matmul, mats)
        cond = torch.stack(
            [a0, a1, a2, a3.cos() - 1, a3.sin(), a5 * a4.cos(), a5 * a4.sin(), a6, a7]
        )

        # apply the transformation
        image_orig = np.array(image, dtype=np.float32) / 255
        if image_orig.ndim == 2:
            # Give 2-D (single-channel) input an explicit trailing channel axis.
            image_orig = image_orig[..., None]
        tf = transform.AffineTransform(mat.numpy())
        image = transform.warp(
            image_orig,
            tf.inverse,
            order=3,
            mode="reflect",
            cval=0.5,
            clip=False,
            preserve_range=True,
        )
        # Channel axis to the front, then rescale by * 2 - 1.
        image_orig = torch.as_tensor(image_orig).movedim(2, 0) * 2 - 1
        image = torch.as_tensor(image).movedim(2, 0) * 2 - 1
        return image, image_orig, cond
|
||||
|
||||
|
||||
class KarrasAugmentWrapper(nn.Module):
    """Wrap a model so augmentation conditioning is always supplied.

    The augmentation conditioning (zeros of width 9 when not given) is
    passed to the inner model as ``mapping_cond``, placed in front of any
    caller-provided ``mapping_cond``.
    """

    def __init__(self, model):
        super().__init__()
        self.inner_model = model

    def forward(self, input, sigma, aug_cond=None, mapping_cond=None, **kwargs):
        """Call the inner model with the combined mapping conditioning."""
        if aug_cond is None:
            aug_cond = input.new_zeros([input.shape[0], 9])
        combined = (
            aug_cond
            if mapping_cond is None
            else torch.cat([aug_cond, mapping_cond], dim=1)
        )
        return self.inner_model(input, sigma, mapping_cond=combined, **kwargs)

    def set_skip_stages(self, skip_stages):
        """Delegate to the inner model's ``set_skip_stages``."""
        return self.inner_model.set_skip_stages(skip_stages)

    def set_patch_size(self, patch_size):
        """Delegate to the inner model's ``set_patch_size``."""
        return self.inner_model.set_patch_size(patch_size)
|
1
imaginairy/vendored/k_diffusion/clip-commit-hash.txt
Normal file
1
imaginairy/vendored/k_diffusion/clip-commit-hash.txt
Normal file
@ -0,0 +1 @@
|
||||
1a0703dfb7d24d8806267c3e7ccc4caf67fd1331
|
0
imaginairy/vendored/k_diffusion/config.py
Normal file
0
imaginairy/vendored/k_diffusion/config.py
Normal file
0
imaginairy/vendored/k_diffusion/evaluation.py
Normal file
0
imaginairy/vendored/k_diffusion/evaluation.py
Normal file
145
imaginairy/vendored/k_diffusion/external.py
Normal file
145
imaginairy/vendored/k_diffusion/external.py
Normal file
@ -0,0 +1,145 @@
|
||||
import math
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from . import sampling, utils
|
||||
|
||||
|
||||
class VDenoiser(nn.Module):
    """A v-diffusion-pytorch model wrapper for k-diffusion."""

    def __init__(self, inner_model):
        super().__init__()
        self.inner_model = inner_model
        self.sigma_data = 1.0

    def get_scalings(self, sigma):
        """Return the ``(c_skip, c_out, c_in)`` scaling factors for ``sigma``."""
        data_var = self.sigma_data**2
        c_skip = data_var / (sigma**2 + data_var)
        c_out = -sigma * self.sigma_data / (sigma**2 + data_var) ** 0.5
        c_in = 1 / (sigma**2 + data_var) ** 0.5
        return c_skip, c_out, c_in

    def sigma_to_t(self, sigma):
        """Map a noise level to a timestep via ``atan(sigma) * 2 / pi``."""
        return sigma.atan() / math.pi * 2

    def t_to_sigma(self, t):
        """Map a timestep to a noise level via ``tan(t * pi / 2)``."""
        return (t * math.pi / 2).tan()

    def loss(self, input, noise, sigma, **kwargs):
        """Return the per-sample mean squared error of the model output.

        The input is noised with ``noise * sigma``, the inner model is run
        on the scaled noised input, and its output is compared with the
        target ``(input - c_skip * noised_input) / c_out``.
        """
        ndim = input.ndim
        c_skip, c_out, c_in = (
            utils.append_dims(scaling, ndim) for scaling in self.get_scalings(sigma)
        )
        noised_input = input + noise * utils.append_dims(sigma, ndim)
        timestep = self.sigma_to_t(sigma)
        model_output = self.inner_model(noised_input * c_in, timestep, **kwargs)
        target = (input - c_skip * noised_input) / c_out
        squared_error = (model_output - target).pow(2)
        return squared_error.flatten(1).mean(1)

    def forward(self, input, sigma, **kwargs):
        """Return the denoised estimate for ``input`` at noise level ``sigma``."""
        c_skip, c_out, c_in = (
            utils.append_dims(scaling, input.ndim)
            for scaling in self.get_scalings(sigma)
        )
        prediction = self.inner_model(input * c_in, self.sigma_to_t(sigma), **kwargs)
        return prediction * c_out + input * c_skip
|
||||
|
||||
|
||||
class DiscreteSchedule(nn.Module):
    """A mapping between continuous noise levels (sigmas) and a list of discrete noise
    levels."""

    def __init__(self, sigmas, quantize):
        super().__init__()
        self.register_buffer("sigmas", sigmas)
        self.quantize = quantize

    def get_sigmas(self, n=None):
        """Return a sigma schedule with a zero appended.

        With ``n`` omitted, the stored sigmas are returned in reversed
        order. Otherwise ``n`` timesteps are spaced linearly from the last
        index down to 0 and converted to sigmas by interpolation.
        """
        if n is None:
            schedule = self.sigmas.flip(0)
        else:
            t_max = len(self.sigmas) - 1
            timesteps = torch.linspace(t_max, 0, n, device=self.sigmas.device)
            schedule = self.t_to_sigma(timesteps)
        return sampling.append_zero(schedule)

    def sigma_to_t(self, sigma, quantize=None):
        """Convert noise levels to (possibly fractional) timestep indices.

        When quantizing, the index of the nearest stored sigma is returned.
        Otherwise the result is interpolated linearly between the two
        nearest stored sigmas. ``quantize=None`` uses the instance default.
        """
        if quantize is None:
            quantize = self.quantize
        dists = torch.abs(sigma - self.sigmas[:, None])
        if quantize:
            nearest = torch.argmin(dists, dim=0)
            return nearest.view(sigma.shape)
        nearest_two = torch.topk(dists, dim=0, k=2, largest=False).indices
        low_idx, high_idx = torch.sort(nearest_two, dim=0)[0]
        low = self.sigmas[low_idx]
        high = self.sigmas[high_idx]
        # Interpolation weight of the upper neighbour, clamped to [0, 1].
        w = ((low - sigma) / (low - high)).clamp(0, 1)
        t = (1 - w) * low_idx + w * high_idx
        return t.view(sigma.shape)

    def t_to_sigma(self, t):
        """Convert timestep indices to sigmas by linear interpolation."""
        t = t.float()
        low_idx = t.floor().long()
        high_idx = t.ceil().long()
        w = t.frac()
        return (1 - w) * self.sigmas[low_idx] + w * self.sigmas[high_idx]
|
||||
|
||||
|
||||
class DiscreteEpsDDPMDenoiser(DiscreteSchedule):
    """Wraps a discrete-schedule DDPM model whose output is eps (the predicted
    noise) so that it can be called as a denoiser."""

    def __init__(self, model, alphas_cumprod, quantize):
        """Derive the sigma table from ``alphas_cumprod`` and keep the model.

        Args:
            model: The wrapped eps-predicting model.
            alphas_cumprod: Cumulative alpha products; converted to sigmas as
                ``sqrt((1 - a) / a)``.
            quantize: Passed to ``DiscreteSchedule``.
        """
        sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
        super().__init__(sigmas, quantize)
        self.inner_model = model
        self.sigma_data = 1.0

    def get_scalings(self, sigma):
        """Return ``(c_out, c_in)``: the output and input scalings for ``sigma``."""
        c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5
        return -sigma, c_in

    def get_eps(self, *args, **kwargs):
        """Call the wrapped model; subclasses override this to adapt its API."""
        return self.inner_model(*args, **kwargs)

    def loss(self, input, noise, sigma, **kwargs):
        """Per-sample mean squared error between the predicted eps and ``noise``.

        ``input`` is noised with ``noise * sigma`` before being scaled and fed
        to the model; the squared error is averaged over all non-batch axes.
        """
        # Only the input scaling is needed here; c_out is not used by the loss.
        _, c_in = (
            utils.append_dims(scaling, input.ndim)
            for scaling in self.get_scalings(sigma)
        )
        noised_input = input + noise * utils.append_dims(sigma, input.ndim)
        eps = self.get_eps(noised_input * c_in, self.sigma_to_t(sigma), **kwargs)
        squared_error = (eps - noise).pow(2)
        return squared_error.flatten(1).mean(1)

    def forward(self, input, sigma, **kwargs):
        """Return the denoised estimate ``input + eps * c_out`` for ``input``
        at noise level ``sigma``."""
        c_out, c_in = (
            utils.append_dims(scaling, input.ndim)
            for scaling in self.get_scalings(sigma)
        )
        timestep = self.sigma_to_t(sigma)
        eps = self.get_eps(input * c_in, timestep, **kwargs)
        return input + eps * c_out
|
||||
|
||||
|
||||
class OpenAIDenoiser(DiscreteEpsDDPMDenoiser):
    """Adapter that exposes an OpenAI diffusion model as a denoiser."""

    def __init__(
        self, model, diffusion, quantize=False, has_learned_sigmas=True, device="cpu"
    ):
        """Build the schedule from ``diffusion.alphas_cumprod``.

        Args:
            model: The wrapped model.
            diffusion: Object exposing ``alphas_cumprod``; the values are copied
                into a float32 tensor on ``device``.
            quantize: Passed to the parent class.
            has_learned_sigmas: When true, only the first half of the model
                output along dim 1 is treated as eps.
            device: Device for the ``alphas_cumprod`` tensor.
        """
        super().__init__(
            model,
            torch.tensor(diffusion.alphas_cumprod, device=device, dtype=torch.float32),
            quantize=quantize,
        )
        self.has_learned_sigmas = has_learned_sigmas

    def get_eps(self, *args, **kwargs):
        """Call the model and return its eps prediction."""
        output = self.inner_model(*args, **kwargs)
        if not self.has_learned_sigmas:
            return output
        # Keep the first of two chunks along dim 1; the second chunk is discarded.
        return output.chunk(2, dim=1)[0]
|
||||
|
||||
|
||||
class CompVisDenoiser(DiscreteEpsDDPMDenoiser):
    """Adapter that exposes a CompVis diffusion model as a denoiser."""

    def __init__(self, model, quantize=False, device="cpu"):
        """Build the schedule from ``model.alphas_cumprod``.

        Args:
            model: Wrapped model; must expose ``alphas_cumprod`` and ``apply_model``.
            quantize: Passed to the parent class.
            device: Accepted for signature compatibility; not used here.
        """
        alphas_cumprod = model.alphas_cumprod
        super().__init__(model, alphas_cumprod, quantize=quantize)

    def get_eps(self, *args, **kwargs):
        """Return the eps prediction via the model's ``apply_model`` method."""
        apply_model = self.inner_model.apply_model
        return apply_model(*args, **kwargs)
|
115
imaginairy/vendored/k_diffusion/gns.py
Normal file
115
imaginairy/vendored/k_diffusion/gns.py
Normal file
@ -0,0 +1,115 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class DDPGradientStatsHook:
    """Communication hook that records squared gradient norms before and after
    the cross-rank average, for use with ``GradientNoiseScale``.

    The instance registers itself as the hook *state* on ``ddp_module``; the
    hook then appends one squared norm per gradient bucket to each list.
    """

    def __init__(self, ddp_module):
        """Attach the hook to ``ddp_module``.

        Args:
            ddp_module: A module exposing ``register_comm_hook`` (i.e. wrapped
                in DistributedDataParallel).

        Raises:
            ValueError: If ``ddp_module`` has no ``register_comm_hook``; the
                original ``AttributeError`` is attached as ``__cause__``.
        """
        # Initialise state first so the hook can never observe a half-built object.
        self._clear_state()
        try:
            ddp_module.register_comm_hook(self, self._hook_fn)
        except AttributeError as err:
            raise ValueError(
                "DDPGradientStatsHook does not support non-DDP wrapped modules"
            ) from err

    def _clear_state(self):
        """Reset the per-bucket squared-norm accumulators."""
        self.bucket_sq_norms_small_batch = []
        self.bucket_sq_norms_large_batch = []

    @staticmethod
    def _hook_fn(self, bucket):
        """Average ``bucket`` across ranks while recording squared norms.

        Declared static because it is registered with the instance as the hook
        state, so the state arrives as the explicit first argument (``self``).

        Returns:
            A future resolving to the averaged gradient buffer.
        """
        buf = bucket.buffer()
        # Local (per-rank, "small batch") squared norm, taken before averaging.
        self.bucket_sq_norms_small_batch.append(buf.pow(2).sum())
        fut = torch.distributed.all_reduce(
            buf, op=torch.distributed.ReduceOp.AVG, async_op=True
        ).get_future()

        def callback(fut):
            # Averaged ("large batch") squared norm, taken after the all-reduce.
            buf = fut.value()[0]
            self.bucket_sq_norms_large_batch.append(buf.pow(2).sum())
            return buf

        return fut.then(callback)

    def get_stats(self):
        """Return ``(sq_norm_small_batch, sq_norm_large_batch)`` as floats.

        The per-bucket values are summed, the accumulators are cleared, and the
        two sums are averaged across ranks with an all-reduce. At least one
        bucket must have been recorded since the last call, because the sums
        are stacked as tensors.
        """
        sq_norm_small_batch = sum(self.bucket_sq_norms_small_batch)
        sq_norm_large_batch = sum(self.bucket_sq_norms_large_batch)
        self._clear_state()
        stats = torch.stack([sq_norm_small_batch, sq_norm_large_batch])
        torch.distributed.all_reduce(stats, op=torch.distributed.ReduceOp.AVG)
        return stats[0].item(), stats[1].item()
|
||||
|
||||
|
||||
class GradientNoiseScale:
    """Calculates the gradient noise scale (1 / SNR), or critical batch size,
    from _An Empirical Model of Large-Batch Training_
    (https://arxiv.org/abs/1812.06162).

    Args:
        beta (float): The decay factor for the exponential moving averages used to
            calculate the gradient noise scale.
            Default: 0.9998
        eps (float): Added for numerical stability.
            Default: 1e-8
    """

    def __init__(self, beta=0.9998, eps=1e-8):
        self.beta = beta
        self.eps = eps
        self.ema_sq_norm = 0.0
        self.ema_var = 0.0
        # Running product of beta, used to debias the EMAs in get_stats().
        self.beta_cumprod = 1.0
        self.gradient_noise_scale = float("nan")

    def state_dict(self):
        """Returns the state of the object as a :class:`dict`."""
        return dict(self.__dict__)

    def load_state_dict(self, state_dict):
        """Loads the object's state.

        Args:
            state_dict (dict): object state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        self.__dict__.update(state_dict)

    def update(
        self, sq_norm_small_batch, sq_norm_large_batch, n_small_batch, n_large_batch
    ):
        """Updates the state with a new batch's gradient statistics, and returns the
        current gradient noise scale.

        The two batch sizes must differ: both estimates divide by a quantity
        that is zero when ``n_small_batch == n_large_batch``.

        Args:
            sq_norm_small_batch (float): The mean of the squared 2-norms of microbatch or
                per sample gradients.
            sq_norm_large_batch (float): The squared 2-norm of the mean of the microbatch or
                per sample gradients.
            n_small_batch (int): The batch size of the individual microbatch or per sample
                gradients (1 if per sample).
            n_large_batch (int): The total batch size of the mean of the microbatch or
                per sample gradients.
        """
        est_sq_norm = (
            n_large_batch * sq_norm_large_batch - n_small_batch * sq_norm_small_batch
        ) / (n_large_batch - n_small_batch)
        est_var = (sq_norm_small_batch - sq_norm_large_batch) / (
            1 / n_small_batch - 1 / n_large_batch
        )
        self.ema_sq_norm = self.beta * self.ema_sq_norm + (1 - self.beta) * est_sq_norm
        self.ema_var = self.beta * self.ema_var + (1 - self.beta) * est_var
        self.beta_cumprod *= self.beta
        # The debiasing factor is common to both EMAs and cancels in the ratio.
        self.gradient_noise_scale = max(self.ema_var, self.eps) / max(
            self.ema_sq_norm, self.eps
        )
        return self.gradient_noise_scale

    def get_gns(self):
        """Returns the current gradient noise scale."""
        return self.gradient_noise_scale

    def get_stats(self):
        """Returns the current (debiased) estimates of the squared mean gradient
        and gradient variance.

        Before the first :meth:`update` (or whenever the debiasing factor
        ``1 - beta_cumprod`` is zero) there is nothing to debias, so
        ``(nan, nan)`` is returned, matching what :meth:`get_gns` reports in
        that state, instead of raising ``ZeroDivisionError``.
        """
        debias = 1 - self.beta_cumprod
        if debias == 0:
            nan = float("nan")
            return nan, nan
        return self.ema_sq_norm / debias, self.ema_var / debias
|
296
imaginairy/vendored/k_diffusion/layers.py
Normal file
296
imaginairy/vendored/k_diffusion/layers.py
Normal file
@ -0,0 +1,296 @@
|
||||
import math
|
||||
|
||||
import torch
|
||||
from einops import rearrange, repeat
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from . import utils
|
||||
|
||||
# Karras et al. preconditioned denoiser
|
||||
|
||||
|
||||
class Denoiser(nn.Module):
    """Karras et al. preconditioning wrapper around a denoising model."""

    def __init__(self, inner_model, sigma_data=1.0):
        """Keep the wrapped model and the ``sigma_data`` constant used by the
        scalings."""
        super().__init__()
        self.inner_model = inner_model
        self.sigma_data = sigma_data

    def get_scalings(self, sigma):
        """Return ``(c_skip, c_out, c_in)`` for noise level ``sigma``."""
        total_var = sigma**2 + self.sigma_data**2
        norm = total_var**0.5
        c_skip = self.sigma_data**2 / total_var
        c_out = sigma * self.sigma_data / norm
        c_in = 1 / norm
        return c_skip, c_out, c_in

    def loss(self, input, noise, sigma, **kwargs):
        """Per-sample mean squared error against the preconditioned target.

        ``input`` is noised with ``noise * sigma``; the model sees the noised
        input scaled by ``c_in`` and is compared with
        ``(input - c_skip * noised_input) / c_out``, averaged over all
        non-batch axes.
        """
        ndim = input.ndim
        c_skip, c_out, c_in = (
            utils.append_dims(scaling, ndim) for scaling in self.get_scalings(sigma)
        )
        noised_input = input + noise * utils.append_dims(sigma, ndim)
        model_output = self.inner_model(noised_input * c_in, sigma, **kwargs)
        target = (input - c_skip * noised_input) / c_out
        squared_error = (model_output - target).pow(2)
        return squared_error.flatten(1).mean(1)

    def forward(self, input, sigma, **kwargs):
        """Return ``inner_model(input * c_in, sigma) * c_out + input * c_skip``."""
        c_skip, c_out, c_in = (
            utils.append_dims(scaling, input.ndim)
            for scaling in self.get_scalings(sigma)
        )
        prediction = self.inner_model(input * c_in, sigma, **kwargs)
        return prediction * c_out + input * c_skip
|
||||
|
||||
|
||||
class DenoiserWithVariance(Denoiser):
    """Denoiser whose loss also uses a log-variance predicted by the model."""

    def loss(self, input, noise, sigma, **kwargs):
        """Per-sample loss weighted by the model's predicted log-variance.

        The inner model is called with ``return_variance=True`` and must return
        ``(model_output, logvar)``. Each squared error is divided by
        ``exp(logvar)``, ``logvar`` is added, the sum is halved and then
        averaged over all non-batch axes.
        """
        ndim = input.ndim
        c_skip, c_out, c_in = (
            utils.append_dims(scaling, ndim) for scaling in self.get_scalings(sigma)
        )
        noised_input = input + noise * utils.append_dims(sigma, ndim)
        model_output, logvar = self.inner_model(
            noised_input * c_in, sigma, return_variance=True, **kwargs
        )
        logvar = utils.append_dims(logvar, model_output.ndim)
        target = (input - c_skip * noised_input) / c_out
        squared_error = (model_output - target) ** 2
        losses = (squared_error / logvar.exp() + logvar) / 2
        return losses.flatten(1).mean(1)
|
||||
|
||||
|
||||
# Residual blocks
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
    """Adds the output of a sequential main branch to a skip branch."""

    def __init__(self, *main, skip=None):
        """
        Args:
            *main: Modules chained into the main branch.
            skip: Module for the skip branch; an identity is used when this is
                falsy (e.g. ``None``).
        """
        super().__init__()
        self.main = nn.Sequential(*main)
        if not skip:
            skip = nn.Identity()
        self.skip = skip

    def forward(self, input):
        """Return ``main(input) + skip(input)``."""
        main_out = self.main(input)
        return main_out + self.skip(input)
|
||||
|
||||
|
||||
# Noise level (and other) conditioning
|
||||
|
||||
|
||||
class ConditionedModule(nn.Module):
    """Marker base class for modules whose ``forward`` takes ``(input, cond)``.

    Containers in this file use ``isinstance(..., ConditionedModule)`` to decide
    whether to pass the conditioning object to a child module.
    """


class UnconditionedModule(ConditionedModule):
    """Wraps an ordinary module so it can be called with ``(input, cond)``;
    the conditioning argument is ignored."""

    def __init__(self, module):
        """
        Args:
            module: The module to wrap; it is called with ``input`` only.
        """
        # nn.Module must be initialised before a submodule can be assigned;
        # without this call the assignment below raises AttributeError.
        super().__init__()
        self.module = module

    def forward(self, input, cond):
        """Return ``self.module(input)``, discarding ``cond``."""
        return self.module(input)
|
||||
|
||||
|
||||
class ConditionedSequential(nn.Sequential, ConditionedModule):
    """Sequential container that passes ``cond`` only to the children that are
    ``ConditionedModule`` instances."""

    def forward(self, input, cond):
        """Feed ``input`` through every child in order and return the result."""
        for layer in self:
            takes_cond = isinstance(layer, ConditionedModule)
            input = layer(input, cond) if takes_cond else layer(input)
        return input
|
||||
|
||||
|
||||
class ConditionedResidualBlock(ConditionedModule):
    """Residual block whose main branch (and optionally its skip branch)
    receives the conditioning object."""

    def __init__(self, *main, skip=None):
        """
        Args:
            *main: Modules chained into a ``ConditionedSequential`` main branch.
            skip: Module for the skip branch; an identity is used when this is
                falsy (e.g. ``None``).
        """
        super().__init__()
        self.main = ConditionedSequential(*main)
        if not skip:
            skip = nn.Identity()
        self.skip = skip

    def forward(self, input, cond):
        """Return ``main(input, cond)`` plus the skip branch output."""
        # The skip branch only receives ``cond`` if it is a ConditionedModule.
        if isinstance(self.skip, ConditionedModule):
            shortcut = self.skip(input, cond)
        else:
            shortcut = self.skip(input)
        return self.main(input, cond) + shortcut
|
||||
|
||||
|
||||
class AdaGN(ConditionedModule):
    """Group normalization followed by a scale and shift predicted from the
    conditioning features."""

    def __init__(self, feats_in, c_out, num_groups, eps=1e-5, cond_key="cond"):
        """
        Args:
            feats_in: Size of the conditioning feature vector.
            c_out: Number of scale values (and of shift values) produced by the
                linear mapper.
            num_groups: Group count handed to ``F.group_norm``.
            eps: Epsilon handed to ``F.group_norm``.
            cond_key: Key under which the features are looked up in ``cond``.
        """
        super().__init__()
        self.num_groups = num_groups
        self.eps = eps
        self.cond_key = cond_key
        self.mapper = nn.Linear(feats_in, c_out * 2)

    def forward(self, input, cond):
        """Return ``bias + group_norm(input) * (weight + 1)``."""
        scale_shift = self.mapper(cond[self.cond_key])
        weight, bias = scale_shift.chunk(2, dim=-1)
        normed = F.group_norm(input, self.num_groups, eps=self.eps)
        ndim = normed.ndim
        return torch.addcmul(
            utils.append_dims(bias, ndim),
            normed,
            utils.append_dims(weight, ndim) + 1,
        )
|
||||
|
||||
|
||||
# Attention
|
||||
|
||||
|
||||
class SelfAttention2d(ConditionedModule):
    """Multi-head self-attention over the spatial positions of a 4-D input,
    added residually to that input."""

    def __init__(self, c_in, n_head, norm, dropout_rate=0.0):
        """
        Args:
            c_in: Channel count of the input; must be divisible by ``n_head``.
            n_head: Number of attention heads.
            norm: Callable taking a channel count and returning the input
                normalization module, which is later called with ``(input, cond)``.
            dropout_rate: Dropout probability applied to the attention weights.
        """
        super().__init__()
        assert c_in % n_head == 0
        self.norm_in = norm(c_in)
        self.n_head = n_head
        self.qkv_proj = nn.Conv2d(c_in, c_in * 3, 1)
        self.out_proj = nn.Conv2d(c_in, c_in, 1)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, input, cond):
        """Return ``input`` plus the projected attention output."""
        batch, channels, height, width = input.shape
        head_dim = channels // self.n_head
        qkv = self.qkv_proj(self.norm_in(input, cond))
        # -> [batch, 3 * heads, positions, head_dim], then split into q, k, v.
        qkv = qkv.view([batch, self.n_head * 3, head_dim, height * width])
        q, k, v = qkv.transpose(2, 3).chunk(3, dim=1)
        # The scale is applied to q and k separately (head_dim ** -0.25 each).
        scale = k.shape[3] ** -0.25
        scores = (q * scale) @ (k.transpose(2, 3) * scale)
        att = self.dropout(scores.softmax(3))
        y = (att @ v).transpose(2, 3).contiguous()
        y = y.view([batch, channels, height, width])
        return input + self.out_proj(y)
|
||||
|
||||
|
||||
class CrossAttention2d(ConditionedModule):
    """Multi-head cross-attention from the spatial positions of a 4-D input to
    an encoder sequence taken from ``cond``, added residually to the input."""

    def __init__(
        self,
        c_dec,
        c_enc,
        n_head,
        norm_dec,
        dropout_rate=0.0,
        cond_key="cross",
        cond_key_padding="cross_padding",
    ):
        """
        Args:
            c_dec: Channel count of the decoder input; must be divisible by ``n_head``.
            c_enc: Feature size of the encoder sequence.
            n_head: Number of attention heads.
            norm_dec: Callable taking a channel count and returning the decoder
                normalization module, which is later called with ``(input, cond)``.
            dropout_rate: Dropout probability applied to the attention weights.
            cond_key: Key of the encoder sequence in ``cond``.
            cond_key_padding: Key of the padding mask in ``cond``.
        """
        super().__init__()
        assert c_dec % n_head == 0
        self.cond_key = cond_key
        self.cond_key_padding = cond_key_padding
        self.norm_enc = nn.LayerNorm(c_enc)
        self.norm_dec = norm_dec(c_dec)
        self.n_head = n_head
        self.q_proj = nn.Conv2d(c_dec, c_dec, 1)
        self.kv_proj = nn.Linear(c_enc, c_dec * 2)
        self.out_proj = nn.Conv2d(c_dec, c_dec, 1)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, input, cond):
        """Return ``input`` plus the projected cross-attention output."""
        batch, channels, height, width = input.shape
        head_dim = channels // self.n_head

        q = self.q_proj(self.norm_dec(input, cond))
        q = q.view([batch, self.n_head, head_dim, height * width]).transpose(2, 3)

        encoded = self.norm_enc(cond[self.cond_key])
        kv = self.kv_proj(encoded)
        kv = kv.view([batch, -1, self.n_head * 2, head_dim]).transpose(1, 2)
        k, v = kv.chunk(2, dim=1)

        # The scale is applied to q and k separately (head_dim ** -0.25 each).
        scale = k.shape[3] ** -0.25
        scores = (q * scale) @ (k.transpose(2, 3) * scale)
        # Positions flagged in the padding mask get a large negative bias so
        # they receive (almost) no weight after the softmax.
        padding = cond[self.cond_key_padding][:, None, None, :]
        scores = scores - padding * 10000
        att = self.dropout(scores.softmax(3))

        y = (att @ v).transpose(2, 3)
        y = y.contiguous().view([batch, channels, height, width])
        return input + self.out_proj(y)
|
||||
|
||||
|
||||
# Downsampling/upsampling
|
||||
|
||||
# 1-D resampling filter taps, keyed by filter name. Downsample2d and Upsample2d
# look a filter up by name and form the 2-D kernel as the outer product of the
# taps with themselves. Every tap list is symmetric and sums to 1 (up to
# rounding for "lanczos3").
_kernels = {
    "linear": [1 / 8, 3 / 8, 3 / 8, 1 / 8],
    "cubic": [
        -0.01171875,
        -0.03515625,
        0.11328125,
        0.43359375,
        0.43359375,
        0.11328125,
        -0.03515625,
        -0.01171875,
    ],
    "lanczos3": [
        0.003689131001010537,
        0.015056144446134567,
        -0.03399861603975296,
        -0.066637322306633,
        0.13550527393817902,
        0.44638532400131226,
        0.44638532400131226,
        0.13550527393817902,
        -0.066637322306633,
        -0.03399861603975296,
        0.015056144446134567,
        0.003689131001010537,
    ],
}
# Aliases: the 2-D names refer to the same tap lists as their 1-D counterparts.
_kernels["bilinear"] = _kernels["linear"]
_kernels["bicubic"] = _kernels["cubic"]
|
||||
|
||||
|
||||
class Downsample2d(nn.Module):
    """Filters each channel independently with a fixed kernel, using stride 2."""

    def __init__(self, kernel="linear", pad_mode="reflect"):
        """
        Args:
            kernel: Name of the 1-D filter in ``_kernels``.
            pad_mode: Padding mode handed to ``F.pad``.
        """
        super().__init__()
        self.pad_mode = pad_mode
        taps = torch.tensor([_kernels[kernel]])
        self.pad = taps.shape[1] // 2 - 1
        # 2-D kernel as the outer product of the 1-D taps.
        self.register_buffer("kernel", taps.T @ taps)

    def forward(self, x):
        """Pad ``x`` and apply the per-channel strided convolution."""
        channels = x.shape[1]
        x = F.pad(x, (self.pad,) * 4, self.pad_mode)
        k_h, k_w = self.kernel.shape
        # Dense [C, C, kh, kw] weight with the kernel only on the channel
        # diagonal, so channels are never mixed.
        weight = x.new_zeros([channels, channels, k_h, k_w])
        diag = torch.arange(channels, device=x.device)
        weight[diag, diag] = self.kernel.to(weight)
        return F.conv2d(x, weight, stride=2)
|
||||
|
||||
|
||||
class Upsample2d(nn.Module):
    """Upsamples each channel independently with a fixed kernel, using a
    stride-2 transposed convolution."""

    def __init__(self, kernel="linear", pad_mode="reflect"):
        """
        Args:
            kernel: Name of the 1-D filter in ``_kernels``; the taps are doubled
                before the 2-D kernel is formed.
            pad_mode: Padding mode handed to ``F.pad``.
        """
        super().__init__()
        self.pad_mode = pad_mode
        taps = torch.tensor([_kernels[kernel]]) * 2
        self.pad = taps.shape[1] // 2 - 1
        # 2-D kernel as the outer product of the (doubled) 1-D taps.
        self.register_buffer("kernel", taps.T @ taps)

    def forward(self, x):
        """Pad ``x`` and apply the per-channel transposed convolution."""
        channels = x.shape[1]
        border = (self.pad + 1) // 2
        x = F.pad(x, (border,) * 4, self.pad_mode)
        k_h, k_w = self.kernel.shape
        # Dense [C, C, kh, kw] weight with the kernel only on the channel
        # diagonal, so channels are never mixed.
        weight = x.new_zeros([channels, channels, k_h, k_w])
        diag = torch.arange(channels, device=x.device)
        weight[diag, diag] = self.kernel.to(weight)
        return F.conv_transpose2d(x, weight, stride=2, padding=self.pad * 2 + 1)
|
||||
|
||||
|
||||
# Embeddings
|
||||
|
||||
|
||||
class FourierFeatures(nn.Module):
    """Random Fourier feature embedding with a fixed (non-trainable)
    projection stored as a buffer."""

    def __init__(self, in_features, out_features, std=1.0):
        """
        Args:
            in_features: Size of the last input dimension.
            out_features: Size of the embedding; must be even (half cosines,
                half sines).
            std: Standard deviation of the random projection weights.
        """
        super().__init__()
        assert out_features % 2 == 0
        projection = torch.randn([out_features // 2, in_features]) * std
        self.register_buffer("weight", projection)

    def forward(self, input):
        """Return ``[cos(f), sin(f)]`` concatenated on the last axis, where
        ``f = 2 * pi * input @ weight.T``."""
        angles = (2 * math.pi * input) @ self.weight.T
        return torch.cat([angles.cos(), angles.sin()], dim=-1)
|
||||
|
||||
|
||||
# U-Nets
|
||||
|
||||
|
||||
class UNet(ConditionedModule):
    """Runs a list of down blocks, then a list of up blocks that receive the
    down-block outputs as skip connections."""

    def __init__(self, d_blocks, u_blocks, skip_stages=0):
        """
        Args:
            d_blocks: Iterable of down blocks, called as ``block(input, cond)``.
            u_blocks: Iterable of up blocks, called as ``block(input, cond, skip)``.
            skip_stages: Number of leading down blocks to bypass in ``forward``.
        """
        super().__init__()
        self.d_blocks = nn.ModuleList(d_blocks)
        self.u_blocks = nn.ModuleList(u_blocks)
        self.skip_stages = skip_stages

    def forward(self, input, cond):
        """Return the output of the last up block."""
        skips = []
        for down_block in self.d_blocks[self.skip_stages :]:
            input = down_block(input, cond)
            skips.append(input)
        # Up blocks consume the skips in reverse order. The first up block gets
        # no skip, because its input already is the deepest down-block output.
        for index, (up_block, residual) in enumerate(zip(self.u_blocks, skips[::-1])):
            if index == 0:
                residual = None
            input = up_block(input, cond, residual)
        return input
|
1
imaginairy/vendored/k_diffusion/models/__init__.py
Normal file
1
imaginairy/vendored/k_diffusion/models/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
from .image_v1 import ImageDenoiserModelV1
|
305
imaginairy/vendored/k_diffusion/models/image_v1.py
Normal file
305
imaginairy/vendored/k_diffusion/models/image_v1.py
Normal file
@ -0,0 +1,305 @@
|
||||
import math
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from .. import layers, utils
|
||||
|
||||
|
||||
def orthogonal_(module):
    """Re-initialise ``module.weight`` in place with an orthogonal matrix and
    return the same module, so the call can wrap a constructor."""
    weight = module.weight
    nn.init.orthogonal_(weight)
    return module
|
||||
|
||||
|
||||
class ResConvBlock(layers.ConditionedResidualBlock):
    """Conditioned residual block made of two AdaGN -> GELU -> 3x3 conv ->
    dropout stages."""

    def __init__(self, feats_in, c_in, c_mid, c_out, group_size=32, dropout_rate=0.0):
        """
        Args:
            feats_in: Size of the conditioning features consumed by the AdaGN layers.
            c_in: Input channel count.
            c_mid: Channel count between the two convolutions.
            c_out: Output channel count.
            group_size: Target channels per normalization group (at least one
                group is always used).
            dropout_rate: Probability for the in-place ``Dropout2d`` layers.
        """
        # The skip branch is created first, matching the original construction
        # order. A bias-free 1x1 convolution is only needed when the channel
        # count changes.
        if c_in == c_out:
            skip = None
        else:
            skip = orthogonal_(nn.Conv2d(c_in, c_out, 1, bias=False))
        main = [
            layers.AdaGN(feats_in, c_in, max(1, c_in // group_size)),
            nn.GELU(),
            nn.Conv2d(c_in, c_mid, 3, padding=1),
            nn.Dropout2d(dropout_rate, inplace=True),
            layers.AdaGN(feats_in, c_mid, max(1, c_mid // group_size)),
            nn.GELU(),
            nn.Conv2d(c_mid, c_out, 3, padding=1),
            nn.Dropout2d(dropout_rate, inplace=True),
        ]
        super().__init__(*main, skip=skip)
|
||||
|
||||
|
||||
class DBlock(layers.ConditionedSequential):
    """Down block: an optional downsampling stage followed by ``n_layers``
    residual conv blocks, each optionally followed by self- and cross-attention."""

    def __init__(
        self,
        n_layers,
        feats_in,
        c_in,
        c_mid,
        c_out,
        group_size=32,
        head_size=64,
        dropout_rate=0.0,
        downsample=False,
        self_attn=False,
        cross_attn=False,
        c_enc=0,
    ):
        """
        Args:
            n_layers: Number of residual conv blocks.
            feats_in: Size of the conditioning features.
            c_in: Input channels of the first conv block.
            c_mid: Channels between conv blocks.
            c_out: Output channels of the last conv block.
            group_size: Target channels per normalization group.
            head_size: Target channels per attention head.
            dropout_rate: Dropout probability for conv blocks and attention.
            downsample: Whether slot 0 holds a ``Downsample2d`` (else an identity).
            self_attn: Add self-attention after every conv block.
            cross_attn: Add cross-attention after every conv block.
            c_enc: Feature size of the cross-attention encoder sequence.
        """
        # Slot 0 is a placeholder that set_downsample() fills in below.
        stack = [nn.Identity()]
        last = n_layers - 1
        for index in range(n_layers):
            block_in = c_mid if index else c_in
            block_out = c_out if index >= last else c_mid
            stack.append(
                ResConvBlock(
                    feats_in, block_in, c_mid, block_out, group_size, dropout_rate
                )
            )
            n_groups = max(1, block_out // group_size)
            n_heads = max(1, block_out // head_size)

            # Called by the attention constructors during this iteration, so it
            # always sees the current ``n_groups``.
            def make_norm(channels):
                return layers.AdaGN(feats_in, channels, n_groups)

            if self_attn:
                stack.append(
                    layers.SelfAttention2d(block_out, n_heads, make_norm, dropout_rate)
                )
            if cross_attn:
                stack.append(
                    layers.CrossAttention2d(
                        block_out, c_enc, n_heads, make_norm, dropout_rate
                    )
                )
        super().__init__(*stack)
        self.set_downsample(downsample)

    def set_downsample(self, downsample):
        """Swap slot 0 between ``Downsample2d`` and identity; returns ``self``."""
        self[0] = layers.Downsample2d() if downsample else nn.Identity()
        return self
|
||||
|
||||
|
||||
class UBlock(layers.ConditionedSequential):
    """Up block: ``n_layers`` residual conv blocks, each optionally followed by
    self- and cross-attention, then an optional upsampling stage."""

    def __init__(
        self,
        n_layers,
        feats_in,
        c_in,
        c_mid,
        c_out,
        group_size=32,
        head_size=64,
        dropout_rate=0.0,
        upsample=False,
        self_attn=False,
        cross_attn=False,
        c_enc=0,
    ):
        """
        Args:
            n_layers: Number of residual conv blocks.
            feats_in: Size of the conditioning features.
            c_in: Input channels of the first conv block.
            c_mid: Channels between conv blocks.
            c_out: Output channels of the last conv block.
            group_size: Target channels per normalization group.
            head_size: Target channels per attention head.
            dropout_rate: Dropout probability for conv blocks and attention.
            upsample: Whether the last slot holds an ``Upsample2d`` (else an identity).
            self_attn: Add self-attention after every conv block.
            cross_attn: Add cross-attention after every conv block.
            c_enc: Feature size of the cross-attention encoder sequence.
        """
        stack = []
        last = n_layers - 1
        for index in range(n_layers):
            block_in = c_mid if index else c_in
            block_out = c_out if index >= last else c_mid
            stack.append(
                ResConvBlock(
                    feats_in, block_in, c_mid, block_out, group_size, dropout_rate
                )
            )
            n_groups = max(1, block_out // group_size)
            n_heads = max(1, block_out // head_size)

            # Called by the attention constructors during this iteration, so it
            # always sees the current ``n_groups``.
            def make_norm(channels):
                return layers.AdaGN(feats_in, channels, n_groups)

            if self_attn:
                stack.append(
                    layers.SelfAttention2d(block_out, n_heads, make_norm, dropout_rate)
                )
            if cross_attn:
                stack.append(
                    layers.CrossAttention2d(
                        block_out, c_enc, n_heads, make_norm, dropout_rate
                    )
                )
        # The last slot is a placeholder that set_upsample() fills in below.
        stack.append(nn.Identity())
        super().__init__(*stack)
        self.set_upsample(upsample)

    def forward(self, input, cond, skip=None):
        """Concatenate ``skip`` (if given) onto ``input`` along dim 1, then run
        the block."""
        if skip is None:
            return super().forward(input, cond)
        merged = torch.cat([input, skip], dim=1)
        return super().forward(merged, cond)

    def set_upsample(self, upsample):
        """Swap the last slot between ``Upsample2d`` and identity; returns ``self``."""
        self[-1] = layers.Upsample2d() if upsample else nn.Identity()
        return self
|
||||
|
||||
|
||||
class MappingNet(nn.Sequential):
    """Stack of orthogonally initialised linear layers, each followed by GELU."""

    def __init__(self, feats_in, feats_out, n_layers=2):
        """
        Args:
            feats_in: Input size of the first linear layer.
            feats_out: Output size of every linear layer.
            n_layers: Number of linear + GELU pairs.
        """
        # Named ``stages`` so the imported ``layers`` module is not shadowed.
        stages = []
        for index in range(n_layers):
            width_in = feats_out if index else feats_in
            stages += [orthogonal_(nn.Linear(width_in, feats_out)), nn.GELU()]
        super().__init__(*stages)
|
||||
|
||||
|
||||
class ImageDenoiserModelV1(nn.Module):
    """Conditioned U-Net image denoiser.

    The noise level is embedded with Fourier features, optionally combined
    with a mapping-condition embedding, passed through a mapping network and
    used as the ``"cond"`` entry for every conditioned layer of the U-Net.
    """

    def __init__(
        self,
        c_in,
        feats_in,
        depths,
        channels,
        self_attn_depths,
        cross_attn_depths=None,
        mapping_cond_dim=0,
        unet_cond_dim=0,
        cross_cond_dim=0,
        dropout_rate=0.0,
        patch_size=1,
        skip_stages=0,
        has_variance=False,
    ):
        """
        Args:
            c_in: Number of image channels.
            feats_in: Size of the conditioning feature vector.
            depths: Number of residual blocks at each U-Net level.
            channels: Channel count at each U-Net level.
            self_attn_depths: Per-level flags enabling self-attention.
            cross_attn_depths: Per-level flags enabling cross-attention. Treated
                as all ``False`` when ``None`` or when ``cross_cond_dim == 0``.
            mapping_cond_dim: Size of the optional mapping condition; the
                ``mapping_cond`` layer only exists when this is positive.
            unet_cond_dim: Channels of the optional condition concatenated to
                the input.
            cross_cond_dim: Feature size of the cross-attention sequence.
            dropout_rate: Dropout probability used throughout the U-Net.
            patch_size: Pixel-unshuffle factor applied to the input.
            skip_stages: Number of leading U-Net stages to bypass.
            has_variance: Whether an extra output channel predicts a log-variance.
        """
        super().__init__()
        self.c_in = c_in
        self.channels = channels
        self.unet_cond_dim = unet_cond_dim
        self.patch_size = patch_size
        self.has_variance = has_variance
        self.timestep_embed = layers.FourierFeatures(1, feats_in)
        if mapping_cond_dim > 0:
            self.mapping_cond = nn.Linear(mapping_cond_dim, feats_in, bias=False)
        self.mapping = MappingNet(feats_in, feats_in)
        self._replace_projections(
            self._proj_in_channels(), self._proj_out_channels(), skip_stages
        )
        # The default cross_attn_depths=None used to crash on indexing whenever
        # cross_cond_dim was non-zero; it now means "no cross-attention".
        if cross_cond_dim == 0 or cross_attn_depths is None:
            cross_attn_depths = [False] * len(self_attn_depths)
        d_blocks, u_blocks = [], []
        for i in range(len(depths)):
            my_c_in = channels[max(0, i - 1)]
            d_blocks.append(
                DBlock(
                    depths[i],
                    feats_in,
                    my_c_in,
                    channels[i],
                    channels[i],
                    downsample=i > skip_stages,
                    self_attn=self_attn_depths[i],
                    cross_attn=cross_attn_depths[i],
                    c_enc=cross_cond_dim,
                    dropout_rate=dropout_rate,
                )
            )
        for i in range(len(depths)):
            # Every level except the deepest also receives a skip connection,
            # which doubles its input channels.
            my_c_in = channels[i] * 2 if i < len(depths) - 1 else channels[i]
            my_c_out = channels[max(0, i - 1)]
            u_blocks.append(
                UBlock(
                    depths[i],
                    feats_in,
                    my_c_in,
                    channels[i],
                    my_c_out,
                    upsample=i > skip_stages,
                    self_attn=self_attn_depths[i],
                    cross_attn=cross_attn_depths[i],
                    c_enc=cross_cond_dim,
                    dropout_rate=dropout_rate,
                )
            )
        self.u_net = layers.UNet(d_blocks, reversed(u_blocks), skip_stages=skip_stages)

    def _proj_in_channels(self):
        """Input channels of ``proj_in`` for the current patch size."""
        return (self.c_in + self.unet_cond_dim) * self.patch_size**2

    def _proj_out_channels(self):
        """Output channels of ``proj_out`` for the current patch size."""
        return self.c_in * self.patch_size**2 + (1 if self.has_variance else 0)

    def _replace_projections(self, in_channels, out_channels, skip_stages):
        """Create fresh ``proj_in`` / ``proj_out`` 1x1 convolutions.

        ``proj_out`` is zero-initialised. When projections already exist, the
        new layers are moved to the device and dtype of the old ``proj_in``
        weight so that rebuilding them does not silently drop the model's
        placement.
        """
        width = self.channels[max(0, skip_stages - 1)]
        proj_in = nn.Conv2d(in_channels, width, 1)
        proj_out = nn.Conv2d(width, out_channels, 1)
        nn.init.zeros_(proj_out.weight)
        nn.init.zeros_(proj_out.bias)
        previous = getattr(self, "proj_in", None)
        if previous is not None:
            ref = previous.weight
            proj_in.to(device=ref.device, dtype=ref.dtype)
            proj_out.to(device=ref.device, dtype=ref.dtype)
        self.proj_in = proj_in
        self.proj_out = proj_out

    def forward(
        self,
        input,
        sigma,
        mapping_cond=None,
        unet_cond=None,
        cross_cond=None,
        cross_cond_padding=None,
        return_variance=False,
    ):
        """Run the model on ``input`` at noise level ``sigma``.

        Args:
            input: Image batch.
            sigma: Noise level per batch item; embedded as ``log(sigma) / 4``.
            mapping_cond: Optional condition added to the timestep embedding
                (requires the model to have been built with ``mapping_cond_dim > 0``).
            unet_cond: Optional tensor concatenated to ``input`` along dim 1.
            cross_cond: Optional sequence for cross-attention.
            cross_cond_padding: Padding mask accompanying ``cross_cond``.
            return_variance: Also return the log-variance when the model has
                a variance output.

        Returns:
            The model output, or ``(output, logvar)`` when both
            ``has_variance`` and ``return_variance`` are true.
        """
        c_noise = sigma.log() / 4
        timestep_embed = self.timestep_embed(utils.append_dims(c_noise, 2))
        mapping_cond_embed = (
            torch.zeros_like(timestep_embed)
            if mapping_cond is None
            else self.mapping_cond(mapping_cond)
        )
        mapping_out = self.mapping(timestep_embed + mapping_cond_embed)
        cond = {"cond": mapping_out}
        if unet_cond is not None:
            input = torch.cat([input, unet_cond], dim=1)
        if cross_cond is not None:
            cond["cross"] = cross_cond
            cond["cross_padding"] = cross_cond_padding
        if self.patch_size > 1:
            input = F.pixel_unshuffle(input, self.patch_size)
        input = self.proj_in(input)
        input = self.u_net(input, cond)
        input = self.proj_out(input)
        if self.has_variance:
            # The last channel holds the log-variance, averaged per sample.
            input, logvar = input[:, :-1], input[:, -1].flatten(1).mean(1)
        if self.patch_size > 1:
            input = F.pixel_shuffle(input, self.patch_size)
        if self.has_variance and return_variance:
            return input, logvar
        return input

    def set_skip_stages(self, skip_stages):
        """Change the number of bypassed stages and return ``self``.

        Both projections are re-created (``proj_out`` zero-initialised) with
        their previous channel counts on the non-U-Net side, and the
        down/upsampling stages of every block are updated.
        """
        self._replace_projections(
            self.proj_in.in_channels, self.proj_out.out_channels, skip_stages
        )
        self.u_net.skip_stages = skip_stages
        for i, block in enumerate(self.u_net.d_blocks):
            block.set_downsample(i > skip_stages)
        # u_blocks are stored deepest-first, so iterate them in reverse to get
        # the same level indices as the down blocks.
        for i, block in enumerate(reversed(self.u_net.u_blocks)):
            block.set_upsample(i > skip_stages)
        return self

    def set_patch_size(self, patch_size):
        """Change the patch size and re-create both projections to match
        (``proj_out`` zero-initialised). Returns ``None``."""
        self.patch_size = patch_size
        self._replace_projections(
            self._proj_in_channels(),
            self._proj_out_channels(),
            self.u_net.skip_stages,
        )
|
1
imaginairy/vendored/k_diffusion/readme.txt
Normal file
1
imaginairy/vendored/k_diffusion/readme.txt
Normal file
@ -0,0 +1 @@
|
||||
vendored from git@github.com:crowsonkb/k-diffusion.git
|
333
imaginairy/vendored/k_diffusion/sampling.py
Normal file
333
imaginairy/vendored/k_diffusion/sampling.py
Normal file
@ -0,0 +1,333 @@
|
||||
import math
|
||||
|
||||
import torch
|
||||
from scipy import integrate
|
||||
from torchdiffeq import odeint
|
||||
from tqdm.auto import tqdm, trange
|
||||
|
||||
from . import utils
|
||||
|
||||
|
||||
def append_zero(x):
    """Return 1-D tensor ``x`` with a single zero (same dtype and device)
    appended."""
    zero = x.new_zeros([1])
    return torch.cat([x, zero])
|
||||
|
||||
|
||||
def get_sigmas_karras(n, sigma_min, sigma_max, rho=7.0, device="cpu"):
    """Constructs the noise schedule of Karras et al. (2022).

    Interpolates linearly between ``sigma_max ** (1 / rho)`` and
    ``sigma_min ** (1 / rho)`` over ``n`` points, raises the result to the
    power ``rho`` and appends a final zero. The ramp is built on the default
    device and the finished schedule is moved to ``device``.
    """
    inv_rho = 1 / rho
    min_inv_rho = sigma_min**inv_rho
    max_inv_rho = sigma_max**inv_rho
    ramp = torch.linspace(0, 1, n)
    sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
    return append_zero(sigmas).to(device)
|
||||
|
||||
|
||||
def get_sigmas_exponential(n, sigma_min, sigma_max, device="cpu"):
    """Constructs an exponential noise schedule.

    ``n`` values are spaced linearly in log space from ``sigma_max`` down to
    ``sigma_min`` on ``device``, exponentiated, and a final zero is appended.
    """
    log_max = math.log(sigma_max)
    log_min = math.log(sigma_min)
    log_sigmas = torch.linspace(log_max, log_min, n, device=device)
    return append_zero(log_sigmas.exp())
|
||||
|
||||
|
||||
def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device="cpu"):
    """Constructs a continuous VP noise schedule.

    ``t`` runs linearly from 1 down to ``eps_s`` over ``n`` points on
    ``device``; each sigma is ``sqrt(exp(beta_d * t**2 / 2 + beta_min * t) - 1)``
    and a final zero is appended.
    """
    t = torch.linspace(1, eps_s, n, device=device)
    exponent = beta_d * t**2 / 2 + beta_min * t
    sigmas = torch.sqrt(torch.exp(exponent) - 1)
    return append_zero(sigmas)
|
||||
|
||||
|
||||
def to_d(x, sigma, denoised):
    """Converts a denoiser output to a Karras ODE derivative.

    ``sigma`` is divided into the residual as given, so it must already
    broadcast against ``x``.
    """
    residual = x - denoised
    return residual / sigma
|
||||
|
||||
|
||||
def get_ancestral_step(sigma_from, sigma_to):
    """Calculates the noise level (sigma_down) to step down to and the amount
    of noise to add (sigma_up) when doing an ancestral sampling step.

    Returns:
        ``(sigma_down, sigma_up)``, which satisfy
        ``sigma_down**2 + sigma_up**2 == sigma_to**2``.
    """
    var_from = sigma_from**2
    var_to = sigma_to**2
    sigma_up = (var_to * (var_from - var_to) / var_from) ** 0.5
    sigma_down = (var_to - sigma_up**2) ** 0.5
    return sigma_down, sigma_up
|
||||
|
||||
|
||||
@torch.no_grad()
def sample_euler(
    model,
    x,
    sigmas,
    extra_args=None,
    callback=None,
    disable=None,
    s_churn=0.0,
    s_tmin=0.0,
    s_tmax=float("inf"),
    s_noise=1.0,
):
    """Implements Algorithm 2 (Euler steps) from Karras et al. (2022).

    Args:
        model: Denoiser called as ``model(x, sigma_batch, **extra_args)``.
        x: Starting sample; ``x.shape[0]`` is used as the batch size.
        sigmas: Noise levels to step through, one step per consecutive pair.
        extra_args: Extra keyword arguments for ``model``.
        callback: Optional callable given a dict with ``x``, ``i``, ``sigma``,
            ``sigma_hat`` and ``denoised`` on every step.
        disable: Passed to ``trange`` to switch the progress bar off.
        s_churn: Total churn; split evenly over the steps and capped at
            ``sqrt(2) - 1`` per step.
        s_tmin: Lower bound of the noise levels at which churn is applied.
        s_tmax: Upper bound of the noise levels at which churn is applied.
        s_noise: Scale of the noise added when churning.

    Returns:
        The sample after the final step.
    """
    if extra_args is None:
        extra_args = {}
    n_steps = len(sigmas) - 1
    s_in = x.new_ones([x.shape[0]])
    for i in trange(n_steps, disable=disable):
        sigma, sigma_next = sigmas[i], sigmas[i + 1]
        if s_tmin <= sigma <= s_tmax:
            gamma = min(s_churn / n_steps, 2**0.5 - 1)
        else:
            gamma = 0.0
        # Noise is drawn on every step, even when it is not used, so the
        # random stream does not depend on the churn settings.
        eps = torch.randn_like(x) * s_noise
        sigma_hat = sigma * (gamma + 1)
        if gamma > 0:
            x = x + eps * (sigma_hat**2 - sigma**2) ** 0.5
        denoised = model(x, sigma_hat * s_in, **extra_args)
        d = to_d(x, sigma_hat, denoised)
        if callback is not None:
            info = {
                "x": x,
                "i": i,
                "sigma": sigma,
                "sigma_hat": sigma_hat,
                "denoised": denoised,
            }
            callback(info)
        # Euler method
        x = x + d * (sigma_next - sigma_hat)
    return x
|
||||
|
||||
|
||||
@torch.no_grad()
def sample_euler_ancestral(
    model, x, sigmas, extra_args=None, callback=None, disable=None
):
    """Ancestral sampling with Euler method steps.

    Each step denoises at ``sigmas[i]``, takes an Euler step down to
    ``sigma_down`` and then adds fresh noise scaled by ``sigma_up`` (both from
    ``get_ancestral_step``).

    Args:
        model: Denoiser called as ``model(x, sigma_batch, **extra_args)``.
        x: Starting sample; ``x.shape[0]`` is used as the batch size.
        sigmas: Noise levels to step through, one step per consecutive pair.
        extra_args: Extra keyword arguments for ``model``.
        callback: Optional callable given a dict with ``x``, ``i``, ``sigma``,
            ``sigma_hat`` and ``denoised`` on every step.
        disable: Passed to ``trange`` to switch the progress bar off.

    Returns:
        The sample after the final step.
    """
    if extra_args is None:
        extra_args = {}
    s_in = x.new_ones([x.shape[0]])
    for i in trange(len(sigmas) - 1, disable=disable):
        sigma = sigmas[i]
        denoised = model(x, sigma * s_in, **extra_args)
        sigma_down, sigma_up = get_ancestral_step(sigma, sigmas[i + 1])
        if callback is not None:
            info = {
                "x": x,
                "i": i,
                "sigma": sigma,
                "sigma_hat": sigma,
                "denoised": denoised,
            }
            callback(info)
        d = to_d(x, sigma, denoised)
        # Euler method
        x = x + d * (sigma_down - sigma)
        x = x + torch.randn_like(x) * sigma_up
    return x
|
||||
|
||||
|
||||
@torch.no_grad()
def sample_heun(
    model,
    x,
    sigmas,
    extra_args=None,
    callback=None,
    disable=None,
    s_churn=0.0,
    s_tmin=0.0,
    s_tmax=float("inf"),
    s_noise=1.0,
):
    """Implements Algorithm 2 (Heun steps) from Karras et al. (2022).

    Args:
        model: Denoiser called as ``model(x, sigma_batch, **extra_args)``.
        x: Starting sample; ``x.shape[0]`` is used as the batch size.
        sigmas: Noise levels to step through, one step per consecutive pair.
        extra_args: Extra keyword arguments for ``model``.
        callback: Optional callable given a dict with ``x``, ``i``, ``sigma``,
            ``sigma_hat`` and ``denoised`` on every step.
        disable: Passed to ``trange`` to switch the progress bar off.
        s_churn: Total churn; split evenly over the steps and capped at
            ``sqrt(2) - 1`` per step.
        s_tmin: Lower bound of the noise levels at which churn is applied.
        s_tmax: Upper bound of the noise levels at which churn is applied.
        s_noise: Scale of the noise added when churning.

    Returns:
        The sample after the final step.
    """
    if extra_args is None:
        extra_args = {}
    n_steps = len(sigmas) - 1
    s_in = x.new_ones([x.shape[0]])
    for i in trange(n_steps, disable=disable):
        sigma, sigma_next = sigmas[i], sigmas[i + 1]
        if s_tmin <= sigma <= s_tmax:
            gamma = min(s_churn / n_steps, 2**0.5 - 1)
        else:
            gamma = 0.0
        # Noise is drawn on every step, even when it is not used, so the
        # random stream does not depend on the churn settings.
        eps = torch.randn_like(x) * s_noise
        sigma_hat = sigma * (gamma + 1)
        if gamma > 0:
            x = x + eps * (sigma_hat**2 - sigma**2) ** 0.5
        denoised = model(x, sigma_hat * s_in, **extra_args)
        d = to_d(x, sigma_hat, denoised)
        if callback is not None:
            info = {
                "x": x,
                "i": i,
                "sigma": sigma,
                "sigma_hat": sigma_hat,
                "denoised": denoised,
            }
            callback(info)
        dt = sigma_next - sigma_hat
        if sigma_next == 0:
            # Euler method: the final step to sigma = 0 gets no correction,
            # because the derivative divides by sigma.
            x = x + d * dt
            continue
        # Heun's method: average the slopes at both ends of the step.
        x_2 = x + d * dt
        denoised_2 = model(x_2, sigma_next * s_in, **extra_args)
        d_2 = to_d(x_2, sigma_next, denoised_2)
        d_prime = (d + d_2) / 2
        x = x + d_prime * dt
    return x
|
||||
|
||||
|
||||
@torch.no_grad()
def sample_dpm_2(
    model,
    x,
    sigmas,
    extra_args=None,
    callback=None,
    disable=None,
    s_churn=0.0,
    s_tmin=0.0,
    s_tmax=float("inf"),
    s_noise=1.0,
):
    """A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022).

    Args:
        model: Denoiser called as ``model(x, sigma_batch, **extra_args)``.
        x: Starting sample; ``x.shape[0]`` is used as the batch size.
        sigmas: Noise levels to step through, one step per consecutive pair.
        extra_args: Extra keyword arguments for ``model``.
        callback: Optional callable given a dict with ``x``, ``i``, ``sigma``,
            ``sigma_hat`` and ``denoised`` on every step.
        disable: Passed to ``trange`` to switch the progress bar off.
        s_churn: Total churn; split evenly over the steps and capped at
            ``sqrt(2) - 1`` per step.
        s_tmin: Lower bound of the noise levels at which churn is applied.
        s_tmax: Upper bound of the noise levels at which churn is applied.
        s_noise: Scale of the noise added when churning.

    Returns:
        The sample after the final step.
    """
    if extra_args is None:
        extra_args = {}
    n_steps = len(sigmas) - 1
    s_in = x.new_ones([x.shape[0]])
    for i in trange(n_steps, disable=disable):
        sigma, sigma_next = sigmas[i], sigmas[i + 1]
        if s_tmin <= sigma <= s_tmax:
            gamma = min(s_churn / n_steps, 2**0.5 - 1)
        else:
            gamma = 0.0
        # Noise is drawn on every step, even when it is not used, so the
        # random stream does not depend on the churn settings.
        eps = torch.randn_like(x) * s_noise
        sigma_hat = sigma * (gamma + 1)
        if gamma > 0:
            x = x + eps * (sigma_hat**2 - sigma**2) ** 0.5
        denoised = model(x, sigma_hat * s_in, **extra_args)
        d = to_d(x, sigma_hat, denoised)
        if callback is not None:
            info = {
                "x": x,
                "i": i,
                "sigma": sigma,
                "sigma_hat": sigma_hat,
                "denoised": denoised,
            }
            callback(info)
        # Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule
        sigma_mid = ((sigma_hat ** (1 / 3) + sigma_next ** (1 / 3)) / 2) ** 3
        dt_1 = sigma_mid - sigma_hat
        dt_2 = sigma_next - sigma_hat
        x_2 = x + d * dt_1
        denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
        d_2 = to_d(x_2, sigma_mid, denoised_2)
        # The full step uses the slope evaluated at the midpoint.
        x = x + d_2 * dt_2
    return x
|
||||
|
||||
|
||||
@torch.no_grad()
def sample_dpm_2_ancestral(
    model, x, sigmas, extra_args=None, callback=None, disable=None
):
    """Ancestral sampling with DPM-Solver inspired second-order steps.

    Each step denoises ``x`` at ``sigmas[i]``, takes a midpoint-method step
    down to ``sigma_down`` (from ``get_ancestral_step``) and then adds fresh
    Gaussian noise scaled by ``sigma_up``.

    ``callback``, when given, receives once per step a dict with the keys
    ``x``, ``i``, ``sigma``, ``sigma_hat`` and ``denoised``. ``disable`` is
    forwarded to ``trange``. Returns the final ``x``.
    """
    if extra_args is None:
        extra_args = {}
    batch_ones = x.new_ones([x.shape[0]])
    for i in trange(len(sigmas) - 1, disable=disable):
        sigma_cur = sigmas[i]
        denoised = model(x, sigma_cur * batch_ones, **extra_args)
        sigma_down, sigma_up = get_ancestral_step(sigma_cur, sigmas[i + 1])
        if callback is not None:
            callback(
                dict(
                    x=x,
                    i=i,
                    sigma=sigma_cur,
                    sigma_hat=sigma_cur,
                    denoised=denoised,
                )
            )
        d = to_d(x, sigma_cur, denoised)
        # Midpoint method, where the midpoint is chosen according to a rho=3
        # Karras schedule (mean of the cube roots, cubed).
        sigma_mid = ((sigma_cur ** (1 / 3) + sigma_down ** (1 / 3)) / 2) ** 3
        x_mid = x + d * (sigma_mid - sigma_cur)
        denoised_mid = model(x_mid, sigma_mid * batch_ones, **extra_args)
        d_mid = to_d(x_mid, sigma_mid, denoised_mid)
        x = x + d_mid * (sigma_down - sigma_cur)
        # Ancestral part: re-inject noise after the deterministic step.
        x = x + torch.randn_like(x) * sigma_up
    return x
|
||||
|
||||
|
||||
def linear_multistep_coeff(order, t, i, j):
    """Integrate the ``j``-th Lagrange basis polynomial over ``[t[i], t[i + 1]]``.

    The basis polynomial is built on the ``order`` nodes
    ``t[i], t[i - 1], ..., t[i - order + 1]`` and equals one at ``t[i - j]``.
    The integral is evaluated numerically with ``scipy.integrate.quad``.

    Raises:
        ValueError: if ``order - 1 > i``, i.e. there are not enough earlier
            nodes for the requested order.
    """
    if order - 1 > i:
        raise ValueError(f"Order {order} too high for step {i}")

    other_nodes = [k for k in range(order) if k != j]

    def basis(tau):
        value = 1.0
        for k in other_nodes:
            value *= (tau - t[i - k]) / (t[i - j] - t[i - k])
        return value

    integral, _abs_err = integrate.quad(basis, t[i], t[i + 1], epsrel=1e-4)
    return integral
|
||||
|
||||
|
||||
@torch.no_grad()
def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, order=4):
    """Linear multistep sampler.

    Keeps the most recent derivatives (at most ``order`` of them) and advances
    ``x`` by their weighted sum, with weights from ``linear_multistep_coeff``.
    During the first steps the effective order is limited to ``i + 1``.

    ``callback``, when given, receives once per step a dict with the keys
    ``x``, ``i``, ``sigma``, ``sigma_hat`` and ``denoised``. ``disable`` is
    forwarded to ``trange``. Returns the final ``x``.
    """
    if extra_args is None:
        extra_args = {}
    batch_ones = x.new_ones([x.shape[0]])
    # The coefficient integrals are computed on a NumPy copy of the schedule.
    sigmas_cpu = sigmas.detach().cpu().numpy()
    history = []  # derivatives, oldest first
    for i in trange(len(sigmas) - 1, disable=disable):
        sigma_cur = sigmas[i]
        denoised = model(x, sigma_cur * batch_ones, **extra_args)
        history.append(to_d(x, sigma_cur, denoised))
        if len(history) > order:
            del history[0]
        if callback is not None:
            callback(
                dict(
                    x=x,
                    i=i,
                    sigma=sigma_cur,
                    sigma_hat=sigma_cur,
                    denoised=denoised,
                )
            )
        cur_order = min(i + 1, order)
        coeffs = []
        for j in range(cur_order):
            coeffs.append(linear_multistep_coeff(cur_order, sigmas_cpu, i, j))
        # Coefficient j pairs with the derivative from j steps ago.
        x = x + sum(c * past_d for c, past_d in zip(coeffs, reversed(history)))
    return x
|
||||
|
||||
|
||||
@torch.no_grad()
def log_likelihood(
    model, x, sigma_min, sigma_max, extra_args=None, atol=1e-4, rtol=1e-4
):
    """Estimate the log-likelihood of ``x`` by integrating the sampling ODE.

    The state ``(x, accumulated term)`` is integrated from ``sigma_min`` to
    ``sigma_max`` with ``odeint`` (``dopri5``). The accumulated term uses a
    fixed random +/-1 probe vector to form ``v * d(d . v)/dx`` summed per
    sample (this looks like a Hutchinson-style divergence estimate — confirm
    against the upstream project).

    Returns:
        A tuple ``(ll, info)`` where ``ll`` is the normal prior log-density of
        the final latent (scale ``sigma_max``) plus the accumulated term, and
        ``info`` is ``{"fevals": <number of ODE function evaluations>}``.
    """
    if extra_args is None:
        extra_args = {}
    batch_ones = x.new_ones([x.shape[0]])
    # Entries are drawn from {0, 1} and mapped to {-1, +1}.
    v = torch.randint_like(x, 2) * 2 - 1
    fevals = 0

    def ode_fn(sigma, state):
        nonlocal fevals
        # Gradients are needed here even though the outer function disables them.
        with torch.enable_grad():
            x_in = state[0].detach().requires_grad_()
            denoised = model(x_in, sigma * batch_ones, **extra_args)
            d = to_d(x_in, sigma, denoised)
            fevals += 1
            grad = torch.autograd.grad((d * v).sum(), x_in)[0]
            d_ll = (v * grad).flatten(1).sum(1)
        return d.detach(), d_ll

    initial_state = (x, x.new_zeros([x.shape[0]]))
    t = x.new_tensor([sigma_min, sigma_max])
    sol = odeint(ode_fn, initial_state, t, atol=atol, rtol=rtol, method="dopri5")
    latent = sol[0][-1]
    delta_ll = sol[1][-1]
    prior = torch.distributions.Normal(0, sigma_max)
    ll_prior = prior.log_prob(latent).flatten(1).sum(1)
    return ll_prior + delta_ll, {"fevals": fevals}
|
221
imaginairy/vendored/k_diffusion/sampling.py-e
Normal file
221
imaginairy/vendored/k_diffusion/sampling.py-e
Normal file
@ -0,0 +1,221 @@
|
||||
import math
|
||||
|
||||
import torch
|
||||
from scipy import integrate
|
||||
from torchdiffeq import odeint
|
||||
from tqdm.auto import tqdm, trange
|
||||
|
||||
from . import utils
|
||||
|
||||
|
||||
def append_zero(x):
    """Return ``x`` with a single zero (same dtype and device) concatenated at the end."""
    trailing_zero = x.new_zeros([1])
    return torch.cat((x, trailing_zero))
|
||||
|
||||
|
||||
def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu'):
    """Constructs the noise schedule of Karras et al. (2022).

    Interpolates linearly between ``sigma_max ** (1 / rho)`` and
    ``sigma_min ** (1 / rho)`` over ``n`` points, raises the result to the
    power ``rho``, appends a final zero and moves the tensor to ``device``.
    """
    lo = sigma_min ** (1 / rho)
    hi = sigma_max ** (1 / rho)
    # 0 -> sigma_max end of the schedule, 1 -> sigma_min end.
    ramp = torch.linspace(0, 1, n)
    schedule = (hi + ramp * (lo - hi)) ** rho
    return append_zero(schedule).to(device)
|
||||
|
||||
|
||||
def get_sigmas_exponential(n, sigma_min, sigma_max, device='cpu'):
    """Constructs an exponential noise schedule.

    The ``n`` sigmas are evenly spaced in log space from ``sigma_max`` down to
    ``sigma_min``; a final zero is appended.
    """
    log_hi = math.log(sigma_max)
    log_lo = math.log(sigma_min)
    log_sigmas = torch.linspace(log_hi, log_lo, n, device=device)
    return append_zero(log_sigmas.exp())
|
||||
|
||||
|
||||
def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'):
    """Constructs a continuous VP noise schedule.

    Evaluates ``sqrt(exp(beta_d * t**2 / 2 + beta_min * t) - 1)`` at ``n``
    values of ``t`` spaced linearly from 1 down to ``eps_s`` and appends a
    final zero.
    """
    t = torch.linspace(1, eps_s, n, device=device)
    exponent = beta_d * t ** 2 / 2 + beta_min * t
    sigmas = torch.sqrt(torch.exp(exponent) - 1)
    return append_zero(sigmas)
|
||||
|
||||
|
||||
def to_d(x, sigma, denoised):
    """Converts a denoiser output to a Karras ODE derivative.

    Computes ``(x - denoised) / sigma`` after ``sigma`` has been given
    trailing singleton dimensions to match ``x.ndim``.
    """
    residual = x - denoised
    sigma_expanded = utils.append_dims(sigma, x.ndim)
    return residual / sigma_expanded
|
||||
|
||||
|
||||
def get_ancestral_step(sigma_from, sigma_to):
    """Split an ancestral sampling step into its two parts.

    Returns ``(sigma_down, sigma_up)``: the noise level to step down to and
    the amount of noise to add afterwards. They satisfy
    ``sigma_down ** 2 + sigma_up ** 2 == sigma_to ** 2`` up to rounding.
    """
    var_from = sigma_from ** 2
    var_to = sigma_to ** 2
    sigma_up = (var_to * (var_from - var_to) / var_from) ** 0.5
    sigma_down = (var_to - sigma_up ** 2) ** 0.5
    return sigma_down, sigma_up
|
||||
|
||||
|
||||
@torch.no_grad()
def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
    """Implements Algorithm 2 (Euler steps) from Karras et al. (2022).

    For every consecutive pair in ``sigmas`` the sample is optionally
    re-noised up to ``sigma_hat`` (when ``s_churn`` yields a positive
    ``gamma``), denoised by ``model`` and advanced with one Euler step.

    ``callback``, when given, receives once per step a dict with the keys
    ``x``, ``i``, ``sigma``, ``sigma_hat`` and ``denoised``. ``disable`` is
    forwarded to ``trange``. Returns the final ``x``.
    """
    if extra_args is None:
        extra_args = {}
    batch_ones = x.new_ones([x.shape[0]])
    n_steps = len(sigmas) - 1
    for i in trange(n_steps, disable=disable):
        sigma_cur = sigmas[i]
        # Churn only applies while the current sigma is within [s_tmin, s_tmax].
        if s_tmin <= sigma_cur <= s_tmax:
            gamma = min(s_churn / n_steps, 2 ** 0.5 - 1)
        else:
            gamma = 0.
        # The noise is drawn on every step, whether or not it ends up being used.
        eps = torch.randn_like(x) * s_noise
        sigma_hat = sigma_cur * (gamma + 1)
        if gamma > 0:
            x = x + eps * (sigma_hat ** 2 - sigma_cur ** 2) ** 0.5
        denoised = model(x, sigma_hat * batch_ones, **extra_args)
        d = to_d(x, sigma_hat, denoised)
        if callback is not None:
            callback(dict(x=x, i=i, sigma=sigma_cur, sigma_hat=sigma_hat, denoised=denoised))
        # Euler method
        x = x + d * (sigmas[i + 1] - sigma_hat)
    return x
|
||||
|
||||
|
||||
@torch.no_grad()
def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None):
    """Ancestral sampling with Euler method steps.

    Each step denoises ``x`` at ``sigmas[i]``, takes an Euler step down to
    ``sigma_down`` (from ``get_ancestral_step``) and then adds fresh Gaussian
    noise scaled by ``sigma_up``.

    ``callback``, when given, receives once per step a dict with the keys
    ``x``, ``i``, ``sigma``, ``sigma_hat`` and ``denoised``. ``disable`` is
    forwarded to ``trange``. Returns the final ``x``.
    """
    if extra_args is None:
        extra_args = {}
    batch_ones = x.new_ones([x.shape[0]])
    for i in trange(len(sigmas) - 1, disable=disable):
        sigma_cur = sigmas[i]
        denoised = model(x, sigma_cur * batch_ones, **extra_args)
        sigma_down, sigma_up = get_ancestral_step(sigma_cur, sigmas[i + 1])
        if callback is not None:
            callback(dict(x=x, i=i, sigma=sigma_cur, sigma_hat=sigma_cur, denoised=denoised))
        d = to_d(x, sigma_cur, denoised)
        # Euler method
        x = x + d * (sigma_down - sigma_cur)
        # Ancestral part: re-inject noise after the deterministic step.
        x = x + torch.randn_like(x) * sigma_up
    return x
|
||||
|
||||
|
||||
@torch.no_grad()
def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
    """Implements Algorithm 2 (Heun steps) from Karras et al. (2022).

    For every consecutive pair in ``sigmas`` the sample is optionally
    re-noised up to ``sigma_hat``, denoised by ``model`` and advanced with an
    Euler step. Unless the target sigma is zero, the step is redone with the
    average of the derivatives at both ends of the step (Heun's method).

    ``callback``, when given, receives once per step a dict with the keys
    ``x``, ``i``, ``sigma``, ``sigma_hat`` and ``denoised``. ``disable`` is
    forwarded to ``trange``. Returns the final ``x``.
    """
    if extra_args is None:
        extra_args = {}
    batch_ones = x.new_ones([x.shape[0]])
    n_steps = len(sigmas) - 1
    for i in trange(n_steps, disable=disable):
        sigma_cur = sigmas[i]
        # Churn only applies while the current sigma is within [s_tmin, s_tmax].
        if s_tmin <= sigma_cur <= s_tmax:
            gamma = min(s_churn / n_steps, 2 ** 0.5 - 1)
        else:
            gamma = 0.
        # The noise is drawn on every step, whether or not it ends up being used.
        eps = torch.randn_like(x) * s_noise
        sigma_hat = sigma_cur * (gamma + 1)
        if gamma > 0:
            x = x + eps * (sigma_hat ** 2 - sigma_cur ** 2) ** 0.5
        denoised = model(x, sigma_hat * batch_ones, **extra_args)
        d = to_d(x, sigma_hat, denoised)
        if callback is not None:
            callback(dict(x=x, i=i, sigma=sigma_cur, sigma_hat=sigma_hat, denoised=denoised))
        sigma_next = sigmas[i + 1]
        dt = sigma_next - sigma_hat
        if sigma_next == 0:
            # Stepping onto sigma == 0: plain Euler step, no correction.
            x = x + d * dt
            continue
        # Heun's method: average the slopes at both ends of the step.
        x_pred = x + d * dt
        denoised_pred = model(x_pred, sigma_next * batch_ones, **extra_args)
        d_pred = to_d(x_pred, sigma_next, denoised_pred)
        x = x + (d + d_pred) / 2 * dt
    return x
|
||||
|
||||
|
||||
@torch.no_grad()
def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
    """A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022).

    Each step optionally re-noises ``x`` up to ``sigma_hat``, evaluates the
    derivative there, takes a trial step to an intermediate ``sigma_mid`` and
    then advances the whole step using the derivative found at that midpoint.

    ``callback``, when given, receives once per step a dict with the keys
    ``x``, ``i``, ``sigma``, ``sigma_hat`` and ``denoised``. ``disable`` is
    forwarded to ``trange``. Returns the final ``x``.
    """
    if extra_args is None:
        extra_args = {}
    batch_ones = x.new_ones([x.shape[0]])
    n_steps = len(sigmas) - 1
    for i in trange(n_steps, disable=disable):
        sigma_cur = sigmas[i]
        # Churn only applies while the current sigma is within [s_tmin, s_tmax].
        if s_tmin <= sigma_cur <= s_tmax:
            gamma = min(s_churn / n_steps, 2 ** 0.5 - 1)
        else:
            gamma = 0.
        # The noise is drawn on every step, whether or not it ends up being used.
        eps = torch.randn_like(x) * s_noise
        sigma_hat = sigma_cur * (gamma + 1)
        if gamma > 0:
            x = x + eps * (sigma_hat ** 2 - sigma_cur ** 2) ** 0.5
        denoised = model(x, sigma_hat * batch_ones, **extra_args)
        d = to_d(x, sigma_hat, denoised)
        if callback is not None:
            callback(dict(x=x, i=i, sigma=sigma_cur, sigma_hat=sigma_hat, denoised=denoised))
        sigma_next = sigmas[i + 1]
        # Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule
        sigma_mid = ((sigma_hat ** (1 / 3) + sigma_next ** (1 / 3)) / 2) ** 3
        x_mid = x + d * (sigma_mid - sigma_hat)
        denoised_mid = model(x_mid, sigma_mid * batch_ones, **extra_args)
        d_mid = to_d(x_mid, sigma_mid, denoised_mid)
        x = x + d_mid * (sigma_next - sigma_hat)
    return x
|
||||
|
||||
|
||||
@torch.no_grad()
def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None):
    """Ancestral sampling with DPM-Solver inspired second-order steps.

    Each step denoises ``x`` at ``sigmas[i]``, takes a midpoint-method step
    down to ``sigma_down`` (from ``get_ancestral_step``) and then adds fresh
    Gaussian noise scaled by ``sigma_up``.

    ``callback``, when given, receives once per step a dict with the keys
    ``x``, ``i``, ``sigma``, ``sigma_hat`` and ``denoised``. ``disable`` is
    forwarded to ``trange``. Returns the final ``x``.
    """
    if extra_args is None:
        extra_args = {}
    batch_ones = x.new_ones([x.shape[0]])
    for i in trange(len(sigmas) - 1, disable=disable):
        sigma_cur = sigmas[i]
        denoised = model(x, sigma_cur * batch_ones, **extra_args)
        sigma_down, sigma_up = get_ancestral_step(sigma_cur, sigmas[i + 1])
        if callback is not None:
            callback(dict(x=x, i=i, sigma=sigma_cur, sigma_hat=sigma_cur, denoised=denoised))
        d = to_d(x, sigma_cur, denoised)
        # Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule
        sigma_mid = ((sigma_cur ** (1 / 3) + sigma_down ** (1 / 3)) / 2) ** 3
        x_mid = x + d * (sigma_mid - sigma_cur)
        denoised_mid = model(x_mid, sigma_mid * batch_ones, **extra_args)
        d_mid = to_d(x_mid, sigma_mid, denoised_mid)
        x = x + d_mid * (sigma_down - sigma_cur)
        # Ancestral part: re-inject noise after the deterministic step.
        x = x + torch.randn_like(x) * sigma_up
    return x
|
||||
|
||||
|
||||
def linear_multistep_coeff(order, t, i, j):
    """Integrate the ``j``-th Lagrange basis polynomial over ``[t[i], t[i + 1]]``.

    The basis polynomial is built on the ``order`` nodes
    ``t[i], t[i - 1], ..., t[i - order + 1]`` and equals one at ``t[i - j]``.
    The integral is evaluated numerically with ``scipy.integrate.quad``.

    Raises:
        ValueError: if ``order - 1 > i``, i.e. there are not enough earlier
            nodes for the requested order.
    """
    if order - 1 > i:
        raise ValueError(f'Order {order} too high for step {i}')

    other_nodes = [k for k in range(order) if k != j]

    def basis(tau):
        value = 1.
        for k in other_nodes:
            value *= (tau - t[i - k]) / (t[i - j] - t[i - k])
        return value

    integral, _abs_err = integrate.quad(basis, t[i], t[i + 1], epsrel=1e-4)
    return integral
|
||||
|
||||
|
||||
@torch.no_grad()
def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, order=4):
    """Linear multistep sampler.

    Keeps the most recent derivatives (at most ``order`` of them) and advances
    ``x`` by their weighted sum, with weights from ``linear_multistep_coeff``.
    During the first steps the effective order is limited to ``i + 1``.

    ``callback``, when given, receives once per step a dict with the keys
    ``x``, ``i``, ``sigma``, ``sigma_hat`` and ``denoised``. ``disable`` is
    forwarded to ``trange``. Returns the final ``x``.
    """
    if extra_args is None:
        extra_args = {}
    batch_ones = x.new_ones([x.shape[0]])
    # The coefficient integrals are computed on a NumPy copy of the schedule.
    sigmas_cpu = sigmas.detach().cpu().numpy()
    history = []  # derivatives, oldest first
    for i in trange(len(sigmas) - 1, disable=disable):
        sigma_cur = sigmas[i]
        denoised = model(x, sigma_cur * batch_ones, **extra_args)
        history.append(to_d(x, sigma_cur, denoised))
        if len(history) > order:
            del history[0]
        if callback is not None:
            callback(dict(x=x, i=i, sigma=sigma_cur, sigma_hat=sigma_cur, denoised=denoised))
        cur_order = min(i + 1, order)
        coeffs = []
        for j in range(cur_order):
            coeffs.append(linear_multistep_coeff(cur_order, sigmas_cpu, i, j))
        # Coefficient j pairs with the derivative from j steps ago.
        x = x + sum(c * past_d for c, past_d in zip(coeffs, reversed(history)))
    return x
|
||||
|
||||
|
||||
@torch.no_grad()
def log_likelihood(model, x, sigma_min, sigma_max, extra_args=None, atol=1e-4, rtol=1e-4):
    """Estimate the log-likelihood of ``x`` by integrating the sampling ODE.

    The state ``(x, accumulated term)`` is integrated from ``sigma_min`` to
    ``sigma_max`` with ``odeint`` (``dopri5``). The accumulated term uses a
    fixed random +/-1 probe vector to form ``v * d(d . v)/dx`` summed per
    sample (this looks like a Hutchinson-style divergence estimate — confirm
    against the upstream project).

    Returns:
        A tuple ``(ll, info)`` where ``ll`` is the normal prior log-density of
        the final latent (scale ``sigma_max``) plus the accumulated term, and
        ``info`` is ``{'fevals': <number of ODE function evaluations>}``.
    """
    if extra_args is None:
        extra_args = {}
    batch_ones = x.new_ones([x.shape[0]])
    # Entries are drawn from {0, 1} and mapped to {-1, +1}.
    v = torch.randint_like(x, 2) * 2 - 1
    fevals = 0

    def ode_fn(sigma, state):
        nonlocal fevals
        # Gradients are needed here even though the outer function disables them.
        with torch.enable_grad():
            x_in = state[0].detach().requires_grad_()
            denoised = model(x_in, sigma * batch_ones, **extra_args)
            d = to_d(x_in, sigma, denoised)
            fevals += 1
            grad = torch.autograd.grad((d * v).sum(), x_in)[0]
            d_ll = (v * grad).flatten(1).sum(1)
        return d.detach(), d_ll

    initial_state = (x, x.new_zeros([x.shape[0]]))
    t = x.new_tensor([sigma_min, sigma_max])
    sol = odeint(ode_fn, initial_state, t, atol=atol, rtol=rtol, method='dopri5')
    latent = sol[0][-1]
    delta_ll = sol[1][-1]
    prior = torch.distributions.Normal(0, sigma_max)
    ll_prior = prior.log_prob(latent).flatten(1).sum(1)
    return ll_prior + delta_ll, {'fevals': fevals}
|
385
imaginairy/vendored/k_diffusion/utils.py
Normal file
385
imaginairy/vendored/k_diffusion/utils.py
Normal file
@ -0,0 +1,385 @@
|
||||
import hashlib
|
||||
import math
|
||||
import shutil
|
||||
import urllib
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torch import nn, optim
|
||||
from torch.utils import data
|
||||
from torchvision.transforms import functional as TF
|
||||
|
||||
|
||||
def from_pil_image(x):
    """Converts from a PIL image to a tensor.

    The result of ``TF.to_tensor`` is rescaled with ``t * 2 - 1``; a 2-D
    result gets a trailing singleton dimension first.
    """
    tensor = TF.to_tensor(x)
    if tensor.ndim == 2:
        tensor = tensor[..., None]
    return tensor * 2 - 1
|
||||
|
||||
|
||||
def to_pil_image(x):
    """Converts from a tensor to a PIL image.

    A 4-D input must have a leading dimension of size 1, which is dropped; a
    remaining leading dimension of size 1 is dropped as well. Values are
    clamped to [-1, 1] and rescaled to [0, 1] before conversion.
    """
    if x.ndim == 4:
        assert x.shape[0] == 1
        x = x[0]
    if x.shape[0] == 1:
        x = x[0]
    unit_range = (x.clamp(-1, 1) + 1) / 2
    return TF.to_pil_image(unit_range)
|
||||
|
||||
|
||||
def hf_datasets_augs_helper(examples, transform, image_key, mode="RGB"):
    """Apply passed in transforms for HuggingFace Datasets.

    Every image under ``examples[image_key]`` is converted to ``mode`` and
    passed through ``transform``; the results are returned as
    ``{image_key: [...]}``.
    """
    transformed = []
    for image in examples[image_key]:
        transformed.append(transform(image.convert(mode)))
    return {image_key: transformed}
|
||||
|
||||
|
||||
def append_dims(x, target_dims):
    """Appends dimensions to the end of a tensor until it has target_dims dimensions.

    Raises:
        ValueError: if ``x`` already has more than ``target_dims`` dimensions.
    """
    missing = target_dims - x.ndim
    if missing < 0:
        raise ValueError(
            f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
        )
    # Keep every existing axis and index one new axis per missing dimension.
    index = (...,) + (None,) * missing
    return x[index]
|
||||
|
||||
|
||||
def n_params(module):
    """Returns the total number of elements across ``module.parameters()``.

    Every parameter is counted; ``requires_grad`` is not consulted.
    """
    total = 0
    for param in module.parameters():
        total += param.numel()
    return total
|
||||
|
||||
|
||||
def download_file(path, url, digest=None):
    """Downloads a file if it does not exist, optionally checking its SHA-256 hash.

    Args:
        path: Destination path; missing parent directories are created.
        url: Where to fetch the file from when ``path`` does not exist yet.
        digest: Optional expected SHA-256 hex digest of the file at ``path``.

    Returns:
        The destination as a ``pathlib.Path``.

    Raises:
        OSError: if ``digest`` is given and does not match the file's hash.
    """
    # ``import urllib`` alone does not guarantee that the ``request``
    # submodule is loaded, so import it explicitly where it is used.
    from urllib import request

    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    if not path.exists():
        with request.urlopen(url) as response, open(path, "wb") as f:
            shutil.copyfileobj(response, f)
    if digest is not None:
        hasher = hashlib.sha256()
        # Hash in chunks and close the handle deterministically.
        with open(path, "rb") as f:
            for chunk in iter(lambda: f.read(1 << 20), b""):
                hasher.update(chunk)
        if digest != hasher.hexdigest():
            raise OSError(f"hash of {path} (url: {url}) failed to validate")
    return path
|
||||
|
||||
|
||||
@contextmanager
def train_mode(model, mode=True):
    """A context manager that places a model into training mode and restores
    the previous mode on exit.

    The ``training`` flag of every submodule is recorded on entry and written
    back on exit, including when the body raises. Yields the result of
    ``model.train(mode)``.
    """
    previous = [submodule.training for submodule in model.modules()]
    try:
        yield model.train(mode)
    finally:
        for idx, submodule in enumerate(model.modules()):
            submodule.training = previous[idx]
|
||||
|
||||
|
||||
def eval_mode(model):
    """A context manager that places a model into evaluation mode and restores
    the previous mode on exit.

    Equivalent to ``train_mode`` with ``mode`` set to ``False``.
    """
    return train_mode(model, mode=False)
|
||||
|
||||
|
||||
@torch.no_grad()
def ema_update(model, averaged_model, decay):
    """Incorporates updated model parameters into an exponential moving averaged
    version of a model. It should be called after each optimizer step.

    Parameters of ``averaged_model`` are updated in place to
    ``decay * averaged + (1 - decay) * current``; buffers are copied over
    unchanged. Both models must expose the same parameter and buffer names.
    """
    current_params = dict(model.named_parameters())
    ema_params = dict(averaged_model.named_parameters())
    assert current_params.keys() == ema_params.keys()

    for name in current_params:
        target = ema_params[name]
        target.mul_(decay)
        target.add_(current_params[name], alpha=1 - decay)

    current_buffers = dict(model.named_buffers())
    ema_buffers = dict(averaged_model.named_buffers())
    assert current_buffers.keys() == ema_buffers.keys()

    # Buffers are not averaged, only mirrored.
    for name in current_buffers:
        ema_buffers[name].copy_(current_buffers[name])
|
||||
|
||||
|
||||
class EMAWarmup:
    """Implements an EMA warmup using an inverse decay schedule.

    If inv_gamma=1 and power=1, implements a simple average. inv_gamma=1, power=2/3 are
    good values for models you plan to train for a million or more steps (reaches decay
    factor 0.999 at 31.6K steps, 0.9999 at 1M steps), inv_gamma=1, power=3/4 for models
    you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at
    215.4k steps).

    Args:
        inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
        power (float): Exponential factor of EMA warmup. Default: 1.
        min_value (float): The minimum EMA decay rate. Default: 0.
        max_value (float): The maximum EMA decay rate. Default: 1.
        start_at (int): The epoch to start averaging at. Default: 0.
        last_epoch (int): The index of last epoch. Default: 0.
    """

    def __init__(
        self,
        inv_gamma=1.0,
        power=1.0,
        min_value=0.0,
        max_value=1.0,
        start_at=0,
        last_epoch=0,
    ):
        self.inv_gamma = inv_gamma
        self.power = power
        self.min_value = min_value
        self.max_value = max_value
        self.start_at = start_at
        self.last_epoch = last_epoch

    def state_dict(self):
        """Returns the state of the class as a :class:`dict`."""
        return dict(self.__dict__.items())

    def load_state_dict(self, state_dict):
        """Loads the class's state.

        Args:
            state_dict (dict): scaler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        self.__dict__.update(state_dict)

    def get_value(self):
        """Gets the current EMA decay rate.

        Returns 0.0 while ``last_epoch`` is still before ``start_at``;
        afterwards the inverse-decay value clamped to
        ``[min_value, max_value]``.
        """
        epoch = self.last_epoch - self.start_at
        # Test the sign before clamping: clamping first made this branch
        # unreachable, so a non-zero min_value was returned before start_at.
        if epoch < 0:
            return 0.0
        value = 1 - (1 + epoch / self.inv_gamma) ** -self.power
        return min(self.max_value, max(self.min_value, value))

    def step(self):
        """Updates the step count."""
        self.last_epoch += 1
|
||||
|
||||
|
||||
class InverseLR(optim.lr_scheduler._LRScheduler):
    """Implements an inverse decay learning rate schedule with an optional exponential
    warmup. When last_epoch=-1, sets initial lr as lr.

    inv_gamma is the number of steps/epochs required for the learning rate to decay to
    (1 / 2)**power of its original value.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        inv_gamma (float): Inverse multiplicative factor of learning rate decay. Default: 1.
        power (float): Exponential factor of learning rate decay. Default: 1.
        warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable)
            Default: 0.
        min_lr (float): The minimum learning rate. Default: 0.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

    Raises:
        ValueError: if ``warmup`` is outside ``[0, 1)``.
    """

    def __init__(
        self,
        optimizer,
        inv_gamma=1.0,
        power=1.0,
        warmup=0.0,
        min_lr=0.0,
        last_epoch=-1,
        verbose=False,
    ):
        self.inv_gamma = inv_gamma
        self.power = power
        if not 0.0 <= warmup < 1:
            raise ValueError("Invalid value for warmup")
        self.warmup = warmup
        self.min_lr = min_lr
        # Only forward ``verbose`` when it was requested: newer PyTorch
        # releases deprecate the argument, and leaving it out is equivalent
        # to passing the default ``False`` on older ones.
        base_kwargs = {"verbose": verbose} if verbose else {}
        super().__init__(optimizer, last_epoch, **base_kwargs)

    def get_lr(self):
        """Return the learning rates for the current ``last_epoch``."""
        if not self._get_lr_called_within_step:
            warnings.warn(
                "To get the last learning rate computed by the scheduler, "
                "please use `get_last_lr()`."
            )

        return self._get_closed_form_lr()

    def _get_closed_form_lr(self):
        # Warmup factor rises from (1 - warmup) towards 1 as epochs pass.
        warmup = 1 - self.warmup ** (self.last_epoch + 1)
        lr_mult = (1 + self.last_epoch / self.inv_gamma) ** -self.power
        return [
            warmup * max(self.min_lr, base_lr * lr_mult) for base_lr in self.base_lrs
        ]
|
||||
|
||||
|
||||
class ExponentialLR(optim.lr_scheduler._LRScheduler):
    """Implements an exponential learning rate schedule with an optional exponential
    warmup. When last_epoch=-1, sets initial lr as lr. Decays the learning rate
    continuously by decay (default 0.5) every num_steps steps.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        num_steps (float): The number of steps to decay the learning rate by decay in.
        decay (float): The factor by which to decay the learning rate every num_steps
            steps. Default: 0.5.
        warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable)
            Default: 0.
        min_lr (float): The minimum learning rate. Default: 0.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

    Raises:
        ValueError: if ``warmup`` is outside ``[0, 1)``.
    """

    def __init__(
        self,
        optimizer,
        num_steps,
        decay=0.5,
        warmup=0.0,
        min_lr=0.0,
        last_epoch=-1,
        verbose=False,
    ):
        self.num_steps = num_steps
        self.decay = decay
        if not 0.0 <= warmup < 1:
            raise ValueError("Invalid value for warmup")
        self.warmup = warmup
        self.min_lr = min_lr
        # Only forward ``verbose`` when it was requested: newer PyTorch
        # releases deprecate the argument, and leaving it out is equivalent
        # to passing the default ``False`` on older ones.
        base_kwargs = {"verbose": verbose} if verbose else {}
        super().__init__(optimizer, last_epoch, **base_kwargs)

    def get_lr(self):
        """Return the learning rates for the current ``last_epoch``."""
        if not self._get_lr_called_within_step:
            warnings.warn(
                "To get the last learning rate computed by the scheduler, "
                "please use `get_last_lr()`."
            )

        return self._get_closed_form_lr()

    def _get_closed_form_lr(self):
        # Warmup factor rises from (1 - warmup) towards 1 as epochs pass.
        warmup = 1 - self.warmup ** (self.last_epoch + 1)
        # Per-step factor chosen so that num_steps steps multiply the lr by decay.
        lr_mult = (self.decay ** (1 / self.num_steps)) ** self.last_epoch
        return [
            warmup * max(self.min_lr, base_lr * lr_mult) for base_lr in self.base_lrs
        ]
|
||||
|
||||
|
||||
def rand_log_normal(shape, loc=0.0, scale=1.0, device="cpu", dtype=torch.float32):
    """Draws samples from an lognormal distribution.

    Computes ``exp(loc + scale * z)`` for standard-normal ``z`` of the given
    ``shape``, ``device`` and ``dtype``.
    """
    gaussian = torch.randn(shape, device=device, dtype=dtype)
    return torch.exp(gaussian * scale + loc)
|
||||
|
||||
|
||||
def rand_log_logistic(
    shape,
    loc=0.0,
    scale=1.0,
    min_value=0.0,
    max_value=float("inf"),
    device="cpu",
    dtype=torch.float32,
):
    """Draws samples from an optionally truncated log-logistic distribution.

    Uniform samples are drawn between the CDF values of ``min_value`` and
    ``max_value`` and mapped back through the inverse CDF. The computation
    runs in float64 and the result is cast to ``dtype`` at the end.
    """
    lower = torch.as_tensor(min_value, device=device, dtype=torch.float64)
    upper = torch.as_tensor(max_value, device=device, dtype=torch.float64)
    lower_cdf = torch.sigmoid((torch.log(lower) - loc) / scale)
    upper_cdf = torch.sigmoid((torch.log(upper) - loc) / scale)
    uniform = torch.rand(shape, device=device, dtype=torch.float64)
    u = uniform * (upper_cdf - lower_cdf) + lower_cdf
    return torch.exp(torch.logit(u) * scale + loc).to(dtype)
|
||||
|
||||
|
||||
def rand_log_uniform(shape, min_value, max_value, device="cpu", dtype=torch.float32):
    """Draws samples from an log-uniform distribution.

    Samples uniformly between ``log(min_value)`` and ``log(max_value)`` and
    exponentiates the result.
    """
    log_lo = math.log(min_value)
    log_hi = math.log(max_value)
    uniform = torch.rand(shape, device=device, dtype=dtype)
    return torch.exp(uniform * (log_hi - log_lo) + log_lo)
|
||||
|
||||
|
||||
def rand_v_diffusion(
    shape,
    sigma_data=1.0,
    min_value=0.0,
    max_value=float("inf"),
    device="cpu",
    dtype=torch.float32,
):
    """Draws samples from a truncated v-diffusion training timestep distribution.

    Uniform samples are drawn between ``atan(min_value / sigma_data) * 2 / pi``
    and ``atan(max_value / sigma_data) * 2 / pi`` and mapped back with
    ``tan(u * pi / 2) * sigma_data``.
    """
    to_unit = 2 / math.pi
    lower_cdf = math.atan(min_value / sigma_data) * to_unit
    upper_cdf = math.atan(max_value / sigma_data) * to_unit
    uniform = torch.rand(shape, device=device, dtype=dtype)
    u = uniform * (upper_cdf - lower_cdf) + lower_cdf
    return torch.tan(u * math.pi / 2) * sigma_data
|
||||
|
||||
|
||||
class FolderOfImages(data.Dataset):
    """Recursively finds all images in a directory. It does not support
    classes/targets.

    Files are selected by (case-insensitive) extension and kept in sorted
    order. Indexing opens the file, converts it to RGB, applies ``transform``
    and returns a one-element tuple.
    """

    IMG_EXTENSIONS = {
        ".jpg",
        ".jpeg",
        ".png",
        ".ppm",
        ".bmp",
        ".pgm",
        ".tif",
        ".tiff",
        ".webp",
    }

    def __init__(self, root, transform=None):
        super().__init__()
        self.root = Path(root)
        # Without a transform the image is passed through unchanged.
        if transform is None:
            transform = nn.Identity()
        self.transform = transform
        candidates = self.root.rglob("*")
        self.paths = sorted(
            p for p in candidates if p.suffix.lower() in self.IMG_EXTENSIONS
        )

    def __repr__(self):
        return f'FolderOfImages(root="{self.root}", len: {len(self)})'

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, key):
        with open(self.paths[key], "rb") as handle:
            # convert() runs while the file is still open.
            rgb = Image.open(handle).convert("RGB")
        return (self.transform(rgb),)
|
||||
|
||||
|
||||
class CSVLogger:
    """Writes comma-separated rows to a file, one ``write`` call per row.

    If ``filename`` already exists, rows are appended to it; otherwise the
    file is created and ``columns`` is written as the header row. Every row
    is flushed immediately. Call :meth:`close` (or use the logger as a
    context manager) to release the file handle.
    """

    def __init__(self, filename, columns):
        self.filename = Path(filename)
        self.columns = columns
        if self.filename.exists():
            self.file = open(self.filename, "a")
        else:
            self.file = open(self.filename, "w")
            self.write(*self.columns)

    def write(self, *args):
        """Write ``args`` as one comma-separated line and flush."""
        print(*args, sep=",", file=self.file, flush=True)

    def close(self):
        """Close the underlying file; safe to call more than once."""
        if not self.file.closed:
            self.file.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
|
||||
|
||||
|
||||
@contextmanager
def tf32_mode(cudnn=None, matmul=None):
    """A context manager that sets whether TF32 is allowed on cuDNN or matmul.

    An argument left as ``None`` leaves the corresponding flag untouched.
    Flags that were changed are restored to their previous values on exit,
    including when the body raises.
    """
    previous_cudnn = torch.backends.cudnn.allow_tf32
    previous_matmul = torch.backends.cuda.matmul.allow_tf32
    change_cudnn = cudnn is not None
    change_matmul = matmul is not None
    try:
        if change_cudnn:
            torch.backends.cudnn.allow_tf32 = cudnn
        if change_matmul:
            torch.backends.cuda.matmul.allow_tf32 = matmul
        yield
    finally:
        if change_cudnn:
            torch.backends.cudnn.allow_tf32 = previous_cudnn
        if change_matmul:
            torch.backends.cuda.matmul.allow_tf32 = previous_matmul
|
0
imaginairy/vendored/k_diffusion/version.py
Normal file
0
imaginairy/vendored/k_diffusion/version.py
Normal file
@ -32,7 +32,7 @@ black==22.8.0
|
||||
# via -r requirements-dev.in
|
||||
cachetools==5.2.0
|
||||
# via google-auth
|
||||
certifi==2022.6.15.1
|
||||
certifi==2022.6.15.2
|
||||
# via requests
|
||||
charset-normalizer==2.1.1
|
||||
# via
|
||||
@ -98,7 +98,7 @@ huggingface-hub==0.9.1
|
||||
# via
|
||||
# diffusers
|
||||
# transformers
|
||||
idna==3.3
|
||||
idna==3.4
|
||||
# via
|
||||
# requests
|
||||
# yarl
|
||||
@ -209,7 +209,7 @@ platformdirs==2.5.2
|
||||
# pylint
|
||||
pluggy==1.0.0
|
||||
# via pytest
|
||||
protobuf==3.19.4
|
||||
protobuf==3.19.5
|
||||
# via
|
||||
# tb-nightly
|
||||
# tensorboard
|
||||
@ -257,7 +257,7 @@ pyyaml==6.0
|
||||
# transformers
|
||||
realesrgan==0.2.5.0
|
||||
# via imaginAIry (setup.py)
|
||||
regex==2022.9.11
|
||||
regex==2022.9.13
|
||||
# via
|
||||
# diffusers
|
||||
# transformers
|
||||
@ -285,6 +285,7 @@ scipy==1.9.1
|
||||
# filterpy
|
||||
# gfpgan
|
||||
# scikit-image
|
||||
# torchdiffeq
|
||||
six==1.16.0
|
||||
# via
|
||||
# google-auth
|
||||
@ -292,7 +293,7 @@ six==1.16.0
|
||||
# python-dateutil
|
||||
snowballstemmer==2.2.0
|
||||
# via pydocstyle
|
||||
tb-nightly==2.11.0a20220912
|
||||
tb-nightly==2.11.0a20220913
|
||||
# via
|
||||
# basicsr
|
||||
# gfpgan
|
||||
@ -327,8 +328,11 @@ torch==1.12.1
|
||||
# kornia
|
||||
# pytorch-lightning
|
||||
# realesrgan
|
||||
# torchdiffeq
|
||||
# torchmetrics
|
||||
# torchvision
|
||||
torchdiffeq==0.2.3
|
||||
# via imaginAIry (setup.py)
|
||||
torchmetrics==0.6.0
|
||||
# via
|
||||
# imaginAIry (setup.py)
|
||||
|
Loading…
Reference in New Issue
Block a user