fix: don't report a safety issue when a black image is generated

This commit is contained in:
Bryce 2023-01-16 14:45:17 -08:00 committed by Bryce Drennan
parent 2dd99c183c
commit 24e10f9e5f
6 changed files with 51 additions and 0 deletions

View File

@ -0,0 +1,17 @@
import cv2
from imaginairy.img_utils import pillow_img_to_opencv_img
def calculate_blurriness_level(img):
img = pillow_img_to_opencv_img(img)
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
sharpness = cv2.Laplacian(gray, cv2.CV_64F).var()
sharpness = max(sharpness, 0.000001)
bluriness = 1 / sharpness
return bluriness
def is_blurry(img, threshold=0.91):
return calculate_blurriness_level(img) > threshold

View File

@ -5,6 +5,8 @@ import torch
from diffusers.pipelines.stable_diffusion import safety_checker as safety_checker_mod
from transformers import AutoFeatureExtractor
from imaginairy.enhancers.blur_detect import is_blurry
logger = logging.getLogger(__name__)
@ -145,6 +147,14 @@ _SPECIAL_CARE_DESCRIPTIONS = []
def create_safety_score(img, safety_mode=SafetyMode.STRICT):
if is_blurry(img):
sr = SafetyResult()
sr.add_special_care_score(0, 0, 1)
sr.add_special_care_score(1, 0, 1)
sr.add_special_care_score(2, 0, 1)
sr.add_nsfw_score(0, 0, 1)
return sr
safety_feature_extractor, safety_checker = safety_models()
safety_checker_input = safety_feature_extractor([img], return_tensors="pt")
clip_input = safety_checker_input.pixel_values

BIN
tests/data/black_square.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 6.4 KiB

BIN
tests/data/latent_noise.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 163 KiB

View File

@ -0,0 +1,19 @@
import pytest
from PIL import Image
from imaginairy.enhancers.blur_detect import is_blurry
from tests import TESTS_FOLDER
blur_params = [
(f"{TESTS_FOLDER}/data/black_square.jpg", True),
(f"{TESTS_FOLDER}/data/safety.jpg", False),
(f"{TESTS_FOLDER}/data/latent_noise.jpg", False),
(f"{TESTS_FOLDER}/data/girl_with_a_pearl_earring.jpg", False),
]
@pytest.mark.parametrize("img_path,expected", blur_params)
def test_calculate_blurriness_level(img_path, expected):
img = Image.open(img_path)
assert is_blurry(img) == expected

View File

@ -16,3 +16,8 @@ def test_is_nsfw():
img = Image.open(f"{TESTS_FOLDER}/data/girl_with_a_pearl_earring.jpg")
safety_score = create_safety_score(img)
assert not safety_score.is_nsfw
img = Image.open(f"{TESTS_FOLDER}/data/black_square.jpg")
safety_score = create_safety_score(img)
assert not safety_score.is_nsfw