stad
StAD: Stein Amortized Divergence for Fast Likelihoods with Diffusion and Flow
Jagwani, Gurjeet, Thorp, Stephen, Deger, Sinan, Peiris, Hiranya
Diffusion and flow-based models are ubiquitous in generative modelling and density estimation. They admit a deterministic probability flow ordinary differential equation (PF-ODE), analogous to continuous normalizing flows (CNFs), which describes the transport of probability mass. Obtaining the likelihood from these models is of interest to many workflows, especially Bayesian analysis, and requires the divergence of the learned PF-ODE, i.e. the trace of its Jacobian, which costs $\mathcal{O}(D^2)$ to compute exactly or $\mathcal{O}(D)$ with a noisy estimate. We introduce StAD, a new distillation method that learns to predict the divergence of the PF-ODE using the Langevin-Stein operator without ever computing the Jacobian. We show that our method is competitive with the Hutchinson and Hutch++ estimators on CIFAR-10, ImageNet and other density estimation tasks, consistently improving the variance and speed of the likelihood predictions compared to the Hutchinson estimator. We additionally show that our method generalizes to a varied class of generative models, and that under some regularity conditions these learned vector fields can be made to belong to the Stein class.
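For context, the two baseline ways of computing a divergence that the abstract contrasts (exact trace of the Jacobian versus a noisy Hutchinson estimate) can be sketched as follows. This is a minimal illustration and not the StAD method itself. The linear toy field `f(x) = A x`, the finite-difference Jacobian-vector product `jvp_fd` (a stand-in for automatic differentiation), and all function names are assumptions made for the example.

```python
import numpy as np

def jvp_fd(f, x, v, eps=1e-5):
    """Central finite-difference approximation of the Jacobian-vector product J_f(x) v.
    Assumption: stands in for an autodiff JVP, which a real model would use."""
    return (f(x + eps * v) - f(x - eps * v)) / (2.0 * eps)

def divergence_exact(f, x):
    """Trace of the Jacobian, using one JVP per coordinate (D JVP calls in total)."""
    D = x.shape[0]
    basis = np.eye(D)
    return float(sum(jvp_fd(f, x, basis[i])[i] for i in range(D)))

def divergence_hutchinson(f, x, n_probes, rng):
    """Hutchinson estimate: average of v^T J v over random Rademacher probes v.
    Each probe needs a single JVP, but the result is noisy."""
    D = x.shape[0]
    estimates = []
    for _ in range(n_probes):
        v = rng.choice([-1.0, 1.0], size=D)
        estimates.append(v @ jvp_fd(f, x, v))
    return float(np.mean(estimates))

# Hypothetical toy vector field f(x) = A x, whose divergence is trace(A).
rng = np.random.default_rng(0)
D = 8
A = rng.normal(size=(D, D))
f = lambda x: A @ x
x0 = rng.normal(size=D)

exact = divergence_exact(f, x0)
approx = divergence_hutchinson(f, x0, n_probes=4000, rng=rng)
print(f"trace(A)={np.trace(A):.4f}  exact={exact:.4f}  hutchinson={approx:.4f}")
```

With few probes the Hutchinson estimate fluctuates noticeably around the exact value, which is the variance that the abstract reports improving on.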
Test-Time Adaptation with State-Space Models
Schirmer, Mona, Zhang, Dan, Nalisnick, Eric
Distribution shifts between training and test data are all but inevitable over the lifecycle of a deployed model and lead to performance decay. Adapting the model can hopefully mitigate this drop in performance. Yet, adaptation is challenging since it must be unsupervised: we usually do not have access to any labeled data at test time. In this paper, we propose a probabilistic state-space model that can adapt a deployed model subjected to distribution drift. Our model learns the dynamics induced by distribution shifts on the last set of hidden features. Without requiring labels, we infer time-evolving class prototypes that serve as a dynamic classification head. Moreover, our approach is lightweight, modifying only the model's last linear layer. In experiments on real-world distribution shifts and synthetic corruptions, we demonstrate that our approach performs competitively with methods that require back-propagation and access to the model backbone.
STAD: Spatio-Temporal Adjustment of Traffic-Oblivious Travel-Time Estimation
Abbar, Sofiane, Stanojevic, Rade, Mokbel, Mohamed
Travel time estimation is an important component in modern transportation applications. State-of-the-art techniques for travel time estimation use GPS traces to learn the weights of a road network, often modeled as a directed graph, then apply Dijkstra-like algorithms to find shortest paths. Travel time is then computed as the sum of edge weights on the returned path. To enable time-dependency, existing systems compute multiple weighted graphs corresponding to different time windows. These graphs are often optimized offline before they are deployed into production routing engines, causing a serious engineering overhead. In this paper, we present STAD, a system that adjusts, on the fly, travel time estimates for any trip request expressed in the form of origin, destination, and departure time. STAD uses machine learning and sparse trip data to learn the imperfections of any basic routing engine, then turns it into a full-fledged time-dependent system capable of adjusting travel times to real traffic conditions in a city. STAD leverages the spatio-temporal properties of traffic by combining spatial features, such as departure and destination geographic zones, with temporal features, such as departure time and day, to significantly improve the travel time estimates of the basic routing engine. Experiments on real trip datasets from Doha, New York City, and Porto show a reduction in median absolute error of 14% in the first two cities and 29% in Porto. We also show that STAD performs better than different commercial and research baselines in all three cities.
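The baseline pipeline described above (one weighted directed graph per time window, a Dijkstra search, and travel time as the sum of edge weights on the returned path) can be sketched as follows. This illustrates only the traffic-oblivious baseline, not STAD's learned adjustment. The four-node network, its weights in minutes, and the window names `off_peak` and `rush_hour` are hypothetical.

```python
import heapq

def travel_time(graph, origin, destination):
    """Dijkstra's algorithm on a directed graph given as {node: [(neighbor, minutes), ...]}.
    Returns (total_minutes, path), or (inf, []) if the destination is unreachable."""
    best = {origin: 0.0}
    prev = {}
    heap = [(0.0, origin)]
    while heap:
        t, node = heapq.heappop(heap)
        if node == destination:
            break
        if t > best.get(node, float("inf")):
            continue  # stale queue entry, a shorter route to this node was already found
        for nxt, w in graph.get(node, []):
            cand = t + w
            if cand < best.get(nxt, float("inf")):
                best[nxt] = cand
                prev[nxt] = node
                heapq.heappush(heap, (cand, nxt))
    if destination not in best:
        return float("inf"), []
    # Walk predecessor links back from the destination to recover the path.
    path = [destination]
    while path[-1] != origin:
        path.append(prev[path[-1]])
    return best[destination], path[::-1]

# Hypothetical road network with one weighted graph per time window.
graphs_by_window = {
    "off_peak": {"A": [("B", 4.0), ("C", 2.0)], "B": [("D", 5.0)],
                 "C": [("B", 1.0), ("D", 8.0)], "D": []},
    "rush_hour": {"A": [("B", 6.0), ("C", 5.0)], "B": [("D", 7.0)],
                  "C": [("B", 4.0), ("D", 9.0)], "D": []},
}

for window, graph in graphs_by_window.items():
    minutes, path = travel_time(graph, "A", "D")
    print(window, minutes, path)
# off_peak 8.0 ['A', 'C', 'B', 'D']
# rush_hour 13.0 ['A', 'B', 'D']
```

Each additional time window requires another full set of edge weights, which is the offline maintenance overhead the abstract refers to.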