Show HN: Implementing and Training a Transformer and Tokenizer in Rust


An implementation of a causal decoder transformer and BPE tokenizer.

Built with Rust 🦀 and tch-rs 🔥.

Transformer Diagram

This project relies on just (a command runner) for script management. Install it with cargo:

cargo install just

Then install libtorch locally and set up the correct configuration with:

Currently, there is only a setup script for arm64 macOS. If you are not on arm64 macOS, see the tch-rs docs for how to set up libtorch on your machine.

To generate text from the example model trained on tinystories-10k, run the command below. A sample of its output follows.

cargo run --release -- generate \
  --dataset-file example/dataset_2048_stories.parquet.ron \
  --transformer-config-file example/transformer_config.ron \
  --token-count 256 \
  --input "Once" \
  --transformer-safetensors example/trained_transformer.safetensors

Once upon a time, there was a little boy named Timmy. Timmy liked to play outside and play with his toys when he saw a bird in the sky. He was very excited to see what it was where it was, so he asked the bird if it could be like to fly to the clouds. The bird said yes!

Timmy and his owner went out of the forest and went to the park. They were so happy to see the bird in the park. But then, the bird saw a big bird who was trying to feet it. Timmy and his mom were very scared and wanted to help. They decided to say hello and stay in the woods. They said he could not find the bird's family. They said they could grow. They said they could make the bird happy.

Later that day, Timmy and his family went to the park to play. They got a new bird in the sky. They were happy and said, "Thank you, mom! You are so kind." They went back to the park. They played with the ball every day.

After that, Timmy and his family went to the park to play. They saw something very big and wide and round. They decided to pretend it was a big, red car to play. They played with the car and had lots of fun.

Train a Model from Scratch

The tinystories-10k dataset is included in this repository as ./stories.parquet. Any .parquet file whose first column contains the text entries to train on can be used.

Start by tokenizing the stories.parquet dataset:

cargo run --release -- --seed 1234 tokenize \
  --data-parquet stories.parquet \
  --transformer-config-file example/transformer_config.ron

This writes the tokenized dataset to datasets/dataset_2048_stories.parquet.ron.

Train a model on the tokenized dataset:

cargo run --release -- --seed 1234 train \
  --dataset-file datasets/dataset_2048_stories.parquet.ron \
  --train-config-file example/train_config.ron \
  --transformer-config-file example/transformer_config.ron

This writes .safetensors checkpoints to a directory under ./training. The fully trained model is saved as final.safetensors.
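For a causal decoder, each training example pairs a window of tokens with the same window shifted one position to the right, so every position learns to predict the token that follows it. The sketch below shows that pairing on a flat token stream using only the Rust standard library. The function name next_token_windows, the parameter context_len, and the sliding stride of 1 are illustrative assumptions, not taken from this repo's train.rs.

```rust
/// Build (input, target) pairs for next-token prediction from a flat token stream.
/// `context_len` is a hypothetical name for the model's context window.
fn next_token_windows(tokens: &[u32], context_len: usize) -> Vec<(Vec<u32>, Vec<u32>)> {
    let mut pairs = Vec::new();
    if context_len == 0 || tokens.len() <= context_len {
        return pairs; // not enough tokens for a single (input, target) pair
    }
    // Each window of context_len + 1 tokens yields one example: the first
    // context_len tokens are the input, the last context_len are the target.
    for start in 0..tokens.len() - context_len {
        let window = &tokens[start..start + context_len + 1];
        pairs.push((window[..context_len].to_vec(), window[1..].to_vec()));
    }
    pairs
}

fn main() {
    let tokens: Vec<u32> = vec![10, 11, 12, 13, 14];
    for (input, target) in next_token_windows(&tokens, 3) {
        println!("{:?} -> {:?}", input, target);
    }
}
```

A stride of 1 produces heavily overlapping examples. A trainer might instead sample random offsets into the stream for each batch.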

The example model was trained with MPS on an M3 MacBook Air. Training took about 2.5 hours. If you are on a Mac that supports MPS, you can enable it explicitly with --mps:

cargo run --release -- --seed 1234 --mps train \
  --dataset-file datasets/dataset_2048_stories.parquet.ron \
  --train-config-file example/train_config.ron \
  --transformer-config-file example/transformer_config.ron

The example configuration in the repo yields a 2.1M-parameter model. In TinyStories: How Small Can Language Models Be and Still Speak Coherent English?, Ronen Eldan and Yuanzhi Li produced a 1M-parameter model with reasonable performance. It may therefore be possible to make this model even more lightweight, especially for CPU training.
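The configuration values determine model size. The sketch below estimates the parameter count of a minGPT-style decoder. It assumes the following layout, which is modeled on minGPT rather than read from this repo's model.rs:

- learned position embeddings;
- biased linear layers;
- LayerNorm with weight and bias;
- an output head that is not tied to the token embedding and has no bias.

The struct GptShape and its field names are hypothetical and do not reflect the schema of example/transformer_config.ron. The values in main are illustrative only.

```rust
/// Hypothetical model shape; field names are illustrative, not the repo's config schema.
struct GptShape {
    vocab_size: usize,
    context_len: usize,
    d_model: usize,
    n_layers: usize,
}

/// Parameter count for an assumed minGPT-style decoder layout.
fn param_count(c: &GptShape) -> usize {
    let d = c.d_model;
    // Token embedding table plus learned position embedding table.
    let embeddings = c.vocab_size * d + c.context_len * d;
    // Per block: two LayerNorms (2d each), fused QKV projection (3d^2 + 3d),
    // attention output projection (d^2 + d), MLP d -> 4d -> d (8d^2 + 5d).
    let per_block = 12 * d * d + 13 * d;
    let final_norm = 2 * d;
    let lm_head = d * c.vocab_size; // untied output projection, no bias
    embeddings + c.n_layers * per_block + final_norm + lm_head
}

fn main() {
    // Illustrative values only; not taken from example/transformer_config.ron.
    let cfg = GptShape { vocab_size: 2048, context_len: 256, d_model: 128, n_layers: 4 };
    println!("{} parameters", param_count(&cfg));
}
```

The number of attention heads does not appear in the formula, because splitting d_model across heads does not change the size of the projection matrices. Shrinking the model mostly means reducing d_model, which enters quadratically through each block, and n_layers.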

You can generate text from your trained model with:

cargo run --release -- generate \
  --dataset-file datasets/dataset_2048_stories.parquet.ron \
  --transformer-config-file example/transformer_config.ron \
  --token-count <TOKEN_COUNT> \
  --input <INPUT> \
  --transformer-safetensors <SAFETENSORS>
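Generation with a causal decoder is an autoregressive loop, repeated --token-count times:

1. Feed the current tokens to the model.
2. Pick the next token.
3. Append it and repeat.

Before each step, the context is cropped to the model's window. The sketch below shows that loop in plain Rust. next_token_logits is a hypothetical stand-in for the model's forward pass. The sketch uses greedy (argmax) decoding for simplicity. The repo's generate command accepts --temperature, so it presumably samples instead.

```rust
/// Greedy autoregressive generation with context cropping.
/// `next_token_logits` is a hypothetical stand-in for the model's forward pass.
fn generate<F>(prompt: &[u32], token_count: usize, context_len: usize, mut next_token_logits: F) -> Vec<u32>
where
    F: FnMut(&[u32]) -> Vec<f32>,
{
    let mut tokens = prompt.to_vec();
    for _ in 0..token_count {
        // Only the most recent `context_len` tokens fit in the context window.
        let start = tokens.len().saturating_sub(context_len);
        let logits = next_token_logits(&tokens[start..]);
        // Greedy decoding: take the index of the largest logit.
        let next = logits
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.total_cmp(b.1))
            .map(|(i, _)| i as u32)
            .expect("model returned no logits");
        tokens.push(next);
    }
    tokens
}

fn main() {
    // Toy stand-in "model" over a 5-token vocabulary: always predicts (last + 1) % 5.
    let toy_model = |ctx: &[u32]| {
        let last = *ctx.last().unwrap() as usize;
        let mut logits = vec![0.0f32; 5];
        logits[(last + 1) % 5] = 1.0;
        logits
    };
    println!("{:?}", generate(&[3], 4, 2, toy_model));
}
```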

This implementation was built to be as verbose as possible. The core of the model lives in ./src/model.rs and is heavily inspired by Karpathy's minGPT.
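The defining piece of a causal decoder is masked self-attention: each position may attend only to itself and earlier positions. The sketch below shows single-head scaled dot-product attention with that mask, using plain Vec<Vec<f32>> matrices instead of tch-rs tensors. It deliberately omits several components that minGPT includes:

- learned Q/K/V and output projections;
- multiple heads;
- dropout.

It is an illustration of the technique, not the code in model.rs.

```rust
/// Single-head causal scaled dot-product attention.
/// q, k, v hold one row per position; q and k rows share the same length d_head.
fn causal_attention(q: &[Vec<f32>], k: &[Vec<f32>], v: &[Vec<f32>]) -> Vec<Vec<f32>> {
    if q.is_empty() {
        return Vec::new();
    }
    let t = q.len();
    let scale = 1.0 / (q[0].len() as f32).sqrt();
    let mut out = vec![vec![0.0f32; v[0].len()]; t];
    for i in 0..t {
        // Causal mask: position i only scores keys 0..=i; later keys are never seen.
        let scores: Vec<f32> = (0..=i)
            .map(|j| q[i].iter().zip(&k[j]).map(|(a, b)| a * b).sum::<f32>() * scale)
            .collect();
        // Numerically stable softmax over the unmasked scores.
        let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let exps: Vec<f32> = scores.iter().map(|s| (s - max).exp()).collect();
        let total: f32 = exps.iter().sum();
        // Output row i is the attention-weighted sum of value rows 0..=i.
        for (j, e) in exps.iter().enumerate() {
            let w = e / total;
            for (o, x) in out[i].iter_mut().zip(&v[j]) {
                *o += w * x;
            }
        }
    }
    out
}

fn main() {
    let q: Vec<Vec<f32>> = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
    let k = q.clone();
    let v: Vec<Vec<f32>> = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
    println!("{:?}", causal_attention(&q, &k, &v));
}
```

Because of the mask, the first output row depends only on the first value row, whatever comes later in the sequence. That property is what lets the model be trained on every position of a window in parallel.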

Some details:

The BPE tokenizer is in ./src/tokenizer.rs. Tokenized datasets are serialized into .ron (Rusty Object Notation) files for later use. The tokenizer prefixes words with a space, never merges across word boundaries, and supports a set of do-not-merge tokens.
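The sketch below shows how those three properties could fit into BPE training, using only the standard library:

- Each whitespace-separated word gets a leading space.
- Pairs are counted only inside a word, so no merge can cross a word boundary.
- Any pair touching a protected symbol is skipped.

It interprets "do-not-merge tokens" as symbols that never take part in a merge, and breaks frequency ties toward the lexicographically smallest pair. Both choices are assumptions, and all function names are hypothetical. This is not the code in tokenizer.rs.

```rust
use std::collections::{HashMap, HashSet};

/// Split on whitespace, prefix each word with a space, and explode it into one-char symbols.
fn pretokenize(text: &str) -> Vec<Vec<String>> {
    text.split_whitespace()
        .map(|w| format!(" {w}").chars().map(|c| c.to_string()).collect::<Vec<String>>())
        .collect()
}

/// Most frequent adjacent pair inside words, skipping pairs that touch a do-not-merge symbol.
/// Ties go to the lexicographically smallest pair so training is deterministic.
fn best_pair(words: &[Vec<String>], no_merge: &HashSet<String>) -> Option<(String, String)> {
    let mut counts: HashMap<(String, String), usize> = HashMap::new();
    for word in words {
        // Windows never span two words, so merges cannot cross word boundaries.
        for pair in word.windows(2) {
            if no_merge.contains(&pair[0]) || no_merge.contains(&pair[1]) {
                continue;
            }
            *counts.entry((pair[0].clone(), pair[1].clone())).or_insert(0) += 1;
        }
    }
    counts
        .into_iter()
        .max_by(|(pa, ca), (pb, cb)| ca.cmp(cb).then_with(|| pb.cmp(pa)))
        .map(|(pair, _)| pair)
}

/// Replace every occurrence of `pair` inside each word with the merged symbol.
fn merge_pair(words: &mut [Vec<String>], pair: &(String, String)) {
    let merged = format!("{}{}", pair.0, pair.1);
    for word in words.iter_mut() {
        let mut out = Vec::with_capacity(word.len());
        let mut i = 0;
        while i < word.len() {
            if i + 1 < word.len() && word[i] == pair.0 && word[i + 1] == pair.1 {
                out.push(merged.clone());
                i += 2;
            } else {
                out.push(word[i].clone());
                i += 1;
            }
        }
        *word = out;
    }
}

/// Learn up to `n_merges` merges; returns the merge list and the final word segmentation.
fn train_bpe(text: &str, n_merges: usize, no_merge: &HashSet<String>) -> (Vec<(String, String)>, Vec<Vec<String>>) {
    let mut words = pretokenize(text);
    let mut merges = Vec::new();
    for _ in 0..n_merges {
        match best_pair(&words, no_merge) {
            Some(pair) => {
                merge_pair(&mut words, &pair);
                merges.push(pair);
            }
            None => break, // nothing left that is allowed to merge
        }
    }
    (merges, words)
}

fn main() {
    let no_merge: HashSet<String> = [".".to_string()].into_iter().collect();
    let (merges, words) = train_bpe("the cat the hat.", 5, &no_merge);
    println!("merges: {merges:?}");
    println!("words: {words:?}");
}
```

A real tokenizer would also record each symbol's id and replay the learned merges, in order, when encoding new text.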

The training loop is in ./src/train.rs and uses the AdamW optimizer. This branch has no learning rate scheduling. See the linear-warmup-cosine-decay branch for an implementation of a more refined scheduler (though it may be out of date with main).
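A linear-warmup, cosine-decay schedule typically works in two phases. It ramps the learning rate up linearly over the first few steps, then follows half a cosine wave down to a floor. The function below is a generic sketch of that formula, not the code from the linear-warmup-cosine-decay branch. lr_at and its parameter names are hypothetical. With tch-rs, a value like this could be applied every step through nn::Optimizer::set_lr.

```rust
use std::f64::consts::PI;

/// Learning rate at `step` for linear warmup followed by cosine decay.
/// Parameter names are illustrative, not the branch's config fields.
fn lr_at(step: usize, warmup_steps: usize, total_steps: usize, max_lr: f64, min_lr: f64) -> f64 {
    if step < warmup_steps {
        // Linear ramp; using step + 1 keeps the very first step from having lr = 0.
        return max_lr * (step + 1) as f64 / warmup_steps as f64;
    }
    if step >= total_steps {
        return min_lr; // hold the floor after the schedule ends
    }
    // Fraction of the decay phase completed, in [0, 1).
    let progress = (step - warmup_steps) as f64 / (total_steps - warmup_steps) as f64;
    let cosine = 0.5 * (1.0 + (PI * progress).cos()); // 1 at the start, approaching 0 at the end
    min_lr + (max_lr - min_lr) * cosine
}

fn main() {
    for step in [0, 5, 9, 10, 55, 99, 100] {
        println!("step {step:>3}: lr = {:.6}", lr_at(step, 10, 100, 1e-3, 1e-4));
    }
}
```

Warmup keeps AdamW's early updates small while its moment estimates are still noisy. The cosine tail lowers the learning rate smoothly instead of in abrupt drops.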

Custom models and training strategies can be configured with .ron files. See ./example/train_config.ron and ./example/transformer_config.ron for example configurations.

  • Run cargo run --release -- -h to see the available commands and the top-level arguments (e.g. seed, MPS).
  • Run cargo run --release -- <COMMAND> -h to get information on what arguments specific commands take.

Some important arguments:

  • The tch and Rust RNGs can be seeded with the --seed argument, e.g. cargo run --release -- --seed 1234 <COMMAND>
  • Training can start from a checkpoint: pass the --transformer-safetensors argument to the train command.
  • The sampling temperature can be changed when generating text: pass the --temperature argument to the generate command.
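Temperature rescales the model's logits before sampling. Values below 1 sharpen the distribution toward the most likely token, and values above 1 flatten it toward uniform. The sketch below shows temperature sampling by inverse-CDF lookup. It assumes the logits arrive as a plain slice and that the caller supplies a uniform random number u in [0, 1). In practice, u would come from an RNG seeded via --seed. The function name is hypothetical, and this is not the repo's sampling code.

```rust
/// Sample a token index from `logits` after dividing them by `temperature`.
/// `u` is a uniform random number in [0, 1), passed in so the function is deterministic.
fn sample_with_temperature(logits: &[f32], temperature: f32, u: f32) -> usize {
    assert!(temperature > 0.0, "temperature must be positive");
    assert!(!logits.is_empty(), "logits must not be empty");
    let scaled: Vec<f32> = logits.iter().map(|l| l / temperature).collect();
    // Softmax, shifted by the max logit for numerical stability.
    let max = scaled.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = scaled.iter().map(|s| (s - max).exp()).collect();
    let total: f32 = exps.iter().sum();
    // Inverse-CDF sampling: walk the cumulative distribution until it exceeds u.
    let mut cumulative = 0.0;
    for (i, e) in exps.iter().enumerate() {
        cumulative += e / total;
        if u < cumulative {
            return i;
        }
    }
    exps.len() - 1 // rounding can leave the sum just below 1; fall back to the last token
}

fn main() {
    let logits = [1.0, 2.0, 3.0];
    for temperature in [0.1, 1.0, 10.0] {
        let picks: Vec<usize> = [0.1, 0.5, 0.9]
            .iter()
            .map(|&u| sample_with_temperature(&logits, temperature, u))
            .collect();
        println!("temperature {temperature}: {picks:?}");
    }
}
```

At low temperature, almost every u maps to the highest-logit token, which approaches greedy decoding. At high temperature, u splits roughly evenly across the vocabulary.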

Built by Liam Ilan, made in Canada 🇨🇦 ❤️.
