Batch GEMMs for Fast LSTM in Torch
Fuse LSTM operations into an nngraph module that batches the four per-gate GEMMs, slashing overhead vs. a standard nn.LSTM (optimized by @jcjohnson).
Batch GEMMs to Cut LSTM Overhead
Standard Torch LSTMs compute each gate's input (i2h) and hidden (h2h) projections with separate nn.Linear layers, multiplying GEMM calls and kernel launch overhead: eight narrow GEMMs per timestep. This gist batches them: i2h and h2h each become one 4x-wider GEMM covering all four gates (i, f, o, c), their sum is computed once, and the result is sliced for the sigmoid/tanh nonlinearities. Result: two wide GEMMs per timestep instead of eight narrow ones, roughly 2-3x faster on GPU for char-level models (as in Karpathy's Python LSTM gist). Trade-offs: fixed rnn_size, no peepholes, Lua-only (Torch7).
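For contrast, here is a minimal sketch of the unfused baseline this replaces. It is not from the gist; slow_gates and its loop are illustrative assumptions showing one narrow nn.Linear per gate and per projection, i.e. eight GEMMs and eight kernel launches per timestep.

require 'nn'
require 'nngraph'

-- Illustrative unfused baseline (assumed, not part of the gist): each gate gets
-- its own narrow i2h and h2h GEMM. x and prev_h are nngraph nodes, e.g. nn.Identity()().
local function slow_gates(input_size, rnn_size, x, prev_h)
  local gates = {}
  for _, name in ipairs({'i', 'f', 'o', 'c'}) do
    local i2h = nn.Linear(input_size, rnn_size)(x)      -- narrow GEMM for the input
    local h2h = nn.Linear(rnn_size, rnn_size)(prev_h)   -- narrow GEMM for the hidden state
    gates[name] = nn.CAddTable()({i2h, h2h})            -- pre-activation for this gate
  end
  return gates   -- eight GEMMs total vs. two wide ones in fast_lstm
end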
Usage: m = LSTM.fast_lstm(input_size, rnn_size) returns a gModule({x, prev_c, prev_h}, {next_c, next_h}). Feed sequences by unrolling; forward() returns the table {next_c, next_h}, so the loop is for t = 1, T do local out = m:forward({x[t], c, h}); c, h = out[1], out[2] end (expanded below).
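Expanded, a forward-only sketch of that unrolled loop under assumed shapes (the batch size, toy sequence tensor, zero initial states, and the LSTM.lua filename are illustrative, not from the gist):

require 'torch'
local LSTM = require 'LSTM'   -- assumes the full code below is saved as LSTM.lua on the Lua path

local input_size, rnn_size, batch, T = 128, 256, 32, 50
local m = LSTM.fast_lstm(input_size, rnn_size)

local x = torch.randn(T, batch, input_size)   -- toy input sequence
local c = torch.zeros(batch, rnn_size)        -- initial cell state
local h = torch.zeros(batch, rnn_size)        -- initial hidden state

for t = 1, T do
  local out = m:forward({x[t], c, h})   -- gModule returns {next_c, next_h}
  c, h = out[1], out[2]
end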
Gate Computation Graph
Builds nn.gModule with:
- Batched gates: i2h = nn.Linear(input_size, 4*rnn_size)(x) and h2h = nn.Linear(rnn_size, 4*rnn_size)(prev_h), summed into all_input_sums = nn.CAddTable()({i2h, h2h}).
- Sigmoid chunk: nn.Narrow(2, 1, 3*rnn_size)(all_input_sums) → gates i, f, o.
- Input transform: nn.Narrow(2, 3*rnn_size + 1, rnn_size)(all_input_sums) → tanh(c~).
- Cell: next_c = forget_gate ⊙ prev_c + in_gate ⊙ c~ (CMulTable + CAddTable).
- Hidden: next_h = out_gate ⊙ tanh(next_c).
Full code:
-- Fused LSTM cell: i2h and h2h each compute all four gate pre-activations in one
-- 4*rnn_size-wide GEMM; nn.Narrow then slices the summed result per gate.
require 'nn'
require 'nngraph'

local LSTM = {}

function LSTM.fast_lstm(input_size, rnn_size)
  -- inputs: x (batch x input_size), prev_c and prev_h (batch x rnn_size)
  local x = nn.Identity()()
  local prev_c = nn.Identity()()
  local prev_h = nn.Identity()()

  -- two wide GEMMs produce every gate pre-activation at once
  local i2h = nn.Linear(input_size, 4 * rnn_size)(x)
  local h2h = nn.Linear(rnn_size, 4 * rnn_size)(prev_h)
  local all_input_sums = nn.CAddTable()({i2h, h2h})

  -- first 3*rnn_size columns: input, forget, output gates (sigmoid)
  local sigmoid_chunk = nn.Narrow(2, 1, 3 * rnn_size)(all_input_sums)
  sigmoid_chunk = nn.Sigmoid()(sigmoid_chunk)
  local in_gate = nn.Narrow(2, 1, rnn_size)(sigmoid_chunk)
  local forget_gate = nn.Narrow(2, rnn_size + 1, rnn_size)(sigmoid_chunk)
  local out_gate = nn.Narrow(2, 2 * rnn_size + 1, rnn_size)(sigmoid_chunk)

  -- last rnn_size columns: candidate cell values (tanh)
  local in_transform = nn.Narrow(2, 3 * rnn_size + 1, rnn_size)(all_input_sums)
  in_transform = nn.Tanh()(in_transform)

  -- next_c = forget_gate * prev_c + in_gate * c~
  local next_c = nn.CAddTable()({
    nn.CMulTable()({forget_gate, prev_c}),
    nn.CMulTable()({in_gate, in_transform})
  })
  -- next_h = out_gate * tanh(next_c)
  local next_h = nn.CMulTable()({out_gate, nn.Tanh()(next_c)})

  return nn.gModule({x, prev_c, prev_h}, {next_c, next_h})
end

return LSTM
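For training with backpropagation through time, reusing a single module across timesteps overwrites the intermediate activations that backward needs, so the usual Torch pattern is to clone the cell once per timestep with shared parameters (char-rnn does this via model_utils.clone_many_times). A minimal sketch; clone_through_time is an illustrative name, not part of the gist:

-- Sketch: clone the cell T times, sharing weights, biases, and their gradients,
-- so a single parameter update covers every timestep.
local function clone_through_time(net, T)
  local clones = {}
  for t = 1, T do
    clones[t] = net:clone('weight', 'bias', 'gradWeight', 'gradBias')
  end
  return clones
end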
Production Notes
From Karpathy (2015): this cell powers char-rnn models, and Justin Johnson's tweaks batch the gate GEMMs throughout. It scales to sequence lengths in the thousands on GTX 580-era GPUs. The modern PyTorch equivalent is torch.nn.LSTM, whose fused cuDNN kernels are faster still. For new work, port to Flux.jl or JAX, but the graph-fusion principle endures for custom RNNs.