imaginAIry/tests/test_safety.py

24 lines
695 B
Python
Raw Normal View History

import pytest
2022-09-17 22:49:38 +00:00
from PIL import Image
from imaginairy.utils import get_device
2023-12-15 21:48:53 +00:00
from imaginairy.utils.safety import create_safety_score
2022-09-17 22:49:38 +00:00
from tests import TESTS_FOLDER
def _is_nsfw(filename):
    """Return whether the safety checker flags ``<TESTS_FOLDER>/data/<filename>``.

    The image is opened in a ``with`` block so its file handle is released
    even if scoring raises; ``is_nsfw`` is read before the file is closed.
    """
    with Image.open(f"{TESTS_FOLDER}/data/{filename}") as img:
        return create_safety_score(img).is_nsfw


@pytest.mark.skipif(get_device() == "cpu", reason="Too slow to run on CPU")
def test_is_nsfw():
    """The safety checker flags the NSFW fixture and passes the benign ones."""
    assert _is_nsfw("safety.jpg")

    # Fixtures that must NOT be flagged; the filename is the assertion
    # message so a failure points at the misclassified image.
    for filename in ("girl_with_a_pearl_earring.jpg", "black_square.jpg"):
        assert not _is_nsfw(filename), filename