From 6d39d791b13bf6739b88b8e384352343aae4b6dc Mon Sep 17 00:00:00 2001 From: Bryce Date: Fri, 15 Dec 2023 13:48:53 -0800 Subject: [PATCH] refactor: move safety to utils --- imaginairy/api/generate.py | 2 +- imaginairy/api/generate_refiners.py | 2 +- imaginairy/{ => utils}/safety.py | 0 tests/test_safety.py | 2 +- 4 files changed, 3 insertions(+), 3 deletions(-) rename imaginairy/{ => utils}/safety.py (100%) diff --git a/imaginairy/api/generate.py b/imaginairy/api/generate.py index cd051d2..e826a2c 100755 --- a/imaginairy/api/generate.py +++ b/imaginairy/api/generate.py @@ -238,7 +238,6 @@ def _generate_single_image_compvis( from imaginairy.enhancers.face_restoration_codeformer import enhance_faces from imaginairy.enhancers.upscale_realesrgan import upscale_image from imaginairy.modules.midas.api import torch_image_to_depth_map - from imaginairy.safety import create_safety_score from imaginairy.samplers import SOLVER_LOOKUP from imaginairy.samplers.editing import CFGEditingDenoiser from imaginairy.schema import ControlInput, ImagineResult, MaskMode @@ -264,6 +263,7 @@ def _generate_single_image_compvis( outpaint_arg_str_parse, prepare_image_for_outpaint, ) + from imaginairy.utils.safety import create_safety_score latent_channels = 4 downsampling_factor = 8 diff --git a/imaginairy/api/generate_refiners.py b/imaginairy/api/generate_refiners.py index d0ac026..3b6ae32 100644 --- a/imaginairy/api/generate_refiners.py +++ b/imaginairy/api/generate_refiners.py @@ -35,7 +35,6 @@ def _generate_single_image( from imaginairy.enhancers.describe_image_blip import generate_caption from imaginairy.enhancers.face_restoration_codeformer import enhance_faces from imaginairy.enhancers.upscale_realesrgan import upscale_image - from imaginairy.safety import create_safety_score from imaginairy.samplers import SolverName from imaginairy.schema import ImagineResult from imaginairy.utils import get_device, randn_seeded @@ -58,6 +57,7 @@ def _generate_single_image( outpaint_arg_str_parse, prepare_image_for_outpaint, ) + from imaginairy.utils.safety import create_safety_score if dtype is None: dtype = torch.float16 if half_mode else torch.float32 diff --git a/imaginairy/safety.py b/imaginairy/utils/safety.py similarity index 100% rename from imaginairy/safety.py rename to imaginairy/utils/safety.py diff --git a/tests/test_safety.py b/tests/test_safety.py index c5bf480..c2299e5 100644 --- a/tests/test_safety.py +++ b/tests/test_safety.py @@ -1,8 +1,8 @@ import pytest from PIL import Image -from imaginairy.safety import create_safety_score from imaginairy.utils import get_device +from imaginairy.utils.safety import create_safety_score from tests import TESTS_FOLDER