Simsek, Berfin
Learning Gaussian Multi-Index Models with Gradient Flow: Time Complexity and Directional Convergence
Simsek, Berfin, Bendjeddou, Amire, Hsu, Daniel
This work focuses on the gradient flow dynamics of a neural network model that uses correlation loss to approximate a multi-index function on high-dimensional standard Gaussian data. Specifically, the multi-index function we consider is a sum of neurons $f^*(x) \!=\! \sum_{j=1}^k \! \sigma^*(v_j^T x)$ where $v_1, \dots, v_k$ are unit vectors, and $\sigma^*$ lacks the first and second Hermite polynomials in its Hermite expansion. It is known that, for the single-index case ($k\!=\!1$), overcoming the search phase requires polynomial time complexity. We first generalize this result to multi-index functions characterized by vectors in arbitrary directions. After the search phase, it is not clear whether the network neurons converge to the index vectors or get stuck at a sub-optimal solution. When the index vectors are orthogonal, we give a complete characterization of the fixed points and prove that neurons converge to the nearest index vectors. Therefore, using $n \! \asymp \! k \log k$ neurons ensures finding the full set of index vectors with gradient flow with high probability over random initialization. When $ v_i^T v_j \!=\! \beta \! \geq \! 0$ for all $i \neq j$, we prove the existence of a sharp threshold $\beta_c \!=\! c/(c+k)$ at which the fixed point that computes the average of the index vectors transitions from a saddle point to a minimum. Numerical simulations show that using a correlation loss and a mild overparameterization suffices to learn all of the index vectors when they are nearly orthogonal; however, the correlation loss fails when the dot product between the index vectors exceeds a certain threshold.
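The $n \asymp k \log k$ count in the orthogonal case can be read as a coupon-collector argument: if each neuron ends at the index vector nearest to its random initialization, all $k$ index vectors are found once every one of them is nearest to at least one neuron. The following is a minimal sketch of that counting argument only; it does not simulate gradient flow. It assumes the index vectors are the first $k$ standard basis vectors and that "nearest" means largest dot product at initialization. The function name `coverage_probability` and its parameters are hypothetical.

```python
import numpy as np

def coverage_probability(k, n, d=None, trials=200, seed=0):
    """Estimate the probability that n randomly initialized neurons
    cover all k orthogonal index vectors.

    Assumptions (not from the abstract): index vectors are e_1..e_k in R^d,
    and a neuron is assigned to the index vector with the largest dot
    product with its Gaussian initialization.
    """
    rng = np.random.default_rng(seed)
    d = k if d is None else d
    hits = 0
    for _ in range(trials):
        W = rng.standard_normal((n, d))        # one row per neuron
        nearest = np.argmax(W[:, :k], axis=1)  # index of nearest e_j per neuron
        hits += len(np.unique(nearest)) == k   # were all k indices hit?
    return hits / trials

if __name__ == "__main__":
    k = 20
    for n in (k, int(k * np.log(k)), int(3 * k * np.log(k))):
        print(n, coverage_probability(k, n))
```

With $n = k$ neurons full coverage is very unlikely, while a constant multiple of $k \log k$ neurons covers all index vectors in almost every trial.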
Learning Associative Memories with Gradient Descent
Cabannes, Vivien, Simsek, Berfin, Bietti, Alberto
This work focuses on the training dynamics of one associative memory module storing outer products of token embeddings. We reduce this problem to the study of a system of particles, which interact according to properties of the data distribution and correlations between embeddings. Through theory and experiments, we provide several insights. In overparameterized regimes, we obtain logarithmic growth of the ``classification margins.'' Yet, we show that imbalance in token frequencies and memory interferences due to correlated embeddings lead to oscillatory transitory regimes. The oscillations are more pronounced with large step sizes, which can create benign loss spikes, although these learning rates speed up the dynamics and accelerate the asymptotic convergence. In underparameterized regimes, we illustrate how the cross-entropy loss can lead to suboptimal memorization schemes. Finally, we assess the validity of our findings on small Transformer models.
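As an illustration of what "storing outer products of token embeddings" can mean, here is a minimal sketch of an associative memory $W = \sum_x u_{y(x)} e_x^T$ that is read out by scoring each output embedding against $W e_x$. It shows only the storage and recall step, not the training dynamics studied in the paper. The names `build_memory` and `recall`, and the array layout, are assumptions made for this example.

```python
import numpy as np

def build_memory(E, U, targets):
    """Return W = sum_x u_{targets[x]} e_x^T.

    E: (N, d) input token embeddings, one row per input token x.
    U: (M, d) output token embeddings, one row per output token y.
    targets: (N,) integer array, targets[x] is the output token stored for x.
    """
    return U[targets].T @ E  # (d, N) @ (N, d) -> (d, d)

def recall(W, E, U):
    """Predict, for every input token x, argmax_y u_y^T W e_x."""
    scores = U @ W @ E.T     # scores[y, x] = u_y^T W e_x
    return np.argmax(scores, axis=0)

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    N, M, d = 10, 5, 512
    E = rng.standard_normal((N, d)) / np.sqrt(d)  # nearly orthogonal rows
    U = rng.standard_normal((M, d)) / np.sqrt(d)
    targets = rng.integers(0, M, size=N)
    W = build_memory(E, U, targets)
    print((recall(W, E, U) == targets).mean())    # fraction recalled correctly
```

With exactly orthonormal embeddings the recall is exact; correlated embeddings add cross terms to the scores, which is the kind of memory interference the abstract refers to.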
The Loss Landscape of Shallow ReLU-like Neural Networks: Stationary Points, Saddle Escaping, and Network Embedding
Wu, Zhengqing, Simsek, Berfin, Ged, Francois
In this paper, we investigate the loss landscape of one-hidden-layer neural networks with ReLU-like activation functions trained with the empirical squared loss. As the activation function is non-differentiable, it is so far unclear how to completely characterize the stationary points. We propose conditions for stationarity that apply to both non-differentiable and differentiable cases. Additionally, we show that, if a stationary point does not contain "escape neurons", which are defined with first-order conditions, then it must be a local minimum. Moreover, for the scalar-output case, the presence of an escape neuron guarantees that the stationary point is not a local minimum. Our results refine the description of the saddle-to-saddle training process starting from infinitesimally small (vanishing) initialization for shallow ReLU-like networks, linking saddle escaping directly with the parameter changes of escape neurons. Finally, we give a full account of how network embedding, that is, instantiating a narrower network within a wider network, reshapes the stationary points.
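One simple way to instantiate a narrower network within a wider one is to duplicate a hidden neuron and split its output weight between the two copies, which leaves the network function unchanged. The sketch below shows this for a scalar-output one-hidden-layer ReLU network. It is an assumed example of an embedding, not necessarily the construction used in the paper; `forward`, `split_neuron`, and `gamma` are hypothetical names.

```python
import numpy as np

def forward(W, a, X):
    """Scalar-output network f(x) = sum_i a_i * relu(w_i^T x).

    W: (n, d) input weights, a: (n,) output weights, X: (m, d) inputs.
    """
    return np.maximum(X @ W.T, 0.0) @ a

def split_neuron(W, a, i, gamma=0.5):
    """Embed an n-neuron network into an (n+1)-neuron network by copying
    neuron i and splitting its output weight as gamma and (1 - gamma)."""
    W_wide = np.vstack([W, W[i:i + 1]])         # append a copy of w_i
    a_wide = np.append(a, (1 - gamma) * a[i])   # new neuron gets (1-gamma) a_i
    a_wide[i] = gamma * a[i]                    # original keeps gamma a_i
    return W_wide, a_wide

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    W, a = rng.standard_normal((3, 4)), rng.standard_normal(3)
    X = rng.standard_normal((5, 4))
    W_wide, a_wide = split_neuron(W, a, i=0, gamma=0.3)
    print(np.allclose(forward(W, a, X), forward(W_wide, a_wide, X)))
```

Because the two copies share the same input weight, their contributions add up to the original neuron's output for any value of `gamma`.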
Weight-space symmetry in deep networks gives rise to permutation saddles, connected by equal-loss valleys across the loss landscape
Brea, Johanni, Simsek, Berfin, Illing, Bernd, Gerstner, Wulfram
The permutation symmetry of neurons in each layer of a deep neural network gives rise not only to multiple equivalent global minima of the loss function, but also to first-order saddle points located on the path between the global minima. In a network of $d-1$ hidden layers with $n_k$ neurons in layers $k = 1, \ldots, d$, we construct smooth paths between equivalent global minima that lead through a `permutation point' where the input and output weight vectors of two neurons in the same hidden layer $k$ collide and interchange. We show that such permutation points are critical points with at least $n_{k+1}$ vanishing eigenvalues of the Hessian matrix of second derivatives, indicating a local plateau of the loss function. We find that a permutation point for the exchange of neurons $i$ and $j$ transitions into a flat valley (or generally, an extended plateau of $n_{k+1}$ flat dimensions) that enables all $n_k!$ permutations of neurons in a given layer $k$ at the same loss value. Moreover, we introduce high-order permutation points by exploiting the recursive structure in neural network functions, and find that the number of $K^{\text{th}}$-order permutation points is larger than the (already huge) number of equivalent global minima by at least a factor of $\sum_{k=1}^{d-1}\frac{1}{2!^K}{n_k-K \choose K}$. In two tasks, we illustrate numerically that some of the permutation points correspond to first-order saddles (`permutation saddles'): first, in a toy network with a single hidden layer on a function approximation task and, second, in a multilayer network on the MNIST task. Our geometric approach yields a lower bound on the number of critical points generated by weight-space symmetries and provides a simple intuitive link between previous mathematical results and numerical observations.
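The permutation symmetry underlying these results can be checked directly: reordering the hidden neurons of a layer, together with the matching rows of its input weights and columns of its output weights, leaves the network function unchanged, so any global minimum comes with permuted copies at the same loss. The following is a minimal sketch for a single hidden layer with a tanh activation. The activation, shapes, and the names `mlp` and `permute_hidden` are assumptions for illustration.

```python
import numpy as np

def mlp(W1, W2, X):
    """One-hidden-layer network: X (m, d) -> tanh(X W1^T) W2^T.

    W1: (n, d) input weights of the n hidden neurons.
    W2: (p, n) output weights.
    """
    return np.tanh(X @ W1.T) @ W2.T

def permute_hidden(W1, W2, perm):
    """Reorder hidden neurons: permute rows of W1 and columns of W2."""
    return W1[perm], W2[:, perm]

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    W1, W2 = rng.standard_normal((4, 3)), rng.standard_normal((2, 4))
    X = rng.standard_normal((6, 3))
    perm = rng.permutation(4)
    W1_p, W2_p = permute_hidden(W1, W2, perm)
    # Same outputs, different point in weight space.
    print(np.allclose(mlp(W1, W2, X), mlp(W1_p, W2_p, X)))
```

The permuted weights compute the same function at a generally different point in weight space, which is why paths between such equivalent points, and the permutation points along them, are of interest.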