"Grokking": Generalisation Beyond Overfitting

An unusual result on small algorithmic datasets.

27th June 2023

Introduction & Motivation

This post explores the ability of overparameterised networks to find the general solution to small algorithmic datasets when trained for a long time. It is based on the work in Nanda et al (2023), which also has a very nice walkthrough (with code). See also Power et al (2022). This work explores the "how" of a neural network: by taking a well-understood problem and a small network, it examines how the network arrives at its solution. This is a subset of "mechanistic interpretability".

The problem chosen was addition modulo p. For a small p this is a nice problem because the entire problem space can be enumerated and split between the train and test sets (for example, with p=113 there are ~12,700 ordered sum pairs).
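As a concrete illustration, here is a minimal sketch (not necessarily the exact code from the repository) that enumerates the full problem space:

```python
# Sketch: enumerate every ordered pair (a, b) for addition modulo p,
# so the full problem space fits comfortably in memory.
import torch

p = 113  # the "nominal" prime used later in the post

a = torch.arange(p).repeat_interleave(p)   # 0,0,...,0,1,1,...  (p*p entries)
b = torch.arange(p).repeat(p)              # 0,1,...,p-1,0,1,...
labels = (a + b) % p                       # target: (a + b) mod p

print(a.shape[0])  # 12769 ordered pairs for p = 113
```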

A single-layer transformer has enough parameters to simply memorise the training set, and this happens in ~200 epochs. However, if it is left to train for long enough, it will eventually solve the general case. This post replicates some of the above paper and explores some of the parameter space.

The code used for the content of this blog can be found here.

Figure 1: The train (blue) and test (orange) loss on a modular addition problem. 100% accuracy occurs at ~0.01 loss. However, if the network is trained for much longer it learns to generalise, leading to a much reduced loss on the test data as well.

Nanda et al (2023)

This paper sets up a network to produce a "grokking" effect on a small arithmetic problem: modular addition. The problem is nice due to the bounded size of the problem space, while still having some reasonable depth.

The network is trained on 30% of the data, enough to work with while still leaving a sizable generalisation target. The optimiser uses a high weight decay to give the network an incentive to generalise, and training is smoothed by running each epoch as a single full batch.
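A rough sketch of this setup is below. The model here is a stand-in MLP rather than the one-layer transformer actually used, and the hyperparameter values are illustrative assumptions, not copied from the repository.

```python
# Sketch of the training setup: fixed 30% train split, AdamW with high
# weight decay, and full-batch training. Model and hyperparameters are
# stand-ins for illustration only.
import torch
import torch.nn as nn
import torch.nn.functional as F

p, d_model, frac_train, num_epochs = 113, 128, 0.3, 1000

# Full problem space, as above
a = torch.arange(p).repeat_interleave(p)
b = torch.arange(p).repeat(p)
labels = (a + b) % p

# Fixed random train/test split for this seed
perm = torch.randperm(p * p)
n_train = int(frac_train * p * p)
train_idx, test_idx = perm[:n_train], perm[n_train:]

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(p, d_model)
        self.mlp = nn.Sequential(nn.Linear(2 * d_model, 512), nn.ReLU(), nn.Linear(512, p))
    def forward(self, a, b):
        return self.mlp(torch.cat([self.embed(a), self.embed(b)], dim=-1))

model = ToyModel()
optimiser = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1.0)

for epoch in range(num_epochs):
    # Single batch containing every training pair ("full-batch" training)
    loss = F.cross_entropy(model(a[train_idx], b[train_idx]), labels[train_idx])
    optimiser.zero_grad()
    loss.backward()
    optimiser.step()
```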

What the authors found is shown in Figure 2. The generalised network had learnt to compose rotations on the unit circle in order to work out the sum. Sine and cosine are approximated at some frequency w for both a and b. Trigonometric identities are then used to find the sine and cosine of the sum (a non-linear operation). Finally, the network derives the c for which the total angle w(a + b - c) is a multiple of 2*pi (the second non-linear operation). This solution is found repeatably, which is attributed to its heavy reliance on linear operations, which a neural network is good at.
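To make the algorithm concrete, here is a minimal NumPy sketch of a single-frequency version of the method (the trained network uses several frequencies and sums their contributions); the frequency k here is an arbitrary choice for illustration.

```python
# Sketch of the Figure 2 algorithm, written directly in NumPy rather than
# as learnt network weights.
import numpy as np

def modular_add_via_rotations(a, b, p=113, k=17):
    w = 2 * np.pi * k / p
    # Step 1: place a and b on the unit circle at frequency w
    cos_a, sin_a = np.cos(w * a), np.sin(w * a)
    cos_b, sin_b = np.cos(w * b), np.sin(w * b)
    # Step 2: trig identities give the rotation for a + b (the multiplications are non-linear)
    cos_ab = cos_a * cos_b - sin_a * sin_b   # cos(w(a+b))
    sin_ab = sin_a * cos_b + cos_a * sin_b   # sin(w(a+b))
    # Step 3: the logit for each candidate c is cos(w(a+b-c)), which peaks
    # when w(a+b-c) is a multiple of 2*pi, i.e. c = (a+b) mod p
    c = np.arange(p)
    logits = cos_ab * np.cos(w * c) + sin_ab * np.sin(w * c)
    return int(np.argmax(logits))

assert modular_add_via_rotations(87, 54) == (87 + 54) % 113
```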

Figure 2: A figure from Nanda et al (2023) showing the method derived by the network to sum the numbers. Note that the algorithm runs bottom to top, matching the standard bottom-up drawing of a transformer.

Investigating behaviour across seeds

The "nominal" prime used was 113, which is the same as is used in the paper. It was run across several random test/train split seeds to investigate the behaviour.

In all cases, the network successfully grokked. Figure 3 shows the test-set log loss across the seeds. The loss always initially increases on this data as the network memorises the training set. The initial (random) loss is log(113) ~= 4.7. It eventually falls, however, with a log loss of ~-2 corresponding to 100% accuracy on the test data. The amount of training required varies by over a factor of two, from 4.5k to 10.5k epochs. Each seed took ~2.5 minutes to train on a single GPU.
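For reference, the quoted initial loss is just the cross-entropy of a uniform guess over the p possible residues, which is easy to check:

```python
# Cross-entropy of a uniform prediction over p classes is -ln(1/p) = ln(p).
import math
print(math.log(113))  # ~4.73, matching the initial (random) loss above
```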

Figure 3: The loss on the test datasets across seeds for the same prime.

Network behaviour across primes

The next question to ask is whether this behaviour is consistent across the choice of prime. (Note that primes are chosen because addition modulo a non-prime looks slightly different from a number-theoretic standpoint.)

The primes chosen span roughly a factor of two. The larger primes see more training data per epoch, as the training fraction was kept fixed at 0.3: the amount of training data grows with the square of the prime, while the network size only scales (approximately) linearly with it.
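A quick illustration of that scaling, using the primes discussed in this post (the embedding width d_model is an assumption for illustration):

```python
# Training pairs grow as p^2, while the embedding table only grows as p.
frac_train, d_model = 0.3, 128
for p in [43, 73, 113]:
    n_train = int(frac_train * p * p)   # number of training pairs
    embed_rows = p * d_model            # embedding parameters
    print(p, n_train, embed_rows)
```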

You can see in Figure 4 that p=73 did not consistently grok within 20k epochs. It would eventually solve the general case, but only with more training and less consistently. Figure 5 then compares some large primes to some small ones, showing that the greater availability of training data and compute makes a big difference to the grokking behaviour. No scaling rule was derived here, due to the wide variability across random seeds.

The fact that the amount of data limits the smaller networks is shown in Figure 6, in which a training fraction of 0.5 allows grokking with p=43 while a fraction of 0.3 does not.

Figure 4: Grokking across a range of prime numbers, and with two different seeds.

Figure 5: A comparison of the grokking behaviour of large and small primes.

Figure 6: The log loss on the test dataset for p=43 and a training fraction of 0.3 (left) and 0.5 (right). This shows that it is data starvation which prevents the network from grokking at small primes.

Deriving the Key Frequencies

The network adds the numbers using geometric rotations at key frequencies. To examine the frequencies used, we can look at the embedding weight matrix, which has a noticeable structure once the network has fully generalised a solution. The structure can be seen more clearly using a singular value decomposition (SVD), as shown in Figure 7.
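Continuing from the training sketch above (where model.embed holds a (p, d_model) embedding table; the attribute name is an assumption of that sketch, not the repository's), the decomposition is a one-liner:

```python
# SVD of the embedding weights, as used for Figure 7.
import torch

W_E = model.embed.weight.detach()   # shape (p, d_model)
U, S, Vh = torch.linalg.svd(W_E)
print(S[:10])                       # in a fully grokked network, the singular values
                                    # drop off after ~8 directions (Figure 7)
# Plotting the leading columns of U against the input token 0..p-1 shows the
# periodic structure visible in the generalised network.
```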

This tells us that there are a few key directions in the weight matrix. We can now examine these using a Fourier analysis.

Figure 8 (left) shows a matrix of Fourier basis vectors, with cosine and sine terms along the respective axes. This allows us to extract the primary frequencies. Figure 8 (right) shows the resulting extracted frequencies: there is a clear signal on a few significant directions, and this forms the basis of the following analysis.
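A sketch of that projection, continuing from the SVD snippet above: build a basis of constant, cosine and sine terms over the residues 0..p-1, project the embedding onto it, and look at the norm of each component.

```python
# Fourier analysis of the embedding weights (a sketch of the Figure 8 analysis).
import torch

p = W_E.shape[0]
x = torch.arange(p)
basis = [torch.ones(p)]
for k in range(1, p // 2 + 1):
    basis.append(torch.cos(2 * torch.pi * k * x / p))
    basis.append(torch.sin(2 * torch.pi * k * x / p))
basis = torch.stack(basis)
basis = basis / basis.norm(dim=1, keepdim=True)   # normalise each basis vector

components = basis @ W_E                          # project W_E onto the Fourier basis
freq_norms = components.norm(dim=1)               # one norm per basis direction
print(freq_norms.topk(8).indices)                 # cos/sin pairs of the key frequencies
```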

Figure 7: The U matrix of the SVD of the embedding weight matrices for an ungeneralised (left) and generalised (right) transformer. You can see that there is a definite structure in the generalised network, but only for the first 8 or so terms.

Figure 8: A Fourier basis matrix (sine and cosine across each axis), and the same basis applied to the embedding weight matrix. You can clearly see that key frequencies have been picked out.

Figure 9: Summing the Fourier matrix applied to the embedding weights. Left: a network which hits 100% on the train set but has not yet grokked. Right: the same network after it has grokked. The key frequencies have been extracted, interestingly always as integer pairs.

The key frequencies found by the networks

The network chooses ~4 frequencies as the basis of its solution. The obvious question is whether there is any pattern to these frequencies. The answer seems to be no. Figure 10 shows the extracted frequencies for several seeds at p=113: 19 is repeated several times, as are 49 and 53, but overall the frequencies look roughly uniformly distributed, without clear patterns. It is possible, however, that there is an advantage to the network finding a frequency near 0.5p, as this may be a good way of getting a consistent signal.
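One way to read off these key frequencies from the Fourier projection above (the 0.5 threshold is an arbitrary choice for the sketch, not the post's exact criterion):

```python
# Combine the cos/sin norms for each frequency k and keep the dominant ones.
pair_norms = (components[1:].norm(dim=1).reshape(-1, 2) ** 2).sum(dim=1).sqrt()
key_freqs = (pair_norms > 0.5 * pair_norms.max()).nonzero().flatten() + 1
print(key_freqs)   # typically ~4 frequencies per trained network
```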

Figure 10: Key frequencies extracted by the network across seeds. This "random" behaviour was observed across primes also.