You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
imaginAIry/tests/test_schema/test_imagineprompt.py

236 lines
7.3 KiB
Python

import pytest
from pydantic import ValidationError
from imaginairy import LazyLoadingImage, config
from imaginairy.schema import ControlInput, ImaginePrompt, WeightedPrompt
from imaginairy.utils.data_distorter import DataDistorter
from tests import TESTS_FOLDER
def test_imagine_prompt_default():
    """A bare ImaginePrompt has an empty prompt and the configured default negative prompt."""
    default_prompt = ImaginePrompt()

    assert default_prompt.prompt == []

    expected_negative = [WeightedPrompt(text=config.DEFAULT_NEGATIVE_PROMPT)]
    assert default_prompt.negative_prompt == expected_negative
def test_imagine_prompt_has_default_negative():
    """A plain string prompt yields WeightedPrompt entries for both prompt lists.

    An arbitrary ``model_weights`` string is passed; the test expects
    construction to succeed and a negative prompt to be present anyway.
    """
    fruit_prompt = ImaginePrompt("fruit salad", model_weights="foobar")

    # Check the positive list first, then the negative list (same order as before).
    for weighted_list in (fruit_prompt.prompt, fruit_prompt.negative_prompt):
        assert isinstance(weighted_list[0], WeightedPrompt)
def test_imagine_prompt_custom_negative_prompt():
    """An explicit negative prompt string is wrapped and used as the first negative entry."""
    custom = ImaginePrompt("fruit salad", negative_prompt="pizza")

    for weighted_list in (custom.prompt, custom.negative_prompt):
        assert isinstance(weighted_list[0], WeightedPrompt)

    first_negative = custom.negative_prompt[0]
    assert first_negative.text == "pizza"
def test_imagine_prompt_model_specific_negative_prompt():
    """Selecting openjourney-v1 yields that model's own negative prompt ("poor quality")."""
    openjourney = ImaginePrompt("fruit salad", model_weights="openjourney-v1")

    for weighted_list in (openjourney.prompt, openjourney.negative_prompt):
        assert isinstance(weighted_list[0], WeightedPrompt)

    first_negative = openjourney.negative_prompt[0]
    assert first_negative.text == "poor quality"
def test_imagine_prompt_weighted_prompts():
    """ImaginePrompt accepts a single WeightedPrompt or a list of them."""
    single = ImaginePrompt(WeightedPrompt(text="cat", weight=0.1))
    assert isinstance(single.prompt[0], WeightedPrompt)

    cat = WeightedPrompt(text="cat", weight=0.1)
    dog = WeightedPrompt(text="dog", weight=0.2)
    multiple = ImaginePrompt([cat, dog])

    leading = multiple.prompt[0]
    assert isinstance(leading, WeightedPrompt)
    # "dog" is passed second but has the higher weight, and is expected first;
    # presumably ImaginePrompt orders prompts by descending weight.
    assert leading.text == "dog"
def test_imagine_prompt_tile_mode():
    """tile_mode accepts booleans and axis strings, lower-cases them, and rejects junk."""
    assert ImaginePrompt("fruit").tile_mode == ""

    # (value passed as tile_mode, normalized value expected back)
    normalization_cases = [
        (True, "xy"),
        (False, ""),
        ("X", "x"),
    ]
    for given, expected in normalization_cases:
        assert ImaginePrompt("fruit", tile_mode=given).tile_mode == expected

    with pytest.raises(ValueError, match=r".*Invalid tile_mode.*"):
        ImaginePrompt("fruit", tile_mode="pizza")
def test_imagine_prompt_copy():
    """full_copy() returns an equal object that is not the same instance."""
    original = ImaginePrompt("fruit")
    duplicate = original.full_copy()

    assert original == duplicate
    assert original is not duplicate
def test_imagine_prompt_concrete_copy():
    """make_concrete_copy() returns a distinct, unequal object with a seed filled in.

    The source prompt's seed is expected to remain None afterwards.
    """
    original = ImaginePrompt("fruit")
    concrete = original.make_concrete_copy()

    assert original != concrete
    assert original is not concrete
    assert original.seed is None
    assert concrete.seed is not None
def test_imagine_prompt_image_paths():
    """A path string passed as init_image ends up as a LazyLoadingImage."""
    red_png_path = f"{TESTS_FOLDER}/data/red.png"
    with_path = ImaginePrompt("fruit", init_image=red_png_path)
    assert isinstance(with_path.init_image, LazyLoadingImage)
def test_imagine_prompt_control_inputs():
    """Control-input images and the init image fill in for one another when missing."""
    red_img = LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/red.png")

    # Scenario 1: a control image but no init image. The init image is expected
    # to be taken from the first control input.
    from_control = ImaginePrompt(
        "fruit",
        control_inputs=[ControlInput(mode="depth", image=red_img)],
    )
    # The converted image is discarded; this call presumably just forces the
    # lazy image to load.
    from_control.control_inputs[0].image.convert("RGB")
    assert from_control.init_image is not None
    assert isinstance(from_control.init_image, LazyLoadingImage)

    # Scenario 2: an init image but a control input without an image. The
    # control input is expected to receive an image.
    from_init = ImaginePrompt(
        "fruit",
        init_image=red_img,
        control_inputs=[ControlInput(mode="depth")],
    )
    assert from_init.control_inputs[0].image is not None

    # Scenario 3: neither image supplied. The control image currently stays
    # None (the desired behavior here is still an open question).
    with_neither = ImaginePrompt(
        "fruit",
        control_inputs=[ControlInput(mode="depth")],
    )
    assert with_neither.control_inputs[0].image is None
