Parallelizing neural networks on one GPU with JAX


Most neural network libraries these days deliver excellent computational performance when training large neural networks. But small networks, which aren't big enough to usefully "fill" a GPU, leave a lot of available compute unused. Running a small network on a GPU is a bit like buying an apartment building and then living in the janitor's closet. In this article, I describe how to get your money's worth: we'll efficiently train dozens of small neural networks in parallel on a single GPU using the vmap function from JAX. Whether you are training ensembles, sweeping over hyperparameters, or averaging across random seeds, this technique can make training 10x-100x faster. If you haven't tried JAX yet, this may give you a reason to.
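
To make the idea concrete before going further, here is a minimal sketch of what "training many networks at once with vmap" can look like. It is an illustration under assumptions, not the exact code from this article: the tiny two-layer network, the toy regression data, the plain SGD update, and every name in it (NUM_NETWORKS, init_params, forward, loss_fn, sgd_step) are hypothetical choices made for the example.

```python
import jax
import jax.numpy as jnp

# Hypothetical sizes chosen for illustration only.
NUM_NETWORKS = 32
IN_DIM, HIDDEN, OUT_DIM = 4, 16, 1

def init_params(key):
    """Initialize ONE small two-layer MLP from a single PRNG key."""
    k1, k2 = jax.random.split(key)
    return {
        "w1": jax.random.normal(k1, (IN_DIM, HIDDEN)) * 0.1,
        "b1": jnp.zeros(HIDDEN),
        "w2": jax.random.normal(k2, (HIDDEN, OUT_DIM)) * 0.1,
        "b2": jnp.zeros(OUT_DIM),
    }

def forward(params, x):
    """Forward pass of one network on a batch of inputs."""
    h = jnp.tanh(x @ params["w1"] + params["b1"])
    return h @ params["w2"] + params["b2"]

def loss_fn(params, x, y):
    """Mean squared error of one network."""
    return jnp.mean((forward(params, x) - y) ** 2)

def sgd_step(params, x, y, lr=0.05):
    """One plain gradient-descent update for one network."""
    grads = jax.grad(loss_fn)(params, x, y)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

# One PRNG key per network, i.e. one random seed each.
keys = jax.random.split(jax.random.PRNGKey(0), NUM_NETWORKS)

# vmap over the keys: every parameter array gains a leading axis of
# size NUM_NETWORKS, so the whole ensemble lives in one pytree.
batched_params = jax.vmap(init_params)(keys)

# Map over the parameters (axis 0) and share the same data (None).
parallel_step = jax.jit(jax.vmap(sgd_step, in_axes=(0, None, None)))
parallel_loss = jax.jit(jax.vmap(loss_fn, in_axes=(0, None, None)))

# Toy regression data (an assumption): the target is the sum of the inputs.
x = jax.random.normal(jax.random.PRNGKey(1), (64, IN_DIM))
y = jnp.sum(x, axis=1, keepdims=True)

initial_losses = parallel_loss(batched_params, x, y)
for _ in range(100):
    batched_params = parallel_step(batched_params, x, y)
final_losses = parallel_loss(batched_params, x, y)

print(final_losses.shape)  # one loss value per network
```

The single-network functions never mention the ensemble; vmap adds the extra axis. Here in_axes=(0, None, None) means every network sees the same batch, which suits a sweep over random seeds. If each network should get its own data, changing the corresponding None to 0 and stacking the data along a leading axis would be one way to do it.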