You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
imaginAIry/tests/test_safety.py

26 lines
729 B
Python

from PIL import Image
from imaginairy.api import load_model
from imaginairy.safety import is_nsfw
from imaginairy.utils import get_device, pillow_img_to_torch_image
from tests import TESTS_FOLDER
def test_is_nsfw():
    """Check `is_nsfw` against one positive and one negative fixture image.

    `data/safety.jpg` is expected to be reported as NSFW, while
    `data/girl_with_a_pearl_earring.jpg` is expected not to be. Each image is
    passed to `is_nsfw` together with the latent produced by
    `_pil_to_latent` for that same image.
    """
    # Image.open is lazy and keeps the file handle open; the `with` block
    # closes it deterministically instead of leaving it to garbage collection.
    with Image.open(f"{TESTS_FOLDER}/data/safety.jpg") as img:
        latent = _pil_to_latent(img)
        assert is_nsfw(img, latent)

    with Image.open(f"{TESTS_FOLDER}/data/girl_with_a_pearl_earring.jpg") as img:
        latent = _pil_to_latent(img)
        assert not is_nsfw(img, latent)
def _pil_to_latent(img):
    """Encode a PIL image into a first-stage latent with the loaded model.

    The model comes from `load_model()` and has `tile_mode(False)` applied
    before encoding. The image is converted with `pillow_img_to_torch_image`
    and moved to the device reported by `get_device()`.

    Returns whatever `model.get_first_stage_encoding` yields for the encoded
    image.
    """
    sd_model = load_model()
    sd_model.tile_mode(False)

    # Convert to a tensor first, then move it onto the active device.
    tensor = pillow_img_to_torch_image(img).to(get_device())

    encoded = sd_model.encode_first_stage(tensor)
    return sd_model.get_first_stage_encoding(encoded)