In chapter 5 of Sebastian Raschka's book "Build a Large Language Model (from Scratch)", we finally trained our LLM (having learned essential aspects like cross entropy loss and perplexity along the way). This is amazing -- we've gone from essentially zero to a full pretrained model. But pretrained models aren't all that useful in and of themselves -- we normally do further training to specialise them on a particular task, like being a chatbot.
Chapter 6 explains a -- to me -- slightly surprising thing that we can do with this kind of fine-tuning. We take our LLM and convert it into a classifier that assesses whether or not a given piece of text is spam. That's simple enough that I can cover everything in one post -- so here it is :-)
The decapitation technique
The core idea in the chapter is the trick that we use to change an LLM that is designed to predict the next token into one that classifies texts into different categories.
What we do is remove its final layer -- the one that maps from the 768-dimensional embedding space to vocab space, giving us the logits that we turn into a probability distribution over the next token -- and replace it with a much simpler one, which maps from the 768-dimensional embedding space to a 2-dimensional one, whose logits, after softmax, represent the respective probabilities of spam vs ham (non-spam). As Raschka says, it's as if we're constraining our model to having an output vocabulary of two tokens. We're replacing our output head with a classification head.
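Because it involves removing the output head of the existing model, I'm calling it the decapitation technique. (ChatGPT tells me that the more prosaic name for training just a new linear head on top of a frozen model is a linear probe; since we'll also be unfreezing a few of the layers below the head, what we're doing here is closer to plain fine-tuning with a classification head.)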
This was something I recognised from Jeremy Howard's fast.ai course -- it looks like it might no longer be part of it (the trick in question has been absorbed into the library that the course uses), but I do remember building an image classifier by removing the output head from an existing model and then training a new, simpler one to replace it.
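To make the swap concrete for the LLM case, here's a minimal sketch in plain PyTorch. Random tensors stand in for the hidden states that come out of the model's final layer norm, the sizes are GPT-2 small's, and the variable names are my own rather than the book's. Both heads see exactly the same 768-dimensional vectors; only the size of the last dimension of the output changes:

```python
import torch
import torch.nn as nn

# GPT-2 small sizes: 768-dim embeddings, 50,257-token vocabulary.
EMB_DIM, VOCAB_SIZE, NUM_CLASSES = 768, 50257, 2

# Stand-in for the hidden states leaving the final layer norm:
# a batch of 1 sequence containing 6 tokens.
hidden = torch.randn(1, 6, EMB_DIM)

# Original output head: embedding space -> vocab space.
vocab_head = nn.Linear(EMB_DIM, VOCAB_SIZE)
# Replacement classification head: embedding space -> 2 classes (ham, spam).
class_head = nn.Linear(EMB_DIM, NUM_CLASSES)

vocab_logits = vocab_head(hidden)
class_logits = class_head(hidden)
print(vocab_logits.shape)  # torch.Size([1, 6, 50257])
print(class_logits.shape)  # torch.Size([1, 6, 2])
```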
With an LLM, there's an extra tweak: we only consider the logits produced for the last token. That makes sense; with the original head, the last token's embedding, projected into vocab space, is the predicted next token for the sequence as a whole, and, because the causal attention mask stops each token from seeing the ones after it, it is the only one with information from all of the other tokens blended into it by the attention layers. Being the richest representation of the sequence, it's naturally the best one to check when we're trying to classify that sequence.
But that leads to another interesting thing about the training -- when we're calculating our loss, we only consider the cross-entropy loss between the logits vector for the last token and the target category. That makes sense simply because we don't have spam vs ham labels for the shorter prefix sequences -- but also, we honestly don't care what the model's predictions for them might be. Which leads to the interesting possibility that they might wind up not being spam vs ham predictions at all -- the model has more "freedom" in how it uses them, so they could be anything.
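Both of those points fit in a few lines. This is a sketch with made-up logits, not the book's actual loss function: it slices out the last position before computing cross-entropy, then checks that overwriting every earlier position's logits leaves the loss exactly the same, which is the "freedom" I mean:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(123)

# Hypothetical classifier outputs: 4 sequences, 5 tokens each, 2 classes.
logits = torch.randn(4, 5, 2)
# One label per sequence (0 = ham, 1 = spam), not one per token.
targets = torch.tensor([0, 1, 1, 0])

def last_token_loss(logits, targets):
    # Keep only the final position: shape (batch, num_classes).
    last_logits = logits[:, -1, :]
    return F.cross_entropy(last_logits, targets)

loss = last_token_loss(logits, targets)

# Scribbling over every earlier position leaves the loss unchanged,
# because those logits never enter the computation.
tampered = logits.clone()
tampered[:, :-1, :] = 1000.0
print(torch.equal(loss, last_token_loss(tampered, targets)))  # True
```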
Datasets
There is a lot of data-wrangling going on in this chapter -- all of which is relatively simple -- but one thing that stood out for me was that we make sure that we have the same amount of spam and ham in our training, validation and test sets. Intuitively this makes sense. If we had a training set that was (say) 95% ham and only 5% spam, then a model that simply classified every message as ham, whatever its content, would be right 95% of the time, which is around the accuracy of the trained model that we wind up with.
But it would also be pretty useless. Making sure that we have similar amounts of both is a good way to avoid the model getting trained in a dumb direction like that.
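Here's a toy illustration with made-up messages (the data and the `always_ham` function are hypothetical). On a 95/5 split, a "classifier" that ignores its input scores 95%, but once the data is balanced by undersampling the ham, which is one simple way to get the balance, the same trick only manages 50%:

```python
import random

random.seed(123)

# Hypothetical imbalanced dataset of (text, label) pairs: 95% ham, 5% spam.
data = [(f"ham message {i}", "ham") for i in range(950)] + [
    (f"spam message {i}", "spam") for i in range(50)
]

def always_ham(text):
    # A "classifier" that ignores its input entirely.
    return "ham"

def accuracy(classifier, examples):
    return sum(classifier(text) == label for text, label in examples) / len(examples)

baseline_acc = accuracy(always_ham, data)
print(baseline_acc)  # 0.95

# Balance by undersampling: keep every spam message, plus a random
# sample of the same number of ham messages.
spam = [ex for ex in data if ex[1] == "spam"]
ham = random.sample([ex for ex in data if ex[1] == "ham"], k=len(spam))
balanced = spam + ham

balanced_acc = accuracy(always_ham, balanced)
print(balanced_acc)  # 0.5
```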
I also spotted something a bit odd in the code that creates the DataLoaders for the different datasets: the drop_last parameter is set to True for the training set, but False for the validation and test sets.
From the docs for PyTorch's DataLoader:
drop_last (bool, optional) -- set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If False and the size of dataset is not divisible by the batch size, then the last batch will be smaller.
So: if the number of training samples isn't divisible by the batch size, and we had drop_last set to False for training data, then our last batch would have fewer items in it.
I think it makes intuitive sense that you'd want all of the batches going through during your training to be the same size. Let's imagine that we have ten items in most of our batches; loosely speaking (and assuming, as usual, that the loss is averaged over the batch), that means that on each gradient update, each one of those items will contribute 10% of the update. But if the last batch has only four items, then those four will each contribute 25%. So the items in the last, smaller batch would have an outsized impact on the model's training, which isn't what we want.
For our validation and test sets, it's pretty clear that it doesn't matter -- all we want to do for them is work out how good the model is at classifying them, and they don't have a direct impact on the model's parameter updates.
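A quick way to see the effect, using a dummy dataset of 24 items and a batch size of 10 (numbers picked purely for illustration):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# 24 dummy samples with a batch size of 10: two full batches plus 4 left over.
dataset = TensorDataset(torch.arange(24))

def batch_sizes(drop_last):
    loader = DataLoader(dataset, batch_size=10, shuffle=True, drop_last=drop_last)
    return [len(batch) for (batch,) in loader]

print(batch_sizes(drop_last=False))  # [10, 10, 4]
print(batch_sizes(drop_last=True))   # [10, 10]

# With a mean-reduced loss, each item's share of its batch's update is 1/batch_size.
for size in batch_sizes(drop_last=False):
    print(f"batch of {size}: each item contributes {1 / size:.0%}")
```

The per-item percentages assume the loss is averaged over the batch, which is the default (`reduction="mean"`) for PyTorch's `cross_entropy`.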
The initial test
Before we start training, we see whether our model -- with its original projection-to-vocab head in place -- can classify text as spam or not. We feed it this:
Is the following text 'spam'? Answer with 'yes' or 'no': 'You are a winner you have been specifically selected to receive $1000 cash or a $2000 reward.'
...and it responds with this:
The following text 'spam'? Answer with 'yes' or 'no': 'You are a
So, no luck -- not a big surprise.
Now, quite some time ago, when I started playing with LLMs, I began a series trying to build an AI chatbot using the OpenAI APIs, which -- at the time -- were purely text-completion based, rather than the chat-based ones that they are now. I found that even with some of the older models, you could get them to work like a chatbot by telling them what the text was meant to look like. So I tried that:
This is a transcript of a conversation between a helpful bot, 'Bot', and a human, 'User'. The bot is very intelligent and always answers the human's questions with a useful reply.
Human: Is the following text 'spam'? Answer with 'yes' or 'no': 'You are a winner you have been specifically selected to receive $1000 cash or a $2000 reward.'
Unfortunately it didn't help much -- here's what it came up with:
Bot: User: Bot: User: Bot:
Ah well, worth a try! I might give it another go once we've instruction-trained it.
So, we know we need to train our LLM -- the next bit of the book explains how to do it.
Training and randomness
Just like last time, even though Raschka uses torch.manual_seed in lots of places to make the results reproducible, I got different numbers to his. Interestingly, in this chapter I got pretty much the same numbers for the first few steps, but after the accuracy functions were introduced, my results started differing. Again, I don't think this matters much so long as the numbers are similar. (With the caveat of some slightly disappointing results I got at the end.)
I was interested in the trick of freezing almost all of the LLM's layers -- only allowing the last transformer block and the final layer norm to be trained, plus the new output head. It kind of reminded me of LoRA, which reduces the amount of work you need to do to fine-tune a model by limiting the number of weights you change -- though the way that works is quite different (you introduce new, smaller weight matrices that "adapt" the results of the existing frozen ones, and then train those).
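Here's a sketch of that freezing pattern on a toy model. The attribute names `trf_blocks`, `final_norm` and `out_head` are meant to mirror the book's `GPTModel`, but `TinyGPT` itself is a made-up stand-in, with small linear layers in place of real transformer blocks:

```python
import torch.nn as nn

class TinyGPT(nn.Module):
    """Toy stand-in for a GPT-style model (not the book's GPTModel)."""

    def __init__(self, emb_dim=16, n_layers=4, vocab_size=100):
        super().__init__()
        # Each "block" is just Linear + GELU here, in place of attention + MLP.
        self.trf_blocks = nn.Sequential(
            *[nn.Sequential(nn.Linear(emb_dim, emb_dim), nn.GELU()) for _ in range(n_layers)]
        )
        self.final_norm = nn.LayerNorm(emb_dim)
        self.out_head = nn.Linear(emb_dim, vocab_size, bias=False)

    def forward(self, x):
        return self.out_head(self.final_norm(self.trf_blocks(x)))

def count_trainable(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

model = TinyGPT()

# 1. Freeze everything.
for param in model.parameters():
    param.requires_grad = False

# 2. Swap in the classification head; newly created layers are trainable by default.
model.out_head = nn.Linear(16, 2)

# 3. Unfreeze the last block and the final layer norm.
for param in model.trf_blocks[-1].parameters():
    param.requires_grad = True
for param in model.final_norm.parameters():
    param.requires_grad = True

print(count_trainable(model))  # 338 of 1154 parameters
```

The new `out_head` doesn't need unfreezing explicitly, because a freshly constructed `nn.Linear` has `requires_grad=True` on its weights.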
Raschka suggests an exercise where you try training all of the layers -- that is, with none of them frozen. I did that, and the results were really interesting -- more on that later...
One thing that did surprise me a bit was that in this chapter we're training with a dropout rate of zero. It seems strange, but I think it's because we're freezing most of the layers. The number of parameters that are actually being trained is small, so randomly dropping out activations might just throw away signal from the training, and make it converge more slowly. What's more, there's nothing stopping dropout from happening in those frozen layers, too -- we've set requires_grad to False, so they're not being updated by the training loop, but the model is still in training mode, so dropout would happen in them if the rate weren't set to zero.
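To check that reasoning, here's a small sketch with hypothetical layer sizes: it freezes a linear layer's weights, puts dropout after it, and compares training mode with evaluation mode. Dropout has no parameters at all, so `requires_grad` can't switch it off; only `train()`/`eval()` (or a rate of zero) does:

```python
import torch
import torch.nn as nn

torch.manual_seed(123)

# A "frozen" layer: its weights get requires_grad=False...
linear = nn.Linear(8, 1000)
for param in linear.parameters():
    param.requires_grad = False
# ...followed by dropout, which has no parameters to freeze at all.
block = nn.Sequential(linear, nn.Dropout(p=0.5))

x = torch.ones(1, 8)
clean = linear(x)

block.train()  # training mode: dropout zeroes activations, frozen or not
out_train = block(x)

block.eval()   # eval mode: dropout is a no-op
out_eval = block(x)

print(f"zeroed in train mode: {(out_train == 0).sum().item()} of 1000")
print(f"identical to no-dropout output in eval mode: {torch.equal(out_eval, clean)}")
```

In training mode roughly half of the activations are zeroed, and the survivors are scaled up by 1 / (1 - p) = 2; in eval mode the output is exactly the plain linear output.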
Anyway, the code to actually run the training is pretty simple and I won't repeat what is already explained perfectly well in the book. My training run (with whatever differences in randomness it had from Raschka's) came out with broadly similar results:
Training accuracy: 97.12%
Validation accuracy: 95.97%
Test accuracy: 95.67%
It took 15 seconds on my RTX 3090 -- within the "less than half a minute" range he mentions for the V100 and A100 datacenter cards. Plotting the loss:

...you can see that it looks very similar to what happens in the book.
The accuracy was kind of interesting, though:

You can see that mine actually dropped off on the validation set near the start -- however, it did recover nicely, so no big deal. It looks like the best results were somewhere around epoch 4, but there's not a huge drop-off from there to where we stop at the end of epoch 5.
The final results were a little disappointing, though. In the last pages of the chapter, we run two sentences through our trained classifier. Firstly, something that is very spammy:
You are a winner you have been specially selected to receive $1000 cash or a $2000 reward
And secondly something that isn't spammy:
Hey, just wanted to check if we're still on for dinner tonight? Let me know!
Unfortunately, presumably due to the differences in seeding between my model and the one Raschka was working with, I got the same result for both: not spam.
After checking the code carefully, I decided to look at the actual predictions: in the classify_review function that we write to do these tests, I added a line to print out the softmax of the output logits, which gives the probability that the model assigns to each class.
For the ham case, this printed:
tensor([[0.9771, 0.0229]], device='cuda:0')
The first element (index 0) is the probability that it is predicting for ham, and the second is the probability for spam. So it was 97% sure that the ham message was indeed ham. That's good!
For the spam message, though, it was pretty much on the fence:
tensor([[0.5987, 0.4013]], device='cuda:0')
It thought that there was a 59.87% chance that it was ham, but a 40.13% chance that it was spam. It would be interesting to know what Raschka's own training run came up with, but I can imagine that it might be a similarly close-run thing, and I was just unlucky with the way the randomness in training fell out.
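For reference, here's how logits turn into numbers like those. The logits are made up rather than real model output; index 0 is ham and index 1 is spam, and the "spam"/"not spam" labels follow the chapter's test output:

```python
import torch

# Hypothetical last-token logits for one message: index 0 = ham, 1 = spam.
logits = torch.tensor([[2.0, 0.0]])

probs = torch.softmax(logits, dim=-1)
predicted = torch.argmax(probs, dim=-1).item()
label = "spam" if predicted == 1 else "not spam"

print(probs)  # tensor([[0.8808, 0.1192]])
print(label)  # not spam
```

Because softmax is monotonic, taking the argmax of the logits or of the probabilities gives the same answer. And with two classes, softmax depends only on the gap between the two logits, so the on-the-fence spam result above corresponds to a ham logit only about ln(0.5987 / 0.4013) ≈ 0.4 higher than the spam one.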
But I decided to try training all of the layers to see if I got any different results.
Training all of the layers
This took 42 seconds rather than 15 -- hardly a big deal at this scale, but you can see how the roughly 3x speedup you get from not training everything would be worthwhile for larger models.
The results were definitely better, though:
Training accuracy: 99.71%
Validation accuracy: 98.66%
Test accuracy: 97.67%
Maybe a little bit of overfitting going on? But the important thing is that the validation and test accuracies are higher too, so we're not just memorising the training set.
The loss graph also looks solid:

And accuracy is really interesting:

You can see that validation accuracy got up to 100% sometime during the first epoch, lining up with the point where the loss plateaus! Perhaps a snapshot taken then would have been the best form of the model.
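One way to keep that kind of snapshot is to track the best validation accuracy during training and copy the weights whenever it improves. A minimal sketch, with a made-up list of accuracies and a dummy weight update standing in for real training steps:

```python
import copy

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
# Hypothetical validation accuracies measured after each evaluation step.
val_accuracies = [0.90, 1.00, 0.97, 0.98]

best_acc, best_state = -1.0, None
for step, val_acc in enumerate(val_accuracies):
    with torch.no_grad():
        model.weight.fill_(float(step))  # stand-in for an optimizer update
    if val_acc > best_acc:
        best_acc = val_acc
        # deepcopy, because state_dict() returns references to the live tensors
        best_state = copy.deepcopy(model.state_dict())

model.load_state_dict(best_state)
print(best_acc, model.weight[0, 0].item())  # 1.0 1.0
```

Since the snapshot is chosen using the validation set, the test set is the one to trust for the final accuracy figure.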
But anyway, I tried running the two samples from the last pages of the chapter through this new version of the model -- the results were exactly right! "You are a winner..." was classified as spam, and "Hey, just wanted to check..." as ham. Looking at the predictions that my modification to the classification code printed out, things were even clearer. The spam had this:
tensor([[0.0285, 0.9715]], device='cuda:0')
A 97% chance that it was spam, which is excellent.
Looking at the ham message, it said this:
tensor([[9.9968e-01, 3.2288e-04]], device='cuda:0')
That's a 99.9% chance that it's ham!
So that was a nice note to end on. Not only did I have results showing that the classification worked; the fact that it worked after just the roughly 30 seconds of extra training that came from enabling gradients on all of the layers meant that my original, slightly disappointing results were unlikely to be due to an error in the code. It was, as I'd suspected, just a bit of bad luck with the randomness.
Wrapping up
So, that's it for chapter 6! Classification with a decapitated LLM done and dusted. Next time, it's on to instruction fine-tuning -- something that I spent quite a lot of time on last year, so let's see if all that work will pay off.