|
|
|
@@ -10,7 +10,7 @@ import ctypes
|
|
|
|
|
import multiprocessing as mp
|
|
|
|
|
import os
|
|
|
|
|
import time
|
|
|
|
|
-from typing import AsyncContextManager, Dict, Optional, Sequence, Tuple
|
|
|
|
|
+from typing import AsyncContextManager, Dict, Optional, Sequence
|
|
|
|
|
|
|
|
|
|
import hivemind
|
|
|
|
|
import torch
|
|
|
|
@@ -29,7 +29,7 @@ class MemoryCache:
|
|
|
|
|
def __init__(self, max_size_bytes: Optional[int], alloc_timeout: float):
|
|
|
|
|
self.max_size_bytes = max_size_bytes if max_size_bytes is not None else (2**64 - 1)
|
|
|
|
|
self.alloc_timeout = alloc_timeout
|
|
|
|
|
-        self._lock_metadata, self.size_decreased_event = mp.Lock(), mp.Event()
|
|
|
|
|
+        self._lock_metadata = mp.Lock()
|
|
|
|
|
self._current_size = mp.Value(ctypes.c_int64, 0, lock=False)
|
|
|
|
|
self._handle_counter = mp.Value(ctypes.c_int64, 0, lock=False)
|
|
|
|
|
self._allocated_tensors: Dict[Handle, torch.Tensor] = {}
|
|
|
|
|