Mixture-of-Depths: Dynamically allocating compute in transformer-based language models

Authors: David Raposo, Sam Ritter, Blake Richards, Timothy Lillicrap, Peter Conway Humphreys, Adam Santoro

What

This paper introduces Mixture-of-Depths (MoD), a novel technique for transformer models that dynamically allocates compute resources by allowing tokens to skip entire transformer blocks based on learned routing decisions, thereby reducing computational cost without sacrificing performance.

Why

This paper is important because it addresses the inherent inefficiency of traditional transformers, which expend uniform computational effort per token regardless of the complexity of the prediction. MoD offers a pathway to significantly reduce the computational cost of training and inference in transformers, particularly relevant for resource-intensive large language models, by selectively allocating compute resources where they are most needed.

How

The authors propose a method in which a per-block router assigns a scalar weight to each token, indicating how important it is to process that token through the block. The top-k tokens with the highest weights pass through the block's self-attention and MLP layers, while the rest bypass the block via the residual connection. Because top-k selection is non-causal, it is used as-is during training; the router itself is trained end-to-end through the language modeling objective, while a small causal predictor is trained with an auxiliary task so that routing decisions can be made during autoregressive inference. The authors perform extensive experiments across model sizes and FLOP budgets, comparing MoD transformers with standard transformers and demonstrating significant performance gains and computational savings.
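To make the routing mechanism concrete, below is a minimal PyTorch-style sketch of a single MoD block. It is illustrative rather than the authors' implementation: the class and argument names are invented, and details such as how the router weights gate the block output are simplified.

```python
import torch
import torch.nn as nn


class MoDBlock(nn.Module):
    """Wraps a standard transformer block so that only the top-k tokens
    (by learned router weight) are processed; the rest skip the block via
    the residual stream."""

    def __init__(self, d_model: int, block: nn.Module, capacity_frac: float = 0.125):
        super().__init__()
        self.block = block                    # self-attention + MLP sub-block
        self.router = nn.Linear(d_model, 1)   # emits one scalar weight per token
        self.capacity_frac = capacity_frac    # fraction of tokens that get compute

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, d_model)
        batch, seq_len, d_model = x.shape
        k = max(1, int(self.capacity_frac * seq_len))

        # Score every token, then keep the k highest-weighted ones.
        scores = self.router(x).squeeze(-1)                  # (batch, seq_len)
        weights, idx = torch.topk(scores, k, dim=-1)
        idx, order = torch.sort(idx, dim=-1)                 # preserve temporal order
        weights = torch.gather(weights, -1, order)

        # Process only the selected tokens through the (now cheaper) block.
        gather_idx = idx.unsqueeze(-1).expand(-1, -1, d_model)
        selected = torch.gather(x, 1, gather_idx)
        processed = self.block(selected)

        # Scale by the router weight so routing stays on the gradient path,
        # then add back into the residual stream; unselected tokens pass through.
        out = x.clone()
        out.scatter_add_(1, gather_idx, weights.unsqueeze(-1) * processed)
        return out
```

Because k is fixed ahead of time, the tensor shapes above are static no matter which tokens are chosen, which is what keeps the computation graph static while the routing itself stays dynamic.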

Result

Key findings include: (1) MoD transformers can outperform isoFLOP-optimal baseline transformers in both performance and speed. (2) The best MoD configurations route only every other block and cap the routed blocks' capacity at a small fraction of the sequence (e.g., 12.5% of the sequence length). (3) Learned routing is crucial for MoD's effectiveness, significantly outperforming stochastic routing schemes. (4) MoD can be seamlessly integrated with Mixture-of-Experts (MoE) models, further improving performance and efficiency. (5) The non-causal nature of top-k routing during training can be handled during autoregressive sampling with a causal predictor (sketched below), with minimal performance degradation.
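Finding (5) relies on a second, smaller predictor that learns, from causal information only, whether a token would have been selected by the non-causal top-k router, so routing decisions can be made one token at a time during sampling. The sketch below is one plausible way to set this up; the class name, classifier architecture, and stop-gradient on the main activations are assumptions consistent with the paper's description, not the authors' code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CausalRoutingPredictor(nn.Module):
    """Small auxiliary classifier that predicts, from causal information only,
    whether a token's router weight will land in the top-k."""

    def __init__(self, d_model: int):
        super().__init__()
        self.clf = nn.Sequential(
            nn.Linear(d_model, d_model // 4),
            nn.ReLU(),
            nn.Linear(d_model // 4, 1),
        )

    def auxiliary_loss(self, x: torch.Tensor, router_scores: torch.Tensor, k: int) -> torch.Tensor:
        # Targets: 1 if the (non-causal) top-k router selected the token, else 0.
        topk_idx = torch.topk(router_scores, k, dim=-1).indices
        targets = torch.zeros_like(router_scores)
        targets.scatter_(1, topk_idx, 1.0)
        # Stop-gradient so the auxiliary task does not perturb the language model.
        logits = self.clf(x.detach()).squeeze(-1)
        return F.binary_cross_entropy_with_logits(logits, targets)

    @torch.no_grad()
    def should_process(self, x_t: torch.Tensor) -> torch.Tensor:
        # During autoregressive sampling: decide per token whether to run the block.
        return self.clf(x_t).squeeze(-1) > 0.0
```

At sampling time, `should_process` stands in for the top-k selection, so each new token can be routed without looking at future tokens.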

Limitations and Future Work

The paper acknowledges limitations and suggests future work: (1) While the current work focuses on a decoder-only setting, extending MoD to encoder-decoder architectures requires further investigation for efficient handling of sequential decoding with non-causal routing. (2) The paper primarily explores routing between standard transformer blocks and residual connections. Investigating routing to diverse computational paths like memory lookup or tool-use functions could be beneficial. (3) Future research could explore decoupling routing decisions for queries, keys, and values in self-attention, potentially leading to more nuanced and efficient compute allocation. (4) MoD’s potential in drastically increasing context length for predictions by efficiently managing long-term memory through selective routing warrants further investigation.

Abstract

Transformer-based language models spread FLOPs uniformly across input sequences. In this work we demonstrate that transformers can instead learn to dynamically allocate FLOPs (or compute) to specific positions in a sequence, optimising the allocation along the sequence for different layers across the model depth. Our method enforces a total compute budget by capping the number of tokens (k) that can participate in the self-attention and MLP computations at a given layer. The tokens to be processed are determined by the network using a top-k routing mechanism. Since k is defined a priori, this simple procedure uses a static computation graph with known tensor sizes, unlike other conditional computation techniques. Nevertheless, since the identities of the k tokens are fluid, this method can expend FLOPs non-uniformly across the time and model depth dimensions. Thus, compute expenditure is entirely predictable in sum total, but dynamic and context-sensitive at the token-level. Not only do models trained in this way learn to dynamically allocate compute, they do so efficiently. These models match baseline performance for equivalent FLOPs and wall-clock times to train, but require a fraction of the FLOPs per forward pass, and can be upwards of 50% faster to step during post-training sampling.
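To see how capping k translates into the FLOP savings the abstract describes, here is a back-of-the-envelope estimate. The model dimensions and block count are illustrative (not taken from the paper), and the 12.5% capacity on every other block follows the configuration noted in the Result section above.

```python
def block_flops(tokens: int, d_model: int, ff_mult: int = 4) -> int:
    """Rough forward-pass FLOPs for one transformer block processing `tokens`
    positions (multiply-adds counted as 2 FLOPs; norms and biases ignored)."""
    attn_proj = 2 * tokens * 4 * d_model * d_model        # Q, K, V and output projections
    attn_mix = 2 * 2 * tokens * tokens * d_model          # QK^T and attention-weighted values
    mlp = 2 * 2 * tokens * d_model * ff_mult * d_model    # up- and down-projections
    return attn_proj + attn_mix + mlp


seq_len, d_model, n_blocks = 2048, 1024, 24               # illustrative sizes only
dense = n_blocks * block_flops(seq_len, d_model)

# MoD variant: every other block only processes a 12.5% capacity of tokens.
capacity = int(0.125 * seq_len)
mod = (n_blocks // 2) * block_flops(seq_len, d_model) + (n_blocks // 2) * block_flops(capacity, d_model)

print(f"MoD forward pass uses ~{mod / dense:.0%} of the dense FLOPs")  # roughly half
```

The saving is more than linear in the capacity because the attention term scales quadratically with the number of tokens a block actually processes.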