Understanding Batch Normalization, Layer Normalization and Group Normalization by implementing from scratch

1. Batch Normalization: This technique, introduced by Ioffe and Szegedy in 2015, normalizes the data across the batch dimension (i.e., for each feature, it calculates the mean and variance across all instances in the batch). It is widely used in Convolutional Neural Networks (CNNs), as it can accelerate training and improve generalization. However, it can cause issues in certain scenarios, such as small batch sizes or sequence models, where the batch statistics can differ at every time step.

  • Given a batch of activations for a specific layer, batch normalization first calculates the mean and variance across the batch. It then subtracts the mean and divides by the standard deviation to normalize the values. A small epsilon is added to the variance before taking the square root for numerical stability.
  • Following normalization, batch normalization applies a learnable scale factor "gamma" and shift factor "beta". These two parameters allow the layer to undo the normalization if that turns out to be more useful.
  • During training, the mean and variance are computed on the fly for each batch. At test time, running averages of these statistics, accumulated during training, are used instead (see the fuller sketch after the snippet below).

def batch_norm(x):
    # statistics are computed across the batch dimension (dim 0),
    # i.e. one mean/variance per feature
    mean = x.mean(0, keepdim=True)
    var = x.var(0, unbiased=False, keepdim=True)
    x_norm = (x - mean) / (var + 1e-5).sqrt()
    return x_norm
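
The snippet above only normalizes; it leaves out the learnable scale/shift and the running statistics mentioned in the bullets. A minimal sketch of a trainable batch-norm module, assuming 2D input of shape (batch, features) and using a made-up class name rather than PyTorch's built-in nn.BatchNorm1d, could look like this:

import torch
from torch import nn

class SimpleBatchNorm(nn.Module):
    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super().__init__()
        self.eps = eps
        self.momentum = momentum
        self.gamma = nn.Parameter(torch.ones(num_features))   # learnable scale
        self.beta = nn.Parameter(torch.zeros(num_features))   # learnable shift
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features))

    def forward(self, x):
        if self.training:
            # batch statistics computed on the fly during training
            mean = x.mean(0)
            var = x.var(0, unbiased=False)
            with torch.no_grad():
                # running averages used later at test time
                self.running_mean.mul_(1 - self.momentum).add_(self.momentum * mean)
                self.running_var.mul_(1 - self.momentum).add_(self.momentum * var)
        else:
            mean, var = self.running_mean, self.running_var
        x_norm = (x - mean) / (var + self.eps).sqrt()
        return self.gamma * x_norm + self.beta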

2. Layer Normalization: Proposed by Ba et al. in 2016, layer normalization operates over the feature dimension (i.e., it calculates the mean and variance for each instance separately, over all the features). Unlike batch normalization, it doesn't depend on the batch size, so it's often used in recurrent models where batch normalization performs poorly.

  • Layer normalization computes the mean and standard deviation over each individual observation (over all channels in the case of images, or all features in the case of an MLP) rather than across the batch. This makes it batch-size independent, so it can be used in models such as RNNs and transformers.

def layer_norm(x):
    # statistics are computed across the feature dimension (dim 1),
    # i.e. one mean/variance per sample
    mean = x.mean(1, keepdim=True)
    var = x.var(1, unbiased=False, keepdim=True)
    x_norm = (x - mean) / (var + 1e-5).sqrt()
    return x_norm
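
Because the statistics are per-sample, the result for a given row does not change when the batch around it changes. A quick illustrative check of that claim:

import torch

def layer_norm(x):
    mean = x.mean(1, keepdim=True)
    var = x.var(1, unbiased=False, keepdim=True)
    return (x - mean) / (var + 1e-5).sqrt()

x = torch.randn(8, 16)
full = layer_norm(x)         # normalize a batch of 8 samples
single = layer_norm(x[:1])   # normalize the first sample on its own
print(torch.allclose(full[:1], single))  # True: no dependence on the batch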

3. Group Normalization: Proposed by Wu and He in 2018, group normalization is a middle-ground approach that divides the channels into smaller groups and normalizes the features within each group. It is designed to perform consistently well for both small and large batch sizes.

  • Group normalization divides channels into groups and normalizes the features within each group. It's computationally straightforward and doesn't have any restrictions regarding batch size. Group normalization performs particularly well in small batch scenarios where batch normalization suffers.

def group_norm(x, num_groups):
    N, C = x.shape
    # reshape so that each group's features sit in the last dimension
    x = x.view(N, num_groups, -1)
    mean = x.mean(-1, keepdim=True)
    var = x.var(-1, unbiased=False, keepdim=True)
    x_norm = (x - mean) / (var + 1e-5).sqrt()
    x_norm = x_norm.view(N, C)
    return x_norm
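
As an illustrative sanity check, after group normalization each group of features should have roughly zero mean and unit variance:

import torch

def group_norm(x, num_groups):
    N, C = x.shape
    x = x.view(N, num_groups, -1)
    mean = x.mean(-1, keepdim=True)
    var = x.var(-1, unbiased=False, keepdim=True)
    x_norm = (x - mean) / (var + 1e-5).sqrt()
    return x_norm.view(N, C)

x = torch.randn(4, 8)
y = group_norm(x, num_groups=2)         # two groups of 4 features each
groups = y.view(4, 2, -1)
print(groups.mean(-1))                  # approximately 0 within each group
print(groups.var(-1, unbiased=False))   # approximately 1 within each group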

Let's implement a basic version of each of these normalization techniques from scratch. Please keep in mind that these implementations are intended to be instructive and may not cover all the edge cases handled by PyTorch's built-in versions.

import torch
from torch import nn
import torch.nn.functional as F
from functools import partial

def batch_norm(x):
    mean = x.mean(0, keepdim=True)
    var = x.var(0, unbiased=False, keepdim=True)
    x_norm = (x - mean) / (var + 1e-5).sqrt()
    return x_norm

def layer_norm(x):
    mean = x.mean(1, keepdim=True)
    var = x.var(1, unbiased=False, keepdim=True)
    x_norm = (x - mean) / (var + 1e-5).sqrt()
    return x_norm

def group_norm(x, num_groups):
    N, C = x.shape
    x = x.view(N, num_groups, -1)
    mean = x.mean(-1, keepdim=True)
    var = x.var(-1, unbiased=False, keepdim=True)
    x_norm = (x - mean) / (var + 1e-5).sqrt()
    x_norm = x_norm.view(N, C)
    return x_norm

# Simple two-layer MLP that applies the given normalization function
# between the first linear layer and the ReLU
class MLP(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, norm_func):
        super().__init__()
        self.linear1 = nn.Linear(input_dim, hidden_dim)
        self.norm_func = norm_func
        self.linear2 = nn.Linear(hidden_dim, output_dim)
        
    def forward(self, x):
        x = self.linear1(x)
        x = self.norm_func(x)
        x = F.relu(x)
        x = self.linear2(x)
        return x

# Create a random tensor with size (batch_size, input_dim)
x = torch.randn(32, 100)

# Create the MLP models with batch norm, layer norm, and group norm
model_bn = MLP(100, 64, 10, batch_norm)
model_ln = MLP(100, 64, 10, layer_norm)
model_gn = MLP(100, 64, 10, partial(group_norm, num_groups=4))

# Pass the input tensor through the models
output_bn = model_bn(x)
output_ln = model_ln(x)
output_gn = model_gn(x)

# Print the outputs
print("Output with batch norm:\n", output_bn)
print("\nOutput with layer norm:\n", output_ln)
print("\nOutput with group norm:\n", output_gn)
        

Each of these normalization techniques has its strengths and weaknesses, and the choice between them depends on the specific problem and model architecture. For instance, batch normalization might be the first choice for convolutional networks with large batch sizes, while layer normalization or group normalization could be more suitable for recurrent networks or other models with small or variable batch sizes.



Hello, I just came across this and would like to make a quick comment. Be careful as your implementation does not match PyTorch. Batch normalization takes the mean and standard deviation by reducing batch and spatial dimensions. So the correct implementation would be : x.mean([0,2,3], keepdim=True). Same for standard deviation. Please check your other functions as well. This is not the same as x.mean(0,keepdim=True). Have a nice day !
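
For readers following along, a sketch of the per-channel reduction described in the comment, assuming 4D NCHW convolutional activations and no affine parameters, might look like this:

import torch
from torch import nn

def batch_norm_2d(x):
    # x has shape (N, C, H, W); statistics are per channel,
    # reduced over the batch and both spatial dimensions
    mean = x.mean([0, 2, 3], keepdim=True)
    var = x.var([0, 2, 3], unbiased=False, keepdim=True)
    return (x - mean) / (var + 1e-5).sqrt()

x = torch.randn(16, 3, 32, 32)
bn = nn.BatchNorm2d(3, affine=False, track_running_stats=False)
print(torch.allclose(batch_norm_2d(x), bn(x), atol=1e-5))  # should print True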
