feature: improved safety filter

- provides more informative logs
- provides a detailed safety score object
- adds non-bypassable filter for extreme content
pull/54/head
Bryce 2 years ago committed by Bryce Drennan
parent 0db5c329bb
commit 6a80759016

@ -213,6 +213,8 @@ docker run -it --gpus all -v $HOME/.cache/huggingface:/root/.cache/huggingface -
[Example Colab](https://colab.research.google.com/drive/1rOvQNs0Cmn_yU1bKWjCOHzGVDgZkaTtO?usp=sharing)
## ChangeLog
**3.0.0**
- feature: improved safety filter
**2.4.0**
- 🎉 feature: prompt expansion

@ -24,7 +24,7 @@ from imaginairy.img_log import (
log_latent,
)
from imaginairy.img_utils import pillow_fit_image_within, pillow_img_to_torch_image
from imaginairy.safety import is_nsfw
from imaginairy.safety import SafetyMode, create_safety_score
from imaginairy.samplers.base import get_sampler
from imaginairy.samplers.plms import PLMSSchedule
from imaginairy.schema import ImaginePrompt, ImagineResult
@ -40,16 +40,14 @@ LIB_PATH = os.path.dirname(__file__)
logger = logging.getLogger(__name__)
class SafetyMode:
DISABLED = "disabled"
CLASSIFY = "classify"
FILTER = "filter"
# leave undocumented. I'd ask that no one publicize this flag. Just want a
# slight barrier to entry. Please don't use this is any way that's gonna cause
# the press or governments to freak out about AI...
IMAGINAIRY_SAFETY_MODE = os.getenv("IMAGINAIRY_SAFETY_MODE", SafetyMode.FILTER)
# the media or politicians to freak out about AI...
IMAGINAIRY_SAFETY_MODE = os.getenv("IMAGINAIRY_SAFETY_MODE", SafetyMode.STRICT)
if IMAGINAIRY_SAFETY_MODE in {"disabled", "classify"}:
IMAGINAIRY_SAFETY_MODE = SafetyMode.RELAXED
elif IMAGINAIRY_SAFETY_MODE == "filter":
IMAGINAIRY_SAFETY_MODE = SafetyMode.STRICT
DEFAULT_MODEL_WEIGHTS_LOCATION = (
"https://www.googleapis.com/storage/v1/b/aai-blog-files/o/sd-v1-4.ckpt?alt=media"
@ -68,8 +66,7 @@ def load_model_from_config(
ckpt_path = cached_path(model_weights_location)
else:
ckpt_path = model_weights_location
logger.info(f"Loading model onto {get_device()} backend...")
logger.debug(f"Loading model from {ckpt_path}")
logger.info(f"Loading model {ckpt_path} onto {get_device()} backend...")
pl_sd = torch.load(ckpt_path, map_location="cpu")
if "global_step" in pl_sd:
logger.debug(f"Global Step: {pl_sd['global_step']}")
@ -352,59 +349,61 @@ def imagine(
upscaled_img = None
rebuilt_orig_img = None
is_nsfw_img = None
if add_caption:
caption = generate_caption(img)
logger.info(f" Generated caption: {caption}")
if IMAGINAIRY_SAFETY_MODE != SafetyMode.DISABLED:
is_nsfw_img = is_nsfw(img)
if is_nsfw_img and IMAGINAIRY_SAFETY_MODE == SafetyMode.FILTER:
logger.info(" ⚠️ Filtering NSFW image")
img = img.filter(ImageFilter.GaussianBlur(radius=40))
if prompt.fix_faces:
logger.info(" Fixing 😊 's in 🖼 using CodeFormer...")
img = enhance_faces(img, fidelity=prompt.fix_faces_fidelity)
if prompt.upscale:
logger.info(" Upscaling 🖼 using real-ESRGAN...")
upscaled_img = upscale_image(img)
# put the newly generated patch back into the original, full size image
if (
prompt.mask_modify_original
and mask_image_orig
and prompt.init_image
):
img_to_add_back_to_original = (
upscaled_img if upscaled_img else img
)
img_to_add_back_to_original = (
img_to_add_back_to_original.resize(
safety_score = create_safety_score(
img,
safety_mode=IMAGINAIRY_SAFETY_MODE,
)
if not safety_score.is_filtered:
if prompt.fix_faces:
logger.info(" Fixing 😊 's in 🖼 using CodeFormer...")
img = enhance_faces(img, fidelity=prompt.fix_faces_fidelity)
if prompt.upscale:
logger.info(" Upscaling 🖼 using real-ESRGAN...")
upscaled_img = upscale_image(img)
# put the newly generated patch back into the original, full size image
if (
prompt.mask_modify_original
and mask_image_orig
and prompt.init_image
):
img_to_add_back_to_original = (
upscaled_img if upscaled_img else img
)
img_to_add_back_to_original = (
img_to_add_back_to_original.resize(
prompt.init_image.size,
resample=Image.Resampling.LANCZOS,
)
)
mask_for_orig_size = mask_image_orig.resize(
prompt.init_image.size,
resample=Image.Resampling.LANCZOS,
)
)
mask_for_orig_size = mask_image_orig.resize(
prompt.init_image.size, resample=Image.Resampling.LANCZOS
)
mask_for_orig_size = mask_for_orig_size.filter(
ImageFilter.GaussianBlur(radius=5)
)
log_img(mask_for_orig_size, "mask for original image size")
mask_for_orig_size = mask_for_orig_size.filter(
ImageFilter.GaussianBlur(radius=5)
)
log_img(mask_for_orig_size, "mask for original image size")
rebuilt_orig_img = Image.composite(
prompt.init_image,
img_to_add_back_to_original,
mask_for_orig_size,
)
log_img(rebuilt_orig_img, "reconstituted original")
rebuilt_orig_img = Image.composite(
prompt.init_image,
img_to_add_back_to_original,
mask_for_orig_size,
)
log_img(rebuilt_orig_img, "reconstituted original")
yield ImagineResult(
img=img,
prompt=prompt,
upscaled_img=upscaled_img,
is_nsfw=is_nsfw_img,
is_nsfw=safety_score.is_nsfw,
safety_score=safety_score,
modified_original=rebuilt_orig_img,
mask_binary=mask_image_orig,
mask_grayscale=mask_grayscale,

@ -1,20 +1,128 @@
import logging
from functools import lru_cache
import numpy as np
import torch
from diffusers.pipelines.stable_diffusion import safety_checker as safety_checker_mod
from diffusers.pipelines.stable_diffusion.safety_checker import (
StableDiffusionSafetyChecker,
)
from transformers import AutoFeatureExtractor
logger = logging.getLogger(__name__)
class SafetyMode:
STRICT = "strict"
RELAXED = "relaxed"
class SafetyResult:
# increase this value to create a stronger `nfsw` filter
# at the cost of increasing the possibility of filtering benign images
_default_adjustment = 0.0
def __init__(self):
self.nsfw_scores = {}
self.special_care_scores = {}
self.is_filtered = False
def add_special_care_score(self, concept_idx, abs_score, threshold):
adjustment = self._default_adjustment
adjusted_score = round(abs_score - threshold + adjustment, 3)
try:
score_name = _SPECIAL_CARE_DESCRIPTIONS[concept_idx]
except LookupError:
score_name = ""
if adjusted_score > 0 and score_name:
logger.debug(
f" 🔞🔞 '{score_name}' abs:{abs_score:.3f} adj:{adjusted_score}"
)
self.special_care_scores[concept_idx] = adjusted_score
def add_nsfw_score(self, concept_idx, abs_score, threshold):
if len(self.special_care_scores) != 3:
raise ValueError("special care scores must be set first")
adjustment = self._default_adjustment
if self.special_care_score > 0:
adjustment += 0.01
adjusted_score = round(abs_score - threshold + adjustment, 3)
try:
score_name = _CONCEPT_DESCRIPTIONS[concept_idx]
except LookupError:
score_name = ""
if adjusted_score > 0 and score_name:
logger.debug(
f" 🔞 '{concept_idx}:{score_name}' abs:{abs_score:.3f} adj:{adjusted_score}"
)
self.nsfw_scores[concept_idx] = adjusted_score
@property
def nsfw_score(self):
return max(self.nsfw_scores.values())
@property
def special_care_score(self):
return max(self.special_care_scores.values())
@property
def special_care_nsfw_score(self):
return min(self.nsfw_score, self.special_care_score)
@property
def is_nsfw(self):
return self.nsfw_score > 0
@property
def is_special_care_nsfw(self):
return self.special_care_nsfw_score > 0
class EnhancedStableDiffusionSafetyChecker(
safety_checker_mod.StableDiffusionSafetyChecker
):
@torch.no_grad()
def forward(self, clip_input):
pooled_output = self.vision_model(clip_input)[1]
image_embeds = self.visual_projection(pooled_output)
special_cos_dist = (
safety_checker_mod.cosine_distance(image_embeds, self.special_care_embeds)
.cpu()
.numpy()
)
cos_dist = (
safety_checker_mod.cosine_distance(image_embeds, self.concept_embeds)
.cpu()
.numpy()
)
safety_results = []
batch_size = image_embeds.shape[0]
for i in range(batch_size):
safety_result = SafetyResult()
for concet_idx in range(len(special_cos_dist[0])):
concept_cos = special_cos_dist[i][concet_idx]
concept_threshold = self.special_care_embeds_weights[concet_idx].item()
safety_result.add_special_care_score(
concet_idx, concept_cos, concept_threshold
)
for concet_idx in range(len(cos_dist[0])):
concept_cos = cos_dist[i][concet_idx]
concept_threshold = self.concept_embeds_weights[concet_idx].item()
safety_result.add_nsfw_score(concet_idx, concept_cos, concept_threshold)
safety_results.append(safety_result)
return safety_results
@lru_cache()
def safety_models():
safety_model_id = "CompVis/stable-diffusion-safety-checker"
monkeypatch_safety_cosine_distance()
safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
safety_checker = EnhancedStableDiffusionSafetyChecker.from_pretrained(
safety_model_id
)
return safety_feature_extractor, safety_checker
@ -32,13 +140,28 @@ def monkeypatch_safety_cosine_distance():
safety_checker_mod.cosine_distance = cosine_distance_float32
def is_nsfw(img):
_CONCEPT_DESCRIPTIONS = []
_SPECIAL_CARE_DESCRIPTIONS = []
def create_safety_score(img, safety_mode=SafetyMode.STRICT):
safety_feature_extractor, safety_checker = safety_models()
safety_checker_input = safety_feature_extractor([img], return_tensors="pt")
clip_input = safety_checker_input.pixel_values
_, has_nsfw_concept = safety_checker(
images=[np.empty((2, 2))], clip_input=clip_input
)
safety_result = safety_checker(clip_input)[0]
if safety_result.is_special_care_nsfw:
img.paste((150, 0, 0), (0, 0, img.size[0], img.size[1]))
safety_result.is_filtered = True
logger.info(
f" ⚠️🔞️ Filtering NSFW image. nsfw score: {safety_result.nsfw_score}"
)
elif safety_mode == SafetyMode.STRICT and safety_result.is_nsfw:
img.paste((50, 0, 0), (0, 0, img.size[0], img.size[1]))
safety_result.is_filtered = True
logger.info(
f" ⚠️ Filtering NSFW image. nsfw score: {safety_result.nsfw_score}"
)
return has_nsfw_concept[0]
return safety_result

@ -191,6 +191,7 @@ class ImagineResult:
img,
prompt: ImaginePrompt,
is_nsfw,
safety_score,
upscaled_img=None,
modified_original=None,
mask_binary=None,
@ -217,6 +218,7 @@ class ImagineResult:
self.upscaled_img = upscaled_img
self.is_nsfw = is_nsfw
self.safety_score = safety_score
self.created_at = datetime.utcnow().replace(tzinfo=timezone.utc)
self.torch_backend = get_device()
self.hardware_name = get_hardware_description(get_device())

@ -146,7 +146,10 @@ def test_clip_mask_parser(mask_text, expected):
def test_describe_picture():
img = Image.open(f"{TESTS_FOLDER}/data/girl_with_a_pearl_earring.jpg")
caption = generate_caption(img)
assert caption == "a painting of a girl with a pearl ear"
assert (
caption
== "a painting of a girl with a pearl earring wearing a yellow dress and a pearl earring in her ear and a black background"
)
def test_clip_text_comparison():

@ -1,13 +1,15 @@
from PIL import Image
from imaginairy.safety import is_nsfw
from imaginairy.safety import create_safety_score
from tests import TESTS_FOLDER
def test_is_nsfw():
img = Image.open(f"{TESTS_FOLDER}/data/safety.jpg")
assert is_nsfw(img)
safety_score = create_safety_score(img)
assert safety_score.is_nsfw
img = Image.open(f"{TESTS_FOLDER}/data/girl_with_a_pearl_earring.jpg")
assert not is_nsfw(img)
safety_score = create_safety_score(img)
assert not safety_score.is_nsfw

Loading…
Cancel
Save