Mixture-of-Experts explained with PyTorch implementation

Faruk Alpay

Mixture-of-Experts (MoE) layers let us grow a Transformer’s capacity without a proportional increase in per-token compute. Instead of one monolithic feed-forward network (FFN), an MoE layer contains multiple expert subnetworks plus a small router (gate). The router looks at each token’s features, scores every expert (via a softmax), and sends the token to the one or two highest-scoring experts. In effect, each token “asks” only a few experts to process it, and only those experts’ parameters are used. This means we can have (say) 64× more FFN parameters in total, yet each token only triggers 1–2 experts, so the compute cost grows modestly. In practice, this divide-and-conquer approach lets different experts specialize on different patterns, while the overall model remains efficient to train and run.

This Switch Transformer encoder block replaces the dense feed-forward network (FFN) layer of the Transformer with a sparse Switch FFN layer (light blue). The layer operates independently on each token in the sequence. In the diagram, two tokens (x1 = “More” and x2 = “Parameters”) are routed (solid lines) across four FFN experts, with the router handling each token independently. The Switch FFN layer returns the output of the selected FFN multiplied by the router gate value (dotted line). (Source)

How MoE Layers Work

In a standard Transformer, each layer uses a single dense FFN that applies two linear transforms (with a nonlinearity) to every token. In an MoE variant, that one FFN is replaced by E parallel FFNs (the “experts”). A small routing network, often just a linear layer, assigns each token a set of expert scores (logits): if the hidden size is d and there are E experts, the router maps each d-dimensional token to an E-dimensional logit vector, and a softmax converts it into a probability distribution over experts. The model then picks the top-k experts (commonly k = 1 or 2) and ignores the rest. Only the selected experts actually process the token, which enables specialization and large capacity without a proportional compute cost.

Intuitively, the router’s softmax behaves like a learned attention mechanism over the experts. Because the gating network and experts train together, the model gradually figures out which expert is best suited for each kind of input. In practice, each token only touches a small portion of the whole model. This matters because it effectively decouples model size from compute cost.

For example, Google’s GLaM model used 64 experts per MoE layer (each itself a large FFN) with top-2 gating, so each token pays for only 2 of the 64 expert FFNs, about 2× the FFN computation of a dense layer. Similarly, Mixtral-8x7B has 8 experts per layer and about 47B total parameters, but because each token is dispatched to only 2 experts, it uses roughly 13B parameters per token, which is about the compute of a 13B–14B dense model.

Key point: In an MoE layer we “pay” compute only for k experts per token (sparse activation), but still have E× as many parameters (massive capacity).

MoE vs. Dense FFN

Dense FFN (standard). A standard Transformer layer applies the same feed-forward network (FFN) to every token. This FFN typically contains two linear layers with a nonlinear activation such as ReLU or GELU between them. Because each token uses the same FFN, all parameters of that FFN participate in every forward pass.

Sparse MoE FFN. A Mixture-of-Experts layer replaces the single FFN with E parallel FFNs, referred to as experts. For each token, only a small subset of these experts is used. A gating network assigns weights to all experts, and the top-k experts are selected for that token. Only the chosen experts perform computation, and their outputs are combined — usually as a weighted sum. Experts not selected for that token perform no computation and therefore do not contribute to the output.
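Written as a formula, the combination step described above might look as follows. The symbols are notation introduced here for illustration: W_r is the router's weight matrix, p the softmax probabilities, T the set of selected experts, and FFN_i the i-th expert. The renormalization over the selected experts is an assumption that matches the toy implementation later in this article; some models skip it.

```latex
p = \mathrm{softmax}(W_r x), \qquad
\mathcal{T} = \operatorname{TopK}(p, k), \qquad
y = \sum_{i \in \mathcal{T}} \frac{p_i}{\sum_{j \in \mathcal{T}} p_j}\, \mathrm{FFN}_i(x)
```

Every expert outside T has an effective weight of zero, which is why it never needs to be evaluated.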

Routing (hard vs soft). There are two main ways to use the gate’s outputs:

  • Soft routing: treat the gate’s output as a full probability distribution and compute a weighted sum of all experts’ outputs.
  • Hard routing: select only the top-k experts (top-k gating) and ignore the others. This saves computation because experts that are not selected do no work for that token.
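The two options above can be compared side by side in a few lines. This is a minimal sketch, not a production routing scheme: the "experts" are single linear layers and all sizes are made-up values chosen for illustration. For clarity it evaluates every expert in both cases; a real hard-routing layer would skip the experts whose weight is zero.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_tokens, d_model, num_experts, top_k = 4, 8, 4, 2  # illustrative sizes

x = torch.randn(num_tokens, d_model)
# Hypothetical stand-ins: each "expert" is one linear map, the router is a linear layer.
experts = [torch.nn.Linear(d_model, d_model) for _ in range(num_experts)]
router = torch.nn.Linear(d_model, num_experts)

with torch.no_grad():
    probs = F.softmax(router(x), dim=-1)                      # (tokens, experts)
    expert_out = torch.stack([e(x) for e in experts], dim=1)  # (tokens, experts, d_model)

    # Soft routing: every expert contributes, weighted by its probability.
    soft = (probs.unsqueeze(-1) * expert_out).sum(dim=1)

    # Hard routing: keep the top-k probabilities, zero the rest, renormalize.
    topk_vals, topk_idx = probs.topk(top_k, dim=-1)
    sparse = torch.zeros_like(probs).scatter(1, topk_idx, topk_vals)
    sparse = sparse / sparse.sum(dim=-1, keepdim=True)
    hard = (sparse.unsqueeze(-1) * expert_out).sum(dim=1)

print("non-zero gate weights per token:", (sparse > 0).sum(dim=-1).tolist())
print("mean |soft - hard|:", (soft - hard).abs().mean().item())
```

The two outputs differ because hard routing discards the contribution of the low-probability experts; the gap shrinks as the router becomes more confident.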

To make hard routing train well in practice, it is common to add small amounts of noise to the router logits (for example, in a “noisy top-k” scheme) and to include auxiliary losses that encourage the gate to spread tokens fairly across experts, so that no single expert becomes overloaded while others remain unused.
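Both tricks can be sketched briefly. The version below makes two simplifying assumptions that are not taken from this article: the noise is plain Gaussian with a fixed standard deviation (published noisy top-k schemes typically learn the noise scale), and the auxiliary loss follows the form used in the Switch Transformer paper, E · Σ f_i · P_i, generalized here to top-k by counting routing slots. The function names and `noise_std` are hypothetical.

```python
import torch
import torch.nn.functional as F


def noisy_topk(logits, k, noise_std=1.0, training=True):
    """Add Gaussian noise to router logits (training only), then pick top-k experts."""
    if training:
        logits = logits + noise_std * torch.randn_like(logits)
    probs = F.softmax(logits, dim=-1)            # (tokens, experts)
    topk_vals, topk_idx = probs.topk(k, dim=-1)  # each (tokens, k)
    return probs, topk_vals, topk_idx


def load_balancing_loss(probs, topk_idx):
    """Auxiliary loss E * sum_i f_i * P_i; equals 1.0 for perfectly uniform routing."""
    num_experts = probs.size(-1)
    # f_i: fraction of routing slots assigned to expert i (not differentiable)
    counts = torch.bincount(topk_idx.reshape(-1), minlength=num_experts).float()
    f = counts / counts.sum()
    # P_i: mean router probability for expert i (gradients flow through this term)
    P = probs.mean(dim=0)
    return num_experts * torch.sum(f * P)


logits = torch.randn(16, 4)  # 16 tokens, 4 experts (illustrative)
probs, topk_vals, topk_idx = noisy_topk(logits, k=2)
aux = load_balancing_loss(probs, topk_idx)
print("auxiliary load-balancing loss:", aux.item())
```

During training, this term would be scaled by a small coefficient and added to the task loss. It reaches its minimum when tokens and probability mass are spread evenly, and it grows toward E when the router collapses onto a single expert.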

Implementing MoE in PyTorch

Let’s build a toy MoE layer step by step. Assume familiarity with torch.nn modules and basic Transformer ideas. The goal here is to make the intuition clear and keep the focus on the core mechanics.

import torch
import torch.nn as nn
import torch.nn.functional as F


class Expert(nn.Module):
    """A simple feed-forward expert (one hidden layer for demo)."""

    def __init__(self, d_model, d_hidden):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_hidden)
        self.act = nn.ReLU()
        self.fc2 = nn.Linear(d_hidden, d_model)

    def forward(self, x):
        # x: (batch, d_model), one row per token.
        # nn.Linear acts on the last dimension, so (batch, seq_len, d_model) also works.
        return self.fc2(self.act(self.fc1(x)))


class Router(nn.Module):
    """Gating network: maps token -> probability distribution over experts."""

    def __init__(self, d_model, num_experts):
        super().__init__()
        self.fc = nn.Linear(d_model, num_experts)

    def forward(self, x):
        # x: (batch, d_model)
        return F.softmax(self.fc(x), dim=-1)  # (batch, num_experts)

Here, Expert is just a two-layer FFN with a ReLU in between. Router is a single linear layer + softmax that produces a distribution over experts for each token.

Now the components are combined into an MoE layer. The layer receives a batch of token vectors, computes their gate weights, selects the top-k experts for each token, and aggregates the corresponding expert outputs.

class MoELayer(nn.Module):
    def __init__(self, d_model, d_hidden, num_experts, top_k=2):
        super().__init__()
        self.experts = nn.ModuleList(
            [Expert(d_model, d_hidden) for _ in range(num_experts)]
        )
        self.router = Router(d_model, num_experts)
        self.top_k = top_k

    def forward(self, x):
        # x: (batch, d_model) token vectors
        # Compute gate probabilities
        gates = self.router(x)  # (batch, num_experts), already softmaxed
        # Pick top-k experts for each token
        topk_vals, topk_idx = torch.topk(gates, self.top_k, dim=-1)  # each is (batch, top_k)
        # Normalize the top-k weights so they sum to 1 for each token
        topk_norm = topk_vals / topk_vals.sum(dim=-1, keepdim=True)  # (batch, top_k)

        # Accumulate the output. Each expert runs only on the tokens routed to it,
        # so we never run all experts on all tokens.
        outputs = torch.zeros_like(x)  # (batch, d_model)

        for i, expert in enumerate(self.experts):
            # Mask of (token, slot) pairs that selected expert i
            mask = topk_idx == i  # (batch, top_k) bool
            # Indices of tokens routed to expert i (a token picks an expert at most once)
            token_indices = mask.any(dim=-1).nonzero(as_tuple=True)[0]
            if token_indices.numel() == 0:
                continue
            # Apply the expert to the selected tokens only
            y_i = expert(x[token_indices])  # (n_sel, d_model)
            # Normalized gate weight of expert i for each selected token
            w_i = (topk_norm[token_indices] * mask[token_indices]).sum(dim=-1)  # (n_sel,)
            outputs[token_indices] += y_i * w_i.unsqueeze(-1)
        return outputs  # (batch, d_model)

The above MoELayer illustrates the full routing logic (though in practice this would be vectorized more efficiently). Each token is dispatched to its top-k experts; the selected experts run, their outputs are weighted, and the results are summed. Experts that are not selected contribute nothing to the token’s output.

Note: This code is for illustration. Libraries such as Hugging Face Transformers implement MoE more efficiently (e.g. by grouping tokens per expert).
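To give a feel for what "grouping tokens per expert" can mean, here is one possible sketch. It is not the code of any particular library: `grouped_moe_forward` is a hypothetical helper that sorts the (token, expert) assignments so each expert receives one contiguous batch, then scatter-adds the weighted results back to the token positions.

```python
import torch
import torch.nn as nn


def grouped_moe_forward(x, experts, topk_idx, topk_weights):
    """x: (tokens, d_model); topk_idx and topk_weights: (tokens, k)."""
    num_tokens, k = topk_idx.shape
    flat_expert = topk_idx.reshape(-1)      # (tokens * k,)
    flat_weight = topk_weights.reshape(-1)  # (tokens * k,)
    flat_token = torch.arange(num_tokens, device=x.device).repeat_interleave(k)

    # Sort the assignments so that each expert's tokens are contiguous.
    order = torch.argsort(flat_expert)
    flat_expert = flat_expert[order]
    flat_weight = flat_weight[order]
    flat_token = flat_token[order]
    counts = torch.bincount(flat_expert, minlength=len(experts)).tolist()

    out = torch.zeros_like(x)
    start = 0
    for expert, n in zip(experts, counts):
        if n == 0:
            continue  # this expert received no tokens
        tok = flat_token[start:start + n]
        y = expert(x[tok]) * flat_weight[start:start + n].unsqueeze(-1)
        out.index_add_(0, tok, y)  # add weighted outputs back at the token positions
        start += n
    return out


torch.manual_seed(0)
experts = nn.ModuleList([nn.Linear(8, 8) for _ in range(4)])  # toy experts
x = torch.randn(6, 8)
probs = torch.softmax(torch.randn(6, 4), dim=-1)
weights, idx = probs.topk(2, dim=-1)
weights = weights / weights.sum(dim=-1, keepdim=True)
out = grouped_moe_forward(x, experts, idx, weights)
print(out.shape)
```

The Python loop over experts remains, but each expert is called once on a dense batch, which is the property that optimized implementations exploit.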

Example: Gating Outputs

Let’s explore a quick example of the router’s behavior (without full training). Imagine having 4 experts and 5 randomly generated token vectors:

router = Router(d_model=8, num_experts=4)
tokens = torch.randn(5, 8) # 5 tokens, hidden size 8
gates = router(tokens) # (5,4) softmax probabilities
print("Gate probabilities per token:\n", gates)
# Pick top-2 experts for each token
top2_vals, top2_idx = torch.topk(gates, 2, dim=-1)
print("Top-2 expert indices per token:\n", top2_idx)

This will print out random softmax weights and which experts got chosen. In a real trained model, these gates would reflect learned preferences.

Routing behavior: By default the router outputs a probability for each expert; we then do top-k selection. Unselected experts get a weight of zero for that token and are never run on it.
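A quick way to inspect routing is to count how many tokens each expert receives. The sketch below uses random gate probabilities as a stand-in for a real router (an assumption made to keep it self-contained), so the split comes out roughly even; a trained or collapsed router can look very different.

```python
import torch

torch.manual_seed(0)
num_tokens, num_experts, top_k = 1000, 4, 2  # illustrative sizes

# Stand-in for router output: random softmax probabilities per token.
gates = torch.softmax(torch.randn(num_tokens, num_experts), dim=-1)
_, topk_idx = gates.topk(top_k, dim=-1)  # (tokens, top_k)

# Count how often each expert appears among the selected slots.
load = torch.bincount(topk_idx.reshape(-1), minlength=num_experts)
share = load.float() / load.sum()

print("tokens per expert:", load.tolist())
print("share of routing slots:", [round(s, 3) for s in share.tolist()])
```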

Gate Weights and Routing Patterns

Even without supervision, MoE tends to learn meaningful routing. Each expert often specializes on some subset of inputs. For example, the authors of the Switch Transformer observed that encoder experts often specialize on specific token groups (e.g. one expert handles punctuation, another handles conjunctions, etc.). Decoder experts may be less cleanly specialized, but still show consistency in which tokens they process.

Token groups routed to different experts in an MoE. Each column is an expert; entries show examples of tokens that the model routed to that expert during training (e.g. one expert got mostly punctuation, another conjunctions, etc.). This emerges purely from the model learning to minimize loss while keeping experts balanced. (Source)

In a continuous-MoE trained on MNIST digits, each expert tends to receive inputs with specific digit shapes. In the figure, for example, Expert 0 mostly handles 1s, Expert 1 mostly handles 2s, Expert 2 focuses on 7-like digits, and Expert 3 specializes in 5s.

Each expert becomes sensitive to certain digits (e.g. Expert 0 to “1”, Expert 1 to “2”, Expert 2 to “7”, and Expert 3 to “5”). (Source)

The router often learns to send similar inputs to the same expert. The figure above makes this concrete. Each block of 3×3 images shows the inputs that a particular expert “likes” the most:

  • Expert 0 receives almost only the digit 1 — thin, vertical strokes.
  • Expert 1 mainly sees 2-shaped digits (with an occasional 1 that looks similar to a 2’s first stroke).
  • Expert 2 is used for digits shaped like 7 (plus a few that resemble 7s, such as a 2 or an 8 with a similar top).
  • Expert 3 specializes in the digit 5.

Even though we never told the model “Expert 0 should handle ones, Expert 3 should handle fives,” the routing mechanism and training objective encourage this pattern to emerge on their own. As each expert focuses on a narrower slice of the problem, the overall system behaves like a team of specialists rather than a single network trying to handle every possible input by itself.

Efficiency and Capacity Trade-offs

A big advantage of MoEs is efficiency at scale. Since only a few experts fire per token, the per-token compute is small, even if the model has many experts (and thus huge total parameters). In practice this means we can train much bigger models for the same speed. For example:

  • Mixtral-8x7B (an 8‑expert decoder model with about 47B params) uses roughly the FLOPs of a 13B–14B dense model, because each token runs through only 2 experts.
  • More generally, an MoE with E experts and top-k gating has about E× more FFN parameters but only k× the compute of the base FFN. If k ≪ E, that’s a huge capacity win with small overhead.

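Simple numerical illustration: With 8 experts and top-2 gating, each token touches 2 experts, so the FFN work is 2× that of a single dense FFN of the same size (ignoring the router). But there are 8 distinct FFNs (8× the parameters). That is 4× more parameters per unit of compute, and the ratio E/k grows with the number of experts: 64 experts with top-2 gating give 32×. Note that making each expert wider does not improve this ratio, because a 10× larger hidden size multiplies both the parameters and the per-token compute of every expert by about 10×. (The extra router params are tiny by comparison.)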
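The arithmetic can be checked with a small calculator. This sketch assumes two-layer FFN experts with biases and uses the number of active parameters as a rough proxy for per-token compute (each weight costs about one multiply-add); the function names and the sizes 1024 and 4096 are illustrative choices, not values from a specific model.

```python
def ffn_params(d_model, d_hidden):
    """Weights plus biases of a two-layer FFN (d_model -> d_hidden -> d_model)."""
    return (d_model * d_hidden + d_hidden) + (d_hidden * d_model + d_model)


def moe_vs_dense(d_model, d_hidden, num_experts, top_k):
    """Compare an MoE layer with a single dense FFN of the same expert size."""
    dense = ffn_params(d_model, d_hidden)
    router = d_model * num_experts + num_experts  # one linear layer
    total = num_experts * dense + router          # parameters stored
    active = top_k * dense + router               # parameters used per token
    return {
        "param_ratio": total / dense,
        "compute_ratio": active / dense,
        "params_per_compute": total / active,
    }


base = moe_vs_dense(d_model=1024, d_hidden=4096, num_experts=8, top_k=2)
wide = moe_vs_dense(d_model=1024, d_hidden=40960, num_experts=8, top_k=2)
print("8 experts, top-2:", base)
print("same, 10x wider experts:", wide)
```

With 8 experts and top-2 gating the parameter ratio comes out close to 8 and the compute ratio close to 2, and the parameters-per-compute figure stays near E/k = 4 even when every expert is made ten times wider.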

Illustration of an MoE layer with 8 experts and top-2 routing: each token activates only 2 experts (≈2× the compute of one expert FFN) while the layer holds the parameters of all 8, i.e. about 4× more parameters per unit of compute. (Figure by Author.)

On the flip side, MoEs need extra VRAM (all experts’ weights must reside in memory) and add some overhead (for selecting experts and moving tokens between them). Modern frameworks and hardware mitigate this: expert parallelism (sharding experts across devices), batching the tokens assigned to each expert, and optimized kernels make MoEs practical. The key takeaway: MoE decouples model size from per-token computation, enabling extremely large models at scale.

Comparison of a dense feed-forward layer and a Mixture-of-Experts layer in terms of parameters, compute per token, and VRAM usage: the MoE design holds vastly more parameters (and needs the VRAM to store them) while only modestly increasing per-token compute, illustrating how expert routing lets model size grow without a proportional rise in per-token cost. (Figure by Author.)

MoEs do introduce costs: all expert weights must stay in memory, and routing tokens between experts adds some overhead. Modern libraries and hardware reduce this burden through optimized kernels and sharding strategies, so MoEs remain practical in large systems. In return, they allow model size to grow dramatically without a matching rise in per-token computation.

Conclusion

MoE layers let you build massive models without paying the full compute cost. The idea is simple: split the work across many experts, but only activate a few per token. With top-2 routing you get 8× or even 64× more FFN parameters while only roughly doubling the per-token FFN compute.

The routing mechanism is surprisingly clever. Without any explicit instruction, experts naturally specialize on different patterns. One handles punctuation, another processes numbers, and so on. This happens automatically during training, which is pretty cool when you think about it.

Of course there are downsides. You need enough memory to hold all the experts, even if most sit idle for any given token. Load balancing can be tricky too. But for many applications the tradeoff is worth it. Models like Mixtral show that MoE works at scale.

If you’re building something that needs serious capacity but can’t afford the compute of a giant dense model, MoE is worth trying. The implementation isn’t too complex, and the potential gains are real. Just don’t expect miracles with small datasets or simple tasks. MoE shines when you have lots of data and need lots of capacity.
