From a3a0de08e97180a277858dee3ddfeb69d5de063e Mon Sep 17 00:00:00 2001 From: Bryce Date: Fri, 23 Sep 2022 14:41:15 -0700 Subject: [PATCH] autoformat --- imaginairy/modules/diffusion/ddpm.py | 12 +++--------- imaginairy/modules/find_noise.py | 1 + imaginairy/samplers/base.py | 1 - 3 files changed, 4 insertions(+), 10 deletions(-) diff --git a/imaginairy/modules/diffusion/ddpm.py b/imaginairy/modules/diffusion/ddpm.py index ab65c18..7736632 100644 --- a/imaginairy/modules/diffusion/ddpm.py +++ b/imaginairy/modules/diffusion/ddpm.py @@ -11,8 +11,8 @@ from functools import partial import numpy as np import pytorch_lightning as pl import torch -from torch import nn from einops import rearrange +from torch import nn from torchvision.utils import make_grid from tqdm import tqdm @@ -348,16 +348,10 @@ class LatentDiffusion(DDPM): model = instantiate_from_config(config) self.cond_stage_model = model - def _get_denoise_row_from_list( - self, samples, desc="" - ): + def _get_denoise_row_from_list(self, samples, desc=""): denoise_row = [] for zd in tqdm(samples, desc=desc): - denoise_row.append( - self.decode_first_stage( - zd.to(self.device) - ) - ) + denoise_row.append(self.decode_first_stage(zd.to(self.device))) n_imgs_per_row = len(denoise_row) denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W denoise_grid = rearrange(denoise_row, "n b c h w -> b n c h w") diff --git a/imaginairy/modules/find_noise.py b/imaginairy/modules/find_noise.py index bb7857a..7edf30a 100644 --- a/imaginairy/modules/find_noise.py +++ b/imaginairy/modules/find_noise.py @@ -15,6 +15,7 @@ from torch import autocast from imaginairy.utils import get_device, pillow_img_to_torch_image from imaginairy.vendored import k_diffusion as K + def pil_img_to_latent(model, img, batch_size=1, half=True): # init_image = pil_img_to_torch(img, half=half).to(device) init_image = pillow_img_to_torch_image(img).to(get_device()) diff --git a/imaginairy/samplers/base.py b/imaginairy/samplers/base.py index 8a69c2f..2a73bae 100644 --- a/imaginairy/samplers/base.py +++ b/imaginairy/samplers/base.py @@ -3,7 +3,6 @@ from torch import nn from imaginairy.utils import get_device - SAMPLER_TYPE_OPTIONS = [ "plms", "ddim",