autoformat

pull/22/head
Bryce 2 years ago
parent df28bf8805
commit a3a0de08e9

@ -11,8 +11,8 @@ from functools import partial
import numpy as np
import pytorch_lightning as pl
import torch
from torch import nn
from einops import rearrange
from torch import nn
from torchvision.utils import make_grid
from tqdm import tqdm
@ -348,16 +348,10 @@ class LatentDiffusion(DDPM):
model = instantiate_from_config(config)
self.cond_stage_model = model
def _get_denoise_row_from_list(
self, samples, desc=""
):
def _get_denoise_row_from_list(self, samples, desc=""):
denoise_row = []
for zd in tqdm(samples, desc=desc):
denoise_row.append(
self.decode_first_stage(
zd.to(self.device)
)
)
denoise_row.append(self.decode_first_stage(zd.to(self.device)))
n_imgs_per_row = len(denoise_row)
denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W
denoise_grid = rearrange(denoise_row, "n b c h w -> b n c h w")

@ -15,6 +15,7 @@ from torch import autocast
from imaginairy.utils import get_device, pillow_img_to_torch_image
from imaginairy.vendored import k_diffusion as K
def pil_img_to_latent(model, img, batch_size=1, half=True):
# init_image = pil_img_to_torch(img, half=half).to(device)
init_image = pillow_img_to_torch_image(img).to(get_device())

@ -3,7 +3,6 @@ from torch import nn
from imaginairy.utils import get_device
SAMPLER_TYPE_OPTIONS = [
"plms",
"ddim",

Loading…
Cancel
Save