imaginAIry/imaginairy/utils/model_cache.py

"""Utilities for managing model caching"""
# pylama: ignore=W0212
import logging
from collections import OrderedDict
from functools import cached_property
from imaginairy.utils import get_device
logger = logging.getLogger(__name__)
log = logger.debug
def get_model_size(model):
from torch import nn
if not isinstance(model, nn.Module) and hasattr(model, "model"):
model = model.model
return sum(v.numel() * v.element_size() for v in model.parameters())
def move_model_device(model, device):
from torch import nn
if not isinstance(model, nn.Module) and hasattr(model, "model"):
model = model.model
return model.to(device)


class MemoryTrackingCache:
    """Ordered cache that tracks the approximate memory footprint of each stored item."""

    def __init__(self, *args, **kwargs):
        self.memory_usage = 0
        self._item_memory_usage = {}
        self._cache = OrderedDict()
        super().__init__(*args, **kwargs)

    def first_key(self):
        if self._cache:
            return next(iter(self._cache))
        raise KeyError("Empty dictionary")

    def last_key(self):
        if self._cache:
            return next(reversed(self._cache))
        raise KeyError("Empty dictionary")

    def set(self, key, value, memory_usage=0):
        if key in self._cache:
            # Subtract old item memory usage if key already exists
            self.memory_usage -= self._item_memory_usage[key]

        self._cache[key] = value

        # Calculate and store new item memory usage; use the larger of the
        # measured model size and the caller-provided estimate
        item_memory_usage = max(get_model_size(value), memory_usage)
        self._item_memory_usage[key] = item_memory_usage
        self.memory_usage += item_memory_usage

    def pop(self, key):
        # Subtract item memory usage before deletion
        self.memory_usage -= self._item_memory_usage[key]
        del self._item_memory_usage[key]
        return self._cache.pop(key)

    def move_to_end(self, key, last=True):
        self._cache.move_to_end(key, last=last)

    def __contains__(self, item):
        return item in self._cache

    def __delitem__(self, key):
        self.pop(key)

    def __getitem__(self, item):
        return self._cache[item]

    def get(self, item):
        return self._cache.get(item)

    def __len__(self):
        return len(self._cache)

    def __bool__(self):
        return bool(self._cache)

    def keys(self):
        return self._cache.keys()
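

# Illustrative sketch (example only, not part of the original module): how
# MemoryTrackingCache keeps a running total of per-item memory. The key names
# and tiny nn.Linear models below are made-up stand-ins for real cached models.
def _memory_tracking_cache_example():
    from torch import nn

    cache = MemoryTrackingCache()
    cache.set("small", nn.Linear(4, 4))  # size measured via get_model_size
    cache.set("hinted", nn.Linear(4, 4), memory_usage=1024**2)  # larger hint wins
    assert cache.first_key() == "small"  # insertion order is preserved
    assert cache.last_key() == "hinted"
    cache.pop("small")  # running total shrinks by the popped item's size
    return cache.memory_usage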


def get_mem_free_total(device):
    import psutil
    import torch

    if device.type == "cuda":
        if not torch.cuda.is_initialized():
            torch.cuda.init()
        stats = torch.cuda.memory_stats(device)
        mem_active = stats["active_bytes.all.current"]
        mem_reserved = stats["reserved_bytes.all.current"]
        mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
        mem_free_torch = mem_reserved - mem_active
        mem_free_total = mem_free_cuda + mem_free_torch
        mem_free_total *= 0.9
    else:
        # if we don't add a buffer, larger images come out as noise
        mem_free_total = psutil.virtual_memory().available * 0.6
    return mem_free_total


class GPUModelCache:
    """Two-tier LRU cache for models.

    Keeps models on the GPU while space allows, spills the least recently used
    ones to a CPU cache when GPU memory runs low, and drops them entirely when
    CPU memory runs low as well.
    """

    def __init__(self, max_cpu_memory_gb="80%", max_gpu_memory_gb="95%", device=None):
        self._device = device
        if device in ("cpu", "mps"):
            # the "gpu" cache will be the only thing we use since there aren't
            # two different memory stores in this case
            max_cpu_memory_gb = 0
        self._max_cpu_memory_gb = max_cpu_memory_gb
        self._max_gpu_memory_gb = max_gpu_memory_gb
        self.gpu_cache = MemoryTrackingCache()
        self.cpu_cache = MemoryTrackingCache()

    def stats_msg(self):
        import psutil

        msg = (
            f" GPU cache: {len(self.gpu_cache)} items; {self.gpu_cache.memory_usage / (1024 ** 2):.1f} MB; Max: {self.max_gpu_memory / (1024 ** 2):.1f} MB;\n"
            f" CPU cache: {len(self.cpu_cache)} items; {self.cpu_cache.memory_usage / (1024 ** 2):.1f} MB; Max: {self.max_cpu_memory / (1024 ** 2):.1f} MB;\n"
            f" mem_free_total: {get_mem_free_total(self.device) / (1024 ** 2):.1f} MB; Ram Free: {psutil.virtual_memory().available / (1024 ** 2):.1f} MB;"
        )
        return msg

    @cached_property
    def device(self):
        import torch

        if self._device is None:
            self._device = get_device()

        if self._device in ("cpu", "mps", "mps:0"):
            # the "gpu" cache will be the only thing we use since there aren't
            # two different memory stores in this case
            self._max_cpu_memory_gb = 0

        return torch.device(self._device)

    def make_gpu_space(self, bytes_to_free):
        import gc

        import torch.cuda

        log(self.stats_msg())
        log(f"Ensuring {bytes_to_free / (1024 ** 2):.1f} MB of GPU space.")
        while self.gpu_cache and (
            self.gpu_cache.memory_usage + bytes_to_free > self.max_gpu_memory
            or self.gpu_cache.memory_usage + bytes_to_free
            > get_mem_free_total(self.device)
        ):
            oldest_gpu_key = self.gpu_cache.first_key()
            self._move_to_cpu(oldest_gpu_key)

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

        if (
            self.gpu_cache.memory_usage + bytes_to_free > self.max_gpu_memory
            or self.gpu_cache.memory_usage + bytes_to_free
            > get_mem_free_total(self.device)
        ):
            msg = f"Unable to make {bytes_to_free / (1024 ** 2):.1f} MB space on {self.device}. \n{self.stats_msg()}"
            raise RuntimeError(msg)

    def make_cpu_space(self, bytes_to_free):
        import gc

        import psutil

        log(self.stats_msg())
        log(f"Ensuring {bytes_to_free / (1024 ** 2):.1f} MB of RAM space.")
        while self.cpu_cache and (
            self.cpu_cache.memory_usage + bytes_to_free > self.max_cpu_memory
            or self.cpu_cache.memory_usage + bytes_to_free
            > psutil.virtual_memory().available * 0.8
        ):
            oldest_cpu_key = self.cpu_cache.first_key()
            log(f"dropping {oldest_cpu_key} from memory")
            self.cpu_cache.pop(oldest_cpu_key)
            log(self.stats_msg())
        gc.collect()

    @cached_property
    def max_cpu_memory(self):
        _ = self.device
        if isinstance(self._max_cpu_memory_gb, str):
            if self._max_cpu_memory_gb.endswith("%"):
                import psutil

                total_ram_gb = round(psutil.virtual_memory().total / (1024**3), 2)
                pct_to_use = float(self._max_cpu_memory_gb[:-1]) / 100.0
                return total_ram_gb * pct_to_use * (1024**3)
            msg = f"Invalid value for max_cpu_memory_gb: {self._max_cpu_memory_gb}"
            raise ValueError(msg)
        return self._max_cpu_memory_gb * (1024**3)

    @cached_property
    def max_gpu_memory(self):
        _ = self.device
        if isinstance(self._max_gpu_memory_gb, str):
            if self._max_gpu_memory_gb.endswith("%"):
                import torch

                if self.device.type == "cuda":
                    device_props = torch.cuda.get_device_properties(0)
                    total_ram_gb = device_props.total_memory / (1024**3)
                else:
                    import psutil

                    total_ram_gb = round(psutil.virtual_memory().total / (1024**3), 2)
                pct_to_use = float(self._max_gpu_memory_gb[:-1]) / 100.0
                return total_ram_gb * pct_to_use * (1024**3)
            msg = f"Invalid value for max_gpu_memory_gb: {self._max_gpu_memory_gb}"
            raise ValueError(msg)
        return self._max_gpu_memory_gb * (1024**3)

    def _move_to_gpu(self, key, model):
        model_size = get_model_size(model)
        if self.gpu_cache.memory_usage + model_size > self.max_gpu_memory:
            if len(self.gpu_cache) == 0:
                msg = f"GPU cache maximum ({self.max_gpu_memory / (1024 ** 2)} MB) is smaller than the item being cached ({model_size / 1024 ** 2} MB)."
                raise RuntimeError(msg)
            self.make_gpu_space(model_size)

        try:
            model_size = max(self.cpu_cache._item_memory_usage[key], model_size)
            self.cpu_cache.pop(key)
            log(f"dropping {key} from cpu cache")
        except KeyError:
            pass

        log(f"moving {key} to gpu")
        move_model_device(model, self.device)
        self.gpu_cache.set(key, value=model, memory_usage=model_size)

    def _move_to_cpu(self, key):
        import gc

        import psutil
        import torch

        memory_usage = self.gpu_cache._item_memory_usage[key]
        model = self.gpu_cache.pop(key)
        model_size = max(get_model_size(model), memory_usage)
        self.make_cpu_space(model_size)
        if (
            self.cpu_cache.memory_usage + model_size < self.max_cpu_memory
            and self.cpu_cache.memory_usage + model_size
            < psutil.virtual_memory().available * 0.8
        ):
            log(f"moving {key} to cpu")
            move_model_device(model, torch.device("cpu"))
            log(self.stats_msg())
            self.cpu_cache.set(key, model, memory_usage=model_size)
        else:
            log(f"dropping {key} from memory")
            del model
            gc.collect()
        log(self.stats_msg())

    def get(self, key):
        import torch

        if key not in self:
            msg = f"The key {key} does not exist in the cache"
            raise KeyError(msg)

        if key in self.cpu_cache and self.device != torch.device("cpu"):
            self.cpu_cache.move_to_end(key)
            self._move_to_gpu(key, self.cpu_cache[key])

        if key in self.gpu_cache:
            self.gpu_cache.move_to_end(key)

        model = self.gpu_cache.get(key)
        return model

    def __getitem__(self, key):
        return self.get(key)

    def set(self, key, model, memory_usage=0):
        from torch import nn

        is_cacheable = isinstance(model, nn.Module) or (
            hasattr(model, "model") and isinstance(model.model, nn.Module)
        )
        if not is_cacheable:
            raise ValueError("Only nn.Module objects can be cached")

        model_size = max(get_model_size(model), memory_usage)
        self.make_gpu_space(model_size)
        self._move_to_gpu(key, model)

    def __contains__(self, key):
        return key in self.gpu_cache or key in self.cpu_cache

    def keys(self):
        return list(self.cpu_cache.keys()) + list(self.gpu_cache.keys())

    def stats(self):
        return {
            "cpu_cache_count": len(self.cpu_cache),
            "cpu_cache_memory_usage": self.cpu_cache.memory_usage,
            "cpu_cache_max_memory": self.max_cpu_memory,
            "gpu_cache_count": len(self.gpu_cache),
            "gpu_cache_memory_usage": self.gpu_cache.memory_usage,
            "gpu_cache_max_memory": self.max_gpu_memory,
        }
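

# Illustrative sketch (example only, not part of the original module): driving
# GPUModelCache directly. It is normally used through MemoryManagedModelWrapper
# below; the key name and tiny model here are made up for illustration.
def _gpu_model_cache_example():
    from torch import nn

    cache = GPUModelCache()
    cache.set("example-encoder", nn.Linear(8, 8))  # moved onto the cache's device
    model = cache["example-encoder"]  # lookup marks the entry most recently used
    log(cache.stats_msg())
    return model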


class MemoryManagedModelWrapper:
    """Lazy-loading proxy around a model factory function.

    The wrapped function is not called until the model is first used (via
    attribute access or call); the loaded model is kept in a GPUModelCache
    shared by all wrappers.
    """

    _mmmw_cache = GPUModelCache()

    def __init__(self, fn, namespace, estimated_ram_size_mb, *args, **kwargs):
        self._mmmw_fn = fn
        self._mmmw_args = args
        self._mmmw_kwargs = kwargs
        self._mmmw_namespace = namespace
        self._mmmw_estimated_ram_size_mb = estimated_ram_size_mb
        self._mmmw_cache_key = (namespace, *args, *tuple(kwargs.items()))

    def _mmmw_load_model(self):
        if self._mmmw_cache_key not in self.__class__._mmmw_cache:
            log(f"Loading model: {self._mmmw_cache_key}")
            self.__class__._mmmw_cache.make_gpu_space(
                self._mmmw_estimated_ram_size_mb * 1024**2
            )
            free_before = get_mem_free_total(self.__class__._mmmw_cache.device)
            model = self._mmmw_fn(*self._mmmw_args, **self._mmmw_kwargs)
            move_model_device(model, self.__class__._mmmw_cache.device)
            free_after = get_mem_free_total(self.__class__._mmmw_cache.device)
            log(f"Model loaded: {self._mmmw_cache_key} Used {free_after - free_before}")
            self.__class__._mmmw_cache.set(
                self._mmmw_cache_key,
                model,
                memory_usage=self._mmmw_estimated_ram_size_mb * 1024**2,
            )

        model = self.__class__._mmmw_cache[self._mmmw_cache_key]
        return model

    def __getattr__(self, name):
        model = self._mmmw_load_model()
        return getattr(model, name)

    def __call__(self, *args, **kwargs):
        model = self._mmmw_load_model()
        return model(*args, **kwargs)


def memory_managed_model(namespace, memory_usage_mb=0):
    """Decorator that wraps a model-loading function in a MemoryManagedModelWrapper."""

    def decorator(fn):
        def wrapper(*args, **kwargs):
            return MemoryManagedModelWrapper(
                fn, namespace, memory_usage_mb, *args, **kwargs
            )

        return wrapper

    return decorator
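

# --- Illustrative usage sketch (example only, not part of the original module) ---
# Demonstrates how a loader function decorated with `memory_managed_model` defers
# loading until the wrapped model is first used. The namespace and the tiny
# nn.Linear "model" below are made up for illustration; real call sites wrap
# actual checkpoint-loading functions.
if __name__ == "__main__":
    import torch
    from torch import nn

    @memory_managed_model("example-tiny-model", memory_usage_mb=1)
    def load_tiny_model():
        # Stand-in loader; a real loader would build the model and read weights.
        return nn.Linear(4, 4)

    model = load_tiny_model()  # returns a MemoryManagedModelWrapper; nothing loaded yet
    x = torch.zeros(1, 4, device=model.weight.device)  # attribute access triggers the load
    out = model(x)  # forwarded to the cached nn.Module on the cache's device
    print(MemoryManagedModelWrapper._mmmw_cache.stats())
    print(tuple(out.shape))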