Batched L2 Norm Layer for Torch Neural Nets
A custom Torch nn.Module that normalizes each row of an n x d input tensor to unit L2 norm, with efficient batched forward/backward passes for training.
Core Layer Design
This nn.L2Normalize module processes 2D tensors (batch size n x vector dim d), normalizing each row vector to unit L2 norm (||x||_2 = 1). Use it in Torch neural nets for tasks like embedding normalization, where direction matters more than magnitude. Instantiate it via local layer = nn.L2Normalize(), then drop it into a container such as nn.Sequential for end-to-end differentiability, as in the usage sketch below.
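A minimal usage sketch, assuming the gist's module has been required into the nn namespace; the require path and layer sizes are illustrative, not taken from the gist:

```lua
require 'nn'
require 'L2Normalize'   -- hypothetical require path; the gist defines nn.L2Normalize

-- Toy embedding model: a linear projection followed by row-wise L2 normalization.
local model = nn.Sequential()
model:add(nn.Linear(128, 64))   -- sizes are illustrative
model:add(nn.L2Normalize())     -- each 64-dim row is scaled to unit L2 norm

local input  = torch.randn(32, 128)                         -- mini-batch of 32 vectors
local output = model:forward(input)                         -- every row of output has norm 1
local gradIn = model:backward(input, torch.randn(32, 64))   -- end-to-end differentiable
```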
Forward pass (updateOutput): computes the per-row squared L2 norms via an elementwise square and a sum over dim 2 (input:cmul(input):sum(2)), takes the square root, then elementwise divides the input by the expanded norms (input:cdiv(buffer:expandAs(input))). Loops are avoided for batch efficiency, and buffers are reused across calls; a reconstructed sketch of the forward pass follows.
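A sketch of updateOutput following the description above; it is a reconstruction rather than the gist's verbatim code, and it squares the input into a reusable buffer instead of modifying the input in place:

```lua
function L2Normalize:updateOutput(input)
  assert(input:dim() == 2, 'only 2D (mini-batch) input is supported')
  self.buffer = self.buffer or input.new()          -- lazy buffer init, reused across calls

  -- Per-row squared norms: elementwise square, then sum over dim 2 -> n x 1.
  self.buffer:resizeAs(input):copy(input):cmul(input)
  self.norms = self.buffer:sum(2):sqrt()            -- n x 1 per-row L2 norms

  -- Divide each row by its norm; expandAs avoids materializing an n x d copy.
  self.output:resizeAs(input):copy(input)
  self.output:cdiv(self.norms:expandAs(self.output))
  return self.output
end
```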
Gradient Computation
Backward pass (updateGradInput) derives the local Jacobian of the L2 transform for the chain rule: for each row x, the Jacobian is (||x||^2 I - x x^T) / ||x||^3, which is applied to the corresponding row of gradOutput. Key steps:
- Forms an identity tensor repeated over the batch (torch.eye(d):repeatTensor(n,1):view(n,d,d)).
- Scales the diagonal by the squared norm (cmul(eye, normSquared:view(n,1,1):expand(n,d,d))).
- Subtracts the outer products (-torch.bmm(input:view(n,d,1), input:view(n,1,d))).
- Divides by the cubed norms (cdiv(pow(buffer,3):expand(n,d,d))).
- Applies the result via batched matmul: bmm(diag, gradOutput:view(n,d,1)):resize(n,d), later fixed with :squeeze() after line 31 of the gist.
This ensures correct gradients during backprop, critical for training stability in nets with normalization layers; a reconstructed sketch of the full backward pass follows.
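A sketch reconstructing the backward pass from the steps above; it recomputes the norms for self-containment instead of reusing the forward-pass buffers, and the variable names are illustrative rather than the gist's:

```lua
function L2Normalize:updateGradInput(input, gradOutput)
  local n, d = input:size(1), input:size(2)

  -- Recompute per-row norms (the gist caches these in buffers during the forward pass).
  local normSq = torch.cmul(input, input):sum(2)    -- n x 1 squared norms
  local norm   = torch.sqrt(normSq)                 -- n x 1 norms

  -- Batch of d x d identity matrices, one per example.
  local eye = torch.eye(d):typeAs(input):repeatTensor(n, 1):view(n, d, d)

  -- Per-example Jacobian: (||x||^2 * I - x * x^T) / ||x||^3.
  local jac = torch.cmul(eye, normSq:view(n, 1, 1):expand(n, d, d))
  jac:add(-1, torch.bmm(input:view(n, d, 1), input:view(n, 1, d)))
  jac:cdiv(torch.pow(norm, 3):view(n, 1, 1):expand(n, d, d))

  -- Chain rule via batched matrix-vector product; squeeze drops the trailing singleton dim.
  self.gradInput = torch.bmm(jac, gradOutput:view(n, d, 1)):squeeze(3)
  return self.gradInput
end
```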
Implementation Notes and Fixes
Code uses lazy buffer initialization (self.buffer = self.buffer or input.new()) for memory efficiency and assumes mini-batch input only (it errors on non-2D tensors). Community feedback: the manual norm computation in the forward pass could be swapped for torch.norm() for simplicity, which Karpathy confirmed was feasible; a sketch of that simplification follows. Atcold noted a dimension mismatch in gradInput without :squeeze() after the bmm resize, which the author fixed. Soumith (Torch maintainer) provided additional pointers (unspecified). The gist is a thin one from 2015; modern PyTorch offers torch.nn.functional.normalize(p=2, dim=1) as a built-in alternative.
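A hypothetical version of the simplified forward pass suggested in the comments, using torch.norm() over dim 2; this illustrates the suggestion and is not code from the gist:

```lua
function L2Normalize:updateOutput(input)
  assert(input:dim() == 2, 'only 2D (mini-batch) input is supported')
  local norms = torch.norm(input, 2, 2)             -- n x 1 per-row L2 norms in one call
  self.output:resizeAs(input):copy(input)
  self.output:cdiv(norms:expandAs(self.output))
  return self.output
end
```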