How Does Critical Batch Size Scale in Pre-training?

Zhang, Hanlin, Morwani, Depen, Vyas, Nikhil, Wu, Jingfeng, Zou, Difan, Ghai, Udaya, Foster, Dean, Kakade, Sham

arXiv.org Machine Learning 

Efficient optimization is critical in pre-training large language models (LMs) at scale (McCandlish et al., 2018; Shoeybi et al., 2019; Kaplan et al., 2020). In particular, large-batch training is key to accelerating training, as it enables more efficient parallelism across hardware accelerators (You et al., 2017; Goyal et al., 2018). Understanding the scaling behavior of the critical batch size (CBS) is therefore essential for optimizing data parallelism, as it defines the point beyond which increasing the batch size yields diminishing gains in computational efficiency. Below the CBS, approximately linear scaling is achievable: doubling the batch size roughly halves the number of optimization steps required to reach a target loss. Beyond this threshold, further increases in batch size lead to diminishing returns, making it essential to balance computational efficiency with model performance (McCandlish et al., 2018; Shallue et al., 2019). This trade-off poses a challenge for studying pre-training under resource constraints, as practitioners must make difficult decisions in balancing compute, data, and training time. We investigate the scaling laws governing CBS in the context of autoregressive transformer-based language modeling (Vaswani, 2017; Radford et al., 2018). Analyzing CBS in pre-training is challenging because the literature lacks a precise formalism relating it to model and data sizes (McCandlish et al., 2018; Kaplan et al., 2020).
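
To make the two regimes concrete, the following is a minimal sketch under an assumed idealization in the spirit of McCandlish et al. (2018), in which the number of steps S needed to reach a target loss at batch size B follows S = S_min (1 + B_crit / B). This functional form, the function names, and the constants S_MIN and B_CRIT are illustrative assumptions, not measurements and not necessarily the definition of CBS adopted in this work.

```python
# Illustrative only: an assumed hyperbolic trade-off between batch size and
# optimization steps, S(B) = S_min * (1 + B_crit / B). All constants below are
# hypothetical placeholders chosen to show the two regimes.

def steps_to_target(batch_size: float, s_min: float, b_crit: float) -> float:
    """Steps needed to reach the target loss at a given batch size (assumed model)."""
    return s_min * (1.0 + b_crit / batch_size)


def doubling_speedup(batch_size: float, s_min: float, b_crit: float) -> float:
    """Factor by which the step count shrinks when the batch size is doubled."""
    before = steps_to_target(batch_size, s_min, b_crit)
    after = steps_to_target(2.0 * batch_size, s_min, b_crit)
    return before / after


S_MIN = 1000.0   # hypothetical minimum number of steps (very large batch limit)
B_CRIT = 256.0   # hypothetical critical batch size

for b in (8, 64, 256, 1024, 8192):
    steps = steps_to_target(b, S_MIN, B_CRIT)
    gain = doubling_speedup(b, S_MIN, B_CRIT)
    print(f"B={b:5d}  steps={steps:10.1f}  speedup from doubling B={gain:.3f}")
```

Under this assumed model, the speedup from doubling the batch size approaches 2 when B is far below B_CRIT (near-linear scaling) and approaches 1 when B is far above it (diminishing returns), which mirrors the qualitative behavior described above.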