submit project
parent
16f0403c02
commit
f7041fd8fa
File diff suppressed because it is too large
Load Diff
File diff suppressed because one or more lines are too long
@ -0,0 +1,239 @@
|
||||
import math
|
||||
import os
|
||||
import hashlib
|
||||
from urllib.request import urlretrieve
|
||||
import zipfile
|
||||
import gzip
|
||||
import shutil
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def _read32(bytestream):
|
||||
"""
|
||||
Read 32-bit integer from bytesteam
|
||||
:param bytestream: A bytestream
|
||||
:return: 32-bit integer
|
||||
"""
|
||||
dt = np.dtype(np.uint32).newbyteorder('>')
|
||||
return np.frombuffer(bytestream.read(4), dtype=dt)[0]
|
||||
|
||||
|
||||
def _unzip(save_path, _, database_name, data_path):
|
||||
"""
|
||||
Unzip wrapper with the same interface as _ungzip
|
||||
:param save_path: The path of the gzip files
|
||||
:param database_name: Name of database
|
||||
:param data_path: Path to extract to
|
||||
:param _: HACK - Used to have to same interface as _ungzip
|
||||
"""
|
||||
print('Extracting {}...'.format(database_name))
|
||||
with zipfile.ZipFile(save_path) as zf:
|
||||
zf.extractall(data_path)
|
||||
|
||||
|
||||
def _ungzip(save_path, extract_path, database_name, _):
|
||||
"""
|
||||
Unzip a gzip file and extract it to extract_path
|
||||
:param save_path: The path of the gzip files
|
||||
:param extract_path: The location to extract the data to
|
||||
:param database_name: Name of database
|
||||
:param _: HACK - Used to have to same interface as _unzip
|
||||
"""
|
||||
# Get data from save_path
|
||||
with open(save_path, 'rb') as f:
|
||||
with gzip.GzipFile(fileobj=f) as bytestream:
|
||||
magic = _read32(bytestream)
|
||||
if magic != 2051:
|
||||
raise ValueError('Invalid magic number {} in file: {}'.format(magic, f.name))
|
||||
num_images = _read32(bytestream)
|
||||
rows = _read32(bytestream)
|
||||
cols = _read32(bytestream)
|
||||
buf = bytestream.read(rows * cols * num_images)
|
||||
data = np.frombuffer(buf, dtype=np.uint8)
|
||||
data = data.reshape(num_images, rows, cols)
|
||||
|
||||
# Save data to extract_path
|
||||
for image_i, image in enumerate(
|
||||
tqdm(data, unit='File', unit_scale=True, miniters=1, desc='Extracting {}'.format(database_name))):
|
||||
Image.fromarray(image, 'L').save(os.path.join(extract_path, 'image_{}.jpg'.format(image_i)))
|
||||
|
||||
|
||||
def get_image(image_path, width, height, mode):
|
||||
"""
|
||||
Read image from image_path
|
||||
:param image_path: Path of image
|
||||
:param width: Width of image
|
||||
:param height: Height of image
|
||||
:param mode: Mode of image
|
||||
:return: Image data
|
||||
"""
|
||||
image = Image.open(image_path)
|
||||
|
||||
if image.size != (width, height): # HACK - Check if image is from the CELEBA dataset
|
||||
# Remove most pixels that aren't part of a face
|
||||
face_width = face_height = 108
|
||||
j = (image.size[0] - face_width) // 2
|
||||
i = (image.size[1] - face_height) // 2
|
||||
image = image.crop([j, i, j + face_width, i + face_height])
|
||||
image = image.resize([width, height], Image.BILINEAR)
|
||||
|
||||
return np.array(image.convert(mode))
|
||||
|
||||
|
||||
def get_batch(image_files, width, height, mode):
|
||||
data_batch = np.array(
|
||||
[get_image(sample_file, width, height, mode) for sample_file in image_files]).astype(np.float32)
|
||||
|
||||
# Make sure the images are in 4 dimensions
|
||||
if len(data_batch.shape) < 4:
|
||||
data_batch = data_batch.reshape(data_batch.shape + (1,))
|
||||
|
||||
return data_batch
|
||||
|
||||
|
||||
def images_square_grid(images, mode):
|
||||
"""
|
||||
Save images as a square grid
|
||||
:param images: Images to be used for the grid
|
||||
:param mode: The mode to use for images
|
||||
:return: Image of images in a square grid
|
||||
"""
|
||||
# Get maximum size for square grid of images
|
||||
save_size = math.floor(np.sqrt(images.shape[0]))
|
||||
|
||||
# Scale to 0-255
|
||||
images = (((images - images.min()) * 255) / (images.max() - images.min())).astype(np.uint8)
|
||||
|
||||
# Put images in a square arrangement
|
||||
images_in_square = np.reshape(
|
||||
images[:save_size*save_size],
|
||||
(save_size, save_size, images.shape[1], images.shape[2], images.shape[3]))
|
||||
if mode == 'L':
|
||||
images_in_square = np.squeeze(images_in_square, 4)
|
||||
|
||||
# Combine images to grid image
|
||||
new_im = Image.new(mode, (images.shape[1] * save_size, images.shape[2] * save_size))
|
||||
for col_i, col_images in enumerate(images_in_square):
|
||||
for image_i, image in enumerate(col_images):
|
||||
im = Image.fromarray(image, mode)
|
||||
new_im.paste(im, (col_i * images.shape[1], image_i * images.shape[2]))
|
||||
|
||||
return new_im
|
||||
|
||||
|
||||
def download_extract(database_name, data_path):
|
||||
"""
|
||||
Download and extract database
|
||||
:param database_name: Database name
|
||||
"""
|
||||
DATASET_CELEBA_NAME = 'celeba'
|
||||
DATASET_MNIST_NAME = 'mnist'
|
||||
|
||||
if database_name == DATASET_CELEBA_NAME:
|
||||
url = 'https://s3-us-west-1.amazonaws.com/udacity-dlnfd/datasets/celeba.zip'
|
||||
hash_code = '00d2c5bc6d35e252742224ab0c1e8fcb'
|
||||
extract_path = os.path.join(data_path, 'img_align_celeba')
|
||||
save_path = os.path.join(data_path, 'celeba.zip')
|
||||
extract_fn = _unzip
|
||||
elif database_name == DATASET_MNIST_NAME:
|
||||
url = 'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz'
|
||||
hash_code = 'f68b3c2dcbeaaa9fbdd348bbdeb94873'
|
||||
extract_path = os.path.join(data_path, 'mnist')
|
||||
save_path = os.path.join(data_path, 'train-images-idx3-ubyte.gz')
|
||||
extract_fn = _ungzip
|
||||
|
||||
if os.path.exists(extract_path):
|
||||
print('Found {} Data'.format(database_name))
|
||||
return
|
||||
|
||||
if not os.path.exists(data_path):
|
||||
os.makedirs(data_path)
|
||||
|
||||
if not os.path.exists(save_path):
|
||||
with DLProgress(unit='B', unit_scale=True, miniters=1, desc='Downloading {}'.format(database_name)) as pbar:
|
||||
urlretrieve(
|
||||
url,
|
||||
save_path,
|
||||
pbar.hook)
|
||||
|
||||
assert hashlib.md5(open(save_path, 'rb').read()).hexdigest() == hash_code, \
|
||||
'{} file is corrupted. Remove the file and try again.'.format(save_path)
|
||||
|
||||
os.makedirs(extract_path)
|
||||
try:
|
||||
extract_fn(save_path, extract_path, database_name, data_path)
|
||||
except Exception as err:
|
||||
shutil.rmtree(extract_path) # Remove extraction folder if there is an error
|
||||
raise err
|
||||
|
||||
# Remove compressed data
|
||||
os.remove(save_path)
|
||||
|
||||
|
||||
class Dataset(object):
|
||||
"""
|
||||
Dataset
|
||||
"""
|
||||
def __init__(self, dataset_name, data_files):
|
||||
"""
|
||||
Initalize the class
|
||||
:param dataset_name: Database name
|
||||
:param data_files: List of files in the database
|
||||
"""
|
||||
DATASET_CELEBA_NAME = 'celeba'
|
||||
DATASET_MNIST_NAME = 'mnist'
|
||||
IMAGE_WIDTH = 28
|
||||
IMAGE_HEIGHT = 28
|
||||
|
||||
if dataset_name == DATASET_CELEBA_NAME:
|
||||
self.image_mode = 'RGB'
|
||||
image_channels = 3
|
||||
|
||||
elif dataset_name == DATASET_MNIST_NAME:
|
||||
self.image_mode = 'L'
|
||||
image_channels = 1
|
||||
|
||||
self.data_files = data_files
|
||||
self.shape = len(data_files), IMAGE_WIDTH, IMAGE_HEIGHT, image_channels
|
||||
|
||||
def get_batches(self, batch_size):
|
||||
"""
|
||||
Generate batches
|
||||
:param batch_size: Batch Size
|
||||
:return: Batches of data
|
||||
"""
|
||||
IMAGE_MAX_VALUE = 255
|
||||
|
||||
current_index = 0
|
||||
while current_index + batch_size <= self.shape[0]:
|
||||
data_batch = get_batch(
|
||||
self.data_files[current_index:current_index + batch_size],
|
||||
*self.shape[1:3],
|
||||
self.image_mode)
|
||||
|
||||
current_index += batch_size
|
||||
|
||||
yield data_batch / IMAGE_MAX_VALUE - 0.5
|
||||
|
||||
|
||||
class DLProgress(tqdm):
|
||||
"""
|
||||
Handle Progress Bar while Downloading
|
||||
"""
|
||||
last_block = 0
|
||||
|
||||
def hook(self, block_num=1, block_size=1, total_size=None):
|
||||
"""
|
||||
A hook function that will be called once on establishment of the network connection and
|
||||
once after each block read thereafter.
|
||||
:param block_num: A count of blocks transferred so far
|
||||
:param block_size: Block size in bytes
|
||||
:param total_size: The total size of the file. This may be -1 on older FTP servers which do not return
|
||||
a file size in response to a retrieval request.
|
||||
"""
|
||||
self.total = total_size
|
||||
self.update((block_num - self.last_block) * block_size)
|
||||
self.last_block = block_num
|
@ -0,0 +1,151 @@
|
||||
from copy import deepcopy
|
||||
from unittest import mock
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
def test_safe(func):
|
||||
"""
|
||||
Isolate tests
|
||||
"""
|
||||
def func_wrapper(*args):
|
||||
with tf.Graph().as_default():
|
||||
result = func(*args)
|
||||
print('Tests Passed')
|
||||
return result
|
||||
|
||||
return func_wrapper
|
||||
|
||||
|
||||
def _assert_tensor_shape(tensor, shape, display_name):
|
||||
assert tf.assert_rank(tensor, len(shape), message='{} has wrong rank'.format(display_name))
|
||||
|
||||
tensor_shape = tensor.get_shape().as_list() if len(shape) else []
|
||||
|
||||
wrong_dimension = [ten_dim for ten_dim, cor_dim in zip(tensor_shape, shape)
|
||||
if cor_dim is not None and ten_dim != cor_dim]
|
||||
assert not wrong_dimension, \
|
||||
'{} has wrong shape. Found {}'.format(display_name, tensor_shape)
|
||||
|
||||
|
||||
def _check_input(tensor, shape, display_name, tf_name=None):
|
||||
assert tensor.op.type == 'Placeholder', \
|
||||
'{} is not a Placeholder.'.format(display_name)
|
||||
|
||||
_assert_tensor_shape(tensor, shape, 'Real Input')
|
||||
|
||||
if tf_name:
|
||||
assert tensor.name == tf_name, \
|
||||
'{} has bad name. Found name {}'.format(display_name, tensor.name)
|
||||
|
||||
|
||||
class TmpMock():
|
||||
"""
|
||||
Mock a attribute. Restore attribute when exiting scope.
|
||||
"""
|
||||
def __init__(self, module, attrib_name):
|
||||
self.original_attrib = deepcopy(getattr(module, attrib_name))
|
||||
setattr(module, attrib_name, mock.MagicMock())
|
||||
self.module = module
|
||||
self.attrib_name = attrib_name
|
||||
|
||||
def __enter__(self):
|
||||
return getattr(self.module, self.attrib_name)
|
||||
|
||||
def __exit__(self, type, value, traceback):
|
||||
setattr(self.module, self.attrib_name, self.original_attrib)
|
||||
|
||||
|
||||
@test_safe
|
||||
def test_model_inputs(model_inputs):
|
||||
image_width = 28
|
||||
image_height = 28
|
||||
image_channels = 3
|
||||
z_dim = 100
|
||||
input_real, input_z, learn_rate = model_inputs(image_width, image_height, image_channels, z_dim)
|
||||
|
||||
_check_input(input_real, [None, image_width, image_height, image_channels], 'Real Input')
|
||||
_check_input(input_z, [None, z_dim], 'Z Input')
|
||||
_check_input(learn_rate, [], 'Learning Rate')
|
||||
|
||||
|
||||
@test_safe
|
||||
def test_discriminator(discriminator, tf_module):
|
||||
with TmpMock(tf_module, 'variable_scope') as mock_variable_scope:
|
||||
image = tf.placeholder(tf.float32, [None, 28, 28, 3])
|
||||
|
||||
output, logits = discriminator(image)
|
||||
_assert_tensor_shape(output, [None, 1], 'Discriminator Training(reuse=false) output')
|
||||
_assert_tensor_shape(logits, [None, 1], 'Discriminator Training(reuse=false) Logits')
|
||||
assert mock_variable_scope.called,\
|
||||
'tf.variable_scope not called in Discriminator Training(reuse=false)'
|
||||
assert mock_variable_scope.call_args == mock.call('discriminator', reuse=False), \
|
||||
'tf.variable_scope called with wrong arguments in Discriminator Training(reuse=false)'
|
||||
|
||||
mock_variable_scope.reset_mock()
|
||||
|
||||
output_reuse, logits_reuse = discriminator(image, True)
|
||||
_assert_tensor_shape(output_reuse, [None, 1], 'Discriminator Inference(reuse=True) output')
|
||||
_assert_tensor_shape(logits_reuse, [None, 1], 'Discriminator Inference(reuse=True) Logits')
|
||||
assert mock_variable_scope.called, \
|
||||
'tf.variable_scope not called in Discriminator Inference(reuse=True)'
|
||||
assert mock_variable_scope.call_args == mock.call('discriminator', reuse=True), \
|
||||
'tf.variable_scope called with wrong arguments in Discriminator Inference(reuse=True)'
|
||||
|
||||
|
||||
@test_safe
|
||||
def test_generator(generator, tf_module):
|
||||
with TmpMock(tf_module, 'variable_scope') as mock_variable_scope:
|
||||
z = tf.placeholder(tf.float32, [None, 100])
|
||||
out_channel_dim = 5
|
||||
|
||||
output = generator(z, out_channel_dim)
|
||||
_assert_tensor_shape(output, [None, 28, 28, out_channel_dim], 'Generator output (is_train=True)')
|
||||
assert mock_variable_scope.called, \
|
||||
'tf.variable_scope not called in Generator Training(reuse=false)'
|
||||
assert mock_variable_scope.call_args == mock.call('generator', reuse=False), \
|
||||
'tf.variable_scope called with wrong arguments in Generator Training(reuse=false)'
|
||||
|
||||
mock_variable_scope.reset_mock()
|
||||
output = generator(z, out_channel_dim, False)
|
||||
_assert_tensor_shape(output, [None, 28, 28, out_channel_dim], 'Generator output (is_train=False)')
|
||||
assert mock_variable_scope.called, \
|
||||
'tf.variable_scope not called in Generator Inference(reuse=True)'
|
||||
assert mock_variable_scope.call_args == mock.call('generator', reuse=True), \
|
||||
'tf.variable_scope called with wrong arguments in Generator Inference(reuse=True)'
|
||||
|
||||
|
||||
@test_safe
|
||||
def test_model_loss(model_loss):
|
||||
out_channel_dim = 4
|
||||
input_real = tf.placeholder(tf.float32, [None, 28, 28, out_channel_dim])
|
||||
input_z = tf.placeholder(tf.float32, [None, 100])
|
||||
|
||||
d_loss, g_loss = model_loss(input_real, input_z, out_channel_dim)
|
||||
|
||||
_assert_tensor_shape(d_loss, [], 'Discriminator Loss')
|
||||
_assert_tensor_shape(d_loss, [], 'Generator Loss')
|
||||
|
||||
|
||||
@test_safe
|
||||
def test_model_opt(model_opt, tf_module):
|
||||
with TmpMock(tf_module, 'trainable_variables') as mock_trainable_variables:
|
||||
with tf.variable_scope('discriminator'):
|
||||
discriminator_logits = tf.Variable(tf.zeros([3, 3]))
|
||||
with tf.variable_scope('generator'):
|
||||
generator_logits = tf.Variable(tf.zeros([3, 3]))
|
||||
|
||||
mock_trainable_variables.return_value = [discriminator_logits, generator_logits]
|
||||
d_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
|
||||
logits=discriminator_logits,
|
||||
labels=[[0.0, 0.0, 1.0], [0.0, 1.0, 0.0], [1.0, 0.0, 0.0]]))
|
||||
g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
|
||||
logits=generator_logits,
|
||||
labels=[[0.0, 0.0, 1.0], [0.0, 1.0, 0.0], [1.0, 0.0, 0.0]]))
|
||||
learning_rate = 0.001
|
||||
beta1 = 0.9
|
||||
|
||||
d_train_opt, g_train_opt = model_opt(d_loss, g_loss, learning_rate, beta1)
|
||||
assert mock_trainable_variables.called,\
|
||||
'tf.mock_trainable_variables not called'
|
||||
|
||||
|
Loading…
Reference in New Issue