In this lab, we introduce an unsupervised learning model: the generative adversarial network (GAN).
A GAN has two main components: a generator and a discriminator. The discriminator tries to distinguish real data from generated data, while the generator tries to produce realistic data that fools the discriminator. The training process alternates between optimizing the discriminator and optimizing the generator. As long as the discriminator is smart enough, it can guide the generator toward the manifold of the real data.
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # disable warnings and info
import tensorflow as tf
import tensorflow.keras as keras
import imageio
import moviepy.editor as mpy
# --- Sample grid used to visualise generator output ---
SAMPLE_COL = 16
SAMPLE_ROW = 16
SAMPLE_NUM = SAMPLE_COL * SAMPLE_ROW  # number of images in one sample grid

# --- Image geometry (height, width, channels) ---
IMG_H = 28
IMG_W = 28
IMG_C = 1
IMG_SHAPE = (IMG_H, IMG_W, IMG_C)

# --- Data pipeline / latent space ---
BATCH_SIZE = 5000
Z_DIM = 128  # length of the latent vector fed to the generator
BZ = (BATCH_SIZE, Z_DIM)  # shape of the latent batch drawn in each training step
BUF = 65536  # shuffle buffer size of the training dataset

# --- DCGAN hyper-parameters ---
DC_LR = 2.5e-04
DC_EPOCH = 256

# --- WGAN hyper-parameters ---
W_LR = 2.0e-04
W_EPOCH = 256
WClipLo = -0.01  # lower bound used when clipping the critic weights
WClipHi = 0.01  # upper bound used when clipping the critic weights

# --- GPU configuration ---
gpus = tf.config.experimental.list_physical_devices('GPU')
# Only configure the first GPU when one is actually present; indexing an
# empty device list would otherwise abort the notebook on CPU-only machines.
if gpus:
    tf.config.experimental.set_memory_growth(gpus[0], True)
    # Cap the memory TensorFlow may use on this GPU (memory_limit is given
    # in MB according to the TensorFlow documentation).
    tf.config.experimental.set_virtual_device_configuration(
        gpus[0],
        [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=10000)],
    )
DCGAN is short for Deep Convolutional Generative Adversarial Network. The DCGAN paper performs well on image tasks; its architecture improves training stability and the quality of the generated samples. In this lab, we modify the DCGAN code and demonstrate the training of DCGAN on the MNIST dataset.
Some suggestions from the DCGAN paper:
# Load the MNIST images; the class labels are discarded because GAN training
# is unsupervised.
(train_images, _), (test_images, _) = tf.keras.datasets.mnist.load_data()
# Add the trailing channel axis. The target shape comes from IMG_SHAPE so the
# data pipeline stays consistent with the shape the models are built with.
iTrain = train_images.reshape((-1,) + IMG_SHAPE).astype(np.float32)
# Normalizing the images to the range of [0., 1.]
iTrain = iTrain / 255.0
# Shuffle, then cut into fixed-size batches; the incomplete last batch is
# dropped so every batch matches the latent batch shape BZ.
dsTrain = (
    tf.data.Dataset.from_tensor_slices(iTrain)
    .shuffle(BUF)
    .batch(BATCH_SIZE, drop_remainder=True)
)
# Utility function
def utPuzzle(imgs, row, col, path=None):
    """Tile images into a single grid image of `row` x `col` cells.

    imgs: sequence of equally shaped (h, w, c) arrays; the cell size is taken
          from imgs[0]. Images are placed left-to-right, top-to-bottom.
    row:  number of grid rows.
    col:  number of grid columns.
    path: optional file path; when given, the grid is also written there
          with imageio.

    Returns the grid as a uint8 array of shape (h * row, w * col, c). Cells
    without a corresponding image stay zero.
    """
    tile_h, tile_w, channels = imgs[0].shape
    canvas = np.zeros((tile_h * row, tile_w * col, channels), np.uint8)
    for index, tile in enumerate(imgs):
        # Position of this tile inside the grid.
        grid_row = index // col
        grid_col = index % col
        top = grid_row * tile_h
        left = grid_col * tile_w
        canvas[top:top + tile_h, left:left + tile_w, :] = tile
    if path is not None:
        imageio.imwrite(path, canvas)
    return canvas
