feature: generate large images with coherent composition

pull/259/head
Bryce 2 years ago committed by Bryce Drennan
parent c3a88c44cd
commit 882cc7e0f1

@ -300,7 +300,8 @@ docker run -it --gpus all -v $HOME/.cache/huggingface:/root/.cache/huggingface -
## ChangeLog ## ChangeLog
- perf: tiled encoding of images (removes memory bottleneck) - feature: 🎉🎉 Make large images while retaining composition. Try `imagine "a flower" -w 1920 -h 1080 --upscale`
- perf: sliced encoding of images to latents (removes memory bottleneck)
- perf: use Silu for performance improvement over nonlinearity - perf: use Silu for performance improvement over nonlinearity
- perf: `xformers` added as a dependency for linux and windows. Gives a nice speed boost. - perf: `xformers` added as a dependency for linux and windows. Gives a nice speed boost.
- perf: sliced attention now runs on MacOS. A typo prevented that from happening previously. - perf: sliced attention now runs on MacOS. A typo prevented that from happening previously.
@ -555,6 +556,7 @@ would be uncorrelated to the rest of the surrounding image. It created terrible
- https://github.com/huggingface/diffusers/pull/532/files - https://github.com/huggingface/diffusers/pull/532/files
- https://github.com/HazyResearch/flash-attention - https://github.com/HazyResearch/flash-attention
- https://github.com/chavinlo/sda-node - https://github.com/chavinlo/sda-node
- https://github.com/AminRezaei0x443/memory-efficient-attention/issues/7
- Development Environment - Development Environment
- ✅ add tests - ✅ add tests

