NumPy Batched LSTM Forward/Backward
Efficient pure-NumPy LSTM that processes batched sequences of shape (n, b, input_size); initialized with Xavier scaling plus a forget-gate bias of 3; correctness verified via a sequential-vs-batch match and numerical gradient checks.
Parameter Initialization for Stable Training
LSTM weights form a single matrix WLSTM of shape (input_size + hidden_size + 1, 4 * hidden_size), with the +1 row holding the biases. Use Xavier initialization: random normal scaled by 1/sqrt(input_size + hidden_size). Set biases to zero initially, but apply fancy_forget_bias_init=3 to the forget-gate biases (columns hidden_size:2*hidden_size of the bias row). Since the raw gate pre-activations are roughly N(0,1), this positive bias pushes the forget gates toward 1 (sigmoid(3) ≈ 0.95), so the cell defaults to retaining its memory early in training.
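A minimal sketch of this initialization, assuming the WLSTM layout described above (`init_lstm` is a hypothetical helper name, not from the original code):

```python
import numpy as np

def init_lstm(input_size, hidden_size, forget_bias=3.0):
    # Fused parameter matrix: row 0 holds the biases, then the input weights,
    # then the recurrent weights; columns are the 4 gate blocks [i | f | o | g].
    WLSTM = np.random.randn(input_size + hidden_size + 1, 4 * hidden_size) \
            / np.sqrt(input_size + hidden_size)   # Xavier-style scaling
    WLSTM[0, :] = 0.0                              # zero all biases
    WLSTM[0, hidden_size:2 * hidden_size] = forget_bias  # forget gates start ~open
    return WLSTM
```

Folding the bias into the weight matrix lets the whole timestep be a single matrix multiply against an input row of [1, x_t, h_{t-1}].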
Batched Forward Pass Logic
Input X: (n,b,input_size). Hidden size d = WLSTM.shape[1] // 4. Initialize c0/h0 as zeros((b,d)) if None. For each timestep t:
- Build Hin[t]: Hin[t,:,0] = 1 (bias), Hin[t,:,1:input_size+1] = X[t], Hin[t,:,input_size+1:] = prev_h (h0 at t=0).
- Compute raw IFOG[t] = Hin[t] @ WLSTM (the main compute).
- Gates: sigmoid on the first 3*d columns (input/forget/output), tanh on the last d (candidate).
- Cell: C[t] = input_gate * candidate + forget_gate * prev_c (c0 at t=0).
- Output: Hout[t] = output_gate * tanh(C[t]).
Cache stores all intermediates (Hin, IFOG, IFOGf, C, Ct = tanh(C), etc.) for the backward pass. Returns the full Hout (n,b,d), the final cell/hidden states, and the cache.
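The steps above can be sketched as a simplified reimplementation (under the layout assumptions stated earlier; not the original code verbatim):

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_forward(X, WLSTM, c0=None, h0=None):
    n, b, input_size = X.shape
    d = WLSTM.shape[1] // 4                       # hidden size from the fused matrix
    c0 = np.zeros((b, d)) if c0 is None else c0
    h0 = np.zeros((b, d)) if h0 is None else h0
    Hin   = np.zeros((n, b, WLSTM.shape[0]))      # rows [1 | x_t | h_{t-1}]
    IFOG  = np.zeros((n, b, 4 * d))               # raw pre-activations
    IFOGf = np.zeros((n, b, 4 * d))               # after nonlinearities
    C     = np.zeros((n, b, d))                   # cell states
    Hout  = np.zeros((n, b, d))                   # hidden states
    for t in range(n):
        prev_h = Hout[t - 1] if t > 0 else h0
        prev_c = C[t - 1] if t > 0 else c0
        Hin[t, :, 0] = 1.0                        # bias column
        Hin[t, :, 1:input_size + 1] = X[t]
        Hin[t, :, input_size + 1:] = prev_h
        IFOG[t] = Hin[t].dot(WLSTM)               # the main compute
        IFOGf[t, :, :3 * d] = sigmoid(IFOG[t, :, :3 * d])   # i, f, o gates
        IFOGf[t, :, 3 * d:] = np.tanh(IFOG[t, :, 3 * d:])   # candidate g
        C[t] = IFOGf[t, :, :d] * IFOGf[t, :, 3 * d:] \
             + IFOGf[t, :, d:2 * d] * prev_c
        Hout[t] = IFOGf[t, :, 2 * d:3 * d] * np.tanh(C[t])
    cache = dict(Hin=Hin, IFOG=IFOG, IFOGf=IFOGf, C=C,
                 c0=c0, h0=h0, WLSTM=WLSTM, Hout=Hout)
    return Hout, C[-1], Hout[-1], cache
```

Note that all b sequences in the batch share one matrix multiply per timestep, which is where the efficiency over per-sequence loops comes from.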
Backward Pass Gradient Computation
Input dHout_in (n,b,d). If gradients dcn/dhn with respect to the final cell/hidden states are provided (for carrying state across chunks), add them to dC[n-1]/dHout[n-1]. Reverse loop over t:
- Output-gate slice: dIFOGf[t,:,2d:3d] = tanh(C[t]) * dHout[t].
- dC[t] += (1 - tanh(C[t])^2) * output_gate * dHout[t]; from C = input*candidate + forget*prev_c, route forget_gate * dC[t] back to dC[t-1] (dc0 at t=0) and fill the input/forget/candidate slices of dIFOGf.
- Backprop the activations: tanh' = (1 - y^2) on the candidate, sigmoid' = y(1 - y) on the three gates.
- dWLSTM += Hin[t].T @ dIFOG[t]; dHin[t] = dIFOG[t] @ WLSTM.T.
- Extract dX[t] = dHin[t][:, 1:input_size+1]; add dHin[t][:, input_size+1:] to dHout[t-1] (dh0 at t=0).
Returns dX (n,b,input_size), dWLSTM, dc0, dh0.
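The body of the reverse loop can be sketched as a single-timestep helper (`lstm_step_backward` is a hypothetical name; the batched version runs this for t = n-1 .. 0, accumulating dWLSTM and carrying the cell and hidden gradients back to the previous step):

```python
import numpy as np

def lstm_step_backward(dHout_t, dC_in, Hin_t, IFOGf_t, C_t, prev_c, WLSTM):
    # dC_in is the cell-state gradient flowing in from later timesteps.
    b, d = dHout_t.shape
    Ct = np.tanh(C_t)
    dIFOGf = np.zeros((b, 4 * d))
    dIFOGf[:, 2 * d:3 * d] = Ct * dHout_t                        # output gate
    dC = dC_in + (1 - Ct ** 2) * (IFOGf_t[:, 2 * d:3 * d] * dHout_t)
    dprev_c = IFOGf_t[:, d:2 * d] * dC                           # through forget gate
    dIFOGf[:, d:2 * d] = prev_c * dC                             # forget gate
    dIFOGf[:, :d] = IFOGf_t[:, 3 * d:] * dC                      # input gate
    dIFOGf[:, 3 * d:] = IFOGf_t[:, :d] * dC                      # candidate
    dIFOG = np.empty_like(dIFOGf)
    dIFOG[:, 3 * d:] = (1 - IFOGf_t[:, 3 * d:] ** 2) * dIFOGf[:, 3 * d:]  # tanh'
    y = IFOGf_t[:, :3 * d]
    dIFOG[:, :3 * d] = y * (1 - y) * dIFOGf[:, :3 * d]           # sigmoid'
    dWLSTM = Hin_t.T.dot(dIFOG)
    dHin = dIFOG.dot(WLSTM.T)        # slices give dX[t] and dprev_h
    return dHin, dWLSTM, dprev_c
```

Splitting dHin back into its bias/input/hidden columns is what lets the gradient flow both to dX[t] and to the previous hidden state.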
Verification Ensures Correctness
Test 1 (sequential vs. batch): with n=5, b=3, d=4, input_size=10, run the forward pass one timestep at a time (carrying c/h between calls) and confirm Hout matches the full batched forward.
Test 2 (gradient check): numerical gradient = (f(x+δ) − f(x−δ)) / (2δ) with δ=1e-5. Relative-error thresholds: warn above 1e-2, error above 1. Every element of X, WLSTM, c0, and h0 is checked against the analytic gradients of the loss = sum(Hout * wrand) for a fixed random wrand. All parameters pass with low relative error, confirming the backward pass.
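The numerical check can be sketched with a centered-difference helper (hypothetical names `numerical_grad`/`rel_error`, following the δ described above):

```python
import numpy as np

def rel_error(a, b, eps=1e-10):
    # elementwise relative error between two gradient arrays
    return np.abs(a - b) / (np.abs(a) + np.abs(b) + eps)

def numerical_grad(f, x, delta=1e-5):
    # centered differences (f(x+d) - f(x-d)) / (2d), one element at a time;
    # f is a zero-argument closure over x returning a scalar loss
    grad = np.zeros_like(x)
    it = np.nditer(x, flags=['multi_index'])
    while not it.finished:
        ix = it.multi_index
        old = x[ix]
        x[ix] = old + delta
        f_plus = f()
        x[ix] = old - delta
        f_minus = f()
        x[ix] = old                      # restore original value
        grad[ix] = (f_plus - f_minus) / (2 * delta)
        it.iternext()
    return grad
```

The relative-error form is preferred over an absolute difference because gradient magnitudes vary widely across parameters.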