Supplementary material

A Experimental details

Neural Information Processing Systems 

We are using JAX [Bradbury et al., 2018]. All the models except those in Section C.4 have been trained with the Softmax loss. Throughout:

- Batch Norm: we are using JAX's Stax implementation of Batch Norm, which does not keep track of running statistics.
- MNIST: trained on 512 samples of MNIST.
- MaxPool((2, 2), 'VALID'): performs max pooling with 'VALID' padding.
- CIFAR-10: trained without data augmentation.

The WRN experiments are run on v3-8 TPUs and the rest on P100 GPUs. Below we describe the particularities of each figure.