Yet Another Generative Adversarial Network (GAN) Guide in Keras, with MNIST testing example

The Generative Adversarial Network (GAN) is a brilliant idea. By training a generative model and a classifier against each other, a GAN can eventually generate samples that fool the classifier and even humans.

We can start coding by writing a simple generator. It takes noise as input and returns a picture, which for MNIST has size 28*28*1. In theory the noise can have any dimension; however, if the dimension is too small, the generator cannot produce much variety (imagine a generator with a 1-dimensional noise input). Say the noise input has 64 dimensions, and each dimension is sampled uniformly from -1 to 1. In the next few layers, we reshape the noise to size 7*7*C (for some number of channels C), then upsample it until we have an output of size 28*28*C', and finally compress the output into 28*28*1 and apply a ‘tanh’ activation so that each pixel of the returned image lies in the range [-1, 1]. A simplified version of the generator (with 4 layers that convert [64] -> [7*7*128] -> [14*14*64] -> [28*28*32] -> [28*28*1]) looks like this:

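    import numpy as np
    from keras.layers import (Activation, BatchNormalization, Conv2D, Dense,
                              Dropout, Input, LeakyReLU, Reshape, UpSampling2D)
    from keras.models import Model

    # Hyperparameters (assumed values; the post does not list them)
    batch_norm = True
    batch_norm_momentum = 0.8
    relu_alpha = 0.2
    dropout = 0.3

    # 64 -> 7*7*128 -> 14*14*64 -> 28*28*32 -> 28*28*1
    # Specify the shape of output for generator (rows, columns, channels)
    height = 28
    width = 28
    scale = 4
    init_shape = (int(np.ceil(height / scale)),
                  int(np.ceil(width / scale)),
                  32 * scale)

    # Noise dimension is 64
    entrance = Input((64,))
    temp = entrance

    # Dense layer and reshape for image generation;
    # the Dense layer needs 7*7*128 units so that Reshape(init_shape) works
    temp = Dense(int(np.prod(init_shape)))(temp)
    temp = BatchNormalization(momentum=batch_norm_momentum)(temp)
    temp = LeakyReLU(alpha=relu_alpha)(temp)
    temp = Reshape(init_shape)(temp)
    temp = Dropout(dropout)(temp)

    # Use UpSampling2D and Conv2D for up sampling
    for layer in range(2):
        temp = UpSampling2D()(temp)
        # kernel_size is an assumption; padding='same' keeps the spatial size
        temp = Conv2D(
            filters=32 * 2 ** (1 - layer),
            kernel_size=5,
            padding='same')(temp)
        if batch_norm:
            temp = BatchNormalization(momentum=batch_norm_momentum)(temp)
        temp = LeakyReLU(alpha=relu_alpha)(temp)
        temp = Dropout(dropout)(temp)

    # Compress the layers into output shape (kernel_size is an assumption)
    temp = Conv2D(filters=1, kernel_size=5, padding='same')(temp)

    temp = Activation('tanh')(temp)
    generator = Model(entrance, temp)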
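To double-check the layer sizes quoted above, here is a small sketch that redoes the shape arithmetic in plain Python, without building any Keras layers. It assumes that every Conv2D in the generator keeps the spatial size (for example with padding='same'), which the text implies but does not state; the names `shapes` and `dense_units` are only for illustration.

```python
import math

height, width, scale = 28, 28, 4

# Same arithmetic as init_shape in the generator
init_shape = (math.ceil(height / scale), math.ceil(width / scale), 32 * scale)

shapes = [init_shape]
rows, cols, _ = init_shape
for layer in range(2):
    # UpSampling2D doubles rows and columns with its default size of (2, 2)
    rows, cols = rows * 2, cols * 2
    # Assumption: the Conv2D that follows keeps rows and cols unchanged
    filters = 32 * 2 ** (1 - layer)
    shapes.append((rows, cols, filters))

# Final Conv2D compresses the channels down to a single one
shapes.append((rows, cols, 1))

# Number of units the Dense layer needs so that the reshape is valid
dense_units = init_shape[0] * init_shape[1] * init_shape[2]

print(dense_units)  # 6272
print(shapes)       # [(7, 7, 128), (14, 14, 64), (28, 28, 32), (28, 28, 1)]
```

The 6272 units of the dense layer are simply 7*7*128, which is why the reshape to 7*7*128 goes through.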

Now that we are done with the generator, the discriminator (a classifier that tells generated images from real ones) should be even easier. It takes a 28*28*1 image and outputs the probability that the input is real ([28*28*1] -> [14*14*64] -> [7*7*128] -> [4*4*256] -> [4096] -> [1]). Note that the discriminator should not be too powerful; otherwise it will simply crush the generator before the generator even begins to learn, which results in an early failure.

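    from keras.layers import (BatchNormalization, Conv2D, Dense, Dropout,
                              Flatten, Input, LeakyReLU)
    from keras.models import Model

    entrance = Input((28, 28, 1))
    temp = entrance

    # Alternatively, could use average pooling for down sampling
    for layer in range(3):
        # strides=2 with padding='same' gives 28 -> 14 -> 7 -> 4;
        # kernel_size is an assumption
        temp = Conv2D(
            filters=64 * 2 ** layer,
            kernel_size=5,
            strides=2,
            padding='same')(temp)
        if batch_norm:
            temp = BatchNormalization(momentum=batch_norm_momentum)(temp)
        temp = LeakyReLU(alpha=relu_alpha)(temp)
        temp = Dropout(dropout)(temp)

    # Flatten the convolutional net and use sigmoid to classify the input
    temp = Flatten()(temp)
    temp = Dense(1, activation='sigmoid')(temp)
    discriminator = Model(entrance, temp)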
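The shape chain quoted for the discriminator can be checked the same way. This sketch assumes each Conv2D downsamples with a stride of 2 and 'same' padding, in which case the output size is the input size divided by the stride, rounded up. The post does not state the stride or padding, so treat both as assumptions; `same_conv_output` is a hypothetical helper.

```python
import math


def same_conv_output(size, stride):
    """Spatial output size of a convolution with 'same' padding."""
    return math.ceil(size / stride)


size, channels = 28, 1
trace = [(size, size, channels)]
for layer in range(3):
    size = same_conv_output(size, stride=2)  # assumed stride of 2
    channels = 64 * 2 ** layer               # 64, 128, 256
    trace.append((size, size, channels))

# Flatten turns the last feature map into a single vector
flattened = size * size * channels

print(trace)      # [(28, 28, 1), (14, 14, 64), (7, 7, 128), (4, 4, 256)]
print(flattened)  # 4096
```

The odd-looking 4*4 step comes from rounding 7 / 2 up, and 4*4*256 gives the 4096 features fed to the final sigmoid unit.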

Before we head to the training, we need two models ready: the discriminator model and the adversarial model. The first one is simply our discriminator, trained with a binary cross-entropy loss. The second one is built by stacking the generator and the discriminator together, with the weights of the discriminator frozen. The freezing matters because, when training the generator, we feed the adversarial model fake images with real labels, and we do not want to mess up the discriminator in this phase. The code looks like this:

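    # Build discriminator_model (attribute names and compile calls are reconstructed)
    self._discriminator_model = Sequential([self._discriminator])
    self._discriminator_model.compile(
        optimizer=Adam(lr=0.0002, decay=1e-06),
        loss='binary_crossentropy')

    # Build adversarial_model: freeze the discriminator before stacking
    self._discriminator.trainable = False
    self._adversarial_model = Sequential([self._generator, self._discriminator])
    self._adversarial_model.compile(
        optimizer=Adam(lr=0.0001, decay=1e-06),
        loss='binary_crossentropy')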
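To see why fake images are paired with real labels when training the generator, it helps to write out binary cross-entropy by hand. The NumPy sketch below is only an illustration: `binary_cross_entropy` is a hypothetical helper (it clips predictions to avoid log(0)), and the discriminator outputs are made-up numbers.

```python
import numpy as np


def binary_cross_entropy(labels, predictions, eps=1e-7):
    """Mean binary cross-entropy; predictions are clipped to avoid log(0)."""
    predictions = np.clip(predictions, eps, 1.0 - eps)
    losses = -(labels * np.log(predictions)
               + (1.0 - labels) * np.log(1.0 - predictions))
    return float(np.mean(losses))


# Made-up discriminator outputs on three fake images:
# low values mean the discriminator is confident that they are fake
d_out_on_fakes = np.array([0.1, 0.2, 0.05])

# Discriminator phase: fakes are labeled 0, so confident rejection is cheap
loss_as_fake = binary_cross_entropy(np.zeros(3), d_out_on_fakes)

# Generator phase: the same fakes are labeled 1, so the loss is
# mean(-log D(G(z))), which is large while the fakes are easy to spot
loss_as_real = binary_cross_entropy(np.ones(3), d_out_on_fakes)

print(loss_as_fake < loss_as_real)  # True
```

With the discriminator frozen, the only way to lower the second loss is to change the generator so that its images receive higher scores.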

Now that we have both models ready, let’s proceed to the training part. In each epoch, we feed the discriminator a batch of real images from MNIST, and then a batch of fake ones from our currently crappy generator (with labels 1 and 0 respectively). I have not tested this part much yet, but I have heard that mixing real and fake samples in the same batch can be bad for the discriminator. The training of the discriminator looks like this:

    # Select a random half batch of images
    indices = np.random.randint(
        0, training_sample_num, half_batch_size)
    half_batch_samples = self._training_samples[indices]

    # Generate a half batch of fake images from noise
    half_batch_noise = self.get_noise(half_batch_size)
    half_batch_mimics = self._generator.predict(half_batch_noise)

    # Train the discriminator model with half real and half fake samples
    # Do not mix the real images with the fake ones
    # (self._discriminator_model and train_on_batch are reconstructed assumptions)
    d_loss_samples = self._discriminator_model.train_on_batch(
        half_batch_samples, np.ones((half_batch_size, 1)))
    d_loss_mimics = self._discriminator_model.train_on_batch(
        half_batch_mimics, np.zeros((half_batch_size, 1)))
    d_loss = np.add(d_loss_samples, d_loss_mimics) / 2
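The snippet above calls a `get_noise` method that is not shown. A minimal standalone sketch of what it might look like, assuming the 64-dimensional uniform noise on [-1, 1) described earlier, is given below together with the two label arrays. The signature, the `rng` argument and the half batch size of 32 are assumptions for illustration, not the post's actual code.

```python
import numpy as np


def get_noise(sample_num, noise_dim=64, rng=None):
    """Return sample_num noise vectors with entries uniform in [-1, 1)."""
    rng = np.random.default_rng() if rng is None else rng
    return rng.uniform(-1.0, 1.0, size=(sample_num, noise_dim))


half_batch_size = 32  # assumed value for the example
half_batch_noise = get_noise(half_batch_size, rng=np.random.default_rng(0))

# Labels used in the discriminator phase: 1 for real images, 0 for fakes
real_labels = np.ones((half_batch_size, 1))
fake_labels = np.zeros((half_batch_size, 1))

print(half_batch_noise.shape)  # (32, 64)
```

Passing a seeded generator, as in the usage lines, makes the noise reproducible, which is handy when comparing runs.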

In the same epoch, we also train the adversarial model on generated images paired with real labels, which pushes the generator to return images that look more realistic in the eyes of the discriminator. The code is even simpler:

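    # Train the generator through the adversarial model with real labels
    # (self._adversarial_model and train_on_batch are reconstructed assumptions)
    batch_noise = self.get_noise(batch_size)
    g_loss = self._adversarial_model.train_on_batch(
        batch_noise, np.ones((batch_size, 1)))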

That’s pretty much everything we need to train a GAN on MNIST. Finally, we can pick a digit and feed its images into the GAN. Let’s pick the digit ‘5’ from MNIST for its asymmetric appearance. Here are the results:

These three images were obtained with uniform noise and look… okay, I guess. And here are some results using noise with a Wigner semicircle distribution:

The digits definitely look clearer and more realistic. I think Gaussian noise would also work well with this GAN. Feel free to check out all the code, including some potential optimizations, at CycleGAN.
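For readers who want to try the Wigner semicircle noise mentioned above, here is one possible way to sample it with NumPy. It relies on the fact that the x-coordinate of a point drawn uniformly from a disk of radius R follows a semicircle distribution on [-R, R]. This is a sketch under that assumption, not necessarily how the linked code does it, and `get_semicircle_noise` is a hypothetical name.

```python
import numpy as np


def get_semicircle_noise(sample_num, noise_dim=64, radius=1.0, rng=None):
    """Sample noise whose entries follow a Wigner semicircle on [-radius, radius]."""
    rng = np.random.default_rng() if rng is None else rng
    shape = (sample_num, noise_dim)
    # Uniform point in a disk: the radius scales with sqrt(U), the angle is uniform
    r = radius * np.sqrt(rng.uniform(0.0, 1.0, size=shape))
    theta = rng.uniform(0.0, 2.0 * np.pi, size=shape)
    # The x-coordinate of that point is semicircle-distributed
    return r * np.cos(theta)


noise = get_semicircle_noise(1000, rng=np.random.default_rng(0))
print(noise.shape)  # (1000, 64)
```

Compared with uniform noise on the same interval, the semicircle puts more mass near 0 and less near the edges; its variance is radius² / 4, versus radius² / 3 for the uniform case.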

