mirror of https://github.com/kritiksoman/GIMP-ML
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
113 lines
3.3 KiB
Python
113 lines
3.3 KiB
Python
# -*- coding: utf8 -*-
|
|
|
|
import collections
import collections.abc

import torch
import torch.cuda as cuda
import torch.nn as nn
from torch.nn.parallel._functions import Gather
|
|
|
|
|
|
# Names exported by ``from <this module> import *``.
__all__ = [
    'UserScatteredDataParallel',
    'user_scattered_collate',
    'async_copy_to',
]
|
|
|
|
|
|
def async_copy_to(obj, dev, main_stream=None):
    """Recursively copy every tensor in ``obj`` to CUDA device ``dev``.

    Tensor copies are issued with ``non_blocking=True``. Mappings are rebuilt
    as plain ``dict``s and other sequences as ``list``s. ``str``/``bytes`` and
    any other non-tensor values are returned unchanged.

    Args:
        obj: a tensor, or a (possibly nested) mapping/sequence containing
            tensors.
        dev: target device, passed to ``Tensor.cuda``.
        main_stream: optional CUDA stream on which the copies will be used.
            When given, each copied tensor is registered on it with
            ``record_stream``.

    Returns:
        ``obj`` with every tensor replaced by its copy on ``dev``.
    """
    if torch.is_tensor(obj):
        copied = obj.cuda(dev, non_blocking=True)
        if main_stream is not None:
            # Tell the caching allocator this memory is also used by main_stream.
            copied.data.record_stream(main_stream)
        return copied
    # collections.abc: the bare ``collections.Mapping``/``collections.Sequence``
    # aliases were removed in Python 3.10.
    if isinstance(obj, collections.abc.Mapping):
        return {k: async_copy_to(o, dev, main_stream) for k, o in obj.items()}
    # str/bytes are Sequences too, but must be kept whole: each character of a
    # str is itself a str, so recursing into one never terminates.
    if isinstance(obj, collections.abc.Sequence) and not isinstance(obj, (str, bytes)):
        return [async_copy_to(o, dev, main_stream) for o in obj]
    return obj
|
|
|
|
|
|
def dict_gather(outputs, target_device, dim=0):
    """
    Gathers variables from different GPUs on a specified device
    (-1 means the CPU), with dictionary support.

    Args:
        outputs: one output structure per replica. Each is a tensor, ``None``,
            a mapping, or a sequence (list/tuple/namedtuple) nesting such
            values. All replicas are assumed to share the structure of
            ``outputs[0]``.
        target_device: device to gather onto (-1 means the CPU).
        dim: dimension along which tensor leaves are gathered.

    Returns:
        A structure mirroring ``outputs[0]``:
        - tensor leaves are combined with ``Gather``;
        - mappings become plain ``dict``s;
        - sequences keep their type (namedtuples included);
        - ``None`` stays ``None``;
        - leaves of any other type (numbers, strings, ...) cannot be gathered
          and become ``None``.
    """
    def gather_map(outputs):
        out = outputs[0]
        if torch.is_tensor(out):
            # MJY(20180330) HACK:: force nr_dims > 0
            # A 0-dim tensor is promoted to shape (1,) in every replica before
            # gathering.
            if out.dim() == 0:
                outputs = [o.unsqueeze(0) for o in outputs]
            return Gather.apply(target_device, dim, *outputs)
        if out is None:
            return None
        # collections.abc: the bare ``collections.Mapping``/``collections.Sequence``
        # aliases were removed in Python 3.10.
        if isinstance(out, collections.abc.Mapping):
            return {k: gather_map([o[k] for o in outputs]) for k in out}
        if isinstance(out, (str, bytes)):
            # Sequences, but recursing into them never terminates (str) or
            # fails (bytes); handle them like any other non-gatherable leaf.
            return None
        if isinstance(out, tuple) and hasattr(out, '_fields') and hasattr(out, '_asdict'):
            # namedtuple constructors take fields positionally, not an iterable.
            return type(out)._make(map(gather_map, zip(*outputs)))
        if isinstance(out, collections.abc.Sequence):
            return type(out)(map(gather_map, zip(*outputs)))
        # Any other leaf type cannot be gathered and is dropped.
        return None

    # gather_map references itself through its closure cell, forming a
    # reference cycle; clearing the name lets it be freed without the GC.
    try:
        return gather_map(outputs)
    finally:
        gather_map = None
