feature: improved safety filter

- provides more informative logs
- provides a detailed safety score object
- adds non-bypassable filter for extreme content
This commit is contained in:
Bryce 2022-10-10 01:22:11 -07:00 committed by Bryce Drennan
parent 0db5c329bb
commit 6a80759016
6 changed files with 195 additions and 64 deletions

View File

@ -213,6 +213,8 @@ docker run -it --gpus all -v $HOME/.cache/huggingface:/root/.cache/huggingface -
[Example Colab](https://colab.research.google.com/drive/1rOvQNs0Cmn_yU1bKWjCOHzGVDgZkaTtO?usp=sharing) [Example Colab](https://colab.research.google.com/drive/1rOvQNs0Cmn_yU1bKWjCOHzGVDgZkaTtO?usp=sharing)
## ChangeLog ## ChangeLog
**3.0.0**
- feature: improved safety filter
**2.4.0** **2.4.0**
- 🎉 feature: prompt expansion - 🎉 feature: prompt expansion

View File

@ -24,7 +24,7 @@ from imaginairy.img_log import (
log_latent, log_latent,
) )
from imaginairy.img_utils import pillow_fit_image_within, pillow_img_to_torch_image from imaginairy.img_utils import pillow_fit_image_within, pillow_img_to_torch_image
from imaginairy.safety import is_nsfw from imaginairy.safety import SafetyMode, create_safety_score
from imaginairy.samplers.base import get_sampler from imaginairy.samplers.base import get_sampler
from imaginairy.samplers.plms import PLMSSchedule from imaginairy.samplers.plms import PLMSSchedule
from imaginairy.schema import ImaginePrompt, ImagineResult from imaginairy.schema import ImaginePrompt, ImagineResult
@ -40,16 +40,14 @@ LIB_PATH = os.path.dirname(__file__)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class SafetyMode:
DISABLED = "disabled"
CLASSIFY = "classify"
FILTER = "filter"
# leave undocumented. I'd ask that no one publicize this flag. Just want a # leave undocumented. I'd ask that no one publicize this flag. Just want a
# slight barrier to entry. Please don't use this is any way that's gonna cause # slight barrier to entry. Please don't use this is any way that's gonna cause
# the press or governments to freak out about AI... # the media or politicians to freak out about AI...
IMAGINAIRY_SAFETY_MODE = os.getenv("IMAGINAIRY_SAFETY_MODE", SafetyMode.FILTER) IMAGINAIRY_SAFETY_MODE = os.getenv("IMAGINAIRY_SAFETY_MODE", SafetyMode.STRICT)
if IMAGINAIRY_SAFETY_MODE in {"disabled", "classify"}:
IMAGINAIRY_SAFETY_MODE = SafetyMode.RELAXED
elif IMAGINAIRY_SAFETY_MODE == "filter":
IMAGINAIRY_SAFETY_MODE = SafetyMode.STRICT
DEFAULT_MODEL_WEIGHTS_LOCATION = ( DEFAULT_MODEL_WEIGHTS_LOCATION = (
"https://www.googleapis.com/storage/v1/b/aai-blog-files/o/sd-v1-4.ckpt?alt=media" "https://www.googleapis.com/storage/v1/b/aai-blog-files/o/sd-v1-4.ckpt?alt=media"
@ -68,8 +66,7 @@ def load_model_from_config(
ckpt_path = cached_path(model_weights_location) ckpt_path = cached_path(model_weights_location)
else: else:
ckpt_path = model_weights_location ckpt_path = model_weights_location
logger.info(f"Loading model onto {get_device()} backend...") logger.info(f"Loading model {ckpt_path} onto {get_device()} backend...")
logger.debug(f"Loading model from {ckpt_path}")
pl_sd = torch.load(ckpt_path, map_location="cpu") pl_sd = torch.load(ckpt_path, map_location="cpu")
if "global_step" in pl_sd: if "global_step" in pl_sd:
logger.debug(f"Global Step: {pl_sd['global_step']}") logger.debug(f"Global Step: {pl_sd['global_step']}")
@ -352,16 +349,16 @@ def imagine(
upscaled_img = None upscaled_img = None
rebuilt_orig_img = None rebuilt_orig_img = None
is_nsfw_img = None
if add_caption: if add_caption:
caption = generate_caption(img) caption = generate_caption(img)
logger.info(f" Generated caption: {caption}") logger.info(f" Generated caption: {caption}")
if IMAGINAIRY_SAFETY_MODE != SafetyMode.DISABLED:
is_nsfw_img = is_nsfw(img)
if is_nsfw_img and IMAGINAIRY_SAFETY_MODE == SafetyMode.FILTER:
logger.info(" ⚠️ Filtering NSFW image")
img = img.filter(ImageFilter.GaussianBlur(radius=40))
safety_score = create_safety_score(
img,
safety_mode=IMAGINAIRY_SAFETY_MODE,
)
if not safety_score.is_filtered:
if prompt.fix_faces: if prompt.fix_faces:
logger.info(" Fixing 😊 's in 🖼 using CodeFormer...") logger.info(" Fixing 😊 's in 🖼 using CodeFormer...")
img = enhance_faces(img, fidelity=prompt.fix_faces_fidelity) img = enhance_faces(img, fidelity=prompt.fix_faces_fidelity)
@ -386,7 +383,8 @@ def imagine(
) )
mask_for_orig_size = mask_image_orig.resize( mask_for_orig_size = mask_image_orig.resize(
prompt.init_image.size, resample=Image.Resampling.LANCZOS prompt.init_image.size,
resample=Image.Resampling.LANCZOS,
) )
mask_for_orig_size = mask_for_orig_size.filter( mask_for_orig_size = mask_for_orig_size.filter(
ImageFilter.GaussianBlur(radius=5) ImageFilter.GaussianBlur(radius=5)
@ -404,7 +402,8 @@ def imagine(
img=img, img=img,
prompt=prompt, prompt=prompt,
upscaled_img=upscaled_img, upscaled_img=upscaled_img,
is_nsfw=is_nsfw_img, is_nsfw=safety_score.is_nsfw,
safety_score=safety_score,
modified_original=rebuilt_orig_img, modified_original=rebuilt_orig_img,
mask_binary=mask_image_orig, mask_binary=mask_image_orig,
mask_grayscale=mask_grayscale, mask_grayscale=mask_grayscale,

View File

@ -1,20 +1,128 @@
import logging
from functools import lru_cache from functools import lru_cache
import numpy as np
import torch import torch
from diffusers.pipelines.stable_diffusion import safety_checker as safety_checker_mod from diffusers.pipelines.stable_diffusion import safety_checker as safety_checker_mod
from diffusers.pipelines.stable_diffusion.safety_checker import (
StableDiffusionSafetyChecker,
)
from transformers import AutoFeatureExtractor from transformers import AutoFeatureExtractor
logger = logging.getLogger(__name__)
class SafetyMode:
STRICT = "strict"
RELAXED = "relaxed"
class SafetyResult:
# increase this value to create a stronger `nfsw` filter
# at the cost of increasing the possibility of filtering benign images
_default_adjustment = 0.0
def __init__(self):
self.nsfw_scores = {}
self.special_care_scores = {}
self.is_filtered = False
def add_special_care_score(self, concept_idx, abs_score, threshold):
adjustment = self._default_adjustment
adjusted_score = round(abs_score - threshold + adjustment, 3)
try:
score_name = _SPECIAL_CARE_DESCRIPTIONS[concept_idx]
except LookupError:
score_name = ""
if adjusted_score > 0 and score_name:
logger.debug(
f" 🔞🔞 '{score_name}' abs:{abs_score:.3f} adj:{adjusted_score}"
)
self.special_care_scores[concept_idx] = adjusted_score
def add_nsfw_score(self, concept_idx, abs_score, threshold):
if len(self.special_care_scores) != 3:
raise ValueError("special care scores must be set first")
adjustment = self._default_adjustment
if self.special_care_score > 0:
adjustment += 0.01
adjusted_score = round(abs_score - threshold + adjustment, 3)
try:
score_name = _CONCEPT_DESCRIPTIONS[concept_idx]
except LookupError:
score_name = ""
if adjusted_score > 0 and score_name:
logger.debug(
f" 🔞 '{concept_idx}:{score_name}' abs:{abs_score:.3f} adj:{adjusted_score}"
)
self.nsfw_scores[concept_idx] = adjusted_score
@property
def nsfw_score(self):
return max(self.nsfw_scores.values())
@property
def special_care_score(self):
return max(self.special_care_scores.values())
@property
def special_care_nsfw_score(self):
return min(self.nsfw_score, self.special_care_score)
@property
def is_nsfw(self):
return self.nsfw_score > 0
@property
def is_special_care_nsfw(self):
return self.special_care_nsfw_score > 0
class EnhancedStableDiffusionSafetyChecker(
safety_checker_mod.StableDiffusionSafetyChecker
):
@torch.no_grad()
def forward(self, clip_input):
pooled_output = self.vision_model(clip_input)[1]
image_embeds = self.visual_projection(pooled_output)
special_cos_dist = (
safety_checker_mod.cosine_distance(image_embeds, self.special_care_embeds)
.cpu()
.numpy()
)
cos_dist = (
safety_checker_mod.cosine_distance(image_embeds, self.concept_embeds)
.cpu()
.numpy()
)
safety_results = []
batch_size = image_embeds.shape[0]
for i in range(batch_size):
safety_result = SafetyResult()
for concet_idx in range(len(special_cos_dist[0])):
concept_cos = special_cos_dist[i][concet_idx]
concept_threshold = self.special_care_embeds_weights[concet_idx].item()
safety_result.add_special_care_score(
concet_idx, concept_cos, concept_threshold
)
for concet_idx in range(len(cos_dist[0])):
concept_cos = cos_dist[i][concet_idx]
concept_threshold = self.concept_embeds_weights[concet_idx].item()
safety_result.add_nsfw_score(concet_idx, concept_cos, concept_threshold)
safety_results.append(safety_result)
return safety_results
@lru_cache() @lru_cache()
def safety_models(): def safety_models():
safety_model_id = "CompVis/stable-diffusion-safety-checker" safety_model_id = "CompVis/stable-diffusion-safety-checker"
monkeypatch_safety_cosine_distance() monkeypatch_safety_cosine_distance()
safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id) safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id) safety_checker = EnhancedStableDiffusionSafetyChecker.from_pretrained(
safety_model_id
)
return safety_feature_extractor, safety_checker return safety_feature_extractor, safety_checker
@ -32,13 +140,28 @@ def monkeypatch_safety_cosine_distance():
safety_checker_mod.cosine_distance = cosine_distance_float32 safety_checker_mod.cosine_distance = cosine_distance_float32
def is_nsfw(img): _CONCEPT_DESCRIPTIONS = []
_SPECIAL_CARE_DESCRIPTIONS = []
def create_safety_score(img, safety_mode=SafetyMode.STRICT):
safety_feature_extractor, safety_checker = safety_models() safety_feature_extractor, safety_checker = safety_models()
safety_checker_input = safety_feature_extractor([img], return_tensors="pt") safety_checker_input = safety_feature_extractor([img], return_tensors="pt")
clip_input = safety_checker_input.pixel_values clip_input = safety_checker_input.pixel_values
_, has_nsfw_concept = safety_checker( safety_result = safety_checker(clip_input)[0]
images=[np.empty((2, 2))], clip_input=clip_input
if safety_result.is_special_care_nsfw:
img.paste((150, 0, 0), (0, 0, img.size[0], img.size[1]))
safety_result.is_filtered = True
logger.info(
f" ⚠️🔞️ Filtering NSFW image. nsfw score: {safety_result.nsfw_score}"
)
elif safety_mode == SafetyMode.STRICT and safety_result.is_nsfw:
img.paste((50, 0, 0), (0, 0, img.size[0], img.size[1]))
safety_result.is_filtered = True
logger.info(
f" ⚠️ Filtering NSFW image. nsfw score: {safety_result.nsfw_score}"
) )
return has_nsfw_concept[0] return safety_result

View File

@ -191,6 +191,7 @@ class ImagineResult:
img, img,
prompt: ImaginePrompt, prompt: ImaginePrompt,
is_nsfw, is_nsfw,
safety_score,
upscaled_img=None, upscaled_img=None,
modified_original=None, modified_original=None,
mask_binary=None, mask_binary=None,
@ -217,6 +218,7 @@ class ImagineResult:
self.upscaled_img = upscaled_img self.upscaled_img = upscaled_img
self.is_nsfw = is_nsfw self.is_nsfw = is_nsfw
self.safety_score = safety_score
self.created_at = datetime.utcnow().replace(tzinfo=timezone.utc) self.created_at = datetime.utcnow().replace(tzinfo=timezone.utc)
self.torch_backend = get_device() self.torch_backend = get_device()
self.hardware_name = get_hardware_description(get_device()) self.hardware_name = get_hardware_description(get_device())

View File

@ -146,7 +146,10 @@ def test_clip_mask_parser(mask_text, expected):
def test_describe_picture(): def test_describe_picture():
img = Image.open(f"{TESTS_FOLDER}/data/girl_with_a_pearl_earring.jpg") img = Image.open(f"{TESTS_FOLDER}/data/girl_with_a_pearl_earring.jpg")
caption = generate_caption(img) caption = generate_caption(img)
assert caption == "a painting of a girl with a pearl ear" assert (
caption
== "a painting of a girl with a pearl earring wearing a yellow dress and a pearl earring in her ear and a black background"
)
def test_clip_text_comparison(): def test_clip_text_comparison():

View File

@ -1,13 +1,15 @@
from PIL import Image from PIL import Image
from imaginairy.safety import is_nsfw from imaginairy.safety import create_safety_score
from tests import TESTS_FOLDER from tests import TESTS_FOLDER
def test_is_nsfw(): def test_is_nsfw():
img = Image.open(f"{TESTS_FOLDER}/data/safety.jpg") img = Image.open(f"{TESTS_FOLDER}/data/safety.jpg")
assert is_nsfw(img) safety_score = create_safety_score(img)
assert safety_score.is_nsfw
img = Image.open(f"{TESTS_FOLDER}/data/girl_with_a_pearl_earring.jpg") img = Image.open(f"{TESTS_FOLDER}/data/girl_with_a_pearl_earring.jpg")
assert not is_nsfw(img) safety_score = create_safety_score(img)
assert not safety_score.is_nsfw