def utMakeGif(imgs, fname, duration):
    """Write a sequence of frames to an animated GIF.

    imgs:     indexable sequence of frames (one image per frame).
    fname:    path of the GIF file to write.
    duration: total length of the animation in seconds; the frames are
              spread evenly over it.
    """
    last = len(imgs) - 1
    n = float(len(imgs)) / duration  # frames per second

    def frame_at(t):
        # Clamp the index: a timestamp at (or rounded up to) `duration`
        # would otherwise map to len(imgs) and raise IndexError.
        return imgs[min(int(n * t), last)]

    clip = mpy.VideoClip(frame_at, duration=duration)
    clip.write_gif(fname, fps=n)
def GAN(img_shape, z_dim):
    """Build a fresh generator / discriminator pair.

    img_shape: (height, width, channels) of the images.
    z_dim:     length of the latent vector consumed by the generator.

    Returns a tuple (generator, discriminator) of keras.Sequential models.
    The generator ends in a sigmoid; the discriminator ends in a single
    linear unit (no activation).
    """
    height, width, channels = img_shape
    # The generator upsamples twice with stride 2, so it starts from a
    # feature map a quarter of the image size in each spatial dimension.
    seed_h = height // 4
    seed_w = width // 4
    layers = keras.layers

    generator = keras.Sequential([
        layers.Dense(units=1024, input_shape=(z_dim,)),
        layers.BatchNormalization(),
        layers.ReLU(),
        layers.Dense(units=seed_h * seed_w * 256),
        layers.BatchNormalization(),
        layers.ReLU(),
        layers.Reshape(target_shape=(seed_h, seed_w, 256)),
        layers.Conv2DTranspose(
            filters=32,
            kernel_size=5,
            strides=2,
            padding="SAME",
        ),
        layers.BatchNormalization(),
        layers.ReLU(),
        layers.Conv2DTranspose(
            filters=channels,
            kernel_size=5,
            strides=2,
            padding="SAME",
            activation=keras.activations.sigmoid,
        ),
    ])

    discriminator = keras.Sequential([
        layers.Conv2D(
            filters=32,
            kernel_size=5,
            strides=(2, 2),
            padding="SAME",
            input_shape=img_shape,
        ),
        layers.LeakyReLU(),
        layers.Conv2D(
            filters=128,
            kernel_size=5,
            strides=(2, 2),
            padding="SAME",
        ),
        layers.BatchNormalization(),
        layers.LeakyReLU(),
        layers.Flatten(),
        layers.Dense(units=1024),
        layers.BatchNormalization(),
        layers.LeakyReLU(),
        layers.Dense(units=1),
    ])

    return generator, discriminator
# Fixed latent vectors: the same ones are fed to the generator after every
# epoch so that the sample grids are comparable over time.
s = tf.random.normal(shape=[SAMPLE_NUM, Z_DIM])
# DCGAN generator and discriminator.
DC_G, DC_D = GAN(img_shape=IMG_SHAPE, z_dim=Z_DIM)
# Separate optimizers, one per network.
optimizer_g = keras.optimizers.Adam(learning_rate=DC_LR)
optimizer_d = keras.optimizers.Adam(learning_rate=DC_LR)
# The discriminator outputs raw logits, hence from_logits=True.
cross_entropy = keras.losses.BinaryCrossentropy(from_logits=True)
def DC_G_Loss(c0):
    """Generator loss: cross-entropy of the fake logits against label 1.

    c0: logits of fake images
    """
    target = tf.ones_like(c0)
    return cross_entropy(target, c0)
