2022-10-17 02:46:32 +00:00
|
|
|
import pytest
|
2022-09-17 22:49:38 +00:00
|
|
|
from PIL import Image
|
|
|
|
|
2022-10-10 08:22:11 +00:00
|
|
|
from imaginairy.safety import create_safety_score
|
2022-10-17 02:46:32 +00:00
|
|
|
from imaginairy.utils import get_device
|
2022-09-17 22:49:38 +00:00
|
|
|
from tests import TESTS_FOLDER
|
|
|
|
|
|
|
|
|
2022-10-17 02:46:32 +00:00
|
|
|
@pytest.mark.skipif(get_device() == "cpu", reason="Too slow to run on CPU")
def test_is_nsfw():
    """Check that ``create_safety_score`` flags NSFW images and passes safe ones.

    Each case is ``(filename under TESTS_FOLDER/data, expected is_nsfw)``.
    Images are opened in a ``with`` block so their file handles are closed
    after scoring rather than being left open until garbage collection.
    """
    cases = [
        ("safety.jpg", True),
        ("girl_with_a_pearl_earring.jpg", False),
        # A featureless black image must not be flagged.
        ("black_square.jpg", False),
    ]
    for filename, expected_nsfw in cases:
        with Image.open(f"{TESTS_FOLDER}/data/{filename}") as img:
            safety_score = create_safety_score(img)
            # bool() keeps the original truthiness semantics; the filename
            # in the message identifies which case failed.
            assert bool(safety_score.is_nsfw) is expected_nsfw, filename
|