data model
Robust Learning with Progressive Data Expansion Against Spurious Correlation
While deep learning models have shown remarkable performance in various tasks, they are susceptible to learning non-generalizable spurious features rather than the core features that are genuinely correlated with the true label. In this paper, beyond existing analyses of linear models, we theoretically examine the learning process of a two-layer nonlinear convolutional neural network in the presence of spurious features. Our analysis suggests that imbalanced data groups and easily learnable spurious features can lead to the dominance of spurious features during the learning process. In light of this, we propose a new training algorithm called PDE that efficiently enhances the model's robustness for better worst-group performance. PDE begins with a group-balanced subset of training data and progressively expands it to facilitate the learning of the core features. Experiments on synthetic and real-world benchmark datasets confirm the superior performance of our method on models such as ResNets and Transformers. On average, our method achieves a 2.8% improvement in worst-group accuracy compared with the state-of-the-art method, while training up to 10× faster. Code is available at https://github.com/uclaml/PDE.
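The two-phase data schedule described in this abstract (a group-balanced warm-up subset, followed by progressive expansion) can be illustrated with a minimal sketch. This is not the authors' implementation: the function names, the choice to size the warm-up by the smallest group, and the fixed `add_per_stage` increment are all assumptions made for illustration, and the model training that would run between stages is omitted.

```python
import random
from collections import defaultdict


def balanced_warmup_indices(groups, rng):
    """Pick an equal number of examples from every group.

    Assumption: the per-group count is the size of the smallest group.
    """
    by_group = defaultdict(list)
    for idx, g in enumerate(groups):
        by_group[g].append(idx)
    n_min = min(len(members) for members in by_group.values())
    chosen = []
    for g in sorted(by_group):
        chosen.extend(rng.sample(by_group[g], n_min))
    return chosen


def expansion_schedule(groups, n_stages, add_per_stage, seed=0):
    """Return the list of active training indices for each stage.

    Stage 0 is the group-balanced subset; every later stage adds up to
    `add_per_stage` randomly chosen examples from the data not yet in use.
    """
    rng = random.Random(seed)
    active = balanced_warmup_indices(groups, rng)
    active_set = set(active)
    remaining = [i for i in range(len(groups)) if i not in active_set]
    rng.shuffle(remaining)

    stages = [list(active)]
    for _ in range(n_stages):
        new, remaining = remaining[:add_per_stage], remaining[add_per_stage:]
        active = active + new
        stages.append(list(active))
    return stages


# Hypothetical toy data: group 0 is the majority, group 1 the minority.
groups = [0] * 8 + [1] * 2
stages = expansion_schedule(groups, n_stages=3, add_per_stage=2)
for t, idxs in enumerate(stages):
    print(t, len(idxs))
```

In a real training loop, each stage would correspond to some number of epochs on the currently active indices before the next expansion.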
Learning Beyond the Gaussian Data: Learning Dynamics of Neural Networks on an Expressive and Cumulant-Controllable Data Model
Ure, Onat, Demir, Samet, Dogan, Zafer
We study the effect of high-order statistics of data on the learning dynamics of neural networks (NNs) by using a moment-controllable non-Gaussian data model. Considering the expressivity of two-layer neural networks, we first construct the data model as a generative two-layer NN where the activation function is expanded by using Hermite polynomials. This allows us to achieve interpretable control over high-order cumulants such as skewness and kurtosis through the Hermite coefficients while keeping the data model realistic. Using samples generated from the data model, we perform controlled online learning experiments with a two-layer NN. Our results reveal a moment-wise progression in training: networks first capture low-order statistics such as mean and covariance, and progressively learn high-order cumulants. Finally, we pretrain the generative model on the Fashion-MNIST dataset and leverage the generated samples for further experiments. The results of these additional experiments confirm our conclusions and show the utility of the data model in a real-world scenario. Overall, our proposed approach bridges simplified data assumptions and practical data complexity, which offers a principled framework for investigating distributional effects in machine learning and signal processing.
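To make the link between Hermite coefficients and high-order cumulants concrete, the following is a scalar toy sketch, not the paper's generative two-layer network. It assumes a single standard Gaussian latent variable passed through an activation written as a sum of probabilists' Hermite polynomials; the names `hermite_e`, `sample`, and `skewness` are illustrative only. Adding a second-order coefficient makes the output skewed, while a purely first-order expansion leaves it Gaussian.

```python
import random


def hermite_e(k, u):
    """Probabilists' Hermite polynomial He_k(u).

    Uses the recurrence He_{n+1}(u) = u * He_n(u) - n * He_{n-1}(u).
    """
    prev, cur = 1.0, u
    if k == 0:
        return prev
    for n in range(1, k):
        prev, cur = cur, u * cur - n * prev
    return cur


def sample(coeffs, n, seed=0):
    """Draw n samples of x = sum_k coeffs[k] * He_k(z), with z ~ N(0, 1)."""
    rng = random.Random(seed)
    out = []
    for _ in range(n):
        z = rng.gauss(0.0, 1.0)
        out.append(sum(c * hermite_e(k, z) for k, c in enumerate(coeffs)))
    return out


def skewness(xs):
    """Sample skewness: third central moment divided by variance ** 1.5."""
    n = len(xs)
    mean = sum(xs) / n
    var = sum((x - mean) ** 2 for x in xs) / n
    third = sum((x - mean) ** 3 for x in xs) / n
    return third / var ** 1.5


gaussian = sample([0.0, 1.0], n=50000)       # x = z, no skew expected
skewed = sample([0.0, 1.0, 0.5], n=50000)    # x = z + 0.5 * (z**2 - 1)
print(skewness(gaussian), skewness(skewed))
```

For x = z + a(z² − 1), the third central moment is 6a + 8a³ and the variance is 1 + 2a², so the skewness grows with the second-order coefficient a; this is the scalar analogue of controlling skewness through the Hermite coefficients.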
Iterative Feature Matching: Toward Provable Domain Generalization with Logarithmic Environments
Domain generalization aims at performing well on unseen test environments with data from a limited number of training environments. Despite a proliferation of proposed algorithms for this task, assessing their performance both theoretically and empirically is still very challenging. Distributional matching algorithms such as (Conditional) Domain Adversarial Networks [Ganin et al., 2016, Long et al., 2018] are popular and enjoy empirical success, but they lack formal guarantees. Other approaches such as Invariant Risk Minimization (IRM) require a prohibitively large number of training environments (linear in the dimension of the spurious feature space $d_s$), even on simple data models like the one proposed by [Rosenfeld et al., 2021]. Under a variant of this model, we show that ERM and IRM can fail to find the optimal invariant predictor with $o(d_s)$ environments. We then present an iterative feature matching algorithm that is guaranteed with high probability to find the optimal invariant predictor after seeing only $O(\log d_s)$ environments. Our results provide the first theoretical justification for distribution-matching algorithms widely used in practice under a concrete nontrivial data model.