mirror of
https://github.com/brycedrennan/imaginAIry
synced 2024-10-31 03:20:40 +00:00
fix: inpainting producing blurry images
While the previous version did produce much better blending, it also made images that lack detail for some reason. tests: Added more tests to help catch this sort of thing earlier. fix: found that median blur is really slow, so I made sure we only do it on downsampled masks. It was taking about 3 minutes to run on the large pearl girl picture on M1. docs: update examples
This commit is contained in:
parent
0fb03f2a1f
commit
95a8fa31a9
3
.gitignore
vendored
3
.gitignore
vendored
@ -16,4 +16,5 @@ dist
|
||||
**/*.egg-info
|
||||
tests/test_output
|
||||
gfpgan/**
|
||||
.python-version
|
||||
.python-version
|
||||
._.DS_Store
|
16
README.md
16
README.md
@ -43,16 +43,18 @@ operators also work. When writing strength modifiers, know that pixel values are
|
||||
```bash
|
||||
>> imagine \
|
||||
--init-image pearl_earring.jpg \
|
||||
--mask-prompt "face{*1.9}" \
|
||||
--mask-prompt "face AND NOT (bandana OR hair OR blue fabric){*6}" \
|
||||
--mask-mode keep \
|
||||
--init-image-strength .4 \
|
||||
"a female doctor" "an elegant woman"
|
||||
--init-image-strength .2 \
|
||||
--fix-faces \
|
||||
"a modern female president" "a female robot" "a female doctor" "a female firefighter"
|
||||
```
|
||||
<img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/mask_examples/pearl000.jpg" height="200">➡️
|
||||
<img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/mask_examples/pearl002.jpg" height="200">
|
||||
<img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/mask_examples/pearl004.jpg" height="200">
|
||||
<img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/mask_examples/pearl001.jpg" height="200">
|
||||
<img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/mask_examples/pearl003.jpg" height="200">
|
||||
<img src="assets/mask_examples/pearl_pres.png" height="200">
|
||||
<img src="assets/mask_examples/pearl_robot.png" height="200">
|
||||
<img src="assets/mask_examples/pearl_doctor.png" height="200">
|
||||
<img src="assets/mask_examples/pearl_firefighter.png" height="200">
|
||||
|
||||
```bash
|
||||
>> imagine \
|
||||
--init-image fruit-bowl.jpg \
|
||||
|
BIN
assets/mask_examples/pearl_doctor.png
Normal file
BIN
assets/mask_examples/pearl_doctor.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 311 KiB |
BIN
assets/mask_examples/pearl_firefighter.png
Normal file
BIN
assets/mask_examples/pearl_firefighter.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 327 KiB |
BIN
assets/mask_examples/pearl_pres.png
Normal file
BIN
assets/mask_examples/pearl_pres.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 309 KiB |
BIN
assets/mask_examples/pearl_robot.png
Normal file
BIN
assets/mask_examples/pearl_robot.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 302 KiB |
@ -120,7 +120,10 @@ def imagine_image_files(
|
||||
add_caption=print_caption,
|
||||
):
|
||||
prompt = result.prompt
|
||||
basefilename = f"{base_count:06}_{prompt.seed}_{prompt.sampler_type}{prompt.steps}_PS{prompt.prompt_strength}_{prompt_normalized(prompt.prompt_text)}"
|
||||
img_str = ""
|
||||
if prompt.init_image:
|
||||
img_str = f"_img2img-{prompt.init_image_strength}"
|
||||
basefilename = f"{base_count:06}_{prompt.seed}_{prompt.sampler_type}{prompt.steps}_PS{prompt.prompt_strength}{img_str}_{prompt_normalized(prompt.prompt_text)}"
|
||||
|
||||
for image_type in result.images:
|
||||
subpath = os.path.join(outdir, image_type)
|
||||
@ -261,6 +264,7 @@ def imagine(
|
||||
|
||||
log_latent(init_latent, "init_latent")
|
||||
# encode (scaled latent)
|
||||
seed_everything(prompt.seed)
|
||||
noise = torch.randn_like(init_latent, device="cpu").to(get_device())
|
||||
z_enc = sampler.stochastic_encode(
|
||||
init_latent,
|
||||
|
@ -244,7 +244,7 @@ def imagine_cmd(
|
||||
outdir=outdir,
|
||||
ddim_eta=ddim_eta,
|
||||
record_step_images=show_work,
|
||||
output_file_extension="png",
|
||||
output_file_extension="jpg",
|
||||
print_caption=caption,
|
||||
precision=precision,
|
||||
)
|
||||
|
@ -9,6 +9,7 @@ from kornia.filters import median_blur
|
||||
from torchvision import transforms
|
||||
|
||||
from imaginairy.img_log import log_img
|
||||
from imaginairy.img_utils import pillow_fit_image_within
|
||||
from imaginairy.vendored.clipseg import CLIPDensePredT
|
||||
|
||||
weights_url = "https://github.com/timojl/clipseg/raw/master/weights/rd64-uni.pth"
|
||||
@ -41,6 +42,8 @@ def get_img_mask(
|
||||
parsed = MASK_PROMPT.parseString(mask_description_statement)
|
||||
parsed_mask = parsed[0][0]
|
||||
descriptions = list(parsed_mask.gather_text_descriptions())
|
||||
orig_size = img.size
|
||||
img = pillow_fit_image_within(img, max_height=352, max_width=352)
|
||||
mask_cache = get_img_masks(img, descriptions)
|
||||
mask = parsed_mask.apply_masks(mask_cache)
|
||||
log_img(mask, "combined mask")
|
||||
@ -49,7 +52,7 @@ def get_img_mask(
|
||||
mask = median_blur(mask.unsqueeze(dim=0).unsqueeze(dim=0), (11, 11)).squeeze()
|
||||
log_img(mask, "median blurred")
|
||||
|
||||
kernel = np.ones((5, 5), np.uint8)
|
||||
kernel = np.ones((3, 3), np.uint8)
|
||||
mask_g = mask.clone()
|
||||
|
||||
# trial and error shows 0.5 threshold has the best "shape"
|
||||
@ -59,7 +62,7 @@ def get_img_mask(
|
||||
log_img(mask, f"mask threshold {0.5}")
|
||||
|
||||
mask_np = mask.cpu().numpy()
|
||||
smoother_strength = 5
|
||||
smoother_strength = 2
|
||||
# grow the mask area to make sure we've masked the thing we care about
|
||||
for _ in range(smoother_strength):
|
||||
mask_np = cv2.dilate(mask_np, kernel)
|
||||
@ -67,7 +70,13 @@ def get_img_mask(
|
||||
mask = torch.from_numpy(mask_np)
|
||||
log_img(mask, "mask after closing (dilation then erosion)")
|
||||
|
||||
return transforms.ToPILImage()(mask), transforms.ToPILImage()(mask_g)
|
||||
mask_img = transforms.ToPILImage()(mask).resize(
|
||||
orig_size, resample=PIL.Image.Resampling.LANCZOS
|
||||
)
|
||||
mask_img_g = transforms.ToPILImage()(mask_g).resize(
|
||||
orig_size, resample=PIL.Image.Resampling.LANCZOS
|
||||
)
|
||||
return mask_img, mask_img_g
|
||||
|
||||
|
||||
def get_img_masks(img, mask_descriptions: Sequence[str]):
|
||||
|
@ -361,9 +361,12 @@ class DDIMSampler:
|
||||
log_latent(xdec_orig, "xdec_orig")
|
||||
# this helps prevent the weird disjointed images that can happen with masking
|
||||
hint_strength = 0.8
|
||||
xdec_orig_with_hints = (
|
||||
xdec_orig * (1 - hint_strength) + orig_latent * hint_strength
|
||||
)
|
||||
if i < 2:
|
||||
xdec_orig_with_hints = (
|
||||
xdec_orig * (1 - hint_strength) + orig_latent * hint_strength
|
||||
)
|
||||
else:
|
||||
xdec_orig_with_hints = xdec_orig
|
||||
x_dec = xdec_orig_with_hints * mask + (1.0 - mask) * x_dec
|
||||
log_latent(x_dec, "x_dec")
|
||||
|
||||
|
@ -423,14 +423,17 @@ class PLMSSampler:
|
||||
if mask is not None:
|
||||
assert orig_latent is not None
|
||||
xdec_orig = self.model.q_sample(orig_latent, ts, noise)
|
||||
log_latent(xdec_orig, "xdec_orig")
|
||||
log_latent(xdec_orig, f"xdec_orig i={i} index-{index}")
|
||||
# this helps prevent the weird disjointed images that can happen with masking
|
||||
hint_strength = 0.8
|
||||
xdec_orig_with_hints = (
|
||||
xdec_orig * (1 - hint_strength) + orig_latent * hint_strength
|
||||
)
|
||||
if i < 2:
|
||||
xdec_orig_with_hints = (
|
||||
xdec_orig * (1 - hint_strength) + orig_latent * hint_strength
|
||||
)
|
||||
else:
|
||||
xdec_orig_with_hints = xdec_orig
|
||||
x_dec = xdec_orig_with_hints * mask + (1.0 - mask) * x_dec
|
||||
log_latent(x_dec, "x_dec")
|
||||
log_latent(x_dec, f"x_dec {ts}")
|
||||
|
||||
x_dec, pred_x0, e_t = self.p_sample_plms(
|
||||
x_dec,
|
||||
|
@ -1,3 +1,3 @@
|
||||
import os.path
|
||||
|
||||
TESTS_FOLDER = os.path.dirname(__file__)
|
||||
TESTS_FOLDER = os.path.abspath(os.path.dirname(__file__))
|
||||
|
@ -10,6 +10,7 @@ from imaginairy.suppress_logs import suppress_annoying_logs_and_warnings
|
||||
from imaginairy.utils import (
|
||||
fix_torch_group_norm,
|
||||
fix_torch_nn_layer_norm,
|
||||
get_device,
|
||||
platform_appropriate_autocast,
|
||||
)
|
||||
from tests import TESTS_FOLDER
|
||||
@ -24,6 +25,10 @@ logger = logging.getLogger(__name__)
|
||||
def pre_setup():
|
||||
api.IMAGINAIRY_SAFETY_MODE = "disabled"
|
||||
suppress_annoying_logs_and_warnings()
|
||||
# test_output_folder = f"{TESTS_FOLDER}/test_output"
|
||||
|
||||
# delete the testoutput folder and recreate it
|
||||
# rmtree(test_output_folder)
|
||||
os.makedirs(f"{TESTS_FOLDER}/test_output", exist_ok=True)
|
||||
|
||||
orig_urlopen = HTTPConnectionPool.urlopen
|
||||
@ -32,9 +37,17 @@ def pre_setup():
|
||||
# traceback.print_stack()
|
||||
print(os.environ.get("PYTEST_CURRENT_TEST"))
|
||||
print(f"{method} {self.host}{url}")
|
||||
return orig_urlopen(self, method, url, *args, **kwargs)
|
||||
result = orig_urlopen(self, method, url, *args, **kwargs)
|
||||
print(f"{method} {self.host}{url} DONE")
|
||||
return result
|
||||
|
||||
HTTPConnectionPool.urlopen = urlopen_tattle
|
||||
|
||||
with fix_torch_nn_layer_norm(), fix_torch_group_norm(), platform_appropriate_autocast():
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture()
def filename_base_for_outputs(request):
    """Return a path prefix for the current test's output files.

    The prefix lives in the ``test_output`` folder and is built from the
    requesting test's node name and the active device string, so each test
    (and each parametrization of it) gets its own set of file names.
    """
    return f"{TESTS_FOLDER}/test_output/{request.node.name}_{get_device()}_"
|
||||
|
BIN
tests/data/bowl_of_fruit.jpg
Normal file
BIN
tests/data/bowl_of_fruit.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 26 KiB |
268
tests/test_api.py
Normal file
268
tests/test_api.py
Normal file
@ -0,0 +1,268 @@
|
||||
import os.path
|
||||
|
||||
import pytest
|
||||
|
||||
from imaginairy import LazyLoadingImage
|
||||
from imaginairy.api import imagine, imagine_image_files, prompt_normalized
|
||||
from imaginairy.img_utils import pillow_fit_image_within
|
||||
from imaginairy.schema import ImaginePrompt
|
||||
from imaginairy.utils import get_device
|
||||
|
||||
from . import TESTS_FOLDER
|
||||
|
||||
# (sampler_type, expected result.md5()) pairs used to parametrize test_imagine,
# grouped by the device string that get_device() returns.
_MPS_SAMPLER_CASES = [
    ("plms", "78539ae3a3097dc8232da6d630551ab3"),
    ("ddim", "828fc143cd40586347b2f8403c288c9b"),
    ("k_lms", "53d25e59add39c8447537be30e4eff4b"),
    ("k_dpm_2", "5108bceb58a38d88a585f37b2ba1b072"),
    ("k_dpm_2_a", "20396daa6c920d1cfd6db90e73558c01"),
    ("k_euler", "9ab4666ebe6c3aa68673912bb17fb2b1"),
    ("k_euler_a", "c4b03829cc93422801f3243a46bad4bc"),
    ("k_heun", "0d3aad6800d4a9a43f0b0514af9d23b5"),
]
_CUDA_SAMPLER_CASES = [
    ("plms", "b98e1248ad1f144d34122d8809b39fb8"),
    ("ddim", "a645ca24575ed3f18bf48f11354233bb"),
    ("k_lms", "3ddbdef45e3f38768730961771d01727"),
    ("k_dpm_2", "b6e88e16ec2c43e6382b1adec828479d"),
    ("k_dpm_2_a", "b0791770d48cb22d308ad76c72fb660f"),
    ("k_euler", "bcf375769d64d9ca224864d35565ac1d"),
    ("k_euler_a", "38b970ff6a67428efbf00df66a9e48f7"),
    ("k_heun", "ccbd0804c7ce2bb637c682951bd8b693"),
]
device_sampler_type_test_cases = {
    "mps:0": _MPS_SAMPLER_CASES,
    "cuda": _CUDA_SAMPLER_CASES,
    # no cases on CPU: test_imagine is parametrized with an empty list there
    "cpu": [],
}
# Resolved once at import time for the device this test session runs on.
sampler_type_test_cases = device_sampler_type_test_cases[get_device()]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("sampler_type,expected_md5", sampler_type_test_cases)
def test_imagine(sampler_type, expected_md5, filename_base_for_outputs):
    """Generate a seeded landscape with one sampler and compare its md5.

    The image is also written next to the other test outputs so a failing
    hash can be inspected visually.
    """
    landscape_prompt = ImaginePrompt(
        "a scenic landscape",
        width=512,
        height=256,
        steps=20,
        seed=1,
        sampler_type=sampler_type,
    )
    generated = next(imagine(landscape_prompt))
    generated.img.save(f"{filename_base_for_outputs}.jpg")
    assert generated.md5() == expected_md5
|
||||
|
||||
|
||||
# (sampler_type, expected result.md5()) pairs used to parametrize
# test_img2img_beach_to_sunset, keyed by the device string from get_device().
#
# The cases are lists rather than sets so the parametrization order is the
# same on every run. The empty "cpu" entry matters: the lookup below executes
# at import time, before any skipif marker is evaluated, so a missing key
# would abort collection of this whole module on CPU-only machines.
device_sampler_type_test_cases_img_2_img = {
    "mps:0": [
        ("plms", "0d9c40c348cdac7bdc8d5a472f378f42"),
        ("ddim", "0d9c40c348cdac7bdc8d5a472f378f42"),
    ],
    "cuda": [
        ("plms", "841723966344dd8678aee1ce5f9cbb3d"),
        ("ddim", "1f0d72370fabcf2ff716e4068d5b2360"),
    ],
    "cpu": [],
}
sampler_type_test_cases_img_2_img = device_sampler_type_test_cases_img_2_img[
    get_device()
]
|
||||
|
||||
|
||||
@pytest.mark.skipif(get_device() == "cpu", reason="Too slow to run on CPU")
@pytest.mark.parametrize("sampler_type,expected_md5", sampler_type_test_cases_img_2_img)
def test_img2img_beach_to_sunset(sampler_type, expected_md5, filename_base_for_outputs):
    """Repaint the masked sky of the beach painting as a cloudy sunset.

    Writes the fitted source image and the generated image to the test output
    folder. ``expected_md5`` is supplied by the parametrization but is not
    asserted here; the test only checks that generation completes.
    """
    source = LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/beach_at_sainte_adresse.jpg")
    sunset_prompt = ImaginePrompt(
        "a painting of beautiful cloudy sunset at the beach",
        init_image=source,
        init_image_strength=0.5,
        prompt_strength=15,
        mask_prompt="(sky|clouds) AND !(buildings|trees)",
        mask_mode="replace",
        width=512,
        height=512,
        steps=40 * 2,
        seed=1,
        sampler_type=sampler_type,
    )
    generated = next(imagine(sunset_prompt))

    # keep the fitted input next to the output for side-by-side comparison
    fitted_source = pillow_fit_image_within(source)
    fitted_source.save(f"{filename_base_for_outputs}__orig.jpg")
    generated.img.save(f"{filename_base_for_outputs}.jpg")
|
||||
|
||||
|
||||
# (sampler_type, expected result.md5()) pairs used to parametrize
# test_img_to_img_from_url_cats, keyed by the device string from get_device().
#
# The cases are lists rather than sets so the parametrization order is the
# same on every run. The empty "cpu" entry matters: the lookup below executes
# at import time, before any skipif marker is evaluated, so a missing key
# would abort collection of this whole module on CPU-only machines.
device_sampler_type_test_cases_img_2_img = {
    "mps:0": [
        ("plms", "e9bb714771f7984e61debabc4bb3cd22"),
        ("ddim", "62bacc4ae391e6775a3723c88738ec61"),
    ],
    "cuda": [
        ("plms", "b8c7b52da977c1531a9a61c0a082404c"),
        ("ddim", "d6784710dd78e4cb628aba28322b04cf"),
    ],
    "cpu": [],
}
sampler_type_test_cases_img_2_img = device_sampler_type_test_cases_img_2_img[
    get_device()
]
|
||||
|
||||
|
||||
@pytest.mark.skipif(get_device() == "cpu", reason="Too slow to run on CPU")
@pytest.mark.parametrize("sampler_type,expected_md5", sampler_type_test_cases_img_2_img)
def test_img_to_img_from_url_cats(
    sampler_type, expected_md5, filename_base_for_outputs
):
    """Run img2img on an image loaded from a URL and compare the result's md5.

    Both the fitted source image and the generated image are saved so a hash
    mismatch can be checked visually.
    """
    source = LazyLoadingImage(url="http://images.cocodataset.org/val2017/000000039769.jpg")
    couch_prompt = ImaginePrompt(
        "dogs lying on a hot pink couch",
        init_image=source,
        init_image_strength=0.5,
        width=512,
        height=512,
        steps=50,
        seed=1,
        sampler_type=sampler_type,
    )
    generated = next(imagine(couch_prompt))

    fitted_source = pillow_fit_image_within(source)
    fitted_source.save(f"{filename_base_for_outputs}__orig.jpg")
    generated.img.save(f"{filename_base_for_outputs}.jpg")

    assert generated.md5() == expected_md5
|
||||
|
||||
|
||||
@pytest.mark.skipif(get_device() == "cpu", reason="Too slow to run on CPU")
@pytest.mark.parametrize("sampler_type", ["ddim", "plms"])
@pytest.mark.parametrize("init_strength", [0, 0.05, 0.2, 1])
def test_img_to_img_fruit_2_gold(
    filename_base_for_outputs, sampler_type, init_strength
):
    """Replace the masked fruit with gold coins across samplers and strengths.

    Nothing is asserted about the pixels; the fitted source and the generated
    image are saved for every sampler / init-strength combination.
    """
    fruit_url = "https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000056_293284644_PLMS40_PS7.5_photo_of_a_bowl_of_fruit.jpg"
    source = LazyLoadingImage(url=fruit_url)
    gold_prompt = ImaginePrompt(
        "a white bowl filled with gold coins",
        prompt_strength=12,
        init_image=source,
        init_image_strength=init_strength,
        mask_prompt="(fruit OR stem{*5} OR fruit stem)",
        mask_mode="replace",
        steps=80,
        seed=1,
        sampler_type=sampler_type,
    )
    generated = next(imagine(gold_prompt))

    fitted_source = pillow_fit_image_within(source)
    fitted_source.save(f"{filename_base_for_outputs}__orig.jpg")
    generated.img.save(f"{filename_base_for_outputs}.jpg")
|
||||
|
||||
|
||||
@pytest.mark.skipif(get_device() == "cpu", reason="Too slow to run on CPU")
def test_img_to_img_fruit_2_gold_repeat():
    """Run the same seeded fruit-to-gold prompt twice and save each run.

    Meant to be run manually: nothing is asserted. The numbered output files
    (run-01, run-02) are written so the two runs can be compared by eye.
    """
    source = LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/bowl_of_fruit.jpg")
    outdir = f"{TESTS_FOLDER}/test_output/"
    run_count = 1

    def _record_step(img, description, step_count, prompt):
        # Not wired in below (img_callback=None). Presumably meant to be
        # passed as img_callback when step-by-step images are wanted.
        run_folder = f"steps_fruit_2_gold_repeat_{get_device()}_S{prompt.seed}_run_{run_count:02}"
        steps_path = os.path.join(outdir, run_folder)
        os.makedirs(steps_path, exist_ok=True)
        filename = f"fruit_2_gold_repeat_{get_device()}_S{prompt.seed}_step{step_count:04}_{prompt_normalized(description)[:40]}.jpg"
        img.save(os.path.join(steps_path, filename))

    prompt_kwargs = {
        "prompt": "a white bowl filled with gold coins. sharp focus",
        "prompt_strength": 12,
        "init_image": source,
        "init_image_strength": 0.2,
        "mask_prompt": "(fruit OR stem{*5} OR fruit stem)",
        "mask_mode": "replace",
        "steps": 20,
        "seed": 946188797,
        "sampler_type": "plms",
    }
    # two identical prompts, generated back to back in a single imagine() call
    prompts = [ImaginePrompt(**prompt_kwargs) for _ in range(2)]

    fitted = source
    for result in imagine(prompts, img_callback=None):
        # the fitted source is re-derived and re-saved after every result
        fitted = pillow_fit_image_within(fitted)
        fitted.save(f"{TESTS_FOLDER}/test_output/img2img_fruit_2_gold__orig.jpg")
        result.img.save(
            f"{TESTS_FOLDER}/test_output/img2img_fruit_2_gold_plms_{get_device()}_run-{run_count:02}.jpg"
        )
        run_count += 1
|
||||
|
||||
|
||||
@pytest.mark.skipif(get_device() == "cpu", reason="Too slow to run on CPU")
def test_img_to_file():
    """Generate an upscaled, non-square image and write it via imagine_image_files."""
    forest_prompt = ImaginePrompt(
        "an old growth forest, diffuse light poking through the canopy. high-resolution, nature photography, nat geo photo",
        width=576,  # 512 + 64
        height=448,  # 512 - 64
        steps=20,
        seed=2,
        sampler_type="PLMS",
        upscale=True,
    )
    imagine_image_files(forest_prompt, outdir=f"{TESTS_FOLDER}/test_output")
|
||||
|
||||
|
||||
@pytest.mark.skipif(get_device() == "cpu", reason="Too slow to run on CPU")
def test_inpainting_bench(filename_base_for_outputs):
    """Inpaint the bench photo using a mask image loaded from disk.

    Nothing is asserted about the result; the fitted source and the generated
    image are saved for visual review.
    """
    data_dir = f"{TESTS_FOLDER}/data"
    source = LazyLoadingImage(filepath=f"{data_dir}/bench2.png")
    bench_mask = LazyLoadingImage(filepath=f"{data_dir}/bench2_mask.png")
    bench_prompt = ImaginePrompt(
        "a wise old man",
        init_image=source,
        init_image_strength=0.4,
        mask_image=bench_mask,
        width=512,
        height=512,
        steps=40,
        seed=1,
        sampler_type="plms",
    )
    generated = next(imagine(bench_prompt))

    fitted_source = pillow_fit_image_within(source)
    fitted_source.save(f"{filename_base_for_outputs}__orig.jpg")
    generated.img.save(f"{filename_base_for_outputs}.jpg")
|
||||
|
||||
|
||||
@pytest.mark.skipif(get_device() == "cpu", reason="Too slow to run on CPU")
def test_cliptext_inpainting_pearl_doctor(filename_base_for_outputs):
    """Keep the text-masked face and repaint the rest as a doctor.

    Passes when the result's md5 is any one of the known hashes listed below.
    The fitted source and the generated image are saved for visual review.
    """
    source = LazyLoadingImage(
        filepath=f"{TESTS_FOLDER}/data/girl_with_a_pearl_earring.jpg"
    )
    prompt = ImaginePrompt(
        "a female doctor in the hospital",
        prompt_strength=12,
        init_image=source,
        init_image_strength=0.2,
        mask_prompt="face AND NOT (bandana OR hair OR blue fabric){*6}",
        mask_mode=ImaginePrompt.MaskMode.KEEP,
        width=512,
        height=512,
        steps=40,
        sampler_type="plms",
        seed=181509347,
    )
    generated = next(imagine(prompt))

    fitted_source = pillow_fit_image_within(source)
    fitted_source.save(f"{filename_base_for_outputs}__orig.jpg")
    generated.img.save(f"{filename_base_for_outputs}_{prompt.seed}.jpg")

    # mps
    known_md5s = {
        "84868e7477a7375f7089160ac6adc064",
        "c5c0166185c284fc849901123e78d608",
        "6ef63037f5a1bd8bce6aec1c7ad46880",
    }
    found_match = generated.md5() in known_md5s
    assert found_match
|
@ -11,6 +11,14 @@ def test_imagine_cmd():
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(
|
||||
imagine_cmd,
|
||||
["gold coins", "--steps", "5", "--outdir", f"{TESTS_FOLDER}/test_output"],
|
||||
[
|
||||
"gold coins",
|
||||
"--steps",
|
||||
"25",
|
||||
"--outdir",
|
||||
f"{TESTS_FOLDER}/test_output",
|
||||
"--seed",
|
||||
"703425280",
|
||||
],
|
||||
)
|
||||
assert result.exit_code == 0
|
||||
|
@ -35,13 +35,16 @@ def img_hash(img):
|
||||
@pytest.mark.skipif(get_device() == "cpu", reason="Too slow to run on CPU")
|
||||
def test_clip_masking():
|
||||
img = Image.open(f"{TESTS_FOLDER}/data/girl_with_a_pearl_earring_large.jpg")
|
||||
|
||||
for mask_modifier in [
|
||||
"*0.5",
|
||||
"*1",
|
||||
"*10",
|
||||
"*6",
|
||||
]:
|
||||
pred_bin, pred_grayscale = get_img_mask(
|
||||
img, f"(head OR face){{{mask_modifier}}}", threshold=0.1
|
||||
img,
|
||||
f"face AND NOT (bandana OR hair OR blue fabric){{{mask_modifier}}}",
|
||||
threshold=0.5,
|
||||
)
|
||||
pred_grayscale.save(
|
||||
f"{TESTS_FOLDER}/test_output/earring_mask_{mask_modifier}_g.png"
|
||||
@ -51,15 +54,14 @@ def test_clip_masking():
|
||||
)
|
||||
|
||||
prompt = ImaginePrompt(
|
||||
"professional photo of a woman",
|
||||
"a female firefighter in front of a burning building",
|
||||
init_image=img,
|
||||
init_image_strength=0.95,
|
||||
# lower steps for faster tests
|
||||
# steps=40,
|
||||
steps=4,
|
||||
mask_prompt="(head OR face)*5",
|
||||
steps=40,
|
||||
mask_prompt="(head OR face){*5}",
|
||||
mask_mode="replace",
|
||||
upscale=True,
|
||||
upscale=False,
|
||||
fix_faces=True,
|
||||
)
|
||||
|
||||
|
@ -1,152 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from imaginairy import LazyLoadingImage
|
||||
from imaginairy.api import imagine, imagine_image_files
|
||||
from imaginairy.schema import ImaginePrompt
|
||||
from imaginairy.utils import get_device
|
||||
|
||||
from . import TESTS_FOLDER
|
||||
|
||||
device_sampler_type_test_cases = {
|
||||
"mps:0": [
|
||||
("plms", "b4b434ed45919f3505ac2be162791c71"),
|
||||
("ddim", "b369032a025915c0a7ccced165a609b3"),
|
||||
("k_lms", "b87325c189799d646ccd07b331564eb6"),
|
||||
("k_dpm_2", "cb37ca934938466bdbc1dd995da037de"),
|
||||
("k_dpm_2_a", "ef155995ca1638f0ae7db9f573b83767"),
|
||||
("k_euler", "d126da5ca8b08099cde8b5037464e788"),
|
||||
("k_euler_a", "cac5ca2e26c31a544b76a9442eb2ea37"),
|
||||
("k_heun", "0382ef71d9967fefd15676410289ebab"),
|
||||
],
|
||||
"cuda": [
|
||||
("plms", "0c44d2c8222f519a6700ebae54450435"),
|
||||
("ddim", "4493ca85c2b24879525eac2b73e5a538"),
|
||||
("k_lms", "82b38a5638a572d5968422b02e625f66"),
|
||||
("k_dpm_2", "9df2fcd6256ff68c6cc4a6c603ae8f2e"),
|
||||
("k_dpm_2_a", "0c5491c1a73094540ed15785f4106bca"),
|
||||
("k_euler", "c82f628217fab06d8b5d5227827c1d92"),
|
||||
("k_euler_a", "74f748a8371c2fcec54ecc5dcf1dbb64"),
|
||||
("k_heun", "9ae586a7a8b10a0a0bf120405e4937e9"),
|
||||
],
|
||||
"cpu": [],
|
||||
}
|
||||
sampler_type_test_cases = device_sampler_type_test_cases[get_device()]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("sampler_type,expected_md5", sampler_type_test_cases)
|
||||
def test_imagine(sampler_type, expected_md5):
|
||||
prompt_text = "a scenic landscape"
|
||||
prompt = ImaginePrompt(
|
||||
prompt_text, width=512, height=256, steps=5, seed=1, sampler_type=sampler_type
|
||||
)
|
||||
result = next(imagine(prompt))
|
||||
result.img.save(
|
||||
f"{TESTS_FOLDER}/test_output/sampler_type_{sampler_type.upper()}.jpg"
|
||||
)
|
||||
assert result.md5() == expected_md5
|
||||
|
||||
|
||||
device_sampler_type_test_cases_img_2_img = {
|
||||
"mps:0": {
|
||||
("plms", "0d9c40c348cdac7bdc8d5a472f378f42"),
|
||||
("ddim", "0d9c40c348cdac7bdc8d5a472f378f42"),
|
||||
},
|
||||
"cuda": {
|
||||
("plms", "28752d4e1d778abc3e9424f4f23d1aaf"),
|
||||
("ddim", "28752d4e1d778abc3e9424f4f23d1aaf"),
|
||||
},
|
||||
"cpu": [],
|
||||
}
|
||||
sampler_type_test_cases_img_2_img = device_sampler_type_test_cases_img_2_img[
|
||||
get_device()
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.skipif(get_device() == "cpu", reason="Too slow to run on CPU")
|
||||
@pytest.mark.parametrize("sampler_type,expected_md5", sampler_type_test_cases_img_2_img)
|
||||
def test_img_to_img(sampler_type, expected_md5):
|
||||
prompt = ImaginePrompt(
|
||||
"a photo of a beach",
|
||||
init_image=f"{TESTS_FOLDER}/data/beach_at_sainte_adresse.jpg",
|
||||
init_image_strength=0.8,
|
||||
width=512,
|
||||
height=512,
|
||||
steps=5,
|
||||
seed=1,
|
||||
sampler_type=sampler_type,
|
||||
)
|
||||
result = next(imagine(prompt))
|
||||
result.img.save(
|
||||
f"{TESTS_FOLDER}/test_output/sampler_type_{sampler_type.upper()}_img2img_beach.jpg"
|
||||
)
|
||||
assert result.md5() == expected_md5
|
||||
|
||||
|
||||
@pytest.mark.skipif(get_device() == "cpu", reason="Too slow to run on CPU")
|
||||
def test_img_to_img_from_url():
|
||||
prompt = ImaginePrompt(
|
||||
"dogs lying on a hot pink couch",
|
||||
init_image=LazyLoadingImage(
|
||||
url="http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
),
|
||||
init_image_strength=0.5,
|
||||
width=512,
|
||||
height=512,
|
||||
steps=5,
|
||||
seed=1,
|
||||
sampler_type="DDIM",
|
||||
)
|
||||
out_folder = f"{TESTS_FOLDER}/test_output"
|
||||
imagine_image_files(prompt, outdir=out_folder)
|
||||
|
||||
|
||||
@pytest.mark.skipif(get_device() == "cpu", reason="Too slow to run on CPU")
|
||||
def test_img_to_file():
|
||||
prompt = ImaginePrompt(
|
||||
"an old growth forest, diffuse light poking through the canopy. high-resolution, nature photography, nat geo photo",
|
||||
width=512 + 64,
|
||||
height=512 - 64,
|
||||
steps=5,
|
||||
seed=2,
|
||||
sampler_type="PLMS",
|
||||
upscale=True,
|
||||
)
|
||||
out_folder = f"{TESTS_FOLDER}/test_output"
|
||||
imagine_image_files(prompt, outdir=out_folder)
|
||||
|
||||
|
||||
@pytest.mark.skipif(get_device() == "cpu", reason="Too slow to run on CPU")
|
||||
def test_inpainting():
|
||||
prompt = ImaginePrompt(
|
||||
"a basketball on a bench",
|
||||
init_image=f"{TESTS_FOLDER}/data/bench2.png",
|
||||
init_image_strength=0.4,
|
||||
mask_image=LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/bench2_mask.png"),
|
||||
width=512,
|
||||
height=512,
|
||||
steps=5,
|
||||
seed=1,
|
||||
sampler_type="DDIM",
|
||||
)
|
||||
out_folder = f"{TESTS_FOLDER}/test_output"
|
||||
imagine_image_files(prompt, outdir=out_folder)
|
||||
|
||||
|
||||
@pytest.mark.skipif(get_device() == "cpu", reason="Too slow to run on CPU")
|
||||
def test_cliptext_inpainting():
|
||||
prompts = [
|
||||
ImaginePrompt(
|
||||
"elegant woman. oil painting",
|
||||
prompt_strength=12,
|
||||
init_image=f"{TESTS_FOLDER}/data/girl_with_a_pearl_earring.jpg",
|
||||
init_image_strength=0.3,
|
||||
mask_prompt="face{*2}",
|
||||
mask_mode=ImaginePrompt.MaskMode.KEEP,
|
||||
width=512,
|
||||
height=512,
|
||||
steps=5,
|
||||
sampler_type="DDIM",
|
||||
),
|
||||
]
|
||||
out_folder = f"{TESTS_FOLDER}/test_output"
|
||||
imagine_image_files(prompts, outdir=out_folder)
|
Loading…
Reference in New Issue
Block a user