@ -3,6 +3,7 @@ import math
import os import os
import re import re
from imaginairy.enhancers.upscale_riverwing import upscale_latent
from imaginairy.schema import SafetyMode from imaginairy.schema import SafetyMode
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -269,6 +270,7 @@ def _generate_single_image(
with lc.timing("conditioning"): with lc.timing("conditioning"):
# need to expand if doing batches # need to expand if doing batches
neutral_conditioning = _prompts_to_embeddings(prompt.negative_prompt, model) neutral_conditioning = _prompts_to_embeddings(prompt.negative_prompt, model)
_prompts_to_embeddings("", model)
log_conditioning(neutral_conditioning, "neutral conditioning") log_conditioning(neutral_conditioning, "neutral conditioning")
if prompt.conditioning is not None: if prompt.conditioning is not None:
positive_conditioning = prompt.conditioning positive_conditioning = prompt.conditioning
@ -425,40 +427,43 @@ def _generate_single_image(
} }
log_latent(init_latent_noised, "init_latent_noised") log_latent(init_latent_noised, "init_latent_noised")
comp_samples = _generate_composition_latent( if prompt.allow_compose_phase:
sampler=sampler, comp_samples = _generate_composition_latent(
sampler_kwargs={ sampler=sampler,
"num_steps": prompt.steps, sampler_kwargs={
"initial_latent": init_latent_noised, "num_steps": prompt.steps,
"positive_conditioning": positive_conditioning, "initial_latent": init_latent_noised,
"neutral_conditioning": neutral_conditioning, "positive_conditioning": positive_conditioning,
"guidance_scale": prompt.prompt_strength, "neutral_conditioning": neutral_conditioning,
"t_start": t_enc, "guidance_scale": prompt.prompt_strength,
"mask": mask_latent, "t_start": t_enc,
"orig_latent": init_latent, "mask": mask_latent,
"shape": shape, "orig_latent": init_latent,
"batch_size": 1, "shape": shape,
"denoiser_cls": denoiser_cls, "batch_size": 1,
}, "denoiser_cls": denoiser_cls,
) },
if comp_samples is not None:
noise = noise[:, :, : comp_samples.shape[2], : comp_samples.shape[3]]
schedule = NoiseSchedule(
model_num_timesteps=model.num_timesteps,
ddim_num_steps=prompt.steps,
model_alphas_cumprod=model.alphas_cumprod,
ddim_discretize="uniform",
)
t_enc = int(prompt.steps * 0.8)
init_latent_noised = noise_an_image(
comp_samples,
torch.tensor([t_enc - 1]).to(get_device()),
schedule=schedule,
noise=noise,
) )
if comp_samples is not None:
result_images["composition"] = comp_samples
noise = noise[:, :, : comp_samples.shape[2], : comp_samples.shape[3]]
schedule = NoiseSchedule(
model_num_timesteps=model.num_timesteps,
ddim_num_steps=prompt.steps,
model_alphas_cumprod=model.alphas_cumprod,
ddim_discretize="uniform",
)
t_enc = int(prompt.steps * 0.75)
init_latent_noised = noise_an_image(
comp_samples,
torch.tensor([t_enc - 1]).to(get_device()),
schedule=schedule,
noise=noise,
)
log_latent(comp_samples, "comp_samples")
log_latent(comp_samples, "comp_samples")
with lc.timing("sampling"): with lc.timing("sampling"):
samples = sampler.sample( samples = sampler.sample(
num_steps=prompt.steps, num_steps=prompt.steps,
@ -575,8 +580,7 @@ def _generate_composition_latent(
from torch.nn import functional as F from torch.nn import functional as F
new_kwargs = deepcopy(sampler_kwargs) b, c, h, w = orig_shape = sampler_kwargs["shape"]
b, c, h, w = orig_shape = new_kwargs["shape"]
max_compose_gen_size = 768 max_compose_gen_size = 768
shrink_scale = calc_scale_to_fit_within( shrink_scale = calc_scale_to_fit_within(
height=h, height=h,
@ -586,6 +590,8 @@ def _generate_composition_latent(
if shrink_scale >= 1: if shrink_scale >= 1:
return None return None
new_kwargs = deepcopy(sampler_kwargs)
# shrink everything # shrink everything
new_shape = b, c, int(round(h * shrink_scale)), int(round(w * shrink_scale)) new_shape = b, c, int(round(h * shrink_scale)), int(round(w * shrink_scale))
initial_latent = new_kwargs["initial_latent"] initial_latent = new_kwargs["initial_latent"]
@ -622,7 +628,7 @@ def _generate_composition_latent(
} }
) )
samples = sampler.sample(**new_kwargs) samples = sampler.sample(**new_kwargs)
# samples = upscale_latent(samples) samples = upscale_latent(samples)
samples = F.interpolate(samples, size=orig_shape[2:], mode="bilinear") samples = F.interpolate(samples, size=orig_shape[2:], mode="bilinear")
return samples return samples

@ -136,6 +136,11 @@ common_options = [
is_flag=True, is_flag=True,
help="Any images rendered will be tileable in the Y direction.", help="Any images rendered will be tileable in the Y direction.",
), ),
click.option(
"--allow-compose-phase/--no-compose-phase",
default=True,
help="Allow the image to be composed at a lower resolution.",
),
click.option( click.option(
"--mask-image", "--mask-image",
metavar="PATH|URL", metavar="PATH|URL",
@ -342,6 +347,7 @@ def imagine_cmd(
tile, tile,
tile_x, tile_x,
tile_y, tile_y,
allow_compose_phase,
mask_image, mask_image,
mask_prompt, mask_prompt,
mask_mode, mask_mode,
@ -387,6 +393,7 @@ def imagine_cmd(
tile, tile,
tile_x, tile_x,
tile_y, tile_y,
allow_compose_phase,
mask_image, mask_image,
mask_prompt, mask_prompt,
mask_mode, mask_mode,
@ -591,6 +598,7 @@ def _imagine_cmd(
tile, tile,
tile_x, tile_x,
tile_y, tile_y,
allow_compose_phase,
mask_image, mask_image,
mask_prompt, mask_prompt,
mask_mode, mask_mode,
@ -705,6 +713,7 @@ def _imagine_cmd(
fix_faces=fix_faces, fix_faces=fix_faces,
fix_faces_fidelity=fix_faces_fidelity, fix_faces_fidelity=fix_faces_fidelity,
tile_mode=_tile_mode, tile_mode=_tile_mode,
allow_compose_phase=allow_compose_phase,
model=model_weights_path, model=model_weights_path,
model_config_path=model_config_path, model_config_path=model_config_path,
) )

@ -23,7 +23,13 @@ class NoiseLevelAndTextConditionedUpscaler(nn.Module):
def forward(self, inp, sigma, low_res, low_res_sigma, c, **kwargs): def forward(self, inp, sigma, low_res, low_res_sigma, c, **kwargs):
cross_cond, cross_cond_padding, pooler = c cross_cond, cross_cond_padding, pooler = c
c_in = 1 / (low_res_sigma**2 + self.sigma_data**2) ** 0.5 sigma_data = self.sigma_data
# 'MPS does not support power op with int64 input'
if isinstance(low_res_sigma, torch.Tensor):
low_res_sigma = low_res_sigma.to(torch.float32)
if isinstance(sigma_data, torch.Tensor):
sigma_data = sigma_data.to(torch.float32)
c_in = 1 / (low_res_sigma**2 + sigma_data**2) ** 0.5
c_noise = low_res_sigma.log1p()[:, None] c_noise = low_res_sigma.log1p()[:, None]
c_in = append_dims(c_in, low_res.ndim) c_in = append_dims(c_in, low_res.ndim)
low_res_noise_embed = self.low_res_noise_embed(c_noise) low_res_noise_embed = self.low_res_noise_embed(c_noise)
@ -200,7 +206,7 @@ def upscale_latent(
low_res_latent, low_res_latent,
upscale_prompt="", upscale_prompt="",
seed=0, seed=0,
steps=30, steps=15,
guidance_scale=1.0, guidance_scale=1.0,
batch_size=1, batch_size=1,
num_samples=1, num_samples=1,

@ -110,6 +110,7 @@ class ImaginePrompt:
sampler_type=config.DEFAULT_SAMPLER, sampler_type=config.DEFAULT_SAMPLER,
conditioning=None, conditioning=None,
tile_mode="", tile_mode="",
allow_compose_phase=True,
model=config.DEFAULT_MODEL, model=config.DEFAULT_MODEL,
model_config_path=None, model_config_path=None,
is_intermediate=False, is_intermediate=False,
@ -136,8 +137,10 @@ class ImaginePrompt:
self.mask_modify_original = mask_modify_original self.mask_modify_original = mask_modify_original
self.outpaint = outpaint self.outpaint = outpaint
self.tile_mode = tile_mode self.tile_mode = tile_mode
self.allow_compose_phase = allow_compose_phase
self.model = model self.model = model
self.model_config_path = model_config_path self.model_config_path = model_config_path
# we don't want to save intermediate images # we don't want to save intermediate images
self.is_intermediate = is_intermediate self.is_intermediate = is_intermediate
self.collect_progress_latents = collect_progress_latents self.collect_progress_latents = collect_progress_latents
@ -284,7 +287,10 @@ class ImagineResult:
): ):
import torch import torch
from imaginairy.img_utils import torch_img_to_pillow_img from imaginairy.img_utils import (
model_latent_to_pillow_img,
torch_img_to_pillow_img,
)
from imaginairy.utils import get_device, get_hardware_description from imaginairy.utils import get_device, get_hardware_description
self.prompt = prompt self.prompt = prompt
@ -305,7 +311,10 @@ class ImagineResult:
for img_type, r_img in result_images.items(): for img_type, r_img in result_images.items():
if isinstance(r_img, torch.Tensor): if isinstance(r_img, torch.Tensor):
r_img = torch_img_to_pillow_img(r_img) if r_img.shape[1] == 4:
r_img = model_latent_to_pillow_img(r_img)
else:
r_img = torch_img_to_pillow_img(r_img)
self.images[img_type] = r_img self.images[img_type] = r_img
self.timings = timings self.timings = timings

@ -142,6 +142,10 @@ class UBlock(layers.ConditionedSequential):
def forward(self, input, cond, skip=None): def forward(self, input, cond, skip=None):
if skip is not None: if skip is not None:
if input.shape[-2:] != skip.shape[-2:]:
input = nn.functional.interpolate(
input, size=skip.shape[-2:], mode="bilinear"
)
input = torch.cat([input, skip], dim=1) input = torch.cat([input, skip], dim=1)
return super().forward(input, cond) return super().forward(input, cond)

Loading…
Cancel
Save