Muon's Orthogonal Updates Cause Neuron Death in Tall Matrices
Muon computes the polar factor UVᵀ of the gradient matrix G (via thin SVD) and applies the semi-orthogonal weight update W ← W - η UVᵀ, enabling faster convergence than AdamW on nanoGPT speedrun benchmarks. In tall matrices such as SwiGLU MLP up-projections (more rows n than columns m), row-norm anisotropy emerges: a tall semi-orthogonal matrix cannot have uniform row norms of 1 (its squared row norms must sum to m < n), so some rows receive massive updates while others starve. By training step 500, more than a quarter of neurons die permanently, starving downstream layers and compounding the inefficiency. Leverage scores (the squared row norms of U) become highly anisotropic, amplifying the death spiral.
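To make the mechanism concrete, here is a minimal PyTorch sketch (shapes and names are illustrative, and the exact thin SVD is used purely for clarity rather than efficiency) that computes the polar factor of a tall gradient and inspects its row norms:

```python
import torch

def polar_factor(G: torch.Tensor) -> torch.Tensor:
    """Return U @ V^T from the thin SVD G = U S V^T (the semi-orthogonal update direction)."""
    U, _, Vh = torch.linalg.svd(G, full_matrices=False)
    return U @ Vh

torch.manual_seed(0)
n, m = 4096, 1024                        # tall shape, e.g. a SwiGLU up-projection
G = torch.randn(n, m)                    # stand-in for a gradient matrix
O = polar_factor(G)                      # O^T O = I_m, but rows cannot all have norm 1

row_norms = torch.linalg.norm(O, dim=1)  # leverage scores are row_norms ** 2
print(row_norms.mean().item(), row_norms.min().item(), row_norms.max().item())
# Squared row norms sum to m, so the average squared row norm is m/n < 1;
# per the article, real training gradients spread these norms far more than
# this random example, leaving some rows with almost no update.
```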
NorMuon patches this with inverse-RMS row normalization so every neuron receives a uniform-magnitude update, boosting performance but sacrificing polar-factor precision. U-NorMuon refines this by targeting row norm √(m/n), the only uniform value compatible with a column-orthogonal tall matrix, eliminating neuron death and stabilizing gradients even in untouched layers like down-projections; at 340M scale it outperforms both Muon and NorMuon while keeping leverage scores isotropic.
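A rough sketch of the two row-normalization schemes as described above; the function names and epsilon are mine, and the actual NorMuon implementation may use running per-neuron statistics rather than the instantaneous RMS shown here:

```python
import torch

def normuon_rows(O: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """NorMuon-style: divide each row by its RMS so every neuron gets a same-size update."""
    rms = O.pow(2).mean(dim=1, keepdim=True).sqrt()
    return O / (rms + eps)

def u_normuon_rows(O: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """U-NorMuon-style: rescale each row of a tall (n x m) update to norm sqrt(m/n),
    the only uniform row norm compatible with orthonormal columns (O^T O = I_m)."""
    n, m = O.shape
    target = (m / n) ** 0.5
    return O * (target / (O.norm(dim=1, keepdim=True) + eps))
```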
Aurora Solves Joint Constraints for Precise, Uniform Updates
Aurora reformulates the update as steepest descent, maximizing Tr(GᵀU) under dual constraints: UᵀU = Iₘ (semi-orthogonality, i.e. orthonormal columns) and ||Uᵢ||₂ = √(m/n) for every row i (uniform row leverage). This keeps all singular values of U at 1, achieving exact orthogonality without the trade-offs of NorMuon's post-hoc normalization.
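The section does not spell out Aurora's solver, so the following is only a hypothetical illustration of the feasible set rather than the paper's algorithm: starting from the Muon polar factor (which maximizes Tr(GᵀU) over semi-orthogonal U), it alternates between projecting back onto the Stiefel manifold and rescaling rows to the common norm √(m/n); all names are mine and convergence is not guaranteed.

```python
import torch

def project_stiefel(U: torch.Tensor) -> torch.Tensor:
    """Nearest matrix with orthonormal columns (U^T U = I_m): the polar factor of U."""
    P, _, Qh = torch.linalg.svd(U, full_matrices=False)
    return P @ Qh

def project_equal_leverage(U: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Rescale every row to the common norm sqrt(m/n), i.e. uniform leverage scores."""
    n, m = U.shape
    target = (m / n) ** 0.5
    return U * (target / (U.norm(dim=1, keepdim=True) + eps))

def aurora_like_update(G: torch.Tensor, iters: int = 10) -> torch.Tensor:
    """Hypothetical heuristic: start from the polar factor of G, then alternate projections."""
    U = project_stiefel(G)
    for _ in range(iters):
        U = project_equal_leverage(U)
        U = project_stiefel(U)
    return U
```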
Implement it as a drop-in Muon replacement, in two variants: Riemannian Aurora (gradient projection on the joint Stiefel/equal-leverage manifold) or vanilla Aurora (simpler). For wide and square matrices, semi-orthogonality already implies uniform row norms, so those layers are left unchanged. Open-source code is available and scales; the method adds only 6% compute over Muon.
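A quick numeric check of the wide/square case mentioned above (shapes are illustrative): for a wide matrix the polar factor satisfies OOᵀ = Iₙ, so every row already has norm exactly 1 and no extra leverage constraint is needed:

```python
import torch

torch.manual_seed(0)
n, m = 1024, 4096                           # wide shape, e.g. an MLP down-projection
G = torch.randn(n, m)
U, _, Vh = torch.linalg.svd(G, full_matrices=False)
O = U @ Vh                                  # O @ O.T equals I_n for a wide matrix

row_norms = torch.linalg.norm(O, dim=1)
print(torch.allclose(row_norms, torch.ones(n), atol=1e-5))  # True: leverage already uniform
```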
SOTA Results Scale with MLP Width
At 1.1B parameters, Aurora trains a roughly 100x more data-efficient model on open internet data that beats larger models on HellaSwag, and it tops the modded-nanoGPT speedrun (prior SOTA: NorMuon). Gains grow with MLP expansion (wider MLPs mean taller up-projection matrices and greater anisotropy risk), confirming the hypothesis. Use it for GPT-style training to avoid silent capacity loss.