KernelBench Tests LLMs on GPU Kernel Generation
KernelBench's 250 neural-network tasks show that LLMs generate compilable CUDA but falter on correctness for fused ops and full architectures; agentic loops with profiling feedback could push toward near-peak GPU utilization.
Optimized Kernels Bridge Theory and Real-World ML Performance
Big-O complexity misleads ML architecture comparisons because established designs like standard attention run ~5x faster simply from years of kernel tuning that exploits GPU features such as the memory hierarchy and thread utilization. Newer ideas, such as an attention variant that is 30% more efficient in theory, need weeks of custom CUDA work before the comparison is fair. KernelBench quantifies this gap: replace PyTorch reference implementations with custom kernels (Triton, CUTLASS, etc.) that match outputs (1e-2 absolute/relative tolerance on 5 fixed-shape random inputs) and improve wall-clock time. At scale, a 5% gain cuts meaningfully into ChatGPT's 500k+ kWh/day energy draw (roughly the usage of 180k US households) while enabling accurate evaluation of novel architectures under fixed compute budgets.
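A minimal sketch of that correctness check, assuming a harness along these lines (ref_model, custom_model, and get_inputs are placeholders; KernelBench's actual evaluator may differ in detail):

```python
import torch

def check_correctness(ref_model, custom_model, get_inputs, n_trials=5, atol=1e-2, rtol=1e-2):
    """Compare a candidate kernel against the PyTorch reference on fixed-shape random inputs."""
    for _ in range(n_trials):
        inputs = [x.cuda() for x in get_inputs()]  # random tensors at the task's fixed shapes
        if not torch.allclose(ref_model(*inputs), custom_model(*inputs), atol=atol, rtol=rtol):
            return False
    return True
```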
Trade-offs: kernels specialized for given input shapes beat general ones on speed but risk failing edge cases; agentic systems can iterate on Nsight Compute feedback about bottlenecks, refining parallelization and memory access patterns toward peak utilization.
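A rough sketch of such a profile-and-refine loop; everything here is illustrative (the llm client, the compile_and_time helper, the prompt format, and the ncu invocation are assumptions, not a real KernelBench or Nsight Compute integration):

```python
import subprocess

def refine_kernel(llm, compile_and_time, task_spec, max_iters=5):
    """Hypothetical generate -> test -> profile -> refine loop."""
    best = None
    prompt = task_spec
    for _ in range(max_iters):
        kernel_src = llm.generate(prompt)              # hypothetical LLM client
        ok, runtime_ms = compile_and_time(kernel_src)  # hypothetical: build, check correctness, time
        if not ok:
            prompt = task_spec + "\nThe previous attempt failed correctness; fix it:\n" + kernel_src
            continue
        if best is None or runtime_ms < best[1]:
            best = (kernel_src, runtime_ms)
        # Illustrative Nsight Compute CLI call to gather per-kernel metrics as feedback
        report = subprocess.run(["ncu", "--csv", "python", "run_candidate.py"],
                                capture_output=True, text=True).stdout
        prompt = (task_spec + "\nProfiler feedback:\n" + report
                  + "\nImprove memory coalescing and occupancy while keeping outputs identical.")
    return best
```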
KernelBench's Progressive Task Levels Build to Real Systems
The 250 core tasks split into levels; each is a forward-pass-only, self-contained PyTorch model exposing get_inputs() for testing (a task-format sketch follows the list):
- Level 1 (100 tasks): Foundational ops (conv1D/2D/3D variants, matmul, layernorm); manually curated from one-shot examples, with variants generated by varying dims and kernel sizes.
- Level 2 (100 tasks): Fusions like conv + bias + ReLU; a script picks a mainloop (matmul/conv) plus 2-5 epilogues (activations/norms), and an LLM generates the PyTorch spec from a one-shot example.
- Level 3 (50 tasks): Full networks (MobileNet/VGG/MiniGPT/AlexNet); a mix of LLM-generated and cleaned-up GitHub code.
- Level 4 (20 aspirational tasks): Hugging Face models requiring source browsing and library modifications; generated programmatically via API swaps.
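A sketch of what one self-contained task file might look like (the Model/get_inputs/get_init_inputs layout is an assumption based on the description above, not a verbatim benchmark file):

```python
import torch
import torch.nn as nn

class Model(nn.Module):
    """Level-1-style reference task: the op a custom kernel must reproduce."""
    def forward(self, a, b):
        return torch.matmul(a, b)

def get_inputs():
    # Fixed-shape random inputs used for both correctness checks and timing
    return [torch.randn(1024, 4096), torch.randn(4096, 2048)]

def get_init_inputs():
    # Arguments for Model's constructor (none for this task)
    return []
```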
No train/test split; the focus is open-ended optimization. Example: a vector-add PyTorch model becomes JIT-compiled CUDA via load_inline(), launching 256-thread blocks; a diag-matmul kernel gets 12x+ wins by skipping the diag() construction and scaling rows directly (out[row*M+col] = diag[row] * mat[row*M+col]). Fusions like matmul/div/sum/scale hit 3x via a single kernel.
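A sketch of the diag-matmul trick using load_inline(); it follows standard torch.utils.cpp_extension usage, but the kernel and function names, block size, and shapes are illustrative:

```python
import torch
from torch.utils.cpp_extension import load_inline

cuda_src = r"""
__global__ void diag_matmul_kernel(const float* diag, const float* mat, float* out, int N, int M) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;   // one thread per output element
    if (idx < N * M) {
        int row = idx / M;
        out[idx] = diag[row] * mat[idx];               // scale each row; never build the N x N diag matrix
    }
}

torch::Tensor diag_matmul(torch::Tensor diag, torch::Tensor mat) {
    auto out = torch::empty_like(mat);
    int N = mat.size(0), M = mat.size(1);
    int threads = 256;                                 // 256 threads per block (illustrative)
    int blocks = (N * M + threads - 1) / threads;
    diag_matmul_kernel<<<blocks, threads>>>(
        diag.data_ptr<float>(), mat.data_ptr<float>(), out.data_ptr<float>(), N, M);
    return out;
}
"""
cpp_src = "torch::Tensor diag_matmul(torch::Tensor diag, torch::Tensor mat);"

ext = load_inline(name="diag_matmul_ext", cpp_sources=cpp_src,
                  cuda_sources=cuda_src, functions=["diag_matmul"])

d = torch.randn(1024, device="cuda")
A = torch.randn(1024, 1024, device="cuda")
assert torch.allclose(ext.diag_matmul(d, A), torch.diag(d) @ A, atol=1e-2, rtol=1e-2)
```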
LLMs Show Promise but Need Inference Scaling for Correctness
Greedy eval (temperature 0) on frontier models: compilation rates are high (most generated CUDA is valid), but correctness drops with complexity. Level 1: top models exceed 50%; Levels 2/3: below 20%, with o1 beating gpt-4o thanks to inference-time compute. Pass@k (N=100 samples at high temperature): DeepSeek-Coder-V2 reaches 40-60% on Level 1 at k=10 while Llama3.1-70B lags; scaling samples boosts solve rates from 15.9% to 56%, consistent with Large Language Monkeys.
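For reference, pass@k numbers like these are usually computed with the unbiased estimator from the Codex paper; assuming the same convention applies here, it looks like this:

```python
from math import comb

def pass_at_k(n: int, c: int, k: int) -> float:
    """Probability that at least one of k samples is correct, given c of n samples were correct."""
    if n - c < k:
        return 1.0
    return 1.0 - comb(n - c, k) / comb(n, k)

print(pass_at_k(n=100, c=12, k=10))  # e.g., 100 samples with 12 correct -> pass@10 ≈ 0.74
```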
Among correct samples, speedups are modest (median below 1x versus PyTorch/torch.compile), but outliers above 12x (e.g., diag-matmul) and 3x fusions highlight the potential. There is a correctness-performance tension: aggressive optimizations risk errors. A leaderboard (Kernelsseum) tracks the top-5 greedy kernels per problem on an L40S GPU, with open submissions planned. The baselines underscore that base model quality matters more than pure sampling.
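A minimal sketch of how the wall-clock comparison behind those speedups could be measured with CUDA events (warmup and iteration counts are assumptions, not the benchmark's exact harness):

```python
import torch

def median_time_ms(fn, inputs, warmup=10, iters=100):
    """Median wall-clock time of fn(*inputs) in milliseconds, measured with CUDA events."""
    for _ in range(warmup):
        fn(*inputs)
    torch.cuda.synchronize()
    times = []
    for _ in range(iters):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        fn(*inputs)
        end.record()
        torch.cuda.synchronize()
        times.append(start.elapsed_time(end))
    return sorted(times)[len(times) // 2]

# speedup = median_time_ms(ref_model, inputs) / median_time_ms(custom_model, inputs)
# > 1.0 means the candidate kernel beats the PyTorch reference
```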