# pylama: ignore=W0212
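"""
Memory-aware caching of models across GPU and CPU.

Loaded models are kept on the GPU for reuse, spilled to CPU RAM (least
recently used first) when GPU memory runs low, and dropped entirely when
CPU RAM is also scarce.
"""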

import logging
from collections import OrderedDict
from functools import cached_property

from imaginairy.utils import get_device

logger = logging.getLogger(__name__)


def get_model_size(model):
    """Return the total size of a model's parameters, in bytes."""
    from torch import nn

    if not isinstance(model, nn.Module) and hasattr(model, "model"):
        model = model.model

    return sum(v.numel() * v.element_size() for v in model.parameters())


def move_model_device(model, device):
    """Move a model (or an object wrapping one in a `.model` attribute) to `device`."""
    from torch import nn

    if not isinstance(model, nn.Module) and hasattr(model, "model"):
        model = model.model

    return model.to(device)


class MemoryTrackingCache:
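    """An insertion-ordered cache that tracks the approximate memory usage of its items in bytes."""
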
    def __init__(self, *args, **kwargs):
        self.memory_usage = 0
        self._item_memory_usage = {}
        self._cache = OrderedDict()
        super().__init__(*args, **kwargs)

    def first_key(self):
        if self._cache:
            return next(iter(self._cache))
        raise KeyError("Empty dictionary")

    def last_key(self):
        if self._cache:
            return next(reversed(self._cache))
        raise KeyError("Empty dictionary")

    def set(self, key, value, memory_usage=0):
        if key in self._cache:
            # Subtract old item memory usage if key already exists
            self.memory_usage -= self._item_memory_usage[key]

        self._cache[key] = value

        # Calculate and store new item memory usage
        item_memory_usage = max(get_model_size(value), memory_usage)
        self._item_memory_usage[key] = item_memory_usage
        self.memory_usage += item_memory_usage

    def pop(self, key):
        # Subtract item memory usage before deletion
        self.memory_usage -= self._item_memory_usage[key]
        del self._item_memory_usage[key]
        return self._cache.pop(key)

    def move_to_end(self, key, last=True):
        self._cache.move_to_end(key, last=last)

    def __contains__(self, item):
        return item in self._cache

    def __delitem__(self, key):
        self.pop(key)

    def __getitem__(self, item):
        return self._cache[item]

    def get(self, item):
        return self._cache.get(item)

    def __len__(self):
        return len(self._cache)

    def __bool__(self):
        return bool(self._cache)

    def keys(self):
        return self._cache.keys()


def get_mem_free_total(device):
    """Estimate how many bytes of memory are free on `device`, with a safety buffer applied."""
    import psutil
    import torch

    if device.type == "cuda":
        if not torch.cuda.is_initialized():
            torch.cuda.init()
        stats = torch.cuda.memory_stats(device)
        mem_active = stats["active_bytes.all.current"]
        mem_reserved = stats["reserved_bytes.all.current"]
        mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
        mem_free_torch = mem_reserved - mem_active
        mem_free_total = mem_free_cuda + mem_free_torch
        mem_free_total *= 0.9
    else:
        # if we don't add a buffer, larger images come out as noise
        mem_free_total = psutil.virtual_memory().available * 0.6

    return mem_free_total


class GPUModelCache:
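    """
    A two-tier LRU model cache: models stay on the GPU until space runs out,
    then spill to CPU RAM, and are dropped entirely when RAM is also full.
    """
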
    def __init__(self, max_cpu_memory_gb="80%", max_gpu_memory_gb="95%", device=None):
        self._device = device
        if device in ("cpu", "mps"):
            # the "gpu" cache will be the only thing we use since there aren't
            # two different memory stores in this case
            max_cpu_memory_gb = 0

        self._max_cpu_memory_gb = max_cpu_memory_gb
        self._max_gpu_memory_gb = max_gpu_memory_gb
        self.gpu_cache = MemoryTrackingCache()
        self.cpu_cache = MemoryTrackingCache()

    @cached_property
    def device(self):
        import torch

        if self._device is None:
            self._device = get_device()

        if self._device in ("cpu", "mps"):
            # the "gpu" cache will be the only thing we use since there aren't
            # two different memory stores in this case
            self._max_cpu_memory_gb = 0

        return torch.device(self._device)

    def make_gpu_space(self, bytes_to_free):
        """Evict least recently used models from the GPU until `bytes_to_free` will fit."""
        import gc

        import torch.cuda

        mem_free = get_mem_free_total(self.device)
        logger.debug(
            f"Making {bytes_to_free / (1024 ** 2):.1f} MB of GPU space. "
            f"current usage: {self.gpu_cache.memory_usage / (1024 ** 2):.1f} MB; "
            f"free mem: {mem_free / (1024 ** 2):.1f} MB; "
            f"max mem: {self.max_gpu_memory / (1024 ** 2):.1f} MB"
        )
        while self.gpu_cache and (
            self.gpu_cache.memory_usage + bytes_to_free > self.max_gpu_memory
            or self.gpu_cache.memory_usage + bytes_to_free
            > get_mem_free_total(self.device)
        ):
            oldest_gpu_key = self.gpu_cache.first_key()
            self._move_to_cpu(oldest_gpu_key)

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

        if (
            self.gpu_cache.memory_usage + bytes_to_free > self.max_gpu_memory
            or self.gpu_cache.memory_usage + bytes_to_free
            > get_mem_free_total(self.device)
        ):
            raise RuntimeError("Unable to make space on GPU")

    def make_cpu_space(self, bytes_to_free):
        """Evict least recently used models from the CPU cache until `bytes_to_free` will fit."""
        import gc

        import psutil

        mem_free = psutil.virtual_memory().available * 0.8
        logger.debug(
            f"Making {bytes_to_free / (1024 ** 2):.1f} MB of RAM space. "
            f"current usage: {self.cpu_cache.memory_usage / (1024 ** 2):.2f} MB; "
            f"free mem: {mem_free / (1024 ** 2):.1f} MB; "
            f"max mem: {self.max_cpu_memory / (1024 ** 2):.1f} MB"
        )
        while self.cpu_cache and (
            self.cpu_cache.memory_usage + bytes_to_free > self.max_cpu_memory
            or self.cpu_cache.memory_usage + bytes_to_free
            > psutil.virtual_memory().available * 0.8
        ):
            oldest_cpu_key = self.cpu_cache.first_key()
            logger.debug(f"dropping {oldest_cpu_key} from memory")
            self.cpu_cache.pop(oldest_cpu_key)

        gc.collect()

    @cached_property
    def max_cpu_memory(self):
        # accessing `device` first so cpu/mps-only setups zero out the cpu cache limit
        _ = self.device
        if isinstance(self._max_cpu_memory_gb, str):
            if self._max_cpu_memory_gb.endswith("%"):
                import psutil

                total_ram_gb = round(psutil.virtual_memory().total / (1024**3), 2)
                pct_to_use = float(self._max_cpu_memory_gb[:-1]) / 100.0
                return total_ram_gb * pct_to_use * (1024**3)
            raise ValueError(
                f"Invalid value for max_cpu_memory_gb: {self._max_cpu_memory_gb}"
            )
        return self._max_cpu_memory_gb * (1024**3)

    @cached_property
    def max_gpu_memory(self):
        _ = self.device
        if isinstance(self._max_gpu_memory_gb, str):
            if self._max_gpu_memory_gb.endswith("%"):
                import torch

                if self.device.type == "cuda":
                    device_props = torch.cuda.get_device_properties(0)
                    total_ram_gb = device_props.total_memory / (1024**3)
                else:
                    import psutil

                    total_ram_gb = round(psutil.virtual_memory().total / (1024**3), 2)
                pct_to_use = float(self._max_gpu_memory_gb[:-1]) / 100.0
                return total_ram_gb * pct_to_use * (1024**3)
            raise ValueError(
                f"Invalid value for max_gpu_memory_gb: {self._max_gpu_memory_gb}"
            )
        return self._max_gpu_memory_gb * (1024**3)

    def _move_to_gpu(self, key, model):
        model_size = get_model_size(model)
        if self.gpu_cache.memory_usage + model_size > self.max_gpu_memory:
            if len(self.gpu_cache) == 0:
                msg = (
                    f"GPU cache maximum ({self.max_gpu_memory / (1024 ** 2)} MB) is smaller "
                    f"than the item being cached ({model_size / 1024 ** 2} MB)."
                )
                raise RuntimeError(msg)
            self.make_gpu_space(model_size)

        try:
            model_size = max(self.cpu_cache._item_memory_usage[key], model_size)
            self.cpu_cache.pop(key)
        except KeyError:
            pass
        logger.debug(f"moving {key} to gpu")
        move_model_device(model, self.device)

        self.gpu_cache.set(key, value=model, memory_usage=model_size)

    def _move_to_cpu(self, key):
        import psutil
        import torch

        memory_usage = self.gpu_cache._item_memory_usage[key]
        model = self.gpu_cache.pop(key)

        model_size = max(get_model_size(model), memory_usage)
        self.make_cpu_space(model_size)

        if (
            self.cpu_cache.memory_usage + model_size < self.max_cpu_memory
            and self.cpu_cache.memory_usage + model_size
            < psutil.virtual_memory().available * 0.8
        ):
            logger.debug(f"moving {key} to cpu")
            move_model_device(model, torch.device("cpu"))

            self.cpu_cache.set(key, model, memory_usage=model_size)
        else:
            logger.debug(f"dropping {key} from memory")

    def get(self, key):
        import torch

        if key not in self:
            raise KeyError(f"The key {key} does not exist in the cache")

        if key in self.cpu_cache:
            if self.device != torch.device("cpu"):
                self.cpu_cache.move_to_end(key)
                self._move_to_gpu(key, self.cpu_cache[key])

        if key in self.gpu_cache:
            self.gpu_cache.move_to_end(key)

        model = self.gpu_cache.get(key)

        return model

    def __getitem__(self, key):
        return self.get(key)

    def set(self, key, model, memory_usage=0):
        from torch import nn

        if not (
            isinstance(model, nn.Module)
            or (hasattr(model, "model") and isinstance(model.model, nn.Module))
        ):
            raise ValueError("Only nn.Module objects can be cached")

        model_size = max(get_model_size(model), memory_usage)
        self.make_gpu_space(model_size)
        self._move_to_gpu(key, model)

    def __contains__(self, key):
        return key in self.gpu_cache or key in self.cpu_cache

    def keys(self):
        return list(self.cpu_cache.keys()) + list(self.gpu_cache.keys())

    def stats(self):
        return {
            "cpu_cache_count": len(self.cpu_cache),
            "cpu_cache_memory_usage": self.cpu_cache.memory_usage,
            "cpu_cache_max_memory": self.max_cpu_memory,
            "gpu_cache_count": len(self.gpu_cache),
            "gpu_cache_memory_usage": self.gpu_cache.memory_usage,
            "gpu_cache_max_memory": self.max_gpu_memory,
        }


class MemoryManagedModelWrapper:
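    """
    Lazily loads a model via `fn` on first use and keeps it in the shared
    class-level GPUModelCache. Attribute access and calls are proxied to the
    cached model.
    """
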
    _mmmw_cache = GPUModelCache()

    def __init__(self, fn, namespace, estimated_ram_size_mb, *args, **kwargs):
        self._mmmw_fn = fn
        self._mmmw_args = args
        self._mmmw_kwargs = kwargs
        self._mmmw_namespace = namespace
        self._mmmw_estimated_ram_size_mb = estimated_ram_size_mb
        self._mmmw_cache_key = (namespace,) + args + tuple(kwargs.items())

    def _mmmw_load_model(self):
        if self._mmmw_cache_key not in self.__class__._mmmw_cache:
            logger.debug(f"Loading model: {self._mmmw_cache_key}")
            self.__class__._mmmw_cache.make_gpu_space(
                self._mmmw_estimated_ram_size_mb * 1024**2
            )
            free_before = get_mem_free_total(self.__class__._mmmw_cache.device)
            model = self._mmmw_fn(*self._mmmw_args, **self._mmmw_kwargs)
            move_model_device(model, self.__class__._mmmw_cache.device)
            free_after = get_mem_free_total(self.__class__._mmmw_cache.device)
            logger.debug(
                f"Model loaded: {self._mmmw_cache_key} Used {free_before - free_after}"
            )
            self.__class__._mmmw_cache.set(
                self._mmmw_cache_key,
                model,
                memory_usage=self._mmmw_estimated_ram_size_mb * 1024**2,
            )

        model = self.__class__._mmmw_cache[self._mmmw_cache_key]
        return model

    def __getattr__(self, name):
        model = self._mmmw_load_model()
        return getattr(model, name)

    def __call__(self, *args, **kwargs):
        model = self._mmmw_load_model()
        return model(*args, **kwargs)


def memory_managed_model(namespace, memory_usage_mb=0):
    """Decorate a model-loading function so its models load lazily and live in a shared, memory-managed cache."""

    def decorator(fn):
        def wrapper(*args, **kwargs):
            return MemoryManagedModelWrapper(
                fn, namespace, memory_usage_mb, *args, **kwargs
            )

        return wrapper

    return decorator
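

# Example usage (illustrative sketch; `load_upscaler` and the size estimate below
# are hypothetical and not part of this module):
#
#     @memory_managed_model("upscaler", memory_usage_mb=2000)
#     def load_upscaler(weights_path):
#         ...  # build and return an nn.Module
#
#     model = load_upscaler("weights.ckpt")  # returns a lazy wrapper immediately
#     output = model(some_tensor)  # loads the model and moves it to the GPU on first use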