style: fix lint issues

This commit is contained in:
Bryce 2022-09-24 00:29:45 -07:00 committed by Bryce Drennan
parent 38c7f88950
commit 69af07ab67
14 changed files with 62 additions and 53 deletions

View File

@ -213,7 +213,7 @@ def imagine(
ddim_steps = int(prompt.steps / generation_strength)
sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=ddim_eta)
try:
init_image, _, h = pillow_fit_image_within(
init_image = pillow_fit_image_within(
prompt.init_image,
max_height=prompt.height,
max_width=prompt.width,

View File

@ -68,7 +68,7 @@ def enhance_faces(img, fidelity=0):
try:
with torch.no_grad():
output = net(cropped_face_t, w=fidelity, adain=True)[0]
output = net(cropped_face_t, w=fidelity, adain=True)[0] # noqa
restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
del output
torch.cuda.empty_cache()

View File

@ -16,7 +16,7 @@ def pillow_fit_image_within(image: PIL.Image.Image, max_height=512, max_width=51
w, h = int(w * resize_ratio), int(h * resize_ratio)
w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64
image = image.resize((w, h), resample=Image.Resampling.NEAREST)
return image, w, h
return image
def pillow_img_to_torch_image(img: PIL.Image.Image):

View File

@ -85,7 +85,7 @@ class LinearAttention(nn.Module):
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
def forward(self, x):
b, c, h, w = x.shape
b, c, h, w = x.shape # noqa
qkv = self.to_qkv(x)
q, k, v = rearrange(
qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
@ -126,7 +126,7 @@ class SpatialSelfAttention(nn.Module):
v = self.v(h_)
# compute attention
b, c, h, w = q.shape
b, c, h, w = q.shape # noqa
q = rearrange(q, "b c h w -> b (h w) c")
k = rearrange(k, "b c h w -> b c (h w)")
w_ = torch.einsum("bij,bjk->bik", q, k)
@ -178,9 +178,9 @@ class CrossAttention(nn.Module):
if mask is not None:
mask = rearrange(mask, "b ... -> b (...)")
max_neg_value = -torch.finfo(sim.dtype).max
_max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, "b j -> (b h) () j", h=h)
sim.masked_fill_(~mask, max_neg_value)
sim.masked_fill_(~mask, _max_neg_value)
# attention, what we cannot get enough of
attn = sim.softmax(dim=-1)
@ -189,7 +189,7 @@ class CrossAttention(nn.Module):
out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
return self.to_out(out)
def forward_cuda(self, x, context=None, mask=None):
def forward_cuda(self, x, context=None, mask=None): # noqa
h = self.heads
q_in = self.to_q(x)
@ -258,7 +258,7 @@ class BasicTransformerBlock(nn.Module):
dropout=0.0,
context_dim=None,
gated_ff=True,
checkpoint=True,
checkpoint=True, # noqa
):
super().__init__()
self.attn1 = CrossAttention(
@ -326,7 +326,7 @@ class SpatialTransformer(nn.Module):
def forward(self, x, context=None):
# note: if no context is given, cross-attention defaults to self-attention
b, c, h, w = x.shape
b, c, h, w = x.shape # noqa
x_in = x
x = self.norm(x)
x = self.proj_in(x)

View File

@ -17,12 +17,13 @@ class AutoencoderKL(pl.LightningModule):
lossconfig,
embed_dim,
ckpt_path=None,
ignore_keys=[],
ignore_keys=None,
image_key="image",
colorize_nlabels=None,
monitor=None,
):
super().__init__()
ignore_keys = [] if ignore_keys is None else ignore_keys
self.image_key = image_key
self.encoder = Encoder(**ddconfig)
self.decoder = Decoder(**ddconfig)
@ -32,20 +33,21 @@ class AutoencoderKL(pl.LightningModule):
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
self.embed_dim = embed_dim
if colorize_nlabels is not None:
assert type(colorize_nlabels) == int
assert isinstance(colorize_nlabels, int)
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
if monitor is not None:
self.monitor = monitor
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
def init_from_ckpt(self, path, ignore_keys=list()):
def init_from_ckpt(self, path, ignore_keys=None):
ignore_keys = [] if ignore_keys is None else ignore_keys
sd = torch.load(path, map_location="cpu")["state_dict"]
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
logger.info("Deleting key {} from state_dict.".format(k))
logger.info(f"Deleting key {k} from state_dict.")
del sd[k]
self.load_state_dict(sd, strict=False)
logger.info(f"Restored from {path}")
@ -61,7 +63,7 @@ class AutoencoderKL(pl.LightningModule):
dec = self.decoder(z)
return dec
def forward(self, input, sample_posterior=True):
def forward(self, input, sample_posterior=True): # noqa
posterior = self.encode(input)
if sample_posterior:
z = posterior.sample()

View File

@ -1,7 +1,7 @@
import kornia
import torch
import torch.nn as nn
from einops import repeat
from torch import nn
from transformers import CLIPTextModel, CLIPTokenizer
from imaginairy.utils import get_device
@ -102,7 +102,9 @@ class FrozenClipImageEmbedder(nn.Module):
antialias=False,
):
super().__init__()
self.model, preprocess = clip.load(name=model_name, device=device, jit=jit)
self.model, preprocess = clip.load( # noqa
name=model_name, device=device, jit=jit
)
self.antialias = antialias

View File

@ -54,7 +54,7 @@ class DDPM(pl.LightningModule):
beta_schedule="linear",
loss_type="l2",
ckpt_path=None,
ignore_keys=[],
ignore_keys=None,
load_only_unet=False,
monitor="val/loss",
first_stage_key="image",
@ -77,6 +77,8 @@ class DDPM(pl.LightningModule):
logvar_init=0.0,
):
super().__init__()
ignore_keys = [] if ignore_keys is None else ignore_keys
assert parameterization in [
"eps",
"x0",
@ -236,7 +238,6 @@ class LatentDiffusion(DDPM):
conditioning_key=None,
scale_factor=1.0,
scale_by_std=False,
*args,
**kwargs,
):
self.num_timesteps_cond = (
@ -251,7 +252,7 @@ class LatentDiffusion(DDPM):
conditioning_key = None
ckpt_path = kwargs.pop("ckpt_path", None)
ignore_keys = kwargs.pop("ignore_keys", [])
super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
super().__init__(conditioning_key=conditioning_key, **kwargs)
self.concat_mode = concat_mode
self.cond_stage_trainable = cond_stage_trainable
self.cond_stage_key = cond_stage_key
@ -286,7 +287,9 @@ class LatentDiffusion(DDPM):
"""For creating seamless tiles"""
for m in self.modules():
if isinstance(m, nn.Conv2d):
m.padding_mode = "circular" if enabled else m._initial_padding_mode
m.padding_mode = (
"circular" if enabled else m._initial_padding_mode # noqa
)
def make_cond_schedule(
self,
@ -436,7 +439,7 @@ class LatentDiffusion(DDPM):
:param x: img of size (bs, c, h, w)
:return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
"""
bs, nc, h, w = x.shape
bs, nc, h, w = x.shape # noqa
# number of crops in image
Ly = (h - kernel_size[0]) // stride[0] + 1
@ -595,7 +598,7 @@ class LatentDiffusion(DDPM):
stride = self.split_input_params["stride"] # eg. (64, 64)
df = self.split_input_params["vqf"]
self.split_input_params["original_image_size"] = x.shape[-2:]
bs, nc, h, w = x.shape
bs, nc, h, w = x.shape # noqa
if ks[0] > h or ks[1] > w:
ks = (min(ks[0], h), min(ks[1], w))
logger.info("reducing Kernel")

View File

@ -1,3 +1,4 @@
# pylama:ignore=W0613,W0612
# pytorch_diffusion + derived encoder decoder
import gc
import logging
@ -5,8 +6,8 @@ import math
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from torch import nn
from imaginairy.modules.attention import LinearAttention
from imaginairy.modules.distributions import DiagonalGaussianDistribution
@ -17,7 +18,7 @@ logger = logging.getLogger(__name__)
def get_timestep_embedding(timesteps, embedding_dim):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models:
Matches the implementation in Denoising Diffusion Probabilistic Models:
From Fairseq.
Build sinusoidal embeddings.
This matches the implementation in tensor2tensor, but differs slightly
@ -286,10 +287,10 @@ def make_attn(in_channels, attn_type="vanilla"):
)
if attn_type == "vanilla":
return AttnBlock(in_channels)
elif attn_type == "none":
if attn_type == "none":
return nn.Identity(in_channels)
else:
return LinAttnBlock(in_channels)
return LinAttnBlock(in_channels)
class Encoder(nn.Module):
@ -502,6 +503,7 @@ class Decoder(nn.Module):
self.conv_out = torch.nn.Conv2d(
block_in, out_ch, kernel_size=3, stride=1, padding=1
)
self.last_z_shape = None
def forward(self, z):
# assert z.shape[1:] == self.z_shape[1:]
@ -656,22 +658,22 @@ class Resize(nn.Module):
self.mode = mode
if self.with_conv:
logger.info(
f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode"
f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode" # noqa
)
raise NotImplementedError()
assert in_channels is not None
# no asymmetric padding in torch conv, must do it ourselves
self.conv = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=4, stride=2, padding=1
)
# assert in_channels is not None
# # no asymmetric padding in torch conv, must do it ourselves
# self.conv = torch.nn.Conv2d(
# in_channels, in_channels, kernel_size=4, stride=2, padding=1
# )
def forward(self, x, scale_factor=1.0):
if scale_factor == 1.0:
return x
else:
x = torch.nn.functional.interpolate(
x, mode=self.mode, align_corners=False, scale_factor=scale_factor
)
x = torch.nn.functional.interpolate(
x, mode=self.mode, align_corners=False, scale_factor=scale_factor
)
return x

View File

@ -4,6 +4,7 @@ from abc import abstractmethod
import numpy as np
import torch as th
import torch.nn.functional as F
from omegaconf.listconfig import ListConfig
from torch import nn
from imaginairy.modules.attention import SpatialTransformer
@ -488,7 +489,6 @@ class UNetModel(nn.Module):
assert (
use_spatial_transformer
), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
from omegaconf.listconfig import ListConfig
if isinstance(context_dim, ListConfig):
context_dim = list(context_dim)
@ -753,7 +753,7 @@ class UNetModel(nn.Module):
self.middle_block.apply(convert_module_to_f32)
self.output_blocks.apply(convert_module_to_f32)
def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
def forward(self, x, timesteps=None, context=None, y=None, **kwargs): # noqa
"""
Apply the model to an input batch.
:param x: an [N x C x ...] Tensor of inputs.

View File

@ -13,8 +13,8 @@ import math
import numpy as np
import torch
import torch.nn as nn
from einops import repeat as e_repeat
from torch import nn
from imaginairy.utils import instantiate_from_config
@ -57,7 +57,7 @@ def make_beta_schedule(
def frange(start, stop, step):
"""range but handles floats"""
"""Range but handles floats"""
x = start
while True:
if x >= stop:
@ -148,11 +148,11 @@ def checkpoint(func, inputs, params, flag):
if flag:
args = tuple(inputs) + tuple(params)
return CheckpointFunction.apply(func, len(inputs), *args)
else:
return func(*inputs)
return func(*inputs)
class CheckpointFunction(torch.autograd.Function):
class CheckpointFunction(torch.autograd.Function): # noqa
@staticmethod
def forward(ctx, run_function, length, *args):
ctx.run_function = run_function
@ -252,7 +252,7 @@ class SiLU(nn.Module):
class GroupNorm32(nn.GroupNorm):
def forward(self, x):
def forward(self, x): # noqa
return super().forward(x.float()).type(x.dtype)

View File

@ -17,16 +17,13 @@ from imaginairy.vendored import k_diffusion as K
def find_noise_for_image(model, pil_img, prompt, steps=50, cond_scale=1.0, half=True):
img_latent = pillow_img_to_model_latent(
model, pil_img, batch_size=1, device="cuda", half=half
)
img_latent = pillow_img_to_model_latent(model, pil_img, batch_size=1, half=half)
return find_noise_for_latent(
model,
img_latent,
prompt,
steps=steps,
cond_scale=cond_scale,
half=half,
)

View File

@ -1,3 +1,4 @@
# pylama:ignore=W0613
import torch
from torch import nn
@ -25,9 +26,9 @@ _k_sampler_type_lookup = {
def get_sampler(sampler_type, model):
from imaginairy.samplers.ddim import DDIMSampler
from imaginairy.samplers.kdiff import KDiffusionSampler
from imaginairy.samplers.plms import PLMSSampler
from imaginairy.samplers.ddim import DDIMSampler # noqa
from imaginairy.samplers.kdiff import KDiffusionSampler # noqa
from imaginairy.samplers.plms import PLMSSampler # noqa
sampler_type = sampler_type.lower()
if sampler_type == "plms":

View File

@ -1,3 +1,4 @@
# pylama:ignore=W0613
import torch
from imaginairy.img_log import log_latent

View File

@ -1,3 +1,4 @@
# pylama:ignore=W0613
"""SAMPLING ONLY."""
import logging