You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
imaginAIry/imaginairy/log_utils.py

243 lines
7.1 KiB
Python

import logging
import logging.config
import re
import time
import warnings
import torch
from pytorch_lightning import _logger as pytorch_logger
from torchvision.transforms import ToPILImage
from transformers.modeling_utils import logger as modeling_logger
from transformers.utils.logging import _configure_library_root_logger
# The ImageLoggingContext whose ``with`` block is currently active, or None
# when no context is active. Assigned by ImageLoggingContext.__enter__/__exit__
# and read by the module-level log_* helpers, increment_step and
# IndentingFormatter, so callers can log without holding a reference.
_CURRENT_LOGGING_CONTEXT = None
# Module logger; its level/handlers come from the "imaginairy" entry that
# configure_logging installs.
logger = logging.getLogger(__name__)
def log_conditioning(conditioning, description):
    """Forward *conditioning* to the active context's ``log_conditioning``.

    Does nothing when no ImageLoggingContext is active.
    """
    active = _CURRENT_LOGGING_CONTEXT
    if active is not None:
        active.log_conditioning(conditioning, description)
def log_latent(latents, description):
    """Forward *latents* to the active context's ``log_latents``.

    Does nothing when no ImageLoggingContext is active or *latents* is None.
    """
    if _CURRENT_LOGGING_CONTEXT is None or latents is None:
        return
    _CURRENT_LOGGING_CONTEXT.log_latents(latents, description)
def log_img(img, description):
    """Forward *img* to the active context's ``log_img``.

    Does nothing when no ImageLoggingContext is active.
    """
    active = _CURRENT_LOGGING_CONTEXT
    if active is not None:
        active.log_img(img, description)
def log_tensor(t, description=""):
    """Log tensor *t* as an image through the active context, if any.

    NOTE(review): this forwards to the context's ``log_img``, not to its
    ``log_tensor`` method (which only logs 2-D tensors). Confirm this is
    intended before changing it.
    """
    active = _CURRENT_LOGGING_CONTEXT
    if active is not None:
        active.log_img(t, description)
def increment_step():
    """Add one to the active context's ``step_count``; no-op without a context."""
    active = _CURRENT_LOGGING_CONTEXT
    if active is not None:
        active.step_count += 1
class TimingContext:
    """Context manager that records how long its ``with`` block took.

    On exit, the elapsed time in seconds is stored in
    ``logging_context.timings[description]``. Any earlier value for the same
    description is overwritten. The timing is recorded even when the block
    raises, and the exception is not suppressed.

    Args:
        logging_context: Object exposing a ``timings`` dict (an
            ImageLoggingContext in this module).
        description: Key under which the duration is stored.
    """

    def __init__(self, logging_context, description):
        self.logging_context = logging_context
        self.description = description
        self.start_time = None

    def __enter__(self):
        # perf_counter is monotonic, unlike time.time(), so system clock
        # adjustments cannot yield negative or skewed durations. It is also the
        # clock ImageLoggingContext uses for its "total" timing.
        self.start_time = time.perf_counter()
        # Return self so ``with ctx.timing("x") as t:`` binds something useful.
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.logging_context.timings[self.description] = (
            time.perf_counter() - self.start_time
        )
