2022-09-11 07:35:57 +00:00
|
|
|
from functools import lru_cache
|
|
|
|
|
2022-09-17 22:49:38 +00:00
|
|
|
import numpy as np
|
2022-09-11 07:35:57 +00:00
|
|
|
from diffusers.pipelines.stable_diffusion.safety_checker import (
|
|
|
|
StableDiffusionSafetyChecker,
|
|
|
|
)
|
|
|
|
from transformers import AutoFeatureExtractor
|
|
|
|
|
|
|
|
|
|
|
|
@lru_cache()
def safety_models():
    """Load the Stable Diffusion safety-checker components, memoized.

    Both components come from the ``CompVis/stable-diffusion-safety-checker``
    pretrained checkpoint. Because the function is wrapped in ``lru_cache`` and
    takes no arguments, the loading runs on the first call only. Later calls
    return the same cached tuple.

    Returns:
        A ``(feature_extractor, safety_checker)`` tuple.
    """
    model_id = "CompVis/stable-diffusion-safety-checker"
    # The tuple elements are evaluated left to right, so the load order matches
    # the original: extractor first, then checker.
    return (
        AutoFeatureExtractor.from_pretrained(model_id),
        StableDiffusionSafetyChecker.from_pretrained(model_id),
    )
|
|
|
|
|
|
|
|
|
2022-09-22 17:56:18 +00:00
|
|
|
def is_nsfw(img):
    """Return the safety checker's NSFW flag for a single image.

    Args:
        img: The image to check. It is passed to the feature extractor wrapped
            in a one-element list. It is presumably a PIL image or array the
            extractor accepts; confirm this against the callers.

    Returns:
        The first element of the checker's per-image flag list, i.e. the verdict
        for ``img``.
    """
    extractor, checker = safety_models()
    pixel_values = extractor([img], return_tensors="pt").pixel_values
    # Only the second return value (the flags) is used. A tiny placeholder
    # array is passed as ``images`` instead of the real image.
    _, flags = checker(images=[np.empty((2, 2))], clip_input=pixel_values)
    return flags[0]
|