@@ -8,6 +8,9 @@ from imaginairy.utils import get_device
 
 logger = logging.getLogger(__name__)
 
+log = logger.debug
+
+
 
 def get_model_size(model):
     from torch import nn
@@ -119,6 +122,16 @@ class GPUModelCache:
         self.gpu_cache = MemoryTrackingCache()
         self.cpu_cache = MemoryTrackingCache()
 
+    def stats_msg(self):
+        import psutil
+
+        msg = (
+            f" GPU cache: {len(self.gpu_cache)} items; {self.gpu_cache.memory_usage / (1024 ** 2):.1f} MB; Max: {self.max_gpu_memory / (1024 ** 2):.1f} MB;\n"
+            f" CPU cache: {len(self.cpu_cache)} items; {self.cpu_cache.memory_usage / (1024 ** 2):.1f} MB; Max: {self.max_cpu_memory / (1024 ** 2):.1f} MB;\n"
+            f" mem_free_total: {get_mem_free_total(self.device) / (1024 ** 2):.1f} MB; Ram Free: {psutil.virtual_memory().available / (1024 ** 2):.1f} MB;"
+        )
+        return msg
+
     @cached_property
     def device(self):
         import torch
@@ -126,7 +139,7 @@ class GPUModelCache:
         if self._device is None:
             self._device = get_device()
 
-        if self._device in ("cpu", "mps"):
+        if self._device in ("cpu", "mps", "mps:0"):
             # the "gpu" cache will be the only thing we use since there aren't two different memory stores in this case
             self._max_cpu_memory_gb = 0
 
@@ -137,10 +150,9 @@ class GPUModelCache:
         import torch.cuda
 
         mem_free = get_mem_free_total(self.device)
-        logger.debug(
-            f"Making {bytes_to_free / (1024 ** 2):.1f} MB of GPU space. current usage: {self.gpu_cache.memory_usage / (1024 ** 2):.1f} MB; free mem: {mem_free / (1024 ** 2):.1f} MB; Max mem: {self.max_gpu_memory / (1024 ** 2):.1f} MB"
-        )
+        log(self.stats_msg())
+        log(f"Ensuring {bytes_to_free / (1024 ** 2):.1f} MB of GPU space.")
 
         while self.gpu_cache and (
             self.gpu_cache.memory_usage + bytes_to_free > self.max_gpu_memory
             or self.gpu_cache.memory_usage + bytes_to_free
@@ -165,18 +177,17 @@ class GPUModelCache:
         import psutil
 
-        mem_free = psutil.virtual_memory().available * 0.8
-        logger.debug(
-            f"Making {bytes_to_free / (1024 ** 2):.1f} MB of RAM space. current usage: {self.cpu_cache.memory_usage / (1024 ** 2):.2f} MB; free mem: {mem_free / (1024 ** 2):.1f} MB; max mem: {self.max_cpu_memory / (1024 ** 2):.1f} MB"
-        )
+        log(self.stats_msg())
+        log(f"Ensuring {bytes_to_free / (1024 ** 2):.1f} MB of RAM space.")
         while self.cpu_cache and (
             self.cpu_cache.memory_usage + bytes_to_free > self.max_cpu_memory
             or self.cpu_cache.memory_usage + bytes_to_free
             > psutil.virtual_memory().available * 0.8
         ):
             oldest_cpu_key = self.cpu_cache.first_key()
-            logger.debug(f"dropping {oldest_cpu_key} from memory")
+            log(f"dropping {oldest_cpu_key} from memory")
             self.cpu_cache.pop(oldest_cpu_key)
+            log(self.stats_msg())
 
         gc.collect()
@@ -227,14 +238,17 @@ class GPUModelCache:
         try:
             model_size = max(self.cpu_cache._item_memory_usage[key], model_size)
             self.cpu_cache.pop(key)
+            log(f"dropping {key} from cpu cache")
         except KeyError:
             pass
-        logger.debug(f"moving {key} to gpu")
+        log(f"moving {key} to gpu")
         move_model_device(model, self.device)
 
         self.gpu_cache.set(key, value=model, memory_usage=model_size)
 
     def _move_to_cpu(self, key):
         import gc
 
         import psutil
         import torch
@@ -249,12 +263,16 @@ class GPUModelCache:
             and self.cpu_cache.memory_usage + model_size
             < psutil.virtual_memory().available * 0.8
         ):
-            logger.debug(f"moving {key} to cpu")
+            log(f"moving {key} to cpu")
             move_model_device(model, torch.device("cpu"))
+            log(self.stats_msg())
 
             self.cpu_cache.set(key, model, memory_usage=model_size)
         else:
-            logger.debug(f"dropping {key} from memory")
+            log(f"dropping {key} from memory")
             del model
             gc.collect()
+            log(self.stats_msg())
 
     def get(self, key):
         import torch
@@ -321,7 +339,7 @@ class MemoryManagedModelWrapper:
     def _mmmw_load_model(self):
         if self._mmmw_cache_key not in self.__class__._mmmw_cache:
-            logger.debug(f"Loading model: {self._mmmw_cache_key}")
+            log(f"Loading model: {self._mmmw_cache_key}")
             self.__class__._mmmw_cache.make_gpu_space(
                 self._mmmw_estimated_ram_size_mb * 1024**2
             )
@@ -329,9 +347,7 @@ class MemoryManagedModelWrapper:
             model = self._mmmw_fn(*self._mmmw_args, **self._mmmw_kwargs)
             move_model_device(model, self.__class__._mmmw_cache.device)
             free_after = get_mem_free_total(self.__class__._mmmw_cache.device)
-            logger.debug(
-                f"Model loaded: {self._mmmw_cache_key} Used {free_before - free_after}"
-            )
log(f"Model loaded: {self._mmmw_cache_key} Used {free_after - free_before}")
|
|
|
|
|
self.__class__._mmmw_cache.set(
|
|
|
|
|
self._mmmw_cache_key,
|
|
|
|
|
model,
|
|
|
|
|