tests: add docker image for testing environment. minor test improvements

pull/18/head
Bryce 2 years ago
parent 09bc1c70e6
commit cdfeaa4c6f

@ -0,0 +1,21 @@
__pycache__
*.pyc
*.pyo
*.pyd
.Python
env
pip-log.txt
pip-delete-this-directory.txt
.tox
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.log
.git
.mypy_cache
.pytest_cache
.hypothesis
.DS_Store

@ -1,7 +1,6 @@
import logging import logging
import os import os
import re import re
from contextlib import nullcontext
from functools import lru_cache from functools import lru_cache
import numpy as np import numpy as np
@ -11,7 +10,6 @@ from einops import rearrange
from omegaconf import OmegaConf from omegaconf import OmegaConf
from PIL import Image, ImageDraw, ImageFilter, ImageOps from PIL import Image, ImageDraw, ImageFilter, ImageOps
from pytorch_lightning import seed_everything from pytorch_lightning import seed_everything
from torch import autocast
from transformers import cached_path from transformers import cached_path
from imaginairy.enhancers.clip_masking import get_img_mask from imaginairy.enhancers.clip_masking import get_img_mask
@ -35,6 +33,7 @@ from imaginairy.utils import (
instantiate_from_config, instantiate_from_config,
pillow_fit_image_within, pillow_fit_image_within,
pillow_img_to_torch_image, pillow_img_to_torch_image,
platform_appropriate_autocast,
) )
LIB_PATH = os.path.dirname(__file__) LIB_PATH = os.path.dirname(__file__)
@ -159,13 +158,9 @@ def imagine(
_img_callback = None _img_callback = None
if get_device() == "cpu": if get_device() == "cpu":
logger.info("Running in CPU mode. it's gonna be slooooooow.") logger.info("Running in CPU mode. it's gonna be slooooooow.")
precision_scope = (
autocast with torch.no_grad(), platform_appropriate_autocast(
if precision == "autocast" and get_device() in ("cuda", "cpu") precision
else nullcontext
)
with torch.no_grad(), precision_scope(
get_device()
), fix_torch_nn_layer_norm(), fix_torch_group_norm(): ), fix_torch_nn_layer_norm(), fix_torch_group_norm():
for prompt in prompts: for prompt in prompts:
with ImageLoggingContext( with ImageLoggingContext(

@ -61,6 +61,18 @@ def get_obj_from_str(string, reload=False):
return getattr(importlib.import_module(module, package=None), cls) return getattr(importlib.import_module(module, package=None), cls)
@contextmanager
def platform_appropriate_autocast(precision="autocast"):
"""
allow calculations to run in mixed precision, which can be faster
"""
precision_scope = nullcontext
if precision == "autocast" and get_device() in ("cuda", "cpu"):
precision_scope = autocast
with precision_scope(get_device()):
yield
def _fixed_layer_norm( def _fixed_layer_norm(
input: Tensor, # noqa input: Tensor, # noqa
normalized_shape: List[int], normalized_shape: List[int],
@ -119,7 +131,7 @@ def fix_torch_group_norm():
orig_group_norm = functional.group_norm orig_group_norm = functional.group_norm
def _group_norm_wrapper( def _group_norm_wrapper(
input: Tensor, input: Tensor, # noqa
num_groups: int, num_groups: int,
weight: Optional[Tensor] = None, weight: Optional[Tensor] = None,
bias: Optional[Tensor] = None, bias: Optional[Tensor] = None,

@ -6,3 +6,4 @@ pydocstyle
pylama pylama
pylint pylint
pytest pytest
pytest-randomly

@ -10,7 +10,7 @@ absl-py==1.2.0
# tensorboard # tensorboard
addict==2.4.0 addict==2.4.0
# via basicsr # via basicsr
aiohttp==3.8.1 aiohttp==3.8.3
# via fsspec # via fsspec
aiosignal==1.2.0 aiosignal==1.2.0
# via aiohttp # via aiohttp
@ -68,7 +68,7 @@ filelock==3.8.0
# transformers # transformers
filterpy==1.4.5 filterpy==1.4.5
# via facexlib # via facexlib
fonttools==4.37.2 fonttools==4.37.3
# via matplotlib # via matplotlib
frozenlist==1.3.1 frozenlist==1.3.1
# via # via
@ -86,7 +86,7 @@ gfpgan==1.3.8
# via # via
# imaginAIry (setup.py) # imaginAIry (setup.py)
# realesrgan # realesrgan
google-auth==2.11.0 google-auth==2.11.1
# via # via
# google-auth-oauthlib # google-auth-oauthlib
# tb-nightly # tb-nightly
@ -95,7 +95,7 @@ google-auth-oauthlib==0.4.6
# via # via
# tb-nightly # tb-nightly
# tensorboard # tensorboard
grpcio==1.49.0 grpcio==1.48.1
# via # via
# tb-nightly # tb-nightly
# tensorboard # tensorboard
@ -212,6 +212,7 @@ pillow==9.2.0
# diffusers # diffusers
# facexlib # facexlib
# imageio # imageio
# imaginAIry (setup.py)
# matplotlib # matplotlib
# realesrgan # realesrgan
# scikit-image # scikit-image
@ -249,13 +250,17 @@ pyflakes==2.5.0
# via pylama # via pylama
pylama==8.4.1 pylama==8.4.1
# via -r requirements-dev.in # via -r requirements-dev.in
pylint==2.15.2 pylint==2.15.3
# via -r requirements-dev.in # via -r requirements-dev.in
pyparsing==3.0.9 pyparsing==3.0.9
# via # via
# matplotlib # matplotlib
# packaging # packaging
pytest==7.1.3 pytest==7.1.3
# via
# -r requirements-dev.in
# pytest-randomly
pytest-randomly==3.12.0
# via -r requirements-dev.in # via -r requirements-dev.in
python-dateutil==2.8.2 python-dateutil==2.8.2
# via matplotlib # via matplotlib
@ -273,7 +278,7 @@ pyyaml==6.0
# pycln # pycln
# pytorch-lightning # pytorch-lightning
# transformers # transformers
realesrgan==0.2.8 realesrgan==0.3.0
# via imaginAIry (setup.py) # via imaginAIry (setup.py)
regex==2022.9.13 regex==2022.9.13
# via # via
@ -311,7 +316,7 @@ six==1.16.0
# python-dateutil # python-dateutil
snowballstemmer==2.2.0 snowballstemmer==2.2.0
# via pydocstyle # via pydocstyle
tb-nightly==2.11.0a20220918 tb-nightly==2.11.0a20220921
# via # via
# basicsr # basicsr
# gfpgan # gfpgan

@ -0,0 +1,36 @@
FROM python:3.10.6-slim as base
RUN apt-get update && apt-get install -y libgl1 libglib2.0-0 make
ENV PIP_DISABLE_PIP_VERSION_CHECK=1 \
PIP_ROOT_USER_ACTION=ignore
FROM base as build_wheel
RUN pip install wheel
WORKDIR /app
COPY imaginairy ./imaginairy
COPY setup.py README.md ./
RUN python setup.py bdist_wheel
FROM base as install_wheel
WORKDIR /app
COPY requirements-dev.in ./
RUN pip install -r requirements-dev.in
COPY --from=build_wheel /app/dist/* ./
RUN pip install *.whl
RUN imagine --help
COPY Makefile ./
COPY tests ./tests

@ -4,7 +4,11 @@ import pytest
from imaginairy import api from imaginairy import api
from imaginairy.suppress_logs import suppress_annoying_logs_and_warnings from imaginairy.suppress_logs import suppress_annoying_logs_and_warnings
from imaginairy.utils import fix_torch_nn_layer_norm from imaginairy.utils import (
fix_torch_group_norm,
fix_torch_nn_layer_norm,
platform_appropriate_autocast,
)
if "pytest" in str(sys.argv): if "pytest" in str(sys.argv):
suppress_annoying_logs_and_warnings() suppress_annoying_logs_and_warnings()
@ -13,5 +17,6 @@ if "pytest" in str(sys.argv):
@pytest.fixture(scope="session", autouse=True) @pytest.fixture(scope="session", autouse=True)
def pre_setup(): def pre_setup():
api.IMAGINAIRY_SAFETY_MODE = "disabled" api.IMAGINAIRY_SAFETY_MODE = "disabled"
with fix_torch_nn_layer_norm(): suppress_annoying_logs_and_warnings()
with fix_torch_nn_layer_norm(), fix_torch_group_norm(), platform_appropriate_autocast():
yield yield

@ -17,4 +17,4 @@ def test_text_conditioning():
if "mps" in get_device(): if "mps" in get_device():
assert hashed == "263e5ee7d2be087d816e094b80ffc546" assert hashed == "263e5ee7d2be087d816e094b80ffc546"
elif "cuda" in get_device(): elif "cuda" in get_device():
assert hashed == "3d7867d5b2ebf15102a9ca9476d63ebc" assert hashed == "41818051d7c469fc57d0a940c9d24d82"

@ -1,9 +1,12 @@
import pytest
from click.testing import CliRunner from click.testing import CliRunner
from imaginairy.cmds import imagine_cmd from imaginairy.cmds import imagine_cmd
from imaginairy.utils import get_device
from tests import TESTS_FOLDER from tests import TESTS_FOLDER
@pytest.mark.skipif(get_device() == "cpu", reason="Too slow to run on CPU")
def test_imagine_cmd(): def test_imagine_cmd():
runner = CliRunner() runner = CliRunner()
result = runner.invoke( result = runner.invoke(

@ -20,7 +20,7 @@ def test_fix_faces():
if "mps" in get_device(): if "mps" in get_device():
assert img_hash(img) == "a75991307eda675a26eeb7073f828e93" assert img_hash(img) == "a75991307eda675a26eeb7073f828e93"
else: else:
assert img_hash(img) == "5aa847a1464de75b158658a35800b6bf" assert img_hash(img) == "e56c1205bbc8f251be05773f2ba7fa24"
def img_hash(img): def img_hash(img):

@ -8,7 +8,7 @@ from imaginairy.utils import get_device
from . import TESTS_FOLDER from . import TESTS_FOLDER
device_sampler_type_test_cases = { device_sampler_type_test_cases = {
"mps:0": { "mps:0": [
("plms", "b4b434ed45919f3505ac2be162791c71"), ("plms", "b4b434ed45919f3505ac2be162791c71"),
("ddim", "b369032a025915c0a7ccced165a609b3"), ("ddim", "b369032a025915c0a7ccced165a609b3"),
("k_lms", "b87325c189799d646ccd07b331564eb6"), ("k_lms", "b87325c189799d646ccd07b331564eb6"),
@ -17,8 +17,8 @@ device_sampler_type_test_cases = {
("k_euler", "d126da5ca8b08099cde8b5037464e788"), ("k_euler", "d126da5ca8b08099cde8b5037464e788"),
("k_euler_a", "cac5ca2e26c31a544b76a9442eb2ea37"), ("k_euler_a", "cac5ca2e26c31a544b76a9442eb2ea37"),
("k_heun", "0382ef71d9967fefd15676410289ebab"), ("k_heun", "0382ef71d9967fefd15676410289ebab"),
}, ],
"cuda": { "cuda": [
("plms", "62e78287e7848e48d45a1b207fb84102"), ("plms", "62e78287e7848e48d45a1b207fb84102"),
("ddim", "164c2a008b100e5fa07d3db2018605bd"), ("ddim", "164c2a008b100e5fa07d3db2018605bd"),
("k_lms", "450fea507ccfb44b677d30fae9f40a52"), ("k_lms", "450fea507ccfb44b677d30fae9f40a52"),
@ -27,7 +27,8 @@ device_sampler_type_test_cases = {
("k_euler", "06df9c19d472bfa6530db98be4ea10e8"), ("k_euler", "06df9c19d472bfa6530db98be4ea10e8"),
("k_euler_a", "79552628ff77914c8b6870703fe116b5"), ("k_euler_a", "79552628ff77914c8b6870703fe116b5"),
("k_heun", "8ced3578ae25d34da9f4e4b1a20bf416"), ("k_heun", "8ced3578ae25d34da9f4e4b1a20bf416"),
}, ],
"cpu": [],
} }
sampler_type_test_cases = device_sampler_type_test_cases[get_device()] sampler_type_test_cases = device_sampler_type_test_cases[get_device()]
@ -54,12 +55,14 @@ device_sampler_type_test_cases_img_2_img = {
("plms", "efba8b836b51d262dbf72284844869f8"), ("plms", "efba8b836b51d262dbf72284844869f8"),
("ddim", "a62878000ad3b581a11dd3fb329dc7d2"), ("ddim", "a62878000ad3b581a11dd3fb329dc7d2"),
}, },
"cpu": [],
} }
sampler_type_test_cases_img_2_img = device_sampler_type_test_cases_img_2_img[ sampler_type_test_cases_img_2_img = device_sampler_type_test_cases_img_2_img[
get_device() get_device()
] ]
@pytest.mark.skipif(get_device() == "cpu", reason="Too slow to run on CPU")
@pytest.mark.parametrize("sampler_type,expected_md5", sampler_type_test_cases_img_2_img) @pytest.mark.parametrize("sampler_type,expected_md5", sampler_type_test_cases_img_2_img)
def test_img_to_img(sampler_type, expected_md5): def test_img_to_img(sampler_type, expected_md5):
prompt = ImaginePrompt( prompt = ImaginePrompt(
@ -79,6 +82,7 @@ def test_img_to_img(sampler_type, expected_md5):
assert result.md5() == expected_md5 assert result.md5() == expected_md5
@pytest.mark.skipif(get_device() == "cpu", reason="Too slow to run on CPU")
def test_img_to_img_from_url(): def test_img_to_img_from_url():
prompt = ImaginePrompt( prompt = ImaginePrompt(
"dogs lying on a hot pink couch", "dogs lying on a hot pink couch",
@ -96,6 +100,7 @@ def test_img_to_img_from_url():
imagine_image_files(prompt, outdir=out_folder) imagine_image_files(prompt, outdir=out_folder)
@pytest.mark.skipif(get_device() == "cpu", reason="Too slow to run on CPU")
def test_img_to_file(): def test_img_to_file():
prompt = ImaginePrompt( prompt = ImaginePrompt(
"an old growth forest, diffuse light poking through the canopy. high-resolution, nature photography, nat geo photo", "an old growth forest, diffuse light poking through the canopy. high-resolution, nature photography, nat geo photo",
@ -110,6 +115,7 @@ def test_img_to_file():
imagine_image_files(prompt, outdir=out_folder) imagine_image_files(prompt, outdir=out_folder)
@pytest.mark.skipif(get_device() == "cpu", reason="Too slow to run on CPU")
def test_inpainting(): def test_inpainting():
prompt = ImaginePrompt( prompt = ImaginePrompt(
"a basketball on a bench", "a basketball on a bench",
@ -126,6 +132,7 @@ def test_inpainting():
imagine_image_files(prompt, outdir=out_folder) imagine_image_files(prompt, outdir=out_folder)
@pytest.mark.skipif(get_device() == "cpu", reason="Too slow to run on CPU")
def test_cliptext_inpainting(): def test_cliptext_inpainting():
prompts = [ prompts = [
ImaginePrompt( ImaginePrompt(

@ -18,6 +18,7 @@ def test_is_nsfw():
def _pil_to_latent(img): def _pil_to_latent(img):
model = load_model() model = load_model()
model.tile_mode(False)
img = pillow_img_to_torch_image(img) img = pillow_img_to_torch_image(img)
img = img.to(get_device()) img = img.to(get_device())
latent = model.get_first_stage_encoding(model.encode_first_stage(img)) latent = model.get_first_stage_encoding(model.encode_first_stage(img))

Loading…
Cancel
Save