Future Lens: Anticipating Subsequent Tokens from a Single Hidden State
Authors: Koyena Pal, Jiuding Sun, Andrew Yuan, Byron C. Wallace, David Bau
What
This paper investigates whether hidden state vectors in large language models (LLMs) encode information sufficient to predict multiple tokens ahead, going beyond the typical one-token prediction.
Why
This research is significant because it probes how much information is encoded within individual hidden states of LLMs, potentially yielding a deeper understanding of how these models process and retain information over longer spans of text.
How
The authors test their hypothesis by applying three methods to GPT-J-6B: (1) training linear models to approximate future hidden states and decode them into tokens, (2) causal intervention, transplanting hidden states into different contexts and observing the generated continuations, and (3) training a “soft prompt” to optimize the extraction of subsequent-token information from a single hidden state.
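To make method (1) concrete, here is a minimal sketch of a linear future-state probe, assuming GPT-J-6B via Hugging Face transformers. This is not the authors' released code: the layer choice, probe shape, and training loop are illustrative assumptions.

```python
# Minimal sketch of a linear "future hidden state" probe (illustrative,
# not the authors' code). Assumes GPT-J-6B from Hugging Face transformers.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
tok = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B").to(device).eval()

SRC_LAYER = 14           # a middle layer, where the paper reports strong signal
d = model.config.n_embd  # 4096 for GPT-J-6B

# Linear map from the hidden state at position t (layer SRC_LAYER) to an
# approximation of the final-layer hidden state at position t+1; decoding
# that approximation with the frozen LM head yields a guess for token t+2.
probe = torch.nn.Linear(d, d).to(device)
opt = torch.optim.Adam(probe.parameters(), lr=1e-4)

def training_step(input_ids: torch.Tensor) -> float:
    with torch.no_grad():  # the LLM stays frozen; only the probe is trained
        hs = model(input_ids.to(device), output_hidden_states=True).hidden_states
    src = hs[SRC_LAYER][0, :-1]  # states at positions t = 0..T-2
    tgt = hs[-1][0, 1:]          # final-layer states at positions t+1
    loss = torch.nn.functional.mse_loss(probe(src), tgt)
    opt.zero_grad(); loss.backward(); opt.step()
    return loss.item()

# At evaluation time, decode the predicted state with the frozen unembedding:
#   logits = model.lm_head(probe(h_t))   # distribution over token t+2
```

Decoding the probe's output with the model's own unembedding, rather than training a separate classifier, keeps the comparison to the model's actual future predictions direct.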
Result
The study finds that individual hidden states, especially those in middle layers, carry substantial information about tokens beyond the immediate next one. Notably, the “learned prompt” causal intervention achieves the highest accuracy in predicting subsequent tokens, surpassing a bigram baseline.
Limitations & Future Work
The authors acknowledge several limitations: the size of the training data, the focus on a single LLM (GPT-J-6B), the absence of prior baselines for this specific task, and a prediction horizon capped at four tokens ahead. Future work could explore larger datasets, other LLMs, alternative baseline models (e.g., RNNs, non-autoregressive generation), and prediction horizons beyond four tokens.
Abstract
We conjecture that hidden state vectors corresponding to individual input tokens encode information sufficient to accurately predict several tokens ahead. More concretely, in this paper we ask: Given a hidden (internal) representation of a single token at position t in an input, can we reliably anticipate the tokens that will appear at positions ≥ t + 2? To test this, we measure linear approximation and causal intervention methods in GPT-J-6B to evaluate the degree to which individual hidden states in the network contain signal rich enough to predict future hidden states and, ultimately, token outputs. We find that, at some layers, we can approximate a model’s output with more than 48% accuracy with respect to its prediction of subsequent tokens through a single hidden state. Finally, we present a “Future Lens” visualization that uses these methods to create a new view of transformer states.
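As a companion illustration of the causal-intervention idea, below is a minimal sketch using GPT-J-6B and a PyTorch forward hook; the layer index, the prompts, and the transplant position are illustrative assumptions, not the paper's exact setup.

```python
# Minimal sketch of the hidden-state transplant (illustrative assumptions:
# layer index, prompts, and transplant position are not the paper's setup).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B").eval()

LAYER = 14  # hidden_states[LAYER] is the output of transformer block LAYER-1

# 1) Run the source prompt and grab the hidden state at its last position.
src_ids = tok("The Eiffel Tower is located in the city of",
              return_tensors="pt").input_ids
with torch.no_grad():
    src_h = model(src_ids, output_hidden_states=True).hidden_states[LAYER][0, -1]

# 2) Transplant that state into a generic context and decode greedily.
ctx_ids = tok("Please, tell me something about", return_tensors="pt").input_ids
POS = ctx_ids.shape[1] - 1  # overwrite the last context position

def patch(module, inputs, output):
    hs = output[0]           # GPT-J blocks return a tuple; [0] is [B, T, d]
    if hs.shape[1] > POS:    # only patch the initial full-context pass
        hs[:, POS] = src_h.to(hs.dtype)
    return (hs,) + output[1:]

handle = model.transformer.h[LAYER - 1].register_forward_hook(patch)
with torch.no_grad():
    gen = model.generate(ctx_ids, max_new_tokens=4, do_sample=False)
handle.remove()

# If the state carries future-token information, tokens from the source
# continuation (e.g., "Paris") should surface in the generic context.
print(tok.decode(gen[0, ctx_ids.shape[1]:]))
```

If the transplanted state indeed encodes the source's future tokens, the model should continue the generic prompt with tokens from the original continuation rather than with text appropriate to the generic context alone.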