Jankowiak, Martin
Tensor Variable Elimination for Plated Factor Graphs
Obermeyer, Fritz, Bingham, Eli, Jankowiak, Martin, Chiu, Justin, Pradhan, Neeraj, Rush, Alexander, Goodman, Noah
A wide class of machine learning algorithms can be reduced to variable elimination on factor graphs. While factor graphs provide a unifying notation for these algorithms, they do not provide a compact way to express repeated structure when compared to plate diagrams for directed graphical models. To exploit efficient tensor algebra in graphs with plates of variables, we generalize undirected factor graphs to plated factor graphs and variable elimination to a tensor variable elimination algorithm that operates directly on plated factor graphs. Moreover, we generalize complexity bounds based on treewidth and characterize the class of plated factor graphs for which inference is tractable. As an application, we integrate tensor variable elimination into the Pyro probabilistic programming language to enable exact inference in discrete latent variable models with repeated structure. We validate our methods with experiments on both directed and undirected graphical models, including applications to polyphonic music modeling, animal movement modeling, and latent sentiment analysis.
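As a rough illustration of the variable elimination that the abstract builds on, the sketch below computes the partition function of a small chain-structured factor graph by summing out one variable at a time, then checks the result against brute-force enumeration. It is ordinary sum-product elimination in plain Python, not the paper's tensor variable elimination algorithm and not Pyro's implementation; the function names and the toy factor values are assumptions made up for this example.

```python
from itertools import product

def eliminate_chain(unary, pairwise):
    """Sum-product variable elimination on a chain x_0 - x_1 - ... - x_{T-1}.

    unary[t][k] is the factor value for x_t = k;
    pairwise[t][i][j] is the factor value for (x_t = i, x_{t+1} = j).
    """
    # message[k]: sum over all eliminated variables, with the current variable fixed to k
    message = list(unary[0])
    for t in range(1, len(unary)):
        message = [
            unary[t][j]
            * sum(message[i] * pairwise[t - 1][i][j] for i in range(len(message)))
            for j in range(len(unary[t]))
        ]
    return sum(message)

def brute_force(unary, pairwise):
    """Sum the product of all factors over every joint assignment (exponential cost)."""
    total = 0.0
    for xs in product(*(range(len(u)) for u in unary)):
        weight = 1.0
        for t, x in enumerate(xs):
            weight *= unary[t][x]
        for t in range(len(xs) - 1):
            weight *= pairwise[t][xs[t]][xs[t + 1]]
        total += weight
    return total

# Hypothetical toy factors: three binary variables in a chain.
unary = [[0.6, 0.4], [0.5, 0.5], [0.9, 0.1]]
pairwise = [[[0.7, 0.3], [0.2, 0.8]], [[0.6, 0.4], [0.1, 0.9]]]

z_fast = eliminate_chain(unary, pairwise)
z_slow = brute_force(unary, pairwise)
print(z_fast, z_slow)
```

The elimination pass costs time linear in the chain length, whereas the brute-force sum grows exponentially; the paper's contribution concerns extending this kind of computation to graphs with plates.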
Closed Form Variational Objectives For Bayesian Neural Networks with a Single Hidden Layer
Jankowiak, Martin
In this note we consider setups in which variational objectives for Bayesian neural networks can be computed in closed form. In particular, we focus on single-layer networks in which the activation function is piecewise polynomial (e.g. ReLU). In this case we show that for a Normal likelihood and structured Normal variational distributions one can compute a variational lower bound in closed form. In addition, we compute the predictive mean and variance in closed form. Finally, we show how to compute approximate lower bounds for other likelihoods (e.g. softmax classification). In experiments we show how the resulting variational objectives can help improve training and provide fast test-time predictions.
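To give a flavor of why piecewise polynomial activations admit closed-form expectations, the sketch below evaluates the standard identity E[ReLU(x)] = mu * Phi(mu / sigma) + sigma * phi(mu / sigma) for x ~ Normal(mu, sigma^2) and compares it with a Monte Carlo estimate. This is a single building block under that assumption, not the note's full variational objective; the function names and the chosen parameter values are hypothetical.

```python
import math
import random

def expected_relu(mu, sigma):
    """Closed-form E[max(0, x)] for x ~ Normal(mu, sigma^2)."""
    z = mu / sigma
    cdf = 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))          # standard Normal CDF at z
    pdf = math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)   # standard Normal density at z
    return mu * cdf + sigma * pdf

def monte_carlo_relu(mu, sigma, n=200_000, seed=0):
    """Sampling-based estimate of the same expectation, for comparison only."""
    rng = random.Random(seed)
    return sum(max(0.0, rng.gauss(mu, sigma)) for _ in range(n)) / n

closed = expected_relu(0.5, 1.5)
sampled = monte_carlo_relu(0.5, 1.5)
print(closed, sampled)
```

The closed form is exact and needs no samples, which is the kind of property that makes sampling-free objectives and fast test-time predictions possible in this setting.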
Pyro: Deep Universal Probabilistic Programming
Bingham, Eli, Chen, Jonathan P., Jankowiak, Martin, Obermeyer, Fritz, Pradhan, Neeraj, Karaletsos, Theofanis, Singh, Rohit, Szerlip, Paul, Horsfall, Paul, Goodman, Noah D.
Pyro is a probabilistic programming language built on Python as a platform for developing advanced probabilistic models in AI research. To scale to large datasets and high-dimensional models, Pyro uses stochastic variational inference algorithms and probability distributions built on top of PyTorch, a modern GPU-accelerated deep learning framework. To accommodate complex or model-specific algorithmic behavior, Pyro leverages Poutine, a library of composable building blocks for modifying the behavior of probabilistic programs.
Pathwise Derivatives Beyond the Reparameterization Trick
Jankowiak, Martin, Obermeyer, Fritz
We observe that gradients computed via the reparameterization trick are in direct correspondence with solutions of the transport equation in the formalism of optimal transport. We use this perspective to compute (approximate) pathwise gradients for probability distributions not directly amenable to the reparameterization trick: Gamma, Beta, and Dirichlet. We further observe that when the reparameterization trick is applied to the Cholesky-factorized multivariate Normal distribution, the resulting gradients are suboptimal in the sense of optimal transport. We derive the optimal gradients and show that they have reduced variance in a Gaussian Process regression task. We demonstrate with a variety of synthetic experiments and stochastic variational inference tasks that our pathwise gradients are competitive with other methods.
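For readers unfamiliar with the reparameterization trick that this abstract takes as its starting point, here is a minimal sketch for a univariate Normal: writing x = mu + sigma * eps with eps ~ Normal(0, 1), the gradient of E[f(x)] is estimated by averaging f'(x) * dx/dmu and f'(x) * dx/dsigma over samples. The test function f(x) = x^2 is an assumption chosen because the exact gradients (2 * mu and 2 * sigma) are known; the sketch does not cover the Gamma, Beta, Dirichlet, or optimal-transport gradients derived in the paper.

```python
import random

def pathwise_gradients(mu, sigma, n=100_000, seed=0):
    """Reparameterized gradient estimates of E[x^2] for x ~ Normal(mu, sigma^2)."""
    rng = random.Random(seed)
    grad_mu = 0.0
    grad_sigma = 0.0
    for _ in range(n):
        eps = rng.gauss(0.0, 1.0)
        x = mu + sigma * eps          # reparameterized sample
        df_dx = 2.0 * x               # derivative of f(x) = x^2
        grad_mu += df_dx * 1.0        # dx/dmu = 1
        grad_sigma += df_dx * eps     # dx/dsigma = eps
    return grad_mu / n, grad_sigma / n

g_mu, g_sigma = pathwise_gradients(1.0, 0.5)
print(g_mu, g_sigma)  # exact values are 2 * mu = 2.0 and 2 * sigma = 1.0
```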
Pathwise Derivatives for Multivariate Distributions
Jankowiak, Martin, Karaletsos, Theofanis
We exploit the link between the transport equation and derivatives of expectations to construct efficient pathwise gradient estimators for multivariate distributions. We focus on two main threads. First, we use null solutions of the transport equation to construct adaptive control variates that yield gradient estimators with reduced variance. Second, we consider the case of multivariate mixture distributions. In particular, we show how to compute pathwise derivatives for mixtures of multivariate Normal distributions with arbitrary means and diagonal covariances. We demonstrate in a variety of experiments in the context of variational inference that our gradient estimators can outperform other methods, especially in high dimensions.