mirror of
https://github.com/brycedrennan/imaginAIry
synced 2024-11-05 12:00:15 +00:00
feature: generate large images with coherent composition
This commit is contained in:
parent
c3a88c44cd
commit
882cc7e0f1
@ -300,7 +300,8 @@ docker run -it --gpus all -v $HOME/.cache/huggingface:/root/.cache/huggingface -
|
|||||||
|
|
||||||
## ChangeLog
|
## ChangeLog
|
||||||
|
|
||||||
- perf: tiled encoding of images (removes memory bottleneck)
|
- feature: 🎉🎉 Make large images while retaining composition. Try `imagine "a flower" -w 1920 -h 1080 --upscale`
|
||||||
|
- perf: sliced encoding of images to latents (removes memory bottleneck)
|
||||||
- perf: use Silu for performance improvement over nonlinearity
|
- perf: use Silu for performance improvement over nonlinearity
|
||||||
- perf: `xformers` added as a dependency for linux and windows. Gives a nice speed boost.
|
- perf: `xformers` added as a dependency for linux and windows. Gives a nice speed boost.
|
||||||
- perf: sliced attention now runs on MacOS. A typo prevented that from happening previously.
|
- perf: sliced attention now runs on MacOS. A typo prevented that from happening previously.
|
||||||
@ -555,6 +556,7 @@ would be uncorrelated to the rest of the surrounding image. It created terrible
|
|||||||
- https://github.com/huggingface/diffusers/pull/532/files
|
- https://github.com/huggingface/diffusers/pull/532/files
|
||||||
- https://github.com/HazyResearch/flash-attention
|
- https://github.com/HazyResearch/flash-attention
|
||||||
- https://github.com/chavinlo/sda-node
|
- https://github.com/chavinlo/sda-node
|
||||||
|
- https://github.com/AminRezaei0x443/memory-efficient-attention/issues/7
|
||||||
|
|
||||||
- Development Environment
|
- Development Environment
|
||||||
- ✅ add tests
|
- ✅ add tests
|
||||||
|
@ -3,6 +3,7 @@ import math
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
|
||||||
|
from imaginairy.enhancers.upscale_riverwing import upscale_latent
|
||||||
from imaginairy.schema import SafetyMode
|
from imaginairy.schema import SafetyMode
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -269,6 +270,7 @@ def _generate_single_image(
|
|||||||
with lc.timing("conditioning"):
|
with lc.timing("conditioning"):
|
||||||
# need to expand if doing batches
|
# need to expand if doing batches
|
||||||
neutral_conditioning = _prompts_to_embeddings(prompt.negative_prompt, model)
|
neutral_conditioning = _prompts_to_embeddings(prompt.negative_prompt, model)
|
||||||
|
_prompts_to_embeddings("", model)
|
||||||
log_conditioning(neutral_conditioning, "neutral conditioning")
|
log_conditioning(neutral_conditioning, "neutral conditioning")
|
||||||
if prompt.conditioning is not None:
|
if prompt.conditioning is not None:
|
||||||
positive_conditioning = prompt.conditioning
|
positive_conditioning = prompt.conditioning
|
||||||
@ -425,6 +427,7 @@ def _generate_single_image(
|
|||||||
}
|
}
|
||||||
log_latent(init_latent_noised, "init_latent_noised")
|
log_latent(init_latent_noised, "init_latent_noised")
|
||||||
|
|
||||||
|
if prompt.allow_compose_phase:
|
||||||
comp_samples = _generate_composition_latent(
|
comp_samples = _generate_composition_latent(
|
||||||
sampler=sampler,
|
sampler=sampler,
|
||||||
sampler_kwargs={
|
sampler_kwargs={
|
||||||
@ -442,6 +445,7 @@ def _generate_single_image(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
if comp_samples is not None:
|
if comp_samples is not None:
|
||||||
|
result_images["composition"] = comp_samples
|
||||||
noise = noise[:, :, : comp_samples.shape[2], : comp_samples.shape[3]]
|
noise = noise[:, :, : comp_samples.shape[2], : comp_samples.shape[3]]
|
||||||
|
|
||||||
schedule = NoiseSchedule(
|
schedule = NoiseSchedule(
|
||||||
@ -450,7 +454,7 @@ def _generate_single_image(
|
|||||||
model_alphas_cumprod=model.alphas_cumprod,
|
model_alphas_cumprod=model.alphas_cumprod,
|
||||||
ddim_discretize="uniform",
|
ddim_discretize="uniform",
|
||||||
)
|
)
|
||||||
t_enc = int(prompt.steps * 0.8)
|
t_enc = int(prompt.steps * 0.75)
|
||||||
init_latent_noised = noise_an_image(
|
init_latent_noised = noise_an_image(
|
||||||
comp_samples,
|
comp_samples,
|
||||||
torch.tensor([t_enc - 1]).to(get_device()),
|
torch.tensor([t_enc - 1]).to(get_device()),
|
||||||
@ -459,6 +463,7 @@ def _generate_single_image(
|
|||||||
)
|
)
|
||||||
|
|
||||||
log_latent(comp_samples, "comp_samples")
|
log_latent(comp_samples, "comp_samples")
|
||||||
|
|
||||||
with lc.timing("sampling"):
|
with lc.timing("sampling"):
|
||||||
samples = sampler.sample(
|
samples = sampler.sample(
|
||||||
num_steps=prompt.steps,
|
num_steps=prompt.steps,
|
||||||
@ -575,8 +580,7 @@ def _generate_composition_latent(
|
|||||||
|
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
|
|
||||||
new_kwargs = deepcopy(sampler_kwargs)
|
b, c, h, w = orig_shape = sampler_kwargs["shape"]
|
||||||
b, c, h, w = orig_shape = new_kwargs["shape"]
|
|
||||||
max_compose_gen_size = 768
|
max_compose_gen_size = 768
|
||||||
shrink_scale = calc_scale_to_fit_within(
|
shrink_scale = calc_scale_to_fit_within(
|
||||||
height=h,
|
height=h,
|
||||||
@ -586,6 +590,8 @@ def _generate_composition_latent(
|
|||||||
if shrink_scale >= 1:
|
if shrink_scale >= 1:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
new_kwargs = deepcopy(sampler_kwargs)
|
||||||
|
|
||||||
# shrink everything
|
# shrink everything
|
||||||
new_shape = b, c, int(round(h * shrink_scale)), int(round(w * shrink_scale))
|
new_shape = b, c, int(round(h * shrink_scale)), int(round(w * shrink_scale))
|
||||||
initial_latent = new_kwargs["initial_latent"]
|
initial_latent = new_kwargs["initial_latent"]
|
||||||
@ -622,7 +628,7 @@ def _generate_composition_latent(
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
samples = sampler.sample(**new_kwargs)
|
samples = sampler.sample(**new_kwargs)
|
||||||
# samples = upscale_latent(samples)
|
samples = upscale_latent(samples)
|
||||||
samples = F.interpolate(samples, size=orig_shape[2:], mode="bilinear")
|
samples = F.interpolate(samples, size=orig_shape[2:], mode="bilinear")
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
|
@ -136,6 +136,11 @@ common_options = [
|
|||||||
is_flag=True,
|
is_flag=True,
|
||||||
help="Any images rendered will be tileable in the Y direction.",
|
help="Any images rendered will be tileable in the Y direction.",
|
||||||
),
|
),
|
||||||
|
click.option(
|
||||||
|
"--allow-compose-phase/--no-compose-phase",
|
||||||
|
default=True,
|
||||||
|
help="Allow the image to be composed at a lower resolution.",
|
||||||
|
),
|
||||||
click.option(
|
click.option(
|
||||||
"--mask-image",
|
"--mask-image",
|
||||||
metavar="PATH|URL",
|
metavar="PATH|URL",
|
||||||
@ -342,6 +347,7 @@ def imagine_cmd(
|
|||||||
tile,
|
tile,
|
||||||
tile_x,
|
tile_x,
|
||||||
tile_y,
|
tile_y,
|
||||||
|
allow_compose_phase,
|
||||||
mask_image,
|
mask_image,
|
||||||
mask_prompt,
|
mask_prompt,
|
||||||
mask_mode,
|
mask_mode,
|
||||||
@ -387,6 +393,7 @@ def imagine_cmd(
|
|||||||
tile,
|
tile,
|
||||||
tile_x,
|
tile_x,
|
||||||
tile_y,
|
tile_y,
|
||||||
|
allow_compose_phase,
|
||||||
mask_image,
|
mask_image,
|
||||||
mask_prompt,
|
mask_prompt,
|
||||||
mask_mode,
|
mask_mode,
|
||||||
@ -591,6 +598,7 @@ def _imagine_cmd(
|
|||||||
tile,
|
tile,
|
||||||
tile_x,
|
tile_x,
|
||||||
tile_y,
|
tile_y,
|
||||||
|
allow_compose_phase,
|
||||||
mask_image,
|
mask_image,
|
||||||
mask_prompt,
|
mask_prompt,
|
||||||
mask_mode,
|
mask_mode,
|
||||||
@ -705,6 +713,7 @@ def _imagine_cmd(
|
|||||||
fix_faces=fix_faces,
|
fix_faces=fix_faces,
|
||||||
fix_faces_fidelity=fix_faces_fidelity,
|
fix_faces_fidelity=fix_faces_fidelity,
|
||||||
tile_mode=_tile_mode,
|
tile_mode=_tile_mode,
|
||||||
|
allow_compose_phase=allow_compose_phase,
|
||||||
model=model_weights_path,
|
model=model_weights_path,
|
||||||
model_config_path=model_config_path,
|
model_config_path=model_config_path,
|
||||||
)
|
)
|
||||||
|
@ -23,7 +23,13 @@ class NoiseLevelAndTextConditionedUpscaler(nn.Module):
|
|||||||
|
|
||||||
def forward(self, inp, sigma, low_res, low_res_sigma, c, **kwargs):
|
def forward(self, inp, sigma, low_res, low_res_sigma, c, **kwargs):
|
||||||
cross_cond, cross_cond_padding, pooler = c
|
cross_cond, cross_cond_padding, pooler = c
|
||||||
c_in = 1 / (low_res_sigma**2 + self.sigma_data**2) ** 0.5
|
sigma_data = self.sigma_data
|
||||||
|
# 'MPS does not support power op with int64 input'
|
||||||
|
if isinstance(low_res_sigma, torch.Tensor):
|
||||||
|
low_res_sigma = low_res_sigma.to(torch.float32)
|
||||||
|
if isinstance(sigma_data, torch.Tensor):
|
||||||
|
sigma_data = sigma_data.to(torch.float32)
|
||||||
|
c_in = 1 / (low_res_sigma**2 + sigma_data**2) ** 0.5
|
||||||
c_noise = low_res_sigma.log1p()[:, None]
|
c_noise = low_res_sigma.log1p()[:, None]
|
||||||
c_in = append_dims(c_in, low_res.ndim)
|
c_in = append_dims(c_in, low_res.ndim)
|
||||||
low_res_noise_embed = self.low_res_noise_embed(c_noise)
|
low_res_noise_embed = self.low_res_noise_embed(c_noise)
|
||||||
@ -200,7 +206,7 @@ def upscale_latent(
|
|||||||
low_res_latent,
|
low_res_latent,
|
||||||
upscale_prompt="",
|
upscale_prompt="",
|
||||||
seed=0,
|
seed=0,
|
||||||
steps=30,
|
steps=15,
|
||||||
guidance_scale=1.0,
|
guidance_scale=1.0,
|
||||||
batch_size=1,
|
batch_size=1,
|
||||||
num_samples=1,
|
num_samples=1,
|
||||||
|
@ -110,6 +110,7 @@ class ImaginePrompt:
|
|||||||
sampler_type=config.DEFAULT_SAMPLER,
|
sampler_type=config.DEFAULT_SAMPLER,
|
||||||
conditioning=None,
|
conditioning=None,
|
||||||
tile_mode="",
|
tile_mode="",
|
||||||
|
allow_compose_phase=True,
|
||||||
model=config.DEFAULT_MODEL,
|
model=config.DEFAULT_MODEL,
|
||||||
model_config_path=None,
|
model_config_path=None,
|
||||||
is_intermediate=False,
|
is_intermediate=False,
|
||||||
@ -136,8 +137,10 @@ class ImaginePrompt:
|
|||||||
self.mask_modify_original = mask_modify_original
|
self.mask_modify_original = mask_modify_original
|
||||||
self.outpaint = outpaint
|
self.outpaint = outpaint
|
||||||
self.tile_mode = tile_mode
|
self.tile_mode = tile_mode
|
||||||
|
self.allow_compose_phase = allow_compose_phase
|
||||||
self.model = model
|
self.model = model
|
||||||
self.model_config_path = model_config_path
|
self.model_config_path = model_config_path
|
||||||
|
|
||||||
# we don't want to save intermediate images
|
# we don't want to save intermediate images
|
||||||
self.is_intermediate = is_intermediate
|
self.is_intermediate = is_intermediate
|
||||||
self.collect_progress_latents = collect_progress_latents
|
self.collect_progress_latents = collect_progress_latents
|
||||||
@ -284,7 +287,10 @@ class ImagineResult:
|
|||||||
):
|
):
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from imaginairy.img_utils import torch_img_to_pillow_img
|
from imaginairy.img_utils import (
|
||||||
|
model_latent_to_pillow_img,
|
||||||
|
torch_img_to_pillow_img,
|
||||||
|
)
|
||||||
from imaginairy.utils import get_device, get_hardware_description
|
from imaginairy.utils import get_device, get_hardware_description
|
||||||
|
|
||||||
self.prompt = prompt
|
self.prompt = prompt
|
||||||
@ -305,6 +311,9 @@ class ImagineResult:
|
|||||||
|
|
||||||
for img_type, r_img in result_images.items():
|
for img_type, r_img in result_images.items():
|
||||||
if isinstance(r_img, torch.Tensor):
|
if isinstance(r_img, torch.Tensor):
|
||||||
|
if r_img.shape[1] == 4:
|
||||||
|
r_img = model_latent_to_pillow_img(r_img)
|
||||||
|
else:
|
||||||
r_img = torch_img_to_pillow_img(r_img)
|
r_img = torch_img_to_pillow_img(r_img)
|
||||||
self.images[img_type] = r_img
|
self.images[img_type] = r_img
|
||||||
|
|
||||||
|
@ -142,6 +142,10 @@ class UBlock(layers.ConditionedSequential):
|
|||||||
|
|
||||||
def forward(self, input, cond, skip=None):
|
def forward(self, input, cond, skip=None):
|
||||||
if skip is not None:
|
if skip is not None:
|
||||||
|
if input.shape[-2:] != skip.shape[-2:]:
|
||||||
|
input = nn.functional.interpolate(
|
||||||
|
input, size=skip.shape[-2:], mode="bilinear"
|
||||||
|
)
|
||||||
input = torch.cat([input, skip], dim=1)
|
input = torch.cat([input, skip], dim=1)
|
||||||
return super().forward(input, cond)
|
return super().forward(input, cond)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user