Fix calling torch.mps.synchronize()

pull/477/head
Aleksandr Borzunov 9 months ago
parent 9f3671114e
commit 5177079239

@ -233,7 +233,7 @@ def synchronize(device: torch.device):
if device.type == "cuda":
torch.cuda.synchronize(device)
elif device.type == "mps":
torch.mps.synchronize(device)
torch.mps.synchronize()
def get_device_name(device: torch.device) -> str:

Loading…
Cancel
Save