imaginAIry/imaginairy/api/generate.py

284 lines
11 KiB
Python
Raw Normal View History

"""Functions for generating and processing images"""
import logging
2022-09-08 03:59:30 +00:00
import os
from pathlib import Path
2023-12-21 05:23:13 +00:00
from typing import TYPE_CHECKING, Callable
2022-09-08 03:59:30 +00:00
2023-12-21 05:23:13 +00:00
if TYPE_CHECKING:
from imaginairy.schema import ImaginePrompt
logger = logging.getLogger(__name__)
2022-09-08 03:59:30 +00:00
# leave undocumented. I'd ask that no one publicize this flag. Just want a
# slight barrier to entry. Please don't use this is any way that's gonna cause
# the media or politicians to freak out about AI...
IMAGINAIRY_SAFETY_MODE = os.getenv("IMAGINAIRY_SAFETY_MODE", "strict")
if IMAGINAIRY_SAFETY_MODE in {"disabled", "classify"}:
IMAGINAIRY_SAFETY_MODE = "relaxed"
elif IMAGINAIRY_SAFETY_MODE == "filter":
IMAGINAIRY_SAFETY_MODE = "strict"
2022-09-11 07:35:57 +00:00
# we put this in the global scope so it can be used in the interactive shell
_most_recent_result = None
def imagine_image_files(
prompts: "list[ImaginePrompt] | ImaginePrompt",
outdir: str,
precision: str = "autocast",
record_step_images: bool = False,
output_file_extension: str = "jpg",
print_caption: bool = False,
make_gif: bool = False,
make_compare_gif: bool = False,
return_filename_type: str = "generated",
videogen: bool = False,
2022-09-08 03:59:30 +00:00
):
"""
Generates and saves image files based on given prompts, with options for animations and videos.
Args:
prompts (list[ImaginePrompt] | ImaginePrompt): A prompt or list of prompts for image generation.
outdir (str): Directory path where the generated images and other files will be saved.
precision (str, optional): Precision mode for image generation, defaults to 'autocast'.
record_step_images (bool, optional): If True, saves step-by-step images of the generation process, defaults to False.
output_file_extension (str, optional): File extension for output images, must be 'jpg' or 'png', defaults to 'jpg'.
print_caption (bool, optional): If True, prints captions on the generated images, defaults to False.
make_gif (bool, optional): If True, creates a GIF from the generation steps, defaults to False.
make_compare_gif (bool, optional): If True, creates a comparison GIF with initial and generated images, defaults to False.
return_filename_type (str, optional): Type of filenames to return, defaults to 'generated'.
videogen (bool, optional): If True, generates a video from the generated images, defaults to False.
Returns:
list[str]: A list of filenames of the generated images.
Raises:
ValueError: If the output file extension is not 'jpg' or 'png'.
"""
from PIL import ImageDraw
2023-12-15 21:47:39 +00:00
from imaginairy.api.video_sample import generate_video
2023-12-21 05:23:13 +00:00
from imaginairy.utils import get_next_filenumber, prompt_normalized
2022-09-24 18:21:53 +00:00
generated_imgs_path = os.path.join(outdir, "generated")
os.makedirs(generated_imgs_path, exist_ok=True)
base_count = get_next_filenumber(generated_imgs_path)
output_file_extension = output_file_extension.lower()
if output_file_extension not in {"jpg", "png"}:
raise ValueError("Must output a png or jpg")
if not isinstance(prompts, list):
prompts = [prompts]
def _record_step(img, description, image_count, step_count, prompt):
steps_path = os.path.join(outdir, "steps", f"{base_count:08}_S{prompt.seed}")
os.makedirs(steps_path, exist_ok=True)
filename = f"{base_count:08}_S{prompt.seed}_{image_count:04}_step{step_count:03}_{prompt_normalized(description)[:40]}.jpg"
destination = os.path.join(steps_path, filename)
draw = ImageDraw.Draw(img)
draw.text((10, 10), str(description))
img.save(destination)
if make_gif:
for p in prompts:
p.collect_progress_latents = True
result_filenames = []
for result in imagine(
prompts,
precision=precision,
2022-11-14 06:51:23 +00:00
debug_img_callback=_record_step if record_step_images else None,
add_caption=print_caption,
):
primary_filename = save_image_result(
result,
base_count,
outdir=outdir,
output_file_extension=output_file_extension,
primary_filename_type=return_filename_type,
make_gif=make_gif,
make_compare_gif=make_compare_gif,
)
if not primary_filename:
continue
result_filenames.append(primary_filename)
if primary_filename and videogen:
try:
generate_video(
input_path=primary_filename,
)
except FileNotFoundError as e:
logger.error(str(e))
exit(1)
base_count += 1
del result
2023-12-21 05:23:13 +00:00
return result_filenames
def save_image_result(
result,
base_count: int,
outdir: str | Path,
output_file_extension: str,
primary_filename_type,
make_gif=False,
make_compare_gif=False,
):
from imaginairy.utils import prompt_normalized
from imaginairy.utils.animations import make_bounce_animation
from imaginairy.utils.img_utils import pillow_fit_image_within
prompt = result.prompt
if prompt.is_intermediate:
# we don't save intermediate images
return
img_str = ""
if prompt.init_image:
img_str = f"_img2img-{prompt.init_image_strength}"
basefilename = (
f"{base_count:06}_{prompt.seed}_{prompt.solver_type.replace('_', '')}{prompt.steps}_"
f"PS{prompt.prompt_strength}{img_str}_{prompt_normalized(prompt.prompt_text)}"
)
primary_filename = None
for image_type in result.images:
subpath = os.path.join(outdir, image_type)
os.makedirs(subpath, exist_ok=True)
filepath = os.path.join(
subpath, f"{basefilename}_[{image_type}].{output_file_extension}"
2022-10-22 09:13:06 +00:00
)
result.save(filepath, image_type=image_type)
logger.info(f" {image_type:<22} {filepath}")
if image_type == primary_filename_type:
primary_filename = filepath
if make_gif and result.progress_latents:
subpath = os.path.join(outdir, "gif")
os.makedirs(subpath, exist_ok=True)
filepath = os.path.join(subpath, f"{basefilename}.gif")
frames = [*result.progress_latents, result.images["generated"]]
if prompt.init_image:
resized_init_image = pillow_fit_image_within(
prompt.init_image, prompt.width, prompt.height
)
frames = [resized_init_image, *frames]
frames.reverse()
make_bounce_animation(
imgs=frames,
outpath=filepath,
start_pause_duration_ms=1500,
end_pause_duration_ms=1000,
)
image_type = "gif"
logger.info(f" {image_type:<22} {filepath}")
if make_compare_gif and prompt.init_image:
subpath = os.path.join(outdir, "gif")
os.makedirs(subpath, exist_ok=True)
filepath = os.path.join(subpath, f"{basefilename}_[compare].gif")
resized_init_image = pillow_fit_image_within(
prompt.init_image, prompt.width, prompt.height
)
frames = [result.images["generated"], resized_init_image]
make_bounce_animation(
imgs=frames,
outpath=filepath,
)
image_type = "gif"
logger.info(f" {image_type:<22} {filepath}")
return primary_filename
def imagine(
prompts: "list[ImaginePrompt] | str | ImaginePrompt",
precision: str = "autocast",
debug_img_callback: Callable | None = None,
progress_img_callback: Callable | None = None,
progress_img_interval_steps: int = 3,
2022-11-14 06:51:23 +00:00
progress_img_interval_min_s=0.1,
half_mode=None,
add_caption: bool = False,
unsafe_retry_count: int = 1,
):
"""
Generates images based on the provided prompts using the ImaginAIry API.
Args:
prompts (list[ImaginePrompt] | str | ImaginePrompt): A prompt or list of prompts for image generation.
Can be a string, a single ImaginePrompt instance, or a list of ImaginePrompt instances.
precision (str, optional): The precision mode for image generation, defaults to 'autocast'.
debug_img_callback (Callable, optional): Callback function for debugging images, defaults to None.
progress_img_callback (Callable, optional): Callback function called at intervals with progress images, defaults to None.
progress_img_interval_steps (int, optional): Number of steps between each progress image callback, defaults to 3.
progress_img_interval_min_s (float, optional): Minimum seconds between each progress image callback, defaults to 0.1.
half_mode: If set, determines whether to use half precision mode for image generation, defaults to None.
add_caption (bool, optional): Flag to add captions to the generated images, defaults to False.
unsafe_retry_count (int, optional): Number of retries for generating an image if it is deemed unsafe, defaults to 1.
Yields:
The generated image(s) based on the provided prompts.
"""
import torch.nn
from imaginairy.api.generate_refiners import generate_single_image
from imaginairy.schema import ImaginePrompt
from imaginairy.utils import (
check_torch_version,
fix_torch_group_norm,
fix_torch_nn_layer_norm,
get_device,
)
check_torch_version()
prompts = [ImaginePrompt(prompts)] if isinstance(prompts, str) else prompts
prompts = [prompts] if isinstance(prompts, ImaginePrompt) else prompts
2022-10-13 05:32:17 +00:00
try:
num_prompts = str(len(prompts))
except TypeError:
num_prompts = "?"
if get_device() == "cpu":
logger.warning("Running in CPU mode. It's gonna be slooooooow.")
from imaginairy.utils.torch_installer import torch_version_check
torch_version_check()
if half_mode is None:
half_mode = "cuda" in get_device() or get_device() == "mps"
2024-01-05 14:34:17 +00:00
with torch.no_grad(), fix_torch_nn_layer_norm(), fix_torch_group_norm():
for i, prompt in enumerate(prompts):
concrete_prompt = prompt.make_concrete_copy()
prog_text = f"{i + 1}/{num_prompts}"
logger.info(f"🖼 {prog_text} {concrete_prompt.prompt_description()}")
for attempt in range(unsafe_retry_count + 1):
if attempt > 0 and isinstance(concrete_prompt.seed, int):
concrete_prompt.seed += 100_000_000 + attempt
result = generate_single_image(
concrete_prompt,
debug_img_callback=debug_img_callback,
progress_img_callback=progress_img_callback,
progress_img_interval_steps=progress_img_interval_steps,
progress_img_interval_min_s=progress_img_interval_min_s,
add_caption=add_caption,
dtype=torch.float16 if half_mode else torch.float32,
output_perf=True,
)
2023-02-15 16:02:58 +00:00
if not result.safety_score.is_filtered:
break
if attempt < unsafe_retry_count:
logger.info(" Image was unsafe, retrying with new seed...")
yield result