feature: Mac M1 support out of the box

- auto-downloads the model checkpoint on first run
- works around a torch layer_norm bug (see fix_torch_nn_layer_norm)
- sets the PYTORCH_ENABLE_MPS_FALLBACK environment variable automatically
pull/1/head
Bryce 2 years ago
parent 66c640ce7b
commit 6d1d0622eb
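For context on the M1 support: get_device() from imaginairy.utils is used throughout the diff below but its body is not shown; presumably it resolves to the MPS backend on Apple Silicon. A minimal sketch of that kind of device selection (hypothetical helper name, assuming torch 1.12+ with MPS support):

import torch

def pick_device() -> str:
    # hypothetical sketch; the real helper is imaginairy.utils.get_device
    if torch.cuda.is_available():
        return "cuda"
    if getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
        return "mps"
    return "cpu"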

.gitignore

@ -10,3 +10,7 @@ downloads
.coveragerc
/imaginairy/data/stable-diffusion-v1.yaml
/imaginairy/data/stable-diffusion-v1-4.ckpt
build
dist
**/*.ckpt
**/*.egg-info

@ -1,15 +1,16 @@
#!/usr/bin/env python
import os
os.putenv("PYTORCH_ENABLE_MPS_FALLBACK", "1")
import click
from imaginairy.imagine import imagine as imagine_f
from imaginairy.imagine import imagine_image_files
from imaginairy.schema import ImaginePrompt
@click.command()
@click.argument(
"prompt_texts", default=None, help="text to render to an image", nargs=-1
)
@click.argument("prompt_texts", default=None, nargs=-1)
@click.option("--outdir", default="./outputs", help="where to write results to")
@click.option(
"-r", "--repeats", default=1, type=int, help="How many times to repeat the renders"
@ -55,6 +56,7 @@ def imagine_cmd(
sampler_type,
ddim_eta,
):
"""Render an image"""
prompts = []
for _ in range(repeats):
for prompt_text in prompt_texts:
@ -66,12 +68,12 @@ def imagine_cmd(
height=height,
width=width,
prompt_strength=prompt_strength,
upscale=True,
fix_faces=True,
upscale=False,
fix_faces=False,
)
prompts.append(prompt)
imagine_f(
imagine_image_files(
prompts,
outdir=outdir,
ddim_eta=ddim_eta,

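For reference, the command above now sets PYTORCH_ENABLE_MPS_FALLBACK before any torch-dependent import, builds ImaginePrompt objects, and hands them to imagine_image_files. A minimal sketch of driving the library directly (prompt values borrowed from the tests; output directory illustrative):

import os
os.putenv("PYTORCH_ENABLE_MPS_FALLBACK", "1")  # as cmds.py now does, before torch loads

from imaginairy.imagine import imagine_image_files
from imaginairy.schema import ImaginePrompt

prompts = [ImaginePrompt("a scenic landscape", width=512, height=256, steps=20, seed=1)]
imagine_image_files(prompts, outdir="./outputs")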
@ -0,0 +1,69 @@
model:
base_learning_rate: 1.0e-04
target: imaginairy.models.diffusion.ddpm.LatentDiffusion
params:
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
cond_stage_key: "txt"
image_size: 64
channels: 4
cond_stage_trainable: false # Note: different from the one we trained before
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
scheduler_config: # 10000 warmup steps
target: imaginairy.lr_scheduler.LambdaLinearScheduler
params:
warm_up_steps: [ 10000 ]
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
f_start: [ 1.e-6 ]
f_max: [ 1. ]
f_min: [ 1. ]
unet_config:
target: imaginairy.modules.diffusionmodules.openaimodel.UNetModel
params:
image_size: 32 # unused
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_heads: 8
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
use_checkpoint: True
legacy: False
first_stage_config:
target: imaginairy.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: imaginairy.modules.clip_embedders.FrozenCLIPEmbedder

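This config now ships inside the package (see the package_data addition in setup.py below). Per the load_model changes in imagine.py, it is consumed roughly as sketched here, assuming instantiate_from_config imports the dotted target path and passes params as kwargs:

from omegaconf import OmegaConf
from imaginairy.utils import instantiate_from_config

# path as referenced in load_model (LIB_PATH + "configs/stable-diffusion-v1.yaml")
config = OmegaConf.load("imaginairy/configs/stable-diffusion-v1.yaml")
model = instantiate_from_config(config.model)  # builds the LatentDiffusion described above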
@ -13,19 +13,35 @@ from einops import rearrange
from omegaconf import OmegaConf
from pytorch_lightning import seed_everything
from torch import autocast
from transformers import cached_path
from imaginairy.models.diffusion.ddim import DDIMSampler
from imaginairy.models.diffusion.plms import PLMSSampler
from imaginairy.schema import ImaginePrompt, ImagineResult
from imaginairy.utils import get_device, instantiate_from_config
from imaginairy.utils import (
get_device,
instantiate_from_config,
fix_torch_nn_layer_norm,
)
# from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
# from transformers import AutoFeatureExtractor
# load safety model
# safety_model_id = "CompVis/stable-diffusion-safety-checker"
# safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
# safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
LIB_PATH = os.path.dirname(__file__)
logger = logging.getLogger(__name__)
def load_model_from_config(config, ckpt):
logger.info(f"Loading model from {ckpt}")
pl_sd = torch.load(ckpt, map_location="cpu")
def load_model_from_config(config):
ckpt_path = cached_path(
"https://www.googleapis.com/storage/v1/b/aai-blog-files/o/sd-v1-4.ckpt?alt=media"
)
logger.info(f"Loading model from {ckpt_path}")
pl_sd = torch.load(ckpt_path, map_location="cpu")
if "global_step" in pl_sd:
logger.info(f"Global Step: {pl_sd['global_step']}")
sd = pl_sd["state_dict"]
@ -36,7 +52,7 @@ def load_model_from_config(config, ckpt):
if len(u) > 0:
logger.info(f"unexpected keys: {u}")
model.cuda()
model.to(get_device())
model.eval()
return model
@ -57,10 +73,9 @@ def load_img(path, max_height=512, max_width=512):
@lru_cache()
def load_model():
config = "data/stable-diffusion-v1.yaml"
ckpt = "data/stable-diffusion-v1-4.ckpt"
config = "configs/stable-diffusion-v1.yaml"
config = OmegaConf.load(f"{LIB_PATH}/{config}")
model = load_model_from_config(config, f"{LIB_PATH}/{ckpt}")
model = load_model_from_config(config)
model = model.to(get_device())
return model
@ -73,7 +88,7 @@ def imagine_image_files(
downsampling_factor=8,
precision="autocast",
ddim_eta=0.0,
record_steps=False
record_steps=False,
):
big_path = os.path.join(outdir, "upscaled")
os.makedirs(outdir, exist_ok=True)
@ -94,6 +109,7 @@ def imagine_image_files(
Image.fromarray(pred_x0.astype(np.uint8)).save(
os.path.join(steps_path, filename)
)
img_callback = _record_steps if record_steps else None
for result in imagine_images(
prompts,
@ -131,7 +147,7 @@ def imagine_images(
_img_callback = None
precision_scope = autocast if precision == "autocast" else nullcontext
with (torch.no_grad(), precision_scope("cuda")):
with (torch.no_grad(), precision_scope("cuda"), fix_torch_nn_layer_norm()):
for prompt in prompts:
seed_everything(prompt.seed)
uc = None
@ -145,6 +161,7 @@ def imagine_images(
]
)
if img_callback:
def _img_callback(samples, i):
img_callback(samples, i, model, prompt)
@ -159,9 +176,7 @@ def imagine_images(
if prompt.init_image:
generation_strength = 1 - prompt.init_image_strength
ddim_steps = int(prompt.steps / generation_strength)
sampler.make_schedule(
ddim_num_steps=ddim_steps, ddim_eta=ddim_eta
)
sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=ddim_eta)
t_enc = int(generation_strength * ddim_steps)
init_image, w, h = load_img(prompt.init_image)
@ -173,7 +188,8 @@ def imagine_images(
# encode (scaled latent)
z_enc = sampler.stochastic_encode(
noised_init_latent, torch.tensor([t_enc]).to(get_device()),
noised_init_latent,
torch.tensor([t_enc]).to(get_device()),
)
_img_callback(noised_init_latent, 0)

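The checkpoint auto-download above uses transformers' cached_path helper, which fetches a URL once and returns the local cache path on subsequent calls. A standalone sketch using the same URL hard-coded in load_model_from_config:

from transformers import cached_path

# downloads the checkpoint on the first call, then reuses the cached local file
ckpt_path = cached_path(
    "https://www.googleapis.com/storage/v1/b/aai-blog-files/o/sd-v1-4.ckpt?alt=media"
)
print(ckpt_path)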
@ -1,8 +1,10 @@
import logging
import numpy as np
import pytorch_lightning as pl
import torch
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
import torch.nn as nn
from einops import rearrange
from imaginairy.modules.diffusionmodules.model import Encoder, Decoder
from imaginairy.modules.distributions import DiagonalGaussianDistribution
@ -11,6 +13,150 @@ from imaginairy.utils import instantiate_from_config
logger = logging.getLogger(__name__)
class VectorQuantizer(nn.Module):
"""
Improved version over original VectorQuantizer, can be used as a drop-in replacement. Mostly
avoids costly matrix multiplications and allows for post-hoc remapping of indices.
https://github.com/CompVis/taming-transformers/blob/141eb746f567a731f71cd703796d4d53a323f45f/taming/modules/vqvae/quantize.py#L213
"""
# NOTE: due to a bug the beta term was applied to the wrong term. for
# backwards compatibility we use the buggy version by default, but you can
# specify legacy=False to fix it.
def __init__(
self,
n_e,
e_dim,
beta,
remap=None,
unknown_index="random",
sane_index_shape=False,
legacy=True,
):
super().__init__()
self.n_e = n_e
self.e_dim = e_dim
self.beta = beta
self.legacy = legacy
self.embedding = nn.Embedding(self.n_e, self.e_dim)
self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
self.remap = remap
if self.remap is not None:
self.register_buffer("used", torch.tensor(np.load(self.remap)))
self.re_embed = self.used.shape[0]
self.unknown_index = unknown_index # "random" or "extra" or integer
if self.unknown_index == "extra":
self.unknown_index = self.re_embed
self.re_embed = self.re_embed + 1
print(
f"Remapping {self.n_e} indices to {self.re_embed} indices. "
f"Using {self.unknown_index} for unknown indices."
)
else:
self.re_embed = n_e
self.sane_index_shape = sane_index_shape
def remap_to_used(self, inds):
ishape = inds.shape
assert len(ishape) > 1
inds = inds.reshape(ishape[0], -1)
used = self.used.to(inds)
match = (inds[:, :, None] == used[None, None, ...]).long()
new = match.argmax(-1)
unknown = match.sum(2) < 1
if self.unknown_index == "random":
new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(
device=new.device
)
else:
new[unknown] = self.unknown_index
return new.reshape(ishape)
def unmap_to_all(self, inds):
ishape = inds.shape
assert len(ishape) > 1
inds = inds.reshape(ishape[0], -1)
used = self.used.to(inds)
if self.re_embed > self.used.shape[0]: # extra token
inds[inds >= self.used.shape[0]] = 0 # simply set to zero
back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
return back.reshape(ishape)
def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
assert rescale_logits == False, "Only for interface compatible with Gumbel"
assert return_logits == False, "Only for interface compatible with Gumbel"
# reshape z -> (batch, height, width, channel) and flatten
z = rearrange(z, "b c h w -> b h w c").contiguous()
z_flattened = z.view(-1, self.e_dim)
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
d = (
torch.sum(z_flattened**2, dim=1, keepdim=True)
+ torch.sum(self.embedding.weight**2, dim=1)
- 2
* torch.einsum(
"bd,dn->bn", z_flattened, rearrange(self.embedding.weight, "n d -> d n")
)
)
min_encoding_indices = torch.argmin(d, dim=1)
z_q = self.embedding(min_encoding_indices).view(z.shape)
perplexity = None
min_encodings = None
# compute loss for embedding
if not self.legacy:
loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean(
(z_q - z.detach()) ** 2
)
else:
loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean(
(z_q - z.detach()) ** 2
)
# preserve gradients
z_q = z + (z_q - z).detach()
# reshape back to match original input shape
z_q = rearrange(z_q, "b h w c -> b c h w").contiguous()
if self.remap is not None:
min_encoding_indices = min_encoding_indices.reshape(
z.shape[0], -1
) # add batch axis
min_encoding_indices = self.remap_to_used(min_encoding_indices)
min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten
if self.sane_index_shape:
min_encoding_indices = min_encoding_indices.reshape(
z_q.shape[0], z_q.shape[2], z_q.shape[3]
)
return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
def get_codebook_entry(self, indices, shape):
# shape specifying (batch, height, width, channel)
if self.remap is not None:
indices = indices.reshape(shape[0], -1) # add batch axis
indices = self.unmap_to_all(indices)
indices = indices.reshape(-1) # flatten again
# get quantized latent vectors
z_q = self.embedding(indices)
if shape is not None:
z_q = z_q.view(shape)
# reshape back to match original input shape
z_q = z_q.permute(0, 3, 1, 2).contiguous()
return z_q
class VQModel(pl.LightningModule):
def __init__(
self,

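One detail worth noting in the copied-in VectorQuantizer (which replaces the taming-transformers import): z_q = z + (z_q - z).detach() is the straight-through estimator, so the forward pass uses the quantized codes while gradients flow back to the encoder output as if quantization were the identity. A tiny numeric sketch:

import torch

z = torch.tensor([0.4, 1.6], requires_grad=True)
z_q = torch.round(z)              # stand-in for the nearest-codebook lookup (not differentiable)
z_q = z + (z_q - z).detach()      # forward: rounded values; backward: identity w.r.t. z
z_q.sum().backward()
print(z_q.detach())               # tensor([0., 2.])
print(z.grad)                     # tensor([1., 1.])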
@ -30,9 +30,7 @@ class DDIMSampler:
attr = attr.to(torch.float32).to(torch.device(self.device_available))
setattr(self, name, attr)
def make_schedule(
self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0
):
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0):
self.ddim_timesteps = make_ddim_timesteps(
ddim_discr_method=ddim_discretize,
num_ddim_timesteps=ddim_num_steps,
@ -256,7 +254,7 @@ class DDIMSampler:
noise_dropout=0.0,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
loss_function=None
loss_function=None,
):
b, *_, device = *x.shape, x.device
@ -268,8 +266,12 @@ class DDIMSampler:
t_in = torch.cat([t] * 2)
c_in = torch.cat([unconditional_conditioning, c])
# with torch.no_grad():
noise_pred_uncond, noise_pred = self.model.apply_model(x_in, t_in, c_in).chunk(2)
noise_pred = noise_pred_uncond + unconditional_guidance_scale * (noise_pred - noise_pred_uncond)
noise_pred_uncond, noise_pred = self.model.apply_model(
x_in, t_in, c_in
).chunk(2)
noise_pred = noise_pred_uncond + unconditional_guidance_scale * (
noise_pred - noise_pred_uncond
)
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
alphas_prev = (
@ -337,7 +339,7 @@ class DDIMSampler:
use_original_steps=False,
img_callback=None,
score_corrector=None,
temperature=1.0
temperature=1.0,
):
timesteps = (
@ -367,7 +369,7 @@ class DDIMSampler:
use_original_steps=use_original_steps,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
temperature=temperature
temperature=temperature,
)
# original_loss = ((x_dec - x_latent).abs().mean()*70)
# sigma_t = torch.full((1, 1, 1, 1), self.ddim_sigmas[index], device=get_device())

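The reformatted block above is classifier-free guidance: the unconditional and conditional noise predictions come from one batched apply_model call and are blended by the guidance scale. A small numeric sketch of the blend (scale value illustrative):

import torch

noise_pred_uncond = torch.tensor([0.10])
noise_pred_cond = torch.tensor([0.30])
guidance_scale = 7.5
guided = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
print(guided)  # tensor([1.6000]); pushed further toward the conditioned prediction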
@ -1,8 +1,12 @@
import os
import importlib
import logging
from contextlib import contextmanager
from functools import lru_cache
from typing import List, Optional
import torch
from torch import Tensor
logger = logging.getLogger(__name__)
@ -38,3 +42,49 @@ def get_obj_from_str(string, reload=False):
module_imp = importlib.import_module(module)
importlib.reload(module_imp)
return getattr(importlib.import_module(module, package=None), cls)
from torch.overrides import has_torch_function_variadic, handle_torch_function
def _fixed_layer_norm(
input: Tensor,
normalized_shape: List[int],
weight: Optional[Tensor] = None,
bias: Optional[Tensor] = None,
eps: float = 1e-5,
) -> Tensor:
r"""Applies Layer Normalization for last certain number of dimensions.
See :class:`~torch.nn.LayerNorm` for details.
"""
if has_torch_function_variadic(input, weight, bias):
return handle_torch_function(
_fixed_layer_norm,
(input, weight, bias),
input,
normalized_shape,
weight=weight,
bias=bias,
eps=eps,
)
return torch.layer_norm(
input.contiguous(),
normalized_shape,
weight,
bias,
eps,
torch.backends.cudnn.enabled,
)
@contextmanager
def fix_torch_nn_layer_norm():
"""https://github.com/CompVis/stable-diffusion/issues/25#issuecomment-1221416526"""
from torch.nn import functional
orig_function = functional.layer_norm
functional.layer_norm = _fixed_layer_norm
try:
yield
finally:
functional.layer_norm = orig_function
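The patch is applied in imagine.py above, wrapping the sampling loop. A standalone usage sketch, assuming the underlying problem is layer_norm failing on non-contiguous inputs under MPS (per the linked issue comment):

import torch
from imaginairy.utils import fix_torch_nn_layer_norm

layer = torch.nn.LayerNorm(8)
x = torch.randn(2, 8, 4).transpose(1, 2)  # non-contiguous tensor of shape (2, 4, 8)
with fix_torch_nn_layer_norm():           # layer_norm now receives input.contiguous()
    out = layer(x)
print(out.shape)  # torch.Size([2, 4, 8])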

@ -1,36 +1,38 @@
from setuptools import setup, find_packages
setup(
name='imaginairy',
version='0.0.1',
description='AI imagined images.',
packages=find_packages("imaginairy"),
name="imaginairy",
version="0.0.1",
description="AI imagined images.",
packages=find_packages(include=("imaginairy", "imaginairy.*")),
entry_points={
'console_scripts': ['imagine=imaginairy.cmds:imagine_cmd'],
"console_scripts": ["imagine=imaginairy.cmds:imagine_cmd"],
},
package_data={"imaginairy": ["configs/*.yaml"]},
install_requires=[
'click',
'torch',
'numpy',
'tqdm',
"click",
"torch",
"numpy",
"tqdm",
# "albumentations==0.4.3",
"diffusers",
# "diffusers",
# opencv-python==4.1.2.30
"pudb==2019.2",
# "pudb==2019.2",
# "invisible-watermark",
"imageio==2.9.0",
"imageio-ffmpeg==0.4.2",
# "imageio-ffmpeg==0.4.2",
"pytorch-lightning==1.4.2",
"omegaconf==2.1.1",
"test-tube>=0.7.5",
"streamlit>=0.73.1",
# "test-tube>=0.7.5",
# "streamlit>=0.73.1",
"einops==0.3.0",
"torch-fidelity==0.3.0",
# "torch-fidelity==0.3.0",
"transformers==4.19.2",
"torchmetrics==0.6.0",
"torchvision>=0.13.1",
"kornia==0.6",
"realesrgan",
"-e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers",
"-e git+https://github.com/openai/CLIP.git@main#egg=clip",
# "realesrgan",
# "-e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers",
"clip @ git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1#egg=clip",
],
)

@ -1,3 +1,3 @@
import os.path
TESTS_FOLDER = os.path.dirname(__file__)
TESTS_FOLDER = os.path.dirname(__file__)

@ -4,9 +4,11 @@ from . import TESTS_FOLDER
def test_imagine():
prompt = ImaginePrompt("a scenic landscape", width=512, height=256, steps=20, seed=1)
prompt = ImaginePrompt(
"a scenic landscape", width=512, height=256, steps=20, seed=1
)
result = next(imagine_images(prompt))
assert result.md5() == '4c5957c498881d365cfcf13014812af0'
assert result.md5() == "4c5957c498881d365cfcf13014812af0"
result.img.save(f"{TESTS_FOLDER}/test_output/scenic_landscape.png")
@ -22,24 +24,28 @@ def test_img_to_img():
sampler_type="DDIM",
)
out_folder = f"{TESTS_FOLDER}/test_output"
out_folder = '/home/bryce/Mounts/drennanfiles/art/tests'
out_folder = "/home/bryce/Mounts/drennanfiles/art/tests"
imagine_image_files(prompt, outdir=out_folder)
def test_img_to_file():
prompt = ImaginePrompt(
[WeightedPrompt("an old growth forest, diffuse light poking through the canopy. high-resolution, nature photography, nat geo photo")],
[
WeightedPrompt(
"an old growth forest, diffuse light poking through the canopy. high-resolution, nature photography, nat geo photo"
)
],
# init_image=f"{TESTS_FOLDER}/data/beach_at_sainte_adresse.jpg",
init_image_strength=0.5,
width=512+64,
height=512-64,
width=512 + 64,
height=512 - 64,
steps=50,
# seed=2,
sampler_type="PLMS",
upscale=True
upscale=True,
)
out_folder = f"{TESTS_FOLDER}/test_output"
out_folder = '/home/bryce/Mounts/drennanfiles/art/tests'
out_folder = "/home/bryce/Mounts/drennanfiles/art/tests"
imagine_image_files(prompt, outdir=out_folder)
@ -48,13 +54,13 @@ def test_img_conditioning():
"photo",
init_image=f"{TESTS_FOLDER}/data/beach_at_sainte_adresse.jpg",
init_image_strength=0.5,
width=512+64,
height=512-64,
width=512 + 64,
height=512 - 64,
steps=50,
# seed=2,
sampler_type="PLMS",
upscale=True
upscale=True,
)
out_folder = f"{TESTS_FOLDER}/test_output"
out_folder = '/home/bryce/Mounts/drennanfiles/art/tests'
imagine_image_files(prompt, outdir=out_folder, record_steps=True)
out_folder = "/home/bryce/Mounts/drennanfiles/art/tests"
imagine_image_files(prompt, outdir=out_folder, record_steps=True)
