mirror of
https://github.com/brycedrennan/imaginAIry
synced 2024-10-31 03:20:40 +00:00
fix: don't report a safety issue when a black image is generated
This commit is contained in:
parent
2dd99c183c
commit
24e10f9e5f
17
imaginairy/enhancers/blur_detect.py
Normal file
17
imaginairy/enhancers/blur_detect.py
Normal file
@ -0,0 +1,17 @@
|
||||
import cv2
|
||||
|
||||
from imaginairy.img_utils import pillow_img_to_opencv_img
|
||||
|
||||
|
||||
def calculate_blurriness_level(img):
|
||||
img = pillow_img_to_opencv_img(img)
|
||||
|
||||
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
||||
sharpness = cv2.Laplacian(gray, cv2.CV_64F).var()
|
||||
sharpness = max(sharpness, 0.000001)
|
||||
bluriness = 1 / sharpness
|
||||
return bluriness
|
||||
|
||||
|
||||
def is_blurry(img, threshold=0.91):
|
||||
return calculate_blurriness_level(img) > threshold
|
@ -5,6 +5,8 @@ import torch
|
||||
from diffusers.pipelines.stable_diffusion import safety_checker as safety_checker_mod
|
||||
from transformers import AutoFeatureExtractor
|
||||
|
||||
from imaginairy.enhancers.blur_detect import is_blurry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -145,6 +147,14 @@ _SPECIAL_CARE_DESCRIPTIONS = []
|
||||
|
||||
|
||||
def create_safety_score(img, safety_mode=SafetyMode.STRICT):
|
||||
if is_blurry(img):
|
||||
sr = SafetyResult()
|
||||
sr.add_special_care_score(0, 0, 1)
|
||||
sr.add_special_care_score(1, 0, 1)
|
||||
sr.add_special_care_score(2, 0, 1)
|
||||
sr.add_nsfw_score(0, 0, 1)
|
||||
return sr
|
||||
|
||||
safety_feature_extractor, safety_checker = safety_models()
|
||||
safety_checker_input = safety_feature_extractor([img], return_tensors="pt")
|
||||
clip_input = safety_checker_input.pixel_values
|
||||
|
BIN
tests/data/black_square.jpg
Normal file
BIN
tests/data/black_square.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 6.4 KiB |
BIN
tests/data/latent_noise.jpg
Normal file
BIN
tests/data/latent_noise.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 163 KiB |
19
tests/enhancers/test_blur_detect.py
Normal file
19
tests/enhancers/test_blur_detect.py
Normal file
@ -0,0 +1,19 @@
|
||||
import pytest
|
||||
from PIL import Image
|
||||
|
||||
from imaginairy.enhancers.blur_detect import is_blurry
|
||||
from tests import TESTS_FOLDER
|
||||
|
||||
blur_params = [
|
||||
(f"{TESTS_FOLDER}/data/black_square.jpg", True),
|
||||
(f"{TESTS_FOLDER}/data/safety.jpg", False),
|
||||
(f"{TESTS_FOLDER}/data/latent_noise.jpg", False),
|
||||
(f"{TESTS_FOLDER}/data/girl_with_a_pearl_earring.jpg", False),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("img_path,expected", blur_params)
|
||||
def test_calculate_blurriness_level(img_path, expected):
|
||||
img = Image.open(img_path)
|
||||
|
||||
assert is_blurry(img) == expected
|
@ -16,3 +16,8 @@ def test_is_nsfw():
|
||||
img = Image.open(f"{TESTS_FOLDER}/data/girl_with_a_pearl_earring.jpg")
|
||||
safety_score = create_safety_score(img)
|
||||
assert not safety_score.is_nsfw
|
||||
|
||||
img = Image.open(f"{TESTS_FOLDER}/data/black_square.jpg")
|
||||
|
||||
safety_score = create_safety_score(img)
|
||||
assert not safety_score.is_nsfw
|
||||
|
Loading…
Reference in New Issue
Block a user