mirror of
https://github.com/kritiksoman/GIMP-ML
synced 2024-10-31 09:20:18 +00:00
32 lines
1.1 KiB
Python
32 lines
1.1 KiB
Python
|
import random
|
||
|
import torch
|
||
|
from torch.autograd import Variable
|
||
|
class ImagePool():
|
||
|
def __init__(self, pool_size):
|
||
|
self.pool_size = pool_size
|
||
|
if self.pool_size > 0:
|
||
|
self.num_imgs = 0
|
||
|
self.images = []
|
||
|
|
||
|
def query(self, images):
|
||
|
if self.pool_size == 0:
|
||
|
return images
|
||
|
return_images = []
|
||
|
for image in images.data:
|
||
|
image = torch.unsqueeze(image, 0)
|
||
|
if self.num_imgs < self.pool_size:
|
||
|
self.num_imgs = self.num_imgs + 1
|
||
|
self.images.append(image)
|
||
|
return_images.append(image)
|
||
|
else:
|
||
|
p = random.uniform(0, 1)
|
||
|
if p > 0.5:
|
||
|
random_id = random.randint(0, self.pool_size-1)
|
||
|
tmp = self.images[random_id].clone()
|
||
|
self.images[random_id] = image
|
||
|
return_images.append(tmp)
|
||
|
else:
|
||
|
return_images.append(image)
|
||
|
return_images = Variable(torch.cat(return_images, 0))
|
||
|
return return_images
|