Chapter 5 of Sebastian Raschka's book "Build a Large Language Model (from Scratch)" explains how to train the LLM. There are a number of things in there that required a bit of thought, so I'll post about each of them in turn.
The chapter starts off easily, with a few bits of code to generate some sample text. Because we have a call to torch.manual_seed at the start to make the random number generator deterministic, you can run the code and get exactly the same results as appear in the book, which is an excellent sanity check.
Once that's covered, we get into the core of the first section: how do we write our loss function?
In order to train an ML system like this using gradient descent, we need a function that will tell us how inaccurate its current results are. The training process runs some inputs through, works out the value of this error, or loss, and then uses that to work out gradients for all of our parameters. We then use those to adjust the parameters a little bit in the right direction, and then try again with some new training inputs. Rinse and repeat, checking constantly that things are going in the right direction -- and we're done.
The standard way to express this in most areas of AI -- including deep learning -- is to have an error/loss function that is zero when the model has given exactly the expected result for our training data, and a value higher than zero if it's wrong, with the size of the value being an indication of how far off the model's output was. We want to descend the gradient of that loss function until we hit a minimum.
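To make that loop a bit more concrete, here's a minimal sketch of a single pass over some training data, PyTorch-style. The names `model`, `optimizer`, `loss_function` and `data_loader` are all hypothetical stand-ins, not anything from the book's code:

```python
# A minimal sketch of the gradient-descent loop described above; all of the
# names (model, optimizer, loss_function, data_loader) are hypothetical
# stand-ins rather than the book's actual code.
def train_one_epoch(model, optimizer, loss_function, data_loader):
    for inputs, targets in data_loader:
        optimizer.zero_grad()                    # clear gradients from the previous step
        outputs = model(inputs)                  # run some training inputs through
        loss = loss_function(outputs, targets)   # how wrong were we?
        loss.backward()                          # work out gradients for all parameters
        optimizer.step()                         # nudge the parameters a little in the right direction
```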
So, how do we define a function that says how wrong our LLM's outputs are? Let's start by looking at what they are. Our LLM receives a sequence of token IDs that represent the text that we want to process, and outputs a sequence of vectors of logits, one for each token in the input sequence. The logits for each token are the prediction for what comes next, based only on the tokens from the start, up to and including that specific token.
Now let's look at our training data. Way back at the start of this series, I commented that it seemed odd that the training targets for an LLM were the original sequence "shifted left, with a new token on the end" -- that is, we might have this input/expected output pair:
"The fat cat sat on the" -> " fat cat sat on the mat"We now know that the LLM is predicting logits for all of the "prefix" sequences in the original sequence, so we can see that pair as kind of a shorthand for:
- "The" -> " fat"
- "The fat" -> " cat"
- "The fat cat" -> " sat"
- "The fat cat sat" -> " on"
- "The fat cat sat on" -> " the"
- "The fat cat sat on the" -> " mat"
That's pretty cool, because we're getting six separate training "runs" from one input sequence. And of course, during training we're likely to be running batches of sequences through, each with its own shifted-left target sequence.
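To make that concrete, here's a small sketch (not the book's code) of how the shifted input/target pair can be built and expanded into those prefix pairs. Using the GPT-2 encoding from the tiktoken library here is just my assumption for the example:

```python
import torch
import tiktoken

# A sketch of building the shifted input/target pair for one training sequence.
# Using tiktoken's GPT-2 encoding is an assumption for illustration only.
tokenizer = tiktoken.get_encoding("gpt2")

token_ids = tokenizer.encode("The fat cat sat on the mat")

# Inputs are everything but the last token; targets are the same sequence
# shifted left by one, so targets[i] is the token that follows inputs[:i+1].
inputs = torch.tensor(token_ids[:-1])
targets = torch.tensor(token_ids[1:])

for i in range(len(inputs)):
    prefix = tokenizer.decode(inputs[: i + 1].tolist())
    next_token = tokenizer.decode([targets[i].item()])
    print(f"{prefix!r} -> {next_token!r}")
```

With the GPT-2 vocabulary, each of those words should come out as a single token, so this prints the six prefix/target pairs from the list above.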
For me, the first step in understanding the loss function that we want to use for these batches of input sequences was to realise that every one of these prefix sequence/target pairs can be treated as stand-alone. So, if we had a batch size of one and just fed in "The fat cat sat on the", we'd have six separate logit vectors, each of which needed to be evaluated against a single target token.
And if our batch size is more than one, the same applies. From the viewpoint of calculating the loss function, if we have b sequences in our batch, each of n tokens, and so we get b×n logit vectors -- and of course we have the same number of target tokens -- then it doesn't really matter that lots of the input sequences involved are prefixes of each other. We just have a bunch of stand-alone results to compare to their targets.
For each one, we just work out a loss for that pair -- and then we can aggregate all of those losses together at the end. The actual aggregation is almost absurdly simple -- we just take the arithmetic average of all of the individual sequence-target losses!
But how do we work out those individual losses? That's where cross entropy loss comes in.
Cross entropy loss: the "how"
Before I started reading this part of the book, I had an inkling of what you might do. Let's imagine an LLM that has a three-token vocab. We've fed in some sequence, and we've received logits. Let's say they look like this:
[10.3, 7.6, 9.9]

We can run that through softmax to convert it into probabilities -- what I (perhaps slightly sloppily) called moving it into a "normalised vocab space" in my post on the maths for LLMs. That would look like this:
[0.5755, 0.0387, 0.3858]

Now, we've fed in a training sequence, so we have a training target for that. Let's say it's the second token in the vocab (index 1 in our lists). We can express our target in the same form as the softmaxed output, as a one-hot vector:
[0.0, 1.0, 0.0]

So, I thought, perhaps we can write a function that works out how far off the number in each position in that first vector is from its counterpart in the second, then combines those differences in some way to get an aggregate loss. That would guide the training process towards reducing the numbers in positions 0 and 2 in the list, and increasing the number in position 1, edging it slightly towards the right output.
Maybe that's not a bad idea, but we can actually do it more simply. If all we do is measure how far off the prediction in the target position is -- that is, how far away 0.0387 is from 1 in the example -- and use that to work out our gradients, we'll train the LLM to increase that value.
That has a huge beneficial side-effect: because the sum of all of the outputs of softmax must be 1, by adjusting that one upwards, we're adjusting the others downwards "automatically" -- so, we get the adjustment to those numbers for free, without having to work out some kind of loss number for every value in our logits vector -- and remember, there are 50,257 of those with the GPT-2 vocab size.
So, we want a loss function that can express how far our prediction for the target token -- let's stick with the example of 0.0387 -- is from 1. There's a pretty obvious one. The prediction is a probability so it's guaranteed to be in the range 0..1, so we could use:
L = 1 - p_{\text{correct}}
...where p_correct is the probability that our LLM assigned to the target (that is, the correct) token. That will be zero if our LLM was perfectly correct and predicted that it was going to be the target, and increasingly more than zero (up to one) if it's wrong.
However, there is a better one -- I'll explain why it's better later on -- and this is what we use:
L = -\log p_{\text{correct}}
Remember that the logarithm is defined such that log x is the value -- let's call it y -- such that for some "base" b, b^y = x.
So, log 1 is 0, because any base raised to the power zero is 1. And log x, if x < 1, is going to be a negative number -- and the size of that negative number will increase quite rapidly as we get closer to zero. Using the natural logarithm, where the base is Euler's number e: ln 1 = 0, ln 0.5 ≈ −0.69, ln 0.1 ≈ −2.3, and ln 0.001 ≈ −6.9.
Now we want our error function to return a positive value, but that's fine -- we just flip the sign with that negation operator in the formula for L above.
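Here's a quick numerical check of both candidate loss functions against the three-token example from earlier -- just a sketch, using PyTorch:

```python
import torch

# The three-token example from above: logits, softmaxed probabilities, and the
# loss for a target of token index 1.
logits = torch.tensor([10.3, 7.6, 9.9])
probas = torch.softmax(logits, dim=0)
print(probas)                     # tensor([0.5755, 0.0387, 0.3858])

p_correct = probas[1]             # the probability assigned to the target token

print(1 - p_correct)              # the "obvious" loss: about 0.96
print(-torch.log(p_correct))      # the cross entropy version: about 3.25
```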
That's a special case of the cross entropy loss function, and all we need to do to work out a loss for a run of a training batch for b sequences, each n tokens long, is to calculate the loss using this function for each of the b×n prefix sequences using the targets that we have for each, and then take the average.
That's essentially what Raschka does in the book in his example (although instead of negating the logs and then averaging them, he averages first and then negates the result). And in the PyTorch code that follows, where he uses the built-in torch.nn.functional.cross_entropy function, he just does a bit of flattening first, so that the outputs are converted from a b×n×v tensor of logits to a (b·n)×v tensor -- that is, to treat prefix sequences from different batches as being "the same" -- and likewise the targets from a b×n tensor to a vector of length b·n. (He also just passes the logits directly in, rather than softmaxing first, as the PyTorch function does the softmax for you.)
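In other words, the shape-juggling before the call to the built-in function looks something like this -- a sketch with random stand-in tensors, not the book's code:

```python
import torch

# b sequences per batch, n tokens per sequence, vocab size v (GPT-2's 50,257).
b, n, v = 2, 6, 50257

logits = torch.randn(b, n, v)           # stand-in for the LLM's output
targets = torch.randint(0, v, (b, n))   # stand-in for the shifted target tokens

loss = torch.nn.functional.cross_entropy(
    logits.flatten(0, 1),   # (b, n, v) -> (b*n, v): one row of logits per prefix
    targets.flatten(),      # (b, n)    -> (b*n,):   one target token per prefix
)
print(loss)  # the average of -log(p_correct) over all b*n prefix/target pairs
```

The built-in function averages over all of the rows by default, which is exactly the aggregation described above.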
And to understand what's going on in the rest of the chapter, I think that's all we need!
But this cross entropy thing interested me, and I wanted to dig in a bit. Why do we use this thing with the logarithm, and why is it called "cross entropy"? It is a bit of a side-quest, but I don't like the idea of using something whose name I don't understand. How does it relate to entropy, and what is it cross with?
If you're not interested in that, do feel free to move on now -- you don't need to read the next bit for the rest of the book. But if you're like me and want to dig in, read on...
Digging into cross entropy
Understanding cross entropy loss means understanding what entropy is, then what cross entropy is, so let's kick off with the former.
Entropy in physics -- strictly, thermodynamics -- is loosely speaking how disordered or "messy" a system is. That much I remember from my undergraduate days -- and it's a concept that pops up in literature a lot. In a closed system, with no energy coming in (and therefore nothing to "tidy things up") entropy always increases. Writers of science (and other) fiction love to bring it up because "things get worse if you don't fight against disorder" is a lovely plot point, both metaphorically and concretely.
Obviously that's a very handwavy summary, but I think that's all we need to move on to what it means in the context of information theory. Back in the 1940s and 50s, Claude Shannon wanted to quantify how much information was actually expressed in a message. Let's say that you're receiving numbers from some source. If it always sends zeros, then it's low-information. If it just sends roughly equal numbers of zeros and ones, there's more information there. If it's a wide range of different numbers, it's even higher information -- and the distribution of numbers could also influence how information-rich it is. How to capture that mathematically?
Entropy for probability distributions
Let's take a more concrete example. Imagine you needed to work out what to wear for the day (without looking at the weather forecast); if you were in the Sahara desert, you could be pretty certain that it was going to be hot and sunny. On the other hand, if you were in London during the early springtime, it could be anything from warm and sunny to torrential rain. If we were to express that as probability distributions, then for the Sahara the probability of "hot and sunny" is really high and everything else is low, whereas for London the distribution is much flatter, with most outcomes having roughly the same probability.
There's an obvious parallel between these "neat" and "messy" distributions in probability and the "neatness" of a low-entropy physical system and the "messiness" of a high-entropy one. Shannon wanted to create a formula that would allow us to take a probability distribution -- the distribution of the numbers being sent by that source -- and get a single number that said what its entropy was -- how messy it was.
Let's start defining some terms. We have a probability distribution that says for each possible outcome, how likely it is. We'll treat that as a function called p, so p(x) is the probability of the weather being x, where x could be "sunny", "rainy", etc.
We want to write a function H(p), that will take that probability distribution p and return a number that represents its entropy across all values of x.
That's (fairly) obviously going to have to look at the value of p for each possible x and somehow munge them together. So the first step is to work out some way of working out how much a given p(x) contributes to entropy.
Shannon decided to start with a measure of how surprising p(x) is -- that is, if outcome x happened, how surprised would we be? We're generally surprised when low-probability things happen, and less surprised when high-probability things happen. If we're in London and it starts drizzling, we're not all that surprised. If we're in the middle of the Sahara desert and it starts drizzling, we're going to be a bit disconcerted.
That means that we want a number for p(x)'s contribution to entropy that is high for low-probability events, and low for high-probability events. But now we're kind of back to where we were earlier; how surprising p(x) is could, it seems, be reasonably expressed as 1−p(x).
Let's chart that:

[Chart: surprise measured as 1 − p(x), plotted against probability p(x) -- a straight line falling from 1 at p(x) = 0 to 0 at p(x) = 1.]
Looking at it, something stands out that makes it not quite ideal for surprise: Let's say that something happens one in a thousand times -- that is, its probability is 0.001, so this surprise measure would be 0.999. Now let's say that something else happens half the time, so its probability is 0.5 and so is its surprisingness. Now imagine something that happens almost all of the time, say probability 0.999 so surprise 0.001. The "jump" in surprisingness from something that happens almost always to something that happens half the time isn't a huge step, at least for me intuitively. The jump from something that happens half the time to something that is a one in a thousand chance is much bigger! But they're equally spaced in our "surprise space".
Thinking about it in terms of information -- something that happens half the time doesn't add on much information, but if something happens that is a one in a thousand chance, it's told us something important -- at the very least, that this occurrence is possible, even if it's not likely.
There are a couple of other more mathematical points too:
- Let's imagine that we're looking at combined probabilities. The probability of rolling a three on a die is, of course, 1/6≈0.167. So on this measure its surprisingness would be about 0.833. Now let's think about the probability of rolling two threes in a row. High-school probability tells us that this is (1/6)×(1/6)≈0.028, so its surprisingness is 0.972. There's no obvious connection between those two surprise numbers, 0.833 and 0.972. What would be nice (and you'll see why shortly) would be if we could add the surprisingness of two independent events like these dice rolls together to get the surprisingness of them both happening.
- This is almost a more mathematical way of looking at the example I gave above about the one in a thousand, 50:50 and almost-certain events, but as you can see from the chart, the derivative of the surprise function is constant -- that is, the rate of change of surprisingness is the same at a zero probability as it is at 0.5 as it is at 1. Especially in our case of training an LLM, that doesn't sound quite right -- and as I understand it, that's a general problem in other domains.
So, what alternative do we have? We want something that has nice high surprise factors for low probability events, and has the nice additive property where the surprise for two independent events happening is the sum of their respective "surprisingnesses".
Given what we wind up using above, you probably won't be surprised to know that it's −log p(x).
Let's chart that using the natural logarithm:

[Chart: surprise measured as −log p(x), plotted against probability p(x) -- near zero when p(x) is close to 1, climbing steeply as p(x) approaches 0.]
We can see straight away that the initial problem I described above is gone. Something that is almost certain to happen will have a surprise number of roughly zero, something that happens with 50:50 chances will be about 0.7, but something that happens one in ten times will be about 2.3. Surprisingness increases rapidly for very unlikely things -- a one in a thousand chance has a surprisingness of about 6.9, for example, and a one in a million chance has 13.8.
The derivative is also not constant as a result; that's pretty obvious from the chart.
But the neat bonus is that you can add these surprise factors together! If you remember your high school maths, you can prove it, but let's just use the dice example from above. The surprisingness of rolling a three is:
-\log(1/6) \approx 1.79

...and the surprisingness of rolling two threes is:

-\log(1/6 \times 1/6) \approx 3.58
That is, you can add together the surprisingness of two independent events and get how surprising both of them happening is.
That's pretty nifty -- and so, we have our formula for how surprising a particular outcome x, with probability p(x), is.
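Here's a quick check of that additive property with the dice numbers -- again, just a sketch:

```python
import math

# The surprise of rolling a three, and of rolling two threes in a row.
surprise_one_three = -math.log(1 / 6)               # about 1.79
surprise_two_threes = -math.log((1 / 6) * (1 / 6))  # about 3.58

# The surprise of both independent events is the sum of their surprises.
print(math.isclose(surprise_one_three + surprise_one_three, surprise_two_threes))  # True
```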
But we want to work out the entropy of the whole distribution p, across all of the xs, so we need to combine them somehow. What should the contribution of this number, the surprise of p(x), be to that entropy?
Well, we need to scale it. Something that is surprising but rarely happens should contribute less to entropy than something that is surprising but happens very often. And conveniently, we already know how likely x is to happen -- it's p(x)!
So for each possible outcome, we have a way to work out how much scaled surprise it should contribute to the entropy:
(-\log p(x)) \cdot p(x) = -p(x) \cdot \log p(x)
That may sound very circular; both halves of that formula -- the surprisingness and the scaling factor -- are essentially the same thing! And I must admit that I'm a bit put off by it too. I know that the surprise number represents how much new information we get by seeing the outcome, and the probability is how often we get that information, but it still doesn't sit quite right.
But that's a me problem, not a problem with the equation -- and the good news is that when we come on to cross entropy (as opposed to the regular entropy that we're looking at right now), it becomes a little more clear.
So let's finish off entropy first. We have a scaled surprise number for all of the individual xs, and we want a combined one for the whole distribution p. Now, remember that one of the reasons we chose the particular measure for surprise that we did was that they could be added to get a meaningful number. So that's exactly what we do with these scaled surprise numbers! We just add together the scaled surprise for every value of x, and that gives us the entropy of our distribution -- or, in mathematical notation:
H(p) = -\sum_x p(x) \cdot \log p(x)
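As a sketch of what that formula does, here's the weather example from earlier in code. The probability distributions are made up purely for illustration:

```python
import math

def entropy(p):
    """H(p) = -sum over x of p(x) * log p(x), skipping zero-probability outcomes."""
    return -sum(px * math.log(px) for px in p if px > 0)

# Made-up distributions over [hot and sunny, warm, rainy, snowy], for illustration only.
sahara = [0.97, 0.02, 0.01, 0.00]
london = [0.20, 0.30, 0.40, 0.10]

print(entropy(sahara))  # about 0.15 -- a "neat", low-entropy distribution
print(entropy(london))  # about 1.28 -- a "messier", higher-entropy one
```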
Cross entropy
So now we have a formula that can tell us how high the entropy of a distribution is. How does that help with training our LLM? That's where the "cross" in cross entropy comes in.
Let's think about our LLM. It's trying to predict the next token in a sequence; you've fed it "the fat cat sat on the". Internally it has some kind of model about what the next token should be, and it's spat out its logits for that sequence, which we run through softmax to get our probability vector, where the value at position i is its prediction of the probability that the next token will be the one with the ID i.
Now, that vector is a probability distribution -- let's call it q. And it has its own entropy, H(q). For "the fat cat sat on the", then for a trained LLM, it's likely to be quite low-entropy because "mat" is high-probability (low-surprise) and most of the other tokens are low-probability (so high surprise). But if you were to feed it garbage like "Armed heNetflix", it would likely have no idea about what the next token might possibly be and would return a flatter, higher-entropy distribution.
That in itself is kind of interesting (at least to me), but what we actually want to do is find out how accurate the model is at predicting the next token. And for that, we need to modify the equation a little.
Remember that the per-outcome calculation for entropy was "how surprising is this outcome", which was −log p(x), times "how frequent is this outcome", which was just p(x).
In this new world where we're predicting next tokens, the surprisingness is actually an attribute of the model. The less likely the LLM thinks that a particular token is, the higher the surprise factor if that token actually turns out to be the next one.
But the frequency is an attribute of the real world -- whether or not that is a valid next token in the training data.
So, we extend the formula so that instead of just measuring the entropy of a probability distribution, it measures the entropy of that distribution if you have a model that's predicting a (potentially different) distribution. We've already said that the LLM's predicted distribution is q, so let's call the real-world distribution p and define cross entropy:
H(p, q) = -\sum_x p(x) \cdot \log q(x)
You can see that the surprisingness of each outcome is based on the LLM's prediction q, but the scaling factor to allow for how often it comes along is from reality, p.
That actually makes more sense to me than the original pure-entropy formula! The two halves of the per-outcome calculation are clearly different.
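Continuing the sketch from above, here's that formula in code, with a made-up "real" distribution over a tiny three-token vocab and two hypothetical models -- one whose predictions match reality, and one whose predictions don't:

```python
import math

def cross_entropy(p, q):
    """H(p, q) = -sum over x of p(x) * log q(x); p is reality, q is the model's prediction."""
    return -sum(px * math.log(qx) for px, qx in zip(p, q) if px > 0)

# Made-up distributions over a tiny three-token vocab, purely for illustration.
reality = [0.1, 0.7, 0.2]

matching_model = [0.1, 0.7, 0.2]             # predictions that match reality
mismatched_model = [0.5755, 0.0387, 0.3858]  # our earlier confidently-wrong prediction

print(cross_entropy(reality, matching_model))    # about 0.80
print(cross_entropy(reality, mismatched_model))  # about 2.52
```

The closer the model's predicted distribution q is to the real distribution p, the lower H(p, q) gets.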
So now we have the beginnings of our loss function. The higher the cross entropy between reality and the model's prediction, the higher the loss. We want our training to guide things in a direction that lowers the cross entropy between our model's predictions and reality.
But after all that it just simplifies away
But how do we go from that simple but non-trivial formula down to the "just do minus the log of the prediction for the actual next token" calculation that we had back at the start of this marathon?
Let's think about what p and q are in an actual training run. We've fed in "the fat cat sat on the" and we're trying to score the LLM's prediction for the next token -- that's q, the probability distribution it produced for that prefix.
Let's say that our training data, predictably enough, has "mat" as the next token. We want to represent that as a probability distribution -- and we only have one possibility. That means that the "real" distribution p is basically a one-hot vector -- every number is zero apart from the one for "mat", which is one.
And if you look at the equation for cross entropy above, that means that every number in our big sum for values of x that are not "mat" will be:

0 \cdot \log q(x) = 0

And for x being "mat", it will be:

1 \cdot \log q(x) = \log q(x)

So the whole equation, for one-hot distributions of p, collapses to:

H(p, q) = -\log q(x)
...where x is the target token -- the outcome that our one-hot vector picks out. And that's exactly the equation we use.
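We can check that collapse numerically with the three-token example from earlier -- the full sum, the collapsed version, and PyTorch's built-in function should all agree (again, just a sketch):

```python
import torch

logits = torch.tensor([10.3, 7.6, 9.9])
q = torch.softmax(logits, dim=0)        # the model's predicted distribution
p = torch.tensor([0.0, 1.0, 0.0])       # one-hot "reality": token 1 comes next

full_sum = -(p * torch.log(q)).sum()    # the full cross entropy formula
collapsed = -torch.log(q[1])            # the collapsed -log(q) version
built_in = torch.nn.functional.cross_entropy(logits.unsqueeze(0), torch.tensor([1]))

print(full_sum, collapsed, built_in)    # all about 3.25
```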
Certainty
One thing that might feel slightly strange about this is that we're being so "certain" about the correct output. We've fed in "the fat cat sat on the" and then calculated cross entropy based on a one-hot that expresses something like "the right answer is definitely 'mat', no ifs or buts".
That, of course, is wrong! Even though "the fat cat sat on the mat" is a cliché, and so "mat" is really very likely indeed, the next token could also reasonably be "dog" or "lap" or something else entirely.
I can imagine that you could actually do some kind of training where you fed in multiple possibilities for a target -- that is, instead of using a one-hot vector, you'd use a vector that expressed the true distribution of the possible next tokens after that sequence, so "mat" would be high, "dog" and "lap" a bit lower, and impossible words like "eat" could be zero. Then you'd use the full cross entropy equation instead of this stripped-down version.
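Purely as a hypothetical illustration (this is not something the book does), that "soft target" version might look something like this, with made-up numbers over a tiny four-token vocab:

```python
import torch

# Hypothetical: a "soft" target distribution over [" mat", " dog", " lap", " eat"]
# instead of a one-hot vector. All numbers here are made up for illustration.
logits = torch.tensor([2.5, 0.7, 1.9, -3.0])          # the model's logits
q = torch.softmax(logits, dim=0)                      # its predicted distribution

soft_target = torch.tensor([0.70, 0.15, 0.15, 0.0])   # " mat" likely, " dog"/" lap" possible, " eat" impossible

loss = -(soft_target * torch.log(q)).sum()            # the full cross entropy formula
print(loss)
```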
But that would be very hard to set up -- imagine trying to work out sequence/next-token probability vectors across the kind of huge pile of training data used for LLMs. And in reality, because we're doing gradient descent, our training that pushes the LLM in the direction of "mat" for this sequence will also be mixed in with training on other sequences with other tokens like "dog" or "lap", each of which nudges the parameters in its preferred direction, and eventually they'll pretty much average out to the right distribution, because it really will see "mat" more than the alternatives.
And, in practice, that works perfectly well, and that's why LLMs are so good at next-token prediction.
Wrapping up
So that's it -- we use a simple formula to work out the cross entropy of the LLM's prediction against the reality as expressed by our training targets. It's what the more complex cross entropy function collapses to when our training targets are expressed as "one-hot" probability distributions. We work that out for all prefix sequence/target pairs across all items in our training batch, average them, and that gives us a loss that -- as we reduce it through gradient descent -- minimises the cross entropy and thus the error.
I hope that was useful and at least reasonably clear!
I found digging into this quite fun, and I'm pretty sure I've got the details mostly correct -- but if anyone reading knows better, as always the comments are open below for corrections, or of course requests for clarification.
Coming up: the next part of the training story that made my −log p(x) quite high -- perplexity.