def DC_D_Loss(c0, c1):
    """Discriminator loss: real logits against label 1 plus fake logits against label 0.

    c0: logits of fake images
    c1: logits of real images
    """
    real_term = cross_entropy(tf.ones_like(c1), c1)
    fake_term = cross_entropy(tf.zeros_like(c0), c0)
    return real_term + fake_term
@tf.function
def DC_D_Train(c1):
    """Run one discriminator update of the DCGAN.

    c1: batch of real images.

    Returns (generator loss, discriminator loss) for this step; only the
    discriminator's variables are updated.
    """
    noise = tf.random.normal(BZ)
    with tf.GradientTape() as tape:
        fake_images = DC_G(noise, training=True)
        fake_logits = DC_D(fake_images, training=True)
        real_logits = DC_D(c1, training=True)
        lg = DC_G_Loss(fake_logits)
        ld = DC_D_Loss(fake_logits, real_logits)
    d_grads = tape.gradient(ld, DC_D.trainable_variables)
    optimizer_d.apply_gradients(zip(d_grads, DC_D.trainable_variables))
    return lg, ld
@tf.function
def DC_G_Train(c1):
    """Run one generator update of the DCGAN.

    c1: batch of real images (used only to report the discriminator loss).

    Returns (generator loss, discriminator loss) for this step; only the
    generator's variables are updated.
    """
    noise = tf.random.normal(BZ)
    with tf.GradientTape() as tape:
        fake_images = DC_G(noise, training=True)
        real_logits = DC_D(c1, training=True)
        fake_logits = DC_D(fake_images, training=True)
        lg = DC_G_Loss(fake_logits)
        ld = DC_D_Loss(fake_logits, real_logits)
    g_grads = tape.gradient(lg, DC_G.trainable_variables)
    optimizer_g.apply_gradients(zip(g_grads, DC_G.trainable_variables))
    return lg, ld
# One cycle of training steps: five discriminator updates followed by one
# generator update (ratio D:G = 5:1).
DCTrain = (DC_D_Train,) * 5 + (DC_G_Train,)
# Length of one cycle; the step counter wraps around at this value.
DCCritic = len(DCTrain)
Let's plot the generated images right after initialization. It is a good way to check whether there are any unexpected artifacts. In the case of DCGAN, we should see a checkerboard effect in the generated samples if we use fully convolutional layers; as mentioned in this blog post, such layers introduce checkerboard artifacts. If the training succeeds, this effect is largely reduced. The blog post replaces the strided deconvolution in the generator with upsampling. This cancels out the checkerboard effect but gives blurrier results.
In the following cell, we also plot the original MNIST dataset.
# Sanity check of the untrained networks: show one generated image and the
# discriminator's logit for it.
print("Generator Initial Output :")
c0 = DC_G(tf.random.normal((1, Z_DIM)), training = False)
# The generator output lies in [0, 1] (sigmoid); rescale to 8-bit grayscale.
plt.imshow((c0[0, :, :, 0] * 255.0).numpy().astype(np.uint8), cmap = "gray")
plt.axis("off")
plt.show()
# The discriminator returns a (1, 1) array for a single image. Index the lone
# logit explicitly instead of relying on implicit array-to-scalar conversion,
# which NumPy has deprecated for arrays with ndim > 0.
print("Discriminator Initial Output : %E" % DC_D(c0).numpy()[0, 0])
Generator Initial Output :
Discriminator Initial Output : -8.492252E-03
dc_lg = [None] * DC_EPOCH  # record loss of g for each epoch
dc_ld = [None] * DC_EPOCH  # record loss of d for each epoch
dc_sp = [None] * DC_EPOCH  # record sample images for each epoch

# Factor that turns the summed per-batch losses of an epoch into a mean
# (batch size divided by the number of training images).
rsTrain = float(BATCH_SIZE) / float(len(iTrain))

# The sample images and the GIF are written below "imgs/"; create the
# directory up front so the first write does not fail when it is missing.
os.makedirs("imgs", exist_ok=True)

