Nakkiran, Preetum
Composition and Control with Distilled Energy Diffusion Models and Sequential Monte Carlo
Thornton, James, Bethune, Louis, Zhang, Ruixiang, Bradley, Arwen, Nakkiran, Preetum, Zhai, Shuangfei
Diffusion models may be formulated as a time-indexed sequence of energy-based models, where the score corresponds to the negative gradient of an energy function. As opposed to learning the score directly, an energy parameterization is attractive as the energy itself can be used to control generation via Monte Carlo samplers. Architectural constraints and training instability in energy parameterized models have so far yielded inferior performance compared to directly approximating the score or denoiser. We address these deficiencies by introducing a novel training regime for the energy function through distillation of pre-trained diffusion models, resembling a Helmholtz decomposition of the score vector field. We further showcase the synergies between energy and score by casting the diffusion sampling procedure as a Feynman-Kac model where sampling is controlled using potentials from the learnt energy functions. The Feynman-Kac model formalism enables composition and low temperature sampling through sequential Monte Carlo.
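The two ingredients named above, a score obtained as the negative gradient of an energy and a sequential Monte Carlo (SMC) step driven by energy-based potentials, can be illustrated on a toy problem. The sketch below is not the paper's algorithm: it assumes a known quadratic energy E(x) = x^2/2 (a standard Gaussian) in place of a distilled diffusion energy, and performs a single reweight-and-resample step that moves samples of p toward the low-temperature target p^beta. The function names are hypothetical.

```python
import numpy as np


def energy(x):
    # Hypothetical toy energy E(x) = x^2 / 2, so exp(-E) is a standard Gaussian.
    return 0.5 * x**2


def score(x, eps=1e-5):
    # Score as the negative gradient of the energy (central finite difference).
    return -(energy(x + eps) - energy(x - eps)) / (2 * eps)


def smc_reweight_resample(particles, log_potentials, rng):
    # One SMC step: weights proportional to exp(log_potential), normalized,
    # then multinomial resampling so high-potential particles are duplicated.
    w = np.exp(log_potentials - log_potentials.max())
    w /= w.sum()
    idx = rng.choice(len(particles), size=len(particles), p=w)
    return particles[idx]


def low_temperature_resample(particles, beta, rng):
    # Target p(x)^beta ∝ exp(-beta * E(x)) starting from samples of p ∝ exp(-E):
    # the incremental log-potential is -(beta - 1) * E(x).
    return smc_reweight_resample(particles, -(beta - 1.0) * energy(particles), rng)


rng = np.random.default_rng(0)
samples = rng.normal(size=20_000)  # draws from exp(-E), i.e. N(0, 1)
sharpened = low_temperature_resample(samples, beta=2.0, rng=rng)
print(np.var(samples), np.var(sharpened))  # roughly 1.0 and roughly 0.5
```

With beta = 2 the target is N(0, 1/2), so the sample variance should roughly halve. In a diffusion sampler the same reweighting would be applied at every denoising step with time-dependent energies, which is where a learnt energy (rather than only a score) becomes necessary.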
Mechanisms of Projective Composition of Diffusion Models
Bradley, Arwen, Nakkiran, Preetum, Berthelot, David, Thornton, James, Susskind, Joshua M.
We study the theoretical foundations of composition in diffusion models, with a particular focus on out-of-distribution extrapolation and length-generalization. Prior work has shown that composing distributions via linear score combination can achieve promising results, including length-generalization in some cases (Du et al., 2023; Liu et al., 2022). However, our theoretical understanding of how and why such compositions work remains incomplete. In fact, it is not even entirely clear what it means for composition to "work". This paper starts to address these fundamental gaps. We begin by precisely defining one possible desired result of composition, which we call projective composition. Then, we investigate: (1) when linear score combinations provably achieve projective composition, (2) whether reverse-diffusion sampling can generate the desired composition, and (3) the conditions under which composition fails. Finally, we connect our theoretical analysis to prior empirical observations where composition has either worked or failed, for reasons that were unclear at the time.
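As a minimal illustration of linear score combination (the simplest case, a product of two densities, not the paper's more general projective composition), the sketch below sums the scores of two Gaussians and checks that the result equals the score of their normalized product. This identity holds for the clean distributions; for noised distributions at intermediate diffusion times, the sum of the noised scores is in general not the score of the noised product, which is one reason question (2) above is nontrivial. The helper names are hypothetical.

```python
import numpy as np


def gaussian_score(mu, sigma):
    # Score of N(mu, sigma^2): d/dx log p(x) = -(x - mu) / sigma^2.
    return lambda x: -(x - mu) / sigma**2


def linear_score_combination(scores, weights):
    # Linear score combination s(x) = sum_i w_i * s_i(x).
    return lambda x: sum(w * s(x) for s, w in zip(scores, weights))


s_a = gaussian_score(0.0, 1.0)
s_b = gaussian_score(2.0, 1.0)
s_product = linear_score_combination([s_a, s_b], [1.0, 1.0])

# N(0, 1) * N(2, 1) is proportional to N(1, 1/2), whose score is -(x - 1) / 0.5.
s_target = gaussian_score(1.0, np.sqrt(0.5))
xs = np.linspace(-3.0, 5.0, 9)
print(np.allclose(s_product(xs), s_target(xs)))  # True
```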
Normalizing Flows are Capable Generative Models
Zhai, Shuangfei, Zhang, Ruixiang, Nakkiran, Preetum, Berthelot, David, Gu, Jiatao, Zheng, Huangjie, Chen, Tianrong, Bautista, Miguel Angel, Jaitly, Navdeep, Susskind, Josh
Normalizing Flows (NFs) are likelihood-based models for continuous inputs. They have demonstrated promising results on both density estimation and generative modeling tasks, but have received relatively little attention in recent years. In this work, we demonstrate that NFs are more powerful than previously believed. We present TarFlow: a simple and scalable architecture that enables highly performant NF models. TarFlow can be thought of as a Transformer-based variant of Masked Autoregressive Flows (MAFs): it consists of a stack of autoregressive Transformer blocks on image patches, alternating the autoregression direction between layers. TarFlow is straightforward to train end-to-end, and capable of directly modeling and generating pixels. We also propose three key techniques to improve sample quality: Gaussian noise augmentation during training, a post-training denoising procedure, and an effective guidance method for both class-conditional and unconditional settings. Putting these together, TarFlow sets new state-of-the-art results on likelihood estimation for images, beating the previous best methods by a large margin, and generates samples with quality and diversity comparable to diffusion models, for the first time with a stand-alone NF model. We make our code available at https://github.com/apple/ml-tarflow.
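The MAF-style structure described above (affine autoregressive layers with the autoregression direction alternating between layers) can be sketched in a few lines. This is an assumption-laden toy, not the TarFlow implementation at the linked repository: the learned causal Transformer is replaced by a hypothetical fixed `conditioner` acting on a 1D sequence instead of image patches, and noise augmentation, denoising and guidance are omitted.

```python
import numpy as np


def conditioner(prefix):
    # Hypothetical stand-in for a causal Transformer: maps the prefix x_{<i}
    # to a shift mu_i and a log-scale alpha_i via a fixed function of its sum.
    s = float(prefix.sum())
    return 0.5 * s, 0.1 * np.tanh(s)


def maf_forward(x):
    # Data -> noise: z_i = (x_i - mu_i) * exp(-alpha_i).
    z = np.empty_like(x)
    logdet = 0.0
    for i in range(len(x)):
        mu, alpha = conditioner(x[:i])
        z[i] = (x[i] - mu) * np.exp(-alpha)
        logdet -= alpha  # triangular Jacobian with diagonal exp(-alpha_i)
    return z, logdet


def maf_inverse(z):
    # Noise -> data is sequential: x_i needs the already-generated x_{<i}.
    x = np.empty_like(z)
    for i in range(len(z)):
        mu, alpha = conditioner(x[:i])
        x[i] = z[i] * np.exp(alpha) + mu
    return x


def flow_forward(x, num_layers=4):
    # Stack layers, reversing the sequence after each one to alternate direction.
    total_logdet = 0.0
    for _ in range(num_layers):
        x, logdet = maf_forward(x)
        total_logdet += logdet
        x = x[::-1]
    return x, total_logdet


def flow_inverse(z, num_layers=4):
    # Undo the layers in reverse order: first undo the flip, then invert the layer.
    for _ in range(num_layers):
        z = maf_inverse(z[::-1])
    return z


def log_likelihood(x, num_layers=4):
    # Change of variables: log p(x) = log N(z; 0, I) + log|det dz/dx|.
    z, logdet = flow_forward(x, num_layers)
    return -0.5 * np.sum(z**2) - 0.5 * len(z) * np.log(2 * np.pi) + logdet
```

In an MAF, the data-to-noise direction (used for likelihood training) can be computed for all positions at once with a causal network, while sampling must proceed position by position; the loops here are sequential in both directions only for readability.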
A Formal Framework for Understanding Length Generalization in Transformers
Huang, Xinting, Yang, Andy, Bhattamishra, Satwik, Sarrof, Yash, Krebs, Andreas, Zhou, Hattie, Nakkiran, Preetum, Hahn, Michael
A major challenge for transformers is generalizing to sequences longer than those observed during training. While previous works have empirically shown that transformers can either succeed or fail at length generalization depending on the task, theoretical understanding of this phenomenon remains limited. In this work, we introduce a rigorous theoretical framework to analyze length generalization in causal transformers with learnable absolute positional encodings. In particular, we characterize those functions that are identifiable in the limit from sufficiently long inputs with absolute positional encodings under an idealized inference scheme using a norm-based regularizer. This enables us to prove the possibility of length generalization for a rich family of problems. We experimentally validate the theory as a predictor of success and failure of length generalization across a range of algorithmic and formal language tasks. Our theory not only explains a broad set of empirical observations but also opens the way to provably predicting length generalization capabilities in transformers.
How JEPA Avoids Noisy Features: The Implicit Bias of Deep Linear Self Distillation Networks
Littwin, Etai, Saremi, Omid, Advani, Madhu, Thilak, Vimal, Nakkiran, Preetum, Huang, Chen, Susskind, Joshua
Two competing paradigms exist for self-supervised learning of data representations. Joint Embedding Predictive Architecture (JEPA) is a class of architectures in which semantically similar inputs are encoded into representations that are predictive of each other. A recent successful approach that falls under the JEPA framework is self-distillation, where an online encoder is trained to predict the output of the target encoder, sometimes using a lightweight predictor network. This is contrasted with the Masked AutoEncoder (MAE) paradigm, where an encoder and decoder are trained to reconstruct missing parts of the input in the data space rather than in its latent representation. A common motivation for using the JEPA approach over MAE is that the JEPA objective prioritizes abstract features over fine-grained pixel information (which can be unpredictable and uninformative). In this work, we seek to understand the mechanism behind this empirical observation by analyzing the training dynamics of deep linear models. We uncover a surprising mechanism: in a simplified linear setting where both approaches learn similar representations, JEPAs are biased to learn high-influence features, i.e., features characterized by having high regression coefficients. Our results point to a distinct implicit bias of predicting in latent space that may shed light on its success in practice.
Step-by-Step Diffusion: An Elementary Tutorial
Nakkiran, Preetum, Bradley, Arwen, Zhou, Hattie, Advani, Madhu
There are many existing resources for learning diffusion models. Why did we write another? Our goal was to teach diffusion as simply as possible, with minimal mathematical and machine learning prerequisites, but in enough detail to reason about its correctness. Unlike most tutorials on this subject, we take neither a Variational Autoencoder (VAE) nor a Stochastic Differential Equation (SDE) approach. In fact, for the core ideas we will not need any SDEs, Evidence Lower Bounds (ELBOs), Langevin dynamics, or even the notion of a score.
When is Multicalibration Post-Processing Necessary?
Hansen, Dutch, Devic, Siddartha, Nakkiran, Preetum, Sharan, Vatsal
A popular approach to ensuring that probabilistic predictions from machine learning algorithms are meaningful is model calibration. Intuitively, calibration requires that amongst all samples given score p ∈ [0, 1] by an ML algorithm, exactly a p-fraction of those samples have positive label. Calibration ensures that a predictor has an accurate estimate of its own predictive uncertainty, and is a fundamental requirement in applications where probabilities may be taken into account for high-stakes decisions such as disease diagnosis (Dahabreh et al., 2017) or credit/lending decisions (Bequé et al., 2017). Miscalibration can result in undesirable downstream consequences when probabilistic predictions are thresholded into decisions: if a predictor has high calibration error in disease diagnosis, for example, the individuals assigned lower predicted probabilities may be unfairly denied treatment. Calibration has a long history in the machine learning community (Guo et al., 2017; Minderer et al., 2021; Niculescu-Mizil and Caruana, 2005; Platt et al., 1999), but was arguably first introduced in fairness contexts by Cleary (1968). More recently, it has appeared in the algorithmic fairness community via the seminal works of Chouldechova (2017); Kleinberg et al. (2017). Although calibration ensures meaningful uncertainty estimates aggregated over the entire population, it does not preclude potential discrimination at the level of groups of individuals: a model may be well calibrated overall but systematically underestimate the risk or qualification probability on historically underrepresented subsets of individuals. For example, Obermeyer et al. (2019) show differing calibration error rates across groups defined by race for prediction in high-risk patient care management systems. As pointed out by Obermeyer et al. (2019), in the
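The calibration condition above, and the observation that a model can be calibrated overall yet miscalibrated on groups, can be checked on a toy example. The sketch below uses a common equal-width binned estimator of calibration error (an assumed choice, not necessarily the estimator used in the paper) and applies it globally and within each group. The data and function names are hypothetical.

```python
import numpy as np


def binned_calibration_error(probs, labels, num_bins=10):
    # Bin predictions into equal-width bins on [0, 1]; in each bin compare the
    # mean predicted probability with the fraction of positive labels,
    # weighting each bin by its share of samples.
    probs = np.asarray(probs, dtype=float)
    labels = np.asarray(labels, dtype=float)
    bins = np.minimum((probs * num_bins).astype(int), num_bins - 1)
    error = 0.0
    for b in range(num_bins):
        mask = bins == b
        if mask.any():
            error += mask.mean() * abs(probs[mask].mean() - labels[mask].mean())
    return error


def groupwise_calibration_error(probs, labels, groups):
    # The same calibration error, computed separately inside each group.
    probs, labels, groups = map(np.asarray, (probs, labels, groups))
    return {g: binned_calibration_error(probs[groups == g], labels[groups == g])
            for g in np.unique(groups)}


# Hypothetical data: group "A" has a 10% positive rate, group "B" has 50%,
# and the predictor outputs 0.3 for everyone.
groups = np.array(["A"] * 10 + ["B"] * 10)
labels = np.array([1] + [0] * 9 + [1] * 5 + [0] * 5)
probs = np.full(20, 0.3)

print(binned_calibration_error(probs, labels))             # ~0.0: calibrated overall
print(groupwise_calibration_error(probs, labels, groups))  # ~0.2 in each group
```

Multicalibration strengthens the condition by requiring small calibration error simultaneously on every group in a specified collection, which this constant predictor fails.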
Perspectives on the State and Future of Deep Learning - 2023
Goldblum, Micah, Anandkumar, Anima, Baraniuk, Richard, Goldstein, Tom, Cho, Kyunghyun, Lipton, Zachary C, Mitchell, Melanie, Nakkiran, Preetum, Welling, Max, Wilson, Andrew Gordon
The goal of this series is to chronicle opinions and issues in the field of machine learning as they stand today and as they change over time. The plan is to host this survey periodically until the AI singularity paperclip-frenzy-driven doomsday, keeping an updated list of topical questions and interviewing new community members for each edition.
When Does Optimizing a Proper Loss Yield Calibration?
Błasiok, Jarosław, Gopalan, Parikshit, Hu, Lunjia, Nakkiran, Preetum
Optimizing proper loss functions is popularly believed to yield predictors with good calibration properties; the intuition being that for such losses, the global optimum is to predict the ground-truth probabilities, which is indeed calibrated. However, typical machine learning models are trained to approximately minimize loss over restricted families of predictors that are unlikely to contain the ground truth. Under what circumstances does optimizing proper loss over a restricted family yield calibrated models? What precise calibration guarantees does it give? In this work, we provide a rigorous answer to these questions. We replace the global optimality with a local optimality condition stipulating that the (proper) loss of the predictor cannot be reduced much by post-processing its predictions with a certain family of Lipschitz functions. We show that any predictor with this local optimality satisfies smooth calibration as defined in Kakade and Foster (2008), Błasiok et al. (2023). Local optimality is plausibly satisfied by well-trained DNNs, which suggests an explanation for why they are calibrated from proper loss minimization alone. Finally, we show that the connection between local optimality and calibration error goes both ways: nearly calibrated predictors are also nearly locally optimal.
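For squared loss, the link between smooth calibration and local optimality has a one-line reason: post-processing f into f + η·w(f) changes the loss E[(y - f - η·w(f))^2] with derivative -2·E[w(f)·(y - f)] at η = 0, so a large correlation between a bounded Lipschitz w(f) and the residual means a small post-processing step lowers the loss. The sketch below assumes smooth calibration error is sup over 1-Lipschitz w: [0, 1] → [-1, 1] of E[w(f)·(y - f)], and searches only a small hypothetical family of such w, so it returns a lower bound rather than the exact value; this additive post-processing is an illustrative choice, not necessarily the family used in the paper.

```python
import numpy as np


def smooth_calibration_lower_bound(preds, labels, centers=np.linspace(0.0, 1.0, 11)):
    # Lower bound on sup_w E[w(f) * (y - f)] using only the 1-Lipschitz,
    # [-1, 1]-valued candidates w = +/-1 and w(v) = +/-(v - c), c in [0, 1].
    preds = np.asarray(preds, dtype=float)
    residual = np.asarray(labels, dtype=float) - preds
    candidates = [np.ones_like(preds)] + [preds - c for c in centers]
    # abs() accounts for both signs of each candidate w
    return max(abs(np.mean(w * residual)) for w in candidates)


def squared_loss_after_postprocessing(preds, labels, w, eta):
    # Squared loss of the post-processed predictor f + eta * w(f).
    preds = np.asarray(preds, dtype=float)
    return np.mean((np.asarray(labels, dtype=float) - preds - eta * w(preds)) ** 2)


biased_preds = np.full(10, 0.5)
all_positive = np.ones(10)
print(smooth_calibration_lower_bound(biased_preds, all_positive))  # 0.5
```

For the biased predictor above, shifting every prediction up by η = 0.1 reduces the squared loss from 0.25 to 0.16, consistent with it being neither calibrated nor locally optimal.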
Loss Minimization Yields Multicalibration for Large Neural Networks
Błasiok, Jarosław, Gopalan, Parikshit, Hu, Lunjia, Kalai, Adam Tauman, Nakkiran, Preetum
Multicalibration is a notion of fairness for predictors that requires them to provide calibrated predictions across a large set of protected groups. Multicalibration is known to be a distinct goal from loss minimization, even for simple predictors such as linear functions. In this work, we consider the setting where the protected groups can be represented by neural networks of size $k$, and the predictors are neural networks of size $n > k$. We show that minimizing the squared loss over all neural nets of size $n$ implies multicalibration for all but a bounded number of unlucky values of $n$. We also give evidence that our bound on the number of unlucky values is tight, given our proof technique. Previously, results of the flavor that loss minimization yields multicalibration were known only for predictors that were near the ground truth, hence were rather limited in applicability. Unlike these, our results rely on the expressivity of neural nets and utilize the representation of the predictor.
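A much weaker, linear analogue of the intuition that expressive loss minimization absorbs group-level errors: if a group indicator lies in the span of the predictor family, a squared-loss minimizer must have residuals orthogonal to it, since otherwise adding a multiple of the indicator would lower the loss. The toy below (hypothetical data, least squares instead of neural nets) shows only this group-level mean consistency, not full multicalibration, which also conditions on the predicted value, and it is not the paper's size-$n$ versus size-$k$ argument.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 1_000
x = rng.normal(size=n)
group = (rng.random(n) < 0.3).astype(float)  # hypothetical protected-group indicator
y = (rng.random(n) < 1.0 / (1.0 + np.exp(-(x + group)))).astype(float)

# Squared-loss minimization over linear predictors on [1, x, group]; the group
# indicator lies in the span of the predictor family.
features = np.column_stack([np.ones(n), x, group])
coef, *_ = np.linalg.lstsq(features, y, rcond=None)
residual = y - features @ coef

# Normal equations: features.T @ residual = 0, so the residuals average to zero
# both inside and outside the group.
print(residual[group == 1].mean(), residual[group == 0].mean())  # both ~0 up to rounding
```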