Batchless Normalization: How to Normalize Activations with just one Instance in Memory
arXiv.org Artificial Intelligence
The basic idea is to look at each activation after a layer and normalize it by scaling and shifting so that its mean and standard deviation across the current batch become 0 and 1, respectively. This uses the batch statistics as an approximation of the population statistics, yielding approximately normalized inputs for the following layer. That said, a batch normalization layer is usually understood to include a denormalization afterwards: the normalized activations are once again transformed affinely so as to have a certain mean and standard deviation, which are learnable parameters of the model. The inputs to the next layer are therefore not normalized, but instead conform approximately to a mean and standard deviation that are independent of whatever the layer before the batch normalization layer produced.

The benefits of batch normalization are manifest empirically, but their theoretical explanation is still under debate. I will say no more about this, since my intention is not to dispute the benefits but to address the shortcomings, of which there are several:

Memory consumption: All instances of the batch must be in memory at the same time in order to compute the batch statistics. This becomes a problem if the data required per instance (the activations as well as the gradients of the loss with respect to the activations) do not fit on the available hardware multiple times. Even if multiple devices are available, this requires either communication between them at each batch normalization layer, or compromising on the accuracy of the batch statistics by computing them separately and independently on each device.
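The normalize-then-denormalize step described above can be sketched as follows. This is a minimal NumPy illustration, not code from the paper; the function name and the choice of a small `eps` for numerical stability are my own, though the latter is standard practice in batch normalization implementations.

```python
import numpy as np

def batch_norm_forward(x, gamma, beta, eps=1e-5):
    """Batch normalization forward pass (sketch).

    x:     activations of shape (batch, features)
    gamma: learnable per-feature scale (the target standard deviation)
    beta:  learnable per-feature shift (the target mean)
    """
    # Batch statistics, computed per feature across the batch dimension.
    mean = x.mean(axis=0)
    var = x.var(axis=0)
    # Normalize: each feature now has mean ~0 and std ~1 across the batch.
    x_hat = (x - mean) / np.sqrt(var + eps)
    # "Denormalize": affine transform with learnable parameters, so the
    # next layer sees statistics independent of the previous layer's output.
    return gamma * x_hat + beta

# Usage: activations with arbitrary mean/std come out normalized
# (here gamma=1, beta=0, i.e. the identity denormalization).
rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=2.0, size=(64, 4))
y = batch_norm_forward(x, gamma=np.ones(4), beta=np.zeros(4))
```

Note that `mean` and `var` depend on the whole batch, which is exactly the source of the memory shortcoming discussed below: every instance's activations must be materialized before the statistics (and hence the output) can be computed.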
Dec-30-2022