mirror of
https://github.com/brycedrennan/imaginAIry
synced 2024-10-31 03:20:40 +00:00
1c986d8644
- disable lint fixer that updates to newer syntax
105 lines
3.1 KiB
Python
105 lines
3.1 KiB
Python
from functools import lru_cache
|
|
from typing import Optional, Sequence
|
|
|
|
import cv2
|
|
import numpy as np
|
|
import PIL.Image
|
|
import torch
|
|
from torchvision import transforms
|
|
|
|
from imaginairy.img_utils import pillow_fit_image_within
|
|
from imaginairy.log_utils import log_img
|
|
from imaginairy.vendored.clipseg import CLIPDensePredT
|
|
|
|
# URL of the "rd64-uni.pth" CLIPSeg weights file in the timojl/clipseg repository.
# NOTE(review): nothing in the visible code reads this constant;
# clip_mask_model() loads the packaged "rd64-uni-refined.pth" file from
# PKG_ROOT instead -- confirm whether this is still needed.
weights_url = "https://github.com/timojl/clipseg/raw/master/weights/rd64-uni.pth"
|
|
|
|
|
|
@lru_cache()
def clip_mask_model():
    """Build the CLIPSeg predictor and load its packaged checkpoint.

    The function is memoized with ``lru_cache``: the model is constructed and
    the checkpoint is read only on the first call, and every later call
    returns that same instance.

    Returns:
        A ``CLIPDensePredT`` model switched to eval mode, with the checkpoint
        tensors mapped to the CPU when loaded.
    """
    # imported at call time rather than at module import
    from imaginairy.paths import PKG_ROOT  # noqa

    predictor = CLIPDensePredT(
        version="ViT-B/16", reduce_dim=64, complex_trans_conv=True
    )
    predictor.eval()

    checkpoint_path = f"{PKG_ROOT}/vendored/clipseg/rd64-uni-refined.pth"
    state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
    # strict=False: do not fail when checkpoint keys and model keys differ
    predictor.load_state_dict(state, strict=False)
    return predictor
|
|
|
|
|
|
def get_img_mask(
    img: PIL.Image.Image,
    mask_description_statement: str,
    threshold: Optional[float] = None,
):
    """Create mask images for ``img`` from a textual mask statement.

    The statement is parsed with ``MASK_PROMPT``. Each text description the
    parsed expression references is scored by ``get_img_masks`` on a copy of
    the image fitted within 352x352, and the expression then combines those
    per-description masks into one.

    Args:
        img: Image to build the mask for.
        mask_description_statement: Mask expression parsed by ``MASK_PROMPT``.
        threshold: When not ``None`` the combined mask is binarized.
            NOTE(review): only ``None``-ness is checked; the cutoff is fixed
            at 0.5 no matter which value is passed.

    Returns:
        A ``(mask_img, mask_img_g)`` tuple of PIL images, both resized to the
        size of the input image: the (optionally binarized) dilated mask, and
        the combined mask as it was before thresholding and dilation.
    """
    from imaginairy.enhancers.bool_masker import MASK_PROMPT  # noqa

    def _to_full_size_image(tensor):
        # convert a mask tensor to a PIL image at the caller's resolution
        return transforms.ToPILImage()(tensor).resize(
            full_size, resample=PIL.Image.Resampling.LANCZOS
        )

    mask_expr = MASK_PROMPT.parseString(mask_description_statement)[0][0]
    text_descriptions = list(mask_expr.gather_text_descriptions())

    full_size = img.size
    img = pillow_fit_image_within(img, max_height=352, max_width=352)
    masks_by_description = get_img_masks(img, text_descriptions)
    mask = mask_expr.apply_masks(masks_by_description)
    log_img(mask, "combined mask")

    # snapshot taken before the in-place edits below
    mask_g = mask.clone()

    # trial and error shows 0.5 threshold has the best "shape"
    if threshold is not None:
        mask[mask < 0.5] = 0
        mask[mask >= 0.5] = 1
    log_img(mask, f"mask threshold {0.5}")

    # grow the mask area to make sure we've masked the thing we care about
    dilation_kernel = np.ones((3, 3), np.uint8)
    dilation_passes = 2
    grown = mask.to(torch.float32).cpu().numpy()
    for _ in range(dilation_passes):
        grown = cv2.dilate(grown, dilation_kernel)
    # todo: add an outer blur (not gaussian)
    mask = torch.from_numpy(grown)
    # NOTE(review): only dilation is applied above; the label still says "closing"
    log_img(mask, "mask after closing (dilation then erosion)")

    mask_img = _to_full_size_image(mask)
    mask_img_g = _to_full_size_image(mask_g)
    return mask_img, mask_img_g
|
|
|
|
|
|
def get_img_masks(img, mask_descriptions: Sequence[str]):
    """Predict one mask per text description for a single image.

    The image is converted to a tensor, normalized, resized to 352x352 and
    repeated once per description, so all descriptions are scored in a single
    forward pass of ``clip_mask_model()``. The predictions are resized back to
    the image's size and passed through a sigmoid.

    Args:
        img: Image to segment. Assumed to be a PIL image, whose ``size`` is
            ``(width, height)`` -- TODO confirm against callers.
        mask_descriptions: Text descriptions to predict masks for.

    Returns:
        A dict mapping each description to its mask tensor. Repeated
        descriptions share a single key. An empty ``mask_descriptions``
        yields an empty dict without running the model.
    """
    # Nothing to predict: avoid feeding the model a zero-sized batch.
    if not mask_descriptions:
        return {}

    width, height = img.size
    # torchvision's Resize takes (height, width), the reverse of img.size
    orig_size = height, width
    log_img(img, "image for masking")

    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            transforms.Resize((352, 352)),
        ]
    )
    img = transform(img).unsqueeze(0)

    with torch.no_grad():
        # one copy of the image per description, scored in a single batch
        preds = clip_mask_model()(
            img.repeat(len(mask_descriptions), 1, 1, 1), mask_descriptions
        )[0]
        preds = transforms.Resize(orig_size)(preds)

    preds = [torch.sigmoid(p[0]) for p in preds]

    preds_dict = {}
    for p, desc in zip(preds, mask_descriptions):
        log_img(p, f"clip mask: {desc}")
        preds_dict[desc] = p

    return preds_dict
|