def test_imagine_prompt_mask_params():
    """Invalid combinations of mask parameters raise descriptive ValueErrors."""
    red_img = LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/red.png")

    # (keyword arguments, regex the error message must match)
    invalid_combinations = [
        # mask_prompt and mask_image supplied together
        (
            {"init_image": red_img, "mask_prompt": "apple", "mask_image": red_img},
            r".*only set one.*",
        ),
        # a mask with no init_image (presumably the missing piece the message refers to)
        ({"mask_prompt": "apple"}, r".*if you want to use a mask.*"),
        ({"mask_image": red_img}, r".*if you want to use a mask.*"),
    ]
    for extra_kwargs, expected_message in invalid_combinations:
        with pytest.raises(ValueError, match=expected_message):
            ImaginePrompt("fruit", **extra_kwargs)
def test_imagine_prompt_default_model():
    """model_weights=None falls back to config.DEFAULT_MODEL_WEIGHTS."""
    resolved_weights = ImaginePrompt("fruit", model_weights=None).model_weights
    assert resolved_weights == config.DEFAULT_MODEL_WEIGHTS
def test_imagine_prompt_default_negative():
    """With no negative prompt given, the first negative entry is the configured default."""
    first_negative = ImaginePrompt("fruit").negative_prompt[0]
    assert first_negative.text == config.DEFAULT_NEGATIVE_PROMPT
def test_imagine_prompt_fix_faces_fidelity():
    """fix_faces_fidelity=None is replaced by 0.2."""
    fidelity_prompt = ImaginePrompt("fruit", fix_faces_fidelity=None)
    assert fidelity_prompt.fix_faces_fidelity == 0.2
def test_imagine_prompt_init_strength_zero():
    """init_image_strength is expected to be 0.0 with control inputs and 0.2 without."""
    red_img = LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/red.png")

    controlled = ImaginePrompt(
        "fruit",
        control_inputs=[ControlInput(mode="depth", image=red_img)],
    )
    assert controlled.init_image_strength == 0.0

    uncontrolled = ImaginePrompt("fruit")
    assert uncontrolled.init_image_strength == 0.2
def test_distorted_prompts():
    """Fuzz ImaginePrompt validation with distorted copies of a fully-populated prompt.

    A prompt with most fields set is dumped to a plain dict. DataDistorter then
    iterates over variants derived from that dict; presumably each variant
    perturbs some field. Each variant is fed back through
    ``ImaginePrompt.model_validate``. Only pydantic ``ValidationError`` is
    tolerated; any other exception escapes and fails the test.
    """
    red_png = f"{TESTS_FOLDER}/data/red.png"
    prompt_obj = ImaginePrompt(
        prompt=[
            WeightedPrompt(text="sunset", weight=0.7),
            WeightedPrompt(text="beach", weight=1.3),
        ],
        negative_prompt=[WeightedPrompt(text="night", weight=1.0)],
        prompt_strength=7.0,
        init_image=LazyLoadingImage(filepath=red_png),
        init_image_strength=0.5,
        control_inputs=[
            ControlInput(
                mode="details",
                image=LazyLoadingImage(filepath=red_png),
                strength=2,
            ),
            ControlInput(
                mode="depth",
                image_raw=LazyLoadingImage(filepath=red_png),
                strength=3,
            ),
        ],
        mask_prompt=None,
        mask_image=LazyLoadingImage(filepath=red_png),
        mask_mode="replace",
        mask_modify_original=False,
        outpaint="all5,up0,down20",
        model_weights=config.DEFAULT_MODEL_WEIGHTS,
        solver_type=config.DEFAULT_SOLVER,
        seed=42,
        steps=10,
        size=256,
        upscale=True,
        fix_faces=True,
        fix_faces_fidelity=0.7,
        conditioning=None,
        tile_mode="xy",
        allow_compose_phase=False,
        is_intermediate=False,
        collect_progress_latents=False,
        caption_text="Sample Caption",
    )
    data = prompt_obj.model_dump(mode="python")

    valid_prompts = []
    total_prompts = 0
    # enumerate(start=1) makes total_prompts the running count; it stays 0
    # if the distorter yields nothing.
    for total_prompts, distorted_data in enumerate(DataDistorter(data), start=1):
        try:
            distorted_prompt = ImaginePrompt.model_validate(distorted_data)
        except ValidationError:
            # Rejecting a distorted variant is an acceptable outcome.
            continue
        valid_prompts.append(distorted_prompt)

    print(f"Valid prompts: {len(valid_prompts)}")
    print(f"Invalid prompts: {total_prompts - len(valid_prompts)}")

    # Without this the test would pass vacuously if no variants were produced.
    assert total_prompts > 0, "DataDistorter produced no distorted variants"

    # NOTE: optional manual smoke test that renders every valid prompt. It is
    # left disabled (presumably because rendering is slow). Re-enabling it also
    # requires importing `imagine_image_files`, which this module does not import.
    # for p in valid_prompts:
    #     try:
    #         imagine_image_files(p, f"{TESTS_FOLDER}/test_output/distorted_prompts/")
    #     except ValueError as e:
    #         print(f"################{e}")
    #         continue
    #     except Exception as e:
    #         print("################")
    #         print(p)
    #         raise e