refactor: fix lint issues

pull/108/head
Authored by Bryce 1 year ago; committed by Bryce Drennan
parent 40ab571fc1
commit 58c2897dd1

@ -6,15 +6,15 @@ import torch.nn.functional as F
from einops import rearrange
from torch import einsum, nn
from imaginairy.modules.diffusion.util import checkpoint
from imaginairy.modules.diffusion.util import checkpoint as checkpoint_eval
from imaginairy.utils import get_device
try:
import xformers
import xformers.ops
import xformers # noqa
import xformers.ops # noqa
XFORMERS_IS_AVAILBLE = True
except:
except ImportError:
XFORMERS_IS_AVAILBLE = False
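
The hunk above tightens the optional xformers import in two ways: the bare `except:` (which would also swallow KeyboardInterrupt, SystemExit, and real bugs inside the package) becomes `except ImportError:`, and the imports get `# noqa` so the linter stops flagging them as unused, since they exist only to probe availability. A minimal self-contained sketch of the pattern; the spelling follows the corrected `XFORMERS_IS_AVAILABLE` used later in the diff, and the helper function is illustrative, not project code:

# Optional-dependency probe: import succeeds -> fast path available.
try:
    import xformers  # noqa  (imported only to detect availability)
    import xformers.ops  # noqa
    XFORMERS_IS_AVAILABLE = True
except ImportError:
    # Only a missing package is tolerated; errors raised inside the
    # package still surface instead of being hidden by a bare `except:`.
    XFORMERS_IS_AVAILABLE = False

def attention_backend():
    """Illustrative helper: pick the memory-efficient op when present."""
    if XFORMERS_IS_AVAILABLE:
        return xformers.ops.memory_efficient_attention
    return None  # caller falls back to the plain softmax attention path
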
@ -350,7 +350,7 @@ class BasicTransformerBlock(nn.Module):
self.checkpoint = checkpoint
def forward(self, x, context=None):
return checkpoint(
return checkpoint_eval(
self._forward, (x, context), self.parameters(), self.checkpoint
)
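
The alias `checkpoint as checkpoint_eval` exists because `BasicTransformerBlock` also takes a constructor argument named `checkpoint` and stores it as `self.checkpoint`, so the same word would otherwise name both a function and a boolean flag. A small sketch of the same idea, using torch's own gradient-checkpoint utility instead of the project wrapper; the class and names below are illustrative:

from torch import nn
from torch.utils.checkpoint import checkpoint as checkpoint_eval  # the function

class TinyBlock(nn.Module):
    def __init__(self, dim, checkpoint=True):
        super().__init__()
        self.net = nn.Linear(dim, dim)
        self.checkpoint = checkpoint  # boolean flag with the same name

    def forward(self, x):
        if self.checkpoint:
            # the alias makes it unambiguous which "checkpoint" is being called
            return checkpoint_eval(self.net, x)
        return self.net(x)
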
@ -428,7 +428,7 @@ class SpatialTransformer(nn.Module):
# note: if no context is given, cross-attention defaults to self-attention
if not isinstance(context, list):
context = [context]
b, c, h, w = x.shape
b, c, h, w = x.shape # noqa
x_in = x
x = self.norm(x)
if not self.use_linear:

@ -1,3 +1,4 @@
# pylama:ignore=W0613
import logging
from contextlib import contextmanager
@ -19,7 +20,7 @@ class AutoencoderKL(pl.LightningModule):
lossconfig,
embed_dim,
ckpt_path=None,
ignore_keys=[],
ignore_keys=None,
image_key="image",
colorize_nlabels=None,
monitor=None,
@ -37,7 +38,7 @@ class AutoencoderKL(pl.LightningModule):
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
self.embed_dim = embed_dim
if colorize_nlabels is not None:
assert type(colorize_nlabels) == int
assert isinstance(colorize_nlabels, int)
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
if monitor is not None:
self.monitor = monitor
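
`isinstance(colorize_nlabels, int)` replaces `type(colorize_nlabels) == int`: it is the idiomatic check that linters ask for, and unlike the exact type comparison it also accepts subclasses. A tiny sketch of the difference, with an illustrative `int` subclass:

class LabelCount(int):
    """Illustrative int subclass."""

n = LabelCount(182)

assert isinstance(n, int)        # True: subclasses count as ints
assert not (type(n) == int)      # the old-style check rejects the subclass

colorize_nlabels = n
if colorize_nlabels is not None:
    assert isinstance(colorize_nlabels, int)  # the guarded check from the hunk
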
@ -52,13 +53,14 @@ class AutoencoderKL(pl.LightningModule):
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
def init_from_ckpt(self, path, ignore_keys=list()):
def init_from_ckpt(self, path, ignore_keys=None):
sd = torch.load(path, map_location="cpu")["state_dict"]
keys = list(sd.keys())
ignore_keys = [] if ignore_keys is None else ignore_keys
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
print(f"Deleting key {k} from state_dict.")
del sd[k]
self.load_state_dict(sd, strict=False)
print(f"Restored from {path}")
@ -93,7 +95,7 @@ class AutoencoderKL(pl.LightningModule):
dec = self.decoder(z)
return dec
def forward(self, input, sample_posterior=True):
def forward(self, input, sample_posterior=True): # noqa
posterior = self.encode(input)
if sample_posterior:
z = posterior.sample()
@ -161,6 +163,7 @@ class AutoencoderKL(pl.LightningModule):
log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False
)
return discloss
return None
def validation_step(self, batch, batch_idx):
log_dict = self._validation_step(batch, batch_idx)
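
The added `return None` makes the two-optimizer training step return explicitly on every path (pylint's inconsistent-return-statements check, R1710): the generator branch returns a loss, so the discriminator branch should not silently fall off the end of the function. The shape of the fix, with illustrative placeholder losses:

def training_step_sketch(optimizer_idx, aeloss=0.1, discloss=0.2):
    """Two-optimizer step: return the loss for whichever optimizer is active."""
    if optimizer_idx == 0:
        return aeloss    # autoencoder / generator update
    if optimizer_idx == 1:
        return discloss  # discriminator update
    return None          # explicit: no loss for an unknown optimizer index

assert training_step_sketch(1) == 0.2
assert training_step_sketch(5) is None
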
@ -218,7 +221,7 @@ class AutoencoderKL(pl.LightningModule):
@torch.no_grad()
def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
log = dict()
log = {}
x = self.get_input(batch, self.image_key)
x = x.to(self.device)
if not only_inputs:
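
`dict()` becomes `{}` here, and `list()` becomes `[]` in the later hunks, matching pylint's use-dict-literal / use-list-literal hints: the literal says the same thing without a name lookup and a call. In isolation:

log = {}             # rather than log = dict()
diffusion_row = []   # rather than diffusion_row = list()

# Behaviour is identical; only the spelling changes.
assert log == dict() and diffusion_row == list()
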

@ -52,7 +52,7 @@ class DDPM(pl.LightningModule):
beta_schedule="linear",
loss_type="l2",
ckpt_path=None,
ignore_keys=[],
ignore_keys=tuple(),
load_only_unet=False,
monitor="val/loss",
use_ema=True,
@ -123,7 +123,7 @@ class DDPM(pl.LightningModule):
if reset_ema:
assert self.use_ema
print(
f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint."
"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint."
)
self.model_ema = LitEma(self.model)
if reset_num_ema_updates:
@ -149,7 +149,7 @@ class DDPM(pl.LightningModule):
if self.learn_logvar:
self.logvar = nn.Parameter(self.logvar, requires_grad=True)
self.ucg_training = ucg_training or dict()
self.ucg_training = ucg_training or {}
if self.ucg_training:
self.ucg_prng = np.random.RandomState()
@ -272,7 +272,7 @@ class DDPM(pl.LightningModule):
print(f"{context}: Restored training weights")
@torch.no_grad()
def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
def init_from_ckpt(self, path, ignore_keys=tuple(), only_model=False):
sd = torch.load(path, map_location="cpu")
if "state_dict" in list(sd.keys()):
sd = sd["state_dict"]
@ -280,7 +280,7 @@ class DDPM(pl.LightningModule):
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
print(f"Deleting key {k} from state_dict.")
del sd[k]
if self.make_it_fit:
n_params = len(
@ -296,7 +296,7 @@ class DDPM(pl.LightningModule):
desc="Fitting old weights to new weights",
total=n_params,
):
if not name in sd:
if name not in sd:
continue
old_shape = sd[name].shape
new_shape = param.shape
@ -592,7 +592,7 @@ class DDPM(pl.LightningModule):
@torch.no_grad()
def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
log = dict()
log = {}
x = self.get_input(batch, self.first_stage_key)
N = min(x.shape[0], N)
n_row = min(x.shape[0], n_row)
@ -600,7 +600,7 @@ class DDPM(pl.LightningModule):
log["inputs"] = x
# get diffusion row
diffusion_row = list()
diffusion_row = []
x_start = x[:n_row]
for t in range(self.num_timesteps):
@ -626,8 +626,7 @@ class DDPM(pl.LightningModule):
if return_keys:
if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
return log
else:
return {key: log[key] for key in return_keys}
return {key: log[key] for key in return_keys}
return log
def configure_optimizers(self):
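
Dropping the `else:` after a `return` is pylint's no-else-return refactor (R1705); the remaining branches read as a flat sequence of early exits. The `return_keys` filter from this hunk, reduced to a standalone function:

import numpy as np

def filter_log_sketch(log, return_keys=None):
    if return_keys:
        if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
            return log                                 # nothing requested is present
        return {key: log[key] for key in return_keys}  # no `else` needed after return
    return log

log = {"inputs": 1, "samples": 2}
assert filter_log_sketch(log, ["samples"]) == {"samples": 2}
assert filter_log_sketch(log, ["missing"]) == log
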
@ -1239,7 +1238,7 @@ class LatentDiffusion(DDPM):
return model_mean, posterior_variance, posterior_log_variance
@torch.no_grad()
def p_sample(
def p_sample( # noqa
self,
x,
c,

@ -4,26 +4,28 @@ from typing import Any, Optional
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from torch import nn
from imaginairy.modules.attention import MemoryEfficientCrossAttention
try:
import xformers
import xformers.ops
import xformers # noqa
import xformers.ops # noqa
XFORMERS_IS_AVAILBLE = True
except:
XFORMERS_IS_AVAILBLE = False
XFORMERS_IS_AVAILABLE = True
except ImportError:
XFORMERS_IS_AVAILABLE = False
# print("No module 'xformers'. Proceeding without it.")
def get_timestep_embedding(timesteps, embedding_dim):
"""
Build sinusoidal embeddings.
This matches the implementation in Denoising Diffusion Probabilistic Models:
From Fairseq.
Build sinusoidal embeddings.
This matches the implementation in tensor2tensor, but differs slightly
from the description in Section 3.5 of "Attention Is All You Need".
"""
@ -271,22 +273,22 @@ def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
"linear",
"none",
], f"attn_type {attn_type} unknown"
if XFORMERS_IS_AVAILBLE and attn_type == "vanilla":
if XFORMERS_IS_AVAILABLE and attn_type == "vanilla":
attn_type = "vanilla-xformers"
# print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
if attn_type == "vanilla":
assert attn_kwargs is None
return AttnBlock(in_channels)
elif attn_type == "vanilla-xformers":
if attn_type == "vanilla-xformers":
# print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
return MemoryEfficientAttnBlock(in_channels)
elif type == "memory-efficient-cross-attn":
if type == "memory-efficient-cross-attn":
attn_kwargs["query_dim"] = in_channels
return MemoryEfficientCrossAttentionWrapper(**attn_kwargs)
elif attn_type == "none":
if attn_type == "none":
return nn.Identity(in_channels)
else:
raise NotImplementedError()
raise NotImplementedError()
class Model(nn.Module):
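
`make_attn` gets the same treatment at function scale: every branch returns, so the `elif`/`else` ladder becomes independent `if` checks with a trailing `raise` for anything unhandled. (The comparison against the bare builtin `type` rather than `attn_type` in the memory-efficient-cross-attn branch is carried over unchanged from the original code; the sketch below only mirrors the control-flow shape with its own illustrative factory.)

from torch import nn

def make_norm(norm_type, in_channels):
    """Illustrative factory in the post-cleanup shape of make_attn."""
    if norm_type == "group":
        return nn.GroupNorm(num_groups=32, num_channels=in_channels)
    if norm_type == "batch":
        return nn.BatchNorm2d(in_channels)
    if norm_type == "none":
        return nn.Identity()
    raise NotImplementedError(f"norm_type {norm_type!r} unknown")

print(make_norm("group", 64))
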
@ -599,7 +601,7 @@ class Decoder(nn.Module):
tanh_out=False,
use_linear_attn=False,
attn_type="vanilla",
**ignorekwargs,
**ignore_kwargs,
):
super().__init__()
if use_linear_attn:
@ -677,6 +679,8 @@ class Decoder(nn.Module):
block_in, out_ch, kernel_size=3, stride=1, padding=1
)
self.last_z_shape = None
def forward(self, z):
# assert z.shape[1:] == self.z_shape[1:]
self.last_z_shape = z.shape
@ -1003,17 +1007,16 @@ class Resize(nn.Module):
# f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode"
# )
raise NotImplementedError()
assert in_channels is not None
# no asymmetric padding in torch conv, must do it ourselves
self.conv = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=4, stride=2, padding=1
)
# assert in_channels is not None
# # no asymmetric padding in torch conv, must do it ourselves
# self.conv = torch.nn.Conv2d(
# in_channels, in_channels, kernel_size=4, stride=2, padding=1
# )
def forward(self, x, scale_factor=1.0):
if scale_factor == 1.0:
return x
else:
x = torch.nn.functional.interpolate(
x, mode=self.mode, align_corners=False, scale_factor=scale_factor
)
x = torch.nn.functional.interpolate(
x, mode=self.mode, align_corners=False, scale_factor=scale_factor
)
return x

@ -4,6 +4,7 @@ from abc import abstractmethod
import numpy as np
import torch as th
import torch.nn.functional as F
from omegaconf.listconfig import ListConfig
from torch import nn
from imaginairy.modules.attention import SpatialTransformer
@ -493,9 +494,8 @@ class UNetModel(nn.Module):
assert (
use_spatial_transformer
), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
from omegaconf.listconfig import ListConfig
if type(context_dim) == ListConfig:
if isinstance(context_dim, ListConfig):
context_dim = list(context_dim)
if num_heads_upsample == -1:
@ -842,5 +842,4 @@ class UNetModel(nn.Module):
h = h.type(x.dtype)
if self.predict_codebook_ids:
return self.id_predictor(h)
else:
return self.out(h)
return self.out(h)

@ -53,7 +53,7 @@ class LitEma(nn.Module):
one_minus_decay * (shadow_params[sname] - m_param[key])
)
else:
assert not key in self.m_name2s_name
assert key not in self.m_name2s_name
def copy_to(self, model):
m_param = dict(model.named_parameters())
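
`assert key not in self.m_name2s_name` replaces `assert not key in ...`: both evaluate the same way, but `not in` is the single membership operator pycodestyle asks for (E713) and reads as intended. In isolation:

m_name2s_name = {"weight": "shadow_weight"}
key = "bias"

assert key not in m_name2s_name                                  # preferred form
assert (key not in m_name2s_name) == (not key in m_name2s_name)  # same truth value
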
@ -62,11 +62,12 @@ class LitEma(nn.Module):
if m_param[key].requires_grad:
m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
else:
assert not key in self.m_name2s_name
assert key not in self.m_name2s_name
def store(self, parameters):
"""
Save the current parameters for restoring later.
Args:
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
temporarily stored.
@ -76,10 +77,12 @@ class LitEma(nn.Module):
def restore(self, parameters):
"""
Restore the parameters stored with the `store` method.
Useful to validate the model with EMA parameters without affecting the
original optimization process. Store the parameters before the
`copy_to` method. After validation (or model saving), use this to
restore the former parameters.
Args:
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
updated with the stored parameters.

@ -1,6 +1,6 @@
import open_clip
import torch
import torch.nn as nn
from torch import nn
from torch.utils.checkpoint import checkpoint
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
@ -10,18 +10,10 @@ from imaginairy.utils import get_device
class AbstractEncoder(nn.Module):
def __init__(self):
super().__init__()
def encode(self, *args, **kwargs):
raise NotImplementedError
class IdentityEncoder(AbstractEncoder):
def encode(self, x):
return x
class ClassEmbedder(nn.Module):
def __init__(self, embed_dim, n_classes=1000, key="class", ucg_rate=0.1):
super().__init__()
@ -51,9 +43,13 @@ class ClassEmbedder(nn.Module):
return uc
def disabled_train(self, mode=True):
"""Overwrite model.train with this function to make sure train/eval mode
does not change anymore."""
def disabled_train(self, mode=True): # noqa
"""
For disabling train/eval mode.
Overwrite `model.train` with this function to make sure train/eval mode
does not change anymore.
"""
return self

@ -16,7 +16,7 @@ aiosignal==1.3.1
# via aiohttp
antlr4-python3-runtime==4.8
# via omegaconf
astroid==2.12.12
astroid==2.12.13
# via pylint
async-timeout==4.0.2
# via aiohttp
@ -52,13 +52,13 @@ coverage==6.5.0
# via -r requirements-dev.in
cycler==0.11.0
# via matplotlib
diffusers==0.7.2
diffusers==0.8.1
# via imaginAIry (setup.py)
dill==0.3.6
# via pylint
einops==0.3.0
# via imaginAIry (setup.py)
exceptiongroup==1.0.1
exceptiongroup==1.0.4
# via pytest
facexlib==0.2.5
# via
@ -82,7 +82,9 @@ frozenlist==1.3.3
fsspec[http]==2022.11.0
# via pytorch-lightning
ftfy==6.1.1
# via imaginAIry (setup.py)
# via
# imaginAIry (setup.py)
# open-clip-torch
future==0.18.2
# via
# basicsr
@ -104,9 +106,10 @@ grpcio==1.50.0
# via
# tb-nightly
# tensorboard
huggingface-hub==0.10.1
huggingface-hub==0.11.0
# via
# diffusers
# open-clip-torch
# timm
# transformers
idna==3.4
@ -117,7 +120,7 @@ imageio==2.9.0
# via
# imaginAIry (setup.py)
# scikit-image
importlib-metadata==5.0.0
importlib-metadata==5.1.0
# via diffusers
iniconfig==1.1.1
# via pytest
@ -163,7 +166,7 @@ networkx==2.8.8
# via scikit-image
numba==0.56.4
# via facexlib
numpy==1.23.4
numpy==1.23.5
# via
# basicsr
# contourpy
@ -192,6 +195,8 @@ oauthlib==3.2.2
# via requests-oauthlib
omegaconf==2.1.1
# via imaginAIry (setup.py)
open-clip-torch==2.7.0
# via imaginAIry (setup.py)
opencv-python==4.6.0.66
# via
# basicsr
@ -245,7 +250,7 @@ pyasn1-modules==0.2.8
# via google-auth
pycln==2.1.2
# via -r requirements-dev.in
pycodestyle==2.9.1
pycodestyle==2.10.0
# via pylama
pydeprecate==0.3.1
# via pytorch-lightning
@ -253,11 +258,11 @@ pydocstyle==6.1.1
# via
# -r requirements-dev.in
# pylama
pyflakes==2.5.0
pyflakes==3.0.1
# via pylama
pylama==8.4.1
# via -r requirements-dev.in
pylint==2.15.5
pylint==2.15.6
# via -r requirements-dev.in
pyparsing==3.0.9
# via
@ -294,6 +299,7 @@ realesrgan==0.3.0
regex==2022.10.31
# via
# diffusers
# open-clip-torch
# transformers
requests==2.28.1
# via
@ -330,7 +336,7 @@ six==1.16.0
# python-dateutil
snowballstemmer==2.2.0
# via pydocstyle
tb-nightly==2.12.0a20221113
tb-nightly==2.12.0a20221125
# via
# basicsr
# gfpgan
@ -344,11 +350,11 @@ tensorboard-plugin-wit==1.8.1
# via
# tb-nightly
# tensorboard
termcolor==2.1.0
termcolor==2.1.1
# via pytest-sugar
tifffile==2022.10.10
# via scikit-image
timm==0.6.11
timm==0.6.12
# via imaginAIry (setup.py)
tokenizers==0.12.1
# via transformers
@ -371,6 +377,7 @@ torch==1.13.0
# gfpgan
# imaginAIry (setup.py)
# kornia
# open-clip-torch
# pytorch-lightning
# realesrgan
# timm
@ -389,6 +396,7 @@ torchvision==0.14.0
# facexlib
# gfpgan
# imaginAIry (setup.py)
# open-clip-torch
# realesrgan
# timm
tqdm==4.64.1
@ -398,6 +406,7 @@ tqdm==4.64.1
# gfpgan
# huggingface-hub
# imaginAIry (setup.py)
# open-clip-torch
# pytorch-lightning
# realesrgan
# transformers
@ -417,7 +426,7 @@ typing-extensions==4.4.0
# typing-inspect
typing-inspect==0.8.0
# via libcst
urllib3==1.26.12
urllib3==1.26.13
# via
# requests
# responses

@ -11,11 +11,11 @@ format = pylint
skip = */.tox/*,*/.env/*,build/*,*/downloads/*,other/*,prolly_delete/*,downloads/*,imaginairy/vendored/*,testing_support/vastai_cli_official.py
linters = pylint,pycodestyle,pydocstyle,pyflakes,mypy
ignore =
Z999,C0103,C0301,C0114,C0115,C0116,
Z999,D100,D101,D102,D103,D105,D106,D107,D200,D202,D203,D205,D212,D400,D401,D406,D407,D415,
Z999,C0103,C0301,C0302,C0114,C0115,C0116,
Z999,D100,D101,D102,D103,D105,D106,D107,D200,D202,D203,D205,D212,D400,D401,D406,D407,D413,D415,
Z999,E501,E1101,
Z999,R0901,R0902,R0903,R0193,R0912,R0913,R0914,R0915,
Z999,W0221,W0511,W1203
Z999,R0901,R0902,R0903,R0904,R0193,R0912,R0913,R0914,R0915,R1702,
Z999,W0221,W0511,W0612,W0613,W1203
[pylama:tests/*]
ignore = C0104,C0114,C0116,D103,W0143,W0613
