feature: progress image callback

pull/96/head
Bryce 2 years ago committed by Bryce Drennan
parent 411f359f4e
commit 893b041a8f

@ -352,10 +352,11 @@ would be uncorrelated to the rest of the surrounding image. It created terrible
- ✅ add tests
- ✅ set up ci (test/lint/format)
- ✅ unified pipeline (txt2img & img2img combined)
- setup parallel testing
- setup parallel testing
- add docs
- remove yaml config
- delete more unused code
- faster latent logging https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/9
- Interface improvements
- ✅ init-image at command line
- ✅ prompt expansion

@ -74,7 +74,7 @@ def imagine_image_files(
for result in imagine(
prompts,
precision=precision,
img_callback=_record_step if record_step_images else None,
debug_img_callback=_record_step if record_step_images else None,
add_caption=print_caption,
):
prompt = result.prompt
@ -101,7 +101,10 @@ def imagine_image_files(
def imagine(
prompts,
precision="autocast",
img_callback=None,
debug_img_callback=None,
progress_img_callback=None,
progress_img_interval_steps=3,
progress_img_interval_min_s=0.1,
half_mode=None,
add_caption=False,
):
@ -135,7 +138,10 @@ def imagine(
with ImageLoggingContext(
prompt=prompt,
model=model,
img_callback=img_callback,
debug_img_callback=debug_img_callback,
progress_img_callback=progress_img_callback,
progress_img_interval_steps=progress_img_interval_steps,
progress_img_interval_min_s=progress_img_interval_min_s,
) as lc:
seed_everything(prompt.seed)
model.tile_mode(prompt.tile_mode)

@ -38,6 +38,12 @@ def log_img(img, description):
_CURRENT_LOGGING_CONTEXT.log_img(img, description)
def log_progress_latent(latent):
if _CURRENT_LOGGING_CONTEXT is None:
return
_CURRENT_LOGGING_CONTEXT.log_progress_latent(latent)
def log_tensor(t, description=""):
if _CURRENT_LOGGING_CONTEXT is None:
return
@ -64,15 +70,30 @@ class TimingContext:
class ImageLoggingContext:
def __init__(self, prompt, model, img_callback=None, img_outdir=None):
def __init__(
self,
prompt,
model,
debug_img_callback=None,
img_outdir=None,
progress_img_callback=None,
progress_img_interval_steps=3,
progress_img_interval_min_s=0.1,
):
self.prompt = prompt
self.model = model
self.step_count = 0
self.image_count = 0
self.img_callback = img_callback
self.debug_img_callback = debug_img_callback
self.img_outdir = img_outdir
self.progress_img_callback = progress_img_callback
self.progress_img_interval_steps = progress_img_interval_steps
self.progress_img_interval_min_s = progress_img_interval_min_s
self.start_ts = time.perf_counter()
self.timings = {}
self.last_progress_img_ts = 0
self.last_progress_img_step = -1000
def __enter__(self):
global _CURRENT_LOGGING_CONTEXT # noqa
@ -91,18 +112,30 @@ class ImageLoggingContext:
return self.timings
def log_conditioning(self, conditioning, description):
if not self.img_callback:
if not self.debug_img_callback:
return
img = conditioning_to_img(conditioning)
self.img_callback(
self.debug_img_callback(
img, description, self.image_count, self.step_count, self.prompt
)
def log_latents(self, latents, description):
from imaginairy.img_utils import model_latents_to_pillow_imgs # noqa
if not self.img_callback:
if "predicted_latent" in description:
if (
self.step_count - self.last_progress_img_step
) > self.progress_img_interval_steps:
if (
time.perf_counter() - self.last_progress_img_ts
> self.progress_img_interval_min_s
):
self.log_progress_latent(latents)
self.last_progress_img_step = self.step_count
self.last_progress_img_ts = time.perf_counter()
if not self.debug_img_callback:
return
if latents.shape[1] != 4:
# logger.info(f"Didn't save tensor of shape {samples.shape} for {description}")
@ -115,23 +148,31 @@ class ImageLoggingContext:
description = f"{description}-{shape_str}"
for img in model_latents_to_pillow_imgs(latents):
self.image_count += 1
self.img_callback(
self.debug_img_callback(
img, description, self.image_count, self.step_count, self.prompt
)
def log_img(self, img, description):
if not self.img_callback:
if not self.debug_img_callback:
return
self.image_count += 1
if isinstance(img, torch.Tensor):
img = ToPILImage()(img.squeeze().cpu().detach())
img = img.copy()
self.img_callback(
self.debug_img_callback(
img, description, self.image_count, self.step_count, self.prompt
)
def log_progress_latent(self, latent):
from imaginairy.img_utils import model_latents_to_pillow_imgs # noqa
if not self.progress_img_callback:
return
for img in model_latents_to_pillow_imgs(latent):
self.progress_img_callback(img)
def log_tensor(self, t, description=""):
if not self.img_callback:
if not self.debug_img_callback:
return
if len(t.shape) == 2:

@ -175,7 +175,7 @@ def test_img_to_img_fruit_2_gold_repeat():
ImaginePrompt(**kwargs),
ImaginePrompt(**kwargs),
]
for result in imagine(prompts, img_callback=None):
for result in imagine(prompts, debug_img_callback=None):
result.img.save(
f"{TESTS_FOLDER}/test_output/img2img_fruit_2_gold_plms_{get_device()}_run-{run_count:02}.jpg"
)

Loading…
Cancel
Save