Training a GAN: Ensuring Stable Convergence
(or all of the things I did wrong during training)
26th October 2022
MNIST is a dataset of 70,000 images of handwritten digits, a sample of which is shown in Figure 1. They are 28x28 pixel greyscale images, derived from NIST datasets of digits written by US Census Bureau employees and high-school students. MNIST is convenient because it comes preprocessed: each digit is centred and correctly oriented. The dataset can be extended with augmenting transformations, but I did not do this in the work described below. The code used can be found on my GitHub here.
Training an MNIST classifier
Because the dataset is already preprocessed, writing an MNIST classifier is relatively simple. The classifier was trained on 60,000 of the images and tested on the remaining 10,000. Training was performed using Google Colab, although a CPU can also train a reasonable classifier for this problem in an acceptable amount of time. The network consisted mainly of just two convolutional blocks followed by two linear layers.
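The post does not list the layer hyperparameters, so the sketch below assumes a common, hypothetical layout: each convolutional block is a 3x3 convolution with padding 1 followed by 2x2 max pooling, with 32 and then 64 channels. It traces the tensor shape through the two blocks using the standard output-size formula, which is what fixes the input size of the first linear layer.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a convolution or pooling layer (no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1


def classifier_shapes(size: int = 28, channels=(32, 64)) -> list:
    """Trace (channels, height, width) through two hypothetical conv blocks."""
    shapes = [(1, size, size)]
    for ch in channels:
        size = conv_out(size, kernel=3, padding=1)  # 3x3 conv, padding 1: size unchanged
        size = conv_out(size, kernel=2, stride=2)   # 2x2 max pool: size halved
        shapes.append((ch, size, size))
    return shapes


shapes = classifier_shapes()
c, h, w = shapes[-1]
print(shapes)     # [(1, 28, 28), (32, 14, 14), (64, 7, 7)]
print(c * h * w)  # 3136 flattened features feed the first linear layer
```

Under these assumptions the two linear layers would map 3,136 features down to the 10 digit classes; different kernel or pooling choices change only the flattened size.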
Figure 1: A sample of the MNIST dataset
The initial training of this classifier is shown in Figure 2, which plots the training loss alongside classification accuracy on the test set. The model quickly exceeds 90% accuracy, reaching 97.6% after two epochs. Figure 3 shows a random sample of the model's mistakes at this point. The errors show that it is not yet classifying at a human level, but some of them do highlight genuine ambiguities in the handwritten digits.
Figure 2: Training convergence and test set performance after 2 epochs (936 batches)
Figure 3: Sample of incorrect classifications after 2 training epochs
Figure 4: Classifier loss over 100 epochs
To see how good this simple classifier could get, I trained it for 100 epochs, which took 213 s on a single GPU. The loss is shown in Figure 4: the additional training gives a steady improvement, although the loss varies a lot from batch to batch. The final model achieved 99.7% classification accuracy on the test set, meaning only around 30 of the 10,000 test images were misclassified. A sample is shown in Figure 5. The errors are now a mix of clear mistakes and genuinely unclear digits, and the classifier could arguably be described as performing at a human level. However, it is likely to make different kinds of mistakes from a human.
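Converting between a test accuracy and a number of misclassified images is simple arithmetic. The helpers below (hypothetical names) round to the nearest whole image for a 10,000-image test set.

```python
def misclassified(accuracy: float, n_test: int = 10_000) -> int:
    """Number of wrong predictions implied by a test-set accuracy."""
    return round((1 - accuracy) * n_test)


def accuracy_from_errors(errors: int, n_test: int = 10_000) -> float:
    """Test-set accuracy implied by a count of wrong predictions."""
    return 1 - errors / n_test


print(misclassified(0.976))                 # 240 errors at 97.6% (after 2 epochs)
print(misclassified(0.997))                 # 30 errors at 99.7% (after 100 epochs)
print(f"{accuracy_from_errors(15):.2%}")    # 99.85%
```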
Figure 5: A selection of the incorrect classifications after 100 epochs.
Training an MNIST GAN
A Generative Adversarial Network (GAN), first proposed by Goodfellow et al. in 2014, is actually composed of two networks: a generator and a discriminator. The generator tries to create images representative of the dataset without ever being shown them directly. Instead, its output is fed into the discriminator, which tries to identify whether a given image is real or generated, and the generator learns from the discriminator's judgement. The discriminator's score on the fake images is the negative of the generator's score, which puts the two networks in direct competition, and they are trained together. Once the generator is sufficiently trained, its output can be used to create new images that should be representative of the training dataset.
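For reference, Goodfellow et al. (2014) write this competition as a minimax game over a value function V, where D(x) is the discriminator's estimated probability that x is real and G(z) is the image generated from a noise vector z. The discriminator maximises V while the generator minimises it, so the generator is scored by exactly the negative of the discriminator's term on the fake images. This is the original formulation; the post does not say which GAN loss variant the code uses.

```latex
\min_G \max_D V(D, G)
  = \mathbb{E}_{x \sim p_{\text{data}}}\left[\log D(x)\right]
  + \mathbb{E}_{z \sim p_z}\left[\log\left(1 - D(G(z))\right)\right]
```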
GANs are notoriously difficult to train, so below I outline some of the mistakes I made while training mine, and how I fixed them. The generator was built from 4 convolutional layers that turn a 1x100 noise vector into a 28x28 image. The discriminator was three convolutional layers that take a 28x28 image and output a single number: the higher the number, the more "real" it judges the image to be.
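The post gives only the layer counts, so here is one hypothetical way to get from 100 noise values to a 28x28 image in four layers: treat the noise as a 100-channel 1x1 image and upsample it with transposed convolutions. The sketch applies the transposed-convolution output-size formula (as documented for PyTorch's ConvTranspose2d, with dilation 1) to an assumed set of kernel sizes, strides and paddings, to check that they land exactly on 28.

```python
def conv_transpose_out(size, kernel, stride=1, padding=0, output_padding=0):
    """Spatial output size of a transposed convolution (dilation 1)."""
    return (size - 1) * stride - 2 * padding + kernel + output_padding


# Hypothetical generator: (kernel, stride, padding) for each of the 4 layers.
GENERATOR_LAYERS = [(7, 1, 0), (4, 2, 1), (4, 2, 1), (3, 1, 1)]

size = 1  # the 1x100 noise vector is reshaped to 100 channels of 1x1
sizes = [size]
for kernel, stride, padding in GENERATOR_LAYERS:
    size = conv_transpose_out(size, kernel, stride, padding)
    sizes.append(size)

print(sizes)  # [1, 7, 14, 28, 28]
```

The stride-2 layers double the resolution (7 to 14 to 28), while the stride-1 layers set the starting size and refine the final image without resizing it.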
Figure 6: The first training scores of my GAN
First Steps
Figure 6 shows the initial results of training. Essentially, the discriminator got too good too quickly: the gradients passed back through it became too small for the generator to learn from, and training settled into a poor equilibrium in which the generator just output noise. Fixing this is a matter of tuning the learning rate, the number of discriminator training iterations per generator training iteration, and the model architectures themselves.
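The vanishing-gradient effect is easiest to see with the original minimax generator loss, log(1 - D(G(z))), assuming D is a sigmoid applied to a discriminator logit a. Its derivative with respect to a is -sigmoid(a), which shrinks towards zero once the discriminator confidently rejects the fakes (very negative a). The "non-saturating" loss -log D(G(z)), also suggested by Goodfellow et al., has derivative -(1 - sigmoid(a)), which stays near -1 in the same regime. This is a sketch of the general effect, not of the exact loss used in my code.

```python
import math


def sigmoid(a: float) -> float:
    return 1.0 / (1.0 + math.exp(-a))


def minimax_grad(a: float) -> float:
    """d/da of log(1 - sigmoid(a)): the original (saturating) generator loss."""
    return -sigmoid(a)


def non_saturating_grad(a: float) -> float:
    """d/da of -log(sigmoid(a)): the non-saturating generator loss."""
    return -(1.0 - sigmoid(a))


# a is the discriminator's logit on a fake image; very negative means "confidently fake".
for a in (0.0, -2.0, -5.0, -10.0):
    print(f"logit {a:6.1f}  D(G(z))={sigmoid(a):.5f}  "
          f"minimax grad={minimax_grad(a):.5f}  "
          f"non-saturating grad={non_saturating_grad(a):.5f}")
```

As the logit falls from 0 to -10, the minimax gradient collapses to about -0.00005 while the non-saturating one approaches -1, which is why a discriminator that wins too early can stall the generator.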
After some tuning I achieved the output shown in Figure 7. Although this is an improvement on noise, the output appears to be saturated. Adding batch normalisation to the generator layers was enough to produce some periodic geometries (Figure 8). This is to be expected from a convolutional network, since each convolution examines the image in small local patches. At this point, after experimenting a little with the batch size, the next step was to scale up the training time!
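For reference, batch normalisation standardises each channel using the mean and variance of the current mini-batch, then applies a learnable scale (gamma) and shift (beta). The plain-Python sketch below shows the training-mode computation for a single channel, with made-up activation values. A framework layer such as PyTorch's BatchNorm2d does this per channel across the batch and spatial dimensions, and also keeps running statistics for evaluation, which is omitted here.

```python
import math


def batch_norm(values, gamma=1.0, beta=0.0, eps=1e-5):
    """Training-mode batch norm for one channel: standardise, then scale and shift."""
    n = len(values)
    mean = sum(values) / n
    var = sum((v - mean) ** 2 for v in values) / n  # biased variance, as used in training mode
    return [gamma * (v - mean) / math.sqrt(var + eps) + beta for v in values]


activations = [0.5, 3.0, 10.0, -2.0, 7.5]  # made-up values for one channel
normed = batch_norm(activations)
print([round(x, 3) for x in normed])
print(sum(normed) / len(normed))  # close to zero after normalisation
```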
Figure 7: Early generated images
Figure 8: Addition of batch normalisation in generator layers
Using Scale
Is scale always the answer? In this case, it got me a lot of the way there. Figure 9 shows the generator producing what are approaching recognisable digits. However, the colours are oddly muted, and there are strange artifacts along the right and bottom edges. Figure 10 shows the losses of both networks reaching a plateau at this point.
Figure 9: 35 training epochs... in gif form!
Figure 10: Loss of both networks over 35 training epochs.
Normalise Your Inputs...
In hindsight this next step is obvious. A generator with a tanh output always produces pixel values in the range [-1, 1]. However, the MNIST pixel values are in the range [0, 255], and I was normalising them to [0, 1]. The discriminator was therefore seeing different distributions for the real and fake images, which prevented convergence. To demonstrate the effect, Figure 11 shows the outputs of three networks that differ only in their input normalisation, applied as follows:
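import torch
data = torch.tensor(data).float()                # raw pixel values in [0, 255]
data = torch.tensor(data).float() / 255          # scaled to [0, 1]
data = (torch.tensor(data).float() - 128) / 128  # roughly [-1, 1], matching tanh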
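The three options map the same pixel range onto quite different intervals, and only a range matching the generator's tanh output gives the discriminator nothing trivial to exploit. The plain-Python sketch below (hypothetical names) evaluates each mapping at the darkest and brightest pixel values. It also includes a common variant, (x - 127.5) / 127.5, which reaches -1 and 1 exactly; this variant was not one of the runs in Figure 11.

```python
# Each option written as a function of one pixel value in [0, 255].
NORMALISATIONS = {
    "raw": lambda x: float(x),                          # [0, 255]
    "divide by 255": lambda x: x / 255,                 # [0, 1]
    "centre on 128": lambda x: (x - 128) / 128,         # [-1, 0.9921875]
    "centre on 127.5": lambda x: (x - 127.5) / 127.5,   # [-1, 1]; variant, not from the post
}

for name, f in NORMALISATIONS.items():
    print(f"{name:>15}: [{f(0):.4f}, {f(255):.4f}]")
```

Dividing by 255 leaves the negative half of the tanh range unused by real images, so any negative pixel immediately marks an image as fake.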
Figure 11: The effects of input normalisation on model outputs
Padding
The correct input scaling gets us to the third image in Figure 11, and more training makes the digits more distinct. However, the artifacts along the right and bottom edges remain. It seemed that the discriminator could not see these areas of the generated images, so the generator received no feedback on them. My solution was to add padding to the convolutional layers of the discriminator. Padding adds a border (of zeros, by default) around each layer's input, so that the convolution windows extend past the edges of the image. Figures 12 and 13 show the effect of 1 and 2 pixel padding on the first and second convolutional layers. This successfully removes the edge distortion from the generated images.
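A likely mechanism, assumed here because the post does not list the discriminator's kernel sizes or strides, is that a strided convolution without padding can skip the last row and column: the number of output positions is rounded down, so the final window may stop short of the edge. A pixel that no window touches has no effect on the discriminator's output, so the generator gets no gradient for it. The sketch lists, along one axis, the input pixels no kernel position covers. Height and width behave independently, which would explain why only the right and bottom edges were affected.

```python
def uncovered_pixels(size: int, kernel: int, stride: int, padding: int = 0) -> list:
    """Input indices (along one axis) that no convolution window touches."""
    n_windows = (size + 2 * padding - kernel) // stride + 1  # rounded down, as in PyTorch
    covered = set()
    for w in range(n_windows):
        start = w * stride - padding  # window start in unpadded coordinates
        covered.update(range(max(start, 0), min(start + kernel, size)))
    return [p for p in range(size) if p not in covered]


# Hypothetical first discriminator layer on a 28-pixel axis.
print(uncovered_pixels(28, kernel=3, stride=2))             # [27]
print(uncovered_pixels(28, kernel=3, stride=2, padding=1))  # []
print(uncovered_pixels(28, kernel=5, stride=2))             # [27]
print(uncovered_pixels(28, kernel=4, stride=2))             # []
```

Whether an edge is dropped depends on how kernel and stride divide the input size, so padding is one fix; choosing a kernel and stride that tile the image exactly is another.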
Figure 12: Adding padding of 1 pixel
Figure 13: Padding of 2 pixels
GAN Outputs
With all of these modifications I can now train a GAN that generates recognisable digits, without needing any tuning to remove degenerate solutions. Figure 14 shows an animation of the training: the model converges and then remains relatively stable. Looking at the raw outputs, the discriminator scores both real and fake images at around -2.4, showing that it is unable to tell the two apart. To check that the generated digits are representative of the training data, Figure 15 shows the distribution of 10,000 generated digits, as classified by the trained classifier. The generation is roughly even across digits, with 3 generated the least and 7 the most. This is expected: if the generated distribution were noticeably uneven, the discriminator could improve its score simply by learning which digits are over-represented among the fakes.
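The evenness check behind Figure 15 can be sketched as counting the classifier's predictions on the generated images and comparing each digit's share with the 10% a uniform split would give (MNIST itself is only approximately uniform). The labels below are made up purely to exercise the code, and the total variation distance summary is an illustrative choice, not something from the post.

```python
from collections import Counter


def digit_shares(labels, n_classes: int = 10) -> dict:
    """Fraction of generated images the classifier assigned to each digit."""
    counts = Counter(labels)
    total = len(labels)
    return {d: counts.get(d, 0) / total for d in range(n_classes)}


def distance_from_uniform(shares: dict) -> float:
    """Total variation distance between the observed shares and a uniform split."""
    uniform = 1 / len(shares)
    return 0.5 * sum(abs(p - uniform) for p in shares.values())


# Made-up predictions standing in for the classifier's output on generated digits.
labels = [d for d in range(10) for _ in range(100)] + [7] * 50 + [1] * 20
shares = digit_shares(labels)
print(max(shares, key=shares.get))           # most frequently generated digit
print(round(distance_from_uniform(shares), 4))
```

A distance of 0 means a perfectly even split; tracking it over training is one way to spot the generator drifting towards a few favourite digits.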
Figure 14: Output of the GAN after 150 epochs
Figure 15: Classified distribution of generated digits.