2.11. Batch Normalization#
In neural networks, the output of the first layer feeds into the second layer, the output of the second layer feeds into the third, and so on. When the parameters of a layer change, so does the distribution of inputs to subsequent layers.
These shifts in input distributions are called internal covariate shift, and they can be problematic for neural networks, especially deep networks with a large number of layers.
Batch normalization (BN) is a method intended to mitigate internal covariate shift for neural networks.
Machine learning methods tend to work better when their input data consists of uncorrelated features with zero mean and unit variance. When training a neural network, we can preprocess the data before feeding it to the network to explicitly decorrelate its features; this will ensure that the first layer of the network sees data that follows a nice distribution.
However even if we preprocess the input data, the activations at deeper layers of the network will likely no longer be decorrelated and will no longer have zero mean or unit variance since they are output from earlier layers in the network. Even worse, during the training process the distribution of features at each layer of the network will shift as the weights of each layer are updated.
To overcome this, at training time a batch normalization layer standardizes each of its input features to zero mean and unit variance (\(\mathcal{N}(\mu=0,\sigma=1)\) in the idealised Gaussian case). A running average of the means and standard deviations is kept during training, and at test time these running averages are used to center and normalize features.
Adding BN layers leads to faster and better convergence (where better means higher accuracy).
Adding BN layers allows us to use a higher learning rate (\(\eta\)) without compromising convergence.
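To make the standardization step described above concrete, here is a minimal sketch; the batch size, feature count and random data are arbitrary choices for illustration:

import numpy as np

np.random.seed(0)
zl = 5 + 3 * np.random.randn(128, 4)    # a hypothetical batch: 128 examples, 4 features

mu = zl.mean(axis=0)                    # per-feature mean over the batch
sigma = zl.std(axis=0)                  # per-feature standard deviation over the batch
zl_hat = (zl - mu) / sigma              # standardized activations

print(zl_hat.mean(axis=0))              # approximately 0 for every feature
print(zl_hat.std(axis=0))               # approximately 1 for every feature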
Implementation
In practice, we treat batch normalization as a standard layer, just like a perceptron, a convolutional layer, an activation function or a dropout layer. It is generally applied after calculating the weighted sum \(z_l\) and before applying the non-linear activation function \(f_l(z_l)\).
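For context, this is roughly how that placement looks in a framework such as PyTorch; this sketch assumes torch is installed (it is not used elsewhere in this chapter) and the layer sizes are arbitrary:

import torch.nn as nn

# BN sits between the affine transform (weighted sum) and the non-linearity
model = nn.Sequential(
    nn.Linear(in_features=20, out_features=64),   # z_l = W_l a_{l-1} + b_l
    nn.BatchNorm1d(num_features=64),              # q_l = BN(z_l)
    nn.ReLU(),                                    # a_l = f_l(q_l)
    nn.Linear(64, 10),
)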
For any layer \(l\), let \(z\) of size \((m, h_l)\) (where \(h_l\) is the number of neurons in that hidden layer and \(m\) is the batch size) be the input to batch normalization (\(\text{BN}\)). In this case batch normalization is defined as follows:

\[ \text{BN}(z) = \gamma \odot \frac{z - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta \]
where \(\mu\) of size \((h_l,1)\) and \(\sigma\) of size \((h_l,1)\) are the per-feature mean and standard deviation of \(z\) computed over the full minibatch (of batch size \(m\)). Note that we add a small constant \(\epsilon > 0\) to the variance estimate to ensure that we never attempt division by zero.
After applying standardization, the resulting minibatch has zero mean and unit variance. The variables \(\gamma\) of size \((h_l,1)\) and \(\beta\) of size \((h_l,1)\) are learned parameters that allow a standardized variable to have any mean and standard deviation.
In simple terms, forcing zero mean and unit standard deviation can reduce the expressive power of the neural network. To maintain that expressive power, it is common to replace the standardized variable \(\hat{z}\) with \(\gamma \hat{z} + \beta\), where \(\gamma\) and \(\beta\), like \(W\) and \(b\), are learned during training.
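A minimal numerical sketch of this scale-and-shift step; the values chosen for gamma and beta below are arbitrary and only illustrate that any mean and standard deviation can be recovered:

import numpy as np

np.random.seed(1)
zhat = np.random.randn(1000, 3)                       # pretend activations, 3 features
zhat = (zhat - zhat.mean(axis=0)) / zhat.std(axis=0)  # standardize: mean 0, std 1

gamma = np.array([0.5, 2.0, 3.0])   # hypothetical learned scale
beta = np.array([-1.0, 0.0, 4.0])   # hypothetical learned shift

q = gamma * zhat + beta
print(q.mean(axis=0))   # approximately beta
print(q.std(axis=0))    # approximately gamma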
Forward pass and Back Propagation in Batch Normalization Layer#
Let us apply batch normalization (\(\text{BN}\)) on layer \(l\) after the weighted sum and before the activation function.
Forward pass Batch Normalization (vectorized)
We know from the standard forward propagation (link to previous chapter) that

\[ z_l = W_l a_{l-1} + b_l \]

This \(z_l\) will be the input to batch normalization (\(\text{BN}\)), and let the output we get from it be \(q_l\). Also, let

\[ \mu = \frac{1}{m} \sum_{i=1}^{m} z_l^{(i)} \]

and

\[ \sigma^2 = \frac{1}{m} \sum_{i=1}^{m} \left( z_l^{(i)} - \mu \right)^2 \]

Therefore,

\[ \hat{z_l} = \frac{z_l - \mu}{\sqrt{\sigma^2 + \epsilon}}, \qquad q_l = \gamma \odot \hat{z_l} + \beta \]

where the parameters are as defined above. Finally, passing \(q_l\) through the activation function \(f_l\) gives

\[ a_l = f_l(q_l) \]
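The forward pass above can be wrapped in a small helper. This is only a sketch (the function name bn_forward is our own); the next sections walk through exactly the same computation step by step:

import numpy as np

def bn_forward(z, gamma, beta, eps=1e-6):
    # z: (m, h_l) pre-activations; gamma, beta: (h_l,) learned scale and shift
    mu = np.mean(z, axis=0)                     # per-feature batch mean
    var = np.var(z, axis=0)                     # per-feature batch variance
    ivar = 1 / np.sqrt(var + eps)               # 1 / sqrt(sigma^2 + eps)
    zhat = (z - mu) * ivar                      # standardize
    q = gamma * zhat + beta                     # scale and shift
    cache = (gamma, beta, z - mu, ivar, zhat)   # saved for the backward pass
    return q, cache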
Backpropagation Batch Normalization (vectorized)
We know from the standard backward propagation (link to previous chapter) that (let us denote the cost function \(J(W, b, \gamma, \beta)\) as \(J\) for simplicity)

\[ \frac{\partial J}{\partial q_l} = \frac{\partial J}{\partial a_l} \odot f_l'(q_l) \]
Note: let \(\sum_c\) denote the sum along each column (i.e. the sum of column 1, then the sum of column 2, and so on), which produces a vector of size \((h_l, 1)\).
This quantity, of size \((m, h_l)\), will serve as the input for calculating the partial derivatives of the cost function \(J\) with respect to \(\gamma\), \(\beta\) and \(z_l\).
Partial derivative of \(J\) with respect to \(\beta\) (size \((h_l,1)\)):

\[ \frac{\partial J}{\partial \beta} = \sum_c \frac{\partial J}{\partial q_l} \]

Partial derivative of \(J\) with respect to \(\gamma\) (size \((h_l,1)\)):

\[ \frac{\partial J}{\partial \gamma} = \sum_c \frac{\partial J}{\partial q_l} \odot \hat{z_l} \]

Partial derivative of \(J\) with respect to \(\hat{z_l}\) (size \((m,h_l)\)):

\[ \frac{\partial J}{\partial \hat{z_l}} = \frac{\partial J}{\partial q_l} \odot \gamma \]

Partial derivative of \(J\) with respect to \(\mu\) (size \((h_l,1)\)):

\[ \frac{\partial J}{\partial \mu} = \sum_c \frac{\partial J}{\partial \hat{z_l}} \odot \left( \frac{-1}{\sqrt{\sigma^2 + \epsilon}} \right) \]

(the additional term involving \(\frac{\partial J}{\partial \sigma^2}\) vanishes because \(\sum_c (z_l - \mu) = 0\))

Partial derivative of \(J\) with respect to \(\sigma^2\) (size \((h_l,1)\)):

\[ \frac{\partial J}{\partial \sigma^2} = \sum_c \frac{\partial J}{\partial \hat{z_l}} \odot (z_l - \mu) \odot \left( -\frac{1}{2} \right) \left( \sigma^2 + \epsilon \right)^{-3/2} \]

Partial derivative of \(J\) with respect to \(z_l\) (size \((m,h_l)\)):

\[ \frac{\partial J}{\partial z_l} = \frac{\partial J}{\partial \hat{z_l}} \odot \frac{1}{\sqrt{\sigma^2 + \epsilon}} + \frac{\partial J}{\partial \sigma^2} \odot \frac{2}{m} (z_l - \mu) + \frac{1}{m} \frac{\partial J}{\partial \mu} \]

And finally, \(\frac{\partial J}{\partial z_l}\) is used to compute \(\frac{\partial J}{\partial W_l}\) and \(\frac{\partial J}{\partial b_l}\) exactly as in the standard backward propagation.
Follow the derivations in [1] or [2] (links to external websites) in case you are interested in more detail.
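These formulas can likewise be collected into a small helper. Again this is only a sketch (the name bn_backward is our own) that mirrors the step-by-step code in the next section:

import numpy as np

def bn_backward(dq, cache):
    # dq: (m, h_l) upstream gradient dJ/dq_l; cache comes from the forward pass
    gamma, beta, zmu, ivar, zhat = cache
    m = dq.shape[0]
    dgamma = np.sum(dq * zhat, axis=0)                      # dJ/dgamma
    dbeta = np.sum(dq, axis=0)                              # dJ/dbeta
    dzhat = dq * gamma                                      # dJ/dzhat
    dvar = np.sum(dzhat * zmu * (-0.5) * ivar**3, axis=0)   # dJ/dsigma^2
    dmu = np.sum(dzhat * (-ivar), axis=0)                   # dJ/dmu
    dz = dzhat * ivar + dvar * (2 / m) * zmu + dmu / m      # dJ/dz_l
    return dz, dgamma, dbeta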
Python code for forward and backward pass of Batch normalization#
This is our input to the BN layer: the variable z below represents \(z_l\).
import numpy as np

np.random.seed(42)
z = np.random.randint(low=0, high=10, size=(7, 3))   # a batch of m=7 examples with h_l=3 features
m, d = z.shape
z
array([[6, 3, 7],
[4, 6, 9],
[2, 6, 7],
[4, 3, 7],
[7, 2, 5],
[4, 1, 7],
[5, 1, 4]])
We next need some initial values of \(\gamma\) and \(\beta\): the variable gamma represents \(\gamma\) and beta represents \(\beta\).
gamma = np.ones((d))    # gamma initialized to ones
beta = np.zeros((d))    # beta initialized to zeros
gamma
array([1., 1., 1.])
beta
array([0., 0., 0.])
Forward pass

In the code below:

eps represents \(\epsilon\)
mu represents \(\mu\)
var represents \(\sigma^2\)
zmu represents \(\bar{z_l}\), i.e. \(z_l - \mu\)
ivar represents \(\frac{1}{\sqrt{\sigma^2 + \epsilon}}\)
zhat represents \(\hat{z_l}\)
q represents \(q_l\)
eps = 1e-6                      # epsilon: small constant for numerical stability
mu = np.mean(z, axis=0)         # mu: per-feature mean over the batch
var = np.var(z, axis=0)         # sigma^2: per-feature variance over the batch
zmu = z - mu                    # z - mu
ivar = 1 / np.sqrt(var + eps)   # 1 / sqrt(sigma^2 + eps)
zhat = zmu * ivar               # standardized activations
q = gamma*zhat + beta           # q_l = gamma * zhat + beta
q
array([[ 0.95346238, -0.07293249, 0.28603871],
[-0.38138495, 1.45864972, 1.62088604],
[-1.71623228, 1.45864972, 0.28603871],
[-0.38138495, -0.07293249, 0.28603871],
[ 1.62088604, -0.58345989, -1.04880861],
[-0.38138495, -1.09398729, 0.28603871],
[ 0.28603871, -1.09398729, -1.71623228]])
mu
array([4.57142857, 3.14285714, 6.57142857])
var
array([2.24489796, 3.83673469, 2.24489796])
We will save some of these variables in a cache, as they will be needed in backpropagation.
cache = (gamma, beta, zmu, ivar, zhat)
Note: during training we also keep an exponentially decaying running value of the mean and variance of each feature, and these averages are used to normalize data at test time. At each step we update the running averages for mean and variance using an exponential decay controlled by the momentum parameter:
running_mean = momentum * running_mean + (1 - momentum) * sample_mean
running_var = momentum * running_var + (1 - momentum) * sample_var
Test-time forward pass for batch normalization
We use the running mean and variance to normalize the incoming test data (\(z_t\)), then scale and shift the normalized data using gamma (\(\gamma\)) and beta (\(\beta\)) respectively. Output is stored in \(q_t\)
zt_hat = (zt - running_mean) / np.sqrt(running_var + eps)
qt = gamma * zt_hat + beta
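Continuing from the variables defined above, here is a minimal end-to-end sketch of this bookkeeping; the momentum value, the zero/one initialization of the running statistics, and the test batch zt are assumptions made purely for illustration:

momentum = 0.9                  # assumed decay factor
running_mean = np.zeros(d)      # assumed initialization
running_var = np.ones(d)        # assumed initialization

# update the running statistics with the current batch statistics
running_mean = momentum * running_mean + (1 - momentum) * mu
running_var = momentum * running_var + (1 - momentum) * var

# normalize a hypothetical test batch with the running statistics
np.random.seed(7)
zt = np.random.randint(low=0, high=10, size=(4, d))
zt_hat = (zt - running_mean) / np.sqrt(running_var + eps)
qt = gamma * zt_hat + beta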
Backpropagation
The dq variable below represents \(\frac{\partial J}{\partial q_l}\).
np.random.seed(24)
dq = np.random.randn(m, d)   # a random stand-in for the upstream gradient dJ/dq_l
dq
array([[ 1.32921217, -0.77003345, -0.31628036],
[-0.99081039, -1.07081626, -1.43871328],
[ 0.56441685, 0.29572189, -1.62640423],
[ 0.2195652 , 0.6788048 , 1.88927273],
[ 0.9615384 , 0.1040112 , -0.48116532],
[ 0.85022853, 1.45342467, 1.05773744],
[ 0.16556161, 0.51501838, -1.33693569]])
In the code below:

dgamma represents \(\frac{\partial J}{\partial \gamma}\)
dbeta represents \(\frac{\partial J}{\partial \beta}\)
dzhat represents \(\frac{\partial J}{\partial \hat{z_l}}\)
dvar represents \(\frac{\partial J}{\partial \sigma^2}\)
dmu represents \(\frac{\partial J}{\partial \mu}\)
dz represents \(\frac{\partial J}{\partial z_l}\)
dgamma = np.sum(dq * zhat, axis=0)                       # dJ/dgamma
dbeta = np.sum(dq, axis=0)                               # dJ/dbeta
dzhat = dq * gamma                                       # dJ/dzhat
dvar = np.sum(dzhat * zmu * (-0.5) * (ivar**3), axis=0)  # dJ/dsigma^2
dmu = np.sum(dzhat * (-ivar), axis=0)                    # dJ/dmu
dz = dzhat * ivar + dvar * (2/m) * zmu + (1/m)*dmu       # dJ/dz_l
dgamma
array([ 1.87446152, -3.33807569, 0.75442823])
dbeta
array([ 3.09971237, 1.20613122, -2.25248871])
dz
array([[ 0.42119623, -0.49884504, -0.01690198],
[-0.888674 , -0.27953285, -0.86205837],
[ 0.38788918, 0.41812232, -0.89130965],
[-0.0808407 , 0.24082659, 1.45513635],
[ 0.05651819, -0.17691132, -0.03093201],
[ 0.34007894, 0.38771122, 0.90015001],
[-0.23616783, -0.09137091, -0.55408435]])
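As a final sanity check (not part of the original walkthrough), the analytical gradients above can be compared against numerical gradients of the scalar loss \(\sum \left( q_l \odot \frac{\partial J}{\partial q_l} \right)\), whose gradient with respect to \(q_l\) is exactly dq. The helper names bn_q and num_grad are our own, and the code continues from the variables defined above; all printed differences should be close to zero:

def bn_q(z_in, gamma_in, beta_in):
    # recompute the BN output for the given inputs (same forward pass as above)
    mu_ = np.mean(z_in, axis=0)
    var_ = np.var(z_in, axis=0)
    zhat_ = (z_in - mu_) / np.sqrt(var_ + eps)
    return gamma_in * zhat_ + beta_in

def num_grad(f, x, df, h=1e-5):
    # central-difference gradient of loss = sum(f(x) * df) with respect to x
    grad = np.zeros_like(x, dtype=float)
    it = np.nditer(x, flags=['multi_index'], op_flags=['readwrite'])
    while not it.finished:
        idx = it.multi_index
        old = x[idx]
        x[idx] = old + h
        plus = np.sum(f(x) * df)
        x[idx] = old - h
        minus = np.sum(f(x) * df)
        x[idx] = old
        grad[idx] = (plus - minus) / (2 * h)
        it.iternext()
    return grad

zf = z.astype(float)   # float copy of the input batch so it can be perturbed
print(np.max(np.abs(num_grad(lambda t: bn_q(t, gamma, beta), zf, dq) - dz)))       # should be ~0
print(np.max(np.abs(num_grad(lambda t: bn_q(zf, t, beta), gamma, dq) - dgamma)))   # should be ~0
print(np.max(np.abs(num_grad(lambda t: bn_q(zf, gamma, t), beta, dq) - dbeta)))    # should be ~0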