# Position inside the DCTrain cycle; it persists across epochs so the D:G
# update ratio is kept even when an epoch ends in the middle of a cycle.
ctr = 0
for ep in range(DC_EPOCH):
    loss_g_t = 0.0
    loss_d_t = 0.0
    for batch in dsTrain:
        loss_g, loss_d = DCTrain[ctr](batch)
        ctr += 1
        loss_g_t += loss_g.numpy()
        loss_d_t += loss_d.numpy()
        if ctr == DCCritic:
            ctr = 0
    dc_lg[ep] = loss_g_t * rsTrain
    dc_ld[ep] = loss_d_t * rsTrain

    # Render the fixed latent vectors `s` as one grid image for this epoch.
    out = DC_G(s, training = False)
    img = utPuzzle(
        (out * 255.0).numpy().astype(np.uint8),
        SAMPLE_ROW,  # utPuzzle expects (row, col) in this order
        SAMPLE_COL,
        "imgs/dc_%04d.png" % ep
    )
    dc_sp[ep] = img
    # Show the sample grid every 32 epochs.
    if (ep + 1) % 32 == 0:
        plt.imshow(img[..., 0], cmap = "gray")
        plt.axis("off")
        plt.title("Epoch %d" % ep)
        plt.show()

# Animate the per-epoch sample grids.
utMakeGif(np.array(dc_sp), "imgs/dcgan.gif", duration = 2)
MoviePy - Building file imgs/dcgan.gif with imageio.
We plot the training losses of the discriminator and the generator. We cannot tell from the training loss whether the model has converged: both curves oscillate around certain levels, independently of the quality of the generated images. So in practice, we plot the generated samples to monitor the training process. Because of this inconvenience, several works proposed in 2017 tried to solve it.
# Plot the per-epoch DCGAN losses: discriminator first, then generator.
for _series, _color, _label in (
    (dc_ld, "blue", "Discriminator Loss"),
    (dc_lg, "red", "Generator Loss"),
):
    plt.plot(range(DC_EPOCH), _series, color=_color, label=_label)
plt.legend(loc="upper right")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.title("DCGAN Training Loss")
plt.show()
There are some theoretical deficiencies in the vanilla GAN. Wasserstein GAN (WGAN) was proposed to solve these problems. Apart from the original paper, this and this may help you understand the motivation of WGAN. We'll skip the theory in this tutorial and jump directly to the implementation. From an engineering perspective, the following are the modifications compared with the original GAN.
Details of the algorithm are shown below.
# WGAN generator and critic, built with the same architecture as the DCGAN.
WG, WD = GAN(img_shape=IMG_SHAPE, z_dim=Z_DIM)
# Fresh RMSprop optimizers; note that these rebind the optimizer_g /
# optimizer_d names used by the DCGAN section above.
optimizer_g = keras.optimizers.RMSprop(learning_rate=W_LR)
optimizer_d = keras.optimizers.RMSprop(learning_rate=W_LR)
@tf.function
def WGTrain(c1):
    """Run one generator update of the WGAN.

    c1: batch of real images (used only to report the critic loss).

    Returns (generator loss, critic loss) for this step; only the
    generator's variables are updated.
    """
    noise = tf.random.normal(BZ)
    with tf.GradientTape() as tape:
        fake_images = WG(noise, training=True)
        real_score = WD(c1, training=True)
        fake_score = WD(fake_images, training=True)
        fake_mean = tf.reduce_mean(fake_score)
        # Generator loss is the negated mean critic score of the fakes.
        lg = -fake_mean
        # Critic loss: mean score of fakes minus mean score of reals.
        ld = fake_mean - tf.reduce_mean(real_score)
    g_grads = tape.gradient(lg, WG.trainable_variables)
    optimizer_g.apply_gradients(zip(g_grads, WG.trainable_variables))
    return lg, ld
