You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
44 lines
1.4 KiB
Python
44 lines
1.4 KiB
Python
import contextlib
|
|
from typing import Callable, List
|
|
|
|
import torch
|
|
|
|
|
|
class TorchRAMTracker(contextlib.ContextDecorator):
    """Track peak CUDA memory usage over a block of code.

    Usable as a context manager or (via ``ContextDecorator``) as a decorator::

        with TorchRAMTracker("forward") as t:
            ...
        print(t.peak_memory, t.peak_memory_delta)

    Trackers may be nested: before resetting PyTorch's peak counter, each
    tracker saves the current peak on a shared class-level stack so that an
    enclosing tracker can still recover peaks observed before an inner
    tracker reset the statistics.
    """

    # Saved peak counters, one entry pushed per tracker on start().  Shared
    # across all instances to support nesting.
    _memory_stack: List[int] = []
    # Indirection over torch.cuda so the memory API can be substituted
    # (e.g. with a fake in tests).
    mem_interface = torch.cuda

    def __init__(
        self, name="", callback_fn: "Callable[[TorchRAMTracker], None] | None" = None
    ):
        """
        :param name: Optional label for this tracker (useful for reporting).
        :param callback_fn: Optional function invoked with this tracker after
            ``stop()`` has computed the final statistics (e.g. for logging).
        """
        self.name = name
        self.peak_memory = 0    # max bytes allocated during the tracked block
        self.start_memory = 0   # bytes allocated when the block started
        self.end_memory = 0     # bytes allocated when the block ended
        self.callback_fn = callback_fn
        self._stack_depth = None  # stack length right after our push; None while inactive

    def start(self):
        """Begin tracking: save the current peak and reset peak statistics."""
        # Preserve the peak seen so far for any enclosing tracker, because
        # reset_peak_memory_stats() below destroys that information.
        current_peak = self.mem_interface.max_memory_allocated()
        TorchRAMTracker._memory_stack.append(current_peak)
        self._stack_depth = len(TorchRAMTracker._memory_stack)
        self.mem_interface.reset_peak_memory_stats()
        self.start_memory = self.mem_interface.memory_allocated()

    def stop(self):
        """Finish tracking and compute peak statistics for the block.

        Invokes ``callback_fn`` (if one was provided) once the statistics
        are final.

        :raises RuntimeError: if called before :meth:`start`.
        """
        if self._stack_depth is None:
            # Previously this surfaced as an opaque TypeError from slicing
            # with None; fail with an explicit message instead.
            raise RuntimeError("TorchRAMTracker.stop() called before start()")
        end_peak = self.mem_interface.max_memory_allocated()
        # Peaks saved by nested trackers (entries above our own) plus the
        # current counter together cover the whole tracked block.
        peaks = TorchRAMTracker._memory_stack[self._stack_depth :] + [end_peak]
        self.peak_memory = max(peaks)
        del TorchRAMTracker._memory_stack[self._stack_depth :]
        self.end_memory = self.mem_interface.memory_allocated()
        self.peak_memory_delta = self.peak_memory - self.start_memory
        # BUG FIX: callback_fn was stored in __init__ but never invoked.
        if self.callback_fn is not None:
            self.callback_fn(self)

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, *exc):
        self.stop()