mirror of https://github.com/kritiksoman/GIMP-ML
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
34 lines
990 B
Python
34 lines
990 B
Python
import random
|
|
import numpy as np
|
|
import torch
|
|
from torch.autograd import Variable
|
|
from collections import deque
|
|
|
|
|
|
class ImagePool:
    """Fixed-capacity FIFO buffer of image tensors.

    ``add`` pushes every item along the first dimension of a batch into the
    pool, evicting the oldest entries once ``pool_size`` entries are held.
    ``query`` concatenates the pooled entries (or a random subset of them)
    back into a single batch along dimension 0.

    NOTE(review): presumably used as an image-history buffer during GAN
    training -- confirm against callers.

    Args:
        pool_size: Maximum number of images kept. A value ``<= 0`` disables
            the pool: ``add`` then returns its input unchanged and stores
            nothing.
        sample_size: Maximum number of images returned by ``query``.
            Defaults to ``pool_size``, in which case ``query`` returns every
            pooled image, oldest first.

    Raises:
        ValueError: if ``sample_size`` is given explicitly and is below 1.
    """

    def __init__(self, pool_size, sample_size=None):
        if sample_size is not None and sample_size < 1:
            raise ValueError(f"sample_size must be >= 1, got {sample_size}")
        self.pool_size = pool_size
        self.sample_size = pool_size if sample_size is None else sample_size
        # Kept for backward compatibility; always equals len(self.images).
        self.num_imgs = 0
        # Created unconditionally so attribute access never fails, even when
        # the pool is disabled. ``maxlen`` makes the deque discard its oldest
        # entry automatically on overflow (negative sizes are clamped to 0,
        # since deque rejects a negative maxlen).
        self.images = deque(maxlen=max(self.pool_size, 0))

    def add(self, images):
        """Push each image of a batch into the pool.

        Args:
            images: A tensor whose first dimension indexes individual images.
                Its ``.data`` is iterated, so stored entries do not track
                gradients.

        Returns:
            ``images`` unchanged when the pool is disabled
            (``pool_size <= 0``); otherwise ``None``.
        """
        # ``<= 0`` (not ``== 0``) so that negative sizes behave like a
        # disabled pool instead of falling through to the buffer logic.
        if self.pool_size <= 0:
            return images
        for image in images.data:
            # Restore a leading batch axis so ``query`` can ``torch.cat``
            # entries along dim 0.
            # NOTE(review): ``unsqueeze`` returns a view, so pooled entries
            # share storage with the caller's batch; in-place edits to that
            # batch will be visible in the pool.
            self.images.append(torch.unsqueeze(image, 0))
        self.num_imgs = len(self.images)
        return None

    def query(self):
        """Return pooled images concatenated along dimension 0.

        If the pool holds more than ``sample_size`` images, a uniformly random
        subset of ``sample_size`` of them (without replacement) is returned.
        Otherwise every pooled image is returned, oldest first.

        Raises:
            RuntimeError: if the pool is empty (including a disabled pool).
        """
        if not self.images:
            raise RuntimeError("ImagePool.query() called on an empty pool")
        # Copy to a list: random.sample indexes its population, and deque
        # indexing is O(n) away from the ends.
        pooled = list(self.images)
        if len(pooled) > self.sample_size:
            pooled = random.sample(pooled, self.sample_size)
        return torch.cat(pooled, 0)
|