feature: generate large images with coherent composition

pull/259/head
Bryce 2 years ago committed by Bryce Drennan
parent c3a88c44cd
commit 882cc7e0f1

@ -300,7 +300,8 @@ docker run -it --gpus all -v $HOME/.cache/huggingface:/root/.cache/huggingface -
## ChangeLog
- perf: tiled encoding of images (removes memory bottleneck)
- feature: 🎉🎉 Make large images while retaining composition. Try `imagine "a flower" -w 1920 -h 1080 --upscale`
- perf: sliced encoding of images to latents (removes memory bottleneck)
- perf: use SiLU for performance improvement over nonlinearity
- perf: `xformers` added as a dependency for linux and windows. Gives a nice speed boost.
- perf: sliced attention now runs on macOS. A typo prevented that from happening previously.
@ -555,6 +556,7 @@ would be uncorrelated to the rest of the surrounding image. It created terrible
- https://github.com/huggingface/diffusers/pull/532/files
- https://github.com/HazyResearch/flash-attention
- https://github.com/chavinlo/sda-node
- https://github.com/AminRezaei0x443/memory-efficient-attention/issues/7
- Development Environment
- ✅ add tests

@ -3,6 +3,7 @@ import math
import os
import re
from imaginairy.enhancers.upscale_riverwing import upscale_latent
from imaginairy.schema import SafetyMode
logger = logging.getLogger(__name__)
@ -269,6 +270,7 @@ def _generate_single_image(
with lc.timing("conditioning"):
# need to expand if doing batches
neutral_conditioning = _prompts_to_embeddings(prompt.negative_prompt, model)
_prompts_to_embeddings("", model)
log_conditioning(neutral_conditioning, "neutral conditioning")
if prompt.conditioning is not None:
positive_conditioning = prompt.conditioning
@ -425,40 +427,43 @@ def _generate_single_image(
}
log_latent(init_latent_noised, "init_latent_noised")
comp_samples = _generate_composition_latent(
sampler=sampler,
sampler_kwargs={
"num_steps": prompt.steps,
"initial_latent": init_latent_noised,
"positive_conditioning": positive_conditioning,
"neutral_conditioning": neutral_conditioning,
"guidance_scale": prompt.prompt_strength,
"t_start": t_enc,
"mask": mask_latent,
"orig_latent": init_latent,
"shape": shape,
"batch_size": 1,
"denoiser_cls": denoiser_cls,
},
)
if comp_samples is not None:
noise = noise[:, :, : comp_samples.shape[2], : comp_samples.shape[3]]
schedule = NoiseSchedule(
model_num_timesteps=model.num_timesteps,
ddim_num_steps=prompt.steps,
model_alphas_cumprod=model.alphas_cumprod,
ddim_discretize="uniform",
)
t_enc = int(prompt.steps * 0.8)
init_latent_noised = noise_an_image(
comp_samples,
torch.tensor([t_enc - 1]).to(get_device()),
schedule=schedule,
noise=noise,
if prompt.allow_compose_phase:
comp_samples = _generate_composition_latent(
sampler=sampler,
sampler_kwargs={
"num_steps": prompt.steps,
"initial_latent": init_latent_noised,
"positive_conditioning": positive_conditioning,
"neutral_conditioning": neutral_conditioning,
"guidance_scale": prompt.prompt_strength,
"t_start": t_enc,
"mask": mask_latent,
"orig_latent": init_latent,
"shape": shape,
"batch_size": 1,
"denoiser_cls": denoiser_cls,
},
)
if comp_samples is not None:
result_images["composition"] = comp_samples
noise = noise[:, :, : comp_samples.shape[2], : comp_samples.shape[3]]
schedule = NoiseSchedule(
model_num_timesteps=model.num_timesteps,
ddim_num_steps=prompt.steps,
model_alphas_cumprod=model.alphas_cumprod,
ddim_discretize="uniform",
)
t_enc = int(prompt.steps * 0.75)
init_latent_noised = noise_an_image(
comp_samples,
torch.tensor([t_enc - 1]).to(get_device()),
schedule=schedule,
noise=noise,
)
log_latent(comp_samples, "comp_samples")
log_latent(comp_samples, "comp_samples")
with lc.timing("sampling"):
samples = sampler.sample(
num_steps=prompt.steps,
@ -575,8 +580,7 @@ def _generate_composition_latent(
from torch.nn import functional as F
new_kwargs = deepcopy(sampler_kwargs)
b, c, h, w = orig_shape = new_kwargs["shape"]
b, c, h, w = orig_shape = sampler_kwargs["shape"]
max_compose_gen_size = 768
shrink_scale = calc_scale_to_fit_within(
height=h,
@ -586,6 +590,8 @@ def _generate_composition_latent(
if shrink_scale >= 1:
return None
new_kwargs = deepcopy(sampler_kwargs)
# shrink everything
new_shape = b, c, int(round(h * shrink_scale)), int(round(w * shrink_scale))
initial_latent = new_kwargs["initial_latent"]
@ -622,7 +628,7 @@ def _generate_composition_latent(
}
)
samples = sampler.sample(**new_kwargs)
# samples = upscale_latent(samples)
samples = upscale_latent(samples)
samples = F.interpolate(samples, size=orig_shape[2:], mode="bilinear")
return samples

@ -136,6 +136,11 @@ common_options = [
is_flag=True,
help="Any images rendered will be tileable in the Y direction.",
),
click.option(
"--allow-compose-phase/--no-compose-phase",
default=True,
help="Allow the image to be composed at a lower resolution.",
),
click.option(
"--mask-image",
metavar="PATH|URL",
@ -342,6 +347,7 @@ def imagine_cmd(
tile,
tile_x,
tile_y,
allow_compose_phase,
mask_image,
mask_prompt,
mask_mode,
@ -387,6 +393,7 @@ def imagine_cmd(
tile,
tile_x,
tile_y,
allow_compose_phase,
mask_image,
mask_prompt,
mask_mode,
@ -591,6 +598,7 @@ def _imagine_cmd(
tile,
tile_x,
tile_y,
allow_compose_phase,
mask_image,
mask_prompt,
mask_mode,
@ -705,6 +713,7 @@ def _imagine_cmd(
fix_faces=fix_faces,
fix_faces_fidelity=fix_faces_fidelity,
tile_mode=_tile_mode,
allow_compose_phase=allow_compose_phase,
model=model_weights_path,
model_config_path=model_config_path,
)

@ -23,7 +23,13 @@ class NoiseLevelAndTextConditionedUpscaler(nn.Module):
def forward(self, inp, sigma, low_res, low_res_sigma, c, **kwargs):
cross_cond, cross_cond_padding, pooler = c
c_in = 1 / (low_res_sigma**2 + self.sigma_data**2) ** 0.5
sigma_data = self.sigma_data
# 'MPS does not support power op with int64 input'
if isinstance(low_res_sigma, torch.Tensor):
low_res_sigma = low_res_sigma.to(torch.float32)
if isinstance(sigma_data, torch.Tensor):
sigma_data = sigma_data.to(torch.float32)
c_in = 1 / (low_res_sigma**2 + sigma_data**2) ** 0.5
c_noise = low_res_sigma.log1p()[:, None]
c_in = append_dims(c_in, low_res.ndim)
low_res_noise_embed = self.low_res_noise_embed(c_noise)
@ -200,7 +206,7 @@ def upscale_latent(
low_res_latent,
upscale_prompt="",
seed=0,
steps=30,
steps=15,
guidance_scale=1.0,
batch_size=1,
num_samples=1,

@ -110,6 +110,7 @@ class ImaginePrompt:
sampler_type=config.DEFAULT_SAMPLER,
conditioning=None,
tile_mode="",
allow_compose_phase=True,
model=config.DEFAULT_MODEL,
model_config_path=None,
is_intermediate=False,
@ -136,8 +137,10 @@ class ImaginePrompt:
self.mask_modify_original = mask_modify_original
self.outpaint = outpaint
self.tile_mode = tile_mode
self.allow_compose_phase = allow_compose_phase
self.model = model
self.model_config_path = model_config_path
# we don't want to save intermediate images
self.is_intermediate = is_intermediate
self.collect_progress_latents = collect_progress_latents
@ -284,7 +287,10 @@ class ImagineResult:
):
import torch
from imaginairy.img_utils import torch_img_to_pillow_img
from imaginairy.img_utils import (
model_latent_to_pillow_img,
torch_img_to_pillow_img,
)
from imaginairy.utils import get_device, get_hardware_description
self.prompt = prompt
@ -305,7 +311,10 @@ class ImagineResult:
for img_type, r_img in result_images.items():
if isinstance(r_img, torch.Tensor):
r_img = torch_img_to_pillow_img(r_img)
if r_img.shape[1] == 4:
r_img = model_latent_to_pillow_img(r_img)
else:
r_img = torch_img_to_pillow_img(r_img)
self.images[img_type] = r_img
self.timings = timings

@ -142,6 +142,10 @@ class UBlock(layers.ConditionedSequential):
# Optionally merge a skip tensor into `input`, then delegate to the parent
# class's forward with `cond`.
# NOTE(review): this fragment comes from a diff view with indentation stripped;
# the nesting described below is presumed from statement order - confirm
# against the real file.
def forward(self, input, cond, skip=None):
# Nothing is merged when no skip tensor is supplied.
if skip is not None:
# If the last two dimensions of `input` and `skip` differ, resize `input`
# to `skip`'s size with bilinear interpolation (presumably so the
# concatenation below sees matching sizes).
if input.shape[-2:] != skip.shape[-2:]:
input = nn.functional.interpolate(
input, size=skip.shape[-2:], mode="bilinear"
)
# Join `input` and `skip` along dim 1.
input = torch.cat([input, skip], dim=1)
return super().forward(input, cond)

Loading…
Cancel
Save