feature: 🎉 Edit Images with Instructions alone!

pull/173/head
Bryce 1 year ago committed by Bryce Drennan
parent 7285644909
commit 2a3e19f5a1

@ -42,7 +42,27 @@ Generating 🖼 : "portrait photo of a freckled woman" 512x512px seed:500686645
<img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000019_786355545_PLMS50_PS7.5_a_scenic_landscape.jpg" height="256"><img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000032_337692011_PLMS40_PS7.5_a_photo_of_a_dog.jpg" height="256"><br>
<img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000056_293284644_PLMS40_PS7.5_photo_of_a_bowl_of_fruit.jpg" height="256"><img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000078_260972468_PLMS40_PS7.5_portrait_photo_of_a_freckled_woman.jpg" height="256">
### Prompt Based Editing [by clipseg](https://github.com/timojl/clipseg)
### 🎉 Edit Images with Instructions alone! [by InstructPix2Pix](https://github.com/timothybrooks/instruct-pix2pix)
Just tell imaginairy how to edit the image and it will do it for you!
Use prompt strength to control how strong the edit is. For extra control, you can combine it
with prompt-based masking.
```bash
>> aimg edit scenic_landscape.jpg "make it winter" --prompt-strength 20
>> aimg edit dog.jpg "make the dog red" --prompt-strength 5
>> aimg edit bowl_of_fruit.jpg "replace the fruit with strawberries"
>> aimg edit freckled_woman.jpg "make her a cyborg" --prompt-strength 13
>> aimg edit pearl_girl.jpg "make her wear clown makeup"
>> aimg edit mona-lisa.jpg "make it a color professional photo headshot" --negative-prompt "old, ugly"
```
<img src="assets/scenic_landscape_winter.jpg" height="256"><img src="assets/dog_red.jpg" height="256"><br>
<img src="assets/bowl_of_fruit_strawberries.jpg" height="256"><img src="assets/freckled_woman_cyborg.jpg" height="256"><br>
<img src="assets/girl_with_a_pearl_earring_clown_makeup.jpg" height="256"><img src="assets/mona-lisa-headshot-photo.jpg" height="256"><br>
### Prompt Based Masking [by clipseg](https://github.com/timojl/clipseg)
Specify advanced text based masks using boolean logic and strength modifiers.
Mask syntax:
- mask descriptions must be lowercase
@ -252,6 +272,10 @@ docker run -it --gpus all -v $HOME/.cache/huggingface:/root/.cache/huggingface -
## ChangeLog
**8.0.0**
- feature: 🎉 edit images with instructions alone!
- feature: prune-ckpt command also removes the non-ema weights
**7.6.0**
- fix: default model config was broken
- feature: print version with `--version`

(8 new binary image assets added — the example edit output images referenced above, 23–252 KiB each; contents not shown.)
@ -26,6 +26,7 @@ from imaginairy.outpaint import outpaint_arg_str_parse, prepare_image_for_outpai
from imaginairy.safety import SafetyMode, create_safety_score
from imaginairy.samplers import SAMPLER_LOOKUP
from imaginairy.samplers.base import NoiseSchedule, noise_an_image
from imaginairy.samplers.editing import CFGEditingDenoiser
from imaginairy.schema import ImaginePrompt, ImagineResult
from imaginairy.utils import (
fix_torch_group_norm,
@ -180,7 +181,10 @@ def imagine(
if prompt.init_image:
starting_image = prompt.init_image
generation_strength = 1 - prompt.init_image_strength
t_enc = int(prompt.steps * generation_strength)
if model.cond_stage_key == "edit":
# the edit model conditions on the init image instead of noising it, so all steps are run
t_enc = prompt.steps
else:
t_enc = int(prompt.steps * generation_strength)
if prompt.mask_prompt:
mask_image, mask_grayscale = get_img_mask(
@ -269,6 +273,7 @@ def imagine(
"txt": batch_size * [prompt.prompt_text],
}
c_cat = []
c_cat_neutral = None
depth_image_display = None
if has_depth_channel and starting_image:
midas_model = AddMiDaS()
@ -346,13 +351,21 @@ def imagine(
c_cat.append(cc)
if c_cat:
c_cat = [torch.cat(c_cat, dim=1)]
denoiser_cls = None
if model.cond_stage_key == "edit":
# edit model: condition on the latent of the original image; the neutral branch
# gets an all-zero image latent
c_cat = [model.encode_first_stage(init_image_t).mode()]
c_cat_neutral = [torch.zeros_like(init_latent)]
denoiser_cls = CFGEditingDenoiser
if c_cat_neutral is None:
c_cat_neutral = c_cat
positive_conditioning = {
"c_concat": c_cat,
"c_crossattn": [positive_conditioning],
}
neutral_conditioning = {
"c_concat": c_cat,
"c_concat": c_cat_neutral,
"c_crossattn": [neutral_conditioning],
}
with lc.timing("sampling"):
@ -367,6 +380,7 @@ def imagine(
orig_latent=init_latent,
shape=shape,
batch_size=1,
denoiser_cls=denoiser_cls,
)
# from torch.nn.functional import interpolate
# samples = interpolate(samples, scale_factor=2, mode='nearest')

@ -255,6 +255,345 @@ def imagine_cmd(
model_config_path,
prompt_library_path,
version, # noqa
):
"""Have the AI generate images. alias:imagine."""
return _imagine_cmd(
ctx,
prompt_texts,
negative_prompt,
prompt_strength,
init_image,
init_image_strength,
outdir,
repeats,
height,
width,
steps,
seed,
upscale,
fix_faces,
fix_faces_fidelity,
sampler_type,
log_level,
quiet,
show_work,
tile,
tile_x,
tile_y,
mask_image,
mask_prompt,
mask_mode,
mask_modify_original,
outpaint,
caption,
precision,
model_weights_path,
model_config_path,
prompt_library_path,
version, # noqa
)
@click.command()
@click.argument("init_image", metavar="PATH|URL", required=True, nargs=1)
@click.argument("prompt_texts", nargs=-1)
@click.option(
"--negative-prompt",
default="",
show_default=True,
help="Negative prompt. Things to try and exclude from images. Same negative prompt will be used for all images.",
)
@click.option(
"--prompt-strength",
default=7.5,
show_default=True,
help="How closely to follow the prompt. Image looks unnatural at higher values",
)
@click.option(
"--init-image",
metavar="PATH|URL",
help="Starting image.",
)
@click.option(
"--outdir",
default="./outputs",
show_default=True,
type=click.Path(),
help="Where to write results to.",
)
@click.option(
"-r",
"--repeats",
default=1,
show_default=True,
type=int,
help="How many times to repeat the renders. If you provide two prompts and --repeat=3 then six images will be generated.",
)
@click.option(
"-h",
"--height",
default=None,
show_default=True,
type=int,
help="Image height. Should be multiple of 64.",
)
@click.option(
"-w",
"--width",
default=None,
show_default=True,
type=int,
help="Image width. Should be multiple of 64.",
)
@click.option(
"--steps",
default=None,
type=int,
show_default=True,
help="How many diffusion steps to run. More steps, more detail, but with diminishing returns.",
)
@click.option(
"--seed",
default=None,
type=int,
help="What seed to use for randomness. Allows reproducible image renders.",
)
@click.option("--upscale", is_flag=True)
@click.option("--fix-faces", is_flag=True)
@click.option(
"--fix-faces-fidelity",
default=1,
type=float,
help="How faithful to the original should face enhancement be. 1 = best fidelity, 0 = best looking face.",
)
@click.option(
"--sampler-type",
"--sampler",
default=config.DEFAULT_SAMPLER,
show_default=True,
type=click.Choice(SAMPLER_TYPE_OPTIONS),
help="What sampling strategy to use.",
)
@click.option(
"--log-level",
default="INFO",
show_default=True,
type=click.Choice(["DEBUG", "INFO", "WARNING", "ERROR"]),
help="What level of logs to show.",
)
@click.option(
"--quiet",
"-q",
is_flag=True,
help="Suppress logs. Alias of `--log-level ERROR`.",
)
@click.option(
"--show-work",
default=False,
is_flag=True,
help="Output a debug images to `steps` folder.",
)
@click.option(
"--tile",
is_flag=True,
help="Any images rendered will be tileable in both X and Y directions.",
)
@click.option(
"--tile-x",
is_flag=True,
help="Any images rendered will be tileable in the X direction.",
)
@click.option(
"--tile-y",
is_flag=True,
help="Any images rendered will be tileable in the Y direction.",
)
@click.option(
"--mask-image",
metavar="PATH|URL",
help="A mask to use for inpainting. White gets painted, Black is left alone.",
)
@click.option(
"--mask-prompt",
help=(
"Describe what you want masked and the AI will mask it for you. "
"You can describe complex masks with AND, OR, NOT keywords and parentheses. "
"The strength of each mask can be modified with {*1.5} notation. \n\n"
"Examples: \n"
"car AND (wheels{*1.1} OR trunk OR engine OR windows OR headlights) AND NOT (truck OR headlights){*10}\n"
"fruit|fruit stem"
),
)
@click.option(
"--mask-mode",
default="replace",
show_default=True,
type=click.Choice(["keep", "replace"]),
help="Should we replace the masked area or keep it?",
)
@click.option(
"--mask-modify-original",
default=True,
is_flag=True,
help="After the inpainting is done, apply the changes to a copy of the original image.",
)
@click.option(
"--outpaint",
help=(
"Specify in what directions to expand the image. Values will be snapped such that output image size is multiples of 64. Examples\n"
" `--outpaint up10,down300,left50,right50`\n"
" `--outpaint u10,d300,l50,r50`\n"
" `--outpaint all200`\n"
" `--outpaint a200`\n"
),
default="",
)
@click.option(
"--caption",
default=False,
is_flag=True,
help="Generate a text description of the generated image.",
)
@click.option(
"--precision",
help="Evaluate at this precision.",
type=click.Choice(["full", "autocast"]),
default="autocast",
show_default=True,
)
@click.option(
"--model-weights-path",
"--model",
help=f"Model to use. Should be one of {', '.join(config.MODEL_SHORT_NAMES)}, or a path to custom weights.",
show_default=True,
default="edit",
)
@click.option(
"--model-config-path",
help="Model config file to use. If a model name is specified, the appropriate config will be used.",
show_default=True,
default=None,
)
@click.option(
"--prompt-library-path",
help="Path to folder containing phrase lists in txt files. Use txt filename in prompt: {_filename_}.",
type=click.Path(exists=True),
default=None,
multiple=True,
)
@click.option(
"--version",
default=False,
is_flag=True,
help="Print the version and exit.",
)
@click.pass_context
def edit_image(
ctx,
init_image,
prompt_texts,
negative_prompt,
prompt_strength,
outdir,
repeats,
height,
width,
steps,
seed,
upscale,
fix_faces,
fix_faces_fidelity,
sampler_type,
log_level,
quiet,
show_work,
tile,
tile_x,
tile_y,
mask_image,
mask_prompt,
mask_mode,
mask_modify_original,
outpaint,
caption,
precision,
model_weights_path,
model_config_path,
prompt_library_path,
version, # noqa
):
init_image_strength = 1
return _imagine_cmd(
ctx,
prompt_texts,
negative_prompt,
prompt_strength,
init_image,
init_image_strength,
outdir,
repeats,
height,
width,
steps,
seed,
upscale,
fix_faces,
fix_faces_fidelity,
sampler_type,
log_level,
quiet,
show_work,
tile,
tile_x,
tile_y,
mask_image,
mask_prompt,
mask_mode,
mask_modify_original,
outpaint,
caption,
precision,
model_weights_path,
model_config_path,
prompt_library_path,
version, # noqa
)
def _imagine_cmd(
ctx,
prompt_texts,
negative_prompt,
prompt_strength,
init_image,
init_image_strength,
outdir,
repeats,
height,
width,
steps,
seed,
upscale,
fix_faces,
fix_faces_fidelity,
sampler_type,
log_level,
quiet,
show_work,
tile,
tile_x,
tile_y,
mask_image,
mask_prompt,
mask_mode,
mask_modify_original,
outpaint,
caption,
precision,
model_weights_path,
model_config_path,
prompt_library_path,
version, # noqa
):
"""Have the AI generate images. alias:imagine."""
if ctx.invoked_subcommand is not None:
@ -595,6 +934,7 @@ def prune_ckpt(ckpt_paths):
aimg.add_command(imagine_cmd, name="imagine")
aimg.add_command(edit_image, name="edit")
if __name__ == "__main__":
imagine_cmd() # noqa

@ -23,6 +23,7 @@ class ModelConfig:
default_image_size: int
weights_url_full: str = None
forced_attn_precision: str = "default"
default_negative_prompt: str = DEFAULT_NEGATIVE_PROMPT
midas_url = "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt"
@ -90,6 +91,13 @@ MODEL_CONFIGS = [
weights_url="https://huggingface.co/stabilityai/stable-diffusion-2-depth/resolve/main/512-depth-ema.ckpt",
default_image_size=512,
),
ModelConfig(
short_name="edit",
config_path="configs/instruct-pix2pix.yaml",
weights_url="https://huggingface.co/imaginairy/instruct-pix2pix/resolve/ea0009b3d0d4888f410a40bd06d69516d0b5a577/instruct-pix2pix-00-22000-pruned.ckpt",
default_image_size=512,
default_negative_prompt="",
)
# ModelConfig(
# short_name="SD-2.0-upscale",
# config_path="configs/stable-diffusion-v2-upscaling.yaml",

@ -0,0 +1,70 @@
model:
base_learning_rate: 1.0e-04
target: imaginairy.modules.diffusion.ddpm.LatentDiffusion
params:
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "edited"
cond_stage_key: "edit"
image_size: 16
channels: 4
cond_stage_trainable: false
conditioning_key: hybrid
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: false
scheduler_config:
target: imaginairy.lr_scheduler.LambdaLinearScheduler
params:
warm_up_steps: [ 0 ]
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
f_start: [ 1.e-6 ]
f_max: [ 1. ]
f_min: [ 1. ]
unet_config:
target: imaginairy.modules.diffusion.openaimodel.UNetModel
params:
use_checkpoint: True
image_size: 32 # unused
in_channels: 8 # 4 noisy-latent channels + 4 image-conditioning channels
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_heads: 8
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
legacy: False
first_stage_config:
target: imaginairy.modules.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: imaginairy.modules.clip_embedders.FrozenCLIPEmbedder

@ -122,12 +122,7 @@ def load_model_from_config(config, weights_location):
else:
state_dict = pl_sd
model = instantiate_from_config(config.model)
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
if len(missing_keys) > 0:
logger.debug(f"missing keys: {missing_keys}")
if len(unexpected_keys) > 0:
logger.debug(f"unexpected keys: {unexpected_keys}")
model.init_from_state_dict(state_dict)
model.to(get_device())
model.eval()

@ -318,11 +318,30 @@ class DDPM(pl.LightningModule):
print(f"{context}: Restored training weights")
@torch.no_grad()
def init_from_ckpt(self, path, ignore_keys=(), only_model=False):
sd = torch.load(path, map_location="cpu")
def init_from_state_dict(self, sd, ignore_keys=(), only_model=False):
if "state_dict" in list(sd.keys()):
sd = sd["state_dict"]
keys = list(sd.keys())
if self.cond_stage_key == "edit":
# from https://github.com/timothybrooks/instruct-pix2pix/blob/main/stable_diffusion/ldm/models/diffusion/ddpm_edit.py#L203-L221
input_keys = [
"model.diffusion_model.input_blocks.0.0.weight",
"model_ema.diffusion_modelinput_blocks00weight",
]
self_sd = self.state_dict()
for input_key in input_keys:
if input_key not in sd or input_key not in self_sd:
continue
input_weight = self_sd[input_key]
if input_weight.size() != sd[input_key].size():
# loading plain SD weights (4 input channels) into the 8-channel edit UNet:
# copy them into the first 4 channels and leave the rest zeroed
input_weight.zero_()
input_weight[:, :4, :, :].copy_(sd[input_key])
ignore_keys.append(input_key)
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
@ -391,6 +410,11 @@ class DDPM(pl.LightningModule):
# if len(unexpected) > 0:
# print(f"\nUnexpected Keys:\n {unexpected}")
@torch.no_grad()
def init_from_ckpt(self, path, ignore_keys=(), only_model=False):
sd = torch.load(path, map_location="cpu")
self.init_from_state_dict(sd, ignore_keys=ignore_keys, only_model=only_model)
def q_mean_variance(self, x_start, t):
"""
Get the distribution q(x_t | x_0).
@ -1782,6 +1806,11 @@ class LatentFinetuneDiffusion(LatentDiffusion):
def init_from_ckpt(self, path, ignore_keys=(), only_model=False):
sd = torch.load(path, map_location="cpu")
return self.init_from_state_dict(
sd, ignore_keys=ignore_keys, only_model=only_model
)
def init_from_state_dict(self, sd, ignore_keys=(), only_model=False):
if "state_dict" in list(sd.keys()):
sd = sd["state_dict"]
keys = list(sd.keys())
@ -1811,13 +1840,13 @@ class LatentFinetuneDiffusion(LatentDiffusion):
if not only_model
else self.model.load_state_dict(sd, strict=False)
)
print(
f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
)
if len(missing) > 0:
print(f"Missing Keys: {missing}")
if len(unexpected) > 0:
print(f"Unexpected Keys: {unexpected}")
# print(
# f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
# )
# if len(missing) > 0:
# print(f"Missing Keys: {missing}")
# if len(unexpected) > 0:
# print(f"Unexpected Keys: {unexpected}")
@torch.no_grad()
def log_images( # noqa

@ -46,6 +46,7 @@ class DDIMSampler(ImageSampler):
initial_latent=None,
t_start=None,
quantize_x0=False,
**kwargs,
):
# print("Sampling with DDIM")
# print("num_steps", num_steps)

@ -0,0 +1,72 @@
"""
Wrapper for instruct pix2pix model.
modified from https://github.com/timothybrooks/instruct-pix2pix/blob/main/edit_cli.py
"""
import torch
from einops import einops
from torch import nn
from imaginairy.samplers.base import mask_blend
class CFGEditingDenoiser(nn.Module):
def __init__(self, model):
super().__init__()
self.inner_model = model
def forward(
self,
z,
sigma,
cond,
uncond,
cond_scale,
image_cfg_scale=1.5,
mask=None,
mask_noise=None,
orig_latent=None,
):
# run three predictions in one batch: (text + image)-conditioned,
# image-only-conditioned, and fully unconditioned
cfg_z = einops.repeat(z, "1 ... -> n ...", n=3)
cfg_sigma = einops.repeat(sigma, "1 ... -> n ...", n=3)
cfg_cond = {
"c_crossattn": [
torch.cat(
[
cond["c_crossattn"][0],
uncond["c_crossattn"][0],
uncond["c_crossattn"][0],
]
)
],
"c_concat": [
torch.cat(
[cond["c_concat"][0], cond["c_concat"][0], uncond["c_concat"][0]]
)
],
}
if mask is not None:
assert orig_latent is not None
t = self.inner_model.sigma_to_t(sigma, quantize=True)
big_sigma = max(sigma, 1)
cfg_z = mask_blend(
noisy_latent=cfg_z,
orig_latent=orig_latent * big_sigma,
mask=mask,
mask_noise=mask_noise * big_sigma,
ts=t,
model=self.inner_model.inner_model,
)
out_cond, out_img_cond, out_uncond = self.inner_model(
cfg_z, cfg_sigma, cond=cfg_cond
).chunk(3)
# dual classifier-free guidance: cond_scale steers toward the text instruction,
# image_cfg_scale steers toward faithfulness to the input image
result = (
out_uncond
+ cond_scale * (out_cond - out_img_cond)
+ image_cfg_scale * (out_img_cond - out_uncond)
)
return result
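The three-way batch above implements InstructPix2Pix-style dual classifier-free guidance: `cond_scale` (the CLI's `--prompt-strength`) pushes the output toward the text instruction, while `image_cfg_scale` keeps it anchored to the input image. A toy sketch of just the combination step, with stand-in tensors:
```python
import torch

out_cond = torch.randn(1, 4, 64, 64)      # text + image conditioned prediction
out_img_cond = torch.randn(1, 4, 64, 64)  # image-only conditioned prediction
out_uncond = torch.randn(1, 4, 64, 64)    # unconditioned prediction

cond_scale = 7.5       # text guidance strength
image_cfg_scale = 1.5  # image guidance strength

result = (
    out_uncond
    + cond_scale * (out_cond - out_img_cond)
    + image_cfg_scale * (out_img_cond - out_uncond)
)
print(result.shape)  # torch.Size([1, 4, 64, 64])
```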

@ -81,6 +81,7 @@ class KDiffusionSampler(ImageSampler, ABC):
orig_latent=None,
initial_latent=None,
t_start=None,
denoiser_cls=None,
):
# if positive_conditioning.shape[0] != batch_size:
# raise ValueError(
@ -104,7 +105,9 @@ class KDiffusionSampler(ImageSampler, ABC):
x = initial_latent * sigmas[0]
log_latent(x, "initial_sigma_noised_tensor")
model_wrap_cfg = CFGDenoiser(self.cv_denoiser)
if denoiser_cls is None:
denoiser_cls = CFGDenoiser
model_wrap_cfg = denoiser_cls(self.cv_denoiser)
mask_noise = None
if mask is not None:

@ -94,7 +94,7 @@ class ImaginePrompt:
def __init__(
self,
prompt=None,
negative_prompt=config.DEFAULT_NEGATIVE_PROMPT,
negative_prompt=None,
prompt_strength=7.5,
init_image=None, # Pillow Image, LazyLoadingImage, or filepath str
init_image_strength=0.3,
@ -118,7 +118,6 @@ class ImaginePrompt:
):
self.prompts = self.process_prompt_input(prompt)
self.negative_prompt = self.process_prompt_input(negative_prompt)
self.prompt_strength = prompt_strength
if tile_mode is True:
tile_mode = "xy"
@ -159,6 +158,7 @@ class ImaginePrompt:
self.outpaint = outpaint
self.tile_mode = tile_mode
self.model = model
self.model_config_path = model_config_path
if self.height is None or self.width is None or self.steps is None:
SamplerCls = SAMPLER_LOOKUP[self.sampler_type]
@ -166,9 +166,25 @@ class ImaginePrompt:
self.width = self.width or get_model_default_image_size(self.model)
self.height = self.height or get_model_default_image_size(self.model)
if negative_prompt is None:
model_config = config.MODEL_CONFIG_SHORTCUTS.get(self.model, None)
if model_config:
negative_prompt = model_config.default_negative_prompt
else:
negative_prompt = config.DEFAULT_NEGATIVE_PROMPT
self.negative_prompt = self.process_prompt_input(negative_prompt)
if self.model == "SD-2.0-v" and self.sampler_type == SamplerName.PLMS:
raise ValueError("PLMS sampler is not supported for SD-2.0-v model.")
self.model_config_path = model_config_path
if self.model == "edit" and self.sampler_type in (
SamplerName.PLMS,
SamplerName.DDIM,
):
raise ValueError(
"PLMS and DDIM samplers are not supported for pix2pix edit model."
)
@property
def prompt_text(self):

@ -13,7 +13,7 @@ linters = pylint,pycodestyle,pyflakes,mypy
ignore =
Z999,C0103,C0301,C0302,C0114,C0115,C0116,
Z999,D100,D101,D102,D103,D105,D106,D107,D200,D202,D203,D205,D212,D400,D401,D406,D407,D413,D415,D417,
Z999,E203,E501,E1101,E1131,
Z999,E203,E501,E1101,E1131,E1135,E1136,
Z999,R0901,R0902,R0903,R0904,R0193,R0912,R0913,R0914,R0915,R1702,
Z999,W0221,W0511,W0612,W0613,W0632,W1203
