mirror of
https://github.com/brycedrennan/imaginAIry
synced 2024-10-31 03:20:40 +00:00
tests: add docker image for testing environment. minor test improvements
This commit is contained in:
parent
09bc1c70e6
commit
cdfeaa4c6f
21
.dockerignore
Normal file
21
.dockerignore
Normal file
@ -0,0 +1,21 @@
|
||||
__pycache__
|
||||
*.pyc
|
||||
*.pyo
|
||||
*.pyd
|
||||
.Python
|
||||
env
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
.tox
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.log
|
||||
.git
|
||||
.mypy_cache
|
||||
.pytest_cache
|
||||
.hypothesis
|
||||
.DS_Store
|
@ -1,7 +1,6 @@
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from contextlib import nullcontext
|
||||
from functools import lru_cache
|
||||
|
||||
import numpy as np
|
||||
@ -11,7 +10,6 @@ from einops import rearrange
|
||||
from omegaconf import OmegaConf
|
||||
from PIL import Image, ImageDraw, ImageFilter, ImageOps
|
||||
from pytorch_lightning import seed_everything
|
||||
from torch import autocast
|
||||
from transformers import cached_path
|
||||
|
||||
from imaginairy.enhancers.clip_masking import get_img_mask
|
||||
@ -35,6 +33,7 @@ from imaginairy.utils import (
|
||||
instantiate_from_config,
|
||||
pillow_fit_image_within,
|
||||
pillow_img_to_torch_image,
|
||||
platform_appropriate_autocast,
|
||||
)
|
||||
|
||||
LIB_PATH = os.path.dirname(__file__)
|
||||
@ -159,13 +158,9 @@ def imagine(
|
||||
_img_callback = None
|
||||
if get_device() == "cpu":
|
||||
logger.info("Running in CPU mode. it's gonna be slooooooow.")
|
||||
precision_scope = (
|
||||
autocast
|
||||
if precision == "autocast" and get_device() in ("cuda", "cpu")
|
||||
else nullcontext
|
||||
)
|
||||
with torch.no_grad(), precision_scope(
|
||||
get_device()
|
||||
|
||||
with torch.no_grad(), platform_appropriate_autocast(
|
||||
precision
|
||||
), fix_torch_nn_layer_norm(), fix_torch_group_norm():
|
||||
for prompt in prompts:
|
||||
with ImageLoggingContext(
|
||||
|
@ -61,6 +61,18 @@ def get_obj_from_str(string, reload=False):
|
||||
return getattr(importlib.import_module(module, package=None), cls)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def platform_appropriate_autocast(precision="autocast"):
|
||||
"""
|
||||
allow calculations to run in mixed precision, which can be faster
|
||||
"""
|
||||
precision_scope = nullcontext
|
||||
if precision == "autocast" and get_device() in ("cuda", "cpu"):
|
||||
precision_scope = autocast
|
||||
with precision_scope(get_device()):
|
||||
yield
|
||||
|
||||
|
||||
def _fixed_layer_norm(
|
||||
input: Tensor, # noqa
|
||||
normalized_shape: List[int],
|
||||
@ -119,7 +131,7 @@ def fix_torch_group_norm():
|
||||
orig_group_norm = functional.group_norm
|
||||
|
||||
def _group_norm_wrapper(
|
||||
input: Tensor,
|
||||
input: Tensor, # noqa
|
||||
num_groups: int,
|
||||
weight: Optional[Tensor] = None,
|
||||
bias: Optional[Tensor] = None,
|
||||
|
@ -6,3 +6,4 @@ pydocstyle
|
||||
pylama
|
||||
pylint
|
||||
pytest
|
||||
pytest-randomly
|
||||
|
@ -10,7 +10,7 @@ absl-py==1.2.0
|
||||
# tensorboard
|
||||
addict==2.4.0
|
||||
# via basicsr
|
||||
aiohttp==3.8.1
|
||||
aiohttp==3.8.3
|
||||
# via fsspec
|
||||
aiosignal==1.2.0
|
||||
# via aiohttp
|
||||
@ -68,7 +68,7 @@ filelock==3.8.0
|
||||
# transformers
|
||||
filterpy==1.4.5
|
||||
# via facexlib
|
||||
fonttools==4.37.2
|
||||
fonttools==4.37.3
|
||||
# via matplotlib
|
||||
frozenlist==1.3.1
|
||||
# via
|
||||
@ -86,7 +86,7 @@ gfpgan==1.3.8
|
||||
# via
|
||||
# imaginAIry (setup.py)
|
||||
# realesrgan
|
||||
google-auth==2.11.0
|
||||
google-auth==2.11.1
|
||||
# via
|
||||
# google-auth-oauthlib
|
||||
# tb-nightly
|
||||
@ -95,7 +95,7 @@ google-auth-oauthlib==0.4.6
|
||||
# via
|
||||
# tb-nightly
|
||||
# tensorboard
|
||||
grpcio==1.49.0
|
||||
grpcio==1.48.1
|
||||
# via
|
||||
# tb-nightly
|
||||
# tensorboard
|
||||
@ -212,6 +212,7 @@ pillow==9.2.0
|
||||
# diffusers
|
||||
# facexlib
|
||||
# imageio
|
||||
# imaginAIry (setup.py)
|
||||
# matplotlib
|
||||
# realesrgan
|
||||
# scikit-image
|
||||
@ -249,13 +250,17 @@ pyflakes==2.5.0
|
||||
# via pylama
|
||||
pylama==8.4.1
|
||||
# via -r requirements-dev.in
|
||||
pylint==2.15.2
|
||||
pylint==2.15.3
|
||||
# via -r requirements-dev.in
|
||||
pyparsing==3.0.9
|
||||
# via
|
||||
# matplotlib
|
||||
# packaging
|
||||
pytest==7.1.3
|
||||
# via
|
||||
# -r requirements-dev.in
|
||||
# pytest-randomly
|
||||
pytest-randomly==3.12.0
|
||||
# via -r requirements-dev.in
|
||||
python-dateutil==2.8.2
|
||||
# via matplotlib
|
||||
@ -273,7 +278,7 @@ pyyaml==6.0
|
||||
# pycln
|
||||
# pytorch-lightning
|
||||
# transformers
|
||||
realesrgan==0.2.8
|
||||
realesrgan==0.3.0
|
||||
# via imaginAIry (setup.py)
|
||||
regex==2022.9.13
|
||||
# via
|
||||
@ -311,7 +316,7 @@ six==1.16.0
|
||||
# python-dateutil
|
||||
snowballstemmer==2.2.0
|
||||
# via pydocstyle
|
||||
tb-nightly==2.11.0a20220918
|
||||
tb-nightly==2.11.0a20220921
|
||||
# via
|
||||
# basicsr
|
||||
# gfpgan
|
||||
|
36
tests/Dockerfile
Normal file
36
tests/Dockerfile
Normal file
@ -0,0 +1,36 @@
|
||||
FROM python:3.10.6-slim as base
|
||||
|
||||
RUN apt-get update && apt-get install -y libgl1 libglib2.0-0 make
|
||||
|
||||
ENV PIP_DISABLE_PIP_VERSION_CHECK=1 \
|
||||
PIP_ROOT_USER_ACTION=ignore
|
||||
|
||||
|
||||
FROM base as build_wheel
|
||||
|
||||
RUN pip install wheel
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
COPY imaginairy ./imaginairy
|
||||
COPY setup.py README.md ./
|
||||
|
||||
RUN python setup.py bdist_wheel
|
||||
|
||||
|
||||
|
||||
FROM base as install_wheel
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
COPY requirements-dev.in ./
|
||||
|
||||
RUN pip install -r requirements-dev.in
|
||||
|
||||
COPY --from=build_wheel /app/dist/* ./
|
||||
|
||||
RUN pip install *.whl
|
||||
RUN imagine --help
|
||||
COPY Makefile ./
|
||||
COPY tests ./tests
|
||||
|
@ -4,7 +4,11 @@ import pytest
|
||||
|
||||
from imaginairy import api
|
||||
from imaginairy.suppress_logs import suppress_annoying_logs_and_warnings
|
||||
from imaginairy.utils import fix_torch_nn_layer_norm
|
||||
from imaginairy.utils import (
|
||||
fix_torch_group_norm,
|
||||
fix_torch_nn_layer_norm,
|
||||
platform_appropriate_autocast,
|
||||
)
|
||||
|
||||
if "pytest" in str(sys.argv):
|
||||
suppress_annoying_logs_and_warnings()
|
||||
@ -13,5 +17,6 @@ if "pytest" in str(sys.argv):
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def pre_setup():
|
||||
api.IMAGINAIRY_SAFETY_MODE = "disabled"
|
||||
with fix_torch_nn_layer_norm():
|
||||
suppress_annoying_logs_and_warnings()
|
||||
with fix_torch_nn_layer_norm(), fix_torch_group_norm(), platform_appropriate_autocast():
|
||||
yield
|
||||
|
@ -17,4 +17,4 @@ def test_text_conditioning():
|
||||
if "mps" in get_device():
|
||||
assert hashed == "263e5ee7d2be087d816e094b80ffc546"
|
||||
elif "cuda" in get_device():
|
||||
assert hashed == "3d7867d5b2ebf15102a9ca9476d63ebc"
|
||||
assert hashed == "41818051d7c469fc57d0a940c9d24d82"
|
||||
|
@ -1,9 +1,12 @@
|
||||
import pytest
|
||||
from click.testing import CliRunner
|
||||
|
||||
from imaginairy.cmds import imagine_cmd
|
||||
from imaginairy.utils import get_device
|
||||
from tests import TESTS_FOLDER
|
||||
|
||||
|
||||
@pytest.mark.skipif(get_device() == "cpu", reason="Too slow to run on CPU")
|
||||
def test_imagine_cmd():
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(
|
||||
|
@ -20,7 +20,7 @@ def test_fix_faces():
|
||||
if "mps" in get_device():
|
||||
assert img_hash(img) == "a75991307eda675a26eeb7073f828e93"
|
||||
else:
|
||||
assert img_hash(img) == "5aa847a1464de75b158658a35800b6bf"
|
||||
assert img_hash(img) == "e56c1205bbc8f251be05773f2ba7fa24"
|
||||
|
||||
|
||||
def img_hash(img):
|
||||
|
@ -8,7 +8,7 @@ from imaginairy.utils import get_device
|
||||
from . import TESTS_FOLDER
|
||||
|
||||
device_sampler_type_test_cases = {
|
||||
"mps:0": {
|
||||
"mps:0": [
|
||||
("plms", "b4b434ed45919f3505ac2be162791c71"),
|
||||
("ddim", "b369032a025915c0a7ccced165a609b3"),
|
||||
("k_lms", "b87325c189799d646ccd07b331564eb6"),
|
||||
@ -17,8 +17,8 @@ device_sampler_type_test_cases = {
|
||||
("k_euler", "d126da5ca8b08099cde8b5037464e788"),
|
||||
("k_euler_a", "cac5ca2e26c31a544b76a9442eb2ea37"),
|
||||
("k_heun", "0382ef71d9967fefd15676410289ebab"),
|
||||
},
|
||||
"cuda": {
|
||||
],
|
||||
"cuda": [
|
||||
("plms", "62e78287e7848e48d45a1b207fb84102"),
|
||||
("ddim", "164c2a008b100e5fa07d3db2018605bd"),
|
||||
("k_lms", "450fea507ccfb44b677d30fae9f40a52"),
|
||||
@ -27,7 +27,8 @@ device_sampler_type_test_cases = {
|
||||
("k_euler", "06df9c19d472bfa6530db98be4ea10e8"),
|
||||
("k_euler_a", "79552628ff77914c8b6870703fe116b5"),
|
||||
("k_heun", "8ced3578ae25d34da9f4e4b1a20bf416"),
|
||||
},
|
||||
],
|
||||
"cpu": [],
|
||||
}
|
||||
sampler_type_test_cases = device_sampler_type_test_cases[get_device()]
|
||||
|
||||
@ -54,12 +55,14 @@ device_sampler_type_test_cases_img_2_img = {
|
||||
("plms", "efba8b836b51d262dbf72284844869f8"),
|
||||
("ddim", "a62878000ad3b581a11dd3fb329dc7d2"),
|
||||
},
|
||||
"cpu": [],
|
||||
}
|
||||
sampler_type_test_cases_img_2_img = device_sampler_type_test_cases_img_2_img[
|
||||
get_device()
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.skipif(get_device() == "cpu", reason="Too slow to run on CPU")
|
||||
@pytest.mark.parametrize("sampler_type,expected_md5", sampler_type_test_cases_img_2_img)
|
||||
def test_img_to_img(sampler_type, expected_md5):
|
||||
prompt = ImaginePrompt(
|
||||
@ -79,6 +82,7 @@ def test_img_to_img(sampler_type, expected_md5):
|
||||
assert result.md5() == expected_md5
|
||||
|
||||
|
||||
@pytest.mark.skipif(get_device() == "cpu", reason="Too slow to run on CPU")
|
||||
def test_img_to_img_from_url():
|
||||
prompt = ImaginePrompt(
|
||||
"dogs lying on a hot pink couch",
|
||||
@ -96,6 +100,7 @@ def test_img_to_img_from_url():
|
||||
imagine_image_files(prompt, outdir=out_folder)
|
||||
|
||||
|
||||
@pytest.mark.skipif(get_device() == "cpu", reason="Too slow to run on CPU")
|
||||
def test_img_to_file():
|
||||
prompt = ImaginePrompt(
|
||||
"an old growth forest, diffuse light poking through the canopy. high-resolution, nature photography, nat geo photo",
|
||||
@ -110,6 +115,7 @@ def test_img_to_file():
|
||||
imagine_image_files(prompt, outdir=out_folder)
|
||||
|
||||
|
||||
@pytest.mark.skipif(get_device() == "cpu", reason="Too slow to run on CPU")
|
||||
def test_inpainting():
|
||||
prompt = ImaginePrompt(
|
||||
"a basketball on a bench",
|
||||
@ -126,6 +132,7 @@ def test_inpainting():
|
||||
imagine_image_files(prompt, outdir=out_folder)
|
||||
|
||||
|
||||
@pytest.mark.skipif(get_device() == "cpu", reason="Too slow to run on CPU")
|
||||
def test_cliptext_inpainting():
|
||||
prompts = [
|
||||
ImaginePrompt(
|
||||
|
@ -18,6 +18,7 @@ def test_is_nsfw():
|
||||
|
||||
def _pil_to_latent(img):
|
||||
model = load_model()
|
||||
model.tile_mode(False)
|
||||
img = pillow_img_to_torch_image(img)
|
||||
img = img.to(get_device())
|
||||
latent = model.get_first_stage_encoding(model.encode_first_stage(img))
|
||||
|
Loading…
Reference in New Issue
Block a user