From 66c640ce7bf7db90ca33d1dde025e97c688d5692 Mon Sep 17 00:00:00 2001 From: Bryce Date: Fri, 9 Sep 2022 22:14:04 -0700 Subject: [PATCH] feature: add ImageResult. step output option remove verbose args --- imaginairy/cmds.py | 7 +- imaginairy/imagine.py | 199 ++++++++++++--------------- imaginairy/models/diffusion/ddim.py | 45 +++--- imaginairy/models/diffusion/plms.py | 10 +- imaginairy/modules/clip_embedders.py | 2 +- imaginairy/schema.py | 70 ++++++++++ requirements-dev.in | 1 + tests/__init__.py | 3 + tests/test_imagine.py | 60 ++++++++ 9 files changed, 253 insertions(+), 144 deletions(-) create mode 100644 imaginairy/schema.py create mode 100644 requirements-dev.in create mode 100644 tests/test_imagine.py diff --git a/imaginairy/cmds.py b/imaginairy/cmds.py index 007e929..0d2ac69 100644 --- a/imaginairy/cmds.py +++ b/imaginairy/cmds.py @@ -2,7 +2,8 @@ import click -from imaginairy.imagine import ImaginePrompt, imagine as imagine_f +from imaginairy.imagine import imagine as imagine_f +from imaginairy.schema import ImaginePrompt @click.command() @@ -10,7 +11,9 @@ from imaginairy.imagine import ImaginePrompt, imagine as imagine_f "prompt_texts", default=None, help="text to render to an image", nargs=-1 ) @click.option("--outdir", default="./outputs", help="where to write results to") -@click.option("-r", "--repeats", default=1, type=int, help="How many times to repeat the renders") +@click.option( + "-r", "--repeats", default=1, type=int, help="How many times to repeat the renders" +) @click.option( "-h", "--height", diff --git a/imaginairy/imagine.py b/imaginairy/imagine.py index 242e323..157e646 100755 --- a/imaginairy/imagine.py +++ b/imaginairy/imagine.py @@ -1,7 +1,5 @@ -import argparse import logging import os -import random import re import subprocess from contextlib import nullcontext @@ -18,13 +16,14 @@ from torch import autocast from imaginairy.models.diffusion.ddim import DDIMSampler from imaginairy.models.diffusion.plms import PLMSSampler +from imaginairy.schema import ImaginePrompt, ImagineResult from imaginairy.utils import get_device, instantiate_from_config LIB_PATH = os.path.dirname(__file__) logger = logging.getLogger(__name__) -def load_model_from_config(config, ckpt, verbose=False): +def load_model_from_config(config, ckpt): logger.info(f"Loading model from {ckpt}") pl_sd = torch.load(ckpt, map_location="cpu") if "global_step" in pl_sd: @@ -32,9 +31,9 @@ def load_model_from_config(config, ckpt, verbose=False): sd = pl_sd["state_dict"] model = instantiate_from_config(config.model) m, u = model.load_state_dict(sd, strict=False) - if len(m) > 0 and verbose: + if len(m) > 0: logger.info(f"missing keys: {m}") - if len(u) > 0 and verbose: + if len(u) > 0: logger.info(f"unexpected keys: {u}") model.cuda() @@ -42,56 +41,6 @@ def load_model_from_config(config, ckpt, verbose=False): return model -class WeightedPrompt: - def __init__(self, text, weight=1): - self.text = text - self.weight = weight - - def __str__(self): - return f"{self.weight}*({self.text})" - - -class ImaginePrompt: - def __init__( - self, - prompt=None, - seed=None, - prompt_strength=7.5, - sampler_type="PLMS", - init_image=None, - init_image_strength=0.3, - steps=50, - height=512, - width=512, - upscale=True, - fix_faces=True, - parts=None, - ): - prompt = prompt if prompt is not None else "a scenic landscape" - if isinstance(prompt, str): - self.prompts = [WeightedPrompt(prompt, 1)] - else: - self.prompts = prompt - self.init_image = init_image - self.init_image_strength = init_image_strength - self.prompts.sort(key=lambda p: p.weight, reverse=True) - self.seed = random.randint(1, 1_000_000_000) if seed is None else seed - self.prompt_strength = prompt_strength - self.sampler_type = sampler_type - self.steps = steps - self.height = height - self.width = width - self.upscale = upscale - self.fix_faces = fix_faces - self.parts = parts or {} - - @property - def prompt_text(self): - if len(self.prompts) == 1: - return self.prompts[0].text - return "|".join(str(p) for p in self.prompts) - - def load_img(path, max_height=512, max_width=512): image = Image.open(path).convert("RGB") w, h = image.size @@ -108,8 +57,8 @@ def load_img(path, max_height=512, max_width=512): @lru_cache() def load_model(): - config = ("data/stable-diffusion-v1.yaml",) - ckpt = ("data/stable-diffusion-v1-4.ckpt",) + config = "data/stable-diffusion-v1.yaml" + ckpt = "data/stable-diffusion-v1-4.ckpt" config = OmegaConf.load(f"{LIB_PATH}/{config}") model = load_model_from_config(config, f"{LIB_PATH}/{ckpt}") @@ -117,24 +66,69 @@ def load_model(): return model -def imagine( +def imagine_image_files( prompts, - outdir="outputs/txt2img-samples", + outdir, latent_channels=4, downsampling_factor=8, precision="autocast", - skip_save=False, ddim_eta=0.0, + record_steps=False +): + big_path = os.path.join(outdir, "upscaled") + os.makedirs(outdir, exist_ok=True) + os.makedirs(big_path, exist_ok=True) + base_count = len(os.listdir(outdir)) + step_count = 0 + + def _record_steps(samples, i, model, prompt): + nonlocal step_count + step_count += 1 + samples = model.decode_first_stage(samples) + samples = torch.clamp((samples + 1.0) / 2.0, min=0.0, max=1.0) + steps_path = os.path.join(outdir, "steps", f"{base_count:08}_S{prompt.seed}") + os.makedirs(steps_path, exist_ok=True) + for pred_x0 in samples: + pred_x0 = 255.0 * rearrange(pred_x0.cpu().numpy(), "c h w -> h w c") + filename = f"{base_count:08}_S{prompt.seed}_step{step_count:04}.jpg" + Image.fromarray(pred_x0.astype(np.uint8)).save( + os.path.join(steps_path, filename) + ) + img_callback = _record_steps if record_steps else None + for result in imagine_images( + prompts, + latent_channels=latent_channels, + downsampling_factor=downsampling_factor, + precision=precision, + ddim_eta=ddim_eta, + img_callback=img_callback, + ): + prompt = result.prompt + img = result.img + basefilename = f"{base_count:06}_{prompt.seed}_{prompt.sampler_type}{prompt.steps}_PS{prompt.prompt_strength}_{prompt_normalized(prompt.prompt_text)}" + filepath = os.path.join(outdir, f"{basefilename}.jpg") + + img.save(filepath) + if prompt.upscale: + enlarge_realesrgan2x( + filepath, + os.path.join(big_path, basefilename) + ".jpg", + ) + base_count += 1 + + +def imagine_images( + prompts, + latent_channels=4, + downsampling_factor=8, + precision="autocast", + ddim_eta=0.0, + img_callback=None, ): model = load_model() - os.makedirs(outdir, exist_ok=True) - outpath = outdir - - sample_path = os.path.join(outpath) - big_path = os.path.join(sample_path, "esrgan") - os.makedirs(sample_path, exist_ok=True) - os.makedirs(big_path, exist_ok=True) - base_count = len(os.listdir(sample_path)) + prompts = [ImaginePrompt(prompts)] if isinstance(prompts, str) else prompts + prompts = [prompts] if isinstance(prompts, ImaginePrompt) else prompts + _img_callback = None precision_scope = autocast if precision == "autocast" else nullcontext with (torch.no_grad(), precision_scope("cuda")): @@ -150,6 +144,9 @@ def imagine( for wp in prompt.prompts ] ) + if img_callback: + def _img_callback(samples, i): + img_callback(samples, i, model, prompt) shape = [ latent_channels, @@ -157,41 +154,29 @@ def imagine( prompt.width // downsampling_factor, ] - def img_callback(samples, i): - pass - samples = model.decode_first_stage(samples) - samples = torch.clamp((samples + 1.0) / 2.0, min=0.0, max=1.0) - steps_path = os.path.join( - sample_path, "steps", f"{base_count:08}_S{prompt.seed}" - ) - os.makedirs(steps_path, exist_ok=True) - for pred_x0 in samples: - pred_x0 = 255.0 * rearrange(pred_x0.cpu().numpy(), "c h w -> h w c") - filename = f"{base_count:08}_S{prompt.seed}_step{i:04}.jpg" - Image.fromarray(pred_x0.astype(np.uint8)).save( - os.path.join(steps_path, filename) - ) - start_code = None sampler = get_sampler(prompt.sampler_type, model) if prompt.init_image: generation_strength = 1 - prompt.init_image_strength ddim_steps = int(prompt.steps / generation_strength) sampler.make_schedule( - ddim_num_steps=ddim_steps, ddim_eta=ddim_eta, verbose=False + ddim_num_steps=ddim_steps, ddim_eta=ddim_eta ) t_enc = int(generation_strength * ddim_steps) init_image, w, h = load_img(prompt.init_image) init_image = init_image.to(get_device()) - init_latent = model.get_first_stage_encoding( - model.encode_first_stage(init_image) - ) + init_latent = model.encode_first_stage(init_image) + noised_init_latent = model.get_first_stage_encoding(init_latent) + _img_callback(init_latent.mean, 0) + _img_callback(noised_init_latent, 0) # encode (scaled latent) z_enc = sampler.stochastic_encode( - init_latent, torch.tensor([t_enc]).to(get_device()) + noised_init_latent, torch.tensor([t_enc]).to(get_device()), ) + _img_callback(noised_init_latent, 0) + # decode it samples = sampler.decode( z_enc, @@ -199,7 +184,7 @@ def imagine( t_enc, unconditional_guidance_scale=prompt.prompt_strength, unconditional_conditioning=uc, - img_callback=img_callback, + img_callback=_img_callback, ) else: @@ -208,35 +193,27 @@ def imagine( conditioning=c, batch_size=1, shape=shape, - verbose=False, unconditional_guidance_scale=prompt.prompt_strength, unconditional_conditioning=uc, eta=ddim_eta, x_T=start_code, - img_callback=img_callback, + img_callback=_img_callback, ) x_samples = model.decode_first_stage(samples) x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0) - if not skip_save: - for x_sample in x_samples: - x_sample = 255.0 * rearrange( - x_sample.cpu().numpy(), "c h w -> h w c" - ) - basefilename = f"{base_count:06}_{prompt.seed}_{prompt.sampler_type}{prompt.steps}_PS{prompt.prompt_strength}_{prompt_normalized(prompt.prompt_text)}" - filepath = os.path.join(sample_path, f"{basefilename}.jpg") - img = Image.fromarray(x_sample.astype(np.uint8)) - if prompt.fix_faces: - img = fix_faces_GFPGAN(img) - img.save(filepath) - if prompt.upscale: - enlarge_realesrgan2x( - filepath, - os.path.join(big_path, basefilename) + ".jpg", - ) - base_count += 1 - return f"{basefilename}.jpg" + for x_sample in x_samples: + x_sample = 255.0 * rearrange(x_sample.cpu().numpy(), "c h w -> h w c") + img = Image.fromarray(x_sample.astype(np.uint8)) + if prompt.fix_faces: + img = fix_faces_GFPGAN(img) + # if prompt.upscale: + # enlarge_realesrgan2x( + # filepath, + # os.path.join(big_path, basefilename) + ".jpg", + # ) + yield ImagineResult(img=img, prompt=prompt) def prompt_normalized(prompt): @@ -287,7 +264,3 @@ def fix_faces_GFPGAN(image): res = Image.fromarray(restored_img) return res - - -if __name__ == "__main__": - main() diff --git a/imaginairy/models/diffusion/ddim.py b/imaginairy/models/diffusion/ddim.py index 3fe8684..7aa6ec7 100644 --- a/imaginairy/models/diffusion/ddim.py +++ b/imaginairy/models/diffusion/ddim.py @@ -31,13 +31,12 @@ class DDIMSampler: setattr(self, name, attr) def make_schedule( - self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True + self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0 ): self.ddim_timesteps = make_ddim_timesteps( ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, num_ddpm_timesteps=self.ddpm_num_timesteps, - verbose=verbose, ) alphas_cumprod = self.model.alphas_cumprod assert ( @@ -75,7 +74,6 @@ class DDIMSampler: alphacums=alphas_cumprod.cpu(), ddim_timesteps=self.ddim_timesteps, eta=ddim_eta, - verbose=verbose, ) self.register_buffer("ddim_sigmas", ddim_sigmas) self.register_buffer("ddim_alphas", ddim_alphas) @@ -108,7 +106,6 @@ class DDIMSampler: noise_dropout=0.0, score_corrector=None, corrector_kwargs=None, - verbose=True, x_T=None, log_every_t=100, unconditional_guidance_scale=1.0, @@ -129,7 +126,7 @@ class DDIMSampler: f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}" ) - self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + self.make_schedule(ddim_num_steps=S, ddim_eta=eta) # sampling C, H, W = shape size = (batch_size, C, H, W) @@ -230,8 +227,6 @@ class DDIMSampler: quantize_denoised=quantize_denoised, temperature=temperature, noise_dropout=noise_dropout, - score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, unconditional_guidance_scale=unconditional_guidance_scale, unconditional_conditioning=unconditional_conditioning, ) @@ -239,6 +234,7 @@ class DDIMSampler: callback(i) if img_callback: img_callback(pred_x0, i) + img_callback(pred_x0, i) if index % log_every_t == 0 or index == total_steps - 1: intermediates["x_inter"].append(img) @@ -246,7 +242,7 @@ class DDIMSampler: return img, intermediates - @torch.no_grad() + # @torch.no_grad() def p_sample_ddim( self, x, @@ -258,27 +254,22 @@ class DDIMSampler: quantize_denoised=False, temperature=1.0, noise_dropout=0.0, - score_corrector=None, - corrector_kwargs=None, unconditional_guidance_scale=1.0, unconditional_conditioning=None, + loss_function=None ): b, *_, device = *x.shape, x.device if unconditional_conditioning is None or unconditional_guidance_scale == 1.0: - e_t = self.model.apply_model(x, t, c) + with torch.no_grad(): + noise_pred = self.model.apply_model(x, t, c) else: x_in = torch.cat([x] * 2) t_in = torch.cat([t] * 2) c_in = torch.cat([unconditional_conditioning, c]) - e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) - e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) - - if score_corrector is not None: - assert self.model.parameterization == "eps" - e_t = score_corrector.modify_score( - self.model, e_t, x, t, c, **corrector_kwargs - ) + # with torch.no_grad(): + noise_pred_uncond, noise_pred = self.model.apply_model(x_in, t_in, c_in).chunk(2) + noise_pred = noise_pred_uncond + unconditional_guidance_scale * (noise_pred - noise_pred_uncond) alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas alphas_prev = ( @@ -305,11 +296,12 @@ class DDIMSampler: ) # current prediction for x_0 - pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + pred_x0 = (x - sqrt_one_minus_at * noise_pred) / a_t.sqrt() if quantize_denoised: pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) # direction pointing to x_t - dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t + dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * noise_pred + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature if noise_dropout > 0.0: noise = torch.nn.functional.dropout(noise, p=noise_dropout) @@ -344,6 +336,8 @@ class DDIMSampler: unconditional_conditioning=None, use_original_steps=False, img_callback=None, + score_corrector=None, + temperature=1.0 ): timesteps = ( @@ -359,6 +353,7 @@ class DDIMSampler: iterator = tqdm(time_range, desc="Decoding image", total=total_steps) x_dec = x_latent + for i, step in enumerate(iterator): index = total_steps - i - 1 ts = torch.full( @@ -372,7 +367,15 @@ class DDIMSampler: use_original_steps=use_original_steps, unconditional_guidance_scale=unconditional_guidance_scale, unconditional_conditioning=unconditional_conditioning, + temperature=temperature ) + # original_loss = ((x_dec - x_latent).abs().mean()*70) + # sigma_t = torch.full((1, 1, 1, 1), self.ddim_sigmas[index], device=get_device()) + # x_dec = x_dec.detach() + (original_loss * 0.1) ** 2 + # cond_grad = -torch.autograd.grad(original_loss, x_dec)[0] + # x_dec = x_dec.detach() + cond_grad * sigma_t ** 2 + ## x_dec_alt = x_dec + (original_loss * 0.1) ** 2 if img_callback: + img_callback(x_dec, i) img_callback(pred_x0, i) return x_dec diff --git a/imaginairy/models/diffusion/plms.py b/imaginairy/models/diffusion/plms.py index 9283c6c..6506c53 100644 --- a/imaginairy/models/diffusion/plms.py +++ b/imaginairy/models/diffusion/plms.py @@ -29,16 +29,13 @@ class PLMSSampler(object): attr = attr.to(torch.float32).to(torch.device(self.device_available)) setattr(self, name, attr) - def make_schedule( - self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True - ): + def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0): if ddim_eta != 0: raise ValueError("ddim_eta must be 0 for PLMS") self.ddim_timesteps = make_ddim_timesteps( ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, num_ddpm_timesteps=self.ddpm_num_timesteps, - verbose=verbose, ) alphas_cumprod = self.model.alphas_cumprod assert ( @@ -76,7 +73,6 @@ class PLMSSampler(object): alphacums=alphas_cumprod.cpu(), ddim_timesteps=self.ddim_timesteps, eta=ddim_eta, - verbose=verbose, ) self.register_buffer("ddim_sigmas", ddim_sigmas) self.register_buffer("ddim_alphas", ddim_alphas) @@ -109,7 +105,6 @@ class PLMSSampler(object): noise_dropout=0.0, score_corrector=None, corrector_kwargs=None, - verbose=True, x_T=None, log_every_t=100, unconditional_guidance_scale=1.0, @@ -130,7 +125,7 @@ class PLMSSampler(object): f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}" ) - self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + self.make_schedule(ddim_num_steps=S, ddim_eta=eta) # sampling C, H, W = shape size = (batch_size, C, H, W) @@ -252,6 +247,7 @@ class PLMSSampler(object): if callback: callback(i) if img_callback: + img_callback(img, i) img_callback(pred_x0, i) if index % log_every_t == 0 or index == total_steps - 1: diff --git a/imaginairy/modules/clip_embedders.py b/imaginairy/modules/clip_embedders.py index 9c2ea92..ff12938 100644 --- a/imaginairy/modules/clip_embedders.py +++ b/imaginairy/modules/clip_embedders.py @@ -8,7 +8,7 @@ from transformers import CLIPTokenizer, CLIPTextModel from imaginairy.utils import get_device -class FrozenCLIPEmbedder: +class FrozenCLIPEmbedder(nn.Module): """Uses the CLIP transformer encoder for text (from Hugging Face)""" def __init__( diff --git a/imaginairy/schema.py b/imaginairy/schema.py new file mode 100644 index 0000000..fb1e767 --- /dev/null +++ b/imaginairy/schema.py @@ -0,0 +1,70 @@ +import hashlib +import random + +import numpy + + +class WeightedPrompt: + def __init__(self, text, weight=1): + self.text = text + self.weight = weight + + def __str__(self): + return f"{self.weight}*({self.text})" + + +class ImaginePrompt: + def __init__( + self, + prompt=None, + seed=None, + prompt_strength=7.5, + sampler_type="PLMS", + init_image=None, + init_image_strength=0.3, + steps=50, + height=512, + width=512, + upscale=False, + fix_faces=False, + parts=None, + ): + prompt = prompt if prompt is not None else "a scenic landscape" + if isinstance(prompt, str): + self.prompts = [WeightedPrompt(prompt, 1)] + else: + self.prompts = prompt + self.init_image = init_image + self.init_image_strength = init_image_strength + self.prompts.sort(key=lambda p: p.weight, reverse=True) + self.seed = random.randint(1, 1_000_000_000) if seed is None else seed + self.prompt_strength = prompt_strength + self.sampler_type = sampler_type + self.steps = steps + self.height = height + self.width = width + self.upscale = upscale + self.fix_faces = fix_faces + self.parts = parts or {} + + @property + def prompt_text(self): + if len(self.prompts) == 1: + return self.prompts[0].text + return "|".join(str(p) for p in self.prompts) + + +class ImagineResult: + def __init__(self, img, prompt): + self.img = img + self.prompt = prompt + + def cv2_img(self): + open_cv_image = numpy.array(self.img) + # Convert RGB to BGR + open_cv_image = open_cv_image[:, :, ::-1].copy() + return open_cv_image + # return cv2.cvtColor(numpy.array(self.img), cv2.COLOR_RGB2BGR) + + def md5(self): + return hashlib.md5(self.img.tobytes()).hexdigest() diff --git a/requirements-dev.in b/requirements-dev.in new file mode 100644 index 0000000..55b033e --- /dev/null +++ b/requirements-dev.in @@ -0,0 +1 @@ +pytest \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py index e69de29..4d27b58 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -0,0 +1,3 @@ +import os.path + +TESTS_FOLDER = os.path.dirname(__file__) \ No newline at end of file diff --git a/tests/test_imagine.py b/tests/test_imagine.py new file mode 100644 index 0000000..283627f --- /dev/null +++ b/tests/test_imagine.py @@ -0,0 +1,60 @@ +from imaginairy.imagine import imagine_images, imagine_image_files +from imaginairy.schema import ImaginePrompt, WeightedPrompt +from . import TESTS_FOLDER + + +def test_imagine(): + prompt = ImaginePrompt("a scenic landscape", width=512, height=256, steps=20, seed=1) + result = next(imagine_images(prompt)) + assert result.md5() == '4c5957c498881d365cfcf13014812af0' + result.img.save(f"{TESTS_FOLDER}/test_output/scenic_landscape.png") + + +def test_img_to_img(): + prompt = ImaginePrompt( + "a photo of a beach", + init_image=f"{TESTS_FOLDER}/data/beach_at_sainte_adresse.jpg", + init_image_strength=0.5, + width=512, + height=512, + steps=50, + seed=1, + sampler_type="DDIM", + ) + out_folder = f"{TESTS_FOLDER}/test_output" + out_folder = '/home/bryce/Mounts/drennanfiles/art/tests' + imagine_image_files(prompt, outdir=out_folder) + + +def test_img_to_file(): + prompt = ImaginePrompt( + [WeightedPrompt("an old growth forest, diffuse light poking through the canopy. high-resolution, nature photography, nat geo photo")], + # init_image=f"{TESTS_FOLDER}/data/beach_at_sainte_adresse.jpg", + init_image_strength=0.5, + width=512+64, + height=512-64, + steps=50, + # seed=2, + sampler_type="PLMS", + upscale=True + ) + out_folder = f"{TESTS_FOLDER}/test_output" + out_folder = '/home/bryce/Mounts/drennanfiles/art/tests' + imagine_image_files(prompt, outdir=out_folder) + + +def test_img_conditioning(): + prompt = ImaginePrompt( + "photo", + init_image=f"{TESTS_FOLDER}/data/beach_at_sainte_adresse.jpg", + init_image_strength=0.5, + width=512+64, + height=512-64, + steps=50, + # seed=2, + sampler_type="PLMS", + upscale=True + ) + out_folder = f"{TESTS_FOLDER}/test_output" + out_folder = '/home/bryce/Mounts/drennanfiles/art/tests' + imagine_image_files(prompt, outdir=out_folder, record_steps=True) \ No newline at end of file