GIMP-ML/gimp-plugins/DeblurGANv2/test_dataset.py
2020-04-27 10:02:33 +05:30

77 lines
2.6 KiB
Python

import os
import unittest
from shutil import rmtree
from tempfile import mkdtemp
import cv2
import numpy as np
from torch.utils.data import DataLoader
from dataset import PairedDataset
def make_img():
return (np.random.rand(100, 100, 3) * 255).astype('uint8')
class AugTest(unittest.TestCase):
tmp_dir = mkdtemp()
raw = os.path.join(tmp_dir, 'raw')
gt = os.path.join(tmp_dir, 'gt')
def setUp(self):
for d in (self.raw, self.gt):
os.makedirs(d)
for i in range(5):
for d in (self.raw, self.gt):
img = make_img()
cv2.imwrite(os.path.join(d, f'{i}.png'), img)
def tearDown(self):
rmtree(self.tmp_dir)
def dataset_gen(self, equal=True):
base_config = {'files_a': os.path.join(self.raw, '*.png'),
'files_b': os.path.join(self.raw if equal else self.gt, '*.png'),
'size': 32,
}
for b in ([0, 1], [0, 0.9]):
for scope in ('strong', 'weak'):
for crop in ('random', 'center'):
for preload in (0, 1):
for preload_size in (0, 64):
config = base_config.copy()
config['bounds'] = b
config['scope'] = scope
config['crop'] = crop
config['preload'] = preload
config['preload_size'] = preload_size
config['verbose'] = False
dataset = PairedDataset.from_config(config)
yield dataset
def test_equal_datasets(self):
for dataset in self.dataset_gen(equal=True):
dataloader = DataLoader(dataset=dataset,
batch_size=2,
shuffle=True,
drop_last=True)
dataloader = iter(dataloader)
batch = next(dataloader)
a, b = map(lambda x: x.numpy(), map(batch.get, ('a', 'b')))
np.testing.assert_allclose(a, b)
def test_datasets(self):
for dataset in self.dataset_gen(equal=False):
dataloader = DataLoader(dataset=dataset,
batch_size=2,
shuffle=True,
drop_last=True)
dataloader = iter(dataloader)
batch = next(dataloader)
a, b = map(lambda x: x.numpy(), map(batch.get, ('a', 'b')))
assert not np.all(a == b), 'images should not be the same'