# GIMP-ML/gimp-plugins/pytorch-SRResNet/dataset.py
#
# 16 lines, 508 B, Python -- snapshot as of commit 2020-04-27 04:32:33 +00:00
import torch.utils.data as data
import torch
import h5py
class DatasetFromHdf5(data.Dataset):
    """Map-style PyTorch dataset serving paired samples from an HDF5 file.

    The file must contain an HDF5 dataset named ``"data"`` (model inputs) and
    one named ``"label"`` (targets). Sample ``i`` is row ``i`` along the first
    axis of each. ``self.data`` / ``self.target`` hold the h5py dataset
    objects rather than loaded arrays, so each ``__getitem__`` reads a single
    row from the file.

    NOTE(review): presumably ``"data"`` holds low-resolution inputs and
    ``"label"`` the matching high-resolution targets, with equal first-axis
    lengths -- confirm against the script that generates the HDF5 file.
    """

    def __init__(self, file_path):
        """Open ``file_path`` read-only and bind its ``data``/``label`` sets.

        Args:
            file_path: Path to the HDF5 file.

        Raises:
            KeyError: If the file lacks a ``"data"`` or ``"label"`` dataset.
        """
        super(DatasetFromHdf5, self).__init__()
        # Explicit read-only mode: older h5py releases defaulted to append
        # mode, which silently creates a missing file and opens an existing
        # training file for writing.
        hf = h5py.File(file_path, "r")
        try:
            # Subscripting (unlike .get()) fails here with a KeyError naming
            # the missing dataset, instead of storing None and failing later
            # with an unrelated AttributeError.
            self.data = hf["data"]
            self.target = hf["label"]
        except KeyError:
            hf.close()
            raise

    def __getitem__(self, index):
        """Return the ``(input, target)`` pair at ``index`` as float32 tensors.

        Only the first axis is indexed, so stored arrays of any rank >= 2 are
        supported (e.g. 4-D storage yields 3-D samples, 3-D storage yields
        2-D samples).
        """
        sample = torch.from_numpy(self.data[index]).float()
        label = torch.from_numpy(self.target[index]).float()
        return sample, label

    def __len__(self):
        """Return the number of samples (first-axis length of ``"data"``)."""
        return self.data.shape[0]