You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
imaginAIry/imaginairy/enhancers/clip_masking.py

116 lines
3.4 KiB
Python

from functools import lru_cache
from typing import Optional, Sequence
import cv2
import numpy as np
import PIL.Image
import torch
from torchvision import transforms
from imaginairy.img_utils import pillow_fit_image_within
from imaginairy.log_utils import log_img
from imaginairy.vendored.clipseg import CLIPDensePredT
# Upstream CLIPSeg "rd64-uni" checkpoint URL.
# NOTE(review): not referenced in this file's visible code — clip_mask_model()
# loads the packaged "rd64-uni-refined.pth" instead, which is a different file
# from the "rd64-uni.pth" named here. May be used elsewhere or be stale; confirm
# before removing.
weights_url = (
    "https://github.com/timojl/clipseg/raw/master/weights/"
    "rd64-uni.pth"
)
@lru_cache
def clip_mask_model():
    """Build the CLIPSeg mask model once and return the memoized instance.

    Constructs a ``CLIPDensePredT`` (``ViT-B/16`` version, ``reduce_dim=64``,
    ``complex_trans_conv=True``), switches it to eval mode, then loads the
    packaged ``rd64-uni-refined.pth`` checkpoint mapped onto the CPU.

    The state dict is applied with ``strict=False``, so missing or unexpected
    keys are tolerated. Presumably this is because the checkpoint omits the
    CLIP backbone weights (TODO confirm).
    """
    from imaginairy.paths import PKG_ROOT

    mask_model = CLIPDensePredT(
        version="ViT-B/16",
        reduce_dim=64,
        complex_trans_conv=True,
    )
    mask_model.eval()

    weights_path = f"{PKG_ROOT}/vendored/clipseg/rd64-uni-refined.pth"
    cpu = torch.device("cpu")
    state_dict = torch.load(weights_path, map_location=cpu)
    mask_model.load_state_dict(state_dict, strict=False)
    return mask_model
def get_img_mask(
    img: PIL.Image.Image,
    mask_description_statement: str,
    threshold: Optional[float] = None,
):
    """Build a mask for the regions of ``img`` described by a mask statement.

    The statement is parsed with ``MASK_PROMPT``. Every text description it
    mentions is scored with CLIPSeg via ``get_img_masks``, and the resulting
    per-description masks are combined by the parsed expression's
    ``apply_masks``.

    Args:
        img: Image to mask.
        mask_description_statement: Mask prompt understood by ``MASK_PROMPT``.
        threshold: When not ``None``, the combined mask is binarized at this
            value before it is grown: pixels ``< threshold`` become 0 and
            pixels ``>= threshold`` become 1. Trial and error shows 0.5 gives
            the best "shape".

    Returns:
        ``(mask_img, mask_img_g)``. ``mask_img`` is the (optionally
        thresholded) mask dilated twice with a 3x3 kernel. ``mask_img_g`` is
        the combined mask before thresholding and dilation. Both are PIL
        images resized (Lanczos) to ``img``'s original size.
    """
    from imaginairy.enhancers.bool_masker import MASK_PROMPT

    parsed = MASK_PROMPT.parseString(mask_description_statement)
    parsed_mask = parsed[0][0]
    descriptions = list(parsed_mask.gather_text_descriptions())

    orig_size = img.size
    # get_img_masks resizes its input to 352x352 anyway; work on a small copy
    img = pillow_fit_image_within(img, max_height=352, max_width=352)
    mask_cache = get_img_masks(img, descriptions)
    mask = parsed_mask.apply_masks(mask_cache)
    log_img(mask, "combined mask")

    # Snapshot before thresholding mutates ``mask`` in place.
    mask_g = mask.clone()

    if threshold is not None:
        mask[mask < threshold] = 0
        mask[mask >= threshold] = 1
        log_img(mask, f"mask threshold {threshold}")

    mask = _grow_mask(mask, iterations=2)
    log_img(mask, "mask after dilation")

    return _mask_to_pil(mask, orig_size), _mask_to_pil(mask_g, orig_size)


def _grow_mask(mask, iterations=2):
    """Dilate a mask tensor ``iterations`` times with a 3x3 square kernel.

    Assumes ``mask`` is a single-channel (H, W) tensor — TODO confirm against
    ``apply_masks``. Returns a new float32 CPU tensor.
    """
    kernel = np.ones((3, 3), np.uint8)
    mask_np = mask.to(torch.float32).cpu().numpy()
    # grow the mask area to make sure we've masked the thing we care about
    for _ in range(iterations):
        mask_np = cv2.dilate(mask_np, kernel)
    # todo: add an outer blur (not gaussian)
    return torch.from_numpy(mask_np)


def _mask_to_pil(mask, size):
    """Convert a mask tensor to a PIL image resized to ``size`` (width, height) with Lanczos."""
    return transforms.ToPILImage()(mask).resize(
        size, resample=PIL.Image.Resampling.LANCZOS
    )
def get_img_masks(img, mask_descriptions: Sequence[str]):
    """Score ``img`` against each text description with CLIPSeg.

    Preprocessing happens first. ``img`` is converted to a tensor and
    normalized with the ImageNet mean/std. It is then resized to 352x352 and
    repeated once per description, so all descriptions go through a single
    model call.

    The model's first output is resized back to ``img``'s (height, width) and
    passed through a sigmoid. Each resulting mask is logged and stored under
    its description. A duplicated description keeps the mask of its last
    occurrence.

    Returns:
        dict mapping each description to its sigmoid mask tensor.
    """
    width, height = img.size
    log_img(img, "image for masking")

    preprocess = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            transforms.Resize((352, 352)),
        ]
    )
    single = preprocess(img).unsqueeze(0)
    batch = single.repeat(len(mask_descriptions), 1, 1, 1)

    with torch.no_grad():
        logits = clip_mask_model()(batch, mask_descriptions)[0]
    logits = transforms.Resize((height, width))(logits)

    masks_by_description = {}
    for description, logit in zip(mask_descriptions, logits):
        mask = torch.sigmoid(logit[0])
        log_img(mask, f"clip mask: {description}")
        masks_by_description[description] = mask
    return masks_by_description
def img_mask_to_bounding_box(mask_img: "PIL.Image.Image"):
    """Return the bounding box enclosing every non-zero pixel of a mask.

    The mask is converted to a ``uint8`` array first, so fractional values
    below 1 (e.g. from a float ``"F"`` mask) count as background.

    Args:
        mask_img: Single-channel mask image (anything ``np.array`` accepts).

    Returns:
        ``(x0, y0, x1, y1)`` as Python ints, with ``x1``/``y1`` exclusive, or
        ``None`` when the mask contains no non-zero pixels.

    Raises:
        ValueError: if the mask is not 2-D (e.g. an RGB image).
    """
    mask_np = np.array(mask_img).astype(np.uint8)
    if mask_np.ndim != 2:
        raise ValueError(
            f"expected a single-channel mask, got array of shape {mask_np.shape}"
        )
    # Box all foreground regions together. A mask can contain several disjoint
    # regions, and boxing a single contour would miss the others.
    ys, xs = np.nonzero(mask_np)
    if xs.size == 0:
        return None
    return int(xs.min()), int(ys.min()), int(xs.max()) + 1, int(ys.max()) + 1