class ImageLoggingContext:
    """Collect intermediate images and timings for one generation run.

    While used as a ``with`` block, the instance is published as the
    module-level ``_CURRENT_LOGGING_CONTEXT``. That lets the free functions
    ``log_img``/``log_latent``/``log_conditioning``/``increment_step`` reach it
    without being passed a reference. On exit the previously active context
    (``None`` if not nested) is restored.

    Args:
        prompt: Passed through unchanged as the last argument of every
            ``img_callback`` call.
        model: Stored on the instance; not used by the methods here.
        img_callback: Optional callable invoked as
            ``img_callback(img, description, image_count, step_count, prompt)``.
            When falsy, every image-logging method is a no-op.
        img_outdir: Stored on the instance; not used by the methods here.
    """

    def __init__(self, prompt, model, img_callback=None, img_outdir=None):
        self.prompt = prompt
        self.model = model
        self.step_count = 0
        self.image_count = 0
        self.img_callback = img_callback
        self.img_outdir = img_outdir
        self.start_ts = time.perf_counter()
        self.timings = {}
        # Context that was active when __enter__ ran, restored by __exit__ so
        # nested contexts don't leave the outer one unregistered.
        self._previous_context = None

    def __enter__(self):
        global _CURRENT_LOGGING_CONTEXT  # noqa
        self._previous_context = _CURRENT_LOGGING_CONTEXT
        _CURRENT_LOGGING_CONTEXT = self
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        global _CURRENT_LOGGING_CONTEXT  # noqa
        _CURRENT_LOGGING_CONTEXT = self._previous_context
        self._previous_context = None

    def timing(self, description):
        """Return a context manager that stores its block's duration in ``self.timings``."""
        return TimingContext(self, description)

    def get_timings(self):
        """Set ``timings["total"]`` to seconds since construction and return ``timings``."""
        self.timings["total"] = time.perf_counter() - self.start_ts
        return self.timings

    def _emit(self, img, description):
        """Invoke the callback with the current image/step counters and prompt."""
        self.img_callback(
            img, description, self.image_count, self.step_count, self.prompt
        )

    @staticmethod
    def _format_shape(shape):
        """Render *shape* as comma-separated dimensions, e.g. ``"1,4,64,64"``."""
        try:
            return ",".join(str(dim) for dim in shape)
        except TypeError:
            # Not iterable: fall back to its plain string form.
            return str(shape)

    def log_conditioning(self, conditioning, description):
        """Convert *conditioning* to an image and pass it to the callback."""
        if not self.img_callback:
            return
        img = conditioning_to_img(conditioning)
        # Count it like log_img/log_latents do so each callback call gets a
        # distinct image_count.
        self.image_count += 1
        self._emit(img, description)

    def log_latents(self, latents, description):
        """Decode *latents* to images and pass each to the callback.

        Tensors whose dimension 1 is not 4 are skipped. The description gets a
        ``-<dims>`` suffix with the latent shape.
        """
        if not self.img_callback:
            return
        if latents.shape[1] != 4:
            # presumably only 4-channel latents are decodable by
            # model_latents_to_pillow_imgs — confirm.
            return
        # Imported lazily (presumably to avoid a circular import) and only
        # once we know it will be used.
        from imaginairy.img_utils import model_latents_to_pillow_imgs  # noqa

        description = f"{description}-{self._format_shape(latents.shape)}"
        for img in model_latents_to_pillow_imgs(latents):
            self.image_count += 1
            self._emit(img, description)

    def log_img(self, img, description):
        """Pass a copy of *img* to the callback, converting tensors to PIL first."""
        if not self.img_callback:
            return
        self.image_count += 1
        if isinstance(img, torch.Tensor):
            img = ToPILImage()(img.squeeze().cpu().detach())
        # Copy so the callback and the caller never share a mutable image.
        img = img.copy()
        self._emit(img, description)

    def log_tensor(self, t, description=""):
        """Log *t* via log_img only when it is 2-D; other ranks are ignored."""
        if not self.img_callback:
            return
        if len(t.shape) == 2:
            self.log_img(t, description)

    def log_indexed_graph_of_tensor(self):
        """Placeholder; currently does nothing."""
        pass
# def img_callback(self, img, description, step_count, prompt):
# steps_path = os.path.join(self.img_outdir, "steps", f"{self.file_num:08}_S{prompt.seed}")
# os.makedirs(steps_path, exist_ok=True)
# filename = f"{self.file_num:08}_S{prompt.seed}_step{step_count:04}_{filesafe_text(description)[:40]}.jpg"
# destination = os.path.join(steps_path, filename)
# draw = ImageDraw.Draw(img)
# draw.text((10, 10), str(description))
# img.save(destination)
_FILESAFE_DISALLOWED = re.compile(r"[^a-zA-Z0-9.,\[\]() -]+")


def filesafe_text(t):
    """Return *t* made safe for use in a filename.

    Each run of characters outside letters, digits, ``.,[]()``, space and
    ``-`` is collapsed into a single ``_``. The result is then truncated to at
    most 130 characters.
    """
    sanitized = _FILESAFE_DISALLOWED.sub("_", t)
    return sanitized[:130]
def conditioning_to_img(conditioning):
    """Convert *conditioning* to a PIL image with torchvision's ``ToPILImage``."""
    to_pil = ToPILImage()
    return to_pil(conditioning)
