fix: several cli commands, edit demo, negative prompt

- fix colorize cmd. add test
- fix describe cmd. add test
- fix model-list cmd. add test
- fix stable studio
- hide stack grace for ValueErrors in cli
- set controlnet scale
- fix negative prompt to allow emptystring instead of replacing it with default
- adjust edit-demo parameters
- arg scheduler that works at click level (but disable it). works but not ideal experience.
This commit is contained in:
Bryce 2023-12-10 14:46:11 -08:00 committed by Bryce Drennan
parent c299cfffd9
commit e898e3a799
16 changed files with 270 additions and 154 deletions

View File

@ -3,20 +3,25 @@
### v14 todo
- configurable composition cutoff
- rename model parameter weights
- rename model_config parameter to architecture and make it case insensitive
- add --size parameter that accepts strings (e.g. 256x256, 4k, uhd, 8k, etc)
- detect if cuda torch missing and give better error message
- add method to install correct torch version
- ✅ rename model parameter weights
- ✅ rename model_config parameter to architecture and make it case insensitive
- ✅ add --size parameter that accepts strings (e.g. 256x256, 4k, uhd, 8k, etc)
- ✅ detect if cuda torch missing and give better error message
- ✅ add method to install correct torch version
- ✅ make cli run faster again
- ✅ add tests for cli commands
- add type checker
- only output the main image unless some flag is set
- allow selection of output video format
- chain multiple operations together imggen => videogen
- make sure terminal output on windows doesn't suck
- https://github.com/pallets/click/tree/main/examples/imagepipe
- add interface for loading diffusers weights
- https://huggingface.co/playgroundai/playground-v2-1024px-aesthetic
- make sure terminal output on windows doesn't suck
- add method to show cache size
- add method to clear model cache
- add method to clear cached items not recently used (does diffusers have one?)
- https://huggingface.co/playgroundai/playground-v2-1024px-aesthetic
### Old Todo
- Inference Performance Optimizations

View File

