BlackJAX: Composable Bayesian inference in JAX
Cabezas, Alberto, Corenflos, Adrien, Lao, Junpeng, Louf, Rémi, Carnec, Antoine, Chaudhari, Kaustubh, Cohn-Gordon, Reuben, Coullon, Jeremie, Deng, Wei, Duffield, Sam, Durán-Martín, Gerardo, Elantkowski, Marcin, Foreman-Mackey, Dan, Gregori, Michele, Iguaran, Carlos, Kumar, Ravin, Lysy, Martin, Murphy, Kevin, Orduz, Juan Camilo, Patel, Karm, Wang, Xi, Zinkov, Rob
BlackJAX is a library implementing sampling and variational inference algorithms commonly used in Bayesian computation. It is designed for ease of use, speed, and modularity by taking a functional approach to the algorithms' implementation. BlackJAX is written in Python, using JAX to compile and run NumPy-like samplers and variational methods on CPUs, GPUs, and TPUs. The library integrates well with probabilistic programming languages by working directly with the (un-normalized) target log density function. BlackJAX is intended as a collection of low-level, composable implementations of basic statistical 'atoms' that can be combined to perform well-defined Bayesian inference, but it also provides high-level routines for ease of use. It is designed for users who need cutting-edge methods, researchers who want to create complex sampling methods, and people who want to learn how these methods work.
Feb-22-2024
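
The functional design described in the abstract, in which a sampler is a pure function of a state and a random key and needs only the un-normalized target log density, can be illustrated with a minimal sketch in plain JAX. This is not BlackJAX's API: the names `RWState`, `init`, `build_kernel`, and `run_chain` are hypothetical, and a random-walk Metropolis kernel with a Gaussian proposal stands in for the library's samplers purely for illustration.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class RWState(NamedTuple):
    """Hypothetical sampler state: current position and its log density."""
    position: jax.Array
    logdensity: jax.Array


def init(position, logdensity_fn):
    # Cache the log density so each step evaluates the target only once.
    return RWState(position, logdensity_fn(position))


def build_kernel(logdensity_fn, step_size):
    """Return a pure step function (rng_key, state) -> (new_state, accepted)."""

    def step(rng_key, state):
        key_proposal, key_accept = jax.random.split(rng_key)
        # Symmetric Gaussian random-walk proposal.
        noise = jax.random.normal(key_proposal, state.position.shape)
        proposal = state.position + step_size * noise
        proposal_logdensity = logdensity_fn(proposal)
        # Metropolis rule: only the log density difference is needed,
        # so the normalizing constant of the target cancels.
        log_ratio = proposal_logdensity - state.logdensity
        accepted = jnp.log(jax.random.uniform(key_accept)) < log_ratio
        new_position = jnp.where(accepted, proposal, state.position)
        new_logdensity = jnp.where(accepted, proposal_logdensity, state.logdensity)
        return RWState(new_position, new_logdensity), accepted

    return step


def run_chain(rng_key, step, initial_state, num_steps):
    """Compose the step function into a chain with jax.lax.scan."""

    def one_step(state, key):
        state, accepted = step(key, state)
        return state, (state.position, accepted)

    keys = jax.random.split(rng_key, num_steps)
    _, (positions, accepted) = jax.lax.scan(one_step, initial_state, keys)
    return positions, accepted


def logdensity_fn(x):
    # Un-normalized log density of a standard normal (illustrative target).
    return -0.5 * jnp.sum(x**2)


initial_state = init(jnp.zeros(2), logdensity_fn)
step = jax.jit(build_kernel(logdensity_fn, step_size=0.5))
positions, accepted = run_chain(jax.random.PRNGKey(0), step, initial_state, 2000)
```

Because the step function is pure and carries all of its state explicitly, it can be wrapped with `jax.jit`, looped with `jax.lax.scan`, or mapped over several chains with `jax.vmap` without changing the kernel itself. This is the kind of composability the abstract refers to, shown here under the assumptions stated above.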