From 58c2897dd19642c5de8c554cd989b021e0073d50 Mon Sep 17 00:00:00 2001 From: Bryce Date: Fri, 25 Nov 2022 13:46:22 -0800 Subject: [PATCH] refactor: fix lint issues --- imaginairy/modules/attention.py | 12 ++--- imaginairy/modules/autoencoder.py | 15 ++++--- imaginairy/modules/diffusion/ddpm.py | 21 +++++---- imaginairy/modules/diffusion/model.py | 49 +++++++++++---------- imaginairy/modules/diffusion/openaimodel.py | 7 ++- imaginairy/modules/ema.py | 7 ++- imaginairy/modules/encoders.py | 20 ++++----- requirements-dev.txt | 37 ++++++++++------ tox.ini | 8 ++-- 9 files changed, 94 insertions(+), 82 deletions(-) diff --git a/imaginairy/modules/attention.py b/imaginairy/modules/attention.py index d67ff72..e065299 100644 --- a/imaginairy/modules/attention.py +++ b/imaginairy/modules/attention.py @@ -6,15 +6,15 @@ import torch.nn.functional as F from einops import rearrange from torch import einsum, nn -from imaginairy.modules.diffusion.util import checkpoint +from imaginairy.modules.diffusion.util import checkpoint as checkpoint_eval from imaginairy.utils import get_device try: - import xformers - import xformers.ops + import xformers # noqa + import xformers.ops # noqa XFORMERS_IS_AVAILBLE = True -except: +except ImportError: XFORMERS_IS_AVAILBLE = False @@ -350,7 +350,7 @@ class BasicTransformerBlock(nn.Module): self.checkpoint = checkpoint def forward(self, x, context=None): - return checkpoint( + return checkpoint_eval( self._forward, (x, context), self.parameters(), self.checkpoint ) @@ -428,7 +428,7 @@ class SpatialTransformer(nn.Module): # note: if no context is given, cross-attention defaults to self-attention if not isinstance(context, list): context = [context] - b, c, h, w = x.shape + b, c, h, w = x.shape # noqa x_in = x x = self.norm(x) if not self.use_linear: diff --git a/imaginairy/modules/autoencoder.py b/imaginairy/modules/autoencoder.py index 96d6d9a..506952c 100644 --- a/imaginairy/modules/autoencoder.py +++ b/imaginairy/modules/autoencoder.py @@ -1,3 +1,4 @@ +# pylama:ignore=W0613 import logging from contextlib import contextmanager @@ -19,7 +20,7 @@ class AutoencoderKL(pl.LightningModule): lossconfig, embed_dim, ckpt_path=None, - ignore_keys=[], + ignore_keys=None, image_key="image", colorize_nlabels=None, monitor=None, @@ -37,7 +38,7 @@ class AutoencoderKL(pl.LightningModule): self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) self.embed_dim = embed_dim if colorize_nlabels is not None: - assert type(colorize_nlabels) == int + assert isinstance(colorize_nlabels, int) self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) if monitor is not None: self.monitor = monitor @@ -52,13 +53,14 @@ class AutoencoderKL(pl.LightningModule): if ckpt_path is not None: self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) - def init_from_ckpt(self, path, ignore_keys=list()): + def init_from_ckpt(self, path, ignore_keys=None): sd = torch.load(path, map_location="cpu")["state_dict"] keys = list(sd.keys()) + ignore_keys = [] if ignore_keys is None else ignore_keys for k in keys: for ik in ignore_keys: if k.startswith(ik): - print("Deleting key {} from state_dict.".format(k)) + print(f"Deleting key {k} from state_dict.") del sd[k] self.load_state_dict(sd, strict=False) print(f"Restored from {path}") @@ -93,7 +95,7 @@ class AutoencoderKL(pl.LightningModule): dec = self.decoder(z) return dec - def forward(self, input, sample_posterior=True): + def forward(self, input, sample_posterior=True): # noqa posterior = self.encode(input) if sample_posterior: z = posterior.sample() @@ -161,6 +163,7 @@ class AutoencoderKL(pl.LightningModule): log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False ) return discloss + return None def validation_step(self, batch, batch_idx): log_dict = self._validation_step(batch, batch_idx) @@ -218,7 +221,7 @@ class AutoencoderKL(pl.LightningModule): @torch.no_grad() def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs): - log = dict() + log = {} x = self.get_input(batch, self.image_key) x = x.to(self.device) if not only_inputs: diff --git a/imaginairy/modules/diffusion/ddpm.py b/imaginairy/modules/diffusion/ddpm.py index 94d4afe..bd48312 100644 --- a/imaginairy/modules/diffusion/ddpm.py +++ b/imaginairy/modules/diffusion/ddpm.py @@ -52,7 +52,7 @@ class DDPM(pl.LightningModule): beta_schedule="linear", loss_type="l2", ckpt_path=None, - ignore_keys=[], + ignore_keys=tuple(), load_only_unet=False, monitor="val/loss", use_ema=True, @@ -123,7 +123,7 @@ class DDPM(pl.LightningModule): if reset_ema: assert self.use_ema print( - f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint." + "Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint." ) self.model_ema = LitEma(self.model) if reset_num_ema_updates: @@ -149,7 +149,7 @@ class DDPM(pl.LightningModule): if self.learn_logvar: self.logvar = nn.Parameter(self.logvar, requires_grad=True) - self.ucg_training = ucg_training or dict() + self.ucg_training = ucg_training or {} if self.ucg_training: self.ucg_prng = np.random.RandomState() @@ -272,7 +272,7 @@ class DDPM(pl.LightningModule): print(f"{context}: Restored training weights") @torch.no_grad() - def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): + def init_from_ckpt(self, path, ignore_keys=tuple(), only_model=False): sd = torch.load(path, map_location="cpu") if "state_dict" in list(sd.keys()): sd = sd["state_dict"] @@ -280,7 +280,7 @@ class DDPM(pl.LightningModule): for k in keys: for ik in ignore_keys: if k.startswith(ik): - print("Deleting key {} from state_dict.".format(k)) + print(f"Deleting key {k} from state_dict.") del sd[k] if self.make_it_fit: n_params = len( @@ -296,7 +296,7 @@ class DDPM(pl.LightningModule): desc="Fitting old weights to new weights", total=n_params, ): - if not name in sd: + if name not in sd: continue old_shape = sd[name].shape new_shape = param.shape @@ -592,7 +592,7 @@ class DDPM(pl.LightningModule): @torch.no_grad() def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs): - log = dict() + log = {} x = self.get_input(batch, self.first_stage_key) N = min(x.shape[0], N) n_row = min(x.shape[0], n_row) @@ -600,7 +600,7 @@ class DDPM(pl.LightningModule): log["inputs"] = x # get diffusion row - diffusion_row = list() + diffusion_row = [] x_start = x[:n_row] for t in range(self.num_timesteps): @@ -626,8 +626,7 @@ class DDPM(pl.LightningModule): if return_keys: if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: return log - else: - return {key: log[key] for key in return_keys} + return {key: log[key] for key in return_keys} return log def configure_optimizers(self): @@ -1239,7 +1238,7 @@ class LatentDiffusion(DDPM): return model_mean, posterior_variance, posterior_log_variance @torch.no_grad() - def p_sample( + def p_sample( # noqa self, x, c, diff --git a/imaginairy/modules/diffusion/model.py b/imaginairy/modules/diffusion/model.py index 76a0a9a..fa50995 100644 --- a/imaginairy/modules/diffusion/model.py +++ b/imaginairy/modules/diffusion/model.py @@ -4,26 +4,28 @@ from typing import Any, Optional import numpy as np import torch -import torch.nn as nn from einops import rearrange +from torch import nn from imaginairy.modules.attention import MemoryEfficientCrossAttention try: - import xformers - import xformers.ops + import xformers # noqa + import xformers.ops # noqa - XFORMERS_IS_AVAILBLE = True -except: - XFORMERS_IS_AVAILBLE = False + XFORMERS_IS_AVAILABLE = True +except ImportError: + XFORMERS_IS_AVAILABLE = False # print("No module 'xformers'. Proceeding without it.") def get_timestep_embedding(timesteps, embedding_dim): """ + Build sinusoidal embeddings. + This matches the implementation in Denoising Diffusion Probabilistic Models: From Fairseq. - Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of "Attention Is All You Need". """ @@ -271,22 +273,22 @@ def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None): "linear", "none", ], f"attn_type {attn_type} unknown" - if XFORMERS_IS_AVAILBLE and attn_type == "vanilla": + if XFORMERS_IS_AVAILABLE and attn_type == "vanilla": attn_type = "vanilla-xformers" # print(f"making attention of type '{attn_type}' with {in_channels} in_channels") if attn_type == "vanilla": assert attn_kwargs is None return AttnBlock(in_channels) - elif attn_type == "vanilla-xformers": + if attn_type == "vanilla-xformers": # print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...") return MemoryEfficientAttnBlock(in_channels) - elif type == "memory-efficient-cross-attn": + if type == "memory-efficient-cross-attn": attn_kwargs["query_dim"] = in_channels return MemoryEfficientCrossAttentionWrapper(**attn_kwargs) - elif attn_type == "none": + if attn_type == "none": return nn.Identity(in_channels) - else: - raise NotImplementedError() + + raise NotImplementedError() class Model(nn.Module): @@ -599,7 +601,7 @@ class Decoder(nn.Module): tanh_out=False, use_linear_attn=False, attn_type="vanilla", - **ignorekwargs, + **ignore_kwargs, ): super().__init__() if use_linear_attn: @@ -677,6 +679,8 @@ class Decoder(nn.Module): block_in, out_ch, kernel_size=3, stride=1, padding=1 ) + self.last_z_shape = None + def forward(self, z): # assert z.shape[1:] == self.z_shape[1:] self.last_z_shape = z.shape @@ -1003,17 +1007,16 @@ class Resize(nn.Module): # f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode" # ) raise NotImplementedError() - assert in_channels is not None - # no asymmetric padding in torch conv, must do it ourselves - self.conv = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=4, stride=2, padding=1 - ) + # assert in_channels is not None + # # no asymmetric padding in torch conv, must do it ourselves + # self.conv = torch.nn.Conv2d( + # in_channels, in_channels, kernel_size=4, stride=2, padding=1 + # ) def forward(self, x, scale_factor=1.0): if scale_factor == 1.0: return x - else: - x = torch.nn.functional.interpolate( - x, mode=self.mode, align_corners=False, scale_factor=scale_factor - ) + x = torch.nn.functional.interpolate( + x, mode=self.mode, align_corners=False, scale_factor=scale_factor + ) return x diff --git a/imaginairy/modules/diffusion/openaimodel.py b/imaginairy/modules/diffusion/openaimodel.py index 84f6bb2..67bf588 100644 --- a/imaginairy/modules/diffusion/openaimodel.py +++ b/imaginairy/modules/diffusion/openaimodel.py @@ -4,6 +4,7 @@ from abc import abstractmethod import numpy as np import torch as th import torch.nn.functional as F +from omegaconf.listconfig import ListConfig from torch import nn from imaginairy.modules.attention import SpatialTransformer @@ -493,9 +494,8 @@ class UNetModel(nn.Module): assert ( use_spatial_transformer ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..." - from omegaconf.listconfig import ListConfig - if type(context_dim) == ListConfig: + if isinstance(context_dim, ListConfig): context_dim = list(context_dim) if num_heads_upsample == -1: @@ -842,5 +842,4 @@ class UNetModel(nn.Module): h = h.type(x.dtype) if self.predict_codebook_ids: return self.id_predictor(h) - else: - return self.out(h) + return self.out(h) diff --git a/imaginairy/modules/ema.py b/imaginairy/modules/ema.py index 9db7015..68a164e 100644 --- a/imaginairy/modules/ema.py +++ b/imaginairy/modules/ema.py @@ -53,7 +53,7 @@ class LitEma(nn.Module): one_minus_decay * (shadow_params[sname] - m_param[key]) ) else: - assert not key in self.m_name2s_name + assert key not in self.m_name2s_name def copy_to(self, model): m_param = dict(model.named_parameters()) @@ -62,11 +62,12 @@ class LitEma(nn.Module): if m_param[key].requires_grad: m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) else: - assert not key in self.m_name2s_name + assert key not in self.m_name2s_name def store(self, parameters): """ Save the current parameters for restoring later. + Args: parameters: Iterable of `torch.nn.Parameter`; the parameters to be temporarily stored. @@ -76,10 +77,12 @@ class LitEma(nn.Module): def restore(self, parameters): """ Restore the parameters stored with the `store` method. + Useful to validate the model with EMA parameters without affecting the original optimization process. Store the parameters before the `copy_to` method. After validation (or model saving), use this to restore the former parameters. + Args: parameters: Iterable of `torch.nn.Parameter`; the parameters to be updated with the stored parameters. diff --git a/imaginairy/modules/encoders.py b/imaginairy/modules/encoders.py index 2cd2716..8bf59f6 100644 --- a/imaginairy/modules/encoders.py +++ b/imaginairy/modules/encoders.py @@ -1,6 +1,6 @@ import open_clip import torch -import torch.nn as nn +from torch import nn from torch.utils.checkpoint import checkpoint from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer @@ -10,18 +10,10 @@ from imaginairy.utils import get_device class AbstractEncoder(nn.Module): - def __init__(self): - super().__init__() - def encode(self, *args, **kwargs): raise NotImplementedError -class IdentityEncoder(AbstractEncoder): - def encode(self, x): - return x - - class ClassEmbedder(nn.Module): def __init__(self, embed_dim, n_classes=1000, key="class", ucg_rate=0.1): super().__init__() @@ -51,9 +43,13 @@ class ClassEmbedder(nn.Module): return uc -def disabled_train(self, mode=True): - """Overwrite model.train with this function to make sure train/eval mode - does not change anymore.""" +def disabled_train(self, mode=True): # noqa + """ + For disabling train/eval mode. + + Overwrite `model.train` with this function to make sure train/eval mode + does not change anymore. + """ return self diff --git a/requirements-dev.txt b/requirements-dev.txt index 573c11c..60d69e5 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -16,7 +16,7 @@ aiosignal==1.3.1 # via aiohttp antlr4-python3-runtime==4.8 # via omegaconf -astroid==2.12.12 +astroid==2.12.13 # via pylint async-timeout==4.0.2 # via aiohttp @@ -52,13 +52,13 @@ coverage==6.5.0 # via -r requirements-dev.in cycler==0.11.0 # via matplotlib -diffusers==0.7.2 +diffusers==0.8.1 # via imaginAIry (setup.py) dill==0.3.6 # via pylint einops==0.3.0 # via imaginAIry (setup.py) -exceptiongroup==1.0.1 +exceptiongroup==1.0.4 # via pytest facexlib==0.2.5 # via @@ -82,7 +82,9 @@ frozenlist==1.3.3 fsspec[http]==2022.11.0 # via pytorch-lightning ftfy==6.1.1 - # via imaginAIry (setup.py) + # via + # imaginAIry (setup.py) + # open-clip-torch future==0.18.2 # via # basicsr @@ -104,9 +106,10 @@ grpcio==1.50.0 # via # tb-nightly # tensorboard -huggingface-hub==0.10.1 +huggingface-hub==0.11.0 # via # diffusers + # open-clip-torch # timm # transformers idna==3.4 @@ -117,7 +120,7 @@ imageio==2.9.0 # via # imaginAIry (setup.py) # scikit-image -importlib-metadata==5.0.0 +importlib-metadata==5.1.0 # via diffusers iniconfig==1.1.1 # via pytest @@ -163,7 +166,7 @@ networkx==2.8.8 # via scikit-image numba==0.56.4 # via facexlib -numpy==1.23.4 +numpy==1.23.5 # via # basicsr # contourpy @@ -192,6 +195,8 @@ oauthlib==3.2.2 # via requests-oauthlib omegaconf==2.1.1 # via imaginAIry (setup.py) +open-clip-torch==2.7.0 + # via imaginAIry (setup.py) opencv-python==4.6.0.66 # via # basicsr @@ -245,7 +250,7 @@ pyasn1-modules==0.2.8 # via google-auth pycln==2.1.2 # via -r requirements-dev.in -pycodestyle==2.9.1 +pycodestyle==2.10.0 # via pylama pydeprecate==0.3.1 # via pytorch-lightning @@ -253,11 +258,11 @@ pydocstyle==6.1.1 # via # -r requirements-dev.in # pylama -pyflakes==2.5.0 +pyflakes==3.0.1 # via pylama pylama==8.4.1 # via -r requirements-dev.in -pylint==2.15.5 +pylint==2.15.6 # via -r requirements-dev.in pyparsing==3.0.9 # via @@ -294,6 +299,7 @@ realesrgan==0.3.0 regex==2022.10.31 # via # diffusers + # open-clip-torch # transformers requests==2.28.1 # via @@ -330,7 +336,7 @@ six==1.16.0 # python-dateutil snowballstemmer==2.2.0 # via pydocstyle -tb-nightly==2.12.0a20221113 +tb-nightly==2.12.0a20221125 # via # basicsr # gfpgan @@ -344,11 +350,11 @@ tensorboard-plugin-wit==1.8.1 # via # tb-nightly # tensorboard -termcolor==2.1.0 +termcolor==2.1.1 # via pytest-sugar tifffile==2022.10.10 # via scikit-image -timm==0.6.11 +timm==0.6.12 # via imaginAIry (setup.py) tokenizers==0.12.1 # via transformers @@ -371,6 +377,7 @@ torch==1.13.0 # gfpgan # imaginAIry (setup.py) # kornia + # open-clip-torch # pytorch-lightning # realesrgan # timm @@ -389,6 +396,7 @@ torchvision==0.14.0 # facexlib # gfpgan # imaginAIry (setup.py) + # open-clip-torch # realesrgan # timm tqdm==4.64.1 @@ -398,6 +406,7 @@ tqdm==4.64.1 # gfpgan # huggingface-hub # imaginAIry (setup.py) + # open-clip-torch # pytorch-lightning # realesrgan # transformers @@ -417,7 +426,7 @@ typing-extensions==4.4.0 # typing-inspect typing-inspect==0.8.0 # via libcst -urllib3==1.26.12 +urllib3==1.26.13 # via # requests # responses diff --git a/tox.ini b/tox.ini index f788fb0..f95dba5 100644 --- a/tox.ini +++ b/tox.ini @@ -11,11 +11,11 @@ format = pylint skip = */.tox/*,*/.env/*,build/*,*/downloads/*,other/*,prolly_delete/*,downloads/*,imaginairy/vendored/*,testing_support/vastai_cli_official.py linters = pylint,pycodestyle,pydocstyle,pyflakes,mypy ignore = - Z999,C0103,C0301,C0114,C0115,C0116, - Z999,D100,D101,D102,D103,D105,D106,D107,D200,D202,D203,D205,D212,D400,D401,D406,D407,D415, + Z999,C0103,C0301,C0302,C0114,C0115,C0116, + Z999,D100,D101,D102,D103,D105,D106,D107,D200,D202,D203,D205,D212,D400,D401,D406,D407,D413,D415, Z999,E501,E1101, - Z999,R0901,R0902,R0903,R0193,R0912,R0913,R0914,R0915, - Z999,W0221,W0511,W1203 + Z999,R0901,R0902,R0903,R0904,R0193,R0912,R0913,R0914,R0915,R1702, + Z999,W0221,W0511,W0612,W0613,W1203 [pylama:tests/*] ignore = C0104,C0114,C0116,D103,W0143,W0613