Add basic support for device.type == "mps"

pull/477/head
Aleksandr Borzunov 9 months ago
parent 92287b24a3
commit 2fa0e68e8e

@@ -153,7 +153,12 @@ class Server:
         self.should_validate_reachability = not skip_reachability_check and initial_peers == PUBLIC_INITIAL_PEERS
 
         if device is None:
-            device = "cuda" if torch.cuda.is_available() else "cpu"
+            if torch.cuda.is_available():
+                device = "cuda"
+            elif torch.backends.mps.is_available():
+                device = "mps"
+            else:
+                device = "cpu"
         device = torch.device(device)
         if device.type == "cuda" and device.index is None:
             device = torch.device(device.type, index=0)
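Taken in isolation, the new selection logic prefers CUDA, then Apple's Metal Performance Shaders (MPS) backend, and only then falls back to CPU. A minimal standalone sketch of the same fallback order (the helper name is illustrative, not part of this commit):

import torch

def resolve_default_device() -> torch.device:
    # Same precedence as the diff above: CUDA, then MPS, then CPU.
    if torch.cuda.is_available():
        return torch.device("cuda", index=0)
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

print(resolve_default_device())  # e.g. device(type='mps') on Apple Silicon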
@@ -373,6 +378,8 @@ class Server:
                     f"Cleaning up, left {allocated_vram / gib:.1f} GiB allocated memory, "
                     f"{reserved_vram / gib:.1f} GiB reserved memory"
                 )
+            elif self.device.type == "mps":
+                torch.mps.empty_cache()
 
     def _choose_blocks(self) -> List[int]:
         if self.strict_block_indices is not None:
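The CUDA branch of this cleanup logs leftover memory; the new elif gives MPS an equivalent cache release, since torch.mps.empty_cache() returns memory held by the MPS caching allocator to the OS. A sketch of the same dispatch as a free function (the name is illustrative):

import torch

def empty_device_cache(device: torch.device):
    # Release memory cached by the backend's allocator, if it has one.
    if device.type == "cuda":
        torch.cuda.empty_cache()
    elif device.type == "mps":
        torch.mps.empty_cache()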

@@ -207,14 +207,12 @@ def measure_compute_rps(
     elapsed = 0
     dummy_input = torch.randn(1, n_tokens, config.hidden_size, device=device, dtype=dtype)
     _, cache = block.forward(dummy_input, use_cache=True)  # Skip the 1st step to exclude the initialization time
-    if device.type == "cuda":
-        torch.cuda.synchronize(device)
+    synchronize(device)
 
     start_time = time.perf_counter()
-    for step in range(n_steps):
+    for _ in range(n_steps):
         _, cache = block.forward(dummy_input, use_cache=True, layer_past=cache if inference else None)
-    if device.type == "cuda":
-        torch.cuda.synchronize(device)
+    synchronize(device)
     elapsed = time.perf_counter() - start_time
 
     device_rps = n_steps * n_tokens / elapsed
@@ -230,8 +228,15 @@ def measure_compute_rps(
     return device_rps
 
 
+def synchronize(device: torch.device):
+    if device.type == "cuda":
+        torch.cuda.synchronize(device)
+    elif device.type == "mps":
+        torch.mps.synchronize()
+
+
 def get_device_name(device: torch.device) -> str:
-    return f"{torch.cuda.get_device_name(device)} GPU" if device.type == "cuda" else "CPU"
+    return f"{torch.cuda.get_device_name(device)} GPU" if device.type == "cuda" else device.type.upper()
 
 
 def get_dtype_name(dtype: torch.dtype, quant_type: QuantType) -> str:
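The new synchronize() helper matters for the benchmark above: CUDA and MPS kernels are launched asynchronously, so without a blocking synchronization the time.perf_counter() readings would capture kernel dispatch rather than actual compute. Note that torch.mps.synchronize() takes no device argument, unlike its CUDA counterpart. A self-contained sketch of the same synchronize-before-timing pattern (the workload is illustrative):

import time
import torch

def synchronize(device: torch.device):
    # Block until all queued kernels on the device have finished.
    if device.type == "cuda":
        torch.cuda.synchronize(device)
    elif device.type == "mps":
        torch.mps.synchronize()

device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
x = torch.randn(1024, 1024, device=device)

start = time.perf_counter()
for _ in range(10):
    y = x @ x
synchronize(device)  # Without this, elapsed would undercount on CUDA/MPS.
print(f"10 matmuls in {time.perf_counter() - start:.3f}s on {device.type.upper()}")

The get_device_name() change follows the same theme: non-CUDA devices are now reported by their type, so an MPS host shows up as "MPS" instead of "CPU".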
