2022-09-11 07:35:57 +00:00
|
|
|
from functools import lru_cache
|
|
|
|
|
2022-09-17 22:49:38 +00:00
|
|
|
import numpy as np
|
2022-09-25 02:42:54 +00:00
|
|
|
import torch
|
|
|
|
from diffusers.pipelines.stable_diffusion import safety_checker as safety_checker_mod
|
2022-09-11 07:35:57 +00:00
|
|
|
from diffusers.pipelines.stable_diffusion.safety_checker import (
|
|
|
|
StableDiffusionSafetyChecker,
|
|
|
|
)
|
|
|
|
from transformers import AutoFeatureExtractor
|
|
|
|
|
|
|
|
|
|
|
|
@lru_cache()
def safety_models():
    """Load the NSFW safety feature extractor and safety checker.

    Both objects come from the ``CompVis/stable-diffusion-safety-checker``
    checkpoint. Because the function takes no arguments and is wrapped in
    ``lru_cache``, the loading happens only on the first call; later calls
    return the same cached tuple.

    The float32 ``cosine_distance`` patch is installed before either model is
    loaded.

    Returns:
        tuple: ``(feature_extractor, safety_checker)``.
    """
    checkpoint = "CompVis/stable-diffusion-safety-checker"

    # Install the float32 distance patch first (itself cached, so idempotent).
    monkeypatch_safety_cosine_distance()

    # Tuple elements are evaluated left to right: extractor first, then checker.
    return (
        AutoFeatureExtractor.from_pretrained(checkpoint),
        StableDiffusionSafetyChecker.from_pretrained(checkpoint),
    )
|
|
|
|
|
|
|
|
|
2022-09-25 02:42:54 +00:00
|
|
|
@lru_cache()
def monkeypatch_safety_cosine_distance():
    """Make the diffusers safety checker's ``cosine_distance`` return float32.

    In some environments the distance was coming back as BFloat16 instead of
    float32. To handle that, this replaces the module-level
    ``safety_checker_mod.cosine_distance`` with a wrapper that calls the
    original function and casts its result with ``.to(torch.float32)``.

    ``@lru_cache`` turns repeat calls into no-ops, so the original function is
    wrapped only once for as long as the cache is not cleared.
    """
    from functools import wraps  # local import: only this patch needs it

    orig_cosine_distance = safety_checker_mod.cosine_distance

    # ``wraps`` keeps the original's name/qualname/doc on the replacement.
    # It also exposes the unpatched function as ``__wrapped__``, so the patch
    # is visible to introspection and can be undone.
    @wraps(orig_cosine_distance)
    def cosine_distance_float32(image_embeds, text_embeds):
        # Delegate to the original implementation, then force a float32 result.
        return orig_cosine_distance(image_embeds, text_embeds).to(torch.float32)

    safety_checker_mod.cosine_distance = cosine_distance_float32
|
|
|
|
|
|
|
|
|
2022-09-22 17:56:18 +00:00
|
|
|
def is_nsfw(img):
    """Report whether the safety checker flags a single image.

    Args:
        img: the image to check. It is passed as-is, wrapped in a one-element
            list, to the safety feature extractor (presumably a PIL image —
            TODO confirm against callers).

    Returns:
        The first element of the checker's ``has_nsfw_concept`` output, i.e.
        the flag for this one-image batch.
    """
    extractor, checker = safety_models()
    pixel_values = extractor([img], return_tensors="pt").pixel_values

    # A throwaway 2x2 array stands in for the ``images`` argument.
    # The images the checker hands back are discarded; only the flags are kept.
    _, flags = checker(images=[np.empty((2, 2))], clip_input=pixel_values)
    return flags[0]
|