Mirror of https://github.com/brycedrennan/imaginAIry (synced 2024-11-05 12:00:15 +00:00)
feature: improved safety filter

- provides more informative logs
- provides a detailed safety score object
- adds non-bypassable filter for extreme content
Commit 6a80759016 (parent 0db5c329bb)
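Before the diff itself, a minimal sketch of how the new safety API introduced by this commit might be called directly. The import path, the `SafetyMode` values, and the `SafetyResult` attributes are taken from the hunks below; the image path is borrowed from the updated test data, and calling `create_safety_score` as a standalone snippet (outside `imagine()`) is my own assumption.

```python
from PIL import Image

from imaginairy.safety import SafetyMode, create_safety_score

# Image path borrowed from the repo's test data; any PIL image should work.
img = Image.open("tests/data/safety.jpg")

# In strict mode (the default) any NSFW-flagged image is painted over in place
# and marked as filtered; relaxed mode only paints over the extreme-content
# ("special care") case and otherwise just classifies.
safety_score = create_safety_score(img, safety_mode=SafetyMode.STRICT)

print(safety_score.is_nsfw)               # classifier verdict
print(safety_score.is_special_care_nsfw)  # tripped an extreme-content concept
print(safety_score.is_filtered)           # True if the image was painted over
print(safety_score.nsfw_score)            # max adjusted concept score
```

Inside `imagine()` the mode comes from the `IMAGINAIRY_SAFETY_MODE` environment variable; as the diff below shows, "disabled" and "classify" map to relaxed, "filter" maps to strict.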
@@ -213,6 +213,8 @@ docker run -it --gpus all -v $HOME/.cache/huggingface:/root/.cache/huggingface -
 [Example Colab](https://colab.research.google.com/drive/1rOvQNs0Cmn_yU1bKWjCOHzGVDgZkaTtO?usp=sharing)

 ## ChangeLog

+**3.0.0**
+
+- feature: improved safety filter
+
 **2.4.0**

 - 🎉 feature: prompt expansion
@@ -24,7 +24,7 @@ from imaginairy.img_log import (
     log_latent,
 )
 from imaginairy.img_utils import pillow_fit_image_within, pillow_img_to_torch_image
-from imaginairy.safety import is_nsfw
+from imaginairy.safety import SafetyMode, create_safety_score
 from imaginairy.samplers.base import get_sampler
 from imaginairy.samplers.plms import PLMSSchedule
 from imaginairy.schema import ImaginePrompt, ImagineResult
@@ -40,16 +40,14 @@ LIB_PATH = os.path.dirname(__file__)
 logger = logging.getLogger(__name__)


-class SafetyMode:
-    DISABLED = "disabled"
-    CLASSIFY = "classify"
-    FILTER = "filter"
-
-
 # leave undocumented. I'd ask that no one publicize this flag. Just want a
 # slight barrier to entry. Please don't use this in any way that's gonna cause
-# the press or governments to freak out about AI...
-IMAGINAIRY_SAFETY_MODE = os.getenv("IMAGINAIRY_SAFETY_MODE", SafetyMode.FILTER)
+# the media or politicians to freak out about AI...
+IMAGINAIRY_SAFETY_MODE = os.getenv("IMAGINAIRY_SAFETY_MODE", SafetyMode.STRICT)
+if IMAGINAIRY_SAFETY_MODE in {"disabled", "classify"}:
+    IMAGINAIRY_SAFETY_MODE = SafetyMode.RELAXED
+elif IMAGINAIRY_SAFETY_MODE == "filter":
+    IMAGINAIRY_SAFETY_MODE = SafetyMode.STRICT

 DEFAULT_MODEL_WEIGHTS_LOCATION = (
     "https://www.googleapis.com/storage/v1/b/aai-blog-files/o/sd-v1-4.ckpt?alt=media"
@@ -68,8 +66,7 @@ def load_model_from_config(
         ckpt_path = cached_path(model_weights_location)
     else:
         ckpt_path = model_weights_location
-    logger.info(f"Loading model onto {get_device()} backend...")
-    logger.debug(f"Loading model from {ckpt_path}")
+    logger.info(f"Loading model {ckpt_path} onto {get_device()} backend...")
     pl_sd = torch.load(ckpt_path, map_location="cpu")
     if "global_step" in pl_sd:
         logger.debug(f"Global Step: {pl_sd['global_step']}")
@@ -352,16 +349,16 @@ def imagine(

                 upscaled_img = None
                 rebuilt_orig_img = None
-                is_nsfw_img = None
+
                 if add_caption:
                     caption = generate_caption(img)
                     logger.info(f" Generated caption: {caption}")
-                if IMAGINAIRY_SAFETY_MODE != SafetyMode.DISABLED:
-                    is_nsfw_img = is_nsfw(img)
-                    if is_nsfw_img and IMAGINAIRY_SAFETY_MODE == SafetyMode.FILTER:
-                        logger.info(" ⚠️ Filtering NSFW image")
-                        img = img.filter(ImageFilter.GaussianBlur(radius=40))

-                if prompt.fix_faces:
-                    logger.info(" Fixing 😊 's in 🖼 using CodeFormer...")
-                    img = enhance_faces(img, fidelity=prompt.fix_faces_fidelity)
+                safety_score = create_safety_score(
+                    img,
+                    safety_mode=IMAGINAIRY_SAFETY_MODE,
+                )
+                if not safety_score.is_filtered:
+                    if prompt.fix_faces:
+                        logger.info(" Fixing 😊 's in 🖼 using CodeFormer...")
+                        img = enhance_faces(img, fidelity=prompt.fix_faces_fidelity)
@@ -386,7 +383,8 @@ def imagine(
                     )

                     mask_for_orig_size = mask_image_orig.resize(
-                        prompt.init_image.size, resample=Image.Resampling.LANCZOS
+                        prompt.init_image.size,
+                        resample=Image.Resampling.LANCZOS,
                     )
                     mask_for_orig_size = mask_for_orig_size.filter(
                         ImageFilter.GaussianBlur(radius=5)
@@ -404,7 +402,8 @@ def imagine(
                     img=img,
                     prompt=prompt,
                     upscaled_img=upscaled_img,
-                    is_nsfw=is_nsfw_img,
+                    is_nsfw=safety_score.is_nsfw,
+                    safety_score=safety_score,
                     modified_original=rebuilt_orig_img,
                     mask_binary=mask_image_orig,
                     mask_grayscale=mask_grayscale,
@@ -1,20 +1,128 @@
+import logging
 from functools import lru_cache

-import numpy as np
 import torch
 from diffusers.pipelines.stable_diffusion import safety_checker as safety_checker_mod
-from diffusers.pipelines.stable_diffusion.safety_checker import (
-    StableDiffusionSafetyChecker,
-)
 from transformers import AutoFeatureExtractor

+logger = logging.getLogger(__name__)
+
+
+class SafetyMode:
+    STRICT = "strict"
+    RELAXED = "relaxed"
+
+
+class SafetyResult:
+    # increase this value to create a stronger `nsfw` filter
+    # at the cost of increasing the possibility of filtering benign images
+    _default_adjustment = 0.0
+
+    def __init__(self):
+        self.nsfw_scores = {}
+        self.special_care_scores = {}
+        self.is_filtered = False
+
+    def add_special_care_score(self, concept_idx, abs_score, threshold):
+        adjustment = self._default_adjustment
+        adjusted_score = round(abs_score - threshold + adjustment, 3)
+        try:
+            score_name = _SPECIAL_CARE_DESCRIPTIONS[concept_idx]
+        except LookupError:
+            score_name = ""
+        if adjusted_score > 0 and score_name:
+            logger.debug(
+                f" 🔞🔞 '{score_name}' abs:{abs_score:.3f} adj:{adjusted_score}"
+            )
+        self.special_care_scores[concept_idx] = adjusted_score
+
+    def add_nsfw_score(self, concept_idx, abs_score, threshold):
+        if len(self.special_care_scores) != 3:
+            raise ValueError("special care scores must be set first")
+        adjustment = self._default_adjustment
+        if self.special_care_score > 0:
+            adjustment += 0.01
+        adjusted_score = round(abs_score - threshold + adjustment, 3)
+        try:
+            score_name = _CONCEPT_DESCRIPTIONS[concept_idx]
+        except LookupError:
+            score_name = ""
+        if adjusted_score > 0 and score_name:
+            logger.debug(
+                f" 🔞 '{concept_idx}:{score_name}' abs:{abs_score:.3f} adj:{adjusted_score}"
+            )
+        self.nsfw_scores[concept_idx] = adjusted_score
+
+    @property
+    def nsfw_score(self):
+        return max(self.nsfw_scores.values())
+
+    @property
+    def special_care_score(self):
+        return max(self.special_care_scores.values())
+
+    @property
+    def special_care_nsfw_score(self):
+        return min(self.nsfw_score, self.special_care_score)
+
+    @property
+    def is_nsfw(self):
+        return self.nsfw_score > 0
+
+    @property
+    def is_special_care_nsfw(self):
+        return self.special_care_nsfw_score > 0
+
+
+class EnhancedStableDiffusionSafetyChecker(
+    safety_checker_mod.StableDiffusionSafetyChecker
+):
+    @torch.no_grad()
+    def forward(self, clip_input):
+        pooled_output = self.vision_model(clip_input)[1]
+        image_embeds = self.visual_projection(pooled_output)
+
+        special_cos_dist = (
+            safety_checker_mod.cosine_distance(image_embeds, self.special_care_embeds)
+            .cpu()
+            .numpy()
+        )
+        cos_dist = (
+            safety_checker_mod.cosine_distance(image_embeds, self.concept_embeds)
+            .cpu()
+            .numpy()
+        )
+
+        safety_results = []
+        batch_size = image_embeds.shape[0]
+        for i in range(batch_size):
+            safety_result = SafetyResult()
+
+            for concept_idx in range(len(special_cos_dist[0])):
+                concept_cos = special_cos_dist[i][concept_idx]
+                concept_threshold = self.special_care_embeds_weights[concept_idx].item()
+                safety_result.add_special_care_score(
+                    concept_idx, concept_cos, concept_threshold
+                )
+
+            for concept_idx in range(len(cos_dist[0])):
+                concept_cos = cos_dist[i][concept_idx]
+                concept_threshold = self.concept_embeds_weights[concept_idx].item()
+                safety_result.add_nsfw_score(concept_idx, concept_cos, concept_threshold)
+
+            safety_results.append(safety_result)
+
+        return safety_results
+
+
 @lru_cache()
 def safety_models():
     safety_model_id = "CompVis/stable-diffusion-safety-checker"
     monkeypatch_safety_cosine_distance()
     safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
-    safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
+    safety_checker = EnhancedStableDiffusionSafetyChecker.from_pretrained(
+        safety_model_id
+    )
     return safety_feature_extractor, safety_checker
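To make the scoring rule in `SafetyResult` above concrete, here is a small worked example of the adjustment arithmetic. The numbers are invented for illustration and are not real model outputs.

```python
# Hypothetical values, for illustration only.
abs_score = 0.213   # cosine score between the image embedding and one NSFW concept
threshold = 0.200   # learned per-concept threshold (concept_embeds_weights[i])
adjustment = 0.0    # SafetyResult._default_adjustment; raise it for a stricter filter

# If any special-care concept already scored above 0, NSFW scoring gets a
# +0.01 penalty, so images near an extreme-content concept are more likely
# to be flagged.
special_care_tripped = True
if special_care_tripped:
    adjustment += 0.01

adjusted_score = round(abs_score - threshold + adjustment, 3)  # 0.023
print(adjusted_score > 0)  # True -> this concept counts toward is_nsfw
```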
@@ -32,13 +140,28 @@ def monkeypatch_safety_cosine_distance():
     safety_checker_mod.cosine_distance = cosine_distance_float32


-def is_nsfw(img):
+_CONCEPT_DESCRIPTIONS = []
+_SPECIAL_CARE_DESCRIPTIONS = []
+
+
+def create_safety_score(img, safety_mode=SafetyMode.STRICT):
     safety_feature_extractor, safety_checker = safety_models()
     safety_checker_input = safety_feature_extractor([img], return_tensors="pt")
     clip_input = safety_checker_input.pixel_values

-    _, has_nsfw_concept = safety_checker(
-        images=[np.empty((2, 2))], clip_input=clip_input
-    )
+    safety_result = safety_checker(clip_input)[0]
+
+    if safety_result.is_special_care_nsfw:
+        img.paste((150, 0, 0), (0, 0, img.size[0], img.size[1]))
+        safety_result.is_filtered = True
+        logger.info(
+            f" ⚠️🔞️ Filtering NSFW image. nsfw score: {safety_result.nsfw_score}"
+        )
+    elif safety_mode == SafetyMode.STRICT and safety_result.is_nsfw:
+        img.paste((50, 0, 0), (0, 0, img.size[0], img.size[1]))
+        safety_result.is_filtered = True
+        logger.info(
+            f" ⚠️ Filtering NSFW image. nsfw score: {safety_result.nsfw_score}"
+        )

-    return has_nsfw_concept[0]
+    return safety_result
@@ -191,6 +191,7 @@ class ImagineResult:
         img,
         prompt: ImaginePrompt,
         is_nsfw,
+        safety_score,
         upscaled_img=None,
         modified_original=None,
         mask_binary=None,
@@ -217,6 +218,7 @@ class ImagineResult:
         self.upscaled_img = upscaled_img

         self.is_nsfw = is_nsfw
+        self.safety_score = safety_score
         self.created_at = datetime.utcnow().replace(tzinfo=timezone.utc)
         self.torch_backend = get_device()
         self.hardware_name = get_hardware_description(get_device())
@@ -146,7 +146,10 @@ def test_clip_mask_parser(mask_text, expected):
 def test_describe_picture():
     img = Image.open(f"{TESTS_FOLDER}/data/girl_with_a_pearl_earring.jpg")
     caption = generate_caption(img)
-    assert caption == "a painting of a girl with a pearl ear"
+    assert (
+        caption
+        == "a painting of a girl with a pearl earring wearing a yellow dress and a pearl earring in her ear and a black background"
+    )


 def test_clip_text_comparison():
@@ -1,13 +1,15 @@
 from PIL import Image

-from imaginairy.safety import is_nsfw
+from imaginairy.safety import create_safety_score
 from tests import TESTS_FOLDER


 def test_is_nsfw():
     img = Image.open(f"{TESTS_FOLDER}/data/safety.jpg")

-    assert is_nsfw(img)
+    safety_score = create_safety_score(img)
+    assert safety_score.is_nsfw

     img = Image.open(f"{TESTS_FOLDER}/data/girl_with_a_pearl_earring.jpg")
-    assert not is_nsfw(img)
+    safety_score = create_safety_score(img)
+    assert not safety_score.is_nsfw