From ea1d4baafe995342945dfd0298825523e868b5cb Mon Sep 17 00:00:00 2001 From: Bryce Date: Wed, 15 Feb 2023 08:02:36 -0800 Subject: [PATCH] refactor: cleanup image generation code --- imaginairy/api.py | 321 +++++++++++++------------------- imaginairy/img_utils.py | 64 +++++-- imaginairy/modules/midas/api.py | 27 +++ imaginairy/schema.py | 11 +- tests/test_api.py | 2 +- tests/test_feather_tile.py | 28 ++- 6 files changed, 228 insertions(+), 225 deletions(-) diff --git a/imaginairy/api.py b/imaginairy/api.py index b30dee5..05ffca8 100755 --- a/imaginairy/api.py +++ b/imaginairy/api.py @@ -195,18 +195,20 @@ def _generate_single_image( half_mode=None, add_caption=False, ): - import numpy as np import torch.nn - from einops import rearrange, repeat from PIL import Image, ImageOps from pytorch_lightning import seed_everything - from torch.cuda import OutOfMemoryError from imaginairy.enhancers.clip_masking import get_img_mask from imaginairy.enhancers.describe_image_blip import generate_caption from imaginairy.enhancers.face_restoration_codeformer import enhance_faces from imaginairy.enhancers.upscale_realesrgan import upscale_image - from imaginairy.img_utils import pillow_fit_image_within, pillow_img_to_torch_image + from imaginairy.img_utils import ( + pillow_fit_image_within, + pillow_img_to_torch_image, + pillow_mask_to_latent_mask, + torch_img_to_pillow_img, + ) from imaginairy.log_utils import ( ImageLoggingContext, log_conditioning, @@ -214,7 +216,7 @@ def _generate_single_image( log_latent, ) from imaginairy.model_manager import get_diffusion_model - from imaginairy.modules.midas.utils import AddMiDaS + from imaginairy.modules.midas.api import torch_image_to_depth_map from imaginairy.outpaint import outpaint_arg_str_parse, prepare_image_for_outpaint from imaginairy.safety import create_safety_score from imaginairy.samplers import SAMPLER_LOOKUP @@ -241,7 +243,6 @@ def _generate_single_image( half_mode=half_mode, for_inpainting=prompt.mask_image or prompt.mask_prompt or prompt.outpaint, ) - has_depth_channel = hasattr(model, "depth_stage_key") progress_latents = [] def latent_logger(latents): @@ -279,9 +280,14 @@ def _generate_single_image( ] SamplerCls = SAMPLER_LOOKUP[prompt.sampler_type.lower()] sampler = SamplerCls(model) - mask = mask_image = mask_image_orig = mask_grayscale = None + mask_latent = mask_image = mask_image_orig = mask_grayscale = None t_enc = init_latent = init_latent_noised = None starting_image = None + denoiser_cls = None + + c_cat = [] + c_cat_neutral = None + result_images = {} if prompt.init_image: starting_image = prompt.init_image generation_strength = 1 - prompt.init_image_strength @@ -321,25 +327,12 @@ def _generate_single_image( if prompt.mask_mode == ImaginePrompt.MaskMode.REPLACE: mask_image = ImageOps.invert(mask_image) - log_img( - Image.composite(init_image, mask_image, mask_image), - "mask overlay", - ) mask_image_orig = mask_image - mask_image = mask_image.resize( - ( - mask_image.width // downsampling_factor, - mask_image.height // downsampling_factor, - ), - resample=Image.Resampling.LANCZOS, - ) log_img(mask_image, "latent_mask") + mask_latent = pillow_mask_to_latent_mask( + mask_image, downsampling_factor=downsampling_factor + ).to(get_device()) - mask = np.array(mask_image) - mask = mask.astype(np.float32) / 255.0 - mask = mask[None, None] - mask = torch.from_numpy(mask) - mask = mask.to(get_device()) init_image_t = pillow_img_to_torch_image(init_image) init_image_t = init_image_t.to(get_device()) init_latent = model.get_first_stage_encoding( @@ -348,7 +341,7 @@ def _generate_single_image( shape = init_latent.shape log_latent(init_latent, "init_latent") - # encode (scaled latent) + seed_everything(prompt.seed) noise = randn_seeded(seed=prompt.seed, size=init_latent.size()) noise = noise.to(get_device()) @@ -371,90 +364,45 @@ def _generate_single_image( schedule=schedule, noise=noise, ) - batch_size = 1 - log_latent(init_latent_noised, "init_latent_noised") - batch = { - "txt": batch_size * [prompt.prompt_text], - } - c_cat = [] - c_cat_neutral = None - depth_image_display = None - if has_depth_channel and starting_image: - midas_model = AddMiDaS() - _init_image_d = np.array(starting_image.convert("RGB")) - _init_image_d = ( - torch.from_numpy(_init_image_d).to(dtype=torch.float32) / 127.5 - 1.0 - ) - depth_image = midas_model(_init_image_d) - depth_image = torch.from_numpy(depth_image[None, ...]) - batch[model.depth_stage_key] = depth_image.to(device=get_device()) - _init_image_d = rearrange(_init_image_d, "h w c -> 1 c h w") - batch["jpg"] = _init_image_d - for ck in model.concat_keys: - cc = batch[ck] - cc = model.depth_model(cc) - depth_min, depth_max = torch.amin( - cc, dim=[1, 2, 3], keepdim=True - ), torch.amax(cc, dim=[1, 2, 3], keepdim=True) - display_depth = (cc - depth_min) / (depth_max - depth_min) - depth_image_display = Image.fromarray( - (display_depth[0, 0, ...].cpu().numpy() * 255.0).astype(np.uint8) - ) - cc = torch.nn.functional.interpolate( - cc, + + if hasattr(model, "depth_stage_key"): + # depth model + depth_t = torch_image_to_depth_map(init_image_t) + depth_latent = torch.nn.functional.interpolate( + depth_t, size=shape[2:], mode="bicubic", align_corners=False, ) - depth_min, depth_max = torch.amin( - cc, dim=[1, 2, 3], keepdim=True - ), torch.amax(cc, dim=[1, 2, 3], keepdim=True) - cc = 2.0 * (cc - depth_min) / (depth_max - depth_min) - 1.0 - c_cat.append(cc) - c_cat = [torch.cat(c_cat, dim=1)] + result_images["depth_image"] = depth_t + c_cat.append(depth_latent) - if mask_image_orig and not has_depth_channel: - mask_t = pillow_img_to_torch_image(ImageOps.invert(mask_image_orig)).to( - get_device() - ) - inverted_mask = 1 - mask - masked_image_t = init_image_t * (mask_t < 0.5) - batch.update( - { - "image": repeat( - init_image_t.to(device=get_device()), - "1 ... -> n ...", - n=batch_size, - ), - "txt": batch_size * [prompt.prompt_text], - "mask": repeat( - inverted_mask.to(device=get_device()), - "1 ... -> n ...", - n=batch_size, - ), - "masked_image": repeat( - masked_image_t.to(device=get_device()), - "1 ... -> n ...", - n=batch_size, - ), - } - ) + elif hasattr(model, "masked_image_key"): + # inpainting model + mask_t = pillow_img_to_torch_image(ImageOps.invert(mask_image_orig)).to( + get_device() + ) + inverted_mask = 1 - mask_latent + masked_image_t = init_image_t * (mask_t < 0.5) + log_img(masked_image_t, "masked_image") - for concat_key in getattr(model, "concat_keys", []): - cc = batch[concat_key].float() - if concat_key != model.masked_image_key: - bchw = [batch_size, 4, shape[2], shape[3]] - cc = torch.nn.functional.interpolate(cc, size=bchw[-2:]) - else: - cc = model.get_first_stage_encoding(model.encode_first_stage(cc)) - c_cat.append(cc) + inverted_mask_latent = torch.nn.functional.interpolate( + inverted_mask, size=shape[-2:] + ) + c_cat.append(inverted_mask_latent) + + masked_image_latent = model.get_first_stage_encoding( + model.encode_first_stage(masked_image_t) + ) + c_cat.append(masked_image_latent) + + elif model.cond_stage_key == "edit": + # pix2pix model + c_cat = [model.encode_first_stage(init_image_t).mode()] + c_cat_neutral = [torch.zeros_like(init_latent)] + denoiser_cls = CFGEditingDenoiser if c_cat: c_cat = [torch.cat(c_cat, dim=1)] - denoiser_cls = None - if model.cond_stage_key == "edit": - c_cat = [model.encode_first_stage(init_image_t).mode()] - c_cat_neutral = [torch.zeros_like(init_latent)] - denoiser_cls = CFGEditingDenoiser if c_cat_neutral is None: c_cat_neutral = c_cat @@ -467,6 +415,8 @@ def _generate_single_image( "c_concat": c_cat_neutral, "c_crossattn": [neutral_conditioning], } + log_latent(init_latent_noised, "init_latent_noised") + with lc.timing("sampling"): samples = sampler.sample( num_steps=prompt.steps, @@ -475,105 +425,74 @@ def _generate_single_image( neutral_conditioning=neutral_conditioning, guidance_scale=prompt.prompt_strength, t_start=t_enc, - mask=mask, + mask=mask_latent, orig_latent=init_latent, shape=shape, batch_size=1, denoiser_cls=denoiser_cls, ) - # from torch.nn.functional import interpolate - # samples = interpolate(samples, scale_factor=2, mode='nearest') + with lc.timing("decoding"): - try: - x_samples = model.decode_first_stage(samples) - except OutOfMemoryError: - model.cond_stage_model.to("cpu") - model.model.to("cpu") - x_samples = model.decode_first_stage(samples) - model.cond_stage_model.to(get_device()) - model.model.to(get_device()) + gen_imgs_t = model.decode_first_stage(samples) + gen_img = torch_img_to_pillow_img(gen_imgs_t) - x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0) + if mask_image_orig and init_image: + mask_final = mask_image_orig.copy() + log_img(mask_final, "reconstituting mask") + mask_final = ImageOps.invert(mask_final) + gen_img = Image.composite(gen_img, init_image, mask_final) + log_img(gen_img, "reconstituted image") - for x_sample in x_samples: - x_sample = x_sample.to(torch.float32) - x_sample = 255.0 * rearrange(x_sample.cpu().numpy(), "c h w -> h w c") - x_sample_8_orig = x_sample.astype(np.uint8) - img = Image.fromarray(x_sample_8_orig) - if mask_image_orig and init_image: - # mask_final = mask_image_orig.filter( - # ImageFilter.GaussianBlur(radius=3) - # ) - mask_final = mask_image_orig.copy() - log_img(mask_final, "reconstituting mask") - mask_final = ImageOps.invert(mask_final) - img = Image.composite(img, init_image, mask_final) - log_img(img, "reconstituted image") + upscaled_img = None + rebuilt_orig_img = None - upscaled_img = None - rebuilt_orig_img = None + if add_caption: + caption = generate_caption(gen_img) + logger.info(f"Generated caption: {caption}") - if add_caption: - caption = generate_caption(img) - logger.info(f"Generated caption: {caption}") - - with lc.timing("safety-filter"): - safety_score = create_safety_score( - img, - safety_mode=IMAGINAIRY_SAFETY_MODE, - ) - if safety_score.is_filtered: - progress_latents.clear() - if not safety_score.is_filtered: - if prompt.fix_faces: - logger.info("Fixing 😊 's in 🖼 using CodeFormer...") - with lc.timing("face enhancement"): - img = enhance_faces(img, fidelity=prompt.fix_faces_fidelity) - if prompt.upscale: - logger.info("Upscaling 🖼 using real-ESRGAN...") - with lc.timing("upscaling"): - upscaled_img = upscale_image(img) - - # put the newly generated patch back into the original, full size image - if prompt.mask_modify_original and mask_image_orig and starting_image: - img_to_add_back_to_original = upscaled_img if upscaled_img else img - img_to_add_back_to_original = img_to_add_back_to_original.resize( - starting_image.size, - resample=Image.Resampling.LANCZOS, - ) - - mask_for_orig_size = mask_image_orig.resize( - starting_image.size, - resample=Image.Resampling.LANCZOS, - ) - # mask_for_orig_size = mask_for_orig_size.filter( - # ImageFilter.GaussianBlur(radius=5) - # ) - log_img(mask_for_orig_size, "mask for original image size") - - rebuilt_orig_img = Image.composite( - starting_image, - img_to_add_back_to_original, - mask_for_orig_size, - ) - log_img(rebuilt_orig_img, "reconstituted original") - - result = ImagineResult( - img=img, - prompt=prompt, - upscaled_img=upscaled_img, - is_nsfw=safety_score.is_nsfw, - safety_score=safety_score, - modified_original=rebuilt_orig_img, - mask_binary=mask_image_orig, - mask_grayscale=mask_grayscale, - depth_image=depth_image_display, - timings=lc.get_timings(), - progress_latents=progress_latents.copy(), + with lc.timing("safety-filter"): + safety_score = create_safety_score( + gen_img, + safety_mode=IMAGINAIRY_SAFETY_MODE, ) - _most_recent_result = result - logger.info(f"Image Generated. Timings: {result.timings_str()}") - return result + if safety_score.is_filtered: + progress_latents.clear() + if not safety_score.is_filtered: + if prompt.fix_faces: + logger.info("Fixing 😊 's in 🖼 using CodeFormer...") + with lc.timing("face enhancement"): + gen_img = enhance_faces(gen_img, fidelity=prompt.fix_faces_fidelity) + if prompt.upscale: + logger.info("Upscaling 🖼 using real-ESRGAN...") + with lc.timing("upscaling"): + upscaled_img = upscale_image(gen_img) + + # put the newly generated patch back into the original, full-size image + if prompt.mask_modify_original and mask_image_orig and starting_image: + img_to_add_back_to_original = upscaled_img if upscaled_img else gen_img + rebuilt_orig_img = combine_image( + original_img=starting_image, + generated_img=img_to_add_back_to_original, + mask_img=mask_image_orig, + ) + + result = ImagineResult( + img=gen_img, + prompt=prompt, + upscaled_img=upscaled_img, + is_nsfw=safety_score.is_nsfw, + safety_score=safety_score, + modified_original=rebuilt_orig_img, + mask_binary=mask_image_orig, + mask_grayscale=mask_grayscale, + result_images=result_images, + timings=lc.get_timings(), + progress_latents=progress_latents.copy(), + ) + + _most_recent_result = result + logger.info(f"Image Generated. Timings: {result.timings_str()}") + return result def _prompts_to_embeddings(prompts, model): @@ -587,3 +506,29 @@ def _prompts_to_embeddings(prompts, model): def prompt_normalized(prompt): return re.sub(r"[^a-zA-Z0-9.,\[\]-]+", "_", prompt)[:130] + + +def combine_image(original_img, generated_img, mask_img): + """Combine the generated image with the original image using the mask image.""" + from PIL import Image + + from imaginairy.log_utils import log_img + + generated_img = generated_img.resize( + original_img.size, + resample=Image.Resampling.LANCZOS, + ) + + mask_for_orig_size = mask_img.resize( + original_img.size, + resample=Image.Resampling.LANCZOS, + ) + log_img(mask_for_orig_size, "mask for original image size") + + rebuilt_orig_img = Image.composite( + original_img, + generated_img, + mask_for_orig_size, + ) + log_img(rebuilt_orig_img, "reconstituted original") + return rebuilt_orig_img diff --git a/imaginairy/img_utils.py b/imaginairy/img_utils.py index 654932e..af2c6d2 100644 --- a/imaginairy/img_utils.py +++ b/imaginairy/img_utils.py @@ -34,20 +34,28 @@ def pillow_fit_image_within( return image -def pillow_img_to_torch_image(img: PIL.Image.Image): - img = img.convert("RGB") +def pillow_img_to_torch_image(img: PIL.Image.Image, convert="RGB"): + if convert: + img = img.convert(convert) img = np.array(img).astype(np.float32) / 255.0 img = img[None].transpose(0, 3, 1, 2) img = torch.from_numpy(img) return 2.0 * img - 1.0 -def torch_img_to_pillow_img(img: torch.Tensor): - img = rearrange(img, "b c h w -> b h w c") - img = torch.clamp((img + 1.0) / 2.0, min=0.0, max=1.0) - img = (255.0 * img).cpu().numpy().astype(np.uint8) - img = Image.fromarray(img[0]) - return img +def pillow_mask_to_latent_mask(mask_img: PIL.Image.Image, downsampling_factor): + mask_img = mask_img.resize( + ( + mask_img.width // downsampling_factor, + mask_img.height // downsampling_factor, + ), + resample=Image.Resampling.LANCZOS, + ) + + mask = np.array(mask_img).astype(np.float32) / 255.0 + mask = mask[None, None] + mask = torch.from_numpy(mask) + return mask def pillow_img_to_opencv_img(img: PIL.Image.Image): @@ -57,22 +65,42 @@ def pillow_img_to_opencv_img(img: PIL.Image.Image): return open_cv_image -def model_latents_to_pillow_imgs(latents: torch.Tensor) -> Sequence[PIL.Image.Image]: +def torch_img_to_pillow_img(img_t: torch.Tensor): + if len(img_t.shape) == 3: + img_t = img_t.unsqueeze(0) + if img_t.shape[0] != 1: + raise ValueError("Only batch size 1 supported") + if img_t.shape[1] == 1: + colorspace = "L" + elif img_t.shape[1] == 3: + colorspace = "RGB" + else: + raise ValueError("Unsupported colorspace") + img_t = rearrange(img_t, "b c h w -> b h w c") + img_t = torch.clamp((img_t + 1.0) / 2.0, min=0.0, max=1.0) + img_np = (255.0 * img_t).cpu().numpy().astype(np.uint8)[0] + if colorspace == "L": + img_np = img_np[:, :, 0] + return Image.fromarray(img_np, colorspace) + + +def model_latent_to_pillow_img(latent: torch.Tensor) -> PIL.Image.Image: from imaginairy.model_manager import get_current_diffusion_model # noqa + if len(latent.shape) == 3: + latent = latent.unsqueeze(0) + if latent.shape[0] != 1: + raise ValueError("Only batch size 1 supported") model = get_current_diffusion_model() - latents = model.decode_first_stage(latents) - latents = torch.clamp((latents + 1.0) / 2.0, min=0.0, max=1.0) - imgs = [] - for latent in latents: - latent = 255.0 * rearrange(latent.cpu().numpy(), "c h w -> h w c") - img = Image.fromarray(latent.astype(np.uint8)) - imgs.append(img) - return imgs + img_t = model.decode_first_stage(latent) + return torch_img_to_pillow_img(img_t) + + +def model_latents_to_pillow_imgs(latents: torch.Tensor) -> Sequence[PIL.Image.Image]: + return [model_latent_to_pillow_img(latent) for latent in latents] def pillow_img_to_model_latent(model, img, batch_size=1, half=True): - # init_image = pil_img_to_torch(img, half=half).to(device) init_image = pillow_img_to_torch_image(img).to(get_device()) init_image = repeat(init_image, "1 ... -> b ...", b=batch_size) if half: diff --git a/imaginairy/modules/midas/api.py b/imaginairy/modules/midas/api.py index a90e251..138c852 100644 --- a/imaginairy/modules/midas/api.py +++ b/imaginairy/modules/midas/api.py @@ -1,7 +1,9 @@ # based on https://github.com/isl-org/MiDaS +from functools import lru_cache import cv2 import torch +from einops import rearrange from torch import nn from torchvision.transforms import Compose @@ -13,6 +15,7 @@ from imaginairy.modules.midas.midas.transforms import ( PrepareForNet, Resize, ) +from imaginairy.utils import get_device ISL_PATHS = { "dpt_large": "midas_models/dpt_large-midas-2f21e586.pt", @@ -151,6 +154,30 @@ def load_model(model_type): return model.eval(), transform +@lru_cache() +def load_midas(model_type="dpt_hybrid"): + model = MiDaSInference(model_type) + model.to(get_device()) + return model + + +def torch_image_to_depth_map(image_t: torch.Tensor, model_type="dpt_hybrid"): + model = load_midas(model_type) + transform = load_midas_transform(model_type) + image_t = rearrange(image_t, "b c h w -> b h w c")[0] + image_np = ((image_t + 1.0) * 0.5).detach().cpu().numpy() + image_np = transform({"image": image_np})["image"] + image_t = torch.from_numpy(image_np[None, ...]) + image_t = image_t.to(device=get_device()) + + depth_t = model(image_t) + depth_min = torch.amin(depth_t, dim=[1, 2, 3], keepdim=True) + depth_max = torch.amax(depth_t, dim=[1, 2, 3], keepdim=True) + + depth_t = (depth_t - depth_min) / (depth_max - depth_min) + return depth_t + + class MiDaSInference(nn.Module): MODEL_TYPES_TORCH_HUB = ["DPT_Large", "DPT_Hybrid", "MiDaS_small"] MODEL_TYPES_ISL = [ diff --git a/imaginairy/schema.py b/imaginairy/schema.py index 3fc94c0..879baf1 100644 --- a/imaginairy/schema.py +++ b/imaginairy/schema.py @@ -278,10 +278,13 @@ class ImagineResult: modified_original=None, mask_binary=None, mask_grayscale=None, - depth_image=None, + result_images=None, timings=None, progress_latents=None, ): + import torch + + from imaginairy.img_utils import torch_img_to_pillow_img from imaginairy.utils import get_device, get_hardware_description self.prompt = prompt @@ -300,8 +303,10 @@ class ImagineResult: if mask_grayscale: self.images["mask_grayscale"] = mask_grayscale - if depth_image is not None: - self.images["depth_image"] = depth_image + for img_type, r_img in result_images.items(): + if isinstance(r_img, torch.Tensor): + r_img = torch_img_to_pillow_img(r_img) + self.images[img_type] = r_img self.timings = timings self.progress_latents = progress_latents diff --git a/tests/test_api.py b/tests/test_api.py index 390c7f9..1b91e8c 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -158,7 +158,7 @@ def test_img_to_img_fruit_2_gold( "k_euler_a": 18000, "k_dpm_adaptive": 13000, } - threshold = threshold_lookup.get(sampler_type, 11000) + threshold = threshold_lookup.get(sampler_type, 14000) pillow_fit_image_within(img).save(f"{filename_base_for_orig_outputs}__orig.jpg") img_path = f"{filename_base_for_outputs}.png" diff --git a/tests/test_feather_tile.py b/tests/test_feather_tile.py index 90b2cb8..7dbb24e 100644 --- a/tests/test_feather_tile.py +++ b/tests/test_feather_tile.py @@ -58,9 +58,9 @@ def test_feather_tile_simple(img_ratio, tile_size, overlap_pct): tile_size, overlap_pct, (img.size(2), img.size(3)) ) - print( - f"tile_coords={tile_coords}, tile_size={tile_size}, overlap={overlap}, img.shape={img.shape}" - ) + # print( + # f"tile_coords={tile_coords}, tile_size={tile_size}, overlap={overlap}, img.shape={img.shape}" + # ) rebuilt = rebuild_image( tiles, base_img=img, tile_size=tile_size, overlap_percent=overlap_pct @@ -86,9 +86,9 @@ def test_feather_tile_brute(): tile_coords, tile_size, overlap = tile_setup( tile_size, overlap_percent, (img.size(2), img.size(3)) ) - print( - f"tile_coords={tile_coords}, tile_size={tile_size}, overlap={overlap}, img.shape={img.shape}" - ) + # print( + # f"tile_coords={tile_coords}, tile_size={tile_size}, overlap={overlap}, img.shape={img.shape}" + # ) rebuilt = rebuild_image( tiles, base_img=img, tile_size=tile_size, overlap_percent=overlap_percent @@ -100,13 +100,11 @@ def test_feather_tile_brute(): torch_img_to_pillow_img(rebuilt).show() torch_img_to_pillow_img((rebuilt - img) * 20).show() - status = "🚫 FAILED" - else: - status = "✅ PASSED" - print( - f"{status}: img:{img.shape} tile_size={tile_size} overlap_percent={overlap_percent} diff={diff}" - ) + pass + # print( + # f"{status}: img:{img.shape} tile_size={tile_size} overlap_percent={overlap_percent} diff={diff}" + # ) assert diff < 1 for tile_size_pct, overlap_percent, img_ratio, flip_ratio in itertools.product( @@ -123,9 +121,9 @@ def test_feather_tile_brute(): if overlap_percent >= 0.5: continue - print( - f"img_ratio={img_ratio}, tile_size_pct={tile_size_pct}, overlap_percent={overlap_percent}, tile_size={tile_size} img.shape={img.shape}" - ) + # print( + # f"img_ratio={img_ratio}, tile_size_pct={tile_size_pct}, overlap_percent={overlap_percent}, tile_size={tile_size} img.shape={img.shape}" + # ) tile_untile(img, tile_size=tile_size, overlap_percent=overlap_percent) del img