# -*- coding: utf8 -*-
import torch.cuda as cuda
import torch.nn as nn
import torch
import collections.abc  # collections.Mapping/Sequence were removed in Python 3.10
from torch.nn.parallel._functions import Gather

__all__ = ['UserScatteredDataParallel', 'user_scattered_collate', 'async_copy_to']


def async_copy_to(obj, dev, main_stream=None):
    """Recursively copy (nested) tensors to `dev` with non-blocking transfers."""
    if torch.is_tensor(obj):
        v = obj.cuda(dev, non_blocking=True)
        if main_stream is not None:
            # Mark the copied tensor as used by the main stream so the caching
            # allocator does not recycle its memory prematurely.
            v.data.record_stream(main_stream)
        return v
    elif isinstance(obj, collections.abc.Mapping):
        return {k: async_copy_to(o, dev, main_stream) for k, o in obj.items()}
    elif isinstance(obj, str):
        # Treat strings as leaves: they are Sequences and would otherwise
        # recurse forever, one character at a time.
        return obj
    elif isinstance(obj, collections.abc.Sequence):
        return [async_copy_to(o, dev, main_stream) for o in obj]
    else:
        return obj
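

# A minimal usage sketch (an illustrative addition, not part of the original
# module; assumes at least one CUDA device): copying a nested batch dict to
# GPU 0. Tensor structure is preserved, and non-tensor leaves such as the
# filename string pass through unchanged.
def _example_async_copy():
    batch = {'img': torch.randn(2, 3, 8, 8),
             'labels': [torch.zeros(2, dtype=torch.long)],
             'info': 'sample_0001'}
    gpu_batch = async_copy_to(batch, 0)
    assert gpu_batch['img'].is_cuda
    assert gpu_batch['labels'][0].is_cuda
    assert gpu_batch['info'] == 'sample_0001'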


def dict_gather(outputs, target_device, dim=0):
    """
    Gathers tensors from different GPUs on a specified device
    (-1 means the CPU), with dictionary support.
    """
    def gather_map(outputs):
        out = outputs[0]
        if torch.is_tensor(out):
            # MJY(20180330) HACK:: force nr_dims > 0
            if out.dim() == 0:
                outputs = [o.unsqueeze(0) for o in outputs]
            return Gather.apply(target_device, dim, *outputs)
        elif out is None:
            return None
        elif isinstance(out, collections.abc.Mapping):
            return {k: gather_map([o[k] for o in outputs]) for k in out}
        elif isinstance(out, collections.abc.Sequence):
            return type(out)(map(gather_map, zip(*outputs)))
    return gather_map(outputs)
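

# A short sketch of dict_gather (an illustrative addition; one CUDA device is
# enough, since both replica outputs live on the default GPU here). The 0-dim
# 'loss' tensors are unsqueezed to shape (1,) before gathering, so the merged
# dict carries a (2,) loss and a (4, 3) prediction.
def _example_dict_gather():
    o1 = {'loss': torch.tensor(0.5, device='cuda'),
          'pred': torch.ones(2, 3, device='cuda')}
    o2 = {'loss': torch.tensor(0.7, device='cuda'),
          'pred': torch.zeros(2, 3, device='cuda')}
    merged = dict_gather([o1, o2], target_device=0)
    assert merged['loss'].shape == (2,)
    assert merged['pred'].shape == (4, 3)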


class DictGatherDataParallel(nn.DataParallel):
    def gather(self, outputs, output_device):
        return dict_gather(outputs, output_device, dim=self.dim)


class UserScatteredDataParallel(DictGatherDataParallel):
    def scatter(self, inputs, kwargs, device_ids):
        assert len(inputs) == 1
        inputs = inputs[0]
        inputs = _async_copy_stream(inputs, device_ids)
        inputs = [[i] for i in inputs]
        assert len(kwargs) == 0
        kwargs = [{} for _ in range(len(inputs))]
        return inputs, kwargs


def user_scattered_collate(batch):
    return batch
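

# How the pieces fit together (a hedged sketch, not from the original file):
# `net` and `dataset` are assumed user-defined, with each dataset item already
# a full sub-batch dict for one GPU, and `batch_size` equal to the number of
# GPUs so scatter() receives exactly one sample per device.
def _example_wiring(net, dataset):
    import torch.utils.data
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=2,                        # == len(device_ids)
        collate_fn=user_scattered_collate,   # keep samples as a plain list
        drop_last=True)
    model = UserScatteredDataParallel(net, device_ids=[0, 1]).cuda()
    for batch in loader:
        # scatter() async-copies batch[i] onto device_ids[i], each replica
        # runs net(batch[i]), and dict outputs are gathered onto device 0.
        out = model(batch)
        return out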


def _async_copy(inputs, device_ids):
    nr_devs = len(device_ids)
    assert type(inputs) in (tuple, list)
    assert len(inputs) == nr_devs

    outputs = []
    for i, dev in zip(inputs, device_ids):
        with cuda.device(dev):
            outputs.append(async_copy_to(i, dev))

    return tuple(outputs)


def _async_copy_stream(inputs, device_ids):
    nr_devs = len(device_ids)
    assert type(inputs) in (tuple, list)
    assert len(inputs) == nr_devs

    outputs = []
    streams = [_get_stream(d) for d in device_ids]
    for i, dev, stream in zip(inputs, device_ids, streams):
        with cuda.device(dev):
            main_stream = cuda.current_stream()
            with cuda.stream(stream):
                outputs.append(async_copy_to(i, dev, main_stream=main_stream))
            main_stream.wait_stream(stream)

    return outputs
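

# Explanatory note (added): issuing the host-to-device copy on a dedicated
# background stream lets it overlap with compute already queued on the main
# stream. wait_stream() then makes the main stream block until the copy has
# finished before consuming the data, and record_stream() (inside
# async_copy_to) marks the copied tensor as used by the main stream so its
# memory is not recycled early.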
"""Adapted from: torch/nn/parallel/_functions.py"""
# background streams used for copying
_streams = None
def _get_stream(device):
"""Gets a background stream for copying between CPU and GPU"""
global _streams
if device == -1:
return None
if _streams is None:
_streams = [None] * cuda.device_count()
if _streams[device] is None: _streams[device] = cuda.Stream(device)
return _streams[device]
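

if __name__ == '__main__':
    # Smoke test for the illustrative sketches above (an addition, skipped
    # when no CUDA device is present).
    if cuda.is_available():
        _example_async_copy()
        _example_dict_gather()
        print('examples OK')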