Bits-per-Byte (BPB): a tokenizer-agnostic way to measure LLMs


Karpathy recently released the nanochat repo, which contains code for training the best ChatGPT you can build for under $100. While skimming the high-level code, I noticed it reports bits per byte instead of the typical cross-entropy loss. I found that interesting, so I decided to dig in.

TL;DR

  • Bits per byte (BPB) is just cross-entropy normalized per byte of text instead of per token, with the natural-log loss divided by ln(2) to convert nats to bits.
  • Because it’s per byte, BPB is tokenizer-agnostic and lets you compare models fairly even when they use different vocabularies and rules.
  • Perplexity and token-level loss change when you change the tokenizer; BPB largely doesn’t.

An LLM doesn’t predict text directly; it predicts the next token. But what counts as a token depends on the tokenizer (BPE, Unigram, merge rules, special tokens, etc.). Swap tokenizers and the same sentence can become more or fewer tokens, so per-token metrics (average cross-entropy, perplexity) change even if the underlying modeling quality didn’t.
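As a quick illustration (a minimal sketch using the tiktoken library; the sentence and the two encodings are arbitrary choices for illustration, not anything from nanochat), the same text keeps its byte length while the token count shifts with the tokenizer:

import tiktoken

text = "Bits-per-byte is a tokenizer-agnostic metric."
num_bytes = len(text.encode("utf-8"))   # a property of the text, not the tokenizer

for name in ["gpt2", "cl100k_base"]:
    enc = tiktoken.get_encoding(name)
    tokens = enc.encode(text)
    print(f"{name}: {len(tokens)} tokens over {num_bytes} bytes")

The byte count stays fixed; only the token count moves. Per-token metrics inherit that dependence on the tokenizer, per-byte metrics do not.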

Some popular tokenizer choices are:

Model      | Tokenizer               | Vocab Size
-----------|-------------------------|------------
GPT-4      | cl100k_base (BPE)       | 100,256
LLaMA 3    | TikToken (BPE)          | 128,000
Gemini 2.5 | SentencePiece (Unigram) | 256,000
Claude     | closed-source           | undisclosed

Different tokenizers ≠ comparable “tokens”. A model with a finer-grained tokenizer (more, shorter tokens) spreads roughly the same total loss over more positions, so its per-token loss and perplexity can look lower simply because the denominator changed, not because it models the text any better.

Instead of normalizing the loss per token, normalize it per byte of the UTF-8 text those tokens represent. Then, no matter how the text is split into tokens, you are asking the same question: how many bits, on average, does the model need to encode each byte of text?
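To make the arithmetic concrete, here is a toy example with made-up numbers: the same 100-byte passage and the same 50 nats of total loss, under two different tokenizations. Perplexity moves with the token count; BPB does not:

import math

total_nats = 50.0   # hypothetical total cross-entropy over the passage, in nats
num_bytes = 100     # UTF-8 byte length of the passage

for num_tokens in (25, 20):                        # two different tokenizations
    avg_nats = total_nats / num_tokens             # per-token loss depends on the split
    ppl = math.exp(avg_nats)                       # so does perplexity
    bpb = total_nats / (math.log(2) * num_bytes)   # BPB does not
    print(f"{num_tokens} tokens: ppl={ppl:.2f}, bpb={bpb:.3f}")

# 25 tokens: ppl=7.39,  bpb=0.721
# 20 tokens: ppl=12.18, bpb=0.721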

Below is a simplified, more readable version of the original code.

import math
import torch
import torch.distributed as dist


@torch.no_grad()
def evaluate_bpb(model, batches, steps: int, token_bytes: torch.Tensor) -> float:
    """
    Compute Bits-Per-Byte (BPB) over `steps` batches.

    Shapes (your mental model):
        B   = batch size
        Seq = sequence length
        V   = vocab size

    Inputs:
    - model: callable like model(x, y, loss_reduction='none') -> loss per token.
      Expects:
        x: (B, Seq) token ids (int64)
        y: (B, Seq) target token ids (int64), may contain ignore_index (<0)
      Returns:
        loss2d: (B, Seq) per-token loss in NATs (float32/float16)
    - batches: iterable yielding (x, y) as above.
    - steps: number of batches to evaluate.
    - token_bytes: (V,) int64 - byte length of each token id; 0 for special tokens
      (those should not count toward BPB).

    Notes:
    - BPB = (sum of losses in NATs over *counted* tokens) / (ln(2) * total_counted_bytes)
    - Tokens contribute to the denominator by their byte length; tokens with 0 bytes
      (specials) and ignored targets (<0) are excluded from both numerator & denominator.
    """
    device = model.get_device() if hasattr(model, "get_device") else next(model.parameters()).device

    # Accumulators across steps (and later across ranks)
    sum_nats = torch.tensor(0.0, dtype=torch.float32, device=device)  # scalar
    sum_bytes = torch.tensor(0, dtype=torch.int64, device=device)     # scalar
    token_bytes = token_bytes.to(device=device, dtype=torch.int64)    # (V,)

    batch_iter = iter(batches)
    for _ in range(steps):
        x, y = next(batch_iter)   # x: (B, Seq), y: (B, Seq)
        x = x.to(device)
        y = y.to(device)

        loss2d = model(x, y, loss_reduction='none')  # (B, Seq) NATs
        loss1d = loss2d.reshape(-1)                  # (B*Seq,)
        y1d = y.reshape(-1)                          # (B*Seq,)

        if (y1d < 0).any():
            # Mask out ignore_index (<0) before indexing into token_bytes
            valid = (y1d >= 0)                                                   # (B*Seq,)
            ysafe = torch.where(valid, y1d, torch.zeros_like(y1d))               # (B*Seq,)
            nb = torch.where(valid, token_bytes[ysafe], torch.zeros_like(y1d))   # (B*Seq,) int64
        else:
            nb = token_bytes[y1d]  # (B*Seq,) int64

        # Count only tokens with positive byte length
        counted = (nb > 0)                   # (B*Seq,) bool
        sum_nats += (loss1d[counted]).sum()  # scalar
        sum_bytes += nb[counted].sum()       # scalar int64

    # Distributed sum over all ranks, if initialized
    if dist.is_initialized() and dist.get_world_size() > 1:
        dist.all_reduce(sum_nats, op=dist.ReduceOp.SUM)
        dist.all_reduce(sum_bytes, op=dist.ReduceOp.SUM)

    total_nats = float(sum_nats.item())
    total_bytes = int(sum_bytes.item())

    # Guard against division by zero (e.g., all tokens were special/ignored)
    if total_bytes == 0:
        return float("nan")

    bpb = total_nats / (math.log(2.0) * total_bytes)
    return bpb
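The function assumes you already have a token_bytes tensor. Here is a sketch of how you might build one, using tiktoken's cl100k_base as a stand-in tokenizer (nanochat uses its own tokenizer, so treat this as an illustration rather than the repo's code):

import tiktoken
import torch

enc = tiktoken.get_encoding("cl100k_base")
special_ids = {enc.encode_single_token(tok) for tok in enc.special_tokens_set}

lengths = []
for token_id in range(enc.n_vocab):
    if token_id in special_ids:
        lengths.append(0)  # special tokens carry no text bytes
        continue
    try:
        lengths.append(len(enc.decode_single_token_bytes(token_id)))
    except KeyError:
        lengths.append(0)  # unused ids in the vocab range also get 0 bytes

token_bytes = torch.tensor(lengths, dtype=torch.int64)  # shape (V,)

# Hypothetical usage with your own model and validation batches:
# bpb = evaluate_bpb(model, val_batches, steps=100, token_bytes=token_bytes)

Giving special and unused ids a byte length of 0 is what makes evaluate_bpb drop them from both the numerator and the denominator.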