You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
imaginAIry/imaginairy/safety.py

24 lines
832 B
Python

from functools import lru_cache
from diffusers.pipelines.stable_diffusion.safety_checker import (
StableDiffusionSafetyChecker,
)
from transformers import AutoFeatureExtractor
@lru_cache()
def safety_models():
safety_model_id = "CompVis/stable-diffusion-safety-checker"
safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
return safety_feature_extractor, safety_checker
def is_nsfw(img, x_sample):
safety_feature_extractor, safety_checker = safety_models()
safety_checker_input = safety_feature_extractor([img], return_tensors="pt")
_, has_nsfw_concept = safety_checker(
images=x_sample[None, :], clip_input=safety_checker_input.pixel_values
)
return has_nsfw_concept[0]