Two-Phase Training Boosts Token Throughput Without Architecture Changes

Token Superposition Training (TST) accelerates LLM pre-training by increasing the text processed per FLOP during an initial superposition phase, then reverting to standard next-token prediction. In Phase 1 (the first r fraction of steps, optimal r=0.2-0.4), segment each input sequence of length L into non-overlapping bags of s contiguous tokens (s=3-16, model-size dependent). Average the embeddings within each bag to create s-tokens, shortening the effective sequence length to L/s. To match baseline FLOPs per step, scale the input data length by s×, ingesting s× more tokens per unit of compute.
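A minimal sketch of the Phase-1 input path, assuming a standard PyTorch embedding table and a sequence length divisible by s (the function name `superpose_inputs` is illustrative, not the authors' code):

```python
import torch
import torch.nn as nn

def superpose_inputs(token_ids: torch.Tensor, embed: nn.Embedding, s: int) -> torch.Tensor:
    """Fold a (batch, seq_len) id tensor into non-overlapping bags of s contiguous
    tokens and average their embeddings, shortening the effective sequence by s.

    Assumes seq_len is divisible by s; to hold FLOPs per step constant, the
    dataloader should supply sequences s times longer than in baseline training.
    """
    B, L = token_ids.shape
    assert L % s == 0, "sequence length must be a multiple of the bag size s"
    emb = embed(token_ids)                              # (B, L, d)
    emb = emb.view(B, L // s, s, -1)                    # (B, L/s, s, d) bags
    # Accumulate in float32 for precision, then cast back to the working dtype.
    s_tokens = emb.float().mean(dim=2).to(emb.dtype)    # (B, L/s, d)
    return s_tokens
```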

Output predicts next bag via multi-hot cross-entropy (MCE) loss: assign 1/s probability mass to each of s target tokens, implemented as mean of s standard CE losses using existing fused kernels—no new heads or parameters. Phase 2 resumes from checkpoint with vanilla next-token prediction for remaining 1-r steps, fully removing TST code. Expect 1-2 nat loss spike at transition, resolving in thousands of steps; final model matches standard inference exactly.

Shared embeddings across phases are critical: re-initializing them at the boundary on the 3B model raises final loss to 2.938 (vs. 2.676 for TST and 2.808 for the baseline), showing that Phase 1 builds transferable representations. Input averaging may regularize embedding geometry (forcing linear separability of s-grams) or act as coarse pre-pre-training; bag prediction echoes multi-token prediction but is cheaper, adding no extra parameters.
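Putting the pieces together, a schematic of how the phase switch could be dispatched in the training loop, reusing `superpose_inputs` and `mce_loss` from the sketches above; `model.embed`, `model.backbone`, and `model.lm_head` are placeholder names, and the same embedding table serves both branches, matching the shared-embedding requirement:

```python
import torch.nn.functional as F

def training_step(model, input_ids, step, total_steps, s=6, r=0.3):
    """Phase 1 (first r fraction of steps): token superposition with MCE loss.
    Phase 2: vanilla next-token prediction with the TST path removed.
    The embedding table (model.embed) is carried across both phases unchanged."""
    if step < int(r * total_steps):
        # Phase 1: input_ids has shape (B, s*L); fold into L s-tokens.
        x = superpose_inputs(input_ids, model.embed, s)      # (B, L, d)
        logits = model.lm_head(model.backbone(x))            # (B, L, vocab)
        bags = input_ids.view(input_ids.shape[0], -1, s)     # (B, L, s)
        # s-token j predicts the tokens of bag j+1.
        return mce_loss(logits[:, :-1], bags[:, 1:], s)
    # Phase 2: standard next-token prediction on ordinary sequences.
    logits = model.lm_head(model.backbone(model.embed(input_ids)))
    return F.cross_entropy(
        logits[:, :-1].flatten(0, 1), input_ids[:, 1:].flatten()
    )
```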

Results: Lower Loss and Speedups at Equal FLOPs or Loss

Validated on 270M/600M dense models (SmolLM2/Llama3 shapes), a 3B dense model (SmolLM3), and a 10B-A1B MoE (Qwen3), using DCLM or DCLM+FineWeb-Edu data, AdamW with a Warmup-Stable-Decay LR schedule, and TorchTitan/FSDP on B200 GPUs.

At 3B (s=6, r=0.3): 20k TST steps hit loss 2.676 (vs. baseline 2.677 at 36k steps), using 247 GPU-hours vs. 443 (1.8× speedup); HellaSwag 62.4 vs. 62.3, ARC-Easy 66.3 vs. 65.9.

At 10B-A1B MoE (s=16, r≈0.25): TST processes 2T tokens to loss 2.236 (below baseline 2.252 at 1.05T), using 4,768 GPU-hours vs. 12,311 (2.5× speedup); beats baseline on HellaSwag (71.2 vs. 70.1), ARC-Easy (74.2 vs. 73.8), ARC-Challenge (47.3 vs. 46.3), MMLU (39.0 vs. 37.4).

TST wins equal-FLOPs and equal-loss comparisons; the baseline wins equal-data comparisons (TST spends less compute per token). Ablations confirm the input and output mechanisms are orthogonal: each alone beats the baseline, and combining them is best.

| Model | s | r | TST GPU-hrs | Baseline GPU-hrs | Speedup |
|---|---|---|---|---|---|
| 3B | 6 | 0.3 | 247 | 443 | 1.8× |
| 10B MoE | 16 | 0.25 | 4,768 | 12,311 | 2.5× |

Practical Implementation: Minimal Code Changes, Defined Hyperparams

PyTorch tweaks: (1) fold inputs into bags before the embedding lookup; (2) average the embeddings (sum in float32 for precision); (3) apply the MCE loss on the output. For large bags (s≥8), use power-law weighting over bag positions (≈1/i, exponent k≈-1.25) instead of uniform weights, as sketched below.
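A sketch of the power-law variant for large bags, assuming the weights rescale the per-position CE terms within each bag rather than the uniform 1/s mass; the function name, default exponent, and normalization are illustrative:

```python
import torch
import torch.nn.functional as F

def weighted_mce_loss(logits: torch.Tensor, target_ids: torch.Tensor, s: int,
                      k: float = -1.25) -> torch.Tensor:
    """MCE with power-law weights over positions within the target bag:
    weight_i ∝ i**k for i = 1..s (k ≈ -1.25, roughly 1/i), normalized to sum
    to 1, so earlier tokens in the bag contribute more than later ones.
    Setting k = 0 recovers the uniform 1/s weighting of plain MCE."""
    weights = torch.arange(1, s + 1, dtype=torch.float32, device=logits.device) ** k
    weights = weights / weights.sum()
    terms = []
    for i in range(s):
        ce = F.cross_entropy(logits.flatten(0, 1), target_ids[..., i].flatten())
        terms.append(weights[i] * ce)
    return torch.stack(terms).sum()
```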

Hyperparams:

| Model | s range | r |
|---|---|---|
| 270M | 3-8 | 0.2-0.4 |
| 600M | 6-10 | 0.2-0.4 |
| 3B | 6 | 0.3 |
| 10B MoE | 16 | ~0.25 |

Failures to avoid: applying positional encodings before averaging (hurts), rescaling RoPE at the phase switch (risks higher loss), separate output heads per token in the bag (no gain, more cost), binary/CE loss variants (underperform), and retaining TST in Phase 2 (untested).

Use TST for compute-bound runs with ample data; skip it if data-limited. The output-only variant suits data-bound regimes.

Trade-offs: Compute Efficiency at Cost of Data

TST trades lower compute per token for higher token throughput, ideal when GPUs are the bottleneck and data is abundant. It adds no inference overhead and drops into existing stacks. The simplest version (mean pooling on input and output, hard phase switch) is optimal; no extras are needed.