Retro Language Models: Rebuilding Karpathy's RNN in PyTorch


I recently posted about Andrej Karpathy's classic 2015 essay, "The Unreasonable Effectiveness of Recurrent Neural Networks". In that post, I went through what the essay said, and gave a few hints on how the RNNs he was working with at the time differ from the Transformers-based LLMs I've been learning about.

This post is a bit more hands-on. To understand how these RNNs really work, it's best to write some actual code, so I've implemented a version of Karpathy's original code using PyTorch's built-in LSTM class -- here's the repo. I've tried to stay as close as possible to the original, but I believe it's reasonably PyTorch-native in style too. (Which is maybe not all that surprising, given that he wrote it using Torch, the Lua-based predecessor to PyTorch.)

In this post, I'll walk through how it works. In follow-up posts, I'll dig in further, actually implementing my own RNNs rather than relying on PyTorch's.

All set?

Previously, on Retro Language Models...

If you already have a basic understanding of what RNNs are and roughly how they work, you should be fine with this post. However, if you're coming directly from normal "vanilla" neural nets, or even Transformers-based LLMs (like the one I'm working through in my LLM from scratch series), then it's definitely worth reading through the last post, where I give a crash course in the important stuff.

So with that said, let's get into the weirdest bit from a "normal" LLM perspective: the dataset.

Datasets for RNN training

Every now and then on X/Twitter you'll see wry comments from practitioners along the lines of "AI is 5% writing cool models and 95% wrangling data". My limited experience bears this out, and for RNNs it's particularly weird, because the format of the data that you feed in is very different to what you might be used to for LLMs.

With a transformers-based LLM, you have a fixed context length -- for the GPT-2 style ones I've posted about in the past, for example, you have a fixed set of position embeddings. More recent position encoding mechanisms exist that aren't quite so constraining, but even then, for a given training run you're going to be thinking in terms of a specific context length -- let's call it n -- that you want to train for.

So: you split up your training data into independent chunks, each one n long. Then you designate some subset of those as your validation set (and perhaps another bunch as your test set), and train on them -- probably in a completely random order. You'll be training with batches, of course; each batch would likely be a completely random set of chunks.
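Just to make the contrast with what's coming concrete, here's a minimal sketch of that kind of independent-chunk split -- this isn't code from the repo, and the function name and parameters are made up for illustration:

import random

def chunk_for_llm(data, n, val_fraction=0.1):
    # Non-overlapping chunks of length n; each one stands alone.
    chunks = [data[i:i + n] for i in range(0, len(data) - n + 1, n)]
    random.shuffle(chunks)  # order doesn't matter here -- that's the point
    val_count = max(1, int(len(chunks) * val_fraction))
    return chunks[val_count:], chunks[:val_count]  # train, validation

train_chunks, val_chunks = chunk_for_llm("abcdefghijklmnopqrstuvwxy", n=3)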

To get to the core of how different RNNs are, it helps to start with an idealised model of how you might train one. Remember, an RNN receives an input, uses that to modify its internal hidden state, and then emits an output based on the updated hidden state. Then you feed in the next input, update the hidden state again, get the next output, and so on.

Let's imagine that you wanted to train an RNN on the complete works of Shakespeare. A super-simple -- if impractical -- way to do that would be to feed it in, character by character. Each time you'd work out your cross-entropy loss. Once you'd run it all through, you'd use those accumulated per-character losses to work out an overall loss (probably just by averaging them). You would run a backward pass using that loss, and use that to adjust the parameters.

If you're feeling all at sea with that backpropagation over multiple steps of a single neural network with hidden state, check out the "Training RNNs" section of the last post.

You can see that in this model, we don't have any kind of chunked data. The whole thing is just run through as a single sequence. But there are three problems:

  1. Vanishing/exploding gradients. Let's say that we're training a three-layer network on the 5,617,124 characters of the Project Gutenberg "Complete Works of Shakespeare". That's essentially backpropagation through a 16-million layer network. You won't get far through that before your gradients vanish to zero or explode to infinity. The only meaningful parameter updates will be for the last something-or-other layers.
  2. Batching. Running multiple inputs through a model in parallel has two benefits: it's faster and more efficient, and it means that your gradient updates are informed by multiple inputs at the same time, which will make them more stable.
  3. Validation. There's no validation set in there, so we'd have no way of checking whether our model is really learning, or just memorising the training set. (The same problem applies to the test set, but for this writeup I'll ignore that, as the solution is the same.)

Let's address those -- firstly, those vanishing or exploding gradients. In the last post I touched on truncated backpropagation through time (TBPTT). The idea is that instead of backpropagating through every step we took while going through our batched input sequences, we run a number of them through, backpropagate, and then continue. Importantly, we keep the hidden state going through the whole sequence -- but we detach it from the compute graph after each of these chunks. That essentially means we start accumulating gradients afresh, as if it were a new sequence; but because each chunk starts from a non-zero hidden state, we still get some training value from the text we've already been through.

Imagine we have this simple sequence:

abcdefghijklmnopqrstuvwxy

Let's say we're doing TBPTT of length 3: we can split up our training set so that it looks like this:

abc:def:ghi:jkl:mno:pqr:stu:vwx:y

So now, we just feed in "a", then "b", then "c", then do our TBPTT step -- we calculate the loss over just those items, backpropagate to update our gradients, and then detach the hidden state, keeping its raw, un-gradient-ed value. Then we start with that stored hidden state, and feed in "d", "e", "f". Rinse and repeat.

In practice we'd probably throw away that short sequence at the end (because it would cause issues with gradient updates -- more here), so we'd just get this:

abc:def:ghi:jkl:mno:pqr:stu:vwx
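To make that concrete, here's a tiny self-contained sketch of a TBPTT loop using a bare nn.RNN and made-up shapes -- not the code from the repo (we'll get to that below), but the same shape of loop:

import torch
from torch import nn

rnn = nn.RNN(input_size=4, hidden_size=8, batch_first=True)
head = nn.Linear(8, 4)
optimizer = torch.optim.SGD(list(rnn.parameters()) + list(head.parameters()), lr=0.1)

data = torch.randn(1, 12, 4)             # one long sequence of 12 steps
targets = torch.randint(0, 4, (1, 12))   # the "next item" at each step

hidden = None
for start in range(0, 12, 3):            # TBPTT chunks of length 3
    if hidden is not None:
        hidden = hidden.detach()         # keep the value, truncate the graph here
    chunk = data[:, start:start + 3, :]
    chunk_targets = targets[:, start:start + 3]
    outputs, hidden = rnn(chunk, hidden)
    loss = nn.functional.cross_entropy(head(outputs).flatten(0, 1), chunk_targets.flatten())
    loss.backward()                      # gradients only flow back 3 steps
    optimizer.step()
    optimizer.zero_grad()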

Now, let's look into batching. It's a bit harder, but with a bit of thought it's clear enough. Let's say that you want b items in your batch. You can just split your data into b separate sequences, and then "stack them up", like this with b=2:

abc:def:ghi:jkl
mno:pqr:stu:vwx

So for training, we'd feed the vector (a, m) in as a batch, calculate loss on both of them, then (b, n), and so on. The important thing is that each batch position -- each row, in that example -- is a consistent, continuous, meaningful sequence in and of itself.

Finally, for validation, you also need some real sequences. For that, you can just split up the batched subsequences, with a "vertical" slice. Let's take the rather extreme view that you want 50% of your data for validation (in reality it would be more like 10-20%, but using 50% here makes it clearer):

Your training set would wind up being this:

abc:def
mno:pqr

...and the validation set this:

ghi:jkl
stu:vwx

And we're done!

So that's what we wind up feeding in. And it looks quite a lot like what we might feed into a regular LLM training loop! It's a set of fixed-length chunks. But there's one critically important difference -- they're not in an arbitrary order, and we can't randomise anything. The sequence of inputs in, for example, batch position one needs to be a real sequence from our original data.

This has been a lot of theoretical stuff for a post that is meant to be getting down and dirty with the code. But I think it's important to get it clear before moving on, because when you see the code, it looks pretty much like normal dataset-wrangling -- so you need to know why it's really not.

Let's get into the code now. In the file next_byte_dataset.py, we define our dataset:

class NextByteDataset(Dataset):
    def __init__(self, full_data, seq_length):
        super().__init__()
        assert seq_length > 0, "Sequence length must be > 0"
        self.seq_length = seq_length

The full_data that we pass in will be our complete training corpus -- eg. the complete works of Shakespeare -- and seq_length is the limit we're going to apply to our truncated backpropagation through time -- that is, three in the example above. Karpathy's blog post mentions using 100, though he says that limiting it to 50 doesn't have any major impact.

Next, we make sure that we have at least enough data to do one of those TBPTTs, plus one extra byte at the end (remember, we need our targets for the predictions -- the Ys are the Xs shifted left with an extra byte at the end).

self.num_sequences = (len(full_data) - 1) // self.seq_length
assert self.num_sequences > 0, "Not enough data for any sequences"

...and we stash away the data, trimmed so that we have an exact number of these sequences, plus one extra byte for our shifted-left targets.

self._data = full_data[:(self.num_sequences * self.seq_length) + 1]

Now we create a tokeniser.

self.tokenizer = NextByteTokenizer(sorted(set(self._data)))

This is related to something I mentioned in the last post. Karpathy's post talks about character-based RNNs, but the code works with bytes. The RNNs receive as their input a one-hot vector. Now, if we just used the bytes naively, that would mean we'd need 256 inputs (and accept 256 outputs) to handle that representation. That's quite a lot of inputs, and the network would have to learn quite a lot about them -- which would be wasteful, because real human-language text, at least in European languages, will rarely use most of them.

His solution is to convert each byte into an ID; there are exactly as many possible IDs as there are different bytes in the training corpus, and they're assigned an ID based on their position in their natural sort order -- that is, if our corpus was just the bytes 43, 12 and 99, then we'd have this mapping:

0 -> 12
1 -> 43
2 -> 99

We just run the full dataset through set to get the set of unique bytes, then sort it -- that gives us a Python list in the right order so that we can just do lookups into it to map from an ID to the actual byte. The NextByteTokenizer class is defined in next_byte_tokenizer.py and is too simple to be worth digging into; it just defines quick and easy ways to get the vocab size (the number of IDs we have), and to encode sequences of bytes into PyTorch tensors of byte IDs and to decode them in the other direction. Because these byte IDs are so similar to the token IDs that we use in LLMs, I've adopted the name "tokens" for them just because it's familiar (I don't know if this is standard).
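To make that concrete, a class doing that job could look something like the sketch below. This is not the repo's NextByteTokenizer (which also has a few other conveniences), just an illustration of the idea:

import torch

class ByteTokenizerSketch:
    def __init__(self, sorted_unique_bytes):
        # eg. [12, 43, 99] -- position in the list is the ID
        self.id_to_byte = list(sorted_unique_bytes)
        self.byte_to_id = {b: i for i, b in enumerate(self.id_to_byte)}
        self.vocab_size = len(self.id_to_byte)

    def encode(self, data):
        # bytes -> 1-D tensor of byte IDs
        return torch.tensor([self.byte_to_id[b] for b in data], dtype=torch.long)

    def decode(self, ids):
        # byte IDs (tensor or list) -> bytes
        return bytes(self.id_to_byte[int(i)] for i in ids)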

So, at this point, we have our data and our tokenizer; we finish up by stashing away an encoded version of the data ready to go:

self._data_as_ids = self.tokenizer.encode(self._data)

Next we define a __len__ method to say how long our dataset is -- this is calculated in terms of how many TBPTT sequences it has:

def __len__(self):
    return self.num_sequences

-- and a __getitem__ method:

def __getitem__(self, ix):
    start = ix * self.seq_length
    end = start + self.seq_length
    xs = self._data[start:end]
    x_ids = self._data_as_ids[start:end]
    ys = self._data[start + 1:end + 1]
    y_ids = self._data_as_ids[start + 1:end + 1]
    return x_ids, y_ids, xs, ys

This works out the start and the end of the ixth subsequence of length self.seq_length in the data. It then returns four things:

  1. x_ids: the byte IDs of the bytes in that sequence -- these are the ones we'll run through the model, our Xs. Note that these are slices of the PyTorch tensors that were returned by the tokeniser, so they're tensors themselves.
  2. y_ids: the shifted-left-by-one-plus-an-extra-byte target sequence as byte IDs -- the Ys for those Xs. These are likewise tensors.
  3. xs: the raw bytes for the x_ids.
  4. ys: the raw bytes for the y_ids.

The code as it stands doesn't actually use the last two, the raw bytes -- but they did prove useful when debugging, and I've left them in just in case they're useful in the future.

If you look back at the more theoretical examples above, what this Dataset is doing is essentially the first bit: the splitting into TBPTT-length subsequences and dropping any short one from the end -- the bit where we go from

abcdefghijklmnopqrstuvwxy

to

abc:def:ghi:jkl:mno:pqr:stu:vwx

The only extra thing is that it also works out our target sequences, which will be a transformation like this:

bcdefghijklmnopqrstuvwxy

to

bcd:efg:hij:klm:nop:qrs:tuv:wxy

So that's our NextByteDataset. Next we have a simple function to read in data; like the original code I just assume that input data is in some file called input.txt in a directory somewhere:

def read_corpus_bytes(directory):
    path = Path(directory) / "input.txt"
    return path.read_bytes()

Now we have the next step, the function batchify:

def batchify(dataset, batch_size):
    num_batches = len(dataset) // batch_size
    batches = []
    for batch_num in range(num_batches):
        batch_x_ids_list = []
        batch_xs_list = []
        batch_y_ids_list = []
        batch_ys_list = []
        for batch_position in range(batch_size):
            item = dataset[batch_num + batch_position * num_batches]
            x_ids, y_ids, xs, ys = item
            batch_x_ids_list.append(x_ids)
            batch_xs_list.append(xs)
            batch_y_ids_list.append(y_ids)
            batch_ys_list.append(ys)
        batches.append((
            torch.stack(batch_x_ids_list),
            torch.stack(batch_y_ids_list),
            batch_xs_list,
            batch_ys_list,
        ))
    return batches

This looks a little more complicated than it actually is, because it's building up a list of tuples, each one of which is a set of x_ids, y_ids, xs and ys. If we imagine that it only did the x_ids, it would look like this:

def batchify(dataset, batch_size):
    num_batches = len(dataset) // batch_size
    batches = []
    for batch_num in range(num_batches):
        batch_x_ids_list = []
        for batch_position in range(batch_size):
            item = dataset[batch_num + batch_position * num_batches]
            x_ids, _, __, __ = item
            batch_x_ids_list.append(x_ids)
        batches.append(torch.stack(batch_x_ids_list))
    return batches

So, what it's doing is working out how many batches of size batch_size there are in the sequence. With our toy sequence

abc:def:ghi:jkl:mno:pqr:stu:vwx

...and a batch size of two, there are 8 // 2 = 4. In this case, it would then loop from zero to 3 inclusive. Inside that loop it would create a list, then loop from zero to 1 inclusive. The first time round that loop it would get the item at batch_num + batch_position * num_batches, which is 0 + 0 * 4 = 0, so the subsequence abc. It would add that to the list. Then it would go round the inner loop again, and get the item at the new batch_num + batch_position * num_batches. batch_position is now 1, so that would be 0 + 1 * 4 = 4, so it would get the subsequence at index 4, which is mno, and add that to the list.

We'd now have finished our first run through the inner loop, and we'd have the list [abc, mno], so we stack them up into a 2-D tensor:

abc
mno

Hopefully it's now fairly clear that in our next pass around the outer loop, we'll pull out the items at index 1 and index 5 to get our next batch, def and pqr, and so on, so that at the end we have done the full calculation to get this:

abc:def:ghi:jkl
mno:pqr:stu:vwx

...as a list of 2×3 PyTorch tensors.

And equally hopefully, it's clear that the code in batchify is just doing that, but not only for the x_ids but also for the y_ids, xs and ys.

One thing to note before moving on is what happens if the number of items doesn't divide evenly into batches -- this code:

num_batches = len(dataset) // batch_size

...means that we'll drop them. So, for example, if we wanted a batch size of three with our toy sequence

abc:def:ghi:jkl:mno:pqr:stu:vwx

...then we'd get this:

abc:def
ghi:jkl
mno:pqr

...and the stu and vwx would be dropped.

And that's it for the dataset code! You might be wondering where the split to get the validation set comes -- that's actually later on, in the training code that actually uses this stuff.

So let's move on to that!

The training code

This is, logically enough, in the file train_rnn.py. There's quite a lot of code in there, but much of it is stuff I put in for quality-of-life (QoL) while using this. It's useful -- but I'll skip it for now and come back to it later. Initially, I want to focus on the core.

We'll start with the main function at the bottom. It starts like this:

@click.command()
@click.argument("directory")
@click.argument("run_name")
def main(directory, run_name):
    run = RunData(directory, run_name)

    dataset = NextByteDataset(read_corpus_bytes(run.data_dir), run.train_data["seq_length"])
    batches = batchify(dataset, run.train_data["batch_size"])

The RunData-related stuff is QoL, so we'll come back to it later. All we need to know right now is that it's a way of getting information into the system about where its input data is, plus some other stuff -- in particular our TBPTT sequence length and our batch_size. So it uses that to read in some training data, then initialises one of our NextByteDatasets with it and the seq_length, then uses batchify to split it into batches.

Next we have this:

val_batch_count = int(len(batches) * (run.train_data["val_batch_percent"] / 100))
if val_batch_count == 0:
    val_batch_count = 1
train_batch_count = len(batches) - val_batch_count
assert train_batch_count > 0, "Not enough data for training and validation"
train_batches = batches[0:train_batch_count]
val_batches = batches[train_batch_count:]
print(f"We have {len(train_batches)} training batches and {len(val_batches)} validation batches")

So our RunData gives us a validation data percentage; we do some sanity checks and then just slice off an appropriate amount from the end of the batches to split the data into train and validation sets. That's the equivalent of the transformation in the earlier example, from

abc:def:ghi:jkl
mno:pqr:stu:vwx

...to this training set:

abc:def
mno:pqr

...and this validation set:

ghi:jkl
stu:vwx

Now, we create our model:

model = KarpathyLSTM(vocab_size=dataset.tokenizer.vocab_size, **run.model_data)

We're using a new KarpathyLSTM class, which wraps the PyTorch built-in LSTM class -- we'll come back to that later. It's also getting parameters (things like the size of the hidden state and the number of layers) from the RunData.

Finally, we do the training in a train function:

train(model, run, dataset.tokenizer, train_batches, val_batches)

So let's look at that now. It starts like this:

def train(model, run, tokenizer, train_batches, val_batches):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

That's fairly standard boilerplate to use CUDA if we have it, and to put the model onto whatever device we wind up using. Next:

optimizer_class = getattr(torch.optim, run.train_data["optimizer"], None)
if optimizer_class is None:
    raise Exception(f"Could not find optimizer {run.train_data['optimizer']}")
optimizer = optimizer_class(
    model.parameters(),
    lr=run.train_data["lr"],
    weight_decay=run.train_data["weight_decay"],
)

The class name for the optimiser is another one of those things from the RunData, as are the learning rate and weight decay hyperparameters. So we just create an instance of it, and give it the model's parameters to work with along with those.

Next, we get our patience:

if "patience" in run.train_data: patience = run.train_data["patience"] assert patience > 0 else: patience = math.inf

This is a QoL thing, but I think it's worth going into what it actually means. When we're training, we normally train for a fixed number of epochs. However, sometimes we might find that our model is overfitting -- say, at epoch 50 out of 100 we might see that the training loss is still decreasing, but the validation loss has started rising.

Any training past that point is probably wasted effort -- if we're doing things properly, we're saving checkpoints of the model periodically, so we'd be able to resurrect the model from the epoch where validation loss was lowest, but we're still burning time by continuing to train.

A common solution to that is to have early stopping in the training loop. If the validation loss starts rising then we bail out early, and don't do the full number of epochs that we originally planned to do.

Naively, we might keep track of the validation loss from the last epoch, and bail out if the current epoch has a higher loss. However, sometimes you find that validation loss rises a bit, but then starts going down again -- it's kind of like a meta version of finding a local minimum in the loss function itself.

The solution to that is to use patience -- a measure of how many epochs of rising validation loss you're willing to put up with before you do your early exit. That's the number we're getting from our RunData here -- it's a positive number (note the paranoid assert), and if it's not defined we just assume that we have infinite patience.
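As a standalone toy illustration of the idea (the real loop below spreads the same logic across the epoch loop, as we'll see):

best_val_loss, best_epoch, patience = None, None, 3

for epoch, val_loss in enumerate([1.9, 1.5, 1.4, 1.45, 1.43, 1.48, 1.5]):
    if best_val_loss is None or val_loss < best_val_loss:
        best_val_loss, best_epoch = val_loss, epoch
    if epoch - best_epoch >= patience:
        print(f"Stopping at epoch {epoch}; best was epoch {best_epoch}")  # epoch 5, best 2
        break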

The next two lines are related to patience too -- before we go into our main training loop, we define the two variables we need to control early exit with patience:

best_val_loss = None
best_epoch = None

Pretty obviously, those are the best validation loss that we've seen so far, and the number of the epoch where we saw it.

Right, finally we get to some training code! We have our epoch loop:

for epoch in tqdm(range(run.train_data["epochs"]), desc="Run"):
    print(f"Starting epoch {epoch}")

We're using the rather nice tqdm module to get progress bars showing how far we are through the training run (ignoring any early exits due to running out of patience, of course).

Next:

print("Sample text at epoch start:") print(repr(generate_sample_text(model, tokenizer, 100, temperature=1)))

We start the epoch by generating some random text from the model. This gives us a reasonably easy-to-understand indication of progress as we go.

Next we put our model into training mode:

model.train()

...and set an initial empty hidden state:

hidden_state = None

You might be wondering why the hidden state is getting a variable of its own, given that it's meant to be hidden -- it's right there in the name! Don't worry, we'll come to that.

Next we initialise some variables we'll use to keep track of loss -- the total loss across all of the batches we've pushed through, plus the total number of tokens.

total_train_loss = 0
total_train_tokens = 0

The metric we track for each epoch is the loss per token, so we use those to work out an average.

Now it's time to start the inner training loop over our batches:

for x_ids, target_y_ids, xs, ys in tqdm(train_batches, desc=f"Epoch {epoch} train"):

We're just unpacking those tuples that were created by batchify into our x_ids and y_ids (I think I was being ultra-cautious about things here when I added target_ to the start of y_ids). And again we're using tqdm to have a sub-progress bar for this epoch.

Next, we move our Xs and Ys to the device we have the model sitting on:

x_ids = x_ids.to(device)
target_y_ids = target_y_ids.to(device)

And then run it through the model. The code to do this looks like this:

if hidden_state is not None:
    h_n, c_n = hidden_state
    hidden_state = (h_n.detach(), c_n.detach())
    y_logits, hidden_state = model(x_ids, hidden_state)
else:
    y_logits, hidden_state = model(x_ids)

...and I think it's worth breaking down a bit. You can see that there's a branch at the top: if there's a hidden state, we need to pass it in, and if there isn't, we don't. But let's focus on the no-hidden-state case in the else branch first, because there's something surprising there:

y_logits, hidden_state = model(x_ids)

Remember the description of an RNN from above:

an RNN receives an input, uses that to modify its internal hidden state, and then emits an output based on the updated hidden state. Then you feed in the next input, update the hidden state again, get the next output, and so on.

We can easily extend that to handle batches -- you'd give the RNN a batch of inputs (let's say a b×1 tensor) and get a batch of results, also b×1. You'd also need the RNN to hold b hidden states, but that's not a big jump.

But what we're doing in that code is something different -- we're feeding in a whole series of inputs -- that is, x_ids is of size b×n, where n is our desired TBPTT sequence length.

What's worse, in our description above, the hidden state was just that -- something hidden in the model. Now it's being returned by the RNN!

What's going on?

Let's start off with that hidden state. We often need to do stuff with the hidden state from outside the RNN -- indeed, we're detaching it as an important part of our TBPTT. So the PyTorch RNN actually does work rather like the simplified model that I described in my last post, and treats the hidden state like an output, like in this pseudocode:

hidden_state = zeros()
for ii in inputs:
    output, hidden_state = model(ii, hidden_state)

That is, the hidden state is an input and a return value, like this:

An RNN viewed as a simple NN with the input and hidden state going in, and the output and new hidden state coming out

OK, so the hidden state thing makes sense. How about the fact that we're feeding in a whole set of inputs?

This is actually just a quality-of-life thing provided by PyTorch's various RNN classes. Feeding in a sequence is, of course, a super-common thing to want to do with an RNN. So instead of having to do something like the pseudocode above, it's baked in. When you run

y_logits, hidden_state = model(x_ids)

...then because x_ids is b×n, it just runs the RNN n times, accumulating the outputs, then returns them stacked up as a tensor of b×n outputs, along with the final hidden_state from the last run through that loop. (There is a wrinkle there that we'll come to shortly.)

With that explained, hopefully that branch is clear. We don't have a hidden state right now, so we run all of the inputs across all of our batch items through the RNN in one go, and we get the outputs plus the hidden state that the RNN had at the end of processing that batch of sequences.
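If you want to convince yourself that the feed-the-whole-sequence-in call really is equivalent to looping over the steps yourself, a quick check along these lines (using a bare nn.LSTM with made-up sizes, not the model from the repo) does the trick:

import torch
from torch import nn

lstm = nn.LSTM(input_size=8, hidden_size=16, batch_first=True)
x = torch.randn(2, 5, 8)                 # batch of 2, sequence length 5

all_at_once, _ = lstm(x)                 # the whole sequence in one call

step_outputs, state = [], None
for t in range(x.shape[1]):              # ...or one step at a time
    out, state = lstm(x[:, t:t + 1, :], state)
    step_outputs.append(out)

# The outputs match (up to floating-point noise)
assert torch.allclose(all_at_once, torch.cat(step_outputs, dim=1), atol=1e-6)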

Now let's look at the other branch, where there is a pre-existing hidden state:

h_n, c_n = hidden_state
hidden_state = (h_n.detach(), c_n.detach())
y_logits, hidden_state = model(x_ids, hidden_state)

Hopefully the last line is clear -- we're just doing the same as we did in the else branch, but we're passing the hidden state in because in this case we actually have one.

The first two lines are a bit more complex. As you know, we need to detach the hidden state from PyTorch's computation graph in order to truncate our backpropagation through time. We're doing that here at the start of the loop just to make sure that each batch that we're pushing through starts with a guaranteed-detached hidden state.

So that explains those calls to the detach methods.

The fact that our hidden state is a tuple of two things that we have to detach separately is a little deeper; for now, all we need to know is that the LSTM models that we're using are a variant of RNN that has two hidden states rather than one, and so we need to handle that. I'll go into that in more depth in a future post.
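A quick way to see that two-part state, and the shapes involved, is to poke at a bare nn.LSTM directly -- again, a throwaway example with made-up sizes rather than repo code:

import torch
from torch import nn

lstm = nn.LSTM(input_size=8, hidden_size=16, num_layers=3, batch_first=True)
x = torch.randn(2, 5, 8)                 # (batch, seq_length, input_size)
outputs, (h_n, c_n) = lstm(x)

print(outputs.shape)   # torch.Size([2, 5, 16]) -- one output per input step
print(h_n.shape)       # torch.Size([3, 2, 16]) -- hidden state per layer, per batch item
print(c_n.shape)       # torch.Size([3, 2, 16]) -- the second ("cell") state, same shape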

Once we've done that, we've completed our forward pass for this batch. Let's move on to the backward pass.

Next, we have this:

train_loss = calculate_loss(y_logits, target_y_ids)
train_loss.backward()

Pretty standard stuff. calculate_loss is defined further up in the file:

def calculate_loss(y_logits, target_y_ids):
    return F.cross_entropy(y_logits.flatten(0, 1), target_y_ids.flatten())

It's exactly the same as the function we used to calculate loss in the LLM-from-scratch posts:

loss = torch.nn.functional.cross_entropy(
    logits.flatten(0, 1), target_batch.flatten()
)

I wrote more about that here if you're interested in the details.

Next, we do something new:

if run.train_data["max_grad_norm"] > 0:
    torch.nn.utils.clip_grad_norm_(model.parameters(), run.train_data["max_grad_norm"])

This is something that is generally very useful in RNNs. They are prone to vanishing and exploding gradients, and this code helps handle the exploding case. What it says is: if we've defined a max_grad_norm, we use it to clip gradients when they get too big, which makes training more stable because we don't get updates swinging wildly up and down.

Let's say that we set max_grad_norm to 1.0. If, at the time this code is run, the norm of the gradients -- which is a measurement of their size -- is, say, 10, then they would all be scaled down to 10% of their size, making the new norm 1.0. So that keeps them in check, and stops any wild variations in gradient updates.

So, in short -- it's a stabilisation technique to stop exploding gradients leading to issues with training.
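Here's a tiny standalone demonstration of what clip_grad_norm_ does to the numbers (not repo code):

import torch

p = torch.nn.Parameter(torch.zeros(3))
p.grad = torch.tensor([6.0, 8.0, 0.0])   # gradient norm is sqrt(36 + 64) = 10

torch.nn.utils.clip_grad_norm_([p], max_norm=1.0)

print(p.grad)   # tensor([0.6000, 0.8000, 0.0000]) -- scaled down so the norm is 1.0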

Next, we have our normal code to update the parameters based on these (potentially clipped) gradients:

optimizer.step()
optimizer.zero_grad()

And finally, we update our count of how many inputs we've seen and our total loss so far in this epoch:

num_tokens = x_ids.numel()
total_train_tokens += num_tokens
total_train_loss += train_loss.item() * num_tokens

That's our training loop! Once we've run our input through the model, calculated the loss, worked out our gradients, clipped them if necessary, done our update, and stored away our housekeeping data, we can move on to the next batch in our sequences.

When we've gone through all of the batches that we have, our training for the epoch is complete. We print out our loss per-token:

train_per_token_loss = total_train_loss / total_train_tokens
print(f"Epoch {epoch}, average train loss per-token is {train_per_token_loss}")

...and then it's time for our validation loop. This is so similar to the training loop that I don't think it needs a detailed explanation:

with torch.no_grad():
    total_val_loss = 0
    total_val_tokens = 0
    model.eval()
    hidden_state = None
    for x_ids, target_y_ids, xs, ys in tqdm(val_batches, desc=f"Epoch {epoch} validation"):
        x_ids = x_ids.to(device)
        target_y_ids = target_y_ids.to(device)
        y_logits, hidden_state = model(x_ids, hidden_state)
        val_loss = calculate_loss(y_logits, target_y_ids)

        num_tokens = x_ids.numel()
        total_val_tokens += num_tokens
        total_val_loss += val_loss.item() * num_tokens

    val_per_token_loss = total_val_loss / total_val_tokens
    print(f"Epoch {epoch}, validation loss is {val_per_token_loss}")

The only big difference (apart from the lack of a backward pass and parameter updates) is that we're not detaching the hidden state, which makes sense -- we're in a no_grad block with the model in eval mode, so there is no computation graph to detach them from.

Validation done, it's time for a bit of housekeeping:

is_best_epoch = True
if best_val_loss is None:
    best_val_loss = val_per_token_loss
    best_epoch = epoch
elif val_per_token_loss < best_val_loss:
    best_val_loss = val_per_token_loss
    best_epoch = epoch
else:
    is_best_epoch = False

All we're doing here is keeping track of whether this is the best epoch in terms of validation loss. The is_best_epoch boolean is exactly what it says it is. If we're on our first run through the loop (best_val_loss is None) then we record our current val loss as best_val_loss, and store this epoch's number into best_epoch. Otherwise, we do have an existing best_val_loss, and if our current val loss is lower than that one, we also stash away our current loss and epoch as the best ones. Otherwise we are clearly not in the best epoch so we update is_best_epoch to reflect that.

Once we've done that, we save a checkpoint:

save_checkpoint(
    run, f"epoch-{epoch}", model, tokenizer, epoch,
    train_per_token_loss, val_per_token_loss, is_best_epoch
)

I'll go into the persistence stuff -- saving and loading checkpoints -- later on.

Next, a QoL thing -- we generate a chart showing how training and validation loss have been going so far:

generate_training_chart(run)

Again, I'll go into that later.

Finally, we do our early stopping if we need to:

if epoch - best_epoch >= patience:
    print("validation loss not going down, stopping early")
    break

If the current epoch is more than patience epochs past the one that had the best validation loss so far, then we stop.

That's the end of the outer loop over epochs for our training! If we manage to get through all of that, we print out some sample text:

print("Sample text at training end:") print(repr(generate_sample_text(model, tokenizer, 100, temperature=0.5)))

...and we're done! That's our training loop. Now let's move on to the model itself.

The model

I called my model class a KarpathyLSTM, and you can see the code here.

It's actually not a great name, as it implies there's something specifically Andrej Karpathy-like about it as a way of doing LSTMs, while what I was trying to express is that it wraps a regular PyTorch LSTM with some extra stuff to make it work more like his original Lua Torch implementation. I tried to come up with a more descriptive name, but they all started feeling like the kinds of class names you get in "Enterprise" Java code like

AbstractSingletonProxyFactoryBeanBuilderStrategyAdapterDecoratorVisitorObserverCommandChainOfResponsibilityTemplateMethodFacadeManagerProviderFactoryFactory

so I gave up and named it after Karpathy. Hopefully he'll never find out, and won't mind if he does...

The Lua code does four things differently to PyTorch's built-in LSTM class:

  1. It accepts the inputs as "token IDs", and maps them to a one-hot vector itself.
  2. It applies dropout after the last layer of the LSTM (rather than just internally between the layers).
  3. It expands the output vector back out to the vocab size with a linear layer after the LSTM so that we have logits across our vocab space. This is because an LSTM's output has the same dimensionality as the hidden state.
  4. It runs those logits through softmax so that it returns probabilities.

Let's look at the code now:

class KarpathyLSTM(nn.Module):
    def __init__(self, vocab_size, hidden_size, num_layers, dropout):
        super().__init__()
        self.vocab_size = vocab_size
        self.lstm = nn.LSTM(
            input_size=vocab_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout,
            batch_first=True,
        )
        self.final_dropout = nn.Dropout(dropout)
        self.decoder = nn.Linear(hidden_size, vocab_size)

    def forward(self, x_ids, state=None):
        one_hot = F.one_hot(x_ids, num_classes=self.vocab_size).float()
        if state is not None:
            outputs, new_state = self.lstm(one_hot, state)
        else:
            outputs, new_state = self.lstm(one_hot)
        outputs = self.final_dropout(outputs)
        logits = self.decoder(outputs)
        return logits, new_state

You can see that it's doing steps 1 to 3 from the list above -- the one-hot, the extra dropout, and the linear layer to project back to vocab space. The only other oddity there is this kwarg:

batch_first=True,

That's the wrinkle I was talking about when we went through the training loop and were discussing batches. The PyTorch LSTM by default expects the batch dimension to be the second one of the input tensors -- that is, instead of passing in a b×n tensor, it wants an n×b one. That's not what I'm used to (nor is it what the original Lua code uses, if I'm reading it correctly), but luckily it can be overridden by the logically-named batch_first option.

The only step we don't do in this class is the softmaxing of the logits to convert them to probabilities. That's because PyTorch's built-in torch.nn.functional.cross_entropy wants logits rather than probabilities, so it was easier to just call softmax on the outputs where necessary.

So that's our model. Let's take a look at the code that we can use to run it and generate some text.

Running the model

The code for this is in generate_sample_text.py. Ignoring the click boilerplate that parses the command-line options, we can start here:

def main(directory, run_name, checkpoint, length, temperature, primer_text):

So, we're taking the directory and run name used by the QoL helpers that I'll be describing later, a specific checkpoint of a training run to use, the number of bytes that we want to generate, the temperature to use when sampling (more about temperature here), and a "primer" text.

That last one is because in order to get something out of our RNN, we need to feed something in. I tried using a single random byte from the vocab initially (that's still the default, as we'll see shortly), and that was OK, but the bytes aren't equally represented in the training data (eg. "z" is less common than "e", but weird bytes that only occur in occasional multibyte unicode characters are rarer still) -- and that means that we might be trying to get our RNN to start with something it hasn't seen very much, so we get bad results. Even worse, because some of the input text is unicode, there's no guarantee that a random byte is even valid on its own -- it might be something that only makes sense after some leader bytes. So I found that in general it's best to provide a fixed string to start with -- say, "ACT" for Shakespeare, or "He said" for "War and Peace".

So, with those command-line flags, we start off by using the QoL stuff to get the metadata we need about the model:

run = RunData(directory, run_name)

...then we use our persistence code to load up the desired checkpoint:

model, tokenizer = load_checkpoint(run, checkpoint)

At this point we have the version of the model that was saved for that checkpoint, and its associated tokeniser. We move this to an appropriate device -- CUDA if we have it, CPU otherwise:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

...and then use a helper function to generate some text:

text = generate_sample_text(
    model, tokenizer, length=length, primer_text=primer_text, temperature=temperature,
)

Once we have that, we print it out, after decoding it as UTF-8:

print((primer_text or "") + text.decode("utf-8", errors="replace"))

If a primer was provided, we print it first, but if the primer was a random byte we don't. Also, because the generated bytes might include invalid Unicode, we just replace those with "?" when we decode (that errors="replace" kwarg).

Let's look at the generate_sample_text helper next.

def generate_sample_text(model, tokenizer, length, primer_text=None, temperature=0):
    assert length >= 1

    with torch.no_grad():
        model.eval()

So, after a little bit of paranoia about the requested output length, we make sure we're not tracking gradients and put the model into eval mode (to disable dropout). Next, we work out our primer bytes -- either by picking a random one, or by encoding the string that we were provided into its constituent UTF-8 bytes:

if primer_text is None:
    primer_bytes = [tokenizer.random_vocab_byte()]
else:
    primer_bytes = primer_text.encode("utf-8")

The primer needs to be converted to the byte token IDs that our tokeniser uses:

primer = tokenizer.encode(primer_bytes).unsqueeze(0)

The unsqueeze(0) is something you might remember from the LLM posts -- we need to run a batch through our RNN, and the encoded primer is just a 1-D tensor of n byte IDs. unsqueeze adds an extra dimension at the front so that it's 1×n, as we want.
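If unsqueeze is unfamiliar, this throwaway example is all there is to it:

import torch

ids = torch.tensor([12, 7, 3])   # shape (3,) -- a single sequence of byte IDs
batch = ids.unsqueeze(0)         # shape (1, 3) -- a batch containing that one sequence
print(batch.shape)               # torch.Size([1, 3])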

Next, we put the primer onto the same device as the model:

primer = primer.to(next(model.parameters()).device)

As an aside, I think I might start using code like that more often; I often find myself passing device variables around, and TBH it seems much more natural to just ask the model what device it's using.

Next, we run it through the model:

y_logits, hidden_state = model(primer)

Now we use a helper function to sample from those logits to get our first generated byte:

next_id = sample(y_logits[:, -1, :], temperature)

Note that we are explicitly taking the last item from y_logits. It is a b×n×v tensor, where b is our batch size (always one in this script), n is the length of the primer that we fed in, and v is our vocab size. The y_logits[:, -1, :] just extracts the last item along the n dimension so that we have the b×v logits that came out of the RNN for the last character of the primer, which is what we want.

We'll get to the sample function later, but it returns a b×1 tensor, so now, we just extract the byte ID from it and put it into a new list:

output_ids = [next_id.item()]

Next comes our autoregressive loop -- we've already generated one byte, so we loop length - 1 times to get the rest, each time running the model on the last byte we got, sampling from the distribution implied by the logits, and adding it onto our output_ids list:

for ii in range(length - 1):
    y_logits, hidden_state = model(next_id, hidden_state)
    next_id = sample(y_logits[:, -1, :], temperature)
    output_ids.append(next_id.item())

Once that's done, output_ids holds the length byte IDs that we generated, so we just use the tokeniser to turn them back into bytes and return the result:

return tokenizer.decode(output_ids)

Easy, right? Now let's look at sample. The function takes logits and the temperature:

def sample(logits, temperature):

Firstly, we handle the case where temperature is zero. By convention this means greedy sampling -- we just always return the highest-probability next token, so we can use argmax for that:

if temperature == 0.0:
    return torch.argmax(logits, dim=-1, keepdim=True)

If the temperature is non-zero, we divide the logits by it and run softmax over the result:

probs = torch.softmax(logits / temperature, dim=-1)

...and then we just sample from the probability distribution that we get from that:

return torch.multinomial(probs, num_samples=1)
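As a quick aside, here's a throwaway illustration (not repo code) of what dividing by the temperature does to a toy set of logits:

import torch

logits = torch.tensor([2.0, 1.0, 0.0])

for temperature in (0.5, 1.0, 2.0):
    print(temperature, torch.softmax(logits / temperature, dim=-1))

# 0.5 -> tensor([0.8668, 0.1173, 0.0159])   sharper: low temperature plays it safe
# 1.0 -> tensor([0.6652, 0.2447, 0.0900])
# 2.0 -> tensor([0.5065, 0.3072, 0.1863])   flatter: high temperature is more adventurous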

And that's it! The only things to explain now are the quality of life stuff, and the persistence functions that handle saving and loading checkpoints. Let's look at our QoL things first.

The quality-of-life stuff

When I started building this code I knew I wanted to run RNNs on multiple input texts -- Shakespeare, "War and Peace", etc. I also realised that for each of those input texts, I'd want to try different model sizes.

The underlying concept I came up with was to have "experiments", which would each have a particular training text. Each experiment would have multiple "runs", which would have particular training hyperparameters -- the model size, number of epochs, and so on.

I decided to represent that with a directory structure, which you can see here. One subdirectory per experiment, and if you go into the gilesthomas.com one you'll see that it has two subdirectories, data for the training data and runs for the different training runs I tried.

The data directory contains a file called input.txt, which is the training data itself. That one only exists in the gilesthomas.com experiment, though, because I was concerned with copyright for the other training sets. There is a SOURCE.md file in all data directories for all experiments, though, which explains how to get the data.

The runs directory has more in it. Each run is for a particular set of hyperparameters, so let's look at the ones for the large-model run. We have two files, model.json, which looks like this:

{ "hidden_size": 512, "num_layers": 3, "dropout": 0.5 }

It's essentially the model-specific hyperparameters, the ones we pass in when creating our KarpathyLSTM -- for example, remember this from the training code:

model = KarpathyLSTM(vocab_size=dataset.tokenizer.vocab_size, **run.model_data)

run.model_data is this JSON dict loaded into Python.

There's also train.json, which has the training hyperparameters:

{ "seq_length": 100, "batch_size": 100, "val_batch_percent": 5, "optimizer": "Adam", "lr": 0.0003, "weight_decay": 0.0, "epochs": 10000, "max_grad_norm": 5.0, "patience": 5 }

Hopefully these are all familiar from the training code; they all go into run.train_data, so they're used in code like this:

optimizer_class = getattr(torch.optim, run.train_data["optimizer"], None)
if optimizer_class is None:
    raise Exception(f"Could not find optimizer {run.train_data['optimizer']}")
optimizer = optimizer_class(
    model.parameters(),
    lr=run.train_data["lr"],
    weight_decay=run.train_data["weight_decay"],
)

So, now when we look at the start of the train_rnn.py and generate_sample_text.py scripts, and see things like this:

def main(directory, run_name, checkpoint, length, temperature, primer_text):
    run = RunData(directory, run_name)

...it should be clear that we're loading up those JSON dicts from those files. You can see that code at the start of persistence.py. It looks like this:

class RunData:
    def __init__(self, directory, run_name):
        self.root_dir = Path(directory)
        if not self.root_dir.is_dir():
            raise Exception(f"Could not find directory {self.root_dir}")
        self.run_dir = self.root_dir / "runs" / run_name
        if not self.run_dir.is_dir():
            raise Exception(f"No runs directory {self.run_dir}")
        self.data_dir = self.root_dir / "data"
        if not self.data_dir.is_dir():
            raise Exception(f"No data directory {self.data_dir}")

So, some basic sanity checking that we have the directories we expect. Next:

self.checkpoints_dir = self.run_dir / "checkpoints"
if not self.checkpoints_dir.is_dir():
    self.checkpoints_dir.mkdir()

...we create a checkpoints directory if it doesn't exist, stashing away its path, then finally we load up those two JSON files:

self.train_data = json.loads((self.run_dir / "train.json").read_text())
self.model_data = json.loads((self.run_dir / "model.json").read_text())

The rest of that file handles checkpointing, so let's move on to that.

Saving and loading checkpoints

Remember, in the training loop, each epoch we saved a checkpoint:

save_checkpoint(
    run, f"epoch-{epoch}", model, tokenizer, epoch,
    train_per_token_loss, val_per_token_loss, is_best_epoch
)

..and at the start of the code to generate some text, we load one:

model, tokenizer = load_checkpoint(run, checkpoint)

Let's take a look at saving first. Each checkpoint is a directory with a filename based on the timestamp when it was saved, inside the directory for the run that it relates to, so firstly we work out the full path for that:

def save_checkpoint(
    run, descriptor, model, tokenizer, epoch, train_loss, val_loss, is_best_epoch
):
    now = datetime.datetime.now(datetime.UTC)
    save_dir = run.checkpoints_dir / f"{now:%Y%m%dZ%H%M%S}-{descriptor}"

(The checkpoints directories inside experiments are explicitly ignored in our .gitignore file so that we don't accidentally commit them.)

Now, we don't want half-saved checkpoints due to crashes or anything like that, so we initially create a directory to write to using the path that we're going to use but with .tmp at the end:

save_dir_tmp = save_dir.with_suffix(".tmp")
save_dir_tmp.mkdir()

Next, we write a meta file (the path within the checkpoint's dir is worked out by a helper function) containing some useful information about the model's progress -- its epoch number, the training and validation loss, and the id_to_byte mapping that its tokeniser uses (from which we can later construct a new tokeniser):

meta = {
    "epoch": epoch,
    "train_loss": train_loss,
    "val_loss": val_loss,
    "id_to_byte": tokenizer.id_to_byte,
}
meta_file(save_dir_tmp).write_text(json.dumps(meta) + "\n")

Then we dump the model's current parameters into a file using the save_file function from the Hugging Face safetensors library (getting the file's path through another helper function):

save_file(model.state_dict(), safetensors_file(save_dir_tmp))

Now that our checkpoint is complete, we can rename our temporary directory to the real name for the checkpoint:

save_dir_tmp.rename(save_dir)

Next, we do some symlinks. We want a symlink in the checkpoints directory called best, which links to the checkpoint that had the lowest validation loss. The training loop is tracking whether any given epoch had the lowest, and you can see it passed in an is_best_epoch parameter, so if that's true, we create the symlink, removing any pre-existing one:

symlink_target = Path(".") / save_dir.name
if is_best_epoch:
    best_path = run.checkpoints_dir / "best"
    best_path.unlink(missing_ok=True)
    best_path.symlink_to(symlink_target, target_is_directory=True)

For completeness, we also create one that points to the most recent checkpoint -- that will always be the one we're doing right now, so:

latest_path = run.checkpoints_dir / "latest"
latest_path.unlink(missing_ok=True)
latest_path.symlink_to(symlink_target, target_is_directory=True)

And that's it for saving!

Loading is even simpler (and note that we can just specify "best" as the checkpoint due to that symlink -- I pretty much always do):

def load_checkpoint(run, checkpoint):
    checkpoint_dir = run.checkpoints_dir / checkpoint
    if not checkpoint_dir.is_dir():
        raise Exception(f"Could not find checkpoint dir {checkpoint_dir}")

So, we've made sure that the checkpoint directory is indeed a directory. Next, we load up the model metadata:

meta = json.loads(meta_file(checkpoint_dir).read_text())

...and we use safetensors' load_file to load our parameters:

state = load_file(safetensors_file(checkpoint_dir))

Now we can construct a tokeniser based on that id_to_byte mapping that we put into the metadata:

tokenizer = NextByteTokenizer(meta["id_to_byte"])

...and a KarpathyLSTM based on the run's model hyperparameters:

model = KarpathyLSTM(vocab_size=tokenizer.vocab_size, **run.model_data)

and load the parameters into the model:

model.load_state_dict(state)

That's it! We can return the model and the tokeniser for use:

return model, tokenizer

So that's all the code needed for checkpointing. Now let's look at the final QoL trick, one that I left out of the earlier list because it needs the checkpoints to work: charting our progress.

Charting training progress

Remember this line from the training loop, which was called after we saved our checkpoint?

generate_training_chart(run)

It generates charts like this:

Chart of loss over a training run

The chart is updated every epoch, and saved into the run's directory. There's also a helpful HTML file placed there that reloads the generated chart every second, so you can just load it into a browser tab while you are training and watch it live.

Let's look into the code. It's in generate_training_chart.py. The function starts like this:

def generate_training_chart(run):
    train_points, val_points, best_epoch = get_training_data(run)

So, we use a utility function (which we'll get into in a moment) to load up the data -- training and validation loss per epoch, and the specific epoch that was the best.

Once we have that, we just use pyplot (with my preferred xkcd styling) to plot the two loss lines:

plt.title("TRAINING RUN LOSS")
plt.xkcd()
plt.rcParams['font.family'] = "xkcd"
fig, ax = plt.subplots(figsize=(8, 6), dpi=100)

train_epochs, train_losses = zip(*train_points)
val_epochs, val_losses = zip(*val_points)
ax.plot(train_epochs, train_losses, label="TRAINING LOSS", marker="o")
ax.plot(val_epochs, val_losses, label="VALIDATION LOSS", marker="s")

We also plot a single vertical red line at the best epoch so that we can see if we're past that and running into the patience period:

ax.axvline(
    best_epoch, color="red", linestyle="--", linewidth=1.5, label="BEST EPOCH"
)

Then a bit more pyplot boilerplate...

ax.set_title("TRAINING RUN LOSS")
ax.set_xlabel("EPOCH")
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
ax.set_ylabel("LOSS")
ax.legend()
fig.tight_layout()

image_file = run.run_dir / "training_run.png"
fig.savefig(image_file, bbox_inches="tight")
plt.close(fig)

...and we've got our chart, saved as training_run.png.

Finally, we just copy that useful auto-reloading HTML file into the same directory as the chart:

this_dir = Path(__file__).resolve().parent
html_source = this_dir / "templates" / "training_run.html"
html_dest = run.run_dir / "training_run.html"
shutil.copyfile(html_source, html_dest)

...and we're done.

So, how do we get the data? Originally I was keeping lists of loss values over time, but I eventually realised that the data was already there in the checkpoint metadata files. So the get_training_data helper function just iterates over the checkpoints, skipping the latest symlink, building lists of (epoch number, loss) tuples for training and validation loss from the numbers in those metadata files, and, for the best symlink, just storing its epoch number:

def get_training_data(run):
    train_losses = []
    val_losses = []
    best_epoch = None
    for item in run.checkpoints_dir.iterdir():
        if item.name == "latest":
            continue
        meta = json.loads(meta_file(item).read_text())
        if item.name == "best":
            best_epoch = meta["epoch"]
            continue
        train_losses.append((meta["epoch"], meta["train_loss"]))
        val_losses.append((meta["epoch"], meta["val_loss"]))

Those loss lists will just be in whatever random order iterdir returned them in, so we sort them by epoch number:

train_losses.sort(key=lambda x: x[0])
val_losses.sort(key=lambda x: x[0])

...and we have something we can return to the charting code:

return train_losses, val_losses, best_epoch

That brings us to the end of the charting code -- and, indeed, to the end of all of the code in this repo! So let's wrap up.

Phew!

That was quite a long writeup, but I think it was worthwhile. Indeed, if you look at the commit history, you'll see that there were one or two places where, while explaining the code, I realised it was doing things badly -- not so badly that it didn't work or gave bad results, but badly enough to offend my sense of what's right as an engineer.

Hopefully it was interesting, and has set things up well for the next step, where I'll use the same framework, but plug in my own RNN implementation so that we can see how it compares. Stay tuned :-)
