from functools import lru_cache

import torch
from torchvision import transforms

from imaginairy.img_log import log_img
from imaginairy.vendored.clipseg import CLIPDensePredT

weights_url = "https://github.com/timojl/clipseg/raw/master/weights/rd64-uni.pth"
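# Upstream location of the rd64-uni CLIPSeg weights; clip_mask_model() below loads
# a vendored copy (same filename) from PKG_ROOT rather than fetching this URL.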


@lru_cache()
def clip_mask_model():
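    """Load (and cache) the vendored CLIPSeg model used for text-prompted masking."""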
    from imaginairy import PKG_ROOT  # noqa

    model = CLIPDensePredT(version="ViT-B/16", reduce_dim=64)
    model.eval()

    # strict=False allows the checkpoint and model state dicts to differ in their keys
    model.load_state_dict(
        torch.load(
            f"{PKG_ROOT}/vendored/clipseg/rd64-uni.pth",
            map_location=torch.device("cpu"),
        ),
        strict=False,
    )
    return model


def get_img_mask(img, mask_description, negative_description=""):
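    """Create a black-and-white PIL mask of `img` from text descriptions.

    Both the positive and the negative description may contain several terms separated by "|".
    """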
    pos_descriptions = mask_description.split("|")
    pos_masks = get_img_masks(img, pos_descriptions)

    # combine the per-description masks by elementwise maximum
    pos_mask = pos_masks[0]
    for pred in pos_masks:
        pos_mask = torch.maximum(pos_mask, pred)

    log_img(pos_mask, "pos mask")

    if negative_description:
        neg_descriptions = negative_description.split("|")
        neg_masks = get_img_masks(img, neg_descriptions)
        neg_mask = neg_masks[0]
        for pred in neg_masks:
            neg_mask = torch.maximum(neg_mask, pred)
        neg_mask = (neg_mask * 3).clip(0, 1)
        log_img(neg_mask, "neg_mask")
        # carve the negative regions out of the positive mask
        pos_mask = torch.minimum(pos_mask, 1 - neg_mask)

    # binarize: keep pixels above 35% of the mask's value range
    _min = pos_mask.min()
    _max = pos_mask.max()
    _range = _max - _min
    pos_mask = (pos_mask > (_min + (_range * 0.35))).float()

    return transforms.ToPILImage()(pos_mask)


def get_img_masks(img, mask_descriptions):
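    """Return one soft (0-1) mask tensor per description, resized to the original image size."""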
    # PIL size is (width, height); torchvision Resize expects (height, width)
    a, b = img.size
    orig_size = b, a
    log_img(img, "image for masking")

    # normalize and resize to the 352x352 resolution the CLIPSeg model expects
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            transforms.Resize((352, 352)),
        ]
    )
    img = transform(img).unsqueeze(0)

    # run all descriptions against the same image in a single batch
    with torch.no_grad():
        preds = clip_mask_model()(
            img.repeat(len(mask_descriptions), 1, 1, 1), mask_descriptions
        )[0]
    # upsample back to the original size and smooth the raw predictions
    preds = transforms.Resize(orig_size)(preds)
    preds = transforms.GaussianBlur(kernel_size=9)(preds)

    preds = [torch.sigmoid(p[0]) for p in preds]

    bw_preds = []
    for p, desc in zip(preds, mask_descriptions):
        log_img(p, f"clip mask: {desc}")
        # bw_preds.append(pred_transform(p))
        bw_preds.append(p)

    return bw_preds
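

# Minimal usage sketch. The image path and prompt strings below are illustrative
# placeholders, not part of this module's API.
if __name__ == "__main__":
    from PIL import Image

    example_img = Image.open("photo.jpg").convert("RGB")
    example_mask = get_img_mask(example_img, "dog|cat", negative_description="grass")
    example_mask.save("dog_cat_mask.png")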