Saxe, Andrew
Revisiting the Role of Relearning in Semantic Dementia
Jarvis, Devon, Klar, Verena, Klein, Richard, Rosman, Benjamin, Saxe, Andrew
Patients with semantic dementia (SD) present with remarkably consistent atrophy of neurons in the anterior temporal lobe and behavioural impairments, such as graded loss of category knowledge. While relearning of lost knowledge has been shown in acute brain injuries such as stroke, it has not been widely supported in chronic cognitive diseases such as SD. Previous research has shown that deep linear artificial neural networks exhibit stages of semantic learning akin to humans. Here, we use a deep linear network to test the hypothesis that relearning during disease progression, rather than particular atrophy, causes the specific behavioural patterns associated with SD. After training the network to generate the common semantic features of various hierarchically organised objects, neurons are successively deleted to mimic atrophy while retraining the model. The model with relearning and deleted neurons reproduced errors specific to SD, including prototyping errors and cross-category confusions. This suggests that relearning is necessary for artificial neural networks to reproduce the behavioural patterns associated with SD in the absence of output non-linearities. Our results support a theory of SD progression that results from continuous relearning of lost information. Future research should revisit the role of relearning as a contributing factor to cognitive diseases.
Training Dynamics of In-Context Learning in Linear Attention
Zhang, Yedi, Singh, Aaditya K., Latham, Peter E., Saxe, Andrew
While attention-based models have demonstrated the remarkable ability of in-context learning, the theoretical understanding of how these models acquired this ability through gradient descent training is still preliminary. Towards answering this question, we study the gradient descent dynamics of multi-head linear self-attention trained for in-context linear regression. We examine two parametrizations of linear self-attention: one with the key and query weights merged as a single matrix (common in theoretical studies), and one with separate key and query matrices (closer to practical settings). For the merged parametrization, we show the training dynamics has two fixed points and the loss trajectory exhibits a single, abrupt drop. We derive an analytical time-course solution for a certain class of datasets and initialization. For the separate parametrization, we show the training dynamics has exponentially many fixed points and the loss exhibits saddle-to-saddle dynamics, which we reduce to scalar ordinary differential equations. During training, the model implements principal component regression in context with the number of principal components increasing over training time. Overall, we characterize how in-context learning abilities evolve during gradient descent training of linear attention, revealing dynamics of abrupt acquisition versus progressive improvements in models with different parametrizations.
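The merged key-query setting can be illustrated with a reduced form that is common in theoretical treatments and is assumed here rather than taken from the abstract: a single linear-attention head predicts the query label as beta^T Gamma x_q, where beta = (1/N) sum_n y_n x_n summarises the context and Gamma stands in for the merged key-query matrix. The function name `linear_attention_predict` and the toy data are hypothetical.

```python
def linear_attention_predict(context_x, context_y, query_x, gamma):
    """Reduced single-head linear attention (assumed form): y_hat = beta^T Gamma x_q."""
    n = len(context_x)
    d = len(query_x)
    # beta = (1/N) * sum_n y_n x_n, the input-output correlation of the context
    beta = [sum(y * x[j] for x, y in zip(context_x, context_y)) / n for j in range(d)]
    # Gamma x_q, with Gamma playing the role of the merged key-query matrix
    gamma_xq = [sum(gamma[i][j] * query_x[j] for j in range(d)) for i in range(d)]
    return sum(b * g for b, g in zip(beta, gamma_xq))

# Toy in-context regression task: y = w . x with the standard basis as context.
w = [2.0, -1.0, 0.5]
context_x = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
context_y = [sum(wi * xi for wi, xi in zip(w, x)) for x in context_x]

# For this context, Gamma = N * I makes the head reproduce w . x_q.
n = len(context_x)
gamma = [[float(n) if i == j else 0.0 for j in range(3)] for i in range(3)]
query_x = [1.0, 1.0, 1.0]
prediction = linear_attention_predict(context_x, context_y, query_x, gamma)
target = sum(wi * xi for wi, xi in zip(w, query_x))
print(prediction, target)
```

In this sketch Gamma is set by hand; in the setting the abstract describes, the corresponding weights would be learned by gradient descent.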
When Are Bias-Free ReLU Networks Like Linear Networks?
Zhang, Yedi, Saxe, Andrew, Latham, Peter E.
We investigate the expressivity and learning dynamics of bias-free ReLU networks. We first show that two-layer bias-free ReLU networks have limited expressivity: the only odd function two-layer bias-free ReLU networks can express is a linear one. We then show that, under symmetry conditions on the data, these networks have the same learning dynamics as linear networks. This allows us to give closed-form time-course solutions to certain two-layer bias-free ReLU networks, which has not been done for nonlinear networks outside the lazy learning regime. While deep bias-free ReLU networks are more expressive than their two-layer counterparts, they still share a number of similarities with deep linear networks. These similarities enable us to leverage insights from linear networks, leading to a novel understanding of bias-free ReLU networks. Overall, our results show that some properties established for bias-free ReLU networks arise due to equivalence to linear networks, and suggest that including bias or considering asymmetric data are avenues to engage with nonlinear behaviors.
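The expressivity claim can be checked numerically. For a two-layer bias-free ReLU network f(x) = sum_i a_i ReLU(w_i . x), the identity ReLU(z) - ReLU(-z) = z implies that the odd part (f(x) - f(-x))/2 equals the linear map (1/2) sum_i a_i (w_i . x). The sketch below uses hypothetical function names and randomly drawn weights chosen only for illustration.

```python
import random

def relu(z):
    return z if z > 0.0 else 0.0

def dot(u, v):
    return sum(ui * vi for ui, vi in zip(u, v))

def network(x, hidden_weights, readout):
    """Two-layer bias-free ReLU network: f(x) = sum_i a_i * relu(w_i . x)."""
    return sum(a * relu(dot(w, x)) for w, a in zip(hidden_weights, readout))

def odd_part(x, hidden_weights, readout):
    """Odd part of the network: (f(x) - f(-x)) / 2."""
    neg_x = [-xi for xi in x]
    return 0.5 * (network(x, hidden_weights, readout)
                  - network(neg_x, hidden_weights, readout))

def linear_part(x, hidden_weights, readout):
    """Linear map predicted by the identity relu(z) - relu(-z) = z."""
    return 0.5 * sum(a * dot(w, x) for w, a in zip(hidden_weights, readout))

rng = random.Random(0)  # fixed seed so the run is reproducible
dim, width = 3, 5
hidden_weights = [[rng.gauss(0, 1) for _ in range(dim)] for _ in range(width)]
readout = [rng.gauss(0, 1) for _ in range(width)]
x = [rng.gauss(0, 1) for _ in range(dim)]

gap = abs(odd_part(x, hidden_weights, readout) - linear_part(x, hidden_weights, readout))
print(gap)  # zero up to floating-point rounding
```

Since any odd function equals its own odd part, an odd function expressed by such a network must coincide with this linear map.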
Get rich quick: exact solutions reveal how unbalanced initializations promote rapid feature learning
Kunin, Daniel, Raventós, Allan, Dominé, Clémentine, Chen, Feng, Klindt, David, Saxe, Andrew, Ganguli, Surya
While the impressive performance of modern neural networks is often attributed to their capacity to efficiently extract task-relevant features from data, the mechanisms underlying this rich feature learning regime remain elusive, with much of our theoretical understanding stemming from the opposing lazy regime. In this work, we derive exact solutions to a minimal model that transitions between lazy and rich learning, precisely elucidating how unbalanced layer-specific initialization variances and learning rates determine the degree of feature learning. Our analysis reveals that they conspire to influence the learning regime through a set of conserved quantities that constrain and modify the geometry of learning trajectories in parameter and function space. We extend our analysis to more complex linear models with multiple neurons, outputs, and layers and to shallow nonlinear networks with piecewise linear activation functions. In linear networks, rapid feature learning occurs only with balanced initializations, where all layers learn at similar speeds, while in nonlinear networks, unbalanced initializations that promote faster learning in earlier layers can accelerate rich learning. Through a series of experiments, we provide evidence that this unbalanced rich regime drives feature learning in deep finite-width networks, promotes interpretability of early layers in CNNs, reduces the sample complexity of learning hierarchical data, and decreases the time to grokking in modular arithmetic. Our theory motivates further exploration of unbalanced initializations to enhance efficient feature learning.
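A small numerical check illustrates what a conserved quantity looks like in the simplest case. For a scalar two-layer linear model f(x) = a * w * x trained on L = 0.5 * (a * w - target)^2, gradient flow with equal learning rates keeps a^2 - w^2 constant, and small-step gradient descent keeps it nearly constant. The sketch assumes equal learning rates for both layers, a simplification of the layer-specific rates the abstract considers; the name `train` and the numerical settings are illustrative choices.

```python
def train(a, w, target=1.0, lr=1e-3, steps=20000):
    """Gradient descent on L = 0.5 * (a * w - target)**2 for a scalar two-layer linear model."""
    for _ in range(steps):
        err = a * w - target
        # Simultaneous update of both layers with the same learning rate (assumption).
        a, w = a - lr * err * w, w - lr * err * a
    return a, w

a0, w0 = 2.0, 0.1            # unbalanced initialization: the second layer starts larger
a1, w1 = train(a0, w0)

delta_before = a0 ** 2 - w0 ** 2
delta_after = a1 ** 2 - w1 ** 2
print(delta_before, delta_after, a1 * w1)
```

Each discrete step rescales a^2 - w^2 by the factor 1 - lr^2 * err^2, so the drift vanishes as the learning rate goes to zero and the trajectory stays close to the hyperbola fixed by the initialization.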
Tilting the Odds at the Lottery: the Interplay of Overparameterisation and Curricula in Neural Networks
Mannelli, Stefano Sarao, Ivashinka, Yaraslau, Saxe, Andrew, Saglietti, Luca
A wide range of empirical and theoretical works have shown that overparameterisation can amplify the performance of neural networks. According to the lottery ticket hypothesis, overparameterised networks have an increased chance of containing a sub-network that is well-initialised to solve the task at hand. A more parsimonious approach, inspired by animal learning, consists in guiding the learner towards solving the task by curating the order of the examples, i.e. providing a curriculum. However, this learning strategy seems to be hardly beneficial in deep learning applications. In this work, we undertake an analytical study that connects curriculum learning and overparameterisation. In particular, we investigate their interplay in the online learning setting for a 2-layer network in the XOR-like Gaussian Mixture problem. Our results show that a high degree of overparameterisation -- while simplifying the problem -- can limit the benefit from curricula, providing a theoretical account of the ineffectiveness of curricula in deep learning.
A Theory of Unimodal Bias in Multimodal Learning
Zhang, Yedi, Latham, Peter E., Saxe, Andrew
Using multiple input streams simultaneously in training multimodal neural networks is intuitively advantageous, but practically challenging. A key challenge is unimodal bias, where a network overly relies on one modality and ignores others during joint training. While unimodal bias is well-documented empirically, our theoretical understanding of how architecture and data statistics influence this bias remains incomplete. Here we develop a theory of unimodal bias with deep multimodal linear networks. We calculate the duration of the unimodal phase in learning as a function of the depth at which modalities are fused within the network, dataset statistics, and initialization. We find that the deeper the layer at which fusion occurs, the longer the unimodal phase. A long unimodal phase can lead to a generalization deficit and permanent unimodal bias in the overparametrized regime. In addition, our theory reveals the modality learned first is not necessarily the modality that contributes more to the output. Our results, derived for multimodal linear networks, extend to ReLU networks in certain settings. Taken together, this work illuminates pathologies of multimodal learning under joint training, showing that late and intermediate fusion architectures can give rise to long unimodal phases and permanent unimodal bias.
The RL Perceptron: Generalisation Dynamics of Policy Learning in High Dimensions
Patel, Nishil, Lee, Sebastian, Mannelli, Stefano Sarao, Goldt, Sebastian, Saxe, Andrew
Reinforcement learning (RL) algorithms have proven transformative in a range of domains. To tackle real-world domains, these systems often use neural networks to learn policies directly from pixels or other high-dimensional sensory input. By contrast, much theory of RL has focused on discrete state spaces or worst-case analysis, and fundamental questions remain about the dynamics of policy learning in high-dimensional settings. Here, we propose a solvable high-dimensional model of RL that can capture a variety of learning protocols, and derive its typical dynamics as a set of closed-form ordinary differential equations (ODEs). We derive optimal schedules for the learning rates and task difficulty - analogous to annealing schemes and curricula during training in RL - and show that the model exhibits rich behaviour, including delayed learning under sparse rewards; a variety of learning regimes depending on reward baselines; and a speed-accuracy trade-off driven by reward stringency. Experiments on variants of the Procgen game "Bossfight" and Arcade Learning Environment game "Pong" also show such a speed-accuracy trade-off in practice. Together, these results take a step towards closing the gap between theory and practice in high-dimensional RL.
Know your audience: specializing grounded language models with listener subtraction
Singh, Aaditya K., Ding, David, Saxe, Andrew, Hill, Felix, Lampinen, Andrew K.
Effective communication requires adapting to the idiosyncrasies of each communicative context--such as the common ground shared with each partner. Humans demonstrate this ability to specialize to their audience in many contexts, such as the popular game Dixit. We take inspiration from Dixit to formulate a multi-agent image reference game where a (trained) speaker model is rewarded for describing a target image such that one (pretrained) listener model can correctly identify it among distractors, but another listener cannot. To adapt, the speaker must exploit differences in the knowledge it shares with the different listeners. We show that finetuning an attention-based adapter between a CLIP vision encoder and a large language model in this contrastive, multi-agent setting gives rise to context-dependent natural language specialization from rewards only, without direct supervision. Through controlled experiments, we show that training a speaker with two listeners that perceive differently, using our method, allows the speaker to adapt to the idiosyncrasies of the listeners. Furthermore, we show zero-shot transfer of the specialization to real-world data. Our experiments demonstrate a method for specializing grounded language models without direct supervision and highlight the interesting research challenges posed by complex multi-agent communication.
Probing transfer learning with a model of synthetic correlated datasets
Gerace, Federica, Saglietti, Luca, Mannelli, Stefano Sarao, Saxe, Andrew, Zdeborová, Lenka
Transfer learning can significantly improve the sample efficiency of neural networks, by exploiting the relatedness between a data-scarce target task and a data-abundant source task. Despite years of successful applications, transfer learning practice often relies on ad-hoc solutions, while theoretical understanding of these procedures is still limited. In the present work, we re-think a solvable model of synthetic data as a framework for modeling correlation between data-sets. This setup allows for an analytic characterization of the generalization performance obtained when transferring the learned feature map from the source to the target task. Focusing on the problem of training two-layer networks in a binary classification setting, we show that our model can capture a range of salient features of transfer learning with real data. Moreover, by exploiting parametric control over the correlation between the two data-sets, we systematically investigate under which conditions the transfer of features is beneficial for generalization.
Continual Learning in the Teacher-Student Setup: Impact of Task Similarity
Lee, Sebastian, Goldt, Sebastian, Saxe, Andrew
Continual learning, the ability to learn many tasks in sequence, is critical for artificial learning systems. Yet standard training methods for deep networks often suffer from catastrophic forgetting, where learning new tasks erases knowledge of earlier tasks. While catastrophic forgetting labels the problem, the theoretical reasons for interference between tasks remain unclear. Here, we attempt to narrow this gap between theory and practice by studying continual learning in the teacher-student setup. We extend previous analytical work on two-layer networks in the teacher-student setup to multiple teachers. Using each teacher to represent a different task, we investigate how the relationship between teachers affects the amount of forgetting and transfer exhibited by the student when the task switches. In line with recent work, we find that when tasks depend on similar features, intermediate task similarity leads to greatest forgetting. However, feature similarity is only one way in which tasks may be related. The teacher-student approach allows us to disentangle task similarity at the level of readouts (hidden-to-output weights) and features (input-to-hidden weights). We find a complex interplay between both types of similarity, initial transfer/forgetting rates, maximum transfer/forgetting, and long-term transfer/forgetting. Together, these results help illuminate the diverse factors contributing to catastrophic forgetting.
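One way to obtain teachers with a controlled degree of feature similarity is to construct the second teacher's weight vector so that its cosine overlap with the first equals a chosen value. The construction below is an assumption made for illustration and is not necessarily the authors' procedure; `make_related_teacher` and `gamma` are hypothetical names.

```python
import math
import random

def dot(u, v):
    return sum(ui * vi for ui, vi in zip(u, v))

def normalize(v):
    norm = math.sqrt(dot(v, v))
    return [vi / norm for vi in v]

def make_related_teacher(w1, gamma, rng):
    """Return a unit vector whose cosine similarity with w1 equals gamma (|gamma| <= 1)."""
    u = normalize(w1)
    noise = [rng.gauss(0, 1) for _ in w1]
    # Remove the component of the noise along u, leaving a direction orthogonal to u.
    proj = dot(noise, u)
    perp = normalize([n - proj * ui for n, ui in zip(noise, u)])
    scale = math.sqrt(1.0 - gamma ** 2)
    return [gamma * ui + scale * pi for ui, pi in zip(u, perp)]

rng = random.Random(1)  # fixed seed so the run is reproducible
dim = 50
teacher_1 = normalize([rng.gauss(0, 1) for _ in range(dim)])
teacher_2 = make_related_teacher(teacher_1, gamma=0.5, rng=rng)

overlap = dot(teacher_1, teacher_2)
print(overlap)
```

Sweeping `gamma` from 0 to 1 moves the two tasks from unrelated to identical features; the same construction could be applied to readout weights to vary similarity at that level instead.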