
Collaborating Authors

 Raghu, Maithra


Do Vision Transformers See Like Convolutional Neural Networks?

arXiv.org Artificial Intelligence

Convolutional neural networks (CNNs) have so far been the de-facto model for visual data. Recent work has shown that (Vision) Transformer models (ViT) can achieve comparable or even superior performance on image classification tasks. This raises a central question: how are Vision Transformers solving these tasks? Are they acting like convolutional networks, or learning entirely different visual representations? Analyzing the internal representation structure of ViTs and CNNs on image classification benchmarks, we find striking differences between the two architectures, such as ViT having more uniform representations across all layers. We explore how these differences arise, finding crucial roles played by self-attention, which enables early aggregation of global information, and ViT residual connections, which strongly propagate features from lower to higher layers. We study the ramifications for spatial localization, demonstrating ViTs successfully preserve input spatial information, with noticeable effects from different classification methods. Finally, we study the effect of (pretraining) dataset scale on intermediate features and transfer learning, and conclude with a discussion on connections to new architectures such as the MLP-Mixer.


Pointer Value Retrieval: A new benchmark for understanding the limits of neural network generalization

arXiv.org Artificial Intelligence

The successes of deep learning critically rely on the ability of neural networks to output meaningful predictions on unseen data -- generalization. Yet despite its criticality, there remain fundamental open questions on how neural networks generalize. How much do neural networks rely on memorization -- seeing highly similar training examples -- and how much are they capable of human-intelligence-style reasoning -- identifying abstract rules underlying the data? In this paper we introduce a novel benchmark, Pointer Value Retrieval (PVR) tasks, that explore the limits of neural network generalization. While PVR tasks can consist of visual as well as symbolic inputs, each with varying levels of difficulty, they all have a simple underlying rule. One part of the PVR task input acts as a pointer, giving the location of a different part of the input, which forms the value (and output). We demonstrate that this task structure provides a rich testbed for understanding generalization, with our empirical study showing large variations in neural network performance based on dataset size, task complexity and model architecture. The interaction of position, values and the pointer rule also allows the development of nuanced tests of generalization, by introducing distribution shift and increasing functional complexity. These reveal both subtle failures and surprising successes, suggesting many promising directions of exploration on this benchmark.
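The pointer rule described above can be made concrete with a small generator for a symbolic PVR-style task. This is a minimal sketch under assumptions of our own, not the benchmark's released code: the first digit of the sequence is taken to be the pointer, it indexes (0-based) into the ten digits that follow, and the label is the digit it points to. The helper names `pvr_label` and `make_pvr_example` are hypothetical.

```python
import random


def pvr_label(seq):
    """Label of a symbolic PVR-style example: seq[0] is the pointer and the
    label is the value it points to among the digits that follow (0-based)."""
    pointer, values = seq[0], seq[1:]
    return values[pointer]


def make_pvr_example(rng, num_values=10):
    """Draw one random example. The pointer is limited to 0..num_values-1
    so that it always lands on a value position (an assumed design choice)."""
    pointer = rng.randrange(num_values)
    values = [rng.randrange(10) for _ in range(num_values)]
    seq = [pointer] + values
    return seq, pvr_label(seq)


if __name__ == "__main__":
    rng = random.Random(0)
    for _ in range(3):
        seq, label = make_pvr_example(rng)
        print(seq, "->", label)
```

For instance, `pvr_label([2, 5, 1, 7, 0, 3, 8, 8, 4, 6, 9])` returns 7: the pointer 2 selects the third of the remaining digits. A network that has learned only surface statistics of the digits, rather than the pointer rule, has no reliable way to produce this label.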


Anatomy of Catastrophic Forgetting: Hidden Representations and Task Semantics

arXiv.org Machine Learning

A central challenge in developing versatile machine learning systems is catastrophic forgetting: a model trained on tasks in sequence will suffer significant performance drops on earlier tasks. Despite the ubiquity of catastrophic forgetting, there is limited understanding of the underlying process and its causes. In this paper, we address this important knowledge gap, investigating how forgetting affects representations in neural network models. Through representational analysis techniques, we find that deeper layers are disproportionately the source of forgetting. Supporting this, a study of methods to mitigate forgetting illustrates that they act to stabilize deeper layers. These insights enable the development of an analytic argument and empirical picture relating the degree of forgetting to representational similarity between tasks. Consistent with this picture, we observe maximal forgetting occurs for task sequences with intermediate similarity. We perform empirical studies on the standard split CIFAR-10 setup and also introduce a novel CIFAR-100 based task approximating realistic input distribution shift.


Rapid Learning or Feature Reuse? Towards Understanding the Effectiveness of MAML

arXiv.org Machine Learning

An important research direction in machine learning has centered around developing meta-learning algorithms to tackle few-shot learning. An especially successful algorithm has been Model Agnostic Meta-Learning (MAML), a method that consists of two optimization loops, with the outer loop finding a meta-initialization, from which the inner loop can efficiently learn new tasks. Despite MAML's popularity, a fundamental open question remains -- is the effectiveness of MAML due to the meta-initialization being primed for rapid learning (large, efficient changes in the representations) or due to feature reuse, with the meta-initialization already containing high quality features? We investigate this question, via ablation studies and analysis of the latent representations, finding that feature reuse is the dominant factor. This leads to the ANIL (Almost No Inner Loop) algorithm, a simplification of MAML where we remove the inner loop for all but the (task-specific) head of a MAML-trained network. ANIL matches MAML's performance on benchmark few-shot image classification and RL and offers computational improvements over MAML. We further study the precise contributions of the head and body of the network, showing that performance on the test tasks is entirely determined by the quality of the learned features, and we can remove even the head of the network (the NIL algorithm). We conclude with a discussion of the rapid learning versus feature reuse question for meta-learning algorithms more broadly.
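The difference between MAML's inner loop and ANIL's can be illustrated on a toy regression problem. The sketch below, written in plain Python under our own assumptions, shows only an ANIL-style inner loop: a few gradient steps on a task's support set that update the head while reusing the body unchanged. The ReLU "body", the linear "head", the squared-error loss and all function names are hypothetical stand-ins, and the outer (meta-training) loop is omitted.

```python
def features(x, body):
    """Hypothetical fixed 'body': one ReLU unit per (weight, bias) pair."""
    return [max(0.0, w * x + b) for w, b in body]


def predict(head, feats):
    """Linear 'head' on top of the body's features."""
    return sum(h * f for h, f in zip(head, feats))


def mse(body, head, data):
    """Mean squared error of the full network on (x, y) pairs."""
    return sum((predict(head, features(x, body)) - y) ** 2 for x, y in data) / len(data)


def anil_inner_loop(body, head, support, lr=0.1, steps=5):
    """ANIL-style adaptation: gradient descent on the support set's mean
    squared error, updating only the head. The body is reused as-is; a full
    MAML inner loop would update the body's parameters here as well."""
    head = list(head)  # copy, so the meta-learned head is left untouched
    for _ in range(steps):
        grad = [0.0] * len(head)
        for x, y in support:
            feats = features(x, body)
            err = predict(head, feats) - y
            for i, f in enumerate(feats):
                grad[i] += 2 * err * f / len(support)
        head = [h - lr * g for h, g in zip(head, grad)]
    return head


if __name__ == "__main__":
    body = [(1.0, 0.0), (-1.0, 0.0), (1.0, -1.0)]   # meta-learned body (assumed given)
    head = [0.0, 0.0, 0.0]                          # meta-learned head initialization
    support = [(0.5, 1.0), (1.0, 2.0), (1.5, 3.0)]  # a new task: y = 2x
    adapted = anil_inner_loop(body, head, support)
    print(mse(body, head, support), "->", mse(body, adapted, support))
```

Because only the head's few parameters change per task, the inner loop is cheap. In this toy setting, how well a task can be fit depends entirely on whether the fixed features are adequate, which mirrors the feature-reuse explanation above.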


The Algorithmic Automation Problem: Prediction, Triage, and Human Effort

arXiv.org Artificial Intelligence

On a variety of high-stakes tasks, machine learning algorithms are on the threshold of doing what human experts do with such high fidelity that we are contemplating using their predictions as a substitute for human output. For example, convolutional neural networks are close to outperforming radiologists at diagnosing pneumonia from chest X-rays [14, 15]; examples like these underpin much of the widespread discussion of algorithmic automation in these tasks. In assessing the potential for algorithms, however, the community has implicitly equated the specific task of prediction with the general task of automation. We argue here that this implicit correspondence misses key aspects of the automation problem; a broader conceptualization of automation can lead directly to concrete benefits in some of the key application areas where this process is unfolding. We start from the premise that automation is more than just the replacement of human effort on a task; it is also the meta-decision of which instances of the task to automate. And it is here that algorithms distinguish themselves from earlier technology used for automation, because they can actively take part in this decision of what to automate. But as currently constructed, they are not set up to help with this second part of the problem. The automation problem, then, should involve an algorithm that on any given instance both (i) produces a prediction output; and (ii) also produces a triage judgment of its effectiveness relative to the human effort it would replace on that instance.
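The two outputs named in (i) and (ii) can be pictured as a per-instance record plus a triage rule that decides which instances to automate. The sketch below is illustrative only and assumes things the text does not specify: that per-instance error estimates for both the algorithm and the human are available, that the benefit of automating is their difference, and that at most `budget` instances may be automated. The `Instance` type and `triage` function are hypothetical.

```python
from dataclasses import dataclass


@dataclass
class Instance:
    name: str
    prediction: int     # (i) the algorithm's prediction output
    algo_error: float   # estimated algorithm error on this instance (assumed given)
    human_error: float  # estimated human error on this instance (assumed given)


def triage(instances, budget):
    """(ii) Triage judgment: return the names of at most `budget` instances to
    automate, preferring those where the algorithm is expected to beat the
    human by the largest margin, and never one where it is expected to do worse."""
    gains = [(inst.human_error - inst.algo_error, inst.name) for inst in instances]
    gains = [g for g in gains if g[0] > 0]
    gains.sort(reverse=True)
    return [name for _, name in gains[:budget]]


if __name__ == "__main__":
    cases = [
        Instance("a", 1, algo_error=0.10, human_error=0.30),
        Instance("b", 0, algo_error=0.20, human_error=0.25),
        Instance("c", 1, algo_error=0.40, human_error=0.20),
    ]
    print(triage(cases, budget=2))
```

Instances that are not selected stay with the human, so a model with a slightly worse average accuracy than the human can still improve the combined system by automating only the instances it handles well.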


Transfusion: Understanding Transfer Learning with Applications to Medical Imaging

arXiv.org Machine Learning

With the increasingly varied applications of deep learning, transfer learning has emerged as a critically important technique. However, the central question of how much feature reuse in transfer is the source of benefit remains unanswered. In this paper, we present an in-depth analysis of the effects of transfer, focusing on medical imaging, which is a particularly intriguing setting. Here, transfer learning is extremely popular, but data differences between pretraining and finetuning are considerable, which sharpens the question of what is actually transferred. With experiments on two large scale medical imaging datasets, and CIFAR-10, we find transfer has almost negligible effects on performance, but significantly helps convergence speed. However, in all of these settings, convergence without transfer can be sped up dramatically by using only mean and variance statistics of the pretrained weights. Visualizing the lower layer filters shows that models trained from random initialization do not learn Gabor filters on medical images. We use CCA (canonical correlation analysis) to study the learned representations of the different models, finding that pretrained models are surprisingly similar to random initialization at higher layers. This similarity is evidenced both through model learning dynamics and a transfusion experiment, which explores the convergence speed using a subset of pretrained weights.
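An initialization that uses "only mean and variance statistics of the pretrained weights" can be sketched as follows. We assume, for illustration, that each layer's new weights are drawn i.i.d. from a Gaussian whose mean and standard deviation match that layer's pretrained weights, and that weights are stored as flat lists keyed by a layer name; `mean_var_init` and the layer names are hypothetical.

```python
import random
import statistics


def mean_var_init(pretrained, seed=0):
    """pretrained: dict mapping layer name -> flat list of weights.
    Returns freshly sampled weights of the same sizes. Each layer keeps only
    the mean and (population) standard deviation of its pretrained weights;
    the individual pretrained values, and hence the learned features, are discarded."""
    rng = random.Random(seed)
    new = {}
    for name, weights in pretrained.items():
        mu = statistics.fmean(weights)
        sigma = statistics.pstdev(weights)
        new[name] = [rng.gauss(mu, sigma) for _ in weights]
    return new


if __name__ == "__main__":
    src = random.Random(42)
    pretrained = {"conv1": [src.gauss(0.5, 2.0) for _ in range(10000)]}
    fresh = mean_var_init(pretrained)
    print(statistics.fmean(fresh["conv1"]), statistics.pstdev(fresh["conv1"]))
```

The comparison this enables is between a network that inherits per-layer weight scales from pretraining and one that inherits the weights themselves.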


Insights on representational similarity in neural networks with canonical correlation

Neural Information Processing Systems

Comparing different neural network representations and determining how representations evolve over time remain challenging open questions in our understanding of the function of neural networks. Comparing representations in neural networks is fundamentally difficult as the structure of representations varies greatly, even across groups of networks trained on identical tasks, and over the course of training. Here, we develop projection weighted CCA (Canonical Correlation Analysis) as a tool for understanding neural networks, building off of SVCCA, a recently proposed method (Raghu et al., 2017). We first improve the core method, showing how to differentiate between signal and noise, and then apply this technique to compare across a group of CNNs, demonstrating that networks which generalize converge to more similar representations than networks which memorize, that wider networks converge to more similar solutions than narrow networks, and that trained networks with identical topology but different learning rates converge to distinct clusters with diverse representations. We also investigate the representational dynamics of RNNs, across both training and sequential timesteps, finding that RNNs converge in a bottom-up pattern over the course of training and that the hidden state is highly variable over the course of a sequence, even when accounting for linear transforms. Together, these results provide new insights into the function of CNNs and RNNs, and demonstrate the utility of using CCA to understand representations.
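A projection-weighted CCA comparison can be sketched in a few lines of NumPy. This is a minimal sketch under our own assumptions, not the authors' released implementation: activations are given as matrices with one row per input and one column per neuron, there are more inputs than neurons, and the columns are linearly independent. Canonical correlations are obtained as the singular values of `Qx.T @ Qy` for orthonormal bases `Qx`, `Qy` of the centered activations. Each canonical direction is then weighted by how strongly it projects onto the neurons of `X`, so that directions accounting for little of the representation (more likely noise) contribute less. The name `pwcca_distance` is hypothetical.

```python
import numpy as np


def pwcca_distance(X, Y):
    """X: (n, p1), Y: (n, p2) activations of two layers on the same n inputs
    (rows = inputs, columns = neurons). Assumes n > p1, p2 and full column rank.
    Returns 1 - sum_i alpha_i * rho_i (0 means the representations match up to
    an invertible linear transform)."""
    X = X - X.mean(axis=0)
    Y = Y - Y.mean(axis=0)
    # Orthonormal bases for the column spaces of the centered activations.
    Qx, _ = np.linalg.qr(X)
    Qy, _ = np.linalg.qr(Y)
    # Singular values of Qx^T Qy are the canonical correlations rho_i.
    U, rho, _ = np.linalg.svd(Qx.T @ Qy, full_matrices=False)
    # Canonical variates in X's space: unit-norm columns h_i of H, shape (n, k).
    H = Qx @ U
    # Projection weights: alpha_i proportional to sum_j |<h_i, x_j>| over neurons x_j.
    alpha = np.abs(H.T @ X).sum(axis=1)
    alpha = alpha / alpha.sum()
    return 1.0 - float(np.sum(alpha * rho))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    X = rng.standard_normal((500, 8))
    A = rng.standard_normal((8, 8)) + 8 * np.eye(8)  # a well-conditioned linear map
    print(pwcca_distance(X, X @ A))                         # close to 0
    print(pwcca_distance(X, rng.standard_normal((500, 8))))  # much larger
```

The first comparison illustrates why CCA-based measures suit this kind of study: a layer and an invertible linear transform of it are treated as the same representation, whereas unrelated activations are scored as distant.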


Direct Uncertainty Prediction with Applications to Healthcare

arXiv.org Machine Learning

Large labeled datasets for supervised learning are frequently constructed by assigning each instance to multiple human evaluators, and this leads to disagreement in the labels associated with a single instance. Here we consider the question of predicting the level of disagreement for a given instance, and we find an interesting phenomenon: direct prediction of uncertainty performs better than the two-step process of training a classifier and then using the classifier outputs to derive an uncertainty. We show stronger performance for predicting disagreement via this direct method both in a synthetic setting whose parameters we can fully control, and in a paradigmatic healthcare application involving multiple labels assigned by medical domain experts. We further show implications for allocating additional labeling effort toward instances with the greatest levels of predicted disagreement.
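The two routes to an uncertainty estimate differ in what the model is trained on. In the direct approach, a model is fit to a per-instance disagreement target computed from the evaluators' labels. In the two-step approach, a classifier is fit to the labels and an uncertainty is derived from its outputs afterwards. The sketch below shows only these target and post-processing computations, not model training. The specific measures (fraction of evaluators disagreeing with the majority, and one minus the top predicted probability) are our illustrative assumptions and may differ from those used in the paper. All function names are hypothetical.

```python
from collections import Counter


def disagreement(labels):
    """Direct-prediction target: fraction of evaluators whose label differs
    from the majority label (0.0 means the evaluators were unanimous)."""
    majority_count = Counter(labels).most_common(1)[0][1]
    return 1 - majority_count / len(labels)


def two_step_uncertainty(class_probs):
    """Two-step alternative: an uncertainty derived after the fact from a
    trained classifier's predicted class probabilities (one minus the top one)."""
    return 1 - max(class_probs)


def pick_for_relabeling(predicted_disagreement, k):
    """Indices of the k instances with the highest predicted disagreement,
    i.e. candidates for additional labeling effort."""
    order = sorted(range(len(predicted_disagreement)),
                   key=lambda i: predicted_disagreement[i], reverse=True)
    return order[:k]


if __name__ == "__main__":
    print(disagreement([2, 2, 3, 2]))             # three of four evaluators agree
    print(two_step_uncertainty([0.7, 0.2, 0.1]))
    print(pick_for_relabeling([0.1, 0.5, 0.0, 0.3], k=2))
```

Under the direct approach, a regressor would be trained to map each input to `disagreement(labels)`. Its predictions could then be passed to `pick_for_relabeling` to decide where extra labels are most useful.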


Insights on representational similarity in neural networks with canonical correlation

arXiv.org Artificial Intelligence

Comparing different neural network representations and determining how representations evolve over time remain challenging open questions in our understanding of the function of neural networks. Comparing representations in neural networks is fundamentally difficult as the structure of representations varies greatly, even across groups of networks trained on identical tasks, and over the course of training. Here, we develop projection weighted CCA (Canonical Correlation Analysis) as a tool for understanding neural networks, building off of SVCCA, a recently proposed method. We first improve the core method, showing how to differentiate between signal and noise, and then apply this technique to compare across a group of CNNs, demonstrating that networks which generalize converge to more similar representations than networks which memorize, that wider networks converge to more similar solutions than narrow networks, and that trained networks with identical topology but different learning rates converge to distinct clusters with diverse representations. We also investigate the representational dynamics of RNNs, across both training and sequential timesteps, finding that RNNs converge in a bottom-up pattern over the course of training and that the hidden state is highly variable over the course of a sequence, even when accounting for linear transforms. Together, these results provide new insights into the function of CNNs and RNNs, and demonstrate the utility of using CCA to understand representations.