autoformat

This commit is contained in:
Bryce 2022-09-23 14:41:15 -07:00
parent df28bf8805
commit a3a0de08e9
3 changed files with 4 additions and 10 deletions

View File

@ -11,8 +11,8 @@ from functools import partial
import numpy as np import numpy as np
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from torch import nn
from einops import rearrange from einops import rearrange
from torch import nn
from torchvision.utils import make_grid from torchvision.utils import make_grid
from tqdm import tqdm from tqdm import tqdm
@ -348,16 +348,10 @@ class LatentDiffusion(DDPM):
model = instantiate_from_config(config) model = instantiate_from_config(config)
self.cond_stage_model = model self.cond_stage_model = model
def _get_denoise_row_from_list( def _get_denoise_row_from_list(self, samples, desc=""):
self, samples, desc=""
):
denoise_row = [] denoise_row = []
for zd in tqdm(samples, desc=desc): for zd in tqdm(samples, desc=desc):
denoise_row.append( denoise_row.append(self.decode_first_stage(zd.to(self.device)))
self.decode_first_stage(
zd.to(self.device)
)
)
n_imgs_per_row = len(denoise_row) n_imgs_per_row = len(denoise_row)
denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W
denoise_grid = rearrange(denoise_row, "n b c h w -> b n c h w") denoise_grid = rearrange(denoise_row, "n b c h w -> b n c h w")

View File

@ -15,6 +15,7 @@ from torch import autocast
from imaginairy.utils import get_device, pillow_img_to_torch_image from imaginairy.utils import get_device, pillow_img_to_torch_image
from imaginairy.vendored import k_diffusion as K from imaginairy.vendored import k_diffusion as K
def pil_img_to_latent(model, img, batch_size=1, half=True): def pil_img_to_latent(model, img, batch_size=1, half=True):
# init_image = pil_img_to_torch(img, half=half).to(device) # init_image = pil_img_to_torch(img, half=half).to(device)
init_image = pillow_img_to_torch_image(img).to(get_device()) init_image = pillow_img_to_torch_image(img).to(get_device())

View File

@ -3,7 +3,6 @@ from torch import nn
from imaginairy.utils import get_device from imaginairy.utils import get_device
SAMPLER_TYPE_OPTIONS = [ SAMPLER_TYPE_OPTIONS = [
"plms", "plms",
"ddim", "ddim",