2023-12-15 20:31:28 +00:00
|
|
|
"""Functions for generating and processing images"""
|
|
|
|
|
2022-09-09 04:51:25 +00:00
|
|
|
import logging
|
2022-09-08 03:59:30 +00:00
|
|
|
import os
|
2023-12-21 05:23:13 +00:00
|
|
|
from typing import TYPE_CHECKING, Callable
|
2022-09-08 03:59:30 +00:00
|
|
|
|
2023-12-21 05:23:13 +00:00
|
|
|
if TYPE_CHECKING:
|
|
|
|
from imaginairy.schema import ImaginePrompt
|
2022-09-10 07:32:31 +00:00
|
|
|
|
2022-09-09 04:51:25 +00:00
|
|
|
logger = logging.getLogger(__name__)
|
2022-09-08 03:59:30 +00:00
|
|
|
|
2022-09-15 02:40:50 +00:00
|
|
|
# Intentionally undocumented escape hatch — please don't publicize this flag;
# the small barrier to entry is deliberate. Don't use it in ways that would
# cause the media or politicians to freak out about AI...
IMAGINAIRY_SAFETY_MODE = os.getenv("IMAGINAIRY_SAFETY_MODE", "strict")
# Fold legacy mode names into the two currently supported modes.
IMAGINAIRY_SAFETY_MODE = {
    "disabled": "relaxed",
    "classify": "relaxed",
    "filter": "strict",
}.get(IMAGINAIRY_SAFETY_MODE, IMAGINAIRY_SAFETY_MODE)

# Kept at module scope so it can be inspected from the interactive shell.
_most_recent_result = None
|
|
|
|
|
2022-09-09 04:51:25 +00:00
|
|
|
|
2022-09-10 05:14:04 +00:00
|
|
|
def imagine_image_files(
    prompts: "list[ImaginePrompt] | ImaginePrompt",
    outdir: str,
    precision: str = "autocast",
    record_step_images: bool = False,
    output_file_extension: str = "jpg",
    print_caption: bool = False,
    make_gif: bool = False,
    make_compare_gif: bool = False,
    return_filename_type: str = "generated",
    videogen: bool = False,
):
    """Generate images from `prompts` and write them under `outdir`.

    Args:
        prompts: One prompt or a list of prompts to render.
        outdir: Root output directory; images land in per-type subfolders
            (e.g. `outdir/generated`, `outdir/gif`).
        precision: Autocast precision passed through to `imagine()`.
        record_step_images: If True, save each diffusion step image under
            `outdir/steps` (debugging aid).
        output_file_extension: "jpg" or "png".
        print_caption: Ask `imagine()` to caption the generated images.
        make_gif: Save an animation of the generation progress latents.
        make_compare_gif: Save a before/after animation when an init image
            was used.
        return_filename_type: Which image type's filepaths to collect and
            return (default the final "generated" image).
        videogen: Also run video generation on each returned image file.

    Returns:
        List of filepaths for images of type `return_filename_type`.

    Raises:
        ValueError: If `output_file_extension` is not "jpg" or "png".
        SystemExit: If video generation fails with a missing file.
    """
    from PIL import ImageDraw

    from imaginairy.api.video_sample import generate_video
    from imaginairy.utils import get_next_filenumber, prompt_normalized
    from imaginairy.utils.animations import make_bounce_animation
    from imaginairy.utils.img_utils import pillow_fit_image_within

    generated_imgs_path = os.path.join(outdir, "generated")
    os.makedirs(generated_imgs_path, exist_ok=True)

    # Continue numbering from whatever files already exist in the folder.
    base_count = get_next_filenumber(generated_imgs_path)
    output_file_extension = output_file_extension.lower()
    if output_file_extension not in {"jpg", "png"}:
        raise ValueError("Must output a png or jpg")

    if not isinstance(prompts, list):
        prompts = [prompts]

    def _record_step(img, description, image_count, step_count, prompt):
        # Save one annotated per-step debug image under outdir/steps.
        steps_path = os.path.join(outdir, "steps", f"{base_count:08}_S{prompt.seed}")
        os.makedirs(steps_path, exist_ok=True)
        filename = f"{base_count:08}_S{prompt.seed}_{image_count:04}_step{step_count:03}_{prompt_normalized(description)[:40]}.jpg"

        destination = os.path.join(steps_path, filename)
        draw = ImageDraw.Draw(img)
        draw.text((10, 10), str(description))
        img.save(destination)

    if make_gif:
        # Progress latents are only collected when explicitly requested,
        # since keeping them costs memory.
        for p in prompts:
            p.collect_progress_latents = True
    result_filenames = []
    for result in imagine(
        prompts,
        precision=precision,
        debug_img_callback=_record_step if record_step_images else None,
        add_caption=print_caption,
    ):
        prompt = result.prompt
        if prompt.is_intermediate:
            # we don't save intermediate images
            continue
        img_str = ""
        if prompt.init_image:
            img_str = f"_img2img-{prompt.init_image_strength}"

        basefilename = (
            f"{base_count:06}_{prompt.seed}_{prompt.solver_type.replace('_', '')}{prompt.steps}_"
            f"PS{prompt.prompt_strength}{img_str}_{prompt_normalized(prompt.prompt_text)}"
        )
        # Each image type (generated, upscaled, etc.) goes in its own subdir.
        for image_type in result.images:
            subpath = os.path.join(outdir, image_type)
            os.makedirs(subpath, exist_ok=True)
            filepath = os.path.join(
                subpath, f"{basefilename}_[{image_type}].{output_file_extension}"
            )
            result.save(filepath, image_type=image_type)
            logger.info(f" {image_type:<22} {filepath}")
            if image_type == return_filename_type:
                result_filenames.append(filepath)
                if videogen:
                    try:
                        generate_video(
                            input_path=filepath,
                        )
                    except FileNotFoundError as e:
                        logger.error(str(e))
                        # `exit()` is a site-module convenience that may not
                        # exist (e.g. under `python -S`); raise SystemExit
                        # directly for the same effect.
                        raise SystemExit(1)

        if make_gif and result.progress_latents:
            subpath = os.path.join(outdir, "gif")
            os.makedirs(subpath, exist_ok=True)
            filepath = os.path.join(subpath, f"{basefilename}.gif")

            frames = [*result.progress_latents, result.images["generated"]]

            if prompt.init_image:
                resized_init_image = pillow_fit_image_within(
                    prompt.init_image, prompt.width, prompt.height
                )
                frames = [resized_init_image, *frames]
            # Play from finished image back toward noise/init image.
            frames.reverse()
            make_bounce_animation(
                imgs=frames,
                outpath=filepath,
                start_pause_duration_ms=1500,
                end_pause_duration_ms=1000,
            )
            image_type = "gif"
            logger.info(f" {image_type:<22} {filepath}")
        if make_compare_gif and prompt.init_image:
            subpath = os.path.join(outdir, "gif")
            os.makedirs(subpath, exist_ok=True)
            filepath = os.path.join(subpath, f"{basefilename}_[compare].gif")
            resized_init_image = pillow_fit_image_within(
                prompt.init_image, prompt.width, prompt.height
            )
            frames = [result.images["generated"], resized_init_image]

            make_bounce_animation(
                imgs=frames,
                outpath=filepath,
            )
            image_type = "gif"
            logger.info(f" {image_type:<22} {filepath}")

        base_count += 1
        # Free the (potentially large) result before the next iteration.
        del result

    return result_filenames
