Mirror of https://github.com/brycedrennan/imaginAIry, synced 2024-11-05 12:00:15 +00:00

style: fix lint issues
This commit is contained in:
parent 38c7f88950
commit 69af07ab67

@@ -213,7 +213,7 @@ def imagine(
     ddim_steps = int(prompt.steps / generation_strength)
     sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=ddim_eta)
     try:
-        init_image, _, h = pillow_fit_image_within(
+        init_image = pillow_fit_image_within(
             prompt.init_image,
             max_height=prompt.height,
             max_width=prompt.width,

@@ -68,7 +68,7 @@ def enhance_faces(img, fidelity=0):
 
         try:
             with torch.no_grad():
-                output = net(cropped_face_t, w=fidelity, adain=True)[0]
+                output = net(cropped_face_t, w=fidelity, adain=True)[0]  # noqa
                 restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
             del output
             torch.cuda.empty_cache()

@@ -16,7 +16,7 @@ def pillow_fit_image_within(image: PIL.Image.Image, max_height=512, max_width=51
         w, h = int(w * resize_ratio), int(h * resize_ratio)
     w, h = map(lambda x: x - x % 64, (w, h))  # resize to integer multiple of 64
     image = image.resize((w, h), resample=Image.Resampling.NEAREST)
-    return image, w, h
+    return image
 
 
 def pillow_img_to_torch_image(img: PIL.Image.Image):

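Aside: the rounding line above snaps both sides down to the nearest multiple of 64 so the latent and UNet shapes stay integral. A quick illustration of that arithmetic (the helper name is made up for this sketch):

# Illustrative helper, not from the repo: same x - x % 64 rounding as above.
def snap_to_multiple_of_64(w: int, h: int) -> tuple[int, int]:
    return w - w % 64, h - h % 64

assert snap_to_multiple_of_64(700, 500) == (640, 448)
assert snap_to_multiple_of_64(512, 512) == (512, 512)
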
@@ -85,7 +85,7 @@ class LinearAttention(nn.Module):
         self.to_out = nn.Conv2d(hidden_dim, dim, 1)
 
     def forward(self, x):
-        b, c, h, w = x.shape
+        b, c, h, w = x.shape  # noqa
         qkv = self.to_qkv(x)
         q, k, v = rearrange(
             qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3

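Side note: the unpacked b, c, h, w go unused (hence the # noqa); the real work is the rearrange pattern, which splits one fused conv output into per-head q, k, v in a single step. A runnable sketch with toy sizes:

import torch
from einops import rearrange

heads, c = 4, 8
qkv = torch.randn(2, 3 * heads * c, 16, 16)  # output of a single to_qkv conv
q, k, v = rearrange(
    qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=heads, qkv=3
)
assert q.shape == (2, heads, c, 16 * 16)  # one tensor each for q, k, v
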
@@ -126,7 +126,7 @@ class SpatialSelfAttention(nn.Module):
         v = self.v(h_)
 
         # compute attention
-        b, c, h, w = q.shape
+        b, c, h, w = q.shape  # noqa
         q = rearrange(q, "b c h w -> b (h w) c")
         k = rearrange(k, "b c h w -> b c (h w)")
         w_ = torch.einsum("bij,bjk->bik", q, k)

@@ -178,9 +178,9 @@ class CrossAttention(nn.Module):
 
         if mask is not None:
             mask = rearrange(mask, "b ... -> b (...)")
-            max_neg_value = -torch.finfo(sim.dtype).max
+            _max_neg_value = -torch.finfo(sim.dtype).max
             mask = repeat(mask, "b j -> (b h) () j", h=h)
-            sim.masked_fill_(~mask, max_neg_value)
+            sim.masked_fill_(~mask, _max_neg_value)
 
         # attention, what we cannot get enough of
         attn = sim.softmax(dim=-1)

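The underscore rename just marks the local as intentionally short-lived for the linter. The masking idiom itself is worth seeing in isolation: disallowed positions are filled with the most negative finite value of the logits' dtype, so softmax drives their weights to ~0. A minimal sketch:

import torch

sim = torch.randn(1, 4, 4)                         # attention logits
mask = torch.tensor([[True, True, False, False]])  # True = attend, False = block
neg = -torch.finfo(sim.dtype).max
sim = sim.masked_fill(~mask, neg)    # blocked logits -> most negative float
attn = sim.softmax(dim=-1)           # blocked columns get ~zero probability
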
@@ -189,7 +189,7 @@ class CrossAttention(nn.Module):
         out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
         return self.to_out(out)
 
-    def forward_cuda(self, x, context=None, mask=None):
+    def forward_cuda(self, x, context=None, mask=None):  # noqa
         h = self.heads
 
         q_in = self.to_q(x)

@@ -258,7 +258,7 @@ class BasicTransformerBlock(nn.Module):
         dropout=0.0,
         context_dim=None,
         gated_ff=True,
-        checkpoint=True,
+        checkpoint=True,  # noqa
     ):
         super().__init__()
         self.attn1 = CrossAttention(

@@ -326,7 +326,7 @@ class SpatialTransformer(nn.Module):
 
     def forward(self, x, context=None):
         # note: if no context is given, cross-attention defaults to self-attention
-        b, c, h, w = x.shape
+        b, c, h, w = x.shape  # noqa
         x_in = x
         x = self.norm(x)
         x = self.proj_in(x)

@@ -17,12 +17,13 @@ class AutoencoderKL(pl.LightningModule):
         lossconfig,
         embed_dim,
         ckpt_path=None,
-        ignore_keys=[],
+        ignore_keys=None,
         image_key="image",
         colorize_nlabels=None,
         monitor=None,
     ):
         super().__init__()
+        ignore_keys = [] if ignore_keys is None else ignore_keys
         self.image_key = image_key
         self.encoder = Encoder(**ddconfig)
         self.decoder = Decoder(**ddconfig)

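ignore_keys=[] is the classic mutable-default pitfall this commit is fixing: the default list is created once, at function definition, and shared by every call that omits the argument. The None sentinel plus in-body initialization avoids that. A self-contained demonstration (hypothetical functions, not from the repo):

def bad(key, seen=[]):  # one list shared by ALL calls
    seen.append(key)
    return seen

def good(key, seen=None):  # fresh list per call
    seen = [] if seen is None else seen
    seen.append(key)
    return seen

assert bad("a") == ["a"]
assert bad("b") == ["a", "b"]  # state leaked from the previous call
assert good("a") == ["a"]
assert good("b") == ["b"]
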
@@ -32,20 +33,21 @@ class AutoencoderKL(pl.LightningModule):
         self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
         self.embed_dim = embed_dim
         if colorize_nlabels is not None:
-            assert type(colorize_nlabels) == int
+            assert isinstance(colorize_nlabels, int)
             self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
         if monitor is not None:
             self.monitor = monitor
         if ckpt_path is not None:
             self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
 
-    def init_from_ckpt(self, path, ignore_keys=list()):
+    def init_from_ckpt(self, path, ignore_keys=None):
+        ignore_keys = [] if ignore_keys is None else ignore_keys
         sd = torch.load(path, map_location="cpu")["state_dict"]
         keys = list(sd.keys())
         for k in keys:
             for ik in ignore_keys:
                 if k.startswith(ik):
-                    logger.info("Deleting key {} from state_dict.".format(k))
+                    logger.info(f"Deleting key {k} from state_dict.")
                     del sd[k]
         self.load_state_dict(sd, strict=False)
         logger.info(f"Restored from {path}")

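Two more lint patterns land in this hunk: isinstance replaces the exact type comparison because it also accepts subclasses, and the .format call becomes an f-string. A tiny illustration of the isinstance difference (the Count class is made up):

class Count(int):
    """Illustrative int subclass, not from the repo."""

n = Count(3)
assert isinstance(n, int)   # subclasses satisfy isinstance
assert type(n) is not int   # but fail an exact type comparison
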
@@ -61,7 +63,7 @@ class AutoencoderKL(pl.LightningModule):
         dec = self.decoder(z)
         return dec
 
-    def forward(self, input, sample_posterior=True):
+    def forward(self, input, sample_posterior=True):  # noqa
         posterior = self.encode(input)
         if sample_posterior:
             z = posterior.sample()

@@ -1,7 +1,7 @@
 import kornia
 import torch
-import torch.nn as nn
 from einops import repeat
+from torch import nn
 from transformers import CLIPTextModel, CLIPTokenizer
 
 from imaginairy.utils import get_device

@@ -102,7 +102,9 @@ class FrozenClipImageEmbedder(nn.Module):
         antialias=False,
     ):
         super().__init__()
-        self.model, preprocess = clip.load(name=model_name, device=device, jit=jit)
+        self.model, preprocess = clip.load(  # noqa
+            name=model_name, device=device, jit=jit
+        )
 
         self.antialias = antialias

@@ -54,7 +54,7 @@ class DDPM(pl.LightningModule):
         beta_schedule="linear",
         loss_type="l2",
         ckpt_path=None,
-        ignore_keys=[],
+        ignore_keys=None,
         load_only_unet=False,
         monitor="val/loss",
         first_stage_key="image",

@@ -77,6 +77,8 @@ class DDPM(pl.LightningModule):
         logvar_init=0.0,
     ):
         super().__init__()
+        ignore_keys = [] if ignore_keys is None else ignore_keys
+
         assert parameterization in [
             "eps",
             "x0",

@@ -236,7 +238,6 @@ class LatentDiffusion(DDPM):
         conditioning_key=None,
         scale_factor=1.0,
         scale_by_std=False,
-        *args,
         **kwargs,
     ):
         self.num_timesteps_cond = (

@@ -251,7 +252,7 @@ class LatentDiffusion(DDPM):
             conditioning_key = None
         ckpt_path = kwargs.pop("ckpt_path", None)
         ignore_keys = kwargs.pop("ignore_keys", [])
-        super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
+        super().__init__(conditioning_key=conditioning_key, **kwargs)
         self.concat_mode = concat_mode
         self.cond_stage_trainable = cond_stage_trainable
         self.cond_stage_key = cond_stage_key

@@ -286,7 +287,9 @@ class LatentDiffusion(DDPM):
         """For creating seamless tiles"""
         for m in self.modules():
             if isinstance(m, nn.Conv2d):
-                m.padding_mode = "circular" if enabled else m._initial_padding_mode
+                m.padding_mode = (
+                    "circular" if enabled else m._initial_padding_mode  # noqa
+                )
 
     def make_cond_schedule(
         self,

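Here the # noqa quiets the protected-member access on m._initial_padding_mode; the technique itself is the interesting part. Switching every nn.Conv2d to circular padding makes convolutions wrap around the image border, which is what lets generated tiles repeat seamlessly. A minimal sketch:

import torch
from torch import nn

conv = nn.Conv2d(1, 1, kernel_size=3, padding=1, padding_mode="circular")
x = torch.arange(16.0).reshape(1, 1, 4, 4)
y = conv(x)  # each edge now "sees" the opposite edge, so tiles wrap cleanly
assert y.shape == x.shape
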
@@ -436,7 +439,7 @@ class LatentDiffusion(DDPM):
         :param x: img of size (bs, c, h, w)
         :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
         """
-        bs, nc, h, w = x.shape
+        bs, nc, h, w = x.shape  # noqa
 
         # number of crops in image
         Ly = (h - kernel_size[0]) // stride[0] + 1

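Only h and w from the unpacking are used, hence the # noqa. The crop count below it is the standard sliding-window formula, windows = (size - kernel) // stride + 1 per axis. For example, with illustrative values:

h, w = 512, 512
kernel_size, stride = (128, 128), (64, 64)
Ly = (h - kernel_size[0]) // stride[0] + 1  # 7 windows vertically
Lx = (w - kernel_size[1]) // stride[1] + 1  # 7 windows horizontally
assert (Ly, Lx) == (7, 7)
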
@@ -595,7 +598,7 @@ class LatentDiffusion(DDPM):
         stride = self.split_input_params["stride"]  # eg. (64, 64)
         df = self.split_input_params["vqf"]
         self.split_input_params["original_image_size"] = x.shape[-2:]
-        bs, nc, h, w = x.shape
+        bs, nc, h, w = x.shape  # noqa
         if ks[0] > h or ks[1] > w:
             ks = (min(ks[0], h), min(ks[1], w))
             logger.info("reducing Kernel")

@@ -1,3 +1,4 @@
+# pylama:ignore=W0613,W0612
 # pytorch_diffusion + derived encoder decoder
 import gc
 import logging

@@ -5,8 +6,8 @@ import math
 
 import numpy as np
 import torch
-import torch.nn as nn
 from einops import rearrange
+from torch import nn
 
 from imaginairy.modules.attention import LinearAttention
 from imaginairy.modules.distributions import DiagonalGaussianDistribution

@@ -17,7 +18,7 @@ logger = logging.getLogger(__name__)
 
 def get_timestep_embedding(timesteps, embedding_dim):
     """
-    This matches the implementation in Denoising Diffusion Probabilistic Models:
+    Matches the implementation in Denoising Diffusion Probabilistic Models:
     From Fairseq.
     Build sinusoidal embeddings.
     This matches the implementation in tensor2tensor, but differs slightly

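The docstring edit drops the duplicated opener. For context, the function it documents builds the DDPM-style sinusoidal timestep embedding: each timestep is paired with sin/cos values at geometrically spaced frequencies. A simplified sketch of the idea (not the repo's exact code):

import math
import torch

def timestep_embedding_sketch(timesteps, dim):
    half = dim // 2
    freqs = torch.exp(
        -math.log(10000) * torch.arange(half, dtype=torch.float32) / half
    )
    args = timesteps.float()[:, None] * freqs[None, :]
    return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)

emb = timestep_embedding_sketch(torch.arange(4), 8)
assert emb.shape == (4, 8)
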
@@ -286,10 +287,10 @@ def make_attn(in_channels, attn_type="vanilla"):
     )
     if attn_type == "vanilla":
         return AttnBlock(in_channels)
-    elif attn_type == "none":
+    if attn_type == "none":
         return nn.Identity(in_channels)
-    else:
-        return LinAttnBlock(in_channels)
+
+    return LinAttnBlock(in_channels)
 
 
 class Encoder(nn.Module):

@@ -502,6 +503,7 @@ class Decoder(nn.Module):
         self.conv_out = torch.nn.Conv2d(
             block_in, out_ch, kernel_size=3, stride=1, padding=1
         )
+        self.last_z_shape = None
 
     def forward(self, z):
         # assert z.shape[1:] == self.z_shape[1:]

@@ -656,22 +658,22 @@ class Resize(nn.Module):
         self.mode = mode
         if self.with_conv:
             logger.info(
-                f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode"
+                f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode"  # noqa
             )
             raise NotImplementedError()
-            assert in_channels is not None
-            # no asymmetric padding in torch conv, must do it ourselves
-            self.conv = torch.nn.Conv2d(
-                in_channels, in_channels, kernel_size=4, stride=2, padding=1
-            )
+            # assert in_channels is not None
+            # # no asymmetric padding in torch conv, must do it ourselves
+            # self.conv = torch.nn.Conv2d(
+            #     in_channels, in_channels, kernel_size=4, stride=2, padding=1
+            # )
 
     def forward(self, x, scale_factor=1.0):
         if scale_factor == 1.0:
             return x
-        else:
-            x = torch.nn.functional.interpolate(
-                x, mode=self.mode, align_corners=False, scale_factor=scale_factor
-            )
+
+        x = torch.nn.functional.interpolate(
+            x, mode=self.mode, align_corners=False, scale_factor=scale_factor
+        )
         return x

@@ -4,6 +4,7 @@ from abc import abstractmethod
 import numpy as np
 import torch as th
 import torch.nn.functional as F
+from omegaconf.listconfig import ListConfig
 from torch import nn
 
 from imaginairy.modules.attention import SpatialTransformer

@@ -488,7 +489,6 @@ class UNetModel(nn.Module):
             assert (
                 use_spatial_transformer
             ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
-            from omegaconf.listconfig import ListConfig
 
             if isinstance(context_dim, ListConfig):
                 context_dim = list(context_dim)

@@ -753,7 +753,7 @@ class UNetModel(nn.Module):
         self.middle_block.apply(convert_module_to_f32)
         self.output_blocks.apply(convert_module_to_f32)
 
-    def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
+    def forward(self, x, timesteps=None, context=None, y=None, **kwargs):  # noqa
         """
         Apply the model to an input batch.
         :param x: an [N x C x ...] Tensor of inputs.

@@ -13,8 +13,8 @@ import math
 
 import numpy as np
 import torch
-import torch.nn as nn
 from einops import repeat as e_repeat
+from torch import nn
 
 from imaginairy.utils import instantiate_from_config

@@ -57,7 +57,7 @@ def make_beta_schedule(
 
 
 def frange(start, stop, step):
-    """range but handles floats"""
+    """Range but handles floats"""
     x = start
     while True:
         if x >= stop:

@@ -148,11 +148,11 @@ def checkpoint(func, inputs, params, flag):
     if flag:
         args = tuple(inputs) + tuple(params)
         return CheckpointFunction.apply(func, len(inputs), *args)
-    else:
-        return func(*inputs)
+
+    return func(*inputs)
 
 
-class CheckpointFunction(torch.autograd.Function):
+class CheckpointFunction(torch.autograd.Function):  # noqa
     @staticmethod
     def forward(ctx, run_function, length, *args):
         ctx.run_function = run_function

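CheckpointFunction is a hand-rolled form of gradient checkpointing: skip storing intermediate activations in the forward pass and recompute them during backward, trading compute for memory. Stock PyTorch exposes the same behavior; a minimal sketch (use_reentrant=False assumes a reasonably recent torch):

import torch
from torch.utils.checkpoint import checkpoint

layer = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU())
x = torch.randn(4, 16, requires_grad=True)

y = checkpoint(layer, x, use_reentrant=False)  # activations recomputed in backward
y.sum().backward()
assert x.grad is not None
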
@@ -252,7 +252,7 @@ class SiLU(nn.Module):
 
 
 class GroupNorm32(nn.GroupNorm):
-    def forward(self, x):
+    def forward(self, x):  # noqa
         return super().forward(x.float()).type(x.dtype)

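GroupNorm32 upcasts to float32 for the normalization and casts the result back, a common mixed-precision stabilization since norm statistics are noisy in float16. The pattern in isolation (class name is illustrative):

import torch
from torch import nn

class GroupNormF32(nn.GroupNorm):
    def forward(self, x):
        # Compute statistics in float32, return in the caller's dtype.
        return super().forward(x.float()).type(x.dtype)

gn = GroupNormF32(num_groups=4, num_channels=8)
y = gn(torch.randn(2, 8, 4, 4, dtype=torch.float16))
assert y.dtype == torch.float16
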
@@ -17,16 +17,13 @@ from imaginairy.vendored import k_diffusion as K
 
 
 def find_noise_for_image(model, pil_img, prompt, steps=50, cond_scale=1.0, half=True):
-    img_latent = pillow_img_to_model_latent(
-        model, pil_img, batch_size=1, device="cuda", half=half
-    )
+    img_latent = pillow_img_to_model_latent(model, pil_img, batch_size=1, half=half)
     return find_noise_for_latent(
         model,
         img_latent,
         prompt,
         steps=steps,
         cond_scale=cond_scale,
         half=half,
     )

@@ -1,3 +1,4 @@
+# pylama:ignore=W0613
 import torch
 from torch import nn

@@ -25,9 +26,9 @@ _k_sampler_type_lookup = {
 
 
 def get_sampler(sampler_type, model):
-    from imaginairy.samplers.ddim import DDIMSampler
-    from imaginairy.samplers.kdiff import KDiffusionSampler
-    from imaginairy.samplers.plms import PLMSSampler
+    from imaginairy.samplers.ddim import DDIMSampler  # noqa
+    from imaginairy.samplers.kdiff import KDiffusionSampler  # noqa
+    from imaginairy.samplers.plms import PLMSSampler  # noqa
 
     sampler_type = sampler_type.lower()
     if sampler_type == "plms":

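Keeping the sampler imports inside get_sampler defers them to call time, the usual way to break an import cycle between a dispatch module and the implementations that import it back. The same dispatch can also be written as a lookup table; a sketch (helper name is made up, module paths are the ones shown above):

from importlib import import_module

def get_sampler_sketch(sampler_type):
    registry = {
        "plms": ("imaginairy.samplers.plms", "PLMSSampler"),
        "ddim": ("imaginairy.samplers.ddim", "DDIMSampler"),
    }
    module_name, class_name = registry[sampler_type.lower()]
    return getattr(import_module(module_name), class_name)
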
@@ -1,3 +1,4 @@
+# pylama:ignore=W0613
 import torch
 
 from imaginairy.img_log import log_latent

@@ -1,3 +1,4 @@
+# pylama:ignore=W0613
 """SAMPLING ONLY."""
 import logging