Understanding Why Generalized Reweighting Does Not Improve Over ERM

Zhai, Runtian; Dan, Chen; Kolter, Zico; Ravikumar, Pradeep

arXiv.org Machine Learning 

It is now well established that empirical risk minimization (ERM) can achieve high test performance on a variety of tasks, particularly with modern overparameterized models, in which the number of parameters is much larger than the number of training samples. However, this strong performance has been shown to degrade under distribution shift, where the training and test distributions differ [HS15, BGO16, Tat17]. Two broad categories of distribution shift have been studied in recent years.

The first is domain generalization, where the training distribution is a mixture of environments, while the test distribution contains new environments that do not appear in the training distribution. The hope in such cases is to learn "invariant features" that do not change across environments, as opposed to spurious features: for example, the background of an image rather than the object itself in image classification, or negation words such as "not" and "never" rather than the meaning of the sentence in sentiment analysis. However, it has been shown empirically that overparameterized models trained via ERM tend to learn spurious features.

The second is subpopulation shift, where the training distribution consists of a number of groups, and the test distribution is the group-conditional distribution of any single group (or, more generally, an arbitrary mixture of the training groups). Subpopulation shift arises in fair machine learning, where the dataset is divided into demographic groups and the goal is to perform well on every group, and in learning with imbalanced classes, where each class is a group and the model needs to perform well on all classes. While overparameterized models trained via ERM can achieve high average performance over the entire data domain, they have been shown to perform poorly on underrepresented subpopulations.
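The contrast between average performance and performance on underrepresented groups can be made concrete with a minimal NumPy sketch. It assumes a fixed model whose per-sample losses are already computed; the loss values, group labels, and helper names (`erm_risk`, `group_risks`, `mixture_risk`) are hypothetical and chosen only for illustration, not taken from the paper. ERM minimizes the unweighted average loss, while under subpopulation shift the test risk is a weighted mixture of the group-conditional risks. Because that mixture is linear in the weights, its worst case over all mixtures is attained at a single group, i.e. it equals the largest group risk.

```python
import numpy as np

def erm_risk(losses):
    """ERM objective: the unweighted mean loss over all training samples."""
    return float(np.mean(losses))

def group_risks(losses, groups, num_groups):
    """Group-conditional risk: the mean loss within each group."""
    return np.array([losses[groups == g].mean() for g in range(num_groups)])

def mixture_risk(per_group, weights):
    """Risk under a test distribution that mixes the groups with the given weights."""
    weights = np.asarray(weights, dtype=float)
    if np.any(weights < 0) or not np.isclose(weights.sum(), 1.0):
        raise ValueError("weights must be a probability vector")
    return float(per_group @ weights)

# Hypothetical per-sample losses of a fixed model: group 0 is the majority
# (8 samples, low loss), group 1 is an underrepresented minority (2 samples, high loss).
losses = np.array([0.1] * 8 + [0.9] * 2)
groups = np.array([0] * 8 + [1] * 2)

per_group = group_risks(losses, groups, num_groups=2)
print("ERM (average) risk:", erm_risk(losses))
print("per-group risks:", per_group)

# The mixture risk is linear in the weights, so its maximum over all mixtures
# is attained at a single group: the worst-group risk.
worst_case_risk = float(per_group.max())
rng = np.random.default_rng(0)
for w in rng.dirichlet(np.ones(2), size=5):
    # Randomly sampled mixtures never exceed the worst-group risk.
    assert mixture_risk(per_group, w) <= worst_case_risk + 1e-12
print("worst-case (worst-group) risk:", worst_case_risk)
```

In this toy setting the average risk is about 0.26 because the majority group dominates the mean, whereas a test distribution concentrated on the minority group incurs a risk of 0.9. This is the kind of gap between average and subpopulation performance described above.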