SAMPa: Sharpness-aware Minimization Parallelized

Neural Information Processing Systems 

Sharpness-aware minimization (SAM) has been shown to improve the generalization of neural networks. However, each SAM update requires sequentially computing two gradients, effectively doubling the per-iteration cost relative to base optimizers such as SGD. We propose SAMPa, a simple modification of SAM that fully parallelizes the two gradient computations. Under the assumption that communication costs between devices are negligible, SAMPa achieves a twofold speedup over SAM. Empirical results show that SAMPa ranks among the most efficient SAM variants in terms of computation time.
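To illustrate why a SAM update costs two gradient evaluations, here is a minimal sketch of the standard SAM step on a toy quadratic loss (the loss, step size `lr`, and perturbation radius `rho` are illustrative choices, not values from the paper). Note the two gradient calls that must run one after the other, since the second depends on the first:

```python
import numpy as np

def loss_grad(w):
    # Toy quadratic loss L(w) = 0.5 * ||w||^2, whose gradient is simply w.
    return w

def sam_step(w, lr=0.1, rho=0.05):
    # Standard SAM requires two *sequential* gradient computations:
    g = loss_grad(w)                              # 1st gradient, at current weights
    eps = rho * g / (np.linalg.norm(g) + 1e-12)   # ascent perturbation within radius rho
    g_perturbed = loss_grad(w + eps)              # 2nd gradient, at perturbed weights
    return w - lr * g_perturbed                   # descend using the perturbed gradient

w = np.array([1.0, -2.0])
for _ in range(100):
    w = sam_step(w)
```

The data dependency between the two calls is exactly what SAMPa removes, so that both gradients can be computed in parallel on separate devices.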