feature: improvements to memory management
not thoroughly tested on low-memory devices
pull/333/head
parent 6db296aa37
commit 4c77fd376b
@@ -0,0 +1,91 @@
"""Most of these modifications are just so we get full stack traces in the shell."""
import logging
import shlex
import traceback
from functools import update_wrapper

import click
from click_help_colors import HelpColorsCommand, HelpColorsMixin
from click_shell import Shell
from click_shell._compat import get_method_type
from click_shell.core import ClickShell, get_complete, get_help

logger = logging.getLogger(__name__)


def mod_get_invoke(command):
    """
    Get the Cmd main method from the click command.

    :param command: The click Command object
    :return: the do_* method for Cmd
    :rtype: function
    """
    assert isinstance(command, click.Command)

    def invoke_(self, arg):  # pylint: disable=unused-argument
        try:
            command.main(
                args=shlex.split(arg),
                prog_name=command.name,
                standalone_mode=False,
                parent=self.ctx,
            )
        except click.ClickException as e:
            # Show the error message
            e.show()
        except click.Abort:
            # We got an EOF or keyboard interrupt. Just silence it.
            pass
        except SystemExit:
            # Catch this and return the code instead. All of click's help commands
            # do a sys.exit(), and that's not ideal when running in a shell.
            pass
        except Exception as e:  # noqa
            traceback.print_exception(e)
            # logger.warning(traceback.format_exc())

        # Always return False so the shell doesn't exit
        return False

    invoke_ = update_wrapper(invoke_, command.callback)
    invoke_.__name__ = "do_%s" % command.name  # noqa
    return invoke_


class ModClickShell(ClickShell):
    def add_command(self, cmd, name):
        # Use the MethodType to add these as bound methods to our current instance
        setattr(
            self, "do_%s" % name, get_method_type(mod_get_invoke(cmd), self)  # noqa
        )
        setattr(self, "help_%s" % name, get_method_type(get_help(cmd), self))  # noqa
        setattr(
            self, "complete_%s" % name, get_method_type(get_complete(cmd), self)  # noqa
        )


class ModShell(Shell):
    def __init__(
        self, prompt=None, intro=None, hist_file=None, on_finished=None, **attrs
    ):
        attrs["invoke_without_command"] = True
        super(Shell, self).__init__(**attrs)

        # Make our shell
        self.shell = ModClickShell(hist_file=hist_file, on_finished=on_finished)
        if prompt:
            self.shell.prompt = prompt
        self.shell.intro = intro


class ColorShell(HelpColorsMixin, ModShell):
    pass


class ImagineColorsCommand(HelpColorsCommand):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.help_headers_color = "yellow"
        self.help_options_color = "green"
@@ -0,0 +1,363 @@
# pylama: ignore=W0212
import logging
from collections import OrderedDict
from functools import cached_property

from imaginairy.utils import get_device

logger = logging.getLogger(__name__)


def get_model_size(model):
    from torch import nn

    if not isinstance(model, nn.Module) and hasattr(model, "model"):
        model = model.model

    return sum(v.numel() * v.element_size() for v in model.parameters())


def move_model_device(model, device):
    from torch import nn

    if not isinstance(model, nn.Module) and hasattr(model, "model"):
        model = model.model

    return model.to(device)


class MemoryTrackingCache:
    def __init__(self, *args, **kwargs):
        self.memory_usage = 0
        self._item_memory_usage = {}
        self._cache = OrderedDict()
        super().__init__(*args, **kwargs)

    def first_key(self):
        if self._cache:
            return next(iter(self._cache))
        raise KeyError("Empty dictionary")

    def last_key(self):
        if self._cache:
            return next(reversed(self._cache))
        raise KeyError("Empty dictionary")

    def set(self, key, value, memory_usage=0):
        if key in self._cache:
            # Subtract old item memory usage if key already exists
            self.memory_usage -= self._item_memory_usage[key]

        self._cache[key] = value

        # Calculate and store new item memory usage
        item_memory_usage = max(get_model_size(value), memory_usage)
        self._item_memory_usage[key] = item_memory_usage
        self.memory_usage += item_memory_usage

    def pop(self, key):
        # Subtract item memory usage before deletion
        self.memory_usage -= self._item_memory_usage[key]
        del self._item_memory_usage[key]
        return self._cache.pop(key)

    def move_to_end(self, key, last=True):
        self._cache.move_to_end(key, last=last)

    def __contains__(self, item):
        return item in self._cache

    def __delitem__(self, key):
        self.pop(key)

    def __getitem__(self, item):
        return self._cache[item]

    def get(self, item):
        return self._cache.get(item)

    def __len__(self):
        return len(self._cache)

    def __bool__(self):
        return bool(self._cache)

    def keys(self):
        return self._cache.keys()
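

# A small sketch of the accounting above ("clip"/"unet" and the sizes are
# hypothetical; any nn.Module-like object works):
#
#   cache = MemoryTrackingCache()
#   cache.set("clip", clip_model)                      # sized via get_model_size()
#   cache.set("unet", unet_model, memory_usage=2**30)  # or with an explicit floor
#   cache.memory_usage                                 # running total in bytes
#   cache.pop(cache.first_key())                       # evict oldest; total shrinks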


def get_mem_free_total(device):
    import psutil
    import torch

    if device.type == "cuda":
        if not torch.cuda.is_initialized():
            torch.cuda.init()
        stats = torch.cuda.memory_stats(device)
        mem_active = stats["active_bytes.all.current"]
        mem_reserved = stats["reserved_bytes.all.current"]
        mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
        mem_free_torch = mem_reserved - mem_active
        mem_free_total = mem_free_cuda + mem_free_torch
        mem_free_total *= 0.9
    else:
        # if we don't add a buffer, larger images come out as noise
        mem_free_total = psutil.virtual_memory().available * 0.6

    return mem_free_total
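

# Rough illustration of the buffers above (numbers are hypothetical): a CUDA
# card reporting 20 GB free plus 2 GB of reserved-but-inactive torch memory
# yields 0.9 * 22 GB = 19.8 GB of usable space; on cpu/mps only 60% of
# available RAM is reported, keeping headroom for the reasons noted above.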


class GPUModelCache:
    def __init__(self, max_cpu_memory_gb="80%", max_gpu_memory_gb="95%", device=None):
        self._device = device
        if device in ("cpu", "mps"):
            # the "gpu" cache will be the only thing we use since there aren't
            # two different memory stores in this case
            max_cpu_memory_gb = 0

        self._max_cpu_memory_gb = max_cpu_memory_gb
        self._max_gpu_memory_gb = max_gpu_memory_gb
        self.gpu_cache = MemoryTrackingCache()
        self.cpu_cache = MemoryTrackingCache()

    @cached_property
    def device(self):
        import torch

        if self._device is None:
            self._device = get_device()

        if self._device in ("cpu", "mps"):
            # the "gpu" cache will be the only thing we use since there aren't
            # two different memory stores in this case
            self._max_cpu_memory_gb = 0

        return torch.device(self._device)

    def make_gpu_space(self, bytes_to_free):
        import gc

        import torch.cuda

        mem_free = get_mem_free_total(self.device)
        logger.debug(
            f"Making {bytes_to_free / (1024 ** 2):.1f} MB of GPU space. current usage: {self.gpu_cache.memory_usage / (1024 ** 2):.1f} MB; free mem: {mem_free / (1024 ** 2):.1f} MB; max mem: {self.max_gpu_memory / (1024 ** 2):.1f} MB"
        )
        while self.gpu_cache and (
            self.gpu_cache.memory_usage + bytes_to_free > self.max_gpu_memory
            or self.gpu_cache.memory_usage + bytes_to_free
            > get_mem_free_total(self.device)
        ):
            oldest_gpu_key = self.gpu_cache.first_key()
            self._move_to_cpu(oldest_gpu_key)

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

        if (
            self.gpu_cache.memory_usage + bytes_to_free > self.max_gpu_memory
            or self.gpu_cache.memory_usage + bytes_to_free
            > get_mem_free_total(self.device)
        ):
            raise RuntimeError("Unable to make space on GPU")

    def make_cpu_space(self, bytes_to_free):
        import gc

        import psutil

        mem_free = psutil.virtual_memory().available * 0.8
        logger.debug(
            f"Making {bytes_to_free / (1024 ** 2):.1f} MB of RAM space. current usage: {self.cpu_cache.memory_usage / (1024 ** 2):.2f} MB; free mem: {mem_free / (1024 ** 2):.1f} MB; max mem: {self.max_cpu_memory / (1024 ** 2):.1f} MB"
        )
        while self.cpu_cache and (
            self.cpu_cache.memory_usage + bytes_to_free > self.max_cpu_memory
            or self.cpu_cache.memory_usage + bytes_to_free
            > psutil.virtual_memory().available * 0.8
        ):
            oldest_cpu_key = self.cpu_cache.first_key()
            logger.debug(f"dropping {oldest_cpu_key} from memory")
            self.cpu_cache.pop(oldest_cpu_key)

        gc.collect()

    @cached_property
    def max_cpu_memory(self):
        _ = self.device
        if isinstance(self._max_cpu_memory_gb, str):
            if self._max_cpu_memory_gb.endswith("%"):
                import psutil

                total_ram_gb = round(psutil.virtual_memory().total / (1024**3), 2)
                pct_to_use = float(self._max_cpu_memory_gb[:-1]) / 100.0
                return total_ram_gb * pct_to_use * (1024**3)
            raise ValueError(
                f"Invalid value for max_cpu_memory_gb: {self._max_cpu_memory_gb}"
            )
        return self._max_cpu_memory_gb * (1024**3)

    @cached_property
    def max_gpu_memory(self):
        _ = self.device
        if isinstance(self._max_gpu_memory_gb, str):
            if self._max_gpu_memory_gb.endswith("%"):
                import torch

                if self.device.type == "cuda":
                    device_props = torch.cuda.get_device_properties(0)
                    total_ram_gb = device_props.total_memory / (1024**3)
                else:
                    import psutil

                    total_ram_gb = round(psutil.virtual_memory().total / (1024**3), 2)
                pct_to_use = float(self._max_gpu_memory_gb[:-1]) / 100.0
                return total_ram_gb * pct_to_use * (1024**3)
            raise ValueError(
                f"Invalid value for max_gpu_memory_gb: {self._max_gpu_memory_gb}"
            )
        return self._max_gpu_memory_gb * (1024**3)

    def _move_to_gpu(self, key, model):
        model_size = get_model_size(model)
        if self.gpu_cache.memory_usage + model_size > self.max_gpu_memory:
            if len(self.gpu_cache) == 0:
                msg = f"GPU cache maximum ({self.max_gpu_memory / (1024 ** 2)} MB) is smaller than the item being cached ({model_size / 1024 ** 2} MB)."
                raise RuntimeError(msg)
            self.make_gpu_space(model_size)

        try:
            model_size = max(self.cpu_cache._item_memory_usage[key], model_size)
            self.cpu_cache.pop(key)
        except KeyError:
            pass
        logger.debug(f"moving {key} to gpu")
        move_model_device(model, self.device)

        self.gpu_cache.set(key, value=model, memory_usage=model_size)

    def _move_to_cpu(self, key):
        import psutil
        import torch

        memory_usage = self.gpu_cache._item_memory_usage[key]
        model = self.gpu_cache.pop(key)

        model_size = max(get_model_size(model), memory_usage)
        self.make_cpu_space(model_size)

        if (
            self.cpu_cache.memory_usage + model_size < self.max_cpu_memory
            and self.cpu_cache.memory_usage + model_size
            < psutil.virtual_memory().available * 0.8
        ):
            logger.debug(f"moving {key} to cpu")
            move_model_device(model, torch.device("cpu"))

            self.cpu_cache.set(key, model, memory_usage=model_size)
        else:
            logger.debug(f"dropping {key} from memory")

    def get(self, key):
        import torch

        if key not in self:
            raise KeyError(f"The key {key} does not exist in the cache")

        if key in self.cpu_cache:
            if self.device != torch.device("cpu"):
                self.cpu_cache.move_to_end(key)
                self._move_to_gpu(key, self.cpu_cache[key])

        if key in self.gpu_cache:
            self.gpu_cache.move_to_end(key)

        return self.gpu_cache.get(key)

    def __getitem__(self, key):
        return self.get(key)

    def set(self, key, model, memory_usage=0):
        from torch import nn

        if not (
            isinstance(model, nn.Module)
            or (hasattr(model, "model") and isinstance(model.model, nn.Module))
        ):
            raise ValueError("Only nn.Module objects can be cached")

        model_size = max(get_model_size(model), memory_usage)
        self.make_gpu_space(model_size)
        self._move_to_gpu(key, model)

    def __contains__(self, key):
        return key in self.gpu_cache or key in self.cpu_cache

    def keys(self):
        return list(self.cpu_cache.keys()) + list(self.gpu_cache.keys())

    def stats(self):
        return {
            "cpu_cache_count": len(self.cpu_cache),
            "cpu_cache_memory_usage": self.cpu_cache.memory_usage,
            "cpu_cache_max_memory": self.max_cpu_memory,
            "gpu_cache_count": len(self.gpu_cache),
            "gpu_cache_memory_usage": self.gpu_cache.memory_usage,
            "gpu_cache_max_memory": self.max_gpu_memory,
        }
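

# A minimal usage sketch (the key name and percentages are illustrative):
#
#   cache = GPUModelCache(max_cpu_memory_gb="50%", max_gpu_memory_gb="90%")
#   cache.set("sd-1.5/unet", unet)  # moves model to self.device, evicting LRU
#   model = cache["sd-1.5/unet"]    # gpu hit, or promoted back from cpu_cache
#   cache.stats()                   # counts, usage, and limits for both tiers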


class MemoryManagedModelWrapper:
    _mmmw_cache = GPUModelCache()

    def __init__(self, fn, namespace, estimated_ram_size_mb, *args, **kwargs):
        self._mmmw_fn = fn
        self._mmmw_args = args
        self._mmmw_kwargs = kwargs
        self._mmmw_namespace = namespace
        self._mmmw_estimated_ram_size_mb = estimated_ram_size_mb
        self._mmmw_cache_key = (namespace,) + args + tuple(kwargs.items())

    def _mmmw_load_model(self):
        if self._mmmw_cache_key not in self.__class__._mmmw_cache:
            logger.debug(f"Loading model: {self._mmmw_cache_key}")
            self.__class__._mmmw_cache.make_gpu_space(
                self._mmmw_estimated_ram_size_mb * 1024**2
            )
            free_before = get_mem_free_total(self.__class__._mmmw_cache.device)
            model = self._mmmw_fn(*self._mmmw_args, **self._mmmw_kwargs)
            move_model_device(model, self.__class__._mmmw_cache.device)
            free_after = get_mem_free_total(self.__class__._mmmw_cache.device)
            logger.debug(
                f"Model loaded: {self._mmmw_cache_key} Used {free_before - free_after}"
            )
            self.__class__._mmmw_cache.set(
                self._mmmw_cache_key,
                model,
                memory_usage=self._mmmw_estimated_ram_size_mb * 1024**2,
            )

        return self.__class__._mmmw_cache[self._mmmw_cache_key]

    def __getattr__(self, name):
        model = self._mmmw_load_model()
        return getattr(model, name)

    def __call__(self, *args, **kwargs):
        model = self._mmmw_load_model()
        return model(*args, **kwargs)


def memory_managed_model(namespace, memory_usage_mb=0):
    def decorator(fn):
        def wrapper(*args, **kwargs):
            return MemoryManagedModelWrapper(
                fn, namespace, memory_usage_mb, *args, **kwargs
            )

        return wrapper

    return decorator
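

# A usage sketch (the loader name, namespace, and size estimate below are
# illustrative, not part of this commit):
#
#   @memory_managed_model("upscaler", memory_usage_mb=1500)
#   def load_upscaler(version="v1"):
#       ...  # build and return an nn.Module
#
#   upscaler = load_upscaler()  # cheap: returns a MemoryManagedModelWrapper
#   upscaler(batch)             # first attribute access/call loads and caches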
@@ -0,0 +1,156 @@
import pytest
import torch
from torch import nn

from imaginairy import ImaginePrompt, imagine
from imaginairy.utils import get_device
from imaginairy.utils.model_cache import GPUModelCache


class DummyMemoryModule(nn.Module):
    def __init__(self, in_features):
        super().__init__()
        self.large_layer = nn.Linear(in_features - 1, 1)

    def forward(self, x):
        return self.large_layer(x)


def create_model_of_n_bytes(n):
    import math

    n = int(math.floor(n / 4))
    return DummyMemoryModule(n)
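

# Size math (assuming float32 parameters): nn.Linear(k - 1, 1) holds k - 1
# weights plus 1 bias, i.e. k parameters at 4 bytes each, so the helper yields
# a model of roughly n bytes (n rounded down to a multiple of 4).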


@pytest.mark.skip()
@pytest.mark.parametrize(
    "model_version",
    [
        "SD-1.4",
        "SD-1.5",
        "SD-2.0",
        "SD-2.0-v",
        "SD-2.1",
        "SD-2.1-v",
        "openjourney-v1",
        "openjourney-v2",
        "openjourney-v4",
    ],
)
def test_memory_usage(filename_base_for_orig_outputs, model_version):
    """Test that we can switch between model versions."""
    prompt_text = "valley, fairytale treehouse village covered, , matte painting, highly detailed, dynamic lighting, cinematic, realism, realistic, photo real, sunset, detailed, high contrast, denoised, centered, michael whelan"
    prompts = [ImaginePrompt(prompt_text, model=model_version, seed=1, steps=30)]

    for i, result in enumerate(imagine(prompts)):
        img_path = f"{filename_base_for_orig_outputs}_{result.prompt.prompt_text}_{result.prompt.model}.png"
        result.img.save(img_path)


def test_get_nonexistent():
    cache = GPUModelCache(max_cpu_memory_gb=1, max_gpu_memory_gb=1)
    with pytest.raises(KeyError):
        cache.get("nonexistent_key")


@pytest.mark.skipif(get_device() == "cpu", reason="GPU not available")
def test_set_cpu_full():
    cache = GPUModelCache(
        max_cpu_memory_gb=0.000000001, max_gpu_memory_gb=0.01, device=get_device()
    )

    for i in range(4):
        cache.set(f"key{i}", create_model_of_n_bytes(4_000_000))
    assert len(cache.cpu_cache) == 0
    assert len(cache.gpu_cache) == 2


@pytest.mark.skipif(get_device() == "cpu", reason="GPU not available")
def test_set_gpu_full():
    cache = GPUModelCache(
        max_cpu_memory_gb=1, max_gpu_memory_gb=0.0000001, device=get_device()
    )
    assert cache.max_cpu_memory == 1073741824
    model = create_model_of_n_bytes(100_000)
    with pytest.raises(RuntimeError):
        cache.set("key1", model)


@pytest.mark.skipif(get_device() == "cpu", reason="GPU not available")
def test_get_existing_cpu():
    cache = GPUModelCache(max_cpu_memory_gb=0.1, max_gpu_memory_gb=0.1, device="cpu")
    model = create_model_of_n_bytes(10_000)
    cache.set("key", model)
    retrieved_data = cache.get("key")
    assert retrieved_data == model
    # assert 'key' in cache.cpu_cache
    assert "key" in cache.gpu_cache


@pytest.mark.skipif(get_device() == "cpu", reason="GPU not available")
def test_get_existing_move_to_gpu():
    cache = GPUModelCache(
        max_cpu_memory_gb=0.1, max_gpu_memory_gb=0.1, device=get_device()
    )
    model = create_model_of_n_bytes(10_000)
    cache.set("key", model)
    retrieved_data = cache.get("key")
    assert retrieved_data == model
    assert "key" not in cache.cpu_cache
    assert "key" in cache.gpu_cache


@pytest.mark.skipif(get_device() == "cpu", reason="GPU not available")
def test_cache_ordering():
    cache = GPUModelCache(
        max_cpu_memory_gb=0.01, max_gpu_memory_gb=0.01, device=get_device()
    )

    cache.set("key-0", create_model_of_n_bytes(4_000_000))
    assert list(cache.cpu_cache.keys()) == []  # noqa
    assert list(cache.gpu_cache.keys()) == ["key-0"]
    assert (cache.cpu_cache.memory_usage, cache.gpu_cache.memory_usage) == (
        0,
        4_000_000,
    )

    cache.set("key-1", create_model_of_n_bytes(4_000_000))
    assert list(cache.cpu_cache.keys()) == []  # noqa
    assert list(cache.gpu_cache.keys()) == ["key-0", "key-1"]
    assert (cache.cpu_cache.memory_usage, cache.gpu_cache.memory_usage) == (
        0,
        8_000_000,
    )

    cache.set("key-2", create_model_of_n_bytes(4_000_000))
    assert list(cache.cpu_cache.keys()) == ["key-0"]
    assert list(cache.gpu_cache.keys()) == ["key-1", "key-2"]
    assert (cache.cpu_cache.memory_usage, cache.gpu_cache.memory_usage) == (
        4_000_000,
        8_000_000,
    )

    cache.set("key-3", create_model_of_n_bytes(4_000_000))
    assert list(cache.cpu_cache.keys()) == ["key-0", "key-1"]
    assert list(cache.gpu_cache.keys()) == ["key-2", "key-3"]
    assert (cache.cpu_cache.memory_usage, cache.gpu_cache.memory_usage) == (
        8_000_000,
        8_000_000,
    )

    cache.set("key-4", create_model_of_n_bytes(4_000_000))
    assert list(cache.cpu_cache.keys()) == ["key-1", "key-2"]
    assert list(cache.gpu_cache.keys()) == ["key-3", "key-4"]
    assert list(cache.keys()) == ["key-1", "key-2", "key-3", "key-4"]
    assert (cache.cpu_cache.memory_usage, cache.gpu_cache.memory_usage) == (
        8_000_000,
        8_000_000,
    )

    cache.get("key-2")
    assert list(cache.keys()) == ["key-3", "key-4", "key-2"]

    cache.set("key-5", create_model_of_n_bytes(9_000_000))
    assert list(cache.cpu_cache.keys()) == ["key-4", "key-2"]
    assert list(cache.gpu_cache.keys()) == ["key-5"]
    assert list(cache.keys()) == ["key-4", "key-2", "key-5"]