An astronomer's introduction to NumPyro
Over the past year or so, I've been using JAX extensively for my research, and I've also been encouraging other astronomers to give it a try. In particular, I've been using JAX as the computation engine for probabilistic inference tasks. There's more to it, but one way that I like to think about JAX is as NumPy with just-in-time compilation and automatic differentiation. The just-in-time compilation features of JAX can be used to speed up your NumPy computations by removing some Python overhead and by executing them on your GPU. Then, automatic differentiation can be used to efficiently compute the derivatives of your code with respect to its input parameters.
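To make the "NumPy plus just-in-time compilation plus automatic differentiation" picture concrete, here is a minimal sketch. It is not code from this post. The straight-line model, the function name `log_likelihood`, and the simulated data are all hypothetical choices made for illustration. The sketch shows that a function written with `jax.numpy` reads like ordinary NumPy, that `jax.jit` compiles it, and that `jax.grad` returns its derivative with respect to the parameters.

```python
import jax
import jax.numpy as jnp


def log_likelihood(params, x, y, yerr):
    """Gaussian log-likelihood for a hypothetical straight-line model y = m * x + b."""
    m, b = params
    model = m * x + b
    return -0.5 * jnp.sum(((y - model) / yerr) ** 2)


# jax.jit traces the function for the given input shapes/dtypes and compiles it
# with XLA; later calls with matching shapes reuse the compiled version, which
# runs on whatever device JAX is using (CPU, GPU, or TPU).
log_likelihood_jit = jax.jit(log_likelihood)

# jax.grad differentiates with respect to the first argument (params) by default.
# The result is an ordinary Python function, so it can be jit-compiled too.
grad_log_likelihood = jax.jit(jax.grad(log_likelihood))

# Simulated data, for illustration only (hypothetical values).
x = jnp.linspace(0.0, 10.0, 50)
y = 0.5 * x - 1.0 + 0.1 * jnp.sin(x)
yerr = 0.1 * jnp.ones_like(x)
params = jnp.array([0.4, -0.8])

value = log_likelihood_jit(params, x, y, yerr)
gradient = grad_log_likelihood(params, x, y, yerr)
print(value, gradient)
```

The first call to each jitted function includes compilation time, so the speedup shows up on repeated evaluations, which is exactly the pattern in inference, where a sampler or optimizer evaluates the log-likelihood and its gradient many times.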
Aug-4-2022, 15:52:54 GMT