Rethinking K-Means Dataflow

Flash-KMeans is an open-source, IO-aware implementation of Lloyd’s k-means algorithm designed for modern AI pipelines where clustering occurs within training and inference loops. Unlike algorithmic approaches that use pruning or sampling to approximate results, Flash-KMeans maintains exact mathematical parity with standard k-means. Its reported performance gains (up to 200x faster than FAISS and 33x faster than NVIDIA cuML) come entirely from optimizing how data moves between two levels of the GPU memory hierarchy: off-chip High Bandwidth Memory (HBM) and on-chip SRAM.

Eliminating Memory Bottlenecks

The library targets two primary bottlenecks inherent in standard GPU-based k-means implementations:

  • Assignment Stage (FlashAssign): Standard implementations materialize a full N×K distance matrix in High Bandwidth Memory (HBM) for N points and K centroids, which is costly to write and then read back. FlashAssign adopts a strategy similar to FlashAttention, streaming tiles of points and centroids into on-chip SRAM and fusing distance computation with an online argmin, so only the running best distance and index per point are kept. Because the distance matrix is never written to HBM, IO complexity drops from O(NK) to O(Nd + Kd), where d is the feature dimension.
  • Centroid Update Stage (Sort-Inverse Update): Standard implementations rely on atomic adds that cause hardware contention when multiple threads attempt to update the same 'hot' centroid. Flash-KMeans uses a Sort-Inverse approach: it sorts the assignment vector by cluster ID, allowing thread blocks to perform reductions on contiguous segments in on-chip memory. This minimizes atomic operations and avoids the performance degradation caused by scatter-style updates.
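
The two stages above can be sketched in NumPy to show the dataflow only. This is a minimal CPU illustration, not the library's Triton kernels: the function names assign_tiled and update_sorted, the tile sizes, and the per-cluster loop are all assumptions made for this example, and NumPy cannot reproduce SRAM residency or GPU performance.

```python
import numpy as np

def assign_tiled(points, centroids, tile_n=1024, tile_k=256):
    """Nearest-centroid labels without ever building the full N x K matrix.

    Hypothetical helper: only a (tile_n x tile_k) block of distances exists
    at any time, plus a running best distance and index per point.
    """
    n, k = points.shape[0], centroids.shape[0]
    c_sq = (centroids ** 2).sum(axis=1)          # ||c||^2, computed once
    labels = np.empty(n, dtype=np.int64)
    for i in range(0, n, tile_n):
        x = points[i:i + tile_n]
        rows = np.arange(x.shape[0])
        best_dist = np.full(x.shape[0], np.inf)
        best_idx = np.zeros(x.shape[0], dtype=np.int64)
        for j in range(0, k, tile_k):
            c = centroids[j:j + tile_k]
            # ||x - c||^2 = ||x||^2 - 2 x.c + ||c||^2; the ||x||^2 term is
            # constant per row, so it is dropped without changing the argmin.
            d = c_sq[j:j + tile_k][None, :] - 2.0 * (x @ c.T)
            local = d.argmin(axis=1)
            local_dist = d[rows, local]
            better = local_dist < best_dist      # online argmin across tiles
            best_dist[better] = local_dist[better]
            best_idx[better] = local[better] + j
        labels[i:i + tile_n] = best_idx
    return labels

def update_sorted(points, labels, k):
    """Per-cluster sums and counts via sorting instead of scatter-adds.

    Hypothetical helper: sorting the labels makes each cluster's points a
    contiguous segment, so each segment is reduced independently.
    """
    order = np.argsort(labels, kind="stable")    # permutation grouping by cluster ID
    sorted_labels = labels[order]
    cluster_ids = np.arange(k)
    starts = np.searchsorted(sorted_labels, cluster_ids, side="left")
    ends = np.searchsorted(sorted_labels, cluster_ids, side="right")
    sums = np.zeros((k, points.shape[1]), dtype=np.float64)
    for c in range(k):
        if ends[c] > starts[c]:                  # skip empty clusters
            sums[c] = points[order[starts[c]:ends[c]]].sum(axis=0)
    return sums, ends - starts
```

New centroids would then be sums divided by counts for every non-empty cluster. How empty clusters are handled is not stated in the text above, so that choice is left to the caller in this sketch.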

Performance and Practicality

Flash-KMeans is built with Triton GPU kernels. For datasets too large to fit in GPU memory, it supports out-of-core processing, using chunked stream overlap to hide PCIe transfer latency. Benchmarks on an NVIDIA H200 (FP16, d=128) show end-to-end speedups, including a 17.9x improvement over the best baseline for large-scale clustering (N=8M, K=1024). The library provides both a batched tensor API and a scikit-learn-style interface, and is intended as a drop-in replacement in existing production vector-search and clustering workflows.
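
To put the benchmark shape quoted above (N=8M, K=1024, d=128, FP16) in perspective, a back-of-the-envelope calculation compares the size of a fully materialized distance matrix with the size of the inputs. It rests on assumptions not stated in the text: that 8M means 8,000,000 points and that a baseline would store distances at the same 2-byte FP16 width (FP32 distances would double the first figure). It counts array sizes only, not measured memory traffic.

```python
# Benchmark shape from the text; the exact value of "8M" and the FP16
# storage width for distances are assumptions made for this estimate.
N, K, d = 8_000_000, 1024, 128
BYTES_FP16 = 2

distance_matrix = N * K * BYTES_FP16   # full N x K matrix, if materialized
points = N * d * BYTES_FP16            # input points
centroids = K * d * BYTES_FP16         # centroid table
ratio = distance_matrix / (points + centroids)

print(f"distance matrix:    {distance_matrix / 1e9:.2f} GB")
print(f"points + centroids: {(points + centroids) / 1e9:.2f} GB")
print(f"ratio:              {ratio:.1f}x")
```

Under these assumptions the distance matrix is roughly K/d = 8 times larger than the inputs, which is the gap that the O(NK) versus O(Nd + Kd) comparison describes.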