@tf.function
def WDTrain(c1):
    """Run one critic update of the WGAN, followed by weight clipping.

    c1: batch of real images.

    Returns (generator loss, critic loss) for this step; only the critic's
    variables are updated.
    """
    noise = tf.random.normal(BZ)
    with tf.GradientTape() as tape:
        fake_images = WG(noise, training=True)
        real_score = WD(c1, training=True)
        fake_score = WD(fake_images, training=True)
        fake_mean = tf.reduce_mean(fake_score)
        lg = -fake_mean
        # Critic loss: mean score of fakes minus mean score of reals.
        ld = fake_mean - tf.reduce_mean(real_score)
    d_grads = tape.gradient(ld, WD.trainable_variables)
    optimizer_d.apply_gradients(zip(d_grads, WD.trainable_variables))
    # Clip every trainable critic variable to [WClipLo, WClipHi] after the update.
    for weight in WD.trainable_variables:
        weight.assign(tf.clip_by_value(weight, WClipLo, WClipHi))
    return lg, ld
# One cycle of training steps: five critic updates followed by one
# generator update.
WTrain = (WDTrain,) * 5 + (WGTrain,)
# Length of one cycle; the step counter wraps around at this value.
WCritic = len(WTrain)
Then we train the WGAN and visualize the training as before.
wlg = [None] * W_EPOCH  # record loss of g for each epoch
wld = [None] * W_EPOCH  # record loss of d for each epoch
wsp = [None] * W_EPOCH  # record sample images for each epoch

# Factor that turns the summed per-batch losses of an epoch into a mean
# (batch size divided by the number of training images).
rsTrain = float(BATCH_SIZE) / float(len(iTrain))

# The sample images and the GIF are written below "imgs/"; create the
# directory up front so the first write does not fail when it is missing.
os.makedirs("imgs", exist_ok=True)

# Position inside the WTrain cycle; it persists across epochs so the
# critic:generator update ratio is kept even when an epoch ends mid-cycle.
ctr = 0
for ep in range(W_EPOCH):
    lgt = 0.0
    ldt = 0.0
    for c1 in dsTrain:
        lg, ld = WTrain[ctr](c1)
        ctr += 1
        lgt += lg.numpy()
        ldt += ld.numpy()
        if ctr == WCritic:
            ctr = 0
    wlg[ep] = lgt * rsTrain
    wld[ep] = ldt * rsTrain

    # Render the fixed latent vectors `s` as one grid image for this epoch.
    out = WG(s, training = False)
    img = utPuzzle(
        (out * 255.0).numpy().astype(np.uint8),
        SAMPLE_ROW,  # utPuzzle expects (row, col) in this order
        SAMPLE_COL,
        "imgs/w_%04d.png" % ep
    )
    wsp[ep] = img
    # Show the sample grid every 32 epochs.
    if (ep + 1) % 32 == 0:
        plt.imshow(img[..., 0], cmap = "gray")
        plt.axis("off")
        plt.title("Epoch %d" % ep)
        plt.show()

# Animate the per-epoch sample grids.
utMakeGif(np.array(wsp), "imgs/wgan.gif", duration = 2)
MoviePy - Building file imgs/wgan.gif with imageio.
# Plot the per-epoch WGAN losses: discriminator (critic) first, then generator.
for _series, _color, _label in (
    (wld, "blue", "Discriminator Loss"),
    (wlg, "red", "Generator Loss"),
):
    plt.plot(range(W_EPOCH), _series, color=_color, label=_label)
plt.legend(loc="upper right")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.title("WGAN Training Loss")
plt.show()
Although Wasserstein GAN (WGAN) made progress toward stable training of GANs, it still fails to converge in some settings. In this lab, you are required to implement Improved Wasserstein GAN, which is a milestone in GAN research.
We will show the training result of Improved WGAN below, which indicates that, compared to WGAN, Improved WGAN has a much better performance. It generates recognizable digits much faster during the training process.
Details of the algorithm are shown below.