Two-Phase Training Boosts Token Throughput Without Architecture Changes
Token Superposition Training (TST) accelerates LLM pre-training by increasing the amount of text processed per FLOP during an initial superposition phase, then reverting to standard next-token prediction. In Phase 1 (the first r fraction of steps; optimal r=0.2-0.4), segment each input sequence of length L into non-overlapping bags of s contiguous tokens (s=3-16, model-size dependent). Average the embeddings within each bag to form s-tokens, shortening the effective sequence to L/s. To match baseline FLOPs per step, scale the input data length by s×, ingesting s× more tokens per unit of compute.
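A minimal sketch of the Phase-1 input bagging, assuming a (batch, L, d) embedding tensor with L divisible by s (the function name is hypothetical):

```python
import torch

def bag_embeddings(token_embs: torch.Tensor, s: int) -> torch.Tensor:
    """Average non-overlapping bags of s contiguous token embeddings.

    token_embs: (batch, L, d), with L divisible by s (pad otherwise).
    Returns: (batch, L // s, d) s-token embeddings.
    """
    b, L, d = token_embs.shape
    assert L % s == 0, "pad or truncate so L is divisible by s"
    bags = token_embs.view(b, L // s, s, d)
    # Accumulate the bag mean in float32 for precision, then cast back.
    return bags.float().mean(dim=2).to(token_embs.dtype)
```

Because the model body now sees L/s positions, feeding s× longer raw sequences keeps per-step FLOPs at the baseline level while ingesting s× more tokens.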
The output predicts the next bag via a multi-hot cross-entropy (MCE) loss: assign 1/s probability mass to each of the s target tokens, implemented as the mean of s standard CE losses using existing fused kernels—no new heads or parameters. Phase 2 resumes from the checkpoint with vanilla next-token prediction for the remaining 1-r fraction of steps, fully removing the TST code. Expect a 1-2 nat loss spike at the transition, resolving within thousands of steps; the final model matches standard inference exactly.
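The MCE loss can be sketched as a direct reuse of the standard CE kernel, assuming one logit vector per bag and the s token ids of the next bag as targets (the function name is hypothetical):

```python
import torch
import torch.nn.functional as F

def mce_loss(logits: torch.Tensor, target_bags: torch.Tensor) -> torch.Tensor:
    """Multi-hot CE: 1/s probability mass on each of the s target tokens.

    logits: (batch, n_bags, vocab) -- one prediction per bag.
    target_bags: (batch, n_bags, s) -- token ids of the next bag.
    """
    b, n, v = logits.shape
    s = target_bags.shape[-1]
    # Score all s targets of each bag against the same logits; the mean
    # over the b*n*s entries equals the mean of s standard CE losses.
    expanded = logits.unsqueeze(2).expand(b, n, s, v).reshape(-1, v)
    return F.cross_entropy(expanded, target_bags.reshape(-1))
```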
Shared embeddings across phases are critical: re-initializing them at the boundary on the 3B model raises final loss to 2.938 (vs. 2.676 for TST, 2.808 for the baseline), showing that Phase 1 builds transferable representations. Input averaging may regularize embedding geometry (forcing linear separability of s-grams) or act as coarse pre-pre-training; bag prediction echoes multi-token prediction but is cheaper, with no extra parameters.
Results: Lower Loss and Speedups at Equal FLOPs or Loss
Validated on 270M/600M dense models (SmolLM2/Llama3 shapes), a 3B dense model (SmolLM3), and a 10B-A1B MoE (Qwen3), using DCLM or DCLM+FineWeb-Edu data, AdamW with a Warmup-Stable-Decay LR schedule, and TorchTitan/FSDP on B200 GPUs.
At 3B (s=6, r=0.3): 20k TST steps hit loss 2.676 (vs. baseline 2.677 at 36k steps), using 247 GPU-hours vs. 443 (1.8× speedup); HellaSwag 62.4 vs. 62.3, ARC-Easy 66.3 vs. 65.9.
At 10B-A1B MoE (s=16, r≈0.25): TST processes 2T tokens to loss 2.236 (below baseline 2.252 at 1.05T), using 4,768 GPU-hours vs. 12,311 (2.5× speedup); beats baseline on HellaSwag (71.2 vs. 70.1), ARC-Easy (74.2 vs. 73.8), ARC-Challenge (47.3 vs. 46.3), MMLU (39.0 vs. 37.4).
TST wins equal-FLOPs and equal-loss comparisons; the baseline wins equal-data comparisons (TST spends less compute per token). Ablations confirm the input and output mechanisms are orthogonal: each beats the baseline alone, and the combination is best.
| Model | s | r | TST GPU-hrs | Baseline GPU-hrs | Speedup |
|---|---|---|---|---|---|
| 3B | 6 | 0.3 | 247 | 443 | 1.8× |
| 10B MoE | 16 | 0.25 | 4768 | 12311 | 2.5× |
Practical Implementation: Minimal Code Changes, Defined Hyperparams
PyTorch tweaks: (1) fold inputs into bags pre-embedding; (2) average the embeddings (sum in float32 for precision); (3) apply the MCE loss on the output. For large s (≥8), weight the s targets within a bag by a power law (weight of the i-th token ∝ i^k, k≈-1.25) instead of uniformly.
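The power-law variant can be sketched by swapping the uniform 1/s mass for normalized i^k weights (function names are hypothetical):

```python
import torch
import torch.nn.functional as F

def power_law_weights(s: int, k: float = -1.25) -> torch.Tensor:
    # Weight the i-th token of each bag by i**k, normalized to sum to 1
    # so the loss scale matches the uniform 1/s case.
    i = torch.arange(1, s + 1, dtype=torch.float32)
    w = i ** k
    return w / w.sum()

def weighted_mce_loss(logits: torch.Tensor, target_bags: torch.Tensor,
                      k: float = -1.25) -> torch.Tensor:
    # logits: (batch, n_bags, vocab); target_bags: (batch, n_bags, s).
    b, n, v = logits.shape
    s = target_bags.shape[-1]
    w = power_law_weights(s, k).to(logits.device)
    expanded = logits.unsqueeze(2).expand(b, n, s, v).reshape(-1, v)
    ce = F.cross_entropy(expanded, target_bags.reshape(-1), reduction="none")
    # Down-weight later tokens in the bag; k=0 recovers the uniform case.
    return (ce.view(b, n, s) * w).sum(-1).mean()
```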
Hyperparams:
| Model | s Range | r |
|---|---|---|
| 270M | 3-8 | 0.2-0.4 |
| 600M | 6-10 | 0.2-0.4 |
| 3B | 6 | 0.3 |
| 10B MoE | 16 | ~0.25 |
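The hyperparameter table above can be encoded as a small config helper for training scripts (names are illustrative, not an official API):

```python
# Hypothetical encoding of the hyperparameter table: (min, max) ranges
# for the bag size s and the Phase-1 fraction r per model scale.
TST_HYPERPARAMS = {
    "270M":    {"s": (3, 8),   "r": (0.2, 0.4)},
    "600M":    {"s": (6, 10),  "r": (0.2, 0.4)},
    "3B":      {"s": (6, 6),   "r": (0.3, 0.3)},
    "10B-MoE": {"s": (16, 16), "r": (0.25, 0.25)},
}

def phase1_steps(total_steps: int, r: float) -> int:
    # Number of superposition-phase (Phase 1) steps for a run.
    return int(round(r * total_steps))
```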
Failures to avoid: applying positional encodings before the bag average (hurts), rescaling RoPE at the phase switch (risks higher loss), separate heads per bag position (no gain, more cost), binary or other CE-loss variants (underperform MCE), and retaining TST in Phase 2 (untested).
Use TST for compute-bound runs with ample data; skip it if data-limited. The output-only variant suits data-bound regimes.
Trade-offs: Compute Efficiency at Cost of Data
TST trades lower compute per token for higher token throughput, making it ideal when GPUs are the bottleneck but data is abundant. There is no inference overhead, and it drops into existing stacks. The simplest version (mean pooling in, mean CE out, hard phase switch) is optimal; no extras are needed.