imaginAIry/imaginairy/utils/memory_tracker.py
Bryce f50a1f5b0c fix: interrupted generations don't prevent more generations
fixes #424

- perf: improve memory usage when loading SD15.
- feature: clean up CLI output more
- feature: cuda memory tracking context manager
- feature: use safetensors fp16 for sd15
2024-01-01 19:59:31 -08:00

44 lines
1.4 KiB
Python

import contextlib
from typing import Callable, List
import torch
class TorchRAMTracker(contextlib.ContextDecorator):
    """Tracks peak CUDA memory usage for a block of code.

    Usable as a context manager or (via ``ContextDecorator``) as a function
    decorator.  Trackers may be nested: each tracker resets the global CUDA
    peak-memory statistic on entry, which would erase the peak observed so far
    by an enclosing tracker.  To compensate, every tracker saves the current
    peak on a shared class-level stack before resetting, and the enclosing
    tracker folds those saved readings into its own peak in ``stop()``.
    """

    # Saved peak readings from nested trackers.  An entry pushed by a child in
    # start() is consumed (and removed) by its parent in stop().
    _memory_stack: List[int] = []
    # Indirection over torch.cuda so the memory backend can be swapped/mocked.
    mem_interface = torch.cuda

    def __init__(
        self, name="", callback_fn: "Callable[[TorchRAMTracker], None] | None" = None
    ):
        self.name = name
        self.peak_memory = 0  # true peak (bytes) within the block, set by stop()
        self.start_memory = 0  # bytes allocated when the block was entered
        self.end_memory = 0  # bytes allocated when the block was exited
        # Initialized here so the attribute exists even before stop() runs.
        self.peak_memory_delta = 0
        # Optional hook invoked with this tracker after stop() computes stats.
        self.callback_fn = callback_fn
        # Length of _memory_stack just after our own push (one past our entry).
        self._stack_depth: "int | None" = None

    def start(self):
        """Begin tracking: save the current peak for any enclosing tracker,
        then reset the CUDA peak statistic so it measures only this block."""
        current_peak = self.mem_interface.max_memory_allocated()
        TorchRAMTracker._memory_stack.append(current_peak)
        self._stack_depth = len(TorchRAMTracker._memory_stack)
        self.mem_interface.reset_peak_memory_stats()
        self.start_memory = self.mem_interface.memory_allocated()

    def stop(self):
        """Stop tracking, compute this block's true peak, and fire the callback."""
        end_peak = self.mem_interface.max_memory_allocated()
        # Entries above our own were pushed by nested trackers just before they
        # reset the peak statistic; the real peak is the max of those saved
        # readings and the final reading since the last reset.
        peaks = TorchRAMTracker._memory_stack[self._stack_depth :] + [end_peak]
        self.peak_memory = max(peaks)
        del TorchRAMTracker._memory_stack[self._stack_depth :]
        if self._stack_depth == 1:
            # Outermost tracker: no parent will ever consume our own saved
            # entry, so drop it too.  (Previously it was left behind, leaking
            # one stack entry per top-level use.)
            TorchRAMTracker._memory_stack.clear()
        self.end_memory = self.mem_interface.memory_allocated()
        self.peak_memory_delta = self.peak_memory - self.start_memory
        # Previously callback_fn was accepted but never invoked.
        if self.callback_fn is not None:
            self.callback_fn(self)

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, *exc):
        # Runs even when the block raises, so interrupted work still records
        # its stats and unwinds the shared stack.
        self.stop()