perf: improve memory usage (#433)

add warning for corrupt weights files
pull/434/head
Bryce Drennan 5 months ago committed by GitHub
parent 26d1ff9bc4
commit 77c4b85037
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -310,7 +310,9 @@ def generate_single_image(
condition_scale=prompt.prompt_strength,
**text_conditioning_kwargs,
)
# trying to clear memory. not sure if this helps
sd.unet.set_context(context="self_attention_map", value={})
sd.unet._reset_context()
clear_gpu_cache()
logger.debug("Decoding image")

@ -12,7 +12,7 @@ from imaginairy.vendored.realesrgan import RealESRGANer
@memory_managed_model("realesrgan_upsampler", memory_usage_mb=70)
def realesrgan_upsampler(tile=1024, tile_pad=50, ultrasharp=False):
def realesrgan_upsampler(tile=512, tile_pad=50, ultrasharp=False):
model = RRDBNet(
num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4
)

@ -674,7 +674,7 @@ def load_sdxl_diffusers_weights(base_url: str, device=None, dtype=torch.float16)
vae_weights_path = download_diffusers_weights(
base_url=base_url, sub="vae", prefer_fp16=False
)
print(vae_weights_path)
logger.debug(f"vae: {vae_weights_path}")
vae_weights = translator.load_and_translate_weights(
source_path=vae_weights_path,
device="cpu",
@ -684,8 +684,10 @@ def load_sdxl_diffusers_weights(base_url: str, device=None, dtype=torch.float16)
del vae_weights
translator = translators.diffusers_unet_sdxl_to_refiners_translator()
unet_weights_path = download_diffusers_weights(base_url=base_url, sub="unet")
print(unet_weights_path)
unet_weights_path = download_diffusers_weights(
base_url=base_url, sub="unet", prefer_fp16=True
)
logger.debug(f"unet: {unet_weights_path}")
unet_weights = translator.load_and_translate_weights(
source_path=unet_weights_path,
device="cpu",
@ -700,8 +702,8 @@ def load_sdxl_diffusers_weights(base_url: str, device=None, dtype=torch.float16)
text_encoder_2_path = download_diffusers_weights(
base_url=base_url, sub="text_encoder_2"
)
print(text_encoder_1_path)
print(text_encoder_2_path)
logger.debug(f"text encoder 1: {text_encoder_1_path}")
logger.debug(f"text encoder 2: {text_encoder_2_path}")
text_encoder_weights = (
translators.DoubleTextEncoderTranslator().load_and_translate_weights(
text_encoder_l_weights_path=text_encoder_1_path,

@ -1,3 +1,4 @@
import logging
import math
import os
import queue
@ -12,6 +13,7 @@ from imaginairy.utils.model_manager import get_cached_url_path
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
logger = logging.getLogger(__name__)
class RealESRGANer:
"""A helper class for upsampling images with RealESRGAN.
@ -146,7 +148,7 @@ class RealESRGANer:
self.output = self.img.new_zeros(output_shape)
tiles_x = math.ceil(width / self.tile_size)
tiles_y = math.ceil(height / self.tile_size)
logger.debug(f"Tiling with {tiles_x}x{tiles_y} ({tiles_x*tiles_y}) tiles")
# loop over all tiles
for y in range(tiles_y):
for x in range(tiles_x):

@ -4,7 +4,7 @@ from dataclasses import asdict, dataclass, field
from typing import Dict
import torch
from refiners.fluxion import load_from_safetensors
from safetensors import safe_open
from torch import device as Device
logger = logging.getLogger(__name__)
@ -29,7 +29,8 @@ class WeightTranslationMap:
source_weights = torch.load(source_path, map_location="cpu")
elif extension in ["safetensors"]:
source_weights = load_from_safetensors(source_path, device=device)
with safe_open(source_path, framework="pt", device=device) as f: # type: ignore
source_weights = {k: f.get_tensor(k) for k in f.keys()} # noqa
else:
msg = f"Unsupported extension {extension}"
raise ValueError(msg)
@ -79,10 +80,30 @@ class WeightTranslationMap:
return cls(**d)
def check_nan_path(path: str, device) -> None:
    """Scan a safetensors file and warn about every tensor that contains NaN.

    The file is opened with ``safe_open`` and tensors are fetched one key at
    a time, so each tensor is inspected individually rather than building a
    dict of all weights first.

    Args:
        path: Location of the safetensors file to inspect.
        device: Passed straight through to ``safe_open`` as its ``device``
            argument (presumably a torch device string such as "cpu" --
            confirm against callers).

    Returns:
        None. Findings are reported through the module logger only.
    """
    from safetensors import safe_open

    with safe_open(path, framework="pt", device=device) as f:  # type: ignore
        for key in f.keys():  # noqa
            tensor = f.get_tensor(key)
            if torch.isnan(tensor).any():
                # Report via the module logger (not print) so this diagnostic
                # follows the same channel as the nan warnings emitted by
                # translate_weights.
                logger.warning(f"Found nan values in {key} of {path}")
def translate_weights(
source_weights: TensorDict, weight_map: WeightTranslationMap
) -> TensorDict:
new_state_dict: TensorDict = {}
# check source weights for nan
for k, v in source_weights.items():
nan_count = torch.sum(torch.isnan(v)).item()
if nan_count:
msg = (
f"Found {nan_count} nan values in {k} of source state dict."
" This could indicate the source weights are corrupted and "
"need to be re-downloaded. "
)
logger.warning(msg)
# print(f"Translating {len(source_weights)} weights")
# print(f"Using {len(weight_map.name_map)} name mappings")
# print(source_weights.keys())
@ -142,7 +163,7 @@ def translate_weights(
if source_weights:
msg = f"Unmapped keys: {list(source_weights.keys())}"
print(msg)
logger.info(msg)
for k in source_weights:
if isinstance(source_weights[k], torch.Tensor):
print(f" {k}: {source_weights[k].shape}")
@ -154,6 +175,15 @@ def translate_weights(
if key in new_state_dict:
new_state_dict[key] = new_state_dict[key].reshape(new_shape)
# check for nan values
for k in list(new_state_dict.keys()):
v = new_state_dict[k]
nan_count = torch.sum(torch.isnan(v)).item()
if nan_count:
logger.warning(
f"Found {nan_count} nan values in {k} of converted state dict."
)
return new_state_dict

Loading…
Cancel
Save