class IndentingFormatter(logging.Formatter):
    """Formatter that marks records emitted inside an active logging context.

    Output matches :class:`logging.Formatter`, except that a single leading
    space is prepended while ``_CURRENT_LOGGING_CONTEXT`` is set.
    """

    def format(self, record):
        formatted = super().format(record)
        in_context = _CURRENT_LOGGING_CONTEXT is not None
        return f" {formatted}" if in_context else formatted
def configure_logging(level="INFO"):
    """Install imaginAIry's logging configuration and silence noisy libraries.

    Args:
        level: Level name for the ``imaginairy`` logger (e.g. ``"INFO"``,
            ``"DEBUG"``). ``"DEBUG"`` also switches to a verbose format with
            timestamp, level, logger name and line number, and lets DEBUG
            records through the shared stdout handler.

    Side effects: calls ``suppress_annoying_logs_and_warnings()``, then
    replaces the global logging configuration via ``logging.config.dictConfig``
    with ``disable_existing_loggers=True``.
    """
    fmt = "%(message)s"
    # Handler levels filter independently of logger levels. If the shared
    # handler stayed at INFO, DEBUG records from the "imaginairy" logger would
    # be dropped even with level="DEBUG".
    handler_level = "INFO"
    if level == "DEBUG":
        fmt = "%(asctime)s [%(levelname)s] %(name)s:%(lineno)d: %(message)s"
        handler_level = "DEBUG"
    logging_config = {
        "version": 1,
        "disable_existing_loggers": True,
        "formatters": {
            "standard": {
                "format": fmt,
                "class": "imaginairy.log_utils.IndentingFormatter",
            },
        },
        "handlers": {
            "default": {
                "level": handler_level,
                "formatter": "standard",
                "class": "logging.StreamHandler",
                "stream": "ext://sys.stdout",  # StreamHandler defaults to stderr
            },
        },
        "loggers": {
            "": {  # root logger
                "handlers": ["default"],
                "level": "WARNING",
                "propagate": False,
            },
            "imaginairy": {"handlers": ["default"], "level": level, "propagate": False},
            "transformers.modeling_utils": {
                "handlers": ["default"],
                "level": "ERROR",
                "propagate": False,
            },
        },
    }
    # Quiet third-party loggers/warnings first; dictConfig then installs ours.
    suppress_annoying_logs_and_warnings()
    logging.config.dictConfig(logging_config)
def disable_transformers_custom_logging():
    """Strip transformers' own logging handlers so records propagate upward.

    First makes sure transformers has configured its library root logger.
    Then, for the parent of ``transformers.modeling_utils``'s logger
    (presumably the ``transformers`` library root — confirm) and for the
    modeling logger itself, it removes all handlers and enables propagation.
    The parent is left at NOTSET; the modeling logger is set to ERROR.
    """
    _configure_library_root_logger()
    resets = (
        (modeling_logger.parent, logging.NOTSET),
        (modeling_logger, logging.ERROR),
    )
    for target, target_level in resets:
        target.handlers = []
        target.propagate = True
        target.setLevel(target_level)
def disable_pytorch_lighting_custom_logging():
    """Reset pytorch_lightning's logger to NOTSET so its level is inherited from ancestors."""
    pytorch_logger.setLevel(level=logging.NOTSET)
def disable_common_warnings():
    """Install "ignore" filters for known-noisy third-party warnings.

    Three ``UserWarning`` messages (regexes matched case-insensitively against
    the start of the message) and every ``DeprecationWarning`` are ignored.
    Filters are added in table order, so each one is placed ahead of the
    previous one in ``warnings.filters``.
    """
    noisy = (
        (UserWarning, r"The operator .*?is not currently supported.*"),
        (UserWarning, r"The parameter 'pretrained' is.*"),
        (UserWarning, r"Arguments other than a weight.*"),
        (DeprecationWarning, ""),
    )
    for category, pattern in noisy:
        warnings.filterwarnings("ignore", category=category, message=pattern)
def suppress_annoying_logs_and_warnings():
    """Run each noise-reduction helper in order.

    The helpers cover transformers logging, pytorch_lightning logging, then
    common warnings.
    """
    for silence in (
        disable_transformers_custom_logging,
        disable_pytorch_lighting_custom_logging,
        disable_common_warnings,
    ):
        silence()