|
|
|
|
|
2022-09-10 05:14:04 +00:00
|
|
|
|
2022-09-13 07:27:53 +00:00
|
|
|
def imagine(
    prompts: "list[ImaginePrompt] | str | ImaginePrompt",
    precision: str = "autocast",
    debug_img_callback: Callable | None = None,
    progress_img_callback: Callable | None = None,
    progress_img_interval_steps: int = 3,
    progress_img_interval_min_s: float = 0.1,
    half_mode: bool | None = None,
    add_caption: bool = False,
    unsafe_retry_count: int = 1,
):
    """Generate images for each prompt, yielding one result per prompt.

    Args:
        prompts: A prompt string, a single ImaginePrompt, or a list of
            ImaginePrompts. A bare string is wrapped in an ImaginePrompt.
        precision: Autocast precision mode.
        debug_img_callback: Called with per-step debug images, if set.
        progress_img_callback: Called with periodic progress images, if set.
        progress_img_interval_steps: Minimum steps between progress images.
        progress_img_interval_min_s: Minimum seconds between progress images.
        half_mode: Force float16 (True) or float32 (False); if None, float16
            is chosen automatically on CUDA/MPS devices.
        add_caption: Caption the generated image.
        unsafe_retry_count: Number of extra attempts (with shifted seeds)
            when a result is filtered as unsafe.

    Yields:
        One generation result per prompt (the last attempt's result if all
        retries were filtered).
    """
    # Heavy imports are deferred so importing this module stays cheap.
    import torch.nn

    from imaginairy.api.generate_refiners import generate_single_image
    from imaginairy.schema import ImaginePrompt
    from imaginairy.utils import (
        check_torch_version,
        fix_torch_group_norm,
        fix_torch_nn_layer_norm,
        get_device,
        platform_appropriate_autocast,
    )

    check_torch_version()

    # Normalize input to a sequence of ImaginePrompts.
    prompts = [ImaginePrompt(prompts)] if isinstance(prompts, str) else prompts
    prompts = [prompts] if isinstance(prompts, ImaginePrompt) else prompts

    # `prompts` may be a generator with no len(); fall back to "?" for logs.
    try:
        num_prompts = str(len(prompts))
    except TypeError:
        num_prompts = "?"

    if get_device() == "cpu":
        logger.warning("Running in CPU mode. It's gonna be slooooooow.")
        from imaginairy.utils.torch_installer import torch_version_check

        torch_version_check()

    # Auto-pick half precision on GPU-class devices when not specified.
    if half_mode is None:
        half_mode = "cuda" in get_device() or get_device() == "mps"

    with torch.no_grad(), platform_appropriate_autocast(
        precision
    ), fix_torch_nn_layer_norm(), fix_torch_group_norm():
        for i, prompt in enumerate(prompts):
            # Resolve randomized/unset fields (e.g. seed) into fixed values.
            concrete_prompt = prompt.make_concrete_copy()
            prog_text = f"{i + 1}/{num_prompts}"
            logger.info(f"🖼 {prog_text} {concrete_prompt.prompt_description()}")
            for attempt in range(unsafe_retry_count + 1):
                # On retries, shift the seed so the unsafe output isn't
                # simply regenerated.
                if attempt > 0 and isinstance(concrete_prompt.seed, int):
                    concrete_prompt.seed += 100_000_000 + attempt
                result = generate_single_image(
                    concrete_prompt,
                    debug_img_callback=debug_img_callback,
                    progress_img_callback=progress_img_callback,
                    progress_img_interval_steps=progress_img_interval_steps,
                    progress_img_interval_min_s=progress_img_interval_min_s,
                    add_caption=add_caption,
                    dtype=torch.float16 if half_mode else torch.float32,
                    output_perf=True,
                )
                if not result.safety_score.is_filtered:
                    break
                if attempt < unsafe_retry_count:
                    logger.info(" Image was unsafe, retrying with new seed...")

            yield result
|