Merge pull request #18 from brycedrennan/bugfixes

Bugfixes + per-prompt tile mode
pull/21/head 1.6.0
Bryce Drennan 2 years ago committed by GitHub
commit 08fca72033
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,21 @@
__pycache__
*.pyc
*.pyo
*.pyd
.Python
env
pip-log.txt
pip-delete-this-directory.txt
.tox
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.log
.git
.mypy_cache
.pytest_cache
.hypothesis
.DS_Store

@ -51,6 +51,12 @@ deploy: ## Deploy the package to pypi.org
rm -rf dist rm -rf dist
@echo "Deploy successful! ✨ 🍰 ✨" @echo "Deploy successful! ✨ 🍰 ✨"
build-dev-image: ## Build the imaginairy-dev docker image from tests/Dockerfile
	docker build -f tests/Dockerfile -t imaginairy-dev .

# The huggingface and torch caches are mounted from the host's home directory,
# and ./outputs is mounted at /outputs inside the container.
run-dev: build-dev-image ## Open a shell inside the imaginairy-dev docker image
	docker run -it -v $$HOME/.cache/huggingface:/root/.cache/huggingface -v $$HOME/.cache/torch:/root/.cache/torch -v `pwd`/outputs:/outputs imaginairy-dev /bin/bash
requirements: ## Freeze the requirements.txt file requirements: ## Freeze the requirements.txt file
pip-compile setup.py requirements-dev.in --output-file=requirements-dev.txt --upgrade pip-compile setup.py requirements-dev.in --output-file=requirements-dev.txt --upgrade

