Merge pull request #9 from brycedrennan/clip-masking

Clip masking
Bryce Drennan 2 years ago committed by GitHub
commit a5e010d01f

@ -76,10 +76,25 @@ vendor_openai_clip:
git --git-dir ./downloads/CLIP/.git rev-parse HEAD | tee ./imaginairy/vendored/clip/clip-commit-hash.txt
echo "vendored from git@github.com:openai/CLIP.git" | tee ./imaginairy/vendored/clip/readme.txt
revendorize:
revendorize: vendorize_kdiffusion
make vendorize REPO=git@github.com:openai/CLIP.git PKG=clip COMMIT=d50d76daa670286dd6cacf3bcd80b5e4823fc8e1
make af
vendorize_clipseg:
make download_repo REPO=git@github.com:timojl/clipseg.git PKG=clipseg COMMIT=664ee94393491cdd7ad422f67eb1ce670d3d00e6
rm -rf ./imaginairy/vendored/clipseg
mkdir -p ./imaginairy/vendored/clipseg
cp -R ./downloads/clipseg/models/* ./imaginairy/vendored/clipseg/
sed -i '' -e 's#import clip#from imaginairy.vendored import clip#g' ./imaginairy/vendored/clipseg/clipseg.py
rm ./imaginairy/vendored/clipseg/vitseg.py
mv ./imaginairy/vendored/clipseg/clipseg.py ./imaginairy/vendored/clipseg/__init__.py
wget https://github.com/timojl/clipseg/raw/master/weights/rd64-uni.pth -P ./imaginairy/vendored/clipseg
vendorize_kdiffusion:
make vendorize REPO=git@github.com:crowsonkb/k-diffusion.git PKG=k_diffusion COMMIT=1a0703dfb7d24d8806267c3e7ccc4caf67fd1331
#sed -i '' -e 's/^import\sclip/from\simaginairy.vendored\simport\sclip/g' imaginairy/vendored/k_diffusion/evaluation.py
#sed -i '' -e 's/import\sclip/from\simaginairy.vendored\simport\sclip/g' imaginairy/vendored/k_diffusion/evaluation.py
rm imaginairy/vendored/k_diffusion/evaluation.py
touch imaginairy/vendored/k_diffusion/evaluation.py
rm imaginairy/vendored/k_diffusion/config.py
@ -89,8 +104,6 @@ revendorize:
sed -i '' -e 's#x = x + torch.randn_like(x) \* sigma_up#x = x + torch.randn_like(x, device="cpu").to(x.device) \* sigma_up#g' imaginairy/vendored/k_diffusion/sampling.py
make af
vendorize: ## vendorize a github repo. `make vendorize REPO=git@github.com:openai/CLIP.git PKG=clip`
mkdir -p ./downloads
-cd ./downloads && git clone $(REPO) $(PKG)
@ -101,6 +114,11 @@ vendorize: ## vendorize a github repo. `make vendorize REPO=git@github.com:ope
touch ./imaginairy/vendored/$(PKG)/version.py
echo "vendored from $(REPO)" | tee ./imaginairy/vendored/$(PKG)/readme.txt
download_repo:
mkdir -p ./downloads
-cd ./downloads && git clone $(REPO) $(PKG)
cd ./downloads/$(PKG) && git pull
vendorize_whole_repo:
mkdir -p ./downloads
-cd ./downloads && git clone $(REPO) $(PKG)

@ -34,31 +34,31 @@ Generating 🖼 : "portrait photo of a freckled woman" 512x512px seed:500686645
<img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000019_786355545_PLMS50_PS7.5_a_scenic_landscape.jpg" height="256"><img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000032_337692011_PLMS40_PS7.5_a_photo_of_a_dog.jpg" height="256"><br>
<img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000056_293284644_PLMS40_PS7.5_photo_of_a_bowl_of_fruit.jpg" height="256"><img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000078_260972468_PLMS40_PS7.5_portrait_photo_of_a_freckled_woman.jpg" height="256">
### Tiled Images
### Automated Replacement (txt2mask) [by clipseg](https://github.com/timojl/clipseg)
```bash
>> imagine "gold coins" "a lush forest" "piles of old books" leaves --tile
>> imagine --init-image pearl_earring.jpg --mask-prompt face --mask-mode keep --init-image-strength .4 "a female doctor" "an elegant woman"
```
<img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000066_801493266_PLMS40_PS7.5_gold_coins.jpg" height="128"><img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000066_801493266_PLMS40_PS7.5_gold_coins.jpg" height="128"><img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000066_801493266_PLMS40_PS7.5_gold_coins.jpg" height="128">
<img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000118_597948545_PLMS40_PS7.5_a_lush_forest.jpg" height="128"><img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000118_597948545_PLMS40_PS7.5_a_lush_forest.jpg" height="128"><img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000118_597948545_PLMS40_PS7.5_a_lush_forest.jpg" height="128">
<br>
<img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000075_961095192_PLMS40_PS7.5_piles_of_old_books.jpg" height="128"><img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000075_961095192_PLMS40_PS7.5_piles_of_old_books.jpg" height="128"><img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000075_961095192_PLMS40_PS7.5_piles_of_old_books.jpg" height="128">
<img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000040_527733581_PLMS40_PS7.5_leaves.jpg" height="128"><img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000040_527733581_PLMS40_PS7.5_leaves.jpg" height="128"><img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000040_527733581_PLMS40_PS7.5_leaves.jpg" height="128">
### Image-to-Image
<img src="assets/mask_examples/pearl000.jpg" height="256">➡️
<img src="assets/mask_examples/pearl002.jpg" height="256">
<img src="assets/mask_examples/pearl004.jpg" height="256">
<img src="assets/mask_examples/pearl001.jpg" height="256">
<img src="assets/mask_examples/pearl003.jpg" height="256">
```bash
>> imagine "portrait of a smiling lady. oil painting" --init-image girl_with_a_pearl_earring.jpg
>> imagine --init-image fruit-bowl.jpg --mask-prompt fruit --mask-mode replace --init-image-strength .1 "a bowl of pears" "a bowl of gold" "a bowl of popcorn" "a bowl of spaghetti"
```
<img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/tests/data/girl_with_a_pearl_earring.jpg" height="256"> =>
<img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000105_33084057_DDIM40_PS7.5_portrait_of_a_smiling_lady._oil_painting._.jpg" height="256">
<img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000056_293284644_PLMS40_PS7.5_photo_of_a_bowl_of_fruit.jpg" height="256">➡️
<img src="assets/mask_examples/bowl004.jpg" height="256">
<img src="assets/mask_examples/bowl001.jpg" height="256">
<img src="assets/mask_examples/bowl002.jpg" height="256">
<img src="assets/mask_examples/bowl003.jpg" height="256">
### Face Enhancement [by CodeFormer](https://github.com/sczhou/CodeFormer)
```bash
>> imagine "a couple smiling" --steps 40 --seed 1 --fix-faces
```
<img src="https://github.com/brycedrennan/imaginAIry/raw/master/assets/000178_1_PLMS40_PS7.5_a_couple_smiling_nofix.png" height="256"> =>
<img src="https://github.com/brycedrennan/imaginAIry/raw/master/assets/000178_1_PLMS40_PS7.5_a_couple_smiling_nofix.png" height="256"> ➡️
<img src="https://github.com/brycedrennan/imaginAIry/raw/master/assets/000178_1_PLMS40_PS7.5_a_couple_smiling_fixed.png" height="256">
@ -66,9 +66,28 @@ Generating 🖼 : "portrait photo of a freckled woman" 512x512px seed:500686645
```bash
>> imagine "colorful smoke" --steps 40 --upscale
```
<img src="https://github.com/brycedrennan/imaginAIry/raw/master/assets/000206_856637805_PLMS40_PS7.5_colorful_smoke.jpg" height="128"> =>
<img src="https://github.com/brycedrennan/imaginAIry/raw/master/assets/000206_856637805_PLMS40_PS7.5_colorful_smoke.jpg" height="128"> ➡️
<img src="https://github.com/brycedrennan/imaginAIry/raw/master/assets/000206_856637805_PLMS40_PS7.5_colorful_smoke_upscaled.jpg" height="256">
### Tiled Images
```bash
>> imagine "gold coins" "a lush forest" "piles of old books" leaves --tile
```
<img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000066_801493266_PLMS40_PS7.5_gold_coins.jpg" height="128"><img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000066_801493266_PLMS40_PS7.5_gold_coins.jpg" height="128"><img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000066_801493266_PLMS40_PS7.5_gold_coins.jpg" height="128">
<img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000118_597948545_PLMS40_PS7.5_a_lush_forest.jpg" height="128"><img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000118_597948545_PLMS40_PS7.5_a_lush_forest.jpg" height="128"><img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000118_597948545_PLMS40_PS7.5_a_lush_forest.jpg" height="128">
<br>
<img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000075_961095192_PLMS40_PS7.5_piles_of_old_books.jpg" height="128"><img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000075_961095192_PLMS40_PS7.5_piles_of_old_books.jpg" height="128"><img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000075_961095192_PLMS40_PS7.5_piles_of_old_books.jpg" height="128">
<img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000040_527733581_PLMS40_PS7.5_leaves.jpg" height="128"><img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000040_527733581_PLMS40_PS7.5_leaves.jpg" height="128"><img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000040_527733581_PLMS40_PS7.5_leaves.jpg" height="128">
### Image-to-Image
```bash
>> imagine "portrait of a smiling lady. oil painting" --init-image girl_with_a_pearl_earring.jpg
```
<img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/tests/data/girl_with_a_pearl_earring.jpg" height="256"> ➡️
<img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000105_33084057_DDIM40_PS7.5_portrait_of_a_smiling_lady._oil_painting._.jpg" height="256">
## Features
- It makes images from text descriptions! 🎉

Binary files not shown: 9 new image assets added (33 KiB, 34 KiB, 30 KiB, 25 KiB, 34 KiB, 32 KiB, 44 KiB, 33 KiB, and 38 KiB).

@ -1,4 +1,4 @@
import os
import os.path
# tells pytorch to allow MPS usage (for Mac M1 compatibility)
os.putenv("PYTORCH_ENABLE_MPS_FALLBACK", "1")
@ -10,3 +10,5 @@ from .schema import ( # noqa
LazyLoadingImage,
WeightedPrompt,
)
PKG_ROOT = os.path.dirname(__file__)

@ -9,21 +9,29 @@ import torch
import torch.nn
from einops import rearrange
from omegaconf import OmegaConf
from PIL import Image, ImageDraw, ImageFilter
from PIL import Image, ImageDraw, ImageFilter, ImageOps
from pytorch_lightning import seed_everything
from torch import autocast
from transformers import cached_path
from imaginairy.enhancers.clip_masking import get_img_mask
from imaginairy.enhancers.face_restoration_codeformer import enhance_faces
from imaginairy.enhancers.upscale_realesrgan import upscale_image
from imaginairy.img_log import ImageLoggingContext, log_conditioning, log_latent
from imaginairy.img_log import (
ImageLoggingContext,
log_conditioning,
log_img,
log_latent,
)
from imaginairy.safety import is_nsfw
from imaginairy.samplers.base import get_sampler
from imaginairy.schema import ImaginePrompt, ImagineResult
from imaginairy.utils import (
expand_mask,
fix_torch_nn_layer_norm,
get_device,
instantiate_from_config,
pillow_fit_image_within,
pillow_img_to_torch_image,
)
@ -202,22 +210,61 @@ def imagine(
prompt.height // downsampling_factor,
prompt.width // downsampling_factor,
]
if prompt.init_image:
sampler_type = "ddim"
else:
sampler_type = prompt.sampler_type
start_code = None
sampler = get_sampler(prompt.sampler_type, model)
sampler = get_sampler(sampler_type, model)
mask, mask_image, mask_image_orig = None, None, None
if prompt.init_image:
generation_strength = 1 - prompt.init_image_strength
ddim_steps = int(prompt.steps / generation_strength)
sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=ddim_eta)
init_image, w, h = pillow_img_to_torch_image( # noqa
init_image, _, h = pillow_fit_image_within(
prompt.init_image,
max_height=prompt.height,
max_width=prompt.width,
)
init_image = init_image.to(get_device())
init_image_t = pillow_img_to_torch_image(init_image)
if prompt.mask_prompt:
mask_image = get_img_mask(init_image, prompt.mask_prompt)
elif prompt.mask_image:
mask_image = prompt.mask_image
if mask_image is not None:
log_img(mask_image, "init mask")
# mask_image = mask_image.filter(ImageFilter.GaussianBlur(8))
mask_image = expand_mask(mask_image, prompt.mask_expansion)
if prompt.mask_mode == ImaginePrompt.MaskMode.REPLACE:
mask_image = ImageOps.invert(mask_image)
log_img(mask_image, "init mask expanded")
log_img(
Image.composite(init_image, mask_image, mask_image),
"mask overlay",
)
mask_image_orig = mask_image
mask_image = mask_image.resize(
(
mask_image.width // downsampling_factor,
mask_image.height // downsampling_factor,
),
resample=Image.Resampling.NEAREST,
)
log_img(mask_image, "init mask 2")
mask = np.array(mask_image)
mask = mask.astype(np.float32) / 255.0
mask = mask[None, None]
mask[mask < 0.9] = 0
mask[mask >= 0.9] = 1
mask = torch.from_numpy(mask)
mask = mask.to(get_device())
init_image_t = init_image_t.to(get_device())
init_latent = model.get_first_stage_encoding(
model.encode_first_stage(init_image)
model.encode_first_stage(init_image_t)
)
log_latent(init_latent, "init_latent")
@ -236,6 +283,8 @@ def imagine(
unconditional_guidance_scale=prompt.prompt_strength,
unconditional_conditioning=uc,
img_callback=_img_callback,
mask=mask,
orig_latent=init_latent,
)
else:
@ -260,6 +309,14 @@ def imagine(
)
x_sample_8_orig = x_sample.astype(np.uint8)
img = Image.fromarray(x_sample_8_orig)
if mask_image_orig and init_image:
mask_image_orig = mask_image_orig.filter(
ImageFilter.GaussianBlur(radius=3)
)
mask_image_orig = ImageOps.invert(mask_image_orig)
img = Image.composite(img, init_image, mask_image_orig)
log_img(img, "reconstituted image")
upscaled_img = None
is_nsfw_img = None
if IMAGINAIRY_SAFETY_MODE != SafetyMode.DISABLED:
@ -269,7 +326,7 @@ def imagine(
img = img.filter(ImageFilter.GaussianBlur(radius=40))
if prompt.fix_faces:
logger.info(" Fixing 😊 's in 🖼 using GFPGAN...")
logger.info(" Fixing 😊 's in 🖼 using CodeFormer...")
img = enhance_faces(img, fidelity=0.2)
if prompt.upscale:
logger.info(" Upscaling 🖼 using real-ESRGAN...")

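For reference, a minimal standalone sketch of the mask preparation performed above: the Pillow mask is binarized and shrunk to latent resolution before it can gate the diffusion latents. The helper name and the 8x downsampling factor are illustrative, not part of the API.

```python
import numpy as np
import torch
from PIL import Image


def mask_to_latent_tensor(mask_image, downsampling_factor=8):
    """Turn a black/white PIL mask into a 1x1xHxW float tensor at latent resolution."""
    mask_image = mask_image.convert("L")  # single channel, as produced by expand_mask
    mask_image = mask_image.resize(
        (
            mask_image.width // downsampling_factor,
            mask_image.height // downsampling_factor,
        ),
        resample=Image.Resampling.NEAREST,
    )
    mask = np.array(mask_image).astype(np.float32) / 255.0
    mask = mask[None, None]   # add batch and channel dims
    mask[mask < 0.9] = 0      # binarize: only clearly-white pixels survive
    mask[mask >= 0.9] = 1
    return torch.from_numpy(mask)
```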
@ -123,6 +123,26 @@ def configure_logging(level="INFO"):
is_flag=True,
help="Any images rendered will be tileable. Unfortunately cannot be controlled at the per-image level yet",
)
@click.option(
"--mask-image",
help="A mask to use for inpainting. White gets painted, Black is left alone.",
)
@click.option(
"--mask-prompt",
help="Describe what you want masked and the AI will mask it for you",
)
@click.option(
"--mask-mode",
default="replace",
type=click.Choice(["keep", "replace"]),
help="Should we replace the masked area or keep it?",
)
@click.option(
"--mask-expansion",
default="8",
type=int,
help="How much to grow (or shrink) the mask area",
)
def imagine_cmd(
prompt_texts,
prompt_strength,
@ -141,6 +161,10 @@ def imagine_cmd(
log_level,
show_work,
tile,
mask_image,
mask_prompt,
mask_mode,
mask_expansion,
):
"""Render an image"""
suppress_annoying_logs_and_warnings()
@ -161,7 +185,6 @@ def imagine_cmd(
load_model(tile_mode=tile)
for _ in range(repeats):
for prompt_text in prompt_texts:
prompt = ImaginePrompt(
prompt_text,
prompt_strength=prompt_strength,
@ -172,6 +195,10 @@ def imagine_cmd(
steps=steps,
height=height,
width=width,
mask_image=mask_image,
mask_prompt=mask_prompt,
mask_expansion=mask_expansion,
mask_mode=mask_mode,
upscale=upscale,
fix_faces=fix_faces,
)

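The new flags can also be exercised from Python via click's test runner; a sketch mirroring the README example (argument values and file paths are placeholders, and invoking it really does run generation):

```python
from click.testing import CliRunner

from imaginairy.cmds import imagine_cmd

runner = CliRunner()
result = runner.invoke(
    imagine_cmd,
    [
        "a female doctor",
        "--init-image", "pearl_earring.jpg",
        "--init-image-strength", "0.4",
        "--mask-prompt", "face",
        "--mask-mode", "keep",
        "--mask-expansion", "8",
    ],
)
print(result.exit_code)
```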
@ -0,0 +1,64 @@
from functools import lru_cache
import torch
from torchvision import transforms
from imaginairy.img_log import log_img
from imaginairy.vendored.clipseg import CLIPDensePredT
weights_url = "https://github.com/timojl/clipseg/raw/master/weights/rd64-uni.pth"
@lru_cache()
def clip_mask_model():
from imaginairy import PKG_ROOT
model = CLIPDensePredT(version="ViT-B/16", reduce_dim=64)
model.eval()
model.load_state_dict(
torch.load(
f"{PKG_ROOT}/vendored/clipseg/rd64-uni.pth",
map_location=torch.device("cpu"),
),
strict=False,
)
return model
def get_img_mask(img, mask_description):
return get_img_masks(img, [mask_description])[0]
def get_img_masks(img, mask_descriptions):
a, b = img.size
orig_size = b, a
log_img(img, "image for masking")
# orig_shape = tuple(img.shape)[1:]
transform = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
transforms.Resize((352, 352)),
]
)
img = transform(img).unsqueeze(0)
with torch.no_grad():
preds = clip_mask_model()(
img.repeat(len(mask_descriptions), 1, 1, 1), mask_descriptions
)[0]
preds = transforms.Resize(orig_size)(preds)
preds = [torch.sigmoid(p[0]) for p in preds]
bw_preds = []
for p in preds:
log_img(p, f"clip mask for {mask_descriptions}")
# bw_preds.append(pred_transform(p))
_min = p.min()
_max = p.max()
_range = _max - _min
p = (p > (_min + (_range * 0.5))).float()
bw_preds.append(transforms.ToPILImage()(p))
return bw_preds

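Usage of the new enhancer mirrors the test added further down: given any Pillow image and a text description, it returns a black/white Pillow mask. The file paths below are placeholders.

```python
from PIL import Image

from imaginairy.enhancers.clip_masking import get_img_mask

img = Image.open("tests/data/girl_with_a_pearl_earring.jpg")
mask = get_img_mask(img, "face")  # white where the model thinks "face" is
mask.save("face_mask.png")
```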
@ -19,6 +19,8 @@ def realesrgan_upsampler():
upsampler = RealESRGANer(scale=4, model_path=model_path, model=model, tile=0)
device = get_device()
if "mps" in device:
device = "cpu"
upsampler.device = torch.device(device)
upsampler.model.to(device)
@ -28,6 +30,7 @@ def realesrgan_upsampler():
def upscale_image(img):
img = img.convert("RGB")
np_img = np.array(img, dtype=np.uint8)
upsampler_output, img_mode = realesrgan_upsampler().enhance(np_img[:, :, ::-1])
return Image.fromarray(upsampler_output[:, :, ::-1], mode=img_mode)

@ -32,6 +32,12 @@ def log_latent(latents, description):
_CURRENT_LOGGING_CONTEXT.log_latents(latents, description)
def log_img(img, description):
if _CURRENT_LOGGING_CONTEXT is None:
return
_CURRENT_LOGGING_CONTEXT.log_img(img, description)
class ImageLoggingContext:
def __init__(self, prompt, model, img_callback=None, img_outdir=None):
self.prompt = prompt
@ -71,6 +77,15 @@ class ImageLoggingContext:
img = Image.fromarray(latent.astype(np.uint8))
self.img_callback(img, description, self.step_count, self.prompt)
def log_img(self, img, description):
if not self.img_callback:
return
self.step_count += 1
if isinstance(img, torch.Tensor):
img = ToPILImage()(img.squeeze().cpu().detach())
img = img.copy()
self.img_callback(img, description, self.step_count, self.prompt)
# def img_callback(self, img, description, step_count, prompt):
# steps_path = os.path.join(self.img_outdir, "steps", f"{self.file_num:08}_S{prompt.seed}")
# os.makedirs(steps_path, exist_ok=True)

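The new log_img hook reuses the existing callback signature, `img_callback(img, description, step_count, prompt)`, so intermediate masks and reconstructions can be captured with a small function passed as `img_callback` when constructing `ImageLoggingContext`. A hypothetical callback:

```python
import os


def save_step_images(img, description, step_count, prompt):
    """Write every logged intermediate image to ./steps for debugging."""
    os.makedirs("steps", exist_ok=True)
    safe_desc = description.replace(" ", "_")
    img.save(f"steps/{step_count:04d}_{safe_desc}.png")
```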
@ -16,7 +16,11 @@ from einops import rearrange
from torchvision.utils import make_grid
from tqdm import tqdm
from imaginairy.modules.diffusion.util import make_beta_schedule, noise_like
from imaginairy.modules.diffusion.util import (
extract_into_tensor,
make_beta_schedule,
noise_like,
)
from imaginairy.modules.distributions import DiagonalGaussianDistribution
from imaginairy.utils import instantiate_from_config, log_params
@ -852,6 +856,18 @@ class LatentDiffusion(DDPM):
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
def q_sample(self, x_start, t, noise=None):
noise = (
noise
if noise is not None
else torch.randn_like(x_start, device="cpu").to(x_start.device)
)
return (
extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
* noise
)
class DiffusionWrapper(pl.LightningModule):
def __init__(self, diff_model_config, conditioning_key):

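The added `q_sample` is the standard closed-form forward-diffusion step, with the noise drawn on CPU and moved to the target device (consistent with the sampling.py patch in the Makefile hunk above):

```latex
x_t = \sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon,
\qquad \epsilon \sim \mathcal{N}(0, I)
```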
@ -42,20 +42,25 @@ def get_sampler(sampler_type, model):
class CFGDenoiser(nn.Module):
"""
Conditional forward guidance wrapper
"""
def __init__(self, model):
super().__init__()
self.inner_model = model
def forward(self, x, sigma, uncond, cond, cond_scale):
def forward(self, x, sigma, uncond, cond, cond_scale, mask=None, orig_latent=None):
x_in = torch.cat([x] * 2)
sigma_in = torch.cat([sigma] * 2)
cond_in = torch.cat([uncond, cond])
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
return uncond + (cond - uncond) * cond_scale
denoised = uncond + (cond - uncond) * cond_scale
if mask is not None:
assert orig_latent is not None
mask_inv = 1.0 - mask
denoised = (orig_latent * mask_inv) + (mask * denoised)
return denoised
class DiffusionSampler:

@ -335,10 +335,11 @@ class DDIMSampler:
img_callback=None,
score_corrector=None,
temperature=1.0,
mask=None,
orig_latent=None,
):
timesteps = self.ddim_timesteps
timesteps = timesteps[:t_start]
timesteps = self.ddim_timesteps[:t_start]
time_range = np.flip(timesteps)
total_steps = timesteps.shape[0]
@ -352,6 +353,15 @@ class DDIMSampler:
ts = torch.full(
(x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long
)
if mask is not None:
assert orig_latent is not None
xdec_orig = self.model.q_sample(orig_latent, ts)
log_latent(xdec_orig, "xdec_orig")
log_latent(xdec_orig * mask, "masked_xdec_orig")
x_dec = xdec_orig * mask + (1.0 - mask) * x_dec
log_latent(x_dec, "x_dec")
x_dec, pred_x0 = self.p_sample_ddim(
x_dec,
cond,

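Conceptually, each masked decode step re-noises the original latent to the current timestep (via the new `q_sample`) and pastes it back wherever `mask == 1`, so diffusion only rewrites the region being replaced. A hypothetical one-function summary of the blend:

```python
def blend_with_original(x_dec, noised_orig, mask):
    """Keep noised_orig where mask == 1; let the sampler's x_dec through elsewhere."""
    return noised_orig * mask + (1.0 - mask) * x_dec
```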
@ -8,28 +8,6 @@ from imaginairy.vendored.k_diffusion import sampling as k_sampling
from imaginairy.vendored.k_diffusion.external import CompVisDenoiser
class CFGMaskedDenoiser(nn.Module):
def __init__(self, model):
super().__init__()
self.inner_model = model
def forward(self, x, sigma, uncond, cond, cond_scale, mask, x0, xi):
x_in = x
x_in = torch.cat([x_in] * 2)
sigma_in = torch.cat([sigma] * 2)
cond_in = torch.cat([uncond, cond])
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
denoised = uncond + (cond - uncond) * cond_scale
if mask is not None:
assert x0 is not None
img_orig = x0
mask_inv = 1.0 - mask
denoised = (img_orig * mask_inv) + (mask * denoised)
return denoised
class KDiffusionSampler:
def __init__(self, model, sampler_name):
self.model = model

@ -81,12 +81,20 @@ class WeightedPrompt:
class ImaginePrompt:
class MaskMode:
KEEP = "keep"
REPLACE = "replace"
def __init__(
self,
prompt=None,
prompt_strength=7.5,
init_image=None, # Pillow Image, LazyLoadingImage, or filepath str
init_image_strength=0.3,
mask_prompt=None,
mask_image=None,
mask_mode=MaskMode.REPLACE,
mask_expansion=8,
seed=None,
steps=50,
height=512,
@ -105,6 +113,10 @@ class ImaginePrompt:
self.prompt_strength = prompt_strength
if isinstance(init_image, str):
init_image = LazyLoadingImage(filepath=init_image)
if mask_image is not None and mask_prompt is not None:
raise ValueError("You can only set one of `mask_image` and `mask_prompt`")
self.init_image = init_image
self.init_image_strength = init_image_strength
self.seed = random.randint(1, 1_000_000_000) if seed is None else seed
@ -115,6 +127,10 @@ class ImaginePrompt:
self.fix_faces = fix_faces
self.sampler_type = sampler_type
self.conditioning = conditioning
self.mask_prompt = mask_prompt
self.mask_image = mask_image
self.mask_mode = mask_mode
self.mask_expansion = mask_expansion
@property
def prompt_text(self):

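End-to-end, the new prompt fields are used like this, mirroring test_cliptext_inpainting below (file path is a placeholder; `mask_prompt` and `mask_image` are mutually exclusive):

```python
from imaginairy.schema import ImaginePrompt

prompt = ImaginePrompt(
    "elegant woman. oil painting",
    init_image="girl_with_a_pearl_earring.jpg",
    init_image_strength=0.3,
    mask_prompt="face",
    mask_mode=ImaginePrompt.MaskMode.KEEP,  # regenerate everything except the face
    mask_expansion=-3,                      # negative values shrink the mask slightly
    steps=5,
    sampler_type="DDIM",
)
```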
@ -9,12 +9,14 @@ from typing import List, Optional
import numpy as np
import requests
import torch
from PIL import Image
from PIL import Image, ImageFilter
from torch import Tensor
from torch.nn import functional
from torch.overrides import handle_torch_function, has_torch_function_variadic
from transformers import cached_path
from imaginairy.img_log import log_img
logger = logging.getLogger(__name__)
@ -102,23 +104,45 @@ def fix_torch_nn_layer_norm():
functional.layer_norm = orig_function
def img_path_to_torch_image(path, max_height=512, max_width=512):
def expand_mask(mask_image, size):
if size < 0:
threshold = 0.9
else:
threshold = 0.1
mask_image = mask_image.convert("L")
mask_image = mask_image.filter(ImageFilter.GaussianBlur(size))
log_img(mask_image, "init mask blurred")
mask = np.array(mask_image)
mask = mask.astype(np.float32) / 255.0
mask = mask[None, None]
mask[mask < threshold] = 0
mask[mask >= threshold] = 1
return Image.fromarray(np.uint8(mask.squeeze() * 255))
def img_path_to_torch_image(path):
image = Image.open(path).convert("RGB")
logger.info(f"Loaded input 🖼 of size {image.size} from {path}")
return pillow_img_to_torch_image(image, max_height=max_height, max_width=max_width)
return pillow_img_to_torch_image(image)
def pillow_img_to_torch_image(image, max_height=512, max_width=512):
def pillow_fit_image_within(image, max_height=512, max_width=512):
image = image.convert("RGB")
w, h = image.size
resize_ratio = min(max_width / w, max_height / h)
w, h = int(w * resize_ratio), int(h * resize_ratio)
w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 32
image = image.resize((w, h), resample=Image.Resampling.LANCZOS)
w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64
image = image.resize((w, h), resample=Image.Resampling.NEAREST)
return image, w, h
def pillow_img_to_torch_image(image):
image = image.convert("RGB")
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
return 2.0 * image - 1.0, w, h
return 2.0 * image - 1.0
def get_cache_dir():

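A small usage sketch for the new helpers (file names are placeholders): `expand_mask` blurs and re-thresholds a mask to grow it by roughly `size` pixels (the negative-size branch raises the threshold so the mask shrinks instead), and `pillow_fit_image_within` fits an image inside the target box at a multiple of 64 px.

```python
from PIL import Image

from imaginairy.utils import expand_mask, pillow_fit_image_within

mask = Image.open("face_mask.png")
grown_mask = expand_mask(mask, 8)  # blur radius 8, low threshold -> larger white area

init_img = Image.open("pearl_earring.jpg")
fitted, w, h = pillow_fit_image_within(init_img, max_height=512, max_width=512)
print(w, h)  # both are multiples of 64 and no larger than 512
```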
@ -0,0 +1,727 @@
import math
from os.path import basename, dirname, isfile, join
import torch
from torch import nn
from torch.nn import functional as nnf
from torch.nn.modules.activation import ReLU
def precompute_clip_vectors():
from trails.initialization import init_dataset
lvis = init_dataset(
"LVIS_OneShot3",
split="train",
mask="text_label",
image_size=224,
aug=1,
normalize=True,
reduce_factor=None,
add_bar=False,
negative_prob=0.5,
)
all_names = list(lvis.category_names.values())
from models.clip_prompts import imagenet_templates
from imaginairy.vendored import clip
clip_model = clip.load("ViT-B/32", device="cuda", jit=False)[0]
prompt_vectors = {}
for name in all_names[:100]:
with torch.no_grad():
conditionals = [
t.format(name).replace("_", " ") for t in imagenet_templates
]
text_tokens = clip.tokenize(conditionals).cuda()
cond = clip_model.encode_text(text_tokens).cpu()
for cond, vec in zip(conditionals, cond):
prompt_vectors[cond] = vec.cpu()
import pickle
pickle.dump(prompt_vectors, open("precomputed_prompt_vectors.pickle", "wb"))
def get_prompt_list(prompt):
if prompt == "plain":
return ["{}"]
elif prompt == "fixed":
return ["a photo of a {}."]
elif prompt == "shuffle":
return ["a photo of a {}.", "a photograph of a {}.", "an image of a {}.", "{}."]
elif prompt == "shuffle+":
return [
"a photo of a {}.",
"a photograph of a {}.",
"an image of a {}.",
"{}.",
"a cropped photo of a {}.",
"a good photo of a {}.",
"a photo of one {}.",
"a bad photo of a {}.",
"a photo of the {}.",
]
elif prompt == "shuffle_clip":
from models.clip_prompts import imagenet_templates
return imagenet_templates
else:
raise ValueError("Invalid value for prompt")
def forward_multihead_attention(x, b, with_aff=False, attn_mask=None):
"""
Simplified version of multihead attention (taken from torch source code but without tons of if clauses).
The mlp and layer norm come from CLIP.
x: input.
b: multihead attention module.
"""
x_ = b.ln_1(x)
q, k, v = nnf.linear(x_, b.attn.in_proj_weight, b.attn.in_proj_bias).chunk(
3, dim=-1
)
tgt_len, bsz, embed_dim = q.size()
head_dim = embed_dim // b.attn.num_heads
scaling = float(head_dim) ** -0.5
q = (
q.contiguous()
.view(tgt_len, bsz * b.attn.num_heads, b.attn.head_dim)
.transpose(0, 1)
)
k = k.contiguous().view(-1, bsz * b.attn.num_heads, b.attn.head_dim).transpose(0, 1)
v = v.contiguous().view(-1, bsz * b.attn.num_heads, b.attn.head_dim).transpose(0, 1)
q = q * scaling
attn_output_weights = torch.bmm(
q, k.transpose(1, 2)
) # n_heads * batch_size, tokens^2, tokens^2
if attn_mask is not None:
attn_mask_type, attn_mask = attn_mask
n_heads = attn_output_weights.size(0) // attn_mask.size(0)
attn_mask = attn_mask.repeat(n_heads, 1)
if attn_mask_type == "cls_token":
# the mask only affects similarities compared to the readout-token.
attn_output_weights[:, 0, 1:] = (
attn_output_weights[:, 0, 1:] * attn_mask[None, ...]
)
# attn_output_weights[:, 0, 0] = 0*attn_output_weights[:, 0, 0]
if attn_mask_type == "all":
# print(attn_output_weights.shape, attn_mask[:, None].shape)
attn_output_weights[:, 1:, 1:] = (
attn_output_weights[:, 1:, 1:] * attn_mask[:, None]
)
attn_output_weights = torch.softmax(attn_output_weights, dim=-1)
attn_output = torch.bmm(attn_output_weights, v)
attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
attn_output = b.attn.out_proj(attn_output)
x = x + attn_output
x = x + b.mlp(b.ln_2(x))
if with_aff:
return x, attn_output_weights
else:
return x
class CLIPDenseBase(nn.Module):
def __init__(self, version, reduce_cond, reduce_dim, prompt, n_tokens):
super().__init__()
from imaginairy.vendored import clip
# prec = torch.FloatTensor
self.clip_model, _ = clip.load(version, device="cpu", jit=False)
self.model = self.clip_model.visual
# if not None, scale conv weights such that we obtain n_tokens.
self.n_tokens = n_tokens
for p in self.clip_model.parameters():
p.requires_grad_(False)
# conditional
if reduce_cond is not None:
self.reduce_cond = nn.Linear(512, reduce_cond)
for p in self.reduce_cond.parameters():
p.requires_grad_(False)
else:
self.reduce_cond = None
self.film_mul = nn.Linear(
512 if reduce_cond is None else reduce_cond, reduce_dim
)
self.film_add = nn.Linear(
512 if reduce_cond is None else reduce_cond, reduce_dim
)
self.reduce = nn.Linear(768, reduce_dim)
self.prompt_list = get_prompt_list(prompt)
# precomputed prompts
import pickle
if isfile("precomputed_prompt_vectors.pickle"):
precomp = pickle.load(open("precomputed_prompt_vectors.pickle", "rb"))
self.precomputed_prompts = {
k: torch.from_numpy(v) for k, v in precomp.items()
}
else:
self.precomputed_prompts = dict()
def rescaled_pos_emb(self, new_size):
assert len(new_size) == 2
a = self.model.positional_embedding[1:].T.view(1, 768, *self.token_shape)
b = (
nnf.interpolate(a, new_size, mode="bicubic", align_corners=False)
.squeeze(0)
.view(768, new_size[0] * new_size[1])
.T
)
return torch.cat([self.model.positional_embedding[:1], b])
def visual_forward(self, x_inp, extract_layers=(), skip=False, mask=None):
with torch.no_grad():
inp_size = x_inp.shape[2:]
if self.n_tokens is not None:
stride2 = x_inp.shape[2] // self.n_tokens
conv_weight2 = nnf.interpolate(
self.model.conv1.weight,
(stride2, stride2),
mode="bilinear",
align_corners=True,
)
x = nnf.conv2d(
x_inp,
conv_weight2,
bias=self.model.conv1.bias,
stride=stride2,
dilation=self.model.conv1.dilation,
)
else:
x = self.model.conv1(x_inp) # shape = [*, width, grid, grid]
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
x = torch.cat(
[
self.model.class_embedding.to(x.dtype)
+ torch.zeros(
x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
),
x,
],
dim=1,
) # shape = [*, grid ** 2 + 1, width]
standard_n_tokens = 50 if self.model.conv1.kernel_size[0] == 32 else 197
if x.shape[1] != standard_n_tokens:
new_shape = int(math.sqrt(x.shape[1] - 1))
x = (
x
+ self.rescaled_pos_emb((new_shape, new_shape)).to(x.dtype)[
None, :, :
]
)
else:
x = x + self.model.positional_embedding.to(x.dtype)
x = self.model.ln_pre(x)
x = x.permute(1, 0, 2) # NLD -> LND
activations, affinities = [], []
for i, res_block in enumerate(self.model.transformer.resblocks):
if mask is not None:
mask_layer, mask_type, mask_tensor = mask
if mask_layer == i or mask_layer == "all":
# import ipdb; ipdb.set_trace()
size = int(math.sqrt(x.shape[0] - 1))
attn_mask = (
mask_type,
nnf.interpolate(
mask_tensor.unsqueeze(1).float(), (size, size)
).view(mask_tensor.shape[0], size * size),
)
else:
attn_mask = None
else:
attn_mask = None
x, aff_per_head = forward_multihead_attention(
x, res_block, with_aff=True, attn_mask=attn_mask
)
if i in extract_layers:
affinities += [aff_per_head]
# if self.n_tokens is not None:
# activations += [nnf.interpolate(x, inp_size, mode='bilinear', align_corners=True)]
# else:
activations += [x]
if len(extract_layers) > 0 and i == max(extract_layers) and skip:
print("early skip")
break
x = x.permute(1, 0, 2) # LND -> NLD
x = self.model.ln_post(x[:, 0, :])
if self.model.proj is not None:
x = x @ self.model.proj
return x, activations, affinities
def sample_prompts(self, words, prompt_list=None):
prompt_list = prompt_list if prompt_list is not None else self.prompt_list
prompt_indices = torch.multinomial(
torch.ones(len(prompt_list)), len(words), replacement=True
)
prompts = [prompt_list[i] for i in prompt_indices]
return [promt.format(w) for promt, w in zip(prompts, words)]
def get_cond_vec(self, conditional, batch_size):
# compute conditional from a single string
if conditional is not None and type(conditional) == str:
cond = self.compute_conditional(conditional)
cond = cond.repeat(batch_size, 1)
# compute conditional from string list/tuple
elif (
conditional is not None
and type(conditional) in {list, tuple}
and type(conditional[0]) == str
):
assert len(conditional) == batch_size
cond = self.compute_conditional(conditional)
# use conditional directly
elif (
conditional is not None
and type(conditional) == torch.Tensor
and conditional.ndim == 2
):
cond = conditional
# compute conditional from image
elif conditional is not None and type(conditional) == torch.Tensor:
with torch.no_grad():
cond, _, _ = self.visual_forward(conditional)
else:
raise ValueError("invalid conditional")
return cond
def compute_conditional(self, conditional):
from imaginairy.vendored import clip
dev = next(self.parameters()).device
if type(conditional) in {list, tuple}:
text_tokens = clip.tokenize(conditional).to(dev)
cond = self.clip_model.encode_text(text_tokens)
else:
if conditional in self.precomputed_prompts:
cond = self.precomputed_prompts[conditional].float().to(dev)
else:
text_tokens = clip.tokenize([conditional]).to(dev)
cond = self.clip_model.encode_text(text_tokens)[0]
if self.shift_vector is not None:
return cond + self.shift_vector
else:
return cond
def clip_load_untrained(version):
assert version == "ViT-B/16"
from clip.clip import _MODELS, _download
from clip.model import CLIP
model = torch.jit.load(_download(_MODELS["ViT-B/16"])).eval()
state_dict = model.state_dict()
vision_width = state_dict["visual.conv1.weight"].shape[0]
vision_layers = len(
[
k
for k in state_dict.keys()
if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")
]
)
vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
image_resolution = vision_patch_size * grid_size
embed_dim = state_dict["text_projection"].shape[1]
context_length = state_dict["positional_embedding"].shape[0]
vocab_size = state_dict["token_embedding.weight"].shape[0]
transformer_width = state_dict["ln_final.weight"].shape[0]
transformer_heads = transformer_width // 64
transformer_layers = len(
set(
k.split(".")[2]
for k in state_dict
if k.startswith(f"transformer.resblocks")
)
)
return CLIP(
embed_dim,
image_resolution,
vision_layers,
vision_width,
vision_patch_size,
context_length,
vocab_size,
transformer_width,
transformer_heads,
transformer_layers,
)
class CLIPDensePredT(CLIPDenseBase):
def __init__(
self,
version="ViT-B/32",
extract_layers=(3, 6, 9),
cond_layer=0,
reduce_dim=128,
n_heads=4,
prompt="fixed",
extra_blocks=0,
reduce_cond=None,
fix_shift=False,
learn_trans_conv_only=False,
limit_to_clip_only=False,
upsample=False,
add_calibration=False,
rev_activations=False,
trans_conv=None,
n_tokens=None,
):
super().__init__(version, reduce_cond, reduce_dim, prompt, n_tokens)
# device = 'cpu'
self.extract_layers = extract_layers
self.cond_layer = cond_layer
self.limit_to_clip_only = limit_to_clip_only
self.process_cond = None
self.rev_activations = rev_activations
depth = len(extract_layers)
if add_calibration:
self.calibration_conds = 1
self.upsample_proj = (
nn.Conv2d(reduce_dim, 1, kernel_size=1) if upsample else None
)
self.add_activation1 = True
self.version = version
self.token_shape = {"ViT-B/32": (7, 7), "ViT-B/16": (14, 14)}[version]
if fix_shift:
# self.shift_vector = nn.Parameter(torch.load(join(dirname(basename(__file__)), 'clip_text_shift_vector.pth')), requires_grad=False)
self.shift_vector = nn.Parameter(
torch.load(join(dirname(basename(__file__)), "shift_text_to_vis.pth")),
requires_grad=False,
)
# self.shift_vector = nn.Parameter(-1*torch.load(join(dirname(basename(__file__)), 'shift2.pth')), requires_grad=False)
else:
self.shift_vector = None
if trans_conv is None:
trans_conv_ks = {"ViT-B/32": (32, 32), "ViT-B/16": (16, 16)}[version]
else:
# explicitly define transposed conv kernel size
trans_conv_ks = (trans_conv, trans_conv)
self.trans_conv = nn.ConvTranspose2d(
reduce_dim, 1, trans_conv_ks, stride=trans_conv_ks
)
assert len(self.extract_layers) == depth
self.reduces = nn.ModuleList([nn.Linear(768, reduce_dim) for _ in range(depth)])
self.blocks = nn.ModuleList(
[
nn.TransformerEncoderLayer(d_model=reduce_dim, nhead=n_heads)
for _ in range(len(self.extract_layers))
]
)
self.extra_blocks = nn.ModuleList(
[
nn.TransformerEncoderLayer(d_model=reduce_dim, nhead=n_heads)
for _ in range(extra_blocks)
]
)
# refinement and trans conv
if learn_trans_conv_only:
for p in self.parameters():
p.requires_grad_(False)
for p in self.trans_conv.parameters():
p.requires_grad_(True)
self.prompt_list = get_prompt_list(prompt)
def forward(self, inp_image, conditional=None, return_features=False, mask=None):
assert type(return_features) == bool
inp_image = inp_image.to(self.model.positional_embedding.device)
if mask is not None:
raise ValueError("mask not supported")
# x_inp = normalize(inp_image)
x_inp = inp_image
bs, dev = inp_image.shape[0], x_inp.device
cond = self.get_cond_vec(conditional, bs)
visual_q, activations, _ = self.visual_forward(
x_inp, extract_layers=[0] + list(self.extract_layers)
)
activation1 = activations[0]
activations = activations[1:]
_activations = activations[::-1] if not self.rev_activations else activations
a = None
for i, (activation, block, reduce) in enumerate(
zip(_activations, self.blocks, self.reduces)
):
if a is not None:
a = reduce(activation) + a
else:
a = reduce(activation)
if i == self.cond_layer:
if self.reduce_cond is not None:
cond = self.reduce_cond(cond)
a = self.film_mul(cond) * a + self.film_add(cond)
a = block(a)
for block in self.extra_blocks:
a = a + block(a)
a = a[1:].permute(1, 2, 0) # rm cls token and -> BS, Feats, Tokens
size = int(math.sqrt(a.shape[2]))
a = a.view(bs, a.shape[1], size, size)
a = self.trans_conv(a)
if self.n_tokens is not None:
a = nnf.interpolate(a, x_inp.shape[2:], mode="bilinear", align_corners=True)
if self.upsample_proj is not None:
a = self.upsample_proj(a)
a = nnf.interpolate(a, x_inp.shape[2:], mode="bilinear")
if return_features:
return a, visual_q, cond, [activation1] + activations
else:
return (a,)
class CLIPDensePredTMasked(CLIPDensePredT):
def __init__(
self,
version="ViT-B/32",
extract_layers=(3, 6, 9),
cond_layer=0,
reduce_dim=128,
n_heads=4,
prompt="fixed",
extra_blocks=0,
reduce_cond=None,
fix_shift=False,
learn_trans_conv_only=False,
refine=None,
limit_to_clip_only=False,
upsample=False,
add_calibration=False,
n_tokens=None,
):
super().__init__(
version=version,
extract_layers=extract_layers,
cond_layer=cond_layer,
reduce_dim=reduce_dim,
n_heads=n_heads,
prompt=prompt,
extra_blocks=extra_blocks,
reduce_cond=reduce_cond,
fix_shift=fix_shift,
learn_trans_conv_only=learn_trans_conv_only,
limit_to_clip_only=limit_to_clip_only,
upsample=upsample,
add_calibration=add_calibration,
n_tokens=n_tokens,
)
def visual_forward_masked(self, img_s, seg_s):
return super().visual_forward(img_s, mask=("all", "cls_token", seg_s))
def forward(self, img_q, cond_or_img_s, seg_s=None, return_features=False):
if seg_s is None:
cond = cond_or_img_s
else:
img_s = cond_or_img_s
with torch.no_grad():
cond, _, _ = self.visual_forward_masked(img_s, seg_s)
return super().forward(img_q, cond, return_features=return_features)
class CLIPDenseBaseline(CLIPDenseBase):
def __init__(
self,
version="ViT-B/32",
cond_layer=0,
extract_layer=9,
reduce_dim=128,
reduce2_dim=None,
prompt="fixed",
reduce_cond=None,
limit_to_clip_only=False,
n_tokens=None,
):
super().__init__(version, reduce_cond, reduce_dim, prompt, n_tokens)
device = "cpu"
# self.cond_layer = cond_layer
self.extract_layer = extract_layer
self.limit_to_clip_only = limit_to_clip_only
self.shift_vector = None
self.token_shape = {"ViT-B/32": (7, 7), "ViT-B/16": (14, 14)}[version]
assert reduce2_dim is not None
self.reduce2 = nn.Sequential(
nn.Linear(reduce_dim, reduce2_dim),
nn.ReLU(),
nn.Linear(reduce2_dim, reduce_dim),
)
trans_conv_ks = {"ViT-B/32": (32, 32), "ViT-B/16": (16, 16)}[version]
self.trans_conv = nn.ConvTranspose2d(
reduce_dim, 1, trans_conv_ks, stride=trans_conv_ks
)
def forward(self, inp_image, conditional=None, return_features=False):
inp_image = inp_image.to(self.model.positional_embedding.device)
# x_inp = normalize(inp_image)
x_inp = inp_image
bs, dev = inp_image.shape[0], x_inp.device
cond = self.get_cond_vec(conditional, bs)
visual_q, activations, affinities = self.visual_forward(
x_inp, extract_layers=[self.extract_layer]
)
a = activations[0]
a = self.reduce(a)
a = self.film_mul(cond) * a + self.film_add(cond)
if self.reduce2 is not None:
a = self.reduce2(a)
# the original model would execute a transformer block here
a = a[1:].permute(1, 2, 0) # rm cls token and -> BS, Feats, Tokens
size = int(math.sqrt(a.shape[2]))
a = a.view(bs, a.shape[1], size, size)
a = self.trans_conv(a)
if return_features:
return a, visual_q, cond, activations
else:
return (a,)
class CLIPSegMultiLabel(nn.Module):
def __init__(self, model) -> None:
super().__init__()
from third_party.JoEm.data_loader import VOC, get_seen_idx, get_unseen_idx
self.pascal_classes = VOC
from general_utils import load_model
from models.clipseg import CLIPDensePredT
# self.clipseg = load_model('rd64-vit16-neg0.2-phrasecut', strict=False)
self.clipseg = load_model(model, strict=False)
self.clipseg.eval()
def forward(self, x):
bs = x.shape[0]
out = torch.ones(21, bs, 352, 352).to(x.device) * -10
for class_id, class_name in enumerate(self.pascal_classes):
fac = 3 if class_name == "background" else 1
with torch.no_grad():
pred = torch.sigmoid(self.clipseg(x, class_name)[0][:, 0]) * fac
out[class_id] += pred
out = out.permute(1, 0, 2, 3)
return out
# construct output tensor

@ -19,7 +19,13 @@ setup(
entry_points={
"console_scripts": ["imagine=imaginairy.cmds:imagine_cmd"],
},
package_data={"imaginairy": ["configs/*.yaml", "vendored/clip/*.txt.gz"]},
package_data={
"imaginairy": [
"configs/*.yaml",
"vendored/clip/*.txt.gz",
"vendored/clipseg/*.pth",
]
},
install_requires=[
"click",
"protobuf != 3.20.2, != 3.19.5",

Binary files not shown: 2 new images added (553 KiB and 12 KiB).

@ -3,6 +3,7 @@ import hashlib
from PIL import Image
from pytorch_lightning import seed_everything
from imaginairy.enhancers.clip_masking import get_img_mask
from imaginairy.enhancers.face_restoration_codeformer import enhance_faces
from imaginairy.utils import get_device
from tests import TESTS_FOLDER
@ -21,3 +22,14 @@ def test_fix_faces():
def img_hash(img):
return hashlib.md5(img.tobytes()).hexdigest()
def test_clip_masking():
img = Image.open(f"{TESTS_FOLDER}/data/girl_with_a_pearl_earring.jpg")
pred = get_img_mask(img, "head")
pred.save(f"{TESTS_FOLDER}/test_output/earring_mask.png")
def test_clip_inpainting():
img = Image.open(f"{TESTS_FOLDER}/data/girl_with_a_pearl_earring.jpg")
pred = get_img_mask(img, "background")

@ -89,3 +89,39 @@ def test_img_to_file():
)
out_folder = f"{TESTS_FOLDER}/test_output"
imagine_image_files(prompt, outdir=out_folder)
def test_inpainting():
prompt = ImaginePrompt(
"a basketball on a bench",
init_image=f"{TESTS_FOLDER}/data/bench2.png",
init_image_strength=0.4,
mask_image=LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/bench2_mask.png"),
width=512,
height=512,
steps=5,
seed=1,
sampler_type="DDIM",
)
out_folder = f"{TESTS_FOLDER}/test_output"
imagine_image_files(prompt, outdir=out_folder)
def test_cliptext_inpainting():
prompts = [
ImaginePrompt(
"elegant woman. oil painting",
prompt_strength=12,
init_image=f"{TESTS_FOLDER}/data/girl_with_a_pearl_earring.jpg",
init_image_strength=0.3,
mask_prompt="face",
mask_mode=ImaginePrompt.MaskMode.KEEP,
mask_expansion=-3,
width=512,
height=512,
steps=5,
sampler_type="DDIM",
),
]
out_folder = f"{TESTS_FOLDER}/test_output"
imagine_image_files(prompts, outdir=out_folder)

@ -18,7 +18,7 @@ def test_is_nsfw():
def _pil_to_latent(img):
model = load_model()
img, w, h = pillow_img_to_torch_image(img)
img = pillow_img_to_torch_image(img)
img = img.to(get_device())
latent = model.get_first_stage_encoding(model.encode_first_stage(img))
return latent
