# imaginAIry/imaginairy/enhancers/bool_masker.py
# (177 lines, 5.0 KiB, Python)

# pylama:ignore=W0613
"""
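Logic for parsing mask prompts.

Supports:
  - lower case text descriptions
  - combinations: AND OR NOT () (the symbols & | ! are accepted as aliases)
  - strength modifiers: {<operator><number>}
      * and / scale the mask values; the result is clamped to the 0..1 range
      + and - dilate / erode the mask with a square kernel whose side is
        <number> rounded to an integer and bumped up to the next odd size

Examples:
    fruit
    fruit bowl
    fruit AND NOT pears
    fruit OR bowl
    (pears OR oranges OR peaches){*1.5}
    fruit{-0.1} OR bowl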
"""
import operator
from abc import ABC
from typing import ClassVar
import pyparsing as pp
import torch
from kornia.morphology import dilation, erosion
from pyparsing import ParserElement
ParserElement.enablePackrat()
class Mask(ABC):
    """Common base class for the nodes of a parsed mask-prompt expression.

    Every method here is a default; subclasses override the ones they need.
    """

    def get_mask_for_image(self, img):
        """Default hook: does nothing and returns ``None``."""

    def gather_text_descriptions(self):
        """Return the text descriptions used by this node (none by default)."""
        return set()

    def apply_masks(self, mask_cache):
        """Default hook: does nothing and returns ``None``."""
class SimpleMask(Mask):
    """Leaf node holding a single text description."""

    def __init__(self, text):
        self.text = text

    @classmethod
    def from_simple_prompt(cls, instring, tokens_start, ret_tokens):
        """Build a leaf from the first parsed token.

        The signature follows the three-argument form this module uses for
        its parse-action callbacks; only ``ret_tokens`` is read.
        """
        description = ret_tokens[0]
        return cls(text=description)

    def __repr__(self):
        return f"'{self.text}'"

    def gather_text_descriptions(self):
        """Return a one-element set containing this node's text."""
        descriptions = {self.text}
        return descriptions

    def apply_masks(self, mask_cache):
        """Look up this node's mask in ``mask_cache``, keyed by its text."""
        cached_mask = mask_cache[self.text]
        return cached_mask
class ModifiedMask(Mask):
    """A mask with a ``<operator><number>`` modifier applied to it.

    ``*`` and ``/`` apply the arithmetic operator to the mask values and clamp
    the result to ``[0, 1]``.  ``+`` and ``-`` dilate / erode the mask with a
    square all-ones kernel whose side is derived from the number.
    """

    # Maps the modifier's leading character to the arithmetic function.
    ops: ClassVar = {
        "+": operator.add,
        "-": operator.sub,
        "*": operator.mul,
        "/": operator.truediv,
    }

    def __init__(self, mask, modifier):
        """Store the wrapped mask and split the modifier into operator and value.

        Args:
            mask: the mask node the modifier applies to.
            modifier: modifier text such as ``"{*1.5}"``; surrounding braces
                are stripped before parsing.

        Raises:
            KeyError: if the operator character is not a key of ``ops``.
            ValueError: if the text after the operator is not a valid float.
        """
        if modifier:
            modifier = modifier.strip("{}")
        self.mask = mask
        self.modifier = modifier
        self.operand_str = modifier[0]
        self.operand = self.ops[self.operand_str]
        self.value = float(modifier[1:])

    @classmethod
    def from_modifier_parse(cls, instring, tokens_start, ret_tokens):
        """Build a node from a parsed ``[mask, modifier]`` token group."""
        return cls(mask=ret_tokens[0][0], modifier=ret_tokens[0][1])

    def __repr__(self):
        return f"{self.mask!r}{self.modifier}"

    def gather_text_descriptions(self):
        """Return the text descriptions used by the wrapped mask."""
        return self.mask.gather_text_descriptions()

    def apply_masks(self, mask_cache):
        """Evaluate the wrapped mask and apply the modifier to the result.

        Args:
            mask_cache: mapping handed through to the wrapped mask's
                ``apply_masks``.

        Returns:
            The modified mask tensor.  Tensors held in ``mask_cache`` are
            never modified in place.
        """
        mask = self.mask.apply_masks(mask_cache)
        if self.operand_str in {"+", "-"}:
            # kernel must be odd
            kernel_size = int(round(self.value))
            if kernel_size % 2 == 0:
                kernel_size += 1
            morph_method = dilation if self.operand_str == "+" else erosion
            # Create the kernel on the same device as the mask so the two can
            # be combined regardless of where the mask lives.
            kernel = torch.ones(kernel_size, kernel_size, device=mask.device)
            # Out-of-place unsqueeze: `mask` can be the exact tensor object
            # stored in mask_cache, and the in-place variant would leave that
            # cached entry 4-D for every later lookup.
            batched = mask.unsqueeze(0).unsqueeze(0)
            morphed = morph_method(batched, kernel)
            # Drop only the two leading dims added above; a bare squeeze()
            # would also collapse a genuine height or width of 1.
            return morphed.squeeze(0).squeeze(0)
        return torch.clamp(self.operand(mask, self.value), 0, 1)
class NestedMask(Mask):
    """Combination of child masks under one operator: ``OR``, ``AND`` or ``NOT``."""

    def __init__(self, masks, op):
        self.masks = masks
        self.op = op

    @staticmethod
    def _child_masks(ret_tokens):
        """Return the ``Mask`` instances of the first token group, skipping the rest."""
        return [token for token in ret_tokens[0] if isinstance(token, Mask)]

    @classmethod
    def from_or(cls, instring, tokens_start, ret_tokens):
        """Build an ``OR`` node from a parsed token group."""
        children = cls._child_masks(ret_tokens)
        return cls(masks=children, op="OR")

    @classmethod
    def from_and(cls, instring, tokens_start, ret_tokens):
        """Build an ``AND`` node from a parsed token group."""
        children = cls._child_masks(ret_tokens)
        return cls(masks=children, op="AND")

    @classmethod
    def from_not(cls, instring, tokens_start, ret_tokens):
        """Build a ``NOT`` node; the token group must contain exactly one mask."""
        children = cls._child_masks(ret_tokens)
        assert len(children) == 1
        return cls(masks=children, op="NOT")

    def __repr__(self):
        if self.op != "NOT":
            joined = f" {self.op} ".join(map(repr, self.masks))
            return f"({joined})"
        return f"NOT {self.masks[0]}"

    def gather_text_descriptions(self):
        """Return the union of the text descriptions of all child masks."""
        per_child = [child.gather_text_descriptions() for child in self.masks]
        descriptions = set()
        for child_descriptions in per_child:
            descriptions.update(child_descriptions)
        return descriptions

    def apply_masks(self, mask_cache):
        """Evaluate every child mask and combine them according to ``self.op``.

        ``OR`` takes the element-wise maximum, ``AND`` the element-wise
        minimum, and ``NOT`` computes ``1 - mask`` of the first child.  The
        result is clamped to ``[0, 1]``.

        Raises:
            ValueError: if ``self.op`` is not one of the three operators.
        """
        child_results = [child.apply_masks(mask_cache) for child in self.masks]
        combined = child_results[0]
        if self.op == "NOT":
            combined = 1 - combined
        elif self.op in ("OR", "AND"):
            merge = torch.maximum if self.op == "OR" else torch.minimum
            for child_result in child_results:
                combined = merge(combined, child_result)
        else:
            msg = f"Invalid operand {self.op}"
            raise ValueError(msg)
        return torch.clamp(combined, 0, 1)
# Grammar for mask prompts, built from pyparsing elements.
#
# Boolean operators: each accepts either a keyword or a one-character symbol
# and is given the results name "op".
AND = (pp.Literal("AND") | pp.Literal("&")).setName("AND").setResultsName("op")
OR = (pp.Literal("OR") | pp.Literal("|")).setName("OR").setResultsName("op")
NOT = (pp.Literal("NOT") | pp.Literal("!")).setName("NOT").setResultsName("op")
# Modifier suffix: one operator character (* / + -) followed by an unsigned
# number with an optional fractional part, wrapped in braces, e.g. {*1.5}.
PROMPT_MODIFIER = (
pp.Regex(r"{[*/+-]\d+\.?\d*}")
.setName("prompt_modifier")
.setResultsName("prompt_modifier")
)
# Description text: lowercase letters, digits, spaces and hyphens; the match
# must end on a letter or digit.
PROMPT_TEXT = (
pp.Regex(r"[a-z0-9]?[a-z0-9 -]*[a-z0-9]")
.setName("prompt_text")
.setResultsName("prompt_text")
)
# A matched description is turned into a SimpleMask by the parse action.
SIMPLE_PROMPT = PROMPT_TEXT.setResultsName("simplePrompt")
SIMPLE_PROMPT.setParseAction(SimpleMask.from_simple_prompt)
# Full expression grammar. Each table entry is
# (operator expression, operand count, associativity, parse action).
# NOTE(review): pyparsing documents this list as ordered from highest to
# lowest precedence, which would make the modifier bind tightest and OR
# loosest -- confirm against the pyparsing infixNotation documentation.
COMPLEX_PROMPT = pp.infixNotation(
SIMPLE_PROMPT,
[
(PROMPT_MODIFIER, 1, pp.opAssoc.LEFT, ModifiedMask.from_modifier_parse),
(NOT, 1, pp.opAssoc.RIGHT, NestedMask.from_not),
(AND, 2, pp.opAssoc.LEFT, NestedMask.from_and),
(OR, 2, pp.opAssoc.LEFT, NestedMask.from_or),
],
)
# Top-level element: the complete expression wrapped in a Group and given the
# results name "complexPrompt".
MASK_PROMPT = pp.Group(COMPLEX_PROMPT).setResultsName("complexPrompt")