fix: memory management issue

the dtype being used as a cache key wasn't consistent. this caused the model to be loaded twice
pull/408/head
Bryce 6 months ago committed by Bryce Drennan
parent 82c30024c9
commit 0fe3733933

@ -174,6 +174,9 @@ def imagine(
if get_device() == "cpu":
logger.info("Running in CPU mode. it's gonna be slooooooow.")
if half_mode is None:
half_mode = "cuda" in get_device() or get_device() == "mps"
with torch.no_grad(), platform_appropriate_autocast(
precision
), fix_torch_nn_layer_norm(), fix_torch_group_norm():
@ -192,6 +195,7 @@ def imagine(
progress_img_interval_min_s=progress_img_interval_min_s,
half_mode=half_mode,
add_caption=add_caption,
dtype=torch.float16 if half_mode else torch.float32,
)
if not result.safety_score.is_filtered:
break
@ -682,14 +686,19 @@ def _scale_latent(
return latent
def _generate_composition_image(prompt, target_height, target_width, cutoff=512):
def _generate_composition_image(
prompt, target_height, target_width, cutoff=512, dtype=None
):
from PIL import Image
from imaginairy.api_refiners import _generate_single_image
from imaginairy.utils import default, get_default_dtype
if prompt.width <= cutoff and prompt.height <= cutoff:
return None, None
dtype = default(dtype, get_default_dtype)
shrink_scale = calc_scale_to_fit_within(
height=prompt.height,
width=prompt.width,
@ -708,7 +717,7 @@ def _generate_composition_image(prompt, target_height, target_width, cutoff=512)
},
)
result = _generate_single_image(composition_prompt)
result = _generate_single_image(composition_prompt, dtype=dtype)
img = result.images["generated"]
while img.width < target_width:
from imaginairy.enhancers.upscale_realesrgan import upscale_image

@ -14,11 +14,12 @@ def _generate_single_image(
progress_img_callback=None,
progress_img_interval_steps=3,
progress_img_interval_min_s=0.1,
half_mode=None,
add_caption=False,
# controlnet, finetune, naive, auto
inpaint_method="finetune",
return_latent=False,
dtype=None,
half_mode=None,
):
import gc
@ -58,6 +59,9 @@ def _generate_single_image(
from imaginairy.schema import ImaginePrompt, ImagineResult
from imaginairy.utils import get_device, randn_seeded
if dtype is None:
dtype = torch.float16 if half_mode else torch.float32
get_device()
gc.collect()
torch.cuda.empty_cache()
@ -75,7 +79,7 @@ def _generate_single_image(
weights_location=prompt.model,
config_path=prompt.model_config_path,
control_weights_locations=tuple(control_modes),
half_mode=half_mode,
dtype=dtype,
for_inpainting=for_inpainting and inpaint_method == "finetune",
)
@ -110,6 +114,7 @@ def _generate_single_image(
positive_conditioning=prompt.conditioning,
text_encoder=sd.clip_text_encoder,
)
clip_text_embedding = clip_text_embedding.to(device=sd.device, dtype=sd.dtype)
result_images = {}
progress_latents = []
@ -151,6 +156,7 @@ def _generate_single_image(
init_image_t = pillow_img_to_torch_image(init_image)
init_image_t = init_image_t.to(device=sd.device, dtype=sd.dtype)
init_latent = sd.lda.encode(init_image_t)
shape = init_latent.shape
log_latent(init_latent, "init_latent")
@ -263,6 +269,7 @@ def _generate_single_image(
target_height=init_image.height,
target_width=init_image.width,
cutoff=get_model_default_image_size(prompt.model),
dtype=dtype,
)
else:
comp_image, comp_img_orig = _generate_composition_image(
@ -270,21 +277,20 @@ def _generate_single_image(
target_height=prompt.height,
target_width=prompt.width,
cutoff=get_model_default_image_size(prompt.model),
dtype=dtype,
)
if comp_image is not None:
result_images["composition"] = comp_img_orig
result_images["composition-upscaled"] = comp_image
# noise = noise[:, :, : comp_image.height, : comp_image.shape[3]]
comp_cutoff = 0.60
comp_cutoff = 0.50
first_step = int((prompt.steps) * comp_cutoff)
noise_step = int((prompt.steps - 1) * comp_cutoff)
# noise_step = int(prompt.steps * max(comp_cutoff - 0.05, 0))
# noise_step = max(noise_step, 0)
# noise_step = min(noise_step, prompt.steps - 1)
log_img(comp_image, "comp_image")
log_img(comp_img_orig, "comp_image")
log_img(comp_image, "comp_image_upscaled")
comp_image_t = pillow_img_to_torch_image(comp_image)
comp_image_t = comp_image_t.to(sd.device, dtype=sd.dtype)
init_latent = sd.lda.encode(comp_image_t)
for controlnet, control_image_t in controlnets:
controlnet.set_controlnet_condition(
control_image_t.to(device=sd.device, dtype=sd.dtype)
@ -299,6 +305,7 @@ def _generate_single_image(
raise ValueError(msg)
sd.scheduler.to(device=sd.device, dtype=sd.dtype)
sd.set_num_inference_steps(prompt.steps)
if hasattr(sd, "mask_latents"):
sd.set_inpainting_conditions(
target_image=init_image,
@ -318,6 +325,12 @@ def _generate_single_image(
x = noised_latent
x = x.to(device=sd.device, dtype=sd.dtype)
# if "cuda" in str(sd.lda.device):
# sd.lda.to("cpu")
gc.collect()
torch.cuda.empty_cache()
# print(f"moving unet to {sd.device}")
# sd.unet.to(device=sd.device, dtype=sd.dtype)
for step in tqdm(sd.steps[first_step:], bar_format=" {l_bar}{bar}{r_bar}"):
log_latent(x, "noisy_latent")
x = sd(
@ -327,7 +340,26 @@ def _generate_single_image(
condition_scale=prompt.prompt_strength,
)
# z = sd(
# randn_seeded(seed=prompt.seed, size=[1, 4, 8, 8]).to(
# device=sd.device, dtype=sd.dtype
# ),
# step=step,
# clip_text_embedding=clip_text_embedding,
# condition_scale=prompt.prompt_strength,
# )
if "cuda" in str(sd.unet.device):
# print("moving unet to cpu")
# sd.unet.to("cpu")
gc.collect()
torch.cuda.empty_cache()
logger.debug("Decoding image")
if x.device != sd.lda.device:
sd.lda.to(x.device)
gc.collect()
torch.cuda.empty_cache()
gen_img = sd.lda.decode_latents(x)
if mask_image_orig and init_image:
@ -407,7 +439,11 @@ def _generate_single_image(
def _prompts_to_embeddings(prompts, text_encoder):
import torch
total_weight = sum(wp.weight for wp in prompts)
if str(text_encoder.device) == "cpu":
text_encoder = text_encoder.to(dtype=torch.float32)
conditioning = sum(
text_encoder(wp.text) * (wp.weight / total_weight) for wp in prompts
)

@ -238,7 +238,7 @@ def get_diffusion_model_refiners(
weights_location=iconfig.DEFAULT_MODEL,
config_path="configs/stable-diffusion-v1.yaml",
control_weights_locations=None,
half_mode=None,
dtype=None,
for_inpainting=False,
for_training=False,
):
@ -251,8 +251,8 @@ def get_diffusion_model_refiners(
return _get_diffusion_model_refiners(
weights_location,
config_path,
half_mode,
for_inpainting,
dtype=dtype,
control_weights_locations=control_weights_locations,
for_training=for_training,
)
@ -264,7 +264,7 @@ def get_diffusion_model_refiners(
return _get_diffusion_model_refiners(
iconfig.DEFAULT_MODEL,
config_path,
half_mode,
dtype=dtype,
for_inpainting=False,
for_training=for_training,
control_weights_locations=control_weights_locations,
@ -275,7 +275,6 @@ def get_diffusion_model_refiners(
def _get_diffusion_model_refiners(
weights_location=iconfig.DEFAULT_MODEL,
config_path="configs/stable-diffusion-v1.yaml",
half_mode=None,
for_inpainting=False,
for_training=False,
control_weights_locations=None,
@ -401,11 +400,8 @@ def load_controlnet_adapter(
control_weights_location,
target_unet,
scale=1.0,
half_mode=False,
):
controlnet_state_dict = load_state_dict(
control_weights_location, half_mode=half_mode
)
controlnet_state_dict = load_state_dict(control_weights_location, half_mode=False)
controlnet_state_dict = cast_weights(
source_weights=controlnet_state_dict,
source_model_name="controlnet-1-1",

@ -26,6 +26,18 @@ def get_device() -> str:
return "cpu"
@lru_cache
def get_default_dtype():
"""Return the default dtype for torch."""
if get_device() == "cuda":
return torch.float16
if get_device() == "mps":
return torch.float16
return torch.float32
@lru_cache
def get_hardware_description(device_type: str) -> str:
"""Description of the hardware being used."""

Binary file not shown.

Before

Width:  |  Height:  |  Size: 522 KiB

After

Width:  |  Height:  |  Size: 499 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 534 KiB

After

Width:  |  Height:  |  Size: 519 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 428 KiB

After

Width:  |  Height:  |  Size: 420 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 543 KiB

After

Width:  |  Height:  |  Size: 536 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 557 KiB

After

Width:  |  Height:  |  Size: 556 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 2.4 MiB

After

Width:  |  Height:  |  Size: 2.4 MiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 418 KiB

After

Width:  |  Height:  |  Size: 417 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 397 KiB

After

Width:  |  Height:  |  Size: 358 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 283 KiB

After

Width:  |  Height:  |  Size: 280 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 475 KiB

After

Width:  |  Height:  |  Size: 471 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 364 KiB

After

Width:  |  Height:  |  Size: 381 KiB

@ -27,4 +27,4 @@ def test_control_images(filename_base_for_outputs, control_func, control_name):
control_img = control_img_to_pillow_img(control_t)
img_path = f"{filename_base_for_outputs}.png"
assert_image_similar_to_expectation(control_img, img_path, threshold=4000)
assert_image_similar_to_expectation(control_img, img_path, threshold=8000)

@ -305,7 +305,7 @@ def test_tile_mode(filename_base_for_outputs):
result = next(imagine(prompt))
img_path = f"{filename_base_for_outputs}.png"
assert_image_similar_to_expectation(result.img, img_path=img_path, threshold=25000)
assert_image_similar_to_expectation(result.img, img_path=img_path, threshold=26000)
control_modes = list(CONTROL_MODES.keys())
@ -353,10 +353,13 @@ def test_controlnet(filename_base_for_outputs, control_mode):
result = next(imagine(prompt))
img_path = f"{filename_base_for_outputs}.png"
assert_image_similar_to_expectation(result.img, img_path=img_path, threshold=24000)
assert_image_similar_to_expectation(result.img, img_path=img_path, threshold=25000)
@pytest.mark.skipif(get_device() == "cpu", reason="Too slow to run on CPU")
@pytest.mark.skipif(
get_device() in {"cpu", "mps"},
reason="Too slow to run on CPU. Too much memory for MPS",
)
def test_large_image(filename_base_for_outputs):
prompt_text = "a stormy ocean. oil painting"
prompt = ImaginePrompt(

@ -31,7 +31,7 @@ def test_imagine_cmd(monkeypatch):
f"{TESTS_FOLDER}/test_output",
],
)
assert result.exit_code == 0, (result.stdout, result.stderr)
assert result.exit_code == 0, result.stdout
def test_edit_cmd(monkeypatch):
@ -51,7 +51,7 @@ def test_edit_cmd(monkeypatch):
f"{TESTS_FOLDER}/test_output",
],
)
assert result.exit_code == 0, (result.stdout, result.stderr)
assert result.exit_code == 0, result.stdout
def test_aimg_shell():

@ -141,6 +141,7 @@ def test_describe_picture():
"a painting of a girl with a pearl earring wearing a yellow dress and a pearl earring in her ear and a black background",
"a painting of a girl with a pearl ear wearing a yellow dress and a pearl earring on her left ear and a black background",
"a painting of a woman with a pearl ear wearing an ornament pearl earring and wearing an orange, white, blue and yellow dress",
"a painting of a woman with a pearl earring looking to her left, in profile with her right eye partially closed, standing upright",
}

Loading…
Cancel
Save