mirror of
https://github.com/brycedrennan/imaginAIry
synced 2024-11-19 03:25:41 +00:00
feature: add compilation animations (#224)
- add generation/compare gifs
This commit is contained in:
parent
a67683d318
commit
9ee09ac842
17
README.md
17
README.md
@ -16,6 +16,9 @@ AI imagined images. Pythonic generation of stable diffusion images.
|
|||||||
>> imagine "a scenic landscape" "a photo of a dog" "photo of a fruit bowl" "portrait photo of a freckled woman"
|
>> imagine "a scenic landscape" "a photo of a dog" "photo of a fruit bowl" "portrait photo of a freckled woman"
|
||||||
# Stable Diffusion 2.1
|
# Stable Diffusion 2.1
|
||||||
>> imagine --model SD-2.1 "a forest"
|
>> imagine --model SD-2.1 "a forest"
|
||||||
|
# Make generation gif
|
||||||
|
>> imagine --gif "a flower"
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
<details closed>
|
<details closed>
|
||||||
@ -41,6 +44,7 @@ Generating 🖼 : "portrait photo of a freckled woman" 512x512px seed:500686645
|
|||||||
|
|
||||||
<img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000019_786355545_PLMS50_PS7.5_a_scenic_landscape.jpg" height="256"><img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000032_337692011_PLMS40_PS7.5_a_photo_of_a_dog.jpg" height="256"><br>
|
<img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000019_786355545_PLMS50_PS7.5_a_scenic_landscape.jpg" height="256"><img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000032_337692011_PLMS40_PS7.5_a_photo_of_a_dog.jpg" height="256"><br>
|
||||||
<img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000056_293284644_PLMS40_PS7.5_photo_of_a_bowl_of_fruit.jpg" height="256"><img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000078_260972468_PLMS40_PS7.5_portrait_photo_of_a_freckled_woman.jpg" height="256">
|
<img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000056_293284644_PLMS40_PS7.5_photo_of_a_bowl_of_fruit.jpg" height="256"><img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/assets/000078_260972468_PLMS40_PS7.5_portrait_photo_of_a_freckled_woman.jpg" height="256">
|
||||||
|
<img src="assets/009719_942389026_kdpmpp2m15_PS7.5_a_flower.gif" height="256">
|
||||||
|
|
||||||
### 🎉 Edit Images with Instructions alone! [by InstructPix2Pix](https://github.com/timothybrooks/instruct-pix2pix)
|
### 🎉 Edit Images with Instructions alone! [by InstructPix2Pix](https://github.com/timothybrooks/instruct-pix2pix)
|
||||||
Just tell imaginairy how to edit the image and it will do it for you!
|
Just tell imaginairy how to edit the image and it will do it for you!
|
||||||
@ -49,17 +53,20 @@ with prompt-based masking.
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
>> aimg edit scenic_landscape.jpg "make it winter" --prompt-strength 20
|
>> aimg edit scenic_landscape.jpg "make it winter" --prompt-strength 20
|
||||||
|
>> aimg edit scenic_landscape.jpg "make it winter" --steps 30 --arg-schedule "prompt_strength[2:25:0.5]" --compilation-anim
|
||||||
>> aimg edit dog.jpg "make the dog red" --prompt-strength 5
|
>> aimg edit dog.jpg "make the dog red" --prompt-strength 5
|
||||||
>> aimg edit bowl_of_fruit.jpg "replace the fruit with strawberries"
|
>> aimg edit bowl_of_fruit.jpg "replace the fruit with strawberries"
|
||||||
>> aimg edit freckled_woman.jpg "make her a cyborg" --prompt-strength 13
|
>> aimg edit freckled_woman.jpg "make her a cyborg" --prompt-strength 13
|
||||||
>> aimg edit pearl_girl.jpg "make her wear clown makup"
|
# create a comparison gif
|
||||||
>> aimg edit mona-lisa.jpg "make it a color professional photo headshot" --negative-prompt "old, ugly"
|
>> aimg edit pearl_girl.jpg "make her wear clown makeup" --compare-gif
|
||||||
|
# create an animation showing the edit with increasing prompt strengths
|
||||||
|
>> aimg edit mona-lisa.jpg "make it a color professional photo headshot" --negative-prompt "old, ugly, blurry" --arg-schedule "prompt-strength[2:8:0.5]" --compilation-anim gif
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
<img src="assets/scenic_landscape_winter.jpg" height="256"><img src="assets/dog_red.jpg" height="256"><br>
|
<img src="assets/scenic_landscape_winter.jpg" height="256"><img src="assets/dog_red.jpg" height="256"><br>
|
||||||
<img src="assets/bowl_of_fruit_strawberries.jpg" height="256"><img src="assets/freckled_woman_cyborg.jpg" height="256"><br>
|
<img src="assets/bowl_of_fruit_strawberries.jpg" height="256"><img src="assets/freckled_woman_cyborg.jpg" height="256"><br>
|
||||||
<img src="assets/girl_with_a_pearl_earring_clown_makeup.jpg" height="256"><img src="assets/mona-lisa-headshot-photo.jpg" height="256"><br>
|
<img src="assets/girl-pearl-clown-compare.gif" height="256"><img src="assets/mona-lisa-headshot-anim.gif" height="256"><br>
|
||||||
|
|
||||||
Want just quickly have some fun? Try `--surprise-me` to apply some pre-defined edits.
|
Want just quickly have some fun? Try `--surprise-me` to apply some pre-defined edits.
|
||||||
```bash
|
```bash
|
||||||
@ -283,8 +290,12 @@ docker run -it --gpus all -v $HOME/.cache/huggingface:/root/.cache/huggingface -
|
|||||||
|
|
||||||
## ChangeLog
|
## ChangeLog
|
||||||
|
|
||||||
|
- feature: create `gifs` or `mp4s` from any images made in a single run with `--compilation-anim gif`
|
||||||
|
- feature: create a series of images or edits by iterating over a parameter with the `--arg-schedule` argument
|
||||||
- feature: `openjourney-v1` and `openjourney-v2` models added. available via `--model openjourney-v2`
|
- feature: `openjourney-v1` and `openjourney-v2` models added. available via `--model openjourney-v2`
|
||||||
- feature: add upscale command line function: `aimg upscale`
|
- feature: add upscale command line function: `aimg upscale`
|
||||||
|
- feature: `--gif` option will create a gif showing the generation process for a single image
|
||||||
|
- feature: `--compare-gif` option will create a comparison gif for any image edits
|
||||||
- fix: tile mode was broken since latest perf improvements
|
- fix: tile mode was broken since latest perf improvements
|
||||||
|
|
||||||
**8.2.0**
|
**8.2.0**
|
||||||
|
BIN
assets/009719_942389026_kdpmpp2m15_PS7.5_a_flower.gif
Normal file
BIN
assets/009719_942389026_kdpmpp2m15_PS7.5_a_flower.gif
Normal file
Binary file not shown.
After Width: | Height: | Size: 2.7 MiB |
BIN
assets/girl-pearl-clown-compare.gif
Normal file
BIN
assets/girl-pearl-clown-compare.gif
Normal file
Binary file not shown.
After Width: | Height: | Size: 443 KiB |
BIN
assets/mona-lisa-headshot-anim.gif
Normal file
BIN
assets/mona-lisa-headshot-anim.gif
Normal file
Binary file not shown.
After Width: | Height: | Size: 3.3 MiB |
40
docs/examples/generate_doc_examples.py
Normal file
40
docs/examples/generate_doc_examples.py
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
from imaginairy import ImaginePrompt, LazyLoadingImage, imagine_image_files
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
prompts = [
|
||||||
|
ImaginePrompt(
|
||||||
|
"make her wear clown makeup",
|
||||||
|
seed=952243488,
|
||||||
|
model="edit",
|
||||||
|
init_image=LazyLoadingImage(
|
||||||
|
url="https://github.com/brycedrennan/imaginAIry/raw/2a3e19f5a1a864fcee18c23f17aea02cc0f61bbf/assets/girl_with_a_pearl_earring.jpg"
|
||||||
|
),
|
||||||
|
steps=30,
|
||||||
|
),
|
||||||
|
ImaginePrompt(
|
||||||
|
"make her wear clown makeup",
|
||||||
|
seed=952243488,
|
||||||
|
model="edit",
|
||||||
|
init_image=LazyLoadingImage(
|
||||||
|
url="https://github.com/brycedrennan/imaginAIry/raw/2a3e19f5a1a864fcee18c23f17aea02cc0f61bbf/assets/girl_with_a_pearl_earring.jpg"
|
||||||
|
),
|
||||||
|
steps=30,
|
||||||
|
),
|
||||||
|
ImaginePrompt(
|
||||||
|
"make it a color professional photo headshot",
|
||||||
|
negative_prompt="old, ugly, blurry",
|
||||||
|
seed=390919410,
|
||||||
|
model="edit",
|
||||||
|
init_image=LazyLoadingImage(
|
||||||
|
url="https://github.com/brycedrennan/imaginAIry/raw/2a3e19f5a1a864fcee18c23f17aea02cc0f61bbf/assets/mona-lisa.jpg"
|
||||||
|
),
|
||||||
|
steps=30,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
imagine_image_files(prompts, outdir="./outputs", make_gif=True)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
112
imaginairy/animations.py
Normal file
112
imaginairy/animations.py
Normal file
@ -0,0 +1,112 @@
|
|||||||
|
import os.path
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from imaginairy.img_utils import (
|
||||||
|
add_caption_to_image,
|
||||||
|
imgpaths_to_imgs,
|
||||||
|
model_latents_to_pillow_imgs,
|
||||||
|
pillow_img_to_opencv_img,
|
||||||
|
)
|
||||||
|
from imaginairy.utils import shrink_list
|
||||||
|
|
||||||
|
|
||||||
|
def make_bounce_animation(
|
||||||
|
imgs,
|
||||||
|
outpath,
|
||||||
|
transition_duration_ms=500,
|
||||||
|
start_pause_duration_ms=1000,
|
||||||
|
end_pause_duration_ms=2000,
|
||||||
|
):
|
||||||
|
first_img = imgs[0]
|
||||||
|
last_img = imgs[-1]
|
||||||
|
middle_imgs = imgs[1:-1]
|
||||||
|
max_fps = 20
|
||||||
|
max_frames = int(round(transition_duration_ms / 1000 * max_fps))
|
||||||
|
min_duration = int(1000 / 20)
|
||||||
|
if middle_imgs:
|
||||||
|
progress_duration = int(round(transition_duration_ms / len(middle_imgs)))
|
||||||
|
else:
|
||||||
|
progress_duration = 0
|
||||||
|
progress_duration = max(progress_duration, min_duration)
|
||||||
|
|
||||||
|
middle_imgs = shrink_list(middle_imgs, max_frames)
|
||||||
|
|
||||||
|
frames = [first_img] + middle_imgs + [last_img] + list(reversed(middle_imgs))
|
||||||
|
|
||||||
|
# convert from latents
|
||||||
|
converted_frames = []
|
||||||
|
for frame in frames:
|
||||||
|
if isinstance(frame, torch.Tensor):
|
||||||
|
frame = model_latents_to_pillow_imgs(frame)[0]
|
||||||
|
converted_frames.append(frame)
|
||||||
|
frames = converted_frames
|
||||||
|
|
||||||
|
durations = (
|
||||||
|
[start_pause_duration_ms]
|
||||||
|
+ [progress_duration] * len(middle_imgs)
|
||||||
|
+ [end_pause_duration_ms]
|
||||||
|
+ [progress_duration] * len(middle_imgs)
|
||||||
|
)
|
||||||
|
|
||||||
|
make_animation(imgs=frames, outpath=outpath, frame_duration_ms=durations)
|
||||||
|
|
||||||
|
|
||||||
|
def make_animation(imgs, outpath, frame_duration_ms=100, captions=None):
|
||||||
|
imgs = imgpaths_to_imgs(imgs)
|
||||||
|
ext = os.path.splitext(outpath)[1].lower().strip(".")
|
||||||
|
|
||||||
|
if captions:
|
||||||
|
if len(captions) != len(imgs):
|
||||||
|
raise ValueError("Captions and images must be of same length.")
|
||||||
|
for img, caption in zip(imgs, captions):
|
||||||
|
add_caption_to_image(img, caption)
|
||||||
|
|
||||||
|
if ext == "gif":
|
||||||
|
make_gif_animation(
|
||||||
|
imgs=imgs, outpath=outpath, frame_duration_ms=frame_duration_ms
|
||||||
|
)
|
||||||
|
elif ext == "mp4":
|
||||||
|
make_mp4_animation(
|
||||||
|
imgs=imgs, outpath=outpath, frame_duration_ms=frame_duration_ms
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def make_gif_animation(imgs, outpath, frame_duration_ms=100, loop=0):
|
||||||
|
imgs = imgpaths_to_imgs(imgs)
|
||||||
|
imgs[0].save(
|
||||||
|
outpath,
|
||||||
|
save_all=True,
|
||||||
|
append_images=imgs[1:],
|
||||||
|
duration=frame_duration_ms,
|
||||||
|
loop=loop,
|
||||||
|
optimize=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def make_mp4_animation(imgs, outpath, frame_duration_ms=50, fps=30, codec="mp4v"):
|
||||||
|
imgs = imgpaths_to_imgs(imgs)
|
||||||
|
frame_size = imgs[0].size
|
||||||
|
fourcc = cv2.VideoWriter_fourcc(*codec)
|
||||||
|
out = cv2.VideoWriter(outpath, fourcc, fps, frame_size)
|
||||||
|
if not isinstance(frame_duration_ms, list):
|
||||||
|
frame_duration_ms = [frame_duration_ms] * len(imgs)
|
||||||
|
try:
|
||||||
|
for image in select_images_by_duration_at_fps(imgs, frame_duration_ms, fps):
|
||||||
|
image = pillow_img_to_opencv_img(image)
|
||||||
|
out.write(image)
|
||||||
|
finally:
|
||||||
|
out.release()
|
||||||
|
|
||||||
|
|
||||||
|
def select_images_by_duration_at_fps(images, durations_ms, fps=30):
|
||||||
|
"""select the proper image to show for each frame of a video."""
|
||||||
|
for i, image in enumerate(images):
|
||||||
|
duration = durations_ms[i] / 1000
|
||||||
|
num_frames = int(round(duration * fps))
|
||||||
|
print(
|
||||||
|
f"Showing image {i} for {num_frames} frames for {durations_ms[i]}ms at {fps} fps."
|
||||||
|
)
|
||||||
|
for j in range(num_frames):
|
||||||
|
yield image
|
@ -10,16 +10,12 @@ from PIL import Image, ImageDraw, ImageOps
|
|||||||
from pytorch_lightning import seed_everything
|
from pytorch_lightning import seed_everything
|
||||||
from torch.cuda import OutOfMemoryError
|
from torch.cuda import OutOfMemoryError
|
||||||
|
|
||||||
|
from imaginairy.animations import make_bounce_animation
|
||||||
from imaginairy.enhancers.clip_masking import get_img_mask
|
from imaginairy.enhancers.clip_masking import get_img_mask
|
||||||
from imaginairy.enhancers.describe_image_blip import generate_caption
|
from imaginairy.enhancers.describe_image_blip import generate_caption
|
||||||
from imaginairy.enhancers.face_restoration_codeformer import enhance_faces
|
from imaginairy.enhancers.face_restoration_codeformer import enhance_faces
|
||||||
from imaginairy.enhancers.upscale_realesrgan import upscale_image
|
from imaginairy.enhancers.upscale_realesrgan import upscale_image
|
||||||
from imaginairy.img_utils import (
|
from imaginairy.img_utils import pillow_fit_image_within, pillow_img_to_torch_image
|
||||||
make_gif_image,
|
|
||||||
model_latents_to_pillow_imgs,
|
|
||||||
pillow_fit_image_within,
|
|
||||||
pillow_img_to_torch_image,
|
|
||||||
)
|
|
||||||
from imaginairy.log_utils import (
|
from imaginairy.log_utils import (
|
||||||
ImageLoggingContext,
|
ImageLoggingContext,
|
||||||
log_conditioning,
|
log_conditioning,
|
||||||
@ -65,6 +61,7 @@ def imagine_image_files(
|
|||||||
output_file_extension="jpg",
|
output_file_extension="jpg",
|
||||||
print_caption=False,
|
print_caption=False,
|
||||||
make_gif=False,
|
make_gif=False,
|
||||||
|
make_compare_gif=False,
|
||||||
return_filename_type="generated",
|
return_filename_type="generated",
|
||||||
):
|
):
|
||||||
generated_imgs_path = os.path.join(outdir, "generated")
|
generated_imgs_path = os.path.join(outdir, "generated")
|
||||||
@ -123,44 +120,31 @@ def imagine_image_files(
|
|||||||
os.makedirs(subpath, exist_ok=True)
|
os.makedirs(subpath, exist_ok=True)
|
||||||
filepath = os.path.join(subpath, f"{basefilename}.gif")
|
filepath = os.path.join(subpath, f"{basefilename}.gif")
|
||||||
|
|
||||||
transition_length = 1500
|
frames = result.progress_latents + [result.images["generated"]]
|
||||||
pause_length_ms = 500
|
|
||||||
max_fps = 20
|
|
||||||
max_frames = int(round(transition_length / 1000 * max_fps))
|
|
||||||
|
|
||||||
usable_latents = shrink_list(result.progress_latents, max_frames)
|
|
||||||
progress_imgs = [
|
|
||||||
model_latents_to_pillow_imgs(latent)[0] for latent in usable_latents
|
|
||||||
]
|
|
||||||
frames = (
|
|
||||||
progress_imgs
|
|
||||||
+ [result.images["generated"]]
|
|
||||||
+ list(reversed(progress_imgs))
|
|
||||||
)
|
|
||||||
progress_duration = int(round(300 / len(frames)))
|
|
||||||
min_duration = int(1000 / 20)
|
|
||||||
progress_duration = max(progress_duration, min_duration)
|
|
||||||
durations = (
|
|
||||||
[progress_duration] * len(progress_imgs)
|
|
||||||
+ [pause_length_ms]
|
|
||||||
+ [progress_duration] * len(progress_imgs)
|
|
||||||
)
|
|
||||||
assert len(frames) == len(durations)
|
|
||||||
if prompt.init_image:
|
if prompt.init_image:
|
||||||
resized_init_image = pillow_fit_image_within(
|
resized_init_image = pillow_fit_image_within(
|
||||||
prompt.init_image, prompt.width, prompt.height
|
prompt.init_image, prompt.width, prompt.height
|
||||||
)
|
)
|
||||||
frames = [resized_init_image] + frames
|
frames = [resized_init_image] + frames
|
||||||
durations = [pause_length_ms] + durations
|
|
||||||
else:
|
|
||||||
durations[0] = pause_length_ms
|
|
||||||
|
|
||||||
make_gif_image(
|
make_bounce_animation(imgs=frames, outpath=filepath)
|
||||||
filepath,
|
logger.info(f" [gif] {len(frames)} frames saved to: {filepath}")
|
||||||
imgs=frames,
|
if make_compare_gif and prompt.init_image:
|
||||||
duration=durations,
|
subpath = os.path.join(outdir, "gif")
|
||||||
|
os.makedirs(subpath, exist_ok=True)
|
||||||
|
filepath = os.path.join(subpath, f"{basefilename}_[compare].gif")
|
||||||
|
resized_init_image = pillow_fit_image_within(
|
||||||
|
prompt.init_image, prompt.width, prompt.height
|
||||||
)
|
)
|
||||||
logger.info(f" [gif] saved to: {filepath}")
|
frames = [resized_init_image, result.images["generated"]]
|
||||||
|
|
||||||
|
make_bounce_animation(
|
||||||
|
imgs=frames,
|
||||||
|
outpath=filepath,
|
||||||
|
)
|
||||||
|
logger.info(f" [gif-comparison] saved to: {filepath}")
|
||||||
|
|
||||||
base_count += 1
|
base_count += 1
|
||||||
del result
|
del result
|
||||||
|
|
||||||
@ -589,12 +573,3 @@ def _prompts_to_embeddings(prompts, model):
|
|||||||
|
|
||||||
def prompt_normalized(prompt):
|
def prompt_normalized(prompt):
|
||||||
return re.sub(r"[^a-zA-Z0-9.,\[\]-]+", "_", prompt)[:130]
|
return re.sub(r"[^a-zA-Z0-9.,\[\]-]+", "_", prompt)[:130]
|
||||||
|
|
||||||
|
|
||||||
def shrink_list(items, max_size):
|
|
||||||
if len(items) <= max_size:
|
|
||||||
return items
|
|
||||||
num_to_remove = len(items) - max_size
|
|
||||||
interval = int(round(len(items) / num_to_remove))
|
|
||||||
|
|
||||||
return [val for i, val in enumerate(items) if i % interval != 0]
|
|
||||||
|
@ -7,11 +7,13 @@ from click_shell import shell
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from imaginairy import LazyLoadingImage, __version__, config, generate_caption
|
from imaginairy import LazyLoadingImage, __version__, config, generate_caption
|
||||||
|
from imaginairy.animations import make_bounce_animation
|
||||||
from imaginairy.api import imagine_image_files
|
from imaginairy.api import imagine_image_files
|
||||||
from imaginairy.debug_info import get_debug_info
|
from imaginairy.debug_info import get_debug_info
|
||||||
from imaginairy.enhancers.prompt_expansion import expand_prompts
|
from imaginairy.enhancers.prompt_expansion import expand_prompts
|
||||||
from imaginairy.enhancers.upscale_realesrgan import upscale_image
|
from imaginairy.enhancers.upscale_realesrgan import upscale_image
|
||||||
from imaginairy.log_utils import configure_logging
|
from imaginairy.log_utils import configure_logging
|
||||||
|
from imaginairy.prompt_schedules import parse_schedule_strs, prompt_mutator
|
||||||
from imaginairy.samplers import SAMPLER_TYPE_OPTIONS
|
from imaginairy.samplers import SAMPLER_TYPE_OPTIONS
|
||||||
from imaginairy.schema import ImaginePrompt
|
from imaginairy.schema import ImaginePrompt
|
||||||
from imaginairy.surprise_me import create_surprise_me_images
|
from imaginairy.surprise_me import create_surprise_me_images
|
||||||
@ -47,8 +49,8 @@ logger = logging.getLogger(__name__)
|
|||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--init-image-strength",
|
"--init-image-strength",
|
||||||
default=0.6,
|
default=None,
|
||||||
show_default=True,
|
show_default=False,
|
||||||
help="Starting image strength. Between 0 and 1.",
|
help="Starting image strength. Between 0 and 1.",
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
@ -231,6 +233,26 @@ logger = logging.getLogger(__name__)
|
|||||||
is_flag=True,
|
is_flag=True,
|
||||||
help="Generate a gif of the generation.",
|
help="Generate a gif of the generation.",
|
||||||
)
|
)
|
||||||
|
@click.option(
|
||||||
|
"--compare-gif",
|
||||||
|
"make_compare_gif",
|
||||||
|
default=False,
|
||||||
|
is_flag=True,
|
||||||
|
help="Create a gif comparing the original image to the modified one.",
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--arg-schedule",
|
||||||
|
"arg_schedules",
|
||||||
|
multiple=True,
|
||||||
|
help="Schedule how an argument should change over several generations. Format: `--arg-schedule arg_name[start:end:increment]` or `--arg-schedule arg_name[val,val2,val3]`",
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--compilation-anim",
|
||||||
|
"make_compilation_animation",
|
||||||
|
default=None,
|
||||||
|
type=click.Choice(["gif", "mp4"]),
|
||||||
|
help="Generate an animation composed of all the images generated in this run. Defaults to gif but `--compilation-anim mp4` will generate an mp4 instead.",
|
||||||
|
)
|
||||||
@click.pass_context
|
@click.pass_context
|
||||||
def imagine_cmd(
|
def imagine_cmd(
|
||||||
ctx,
|
ctx,
|
||||||
@ -267,6 +289,9 @@ def imagine_cmd(
|
|||||||
prompt_library_path,
|
prompt_library_path,
|
||||||
version, # noqa
|
version, # noqa
|
||||||
make_gif,
|
make_gif,
|
||||||
|
make_compare_gif,
|
||||||
|
arg_schedules,
|
||||||
|
make_compilation_animation,
|
||||||
):
|
):
|
||||||
"""Have the AI generate images. alias:imagine."""
|
"""Have the AI generate images. alias:imagine."""
|
||||||
return _imagine_cmd(
|
return _imagine_cmd(
|
||||||
@ -304,6 +329,9 @@ def imagine_cmd(
|
|||||||
prompt_library_path,
|
prompt_library_path,
|
||||||
version, # noqa
|
version, # noqa
|
||||||
make_gif,
|
make_gif,
|
||||||
|
make_compare_gif,
|
||||||
|
arg_schedules,
|
||||||
|
make_compilation_animation,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -505,7 +533,14 @@ def imagine_cmd(
|
|||||||
"make_gif",
|
"make_gif",
|
||||||
default=False,
|
default=False,
|
||||||
is_flag=True,
|
is_flag=True,
|
||||||
help="Generate a gif comparing the original image to the modified one.",
|
help="Create a gif showing the generation process.",
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--compare-gif",
|
||||||
|
"make_compare_gif",
|
||||||
|
default=False,
|
||||||
|
is_flag=True,
|
||||||
|
help="Create a gif comparing the original image to the modified one.",
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--surprise-me",
|
"--surprise-me",
|
||||||
@ -514,6 +549,19 @@ def imagine_cmd(
|
|||||||
is_flag=True,
|
is_flag=True,
|
||||||
help="make some fun edits to the provided image",
|
help="make some fun edits to the provided image",
|
||||||
)
|
)
|
||||||
|
@click.option(
|
||||||
|
"--arg-schedule",
|
||||||
|
"arg_schedules",
|
||||||
|
multiple=True,
|
||||||
|
help="Schedule how an argument should change over several generations. Format: `--arg-schedule arg_name[start:end:increment]` or `--arg-schedule arg_name[val,val2,val3]`",
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--compilation-anim",
|
||||||
|
"make_compilation_animation",
|
||||||
|
default=None,
|
||||||
|
type=click.Choice(["gif", "mp4"]),
|
||||||
|
help="Generate an animation composed of all the images generated in this run. Defaults to gif but `--compilation-anim mp4` will generate an mp4 instead.",
|
||||||
|
)
|
||||||
@click.pass_context
|
@click.pass_context
|
||||||
def edit_image( # noqa
|
def edit_image( # noqa
|
||||||
ctx,
|
ctx,
|
||||||
@ -549,7 +597,10 @@ def edit_image( # noqa
|
|||||||
prompt_library_path,
|
prompt_library_path,
|
||||||
version, # noqa
|
version, # noqa
|
||||||
make_gif,
|
make_gif,
|
||||||
|
make_compare_gif,
|
||||||
surprise_me,
|
surprise_me,
|
||||||
|
arg_schedules,
|
||||||
|
make_compilation_animation,
|
||||||
):
|
):
|
||||||
init_image_strength = 1
|
init_image_strength = 1
|
||||||
if surprise_me and prompt_texts:
|
if surprise_me and prompt_texts:
|
||||||
@ -600,6 +651,9 @@ def edit_image( # noqa
|
|||||||
prompt_library_path,
|
prompt_library_path,
|
||||||
version, # noqa
|
version, # noqa
|
||||||
make_gif,
|
make_gif,
|
||||||
|
make_compare_gif,
|
||||||
|
arg_schedules,
|
||||||
|
make_compilation_animation,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -638,6 +692,9 @@ def _imagine_cmd(
|
|||||||
prompt_library_path,
|
prompt_library_path,
|
||||||
version=False, # noqa
|
version=False, # noqa
|
||||||
make_gif=False,
|
make_gif=False,
|
||||||
|
make_compare_gif=False,
|
||||||
|
arg_schedules=None,
|
||||||
|
make_compilation_animation=False,
|
||||||
):
|
):
|
||||||
"""Have the AI generate images. alias:imagine."""
|
"""Have the AI generate images. alias:imagine."""
|
||||||
if ctx.invoked_subcommand is not None:
|
if ctx.invoked_subcommand is not None:
|
||||||
@ -662,6 +719,12 @@ def _imagine_cmd(
|
|||||||
if mask_image and mask_image.startswith("http"):
|
if mask_image and mask_image.startswith("http"):
|
||||||
mask_image = LazyLoadingImage(url=mask_image)
|
mask_image = LazyLoadingImage(url=mask_image)
|
||||||
|
|
||||||
|
if init_image_strength is None:
|
||||||
|
if outpaint or mask_image or mask_prompt:
|
||||||
|
init_image_strength = 0
|
||||||
|
else:
|
||||||
|
init_image_strength = 0.6
|
||||||
|
|
||||||
prompts = []
|
prompts = []
|
||||||
prompt_expanding_iterators = {}
|
prompt_expanding_iterators = {}
|
||||||
for _ in range(repeats):
|
for _ in range(repeats):
|
||||||
@ -705,9 +768,14 @@ def _imagine_cmd(
|
|||||||
model=model_weights_path,
|
model=model_weights_path,
|
||||||
model_config_path=model_config_path,
|
model_config_path=model_config_path,
|
||||||
)
|
)
|
||||||
prompts.append(prompt)
|
if arg_schedules:
|
||||||
|
schedules = parse_schedule_strs(arg_schedules)
|
||||||
|
for new_prompt in prompt_mutator(prompt, schedules):
|
||||||
|
prompts.append(new_prompt)
|
||||||
|
else:
|
||||||
|
prompts.append(prompt)
|
||||||
|
|
||||||
imagine_image_files(
|
filenames = imagine_image_files(
|
||||||
prompts,
|
prompts,
|
||||||
outdir=outdir,
|
outdir=outdir,
|
||||||
record_step_images=show_work,
|
record_step_images=show_work,
|
||||||
@ -715,7 +783,20 @@ def _imagine_cmd(
|
|||||||
print_caption=caption,
|
print_caption=caption,
|
||||||
precision=precision,
|
precision=precision,
|
||||||
make_gif=make_gif,
|
make_gif=make_gif,
|
||||||
|
make_compare_gif=make_compare_gif,
|
||||||
)
|
)
|
||||||
|
if make_compilation_animation:
|
||||||
|
ext = make_compilation_animation
|
||||||
|
|
||||||
|
compilation_outdir = os.path.join(outdir, "compilations")
|
||||||
|
base_count = len(os.listdir(compilation_outdir))
|
||||||
|
new_filename = os.path.join(
|
||||||
|
compilation_outdir, f"{base_count:04d}_compilation.{ext}"
|
||||||
|
)
|
||||||
|
comp_imgs = [LazyLoadingImage(filepath=f) for f in filenames]
|
||||||
|
make_bounce_animation(outpath=new_filename, imgs=comp_imgs)
|
||||||
|
|
||||||
|
logger.info(f"[compilation] saved to: {new_filename}")
|
||||||
|
|
||||||
|
|
||||||
@shell(prompt="🤖🧠> ", intro="Starting imaginAIry...")
|
@shell(prompt="🤖🧠> ", intro="Starting imaginAIry...")
|
||||||
|
@ -4,8 +4,10 @@ import numpy as np
|
|||||||
import PIL
|
import PIL
|
||||||
import torch
|
import torch
|
||||||
from einops import rearrange, repeat
|
from einops import rearrange, repeat
|
||||||
from PIL import Image
|
from PIL import Image, ImageDraw, ImageFont
|
||||||
|
|
||||||
|
from imaginairy.paths import PKG_ROOT
|
||||||
|
from imaginairy.schema import LazyLoadingImage
|
||||||
from imaginairy.utils import get_device
|
from imaginairy.utils import get_device
|
||||||
|
|
||||||
|
|
||||||
@ -72,13 +74,33 @@ def pillow_img_to_model_latent(model, img, batch_size=1, half=True):
|
|||||||
return model.get_first_stage_encoding(model.encode_first_stage(init_image))
|
return model.get_first_stage_encoding(model.encode_first_stage(init_image))
|
||||||
|
|
||||||
|
|
||||||
def make_gif_image(filepath, imgs, duration=1000, loop=0):
|
def imgpaths_to_imgs(imgpaths):
|
||||||
|
imgs = []
|
||||||
|
for imgpath in imgpaths:
|
||||||
|
if isinstance(imgpath, str):
|
||||||
|
img = LazyLoadingImage(filepath=imgpath)
|
||||||
|
imgs.append(img)
|
||||||
|
else:
|
||||||
|
imgs.append(imgpath)
|
||||||
|
|
||||||
imgs[0].save(
|
return imgs
|
||||||
filepath,
|
|
||||||
save_all=True,
|
|
||||||
append_images=imgs[1:],
|
def add_caption_to_image(
|
||||||
duration=duration,
|
img, caption, font_size=16, font_path=f"{PKG_ROOT}/data/DejaVuSans.ttf"
|
||||||
loop=loop,
|
):
|
||||||
optimize=False,
|
draw = ImageDraw.Draw(img)
|
||||||
|
|
||||||
|
font = ImageFont.truetype(font_path, font_size)
|
||||||
|
|
||||||
|
x = 15
|
||||||
|
y = img.height - 15 - font_size
|
||||||
|
|
||||||
|
draw.text(
|
||||||
|
(x, y),
|
||||||
|
caption,
|
||||||
|
font=font,
|
||||||
|
fill=(255, 255, 255),
|
||||||
|
stroke_width=3,
|
||||||
|
stroke_fill=(0, 0, 0),
|
||||||
)
|
)
|
||||||
|
74
imaginairy/prompt_schedules.py
Normal file
74
imaginairy/prompt_schedules.py
Normal file
@ -0,0 +1,74 @@
|
|||||||
|
import csv
|
||||||
|
import re
|
||||||
|
from copy import copy
|
||||||
|
|
||||||
|
from imaginairy import ImaginePrompt
|
||||||
|
from imaginairy.utils import frange
|
||||||
|
|
||||||
|
|
||||||
|
def parse_schedule_str(schedule_str):
|
||||||
|
"""Parse a schedule string into a list of values."""
|
||||||
|
pattern = re.compile(r"([a-zA-Z0-9_-]+)\[([a-zA-Z0-9_:,. -]+)\]")
|
||||||
|
match = pattern.match(schedule_str)
|
||||||
|
if not match:
|
||||||
|
raise ValueError(f"Invalid kwarg schedule: {schedule_str}")
|
||||||
|
|
||||||
|
arg_name = match.group(1).replace("-", "_")
|
||||||
|
if not hasattr(ImaginePrompt(), arg_name):
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid kwarg schedule. Not a valid argument name: {arg_name}"
|
||||||
|
)
|
||||||
|
|
||||||
|
arg_values = match.group(2)
|
||||||
|
if ":" in arg_values:
|
||||||
|
start, end, step = arg_values.split(":")
|
||||||
|
arg_values = list(frange(float(start), float(end), float(step)))
|
||||||
|
else:
|
||||||
|
arg_values = parse_csv_line(arg_values)
|
||||||
|
return arg_name, arg_values
|
||||||
|
|
||||||
|
|
||||||
|
def parse_schedule_strs(schedule_strs):
|
||||||
|
"""Parse and validate input prompt schedules."""
|
||||||
|
schedules = {}
|
||||||
|
for schedule_str in schedule_strs:
|
||||||
|
arg_name, arg_values = parse_schedule_str(schedule_str)
|
||||||
|
schedules[arg_name] = arg_values
|
||||||
|
|
||||||
|
# Validate that all schedules have the same length
|
||||||
|
schedule_lengths = [len(v) for v in schedules.values()]
|
||||||
|
if len(set(schedule_lengths)) > 1:
|
||||||
|
raise ValueError("All schedules must have the same length")
|
||||||
|
|
||||||
|
return schedules
|
||||||
|
|
||||||
|
|
||||||
|
def prompt_mutator(prompt, schedules):
|
||||||
|
"""
|
||||||
|
Given a prompt and a list of kwarg schedules, return a series of prompts that follow the schedule.
|
||||||
|
|
||||||
|
kwarg_schedules example:
|
||||||
|
{
|
||||||
|
"prompt_strength": [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
|
||||||
|
}
|
||||||
|
|
||||||
|
"""
|
||||||
|
schedule_length = len(list(schedules.values())[0])
|
||||||
|
for i in range(schedule_length):
|
||||||
|
new_prompt = copy(prompt)
|
||||||
|
for attr_name, schedule in schedules.items():
|
||||||
|
setattr(new_prompt, attr_name, schedule[i])
|
||||||
|
new_prompt.validate()
|
||||||
|
yield new_prompt
|
||||||
|
|
||||||
|
|
||||||
|
def parse_csv_line(line):
|
||||||
|
reader = csv.reader([line])
|
||||||
|
for row in reader:
|
||||||
|
parsed_row = []
|
||||||
|
for value in row:
|
||||||
|
try:
|
||||||
|
parsed_row.append(float(value))
|
||||||
|
except ValueError:
|
||||||
|
parsed_row.append(value)
|
||||||
|
return parsed_row
|
@ -118,42 +118,20 @@ class ImaginePrompt:
|
|||||||
is_intermediate=False,
|
is_intermediate=False,
|
||||||
collect_progress_latents=False,
|
collect_progress_latents=False,
|
||||||
):
|
):
|
||||||
|
self.prompts = prompt
|
||||||
self.prompts = self.process_prompt_input(prompt)
|
self.negative_prompt = negative_prompt
|
||||||
self.prompt_strength = prompt_strength
|
self.prompt_strength = prompt_strength
|
||||||
if tile_mode is True:
|
|
||||||
tile_mode = "xy"
|
|
||||||
elif tile_mode is False:
|
|
||||||
tile_mode = ""
|
|
||||||
else:
|
|
||||||
tile_mode = tile_mode.lower()
|
|
||||||
assert tile_mode in ("", "x", "y", "xy")
|
|
||||||
|
|
||||||
if isinstance(init_image, str):
|
|
||||||
if not init_image.startswith("*prev."):
|
|
||||||
init_image = LazyLoadingImage(filepath=init_image)
|
|
||||||
|
|
||||||
if isinstance(mask_image, str):
|
|
||||||
if not init_image.startswith("*prev."):
|
|
||||||
mask_image = LazyLoadingImage(filepath=mask_image)
|
|
||||||
|
|
||||||
if mask_image is not None and mask_prompt is not None:
|
|
||||||
raise ValueError("You can only set one of `mask_image` and `mask_prompt`")
|
|
||||||
if model is None:
|
|
||||||
model = config.DEFAULT_MODEL
|
|
||||||
|
|
||||||
self.init_image = init_image
|
self.init_image = init_image
|
||||||
self.init_image_strength = init_image_strength
|
self.init_image_strength = init_image_strength
|
||||||
self.seed = random.randint(1, 1_000_000_000) if seed is None else seed
|
self._orig_seed = seed
|
||||||
|
self.seed = seed
|
||||||
self.steps = steps
|
self.steps = steps
|
||||||
self.height = height
|
self.height = height
|
||||||
self.width = width
|
self.width = width
|
||||||
self.upscale = upscale
|
self.upscale = upscale
|
||||||
self.fix_faces = fix_faces
|
self.fix_faces = fix_faces
|
||||||
self.fix_faces_fidelity = (
|
self.fix_faces_fidelity = fix_faces_fidelity
|
||||||
fix_faces_fidelity if fix_faces_fidelity else self.DEFAULT_FACE_FIDELITY
|
self.sampler_type = sampler_type
|
||||||
)
|
|
||||||
self.sampler_type = sampler_type.lower()
|
|
||||||
self.conditioning = conditioning
|
self.conditioning = conditioning
|
||||||
self.mask_prompt = mask_prompt
|
self.mask_prompt = mask_prompt
|
||||||
self.mask_image = mask_image
|
self.mask_image = mask_image
|
||||||
@ -167,20 +145,56 @@ class ImaginePrompt:
|
|||||||
self.is_intermediate = is_intermediate
|
self.is_intermediate = is_intermediate
|
||||||
self.collect_progress_latents = collect_progress_latents
|
self.collect_progress_latents = collect_progress_latents
|
||||||
|
|
||||||
|
self.validate()
|
||||||
|
|
||||||
|
def validate(self):
|
||||||
|
self.prompts = self.process_prompt_input(self.prompts)
|
||||||
|
|
||||||
|
if self.tile_mode is True:
|
||||||
|
self.tile_mode = "xy"
|
||||||
|
elif self.tile_mode is False:
|
||||||
|
self.tile_mode = ""
|
||||||
|
else:
|
||||||
|
self.tile_mode = self.tile_mode.lower()
|
||||||
|
assert self.tile_mode in ("", "x", "y", "xy")
|
||||||
|
|
||||||
|
if isinstance(self.init_image, str):
|
||||||
|
if not self.init_image.startswith("*prev."):
|
||||||
|
self.init_image = LazyLoadingImage(filepath=self.init_image)
|
||||||
|
|
||||||
|
if isinstance(self.mask_image, str):
|
||||||
|
if not self.mask_image.startswith("*prev."):
|
||||||
|
self.mask_image = LazyLoadingImage(filepath=self.mask_image)
|
||||||
|
|
||||||
|
if self.mask_image is not None and self.mask_prompt is not None:
|
||||||
|
raise ValueError("You can only set one of `mask_image` and `mask_prompt`")
|
||||||
|
if self.model is None:
|
||||||
|
self.model = config.DEFAULT_MODEL
|
||||||
|
|
||||||
|
self.seed = random.randint(1, 1_000_000_000) if self.seed is None else self.seed
|
||||||
|
|
||||||
|
self.sampler_type = self.sampler_type.lower()
|
||||||
|
|
||||||
|
self.fix_faces_fidelity = (
|
||||||
|
self.fix_faces_fidelity
|
||||||
|
if self.fix_faces_fidelity
|
||||||
|
else self.DEFAULT_FACE_FIDELITY
|
||||||
|
)
|
||||||
|
|
||||||
if self.height is None or self.width is None or self.steps is None:
|
if self.height is None or self.width is None or self.steps is None:
|
||||||
SamplerCls = SAMPLER_LOOKUP[self.sampler_type]
|
SamplerCls = SAMPLER_LOOKUP[self.sampler_type]
|
||||||
self.steps = self.steps or SamplerCls.default_steps
|
self.steps = self.steps or SamplerCls.default_steps
|
||||||
self.width = self.width or get_model_default_image_size(self.model)
|
self.width = self.width or get_model_default_image_size(self.model)
|
||||||
self.height = self.height or get_model_default_image_size(self.model)
|
self.height = self.height or get_model_default_image_size(self.model)
|
||||||
|
|
||||||
if negative_prompt is None:
|
if self.negative_prompt is None:
|
||||||
model_config = config.MODEL_CONFIG_SHORTCUTS.get(self.model, None)
|
model_config = config.MODEL_CONFIG_SHORTCUTS.get(self.model, None)
|
||||||
if model_config:
|
if model_config:
|
||||||
negative_prompt = model_config.default_negative_prompt
|
self.negative_prompt = model_config.default_negative_prompt
|
||||||
else:
|
else:
|
||||||
negative_prompt = config.DEFAULT_NEGATIVE_PROMPT
|
self.negative_prompt = config.DEFAULT_NEGATIVE_PROMPT
|
||||||
|
|
||||||
self.negative_prompt = self.process_prompt_input(negative_prompt)
|
self.negative_prompt = self.process_prompt_input(self.negative_prompt)
|
||||||
|
|
||||||
if self.model == "SD-2.0-v" and self.sampler_type == SamplerName.PLMS:
|
if self.model == "SD-2.0-v" and self.sampler_type == SamplerName.PLMS:
|
||||||
raise ValueError("PLMS sampler is not supported for SD-2.0-v model.")
|
raise ValueError("PLMS sampler is not supported for SD-2.0-v model.")
|
||||||
|
@ -6,12 +6,10 @@ aimg.
|
|||||||
|
|
||||||
import os.path
|
import os.path
|
||||||
|
|
||||||
from PIL import ImageDraw, ImageFont
|
|
||||||
|
|
||||||
from imaginairy import ImaginePrompt, LazyLoadingImage, imagine_image_files
|
from imaginairy import ImaginePrompt, LazyLoadingImage, imagine_image_files
|
||||||
|
from imaginairy.animations import make_gif_animation
|
||||||
from imaginairy.enhancers.facecrop import detect_faces
|
from imaginairy.enhancers.facecrop import detect_faces
|
||||||
from imaginairy.img_utils import make_gif_image, pillow_fit_image_within
|
from imaginairy.img_utils import add_caption_to_image, pillow_fit_image_within
|
||||||
from imaginairy.paths import PKG_ROOT
|
|
||||||
|
|
||||||
preserve_head_kwargs = {
|
preserve_head_kwargs = {
|
||||||
"mask_prompt": "head|face",
|
"mask_prompt": "head|face",
|
||||||
@ -202,25 +200,11 @@ def create_surprise_me_images(
|
|||||||
gif_imgs = [simg]
|
gif_imgs = [simg]
|
||||||
for prompt, filename in zip(prompts, generated_filenames):
|
for prompt, filename in zip(prompts, generated_filenames):
|
||||||
gen_img = LazyLoadingImage(filepath=filename)
|
gen_img = LazyLoadingImage(filepath=filename)
|
||||||
draw = ImageDraw.Draw(gen_img)
|
add_caption_to_image(gen_img, prompt.prompt_text)
|
||||||
|
|
||||||
font_size = 16
|
|
||||||
font = ImageFont.truetype(f"{PKG_ROOT}/data/DejaVuSans.ttf", font_size)
|
|
||||||
|
|
||||||
x = 15
|
|
||||||
y = gen_img.height - 15 - font_size
|
|
||||||
|
|
||||||
draw.text(
|
|
||||||
(x, y),
|
|
||||||
prompt.prompt_text,
|
|
||||||
font=font,
|
|
||||||
fill=(255, 255, 255),
|
|
||||||
stroke_width=3,
|
|
||||||
stroke_fill=(0, 0, 0),
|
|
||||||
)
|
|
||||||
gif_imgs.append(gen_img)
|
gif_imgs.append(gen_img)
|
||||||
|
|
||||||
make_gif_image(new_filename, gif_imgs)
|
make_gif_animation(outpath=new_filename, imgs=gif_imgs)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -179,3 +179,24 @@ def check_torch_working():
|
|||||||
"CUDA is not working. Make sure you have a GPU and CUDA installed."
|
"CUDA is not working. Make sure you have a GPU and CUDA installed."
|
||||||
) from e
|
) from e
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
|
||||||
|
def frange(start, stop, step):
|
||||||
|
"""Range but handles floats."""
|
||||||
|
x = start
|
||||||
|
while True:
|
||||||
|
if x >= stop:
|
||||||
|
return
|
||||||
|
yield x
|
||||||
|
x += step
|
||||||
|
|
||||||
|
|
||||||
|
def shrink_list(items, max_size):
|
||||||
|
if len(items) <= max_size:
|
||||||
|
return items
|
||||||
|
|
||||||
|
removal_ratio = len(items) / (max_size - 1)
|
||||||
|
new_items = {}
|
||||||
|
for i, item in enumerate(items):
|
||||||
|
new_items[int(i / removal_ratio)] = item
|
||||||
|
return [items[0]] + list(new_items.values())
|
||||||
|
21
tests/test_prompt_schedules.py
Normal file
21
tests/test_prompt_schedules.py
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from imaginairy.prompt_schedules import parse_schedule_str
|
||||||
|
from imaginairy.utils import frange
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"schedule_str,expected",
|
||||||
|
[
|
||||||
|
("prompt_strength[2:40:1]", ("prompt_strength", list(range(2, 40)))),
|
||||||
|
("prompt_strength[2:40:0.5]", ("prompt_strength", list(frange(2, 40, 0.5)))),
|
||||||
|
("prompt_strength[2,5,10,15]", ("prompt_strength", [2, 5, 10, 15])),
|
||||||
|
(
|
||||||
|
"prompt_strength[red,blue,10,15]",
|
||||||
|
("prompt_strength", ["red", "blue", 10, 15]),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_parse_schedule_str(schedule_str, expected):
|
||||||
|
cleaned_schedules = parse_schedule_str(schedule_str)
|
||||||
|
assert cleaned_schedules == expected
|
Loading…
Reference in New Issue
Block a user