diff --git a/README.md b/README.md
index 944b91c..aede45c 100644
--- a/README.md
+++ b/README.md
@@ -213,6 +213,8 @@ docker run -it --gpus all -v $HOME/.cache/huggingface:/root/.cache/huggingface
 - [Example Colab](https://colab.research.google.com/drive/1rOvQNs0Cmn_yU1bKWjCOHzGVDgZkaTtO?usp=sharing)
 
 ## ChangeLog
+**3.0.0**
+- feature: improved safety filter
 
 **2.4.0**
 - 🎉 feature: prompt expansion
diff --git a/imaginairy/api.py b/imaginairy/api.py
index 2a8059f..c6f209a 100755
--- a/imaginairy/api.py
+++ b/imaginairy/api.py
@@ -24,7 +24,7 @@ from imaginairy.img_log import (
     log_latent,
 )
 from imaginairy.img_utils import pillow_fit_image_within, pillow_img_to_torch_image
-from imaginairy.safety import is_nsfw
+from imaginairy.safety import SafetyMode, create_safety_score
 from imaginairy.samplers.base import get_sampler
 from imaginairy.samplers.plms import PLMSSchedule
 from imaginairy.schema import ImaginePrompt, ImagineResult
@@ -40,16 +40,14 @@ LIB_PATH = os.path.dirname(__file__)
 
 logger = logging.getLogger(__name__)
 
 
-class SafetyMode:
-    DISABLED = "disabled"
-    CLASSIFY = "classify"
-    FILTER = "filter"
-
-
 # leave undocumented. I'd ask that no one publicize this flag. Just want a
 # slight barrier to entry. Please don't use this is any way that's gonna cause
-# the press or governments to freak out about AI...
-IMAGINAIRY_SAFETY_MODE = os.getenv("IMAGINAIRY_SAFETY_MODE", SafetyMode.FILTER)
+# the media or politicians to freak out about AI...
+IMAGINAIRY_SAFETY_MODE = os.getenv("IMAGINAIRY_SAFETY_MODE", SafetyMode.STRICT)
+if IMAGINAIRY_SAFETY_MODE in {"disabled", "classify"}:
+    IMAGINAIRY_SAFETY_MODE = SafetyMode.RELAXED
+elif IMAGINAIRY_SAFETY_MODE == "filter":
+    IMAGINAIRY_SAFETY_MODE = SafetyMode.STRICT
 
 DEFAULT_MODEL_WEIGHTS_LOCATION = (
     "https://www.googleapis.com/storage/v1/b/aai-blog-files/o/sd-v1-4.ckpt?alt=media"
 )
@@ -68,8 +66,7 @@ def load_model_from_config(
         ckpt_path = cached_path(model_weights_location)
     else:
         ckpt_path = model_weights_location
-    logger.info(f"Loading model onto {get_device()} backend...")
-    logger.debug(f"Loading model from {ckpt_path}")
+    logger.info(f"Loading model {ckpt_path} onto {get_device()} backend...")
     pl_sd = torch.load(ckpt_path, map_location="cpu")
     if "global_step" in pl_sd:
         logger.debug(f"Global Step: {pl_sd['global_step']}")
@@ -352,59 +349,61 @@ def imagine(
                     upscaled_img = None
                     rebuilt_orig_img = None
-                    is_nsfw_img = None
+
                     if add_caption:
                         caption = generate_caption(img)
                         logger.info(f"    Generated caption: {caption}")
-                    if IMAGINAIRY_SAFETY_MODE != SafetyMode.DISABLED:
-                        is_nsfw_img = is_nsfw(img)
-                        if is_nsfw_img and IMAGINAIRY_SAFETY_MODE == SafetyMode.FILTER:
-                            logger.info("    ⚠️  Filtering NSFW image")
-                            img = img.filter(ImageFilter.GaussianBlur(radius=40))
-
-                    if prompt.fix_faces:
-                        logger.info("    Fixing 😊 's in 🖼  using CodeFormer...")
-                        img = enhance_faces(img, fidelity=prompt.fix_faces_fidelity)
-                    if prompt.upscale:
-                        logger.info("    Upscaling 🖼  using real-ESRGAN...")
-                        upscaled_img = upscale_image(img)
-
-                    # put the newly generated patch back into the original, full size image
-                    if (
-                        prompt.mask_modify_original
-                        and mask_image_orig
-                        and prompt.init_image
-                    ):
-                        img_to_add_back_to_original = (
-                            upscaled_img if upscaled_img else img
-                        )
-                        img_to_add_back_to_original = (
-                            img_to_add_back_to_original.resize(
+
+                    safety_score = create_safety_score(
+                        img,
+                        safety_mode=IMAGINAIRY_SAFETY_MODE,
+                    )
+                    if not safety_score.is_filtered:
+                        if prompt.fix_faces:
+                            logger.info("    Fixing 😊 's in 🖼  using CodeFormer...")
+                            img = enhance_faces(img, fidelity=prompt.fix_faces_fidelity)
+                        if prompt.upscale:
+                            logger.info("    Upscaling 🖼  using real-ESRGAN...")
+                            upscaled_img = upscale_image(img)
+
+                        # put the newly generated patch back into the original, full size image
+                        if (
+                            prompt.mask_modify_original
+                            and mask_image_orig
+                            and prompt.init_image
+                        ):
+                            img_to_add_back_to_original = (
+                                upscaled_img if upscaled_img else img
+                            )
+                            img_to_add_back_to_original = (
+                                img_to_add_back_to_original.resize(
+                                    prompt.init_image.size,
+                                    resample=Image.Resampling.LANCZOS,
+                                )
+                            )
+
+                            mask_for_orig_size = mask_image_orig.resize(
                                 prompt.init_image.size,
                                 resample=Image.Resampling.LANCZOS,
                             )
-                        )
-
-                        mask_for_orig_size = mask_image_orig.resize(
-                            prompt.init_image.size, resample=Image.Resampling.LANCZOS
-                        )
-                        mask_for_orig_size = mask_for_orig_size.filter(
-                            ImageFilter.GaussianBlur(radius=5)
-                        )
-                        log_img(mask_for_orig_size, "mask for original image size")
+                            mask_for_orig_size = mask_for_orig_size.filter(
+                                ImageFilter.GaussianBlur(radius=5)
+                            )
+                            log_img(mask_for_orig_size, "mask for original image size")
 
-                        rebuilt_orig_img = Image.composite(
-                            prompt.init_image,
-                            img_to_add_back_to_original,
-                            mask_for_orig_size,
-                        )
-                        log_img(rebuilt_orig_img, "reconstituted original")
+                            rebuilt_orig_img = Image.composite(
+                                prompt.init_image,
+                                img_to_add_back_to_original,
+                                mask_for_orig_size,
+                            )
+                            log_img(rebuilt_orig_img, "reconstituted original")
 
                     yield ImagineResult(
                         img=img,
                         prompt=prompt,
                         upscaled_img=upscaled_img,
-                        is_nsfw=is_nsfw_img,
+                        is_nsfw=safety_score.is_nsfw,
+                        safety_score=safety_score,
                         modified_original=rebuilt_orig_img,
                         mask_binary=mask_image_orig,
                         mask_grayscale=mask_grayscale,
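Note on the `imaginairy/api.py` change: `IMAGINAIRY_SAFETY_MODE` is still read once, at module import time, and the legacy `disabled`/`classify`/`filter` values are remapped onto the new `relaxed`/`strict` modes. A minimal sketch of opting into the relaxed mode (the env-var name, the accepted values, and the import path come from the diff; everything else is illustrative):

```python
import os

# Must be set before imaginairy.api is imported, since the module reads it at import time.
# Legacy values still work: "disabled" and "classify" map to RELAXED, "filter" to STRICT.
os.environ["IMAGINAIRY_SAFETY_MODE"] = "relaxed"

from imaginairy.api import imagine  # noqa: E402  import after configuring the env var
```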
diff --git a/imaginairy/safety.py b/imaginairy/safety.py
index 1808dd9..9c3270f 100644
--- a/imaginairy/safety.py
+++ b/imaginairy/safety.py
@@ -1,20 +1,128 @@
+import logging
 from functools import lru_cache
 
-import numpy as np
 import torch
 from diffusers.pipelines.stable_diffusion import safety_checker as safety_checker_mod
-from diffusers.pipelines.stable_diffusion.safety_checker import (
-    StableDiffusionSafetyChecker,
-)
 from transformers import AutoFeatureExtractor
 
+logger = logging.getLogger(__name__)
+
+
+class SafetyMode:
+    STRICT = "strict"
+    RELAXED = "relaxed"
+
+
+class SafetyResult:
+    # increase this value to create a stronger `nsfw` filter
+    # at the cost of increasing the possibility of filtering benign images
+    _default_adjustment = 0.0
+
+    def __init__(self):
+        self.nsfw_scores = {}
+        self.special_care_scores = {}
+        self.is_filtered = False
+
+    def add_special_care_score(self, concept_idx, abs_score, threshold):
+        adjustment = self._default_adjustment
+        adjusted_score = round(abs_score - threshold + adjustment, 3)
+        try:
+            score_name = _SPECIAL_CARE_DESCRIPTIONS[concept_idx]
+        except LookupError:
+            score_name = ""
+        if adjusted_score > 0 and score_name:
+            logger.debug(
+                f"    🔞🔞 '{score_name}' abs:{abs_score:.3f} adj:{adjusted_score}"
+            )
+        self.special_care_scores[concept_idx] = adjusted_score
+
+    def add_nsfw_score(self, concept_idx, abs_score, threshold):
+        if len(self.special_care_scores) != 3:
+            raise ValueError("special care scores must be set first")
+        adjustment = self._default_adjustment
+        if self.special_care_score > 0:
+            adjustment += 0.01
+        adjusted_score = round(abs_score - threshold + adjustment, 3)
+        try:
+            score_name = _CONCEPT_DESCRIPTIONS[concept_idx]
+        except LookupError:
+            score_name = ""
+        if adjusted_score > 0 and score_name:
+            logger.debug(
+                f"    🔞 '{concept_idx}:{score_name}' abs:{abs_score:.3f} adj:{adjusted_score}"
+            )
+        self.nsfw_scores[concept_idx] = adjusted_score
+
+    @property
+    def nsfw_score(self):
+        return max(self.nsfw_scores.values())
+
+    @property
+    def special_care_score(self):
+        return max(self.special_care_scores.values())
+
+    @property
+    def special_care_nsfw_score(self):
+        return min(self.nsfw_score, self.special_care_score)
+
+    @property
+    def is_nsfw(self):
+        return self.nsfw_score > 0
+
+    @property
+    def is_special_care_nsfw(self):
+        return self.special_care_nsfw_score > 0
+
+
+class EnhancedStableDiffusionSafetyChecker(
+    safety_checker_mod.StableDiffusionSafetyChecker
+):
+    @torch.no_grad()
+    def forward(self, clip_input):
+        pooled_output = self.vision_model(clip_input)[1]
+        image_embeds = self.visual_projection(pooled_output)
+
+        special_cos_dist = (
+            safety_checker_mod.cosine_distance(image_embeds, self.special_care_embeds)
+            .cpu()
+            .numpy()
+        )
+        cos_dist = (
+            safety_checker_mod.cosine_distance(image_embeds, self.concept_embeds)
+            .cpu()
+            .numpy()
+        )
+
+        safety_results = []
+        batch_size = image_embeds.shape[0]
+        for i in range(batch_size):
+            safety_result = SafetyResult()
+
+            for concet_idx in range(len(special_cos_dist[0])):
+                concept_cos = special_cos_dist[i][concet_idx]
+                concept_threshold = self.special_care_embeds_weights[concet_idx].item()
+                safety_result.add_special_care_score(
+                    concet_idx, concept_cos, concept_threshold
+                )
+
+            for concet_idx in range(len(cos_dist[0])):
+                concept_cos = cos_dist[i][concet_idx]
+                concept_threshold = self.concept_embeds_weights[concet_idx].item()
+                safety_result.add_nsfw_score(concet_idx, concept_cos, concept_threshold)
+
+            safety_results.append(safety_result)
+
+        return safety_results
+
 
 @lru_cache()
 def safety_models():
     safety_model_id = "CompVis/stable-diffusion-safety-checker"
     monkeypatch_safety_cosine_distance()
     safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
-    safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
+    safety_checker = EnhancedStableDiffusionSafetyChecker.from_pretrained(
+        safety_model_id
+    )
     return safety_feature_extractor, safety_checker
@@ -32,13 +140,28 @@ def monkeypatch_safety_cosine_distance():
     safety_checker_mod.cosine_distance = cosine_distance_float32
 
 
-def is_nsfw(img):
+_CONCEPT_DESCRIPTIONS = []
+_SPECIAL_CARE_DESCRIPTIONS = []
+
+
+def create_safety_score(img, safety_mode=SafetyMode.STRICT):
     safety_feature_extractor, safety_checker = safety_models()
     safety_checker_input = safety_feature_extractor([img], return_tensors="pt")
     clip_input = safety_checker_input.pixel_values
-    _, has_nsfw_concept = safety_checker(
-        images=[np.empty((2, 2))], clip_input=clip_input
-    )
+    safety_result = safety_checker(clip_input)[0]
+
+    if safety_result.is_special_care_nsfw:
+        img.paste((150, 0, 0), (0, 0, img.size[0], img.size[1]))
+        safety_result.is_filtered = True
+        logger.info(
+            f"    ⚠️🔞️  Filtering NSFW image. nsfw score: {safety_result.nsfw_score}"
+        )
+    elif safety_mode == SafetyMode.STRICT and safety_result.is_nsfw:
+        img.paste((50, 0, 0), (0, 0, img.size[0], img.size[1]))
+        safety_result.is_filtered = True
+        logger.info(
+            f"    ⚠️  Filtering NSFW image. nsfw score: {safety_result.nsfw_score}"
+        )
 
-    return has_nsfw_concept[0]
+    return safety_result
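The new `imaginairy/safety.py` exposes the scoring directly. Each score is an adjusted cosine similarity, `abs_score - threshold (+ adjustment)`, so any value above zero counts as a match, and `create_safety_score` paints over the image in place when it decides to filter. A rough sketch of calling it on a PIL image (the file path is hypothetical; the function, class, and property names come from the diff):

```python
from PIL import Image

from imaginairy.safety import SafetyMode, create_safety_score

img = Image.open("generated.png")  # hypothetical path to a generated image
score = create_safety_score(img, safety_mode=SafetyMode.RELAXED)

print(score.is_nsfw)               # some NSFW concept scored above its threshold
print(score.is_special_care_nsfw)  # NSFW and a special-care concept both above threshold
print(score.is_filtered)           # True means `img` was painted over in place
```

In relaxed mode only special-care matches are filtered; strict mode (the default) also filters plain NSFW matches.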
diff --git a/imaginairy/schema.py b/imaginairy/schema.py
index b24ec6f..ff103f6 100644
--- a/imaginairy/schema.py
+++ b/imaginairy/schema.py
@@ -191,6 +191,7 @@ class ImagineResult:
         img,
         prompt: ImaginePrompt,
         is_nsfw,
+        safety_score,
         upscaled_img=None,
         modified_original=None,
         mask_binary=None,
@@ -217,6 +218,7 @@ class ImagineResult:
         self.upscaled_img = upscaled_img
 
         self.is_nsfw = is_nsfw
+        self.safety_score = safety_score
         self.created_at = datetime.utcnow().replace(tzinfo=timezone.utc)
         self.torch_backend = get_device()
         self.hardware_name = get_hardware_description(get_device())
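With `safety_score` now stored on `ImagineResult`, callers can inspect it after generation. A hedged sketch of consuming it from `imagine()` (the `ImaginePrompt` construction and the iteration pattern are assumed from the imports in `imaginairy/api.py`, not shown in this diff):

```python
from imaginairy.api import imagine
from imaginairy.schema import ImaginePrompt

# Assumed calling convention; only .safety_score and .is_nsfw come from this diff.
for result in imagine([ImaginePrompt("a photo of a sunflower field")]):
    if result.safety_score.is_filtered:
        print(f"filtered, nsfw score: {result.safety_score.nsfw_score}")
    elif result.is_nsfw:
        print("classified NSFW but not filtered (relaxed mode)")
```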
diff --git a/tests/test_enhancers.py b/tests/test_enhancers.py
index 9b5d66e..c0bd815 100644
--- a/tests/test_enhancers.py
+++ b/tests/test_enhancers.py
@@ -146,7 +146,10 @@ def test_clip_mask_parser(mask_text, expected):
 def test_describe_picture():
     img = Image.open(f"{TESTS_FOLDER}/data/girl_with_a_pearl_earring.jpg")
     caption = generate_caption(img)
-    assert caption == "a painting of a girl with a pearl ear"
+    assert (
+        caption
+        == "a painting of a girl with a pearl earring wearing a yellow dress and a pearl earring in her ear and a black background"
+    )
 
 
 def test_clip_text_comparison():
diff --git a/tests/test_safety.py b/tests/test_safety.py
index 6de038f..d421e06 100644
--- a/tests/test_safety.py
+++ b/tests/test_safety.py
@@ -1,13 +1,15 @@
 from PIL import Image
 
-from imaginairy.safety import is_nsfw
+from imaginairy.safety import create_safety_score
 from tests import TESTS_FOLDER
 
 
 def test_is_nsfw():
     img = Image.open(f"{TESTS_FOLDER}/data/safety.jpg")
-    assert is_nsfw(img)
+    safety_score = create_safety_score(img)
+    assert safety_score.is_nsfw
 
     img = Image.open(f"{TESTS_FOLDER}/data/girl_with_a_pearl_earring.jpg")
-    assert not is_nsfw(img)
+    safety_score = create_safety_score(img)
+    assert not safety_score.is_nsfw
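The updated test only covers classification. A possible follow-up sketch (not part of this PR) that would also pin down the filtering behaviour, relying on the facts that `create_safety_score` defaults to strict mode and that `safety.jpg` is already asserted to be NSFW above:

```python
from PIL import Image

from imaginairy.safety import SafetyMode, create_safety_score
from tests import TESTS_FOLDER


def test_nsfw_image_is_filtered_in_strict_mode():
    img = Image.open(f"{TESTS_FOLDER}/data/safety.jpg")
    safety_score = create_safety_score(img, safety_mode=SafetyMode.STRICT)
    # strict mode filters any NSFW match, so the image should be painted over in place
    assert safety_score.is_filtered
```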