Incidental Polysemanticity

Update: accepted to the Re-Align and BGPT workshops @ ICLR 2024!


One obstacle to interpreting neural networks is polysemanticity. This is where a single neuron represents multiple features.

If there are more features than neurons, it might be "necessary" for the model to be polysemantic in order to represent everything. This is the notion of "superposition" from Elhage et al. (2022).

Of course, a clear solution would be to train a model large enough to have at least one neuron per feature. However, what we find in "What Causes Polysemanticity?" is that polysemanticity can happen "incidentally" in the training process, even if we have a large enough model.

Hypothesis

When we initialise a neural network, the weights are random. Some neurons will be more correlated with some features than other neurons, just by chance.

As training happens, the optimiser pushes the weights of those correlated neurons in the direction of the features, so that they can represent the features well. If there is pressure for sparsity, only one neuron will represent each feature.

Most likely, this is the neuron which was most correlated to the feature at initialisation. If it happened to be the most correlated neuron for multiple features, then it would end up representing multiple features.

In that case, we get polysemanticity "incidentally".

Experiments

To test this hypothesis, we consider the simplest possible setup: over-parameterised autoencoders similar to those in Elhage et al. (2022). That is:

$$y = \mathrm{ReLU}(WW^T x)$$

where $x \in \mathbb{R}^n$ and $W \in \mathbb{R}^{n \times m}$, with $m \geq n$. These models were trained on the standard basis vectors $e_i \in \mathbb{R}^n$. To induce sparsity, we take two separate approaches: introducing $\ell_1$ regularisation on the model weights and adding noise after the hidden layer.
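To make the setup concrete, here is a minimal NumPy sketch of this model (an illustration rather than the code in our repository; the dimensions, initialisation scale, and function names are arbitrary choices):

```python
import numpy as np

n, m = 64, 256   # n features, m hidden neurons; over-parameterised since m >= n
rng = np.random.default_rng(0)

# Rows W[i, :] play the role of W_{i,:} below; entries of order 1/sqrt(m)
# give each row an l2 norm of roughly 1 at initialisation.
W = rng.normal(0.0, 1.0 / np.sqrt(m), size=(n, m))

def forward(W, x):
    """Tied-weight autoencoder: y = ReLU(W W^T x)."""
    return np.maximum(W @ (W.T @ x), 0.0)

# The training data is just the n standard basis vectors e_i.
E = np.eye(n)
recon = np.stack([forward(W, E[i]) for i in range(n)])
print("reconstruction MSE at initialisation:", np.mean((recon - E) ** 2))
```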

We find that:

  1. $\ell_1$ regularisation induces sparsity
  2. Some types of noise can induce sparsity
  3. The amount of incidental polysemanticity can be predicted
  4. It is due to the weight initialisations

Sparsity from Regularisation

Result 1: $\ell_1$ regularisation induces a winner-takes-all dynamic at a rate proportional to the regularisation parameter.

Loss with $\ell_1$

Taking the ith basis vector as the input, the output of the model is:

$$\left(\mathrm{ReLU}(W_{1,:} \cdot W_{i,:}),\; \ldots,\; \mathrm{ReLU}(W_{n,:} \cdot W_{i,:})\right)$$

With $\ell_1$ regularisation, our loss function becomes:

$$\mathcal{L} = \sum_i \left(\mathrm{ReLU}(W_{i,:} \cdot W_{i,:}) - 1\right)^2 + \sum_i \sum_{j \neq i} \left(\mathrm{ReLU}(W_{i,:} \cdot W_{j,:}) - 0\right)^2 + \sum_i \lambda \lVert W_{i,:} \rVert_1$$

where λ is the regularisation parameter.
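In code, this loss can be computed directly from the Gram matrix $WW^T$ (again a sketch, continuing the NumPy snippet above):

```python
import numpy as np

def l1_loss(W, lam):
    """Reconstruction loss over all basis vectors plus an l1 penalty on each row of W."""
    G = np.maximum(W @ W.T, 0.0)                              # G[i, j] = ReLU(W_i . W_j)
    on_diagonal = np.sum((np.diag(G) - 1.0) ** 2)             # (ReLU(W_i . W_i) - 1)^2 terms
    off_diagonal = np.sum(G ** 2) - np.sum(np.diag(G) ** 2)   # (ReLU(W_i . W_j) - 0)^2, j != i
    penalty = lam * np.sum(np.abs(W))                         # lambda * sum_i ||W_i||_1
    return on_diagonal + off_diagonal + penalty
```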

Thus, gradient descent pushes us in the direction:

$$\dot{W}_{i,:} = 4\left(1 - \lVert W_{i,:} \rVert_2^2\right) W_{i,:} \;-\; 4 \sum_{j \neq i} \mathrm{ReLU}(W_{i,:} \cdot W_{j,:})\, W_{j,:} \;-\; \lambda\, \mathrm{sign}(W_{i,:})$$
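Concretely, one gradient-descent step under this update can be written as follows (a sketch continuing the NumPy snippet above; the step size is an arbitrary illustrative choice, and grad_step is just our name for it):

```python
import numpy as np

def grad_step(W, lam, lr=1e-3):
    """One gradient-descent step on the l1-regularised loss, following the update above."""
    G = np.maximum(W @ W.T, 0.0)          # ReLU(W_i . W_j)
    np.fill_diagonal(G, 0.0)              # interference only involves j != i
    row_norms_sq = np.sum(W ** 2, axis=1, keepdims=True)
    dW = 4.0 * (1.0 - row_norms_sq) * W - 4.0 * (G @ W) - lam * np.sign(W)
    return W + lr * dW
```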

Forces for sparsity

We can split this into the three terms:

  1. Feature benefit: pushes $W_{i,:}$ to be unit length. (This is also why we can use tied weights $W$ and $W^T$: it pushes the $i$th column of $W_{\mathrm{enc}}$ and the $i$th row of $W_{\mathrm{dec}}$ to have a dot product of 1, even if they are initialised as untied weights.)
  2. Interference: pushes $W_{i,:}$ to be orthogonal to the other rows $W_{j,:}$
  3. Regularisation: pushes the non-zero entries of $W_{i,:}$ towards zero, i.e. towards sparsity

The feature benefit and regularisation forces are in competition. Since the regularisation force has a constant magnitude on each entry, while the feature benefit force on an entry is proportional to that entry's magnitude, the regularisation force dominates for small entries of $W_{i,:}$ and the feature benefit force dominates for large ones.

Thus $W_{i,k}$ will be pushed to 0 if it is below some threshold $\theta$. Leaving the derivations for the paper, we find the net effect on sparsity is proportional to how far the weight is from this threshold:

$$\frac{d\,|W_{i,k}|}{dt} \;\propto\; \left(1 - \lVert W_{i,:} \rVert_2^2\right)\left(|W_{i,k}| - \theta\right)$$
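To spell out the step being skipped (a back-of-the-envelope version of the derivation left to the paper, keeping the constants from the gradient above, ignoring the interference force, and assuming $\lVert W_{i,:} \rVert_2 < 1$): the feature benefit and regularisation forces on a non-zero entry combine as

$$\frac{d\,|W_{i,k}|}{dt} \approx 4\left(1 - \lVert W_{i,:} \rVert_2^2\right)|W_{i,k}| - \lambda = 4\left(1 - \lVert W_{i,:} \rVert_2^2\right)\left(|W_{i,k}| - \theta\right), \qquad \theta = \frac{\lambda}{4\left(1 - \lVert W_{i,:} \rVert_2^2\right)}$$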

Speed of sparsification

In fact, we can quantify the speed at which sparsity is induced. Again, leaving the maths for the paper, it follows from the above that the $\ell_1$ norm at time $t$ is inversely proportional to $\lambda t$:

$$\lVert W_{i,:}(t) \rVert_1 = \frac{1}{\Theta\!\left(\frac{1}{\sqrt{m}} + \lambda t\right)}$$

Since $\lVert W_{i,:}(t) \rVert_2 \approx 1$ throughout training, the $m$ non-zero values at initialisation should each have a magnitude of around $\frac{1}{\sqrt{m}}$, and so $\lVert W_{i,:}(0) \rVert_1 \approx m \cdot \frac{1}{\sqrt{m}} = \sqrt{m}$. Thus the $\ell_1$ norm should go from $\Theta(\sqrt{m})$ at initialisation to $\Theta\!\left(\frac{1}{\lambda t}\right)$ once $t \gg \frac{1}{\lambda \sqrt{m}}$, reaching $\Theta(1)$ around $t \approx \frac{1}{\lambda}$. This is exactly what we see:
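For readers who want to reproduce the trend numerically, the sketch below tracks the $\ell_1$ norm of one row during training. It reuses grad_step from above; the learning rate, dimensions, and number of steps are arbitrary choices.

```python
import numpy as np

# Track the l1 norm of one row of W during training (reuses grad_step defined above).
n, m, lam = 64, 256, 0.01
rng = np.random.default_rng(0)
W = rng.normal(0.0, 1.0 / np.sqrt(m), size=(n, m))
for t in range(20_001):
    W = grad_step(W, lam)
    if t % 5_000 == 0:
        # Of order sqrt(m) at initialisation, then decreasing as the row sparsifies.
        print(t, np.sum(np.abs(W[0])))
```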

Sparsity from Noise

Result 2: noise drawn from a distribution with negative excess kurtosis induces sparsity.

Implicit regularisation

In practice, we don't get sparsity in neural networks because of $\ell_1$ regularisation. A more realistic cause is noise in the hidden layer, à la Bricken et al. (2023):

$$y = \mathrm{ReLU}\!\left(W\left(W^T x + \xi\right)\right)$$

for $\xi \in \mathbb{R}^m$ with $\xi \sim \mathcal{D}$. Having removed the regularisation term, the loss is rotationally symmetric with respect to the hidden layer (excluding the noise). That means there is no privileged basis, and no particular reason for a feature to be represented by a single neuron rather than by a linear combination of neurons.
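Concretely, the noisy forward pass looks like this (a sketch; the noise scale and the two distributions shown are just the ones we compare below):

```python
import numpy as np

rng = np.random.default_rng(0)

def forward_noisy(W, x, sigma=0.1, dist="bernoulli"):
    """y = ReLU(W (W^T x + xi)), with xi drawn i.i.d. per hidden unit."""
    m = W.shape[1]
    if dist == "bernoulli":
        xi = sigma * rng.choice([-1.0, 1.0], size=m)   # +/- sigma with equal probability
    else:
        xi = rng.normal(0.0, sigma, size=m)            # zero-mean Gaussian
    return np.maximum(W @ (W.T @ x + xi), 0.0)
```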

However, if we take the noise into account, we find that one term in the loss is:

$$\lVert W_{i,:} \rVert_4^4 \left(\frac{\mu_4}{\sigma^4} - 3\right)$$

where $\mu_4$ is the fourth moment of $\mathcal{D}$, and $\frac{\mu_4}{\sigma^4} - 3$ is the excess kurtosis.

Thus, when $\mathcal{D}$ has negative excess kurtosis, this component of the loss will push to increase $\lVert W_{i,:} \rVert_4$. Due to the constraint that $\lVert W_{i,:} \rVert_2 = 1$ from before, this incentivises $W_{i,j} = \pm 1$ for some $j$ and $W_{i,k} = 0$ for $k \neq j$.

Bernoulli vs. Gaussian noise

Bernoulli noise of either $\pm\sigma$ has an excess kurtosis of $-2$, while Gaussian noise has an excess kurtosis of $0$. Thus we would expect the former to induce sparsity (and an $\ell_4$ norm of 1), while the latter would not. As expected:
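These kurtosis values are easy to verify numerically (a quick sanity check, not one of the experiments):

```python
import numpy as np

def excess_kurtosis(samples):
    centred = samples - samples.mean()
    return np.mean(centred ** 4) / np.mean(centred ** 2) ** 2 - 3.0

rng = np.random.default_rng(0)
sigma, N = 0.1, 1_000_000
print(excess_kurtosis(sigma * rng.choice([-1.0, 1.0], size=N)))   # about -2
print(excess_kurtosis(rng.normal(0.0, sigma, size=N)))            # about 0
```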

Counting Hypothesis

Result 3: the number of polysemantic neurons can be predicted by a simple combinatorial model.

Possible model solutions

Recall that the output of the autoencoder is:

$$\left(\mathrm{ReLU}(W_{1,:} \cdot W_{i,:}),\; \ldots,\; \mathrm{ReLU}(W_{n,:} \cdot W_{i,:})\right)$$

That is, we would like $W_{i,:} \cdot W_{i,:} = 1$ and $W_{i,:} \cdot W_{j,:} \leq 0$ for $i \neq j$.

One way to satisfy this is if $W_{i,:}$ equals the $i$th standard basis vector $f_i \in \mathbb{R}^m$. This is because $WW^T$ will just be the identity matrix, and so $\mathrm{ReLU}(WW^T e_i) = e_i$.

However, when $m > n$, we have another solution. Take $m = 4$ and $n = 2$, and consider the following weight matrix:

$$W = \begin{pmatrix} 1 & 0 & 0 & 0 \\ -1 & 0 & 0 & 0 \end{pmatrix}$$

We see that $\mathrm{ReLU}(WW^T e_1) = (1, 0)$ and $\mathrm{ReLU}(WW^T e_2) = (0, 1)$, which still satisfies the constraints. Both features are represented by the first neuron: this is a polysemantic solution!
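This is easy to verify in a few lines (an illustrative check, not one of the experiments):

```python
import numpy as np

# The polysemantic solution above: both features use the first neuron, with opposite signs.
W = np.array([[ 1.0, 0.0, 0.0, 0.0],
              [-1.0, 0.0, 0.0, 0.0]])
for i, e in enumerate(np.eye(2)):
    print(i + 1, np.maximum(W @ (W.T @ e), 0.0))   # prints (1, 0) then (0, 1)
```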

Interference force

Knowing that it is possible, we can now ask why it occurs. One force we haven't considered in detail is the interference force:

$$\sum_{j \neq i} \mathrm{ReLU}(W_{i,:} \cdot W_{j,:})\, W_{j,:}$$

up to constants.

This is only non-zero if the angle between $W_{i,:}$ and $W_{j,:}$ is less than $\frac{\pi}{2}$. Thus, we can simplify by only considering its effect in the direction of $W_{i,:}$. It has magnitude:

$$\left(\sum_{j \neq i} \mathrm{ReLU}(W_{i,:} \cdot W_{j,:})\, W_{j,:}\right) \cdot W_{i,:} = \sum_{j \neq i} \mathrm{ReLU}(W_{i,:} \cdot W_{j,:})^2$$

This means that the interference force should be weak at the start, when the $W_{i,:} \cdot W_{j,:}$ are mean zero, and should only kick in once $W_{i,:}$ and $W_{j,:}$ share some non-zero coordinate $k$. If $W_{i,k}$ and $W_{j,k}$ have the same sign, the interference force will push at least one of them to zero. Thus, we would only expect polysemanticity to occur when they have opposite signs, since then the ReLU zeroes out the negative dot product and both can maintain their non-zero values.

Balls and bins

With $\binom{n}{2}$ pairs of features, a $\frac{1}{m}$ probability that the most significant neuron is the same for both, and a $\frac{1}{2}$ probability of them having opposite signs, we would predict there to be $\binom{n}{2} \cdot \frac{1}{2m} \approx \frac{n^2}{4m}$ polysemantic neurons, which is what we find:
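This estimate can also be simulated directly from random initialisations, without any training (a sketch of the counting argument; the dimensions and number of trials are arbitrary):

```python
import numpy as np

def expected_collisions(n, m, trials=200, seed=0):
    """Average number of feature pairs whose most significant neuron at initialisation
    coincides with opposite signs -- the incidental-polysemanticity estimate."""
    rng = np.random.default_rng(seed)
    total = 0
    for _ in range(trials):
        W0 = rng.normal(size=(n, m))
        top = np.argmax(np.abs(W0), axis=1)          # most significant neuron per feature
        sign = np.sign(W0[np.arange(n), top])
        for i in range(n):
            for j in range(i + 1, n):
                if top[i] == top[j] and sign[i] != sign[j]:
                    total += 1
    return total / trials

n, m = 64, 256
print(expected_collisions(n, m), n ** 2 / (4 * m))   # both roughly n^2 / 4m
```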

Initialisation Hypothesis

Result 4: polysemanticity occurs in $m > n$ models due to weight initialisations.

If initialisations were the cause of polysemanticity, the weights at the start of training should be correlated with the weights at the end. That is, the diagonal entries of $W_{\text{init}} W_{\text{final}}^T$ should be larger than the off-diagonal entries. As predicted:
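In code, that comparison is just a couple of lines (a sketch; W_init and W_final stand for the weights before and after a training run such as the one sketched earlier):

```python
import numpy as np

def initialisation_overlap(W_init, W_final):
    """Compare each feature's weight row before and after training: if initialisation
    decides which neuron a feature lands on, the diagonal of W_init @ W_final.T
    should dominate the off-diagonal entries."""
    overlap = np.abs(W_init @ W_final.T)
    n = overlap.shape[0]
    diag_mean = np.trace(overlap) / n
    offdiag_mean = (overlap.sum() - np.trace(overlap)) / (n * (n - 1))
    return diag_mean, offdiag_mean
```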

Future Work

The incidental polysemanticity we have discussed in our work is qualitatively different from necessary polysemanticity, because it arises from the learning dynamics inducing a privileged basis. Furthermore, the fact that it occurs all the way up to $m = n^2$ suggests that making the model larger may not solve the problem.

We look forward to future work which investigates this phenomenon in more fleshed-out settings, and which attempts to nudge the learning dynamics to stop it from occurring.

If you want to play with different configurations of our models or reproduce our plots, check out the code repository!