
Collaborating Authors: Gupta, Sharut


Understanding the Role of Equivariance in Self-supervised Learning

arXiv.org Machine Learning

Contrastive learning has been a leading paradigm for self-supervised learning, but it is widely observed that it comes at the price of sacrificing useful features (e.g., colors) by being invariant to data augmentations. Given this limitation, there has been a surge of interest in equivariant self-supervised learning (E-SSL) that learns features to be augmentation-aware. However, even for the simplest rotation prediction method, there is a lack of rigorous understanding of why, when, and how E-SSL learns useful features for downstream tasks. To bridge this gap between practice and theory, we establish an information-theoretic perspective to understand the generalization ability of E-SSL. In particular, we identify a critical explaining-away effect in E-SSL that creates a synergy between the equivariant and classification tasks. This synergy encourages models to extract class-relevant features to improve their equivariant predictions, which, in turn, benefits downstream tasks requiring semantic features. Based on this perspective, we theoretically analyze the influence of data transformations and reveal several principles for practical designs of E-SSL. Our theory not only aligns well with existing E-SSL methods but also sheds light on new directions by exploring the benefits of model equivariance. We believe that a theoretically grounded understanding of the role of equivariance will inspire more principled and advanced designs in this field. Code is available at https://github.com/kaotty/Understanding-ESSL.
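As a minimal illustration of the rotation-prediction pretext task this abstract analyzes, the sketch below rotates a flattened image by a multiple of 90 degrees and scores a 4-way rotation classifier with cross-entropy. All helper names and shapes are illustrative assumptions, not the authors' implementation.

```python
import math

ROTATIONS = [0, 90, 180, 270]  # the four rotation-prediction classes

def rotate_flat(image, size, k):
    """Rotate a flattened size x size image by k * 90 degrees."""
    grid = [image[r * size:(r + 1) * size] for r in range(size)]
    for _ in range(k % 4):
        grid = [list(row) for row in zip(*grid[::-1])]  # one 90-degree turn
    return [v for row in grid for v in row]

def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    s = sum(exps)
    return [e / s for e in exps]

def rotation_loss(logits, target_idx):
    """Cross-entropy for predicting which of the four rotations was applied."""
    return -math.log(softmax(logits)[target_idx])
```

In an E-SSL pipeline, `rotation_loss` would be computed on logits produced by an encoder plus a rotation-prediction head, alongside whatever contrastive or classification objective is in use.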


An Information Criterion for Controlled Disentanglement of Multimodal Data

arXiv.org Artificial Intelligence

Multimodal representation learning seeks to relate and decompose information inherent in multiple modalities. By disentangling modality-specific information from information that is shared across modalities, we can improve interpretability and robustness and enable downstream tasks such as the generation of counterfactual outcomes. Separating the two types of information is challenging since they are often deeply entangled in many real-world applications. We present a comprehensive analysis of the optimality of each disentangled representation, particularly focusing on the scenario not covered in prior work where the so-called Minimum Necessary Information (MNI) point is not attainable. Our method successfully learns shared and modality-specific features on multiple synthetic and real-world datasets and consistently outperforms baselines on various downstream tasks, including prediction tasks for vision-language data, as well as molecule-phenotype retrieval tasks for biological data. Humans understand and interact with the world using multiple senses, each providing unique and complementary information essential for forming a comprehensive mental representation of the environment. Large multimodal representation learning models such as CLIP (Radford et al., 2021), trained through self-supervised learning, maximally capture the mutual information shared across multiple modalities, exploiting the assumption of multi-view redundancy (Tosh et al., 2021; Sridharan & Kakade, 2008). This property indicates that shared information between modalities is exactly what is relevant for downstream tasks. However, the modality gap, rooted in the inherent differences in representational nature and information content across modalities (Liang et al., 2022b; Ramasinghe et al., 2024; Huh et al., 2024), leads to misalignment between modalities and restricts the application of these methods in many real-world multimodal scenarios.
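The analysis above is information-theoretic. As a small self-contained sketch of the central quantity, the snippet below estimates the mutual information between two discrete variables from paired samples; this is the kind of shared-information measure that disentanglement objectives trade off against modality-specific information (an illustrative plug-in estimator, not the paper's method).

```python
import math
from collections import Counter

def mutual_info(pairs):
    """Plug-in estimate of I(X; Y) in nats from a list of (x, y) samples."""
    n = len(pairs)
    pxy = Counter(pairs)              # joint counts
    px = Counter(x for x, _ in pairs)  # marginal counts for X
    py = Counter(y for _, y in pairs)  # marginal counts for Y
    mi = 0.0
    for (x, y), c in pxy.items():
        # log of p(x, y) / (p(x) * p(y)), written with raw counts
        mi += (c / n) * math.log(c * n / (px[x] * py[y]))
    return mi
```

Two fully correlated modalities share one bit (`log 2` nats); two independent ones share none, which is the regime in which modality-specific encoders carry all the usable signal.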


In-Context Symmetries: Self-Supervised Learning through Contextual World Models

arXiv.org Artificial Intelligence

At the core of self-supervised learning for vision is the idea of learning invariant or equivariant representations with respect to a set of data transformations. This approach, however, introduces strong inductive biases, which can render the representations fragile in downstream tasks that do not conform to these symmetries. In this work, drawing insights from world models, we propose to instead learn a general representation that can adapt to be invariant or equivariant to different transformations by paying attention to context -- a memory module that tracks task-specific states, actions, and future states. Here, the action is the transformation, while the current and future states respectively represent the input's representation before and after the transformation. Our proposed algorithm, Contextual Self-Supervised Learning (ContextSSL), learns equivariance to all transformations (as opposed to invariance). In this way, the model can learn to encode all relevant features as general representations while having the versatility to tailor itself to task-specific symmetries when given a few examples as the context. Empirically, we demonstrate significant performance gains over existing methods on equivariance-related tasks, supported by both qualitative and quantitative evaluations.
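The (state, action, future state) memory described above can be sketched as a toy contextual world model: with no context the model behaves invariantly (identity), and after a few in-context examples of a transformation it predicts the transformation's effect on a representation. Names and the averaging rule are illustrative assumptions, not the paper's architecture.

```python
class ContextMemory:
    """Toy context module over (state, action, next_state) triples."""

    def __init__(self):
        self.triples = []

    def add(self, state, action, next_state):
        self.triples.append((state, action, next_state))

    def predict(self, state, action):
        """Predict the post-transformation state by averaging the
        displacement observed for this action in the context."""
        deltas = [
            tuple(n - s for s, n in zip(st, nxt))
            for st, act, nxt in self.triples if act == action
        ]
        if not deltas:  # no context for this action: behave invariantly
            return tuple(state)
        dim = len(state)
        avg = [sum(d[i] for d in deltas) / len(deltas) for i in range(dim)]
        return tuple(s + a for s, a in zip(state, avg))
```

The same representation thus supports both behaviors: invariance when the context is empty, equivariance once the context specifies how the transformation acts.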


Removing Biases from Molecular Representations via Information Maximization

arXiv.org Artificial Intelligence

High-throughput drug screening - using cell imaging or gene expression measurements as readouts of drug effect - is a critical tool in biotechnology to assess and understand the relationship between the chemical structure and biological activity of a drug. Since large-scale screens have to be divided into multiple experiments, a key difficulty is dealing with batch effects, which can introduce systematic errors and non-biological associations in the data. We propose InfoCORE, an Information maximization approach for COnfounder REmoval, to effectively deal with batch effects and obtain refined molecular representations. InfoCORE establishes a variational lower bound on the conditional mutual information of the latent representations given a batch identifier. It adaptively reweights samples to equalize their implied batch distribution. Extensive experiments on drug screening data reveal InfoCORE's superior performance in a multitude of tasks including molecular property prediction and molecule-phenotype retrieval. Additionally, we show how InfoCORE offers a versatile framework that resolves general distribution shifts and issues of data fairness by minimizing correlation with spurious features or removing sensitive attributes. Representation learning (Bengio et al., 2013) has become pivotal in drug discovery (Wu et al., 2018) and understanding biological systems (Yang et al., 2021b). It serves as a pillar for recognizing drug mechanisms, predicting a drug's activity and toxicity, and identifying disease-associated chemical structures. A central challenge in this context is to accurately capture the nuanced relationship between the chemical structure of a small molecule and its biological or physical attributes. Most molecular representation learning methods only encode a molecule's chemical identity and hence provide unimodal representations (Wang et al., 2022; Xu et al., 2021b). A limitation of such techniques is that molecules with similar structures can have very different effects in the cellular context.
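The sample-reweighting idea mentioned in the abstract can be sketched as an importance-weighting heuristic: samples whose representations make their batch identity easy to guess are down-weighted, equalizing the implied batch distribution. The exact rule below is an illustrative assumption, not InfoCORE's estimator.

```python
def batch_equalizing_weights(batch_posteriors, batch_ids, batch_prior):
    """Weights proportional to p(b_i) / q(b_i | z_i), normalized to mean 1.

    batch_posteriors: per-sample dict mapping batch id -> predicted probability
    batch_ids:        the true batch id of each sample
    batch_prior:      dict mapping batch id -> prior probability
    """
    raw = [batch_prior[b] / max(post[b], 1e-8)
           for post, b in zip(batch_posteriors, batch_ids)]
    mean = sum(raw) / len(raw)
    return [w / mean for w in raw]
```

A sample whose batch is predicted with probability 0.9 from its representation (i.e., the representation leaks batch information) receives a smaller weight than one whose batch is at chance level.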


Context is Environment

arXiv.org Machine Learning

One key problem in AI research is to build systems that generalize across a wide range of test environments. In principle, these algorithms should discard spurious correlations present only in certain training environments, and capture invariant patterns appearing across conditions. For example, we would like to build self-driving systems that, while trained on data from environments with varying weather conditions, traffic conditions, and driving rules, can perform satisfactorily in completely new environments. Unfortunately, this goal has so far remained out of reach: trained models catastrophically fail to generalize to unseen weather conditions [Lechner et al., 2022]. Despite its importance, how to perform well beyond the distribution of the training data remains a burning question. In fact, entire research groups are devoted to studying generalization, major international conferences offer well-attended workshops dedicated to the issue [Wald et al., 2023], and news articles remind us of the profound societal impact of failures of ML systems [Angwin et al., 2016]. Research efforts have so far produced domain generalization algorithms that fall into one of two broad categories. On the one hand, invariance proposals [Ganin et al., 2016, Peters et al., 2016, Arjovsky et al., 2019], illustrated in Figure 1a, discard all environment-specific information, thus removing excessive signal about the problem. On the other hand, marginal transfer proposals [Blanchard et al., 2011, Li et al., 2016, Zhang et al., 2020, Bao and Karaletsos, 2023], also illustrated in Figure 1b, summarize observed inputs in each environment as a coarse embedding, diluting important signal at the example level.


AdaBest: Minimizing Client Drift in Federated Learning via Adaptive Bias Estimation

arXiv.org Artificial Intelligence

In Federated Learning (FL), a number of clients or devices collaborate to train a model without sharing their data. Models are optimized locally at each client and further communicated to a central hub for aggregation. While FL is an appealing decentralized training paradigm, heterogeneity among data from different clients can cause the local optimization to drift away from the global objective. In order to estimate and therefore remove this drift, variance reduction techniques have been incorporated into FL optimization recently. However, these approaches inaccurately estimate the clients' drift and ultimately fail to remove it properly. In this work, we propose an adaptive algorithm that accurately estimates drift across clients. In comparison to previous works, our approach necessitates less storage and communication bandwidth, as well as lower compute costs. Additionally, our proposed methodology induces stability by constraining the norm of estimates for client drift, making it more practical for large-scale FL. Experimental findings demonstrate that the proposed algorithm converges significantly faster and achieves higher accuracy than the baselines across various FL benchmarks.
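The drift-correction idea can be sketched on scalar models: each client subtracts a drift estimate from its local gradient before stepping, and the server averages the resulting local models. This is a simplified illustration in the spirit of variance-reduction FL methods such as AdaBest, not the paper's exact update rule.

```python
def local_update(w, grad_fn, drift_est, lr=0.1, steps=5):
    """Local SGD where each step subtracts the client's estimated drift."""
    for _ in range(steps):
        w -= lr * (grad_fn(w) - drift_est)
    return w

def fed_round(w_global, client_grad_fns, drift_ests, lr=0.1, steps=5):
    """One communication round: clients train locally, server averages."""
    local_models = [local_update(w_global, g, d, lr, steps)
                    for g, d in zip(client_grad_fns, drift_ests)]
    return sum(local_models) / len(local_models)
```

With two heterogeneous quadratic clients (optima at 1 and 3) and drift estimates equal to each client's gradient at the global optimum, repeated rounds converge to the global optimum 2 instead of drifting toward the individual client optima.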


Structuring Representation Geometry with Rotationally Equivariant Contrastive Learning

arXiv.org Artificial Intelligence

Self-supervised learning converts raw perceptual data such as images to a compact space where simple Euclidean distances measure meaningful variations in data. In this paper, we extend this formulation by adding additional geometric structure to the embedding space by enforcing transformations of input space to correspond to simple (i.e., linear) transformations of embedding space. Specifically, in the contrastive learning setting, we introduce an equivariance objective and theoretically prove that its minima force augmentations on input space to correspond to rotations on the spherical embedding space. We show that merely combining our equivariant loss with a non-collapse term results in non-trivial representations, without requiring invariance to data augmentations. Optimal performance is achieved by also encouraging approximate invariance, where input augmentations correspond to small rotations. Our method, CARE: Contrastive Augmentation-induced Rotational Equivariance, leads to improved performance on downstream tasks, and ensures sensitivity in embedding space to important variations in data (e.g., color) that standard contrastive methods do not achieve. Code is available at https://github.com/Sharut/CARE.
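The "augmentations correspond to rotations" idea can be written down directly for 2-D embeddings: an equivariance loss penalizes the gap between the rotated embedding of the original input and the embedding of the augmented input. A minimal sketch, with the rotation angle treated as given (in CARE the correspondence is learned, not hand-specified):

```python
import math

def rotate2d(z, theta):
    """Apply the 2-D rotation R(theta) to an embedding z = (x, y)."""
    x, y = z
    return (math.cos(theta) * x - math.sin(theta) * y,
            math.sin(theta) * x + math.cos(theta) * y)

def equivariance_loss(z_aug, z, theta):
    """|| R(theta) z - z_aug ||^2: small when the input augmentation
    corresponds to a rotation of the spherical embedding."""
    rz = rotate2d(z, theta)
    return sum((a - b) ** 2 for a, b in zip(rz, z_aug))
```

The loss is zero exactly when the augmented embedding is the rotated original, and grows when the embedding space ignores the augmentation (the failure mode of purely invariant objectives).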


QU-BraTS: MICCAI BraTS 2020 Challenge on Quantifying Uncertainty in Brain Tumor Segmentation - Analysis of Ranking Scores and Benchmarking Results

arXiv.org Artificial Intelligence

Deep learning (DL) models have provided state-of-the-art performance in various medical imaging benchmarking challenges, including the Brain Tumor Segmentation (BraTS) challenges. However, the task of focal pathology multi-compartment segmentation (e.g., tumor and lesion sub-regions) is particularly challenging, and potential errors hinder translating DL models into clinical workflows. Quantifying the reliability of DL model predictions in the form of uncertainties could enable clinical review of the most uncertain regions, thereby building trust and paving the way toward clinical translation. Several uncertainty estimation methods have recently been introduced for DL medical image segmentation tasks. Developing scores to evaluate and compare the performance of uncertainty measures will assist the end-user in making more informed decisions. In this study, we explore and evaluate a score developed during the BraTS 2019 and BraTS 2020 tasks on uncertainty quantification (QU-BraTS) and designed to assess and rank uncertainty estimates for brain tumor multi-compartment segmentation. This score (1) rewards uncertainty estimates that produce high confidence in correct assertions and those that assign low confidence levels at incorrect assertions, and (2) penalizes uncertainty measures that lead to a higher percentage of under-confident correct assertions. We further benchmark the segmentation uncertainties generated by 14 independent participating teams of QU-BraTS 2020, all of which also participated in the main BraTS segmentation task. Overall, our findings confirm the importance and complementary value that uncertainty estimates provide to segmentation algorithms, highlighting the need for uncertainty quantification in medical image analyses.
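The core mechanism behind such scores can be sketched as threshold-filtered evaluation: voxels whose uncertainty exceeds a threshold are removed, and Dice is recomputed on the confident remainder. If confident voxels are mostly correct, the filtered Dice rises. This is a simplified illustration of the idea, not the full QU-BraTS score.

```python
def dice(pred, truth):
    """Dice overlap between two binary masks given as lists of 0/1."""
    inter = sum(p and t for p, t in zip(pred, truth))
    total = sum(pred) + sum(truth)
    return 1.0 if total == 0 else 2.0 * inter / total

def filtered_dice(pred, truth, uncertainty, tau):
    """Dice over voxels whose uncertainty is at most tau: an uncertainty
    measure that flags its own errors yields a higher filtered Dice."""
    keep = [i for i, u in enumerate(uncertainty) if u <= tau]
    return dice([pred[i] for i in keep], [truth[i] for i in keep])
```

Sweeping `tau` and tracking the filtered Dice (together with the fraction of correct voxels filtered out) is the kind of trade-off curve such benchmarks aggregate into a single ranking score.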


Addressing catastrophic forgetting for medical domain expansion

arXiv.org Artificial Intelligence

Model brittleness is a key concern when deploying deep learning models in real-world medical settings. A model that has high performance at one institution may suffer a significant decline in performance when tested at other institutions. While pooling datasets from multiple institutions and re-training may provide a straightforward solution, it is often infeasible and may compromise patient privacy. An alternative approach is to fine-tune the model on subsequent institutions after training on the original institution. Notably, this approach degrades model performance at the original institution, a phenomenon known as catastrophic forgetting. In this paper, we develop an approach to address catastrophic forgetting based on elastic weight consolidation combined with modulation of batch normalization statistics under two scenarios: first, for expanding the domain from one imaging system's data to another imaging system's, and second, for expanding the domain from a large multi-institutional dataset to another single institution dataset. We show that our approach outperforms several other state-of-the-art approaches and provide theoretical justification for the efficacy of batch normalization modulation. The results of this study are generally applicable to the deployment of any clinical deep learning model which requires domain expansion.
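Elastic weight consolidation, the mechanism the approach above builds on, adds a quadratic penalty that discourages moving parameters that were important (high Fisher information) for the original institution's task. A minimal sketch with made-up Fisher values (the batch-normalization modulation component is not shown):

```python
def ewc_penalty(params, old_params, fisher, lam):
    """EWC regularizer: lam/2 * sum_i F_i * (theta_i - theta_old_i)^2.

    High-Fisher parameters are anchored to their original-domain values,
    so fine-tuning on a new institution reuses low-importance capacity.
    """
    return 0.5 * lam * sum(
        f * (p - o) ** 2 for p, o, f in zip(params, old_params, fisher))
```

During domain expansion, this penalty is added to the new institution's task loss, so the fine-tuned model stays close to the original model along directions that mattered for the original domain.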