Understanding Batch Normalization, Layer Normalization and Group Normalization by implementing from scratch
1. Batch Normalization: Proposed by Ioffe and Szegedy in 2015, batch normalization operates over the batch dimension (i.e., it calculates the mean and variance of each feature across all instances in the batch). It tends to work well with large batch sizes but can struggle when batches are small.
def batch_norm(x):
    mean = x.mean(0, keepdim=True)
    var = x.var(0, unbiased=False, keepdim=True)
    x_norm = (x - mean) / (var + 1e-5).sqrt()
    return x_norm
2. Layer Normalization: Proposed by Ba et al. in 2016, layer normalization operates over the feature dimension (i.e., it calculates the mean and variance for each instance separately, over all the features). Unlike batch normalization, it doesn't depend on the batch size, so it's often used in recurrent models where batch normalization performs poorly.
def layer_norm(x):
    mean = x.mean(1, keepdim=True)
    var = x.var(1, unbiased=False, keepdim=True)
    x_norm = (x - mean) / (var + 1e-5).sqrt()
    return x_norm
3. Group Normalization: Proposed by Wu and He in 2018, group normalization is a middle-ground approach that divides the channels into smaller groups and normalizes the features within each group. It is designed to perform consistently well for both small and large batch sizes.
def group_norm(x, num_groups):
    N, C = x.shape
    x = x.view(N, num_groups, -1)
    mean = x.mean(-1, keepdim=True)
    var = x.var(-1, unbiased=False, keepdim=True)
    x_norm = (x - mean) / (var + 1e-5).sqrt()
    x_norm = x_norm.view(N, C)
    return x_norm
Now let's put it all together and implement a basic version of each of these normalization techniques from scratch. Please keep in mind that these implementations are intended to be instructive and might not cover all the edge cases handled by PyTorch's built-in versions.
import torch
from torch import nn
import torch.nn.functional as F
from functools import partial
def batch_norm(x):
    mean = x.mean(0, keepdim=True)
    var = x.var(0, unbiased=False, keepdim=True)
    x_norm = (x - mean) / (var + 1e-5).sqrt()
    return x_norm

def layer_norm(x):
    mean = x.mean(1, keepdim=True)
    var = x.var(1, unbiased=False, keepdim=True)
    x_norm = (x - mean) / (var + 1e-5).sqrt()
    return x_norm

def group_norm(x, num_groups):
    N, C = x.shape
    x = x.view(N, num_groups, -1)
    mean = x.mean(-1, keepdim=True)
    var = x.var(-1, unbiased=False, keepdim=True)
    x_norm = (x - mean) / (var + 1e-5).sqrt()
    x_norm = x_norm.view(N, C)
    return x_norm
class MLP(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, norm_func):
        super().__init__()
        self.linear1 = nn.Linear(input_dim, hidden_dim)
        self.norm_func = norm_func
        self.linear2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = self.linear1(x)
        x = self.norm_func(x)
        x = F.relu(x)
        x = self.linear2(x)
        return x
# Create a random tensor with size (batch_size, input_dim)
x = torch.randn(32, 100)
# Create the MLP models with batch norm, layer norm, and group norm
model_bn = MLP(100, 64, 10, batch_norm)
model_ln = MLP(100, 64, 10, layer_norm)
model_gn = MLP(100, 64, 10, partial(group_norm, num_groups=4))
# Pass the input tensor through the models
output_bn = model_bn(x)
output_ln = model_ln(x)
output_gn = model_gn(x)
# Print the outputs
print("Output with batch norm:\n", output_bn)
print("\nOutput with layer norm:\n", output_ln)
print("\nOutput with group norm:\n", output_gn)
Each of these normalization techniques has its strengths and weaknesses, and the choice between them depends on the specific problem and model architecture. For instance, batch normalization might be the first choice for convolutional networks with large batch sizes, while layer normalization or group normalization could be more suitable for recurrent networks or other models with small or variable batch sizes.
Hello, I just came across this and would like to make a quick comment. Be careful as your implementation does not match PyTorch. Batch normalization takes the mean and standard deviation by reducing batch and spatial dimensions. So the correct implementation would be: x.mean([0, 2, 3], keepdim=True). Same for standard deviation. Please check your other functions as well. This is not the same as x.mean(0, keepdim=True). Have a nice day!
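The commenter's point applies to 4D convolutional inputs of shape (N, C, H, W), where batch norm keeps per-channel statistics and reduces over the batch and spatial dimensions. A minimal sketch of that version (inference-style, with no affine parameters or running statistics, and `eps` chosen to match PyTorch's default), checked against PyTorch's functional batch norm:

```python
import torch
import torch.nn.functional as F

def batch_norm_2d(x, eps=1e-5):
    # Reduce over batch and spatial dims (0, 2, 3), keeping one
    # mean/variance pair per channel, as the comment describes.
    mean = x.mean([0, 2, 3], keepdim=True)
    var = x.var([0, 2, 3], unbiased=False, keepdim=True)
    return (x - mean) / (var + eps).sqrt()

x = torch.randn(8, 3, 16, 16)
ours = batch_norm_2d(x)
# training=True makes F.batch_norm compute statistics from this batch,
# mirroring what nn.BatchNorm2d does in training mode.
theirs = F.batch_norm(x, None, None, training=True)
print(torch.allclose(ours, theirs, atol=1e-5))
```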