@ -2,7 +2,7 @@ import logging
from typing import List, Optional
from imaginairy.config import CONTROL_CONFIG_SHORTCUTS
from imaginairy.schema import ImaginePrompt, MaskMode, WeightedPrompt
from imaginairy.schema import ControlInput, ImaginePrompt, MaskMode, WeightedPrompt
logger = logging.getLogger(__name__)
@ -20,8 +20,6 @@ def _generate_single_image(
dtype=None,
half_mode=None,
):
import gc
import torch.nn
from PIL import ImageOps
from pytorch_lightning import seed_everything
@ -62,8 +60,7 @@ def _generate_single_image(
dtype = torch.float16 if half_mode else torch.float32
get_device()
gc.collect()
torch.cuda.empty_cache()
clear_gpu_cache()
prompt = prompt.make_concrete_copy()
control_modes = []
@ -178,11 +175,11 @@ def _generate_single_image(
pillow_mask_to_latent_mask(
mask_image, downsampling_factor=downsampling_factor
).to(get_device())
# if inpaint_method == "controlnet":
# result_images["control-inpaint"] = mask_image
# control_inputs.append(
# ControlInput(mode="inpaint", image=mask_image)
# )
if inpaint_method == "controlnet":
result_images["control-inpaint"] = mask_image
control_inputs.append(
ControlInput(mode="inpaint", image=mask_image)
)
seed_everything(prompt.seed)
@ -257,26 +254,30 @@ def _generate_single_image(
target=sd.unet,
weights_location=control_config.weights_location,
)
controlnet.set_scale(control_input.strength)
controlnets.append((controlnet, control_image_t))
if prompt.allow_compose_phase:
compose_kwargs = {
"prompt": prompt,
"target_height": prompt.height,
"target_width": prompt.width,
"cutoff": get_model_default_image_size(prompt.model_architecture),
"dtype": dtype,
}
if prompt.init_image:
comp_image, comp_img_orig = _generate_composition_image(
prompt=prompt,
target_height=init_image.height,
target_width=init_image.width,
cutoff=get_model_default_image_size(prompt.model_architecture),
dtype=dtype,
compose_kwargs.update(
{
"target_height": init_image.height,
"target_width": init_image.width,
}
)
else:
comp_image, comp_img_orig = _generate_composition_image(
prompt=prompt,
target_height=prompt.height,
target_width=prompt.width,
cutoff=get_model_default_image_size(prompt.model_architecture),
dtype=dtype,
**compose_kwargs
)
if comp_image is not None:
result_images["composition"] = comp_img_orig
result_images["composition-upscaled"] = comp_image
@ -325,8 +326,7 @@ def _generate_single_image(
# if "cuda" in str(sd.lda.device):
# sd.lda.to("cpu")
gc.collect()
torch.cuda.empty_cache()
clear_gpu_cache()
# print(f"moving unet to {sd.device}")
# sd.unet.to(device=sd.device, dtype=sd.dtype)
for step in tqdm(sd.steps[first_step:], bar_format=" {l_bar}{bar}{r_bar}"):
@ -338,26 +338,12 @@ def _generate_single_image(
condition_scale=prompt.prompt_strength,
)
# z = sd(
# randn_seeded(seed=prompt.seed, size=[1, 4, 8, 8]).to(
# device=sd.device, dtype=sd.dtype
# ),
# step=step,
# clip_text_embedding=clip_text_embedding,
# condition_scale=prompt.prompt_strength,
# )
if "cuda" in str(sd.unet.device):
# print("moving unet to cpu")
# sd.unet.to("cpu")
gc.collect()
torch.cuda.empty_cache()
clear_gpu_cache()
logger.debug("Decoding image")
if x.device != sd.lda.device:
sd.lda.to(x.device)
gc.collect()
torch.cuda.empty_cache()
clear_gpu_cache()
gen_img = sd.lda.decode_latents(x)
if mask_image_orig and init_image:
@ -428,7 +414,7 @@ def _generate_single_image(
is_nsfw=safety_score.is_nsfw,
safety_score=safety_score,
result_images=result_images,
timings={}, # todo
timings=lc.get_timings(),
progress_latents=[], # todo
)
@ -437,8 +423,7 @@ def _generate_single_image(
logger.info(f"Image Generated. Timings: {result.timings_str()}")
for controlnet, _ in controlnets:
controlnet.eject()
gc.collect()
torch.cuda.empty_cache()
clear_gpu_cache()
return result
@ -480,3 +465,13 @@ def _calc_conditioning(
tensors=(neutral_conditioning, positive_conditioning), dim=0
)
return clip_text_embedding
def clear_gpu_cache():
import gc
import torch
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()

View File

@ -0,0 +1,72 @@
from typing import Iterable
from imaginairy.utils import frange
def with_arg_schedule(f):
"""Decorator to add arg-schedule functionality to a click command."""
def new_func(*args, **kwargs):
arg_schedules = kwargs.pop("arg_schedules", None)
if arg_schedules:
schedules = parse_schedule_strs(arg_schedules)
schedule_length = len(next(iter(schedules.values())))
for i in range(schedule_length):
for attr_name, schedule in schedules.items():
kwargs[attr_name] = schedule[i]
f(*args, **kwargs)
else:
f(*args, **kwargs)
return new_func
def parse_schedule_strs(schedule_strs: Iterable[str]) -> dict:
"""Parse and validate input prompt schedules."""
schedules = {}
for schedule_str in schedule_strs:
arg_name, arg_values = parse_schedule_str(schedule_str)
schedules[arg_name] = arg_values
# Validate that all schedules have the same length
schedule_lengths = [len(v) for v in schedules.values()]
if len(set(schedule_lengths)) > 1:
raise ValueError("All schedules must have the same length")
return schedules
def parse_schedule_str(schedule_str):
"""Parse a schedule string into a list of values."""
import re
pattern = re.compile(r"([a-zA-Z0-9_-]+)\[([a-zA-Z0-9_:,. -]+)\]")
match = pattern.match(schedule_str)
if not match:
msg = f"Invalid kwarg schedule: {schedule_str}"
raise ValueError(msg)
arg_name = match.group(1).replace("-", "_")
arg_values = match.group(2)
if ":" in arg_values:
start, end, step = arg_values.split(":")
arg_values = list(frange(float(start), float(end), float(step)))
else:
arg_values = parse_csv_line(arg_values)
return arg_name, arg_values
def parse_csv_line(line):
import csv
reader = csv.reader([line])
for row in reader:
parsed_row = []
for value in row:
try:
parsed_row.append(float(value))
except ValueError:
parsed_row.append(value)
return parsed_row

View File

@ -36,9 +36,9 @@ def colorize_cmd(image_filepaths, outdir, repeats, caption):
from tqdm import tqdm
from imaginairy import LazyLoadingImage
from imaginairy.colorize import colorize_img
from imaginairy.log_utils import configure_logging
from imaginairy.schema import LazyLoadingImage
configure_logging()

View File

@ -7,8 +7,8 @@ def describe_cmd(image_filepaths):
"""Generate text descriptions of images."""
import os
from imaginairy import LazyLoadingImage
from imaginairy.enhancers.describe_image_blip import generate_caption
from imaginairy.schema import LazyLoadingImage
imgs = []
for p in image_filepaths:

View File

@ -96,13 +96,7 @@ def edit_cmd(
from imaginairy.schema import ControlInput
allow_compose_phase = False
control_inputs = [
ControlInput(
image=None,
image_raw=None,
mode="edit",
)
]
control_inputs = [ControlInput(image=None, image_raw=None, mode="edit", strength=1)]
return _imagine_cmd(
ctx,

View File

@ -1,7 +1,12 @@
import click
from imaginairy.cli.clickshell_mod import ImagineColorsCommand
from imaginairy.cli.shared import _imagine_cmd, add_options, common_options
from imaginairy.cli.shared import (
_imagine_cmd,
add_options,
common_options,
imaginairy_click_context,
)
@click.command(
@ -67,6 +72,7 @@ from imaginairy.cli.shared import _imagine_cmd, add_options, common_options
help="Turns the generated photo into video",
)
@click.pass_context
@imaginairy_click_context()
def imagine_cmd(
ctx,
prompt_texts,

View File

@ -82,17 +82,16 @@ def model_list_cmd():
"""Print list of available models."""
from imaginairy import config
print(f"{'ALIAS': <10} {'NAME': <18} {'DESCRIPTION'}")
print("\nWEIGHT NAMES")
print(f"{'ALIAS': <25} {'NAME': <25} ")
for model_config in config.MODEL_WEIGHT_CONFIGS:
print(
f"{model_config.alias: <10} {model_config.short_name: <18} {model_config.description}"
)
print(f"{model_config.aliases[0]: <25} {model_config.name: <25}")
print("\nCONTROL MODES:")
print(f"{'ALIAS': <10} {'NAME': <18} {'CONTROL TYPE'}")
print("\nCONTROL MODES")
print(f"{'ALIAS': <14} {'NAME': <35} {'CONTROL TYPE'}")
for control_mode in config.CONTROL_CONFIGS:
print(
f"{control_mode.alias: <10} {control_mode.short_name: <18} {control_mode.control_type}"
f"{control_mode.aliases[0]: <14} {control_mode.name: <35} {control_mode.control_type}"
)

View File

@ -15,13 +15,12 @@ def imaginairy_click_context(log_level="INFO"):
from imaginairy.log_utils import configure_logging
errors_to_catch = (FileNotFoundError, ValidationError)
errors_to_catch = (FileNotFoundError, ValidationError, ValueError)
configure_logging(level=log_level)
try:
yield
except errors_to_catch as e:
logger.error(e)
exit(1)
def _imagine_cmd(

File diff suppressed because one or more lines are too long

View File

@ -55,7 +55,7 @@ class StableStudioInput(BaseModel, extra=Extra.forbid):
style: Optional[str] = None
width: Optional[int] = None
height: Optional[int] = None
solver: Optional[StableStudioSolver] = None
solver: Optional[StableStudioSolver] = Field(None, alias="sampler")
cfg_scale: Optional[float] = Field(None, alias="cfgScale")
steps: Optional[int] = None
seed: Optional[int] = None
@ -94,12 +94,11 @@ class StableStudioInput(BaseModel, extra=Extra.forbid):
prompt=positive_prompt,
prompt_strength=self.cfg_scale,
negative_prompt=negative_prompt,
model=self.model,
model_weights=self.model,
solver_type=solver_type,
seed=self.seed,
steps=self.steps,
height=self.height,
width=self.width,
size=(self.width, self.height),
init_image=Image.open(BytesIO(init_image)) if init_image else None,
init_image_strength=init_image_strength,
mask_image=Image.open(BytesIO(mask_image)) if mask_image else None,

View File

@ -56,12 +56,14 @@ async def list_models():
model_objs = []
for model_config in MODEL_WEIGHT_CONFIGS:
if "inpaint" in model_config.description.lower():
if "inpaint" in model_config.name.lower():
continue
if model_config.architecture.output_modality != "image":
continue
model_obj = StableStudioModel(
id=model_config.short_name,
name=model_config.description,
description=model_config.description,
id=model_config.aliases[0],
name=model_config.name,
description=model_config.name,
)
model_objs.append(model_obj)

View File

@ -365,7 +365,7 @@ class ImaginePrompt(BaseModel, protected_namespaces=()):
return []
case str():
if value:
if value is not None:
return [WeightedPrompt(text=value)]
else:
return []
@ -395,10 +395,7 @@ class ImaginePrompt(BaseModel, protected_namespaces=()):
@model_validator(mode="after")
def validate_negative_prompt(self):
if (
self.negative_prompt == [WeightedPrompt(text="")]
or self.negative_prompt == []
):
if self.negative_prompt == []:
model_weight_config = config.MODEL_WEIGHT_CONFIG_LOOKUP.get(
self.model_weights, None
)
@ -411,7 +408,7 @@ class ImaginePrompt(BaseModel, protected_namespaces=()):
self.negative_prompt = [WeightedPrompt(text=default_negative_prompt)]
return self
@field_validator("prompt_strength")
@field_validator("prompt_strength", mode="before")
def validate_prompt_strength(cls, v):
return 7.5 if v is None else v

View File

@ -1,3 +1,4 @@
import logging
import os.path
from imaginairy.animations import make_gif_animation
@ -6,52 +7,65 @@ from imaginairy.enhancers.facecrop import detect_faces
from imaginairy.img_utils import add_caption_to_image, pillow_fit_image_within
from imaginairy.schema import ControlInput, ImaginePrompt, LazyLoadingImage
logger = logging.getLogger(__name__)
preserve_head_kwargs = {
"mask_prompt": "head|face",
"mask_mode": "keep",
}
preserve_face_kwargs = {
"mask_prompt": "face",
"mask_mode": "keep",
}
generic_prompts = [
("add confetti", 6, {}),
("add confetti", 15, {}),
# ("add sparkles", 14, {}),
("make it christmas", 15, preserve_head_kwargs),
("make it halloween", 15, {}),
("give it a dark omninous vibe", 15, {}),
("give it a bright cheery vibe", 15, {}),
("make it christmas", 15, preserve_face_kwargs),
("make it halloween", 15, preserve_face_kwargs),
("give it a depressing vibe", 10, {}),
("give it a bright cheery vibe", 10, {}),
# weather
("make it look like a snowstorm", 20, {}),
("make it midnight", 15, {}),
("make it a sunny day", 15, {}),
("add misty fog", 15, {}),
("make it look like a snowstorm", 15, preserve_face_kwargs),
("make it sunset", 15, preserve_face_kwargs),
("make it a sunny day", 15, preserve_face_kwargs),
("add misty fog", 10, {}),
("make it flooded", 10, {}),
# setting
("make it underwater", 15, {}),
("make it underwater", 10, {}),
("add fireworks to the sky", 15, {}),
# ("make it in a forest", 10, {}),
# ("make it grassy", 11, {}),
("make it on mars", 14, {}),
# style
("add glitter", 10, {}),
("turn it into a still from a western", 15, {}),
("turn it into a still from a western", 10, {}),
("old 1900s photo", 11.5, {}),
("Daguerreotype", 12, {}),
("make it anime style", 18, {}),
("Daguerreotype", 14, {}),
("make it anime style", 15, {}),
("watercolor painting", 10, {}),
("crayon drawing", 10, {}),
# ("make it pen and ink style", 20, {}),
("graphite pencil", 15, {}),
# ("make it a thomas kinkade painting", 20, {}),
("make it pixar style", 20, {}),
("graphite pencil", 10, {"negative_prompt": "low quality"}),
("make it a thomas kinkade painting", 10, {}),
("make it pixar style", 18, {}),
("low-poly", 20, {}),
("make it stained glass", 10, {}),
("make it pop art", 12, {}),
# ("make it street graffiti", 15, {}),
("make it stained glass", 15, {}),
("make it pop art", 15, {}),
("oil painting", 11, {}),
("street graffiti", 10, {}),
("photorealistic", 8, {}),
("vector art", 8, {}),
("comic book style. happy", 9, {}),
("starry night painting", 15, {}),
("make it minecraft", 12, {}),
# materials
("make it look like a marble statue", 15, {}),
("marble statue", 15, {}),
("make it look like a golden statue", 15, {}),
# ("make it claymation", 8, {}),
("golden statue", 15, {}),
# ("make it claymation", 15, {}),
("play-doh", 15, {}),
("voxel", 15, {}),
# ("lego", 15, {}),
@ -74,22 +88,26 @@ only_face_kwargs = {
person_prompt_configs = [
# face
("make the person close their eyes", 10, only_face_kwargs),
# (
# "make the person wear intricate highly detailed facepaint. ornate, artistic",
# 9,
# only_face_kwargs,
# ),
# ("make the person wear makeup. professional photoshoot", 8, only_face_kwargs),
("make the person close their eyes", 7, only_face_kwargs),
(
"make the person wear intricate highly detailed facepaint. ornate, artistic",
6,
only_face_kwargs,
),
# ("make the person wear makeup. professional photoshoot", 15, only_face_kwargs),
# ("make the person wear mime makeup. intricate, artistic", 7, only_face_kwargs),
# ("make the person wear clown makeup. intricate, artistic", 6, only_face_kwargs),
("make the person wear clown makeup. intricate, artistic", 7, only_face_kwargs),
("make the person a cyborg", 14, {}),
# clothes
("make the person wear shiny metal clothes", 14, preserve_head_kwargs),
("make the person wear a tie-dye shirt", 7.5, preserve_head_kwargs),
("make the person wear a suit", 7.5, preserve_head_kwargs),
# ("make the person bald", 7.5, {}),
("change the hair to pink", 7.5, {"mask_mode": "replace", "mask_prompt": "hair"}),
("make the person wear shiny metal clothes", 14, preserve_face_kwargs),
("make the person wear a tie-dye shirt", 14, preserve_face_kwargs),
("make the person wear a suit", 14, preserve_face_kwargs),
("make the person bald", 15, {}),
(
"change the hair to pink",
7.5,
{"mask_mode": "keep", "mask_prompt": "face"},
),
# ("change the hair to black", 7.5, {"mask_mode": "replace", "mask_prompt": "hair"}),
# ("change the hair to blonde", 7.5, {"mask_mode": "replace", "mask_prompt": "hair"}),
# ("change the hair to red", 7.5, {"mask_mode": "replace", "mask_prompt": "hair"}),
@ -101,22 +119,22 @@ person_prompt_configs = [
# ("change the hair to silver", 7.5, {"mask_mode": "replace", "mask_prompt": "hair"}),
(
"professional corporate photo headshot. Canon EOS, sharp focus, high resolution",
10,
{"negative_prompt": "old, ugly"},
6,
{"negative_prompt": "low quality"},
),
# ("make the person stoic. pensive", 10, only_face_kwargs),
# ("make the person sad", 20, {}),
# ("make the person angry", 20, {}),
# ("make the person look like a celebrity", 10, {}),
("make the person younger", 11, {}),
("make the person 70 years old", 9, {}),
("make the person a disney cartoon character", 7.5, {}),
("make the person stoic. pensive", 7, only_face_kwargs),
("make the person sad", 7, only_face_kwargs),
("make the person angry", 7, only_face_kwargs),
("make the person look like a celebrity", 10, {}),
("make the person younger", 7, {}),
("make the person 70 years old", 10, {}),
("make the person a disney cartoon character", 9, {}),
("turn the humans into robots", 13, {}),
("make the person darth vader", 15, {}),
("make the person a starfleet officer", 15, preserve_head_kwargs),
("make the person a superhero", 15, {}),
("make the person a tiger", 15, only_face_kwargs),
# ("lego minifig", 15, {}),
("make the person a jedi knight. star wars", 15, preserve_head_kwargs),
("make the person a starfleet officer. star trek", 15, preserve_head_kwargs),
("make the person a superhero", 15, preserve_head_kwargs),
# ("a tiger", 15, only_face_kwargs),
("lego minifig", 15, {}),
]
@ -132,22 +150,21 @@ def surprise_me_prompts(
if person is None:
person = bool(detect_faces(img))
prompts = []
logger.info("Person detected in photo. Adjusting edits accordingly.")
init_image_strength = 0.3
for prompt_text, strength, kwargs in generic_prompts:
kwargs.setdefault("negative_prompt", None)
kwargs.setdefault("init_image_strength", init_image_strength)
if use_controlnet:
strength = 5
control_input = ControlInput(mode="edit", strength=2)
control_input = ControlInput(mode="edit")
prompts.append(
ImaginePrompt(
prompt_text,
negative_prompt="",
init_image=img,
init_image_strength=0.3,
prompt_strength=strength,
control_inputs=[control_input],
steps=steps,
width=width,
height=height,
size=(width, height),
**kwargs,
)
)
@ -159,8 +176,7 @@ def surprise_me_prompts(
prompt_strength=strength,
model_weights="edit",
steps=steps,
width=width,
height=height,
size=(width, height),
**kwargs,
)
)
@ -175,16 +191,16 @@ def surprise_me_prompts(
control_input = ControlInput(
mode="edit",
)
kwargs.setdefault("negative_prompt", None)
kwargs.setdefault("init_image_strength", init_image_strength)
prompts.append(
ImaginePrompt(
prompt_text,
init_image=img,
init_image_strength=0.05,
prompt_strength=strength,
control_inputs=[control_input],
steps=steps,
width=width,
height=height,
size=(width, height),
seed=seed,
**kwargs,
)
@ -197,8 +213,7 @@ def surprise_me_prompts(
prompt_strength=strength,
model_weights="edit",
steps=steps,
width=width,
height=height,
size=(width, height),
seed=seed,
**kwargs,
)
@ -241,7 +256,7 @@ def create_surprise_me_images(
gif_imgs.append(gen_img)
make_gif_animation(outpath=new_filename, imgs=gif_imgs)
make_gif_animation(outpath=new_filename, imgs=gif_imgs, frame_duration_ms=1000)
if __name__ == "__main__":

View File

@ -19,10 +19,7 @@ from tests.utils import Timer
@pytest.mark.parametrize("subcommand_name", aimg.commands.keys())
def test_cmd_help_time(subcommand_name):
cmd_parts = [
"python",
"-X",
"importtime",
"imaginairy/cli/main.py",
"aimg",
subcommand_name,
"--help",
]
@ -34,6 +31,41 @@ def test_cmd_help_time(subcommand_name):
assert t.elapsed < 1.0, f"{t.elapsed} > 1.0"
def test_model_info_cmd():
runner = CliRunner()
result = runner.invoke(
aimg,
[
"model-list",
],
)
assert result.exit_code == 0, result.stdout
def test_describe_cmd():
runner = CliRunner()
result = runner.invoke(
aimg,
[
"describe",
f"{TESTS_FOLDER}/data/dog.jpg",
],
)
assert result.exit_code == 0, result.stdout
def test_colorize_cmd():
runner = CliRunner()
result = runner.invoke(
aimg,
[
"colorize",
f"{TESTS_FOLDER}/data/dog.jpg",
],
)
assert result.exit_code == 0, result.stdout
def test_imagine_cmd(monkeypatch):
monkeypatch.setattr(GPUModelCache, "make_gpu_space", mock.MagicMock())
runner = CliRunner()
@ -47,8 +79,6 @@ def test_imagine_cmd(monkeypatch):
f"{TESTS_FOLDER}/test_output",
"--seed",
"703425280",
# "--model",
# "empty",
"--outdir",
f"{TESTS_FOLDER}/test_output",
],

View File

@ -14,11 +14,14 @@ from tests import TESTS_FOLDER
def test_imagine_prompt_default():
prompt = ImaginePrompt()
assert prompt.prompt == []
assert prompt.prompt == [WeightedPrompt(text="")]
assert prompt.negative_prompt == [
WeightedPrompt(text=config.DEFAULT_NEGATIVE_PROMPT)
]
prompt = ImaginePrompt(negative_prompt="")
assert prompt.negative_prompt == [WeightedPrompt(text="")]
def test_imagine_prompt_has_default_negative():
prompt = ImaginePrompt("fruit salad", model_weights="foobar")