Batch GEMMs for Fast LSTM in Torch

Fuses LSTM gate computations into a single nngraph module so all four gates share one 4x-wide GEMM per projection, slashing overhead vs per-gate nn.Linear calls (optimizations by @jcjohnson).

Batch GEMMs to Cut LSTM Overhead

Standard Torch LSTM implementations compute each gate's input (i2h) and hidden (h2h) projections as separate nn.Linear calls, multiplying GEMM count and kernel-launch overhead. This gist fuses them: one 4x-wider i2h GEMM and one 4x-wider h2h GEMM produce pre-activations for all four gates (i, f, o, c~) at once; the two results are summed, then sliced for the sigmoid/tanh nonlinearities. Result: two GEMMs per timestep instead of eight, roughly 2-3x faster on GPU for char-level models (as in Karpathy's Python LSTM gist). Trade-offs: fixed rnn_size, no peepholes, Lua-only (Torch7).
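The batching win can be sketched in NumPy (a re-sketch of the Lua math; shapes and names are illustrative): stacking the four per-gate weight matrices column-wise turns four small GEMMs into one 4x-wider GEMM with identical output.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, input_size, rnn_size = 8, 16, 32

x = rng.standard_normal((batch, input_size))
# Four per-gate weight matrices (i, f, o, c~), as a per-gate implementation holds them.
W_gates = [rng.standard_normal((input_size, rnn_size)) for _ in range(4)]

# Unfused: four separate GEMMs, one per gate, then concatenate.
separate = np.concatenate([x @ W for W in W_gates], axis=1)

# Fused: stack the weights once, then a single 4x-wider GEMM.
W_fused = np.concatenate(W_gates, axis=1)   # (input_size, 4*rnn_size)
fused = x @ W_fused                         # (batch, 4*rnn_size)

assert np.allclose(separate, fused)
```

The same trick applies to the h2h projection, which is why the gist needs only two GEMMs per timestep.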

Usage: m = LSTM.fast_lstm(input_size, rnn_size) returns an nn.gModule({x, prev_c, prev_h}, {next_c, next_h}). Feed sequences by unrolling; note forward returns a table in {c, h} order: for t = 1, T do local out = m:forward({x[t], c, h}); c, h = out[1], out[2] end.

Gate Computation Graph

Builds nn.gModule with:

  • Projections: i2h = nn.Linear(input_size, 4*rnn_size)(x) and h2h = nn.Linear(rnn_size, 4*rnn_size)(prev_h); summed via all_input_sums = nn.CAddTable()({i2h, h2h}) (batched gate pre-activations).
  • Sigmoid chunk: nn.Narrow(2,1,3*rnn_size)(all_input_sums) → gates i,f,o.
  • Input transform: nn.Narrow(2,3*rnn_size+1,rnn_size)(all_input_sums) → tanh(c~).
  • Cell: next_c = forget_gate ⊙ prev_c + in_gate ⊙ c~ (CMulTable + CAddTable).
  • Hidden: next_h = out_gate ⊙ tanh(next_c).
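The bullets above correspond to the standard LSTM cell equations in concatenated-weight form (biases folded into the Linear layers):

```latex
\begin{aligned}
\begin{bmatrix} \hat{i} \\ \hat{f} \\ \hat{o} \\ \hat{g} \end{bmatrix}
  &= W_{i2h}\, x_t + W_{h2h}\, h_{t-1} + b \\
i = \sigma(\hat{i}),\quad f = \sigma(\hat{f}),\quad
  o = \sigma(\hat{o}),&\quad \tilde{c} = \tanh(\hat{g}) \\
c_t &= f \odot c_{t-1} + i \odot \tilde{c} \\
h_t &= o \odot \tanh(c_t)
\end{aligned}
```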

Full code:

-- Torch7; requires the nn and nngraph packages.
require 'nn'
require 'nngraph'

local LSTM = {}

function LSTM.fast_lstm(input_size, rnn_size)
  -- Graph inputs: x (batch x input_size), prev_c and prev_h (batch x rnn_size).
  local x = nn.Identity()()
  local prev_c = nn.Identity()()
  local prev_h = nn.Identity()()
  -- One 4x-wide projection each for input and hidden state.
  local i2h = nn.Linear(input_size, 4 * rnn_size)(x)
  local h2h = nn.Linear(rnn_size, 4 * rnn_size)(prev_h)
  local all_input_sums = nn.CAddTable()({i2h, h2h})
  -- Dimension 2 is the feature dimension (dimension 1 is the batch).
  -- First 3*rnn_size columns: sigmoid gates i, f, o in one pass.
  local sigmoid_chunk = nn.Narrow(2, 1, 3 * rnn_size)(all_input_sums)
  sigmoid_chunk = nn.Sigmoid()(sigmoid_chunk)
  local in_gate = nn.Narrow(2, 1, rnn_size)(sigmoid_chunk)
  local forget_gate = nn.Narrow(2, rnn_size + 1, rnn_size)(sigmoid_chunk)
  local out_gate = nn.Narrow(2, 2 * rnn_size + 1, rnn_size)(sigmoid_chunk)
  -- Last rnn_size columns: tanh candidate cell state c~.
  local in_transform = nn.Narrow(2, 3 * rnn_size + 1, rnn_size)(all_input_sums)
  in_transform = nn.Tanh()(in_transform)
  -- next_c = forget_gate * prev_c + in_gate * c~ (elementwise).
  local next_c = nn.CAddTable()({
    nn.CMulTable()({forget_gate, prev_c}),
    nn.CMulTable()({in_gate, in_transform})
  })
  -- next_h = out_gate * tanh(next_c).
  local next_h = nn.CMulTable()({out_gate, nn.Tanh()(next_c)})
  return nn.gModule({x, prev_c, prev_h}, {next_c, next_h})
end

return LSTM
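For reference, a NumPy re-sketch of the same timestep (0-based slices replace the 1-based nn.Narrow offsets; weight and bias names are illustrative, not from the gist):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def fast_lstm_step(x, prev_c, prev_h, W_i2h, W_h2h, b):
    R = prev_h.shape[1]                  # rnn_size
    s = x @ W_i2h + prev_h @ W_h2h + b   # (batch, 4R) pre-activations, two GEMMs
    gates = sigmoid(s[:, :3 * R])        # i, f, o applied in one pass
    i, f, o = gates[:, :R], gates[:, R:2 * R], gates[:, 2 * R:]
    c_tilde = np.tanh(s[:, 3 * R:])      # candidate cell state c~
    next_c = f * prev_c + i * c_tilde
    next_h = o * np.tanh(next_c)
    return next_c, next_h

# Unrolled over a toy sequence, mirroring the Lua usage loop.
rng = np.random.default_rng(0)
B, D, R, T = 4, 8, 16, 5
W_i2h = 0.1 * rng.standard_normal((D, 4 * R))
W_h2h = 0.1 * rng.standard_normal((R, 4 * R))
b = np.zeros(4 * R)
c, h = np.zeros((B, R)), np.zeros((B, R))
for t in range(T):
    c, h = fast_lstm_step(rng.standard_normal((B, D)), c, h, W_i2h, W_h2h, b)
```

Because the output gate lies in (0, 1) and tanh in (-1, 1), every entry of h stays strictly inside (-1, 1), which is a quick sanity check on an implementation.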

Production Notes

From Karpathy (2015): powers char-rnn models, with Justin Johnson's tweaks batching everything into the wide GEMMs. Scales to sequence lengths in the thousands even on GTX 580-era GPUs. The modern PyTorch equivalent is torch.nn.LSTM, which dispatches to fused cuDNN kernels and is faster still. For new work, port to PyTorch, Flux.jl, or JAX, but the graph-fusion principle endures for custom RNN cells.
