mirror of
https://github.com/brycedrennan/imaginAIry
synced 2024-10-31 03:20:40 +00:00
tests: add docker image for testing environment. minor test improvements
This commit is contained in:
parent
09bc1c70e6
commit
cdfeaa4c6f
21
.dockerignore
Normal file
21
.dockerignore
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
__pycache__
|
||||||
|
*.pyc
|
||||||
|
*.pyo
|
||||||
|
*.pyd
|
||||||
|
.Python
|
||||||
|
env
|
||||||
|
pip-log.txt
|
||||||
|
pip-delete-this-directory.txt
|
||||||
|
.tox
|
||||||
|
.coverage
|
||||||
|
.coverage.*
|
||||||
|
.cache
|
||||||
|
nosetests.xml
|
||||||
|
coverage.xml
|
||||||
|
*.cover
|
||||||
|
*.log
|
||||||
|
.git
|
||||||
|
.mypy_cache
|
||||||
|
.pytest_cache
|
||||||
|
.hypothesis
|
||||||
|
.DS_Store
|
@ -1,7 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
from contextlib import nullcontext
|
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -11,7 +10,6 @@ from einops import rearrange
|
|||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
from PIL import Image, ImageDraw, ImageFilter, ImageOps
|
from PIL import Image, ImageDraw, ImageFilter, ImageOps
|
||||||
from pytorch_lightning import seed_everything
|
from pytorch_lightning import seed_everything
|
||||||
from torch import autocast
|
|
||||||
from transformers import cached_path
|
from transformers import cached_path
|
||||||
|
|
||||||
from imaginairy.enhancers.clip_masking import get_img_mask
|
from imaginairy.enhancers.clip_masking import get_img_mask
|
||||||
@ -35,6 +33,7 @@ from imaginairy.utils import (
|
|||||||
instantiate_from_config,
|
instantiate_from_config,
|
||||||
pillow_fit_image_within,
|
pillow_fit_image_within,
|
||||||
pillow_img_to_torch_image,
|
pillow_img_to_torch_image,
|
||||||
|
platform_appropriate_autocast,
|
||||||
)
|
)
|
||||||
|
|
||||||
LIB_PATH = os.path.dirname(__file__)
|
LIB_PATH = os.path.dirname(__file__)
|
||||||
@ -159,13 +158,9 @@ def imagine(
|
|||||||
_img_callback = None
|
_img_callback = None
|
||||||
if get_device() == "cpu":
|
if get_device() == "cpu":
|
||||||
logger.info("Running in CPU mode. it's gonna be slooooooow.")
|
logger.info("Running in CPU mode. it's gonna be slooooooow.")
|
||||||
precision_scope = (
|
|
||||||
autocast
|
with torch.no_grad(), platform_appropriate_autocast(
|
||||||
if precision == "autocast" and get_device() in ("cuda", "cpu")
|
precision
|
||||||
else nullcontext
|
|
||||||
)
|
|
||||||
with torch.no_grad(), precision_scope(
|
|
||||||
get_device()
|
|
||||||
), fix_torch_nn_layer_norm(), fix_torch_group_norm():
|
), fix_torch_nn_layer_norm(), fix_torch_group_norm():
|
||||||
for prompt in prompts:
|
for prompt in prompts:
|
||||||
with ImageLoggingContext(
|
with ImageLoggingContext(
|
||||||
|
@ -61,6 +61,18 @@ def get_obj_from_str(string, reload=False):
|
|||||||
return getattr(importlib.import_module(module, package=None), cls)
|
return getattr(importlib.import_module(module, package=None), cls)
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def platform_appropriate_autocast(precision="autocast"):
|
||||||
|
"""
|
||||||
|
allow calculations to run in mixed precision, which can be faster
|
||||||
|
"""
|
||||||
|
precision_scope = nullcontext
|
||||||
|
if precision == "autocast" and get_device() in ("cuda", "cpu"):
|
||||||
|
precision_scope = autocast
|
||||||
|
with precision_scope(get_device()):
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
def _fixed_layer_norm(
|
def _fixed_layer_norm(
|
||||||
input: Tensor, # noqa
|
input: Tensor, # noqa
|
||||||
normalized_shape: List[int],
|
normalized_shape: List[int],
|
||||||
@ -119,7 +131,7 @@ def fix_torch_group_norm():
|
|||||||
orig_group_norm = functional.group_norm
|
orig_group_norm = functional.group_norm
|
||||||
|
|
||||||
def _group_norm_wrapper(
|
def _group_norm_wrapper(
|
||||||
input: Tensor,
|
input: Tensor, # noqa
|
||||||
num_groups: int,
|
num_groups: int,
|
||||||
weight: Optional[Tensor] = None,
|
weight: Optional[Tensor] = None,
|
||||||
bias: Optional[Tensor] = None,
|
bias: Optional[Tensor] = None,
|
||||||
|
@ -6,3 +6,4 @@ pydocstyle
|
|||||||
pylama
|
pylama
|
||||||
pylint
|
pylint
|
||||||
pytest
|
pytest
|
||||||
|
pytest-randomly
|
||||||
|
@ -10,7 +10,7 @@ absl-py==1.2.0
|
|||||||
# tensorboard
|
# tensorboard
|
||||||
addict==2.4.0
|
addict==2.4.0
|
||||||
# via basicsr
|
# via basicsr
|
||||||
aiohttp==3.8.1
|
aiohttp==3.8.3
|
||||||
# via fsspec
|
# via fsspec
|
||||||
aiosignal==1.2.0
|
aiosignal==1.2.0
|
||||||
# via aiohttp
|
# via aiohttp
|
||||||
@ -68,7 +68,7 @@ filelock==3.8.0
|
|||||||
# transformers
|
# transformers
|
||||||
filterpy==1.4.5
|
filterpy==1.4.5
|
||||||
# via facexlib
|
# via facexlib
|
||||||
fonttools==4.37.2
|
fonttools==4.37.3
|
||||||
# via matplotlib
|
# via matplotlib
|
||||||
frozenlist==1.3.1
|
frozenlist==1.3.1
|
||||||
# via
|
# via
|
||||||
@ -86,7 +86,7 @@ gfpgan==1.3.8
|
|||||||
# via
|
# via
|
||||||
# imaginAIry (setup.py)
|
# imaginAIry (setup.py)
|
||||||
# realesrgan
|
# realesrgan
|
||||||
google-auth==2.11.0
|
google-auth==2.11.1
|
||||||
# via
|
# via
|
||||||
# google-auth-oauthlib
|
# google-auth-oauthlib
|
||||||
# tb-nightly
|
# tb-nightly
|
||||||
@ -95,7 +95,7 @@ google-auth-oauthlib==0.4.6
|
|||||||
# via
|
# via
|
||||||
# tb-nightly
|
# tb-nightly
|
||||||
# tensorboard
|
# tensorboard
|
||||||
grpcio==1.49.0
|
grpcio==1.48.1
|
||||||
# via
|
# via
|
||||||
# tb-nightly
|
# tb-nightly
|
||||||
# tensorboard
|
# tensorboard
|
||||||
@ -212,6 +212,7 @@ pillow==9.2.0
|
|||||||
# diffusers
|
# diffusers
|
||||||
# facexlib
|
# facexlib
|
||||||
# imageio
|
# imageio
|
||||||
|
# imaginAIry (setup.py)
|
||||||
# matplotlib
|
# matplotlib
|
||||||
# realesrgan
|
# realesrgan
|
||||||
# scikit-image
|
# scikit-image
|
||||||
@ -249,13 +250,17 @@ pyflakes==2.5.0
|
|||||||
# via pylama
|
# via pylama
|
||||||
pylama==8.4.1
|
pylama==8.4.1
|
||||||
# via -r requirements-dev.in
|
# via -r requirements-dev.in
|
||||||
pylint==2.15.2
|
pylint==2.15.3
|
||||||
# via -r requirements-dev.in
|
# via -r requirements-dev.in
|
||||||
pyparsing==3.0.9
|
pyparsing==3.0.9
|
||||||
# via
|
# via
|
||||||
# matplotlib
|
# matplotlib
|
||||||
# packaging
|
# packaging
|
||||||
pytest==7.1.3
|
pytest==7.1.3
|
||||||
|
# via
|
||||||
|
# -r requirements-dev.in
|
||||||
|
# pytest-randomly
|
||||||
|
pytest-randomly==3.12.0
|
||||||
# via -r requirements-dev.in
|
# via -r requirements-dev.in
|
||||||
python-dateutil==2.8.2
|
python-dateutil==2.8.2
|
||||||
# via matplotlib
|
# via matplotlib
|
||||||
@ -273,7 +278,7 @@ pyyaml==6.0
|
|||||||
# pycln
|
# pycln
|
||||||
# pytorch-lightning
|
# pytorch-lightning
|
||||||
# transformers
|
# transformers
|
||||||
realesrgan==0.2.8
|
realesrgan==0.3.0
|
||||||
# via imaginAIry (setup.py)
|
# via imaginAIry (setup.py)
|
||||||
regex==2022.9.13
|
regex==2022.9.13
|
||||||
# via
|
# via
|
||||||
@ -311,7 +316,7 @@ six==1.16.0
|
|||||||
# python-dateutil
|
# python-dateutil
|
||||||
snowballstemmer==2.2.0
|
snowballstemmer==2.2.0
|
||||||
# via pydocstyle
|
# via pydocstyle
|
||||||
tb-nightly==2.11.0a20220918
|
tb-nightly==2.11.0a20220921
|
||||||
# via
|
# via
|
||||||
# basicsr
|
# basicsr
|
||||||
# gfpgan
|
# gfpgan
|
||||||
|
36
tests/Dockerfile
Normal file
36
tests/Dockerfile
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
FROM python:3.10.6-slim as base
|
||||||
|
|
||||||
|
RUN apt-get update && apt-get install -y libgl1 libglib2.0-0 make
|
||||||
|
|
||||||
|
ENV PIP_DISABLE_PIP_VERSION_CHECK=1 \
|
||||||
|
PIP_ROOT_USER_ACTION=ignore
|
||||||
|
|
||||||
|
|
||||||
|
FROM base as build_wheel
|
||||||
|
|
||||||
|
RUN pip install wheel
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
COPY imaginairy ./imaginairy
|
||||||
|
COPY setup.py README.md ./
|
||||||
|
|
||||||
|
RUN python setup.py bdist_wheel
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
FROM base as install_wheel
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
COPY requirements-dev.in ./
|
||||||
|
|
||||||
|
RUN pip install -r requirements-dev.in
|
||||||
|
|
||||||
|
COPY --from=build_wheel /app/dist/* ./
|
||||||
|
|
||||||
|
RUN pip install *.whl
|
||||||
|
RUN imagine --help
|
||||||
|
COPY Makefile ./
|
||||||
|
COPY tests ./tests
|
||||||
|
|
@ -4,7 +4,11 @@ import pytest
|
|||||||
|
|
||||||
from imaginairy import api
|
from imaginairy import api
|
||||||
from imaginairy.suppress_logs import suppress_annoying_logs_and_warnings
|
from imaginairy.suppress_logs import suppress_annoying_logs_and_warnings
|
||||||
from imaginairy.utils import fix_torch_nn_layer_norm
|
from imaginairy.utils import (
|
||||||
|
fix_torch_group_norm,
|
||||||
|
fix_torch_nn_layer_norm,
|
||||||
|
platform_appropriate_autocast,
|
||||||
|
)
|
||||||
|
|
||||||
if "pytest" in str(sys.argv):
|
if "pytest" in str(sys.argv):
|
||||||
suppress_annoying_logs_and_warnings()
|
suppress_annoying_logs_and_warnings()
|
||||||
@ -13,5 +17,6 @@ if "pytest" in str(sys.argv):
|
|||||||
@pytest.fixture(scope="session", autouse=True)
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
def pre_setup():
|
def pre_setup():
|
||||||
api.IMAGINAIRY_SAFETY_MODE = "disabled"
|
api.IMAGINAIRY_SAFETY_MODE = "disabled"
|
||||||
with fix_torch_nn_layer_norm():
|
suppress_annoying_logs_and_warnings()
|
||||||
|
with fix_torch_nn_layer_norm(), fix_torch_group_norm(), platform_appropriate_autocast():
|
||||||
yield
|
yield
|
||||||
|
@ -17,4 +17,4 @@ def test_text_conditioning():
|
|||||||
if "mps" in get_device():
|
if "mps" in get_device():
|
||||||
assert hashed == "263e5ee7d2be087d816e094b80ffc546"
|
assert hashed == "263e5ee7d2be087d816e094b80ffc546"
|
||||||
elif "cuda" in get_device():
|
elif "cuda" in get_device():
|
||||||
assert hashed == "3d7867d5b2ebf15102a9ca9476d63ebc"
|
assert hashed == "41818051d7c469fc57d0a940c9d24d82"
|
||||||
|
@ -1,9 +1,12 @@
|
|||||||
|
import pytest
|
||||||
from click.testing import CliRunner
|
from click.testing import CliRunner
|
||||||
|
|
||||||
from imaginairy.cmds import imagine_cmd
|
from imaginairy.cmds import imagine_cmd
|
||||||
|
from imaginairy.utils import get_device
|
||||||
from tests import TESTS_FOLDER
|
from tests import TESTS_FOLDER
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(get_device() == "cpu", reason="Too slow to run on CPU")
|
||||||
def test_imagine_cmd():
|
def test_imagine_cmd():
|
||||||
runner = CliRunner()
|
runner = CliRunner()
|
||||||
result = runner.invoke(
|
result = runner.invoke(
|
||||||
|
@ -20,7 +20,7 @@ def test_fix_faces():
|
|||||||
if "mps" in get_device():
|
if "mps" in get_device():
|
||||||
assert img_hash(img) == "a75991307eda675a26eeb7073f828e93"
|
assert img_hash(img) == "a75991307eda675a26eeb7073f828e93"
|
||||||
else:
|
else:
|
||||||
assert img_hash(img) == "5aa847a1464de75b158658a35800b6bf"
|
assert img_hash(img) == "e56c1205bbc8f251be05773f2ba7fa24"
|
||||||
|
|
||||||
|
|
||||||
def img_hash(img):
|
def img_hash(img):
|
||||||
|
@ -8,7 +8,7 @@ from imaginairy.utils import get_device
|
|||||||
from . import TESTS_FOLDER
|
from . import TESTS_FOLDER
|
||||||
|
|
||||||
device_sampler_type_test_cases = {
|
device_sampler_type_test_cases = {
|
||||||
"mps:0": {
|
"mps:0": [
|
||||||
("plms", "b4b434ed45919f3505ac2be162791c71"),
|
("plms", "b4b434ed45919f3505ac2be162791c71"),
|
||||||
("ddim", "b369032a025915c0a7ccced165a609b3"),
|
("ddim", "b369032a025915c0a7ccced165a609b3"),
|
||||||
("k_lms", "b87325c189799d646ccd07b331564eb6"),
|
("k_lms", "b87325c189799d646ccd07b331564eb6"),
|
||||||
@ -17,8 +17,8 @@ device_sampler_type_test_cases = {
|
|||||||
("k_euler", "d126da5ca8b08099cde8b5037464e788"),
|
("k_euler", "d126da5ca8b08099cde8b5037464e788"),
|
||||||
("k_euler_a", "cac5ca2e26c31a544b76a9442eb2ea37"),
|
("k_euler_a", "cac5ca2e26c31a544b76a9442eb2ea37"),
|
||||||
("k_heun", "0382ef71d9967fefd15676410289ebab"),
|
("k_heun", "0382ef71d9967fefd15676410289ebab"),
|
||||||
},
|
],
|
||||||
"cuda": {
|
"cuda": [
|
||||||
("plms", "62e78287e7848e48d45a1b207fb84102"),
|
("plms", "62e78287e7848e48d45a1b207fb84102"),
|
||||||
("ddim", "164c2a008b100e5fa07d3db2018605bd"),
|
("ddim", "164c2a008b100e5fa07d3db2018605bd"),
|
||||||
("k_lms", "450fea507ccfb44b677d30fae9f40a52"),
|
("k_lms", "450fea507ccfb44b677d30fae9f40a52"),
|
||||||
@ -27,7 +27,8 @@ device_sampler_type_test_cases = {
|
|||||||
("k_euler", "06df9c19d472bfa6530db98be4ea10e8"),
|
("k_euler", "06df9c19d472bfa6530db98be4ea10e8"),
|
||||||
("k_euler_a", "79552628ff77914c8b6870703fe116b5"),
|
("k_euler_a", "79552628ff77914c8b6870703fe116b5"),
|
||||||
("k_heun", "8ced3578ae25d34da9f4e4b1a20bf416"),
|
("k_heun", "8ced3578ae25d34da9f4e4b1a20bf416"),
|
||||||
},
|
],
|
||||||
|
"cpu": [],
|
||||||
}
|
}
|
||||||
sampler_type_test_cases = device_sampler_type_test_cases[get_device()]
|
sampler_type_test_cases = device_sampler_type_test_cases[get_device()]
|
||||||
|
|
||||||
@ -54,12 +55,14 @@ device_sampler_type_test_cases_img_2_img = {
|
|||||||
("plms", "efba8b836b51d262dbf72284844869f8"),
|
("plms", "efba8b836b51d262dbf72284844869f8"),
|
||||||
("ddim", "a62878000ad3b581a11dd3fb329dc7d2"),
|
("ddim", "a62878000ad3b581a11dd3fb329dc7d2"),
|
||||||
},
|
},
|
||||||
|
"cpu": [],
|
||||||
}
|
}
|
||||||
sampler_type_test_cases_img_2_img = device_sampler_type_test_cases_img_2_img[
|
sampler_type_test_cases_img_2_img = device_sampler_type_test_cases_img_2_img[
|
||||||
get_device()
|
get_device()
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(get_device() == "cpu", reason="Too slow to run on CPU")
|
||||||
@pytest.mark.parametrize("sampler_type,expected_md5", sampler_type_test_cases_img_2_img)
|
@pytest.mark.parametrize("sampler_type,expected_md5", sampler_type_test_cases_img_2_img)
|
||||||
def test_img_to_img(sampler_type, expected_md5):
|
def test_img_to_img(sampler_type, expected_md5):
|
||||||
prompt = ImaginePrompt(
|
prompt = ImaginePrompt(
|
||||||
@ -79,6 +82,7 @@ def test_img_to_img(sampler_type, expected_md5):
|
|||||||
assert result.md5() == expected_md5
|
assert result.md5() == expected_md5
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(get_device() == "cpu", reason="Too slow to run on CPU")
|
||||||
def test_img_to_img_from_url():
|
def test_img_to_img_from_url():
|
||||||
prompt = ImaginePrompt(
|
prompt = ImaginePrompt(
|
||||||
"dogs lying on a hot pink couch",
|
"dogs lying on a hot pink couch",
|
||||||
@ -96,6 +100,7 @@ def test_img_to_img_from_url():
|
|||||||
imagine_image_files(prompt, outdir=out_folder)
|
imagine_image_files(prompt, outdir=out_folder)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(get_device() == "cpu", reason="Too slow to run on CPU")
|
||||||
def test_img_to_file():
|
def test_img_to_file():
|
||||||
prompt = ImaginePrompt(
|
prompt = ImaginePrompt(
|
||||||
"an old growth forest, diffuse light poking through the canopy. high-resolution, nature photography, nat geo photo",
|
"an old growth forest, diffuse light poking through the canopy. high-resolution, nature photography, nat geo photo",
|
||||||
@ -110,6 +115,7 @@ def test_img_to_file():
|
|||||||
imagine_image_files(prompt, outdir=out_folder)
|
imagine_image_files(prompt, outdir=out_folder)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(get_device() == "cpu", reason="Too slow to run on CPU")
|
||||||
def test_inpainting():
|
def test_inpainting():
|
||||||
prompt = ImaginePrompt(
|
prompt = ImaginePrompt(
|
||||||
"a basketball on a bench",
|
"a basketball on a bench",
|
||||||
@ -126,6 +132,7 @@ def test_inpainting():
|
|||||||
imagine_image_files(prompt, outdir=out_folder)
|
imagine_image_files(prompt, outdir=out_folder)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(get_device() == "cpu", reason="Too slow to run on CPU")
|
||||||
def test_cliptext_inpainting():
|
def test_cliptext_inpainting():
|
||||||
prompts = [
|
prompts = [
|
||||||
ImaginePrompt(
|
ImaginePrompt(
|
||||||
|
@ -18,6 +18,7 @@ def test_is_nsfw():
|
|||||||
|
|
||||||
def _pil_to_latent(img):
|
def _pil_to_latent(img):
|
||||||
model = load_model()
|
model = load_model()
|
||||||
|
model.tile_mode(False)
|
||||||
img = pillow_img_to_torch_image(img)
|
img = pillow_img_to_torch_image(img)
|
||||||
img = img.to(get_device())
|
img = img.to(get_device())
|
||||||
latent = model.get_first_stage_encoding(model.encode_first_stage(img))
|
latent = model.get_first_stage_encoding(model.encode_first_stage(img))
|
||||||
|
Loading…
Reference in New Issue
Block a user