mirror of
https://github.com/kritiksoman/GIMP-ML
synced 2024-11-06 03:20:34 +00:00
22 lines
601 B
Python
Executable File
22 lines
601 B
Python
Executable File
import torch.nn as nn
|
|
import torch
|
|
import torch.distributed as dist
|
|
|
|
class GlobalAvgPool2d(nn.Module):
    """Global average pooling over the input's spatial dimensions.

    Every dimension after the first two is collapsed into one axis and
    averaged, so an input of size ``(d0, d1, ...)`` produces an output of
    size ``(d0, d1)``.
    """

    def __init__(self):
        super(GlobalAvgPool2d, self).__init__()

    def forward(self, inputs):
        """Average ``inputs`` over all dimensions past the first two.

        Args:
            inputs: tensor with at least two dimensions (typically
                ``(N, C, H, W)``).

        Returns:
            Tensor of size ``(inputs.size(0), inputs.size(1))``.
        """
        in_size = inputs.size()
        # reshape() rather than view(): view() raises on inputs whose
        # trailing dimensions cannot be merged without a copy (e.g. after a
        # transpose or strided slice); reshape() copies only when it must.
        flat = inputs.reshape((in_size[0], in_size[1], -1))
        return flat.mean(dim=2)
|
|
|
|
class SingleGPU(nn.Module):
    """Wrapper that moves its input to the CUDA device before calling a module."""

    def __init__(self, module):
        """Store ``module`` as the wrapped callable under ``self.module``."""
        super(SingleGPU, self).__init__()
        self.module = module

    def forward(self, input):
        """Transfer ``input`` with ``.cuda(non_blocking=True)`` and run the wrapped module on it.

        Returns whatever ``self.module`` returns for the transferred input.
        """
        on_device = input.cuda(non_blocking=True)
        return self.module(on_device)
|
|
|