|
|
|
@ -38,6 +38,12 @@ def log_img(img, description):
|
|
|
|
|
_CURRENT_LOGGING_CONTEXT.log_img(img, description)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def log_progress_latent(latent):
|
|
|
|
|
if _CURRENT_LOGGING_CONTEXT is None:
|
|
|
|
|
return
|
|
|
|
|
_CURRENT_LOGGING_CONTEXT.log_progress_latent(latent)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def log_tensor(t, description=""):
|
|
|
|
|
if _CURRENT_LOGGING_CONTEXT is None:
|
|
|
|
|
return
|
|
|
|
@ -64,15 +70,30 @@ class TimingContext:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ImageLoggingContext:
|
|
|
|
|
def __init__(self, prompt, model, img_callback=None, img_outdir=None):
|
|
|
|
|
def __init__(
|
|
|
|
|
self,
|
|
|
|
|
prompt,
|
|
|
|
|
model,
|
|
|
|
|
debug_img_callback=None,
|
|
|
|
|
img_outdir=None,
|
|
|
|
|
progress_img_callback=None,
|
|
|
|
|
progress_img_interval_steps=3,
|
|
|
|
|
progress_img_interval_min_s=0.1,
|
|
|
|
|
):
|
|
|
|
|
self.prompt = prompt
|
|
|
|
|
self.model = model
|
|
|
|
|
self.step_count = 0
|
|
|
|
|
self.image_count = 0
|
|
|
|
|
self.img_callback = img_callback
|
|
|
|
|
self.debug_img_callback = debug_img_callback
|
|
|
|
|
self.img_outdir = img_outdir
|
|
|
|
|
self.progress_img_callback = progress_img_callback
|
|
|
|
|
self.progress_img_interval_steps = progress_img_interval_steps
|
|
|
|
|
self.progress_img_interval_min_s = progress_img_interval_min_s
|
|
|
|
|
|
|
|
|
|
self.start_ts = time.perf_counter()
|
|
|
|
|
self.timings = {}
|
|
|
|
|
self.last_progress_img_ts = 0
|
|
|
|
|
self.last_progress_img_step = -1000
|
|
|
|
|
|
|
|
|
|
def __enter__(self):
|
|
|
|
|
global _CURRENT_LOGGING_CONTEXT # noqa
|
|
|
|
@ -91,18 +112,30 @@ class ImageLoggingContext:
|
|
|
|
|
return self.timings
|
|
|
|
|
|
|
|
|
|
def log_conditioning(self, conditioning, description):
|
|
|
|
|
if not self.img_callback:
|
|
|
|
|
if not self.debug_img_callback:
|
|
|
|
|
return
|
|
|
|
|
img = conditioning_to_img(conditioning)
|
|
|
|
|
|
|
|
|
|
self.img_callback(
|
|
|
|
|
self.debug_img_callback(
|
|
|
|
|
img, description, self.image_count, self.step_count, self.prompt
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def log_latents(self, latents, description):
|
|
|
|
|
from imaginairy.img_utils import model_latents_to_pillow_imgs # noqa
|
|
|
|
|
|
|
|
|
|
if not self.img_callback:
|
|
|
|
|
if "predicted_latent" in description:
|
|
|
|
|
if (
|
|
|
|
|
self.step_count - self.last_progress_img_step
|
|
|
|
|
) > self.progress_img_interval_steps:
|
|
|
|
|
if (
|
|
|
|
|
time.perf_counter() - self.last_progress_img_ts
|
|
|
|
|
> self.progress_img_interval_min_s
|
|
|
|
|
):
|
|
|
|
|
self.log_progress_latent(latents)
|
|
|
|
|
self.last_progress_img_step = self.step_count
|
|
|
|
|
self.last_progress_img_ts = time.perf_counter()
|
|
|
|
|
|
|
|
|
|
if not self.debug_img_callback:
|
|
|
|
|
return
|
|
|
|
|
if latents.shape[1] != 4:
|
|
|
|
|
# logger.info(f"Didn't save tensor of shape {samples.shape} for {description}")
|
|
|
|
@ -115,23 +148,31 @@ class ImageLoggingContext:
|
|
|
|
|
description = f"{description}-{shape_str}"
|
|
|
|
|
for img in model_latents_to_pillow_imgs(latents):
|
|
|
|
|
self.image_count += 1
|
|
|
|
|
self.img_callback(
|
|
|
|
|
self.debug_img_callback(
|
|
|
|
|
img, description, self.image_count, self.step_count, self.prompt
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def log_img(self, img, description):
|
|
|
|
|
if not self.img_callback:
|
|
|
|
|
if not self.debug_img_callback:
|
|
|
|
|
return
|
|
|
|
|
self.image_count += 1
|
|
|
|
|
if isinstance(img, torch.Tensor):
|
|
|
|
|
img = ToPILImage()(img.squeeze().cpu().detach())
|
|
|
|
|
img = img.copy()
|
|
|
|
|
self.img_callback(
|
|
|
|
|
self.debug_img_callback(
|
|
|
|
|
img, description, self.image_count, self.step_count, self.prompt
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def log_progress_latent(self, latent):
|
|
|
|
|
from imaginairy.img_utils import model_latents_to_pillow_imgs # noqa
|
|
|
|
|
|
|
|
|
|
if not self.progress_img_callback:
|
|
|
|
|
return
|
|
|
|
|
for img in model_latents_to_pillow_imgs(latent):
|
|
|
|
|
self.progress_img_callback(img)
|
|
|
|
|
|
|
|
|
|
def log_tensor(self, t, description=""):
|
|
|
|
|
if not self.img_callback:
|
|
|
|
|
if not self.debug_img_callback:
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
if len(t.shape) == 2:
|
|
|
|
|