|
|
|
|
|
|
class DictGatherDataParallel(nn.DataParallel):
    """``nn.DataParallel`` whose replica outputs are merged by ``dict_gather``.

    ``dict_gather`` adds dictionary support to gathering, so replicas may
    return (nested) mappings as well as tensors and sequences.
    """

    def gather(self, outputs, output_device):
        """Merge the per-replica ``outputs`` onto ``output_device`` along ``self.dim``."""
        merged = dict_gather(outputs, output_device, dim=self.dim)
        return merged
|
|
|
|
|
|
class UserScatteredDataParallel(DictGatherDataParallel):
    """Data-parallel wrapper whose single positional input is already split per device.

    The caller passes exactly one positional argument: a tuple/list with one
    element per entry of ``device_ids``. Element ``i`` is copied to
    ``device_ids[i]`` and becomes the only positional argument of replica
    ``i``. Keyword arguments are not supported. Outputs are gathered with
    ``dict_gather`` (inherited from ``DictGatherDataParallel``).
    """

    def scatter(self, inputs, kwargs, device_ids):
        """Copy pre-scattered inputs to their devices.

        Args:
            inputs: the positional arguments of ``forward``; must contain
                exactly one item, the per-device tuple/list of inputs.
            kwargs: the keyword arguments of ``forward``; must be empty.
            device_ids: target CUDA device indices.

        Returns:
            ``(per_device_args, per_device_kwargs)``: a list of one-element
            argument lists and a matching list of empty kwargs dicts.

        Raises:
            ValueError: if there is not exactly one positional argument, or
                if keyword arguments were given.
        """
        # Explicit checks instead of ``assert``: under ``python -O`` the asserts
        # were stripped and any keyword arguments were silently discarded.
        if len(inputs) != 1:
            raise ValueError(
                'UserScatteredDataParallel expects exactly one positional '
                'argument (the pre-scattered per-device inputs), got %d'
                % len(inputs))
        if kwargs:
            raise ValueError(
                'UserScatteredDataParallel does not support keyword arguments')

        per_device = _async_copy_stream(inputs[0], device_ids)
        scattered = [[x] for x in per_device]
        return scattered, [{} for _ in range(len(scattered))]
|
|
|
|
|
|
def user_scattered_collate(batch):
    """Identity collate function: hand ``batch`` back exactly as received.

    No stacking, copying or conversion is performed.
    NOTE(review): presumably used as a DataLoader ``collate_fn`` so samples
    reach ``UserScatteredDataParallel`` unstacked; confirm against callers.
    """
    passthrough = batch
    return passthrough
|
|
|
|
|
|
def _async_copy(inputs, device_ids):
|
|
nr_devs = len(device_ids)
|
|
assert type(inputs) in (tuple, list)
|
|
assert len(inputs) == nr_devs
|
|
|
|
outputs = []
|
|
for i, dev in zip(inputs, device_ids):
|
|
with cuda.device(dev):
|
|
outputs.append(async_copy_to(i, dev))
|
|
|
|
return tuple(outputs)
|
|
|
|
|
|
def _async_copy_stream(inputs, device_ids):
|
|
nr_devs = len(device_ids)
|
|
assert type(inputs) in (tuple, list)
|
|
assert len(inputs) == nr_devs
|
|
|
|
outputs = []
|
|
streams = [_get_stream(d) for d in device_ids]
|
|
for i, dev, stream in zip(inputs, device_ids, streams):
|
|
with cuda.device(dev):
|
|
main_stream = cuda.current_stream()
|
|
with cuda.stream(stream):
|
|
outputs.append(async_copy_to(i, dev, main_stream=main_stream))
|
|
main_stream.wait_stream(stream)
|
|
|
|
return outputs
|
|
|
|
|
|
# Adapted from: torch/nn/parallel/_functions.py
#
# Background streams used for copying, one slot per CUDA device.
# None until _get_stream() first allocates the per-device list.
_streams = None
|
|
|
|
|
|
def _get_stream(device):
|
|
"""Gets a background stream for copying between CPU and GPU"""
|
|
global _streams
|
|
if device == -1:
|
|
return None
|
|
if _streams is None:
|
|
_streams = [None] * cuda.device_count()
|
|
if _streams[device] is None: _streams[device] = cuda.Stream(device)
|
|
return _streams[device]
|