Multi-Token Prediction Performance on GSM8K Mathematical Reasoning

11 Jun 2025

Abstract and 1. Introduction

2. Method

3. Experiments on real data

3.1. Benefits scale with model size and 3.2. Faster inference

3.3. Learning global patterns with multi-byte prediction and 3.4. Searching for the optimal n

3.5. Training for multiple epochs and 3.6. Finetuning multi-token predictors

3.7. Multi-token prediction on natural language

4. Ablations on synthetic data and 4.1. Induction capability

4.2. Algorithmic reasoning

5. Why does it work? Some speculation and 5.1. Lookahead reinforces choice points

5.2. Information-theoretic argument

6. Related work

7. Conclusion, Impact statement, Environmental impact, Acknowledgements, and References

A. Additional results on self-speculative decoding

B. Alternative architectures

C. Training speeds

D. Finetuning

E. Additional results on model scaling behavior

F. Details on CodeContests finetuning

G. Additional results on natural language benchmarks

H. Additional results on abstractive text summarization

I. Additional results on mathematical reasoning in natural language

J. Additional results on induction learning

K. Additional results on algorithmic reasoning

L. Additional intuitions on multi-token prediction

M. Training hyperparameters

I. Additional results on mathematical reasoning in natural language

Figure S13: Performance on the mathematical reasoning benchmark GSM8K (Cobbe et al., 2021). We evaluate pretrained next-token and multi-token prediction models, trained on 200B and 500B tokens of natural language, in 8-shot mode using nucleus sampling (Holtzman et al., 2020) with probability mass 0.95 and various sampling temperatures. We report how often the correct final answer appears among k samples, for k = 1, 10, 100, estimated from 200 samples as in code generation benchmarks (Chen et al., 2021). After 200B tokens, the 2-token prediction model has a clear advantage over the next-token baseline, but the order reverses after 500B tokens. The 4-token prediction model is worse throughout. We interpret this similarly to the findings in Section 4.1: the follow-your-nose chains-of-thought required for GSM8K may be difficult to learn from a limited amount of data, and the early advantage of the 2-token model attests to the data efficiency of multi-token prediction training. Once the correct circuits for autoregressive chains-of-thought in this domain have formed, however, multi-token prediction comes at a cost.
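The caption says the "correct among k samples" frequencies are estimated from 200 samples as in Chen et al. (2021). A minimal sketch of that kind of estimate follows, assuming the standard unbiased pass@k estimator from that work, 1 - C(n - c, k) / C(n, k), where n is the number of samples drawn per problem and c is the number with a correct final answer. The function names and the per-problem counts are hypothetical and chosen only for illustration; they are not taken from the paper's evaluation code or results.

```python
from math import comb


def pass_at_k(n: int, c: int, k: int) -> float:
    """Estimate the probability that at least one of k samples is correct.

    n: samples drawn for the problem (200 in the caption's setup)
    c: how many of those n samples gave the correct final answer
    k: sample budget being evaluated (1, 10 or 100 in the caption)
    """
    if not 0 <= c <= n:
        raise ValueError("c must satisfy 0 <= c <= n")
    if not 1 <= k <= n:
        raise ValueError("k must satisfy 1 <= k <= n")
    # Fewer than k incorrect samples: every size-k subset contains a correct one.
    if n - c < k:
        return 1.0
    # 1 minus the fraction of size-k subsets that contain only incorrect samples.
    return 1.0 - comb(n - c, k) / comb(n, k)


def mean_pass_at_k(correct_counts: list[int], n: int, k: int) -> float:
    """Average the per-problem estimate over a benchmark."""
    return sum(pass_at_k(n, c, k) for c in correct_counts) / len(correct_counts)


if __name__ == "__main__":
    # Hypothetical correct-answer counts for four problems, 200 samples each.
    counts = [0, 3, 50, 200]
    for k in (1, 10, 100):
        print(f"k={k}: {mean_pass_at_k(counts, n=200, k=k):.3f}")
```

Averaging over size-k subsets of the 200 samples, rather than drawing exactly k samples per problem, reuses one batch of samples for every k and lowers the variance of the estimate.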

This paper is available on arXiv under the CC BY 4.0 DEED license.

Authors:

(1) Fabian Gloeckle, FAIR at Meta, CERMICS Ecole des Ponts ParisTech, and contributed equally;

(2) Badr Youbi Idrissi, FAIR at Meta, LISN Université Paris-Saclay, and contributed equally;

(3) Baptiste Rozière, FAIR at Meta;

(4) David Lopez-Paz, FAIR at Meta and the last author;

(5) Gabriel Synnaeve, FAIR at Meta and the last author.