Structured Inverse-Free Natural Gradient: Memory-Efficient & Numerically-Stable KFAC for Large Neural Nets

Wu Lin, Felix Dangel, Runa Eschenhagen, Kirill Neklyudov, Agustinus Kristiadi, Richard E. Turner, Alireza Makhzani

arXiv.org Machine Learning 

Second-order methods for deep learning, such as KFAC, can be useful for neural net training. However, they are often memory-inefficient and numerically unstable in low-precision training, since their preconditioning Kronecker factors are dense and require high-precision matrix inversion or decomposition. Consequently, such methods are not widely used for training large neural networks such as transformer-based models. We address these two issues by (i) formulating an inverse-free update of KFAC and (ii) imposing structures on each of the Kronecker factors, resulting in a method we term structured inverse-free natural gradient descent (SINGD). On large modern neural networks, we show that, in contrast to KFAC, SINGD is memory-efficient and numerically robust, and often outperforms AdamW even in half precision. Hence, our work closes a gap between first-order and second-order methods in modern low-precision training of large neural nets.

The continuing success of deep learning (DL) is, to a large extent, powered by scaling up computational power (Thompson et al., 2020) to increase the number of neural network (NN) parameters that can be trained. Contemporary natural language processing (Radford et al., 2019; Brown et al., 2020; Touvron et al., 2023) and computer vision (Dehghani et al., 2023) models often consist of many billions of parameters and will likely grow further in the future. To compensate for the increasing computational demands of training more parameters, many training pipelines use lower-precision data types (Micikevicius et al., 2018) and memory-efficient first-order optimizers such as SGD (Robbins & Monro, 1951) or Adam(W) (Kingma & Ba, 2015; Loshchilov & Hutter, 2019). Second-order methods such as KFAC, by contrast, remain rarely used in this setting; one major obstacle is their higher memory consumption and iteration cost.
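To make the two ingredients above concrete, the following is a minimal NumPy sketch of the general idea only; it is not the authors' SINGD update rule. A Newton-Schulz-style multiplicative iteration stands in for the paper's inverse-free update (it maintains approximate inverse Kronecker factors without any matrix inversion or decomposition), and a crude diagonal restriction stands in for the structured factors. All array names, dimensions, and constants are illustrative.

```python
# Illustrative sketch only -- not the authors' exact SINGD algorithm.
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out = 16, 8

# Kronecker factors of the curvature approximation for one linear layer:
# A (d_in x d_in) built from layer inputs, B (d_out x d_out) from output gradients.
X = rng.standard_normal((128, d_in))      # a batch of layer inputs
Gout = rng.standard_normal((128, d_out))  # a batch of output gradients
A = X.T @ X / 128 + 1e-3 * np.eye(d_in)
B = Gout.T @ Gout / 128 + 1e-3 * np.eye(d_out)

# Dense KFAC-style preconditioning (what we want to avoid): explicit inverses.
G = rng.standard_normal((d_out, d_in))    # weight gradient of the layer
P_dense = np.linalg.inv(B) @ G @ np.linalg.inv(A)

# Inverse-free alternative: maintain K_A ~= A^{-1} and K_B ~= B^{-1}
# with a multiplicative update; no inv/cholesky calls anywhere.
def inverse_free_step(K, M):
    # One Newton-Schulz iteration: K <- K (2I - M K) converges to M^{-1}
    # when K is initialized, e.g., as c*I with c small enough (here 1/trace).
    return K @ (2.0 * np.eye(M.shape[0]) - M @ K)

K_A = np.eye(d_in) / np.trace(A)
K_B = np.eye(d_out) / np.trace(B)
for _ in range(30):
    K_A = inverse_free_step(K_A, A)
    K_B = inverse_free_step(K_B, B)

P_free = K_B @ G @ K_A
print("dense vs inverse-free difference:", np.max(np.abs(P_dense - P_free)))

# Structured factors: restricting K_A, K_B to diagonals cuts per-factor memory
# from O(d^2) to O(d), and the update becomes elementwise. (The paper considers
# richer structures than this diagonal stand-in.)
k_a = 1.0 / np.diag(A)
k_b = 1.0 / np.diag(B)
P_diag = (k_b[:, None] * G) * k_a[None, :]
```

In this toy example the dense and inverse-free preconditioned gradients agree to numerical precision, while the diagonal variant trades some accuracy of the preconditioner for much lower memory per factor.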
