Coalescing Tension: Why Naive Transpose Fails on GPU

Matrix transpose (B[x][y] = A[y][x]) looks trivial on a CPU but exposes the constraints of the GPU memory hierarchy. GPUs thrive on coalesced access: when the 32 threads of a warp hit contiguous global memory (GMEM) addresses, the accesses merge into a few transactions. Transpose forces a tradeoff: row-major reads are coalesced but column-major writes are strided (non-coalesced), or vice versa. Strided access kills bandwidth for large matrices.

The author implements two baselines with 16x16 blocks (256 threads) for 100% SM occupancy on an RTX 5060 (1,536 threads/SM max; 6 blocks/SM). Larger 32x32 blocks drop occupancy to 66%; smaller 8x8 blocks increase overheads such as block count and register pressure.
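
A quick back-of-the-envelope check of that occupancy arithmetic (the 1,536 threads/SM figure comes from the text; the helper below is purely illustrative and only considers the thread limit, ignoring registers, SMEM, and max resident blocks per SM):

#include <cstdio>

int main() {
    const int max_threads_per_sm = 1536;   // RTX 5060 limit cited above
    for (int block_threads : {8 * 8, 16 * 16, 32 * 32}) {
        // Resident blocks limited by the thread count alone.
        int blocks_per_sm = max_threads_per_sm / block_threads;
        float occupancy = 100.0f * blocks_per_sm * block_threads / max_threads_per_sm;
        printf("%4d threads/block -> %2d blocks/SM -> %.1f%% occupancy\n",
               block_threads, blocks_per_sm, occupancy);
    }
    return 0;
}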

__global__ void transpose_coalesced_read_kernel(float *a, float *b, int width, int height) {
    int x = blockIdx.x * blockDim.x + threadIdx.x;
    int y = blockIdx.y * blockDim.y + threadIdx.y;
    if (x < width && y < height) {
        b[x * height + y] = a[y * width + x];  // Coalesced read, strided write
    }
}

The coalesced-write variant swaps the indexing (a sketch follows below). Both baselines suffer on one side and bottleneck effective bandwidth. Decision: prioritize occupancy over tile size initially, since hardware limits (e.g., threads/SM) dictate the block configuration before SMEM capacity does (a 32x32 float tile is 4 KB, comfortable against ~100 KB of SMEM per SM).
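
A minimal sketch of that coalesced-write variant, assuming the same a/b/width/height convention as the kernel above (the kernel name and index roles here are illustrative):

__global__ void transpose_coalesced_write_kernel(float *a, float *b, int width, int height) {
    int x = blockIdx.x * blockDim.x + threadIdx.x;  // column index in the output b
    int y = blockIdx.y * blockDim.y + threadIdx.y;  // row index in the output b
    if (x < height && y < width) {
        b[y * height + x] = a[x * width + y];       // Strided read, coalesced write
    }
}

Here the grid is sized over the output matrix (x runs over height, y over width), the mirror image of the read-coalesced baseline above.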

"Matrix transpose has a fundamental tension: If we read by rows (coalesced reads), the transposed write must be by columns (non-coalesced, strided writes)." – Highlights the irreducible access pattern conflict driving all optimizations.

Shared Memory Tiling: Cache for Coalesced GMEM I/O

Solution: use on-chip shared memory (SMEM) as a tiling cache. Partition the matrix into 32x32 tiles; each 32x8 thread block (256 threads) loads a tile's rows from GMEM into SMEM with coalesced reads, then reads the tile's columns out of SMEM and writes them back to GMEM as coalesced rows.

Key decisions:

  • 32x32 tile despite 32x8 threads: each thread covers 4 rows via an unrolled loop (tiling_row = 8, j += tiling_row), keeping every access contiguous.
  • Dual fast/slow paths: Full tiles skip bounds checks; partials handle edges.
  • __syncthreads() synchronizes all loads before the transposed read.
__shared__ float tile[32][32];
// Load:  tile[ty + j][tx] = a[(y + j) * width + x];   // Coalesced GMEM rows into SMEM
__syncthreads();
// Write: b[(y + j) * height + x] = tile[tx][ty + j];  // SMEM columns out as coalesced GMEM rows
// (for correctness the write-side x and y must be rebuilt from swapped block indices; see the sketch below)
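
A fuller sketch of the fast path, under the assumption that width and height are multiples of 32 and a 32x8 block handles one 32x32 tile (TILE_DIM, BLOCK_ROWS, and the kernel name are illustrative, not the author's). The output-side coordinates are rebuilt from swapped block indices so the GMEM write is also coalesced:

#define TILE_DIM 32
#define BLOCK_ROWS 8

__global__ void transpose_smem_kernel(const float *a, float *b, int width, int height) {
    __shared__ float tile[TILE_DIM][TILE_DIM];

    int x = blockIdx.x * TILE_DIM + threadIdx.x;
    int y = blockIdx.y * TILE_DIM + threadIdx.y;

    // Coalesced load: each thread copies 4 rows of the tile (j = 0, 8, 16, 24).
    for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS)
        tile[threadIdx.y + j][threadIdx.x] = a[(y + j) * width + x];

    __syncthreads();  // all loads must land before any transposed read

    // Swap block coordinates so the output write is also coalesced.
    x = blockIdx.y * TILE_DIM + threadIdx.x;
    y = blockIdx.x * TILE_DIM + threadIdx.y;

    for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS)
        b[(y + j) * height + x] = tile[threadIdx.x][threadIdx.y + j];
}

A matching launch would use dim3 block(TILE_DIM, BLOCK_ROWS) and dim3 grid(width / TILE_DIM, height / TILE_DIM).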

Tradeoff: tiling fixes the GMEM pattern but introduces SMEM bank conflicts on the transposed read: a column access hits the same bank repeatedly (32 banks, 4 B each; a warp's column read serializes into 32 requests).

"SMEM is on-chip, has extremely high bandwidth, and tolerates random access far better than GMEM." – Explains why tiling shifts bottleneck from GMEM bandwidth to SMEM access patterns.

Bank Conflicts: Padding vs. Swizzling Tradeoffs

SMEM is organized as 32 banks of 4 B; accesses from one warp that land in the same bank serialize. The transposed read tile[tx][ty + j] maps an entire column onto one bank: with a 32-wide tile, bank = (row * 32 + col) % 32 = col % 32, independent of the row.

Padding fix: declare the tile as 32x33; the extra column shifts each row by one bank, spreading a column access across all 32 banks (snippet below). Simple, but wastes ~3% of the tile (128 B extra on a 4 KB tile).
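
A minimal illustration of the padded layout and the bank arithmetic behind it (the array name is illustrative):

__shared__ float tile[32][33];   // 33 floats per row instead of 32
// bank(tile[row][col]) = (row * 33 + col) % 32 = (row + col) % 32,
// so for a fixed col the 32 rows of a column now hit 32 distinct banks.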

Swizzling fix: an XOR mapping (store element (row, col) at column col ^ row) remaps addresses bijectively: no conflicts and no extra space. Every row and every column then touches all 32 banks exactly once. Used in cuDNN/CUTLASS.

// sx = tid % 8, sy = tid / 8 (from the float4 remap below); va/vb are float4 registers
tile[sy][(sx*4 + 0) ^ sy] = va.x;            // Write with swizzle: column XOR row
vb.x = tile[(sx*4 + 0)][sy ^ (sx*4 + 0)];    // Read with the same XOR (it is its own inverse)

Proof sketch: for a fixed column C and rows x ≠ y, x⊕C ≠ y⊕C because XOR by a constant is invertible, so the 32 rows of any column land in 32 distinct banks. A small script verifies bank uniformity over the full 32x32 tile after swizzling (a sketch follows below).
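
A host-side sketch of what such a uniformity check might look like (an assumption about the script, not the author's actual code):

#include <cassert>
#include <cstdio>

int main() {
    for (int col = 0; col < 32; ++col) {
        unsigned seen = 0;                             // bitmask of banks hit by this column
        for (int row = 0; row < 32; ++row) {
            int bank = (row * 32 + (col ^ row)) % 32;  // 32-wide tile, 4-byte banks, swizzled column
            assert(!(seen & (1u << bank)));            // every bank must be hit exactly once
            seen |= 1u << bank;
        }
    }
    printf("swizzle is conflict-free for every column of the 32x32 tile\n");
    return 0;
}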

Decision chain: padding first (the easy fix), swizzling as the end state (space-efficient and compatible with async copies and Tensor Core layouts). Swizzling is chosen for production because it preserves alignment.

"If padding is like 'expanding the parking lot' (sacrificing space) to avoid congestion, then swizzle is a genius traffic controller." – Metaphor captures swizzle's elegance over brute-force padding.

Vectorization: Float4 for Bandwidth Peak

Further optimization: float4 accesses use 128-bit instructions (e.g., LDS.128/STS.128) and move 16 B per thread, 4x the scalar width. The 256 threads are remapped so each handles 4 consecutive elements: sx = tid % 8 picks the group of 4 columns, sy = tid / 8 picks the row.

No SMEM-side loop: the loaded float4 va is unpacked into 4 scalar SMEM writes (padding breaks vector alignment, and swizzling complicates vector addressing). On the output side, 4 scalars are gathered from SMEM and packed back into a float4 for the store.

Tradeoffs:

  • Gains GMEM bandwidth (4x elems/thread).
  • SMEM unpack/pack stays scalar (vectorized SMEM access is tricky with padding/swizzling).
  • Block still 256 threads, full occupancy.

The swizzled vectorized version applies the XOR per component (a sketch follows below). Result: coalesced vectorized GMEM traffic plus conflict-free SMEM.
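
A self-contained sketch of that combination, assuming width and height are multiples of 32 and a 256-thread block per 32x32 tile (the kernel name and coordinate bookkeeping are illustrative; sx/sy/va/vb follow the text):

__global__ void transpose_vec4_swizzle_kernel(const float *a, float *b, int width, int height) {
    __shared__ float tile[32][32];

    int tid = threadIdx.y * blockDim.x + threadIdx.x;
    int sx = tid % 8;                    // which group of 4 columns within the tile
    int sy = tid / 8;                    // which row within the tile (0..31)

    int x = blockIdx.x * 32 + sx * 4;
    int y = blockIdx.y * 32 + sy;

    // One 128-bit GMEM load per thread, then 4 scalar, swizzled SMEM writes.
    float4 va = *reinterpret_cast<const float4 *>(&a[y * width + x]);
    tile[sy][(sx * 4 + 0) ^ sy] = va.x;
    tile[sy][(sx * 4 + 1) ^ sy] = va.y;
    tile[sy][(sx * 4 + 2) ^ sy] = va.z;
    tile[sy][(sx * 4 + 3) ^ sy] = va.w;

    __syncthreads();

    // Gather a transposed row from SMEM (same XOR, its own inverse) and pack into one float4.
    float4 vb;
    vb.x = tile[sx * 4 + 0][sy ^ (sx * 4 + 0)];
    vb.y = tile[sx * 4 + 1][sy ^ (sx * 4 + 1)];
    vb.z = tile[sx * 4 + 2][sy ^ (sx * 4 + 2)];
    vb.w = tile[sx * 4 + 3][sy ^ (sx * 4 + 3)];

    int ox = blockIdx.y * 32 + sx * 4;   // output column (runs over height)
    int oy = blockIdx.x * 32 + sy;       // output row (runs over width)
    *reinterpret_cast<float4 *>(&b[oy * height + ox]) = vb;
}

Launched with dim3 grid(width / 32, height / 32) and 256 threads per block, each block transposes one 32x32 tile with 128-bit GMEM accesses on both sides and no SMEM bank conflicts.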

"With 32×32, only one block fits per SM, giving just 2/3 occupancy and seriously hurting pipeline efficiency. With 16×16, an SM can host 6 blocks totaling 1,536 threads — 100% occupancy." – Hardware-specific reasoning trumps generic "larger=better".

Key Takeaways

  • Start with 16x16 blocks for occupancy on consumer GPUs (RTX 50xx: 1536 threads/SM); scale tile size independently via loops.
  • Always tile with SMEM for strided GMEM patterns—coalesce both read/write sides.
  • Audit SMEM bank conflicts post-tiling; visualize bank maps for column/row access.
  • Prefer XOR swizzle over padding: Zero extra space, bijection guarantees no conflicts.
  • Vectorize GMEM with float4 early; defer full SMEM vec until alignment/swizzle resolved.
  • Bounds-check only on edges (fast/slow paths); unroll loops for throughput.
  • Hardware knowledge (threads/SM, SMEM KB) drives block sizing over theory.
  • Test iteratively: Naive → SMEM → Padding/Swizzle → Vec.
  • Full code: Vitamin-CUDA GitHub.

"Writing an efficient transpose kernel is a classic litmus test for a CUDA engineer’s skills." – Frames transpose as skill benchmark, urging hands-on iteration.