2022-09-18 00:02:45 +00:00
|
|
|
from functools import lru_cache
|
|
|
|
|
|
|
|
import torch
|
|
|
|
from torchvision import transforms
|
|
|
|
|
2022-09-18 13:07:07 +00:00
|
|
|
from imaginairy.img_log import log_img
|
2022-09-18 00:02:45 +00:00
|
|
|
from imaginairy.vendored.clipseg import CLIPDensePredT
|
|
|
|
|
|
|
|
# Remote location of the "rd64-uni" CLIPSeg checkpoint. The code visible in this
# file loads a same-named file from the vendored package directory and does not
# reference this constant -- presumably kept as a pointer to the upstream
# source of the weights (confirm before removing).
weights_url = "https://github.com/timojl/clipseg/raw/master/weights/rd64-uni.pth"
|
|
|
|
|
|
|
|
|
|
|
|
@lru_cache()
def clip_mask_model():
    """Build the CLIPSeg dense-prediction model and load its vendored weights.

    The function is wrapped in ``lru_cache`` and takes no arguments, so the
    model is constructed and the checkpoint is read only on the first call;
    subsequent calls return the same cached object.

    Returns:
        The ``CLIPDensePredT`` instance, switched to eval mode, with the
        ``rd64-uni.pth`` state dict applied.
    """
    # Function-scope import, as in the original -- presumably to avoid a
    # circular import with the package root (TODO confirm).
    from imaginairy import PKG_ROOT

    clipseg = CLIPDensePredT(version="ViT-B/16", reduce_dim=64)
    clipseg.eval()

    weights_path = f"{PKG_ROOT}/vendored/clipseg/rd64-uni.pth"
    # The checkpoint tensors are mapped to CPU regardless of where they were saved.
    state_dict = torch.load(weights_path, map_location=torch.device("cpu"))
    # strict=False: key mismatches between checkpoint and module are not an error.
    clipseg.load_state_dict(state_dict, strict=False)
    return clipseg
|
|
|
|
|
|
|
|
|
2022-09-18 13:07:07 +00:00
|
|
|
def get_img_mask(img, mask_description):
    """Return a single mask for ``img`` from a ``|``-separated description.

    Args:
        img: Image to mask; forwarded unchanged to ``get_img_masks``.
        mask_description: String of one or more text prompts separated by
            ``|``. Each piece is passed on as its own description.

    Returns:
        The first (and, with ``combine=True``, only) element of the list
        produced by ``get_img_masks``.
    """
    return get_img_masks(img, mask_description.split("|"), combine=True)[0]
|
2022-09-18 13:07:07 +00:00
|
|
|
|
|
|
|
|
2022-09-18 22:24:31 +00:00
|
|
|
def get_img_masks(img, mask_descriptions, combine=False):
    """Produce black-and-white masks for ``img`` from text descriptions.

    Each description is run through the cached CLIPSeg model alongside a copy
    of the preprocessed image. The raw predictions are resized back to the
    input size, blurred, squashed with a sigmoid, and finally thresholded at
    25% of each prediction's own value range.

    Args:
        img: Input image. Presumably a PIL image: it must expose ``.size`` as
            a 2-tuple and be accepted by ``transforms.ToTensor`` (TODO confirm).
        mask_descriptions: Sequence of text prompts, one mask per prompt.
        combine: When true, the per-prompt predictions are merged with an
            element-wise maximum and a single mask is returned.

    Returns:
        A list of images created by ``transforms.ToPILImage`` -- one per
        description, or a single-element list when ``combine`` is true.
    """
    # Assumes ``img.size`` is ordered (width, height) as for PIL images, while
    # ``transforms.Resize`` is given the pair in swapped order -- confirm.
    width, height = img.size
    target_size = (height, width)
    log_img(img, "image for masking")

    preprocess = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            transforms.Resize((352, 352)),
        ]
    )
    single = preprocess(img).unsqueeze(0)

    with torch.no_grad():
        model = clip_mask_model()
        # One copy of the image per description so the batch sizes line up.
        raw = model(
            single.repeat(len(mask_descriptions), 1, 1, 1), mask_descriptions
        )[0]

    raw = transforms.Resize(target_size)(raw)
    raw = transforms.GaussianBlur(kernel_size=9)(raw)

    # Keep only index 0 of each prediction's leading dimension.
    scores = [torch.sigmoid(entry[0]) for entry in raw]

    if combine:
        merged = scores[0]
        for description, score in zip(mask_descriptions, scores):
            log_img(score, f"mask search: {description}")
            merged = torch.maximum(merged, score)
        scores = [merged]

    to_pil = transforms.ToPILImage()
    masks = []
    for score in scores:
        log_img(score, f"clip mask for {mask_descriptions}")
        lowest = score.min()
        highest = score.max()
        # Cut-off sits a quarter of the way up from the minimum to the maximum.
        cutoff = lowest + ((highest - lowest) * 0.25)
        binary = (score > cutoff).float()
        masks.append(to_pil(binary))

    return masks
|