Avoiding spurious correlations via logit correction

Sheng Liu, Xu Zhang, Nitesh Sekhar, Yue Wu, Prateek Singhal, Carlos Fernandez-Granda

arXiv.org Artificial Intelligence 

Empirical studies suggest that machine learning models trained with empirical risk minimization (ERM) often rely on attributes that are spuriously correlated with the class labels. Such models typically perform poorly at inference time on data that lacks these correlations. In this work, we explicitly consider the setting where potential spurious correlations are present in the majority of the training data. In contrast with existing approaches, which use the ERM model outputs to detect samples without spurious correlations and then either heuristically upweight or upsample those samples, we propose the logit correction (LC) loss, a simple yet effective improvement on the softmax cross-entropy loss that corrects the sample logits. We demonstrate that minimizing the LC loss is equivalent to maximizing the group-balanced accuracy, so the proposed LC loss mitigates the negative impact of spurious correlations. Our extensive experimental results further reveal that the LC loss outperforms state-of-the-art solutions on multiple popular benchmarks by a large margin, an average 5.5% absolute improvement, without access to spurious attribute labels. LC is also competitive with oracle methods that make use of the attribute labels.

In practical applications such as self-driving cars, a robust machine learning model must be designed to comprehend its surroundings under rare conditions that may not be well represented in its training set. However, deep neural networks can be negatively affected by spurious correlations between observed features and class labels that hold for well-represented groups but not for rare groups. For example, when classifying stop signs versus other traffic signs in autonomous driving, 99% of the stop signs in the United States are red. A model trained with standard ERM may achieve low average training error by relying on the spurious color attribute instead of the desired "STOP" text on the sign, resulting in high average accuracy but low worst-group accuracy (e.g., making errors on yellow or faded stop signs).
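To make the idea concrete, below is a minimal sketch of what a logit-correction loss can look like, assuming the correction takes the additive form familiar from logit adjustment: a per-sample log-prior term is added to the logits before the standard softmax cross-entropy. The function name `logit_correction_loss` and the way the correction term is supplied are illustrative assumptions, not the paper's exact formulation; in particular, how the group priors are estimated without spurious-attribute labels is specific to the paper and not reproduced here.

```python
import torch
import torch.nn.functional as F

def logit_correction_loss(logits: torch.Tensor,
                          labels: torch.Tensor,
                          log_prior: torch.Tensor) -> torch.Tensor:
    """Sketch of an additive logit correction (names are illustrative).

    logits:    (batch, num_classes) raw model outputs.
    labels:    (batch,) integer class labels.
    log_prior: (batch, num_classes) per-sample correction term, e.g. the
               log-probability of each class under the sample's (inferred)
               group; classes favored by a majority group receive a larger
               shift, which enlarges the margin demanded of minority groups.
    """
    corrected = logits + log_prior  # correct the sample logit additively
    return F.cross_entropy(corrected, labels)

# Toy usage with a hypothetical prior that strongly favors class 0,
# mimicking a majority group dominated by the spurious attribute.
logits = torch.randn(4, 2)
labels = torch.tensor([0, 1, 0, 1])
log_prior = torch.log(torch.tensor([[0.99, 0.01]] * 4))
loss = logit_correction_loss(logits, labels, log_prior)
print(loss.item())
```

Intuitively, the additive shift forces the model to score a class well beyond what the (spurious) group prior alone would predict, which is the mechanism behind the group-balanced-accuracy equivalence described above.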
