refactor/test: logging suppression + hashed image test

- simpler logging suppression for `transformers` library
- suppress logging noise for running tests
- get test running for all samplers on mps and cuda platforms
- refactor safety model env variable to allow classification
pull/1/head
Bryce 2 years ago
parent 8c88f495d2
commit 967eb76365

3
.gitignore vendored

@ -15,4 +15,5 @@ dist
**/*.ckpt
**/*.egg-info
tests/test_output
gfpgan/**
gfpgan/**
.python-version

@ -1 +0,0 @@
imaginairy-3.10.6

@ -1,5 +1,6 @@
import os
# tells pytorch to allow MPS usage (for Mac M1 compatibility)
os.putenv("PYTORCH_ENABLE_MPS_FALLBACK", "1")
from .api import imagine, imagine_image_files # noqa

@ -30,9 +30,17 @@ from imaginairy.utils import (
LIB_PATH = os.path.dirname(__file__)
logger = logging.getLogger(__name__)
# leave undocumented. I'd ask that no one publicize this flag
IMAGINAIRY_ALLOW_NSFW = os.getenv("IMAGINAIRY_ALLOW_NSFW", "False")
IMAGINAIRY_ALLOW_NSFW = bool(IMAGINAIRY_ALLOW_NSFW == "I AM A RESPONSIBLE ADULT")
class SafetyMode:
    """Valid values for the IMAGINAIRY_SAFETY_MODE env var (see below)."""
    # no NSFW check is run at all
    DISABLED = "disabled"
    # NSFW check runs and tags the result, but the image is left untouched
    CLASSIFY = "classify"
    # NSFW check runs and flagged images are blurred
    FILTER = "filter"
# leave undocumented. I'd ask that no one publicize this flag. Just want a
# slight barrier to entry. Please don't use this in any way that's gonna cause
# the press or governments to freak out about AI...
IMAGINAIRY_SAFETY_MODE = os.getenv("IMAGINAIRY_SAFETY_MODE", SafetyMode.FILTER)
def load_model_from_config(config):
@ -243,11 +251,13 @@ def imagine(
x_sample_8_orig = x_sample.astype(np.uint8)
img = Image.fromarray(x_sample_8_orig)
upscaled_img = None
if not IMAGINAIRY_ALLOW_NSFW and is_nsfw(
img, x_sample, half_mode=half_mode
):
logger.info(" ⚠️ Filtering NSFW image")
img = img.filter(ImageFilter.GaussianBlur(radius=40))
is_nsfw_img = None
if IMAGINAIRY_SAFETY_MODE != SafetyMode.DISABLED:
if is_nsfw(img, x_sample, half_mode=half_mode):
is_nsfw_img = True
if IMAGINAIRY_SAFETY_MODE == SafetyMode.FILTER:
logger.info(" ⚠️ Filtering NSFW image")
img = img.filter(ImageFilter.GaussianBlur(radius=40))
if prompt.fix_faces:
logger.info(" Fixing 😊 's in 🖼 using GFPGAN...")
@ -257,7 +267,10 @@ def imagine(
upscaled_img = upscale_image(img)
yield ImagineResult(
img=img, prompt=prompt, upscaled_img=upscaled_img
img=img,
prompt=prompt,
upscaled_img=upscaled_img,
is_nsfw=is_nsfw_img,
)

@ -1,67 +1,8 @@
# only builtin imports allowed at this point since we want to modify
# the environment and code before it's loaded
import importlib.abc
import importlib.util
import logging.config
import os
import site
import sys
import warnings
# tells pytorch to allow MPS usage (for Mac M1 compatibility)
os.putenv("PYTORCH_ENABLE_MPS_FALLBACK", "1")
def disable_transformers_logging():
"""
Disable `transformers` package custom logging.
I can't believe it came to this. I tried like four other approaches first
Loads up the source code from the transformers file and turns it into a module.
We then modify the module. Every other approach (import hooks, custom import function)
loaded the module before it could be modified.
"""
t_logging_path = f"{site.getsitepackages()[0]}/transformers/utils/logging.py"
with open(t_logging_path, "r", encoding="utf-8") as f:
src_code = f.read()
spec = importlib.util.spec_from_loader("transformers.utils.logging", loader=None)
module = importlib.util.module_from_spec(spec)
exec(src_code, module.__dict__)
module.get_logger = logging.getLogger
sys.modules["transformers.utils.logging"] = module
def disable_pytorch_lighting_custom_logging():
from pytorch_lightning import _logger
_logger.setLevel(logging.NOTSET)
def disable_common_warnings():
warnings.filterwarnings(
"ignore",
category=UserWarning,
message=r"The operator .*?is not currently supported.*",
)
warnings.filterwarnings(
"ignore", category=UserWarning, message=r"The parameter 'pretrained' is.*"
)
warnings.filterwarnings(
"ignore", category=UserWarning, message=r"Arguments other than a weight.*"
)
def setup_env():
disable_transformers_logging()
disable_pytorch_lighting_custom_logging()
disable_common_warnings()
def imagine_cmd(*args, **kwargs):
from .suppress_logs import suppress_annoying_logs_and_warnings # noqa
suppress_annoying_logs_and_warnings()
def imagine_cmd(*args, **kwargs):
setup_env()
from imaginairy.cmds import imagine_cmd as imagine_cmd_orig # noqa
imagine_cmd_orig(*args, **kwargs)

@ -39,17 +39,17 @@ class LatentLoggingContext:
global _CURRENT_LOGGING_CONTEXT
_CURRENT_LOGGING_CONTEXT = None
def log_latents(self, samples, description):
def log_latents(self, latents, description):
if not self.img_callback:
return
if samples.shape[1] != 4:
if latents.shape[1] != 4:
# logger.info(f"Didn't save tensor of shape {samples.shape} for {description}")
return
self.step_count += 1
description = f"{description} - {samples.shape}"
samples = self.model.decode_first_stage(samples)
samples = torch.clamp((samples + 1.0) / 2.0, min=0.0, max=1.0)
for pred_x0 in samples:
pred_x0 = 255.0 * rearrange(pred_x0.cpu().numpy(), "c h w -> h w c")
img = Image.fromarray(pred_x0.astype(np.uint8))
description = f"{description} - {latents.shape}"
latents = self.model.decode_first_stage(latents)
latents = torch.clamp((latents + 1.0) / 2.0, min=0.0, max=1.0)
for latent in latents:
latent = 255.0 * rearrange(latent.cpu().numpy(), "c h w -> h w c")
img = Image.fromarray(latent.astype(np.uint8))
self.img_callback(img, description, self.step_count, self.prompt)

@ -91,10 +91,11 @@ class ExifCodes:
class ImagineResult:
def __init__(self, img, prompt: ImaginePrompt, upscaled_img=None):
def __init__(self, img, prompt: ImaginePrompt, is_nsfw, upscaled_img=None):
    """Bundle a generated image with the prompt and metadata that produced it.

    Args:
        img: the generated PIL image.
        prompt: the ImaginePrompt used for generation.
        is_nsfw: True when the safety classifier flagged the image; None when
            classification was skipped (safety mode disabled).
        upscaled_img: optional upscaled version of ``img``.
    """
    self.img = img
    self.upscaled_img = upscaled_img
    self.prompt = prompt
    self.is_nsfw = is_nsfw
    # datetime.utcnow() is deprecated and returns a naive datetime that then
    # needed .replace(tzinfo=...); now(timezone.utc) is aware from the start.
    self.created_at = datetime.now(timezone.utc)
    # hoist the duplicated get_device() call
    device = get_device()
    self.torch_backend = device
    self.hardware_name = get_device_name(device)

@ -0,0 +1,40 @@
import logging.config
import warnings
def disable_transformers_custom_logging():
    """Make the `transformers` package log through standard Python logging.

    After letting the library configure its root logger, we strip that
    logger's custom handlers and re-enable propagation so records flow to
    the application's own handlers instead.
    """
    from transformers.modeling_utils import logger
    from transformers.utils.logging import _configure_library_root_logger

    _configure_library_root_logger()
    package_logger = logger.parent
    package_logger.handlers = []
    package_logger.propagate = True
    package_logger.setLevel(logging.NOTSET)
def disable_pytorch_lighting_custom_logging():
    """Reset pytorch-lightning's internal logger to the default NOTSET level."""
    from pytorch_lightning import _logger as pl_logger

    pl_logger.setLevel(logging.NOTSET)
def disable_common_warnings():
    """Silence warning noise commonly emitted while loading/running models."""
    # UserWarning messages (regex patterns) known to be harmless here.
    noisy_user_warning_patterns = (
        r"The operator .*?is not currently supported.*",
        r"The parameter 'pretrained' is.*",
        r"Arguments other than a weight.*",
    )
    for pattern in noisy_user_warning_patterns:
        warnings.filterwarnings("ignore", category=UserWarning, message=pattern)
    warnings.filterwarnings("ignore", category=DeprecationWarning)
def suppress_annoying_logs_and_warnings():
    """Apply all log/warning suppression tweaks in one call."""
    for quiet in (
        disable_transformers_custom_logging,
        disable_pytorch_lighting_custom_logging,
        disable_common_warnings,
    ):
        quiet()

@ -0,0 +1,15 @@
import sys
import pytest
from imaginairy.suppress_logs import suppress_annoying_logs_and_warnings
if "pytest" in str(sys.argv):
suppress_annoying_logs_and_warnings()
@pytest.fixture(scope="session", autouse=True)
def pre_setup():
    """Disable NSFW safety filtering for the entire pytest session."""
    # Imported lazily; overrides the module-level IMAGINAIRY_SAFETY_MODE
    # value that api reads from the environment at import time.
    from imaginairy import api
    api.IMAGINAIRY_SAFETY_MODE = "disabled"

@ -6,24 +6,36 @@ from imaginairy.utils import get_device
from . import TESTS_FOLDER
mps_sampler_type_test_cases = {
("plms", "3f211329796277a1870378288769fcde"),
("ddim", "70dbf2acce2c052e4e7f37412ae0366e"),
("k_lms", "3585c10c8f27bf091c15e761dca4d578"),
("k_dpm_2", "29b07125c9879540f8efac317ae33aea"),
("k_dpm_2_a", "4fd6767980444ca72e97cba2d0491eb4"),
("k_euler", "50609b279cff756db42ab9d2c85328ed"),
("k_euler_a", "ae7ac199c10f303e5ebd675109e59b23"),
("k_heun", "3668fe66770538337ac8c0b7ac210892"),
# Expected MD5 hashes of generated test images, per (sampler, hash) pair,
# keyed by torch device. Separate tables per platform because generated
# pixels (and thus hashes) differ between mps and cuda backends.
device_sampler_type_test_cases = {
    "mps": {
        ("plms", "b4b434ed45919f3505ac2be162791c71"),
        ("ddim", "b369032a025915c0a7ccced165a609b3"),
        ("k_lms", "b87325c189799d646ccd07b331564eb6"),
        ("k_dpm_2", "cb37ca934938466bdbc1dd995da037de"),
        ("k_dpm_2_a", "ef155995ca1638f0ae7db9f573b83767"),
        ("k_euler", "d126da5ca8b08099cde8b5037464e788"),
        ("k_euler_a", "cac5ca2e26c31a544b76a9442eb2ea37"),
        ("k_heun", "0382ef71d9967fefd15676410289ebab"),
    },
    "cuda": {
        ("plms", "62e78287e7848e48d45a1b207fb84102"),
        ("ddim", "164c2a008b100e5fa07d3db2018605bd"),
        ("k_lms", "450fea507ccfb44b677d30fae9f40a52"),
        ("k_dpm_2", "901daad7a9e359404d8e3d3f4236c4ce"),
        ("k_dpm_2_a", "855e80286dfdc89752f6bdd3fdeb1a62"),
        ("k_euler", "06df9c19d472bfa6530db98be4ea10e8"),
        ("k_euler_a", "79552628ff77914c8b6870703fe116b5"),
        ("k_heun", "8ced3578ae25d34da9f4e4b1a20bf416"),
    },
}
sampler_type_test_cases = device_sampler_type_test_cases[get_device()]
@pytest.mark.skipif(get_device() != "mps", reason="mps hashes")
@pytest.mark.parametrize("sampler_type,expected_md5", mps_sampler_type_test_cases)
@pytest.mark.parametrize("sampler_type,expected_md5", sampler_type_test_cases)
def test_imagine(sampler_type, expected_md5):
prompt_text = "a scenic landscape"
prompt = ImaginePrompt(
prompt_text, width=512, height=256, steps=10, seed=1, sampler_type=sampler_type
prompt_text, width=512, height=256, steps=5, seed=1, sampler_type=sampler_type
)
result = next(imagine(prompt))
result.img.save(
@ -39,7 +51,7 @@ def test_img_to_img():
init_image_strength=0.8,
width=512,
height=512,
steps=50,
steps=5,
seed=1,
sampler_type="DDIM",
)
@ -52,7 +64,7 @@ def test_img_to_file():
"an old growth forest, diffuse light poking through the canopy. high-resolution, nature photography, nat geo photo",
width=512 + 64,
height=512 - 64,
steps=50,
steps=5,
seed=2,
sampler_type="PLMS",
upscale=True,

@ -1,7 +1,10 @@
[pytest]
addopts = --doctest-modules -s --tb=native
norecursedirs = build dist downloads other prolly_delete
asyncio_mode = strict
norecursedirs = build dist downloads other prolly_delete imaginairy/vendored
filterwarnings =
ignore::DeprecationWarning
ignore::UserWarning
[pylama]
format = pylint

Loading…
Cancel
Save