@ -117,7 +117,7 @@ from imaginairy import imagine, imagine_image_files, ImaginePrompt, WeightedProm
url = "https://upload.wikimedia.org/wikipedia/commons/thumb/6/6c/Thomas_Cole_-_Architect%E2%80%99s_Dream_-_Google_Art_Project.jpg/540px-Thomas_Cole_-_Architect%E2%80%99s_Dream_-_Google_Art_Project.jpg" url = "https://upload.wikimedia.org/wikipedia/commons/thumb/6/6c/Thomas_Cole_-_Architect%E2%80%99s_Dream_-_Google_Art_Project.jpg/540px-Thomas_Cole_-_Architect%E2%80%99s_Dream_-_Google_Art_Project.jpg"
prompts = [ prompts = [
ImaginePrompt("a scenic landscape", seed=1), ImaginePrompt("a scenic landscape", seed=1, upscale=True),
ImaginePrompt("a bowl of fruit"), ImaginePrompt("a bowl of fruit"),
ImaginePrompt([ ImaginePrompt([
WeightedPrompt("cat", weight=1), WeightedPrompt("cat", weight=1),
@ -133,7 +133,8 @@ prompts = [
mask_prompt="fruit|stems", mask_prompt="fruit|stems",
mask_mode="replace", mask_mode="replace",
mask_expansion=3 mask_expansion=3
) ),
ImaginePrompt("strawberries", tile_mode=True),
] ]
for result in imagine(prompts): for result in imagine(prompts):
# do something # do something
@ -162,8 +163,16 @@ docker build . -t imaginairy
docker run -it --gpus all -v $HOME/.cache/huggingface:/root/.cache/huggingface -v $HOME/.cache/torch:/root/.cache/torch -v `pwd`/outputs:/outputs imaginairy /bin/bash docker run -it --gpus all -v $HOME/.cache/huggingface:/root/.cache/huggingface -v $HOME/.cache/torch:/root/.cache/torch -v `pwd`/outputs:/outputs imaginairy /bin/bash
``` ```
## Running on Google Colab
[Example Colab](https://colab.research.google.com/drive/1rOvQNs0Cmn_yU1bKWjCOHzGVDgZkaTtO?usp=sharing)
## ChangeLog ## ChangeLog
**1.6.0**
- fix: *maybe* address #13 with `expected scalar type BFloat16 but found Float`
- at minimum one can specify `--precision full` now and that will probably fix the issue
- feature: tile mode can now be specified per-prompt
**1.5.3** **1.5.3**
- fix: missing config file for describe feature - fix: missing config file for describe feature

@ -1,7 +1,6 @@
import logging import logging
import os import os
import re import re
from contextlib import nullcontext
from functools import lru_cache from functools import lru_cache
import numpy as np import numpy as np
@ -11,7 +10,6 @@ from einops import rearrange
from omegaconf import OmegaConf from omegaconf import OmegaConf
from PIL import Image, ImageDraw, ImageFilter, ImageOps from PIL import Image, ImageDraw, ImageFilter, ImageOps
from pytorch_lightning import seed_everything from pytorch_lightning import seed_everything
from torch import autocast
from transformers import cached_path from transformers import cached_path
from imaginairy.enhancers.clip_masking import get_img_mask from imaginairy.enhancers.clip_masking import get_img_mask
@ -29,11 +27,13 @@ from imaginairy.samplers.base import get_sampler
from imaginairy.schema import ImaginePrompt, ImagineResult from imaginairy.schema import ImaginePrompt, ImagineResult
from imaginairy.utils import ( from imaginairy.utils import (
expand_mask, expand_mask,
fix_torch_group_norm,
fix_torch_nn_layer_norm, fix_torch_nn_layer_norm,
get_device, get_device,
instantiate_from_config, instantiate_from_config,
pillow_fit_image_within, pillow_fit_image_within,
pillow_img_to_torch_image, pillow_img_to_torch_image,
platform_appropriate_autocast,
) )
LIB_PATH = os.path.dirname(__file__) LIB_PATH = os.path.dirname(__file__)
@ -73,31 +73,11 @@ def load_model_from_config(config):
return model return model
def patch_conv(**patch):
    """
    Force extra keyword arguments onto every ``torch.nn.Conv2d`` built afterwards.

    Patch to enable tiling mode (e.g. ``padding_mode="circular"``). The
    replacement constructor is installed on the class itself and nothing in
    this function restores the original one.
    https://github.com/replicate/cog-stable-diffusion/compare/main...TomMoore515:material_stable_diffusion:main
    """
    conv_cls = torch.nn.Conv2d
    original_init = conv_cls.__init__

    def __init__(self, *args, **kwargs):
        # Caller kwargs and the patch are merged at call time; a key that
        # appears in both raises TypeError.
        return original_init(self, *args, **kwargs, **patch)

    conv_cls.__init__ = __init__
@lru_cache() @lru_cache()
def load_model(tile_mode=False): def load_model():
if tile_mode:
# generated images are tileable
patch_conv(padding_mode="circular")
config = "configs/stable-diffusion-v1.yaml" config = "configs/stable-diffusion-v1.yaml"
config = OmegaConf.load(f"{LIB_PATH}/{config}") config = OmegaConf.load(f"{LIB_PATH}/{config}")
model = load_model_from_config(config) model = load_model_from_config(config)
model = model.to(get_device()) model = model.to(get_device())
return model return model
@ -111,7 +91,6 @@ def imagine_image_files(
ddim_eta=0.0, ddim_eta=0.0,
record_step_images=False, record_step_images=False,
output_file_extension="jpg", output_file_extension="jpg",
tile_mode=False,
print_caption=False, print_caption=False,
): ):
big_path = os.path.join(outdir, "upscaled") big_path = os.path.join(outdir, "upscaled")
@ -139,7 +118,6 @@ def imagine_image_files(
precision=precision, precision=precision,
ddim_eta=ddim_eta, ddim_eta=ddim_eta,
img_callback=_record_step if record_step_images else None, img_callback=_record_step if record_step_images else None,
tile_mode=tile_mode,
add_caption=print_caption, add_caption=print_caption,
): ):
prompt = result.prompt prompt = result.prompt
@ -164,11 +142,10 @@ def imagine(
precision="autocast", precision="autocast",
ddim_eta=0.0, ddim_eta=0.0,
img_callback=None, img_callback=None,
tile_mode=False,
half_mode=None, half_mode=None,
add_caption=False, add_caption=False,
): ):
model = load_model(tile_mode=tile_mode) model = load_model()
# only run half-mode on cuda. run it by default # only run half-mode on cuda. run it by default
half_mode = half_mode is None and get_device() == "cuda" half_mode = half_mode is None and get_device() == "cuda"
@ -179,13 +156,12 @@ def imagine(
prompts = [ImaginePrompt(prompts)] if isinstance(prompts, str) else prompts prompts = [ImaginePrompt(prompts)] if isinstance(prompts, str) else prompts
prompts = [prompts] if isinstance(prompts, ImaginePrompt) else prompts prompts = [prompts] if isinstance(prompts, ImaginePrompt) else prompts
_img_callback = None _img_callback = None
if get_device() == "cpu":
logger.info("Running in CPU mode. it's gonna be slooooooow.")
precision_scope = ( with torch.no_grad(), platform_appropriate_autocast(
autocast precision
if precision == "autocast" and get_device() in ("cuda", "cpu") ), fix_torch_nn_layer_norm(), fix_torch_group_norm():
else nullcontext
)
with torch.no_grad(), precision_scope(get_device()), fix_torch_nn_layer_norm():
for prompt in prompts: for prompt in prompts:
with ImageLoggingContext( with ImageLoggingContext(
prompt=prompt, prompt=prompt,
@ -194,6 +170,7 @@ def imagine(
): ):
logger.info(f"Generating {prompt.prompt_description()}") logger.info(f"Generating {prompt.prompt_description()}")
seed_everything(prompt.seed) seed_everything(prompt.seed)
model.tile_mode(prompt.tile_mode)
uc = None uc = None
if prompt.prompt_strength != 1.0: if prompt.prompt_strength != 1.0:

@ -121,7 +121,7 @@ def configure_logging(level="INFO"):
@click.option( @click.option(
"--tile", "--tile",
is_flag=True, is_flag=True,
help="Any images rendered will be tileable. Unfortunately cannot be controlled at the per-image level yet", help="Any images rendered will be tileable.",
) )
@click.option( @click.option(
"--mask-image", "--mask-image",
@ -149,6 +149,12 @@ def configure_logging(level="INFO"):
is_flag=True, is_flag=True,
help="Generate a text description of the generated image", help="Generate a text description of the generated image",
) )
@click.option(
"--precision",
help="evaluate at this precision",
type=click.Choice(["full", "autocast"]),
default="autocast",
)
@click.pass_context @click.pass_context
def imagine_cmd( def imagine_cmd(
ctx, ctx,
@ -174,6 +180,7 @@ def imagine_cmd(
mask_mode, mask_mode,
mask_expansion, mask_expansion,
caption, caption,
precision,
): ):
"""Have the AI generate images. alias:imagine""" """Have the AI generate images. alias:imagine"""
if ctx.invoked_subcommand is not None: if ctx.invoked_subcommand is not None:
@ -190,7 +197,7 @@ def imagine_cmd(
init_image = LazyLoadingImage(url=init_image) init_image = LazyLoadingImage(url=init_image)
prompts = [] prompts = []
load_model(tile_mode=tile) load_model()
for _ in range(repeats): for _ in range(repeats):
for prompt_text in prompt_texts: for prompt_text in prompt_texts:
prompt = ImaginePrompt( prompt = ImaginePrompt(
@ -209,6 +216,7 @@ def imagine_cmd(
mask_mode=mask_mode, mask_mode=mask_mode,
upscale=upscale, upscale=upscale,
fix_faces=fix_faces, fix_faces=fix_faces,
tile_mode=tile,
) )
prompts.append(prompt) prompts.append(prompt)
@ -217,9 +225,9 @@ def imagine_cmd(
outdir=outdir, outdir=outdir,
ddim_eta=ddim_eta, ddim_eta=ddim_eta,
record_step_images="images" in show_work, record_step_images="images" in show_work,
tile_mode=tile,
output_file_extension="png", output_file_extension="png",
print_caption=caption, print_caption=caption,
precision=precision,
) )

@ -273,6 +273,18 @@ class LatentDiffusion(DDPM):
self.init_from_ckpt(ckpt_path, ignore_keys) self.init_from_ckpt(ckpt_path, ignore_keys)
self.restarted_from_ckpt = True self.restarted_from_ckpt = True
# store initial padding mode so we can switch to 'circular'
# when we want tiled images
for m in self.modules():
if isinstance(m, nn.Conv2d):
m._initial_padding_mode = m.padding_mode
def tile_mode(self, enabled):
    """
    Switch every Conv2d in the model between tiling and normal padding.

    When `enabled` is truthy each ``nn.Conv2d`` gets ``padding_mode="circular"``
    (for creating seamless tiles); otherwise it gets back the padding mode
    remembered in ``_initial_padding_mode``.
    """
    for m in self.modules():
        if not isinstance(m, nn.Conv2d):
            continue
        # A layer that never had its starting mode recorded would raise
        # AttributeError on restore, so remember the current mode first.
        if not hasattr(m, "_initial_padding_mode"):
            m._initial_padding_mode = m.padding_mode
        m.padding_mode = "circular" if enabled else m._initial_padding_mode
def make_cond_schedule( def make_cond_schedule(
self, self,
): ):

@ -24,23 +24,24 @@ class DiagonalGaussianDistribution:
def kl(self, other=None): def kl(self, other=None):
if self.deterministic: if self.deterministic:
return torch.Tensor([0.0]) return torch.Tensor([0.0])
else:
if other is None: if other is None:
return 0.5 * torch.sum( return 0.5 * torch.sum(
torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
dim=[1, 2, 3], dim=[1, 2, 3],
) )
else:
return 0.5 * torch.sum( return 0.5 * torch.sum(
torch.pow(self.mean - other.mean, 2) / other.var torch.pow(self.mean - other.mean, 2) / other.var
+ self.var / other.var + self.var / other.var
- 1.0 - 1.0
- self.logvar - self.logvar
+ other.logvar, + other.logvar,
dim=[1, 2, 3], dim=[1, 2, 3],
) )
def nll(self, sample, dims=[1, 2, 3]): def nll(self, sample, dims=None):
# Use the default reduction axes only when the caller passed none.
dims = [1, 2, 3] if dims is None else dims
if self.deterministic: if self.deterministic:
return torch.Tensor([0.0]) return torch.Tensor([0.0])
logtwopi = np.log(2.0 * np.pi) logtwopi = np.log(2.0 * np.pi)

@ -103,6 +103,7 @@ class ImaginePrompt:
fix_faces=False, fix_faces=False,
sampler_type="PLMS", sampler_type="PLMS",
conditioning=None, conditioning=None,
tile_mode=False,
): ):
prompt = prompt if prompt is not None else "a scenic landscape" prompt = prompt if prompt is not None else "a scenic landscape"
if isinstance(prompt, str): if isinstance(prompt, str):
@ -131,6 +132,7 @@ class ImaginePrompt:
self.mask_image = mask_image self.mask_image = mask_image
self.mask_mode = mask_mode self.mask_mode = mask_mode
self.mask_expansion = mask_expansion self.mask_expansion = mask_expansion
self.tile_mode = tile_mode
@property @property
def prompt_text(self): def prompt_text(self):

@ -2,7 +2,7 @@ import importlib
import logging import logging
import os.path import os.path
import platform import platform
from contextlib import contextmanager from contextlib import contextmanager, nullcontext
from functools import lru_cache from functools import lru_cache
from typing import List, Optional from typing import List, Optional
@ -10,7 +10,7 @@ import numpy as np
import requests import requests
import torch import torch
from PIL import Image, ImageFilter from PIL import Image, ImageFilter
from torch import Tensor from torch import Tensor, autocast
from torch.nn import functional from torch.nn import functional
from torch.overrides import handle_torch_function, has_torch_function_variadic from torch.overrides import handle_torch_function, has_torch_function_variadic
from transformers import cached_path from transformers import cached_path
@ -61,6 +61,18 @@ def get_obj_from_str(string, reload=False):
return getattr(importlib.import_module(module, package=None), cls) return getattr(importlib.import_module(module, package=None), cls)
@contextmanager
def platform_appropriate_autocast(precision="autocast"):
    """
    Run the wrapped block under torch autocast when it applies.

    Mixed precision can be faster. Autocast is entered only when `precision`
    is "autocast" and the device reported by `get_device()` is "cuda" or
    "cpu"; in every other case the block runs inside a no-op context.
    """
    wants_autocast = precision == "autocast" and get_device() in ("cuda", "cpu")
    scope_factory = autocast if wants_autocast else nullcontext
    with scope_factory(get_device()):
        yield
def _fixed_layer_norm( def _fixed_layer_norm(
input: Tensor, # noqa input: Tensor, # noqa
normalized_shape: List[int], normalized_shape: List[int],
@ -104,6 +116,43 @@ def fix_torch_nn_layer_norm():
functional.layer_norm = orig_function functional.layer_norm = orig_function
@contextmanager
def fix_torch_group_norm():
    """
    Temporarily swap ``functional.group_norm`` for a dtype-aligning wrapper.

    While the context is active, `weight` and `bias` are cast to the dtype of
    the input before the original function is called. The original function
    is put back on exit, including when the body raises.

    Other repos appear to switch to full precision rather than address this,
    which may be slower.

    https://github.com/pytorch/pytorch/pull/81852
    """
    original = functional.group_norm

    def _group_norm_wrapper(
        input: Tensor,  # noqa
        num_groups: int,
        weight: Optional[Tensor] = None,
        bias: Optional[Tensor] = None,
        eps: float = 1e-5,
    ) -> Tensor:
        def _aligned(param):
            # Leave missing or already-matching parameters untouched.
            if param is None or param.dtype == input.dtype:
                return param
            return param.to(input.dtype)

        return original(
            input=input,
            num_groups=num_groups,
            weight=_aligned(weight),
            bias=_aligned(bias),
            eps=eps,
        )

    functional.group_norm = _group_norm_wrapper
    try:
        yield
    finally:
        functional.group_norm = original
def expand_mask(mask_image, size): def expand_mask(mask_image, size):
if size < 0: if size < 0:
threshold = 0.95 threshold = 0.95

@ -6,3 +6,4 @@ pydocstyle
pylama pylama
pylint pylint
pytest pytest
pytest-randomly

@ -10,7 +10,7 @@ absl-py==1.2.0
# tensorboard # tensorboard
addict==2.4.0 addict==2.4.0
# via basicsr # via basicsr
aiohttp==3.8.1 aiohttp==3.8.3
# via fsspec # via fsspec
aiosignal==1.2.0 aiosignal==1.2.0
# via aiohttp # via aiohttp
@ -68,7 +68,7 @@ filelock==3.8.0
# transformers # transformers
filterpy==1.4.5 filterpy==1.4.5
# via facexlib # via facexlib
fonttools==4.37.2 fonttools==4.37.3
# via matplotlib # via matplotlib
frozenlist==1.3.1 frozenlist==1.3.1
# via # via
@ -86,7 +86,7 @@ gfpgan==1.3.8
# via # via
# imaginAIry (setup.py) # imaginAIry (setup.py)
# realesrgan # realesrgan
google-auth==2.11.0 google-auth==2.11.1
# via # via
# google-auth-oauthlib # google-auth-oauthlib
# tb-nightly # tb-nightly
@ -95,7 +95,7 @@ google-auth-oauthlib==0.4.6
# via # via
# tb-nightly # tb-nightly
# tensorboard # tensorboard
grpcio==1.49.0 grpcio==1.48.1
# via # via
# tb-nightly # tb-nightly
# tensorboard # tensorboard
@ -212,6 +212,7 @@ pillow==9.2.0
# diffusers # diffusers
# facexlib # facexlib
# imageio # imageio
# imaginAIry (setup.py)
# matplotlib # matplotlib
# realesrgan # realesrgan
# scikit-image # scikit-image
@ -249,13 +250,17 @@ pyflakes==2.5.0
# via pylama # via pylama
pylama==8.4.1 pylama==8.4.1
# via -r requirements-dev.in # via -r requirements-dev.in
pylint==2.15.2 pylint==2.15.3
# via -r requirements-dev.in # via -r requirements-dev.in
pyparsing==3.0.9 pyparsing==3.0.9
# via # via
# matplotlib # matplotlib
# packaging # packaging
pytest==7.1.3 pytest==7.1.3
# via
# -r requirements-dev.in
# pytest-randomly
pytest-randomly==3.12.0
# via -r requirements-dev.in # via -r requirements-dev.in
python-dateutil==2.8.2 python-dateutil==2.8.2
# via matplotlib # via matplotlib
@ -273,7 +278,7 @@ pyyaml==6.0
# pycln # pycln
# pytorch-lightning # pytorch-lightning
# transformers # transformers
realesrgan==0.2.8 realesrgan==0.3.0
# via imaginAIry (setup.py) # via imaginAIry (setup.py)
regex==2022.9.13 regex==2022.9.13
# via # via
@ -311,7 +316,7 @@ six==1.16.0
# python-dateutil # python-dateutil
snowballstemmer==2.2.0 snowballstemmer==2.2.0
# via pydocstyle # via pydocstyle
tb-nightly==2.11.0a20220918 tb-nightly==2.11.0a20220921
# via # via
# basicsr # basicsr
# gfpgan # gfpgan

@ -7,7 +7,7 @@ setup(
name="imaginAIry", name="imaginAIry",
author="Bryce Drennan", author="Bryce Drennan",
# author_email="b r y p y d o t io", # author_email="b r y p y d o t io",
version="1.5.4", version="1.6.0",
description="AI imagined images. Pythonic generation of stable diffusion images.", description="AI imagined images. Pythonic generation of stable diffusion images.",
long_description=readme, long_description=readme,
long_description_content_type="text/markdown", long_description_content_type="text/markdown",

@ -0,0 +1,36 @@
# Development/test image: one stage builds the imaginairy wheel, another
# installs it next to the dev requirements and copies in the tests.
FROM python:3.10.6-slim AS base

# Remove the apt package lists in the same layer so they are not kept in the image.
RUN apt-get update \
    && apt-get install -y libgl1 libglib2.0-0 make \
    && rm -rf /var/lib/apt/lists/*

ENV PIP_DISABLE_PIP_VERSION_CHECK=1 \
    PIP_ROOT_USER_ACTION=ignore

# Stage: build the wheel from the package sources.
FROM base AS build_wheel

RUN pip install wheel

WORKDIR /app

COPY imaginairy ./imaginairy
COPY setup.py README.md ./

RUN python setup.py bdist_wheel

# Stage: install dev requirements plus the freshly built wheel.
FROM base AS install_wheel

WORKDIR /app

COPY requirements-dev.in ./

RUN pip install -r requirements-dev.in

COPY --from=build_wheel /app/dist/* ./

RUN pip install *.whl
# Fail the build early if the installed CLI cannot start.
RUN imagine --help

COPY Makefile ./

COPY tests ./tests

@ -4,7 +4,11 @@ import pytest
from imaginairy import api from imaginairy import api
from imaginairy.suppress_logs import suppress_annoying_logs_and_warnings from imaginairy.suppress_logs import suppress_annoying_logs_and_warnings
from imaginairy.utils import fix_torch_nn_layer_norm from imaginairy.utils import (
fix_torch_group_norm,
fix_torch_nn_layer_norm,
platform_appropriate_autocast,
)
if "pytest" in str(sys.argv): if "pytest" in str(sys.argv):
suppress_annoying_logs_and_warnings() suppress_annoying_logs_and_warnings()
@ -13,5 +17,6 @@ if "pytest" in str(sys.argv):
@pytest.fixture(scope="session", autouse=True) @pytest.fixture(scope="session", autouse=True)
def pre_setup(): def pre_setup():
api.IMAGINAIRY_SAFETY_MODE = "disabled" api.IMAGINAIRY_SAFETY_MODE = "disabled"
with fix_torch_nn_layer_norm(): suppress_annoying_logs_and_warnings()
with fix_torch_nn_layer_norm(), fix_torch_group_norm(), platform_appropriate_autocast():
yield yield

@ -17,4 +17,4 @@ def test_text_conditioning():
if "mps" in get_device(): if "mps" in get_device():
assert hashed == "263e5ee7d2be087d816e094b80ffc546" assert hashed == "263e5ee7d2be087d816e094b80ffc546"
elif "cuda" in get_device(): elif "cuda" in get_device():
assert hashed == "3d7867d5b2ebf15102a9ca9476d63ebc" assert hashed == "41818051d7c469fc57d0a940c9d24d82"

@ -1,9 +1,12 @@
import pytest
from click.testing import CliRunner from click.testing import CliRunner
from imaginairy.cmds import imagine_cmd from imaginairy.cmds import imagine_cmd
from imaginairy.utils import get_device
from tests import TESTS_FOLDER from tests import TESTS_FOLDER
@pytest.mark.skipif(get_device() == "cpu", reason="Too slow to run on CPU")
def test_imagine_cmd(): def test_imagine_cmd():
runner = CliRunner() runner = CliRunner()
result = runner.invoke( result = runner.invoke(

@ -20,7 +20,7 @@ def test_fix_faces():
if "mps" in get_device(): if "mps" in get_device():
assert img_hash(img) == "a75991307eda675a26eeb7073f828e93" assert img_hash(img) == "a75991307eda675a26eeb7073f828e93"
else: else:
assert img_hash(img) == "5aa847a1464de75b158658a35800b6bf" assert img_hash(img) == "e56c1205bbc8f251be05773f2ba7fa24"
def img_hash(img): def img_hash(img):

@ -8,7 +8,7 @@ from imaginairy.utils import get_device
from . import TESTS_FOLDER from . import TESTS_FOLDER
device_sampler_type_test_cases = { device_sampler_type_test_cases = {
"mps:0": { "mps:0": [
("plms", "b4b434ed45919f3505ac2be162791c71"), ("plms", "b4b434ed45919f3505ac2be162791c71"),
("ddim", "b369032a025915c0a7ccced165a609b3"), ("ddim", "b369032a025915c0a7ccced165a609b3"),
("k_lms", "b87325c189799d646ccd07b331564eb6"), ("k_lms", "b87325c189799d646ccd07b331564eb6"),
@ -17,8 +17,8 @@ device_sampler_type_test_cases = {
("k_euler", "d126da5ca8b08099cde8b5037464e788"), ("k_euler", "d126da5ca8b08099cde8b5037464e788"),
("k_euler_a", "cac5ca2e26c31a544b76a9442eb2ea37"), ("k_euler_a", "cac5ca2e26c31a544b76a9442eb2ea37"),
("k_heun", "0382ef71d9967fefd15676410289ebab"), ("k_heun", "0382ef71d9967fefd15676410289ebab"),
}, ],
"cuda": { "cuda": [
("plms", "62e78287e7848e48d45a1b207fb84102"), ("plms", "62e78287e7848e48d45a1b207fb84102"),
("ddim", "164c2a008b100e5fa07d3db2018605bd"), ("ddim", "164c2a008b100e5fa07d3db2018605bd"),
("k_lms", "450fea507ccfb44b677d30fae9f40a52"), ("k_lms", "450fea507ccfb44b677d30fae9f40a52"),
@ -27,7 +27,8 @@ device_sampler_type_test_cases = {
("k_euler", "06df9c19d472bfa6530db98be4ea10e8"), ("k_euler", "06df9c19d472bfa6530db98be4ea10e8"),
("k_euler_a", "79552628ff77914c8b6870703fe116b5"), ("k_euler_a", "79552628ff77914c8b6870703fe116b5"),
("k_heun", "8ced3578ae25d34da9f4e4b1a20bf416"), ("k_heun", "8ced3578ae25d34da9f4e4b1a20bf416"),
}, ],
"cpu": [],
} }
sampler_type_test_cases = device_sampler_type_test_cases[get_device()] sampler_type_test_cases = device_sampler_type_test_cases[get_device()]
@ -54,12 +55,14 @@ device_sampler_type_test_cases_img_2_img = {
("plms", "efba8b836b51d262dbf72284844869f8"), ("plms", "efba8b836b51d262dbf72284844869f8"),
("ddim", "a62878000ad3b581a11dd3fb329dc7d2"), ("ddim", "a62878000ad3b581a11dd3fb329dc7d2"),
}, },
"cpu": [],
} }
sampler_type_test_cases_img_2_img = device_sampler_type_test_cases_img_2_img[ sampler_type_test_cases_img_2_img = device_sampler_type_test_cases_img_2_img[
get_device() get_device()
] ]
@pytest.mark.skipif(get_device() == "cpu", reason="Too slow to run on CPU")
@pytest.mark.parametrize("sampler_type,expected_md5", sampler_type_test_cases_img_2_img) @pytest.mark.parametrize("sampler_type,expected_md5", sampler_type_test_cases_img_2_img)
def test_img_to_img(sampler_type, expected_md5): def test_img_to_img(sampler_type, expected_md5):
prompt = ImaginePrompt( prompt = ImaginePrompt(
@ -79,6 +82,7 @@ def test_img_to_img(sampler_type, expected_md5):
assert result.md5() == expected_md5 assert result.md5() == expected_md5
@pytest.mark.skipif(get_device() == "cpu", reason="Too slow to run on CPU")
def test_img_to_img_from_url(): def test_img_to_img_from_url():
prompt = ImaginePrompt( prompt = ImaginePrompt(
"dogs lying on a hot pink couch", "dogs lying on a hot pink couch",
@ -96,6 +100,7 @@ def test_img_to_img_from_url():
imagine_image_files(prompt, outdir=out_folder) imagine_image_files(prompt, outdir=out_folder)
@pytest.mark.skipif(get_device() == "cpu", reason="Too slow to run on CPU")
def test_img_to_file(): def test_img_to_file():
prompt = ImaginePrompt( prompt = ImaginePrompt(
"an old growth forest, diffuse light poking through the canopy. high-resolution, nature photography, nat geo photo", "an old growth forest, diffuse light poking through the canopy. high-resolution, nature photography, nat geo photo",
@ -110,6 +115,7 @@ def test_img_to_file():
imagine_image_files(prompt, outdir=out_folder) imagine_image_files(prompt, outdir=out_folder)
@pytest.mark.skipif(get_device() == "cpu", reason="Too slow to run on CPU")
def test_inpainting(): def test_inpainting():
prompt = ImaginePrompt( prompt = ImaginePrompt(
"a basketball on a bench", "a basketball on a bench",
@ -126,6 +132,7 @@ def test_inpainting():
imagine_image_files(prompt, outdir=out_folder) imagine_image_files(prompt, outdir=out_folder)
@pytest.mark.skipif(get_device() == "cpu", reason="Too slow to run on CPU")
def test_cliptext_inpainting(): def test_cliptext_inpainting():
prompts = [ prompts = [
ImaginePrompt( ImaginePrompt(

@ -18,6 +18,7 @@ def test_is_nsfw():
def _pil_to_latent(img): def _pil_to_latent(img):
model = load_model() model = load_model()
model.tile_mode(False)
img = pillow_img_to_torch_image(img) img = pillow_img_to_torch_image(img)
img = img.to(get_device()) img = img.to(get_device())
latent = model.get_first_stage_encoding(model.encode_first_stage(img)) latent = model.get_first_stage_encoding(model.encode_first_stage(img))

Loading…
Cancel
Save