feature: boolean logic masks

Specify advanced text based masks using boolean logic and strength modifiers. Mask descriptions must be lowercase. Keywords uppercase.
Valid symbols: `AND`, `OR`, `NOT`, `()`, and mask strength modifier `{*1.5}` where `+` can be any of `+ - * /`. Single-character boolean
operators also work.  When writing strength modifies know that pixel values are between 0 and 1.

 - feature: apply mask edits to original files
 - feature: auto-rotate images if exif data specifies to do so
 - fix: accept mask images in command line
This commit is contained in:
Bryce 2022-09-23 22:58:48 -07:00 committed by Bryce Drennan
parent d090f9d072
commit 38c7f88950
21 changed files with 463 additions and 162 deletions

View File

@ -35,9 +35,18 @@ Generating 🖼 : "portrait photo of a freckled woman" 512x512px seed:500686645
<img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000019_786355545_PLMS50_PS7.5_a_scenic_landscape.jpg" height="256"><img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000032_337692011_PLMS40_PS7.5_a_photo_of_a_dog.jpg" height="256"><br>
<img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000056_293284644_PLMS40_PS7.5_photo_of_a_bowl_of_fruit.jpg" height="256"><img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000078_260972468_PLMS40_PS7.5_portrait_photo_of_a_freckled_woman.jpg" height="256">
### Automated Replacement (txt2mask) [by clipseg](https://github.com/timojl/clipseg)
### Prompt Based Editing [by clipseg](https://github.com/timojl/clipseg)
Specify advanced text based masks using boolean logic and strength modifiers. Mask descriptions must be lowercase. Keywords uppercase.
Valid symbols: `AND`, `OR`, `NOT`, `()`, and mask strength modifier `{*1.5}` where `+` can be any of `+ - * /`. Single-character boolean
operators also work. When writing strength modifies know that pixel values are between 0 and 1.
```bash
>> imagine --init-image pearl_earring.jpg --mask-prompt face --mask-mode keep --init-image-strength .4 "a female doctor" "an elegant woman"
>> imagine \
--init-image pearl_earring.jpg \
--mask-prompt "face{*1.9}" \
--mask-mode keep \
--init-image-strength .4 \
"a female doctor" "an elegant woman"
```
<img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/mask_examples/pearl000.jpg" height="200">➡️
<img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/mask_examples/pearl002.jpg" height="200">
@ -45,7 +54,12 @@ Generating 🖼 : "portrait photo of a freckled woman" 512x512px seed:500686645
<img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/mask_examples/pearl001.jpg" height="200">
<img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/mask_examples/pearl003.jpg" height="200">
```bash
>> imagine --init-image fruit-bowl.jpg --mask-prompt fruit --mask-mode replace --init-image-strength .1 "a bowl of pears" "a bowl of gold" "a bowl of popcorn" "a bowl of spaghetti"
>> imagine \
--init-image fruit-bowl.jpg \
--mask-prompt "fruit OR fruit stem{*1.5}" \
--mask-mode replace \
--init-image-strength .1 \
"a bowl of kittens" "a bowl of gold coins" "a bowl of popcorn" "a bowl of spaghetti"
```
<img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000056_293284644_PLMS40_PS7.5_photo_of_a_bowl_of_fruit.jpg" height="200">➡️
<img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/mask_examples/bowl004.jpg" height="200">
@ -130,9 +144,8 @@ prompts = [
ImaginePrompt(
"a bowl of strawberries",
init_image=LazyLoadingImage(filepath="mypath/to/bowl_of_fruit.jpg"),
mask_prompt="fruit|stems",
mask_prompt="fruit OR stem{*2}", # amplify the stem mask x2
mask_mode="replace",
mask_expansion=3
),
ImaginePrompt("strawberries", tile_mode=True),
]
@ -167,6 +180,13 @@ docker run -it --gpus all -v $HOME/.cache/huggingface:/root/.cache/huggingface -
[Example Colab](https://colab.research.google.com/drive/1rOvQNs0Cmn_yU1bKWjCOHzGVDgZkaTtO?usp=sharing)
## ChangeLog
- feature: Specify advanced text based masks using boolean logic and strength modifiers. Mask descriptions must be lowercase. Keywords uppercase.
Valid symbols: `AND`, `OR`, `NOT`, `()`, and mask strength modifier `{+0.1}` where `+` can be any of `+ - * /`
- feature: apply mask edits to original files
- feature: auto-rotate images if exif data specifies to do so
- fix: accept mask images in command line
**1.6.2**
- fix: another bfloat16 fix
@ -214,6 +234,8 @@ docker run -it --gpus all -v $HOME/.cache/huggingface:/root/.cache/huggingface -
- training
## Todo
- refactor how output versions are selected (upscaled, modified original, etc)
- performance optimizations
- ✅ https://github.com/huggingface/diffusers/blob/main/docs/source/optimization/fp16.mdx
- ✅ https://github.com/CompVis/stable-diffusion/compare/main...Doggettx:stable-diffusion:autocast-improvements#

Binary file not shown.

Before

Width:  |  Height:  |  Size: 33 KiB

After

Width:  |  Height:  |  Size: 29 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 34 KiB

After

Width:  |  Height:  |  Size: 33 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 30 KiB

After

Width:  |  Height:  |  Size: 30 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 25 KiB

After

Width:  |  Height:  |  Size: 25 KiB

View File

@ -4,6 +4,7 @@ import re
from functools import lru_cache
import numpy as np
import PIL
import torch
import torch.nn
from einops import rearrange
@ -22,17 +23,15 @@ from imaginairy.img_log import (
log_img,
log_latent,
)
from imaginairy.img_utils import pillow_fit_image_within, pillow_img_to_torch_image
from imaginairy.safety import is_nsfw
from imaginairy.samplers.base import get_sampler
from imaginairy.schema import ImaginePrompt, ImagineResult
from imaginairy.utils import (
expand_mask,
fix_torch_group_norm,
fix_torch_nn_layer_norm,
get_device,
instantiate_from_config,
pillow_fit_image_within,
pillow_img_to_torch_image,
platform_appropriate_autocast,
)
@ -92,8 +91,10 @@ def imagine_image_files(
record_step_images=False,
output_file_extension="jpg",
print_caption=False,
create_modified_originals_for_masks=True,
):
big_path = os.path.join(outdir, "upscaled")
masked_orig_path = os.path.join(outdir, "modified_originals")
os.makedirs(outdir, exist_ok=True)
base_count = len(os.listdir(outdir))
@ -119,6 +120,7 @@ def imagine_image_files(
ddim_eta=ddim_eta,
img_callback=_record_step if record_step_images else None,
add_caption=print_caption,
create_modified_originals_for_masks=create_modified_originals_for_masks,
):
prompt = result.prompt
basefilename = f"{base_count:06}_{prompt.seed}_{prompt.sampler_type}{prompt.steps}_PS{prompt.prompt_strength}_{prompt_normalized(prompt.prompt_text)}"
@ -131,6 +133,11 @@ def imagine_image_files(
bigfilepath = os.path.join(big_path, basefilename) + "_upscaled.jpg"
result.save_upscaled(bigfilepath)
logger.info(f" Upscaled 🖼 saved to: {bigfilepath}")
if result.modified_original_img:
os.makedirs(masked_orig_path, exist_ok=True)
bigfilepath = os.path.join(masked_orig_path, basefilename) + "_modified.jpg"
result.save_modified_orig(bigfilepath)
logger.info(f" Modified original 🖼 saved to: {bigfilepath}")
base_count += 1
del result
@ -144,6 +151,7 @@ def imagine(
img_callback=None,
half_mode=None,
add_caption=False,
create_modified_originals_for_masks=True,
):
model = load_model()
@ -197,30 +205,34 @@ def imagine(
logger.info(" Sampler type switched to plms for img2img")
else:
sampler_type = prompt.sampler_type
start_code = None
sampler = get_sampler(sampler_type, model)
mask, mask_image, mask_image_orig = None, None, None
if prompt.init_image:
generation_strength = 1 - prompt.init_image_strength
ddim_steps = int(prompt.steps / generation_strength)
sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=ddim_eta)
init_image, _, h = pillow_fit_image_within(
prompt.init_image,
max_height=prompt.height,
max_width=prompt.width,
)
try:
init_image, _, h = pillow_fit_image_within(
prompt.init_image,
max_height=prompt.height,
max_width=prompt.width,
)
except PIL.UnidentifiedImageError:
logger.warning(f" Could not load image: {prompt.init_image}")
continue
init_image_t = pillow_img_to_torch_image(init_image)
if prompt.mask_prompt:
mask_image = get_img_mask(init_image, prompt.mask_prompt)
mask_image = get_img_mask(
init_image, prompt.mask_prompt, threshold=0.1
)
elif prompt.mask_image:
mask_image = prompt.mask_image
mask_image = prompt.mask_image.convert("L")
if mask_image is not None:
log_img(mask_image, "init mask")
mask_image = expand_mask(mask_image, prompt.mask_expansion)
log_img(mask_image, "init mask expanded")
if prompt.mask_mode == ImaginePrompt.MaskMode.REPLACE:
mask_image = ImageOps.invert(mask_image)
@ -236,13 +248,11 @@ def imagine(
),
resample=Image.Resampling.NEAREST,
)
log_img(mask_image, "init mask 2")
log_img(mask_image, "latent_mask")
mask = np.array(mask_image)
mask = mask.astype(np.float32) / 255.0
mask = mask[None, None]
mask[mask < 0.9] = 0
mask[mask >= 0.9] = 1
mask = torch.from_numpy(mask)
mask = mask.to(get_device())
@ -294,8 +304,6 @@ def imagine(
x_sample_8_orig = x_sample.astype(np.uint8)
img = Image.fromarray(x_sample_8_orig)
if mask_image_orig and init_image:
mask_image_orig = expand_mask(mask_image_orig, -3)
mask_image_orig = mask_image_orig.filter(
ImageFilter.GaussianBlur(radius=3)
)
@ -305,6 +313,7 @@ def imagine(
log_img(img, "reconstituted image")
upscaled_img = None
rebuilt_orig_img = None
is_nsfw_img = None
if add_caption:
caption = generate_caption(img)
@ -325,11 +334,43 @@ def imagine(
logger.info(" Fixing 😊 's in big 🖼 using CodeFormer...")
upscaled_img = enhance_faces(upscaled_img, fidelity=0.8)
# put the newly generated patch back into the original, full size image
if (
create_modified_originals_for_masks
and mask_image_orig
and prompt.init_image
):
img_to_add_back_to_original = (
upscaled_img if upscaled_img else img
)
img_to_add_back_to_original = (
img_to_add_back_to_original.resize(
prompt.init_image.size,
resample=Image.Resampling.LANCZOS,
)
)
mask_for_orig_size = mask_image_orig.resize(
prompt.init_image.size, resample=Image.Resampling.LANCZOS
)
mask_for_orig_size = mask_for_orig_size.filter(
ImageFilter.GaussianBlur(radius=5)
)
log_img(mask_for_orig_size, "mask for original image size")
rebuilt_orig_img = Image.composite(
img_to_add_back_to_original,
prompt.init_image,
mask_for_orig_size,
)
log_img(rebuilt_orig_img, "reconstituted original")
yield ImagineResult(
img=img,
prompt=prompt,
upscaled_img=upscaled_img,
is_nsfw=is_nsfw_img,
modified_original_img=rebuilt_orig_img,
)

View File

@ -129,7 +129,14 @@ def configure_logging(level="INFO"):
)
@click.option(
"--mask-prompt",
help="Describe what you want masked and the AI will mask it for you",
help=(
"Describe what you want masked and the AI will mask it for you. "
"You can describe complex masks with AND, OR, NOT keywords and parentheses. "
"The strength of each mask can be modified with {*1.5} notation. \n\n"
"Examples: \n"
"car AND (wheels{*1.1} OR trunk OR engine OR windows OR headlights) AND NOT (truck OR headlights){*10}\n"
"fruit|fruit stem"
),
)
@click.option(
"--mask-mode",
@ -137,12 +144,6 @@ def configure_logging(level="INFO"):
type=click.Choice(["keep", "replace"]),
help="Should we replace the masked area or keep it?",
)
@click.option(
"--mask-expansion",
default="2",
type=int,
help="How much to grow (or shrink) the mask area",
)
@click.option(
"--caption",
default=False,
@ -178,7 +179,6 @@ def imagine_cmd(
mask_image,
mask_prompt,
mask_mode,
mask_expansion,
caption,
precision,
):
@ -196,6 +196,9 @@ def imagine_cmd(
if init_image and init_image.startswith("http"):
init_image = LazyLoadingImage(url=init_image)
if mask_image and mask_image.startswith("http"):
mask_image = LazyLoadingImage(url=mask_image)
prompts = []
load_model()
for _ in range(repeats):
@ -212,7 +215,6 @@ def imagine_cmd(
width=width,
mask_image=mask_image,
mask_prompt=mask_prompt,
mask_expansion=mask_expansion,
mask_mode=mask_mode,
upscale=upscale,
fix_faces=fix_faces,

View File

@ -0,0 +1,163 @@
# pylama:ignore=W0613
"""
Logic for parsing mask prompts.
Supports
lower case text descriptions
Combinations: AND OR NOT ()
Strength Modifiers: {<operator><number>}
Examples:
fruit
fruit bowl
fruit AND NOT pears
fruit OR bowl
(pears OR oranges OR peaches){*1.5}
fruit{-0.1} OR bowl
"""
import operator
from abc import ABC
import pyparsing as pp
import torch
from pyparsing import ParserElement
ParserElement.enablePackrat()
class Mask(ABC):
def get_mask_for_image(self, img):
pass
def gather_text_descriptions(self):
return set()
def apply_masks(self, mask_cache):
pass
class SimpleMask(Mask):
def __init__(self, text):
self.text = text
@classmethod
def from_simple_prompt(cls, instring, tokens_start, ret_tokens):
return cls(text=ret_tokens[0])
def __repr__(self):
return f"'{self.text}'"
def gather_text_descriptions(self):
return {self.text}
def apply_masks(self, mask_cache):
return mask_cache[self.text]
class ModifiedMask(Mask):
ops = {
"+": operator.add,
"-": operator.sub,
"*": operator.mul,
"/": operator.truediv,
# '%': operator.mod,
# '^': operator.xor,
}
def __init__(self, mask, modifier):
if modifier:
modifier = modifier.strip("{}")
self.mask = mask
self.modifier = modifier
self.operand = self.ops[modifier[0]]
self.value = float(modifier[1:])
@classmethod
def from_modifier_parse(cls, instring, tokens_start, ret_tokens):
return cls(mask=ret_tokens[0][0], modifier=ret_tokens[0][1])
def __repr__(self):
return f"{repr(self.mask)}{self.modifier}"
def gather_text_descriptions(self):
return self.mask.gather_text_descriptions()
def apply_masks(self, mask_cache):
mask = self.mask.apply_masks(mask_cache)
return torch.clamp(self.operand(mask, self.value), 0, 1)
class NestedMask(Mask):
def __init__(self, masks, op):
self.masks = masks
self.op = op
@classmethod
def from_or(cls, instring, tokens_start, ret_tokens):
sub_masks = [t for t in ret_tokens[0] if isinstance(t, Mask)]
return cls(masks=sub_masks, op="OR")
@classmethod
def from_and(cls, instring, tokens_start, ret_tokens):
sub_masks = [t for t in ret_tokens[0] if isinstance(t, Mask)]
return cls(masks=sub_masks, op="AND")
@classmethod
def from_not(cls, instring, tokens_start, ret_tokens):
sub_masks = [t for t in ret_tokens[0] if isinstance(t, Mask)]
assert len(sub_masks) == 1
return cls(masks=sub_masks, op="NOT")
def __repr__(self):
if self.op == "NOT":
return f"NOT {self.masks[0]}"
sub = f" {self.op} ".join(repr(m) for m in self.masks)
return f"({sub})"
def gather_text_descriptions(self):
return set().union(*[m.gather_text_descriptions() for m in self.masks])
def apply_masks(self, mask_cache):
submasks = [m.apply_masks(mask_cache) for m in self.masks]
mask = submasks[0]
if self.op == "OR":
for submask in submasks:
mask = torch.maximum(mask, submask)
elif self.op == "AND":
for submask in submasks:
mask = torch.minimum(mask, submask)
elif self.op == "NOT":
mask = 1 - mask
else:
raise ValueError(f"Invalid operand {self.op}")
return torch.clamp(mask, 0, 1)
AND = (pp.Literal("AND") | pp.Literal("&")).setName("AND").setResultsName("op")
OR = (pp.Literal("OR") | pp.Literal("|")).setName("OR").setResultsName("op")
NOT = (pp.Literal("NOT") | pp.Literal("!")).setName("NOT").setResultsName("op")
PROMPT_MODIFIER = (
pp.Regex(r"{[*/+-]\d+\.?\d*}")
.setName("prompt_modifier")
.setResultsName("prompt_modifier")
)
PROMPT_TEXT = (
pp.Regex(r"[a-z0-9]?[a-z0-9 -]*[a-z0-9]")
.setName("prompt_text")
.setResultsName("prompt_text")
)
SIMPLE_PROMPT = PROMPT_TEXT.setResultsName("simplePrompt")
SIMPLE_PROMPT.setParseAction(SimpleMask.from_simple_prompt)
COMPLEX_PROMPT = pp.infixNotation(
SIMPLE_PROMPT,
[
(PROMPT_MODIFIER, 1, pp.opAssoc.LEFT, ModifiedMask.from_modifier_parse),
(NOT, 1, pp.opAssoc.RIGHT, NestedMask.from_not),
(AND, 2, pp.opAssoc.LEFT, NestedMask.from_and),
(OR, 2, pp.opAssoc.LEFT, NestedMask.from_or),
],
)
MASK_PROMPT = pp.Group(COMPLEX_PROMPT).setResultsName("complexPrompt")

View File

@ -1,5 +1,7 @@
from functools import lru_cache
from typing import Optional, Sequence
import PIL.Image
import torch
from torchvision import transforms
@ -26,33 +28,26 @@ def clip_mask_model():
return model
def get_img_mask(img, mask_description, negative_description=""):
pos_descriptions = mask_description.split("|")
pos_masks = get_img_masks(img, pos_descriptions)
pos_mask = pos_masks[0]
for pred in pos_masks:
pos_mask = torch.maximum(pos_mask, pred)
def get_img_mask(
img: PIL.Image.Image,
mask_description_statement: str,
threshold: Optional[float] = None,
):
from imaginairy.enhancers.bool_masker import MASK_PROMPT # noqa
log_img(pos_mask, "pos mask")
if negative_description:
neg_descriptions = negative_description.split("|")
neg_masks = get_img_masks(img, neg_descriptions)
neg_mask = neg_masks[0]
for pred in neg_masks:
neg_mask = torch.maximum(neg_mask, pred)
neg_mask = (neg_mask * 3).clip(0, 1)
log_img(neg_mask, "neg_mask")
pos_mask = torch.minimum(pos_mask, 1 - neg_mask)
_min = pos_mask.min()
_max = pos_mask.max()
_range = _max - _min
pos_mask = (pos_mask > (_min + (_range * 0.35))).float()
return transforms.ToPILImage()(pos_mask)
parsed = MASK_PROMPT.parseString(mask_description_statement)
parsed_mask = parsed[0][0]
descriptions = list(parsed_mask.gather_text_descriptions())
mask_cache = get_img_masks(img, descriptions)
mask = parsed_mask.apply_masks(mask_cache)
log_img(mask, "combined mask")
if threshold is not None:
mask[mask < threshold] = 0
mask[mask >= threshold] = 1
return transforms.ToPILImage()(mask)
def get_img_masks(img, mask_descriptions):
def get_img_masks(img, mask_descriptions: Sequence[str]):
a, b = img.size
orig_size = b, a
log_img(img, "image for masking")
@ -75,11 +70,9 @@ def get_img_masks(img, mask_descriptions):
preds = [torch.sigmoid(p[0]) for p in preds]
bw_preds = []
preds_dict = {}
for p, desc in zip(preds, mask_descriptions):
log_img(p, f"clip mask: {desc}")
# bw_preds.append(pred_transform(p))
preds_dict[desc] = p
bw_preds.append(p)
return bw_preds
return preds_dict

View File

@ -1,12 +1,11 @@
import logging
import re
import numpy as np
import torch
from einops import rearrange
from PIL import Image
from torchvision.transforms import ToPILImage
from imaginairy.img_utils import model_latents_to_pillow_imgs
_CURRENT_LOGGING_CONTEXT = None
logger = logging.getLogger(__name__)
@ -64,11 +63,7 @@ class ImageLoggingContext:
return
self.step_count += 1
description = f"{description} - {latents.shape}"
latents = self.model.decode_first_stage(latents)
latents = torch.clamp((latents + 1.0) / 2.0, min=0.0, max=1.0)
for latent in latents:
latent = 255.0 * rearrange(latent.cpu().numpy(), "c h w -> h w c")
img = Image.fromarray(latent.astype(np.uint8))
for img in model_latents_to_pillow_imgs(latents):
self.img_callback(img, description, self.step_count, self.prompt)
def log_img(self, img, description):

59
imaginairy/img_utils.py Normal file
View File

@ -0,0 +1,59 @@
from typing import Sequence
import numpy as np
import PIL
import torch
from einops import rearrange, repeat
from PIL import Image
from imaginairy.utils import get_device
def pillow_fit_image_within(image: PIL.Image.Image, max_height=512, max_width=512):
image = image.convert("RGB")
w, h = image.size
resize_ratio = min(max_width / w, max_height / h)
w, h = int(w * resize_ratio), int(h * resize_ratio)
w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64
image = image.resize((w, h), resample=Image.Resampling.NEAREST)
return image, w, h
def pillow_img_to_torch_image(img: PIL.Image.Image):
img = img.convert("RGB")
img = np.array(img).astype(np.float32) / 255.0
img = img[None].transpose(0, 3, 1, 2)
img = torch.from_numpy(img)
return 2.0 * img - 1.0
def pillow_img_to_opencv_img(img: PIL.Image.Image):
open_cv_image = np.array(img)
# Convert RGB to BGR
open_cv_image = open_cv_image[:, :, ::-1].copy()
return open_cv_image
def model_latents_to_pillow_imgs(latents: torch.Tensor) -> Sequence[PIL.Image.Image]:
from imaginairy.api import load_model # noqa
model = load_model()
latents = model.decode_first_stage(latents)
latents = torch.clamp((latents + 1.0) / 2.0, min=0.0, max=1.0)
imgs = []
for latent in latents:
latent = 255.0 * rearrange(latent.cpu().numpy(), "c h w -> h w c")
img = Image.fromarray(latent.astype(np.uint8))
imgs.append(img)
return imgs
def pillow_img_to_model_latent(model, img, batch_size=1, half=True):
# init_image = pil_img_to_torch(img, half=half).to(device)
init_image = pillow_img_to_torch_image(img).to(get_device())
init_image = repeat(init_image, "1 ... -> b ...", b=batch_size)
if half:
return model.get_first_stage_encoding(
model.encode_first_stage(init_image.half())
)
return model.get_first_stage_encoding(model.encode_first_stage(init_image))

View File

@ -9,26 +9,15 @@ needs https://github.com/crowsonkb/k-diffusion
from contextlib import nullcontext
import torch
from einops import repeat
from torch import autocast
from imaginairy.utils import get_device, pillow_img_to_torch_image
from imaginairy.img_utils import pillow_img_to_model_latent
from imaginairy.utils import get_device
from imaginairy.vendored import k_diffusion as K
def pil_img_to_latent(model, img, batch_size=1, half=True):
# init_image = pil_img_to_torch(img, half=half).to(device)
init_image = pillow_img_to_torch_image(img).to(get_device())
init_image = repeat(init_image, "1 ... -> b ...", b=batch_size)
if half:
return model.get_first_stage_encoding(
model.encode_first_stage(init_image.half())
)
return model.get_first_stage_encoding(model.encode_first_stage(init_image))
def find_noise_for_image(model, pil_img, prompt, steps=50, cond_scale=1.0, half=True):
img_latent = pil_img_to_latent(
img_latent = pillow_img_to_model_latent(
model, pil_img, batch_size=1, device="cuda", half=half
)
return find_noise_for_latent(

View File

@ -1,3 +1,4 @@
# pylama:ignore=W0613
"""SAMPLING ONLY."""
import logging

View File

@ -6,9 +6,8 @@ import random
from datetime import datetime, timezone
from functools import lru_cache
import numpy
import requests
from PIL import Image
from PIL import Image, ImageOps
from urllib3.exceptions import LocationParseError
from urllib3.util import parse_url
@ -64,6 +63,8 @@ class LazyLoadingImage:
logger.info(
f"Loaded input 🖼 of size {self._img.size} from {self._lazy_url}"
)
# fix orientation
self._img = ImageOps.exif_transpose(self._img)
return getattr(self._img, key)
@ -94,7 +95,6 @@ class ImaginePrompt:
mask_prompt=None,
mask_image=None,
mask_mode=MaskMode.REPLACE,
mask_expansion=2,
seed=None,
steps=50,
height=512,
@ -114,7 +114,7 @@ class ImaginePrompt:
self.prompt_strength = prompt_strength
if isinstance(init_image, str):
init_image = LazyLoadingImage(filepath=init_image)
if isinstance(mask_image, str):
mask_image = LazyLoadingImage(filepath=mask_image)
@ -134,7 +134,6 @@ class ImaginePrompt:
self.mask_prompt = mask_prompt
self.mask_image = mask_image
self.mask_mode = mask_mode
self.mask_expansion = mask_expansion
self.tile_mode = tile_mode
@property
@ -178,22 +177,23 @@ class ExifCodes:
class ImagineResult:
def __init__(self, img, prompt: ImaginePrompt, is_nsfw, upscaled_img=None):
def __init__(
self,
img,
prompt: ImaginePrompt,
is_nsfw,
upscaled_img=None,
modified_original_img=None,
):
self.img = img
self.upscaled_img = upscaled_img
self.modified_original_img = modified_original_img
self.prompt = prompt
self.is_nsfw = is_nsfw
self.created_at = datetime.utcnow().replace(tzinfo=timezone.utc)
self.torch_backend = get_device()
self.hardware_name = get_device_name(get_device())
def cv2_img(self):
open_cv_image = numpy.array(self.img)
# Convert RGB to BGR
open_cv_image = open_cv_image[:, :, ::-1].copy()
return open_cv_image
# return cv2.cvtColor(numpy.array(self.img), cv2.COLOR_RGB2BGR)
def md5(self):
return hashlib.md5(self.img.tobytes()).hexdigest()
@ -218,6 +218,9 @@ class ImagineResult:
def save_upscaled(self, save_path):
self.upscaled_img.save(save_path, exif=self._exif())
def save_modified_orig(self, save_path):
self.modified_original_img.save(save_path, exif=self._exif())
@lru_cache(maxsize=2)
def _get_briefly_cached_url(url):

View File

@ -6,17 +6,13 @@ from contextlib import contextmanager, nullcontext
from functools import lru_cache
from typing import List, Optional
import numpy as np
import requests
import torch
from PIL import Image, ImageFilter
from torch import Tensor, autocast
from torch.nn import functional
from torch.overrides import handle_torch_function, has_torch_function_variadic
from transformers import cached_path
from imaginairy.img_log import log_img
logger = logging.getLogger(__name__)
@ -64,7 +60,7 @@ def get_obj_from_str(string, reload=False):
@contextmanager
def platform_appropriate_autocast(precision="autocast"):
"""
allow calculations to run in mixed precision, which can be faster
Allow calculations to run in mixed precision, which can be faster
"""
precision_scope = nullcontext
if precision == "autocast" and get_device() in ("cuda", "cpu"):
@ -153,47 +149,6 @@ def fix_torch_group_norm():
functional.group_norm = orig_group_norm
def expand_mask(mask_image, size):
if size < 0:
threshold = 0.95
else:
threshold = 0.05
mask_image = mask_image.convert("L")
mask_image = mask_image.filter(ImageFilter.GaussianBlur(size))
log_img(mask_image, "init mask blurred")
mask = np.array(mask_image)
mask = mask.astype(np.float32) / 255.0
mask = mask[None, None]
mask[mask < threshold] = 0
mask[mask >= threshold] = 1
return Image.fromarray(np.uint8(mask.squeeze() * 255))
def img_path_to_torch_image(path):
image = Image.open(path).convert("RGB")
logger.info(f"Loaded input 🖼 of size {image.size} from {path}")
return pillow_img_to_torch_image(image)
def pillow_fit_image_within(image, max_height=512, max_width=512):
image = image.convert("RGB")
w, h = image.size
resize_ratio = min(max_width / w, max_height / h)
w, h = int(w * resize_ratio), int(h * resize_ratio)
w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64
image = image.resize((w, h), resample=Image.Resampling.NEAREST)
return image, w, h
def pillow_img_to_torch_image(image):
image = image.convert("RGB")
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
return 2.0 * image - 1.0
def get_cache_dir():
xdg_cache_home = os.getenv("XDG_CACHE_HOME", None)
if xdg_cache_home is None:

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.1 MiB

View File

@ -4,6 +4,8 @@ import pytest
from PIL import Image
from pytorch_lightning import seed_everything
from imaginairy import ImaginePrompt, imagine
from imaginairy.enhancers.bool_masker import MASK_PROMPT
from imaginairy.enhancers.clip_masking import get_img_mask
from imaginairy.enhancers.describe_image_blip import generate_caption
from imaginairy.enhancers.describe_image_clip import find_img_text_similarity
@ -28,9 +30,98 @@ def img_hash(img):
def test_clip_masking():
img = Image.open(f"{TESTS_FOLDER}/data/girl_with_a_pearl_earring.jpg")
pred = get_img_mask(img, "head")
pred.save(f"{TESTS_FOLDER}/test_output/earring_mask.png")
img = Image.open(f"{TESTS_FOLDER}/data/girl_with_a_pearl_earring_large.jpg")
for mask_modifier in [
"*0.5",
"*1",
"*10",
]:
pred = get_img_mask(img, f"(head OR face){{{mask_modifier}}}")
pred.save(f"{TESTS_FOLDER}/test_output/earring_mask_{mask_modifier}.png")
prompt = ImaginePrompt(
"professional photo of a woman",
init_image=img,
init_image_strength=0.95,
# lower steps for faster tests
# steps=40,
steps=4,
mask_prompt="(head OR face)*5",
mask_mode="replace",
upscale=True,
fix_faces=True,
)
result = next(imagine(prompt))
result.modified_original_img.save(
f"{TESTS_FOLDER}/test_output/earring_mask_photo.png"
)
boolean_mask_test_cases = [
(
"fruit bowl",
"'fruit bowl'",
),
(
"((((fruit bowl))))",
"'fruit bowl'",
),
(
"fruit OR bowl",
"('fruit' OR 'bowl')",
),
(
"fruit|bowl",
"('fruit' OR 'bowl')",
),
(
"fruit | bowl",
"('fruit' OR 'bowl')",
),
(
"fruit OR bowl OR pear",
"('fruit' OR 'bowl' OR 'pear')",
),
(
"fruit AND bowl",
"('fruit' AND 'bowl')",
),
(
"fruit & bowl",
"('fruit' AND 'bowl')",
),
(
"fruit AND NOT green",
"('fruit' AND NOT 'green')",
),
(
"fruit bowl{+0.5}",
"'fruit bowl'+0.5",
),
(
"fruit bowl{+0.5} OR fruit",
"('fruit bowl'+0.5 OR 'fruit')",
),
(
"NOT pizza",
"NOT 'pizza'",
),
(
"car AND (wheels OR trunk OR engine OR windows) AND NOT (truck OR headlights{*10})",
"('car' AND ('wheels' OR 'trunk' OR 'engine' OR 'windows') AND NOT ('truck' OR 'headlights'*10))",
),
(
"car AND (wheels OR trunk OR engine OR windows OR headlights) AND NOT (truck OR headlights){*10}",
"('car' AND ('wheels' OR 'trunk' OR 'engine' OR 'windows' OR 'headlights') AND NOT ('truck' OR 'headlights')*10)",
),
]
@pytest.mark.parametrize("mask_text,expected", boolean_mask_test_cases)
def test_clip_mask_parser(mask_text, expected):
parsed = MASK_PROMPT.parseString(mask_text)[0][0]
assert str(parsed) == expected
def test_describe_picture():

View File

@ -6,9 +6,10 @@ from PIL import ImageDraw
from imaginairy import ImaginePrompt, LazyLoadingImage, imagine, imagine_image_files
from imaginairy.api import load_model
from imaginairy.img_log import ImageLoggingContext, filesafe_text, log_latent
from imaginairy.img_utils import pillow_img_to_torch_image
from imaginairy.modules.clip_embedders import FrozenCLIPEmbedder
from imaginairy.samplers.ddim import DDIMSampler
from imaginairy.utils import get_device, pillow_img_to_torch_image
from imaginairy.utils import get_device
from tests import TESTS_FOLDER

View File

@ -140,9 +140,8 @@ def test_cliptext_inpainting():
prompt_strength=12,
init_image=f"{TESTS_FOLDER}/data/girl_with_a_pearl_earring.jpg",
init_image_strength=0.3,
mask_prompt="face",
mask_prompt="face{*2}",
mask_mode=ImaginePrompt.MaskMode.KEEP,
mask_expansion=-3,
width=512,
height=512,
steps=5,

View File

@ -1,25 +1,12 @@
from PIL import Image
from imaginairy.api import load_model
from imaginairy.safety import is_nsfw
from imaginairy.utils import get_device, pillow_img_to_torch_image
from tests import TESTS_FOLDER
def test_is_nsfw():
img = Image.open(f"{TESTS_FOLDER}/data/safety.jpg")
latent = _pil_to_latent(img)
assert is_nsfw(img)
img = Image.open(f"{TESTS_FOLDER}/data/girl_with_a_pearl_earring.jpg")
latent = _pil_to_latent(img)
assert not is_nsfw(img)
def _pil_to_latent(img):
model = load_model()
model.tile_mode(False)
img = pillow_img_to_torch_image(img)
img = img.to(get_device())
latent = model.get_first_stage_encoding(model.encode_first_stage(img))
return latent

View File

@ -12,7 +12,7 @@ skip = */.tox/*,*/.env/*,build/*,*/downloads/*,other/*,prolly_delete/*,downloads
linters = pylint,pycodestyle,pydocstyle,pyflakes,mypy
ignore =
Z999,C0103,C0301,C0114,C0115,C0116,
Z999,D100,D101,D102,D103,D105,D106,D107,D200,D202,D203,D205,D212,D400,D401,D415,
Z999,D100,D101,D102,D103,D105,D106,D107,D200,D202,D203,D205,D212,D400,D401,D406,D407,D415,
Z999,E501,E1101,
Z999,R0901,R0902,R0903,R0193,R0912,R0913,R0914,R0915,
Z999,W0221,W0511,W1203