Towards Demystifying Representation Learning with Non-contrastive Self-supervision

Xiang Wang, Xinlei Chen, Simon S. Du, Yuandong Tian

arXiv.org Machine Learning 

Self-supervised learning has recently emerged as a promising direction for learning representations without manual labels. While contrastive learning (Oord et al., 2018; Tian et al., 2019; Bachman et al., 2019; He et al., 2020; Chen et al., 2020a) minimizes the representation distance between positive pairs and maximizes it between negative pairs, non-contrastive self-supervised learning (abbreviated as nc-SSL) has recently been shown to learn nontrivial representations with only positive pairs, using an extra predictor and a stop-gradient operation. Furthermore, the learned representation shows comparable (or even better) performance on downstream tasks (e.g., image classification) (Grill et al., 2020; Chen & He, 2020). This raises two fundamental questions: (1) why does the learned representation not collapse to trivial (i.e., constant) solutions, and (2) without negative pairs, what representation does nc-SSL learn from training, and how does the learned representation reduce the sample complexity of downstream tasks? While many theoretical results on contrastive SSL exist (Arora et al., 2019; Lee et al., 2020; Tosh et al., 2020; Wen & Li, 2021), comparable studies of nc-SSL have been rare. As one of the first works in this direction, Tian et al. (2021) show that while the global optimum of the non-contrastive loss is indeed trivial, by following the gradient direction in nc-SSL one can find a local optimum that admits a nontrivial representation. Based on their theoretical findings on gradient-based methods, they propose a new approach, DirectPred, which directly sets the predictor using the eigen-decomposition of the correlation matrix of the predictor's input, rather than updating the predictor with gradient methods. As an nc-SSL method, DirectPred shows comparable or better performance on multiple datasets, including CIFAR-10 (Krizhevsky et al., 2009), STL-10 (Coates et al., 2011), and ImageNet (Deng et al., 2009), compared to BYOL (Grill et al., 2020) and SimSiam (Chen & He, 2020), which optimize the predictor with gradient descent.
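
To make the DirectPred idea concrete, the following is a minimal sketch of setting a linear predictor from the eigen-decomposition of an exponential-moving-average estimate of the feature correlation matrix. The class name, the EMA rate `rho`, the regularizer `eps`, and the exact eigenvalue scaling are illustrative assumptions, not the paper's exact recipe.

```python
# Hypothetical sketch of a DirectPred-style predictor update; hyperparameters
# (rho, eps) and the eigenvalue scaling are illustrative, not the paper's exact values.
import torch


class DirectPredStylePredictor:
    def __init__(self, dim, rho=0.3, eps=0.1):
        self.corr = torch.zeros(dim, dim)  # EMA estimate of E[f f^T]
        self.rho = rho                     # EMA decay for the correlation estimate
        self.eps = eps                     # small regularizer added to the scaled eigenvalues
        self.W_p = torch.eye(dim)          # linear predictor weight (no bias)

    @torch.no_grad()
    def update(self, f):
        """f: (batch, dim) online-network projections, i.e., the predictor's input."""
        # Update the running estimate of the feature correlation matrix.
        batch_corr = f.t() @ f / f.shape[0]
        self.corr = self.rho * self.corr + (1.0 - self.rho) * batch_corr
        # Eigen-decompose the symmetric correlation matrix.
        eigvals, eigvecs = torch.linalg.eigh(self.corr)
        eigvals = eigvals.clamp(min=0.0)
        # Scale eigenvalues by the largest one, take square roots, and regularize.
        s = eigvals / eigvals.max().clamp(min=1e-12)
        p = s.sqrt() + self.eps
        # Set the predictor directly instead of training it by gradient descent.
        self.W_p = eigvecs @ torch.diag(p) @ eigvecs.t()

    def forward(self, f):
        return f @ self.W_p.t()
```

In a BYOL/SimSiam-style training loop, `update` would be called on the online projections each step (or every few steps), and `forward` would replace the gradient-trained predictor head.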