tests: "prime" the controlnets

Trying to get things working on M1. Doesn't fix everything.
pull/333/head
Bryce authored 1 year ago; committed by Bryce Drennan
parent fb19e34acc · commit 926692ad03

@@ -18,7 +18,9 @@ def realesrgan_upsampler():
     model_path = get_cached_url_path(url)
     device = get_device()
-    upsampler = RealESRGANer(scale=4, model_path=model_path, model=model, tile=512, device=device)
+    upsampler = RealESRGANer(
+        scale=4, model_path=model_path, model=model, tile=512, device=device
+    )
     upsampler.device = torch.device(device)
     upsampler.model.to(device)
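
For context on the M1 fix: get_device() (imported from imaginairy.utils) resolves to a device string such as "cuda", "mps", or "cpu". A minimal sketch of that kind of detection, assuming current torch APIs rather than the project's exact helper:

import torch

def detect_device() -> str:
    # Prefer CUDA, then Apple Silicon's Metal backend (M1/M2), else CPU.
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"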

@@ -8,6 +8,9 @@ from imaginairy.utils import get_device
 logger = logging.getLogger(__name__)
+log = logger.debug
 
 
 def get_model_size(model):
     from torch import nn
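
The body of get_model_size is truncated in this view. A plausible implementation, summing parameter and buffer bytes (an assumption for illustration, not necessarily the project's exact code):

from torch import nn

def get_model_size(model: nn.Module) -> int:
    # Approximate resident size: bytes held by parameters plus buffers.
    param_bytes = sum(p.nelement() * p.element_size() for p in model.parameters())
    buffer_bytes = sum(b.nelement() * b.element_size() for b in model.buffers())
    return param_bytes + buffer_bytes
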
@@ -119,6 +122,16 @@ class GPUModelCache:
         self.gpu_cache = MemoryTrackingCache()
         self.cpu_cache = MemoryTrackingCache()
 
+    def stats_msg(self):
+        import psutil
+
+        msg = (
+            f" GPU cache: {len(self.gpu_cache)} items; {self.gpu_cache.memory_usage / (1024 ** 2):.1f} MB; Max: {self.max_gpu_memory / (1024 ** 2):.1f} MB;\n"
+            f" CPU cache: {len(self.cpu_cache)} items; {self.cpu_cache.memory_usage / (1024 ** 2):.1f} MB; Max: {self.max_cpu_memory / (1024 ** 2):.1f} MB;\n"
+            f" mem_free_total: {get_mem_free_total(self.device) / (1024 ** 2):.1f} MB; Ram Free: {psutil.virtual_memory().available / (1024 ** 2):.1f} MB;"
+        )
+        return msg
+
     @cached_property
     def device(self):
         import torch
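
The new stats_msg leans on MemoryTrackingCache, which this diff doesn't show. Below is a minimal sketch of the interface the surrounding code appears to rely on (len(), memory_usage, first_key, set, pop) — inferred from the call sites above, so treat it as an assumption:

from collections import OrderedDict

class MemoryTrackingCache:
    # Insertion-ordered cache that tracks an approximate byte count per item,
    # so oldest-first eviction can be driven by total memory_usage.
    def __init__(self):
        self._items = OrderedDict()
        self._item_memory_usage = {}
        self.memory_usage = 0

    def set(self, key, value, memory_usage=0):
        self._items[key] = value
        self._item_memory_usage[key] = memory_usage
        self.memory_usage += memory_usage

    def pop(self, key):
        self.memory_usage -= self._item_memory_usage.pop(key)
        return self._items.pop(key)

    def first_key(self):
        return next(iter(self._items))

    def __len__(self):
        return len(self._items)
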
@@ -126,7 +139,7 @@ class GPUModelCache:
         if self._device is None:
             self._device = get_device()
 
-        if self._device in ("cpu", "mps"):
+        if self._device in ("cpu", "mps", "mps:0"):
             # the "gpu" cache will be the only thing we use since there aren't two different memory stores in this case
             self._max_cpu_memory_gb = 0
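
The extra "mps:0" entry matters because torch renders the MPS device string differently depending on how the device was constructed, so a bare "mps" comparison can miss it:

import torch

print(str(torch.device("mps")))     # "mps"
print(str(torch.device("mps", 0)))  # "mps:0" -- the form get_device() may return
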
@@ -137,10 +150,9 @@ class GPUModelCache:
         import torch.cuda
 
         mem_free = get_mem_free_total(self.device)
-        logger.debug(
-            f"Making {bytes_to_free / (1024 ** 2):.1f} MB of GPU space. current usage: {self.gpu_cache.memory_usage / (1024 ** 2):.1f} MB; free mem: {mem_free / (1024 ** 2):.1f} MB; Max mem: {self.max_gpu_memory / (1024 ** 2):.1f} MB"
-        )
+        log(self.stats_msg())
+        log(f"Ensuring {bytes_to_free / (1024 ** 2):.1f} MB of GPU space.")
         while self.gpu_cache and (
             self.gpu_cache.memory_usage + bytes_to_free > self.max_gpu_memory
             or self.gpu_cache.memory_usage + bytes_to_free
@@ -165,18 +177,17 @@ class GPUModelCache:
         import psutil
 
         mem_free = psutil.virtual_memory().available * 0.8
-        logger.debug(
-            f"Making {bytes_to_free / (1024 ** 2):.1f} MB of RAM space. current usage: {self.cpu_cache.memory_usage / (1024 ** 2):.2f} MB; free mem: {mem_free / (1024 ** 2):.1f} MB; max mem: {self.max_cpu_memory / (1024 ** 2):.1f} MB"
-        )
+        log(self.stats_msg())
+        log(f"Ensuring {bytes_to_free / (1024 ** 2):.1f} MB of RAM space.")
         while self.cpu_cache and (
             self.cpu_cache.memory_usage + bytes_to_free > self.max_gpu_memory
             or self.cpu_cache.memory_usage + bytes_to_free
             > psutil.virtual_memory().available * 0.8
         ):
             oldest_cpu_key = self.cpu_cache.first_key()
-            logger.debug(f"dropping {oldest_cpu_key} from memory")
+            log(f"dropping {oldest_cpu_key} from memory")
             self.cpu_cache.pop(oldest_cpu_key)
+        log(self.stats_msg())
         gc.collect()
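
Both make_gpu_space and make_cpu_space follow the same shape: log the cache state, then evict oldest entries until the requested bytes fit under both the configured cap and the actually-free memory. A condensed, hypothetical version of that loop:

def make_space(cache, bytes_to_free, max_memory, free_memory_fn):
    # Evict oldest-first until the pending allocation fits (sketch only).
    while cache and (
        cache.memory_usage + bytes_to_free > max_memory
        or cache.memory_usage + bytes_to_free > free_memory_fn()
    ):
        cache.pop(cache.first_key())
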
@@ -227,14 +238,17 @@ class GPUModelCache:
         try:
             model_size = max(self.cpu_cache._item_memory_usage[key], model_size)
             self.cpu_cache.pop(key)
+            log(f"dropping {key} from cpu cache")
         except KeyError:
             pass
 
-        logger.debug(f"moving {key} to gpu")
+        log(f"moving {key} to gpu")
         move_model_device(model, self.device)
         self.gpu_cache.set(key, value=model, memory_usage=model_size)
 
     def _move_to_cpu(self, key):
         import gc
+        import psutil
         import torch
@@ -249,12 +263,16 @@ class GPUModelCache:
             and self.cpu_cache.memory_usage + model_size
             < psutil.virtual_memory().available * 0.8
         ):
-            logger.debug(f"moving {key} to cpu")
+            log(f"moving {key} to cpu")
             move_model_device(model, torch.device("cpu"))
+            log(self.stats_msg())
             self.cpu_cache.set(key, model, memory_usage=model_size)
         else:
-            logger.debug(f"dropping {key} from memory")
+            log(f"dropping {key} from memory")
             del model
             gc.collect()
+            log(self.stats_msg())
 
     def get(self, key):
         import torch
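
move_model_device is imported from elsewhere in imaginairy; conceptually it just relocates the model's weights. A minimal sketch under that assumption:

import torch
from torch import nn

def move_model_device(model: nn.Module, device: torch.device) -> None:
    # Move parameters/buffers and record the device on the model, mirroring
    # how the RealESRGAN hunk above sets upsampler.device explicitly.
    model.to(device)
    model.device = device
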
@@ -321,7 +339,7 @@ class MemoryManagedModelWrapper:
     def _mmmw_load_model(self):
         if self._mmmw_cache_key not in self.__class__._mmmw_cache:
-            logger.debug(f"Loading model: {self._mmmw_cache_key}")
+            log(f"Loading model: {self._mmmw_cache_key}")
             self.__class__._mmmw_cache.make_gpu_space(
                 self._mmmw_estimated_ram_size_mb * 1024**2
             )
@@ -329,9 +347,7 @@ class MemoryManagedModelWrapper:
             model = self._mmmw_fn(*self._mmmw_args, **self._mmmw_kwargs)
             move_model_device(model, self.__class__._mmmw_cache.device)
             free_after = get_mem_free_total(self.__class__._mmmw_cache.device)
-            logger.debug(
-                f"Model loaded: {self._mmmw_cache_key} Used {free_before - free_after}"
-            )
+            log(f"Model loaded: {self._mmmw_cache_key} Used {free_before - free_after}")
             self.__class__._mmmw_cache.set(
                 self._mmmw_cache_key,
                 model,
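
For readers unfamiliar with MemoryManagedModelWrapper: the pattern defers model construction until first use and routes loaded models through the shared cache. A stripped-down sketch of the lazy-loading half of that idea (illustrative, not the project's class):

class LazyModel:
    # Build the wrapped model on first attribute access.
    def __init__(self, load_fn, *args, **kwargs):
        self._load_fn = load_fn
        self._args = args
        self._kwargs = kwargs
        self._model = None

    def __getattr__(self, name):
        if self._model is None:
            self._model = self._load_fn(*self._args, **self._kwargs)
        return getattr(self._model, name)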

@@ -334,6 +334,13 @@ def test_controlnet(filename_base_for_outputs, control_mode):
         control_mode=control_mode,
         fix_faces=True,
     )
+    prompt.steps = 1
+    prompt.width = 256
+    prompt.height = 256
+    result = next(imagine(prompt))
+    prompt.steps = 15
+    prompt.width = 512
+    prompt.height = 512
     result = next(imagine(prompt))
 
     img_path = f"{filename_base_for_outputs}.png"
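
This is the "priming" from the commit title: a throwaway 1-step, 256×256 generation forces the controlnet weights to load (and any lazy initialization to run) before the real 15-step, 512×512 pass — apparently to work around first-run issues on M1, per the commit message. The same pattern as a hypothetical reusable helper:

def primed_imagine(prompt, steps=15, size=512):
    # Cheap warm-up pass so model loading happens before the run we care about.
    prompt.steps, prompt.width, prompt.height = 1, 256, 256
    next(imagine(prompt))

    prompt.steps, prompt.width, prompt.height = steps, size, size
    return next(imagine(prompt))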
