I'm now working through chapter 4 of Sebastian Raschka's book "Build a Large Language Model (from Scratch)". It's the chapter where we put together the pieces that we've covered so far and wind up with a fully-trainable LLM, so there's a lot in there -- nothing quite as daunting as fully trainable self-attention was, but still stuff that needs thought.
Last time around I covered something that for me seemed absurdly simplistic for what it did -- the step that converts the context vectors our attention layers come up with into output logits. Grasping how a single matrix multiplication could possibly do that took a bit of thought, but it was all clear in the end.
This time, layer normalisation.
Why do we need it?
Looking at my mental model of an LLM at this point, it works like this:
- We tokenise our input and convert the tokens into token embeddings and position embeddings, and then combine those for our input embeddings.
- Those input embeddings are fed into a multi-head attention layer, which spits out context vectors -- a kind of representation of what each token means in the context of the other tokens to its left.
- Those context vectors are fed into a second MHA layer, which does the same but at "a higher level" -- kind of like the way that each successive layer in a vision NN like a CNN has a higher level of abstraction: the first layer detects edges, the next corners, and so on up to the nth layer that can detect dogs' faces (though of course the abstractions the LLM is building are unlikely to be as human-comprehensible as those).
- After a number of layers, we have context vectors for each token that, in embedding space, point in the direction of the most probable next one, given the token itself and the ones to its left.
- We run those final context vectors through a single unbiased linear layer with no activation function -- a simple matrix multiplication -- to project them from embedding space to vocab space, which gives us our logits.
That actually feels like a pretty complete system, and I rather suspect that in principle, an LLM might work just with those calculations.
But -- as I said in the last post -- the code in the book is "folding, spindling and mutilating" the context vectors. Each attention layer isn't just the pure attention mechanism, made causal, with dropout and run multiple times in parallel. A few other things are happening on top of that.
As far as I understand it, these extra steps are to make them trainable in the real world. My naive description above, without those steps, might work in principle, but if you actually tried to train a model that worked that way, you'd not get great results.
Layer normalisation is one of the tricks used to make the LLM trainable, and I think it's pretty easy to understand, at least at a basic level.
What is it?
Imagine we have a matrix of context vectors that came out of our multi-head attention module. Disregarding batches, it's an $n \times d_{\text{emb}}$ matrix -- one row for each of our $n$ tokens in the input, each row being an embedding vector of length $d_{\text{emb}}$.
For layer normalisation, what we do is essentially go through each row, and for each one adjust the numbers so that it has a mean of 0, and a variance of 1. The mean is (of course) the sum of all the numbers in the vector divided by its length, and the variance is a measure of how much the numbers differ from the mean (more about that later).
So, what we're doing is a kind of normalisation of the "sizes" of the vectors in a rather abstract sense.
My mental model for this maps it to the way a producer would normalise analog audio signals going into a mixing desk.
Analog signals are messy -- for example, the voltage from a microphone might have a DC bias, which means that instead of oscillating around 0 volts, it could oscillate around 3V or something. Different input signals might have different DC biases, so you would want to fix that first so that you can mix them. Likewise, different signals might be at different levels -- an active microphone might have a much higher voltage than a passive one, and a direct feed from a synth might be at yet another level.
So you'd want to eliminate that DC bias (which maps to setting the mean to zero), and get them all to roughly the same level -- line level (mapping to making the variance one). Once you have that done, you can then set the levels appropriately for the mix (which would involve making some louder, some quieter, and so on -- but at least you're starting from a point where they're consistent).
It's not an exact match for what normalisation is doing -- in particular the line level part -- but I've found it useful.
Why?
Raschka explains that normalisation is required to avoid problems like exploding or vanishing gradients. I think that's worth going through.
I'll assume that you already know how gradient descent works, but it's worth writing a quick overview just to set the stage.
When we're training something with gradient descent, we run some inputs through it and then measure the error -- how far off the result was from what we wanted. If we simplify a bit, for our LLM, if we feed in:
the fat cat sat on the
...and we want mat to be the predicted next token, then we create an error function that will be zero if that is what the LLM comes up with, and a number larger than zero if it predicts dog or whatever. (I think that the details of how we write an error function that does this effectively will come in the next chapter. Please also note that I'm grossly simplifying here -- as I covered in the last post, we're actually predicting next tokens for every token in the input rather than just the last one. But that doesn't matter for this bit.)
We then use differentiation to work out, for every parameter, the gradient of that error with respect to it. That means, essentially, which direction we'd have to move the parameter in to make the error higher, and how strongly the error responds to a small change in it. We want to make the error lower rather than higher, of course -- and we only want to move a small distance in the "error lower" direction. So we multiply the gradient by our learning rate (a number less than one, usually much less), then subtract it from the parameter so that we move a small distance in the right direction.
Once that has been done for all parameters, we can run another training sample through and repeat until we're done.
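To make that loop concrete, here's a minimal sketch in plain Python with a single parameter. The error function `error(w) = (w - 3)**2` and its hand-derived gradient are hypothetical stand-ins, not an LLM's loss. The point is just the update rule: multiply the gradient by the learning rate and subtract the result from the parameter.

```python
def error(w: float) -> float:
    """Hypothetical toy error function: zero when w == 3, positive otherwise."""
    return (w - 3.0) ** 2


def gradient(w: float) -> float:
    """Derivative of error() with respect to w, worked out by hand: 2(w - 3)."""
    return 2.0 * (w - 3.0)


def train(w: float, learning_rate: float = 0.1, steps: int = 100) -> float:
    """Repeatedly nudge w a small distance in the "error lower" direction."""
    for _ in range(steps):
        # The gradient points in the "error higher" direction, so we subtract it,
        # scaled down by the learning rate so that each move is small.
        w -= learning_rate * gradient(w)
    return w


start = 10.0
final = train(start)
print(f"error before: {error(start):.4f}, after: {error(final):.2e}")
```

In PyTorch, autograd's `backward()` works out the gradients for you, and an optimiser such as `torch.optim.SGD` (without momentum) applies this same subtract-a-scaled-gradient update.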
The nice thing about libraries like PyTorch is that they abstract away all of the differentiation to work out these gradients completely, and we don't need to do any of the maths ourselves. But, unfortunately, it's a kind of leaky abstraction -- that is, we can't completely forget about what is going on under the hood.
With a multi-layer network like an LLM, the way the gradient calculation works is that it firstly calculates the gradients for the parameters in the last layer -- the ones that directly produced the output we're trying to correct. These gradients are then used as part of the calculations for the penultimate layer, and so on -- back-propagation of the error from the "end" of the network to the start.
So, let's think about vanishing and exploding gradients. These are problems that affect the early layers in the network -- the ones that are processed later during backprop. The gradients that you have for the layers at the "end" of the network are pretty solid, but as they're propagated back to earlier ones, they're multiplied by various numbers -- the parameters themselves, for example -- and as a result, they (of course) get larger or smaller.
It's entirely possible for a particular gradient, as it heads backwards, to approach or become zero. That means it kind of reaches a dead end -- it can't have any effect on the layer where it reaches zero, or on any of the earlier ones. Hence, vanishing gradient.
An exploding gradient is (fairly obviously) the opposite issue -- a gradient gets multiplied by a large number, or a series of large numbers, and becomes so large that it drowns out all of the others.
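One deliberately crude way to see both effects is to pretend that, at every layer on the way back, a gradient simply gets multiplied by the same number. Real backprop multiplies by matrices of partial derivatives rather than by one scalar, so this is only an illustration under that simplifying assumption. Still, it shows how quickly repeated multiplication runs away in either direction.

```python
def propagate_back(gradient: float, factor: float, num_layers: int) -> list[float]:
    """Multiply a gradient by the same factor once per layer (a toy model of backprop)."""
    history = [gradient]
    for _ in range(num_layers):
        gradient *= factor
        history.append(gradient)
    return history


# Hypothetical per-layer factors: one below one, one above one.
shrinking = propagate_back(1.0, 0.5, num_layers=30)
growing = propagate_back(1.0, 1.5, num_layers=30)

print(f"factor 0.5, 30 layers: {shrinking[-1]:.3e}")  # vanishing towards zero
print(f"factor 1.5, 30 layers: {growing[-1]:.3e}")    # exploding
```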
Remember that when a gradient propagates back from layer n to layer n−1, all of the gradients in layer n−1 will depend on all of the gradients in n, assuming that the network is fully connected between layers. So that means that one huge gradient in layer n will have an oversized effect on all of the gradients in n−1.
So, the point of this layer normalisation is to help stop that from happening; to keep our gradients from either vanishing or exploding as they propagate back from the last layer to the first.
To use the analog audio signal metaphor from earlier -- we're trying to set things up so that the guitar isn't so loud that we can't even hear the drums, and so that the singer is at least audible enough that we can hear when she's singing.
How?
So, how do we set the mean to zero and the variance to one for each token's context vector?
The mean
The mean is really easy: we simply work out the existing mean of the vector, and then subtract that from all of the values with a broadcast subtraction. That pretty obviously sets the mean to zero. For example, if we start with the vector (2,4,6):
The mean is:
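$$\frac{2+4+6}{3} = \frac{12}{3} = 4$$
Subtracting the mean from each of those gives us
$$\begin{pmatrix} 2-4 & 4-4 & 6-4 \end{pmatrix} = \begin{pmatrix} -2 & 0 & 2 \end{pmatrix}$$
...and the mean of that is:
$$\frac{-2+0+2}{3} = \frac{0}{3} = 0$$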
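Here is the same centring step in plain Python, using the standard library's `statistics.mean`. A list of floats is enough to check the arithmetic above.

```python
from statistics import mean


def centre(values: list[float]) -> list[float]:
    """Subtract the vector's mean from every element, so the result has a mean of zero."""
    m = mean(values)
    return [v - m for v in values]


centred = centre([2.0, 4.0, 6.0])
print(centred)        # [-2.0, 0.0, 2.0]
print(mean(centred))  # 0.0
```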
The variance
To make the variance equal to one, we work out the existing variance, and then divide by its square root (the standard deviation). My initial reaction was "why not divide by the variance itself?".
To answer that, I think it's worth doing a bit of a refresher on how variance and standard deviation work.
To calculate variance, you work out the mean of a vector, then calculate for each element the difference between the element and the mean, then square all of those, then take the average of that.
So, for the vector (2,4,6), we already know from above that the differences from the mean are (−2,0,2).
Squaring those gives us:
$$\left((-2)^2, 0^2, 2^2\right) = (4, 0, 4)$$
We now work out the mean of that:
$$\frac{4+0+4}{3} = \frac{8}{3} = 2.\dot{6}$$
That's the variance, and the square root of that is called the standard deviation. The squaring is necessary because we want to make the numbers positive -- exactly why we do that rather than using the absolute value is something I'm sure I was taught sometime in the late 80s, but I won't dig into it now. (I'm considering a separate blog post series, something like "high school maths refreshed for older devs digging into machine learning". If you'd be interested, please leave a comment below!)
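Here are those steps as code. This version divides by the length of the vector, n, as described above (sometimes called the population variance). The standard library's `statistics.pvariance` and `statistics.pstdev` compute the same thing, whereas `statistics.variance` divides by n − 1 instead.

```python
from math import sqrt
from statistics import pstdev, pvariance


def variance(values: list[float]) -> float:
    """Mean of the squared differences from the mean (dividing by n, not n - 1)."""
    m = sum(values) / len(values)
    squared_diffs = [(v - m) ** 2 for v in values]  # [4.0, 0.0, 4.0] for our example
    return sum(squared_diffs) / len(values)


data = [2.0, 4.0, 6.0]
var = variance(data)
sd = sqrt(var)
print(var, sd)  # roughly 2.667 and 1.633

# The hand-rolled version agrees with the standard library's population statistics.
assert abs(var - pvariance(data)) < 1e-12
assert abs(sd - pstdev(data)) < 1e-12
```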
So, why do we divide by the standard deviation rather than the variance itself?
Let's think about what dividing by the standard deviation does. It essentially converts each element in the vector into a number that says how many standard deviations it is away from the mean. So if something is one standard deviation away from the mean in the original vector, it will have the value 1 if it's one SD greater than the mean, or −1 if it's one SD less than the mean.
It's pretty clear that the standard deviation of such a vector will be one, almost by definition! We've scaled all of the numbers such that the numbers that were one SD from the mean are 1 or −1, but those numbers are still the ones that are one SD away from the mean, so the SD must therefore be one.
And because the variance is the SD squared, that means that it must also be one.
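We can check that claim numerically: subtract the mean, divide by the standard deviation, and then measure the spread of what comes out. This sketch uses population statistics from the standard library. Tiny floating-point errors mean the results are only approximately 0 and 1.

```python
from statistics import mean, pstdev, pvariance


def standardise(values: list[float]) -> list[float]:
    """Express each element as 'how many standard deviations from the mean'."""
    m = mean(values)
    sd = pstdev(values)
    return [(v - m) / sd for v in values]


z = standardise([2.0, 4.0, 6.0])
print(z)                                 # about [-1.22, 0.0, 1.22]
print(mean(z), pstdev(z), pvariance(z))  # approximately 0, 1 and 1
```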
Now imagine that we divided by the variance. That would mean that each element would be a measure of how many variances its "original" was away from the mean. Assuming that the variance was not already one, it must be a number either smaller or larger than one. That means that the numbers one SD away from the mean must now be something that is not one, so the SD is no longer one and neither is its square, the variance.
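Here is the same experiment with the variance as the divisor instead. Working through the algebra, the new variance comes out as 1 divided by the original variance, so it's only 1 if the original variance happened to be 1 already.

```python
from statistics import mean, pvariance


def divide_by_variance(values: list[float]) -> list[float]:
    """Centre the vector, then divide by the variance (not the standard deviation)."""
    m = mean(values)
    var = pvariance(values)
    return [(v - m) / var for v in values]


data = [2.0, 4.0, 6.0]
result = divide_by_variance(data)
print(result)             # about [-0.75, 0.0, 0.75]
print(pvariance(result))  # about 0.375, i.e. 1 / (8/3), rather than 1
```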
Another way to look at it is in terms of units. Let's say that we were dealing with a vector where each number was a measurement of a length in meters. What we're trying to do is convert it to a vector that is normalised such that each number is unitless.
Our variance was calculated by working out the mean (an average of measurements in meters, so it must be in meters), then subtracting that from each measurement (so we have a list of measurements' offsets from their mean -- also meters), and then we square those, getting a list of numbers that -- unit-wise -- represent square meters. We then divide them by the length of the vector (which is unitless).
So the variance, if we consider it in terms of its units, is in square meters. If we were to divide the list of measurements by that number, we'd get values that represented some kind of something per meter -- that is, the unit would be $\text{m}^{-1}$.
But what we want to do is simply turn those measurements into dimensionless values. That means that we need to divide by something that has units of meters -- which is the standard deviation.
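The units argument can be tested directly, too, by expressing the same hypothetical lengths in meters and in centimeters. Dividing by the standard deviation gives the same dimensionless numbers either way. Dividing by the variance gives answers that depend on which unit we happened to pick, which is a symptom of those answers still carrying "per meter" units.

```python
from statistics import mean, pstdev, pvariance


def standardise(values: list[float]) -> list[float]:
    m, sd = mean(values), pstdev(values)
    return [(v - m) / sd for v in values]


def divide_by_variance(values: list[float]) -> list[float]:
    m, var = mean(values), pvariance(values)
    return [(v - m) / var for v in values]


lengths_m = [1.0, 2.0, 4.0]                  # hypothetical measurements in meters
lengths_cm = [100.0 * x for x in lengths_m]  # the same measurements in centimeters

# Dividing by the SD: unit-free, so the two lists agree (up to float rounding).
print(standardise(lengths_m))
print(standardise(lengths_cm))

# Dividing by the variance: the centimeter version is 100 times smaller,
# because the result still has units of "per meter" (or "per centimeter").
print(divide_by_variance(lengths_m))
print(divide_by_variance(lengths_cm))
```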
I'm not 100% happy with those explanations -- the latter, in particular, like all units-based explanations, is susceptible to issues with unitless values. For example, there could be a multiplication by some constant in there, or we might use the length of the vector squared.
However, I think they're enough for now.
But doesn't this break our embeddings?
Having understood how the normalisation works, there was one remaining problem I had in my mental model -- something that applies to all of the manipulations that we apply to the output of the multi-head attention mechanism, but that is particularly clear here.
Embedding vectors represent meaning; in general, the length isn't super-important for that, but the direction is. So if we change the direction of an embedding, we change its meaning -- and that sounds bad.
Scaling a context vector by its standard deviation just changes the vector's length, so that's not super-important. But changing its mean changes the direction!
For example, consider this toy embedding in a 2D space, (1,5). Let's chart it:
We set the mean to zero by working out the mean of the original numbers:
$$\frac{1+5}{2} = \frac{6}{2} = 3$$
...then subtracting that mean from each of them:
$$\begin{pmatrix} 1-3 & 5-3 \end{pmatrix} = \begin{pmatrix} -2 & 2 \end{pmatrix}$$
We don't really need to chart that to see that it's not pointing in the same direction as the original, but let's do it anyway:
The division of this vector by its standard deviation would just shrink it, so that's not an issue -- remember that the length of an embedding vector isn't really important -- but the complete change in direction felt like it would be an issue.
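Measuring the angle between the vectors confirms the intuition. This sketch swaps in cosine similarity, a standard measure of how closely two vectors point the same way, where 1 means an identical direction. Centring (1, 5) swings it round by roughly 56 degrees, while dividing the result by the standard deviation leaves its direction untouched.

```python
from math import acos, degrees, sqrt
from statistics import mean, pstdev


def cosine_similarity(a: list[float], b: list[float]) -> float:
    """1.0 when a and b point in exactly the same direction."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (sqrt(sum(x * x for x in a)) * sqrt(sum(y * y for y in b)))


original = [1.0, 5.0]
m = mean(original)                      # 3.0
centred = [x - m for x in original]     # [-2.0, 2.0]
sd = pstdev(original)                   # 2.0
normalised = [x / sd for x in centred]  # [-1.0, 1.0]

angle = degrees(acos(cosine_similarity(original, centred)))
print(f"centring rotated the vector by about {angle:.1f} degrees")  # about 56.3
print(cosine_similarity(centred, normalised))  # 1.0 (up to rounding): same direction
```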
My concern was that we had these context vectors coming out of our multi-head attention layers that had some kind of meaning. Setting the mean to zero in this normalisation layer looked like it was going to completely break them!
The mistake I was making was in thinking that the context vectors coming out of the MHA layer were meaningful vectors in embedding space. When we train our LLM, we're training it as a whole -- indeed, there's no real guarantee that the context vectors moving between attention layers have any real meaning at all. The only constraint we're applying (via our error function and gradient descent via backprop) is that the context vectors from the very last layer point in the direction of the predicted next token's embedding.
Now, it's reasonable to assume that the context vectors that are passed between the layers have some kind of meaning in embedding space -- given how people talk about them, I imagine some tests have been done to check, though perhaps that's a deep interpretability problem -- but the important thing is that what is passed between the layers has already gone through this normalisation.
That is, if we assume that what is coming out of a single layer -- MHA, normalisation, and all of the other stuff we're going to cover in the next few posts -- is a meaningful context vector in embedding space, then the weights that we have inside the MHA layer itself -- our $W_{\text{query}}$, $W_{\text{key}}$, and $W_{\text{value}}$ matrices -- will be trained so that the values that come out of that pure MHA part of the layer need not be in embedding space in and of themselves. They will instead be trained to output numbers that will be in embedding space, after they have been normalised.
The model I'm forming here is that the normalisation is not so much a mechanical thing that we're doing to the numbers that flow through the network; it's more of a constraint on the kinds of values that can flow through at particular points, and during training, our network learns to shape the "unconstrained" parts -- the attention weights -- so that as a whole it does the desired task while working within those constraints.
A work-in-progress metaphor I have for this is that it's kind of like a company where the interactions between teams are strongly codified, but teams can work pretty much independently so long as they present the required "API" to other teams. I've heard that AWS works kind of like this; the team that is responsible for -- say -- their email-sending service has a lot of freedom in how they work, so long as they provide an email service that operates in a particular way.
Might need more work, but I'll see if it sticks :-)
Scale and shift
Once we've done all of this and have our context vectors all nicely aligned with a mean of 0 and a variance of 1, there's one final step; we multiply them by a scale vector, then add on a shift vector. Each of these has a length equal to the embedding dimensions -- that is, the zeroth number in a given context vector is multiplied by the zeroth one in the scale vector, then has the zeroth in the shift vector added on. These vectors are trainable -- a given normalisation step will use the same ones for every context vector, but the actual values to use will be one of the parameters learned as part of the training.
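Putting the whole step together, here's a minimal layer-norm sketch over a small $n \times d_{\text{emb}}$ matrix (a list of rows), in plain Python rather than PyTorch. Two assumptions are worth flagging. First, the `eps` added to the variance is a common guard against dividing by zero (PyTorch's `torch.nn.LayerNorm` defaults it to `1e-5`). Second, `scale` and `shift` are plain lists here, whereas in a real model they'd be trainable parameters, typically starting out as all ones and all zeros.

```python
from math import sqrt
from statistics import mean, pvariance


def layer_norm(rows: list[list[float]], scale: list[float], shift: list[float],
               eps: float = 1e-5) -> list[list[float]]:
    """Normalise each row (one per token) to mean 0 / variance 1, then scale and shift."""
    out = []
    for row in rows:
        m = mean(row)
        var = pvariance(row)
        # eps stops us dividing by zero if every element in the row is the same.
        normalised = [(x - m) / sqrt(var + eps) for x in row]
        # Element i is multiplied by scale[i], then has shift[i] added on; the same
        # scale and shift vectors are used for every row.
        out.append([s * x + b for s, x, b in zip(scale, normalised, shift)])
    return out


context_vectors = [  # hypothetical: 2 tokens, d_emb = 3
    [2.0, 4.0, 6.0],
    [10.0, -3.0, 0.5],
]

# With scale = ones and shift = zeros, each row comes out with mean ~0, variance ~1.
plain = layer_norm(context_vectors, [1.0, 1.0, 1.0], [0.0, 0.0, 0.0])

# A uniform scale of 2 and shift of 1 gives mean ~1 and variance ~4 instead.
adjusted = layer_norm(context_vectors, [2.0, 2.0, 2.0], [1.0, 1.0, 1.0])

for row in plain + adjusted:
    print([round(x, 3) for x in row], round(mean(row), 3), round(pvariance(row), 3))
```

With a non-uniform scale vector, different dimensions get stretched by different amounts, which is the "equaliser-like" behaviour discussed below.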
That kind of feels like it's undoing some of the work we've done to normalise things! I also can't think of an obvious mapping for either of them in the audio metaphor I've been using so far.
But at a stretch, perhaps the scale vector could be seen as some kind of equaliser step? What it's doing is saying that after this normalisation step, we're going to increase how much the context vectors point in certain dimensions, and decrease it in others. Perhaps an engineering trick that just happens to work? That feels like a bit of a cop-out, though.
The shift vector is even harder to put into the audio metaphor. I was thinking that it might be to avoid issues with the activation function that comes up later (having set the mean to zero means that we might "throw away" a lot of information with an activation function that treats numbers less than zero as zero), but there's a linear layer before that, which could perfectly well handle that side of things.
After discussion with Claude, I think that the best way to look at it is that having a mean of zero and a variance of one is not necessarily the best set of statistical properties for the context vectors to have. By setting all of them to those values, but then applying scale and shift, we give the model a way to learn what the most appropriate statistical properties for them to have actually are, and then to use that.
The important thing is that all of the context vectors coming in will have different variances and means, while the ones going out will all have been brought to the same level and then given the same learned adjustment. So perhaps that still works in the terms of "getting all of the incoming signals to the same level" metaphor, albeit at a level of abstraction that might be becoming a bit too high...
It's still a little foggy in my mind, though. Perhaps something where just doing some experiments would help -- another one to add to my ever-growing list of things to try out later!
Wrapping up
So, as a first draft of a mental model for this -- layer normalisation constrains the values coming out of a particular attention layer so that they have a mean of zero -- no "DC bias" -- and they are within a reasonably similar range -- line signal level -- so that we don't have issues with exploding or vanishing gradients. We then apply a scale vector and a shift vector, which means that they're no longer pinned to a mean of zero and a variance of one; instead, every context vector gets the same learned, per-dimension adjustment.
That all seems pretty reasonable! I think we can wrap this one up here, and next time I'll move on to another part of the folding, spindling and mutilating that we do to the outputs of our simple multi-head attention modules in order to make the network trainable and able to do its job.