imaginAIry/imaginairy/enhancers/bool_masker.py
Bryce 8332593fed feature: dilation and erosion of masks
Previously the `+` and `-` characters in a mask (example: `face{+0.1}`) added to the grayscale value of any masked areas. This wasn't very useful. The new behavior is that the mask will expand or contract by the number of pixels specified. The technical terms for this are dilation and erosion. This allows much greater control over the masked area.
2022-10-17 23:47:26 -07:00

175 lines
5.0 KiB
Python

# pylama:ignore=W0613
"""
Logic for parsing mask prompts.
Supports
lower case text descriptions
Combinations: AND OR NOT ()
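Strength Modifiers: {<operator><number>} (`*` and `/` scale the mask strength; `+` and `-` dilate or erode the mask by about <number> pixels)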
Examples:
fruit
fruit bowl
fruit AND NOT pears
fruit OR bowl
(pears OR oranges OR peaches){*1.5}
fruit{-0.1} OR bowl
"""
import operator
from abc import ABC
import pyparsing as pp
import torch
from kornia.morphology import dilation, erosion
from pyparsing import ParserElement
ParserElement.enablePackrat()
class Mask(ABC):
    """Base class shared by every node of a parsed mask expression.

    The base implementations are inert defaults; subclasses in this module
    override the ones they need.
    """

    def get_mask_for_image(self, img):
        """Do nothing in the base class and return ``None``."""
        return None

    def gather_text_descriptions(self):
        """Return the text descriptions this node needs (none by default)."""
        return set()

    def apply_masks(self, mask_cache):
        """Do nothing in the base class and return ``None``."""
        return None
class SimpleMask(Mask):
    """Leaf node of a mask expression: one plain text description."""

    def __init__(self, text):
        self.text = text

    @classmethod
    def from_simple_prompt(cls, instring, tokens_start, ret_tokens):
        """Parse action: build a SimpleMask from the first matched token."""
        prompt_text = ret_tokens[0]
        return cls(text=prompt_text)

    def __repr__(self):
        return "'{}'".format(self.text)

    def gather_text_descriptions(self):
        """Return a one-element set holding this node's text."""
        descriptions = set()
        descriptions.add(self.text)
        return descriptions

    def apply_masks(self, mask_cache):
        """Look up this node's text in ``mask_cache`` and return the entry."""
        cached = mask_cache[self.text]
        return cached
class ModifiedMask(Mask):
    """A mask with a ``{<operator><number>}`` modifier applied to it.

    ``*`` and ``/`` scale the mask values and clamp the result to ``[0, 1]``.
    ``+`` and ``-`` dilate or erode the mask with a square kernel whose side
    is ``<number>`` rounded to an integer and bumped to the next odd value.
    """

    ops = {
        "+": operator.add,
        "-": operator.sub,
        "*": operator.mul,
        "/": operator.truediv,
    }

    def __init__(self, mask, modifier):
        """Store the wrapped mask and decode the modifier string.

        Args:
            mask: the mask node the modifier applies to.
            modifier: text such as ``"{*1.5}"``; surrounding braces are
                stripped, the first character selects the operator and the
                remainder is parsed as a float.

        Raises:
            KeyError: if the operator character is not a key of ``ops``.
            ValueError: if the text after the operator is not a valid float.
        """
        if modifier:
            modifier = modifier.strip("{}")
        self.mask = mask
        self.modifier = modifier
        self.operand_str = modifier[0]
        self.operand = self.ops[self.operand_str]
        self.value = float(modifier[1:])

    @classmethod
    def from_modifier_parse(cls, instring, tokens_start, ret_tokens):
        """Parse action: first group item is the mask, second the modifier."""
        return cls(mask=ret_tokens[0][0], modifier=ret_tokens[0][1])

    def __repr__(self):
        return f"{repr(self.mask)}{self.modifier}"

    def gather_text_descriptions(self):
        """Delegate to the wrapped mask; the modifier adds no descriptions."""
        return self.mask.gather_text_descriptions()

    def apply_masks(self, mask_cache):
        """Evaluate the wrapped mask and apply the modifier to the result.

        The tensor returned by the wrapped mask is never modified in place,
        because it may be the very object stored in ``mask_cache``.

        Returns:
            The dilated/eroded mask for ``+``/``-`` (same shape as the
            input), otherwise the scaled mask clamped to ``[0, 1]``.
        """
        mask = self.mask.apply_masks(mask_cache)
        if self.operand_str in {"+", "-"}:
            # kernel must be odd
            kernel_size = int(round(self.value))
            if kernel_size % 2 == 0:
                kernel_size += 1
            morph_method = dilation if self.operand_str == "+" else erosion
            # Build the kernel on the mask's device so the morphology op does
            # not fail when the mask is not on the default device.
            kernel = torch.ones(kernel_size, kernel_size, device=mask.device)
            # The morphology functions want a 4-D input; presumably the mask
            # is 2-D (H, W) here -- TODO confirm. Use the out-of-place
            # unsqueeze so the cached tensor keeps its original shape.
            batched = mask.unsqueeze(0).unsqueeze(0)
            morphed = morph_method(batched, kernel)
            # Drop only the two dimensions added above; a bare squeeze()
            # would also collapse a height or width of 1.
            return morphed.squeeze(0).squeeze(0)
        return torch.clamp(self.operand(mask, self.value), 0, 1)
class NestedMask(Mask):
    """Boolean combination (``OR`` / ``AND`` / ``NOT``) of sub-masks."""

    def __init__(self, masks, op):
        self.masks = masks
        self.op = op

    @staticmethod
    def _mask_operands(ret_tokens):
        """Keep only the Mask instances from the first parsed group."""
        return [tok for tok in ret_tokens[0] if isinstance(tok, Mask)]

    @classmethod
    def from_or(cls, instring, tokens_start, ret_tokens):
        """Parse action for an OR expression."""
        return cls(masks=cls._mask_operands(ret_tokens), op="OR")

    @classmethod
    def from_and(cls, instring, tokens_start, ret_tokens):
        """Parse action for an AND expression."""
        return cls(masks=cls._mask_operands(ret_tokens), op="AND")

    @classmethod
    def from_not(cls, instring, tokens_start, ret_tokens):
        """Parse action for a NOT expression; expects exactly one operand."""
        operands = cls._mask_operands(ret_tokens)
        assert len(operands) == 1
        return cls(masks=operands, op="NOT")

    def __repr__(self):
        if self.op != "NOT":
            joined = f" {self.op} ".join(map(repr, self.masks))
            return f"({joined})"
        return f"NOT {self.masks[0]}"

    def gather_text_descriptions(self):
        """Return the union of the descriptions needed by every sub-mask."""
        collected = set()
        for child in self.masks:
            collected.update(child.gather_text_descriptions())
        return collected

    def apply_masks(self, mask_cache):
        """Evaluate every sub-mask and combine them according to ``op``.

        OR takes the element-wise maximum, AND the element-wise minimum and
        NOT returns ``1 - mask``; the result is clamped to ``[0, 1]``.

        Raises:
            ValueError: if ``op`` is not one of ``OR``, ``AND`` or ``NOT``.
        """
        evaluated = [child.apply_masks(mask_cache) for child in self.masks]
        combined = evaluated[0]
        if self.op == "NOT":
            combined = 1 - combined
        elif self.op in ("OR", "AND"):
            merge = torch.maximum if self.op == "OR" else torch.minimum
            for other in evaluated:
                combined = merge(combined, other)
        else:
            raise ValueError(f"Invalid operand {self.op}")
        return torch.clamp(combined, 0, 1)
# ---- pyparsing grammar for mask prompts ----
# Boolean keywords; each accepts a word form or a symbol form and is stored
# under the results name "op".
AND = (pp.Literal("AND") | pp.Literal("&")).setName("AND").setResultsName("op")
OR = (pp.Literal("OR") | pp.Literal("|")).setName("OR").setResultsName("op")
NOT = (pp.Literal("NOT") | pp.Literal("!")).setName("NOT").setResultsName("op")
# Modifier suffix: an opening brace, one operator out of * / + -, a number
# (digits, optional dot, optional digits) and a closing brace, e.g. {*1.5}.
PROMPT_MODIFIER = (
pp.Regex(r"{[*/+-]\d+\.?\d*}")
.setName("prompt_modifier")
.setResultsName("prompt_modifier")
)
# Plain description text: lower-case letters, digits, spaces and hyphens,
# ending in a letter or digit. Upper-case letters are not matched, so the
# keywords AND / OR / NOT cannot be consumed as description text.
PROMPT_TEXT = (
pp.Regex(r"[a-z0-9]?[a-z0-9 -]*[a-z0-9]")
.setName("prompt_text")
.setResultsName("prompt_text")
)
# A bare description; the parse action turns the matched text into a SimpleMask.
SIMPLE_PROMPT = PROMPT_TEXT.setResultsName("simplePrompt")
SIMPLE_PROMPT.setParseAction(SimpleMask.from_simple_prompt)
# Full expression grammar. Per the pyparsing documentation the operator table
# lists precedence levels from tightest-binding to loosest, so the order of
# the tuples below is significant: modifier, then NOT, then AND, then OR.
COMPLEX_PROMPT = pp.infixNotation(
SIMPLE_PROMPT,
[
# unary operator with LEFT association, i.e. written after its operand
(PROMPT_MODIFIER, 1, pp.opAssoc.LEFT, ModifiedMask.from_modifier_parse),
# unary operator with RIGHT association, i.e. written before its operand
(NOT, 1, pp.opAssoc.RIGHT, NestedMask.from_not),
(AND, 2, pp.opAssoc.LEFT, NestedMask.from_and),
(OR, 2, pp.opAssoc.LEFT, NestedMask.from_or),
],
)
# Top-level parser: the whole expression wrapped in a group named "complexPrompt".
MASK_PROMPT = pp.Group(COMPLEX_PROMPT).setResultsName("complexPrompt")