feature: txt2mask - automated, text-prompted replacement of image regions

from https://github.com/timojl/clipseg
pull/9/head
Bryce, commit 930295d840 (parent 7087c4a680)

@ -34,31 +34,31 @@ Generating 🖼 : "portrait photo of a freckled woman" 512x512px seed:500686645
<img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000019_786355545_PLMS50_PS7.5_a_scenic_landscape.jpg" height="256"><img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000032_337692011_PLMS40_PS7.5_a_photo_of_a_dog.jpg" height="256"><br>
<img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000056_293284644_PLMS40_PS7.5_photo_of_a_bowl_of_fruit.jpg" height="256"><img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000078_260972468_PLMS40_PS7.5_portrait_photo_of_a_freckled_woman.jpg" height="256">
### Tiled Images
### Automated Replacement (txt2mask) [by clipseg](https://github.com/timojl/clipseg)
```bash
>> imagine "gold coins" "a lush forest" "piles of old books" leaves --tile
>> imagine --init-image pearl_earring.jpg --mask-prompt face --mask-mode keep --init-image-strength .4 "a female doctor" "an elegant woman"
```
<img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000066_801493266_PLMS40_PS7.5_gold_coins.jpg" height="128"><img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000066_801493266_PLMS40_PS7.5_gold_coins.jpg" height="128"><img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000066_801493266_PLMS40_PS7.5_gold_coins.jpg" height="128">
<img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000118_597948545_PLMS40_PS7.5_a_lush_forest.jpg" height="128"><img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000118_597948545_PLMS40_PS7.5_a_lush_forest.jpg" height="128"><img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000118_597948545_PLMS40_PS7.5_a_lush_forest.jpg" height="128">
<br>
<img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000075_961095192_PLMS40_PS7.5_piles_of_old_books.jpg" height="128"><img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000075_961095192_PLMS40_PS7.5_piles_of_old_books.jpg" height="128"><img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000075_961095192_PLMS40_PS7.5_piles_of_old_books.jpg" height="128">
<img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000040_527733581_PLMS40_PS7.5_leaves.jpg" height="128"><img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000040_527733581_PLMS40_PS7.5_leaves.jpg" height="128"><img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000040_527733581_PLMS40_PS7.5_leaves.jpg" height="128">
### Image-to-Image
<img src="assets/mask_examples/pearl000.jpg" height="256">➡️
<img src="assets/mask_examples/pearl002.jpg" height="256">
<img src="assets/mask_examples/pearl004.jpg" height="256">
<img src="assets/mask_examples/pearl001.jpg" height="256">
<img src="assets/mask_examples/pearl003.jpg" height="256">
```bash
>> imagine "portrait of a smiling lady. oil painting" --init-image girl_with_a_pearl_earring.jpg
>> imagine --init-image fruit-bowl.jpg --mask-prompt fruit --mask-mode replace --init-image-strength .1 "a bowl of pears" "a bowl of gold" "a bowl of popcorn" "a bowl of spaghetti"
```
<img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/tests/data/girl_with_a_pearl_earring.jpg" height="256"> =>
<img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000105_33084057_DDIM40_PS7.5_portrait_of_a_smiling_lady._oil_painting._.jpg" height="256">
<img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000056_293284644_PLMS40_PS7.5_photo_of_a_bowl_of_fruit.jpg" height="256">➡️
<img src="assets/mask_examples/bowl004.jpg" height="256">
<img src="assets/mask_examples/bowl001.jpg" height="256">
<img src="assets/mask_examples/bowl002.jpg" height="256">
<img src="assets/mask_examples/bowl003.jpg" height="256">
### Face Enhancement [by CodeFormer](https://github.com/sczhou/CodeFormer)
```bash
>> imagine "a couple smiling" --steps 40 --seed 1 --fix-faces
```
<img src="https://github.com/brycedrennan/imaginAIry/raw/master/assets/000178_1_PLMS40_PS7.5_a_couple_smiling_nofix.png" height="256"> =>
<img src="https://github.com/brycedrennan/imaginAIry/raw/master/assets/000178_1_PLMS40_PS7.5_a_couple_smiling_nofix.png" height="256"> ➡️
<img src="https://github.com/brycedrennan/imaginAIry/raw/master/assets/000178_1_PLMS40_PS7.5_a_couple_smiling_fixed.png" height="256">
@ -66,9 +66,28 @@ Generating 🖼 : "portrait photo of a freckled woman" 512x512px seed:500686645
```bash
>> imagine "colorful smoke" --steps 40 --upscale
```
<img src="https://github.com/brycedrennan/imaginAIry/raw/master/assets/000206_856637805_PLMS40_PS7.5_colorful_smoke.jpg" height="128"> =>
<img src="https://github.com/brycedrennan/imaginAIry/raw/master/assets/000206_856637805_PLMS40_PS7.5_colorful_smoke.jpg" height="128"> ➡️
<img src="https://github.com/brycedrennan/imaginAIry/raw/master/assets/000206_856637805_PLMS40_PS7.5_colorful_smoke_upscaled.jpg" height="256">
### Tiled Images
```bash
>> imagine "gold coins" "a lush forest" "piles of old books" leaves --tile
```
<img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000066_801493266_PLMS40_PS7.5_gold_coins.jpg" height="128"><img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000066_801493266_PLMS40_PS7.5_gold_coins.jpg" height="128"><img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000066_801493266_PLMS40_PS7.5_gold_coins.jpg" height="128">
<img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000118_597948545_PLMS40_PS7.5_a_lush_forest.jpg" height="128"><img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000118_597948545_PLMS40_PS7.5_a_lush_forest.jpg" height="128"><img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000118_597948545_PLMS40_PS7.5_a_lush_forest.jpg" height="128">
<br>
<img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000075_961095192_PLMS40_PS7.5_piles_of_old_books.jpg" height="128"><img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000075_961095192_PLMS40_PS7.5_piles_of_old_books.jpg" height="128"><img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000075_961095192_PLMS40_PS7.5_piles_of_old_books.jpg" height="128">
<img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000040_527733581_PLMS40_PS7.5_leaves.jpg" height="128"><img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000040_527733581_PLMS40_PS7.5_leaves.jpg" height="128"><img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000040_527733581_PLMS40_PS7.5_leaves.jpg" height="128">
### Image-to-Image
```bash
>> imagine "portrait of a smiling lady. oil painting" --init-image girl_with_a_pearl_earring.jpg
```
<img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/tests/data/girl_with_a_pearl_earring.jpg" height="256"> ➡️
<img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000105_33084057_DDIM40_PS7.5_portrait_of_a_smiling_lady._oil_painting._.jpg" height="256">
## Features
- It makes images from text descriptions! 🎉

[9 binary image files added, not shown — sizes 25–44 KiB; these appear to be the assets/mask_examples example images referenced above]

@ -9,21 +9,29 @@ import torch
import torch.nn
from einops import rearrange
from omegaconf import OmegaConf
from PIL import Image, ImageDraw, ImageFilter
from PIL import Image, ImageDraw, ImageFilter, ImageOps
from pytorch_lightning import seed_everything
from torch import autocast
from transformers import cached_path
from imaginairy.enhancers.clip_masking import get_img_mask
from imaginairy.enhancers.face_restoration_codeformer import enhance_faces
from imaginairy.enhancers.upscale_realesrgan import upscale_image
from imaginairy.img_log import ImageLoggingContext, log_conditioning, log_latent
from imaginairy.img_log import (
ImageLoggingContext,
log_conditioning,
log_img,
log_latent,
)
from imaginairy.safety import is_nsfw
from imaginairy.samplers.base import get_sampler
from imaginairy.schema import ImaginePrompt, ImagineResult
from imaginairy.utils import (
expand_mask,
fix_torch_nn_layer_norm,
get_device,
instantiate_from_config,
pillow_fit_image_within,
pillow_img_to_torch_image,
)
@ -202,22 +210,61 @@ def imagine(
prompt.height // downsampling_factor,
prompt.width // downsampling_factor,
]
if prompt.init_image:
sampler_type = "ddim"
else:
sampler_type = prompt.sampler_type
start_code = None
sampler = get_sampler(prompt.sampler_type, model)
sampler = get_sampler(sampler_type, model)
mask, mask_image, mask_image_orig = None, None, None
if prompt.init_image:
generation_strength = 1 - prompt.init_image_strength
ddim_steps = int(prompt.steps / generation_strength)
sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=ddim_eta)
init_image, w, h = pillow_img_to_torch_image( # noqa
init_image, _, h = pillow_fit_image_within(
prompt.init_image,
max_height=prompt.height,
max_width=prompt.width,
)
init_image = init_image.to(get_device())
init_image_t = pillow_img_to_torch_image(init_image)
if prompt.mask_prompt:
mask_image = get_img_mask(init_image, prompt.mask_prompt)
elif prompt.mask_image:
mask_image = prompt.mask_image
if mask_image is not None:
log_img(mask_image, "init mask")
# mask_image = mask_image.filter(ImageFilter.GaussianBlur(8))
mask_image = expand_mask(mask_image, prompt.mask_expansion)
if prompt.mask_mode == ImaginePrompt.MaskMode.REPLACE:
mask_image = ImageOps.invert(mask_image)
log_img(mask_image, "init mask expanded")
log_img(
Image.composite(init_image, mask_image, mask_image),
"mask overlay",
)
mask_image_orig = mask_image
mask_image = mask_image.resize(
(
mask_image.width // downsampling_factor,
mask_image.height // downsampling_factor,
),
resample=Image.Resampling.NEAREST,
)
log_img(mask_image, "init mask 2")
mask = np.array(mask_image)
mask = mask.astype(np.float32) / 255.0
mask = mask[None, None]
mask[mask < 0.9] = 0
mask[mask >= 0.9] = 1
mask = torch.from_numpy(mask)
mask = mask.to(get_device())
init_image_t = init_image_t.to(get_device())
init_latent = model.get_first_stage_encoding(
model.encode_first_stage(init_image)
model.encode_first_stage(init_image_t)
)
log_latent(init_latent, "init_latent")
@ -236,6 +283,8 @@ def imagine(
unconditional_guidance_scale=prompt.prompt_strength,
unconditional_conditioning=uc,
img_callback=_img_callback,
mask=mask,
orig_latent=init_latent,
)
else:
@ -260,6 +309,14 @@ def imagine(
)
x_sample_8_orig = x_sample.astype(np.uint8)
img = Image.fromarray(x_sample_8_orig)
if mask_image_orig and init_image:
mask_image_orig = mask_image_orig.filter(
ImageFilter.GaussianBlur(radius=3)
)
mask_image_orig = ImageOps.invert(mask_image_orig)
img = Image.composite(img, init_image, mask_image_orig)
log_img(img, "reconstituted image")
upscaled_img = None
is_nsfw_img = None
if IMAGINAIRY_SAFETY_MODE != SafetyMode.DISABLED:
@ -269,7 +326,7 @@ def imagine(
img = img.filter(ImageFilter.GaussianBlur(radius=40))
if prompt.fix_faces:
logger.info(" Fixing 😊 's in 🖼 using GFPGAN...")
logger.info(" Fixing 😊 's in 🖼 using CodeFormer...")
img = enhance_faces(img, fidelity=0.2)
if prompt.upscale:
logger.info(" Upscaling 🖼 using real-ESRGAN...")

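Because the latent-space mask only acts at the downsampled latent resolution, the generation code above finishes with a pixel-space composite: the original mask is lightly blurred, inverted, and passed to `Image.composite`, so pixels outside the repainted region come straight from the init image. A minimal sketch of that composite step, using stand-in images:

```python
# Illustrative recreation of the final composite; the images here are stand-ins.
from PIL import Image, ImageFilter, ImageOps

generated = Image.new("RGB", (512, 512), "red")    # stands in for the sampled image
init_image = Image.new("RGB", (512, 512), "blue")  # stands in for the init image
keep_mask = Image.new("L", (512, 512), 0)
keep_mask.paste(255, (128, 128, 384, 384))         # white = region to keep from init

soft = keep_mask.filter(ImageFilter.GaussianBlur(radius=3))  # soften the seam
paint_mask = ImageOps.invert(soft)                 # white = take the generated pixels
result = Image.composite(generated, init_image, paint_mask)
```
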
@ -123,6 +123,26 @@ def configure_logging(level="INFO"):
is_flag=True,
help="Any images rendered will be tileable. Unfortunately cannot be controlled at the per-image level yet",
)
@click.option(
"--mask-image",
help="A mask to use for inpainting. White areas will be painted over; black areas will be left alone.",
)
@click.option(
"--mask-prompt",
help="Describe what you want masked and the AI will mask it for you",
)
@click.option(
"--mask-mode",
default="replace",
type=click.Choice(["keep", "replace"]),
help="Should we replace the masked area or keep it?",
)
@click.option(
"--mask-expansion",
default="8",
type=int,
help="How much to grow (or shrink) the mask area",
)
def imagine_cmd(
prompt_texts,
prompt_strength,
@ -141,6 +161,10 @@ def imagine_cmd(
log_level,
show_work,
tile,
mask_image,
mask_prompt,
mask_mode,
mask_expansion,
):
"""Render an image"""
suppress_annoying_logs_and_warnings()
@ -161,7 +185,6 @@ def imagine_cmd(
load_model(tile_mode=tile)
for _ in range(repeats):
for prompt_text in prompt_texts:
prompt = ImaginePrompt(
prompt_text,
prompt_strength=prompt_strength,
@ -172,6 +195,10 @@ def imagine_cmd(
steps=steps,
height=height,
width=width,
mask_image=mask_image,
mask_prompt=mask_prompt,
mask_expansion=mask_expansion,
mask_mode=mask_mode,
upscale=upscale,
fix_faces=fix_faces,
)

@ -3,7 +3,7 @@ from functools import lru_cache
import torch
from torchvision import transforms
from imaginairy import PKG_ROOT
from imaginairy.img_log import log_img
from imaginairy.vendored.clipseg import CLIPDensePredT
weights_url = "https://github.com/timojl/clipseg/raw/master/weights/rd64-uni.pth"
@ -11,27 +11,54 @@ weights_url = "https://github.com/timojl/clipseg/raw/master/weights/rd64-uni.pth
@lru_cache()
def clip_mask_model():
model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64)
from imaginairy import PKG_ROOT
model = CLIPDensePredT(version="ViT-B/16", reduce_dim=64)
model.eval()
model.load_state_dict(
torch.load(
f'{PKG_ROOT}/vendored/clipseg/rd64-uni.pth',
map_location=torch.device("cpu")),
strict=False
f"{PKG_ROOT}/vendored/clipseg/rd64-uni.pth",
map_location=torch.device("cpu"),
),
strict=False,
)
return model
def get_img_mask(img, mask_descriptions):
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
transforms.Resize((352, 352)),
])
def get_img_mask(img, mask_description):
return get_img_masks(img, [mask_description])[0]
def get_img_masks(img, mask_descriptions):
a, b = img.size
orig_size = b, a
log_img(img, "image for masking")
# orig_shape = tuple(img.shape)[1:]
transform = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
transforms.Resize((352, 352)),
]
)
img = transform(img).unsqueeze(0)
with torch.no_grad():
preds = clip_mask_model()(img.repeat(len(mask_descriptions), 1, 1, 1), mask_descriptions)[0]
preds = clip_mask_model()(
img.repeat(len(mask_descriptions), 1, 1, 1), mask_descriptions
)[0]
preds = transforms.Resize(orig_size)(preds)
preds = [torch.sigmoid(p[0]) for p in preds]
bw_preds = []
for p in preds:
log_img(p, f"clip mask for {mask_descriptions}")
# bw_preds.append(pred_transform(p))
_min = p.min()
_max = p.max()
_range = _max - _min
p = (p > (_min + (_range * 0.5))).float()
bw_preds.append(transforms.ToPILImage()(p))
return preds
return bw_preds

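After this change `get_img_mask` takes a single description and returns a binarized PIL mask (thresholded at the midpoint of the prediction's value range, at the original image size), while `get_img_masks` handles several descriptions at once. A minimal usage sketch; the file paths are illustrative:

```python
# Usage sketch; file paths are illustrative.
from PIL import Image
from imaginairy.enhancers.clip_masking import get_img_mask

img = Image.open("girl_with_a_pearl_earring.jpg")
mask = get_img_mask(img, "face")   # black/white PIL mask selecting the described region
mask.save("face_mask.png")
```
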
@ -19,6 +19,8 @@ def realesrgan_upsampler():
upsampler = RealESRGANer(scale=4, model_path=model_path, model=model, tile=0)
device = get_device()
if "mps" in device:
device = "cpu"
upsampler.device = torch.device(device)
upsampler.model.to(device)
@ -28,6 +30,7 @@ def realesrgan_upsampler():
def upscale_image(img):
img = img.convert("RGB")
np_img = np.array(img, dtype=np.uint8)
upsampler_output, img_mode = realesrgan_upsampler().enhance(np_img[:, :, ::-1])
return Image.fromarray(upsampler_output[:, :, ::-1], mode=img_mode)

@ -32,6 +32,12 @@ def log_latent(latents, description):
_CURRENT_LOGGING_CONTEXT.log_latents(latents, description)
def log_img(img, description):
if _CURRENT_LOGGING_CONTEXT is None:
return
_CURRENT_LOGGING_CONTEXT.log_img(img, description)
class ImageLoggingContext:
def __init__(self, prompt, model, img_callback=None, img_outdir=None):
self.prompt = prompt
@ -71,6 +77,15 @@ class ImageLoggingContext:
img = Image.fromarray(latent.astype(np.uint8))
self.img_callback(img, description, self.step_count, self.prompt)
def log_img(self, img, description):
if not self.img_callback:
return
self.step_count += 1
if isinstance(img, torch.Tensor):
img = ToPILImage()(img.squeeze().cpu().detach())
img = img.copy()
self.img_callback(img, description, self.step_count, self.prompt)
# def img_callback(self, img, description, step_count, prompt):
# steps_path = os.path.join(self.img_outdir, "steps", f"{self.file_num:08}_S{prompt.seed}")
# os.makedirs(steps_path, exist_ok=True)

@ -16,7 +16,11 @@ from einops import rearrange
from torchvision.utils import make_grid
from tqdm import tqdm
from imaginairy.modules.diffusion.util import make_beta_schedule, noise_like
from imaginairy.modules.diffusion.util import (
extract_into_tensor,
make_beta_schedule,
noise_like,
)
from imaginairy.modules.distributions import DiagonalGaussianDistribution
from imaginairy.utils import instantiate_from_config, log_params
@ -852,6 +856,18 @@ class LatentDiffusion(DDPM):
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
def q_sample(self, x_start, t, noise=None):
noise = (
noise
if noise is not None
else torch.randn_like(x_start, device="cpu").to(x_start.device)
)
return (
extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
* noise
)
class DiffusionWrapper(pl.LightningModule):
def __init__(self, diff_model_config, conditioning_key):

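The new `q_sample` is the standard DDPM forward-noising step, used by the DDIM inpainting path further down to push the original latent to an arbitrary timestep. Given a clean latent $x_0$ and timestep $t$ it returns a sample from $q(x_t \mid x_0)$:

$$x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\varepsilon, \qquad \varepsilon \sim \mathcal{N}(0, I),$$

where `sqrt_alphas_cumprod` and `sqrt_one_minus_alphas_cumprod` store $\sqrt{\bar\alpha_t}$ and $\sqrt{1-\bar\alpha_t}$ per timestep and `extract_into_tensor` gathers the values for the given `t`.
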
@ -42,20 +42,25 @@ def get_sampler(sampler_type, model):
class CFGDenoiser(nn.Module):
"""
Conditional forward guidance wrapper
"""
def __init__(self, model):
super().__init__()
self.inner_model = model
def forward(self, x, sigma, uncond, cond, cond_scale):
def forward(self, x, sigma, uncond, cond, cond_scale, mask=None, orig_latent=None):
x_in = torch.cat([x] * 2)
sigma_in = torch.cat([sigma] * 2)
cond_in = torch.cat([uncond, cond])
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
return uncond + (cond - uncond) * cond_scale
denoised = uncond + (cond - uncond) * cond_scale
if mask is not None:
assert orig_latent is not None
mask_inv = 1.0 - mask
denoised = (orig_latent * mask_inv) + (mask * denoised)
return denoised
class DiffusionSampler:

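With the standalone `CFGMaskedDenoiser` deleted from the k-diffusion wrapper (see that file's diff below), mask handling now sits directly in `CFGDenoiser`: after classifier-free guidance, latent regions where the mask is 1 take the new denoised prediction and regions where it is 0 are reset to the original latent. A minimal sketch of the blend, with illustrative tensor shapes:

```python
# Illustrative blend applied when a mask is supplied; shapes here are assumptions.
import torch

denoised = torch.randn(1, 4, 64, 64)     # classifier-free-guided prediction
orig_latent = torch.randn(1, 4, 64, 64)  # latent of the encoded init image
mask = torch.zeros(1, 1, 64, 64)
mask[..., 16:48, 16:48] = 1.0            # 1 = accept the new prediction, 0 = keep original

blended = orig_latent * (1.0 - mask) + mask * denoised
```
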
@ -335,10 +335,11 @@ class DDIMSampler:
img_callback=None,
score_corrector=None,
temperature=1.0,
mask=None,
orig_latent=None,
):
timesteps = self.ddim_timesteps
timesteps = timesteps[:t_start]
timesteps = self.ddim_timesteps[:t_start]
time_range = np.flip(timesteps)
total_steps = timesteps.shape[0]
@ -352,6 +353,15 @@ class DDIMSampler:
ts = torch.full(
(x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long
)
if mask is not None:
assert orig_latent is not None
xdec_orig = self.model.q_sample(orig_latent, ts)
log_latent(xdec_orig, "xdec_orig")
log_latent(xdec_orig * mask, "masked_xdec_orig")
x_dec = xdec_orig * mask + (1.0 - mask) * x_dec
log_latent(x_dec, "x_dec")
x_dec, pred_x0 = self.p_sample_ddim(
x_dec,
cond,

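In the DDIM `decode` loop the mask is applied at every step: the original latent is re-noised to the current timestep with the new `q_sample` and pasted over the region where the mask is 1, so those latents track the original image at every noise level while the mask-0 region is left to the sampler,

$$x_{\text{dec}} \leftarrow m \odot \text{q\_sample}(x_0, t) + (1 - m) \odot x_{\text{dec}}.$$

Here $m = 1$ marks the area to preserve; in the api changes above, the user-facing mask is inverted for `replace` mode before it reaches this path, so "replace" regions arrive as zeros.
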
@ -8,28 +8,6 @@ from imaginairy.vendored.k_diffusion import sampling as k_sampling
from imaginairy.vendored.k_diffusion.external import CompVisDenoiser
class CFGMaskedDenoiser(nn.Module):
def __init__(self, model):
super().__init__()
self.inner_model = model
def forward(self, x, sigma, uncond, cond, cond_scale, mask, x0, xi):
x_in = x
x_in = torch.cat([x_in] * 2)
sigma_in = torch.cat([sigma] * 2)
cond_in = torch.cat([uncond, cond])
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
denoised = uncond + (cond - uncond) * cond_scale
if mask is not None:
assert x0 is not None
img_orig = x0
mask_inv = 1.0 - mask
denoised = (img_orig * mask_inv) + (mask * denoised)
return denoised
class KDiffusionSampler:
def __init__(self, model, sampler_name):
self.model = model

@ -81,12 +81,20 @@ class WeightedPrompt:
class ImaginePrompt:
class MaskMode:
KEEP = "keep"
REPLACE = "replace"
def __init__(
self,
prompt=None,
prompt_strength=7.5,
init_image=None, # Pillow Image, LazyLoadingImage, or filepath str
init_image_strength=0.3,
mask_prompt=None,
mask_image=None,
mask_mode=MaskMode.REPLACE,
mask_expansion=8,
seed=None,
steps=50,
height=512,
@ -105,6 +113,10 @@ class ImaginePrompt:
self.prompt_strength = prompt_strength
if isinstance(init_image, str):
init_image = LazyLoadingImage(filepath=init_image)
if mask_image is not None and mask_prompt is not None:
raise ValueError("You can only set one of `mask_image` and `mask_prompt`")
self.init_image = init_image
self.init_image_strength = init_image_strength
self.seed = random.randint(1, 1_000_000_000) if seed is None else seed
@ -115,6 +127,10 @@ class ImaginePrompt:
self.fix_faces = fix_faces
self.sampler_type = sampler_type
self.conditioning = conditioning
self.mask_prompt = mask_prompt
self.mask_image = mask_image
self.mask_mode = mask_mode
self.mask_expansion = mask_expansion
@property
def prompt_text(self):

@ -9,12 +9,14 @@ from typing import List, Optional
import numpy as np
import requests
import torch
from PIL import Image
from PIL import Image, ImageFilter
from torch import Tensor
from torch.nn import functional
from torch.overrides import handle_torch_function, has_torch_function_variadic
from transformers import cached_path
from imaginairy.img_log import log_img
logger = logging.getLogger(__name__)
@ -102,23 +104,45 @@ def fix_torch_nn_layer_norm():
functional.layer_norm = orig_function
def img_path_to_torch_image(path, max_height=512, max_width=512):
def expand_mask(mask_image, size):
if size < 0:
threshold = 0.9
else:
threshold = 0.1
mask_image = mask_image.convert("L")
mask_image = mask_image.filter(ImageFilter.GaussianBlur(size))
log_img(mask_image, "init mask blurred")
mask = np.array(mask_image)
mask = mask.astype(np.float32) / 255.0
mask = mask[None, None]
mask[mask < threshold] = 0
mask[mask >= threshold] = 1
return Image.fromarray(np.uint8(mask.squeeze() * 255))
def img_path_to_torch_image(path):
image = Image.open(path).convert("RGB")
logger.info(f"Loaded input 🖼 of size {image.size} from {path}")
return pillow_img_to_torch_image(image, max_height=max_height, max_width=max_width)
return pillow_img_to_torch_image(image)
def pillow_img_to_torch_image(image, max_height=512, max_width=512):
def pillow_fit_image_within(image, max_height=512, max_width=512):
image = image.convert("RGB")
w, h = image.size
resize_ratio = min(max_width / w, max_height / h)
w, h = int(w * resize_ratio), int(h * resize_ratio)
w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 32
image = image.resize((w, h), resample=Image.Resampling.LANCZOS)
w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64
image = image.resize((w, h), resample=Image.Resampling.NEAREST)
return image, w, h
def pillow_img_to_torch_image(image):
image = image.convert("RGB")
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
return 2.0 * image - 1.0, w, h
return 2.0 * image - 1.0
def get_cache_dir():

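The new `expand_mask` helper grows or shrinks a mask by Gaussian-blurring it and re-thresholding: as written, a positive `size` blurs with that radius and keeps everything above a low threshold (the white area grows), while a negative `size` switches to a high threshold so the white area shrinks. A self-contained sketch of the grow case:

```python
# Minimal sketch of the blur-then-threshold idea behind expand_mask (grow case only).
import numpy as np
from PIL import Image, ImageFilter

mask = Image.new("L", (64, 64), 0)
mask.paste(255, (24, 24, 40, 40))                      # small white square

blurred = mask.filter(ImageFilter.GaussianBlur(8))
arr = np.array(blurred, dtype=np.float32) / 255.0
grown = Image.fromarray(np.uint8((arr >= 0.1) * 255))  # low threshold -> white area grows
```
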
@ -1,5 +1,6 @@
import math
from os.path import basename, dirname, join, isfile
from os.path import basename, dirname, isfile, join
import torch
from torch import nn
from torch.nn import functional as nnf
@ -9,86 +10,119 @@ from torch.nn.modules.activation import ReLU
def precompute_clip_vectors():
from trails.initialization import init_dataset
lvis = init_dataset('LVIS_OneShot3', split='train', mask='text_label', image_size=224, aug=1, normalize=True,
reduce_factor=None, add_bar=False, negative_prob=0.5)
lvis = init_dataset(
"LVIS_OneShot3",
split="train",
mask="text_label",
image_size=224,
aug=1,
normalize=True,
reduce_factor=None,
add_bar=False,
negative_prob=0.5,
)
all_names = list(lvis.category_names.values())
from imaginairy.vendored import clip
from models.clip_prompts import imagenet_templates
clip_model = clip.load("ViT-B/32", device='cuda', jit=False)[0]
from imaginairy.vendored import clip
clip_model = clip.load("ViT-B/32", device="cuda", jit=False)[0]
prompt_vectors = {}
for name in all_names[:100]:
with torch.no_grad():
conditionals = [t.format(name).replace('_', ' ') for t in imagenet_templates]
conditionals = [
t.format(name).replace("_", " ") for t in imagenet_templates
]
text_tokens = clip.tokenize(conditionals).cuda()
cond = clip_model.encode_text(text_tokens).cpu()
for cond, vec in zip(conditionals, cond):
prompt_vectors[cond] = vec.cpu()
import pickle
pickle.dump(prompt_vectors, open('precomputed_prompt_vectors.pickle', 'wb'))
pickle.dump(prompt_vectors, open("precomputed_prompt_vectors.pickle", "wb"))
def get_prompt_list(prompt):
if prompt == 'plain':
return ['{}']
elif prompt == 'fixed':
return ['a photo of a {}.']
elif prompt == 'shuffle':
return ['a photo of a {}.', 'a photograph of a {}.', 'an image of a {}.', '{}.']
elif prompt == 'shuffle+':
return ['a photo of a {}.', 'a photograph of a {}.', 'an image of a {}.', '{}.',
'a cropped photo of a {}.', 'a good photo of a {}.', 'a photo of one {}.',
'a bad photo of a {}.', 'a photo of the {}.']
elif prompt == 'shuffle_clip':
if prompt == "plain":
return ["{}"]
elif prompt == "fixed":
return ["a photo of a {}."]
elif prompt == "shuffle":
return ["a photo of a {}.", "a photograph of a {}.", "an image of a {}.", "{}."]
elif prompt == "shuffle+":
return [
"a photo of a {}.",
"a photograph of a {}.",
"an image of a {}.",
"{}.",
"a cropped photo of a {}.",
"a good photo of a {}.",
"a photo of one {}.",
"a bad photo of a {}.",
"a photo of the {}.",
]
elif prompt == "shuffle_clip":
from models.clip_prompts import imagenet_templates
return imagenet_templates
else:
raise ValueError('Invalid value for prompt')
raise ValueError("Invalid value for prompt")
def forward_multihead_attention(x, b, with_aff=False, attn_mask=None):
"""
Simplified version of multihead attention (taken from torch source code but without tons of if clauses).
"""
Simplified version of multihead attention (taken from torch source code but without tons of if clauses).
The mlp and layer norm come from CLIP.
x: input.
b: multihead attention module.
b: multihead attention module.
"""
x_ = b.ln_1(x)
q, k, v = nnf.linear(x_, b.attn.in_proj_weight, b.attn.in_proj_bias).chunk(3, dim=-1)
q, k, v = nnf.linear(x_, b.attn.in_proj_weight, b.attn.in_proj_bias).chunk(
3, dim=-1
)
tgt_len, bsz, embed_dim = q.size()
head_dim = embed_dim // b.attn.num_heads
scaling = float(head_dim) ** -0.5
q = q.contiguous().view(tgt_len, bsz * b.attn.num_heads, b.attn.head_dim).transpose(0, 1)
q = (
q.contiguous()
.view(tgt_len, bsz * b.attn.num_heads, b.attn.head_dim)
.transpose(0, 1)
)
k = k.contiguous().view(-1, bsz * b.attn.num_heads, b.attn.head_dim).transpose(0, 1)
v = v.contiguous().view(-1, bsz * b.attn.num_heads, b.attn.head_dim).transpose(0, 1)
q = q * scaling
attn_output_weights = torch.bmm(q, k.transpose(1, 2)) # n_heads * batch_size, tokens^2, tokens^2
attn_output_weights = torch.bmm(
q, k.transpose(1, 2)
) # n_heads * batch_size, tokens^2, tokens^2
if attn_mask is not None:
attn_mask_type, attn_mask = attn_mask
n_heads = attn_output_weights.size(0) // attn_mask.size(0)
attn_mask = attn_mask.repeat(n_heads, 1)
if attn_mask_type == 'cls_token':
if attn_mask_type == "cls_token":
# the mask only affects similarities compared to the readout-token.
attn_output_weights[:, 0, 1:] = attn_output_weights[:, 0, 1:] * attn_mask[None,...]
attn_output_weights[:, 0, 1:] = (
attn_output_weights[:, 0, 1:] * attn_mask[None, ...]
)
# attn_output_weights[:, 0, 0] = 0*attn_output_weights[:, 0, 0]
if attn_mask_type == 'all':
if attn_mask_type == "all":
# print(attn_output_weights.shape, attn_mask[:, None].shape)
attn_output_weights[:, 1:, 1:] = attn_output_weights[:, 1:, 1:] * attn_mask[:, None]
attn_output_weights[:, 1:, 1:] = (
attn_output_weights[:, 1:, 1:] * attn_mask[:, None]
)
attn_output_weights = torch.softmax(attn_output_weights, dim=-1)
attn_output = torch.bmm(attn_output_weights, v)
@ -105,14 +139,13 @@ def forward_multihead_attention(x, b, with_aff=False, attn_mask=None):
class CLIPDenseBase(nn.Module):
def __init__(self, version, reduce_cond, reduce_dim, prompt, n_tokens):
super().__init__()
from imaginairy.vendored import clip
# prec = torch.FloatTensor
self.clip_model, _ = clip.load(version, device='cpu', jit=False)
self.clip_model, _ = clip.load(version, device="cpu", jit=False)
self.model = self.clip_model.visual
# if not None, scale conv weights such that we obtain n_tokens.
@ -127,32 +160,43 @@ class CLIPDenseBase(nn.Module):
for p in self.reduce_cond.parameters():
p.requires_grad_(False)
else:
self.reduce_cond = None
self.reduce_cond = None
self.film_mul = nn.Linear(
512 if reduce_cond is None else reduce_cond, reduce_dim
)
self.film_add = nn.Linear(
512 if reduce_cond is None else reduce_cond, reduce_dim
)
self.film_mul = nn.Linear(512 if reduce_cond is None else reduce_cond, reduce_dim)
self.film_add = nn.Linear(512 if reduce_cond is None else reduce_cond, reduce_dim)
self.reduce = nn.Linear(768, reduce_dim)
self.prompt_list = get_prompt_list(prompt)
self.prompt_list = get_prompt_list(prompt)
# precomputed prompts
import pickle
if isfile('precomputed_prompt_vectors.pickle'):
precomp = pickle.load(open('precomputed_prompt_vectors.pickle', 'rb'))
self.precomputed_prompts = {k: torch.from_numpy(v) for k, v in precomp.items()}
if isfile("precomputed_prompt_vectors.pickle"):
precomp = pickle.load(open("precomputed_prompt_vectors.pickle", "rb"))
self.precomputed_prompts = {
k: torch.from_numpy(v) for k, v in precomp.items()
}
else:
self.precomputed_prompts = dict()
def rescaled_pos_emb(self, new_size):
assert len(new_size) == 2
a = self.model.positional_embedding[1:].T.view(1, 768, *self.token_shape)
b = nnf.interpolate(a, new_size, mode='bicubic', align_corners=False).squeeze(0).view(768, new_size[0]*new_size[1]).T
b = (
nnf.interpolate(a, new_size, mode="bicubic", align_corners=False)
.squeeze(0)
.view(768, new_size[0] * new_size[1])
.T
)
return torch.cat([self.model.positional_embedding[:1], b])
def visual_forward(self, x_inp, extract_layers=(), skip=False, mask=None):
with torch.no_grad():
@ -160,21 +204,46 @@ class CLIPDenseBase(nn.Module):
if self.n_tokens is not None:
stride2 = x_inp.shape[2] // self.n_tokens
conv_weight2 = nnf.interpolate(self.model.conv1.weight, (stride2, stride2), mode='bilinear', align_corners=True)
x = nnf.conv2d(x_inp, conv_weight2, bias=self.model.conv1.bias, stride=stride2, dilation=self.model.conv1.dilation)
conv_weight2 = nnf.interpolate(
self.model.conv1.weight,
(stride2, stride2),
mode="bilinear",
align_corners=True,
)
x = nnf.conv2d(
x_inp,
conv_weight2,
bias=self.model.conv1.bias,
stride=stride2,
dilation=self.model.conv1.dilation,
)
else:
x = self.model.conv1(x_inp) # shape = [*, width, grid, grid]
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
x = torch.cat([self.model.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
x = torch.cat(
[
self.model.class_embedding.to(x.dtype)
+ torch.zeros(
x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
),
x,
],
dim=1,
) # shape = [*, grid ** 2 + 1, width]
standard_n_tokens = 50 if self.model.conv1.kernel_size[0] == 32 else 197
if x.shape[1] != standard_n_tokens:
new_shape = int(math.sqrt(x.shape[1]-1))
x = x + self.rescaled_pos_emb((new_shape, new_shape)).to(x.dtype)[None,:,:]
new_shape = int(math.sqrt(x.shape[1] - 1))
x = (
x
+ self.rescaled_pos_emb((new_shape, new_shape)).to(x.dtype)[
None, :, :
]
)
else:
x = x + self.model.positional_embedding.to(x.dtype)
@ -184,34 +253,41 @@ class CLIPDenseBase(nn.Module):
activations, affinities = [], []
for i, res_block in enumerate(self.model.transformer.resblocks):
if mask is not None:
mask_layer, mask_type, mask_tensor = mask
if mask_layer == i or mask_layer == 'all':
if mask_layer == i or mask_layer == "all":
# import ipdb; ipdb.set_trace()
size = int(math.sqrt(x.shape[0] - 1))
attn_mask = (mask_type, nnf.interpolate(mask_tensor.unsqueeze(1).float(), (size, size)).view(mask_tensor.shape[0], size * size))
attn_mask = (
mask_type,
nnf.interpolate(
mask_tensor.unsqueeze(1).float(), (size, size)
).view(mask_tensor.shape[0], size * size),
)
else:
attn_mask = None
else:
attn_mask = None
x, aff_per_head = forward_multihead_attention(x, res_block, with_aff=True, attn_mask=attn_mask)
x, aff_per_head = forward_multihead_attention(
x, res_block, with_aff=True, attn_mask=attn_mask
)
if i in extract_layers:
affinities += [aff_per_head]
#if self.n_tokens is not None:
# if self.n_tokens is not None:
# activations += [nnf.interpolate(x, inp_size, mode='bilinear', align_corners=True)]
#else:
# else:
activations += [x]
if len(extract_layers) > 0 and i == max(extract_layers) and skip:
print('early skip')
print("early skip")
break
x = x.permute(1, 0, 2) # LND -> NLD
x = self.model.ln_post(x[:, 0, :])
@ -224,7 +300,9 @@ class CLIPDenseBase(nn.Module):
prompt_list = prompt_list if prompt_list is not None else self.prompt_list
prompt_indices = torch.multinomial(torch.ones(len(prompt_list)), len(words), replacement=True)
prompt_indices = torch.multinomial(
torch.ones(len(prompt_list)), len(words), replacement=True
)
prompts = [prompt_list[i] for i in prompt_indices]
return [promt.format(w) for promt, w in zip(prompts, words)]
@ -235,12 +313,20 @@ class CLIPDenseBase(nn.Module):
cond = cond.repeat(batch_size, 1)
# compute conditional from string list/tuple
elif conditional is not None and type(conditional) in {list, tuple} and type(conditional[0]) == str:
elif (
conditional is not None
and type(conditional) in {list, tuple}
and type(conditional[0]) == str
):
assert len(conditional) == batch_size
cond = self.compute_conditional(conditional)
# use conditional directly
elif conditional is not None and type(conditional) == torch.Tensor and conditional.ndim == 2:
elif (
conditional is not None
and type(conditional) == torch.Tensor
and conditional.ndim == 2
):
cond = conditional
# compute conditional from image
@ -248,8 +334,8 @@ class CLIPDenseBase(nn.Module):
with torch.no_grad():
cond, _, _ = self.visual_forward(conditional)
else:
raise ValueError('invalid conditional')
return cond
raise ValueError("invalid conditional")
return cond
def compute_conditional(self, conditional):
from imaginairy.vendored import clip
@ -265,7 +351,7 @@ class CLIPDenseBase(nn.Module):
else:
text_tokens = clip.tokenize([conditional]).to(dev)
cond = self.clip_model.encode_text(text_tokens)[0]
if self.shift_vector is not None:
return cond + self.shift_vector
else:
@ -273,14 +359,21 @@ class CLIPDenseBase(nn.Module):
def clip_load_untrained(version):
assert version == 'ViT-B/16'
from clip.model import CLIP
assert version == "ViT-B/16"
from clip.clip import _MODELS, _download
model = torch.jit.load(_download(_MODELS['ViT-B/16'])).eval()
from clip.model import CLIP
model = torch.jit.load(_download(_MODELS["ViT-B/16"])).eval()
state_dict = model.state_dict()
vision_width = state_dict["visual.conv1.weight"].shape[0]
vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
vision_layers = len(
[
k
for k in state_dict.keys()
if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")
]
)
vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
image_resolution = vision_patch_size * grid_size
@ -289,19 +382,49 @@ def clip_load_untrained(version):
vocab_size = state_dict["token_embedding.weight"].shape[0]
transformer_width = state_dict["ln_final.weight"].shape[0]
transformer_heads = transformer_width // 64
transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
return CLIP(embed_dim, image_resolution, vision_layers, vision_width, vision_patch_size,
context_length, vocab_size, transformer_width, transformer_heads, transformer_layers)
transformer_layers = len(
set(
k.split(".")[2]
for k in state_dict
if k.startswith(f"transformer.resblocks")
)
)
return CLIP(
embed_dim,
image_resolution,
vision_layers,
vision_width,
vision_patch_size,
context_length,
vocab_size,
transformer_width,
transformer_heads,
transformer_layers,
)
class CLIPDensePredT(CLIPDenseBase):
def __init__(
self,
version="ViT-B/32",
extract_layers=(3, 6, 9),
cond_layer=0,
reduce_dim=128,
n_heads=4,
prompt="fixed",
extra_blocks=0,
reduce_cond=None,
fix_shift=False,
learn_trans_conv_only=False,
limit_to_clip_only=False,
upsample=False,
add_calibration=False,
rev_activations=False,
trans_conv=None,
n_tokens=None,
):
def __init__(self, version='ViT-B/32', extract_layers=(3, 6, 9), cond_layer=0, reduce_dim=128, n_heads=4, prompt='fixed',
extra_blocks=0, reduce_cond=None, fix_shift=False,
learn_trans_conv_only=False, limit_to_clip_only=False, upsample=False,
add_calibration=False, rev_activations=False, trans_conv=None, n_tokens=None):
super().__init__(version, reduce_cond, reduce_dim, prompt, n_tokens)
# device = 'cpu'
@ -310,53 +433,69 @@ class CLIPDensePredT(CLIPDenseBase):
self.limit_to_clip_only = limit_to_clip_only
self.process_cond = None
self.rev_activations = rev_activations
depth = len(extract_layers)
if add_calibration:
self.calibration_conds = 1
self.upsample_proj = nn.Conv2d(reduce_dim, 1, kernel_size=1) if upsample else None
self.upsample_proj = (
nn.Conv2d(reduce_dim, 1, kernel_size=1) if upsample else None
)
self.add_activation1 = True
self.version = version
self.token_shape = {'ViT-B/32': (7, 7), 'ViT-B/16': (14, 14)}[version]
self.token_shape = {"ViT-B/32": (7, 7), "ViT-B/16": (14, 14)}[version]
if fix_shift:
# self.shift_vector = nn.Parameter(torch.load(join(dirname(basename(__file__)), 'clip_text_shift_vector.pth')), requires_grad=False)
self.shift_vector = nn.Parameter(torch.load(join(dirname(basename(__file__)), 'shift_text_to_vis.pth')), requires_grad=False)
self.shift_vector = nn.Parameter(
torch.load(join(dirname(basename(__file__)), "shift_text_to_vis.pth")),
requires_grad=False,
)
# self.shift_vector = nn.Parameter(-1*torch.load(join(dirname(basename(__file__)), 'shift2.pth')), requires_grad=False)
else:
self.shift_vector = None
if trans_conv is None:
trans_conv_ks = {'ViT-B/32': (32, 32), 'ViT-B/16': (16, 16)}[version]
trans_conv_ks = {"ViT-B/32": (32, 32), "ViT-B/16": (16, 16)}[version]
else:
# explicitly define transposed conv kernel size
trans_conv_ks = (trans_conv, trans_conv)
self.trans_conv = nn.ConvTranspose2d(reduce_dim, 1, trans_conv_ks, stride=trans_conv_ks)
self.trans_conv = nn.ConvTranspose2d(
reduce_dim, 1, trans_conv_ks, stride=trans_conv_ks
)
assert len(self.extract_layers) == depth
self.reduces = nn.ModuleList([nn.Linear(768, reduce_dim) for _ in range(depth)])
self.blocks = nn.ModuleList([nn.TransformerEncoderLayer(d_model=reduce_dim, nhead=n_heads) for _ in range(len(self.extract_layers))])
self.extra_blocks = nn.ModuleList([nn.TransformerEncoderLayer(d_model=reduce_dim, nhead=n_heads) for _ in range(extra_blocks)])
self.blocks = nn.ModuleList(
[
nn.TransformerEncoderLayer(d_model=reduce_dim, nhead=n_heads)
for _ in range(len(self.extract_layers))
]
)
self.extra_blocks = nn.ModuleList(
[
nn.TransformerEncoderLayer(d_model=reduce_dim, nhead=n_heads)
for _ in range(extra_blocks)
]
)
# refinement and trans conv
if learn_trans_conv_only:
for p in self.parameters():
p.requires_grad_(False)
for p in self.trans_conv.parameters():
p.requires_grad_(True)
self.prompt_list = get_prompt_list(prompt)
def forward(self, inp_image, conditional=None, return_features=False, mask=None):
assert type(return_features) == bool
@ -364,7 +503,7 @@ class CLIPDensePredT(CLIPDenseBase):
inp_image = inp_image.to(self.model.positional_embedding.device)
if mask is not None:
raise ValueError('mask not supported')
raise ValueError("mask not supported")
# x_inp = normalize(inp_image)
x_inp = inp_image
@ -373,7 +512,9 @@ class CLIPDensePredT(CLIPDenseBase):
cond = self.get_cond_vec(conditional, bs)
visual_q, activations, _ = self.visual_forward(x_inp, extract_layers=[0] + list(self.extract_layers))
visual_q, activations, _ = self.visual_forward(
x_inp, extract_layers=[0] + list(self.extract_layers)
)
activation1 = activations[0]
activations = activations[1:]
@ -381,8 +522,10 @@ class CLIPDensePredT(CLIPDenseBase):
_activations = activations[::-1] if not self.rev_activations else activations
a = None
for i, (activation, block, reduce) in enumerate(zip(_activations, self.blocks, self.reduces)):
for i, (activation, block, reduce) in enumerate(
zip(_activations, self.blocks, self.reduces)
):
if a is not None:
a = reduce(activation) + a
else:
@ -391,7 +534,7 @@ class CLIPDensePredT(CLIPDenseBase):
if i == self.cond_layer:
if self.reduce_cond is not None:
cond = self.reduce_cond(cond)
a = self.film_mul(cond) * a + self.film_add(cond)
a = block(a)
@ -399,7 +542,7 @@ class CLIPDensePredT(CLIPDenseBase):
for block in self.extra_blocks:
a = a + block(a)
a = a[1:].permute(1, 2, 0) # rm cls token and -> BS, Feats, Tokens
a = a[1:].permute(1, 2, 0) # rm cls token and -> BS, Feats, Tokens
size = int(math.sqrt(a.shape[2]))
@ -408,33 +551,57 @@ class CLIPDensePredT(CLIPDenseBase):
a = self.trans_conv(a)
if self.n_tokens is not None:
a = nnf.interpolate(a, x_inp.shape[2:], mode='bilinear', align_corners=True)
a = nnf.interpolate(a, x_inp.shape[2:], mode="bilinear", align_corners=True)
if self.upsample_proj is not None:
a = self.upsample_proj(a)
a = nnf.interpolate(a, x_inp.shape[2:], mode='bilinear')
a = nnf.interpolate(a, x_inp.shape[2:], mode="bilinear")
if return_features:
return a, visual_q, cond, [activation1] + activations
else:
return a,
return (a,)
class CLIPDensePredTMasked(CLIPDensePredT):
def __init__(self, version='ViT-B/32', extract_layers=(3, 6, 9), cond_layer=0, reduce_dim=128, n_heads=4,
prompt='fixed', extra_blocks=0, reduce_cond=None, fix_shift=False, learn_trans_conv_only=False,
refine=None, limit_to_clip_only=False, upsample=False, add_calibration=False, n_tokens=None):
super().__init__(version=version, extract_layers=extract_layers, cond_layer=cond_layer, reduce_dim=reduce_dim,
n_heads=n_heads, prompt=prompt, extra_blocks=extra_blocks, reduce_cond=reduce_cond,
fix_shift=fix_shift, learn_trans_conv_only=learn_trans_conv_only,
limit_to_clip_only=limit_to_clip_only, upsample=upsample, add_calibration=add_calibration,
n_tokens=n_tokens)
def __init__(
self,
version="ViT-B/32",
extract_layers=(3, 6, 9),
cond_layer=0,
reduce_dim=128,
n_heads=4,
prompt="fixed",
extra_blocks=0,
reduce_cond=None,
fix_shift=False,
learn_trans_conv_only=False,
refine=None,
limit_to_clip_only=False,
upsample=False,
add_calibration=False,
n_tokens=None,
):
super().__init__(
version=version,
extract_layers=extract_layers,
cond_layer=cond_layer,
reduce_dim=reduce_dim,
n_heads=n_heads,
prompt=prompt,
extra_blocks=extra_blocks,
reduce_cond=reduce_cond,
fix_shift=fix_shift,
learn_trans_conv_only=learn_trans_conv_only,
limit_to_clip_only=limit_to_clip_only,
upsample=upsample,
add_calibration=add_calibration,
n_tokens=n_tokens,
)
def visual_forward_masked(self, img_s, seg_s):
return super().visual_forward(img_s, mask=('all', 'cls_token', seg_s))
return super().visual_forward(img_s, mask=("all", "cls_token", seg_s))
def forward(self, img_q, cond_or_img_s, seg_s=None, return_features=False):
@ -449,34 +616,42 @@ class CLIPDensePredTMasked(CLIPDensePredT):
return super().forward(img_q, cond, return_features=return_features)
class CLIPDenseBaseline(CLIPDenseBase):
def __init__(
self,
version="ViT-B/32",
cond_layer=0,
extract_layer=9,
reduce_dim=128,
reduce2_dim=None,
prompt="fixed",
reduce_cond=None,
limit_to_clip_only=False,
n_tokens=None,
):
def __init__(self, version='ViT-B/32', cond_layer=0,
extract_layer=9, reduce_dim=128, reduce2_dim=None, prompt='fixed',
reduce_cond=None, limit_to_clip_only=False, n_tokens=None):
super().__init__(version, reduce_cond, reduce_dim, prompt, n_tokens)
device = 'cpu'
device = "cpu"
# self.cond_layer = cond_layer
self.extract_layer = extract_layer
self.limit_to_clip_only = limit_to_clip_only
self.shift_vector = None
self.token_shape = {'ViT-B/32': (7, 7), 'ViT-B/16': (14, 14)}[version]
self.token_shape = {"ViT-B/32": (7, 7), "ViT-B/16": (14, 14)}[version]
assert reduce2_dim is not None
self.reduce2 = nn.Sequential(
nn.Linear(reduce_dim, reduce2_dim),
nn.ReLU(),
nn.Linear(reduce2_dim, reduce_dim)
nn.Linear(reduce2_dim, reduce_dim),
)
trans_conv_ks = {'ViT-B/32': (32, 32), 'ViT-B/16': (16, 16)}[version]
self.trans_conv = nn.ConvTranspose2d(reduce_dim, 1, trans_conv_ks, stride=trans_conv_ks)
trans_conv_ks = {"ViT-B/32": (32, 32), "ViT-B/16": (16, 16)}[version]
self.trans_conv = nn.ConvTranspose2d(
reduce_dim, 1, trans_conv_ks, stride=trans_conv_ks
)
def forward(self, inp_image, conditional=None, return_features=False):
@ -489,7 +664,9 @@ class CLIPDenseBaseline(CLIPDenseBase):
cond = self.get_cond_vec(conditional, bs)
visual_q, activations, affinities = self.visual_forward(x_inp, extract_layers=[self.extract_layer])
visual_q, activations, affinities = self.visual_forward(
x_inp, extract_layers=[self.extract_layer]
)
a = activations[0]
a = self.reduce(a)
@ -500,7 +677,7 @@ class CLIPDenseBaseline(CLIPDenseBase):
# the original model would execute a transformer block here
a = a[1:].permute(1, 2, 0) # rm cls token and -> BS, Feats, Tokens
a = a[1:].permute(1, 2, 0) # rm cls token and -> BS, Feats, Tokens
size = int(math.sqrt(a.shape[2]))
@ -510,23 +687,23 @@ class CLIPDenseBaseline(CLIPDenseBase):
if return_features:
return a, visual_q, cond, activations
else:
return a,
return (a,)
class CLIPSegMultiLabel(nn.Module):
def __init__(self, model) -> None:
super().__init__()
from third_party.JoEm.data_loader import get_seen_idx, get_unseen_idx, VOC
from third_party.JoEm.data_loader import VOC, get_seen_idx, get_unseen_idx
self.pascal_classes = VOC
from models.clipseg import CLIPDensePredT
from general_utils import load_model
from models.clipseg import CLIPDensePredT
# self.clipseg = load_model('rd64-vit16-neg0.2-phrasecut', strict=False)
self.clipseg = load_model(model, strict=False)
self.clipseg.eval()
def forward(self, x):
@ -535,18 +712,16 @@ class CLIPSegMultiLabel(nn.Module):
out = torch.ones(21, bs, 352, 352).to(x.device) * -10
for class_id, class_name in enumerate(self.pascal_classes):
fac = 3 if class_name == 'background' else 1
fac = 3 if class_name == "background" else 1
with torch.no_grad():
pred = torch.sigmoid(self.clipseg(x, class_name)[0][:,0]) * fac
pred = torch.sigmoid(self.clipseg(x, class_name)[0][:, 0]) * fac
out[class_id] += pred
out = out.permute(1, 0, 2, 3)
return out
# construct output tensor

@ -19,7 +19,13 @@ setup(
entry_points={
"console_scripts": ["imagine=imaginairy.cmds:imagine_cmd"],
},
package_data={"imaginairy": ["configs/*.yaml", "vendored/clip/*.txt.gz", "vendored/clipseg/*.pth"]},
package_data={
"imaginairy": [
"configs/*.yaml",
"vendored/clip/*.txt.gz",
"vendored/clipseg/*.pth",
]
},
install_requires=[
"click",
"protobuf != 3.20.2, != 3.19.5",

[2 binary files added, not shown — 553 KiB and 12 KiB; presumably the bench2 test image and mask used by test_inpainting below]

@ -1,9 +1,7 @@
import hashlib
import torch
from PIL import Image
from pytorch_lightning import seed_everything
from torchvision import transforms
from imaginairy.enhancers.clip_masking import get_img_mask
from imaginairy.enhancers.face_restoration_codeformer import enhance_faces
@ -28,6 +26,10 @@ def img_hash(img):
def test_clip_masking():
img = Image.open(f"{TESTS_FOLDER}/data/girl_with_a_pearl_earring.jpg")
preds = get_img_mask(img, ['background'])
mask = transforms.ToPILImage()(torch.sigmoid(preds[0][0]))
mask.save(f"{TESTS_FOLDER}/test_output/earring_mask.png")
pred = get_img_mask(img, "head")
pred.save(f"{TESTS_FOLDER}/test_output/earring_mask.png")
def test_clip_inpainting():
img = Image.open(f"{TESTS_FOLDER}/data/girl_with_a_pearl_earring.jpg")
pred = get_img_mask(img, "background")

@ -89,3 +89,39 @@ def test_img_to_file():
)
out_folder = f"{TESTS_FOLDER}/test_output"
imagine_image_files(prompt, outdir=out_folder)
def test_inpainting():
prompt = ImaginePrompt(
"a basketball on a bench",
init_image=f"{TESTS_FOLDER}/data/bench2.png",
init_image_strength=0.4,
mask_image=LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/bench2_mask.png"),
width=512,
height=512,
steps=5,
seed=1,
sampler_type="DDIM",
)
out_folder = f"{TESTS_FOLDER}/test_output"
imagine_image_files(prompt, outdir=out_folder)
def test_cliptext_inpainting():
prompts = [
ImaginePrompt(
"elegant woman. oil painting",
prompt_strength=12,
init_image=f"{TESTS_FOLDER}/data/girl_with_a_pearl_earring.jpg",
init_image_strength=0.3,
mask_prompt="face",
mask_mode=ImaginePrompt.MaskMode.KEEP,
mask_expansion=-3,
width=512,
height=512,
steps=5,
sampler_type="DDIM",
),
]
out_folder = f"{TESTS_FOLDER}/test_output"
imagine_image_files(prompts, outdir=out_folder)

@ -18,7 +18,7 @@ def test_is_nsfw():
def _pil_to_latent(img):
model = load_model()
img, w, h = pillow_img_to_torch_image(img)
img = pillow_img_to_torch_image(img)
img = img.to(get_device())
latent = model.get_first_stage_encoding(model.encode_first_stage(img))
return latent
