mirror of https://github.com/brycedrennan/imaginAIry
synced 2024-11-05 12:00:15 +00:00

refactor: cleanup image generation code

parent 8a97213622
commit ea1d4baafe
@@ -195,18 +195,20 @@ def _generate_single_image(
     half_mode=None,
     add_caption=False,
 ):
-    import numpy as np
     import torch.nn
     from einops import rearrange, repeat
     from PIL import Image, ImageOps
-    from pytorch_lightning import seed_everything
-    from torch.cuda import OutOfMemoryError

     from imaginairy.enhancers.clip_masking import get_img_mask
     from imaginairy.enhancers.describe_image_blip import generate_caption
     from imaginairy.enhancers.face_restoration_codeformer import enhance_faces
     from imaginairy.enhancers.upscale_realesrgan import upscale_image
-    from imaginairy.img_utils import pillow_fit_image_within, pillow_img_to_torch_image
+    from imaginairy.img_utils import (
+        pillow_fit_image_within,
+        pillow_img_to_torch_image,
+        pillow_mask_to_latent_mask,
+        torch_img_to_pillow_img,
+    )
     from imaginairy.log_utils import (
         ImageLoggingContext,
         log_conditioning,
@@ -214,7 +216,7 @@ def _generate_single_image(
         log_latent,
     )
     from imaginairy.model_manager import get_diffusion_model
-    from imaginairy.modules.midas.utils import AddMiDaS
+    from imaginairy.modules.midas.api import torch_image_to_depth_map
     from imaginairy.outpaint import outpaint_arg_str_parse, prepare_image_for_outpaint
     from imaginairy.safety import create_safety_score
     from imaginairy.samplers import SAMPLER_LOOKUP
@@ -241,7 +243,6 @@ def _generate_single_image(
         half_mode=half_mode,
         for_inpainting=prompt.mask_image or prompt.mask_prompt or prompt.outpaint,
     )
-    has_depth_channel = hasattr(model, "depth_stage_key")
     progress_latents = []

     def latent_logger(latents):
@@ -279,9 +280,14 @@ def _generate_single_image(
     ]
     SamplerCls = SAMPLER_LOOKUP[prompt.sampler_type.lower()]
     sampler = SamplerCls(model)
-    mask = mask_image = mask_image_orig = mask_grayscale = None
+    mask_latent = mask_image = mask_image_orig = mask_grayscale = None
     t_enc = init_latent = init_latent_noised = None
+    starting_image = None
+    denoiser_cls = None

+    c_cat = []
+    c_cat_neutral = None
+    result_images = {}
     if prompt.init_image:
         starting_image = prompt.init_image
         generation_strength = 1 - prompt.init_image_strength
@@ -321,25 +327,12 @@ def _generate_single_image(
             if prompt.mask_mode == ImaginePrompt.MaskMode.REPLACE:
                 mask_image = ImageOps.invert(mask_image)

-            log_img(
-                Image.composite(init_image, mask_image, mask_image),
-                "mask overlay",
-            )
             mask_image_orig = mask_image
-            mask_image = mask_image.resize(
-                (
-                    mask_image.width // downsampling_factor,
-                    mask_image.height // downsampling_factor,
-                ),
-                resample=Image.Resampling.LANCZOS,
-            )
             log_img(mask_image, "latent_mask")
+            mask_latent = pillow_mask_to_latent_mask(
+                mask_image, downsampling_factor=downsampling_factor
+            ).to(get_device())

-            mask = np.array(mask_image)
-            mask = mask.astype(np.float32) / 255.0
-            mask = mask[None, None]
-            mask = torch.from_numpy(mask)
-            mask = mask.to(get_device())
             init_image_t = pillow_img_to_torch_image(init_image)
             init_image_t = init_image_t.to(get_device())
             init_latent = model.get_first_stage_encoding(
@@ -348,7 +341,7 @@ def _generate_single_image(
             shape = init_latent.shape

             log_latent(init_latent, "init_latent")
             # encode (scaled latent)

-            seed_everything(prompt.seed)
+            noise = randn_seeded(seed=prompt.seed, size=init_latent.size())
             noise = noise.to(get_device())
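
Note: randn_seeded replaces the global seed_everything call above. A minimal sketch of the behavior this relies on — assuming randn_seeded lives in imaginairy.utils (its import is outside this diff):

# Hedged sketch: seeded noise generation decoupled from global RNG state.
# Assumes imaginairy.utils.randn_seeded(seed, size), as called in the hunk above.
from imaginairy.utils import randn_seeded

noise_a = randn_seeded(seed=42, size=[1, 4, 64, 64])
noise_b = randn_seeded(seed=42, size=[1, 4, 64, 64])
assert (noise_a == noise_b).all()  # same seed -> identical noise tensor
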
@@ -371,90 +364,45 @@ def _generate_single_image(
                 schedule=schedule,
                 noise=noise,
             )
-    batch_size = 1
-    log_latent(init_latent_noised, "init_latent_noised")
-    batch = {
-        "txt": batch_size * [prompt.prompt_text],
-    }
-    c_cat = []
-    c_cat_neutral = None
-    depth_image_display = None
-    if has_depth_channel and starting_image:
-        midas_model = AddMiDaS()
-        _init_image_d = np.array(starting_image.convert("RGB"))
-        _init_image_d = (
-            torch.from_numpy(_init_image_d).to(dtype=torch.float32) / 127.5 - 1.0
-        )
-        depth_image = midas_model(_init_image_d)
-        depth_image = torch.from_numpy(depth_image[None, ...])
-        batch[model.depth_stage_key] = depth_image.to(device=get_device())
-        _init_image_d = rearrange(_init_image_d, "h w c -> 1 c h w")
-        batch["jpg"] = _init_image_d
-        for ck in model.concat_keys:
-            cc = batch[ck]
-            cc = model.depth_model(cc)
-            depth_min, depth_max = torch.amin(
-                cc, dim=[1, 2, 3], keepdim=True
-            ), torch.amax(cc, dim=[1, 2, 3], keepdim=True)
-            display_depth = (cc - depth_min) / (depth_max - depth_min)
-            depth_image_display = Image.fromarray(
-                (display_depth[0, 0, ...].cpu().numpy() * 255.0).astype(np.uint8)
-            )
-            cc = torch.nn.functional.interpolate(
-                cc,

+    if hasattr(model, "depth_stage_key"):
+        # depth model
+        depth_t = torch_image_to_depth_map(init_image_t)
+        depth_latent = torch.nn.functional.interpolate(
+            depth_t,
             size=shape[2:],
             mode="bicubic",
             align_corners=False,
         )
-            depth_min, depth_max = torch.amin(
-                cc, dim=[1, 2, 3], keepdim=True
-            ), torch.amax(cc, dim=[1, 2, 3], keepdim=True)
-            cc = 2.0 * (cc - depth_min) / (depth_max - depth_min) - 1.0
-            c_cat.append(cc)
-        c_cat = [torch.cat(c_cat, dim=1)]
+        result_images["depth_image"] = depth_t
+        c_cat.append(depth_latent)

-    if mask_image_orig and not has_depth_channel:
-        mask_t = pillow_img_to_torch_image(ImageOps.invert(mask_image_orig)).to(
-            get_device()
-        )
-        inverted_mask = 1 - mask
-        masked_image_t = init_image_t * (mask_t < 0.5)
-        batch.update(
-            {
-                "image": repeat(
-                    init_image_t.to(device=get_device()),
-                    "1 ... -> n ...",
-                    n=batch_size,
-                ),
-                "txt": batch_size * [prompt.prompt_text],
-                "mask": repeat(
-                    inverted_mask.to(device=get_device()),
-                    "1 ... -> n ...",
-                    n=batch_size,
-                ),
-                "masked_image": repeat(
-                    masked_image_t.to(device=get_device()),
-                    "1 ... -> n ...",
-                    n=batch_size,
-                ),
-            }
-        )
+    elif hasattr(model, "masked_image_key"):
+        # inpainting model
+        mask_t = pillow_img_to_torch_image(ImageOps.invert(mask_image_orig)).to(
+            get_device()
+        )
+        inverted_mask = 1 - mask_latent
+        masked_image_t = init_image_t * (mask_t < 0.5)
+        log_img(masked_image_t, "masked_image")

-        for concat_key in getattr(model, "concat_keys", []):
-            cc = batch[concat_key].float()
-            if concat_key != model.masked_image_key:
-                bchw = [batch_size, 4, shape[2], shape[3]]
-                cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
-            else:
-                cc = model.get_first_stage_encoding(model.encode_first_stage(cc))
-            c_cat.append(cc)
+        inverted_mask_latent = torch.nn.functional.interpolate(
+            inverted_mask, size=shape[-2:]
+        )
+        c_cat.append(inverted_mask_latent)
+
+        masked_image_latent = model.get_first_stage_encoding(
+            model.encode_first_stage(masked_image_t)
+        )
+        c_cat.append(masked_image_latent)
+
+    elif model.cond_stage_key == "edit":
+        # pix2pix model
+        c_cat = [model.encode_first_stage(init_image_t).mode()]
+        c_cat_neutral = [torch.zeros_like(init_latent)]
+        denoiser_cls = CFGEditingDenoiser
     if c_cat:
         c_cat = [torch.cat(c_cat, dim=1)]
-    denoiser_cls = None
-    if model.cond_stage_key == "edit":
-        c_cat = [model.encode_first_stage(init_image_t).mode()]
-        c_cat_neutral = [torch.zeros_like(init_latent)]
-        denoiser_cls = CFGEditingDenoiser

     if c_cat_neutral is None:
         c_cat_neutral = c_cat
@@ -467,6 +415,8 @@ def _generate_single_image(
         "c_concat": c_cat_neutral,
         "c_crossattn": [neutral_conditioning],
     }
+    log_latent(init_latent_noised, "init_latent_noised")
+
     with lc.timing("sampling"):
         samples = sampler.sample(
             num_steps=prompt.steps,
@@ -475,105 +425,74 @@ def _generate_single_image(
             neutral_conditioning=neutral_conditioning,
             guidance_scale=prompt.prompt_strength,
             t_start=t_enc,
-            mask=mask,
+            mask=mask_latent,
             orig_latent=init_latent,
             shape=shape,
             batch_size=1,
             denoiser_cls=denoiser_cls,
         )
-        # from torch.nn.functional import interpolate
-        # samples = interpolate(samples, scale_factor=2, mode='nearest')

     with lc.timing("decoding"):
-        try:
-            x_samples = model.decode_first_stage(samples)
-        except OutOfMemoryError:
-            model.cond_stage_model.to("cpu")
-            model.model.to("cpu")
-            x_samples = model.decode_first_stage(samples)
-            model.cond_stage_model.to(get_device())
-            model.model.to(get_device())
+        gen_imgs_t = model.decode_first_stage(samples)
+        gen_img = torch_img_to_pillow_img(gen_imgs_t)

-    x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
+    if mask_image_orig and init_image:
+        mask_final = mask_image_orig.copy()
+        log_img(mask_final, "reconstituting mask")
+        mask_final = ImageOps.invert(mask_final)
+        gen_img = Image.composite(gen_img, init_image, mask_final)
+        log_img(gen_img, "reconstituted image")

-    for x_sample in x_samples:
-        x_sample = x_sample.to(torch.float32)
-        x_sample = 255.0 * rearrange(x_sample.cpu().numpy(), "c h w -> h w c")
-        x_sample_8_orig = x_sample.astype(np.uint8)
-        img = Image.fromarray(x_sample_8_orig)
-        if mask_image_orig and init_image:
-            # mask_final = mask_image_orig.filter(
-            #     ImageFilter.GaussianBlur(radius=3)
-            # )
-            mask_final = mask_image_orig.copy()
-            log_img(mask_final, "reconstituting mask")
-            mask_final = ImageOps.invert(mask_final)
-            img = Image.composite(img, init_image, mask_final)
-            log_img(img, "reconstituted image")
+    upscaled_img = None
+    rebuilt_orig_img = None

-        upscaled_img = None
-        rebuilt_orig_img = None
+    if add_caption:
+        caption = generate_caption(gen_img)
+        logger.info(f"Generated caption: {caption}")

-        if add_caption:
-            caption = generate_caption(img)
-            logger.info(f"Generated caption: {caption}")

-        with lc.timing("safety-filter"):
-            safety_score = create_safety_score(
-                img,
-                safety_mode=IMAGINAIRY_SAFETY_MODE,
-            )
+    with lc.timing("safety-filter"):
+        safety_score = create_safety_score(
+            gen_img,
+            safety_mode=IMAGINAIRY_SAFETY_MODE,
+        )
-        if safety_score.is_filtered:
-            progress_latents.clear()
-        if not safety_score.is_filtered:
-            if prompt.fix_faces:
-                logger.info("Fixing 😊 's in 🖼 using CodeFormer...")
-                with lc.timing("face enhancement"):
-                    img = enhance_faces(img, fidelity=prompt.fix_faces_fidelity)
-            if prompt.upscale:
-                logger.info("Upscaling 🖼 using real-ESRGAN...")
-                with lc.timing("upscaling"):
-                    upscaled_img = upscale_image(img)
+    if safety_score.is_filtered:
+        progress_latents.clear()
+    if not safety_score.is_filtered:
+        if prompt.fix_faces:
+            logger.info("Fixing 😊 's in 🖼 using CodeFormer...")
+            with lc.timing("face enhancement"):
+                gen_img = enhance_faces(gen_img, fidelity=prompt.fix_faces_fidelity)
+        if prompt.upscale:
+            logger.info("Upscaling 🖼 using real-ESRGAN...")
+            with lc.timing("upscaling"):
+                upscaled_img = upscale_image(gen_img)

-        # put the newly generated patch back into the original, full size image
-        if prompt.mask_modify_original and mask_image_orig and starting_image:
-            img_to_add_back_to_original = upscaled_img if upscaled_img else img
-            img_to_add_back_to_original = img_to_add_back_to_original.resize(
-                starting_image.size,
-                resample=Image.Resampling.LANCZOS,
-            )
-
-            mask_for_orig_size = mask_image_orig.resize(
-                starting_image.size,
-                resample=Image.Resampling.LANCZOS,
-            )
-            # mask_for_orig_size = mask_for_orig_size.filter(
-            #     ImageFilter.GaussianBlur(radius=5)
-            # )
-            log_img(mask_for_orig_size, "mask for original image size")
-
-            rebuilt_orig_img = Image.composite(
-                starting_image,
-                img_to_add_back_to_original,
-                mask_for_orig_size,
-            )
-            log_img(rebuilt_orig_img, "reconstituted original")
+    # put the newly generated patch back into the original, full-size image
+    if prompt.mask_modify_original and mask_image_orig and starting_image:
+        img_to_add_back_to_original = upscaled_img if upscaled_img else gen_img
+        rebuilt_orig_img = combine_image(
+            original_img=starting_image,
+            generated_img=img_to_add_back_to_original,
+            mask_img=mask_image_orig,
+        )

-        result = ImagineResult(
-            img=img,
-            prompt=prompt,
-            upscaled_img=upscaled_img,
-            is_nsfw=safety_score.is_nsfw,
-            safety_score=safety_score,
-            modified_original=rebuilt_orig_img,
-            mask_binary=mask_image_orig,
-            mask_grayscale=mask_grayscale,
-            depth_image=depth_image_display,
-            timings=lc.get_timings(),
-            progress_latents=progress_latents.copy(),
-        )
-        _most_recent_result = result
-        logger.info(f"Image Generated. Timings: {result.timings_str()}")
-        return result
+    result = ImagineResult(
+        img=gen_img,
+        prompt=prompt,
+        upscaled_img=upscaled_img,
+        is_nsfw=safety_score.is_nsfw,
+        safety_score=safety_score,
+        modified_original=rebuilt_orig_img,
+        mask_binary=mask_image_orig,
+        mask_grayscale=mask_grayscale,
+        result_images=result_images,
+        timings=lc.get_timings(),
+        progress_latents=progress_latents.copy(),
+    )
+
+    _most_recent_result = result
+    logger.info(f"Image Generated. Timings: {result.timings_str()}")
+    return result
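
Note: the refactored _generate_single_image flow above is reached through the package's public API. A minimal sketch, assuming the top-level imagine/ImaginePrompt interface of this era of the codebase:

# Hedged sketch: drives the refactored generation path end to end.
# Prompt arguments other than the prompt text and seed are assumptions.
from imaginairy import ImaginePrompt, imagine

prompts = [ImaginePrompt("a photo of a bowl of fruit", seed=1, steps=20)]
for result in imagine(prompts):
    result.img.save("fruit_bowl.jpg")  # result.img is the PIL image assembled above
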
@@ -587,3 +506,29 @@ def _prompts_to_embeddings(prompts, model):

 def prompt_normalized(prompt):
     return re.sub(r"[^a-zA-Z0-9.,\[\]-]+", "_", prompt)[:130]
+
+
+def combine_image(original_img, generated_img, mask_img):
+    """Combine the generated image with the original image using the mask image."""
+    from PIL import Image
+
+    from imaginairy.log_utils import log_img
+
+    generated_img = generated_img.resize(
+        original_img.size,
+        resample=Image.Resampling.LANCZOS,
+    )
+
+    mask_for_orig_size = mask_img.resize(
+        original_img.size,
+        resample=Image.Resampling.LANCZOS,
+    )
+    log_img(mask_for_orig_size, "mask for original image size")
+
+    rebuilt_orig_img = Image.composite(
+        original_img,
+        generated_img,
+        mask_for_orig_size,
+    )
+    log_img(rebuilt_orig_img, "reconstituted original")
+    return rebuilt_orig_img
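
Note: a short usage sketch of the new combine_image helper — assuming it is importable from imaginairy.api alongside the function above; the file names are hypothetical:

# Hedged sketch: paste a generated patch back over the original photo.
# PIL.Image.composite keeps the first image where the mask is white (255)
# and the second where it is black (0), so white mask areas keep the original.
from PIL import Image

from imaginairy.api import combine_image  # assumed module path

original = Image.open("original.jpg")
generated = Image.open("generated_patch.jpg")  # any size; resized internally
mask = Image.open("mask.png").convert("L")

rebuilt = combine_image(original_img=original, generated_img=generated, mask_img=mask)
rebuilt.save("rebuilt.jpg")
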
@@ -34,20 +34,28 @@ def pillow_fit_image_within(
     return image


-def pillow_img_to_torch_image(img: PIL.Image.Image):
-    img = img.convert("RGB")
+def pillow_img_to_torch_image(img: PIL.Image.Image, convert="RGB"):
+    if convert:
+        img = img.convert(convert)
     img = np.array(img).astype(np.float32) / 255.0
     img = img[None].transpose(0, 3, 1, 2)
     img = torch.from_numpy(img)
     return 2.0 * img - 1.0


-def torch_img_to_pillow_img(img: torch.Tensor):
-    img = rearrange(img, "b c h w -> b h w c")
-    img = torch.clamp((img + 1.0) / 2.0, min=0.0, max=1.0)
-    img = (255.0 * img).cpu().numpy().astype(np.uint8)
-    img = Image.fromarray(img[0])
-    return img
+def pillow_mask_to_latent_mask(mask_img: PIL.Image.Image, downsampling_factor):
+    mask_img = mask_img.resize(
+        (
+            mask_img.width // downsampling_factor,
+            mask_img.height // downsampling_factor,
+        ),
+        resample=Image.Resampling.LANCZOS,
+    )
+
+    mask = np.array(mask_img).astype(np.float32) / 255.0
+    mask = mask[None, None]
+    mask = torch.from_numpy(mask)
+    return mask


 def pillow_img_to_opencv_img(img: PIL.Image.Image):
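
Note: pillow_mask_to_latent_mask centralizes the mask-to-latent conversion that api.py previously did inline; a minimal shape-contract sketch:

# Hedged sketch: the helper downsizes a PIL mask by the latent downsampling
# factor and returns a (1, 1, H/f, W/f) float tensor scaled to [0, 1].
from PIL import Image

from imaginairy.img_utils import pillow_mask_to_latent_mask

mask_img = Image.new("L", (512, 512), color=255)  # hypothetical all-white mask
mask = pillow_mask_to_latent_mask(mask_img, downsampling_factor=8)
print(mask.shape)  # torch.Size([1, 1, 64, 64])
print(float(mask.min()), float(mask.max()))  # 1.0 1.0 for an all-white mask
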
@@ -57,22 +65,42 @@ def pillow_img_to_opencv_img(img: PIL.Image.Image):
     return open_cv_image


-def model_latents_to_pillow_imgs(latents: torch.Tensor) -> Sequence[PIL.Image.Image]:
+def torch_img_to_pillow_img(img_t: torch.Tensor):
+    if len(img_t.shape) == 3:
+        img_t = img_t.unsqueeze(0)
+    if img_t.shape[0] != 1:
+        raise ValueError("Only batch size 1 supported")
+    if img_t.shape[1] == 1:
+        colorspace = "L"
+    elif img_t.shape[1] == 3:
+        colorspace = "RGB"
+    else:
+        raise ValueError("Unsupported colorspace")
+    img_t = rearrange(img_t, "b c h w -> b h w c")
+    img_t = torch.clamp((img_t + 1.0) / 2.0, min=0.0, max=1.0)
+    img_np = (255.0 * img_t).cpu().numpy().astype(np.uint8)[0]
+    if colorspace == "L":
+        img_np = img_np[:, :, 0]
+    return Image.fromarray(img_np, colorspace)
+
+
+def model_latent_to_pillow_img(latent: torch.Tensor) -> PIL.Image.Image:
     from imaginairy.model_manager import get_current_diffusion_model  # noqa

+    if len(latent.shape) == 3:
+        latent = latent.unsqueeze(0)
+    if latent.shape[0] != 1:
+        raise ValueError("Only batch size 1 supported")
     model = get_current_diffusion_model()
-    latents = model.decode_first_stage(latents)
-    latents = torch.clamp((latents + 1.0) / 2.0, min=0.0, max=1.0)
-    imgs = []
-    for latent in latents:
-        latent = 255.0 * rearrange(latent.cpu().numpy(), "c h w -> h w c")
-        img = Image.fromarray(latent.astype(np.uint8))
-        imgs.append(img)
-    return imgs
+    img_t = model.decode_first_stage(latent)
+    return torch_img_to_pillow_img(img_t)
+
+
+def model_latents_to_pillow_imgs(latents: torch.Tensor) -> Sequence[PIL.Image.Image]:
+    return [model_latent_to_pillow_img(latent) for latent in latents]


 def pillow_img_to_model_latent(model, img, batch_size=1, half=True):
     # init_image = pil_img_to_torch(img, half=half).to(device)
     init_image = pillow_img_to_torch_image(img).to(get_device())
     init_image = repeat(init_image, "1 ... -> b ...", b=batch_size)
     if half:
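
Note: the rewritten torch_img_to_pillow_img now accepts 3-D or 4-D tensors and single-channel input; a round-trip sketch using the two helpers:

# Hedged sketch: PIL -> [-1, 1] tensor -> PIL round trip (batch size 1 only).
from PIL import Image

from imaginairy.img_utils import pillow_img_to_torch_image, torch_img_to_pillow_img

pil_img = Image.new("RGB", (64, 64), color=(255, 0, 0))  # hypothetical input
img_t = pillow_img_to_torch_image(pil_img)  # shape (1, 3, 64, 64), range [-1, 1]
round_tripped = torch_img_to_pillow_img(img_t)
assert round_tripped.size == (64, 64) and round_tripped.mode == "RGB"
# A (1, 1, h, w) tensor would come back as a grayscale "L" image instead.
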
@@ -1,7 +1,9 @@
 # based on https://github.com/isl-org/MiDaS
+from functools import lru_cache

 import cv2
 import torch
+from einops import rearrange
 from torch import nn
 from torchvision.transforms import Compose

@@ -13,6 +15,7 @@ from imaginairy.modules.midas.midas.transforms import (
     PrepareForNet,
     Resize,
 )
+from imaginairy.utils import get_device

 ISL_PATHS = {
     "dpt_large": "midas_models/dpt_large-midas-2f21e586.pt",
@@ -151,6 +154,30 @@ def load_model(model_type):
     return model.eval(), transform


+@lru_cache()
+def load_midas(model_type="dpt_hybrid"):
+    model = MiDaSInference(model_type)
+    model.to(get_device())
+    return model
+
+
+def torch_image_to_depth_map(image_t: torch.Tensor, model_type="dpt_hybrid"):
+    model = load_midas(model_type)
+    transform = load_midas_transform(model_type)
+    image_t = rearrange(image_t, "b c h w -> b h w c")[0]
+    image_np = ((image_t + 1.0) * 0.5).detach().cpu().numpy()
+    image_np = transform({"image": image_np})["image"]
+    image_t = torch.from_numpy(image_np[None, ...])
+    image_t = image_t.to(device=get_device())
+
+    depth_t = model(image_t)
+    depth_min = torch.amin(depth_t, dim=[1, 2, 3], keepdim=True)
+    depth_max = torch.amax(depth_t, dim=[1, 2, 3], keepdim=True)
+
+    depth_t = (depth_t - depth_min) / (depth_max - depth_min)
+    return depth_t
+
+
 class MiDaSInference(nn.Module):
     MODEL_TYPES_TORCH_HUB = ["DPT_Large", "DPT_Hybrid", "MiDaS_small"]
     MODEL_TYPES_ISL = [
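
Note: torch_image_to_depth_map wraps MiDaS behind an lru_cache'd model load; a sketch of feeding it an image tensor (real MiDaS weights are loaded on first call):

# Hedged sketch: input is a (1, 3, H, W) tensor in [-1, 1], the same convention
# pillow_img_to_torch_image produces; output is a 4-D depth tensor scaled to [0, 1].
from PIL import Image

from imaginairy.img_utils import pillow_img_to_torch_image
from imaginairy.modules.midas.api import torch_image_to_depth_map
from imaginairy.utils import get_device

init_image = Image.open("photo.jpg")  # hypothetical input file
image_t = pillow_img_to_torch_image(init_image).to(get_device())
depth_t = torch_image_to_depth_map(image_t)
print(float(depth_t.min()), float(depth_t.max()))  # 0.0 1.0 after normalization
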
@@ -278,10 +278,13 @@ class ImagineResult:
         modified_original=None,
         mask_binary=None,
         mask_grayscale=None,
-        depth_image=None,
+        result_images=None,
         timings=None,
         progress_latents=None,
     ):
+        import torch
+
+        from imaginairy.img_utils import torch_img_to_pillow_img
         from imaginairy.utils import get_device, get_hardware_description

         self.prompt = prompt
@@ -300,8 +303,10 @@ class ImagineResult:
         if mask_grayscale:
             self.images["mask_grayscale"] = mask_grayscale

-        if depth_image is not None:
-            self.images["depth_image"] = depth_image
+        for img_type, r_img in result_images.items():
+            if isinstance(r_img, torch.Tensor):
+                r_img = torch_img_to_pillow_img(r_img)
+            self.images[img_type] = r_img

         self.timings = timings
         self.progress_latents = progress_latents
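
Note: ImagineResult now takes a generic result_images dict in place of the dedicated depth_image argument; tensors are converted once so self.images only ever holds PIL images. The added loop, isolated as a runnable sketch:

# Hedged sketch of the conversion added above, run outside the class.
import torch

from imaginairy.img_utils import torch_img_to_pillow_img

result_images = {"depth_image": torch.rand(1, 1, 64, 64) * 2 - 1}  # stand-in tensor
images = {}
for img_type, r_img in result_images.items():
    if isinstance(r_img, torch.Tensor):
        r_img = torch_img_to_pillow_img(r_img)
    images[img_type] = r_img
print(images["depth_image"].mode)  # "L" for a single-channel tensor
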
@@ -158,7 +158,7 @@ def test_img_to_img_fruit_2_gold(
         "k_euler_a": 18000,
         "k_dpm_adaptive": 13000,
     }
-    threshold = threshold_lookup.get(sampler_type, 11000)
+    threshold = threshold_lookup.get(sampler_type, 14000)

     pillow_fit_image_within(img).save(f"{filename_base_for_orig_outputs}__orig.jpg")
     img_path = f"{filename_base_for_outputs}.png"
@@ -58,9 +58,9 @@ def test_feather_tile_simple(img_ratio, tile_size, overlap_pct):
         tile_size, overlap_pct, (img.size(2), img.size(3))
     )

-    print(
-        f"tile_coords={tile_coords}, tile_size={tile_size}, overlap={overlap}, img.shape={img.shape}"
-    )
+    # print(
+    #     f"tile_coords={tile_coords}, tile_size={tile_size}, overlap={overlap}, img.shape={img.shape}"
+    # )

     rebuilt = rebuild_image(
         tiles, base_img=img, tile_size=tile_size, overlap_percent=overlap_pct
@@ -86,9 +86,9 @@ def test_feather_tile_brute():
         tile_coords, tile_size, overlap = tile_setup(
             tile_size, overlap_percent, (img.size(2), img.size(3))
         )
-        print(
-            f"tile_coords={tile_coords}, tile_size={tile_size}, overlap={overlap}, img.shape={img.shape}"
-        )
+        # print(
+        #     f"tile_coords={tile_coords}, tile_size={tile_size}, overlap={overlap}, img.shape={img.shape}"
+        # )

         rebuilt = rebuild_image(
             tiles, base_img=img, tile_size=tile_size, overlap_percent=overlap_percent
@@ -100,13 +100,11 @@ def test_feather_tile_brute():
             torch_img_to_pillow_img(rebuilt).show()
             torch_img_to_pillow_img((rebuilt - img) * 20).show()
-
             status = "🚫 FAILED"
-
         else:
             status = "✅ PASSED"
-        print(
-            f"{status}: img:{img.shape} tile_size={tile_size} overlap_percent={overlap_percent} diff={diff}"
-        )
+        pass
+        # print(
+        #     f"{status}: img:{img.shape} tile_size={tile_size} overlap_percent={overlap_percent} diff={diff}"
+        # )
         assert diff < 1
-
     for tile_size_pct, overlap_percent, img_ratio, flip_ratio in itertools.product(
@@ -123,9 +121,9 @@ def test_feather_tile_brute():
         if overlap_percent >= 0.5:
             continue

-        print(
-            f"img_ratio={img_ratio}, tile_size_pct={tile_size_pct}, overlap_percent={overlap_percent}, tile_size={tile_size} img.shape={img.shape}"
-        )
+        # print(
+        #     f"img_ratio={img_ratio}, tile_size_pct={tile_size_pct}, overlap_percent={overlap_percent}, tile_size={tile_size} img.shape={img.shape}"
+        # )
         tile_untile(img, tile_size=tile_size, overlap_percent=overlap_percent)
         del img