feature: boolean logic masks

Specify advanced text-based masks using boolean logic and strength modifiers. Mask descriptions must be lowercase and keywords uppercase. Valid symbols: `AND`, `OR`, `NOT`, `()`, and mask strength modifiers like `{*1.5}`, where the operator can be any of `+ - * /`. Single-character boolean operators also work. When writing strength modifiers, keep in mind that pixel values are between 0 and 1.

- feature: apply mask edits to original files
- feature: auto-rotate images if exif data specifies to do so
- fix: accept mask images in command line
README.md
@@ -35,9 +35,18 @@ Generating 🖼 : "portrait photo of a freckled woman" 512x512px seed:500686645
<img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000019_786355545_PLMS50_PS7.5_a_scenic_landscape.jpg" height="256"><img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000032_337692011_PLMS40_PS7.5_a_photo_of_a_dog.jpg" height="256"><br>
<img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000056_293284644_PLMS40_PS7.5_photo_of_a_bowl_of_fruit.jpg" height="256"><img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000078_260972468_PLMS40_PS7.5_portrait_photo_of_a_freckled_woman.jpg" height="256">

### Automated Replacement (txt2mask) [by clipseg](https://github.com/timojl/clipseg)
### Prompt Based Editing [by clipseg](https://github.com/timojl/clipseg)

Specify advanced text-based masks using boolean logic and strength modifiers. Mask descriptions must be lowercase. Keywords uppercase.
Valid symbols: `AND`, `OR`, `NOT`, `()`, and the mask strength modifier `{*1.5}`, where the operator can be any of `+ - * /`. Single-character boolean
operators also work. When writing strength modifiers, keep in mind that pixel values are between 0 and 1.

```bash
>> imagine --init-image pearl_earring.jpg --mask-prompt face --mask-mode keep --init-image-strength .4 "a female doctor" "an elegant woman"
>> imagine \
    --init-image pearl_earring.jpg \
    --mask-prompt "face{*1.9}" \
    --mask-mode keep \
    --init-image-strength .4 \
    "a female doctor" "an elegant woman"
```
<img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/mask_examples/pearl000.jpg" height="200">➡️
<img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/mask_examples/pearl002.jpg" height="200">
@@ -45,7 +54,12 @@ Generating 🖼 : "portrait photo of a freckled woman" 512x512px seed:500686645
<img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/mask_examples/pearl001.jpg" height="200">
<img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/mask_examples/pearl003.jpg" height="200">
```bash
>> imagine --init-image fruit-bowl.jpg --mask-prompt fruit --mask-mode replace --init-image-strength .1 "a bowl of pears" "a bowl of gold" "a bowl of popcorn" "a bowl of spaghetti"
>> imagine \
    --init-image fruit-bowl.jpg \
    --mask-prompt "fruit OR fruit stem{*1.5}" \
    --mask-mode replace \
    --init-image-strength .1 \
    "a bowl of kittens" "a bowl of gold coins" "a bowl of popcorn" "a bowl of spaghetti"
```
<img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000056_293284644_PLMS40_PS7.5_photo_of_a_bowl_of_fruit.jpg" height="200">➡️
<img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/mask_examples/bowl004.jpg" height="200">
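Since the clipseg masks are per-pixel values between 0 and 1, a strength modifier is just elementwise arithmetic followed by a clamp back into that range. A minimal sketch of what the `{*1.9}` modifier above does to a mask tensor (mirroring the `ModifiedMask.apply_masks` logic added in this change; the values are hypothetical):

```python
import torch

mask = torch.tensor([0.2, 0.4, 0.6])     # hypothetical clipseg mask pixels
boosted = torch.clamp(mask * 1.9, 0, 1)  # "{*1.9}" -> tensor([0.3800, 0.7600, 1.0000])
```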
@@ -130,9 +144,8 @@ prompts = [
    ImaginePrompt(
        "a bowl of strawberries",
        init_image=LazyLoadingImage(filepath="mypath/to/bowl_of_fruit.jpg"),
        mask_prompt="fruit|stems",
        mask_prompt="fruit OR stem{*2}",  # amplify the stem mask x2
        mask_mode="replace",
        mask_expansion=3
    ),
    ImaginePrompt("strawberries", tile_mode=True),
]
@@ -167,6 +180,13 @@ docker run -it --gpus all -v $HOME/.cache/huggingface:/root/.cache/huggingface -
[Example Colab](https://colab.research.google.com/drive/1rOvQNs0Cmn_yU1bKWjCOHzGVDgZkaTtO?usp=sharing)

## ChangeLog

- feature: Specify advanced text-based masks using boolean logic and strength modifiers. Mask descriptions must be lowercase. Keywords uppercase.
  Valid symbols: `AND`, `OR`, `NOT`, `()`, and mask strength modifier `{+0.1}` where `+` can be any of `+ - * /`
- feature: apply mask edits to original files
- feature: auto-rotate images if exif data specifies to do so
- fix: accept mask images in command line

**1.6.2**
- fix: another bfloat16 fix
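The "apply mask edits to original files" feature pastes the generated patch back into the full-resolution source image. A hedged sketch of driving it from the Python API (the file paths are hypothetical; the call pattern mirrors the tests in this change):

```python
from imaginairy import ImaginePrompt, imagine

prompt = ImaginePrompt(
    "a bowl of pears",
    init_image="mypath/to/bowl_of_fruit.jpg",  # hypothetical path
    mask_prompt="fruit OR fruit stem{*1.5}",
    mask_mode="replace",
)
result = next(imagine(prompt))
if result.modified_original_img:
    # the edit composited back into the original-size image
    result.save_modified_orig("bowl_of_fruit_modified.jpg")
```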
@@ -214,6 +234,8 @@ docker run -it --gpus all -v $HOME/.cache/huggingface:/root/.cache/huggingface -
- training

## Todo

- refactor how output versions are selected (upscaled, modified original, etc)
- performance optimizations
  - ✅ https://github.com/huggingface/diffusers/blob/main/docs/source/optimization/fp16.mdx
  - ✅ https://github.com/CompVis/stable-diffusion/compare/main...Doggettx:stable-diffusion:autocast-improvements#
imaginairy/api.py
@@ -4,6 +4,7 @@ import re
from functools import lru_cache

import numpy as np
import PIL
import torch
import torch.nn
from einops import rearrange
@@ -22,17 +23,15 @@ from imaginairy.img_log import (
    log_img,
    log_latent,
)
from imaginairy.img_utils import pillow_fit_image_within, pillow_img_to_torch_image
from imaginairy.safety import is_nsfw
from imaginairy.samplers.base import get_sampler
from imaginairy.schema import ImaginePrompt, ImagineResult
from imaginairy.utils import (
    expand_mask,
    fix_torch_group_norm,
    fix_torch_nn_layer_norm,
    get_device,
    instantiate_from_config,
    pillow_fit_image_within,
    pillow_img_to_torch_image,
    platform_appropriate_autocast,
)
@@ -92,8 +91,10 @@ def imagine_image_files(
    record_step_images=False,
    output_file_extension="jpg",
    print_caption=False,
    create_modified_originals_for_masks=True,
):
    big_path = os.path.join(outdir, "upscaled")
    masked_orig_path = os.path.join(outdir, "modified_originals")
    os.makedirs(outdir, exist_ok=True)

    base_count = len(os.listdir(outdir))
@@ -119,6 +120,7 @@ def imagine_image_files(
        ddim_eta=ddim_eta,
        img_callback=_record_step if record_step_images else None,
        add_caption=print_caption,
        create_modified_originals_for_masks=create_modified_originals_for_masks,
    ):
        prompt = result.prompt
        basefilename = f"{base_count:06}_{prompt.seed}_{prompt.sampler_type}{prompt.steps}_PS{prompt.prompt_strength}_{prompt_normalized(prompt.prompt_text)}"
@@ -131,6 +133,11 @@
            bigfilepath = os.path.join(big_path, basefilename) + "_upscaled.jpg"
            result.save_upscaled(bigfilepath)
            logger.info(f" Upscaled 🖼 saved to: {bigfilepath}")
        if result.modified_original_img:
            os.makedirs(masked_orig_path, exist_ok=True)
            bigfilepath = os.path.join(masked_orig_path, basefilename) + "_modified.jpg"
            result.save_modified_orig(bigfilepath)
            logger.info(f" Modified original 🖼 saved to: {bigfilepath}")
        base_count += 1
        del result
@@ -144,6 +151,7 @@ def imagine(
    img_callback=None,
    half_mode=None,
    add_caption=False,
    create_modified_originals_for_masks=True,
):
    model = load_model()
@@ -197,30 +205,34 @@ def imagine(
                logger.info(" Sampler type switched to plms for img2img")
            else:
                sampler_type = prompt.sampler_type
                start_code = None

            sampler = get_sampler(sampler_type, model)
            mask, mask_image, mask_image_orig = None, None, None
            if prompt.init_image:
                generation_strength = 1 - prompt.init_image_strength
                ddim_steps = int(prompt.steps / generation_strength)
                sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=ddim_eta)
                init_image, _, h = pillow_fit_image_within(
                    prompt.init_image,
                    max_height=prompt.height,
                    max_width=prompt.width,
                )
                try:
                    init_image, _, h = pillow_fit_image_within(
                        prompt.init_image,
                        max_height=prompt.height,
                        max_width=prompt.width,
                    )
                except PIL.UnidentifiedImageError:
                    logger.warning(f" Could not load image: {prompt.init_image}")
                    continue

                init_image_t = pillow_img_to_torch_image(init_image)

                if prompt.mask_prompt:
                    mask_image = get_img_mask(init_image, prompt.mask_prompt)
                    mask_image = get_img_mask(
                        init_image, prompt.mask_prompt, threshold=0.1
                    )
                elif prompt.mask_image:
                    mask_image = prompt.mask_image
                    mask_image = prompt.mask_image.convert("L")

                if mask_image is not None:
                    log_img(mask_image, "init mask")
                    mask_image = expand_mask(mask_image, prompt.mask_expansion)
                    log_img(mask_image, "init mask expanded")
                    if prompt.mask_mode == ImaginePrompt.MaskMode.REPLACE:
                        mask_image = ImageOps.invert(mask_image)
@@ -236,13 +248,11 @@ def imagine(
                        ),
                        resample=Image.Resampling.NEAREST,
                    )
                    log_img(mask_image, "init mask 2")
                    log_img(mask_image, "latent_mask")

                    mask = np.array(mask_image)
                    mask = mask.astype(np.float32) / 255.0
                    mask = mask[None, None]
                    mask[mask < 0.9] = 0
                    mask[mask >= 0.9] = 1
                    mask = torch.from_numpy(mask)
                    mask = mask.to(get_device())
@@ -294,8 +304,6 @@ def imagine(
                x_sample_8_orig = x_sample.astype(np.uint8)
                img = Image.fromarray(x_sample_8_orig)
                if mask_image_orig and init_image:
                    mask_image_orig = expand_mask(mask_image_orig, -3)
                    mask_image_orig = mask_image_orig.filter(
                        ImageFilter.GaussianBlur(radius=3)
                    )
@@ -305,6 +313,7 @@ def imagine(
                    log_img(img, "reconstituted image")

                upscaled_img = None
                rebuilt_orig_img = None
                is_nsfw_img = None
                if add_caption:
                    caption = generate_caption(img)
@@ -325,11 +334,43 @@ def imagine(
                    logger.info(" Fixing 😊 's in big 🖼 using CodeFormer...")
                    upscaled_img = enhance_faces(upscaled_img, fidelity=0.8)

                # put the newly generated patch back into the original, full size image
                if (
                    create_modified_originals_for_masks
                    and mask_image_orig
                    and prompt.init_image
                ):
                    img_to_add_back_to_original = (
                        upscaled_img if upscaled_img else img
                    )
                    img_to_add_back_to_original = (
                        img_to_add_back_to_original.resize(
                            prompt.init_image.size,
                            resample=Image.Resampling.LANCZOS,
                        )
                    )

                    mask_for_orig_size = mask_image_orig.resize(
                        prompt.init_image.size, resample=Image.Resampling.LANCZOS
                    )
                    mask_for_orig_size = mask_for_orig_size.filter(
                        ImageFilter.GaussianBlur(radius=5)
                    )
                    log_img(mask_for_orig_size, "mask for original image size")

                    rebuilt_orig_img = Image.composite(
                        img_to_add_back_to_original,
                        prompt.init_image,
                        mask_for_orig_size,
                    )
                    log_img(rebuilt_orig_img, "reconstituted original")

                yield ImagineResult(
                    img=img,
                    prompt=prompt,
                    upscaled_img=upscaled_img,
                    is_nsfw=is_nsfw_img,
                    modified_original_img=rebuilt_orig_img,
                )
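The paste-back step above resizes the generated patch and the mask up to the original image's dimensions, blurs the mask to soften the seam, and composites. A standalone illustration of the same PIL calls with hypothetical sizes:

```python
from PIL import Image, ImageFilter

patch = Image.new("RGB", (512, 512), "red")  # stand-in for the generated image
original = Image.new("RGB", (1600, 1600))    # stand-in for the full-size original
mask = Image.new("L", (512, 512), 255)       # stand-in mask; 255 = take from patch

patch_big = patch.resize(original.size, resample=Image.Resampling.LANCZOS)
mask_big = mask.resize(original.size, resample=Image.Resampling.LANCZOS)
mask_big = mask_big.filter(ImageFilter.GaussianBlur(radius=5))  # soften the seam
rebuilt = Image.composite(patch_big, original, mask_big)  # patch where mask is white
```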
imaginairy/cmds.py
@@ -129,7 +129,14 @@ def configure_logging(level="INFO"):
)
@click.option(
    "--mask-prompt",
    help="Describe what you want masked and the AI will mask it for you",
    help=(
        "Describe what you want masked and the AI will mask it for you. "
        "You can describe complex masks with AND, OR, NOT keywords and parentheses. "
        "The strength of each mask can be modified with {*1.5} notation. \n\n"
        "Examples: \n"
        "car AND (wheels{*1.1} OR trunk OR engine OR windows OR headlights) AND NOT (truck OR headlights){*10}\n"
        "fruit|fruit stem"
    ),
)
@click.option(
    "--mask-mode",
@@ -137,12 +144,6 @@ def configure_logging(level="INFO"):
    type=click.Choice(["keep", "replace"]),
    help="Should we replace the masked area or keep it?",
)
@click.option(
    "--mask-expansion",
    default="2",
    type=int,
    help="How much to grow (or shrink) the mask area",
)
@click.option(
    "--caption",
    default=False,
@@ -178,7 +179,6 @@ def imagine_cmd(
    mask_image,
    mask_prompt,
    mask_mode,
    mask_expansion,
    caption,
    precision,
):
@@ -196,6 +196,9 @@ def imagine_cmd(
    if init_image and init_image.startswith("http"):
        init_image = LazyLoadingImage(url=init_image)

    if mask_image and mask_image.startswith("http"):
        mask_image = LazyLoadingImage(url=mask_image)

    prompts = []
    load_model()
    for _ in range(repeats):
@@ -212,7 +215,6 @@ def imagine_cmd(
            width=width,
            mask_image=mask_image,
            mask_prompt=mask_prompt,
            mask_expansion=mask_expansion,
            mask_mode=mask_mode,
            upscale=upscale,
            fix_faces=fix_faces,
imaginairy/enhancers/bool_masker.py (new file)
@@ -0,0 +1,163 @@
# pylama:ignore=W0613
"""
Logic for parsing mask prompts.

Supports
    lower case text descriptions
    Combinations: AND OR NOT ()
    Strength Modifiers: {<operator><number>}

Examples:
    fruit
    fruit bowl
    fruit AND NOT pears
    fruit OR bowl
    (pears OR oranges OR peaches){*1.5}
    fruit{-0.1} OR bowl

"""
import operator
from abc import ABC

import pyparsing as pp
import torch
from pyparsing import ParserElement

ParserElement.enablePackrat()


class Mask(ABC):
    def get_mask_for_image(self, img):
        pass

    def gather_text_descriptions(self):
        return set()

    def apply_masks(self, mask_cache):
        pass


class SimpleMask(Mask):
    def __init__(self, text):
        self.text = text

    @classmethod
    def from_simple_prompt(cls, instring, tokens_start, ret_tokens):
        return cls(text=ret_tokens[0])

    def __repr__(self):
        return f"'{self.text}'"

    def gather_text_descriptions(self):
        return {self.text}

    def apply_masks(self, mask_cache):
        return mask_cache[self.text]


class ModifiedMask(Mask):
    ops = {
        "+": operator.add,
        "-": operator.sub,
        "*": operator.mul,
        "/": operator.truediv,
        # '%': operator.mod,
        # '^': operator.xor,
    }

    def __init__(self, mask, modifier):
        if modifier:
            modifier = modifier.strip("{}")
        self.mask = mask
        self.modifier = modifier
        self.operand = self.ops[modifier[0]]
        self.value = float(modifier[1:])

    @classmethod
    def from_modifier_parse(cls, instring, tokens_start, ret_tokens):
        return cls(mask=ret_tokens[0][0], modifier=ret_tokens[0][1])

    def __repr__(self):
        return f"{repr(self.mask)}{self.modifier}"

    def gather_text_descriptions(self):
        return self.mask.gather_text_descriptions()

    def apply_masks(self, mask_cache):
        mask = self.mask.apply_masks(mask_cache)
        return torch.clamp(self.operand(mask, self.value), 0, 1)


class NestedMask(Mask):
    def __init__(self, masks, op):
        self.masks = masks
        self.op = op

    @classmethod
    def from_or(cls, instring, tokens_start, ret_tokens):
        sub_masks = [t for t in ret_tokens[0] if isinstance(t, Mask)]
        return cls(masks=sub_masks, op="OR")

    @classmethod
    def from_and(cls, instring, tokens_start, ret_tokens):
        sub_masks = [t for t in ret_tokens[0] if isinstance(t, Mask)]
        return cls(masks=sub_masks, op="AND")

    @classmethod
    def from_not(cls, instring, tokens_start, ret_tokens):
        sub_masks = [t for t in ret_tokens[0] if isinstance(t, Mask)]
        assert len(sub_masks) == 1
        return cls(masks=sub_masks, op="NOT")

    def __repr__(self):
        if self.op == "NOT":
            return f"NOT {self.masks[0]}"
        sub = f" {self.op} ".join(repr(m) for m in self.masks)
        return f"({sub})"

    def gather_text_descriptions(self):
        return set().union(*[m.gather_text_descriptions() for m in self.masks])

    def apply_masks(self, mask_cache):
        submasks = [m.apply_masks(mask_cache) for m in self.masks]
        mask = submasks[0]
        if self.op == "OR":
            for submask in submasks:
                mask = torch.maximum(mask, submask)
        elif self.op == "AND":
            for submask in submasks:
                mask = torch.minimum(mask, submask)
        elif self.op == "NOT":
            mask = 1 - mask
        else:
            raise ValueError(f"Invalid operand {self.op}")
        return torch.clamp(mask, 0, 1)


AND = (pp.Literal("AND") | pp.Literal("&")).setName("AND").setResultsName("op")
OR = (pp.Literal("OR") | pp.Literal("|")).setName("OR").setResultsName("op")
NOT = (pp.Literal("NOT") | pp.Literal("!")).setName("NOT").setResultsName("op")

PROMPT_MODIFIER = (
    pp.Regex(r"{[*/+-]\d+\.?\d*}")
    .setName("prompt_modifier")
    .setResultsName("prompt_modifier")
)
PROMPT_TEXT = (
    pp.Regex(r"[a-z0-9]?[a-z0-9 -]*[a-z0-9]")
    .setName("prompt_text")
    .setResultsName("prompt_text")
)
SIMPLE_PROMPT = PROMPT_TEXT.setResultsName("simplePrompt")
SIMPLE_PROMPT.setParseAction(SimpleMask.from_simple_prompt)

COMPLEX_PROMPT = pp.infixNotation(
    SIMPLE_PROMPT,
    [
        (PROMPT_MODIFIER, 1, pp.opAssoc.LEFT, ModifiedMask.from_modifier_parse),
        (NOT, 1, pp.opAssoc.RIGHT, NestedMask.from_not),
        (AND, 2, pp.opAssoc.LEFT, NestedMask.from_and),
        (OR, 2, pp.opAssoc.LEFT, NestedMask.from_or),
    ],
)
MASK_PROMPT = pp.Group(COMPLEX_PROMPT).setResultsName("complexPrompt")
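A hedged sketch of exercising this grammar directly, in the same way the new parser tests do (the printed forms follow from the `__repr__` methods above):

```python
from imaginairy.enhancers.bool_masker import MASK_PROMPT

parsed = MASK_PROMPT.parseString("car AND NOT (wheels OR windows{*2})")[0][0]
print(parsed)                             # ('car' AND NOT ('wheels' OR 'windows'*2))
print(parsed.gather_text_descriptions())  # {'car', 'wheels', 'windows'} (set order varies)
```

Note the fuzzy-logic semantics in `NestedMask.apply_masks`: `OR` is an elementwise maximum, `AND` an elementwise minimum, and `NOT` is `1 - mask`, so combined masks always stay in the 0-to-1 range.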
imaginairy/enhancers/clip_masking.py
@@ -1,5 +1,7 @@
from functools import lru_cache
from typing import Optional, Sequence

import PIL.Image
import torch
from torchvision import transforms
@@ -26,33 +28,26 @@ def clip_mask_model():
    return model


def get_img_mask(img, mask_description, negative_description=""):
    pos_descriptions = mask_description.split("|")
    pos_masks = get_img_masks(img, pos_descriptions)
    pos_mask = pos_masks[0]
    for pred in pos_masks:
        pos_mask = torch.maximum(pos_mask, pred)
def get_img_mask(
    img: PIL.Image.Image,
    mask_description_statement: str,
    threshold: Optional[float] = None,
):
    from imaginairy.enhancers.bool_masker import MASK_PROMPT  # noqa

    log_img(pos_mask, "pos mask")

    if negative_description:
        neg_descriptions = negative_description.split("|")
        neg_masks = get_img_masks(img, neg_descriptions)
        neg_mask = neg_masks[0]
        for pred in neg_masks:
            neg_mask = torch.maximum(neg_mask, pred)
        neg_mask = (neg_mask * 3).clip(0, 1)
        log_img(neg_mask, "neg_mask")
        pos_mask = torch.minimum(pos_mask, 1 - neg_mask)
    _min = pos_mask.min()
    _max = pos_mask.max()
    _range = _max - _min
    pos_mask = (pos_mask > (_min + (_range * 0.35))).float()

    return transforms.ToPILImage()(pos_mask)
    parsed = MASK_PROMPT.parseString(mask_description_statement)
    parsed_mask = parsed[0][0]
    descriptions = list(parsed_mask.gather_text_descriptions())
    mask_cache = get_img_masks(img, descriptions)
    mask = parsed_mask.apply_masks(mask_cache)
    log_img(mask, "combined mask")
    if threshold is not None:
        mask[mask < threshold] = 0
        mask[mask >= threshold] = 1
    return transforms.ToPILImage()(mask)


def get_img_masks(img, mask_descriptions):
def get_img_masks(img, mask_descriptions: Sequence[str]):
    a, b = img.size
    orig_size = b, a
    log_img(img, "image for masking")
@@ -75,11 +70,9 @@ def get_img_masks(img, mask_descriptions):

    preds = [torch.sigmoid(p[0]) for p in preds]

    bw_preds = []
    preds_dict = {}
    for p, desc in zip(preds, mask_descriptions):
        log_img(p, f"clip mask: {desc}")
        # bw_preds.append(pred_transform(p))
        preds_dict[desc] = p

        bw_preds.append(p)

    return bw_preds
    return preds_dict
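A hedged usage sketch of the reworked `get_img_mask` (the image path is illustrative; `threshold=0.1` matches what `imagine()` now passes):

```python
from PIL import Image

from imaginairy.enhancers.clip_masking import get_img_mask

img = Image.open("tests/data/girl_with_a_pearl_earring.jpg")
mask = get_img_mask(img, "face AND NOT hat{*0.5}", threshold=0.1)
mask.save("face_mask.png")  # binarized grayscale PIL image
```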
imaginairy/img_log.py
@@ -1,12 +1,11 @@
import logging
import re

import numpy as np
import torch
from einops import rearrange
from PIL import Image
from torchvision.transforms import ToPILImage

from imaginairy.img_utils import model_latents_to_pillow_imgs

_CURRENT_LOGGING_CONTEXT = None

logger = logging.getLogger(__name__)
@@ -64,11 +63,7 @@ class ImageLoggingContext:
            return
        self.step_count += 1
        description = f"{description} - {latents.shape}"
        latents = self.model.decode_first_stage(latents)
        latents = torch.clamp((latents + 1.0) / 2.0, min=0.0, max=1.0)
        for latent in latents:
            latent = 255.0 * rearrange(latent.cpu().numpy(), "c h w -> h w c")
            img = Image.fromarray(latent.astype(np.uint8))
        for img in model_latents_to_pillow_imgs(latents):
            self.img_callback(img, description, self.step_count, self.prompt)

    def log_img(self, img, description):
imaginairy/img_utils.py (new file)
@@ -0,0 +1,59 @@
from typing import Sequence

import numpy as np
import PIL
import torch
from einops import rearrange, repeat
from PIL import Image

from imaginairy.utils import get_device


def pillow_fit_image_within(image: PIL.Image.Image, max_height=512, max_width=512):
    image = image.convert("RGB")
    w, h = image.size
    resize_ratio = min(max_width / w, max_height / h)
    w, h = int(w * resize_ratio), int(h * resize_ratio)
    w, h = map(lambda x: x - x % 64, (w, h))  # resize to integer multiple of 64
    image = image.resize((w, h), resample=Image.Resampling.NEAREST)
    return image, w, h
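A worked example of the sizing logic above, assuming the default 512×512 maximums:

```python
from PIL import Image

from imaginairy.img_utils import pillow_fit_image_within

img = Image.new("RGB", (1000, 700))  # hypothetical input
fitted, w, h = pillow_fit_image_within(img)
# resize_ratio = min(512/1000, 512/700) = 0.512 -> provisional (512, 358)
# then snapped down to multiples of 64 -> (512, 320)
assert (w, h) == (512, 320)
```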
def pillow_img_to_torch_image(img: PIL.Image.Image):
    img = img.convert("RGB")
    img = np.array(img).astype(np.float32) / 255.0
    img = img[None].transpose(0, 3, 1, 2)
    img = torch.from_numpy(img)
    return 2.0 * img - 1.0


def pillow_img_to_opencv_img(img: PIL.Image.Image):
    open_cv_image = np.array(img)
    # Convert RGB to BGR
    open_cv_image = open_cv_image[:, :, ::-1].copy()
    return open_cv_image


def model_latents_to_pillow_imgs(latents: torch.Tensor) -> Sequence[PIL.Image.Image]:
    from imaginairy.api import load_model  # noqa

    model = load_model()
    latents = model.decode_first_stage(latents)
    latents = torch.clamp((latents + 1.0) / 2.0, min=0.0, max=1.0)
    imgs = []
    for latent in latents:
        latent = 255.0 * rearrange(latent.cpu().numpy(), "c h w -> h w c")
        img = Image.fromarray(latent.astype(np.uint8))
        imgs.append(img)
    return imgs


def pillow_img_to_model_latent(model, img, batch_size=1, half=True):
    # init_image = pil_img_to_torch(img, half=half).to(device)
    init_image = pillow_img_to_torch_image(img).to(get_device())
    init_image = repeat(init_image, "1 ... -> b ...", b=batch_size)
    if half:
        return model.get_first_stage_encoding(
            model.encode_first_stage(init_image.half())
        )
    return model.get_first_stage_encoding(model.encode_first_stage(init_image))
@@ -9,26 +9,15 @@ needs https://github.com/crowsonkb/k-diffusion
from contextlib import nullcontext

import torch
from einops import repeat
from torch import autocast

from imaginairy.utils import get_device, pillow_img_to_torch_image
from imaginairy.img_utils import pillow_img_to_model_latent
from imaginairy.utils import get_device
from imaginairy.vendored import k_diffusion as K


def pil_img_to_latent(model, img, batch_size=1, half=True):
    # init_image = pil_img_to_torch(img, half=half).to(device)
    init_image = pillow_img_to_torch_image(img).to(get_device())
    init_image = repeat(init_image, "1 ... -> b ...", b=batch_size)
    if half:
        return model.get_first_stage_encoding(
            model.encode_first_stage(init_image.half())
        )
    return model.get_first_stage_encoding(model.encode_first_stage(init_image))


def find_noise_for_image(model, pil_img, prompt, steps=50, cond_scale=1.0, half=True):
    img_latent = pil_img_to_latent(
    img_latent = pillow_img_to_model_latent(
        model, pil_img, batch_size=1, device="cuda", half=half
    )
    return find_noise_for_latent(
@@ -1,3 +1,4 @@
# pylama:ignore=W0613
"""SAMPLING ONLY."""
import logging
imaginairy/schema.py
@@ -6,9 +6,8 @@ import random
from datetime import datetime, timezone
from functools import lru_cache

import numpy
import requests
from PIL import Image
from PIL import Image, ImageOps
from urllib3.exceptions import LocationParseError
from urllib3.util import parse_url

@@ -64,6 +63,8 @@ class LazyLoadingImage:
                logger.info(
                    f"Loaded input 🖼 of size {self._img.size} from {self._lazy_url}"
                )
            # fix orientation
            self._img = ImageOps.exif_transpose(self._img)

        return getattr(self._img, key)
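The auto-rotate fix leans on Pillow's `ImageOps.exif_transpose`, which bakes the EXIF Orientation tag into the pixel data so downstream code always sees an upright image. A minimal sketch (the filename is hypothetical):

```python
from PIL import Image, ImageOps

img = Image.open("phone_photo.jpg")  # hypothetical image shot sideways
img = ImageOps.exif_transpose(img)   # orientation tag applied; image now upright
```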
@@ -94,7 +95,6 @@ class ImaginePrompt:
        mask_prompt=None,
        mask_image=None,
        mask_mode=MaskMode.REPLACE,
        mask_expansion=2,
        seed=None,
        steps=50,
        height=512,
@@ -114,7 +114,7 @@ class ImaginePrompt:
        self.prompt_strength = prompt_strength
        if isinstance(init_image, str):
            init_image = LazyLoadingImage(filepath=init_image)

        if isinstance(mask_image, str):
            mask_image = LazyLoadingImage(filepath=mask_image)

@@ -134,7 +134,6 @@ class ImaginePrompt:
        self.mask_prompt = mask_prompt
        self.mask_image = mask_image
        self.mask_mode = mask_mode
        self.mask_expansion = mask_expansion
        self.tile_mode = tile_mode

    @property
@@ -178,22 +177,23 @@ class ExifCodes:


class ImagineResult:
    def __init__(self, img, prompt: ImaginePrompt, is_nsfw, upscaled_img=None):
    def __init__(
        self,
        img,
        prompt: ImaginePrompt,
        is_nsfw,
        upscaled_img=None,
        modified_original_img=None,
    ):
        self.img = img
        self.upscaled_img = upscaled_img
        self.modified_original_img = modified_original_img
        self.prompt = prompt
        self.is_nsfw = is_nsfw
        self.created_at = datetime.utcnow().replace(tzinfo=timezone.utc)
        self.torch_backend = get_device()
        self.hardware_name = get_device_name(get_device())

    def cv2_img(self):
        open_cv_image = numpy.array(self.img)
        # Convert RGB to BGR
        open_cv_image = open_cv_image[:, :, ::-1].copy()
        return open_cv_image
        # return cv2.cvtColor(numpy.array(self.img), cv2.COLOR_RGB2BGR)

    def md5(self):
        return hashlib.md5(self.img.tobytes()).hexdigest()

@@ -218,6 +218,9 @@ class ImagineResult:
    def save_upscaled(self, save_path):
        self.upscaled_img.save(save_path, exif=self._exif())

    def save_modified_orig(self, save_path):
        self.modified_original_img.save(save_path, exif=self._exif())


@lru_cache(maxsize=2)
def _get_briefly_cached_url(url):
imaginairy/utils.py
@@ -6,17 +6,13 @@ from contextlib import contextmanager, nullcontext
from functools import lru_cache
from typing import List, Optional

import numpy as np
import requests
import torch
from PIL import Image, ImageFilter
from torch import Tensor, autocast
from torch.nn import functional
from torch.overrides import handle_torch_function, has_torch_function_variadic
from transformers import cached_path

from imaginairy.img_log import log_img

logger = logging.getLogger(__name__)


@@ -64,7 +60,7 @@ def get_obj_from_str(string, reload=False):
@contextmanager
def platform_appropriate_autocast(precision="autocast"):
    """
    allow calculations to run in mixed precision, which can be faster
    Allow calculations to run in mixed precision, which can be faster
    """
    precision_scope = nullcontext
    if precision == "autocast" and get_device() in ("cuda", "cpu"):
@@ -153,47 +149,6 @@ def fix_torch_group_norm():
        functional.group_norm = orig_group_norm


def expand_mask(mask_image, size):
    if size < 0:
        threshold = 0.95
    else:
        threshold = 0.05
    mask_image = mask_image.convert("L")
    mask_image = mask_image.filter(ImageFilter.GaussianBlur(size))
    log_img(mask_image, "init mask blurred")
    mask = np.array(mask_image)
    mask = mask.astype(np.float32) / 255.0
    mask = mask[None, None]

    mask[mask < threshold] = 0
    mask[mask >= threshold] = 1
    return Image.fromarray(np.uint8(mask.squeeze() * 255))
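`expand_mask` grows or shrinks a mask by blurring and then re-thresholding: the blur bleeds mask edges outward, so a low threshold (0.05) keeps the faint halo and the mask grows, while a high threshold (0.95) keeps only the solid core and the mask shrinks. A hedged sketch (the mask file is hypothetical):

```python
from PIL import Image

from imaginairy.utils import expand_mask

mask = Image.open("face_mask.png")  # hypothetical black-and-white mask
grown = expand_mask(mask, 5)        # threshold 0.05 -> larger mask
shrunk = expand_mask(mask, -3)      # threshold 0.95 -> smaller mask
```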
def img_path_to_torch_image(path):
    image = Image.open(path).convert("RGB")
    logger.info(f"Loaded input 🖼 of size {image.size} from {path}")
    return pillow_img_to_torch_image(image)


def pillow_fit_image_within(image, max_height=512, max_width=512):
    image = image.convert("RGB")
    w, h = image.size
    resize_ratio = min(max_width / w, max_height / h)
    w, h = int(w * resize_ratio), int(h * resize_ratio)
    w, h = map(lambda x: x - x % 64, (w, h))  # resize to integer multiple of 64
    image = image.resize((w, h), resample=Image.Resampling.NEAREST)
    return image, w, h


def pillow_img_to_torch_image(image):
    image = image.convert("RGB")
    image = np.array(image).astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(image)
    return 2.0 * image - 1.0


def get_cache_dir():
    xdg_cache_home = os.getenv("XDG_CACHE_HOME", None)
    if xdg_cache_home is None:
tests/data/girl_with_a_pearl_earring_large.jpg (new binary file, 1.1 MiB)
tests/test_enhancers.py
@@ -4,6 +4,8 @@ import pytest
from PIL import Image
from pytorch_lightning import seed_everything

from imaginairy import ImaginePrompt, imagine
from imaginairy.enhancers.bool_masker import MASK_PROMPT
from imaginairy.enhancers.clip_masking import get_img_mask
from imaginairy.enhancers.describe_image_blip import generate_caption
from imaginairy.enhancers.describe_image_clip import find_img_text_similarity
@@ -28,9 +30,98 @@ def img_hash(img):


def test_clip_masking():
    img = Image.open(f"{TESTS_FOLDER}/data/girl_with_a_pearl_earring.jpg")
    pred = get_img_mask(img, "head")
    pred.save(f"{TESTS_FOLDER}/test_output/earring_mask.png")
    img = Image.open(f"{TESTS_FOLDER}/data/girl_with_a_pearl_earring_large.jpg")
    for mask_modifier in [
        "*0.5",
        "*1",
        "*10",
    ]:
        pred = get_img_mask(img, f"(head OR face){{{mask_modifier}}}")
        pred.save(f"{TESTS_FOLDER}/test_output/earring_mask_{mask_modifier}.png")

    prompt = ImaginePrompt(
        "professional photo of a woman",
        init_image=img,
        init_image_strength=0.95,
        # lower steps for faster tests
        # steps=40,
        steps=4,
        mask_prompt="(head OR face)*5",
        mask_mode="replace",
        upscale=True,
        fix_faces=True,
    )

    result = next(imagine(prompt))
    result.modified_original_img.save(
        f"{TESTS_FOLDER}/test_output/earring_mask_photo.png"
    )


boolean_mask_test_cases = [
    (
        "fruit bowl",
        "'fruit bowl'",
    ),
    (
        "((((fruit bowl))))",
        "'fruit bowl'",
    ),
    (
        "fruit OR bowl",
        "('fruit' OR 'bowl')",
    ),
    (
        "fruit|bowl",
        "('fruit' OR 'bowl')",
    ),
    (
        "fruit | bowl",
        "('fruit' OR 'bowl')",
    ),
    (
        "fruit OR bowl OR pear",
        "('fruit' OR 'bowl' OR 'pear')",
    ),
    (
        "fruit AND bowl",
        "('fruit' AND 'bowl')",
    ),
    (
        "fruit & bowl",
        "('fruit' AND 'bowl')",
    ),
    (
        "fruit AND NOT green",
        "('fruit' AND NOT 'green')",
    ),
    (
        "fruit bowl{+0.5}",
        "'fruit bowl'+0.5",
    ),
    (
        "fruit bowl{+0.5} OR fruit",
        "('fruit bowl'+0.5 OR 'fruit')",
    ),
    (
        "NOT pizza",
        "NOT 'pizza'",
    ),
    (
        "car AND (wheels OR trunk OR engine OR windows) AND NOT (truck OR headlights{*10})",
        "('car' AND ('wheels' OR 'trunk' OR 'engine' OR 'windows') AND NOT ('truck' OR 'headlights'*10))",
    ),
    (
        "car AND (wheels OR trunk OR engine OR windows OR headlights) AND NOT (truck OR headlights){*10}",
        "('car' AND ('wheels' OR 'trunk' OR 'engine' OR 'windows' OR 'headlights') AND NOT ('truck' OR 'headlights')*10)",
    ),
]


@pytest.mark.parametrize("mask_text,expected", boolean_mask_test_cases)
def test_clip_mask_parser(mask_text, expected):
    parsed = MASK_PROMPT.parseString(mask_text)[0][0]
    assert str(parsed) == expected


def test_describe_picture():
tests/test_imagine.py
@@ -6,9 +6,10 @@ from PIL import ImageDraw

from imaginairy import ImaginePrompt, LazyLoadingImage, imagine, imagine_image_files
from imaginairy.api import load_model
from imaginairy.img_log import ImageLoggingContext, filesafe_text, log_latent
from imaginairy.img_utils import pillow_img_to_torch_image
from imaginairy.modules.clip_embedders import FrozenCLIPEmbedder
from imaginairy.samplers.ddim import DDIMSampler
from imaginairy.utils import get_device, pillow_img_to_torch_image
from imaginairy.utils import get_device
from tests import TESTS_FOLDER
@@ -140,9 +140,8 @@ def test_cliptext_inpainting():
        prompt_strength=12,
        init_image=f"{TESTS_FOLDER}/data/girl_with_a_pearl_earring.jpg",
        init_image_strength=0.3,
        mask_prompt="face",
        mask_prompt="face{*2}",
        mask_mode=ImaginePrompt.MaskMode.KEEP,
        mask_expansion=-3,
        width=512,
        height=512,
        steps=5,
tests/test_safety.py
@@ -1,25 +1,12 @@
from PIL import Image

from imaginairy.api import load_model
from imaginairy.safety import is_nsfw
from imaginairy.utils import get_device, pillow_img_to_torch_image
from tests import TESTS_FOLDER


def test_is_nsfw():
    img = Image.open(f"{TESTS_FOLDER}/data/safety.jpg")
    latent = _pil_to_latent(img)
    assert is_nsfw(img)

    img = Image.open(f"{TESTS_FOLDER}/data/girl_with_a_pearl_earring.jpg")
    latent = _pil_to_latent(img)
    assert not is_nsfw(img)


def _pil_to_latent(img):
    model = load_model()
    model.tile_mode(False)
    img = pillow_img_to_torch_image(img)
    img = img.to(get_device())
    latent = model.get_first_stage_encoding(model.encode_first_stage(img))
    return latent
tox.ini
@@ -12,7 +12,7 @@ skip = */.tox/*,*/.env/*,build/*,*/downloads/*,other/*,prolly_delete/*,downloads
linters = pylint,pycodestyle,pydocstyle,pyflakes,mypy
ignore =
    Z999,C0103,C0301,C0114,C0115,C0116,
    Z999,D100,D101,D102,D103,D105,D106,D107,D200,D202,D203,D205,D212,D400,D401,D415,
    Z999,D100,D101,D102,D103,D105,D106,D107,D200,D202,D203,D205,D212,D400,D401,D406,D407,D415,
    Z999,E501,E1101,
    Z999,R0901,R0902,R0903,R0193,R0912,R0913,R0914,R0915,
    Z999,W0221,W0511,W1203