Balasubramanian, Vineeth N
FW-Shapley: Real-time Estimation of Weighted Shapley Values
Panda, Pranoy, Tandon, Siddharth, Balasubramanian, Vineeth N
Fair credit assignment is essential in various machine learning (ML) applications, and Shapley values have emerged as a valuable tool for this purpose. However, in critical ML applications such as data valuation and feature attribution, the uniform weighting of Shapley values across subset cardinalities leads to unintuitive credit assignments. To address this, weighted Shapley values were proposed as a generalization, allowing different weights for subsets with different cardinalities. Despite their advantages, weighted Shapley values, like Shapley values, suffer from exponential compute costs, making them impractical for high-dimensional datasets. To tackle this issue, we present two key contributions. First, we provide a weighted least squares characterization of weighted Shapley values. Second, using this characterization, we propose Fast Weighted Shapley (FW-Shapley), an amortized framework for efficiently computing weighted Shapley values using a learned estimator. We further show that our estimator's training procedure is theoretically valid even though we do not use ground truth weighted Shapley values during training. On the feature attribution task, we outperform the learned estimator FastSHAP by $27\%$ (on average) in terms of Inclusion AUC. For data valuation, we are 14 times faster than the state-of-the-art KNN Shapley while achieving comparable performance.
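To make the exponential cost mentioned above concrete, the following is a minimal sketch of exact computation by subset enumeration. It assumes one common cardinality-weighted form, in which the average marginal contribution over subsets of size k is multiplied by a weight w_k; the paper's exact definition may differ. The function name `weighted_shapley` and its parameters are hypothetical, and this is not the FW-Shapley estimator itself.

```python
from itertools import combinations
from math import comb


def weighted_shapley(n, value, weights=None):
    """Exact cardinality-weighted values by enumerating all subsets.

    value:   function mapping a frozenset of player indices to a float.
    weights: list of n numbers; weights[k] applies to subsets of size k.
             None means uniform 1/n, which gives the ordinary Shapley value.
             (Assumed weighting form, for illustration only.)
    """
    if weights is None:
        weights = [1.0 / n] * n
    phi = [0.0] * n
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for k in range(n):
            # Sum the marginal contribution of player i over all subsets of size k.
            total = 0.0
            for subset in combinations(others, k):
                s = frozenset(subset)
                total += value(s | {i}) - value(s)
            # Average over the C(n-1, k) subsets, then apply the cardinality weight.
            phi[i] += weights[k] * total / comb(n - 1, k)
    return phi


# Toy additive game: each player's value is simply its own payoff.
payoffs = [1.0, 2.0, 3.0]
print(weighted_shapley(3, lambda s: sum(payoffs[j] for j in s)))
```

The inner loops visit every subset of the other players, so the cost grows as 2^(n-1) per player, which is the bottleneck an amortized estimator is meant to avoid.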
Interpretable Model Drift Detection
Panda, Pranoy, Srinivas, Kancheti Sai, Balasubramanian, Vineeth N, Sinha, Gaurav
Data in the real world often has an evolving distribution. Thus, machine learning models trained on such data get outdated over time. This phenomenon is called model drift. Knowledge of this drift serves two purposes: (i) retaining an accurate model and (ii) discovering knowledge or insights about changes in the relationship between the input features and the output variable with respect to the model. Most existing works focus only on detecting model drift but offer no interpretability. In this work, we take a principled approach to study the problem of interpretable model drift detection from a risk perspective using a feature-interaction aware hypothesis testing framework, which enjoys guarantees on test power. The proposed framework is generic, i.e., it can be adapted to both classification and regression tasks. Experiments on several standard drift detection datasets show that our method is superior to existing interpretable methods (especially on real-world datasets) and on par with state-of-the-art black-box drift detection methods. We also quantitatively and qualitatively study the interpretability aspect, including a case study on the USENET2 dataset. We find that our method focuses on model- and drift-sensitive features compared to baseline interpretable drift detectors.
Grounding Descriptions in Images informs Zero-Shot Visual Recognition
Halbe, Shaunak, Tian, Junjiao, Joseph, K J, Smith, James Seale, Stevo, Katherine, Balasubramanian, Vineeth N, Kira, Zsolt
Vision-language models (VLMs) like CLIP have been cherished for their ability to perform zero-shot visual recognition on open-vocabulary concepts. This is achieved by selecting the object category whose textual representation bears the highest similarity with the query image. While successful in some domains, this method struggles with identifying fine-grained entities as well as generalizing to unseen concepts that are not captured by the training distribution. Recent works attempt to mitigate these challenges by integrating category descriptions at test time, albeit yielding modest improvements. We attribute these limited gains to a fundamental misalignment between image and description representations, which is rooted in the pretraining structure of CLIP. In this paper, we propose GRAIN, a new pretraining strategy aimed at aligning representations at both fine and coarse levels simultaneously. Our approach learns to jointly ground textual descriptions in image regions along with aligning overarching captions with global image representations. To drive this pretraining, we leverage frozen Multimodal Large Language Models (MLLMs) to derive large-scale synthetic annotations. We demonstrate the enhanced zero-shot performance of our model compared to current state-of-the-art methods across 11 diverse image classification datasets. Additionally, we introduce Products-2023, a newly curated, manually labeled dataset featuring novel concepts, and showcase our model's ability to recognize these concepts by benchmarking on it. Significant improvements achieved by our model on other downstream tasks like retrieval further highlight the superior quality of representations learned by our approach. Code available at https://github.com/shaunak27/grain-clip.
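The zero-shot selection rule described above (pick the category whose text representation is most similar to the image) can be sketched as follows. This is an illustrative toy, assuming cosine similarity over precomputed embeddings; the function names and the two-dimensional embeddings are made up for the example and are not taken from CLIP or GRAIN.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def zero_shot_predict(image_emb, class_text_embs):
    """Return the class whose text embedding is most similar to the image.

    image_emb:       embedding of the query image (hypothetical values).
    class_text_embs: dict mapping class name -> text embedding.
    """
    scores = {name: cosine(image_emb, emb) for name, emb in class_text_embs.items()}
    best = max(scores, key=scores.get)
    return best, scores


# Toy embeddings standing in for encoder outputs.
label, scores = zero_shot_predict(
    [1.0, 0.1],
    {"cat": [1.0, 0.0], "dog": [0.0, 1.0]},
)
print(label, scores)
```

Under this rule, adding category descriptions at test time amounts to changing the entries of `class_text_embs`; it only helps if those description embeddings are well aligned with image embeddings.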
Detecting and Measuring Confounding Using Causal Mechanism Shifts
Reddy, Abbavaram Gowtham, Balasubramanian, Vineeth N
Detecting and measuring confounding effects from data is a key challenge in causal inference. Existing methods frequently assume causal sufficiency, disregarding the presence of unobserved confounding variables. Causal sufficiency is both unrealistic and empirically untestable. Additionally, existing methods make strong parametric assumptions about the underlying causal generative process to guarantee the identifiability of confounding variables. Relaxing the causal sufficiency and parametric assumptions and leveraging recent advancements in causal discovery and confounding analysis with non-i.i.d. data, we propose a comprehensive approach for detecting and measuring confounding. We consider various definitions of confounding and introduce tailored methodologies to achieve three objectives: (i) detecting and measuring confounding among a set of variables, (ii) separating observed and unobserved confounding effects, and (iii) understanding the relative strengths of confounding bias between different sets of variables. We present useful properties of a confounding measure and present measures that satisfy those properties. Empirical results support the theoretical analysis.
Teaching Transformers Causal Reasoning through Axiomatic Training
Vashishtha, Aniket, Kumar, Abhinav, Reddy, Abbavaram Gowtham, Balasubramanian, Vineeth N, Sharma, Amit
For text-based AI systems to interact in the real world, causal reasoning is an essential skill. Since interventional data is costly to generate, we study to what extent an agent can learn causal reasoning from passive data. Specifically, we consider an axiomatic training setup where an agent learns from multiple demonstrations of a causal axiom (or rule), rather than incorporating the axiom as an inductive bias or inferring it from data values. A key question is whether the agent would learn to generalize from the axiom demonstrations to new scenarios. For example, if a transformer model is trained on demonstrations of the causal transitivity axiom over small graphs, would it generalize to applying the transitivity axiom over large graphs? Our results, based on a novel axiomatic training scheme, indicate that such generalization is possible. We consider the task of inferring whether a variable causes another variable, given a causal graph structure. We find that a 67 million parameter transformer model, when trained on linear causal chains (along with some noisy variations), can generalize well to new kinds of graphs, including longer causal chains, causal chains with reversed order, and graphs with branching, even when it is not explicitly trained for such settings. Our model performs on par with (or even better than) many larger language models such as GPT-4, Gemini Pro, and Phi-3. Overall, our axiomatic training framework provides a new paradigm of learning causal reasoning from passive data that can be used to learn arbitrary axioms, as long as sufficient demonstrations can be generated.
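As a rough illustration of what demonstrations of the transitivity axiom over a linear causal chain might look like, the sketch below generates (premise, question, answer) triples. The text template, the variable naming, and the function `chain_demonstrations` are assumptions made for this example; the paper's actual data format and noisy variations are not reproduced here.

```python
import random


def chain_demonstrations(num_nodes, rng):
    """Build yes/no demonstrations for a linear chain of num_nodes variables.

    The chain order is a random permutation of variable names, so the
    answer cannot be read off the names themselves.
    """
    names = [f"X{i}" for i in range(num_nodes)]
    rng.shuffle(names)
    # Premise lists only the direct edges of the chain.
    premise = " ".join(f"{a} causes {b}." for a, b in zip(names, names[1:]))
    demos = []
    for i in range(num_nodes):
        for j in range(num_nodes):
            if i == j:
                continue
            # By transitivity, an earlier variable causes every later one.
            answer = "Yes" if i < j else "No"
            question = f"Does {names[i]} cause {names[j]}?"
            demos.append((premise, question, answer))
    return demos


for premise, question, answer in chain_demonstrations(3, random.Random(0)):
    print(premise, question, answer)
```

Testing generalization in this setup would mean training on small values of `num_nodes` and evaluating on larger ones, or on chains whose edge direction or shape differs from the training chains.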
Advancing Ante-Hoc Explainable Models through Generative Adversarial Networks
Garg, Tanmay, Vemuri, Deepika, Balasubramanian, Vineeth N
This paper presents a novel concept learning framework for enhancing model interpretability and performance in visual classification tasks. Our approach appends an unsupervised explanation generator to the primary classifier network and makes use of adversarial training. During training, the explanation module is optimized to extract visual concepts from the classifier's latent representations, while the GAN-based module aims to distinguish images generated from concepts from true images. This joint training scheme enables the model to implicitly align its internally learned concepts with human-interpretable visual properties. Comprehensive experiments demonstrate the robustness of our approach, while producing coherent concept activations. We analyse the learned concepts, showing their semantic concordance with object parts and visual attributes. We also study how perturbations in the adversarial training protocol impact both classification and concept acquisition. In summary, this work presents a significant step towards building inherently interpretable deep vision models with task-aligned concept representations, a key enabler for developing trustworthy AI for real-world perception tasks.
NESTER: An Adaptive Neurosymbolic Method for Causal Effect Estimation
Reddy, Abbavaram Gowtham, Balasubramanian, Vineeth N
Causal effect estimation from observational data is a central problem in causal inference. Methods based on the potential outcomes framework solve this problem by exploiting inductive biases and heuristics from causal inference. Each of these methods addresses a specific aspect of causal effect estimation, such as controlling propensity score, enforcing randomization, etc., by designing neural network (NN) architectures and regularizers. In this paper, we propose an adaptive method called Neurosymbolic Causal Effect Estimator (NESTER), a generalized method for causal effect estimation. NESTER integrates the ideas used in existing methods based on multi-head NNs for causal effect estimation into one framework. We design a Domain Specific Language (DSL) tailored for causal effect estimation based on causal inductive biases used in the literature. We conduct a theoretical analysis to investigate NESTER's efficacy in estimating causal effects. Our comprehensive empirical results show that NESTER performs better than state-of-the-art methods on benchmark datasets.
Rethinking Robustness of Model Attributions
Kamath, Sandesh, Mittal, Sankalp, Deshpande, Amit, Balasubramanian, Vineeth N
For machine learning models to be reliable and trustworthy, their decisions must be interpretable. As these models find increasing use in safety-critical applications, it is important that not just the model predictions but also their explanations (as feature attributions) be robust to small human-imperceptible input perturbations. Recent works have shown that many attribution methods are fragile and have proposed improvements in either these methods or the model training. We observe two main causes for fragile attributions: first, the existing metrics of robustness (e.g., top-k intersection) over-penalize even reasonable local shifts in attribution, thereby making random perturbations appear as a strong attack, and second, the attribution can be concentrated in a small region even when there are multiple important parts in an image. To rectify this, we propose simple ways to strengthen existing metrics and attribution methods that incorporate locality of pixels in robustness metrics and diversity of pixel locations in attributions. Regarding the role of model training in attributional robustness, we empirically observe that adversarially trained models have more robust attributions on smaller datasets; however, this advantage disappears on larger datasets. Code is available at https://github.com/ksandeshk/LENS.
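The over-penalization by top-k intersection noted above can be seen in a small sketch. It assumes the usual reading of the metric (the fraction of the k highest-attributed positions shared by two attribution maps) and uses flat toy lists in place of image attribution maps; the function names are hypothetical and this is not the locality-aware metric proposed in the paper.

```python
def top_k_indices(attribution, k):
    """Indices of the k largest attribution scores."""
    order = sorted(range(len(attribution)), key=lambda i: attribution[i], reverse=True)
    return set(order[:k])


def top_k_intersection(attr_a, attr_b, k):
    """Fraction of top-k positions shared by two attribution maps."""
    shared = top_k_indices(attr_a, k) & top_k_indices(attr_b, k)
    return len(shared) / k


# Toy 1-D attribution map and a copy shifted two positions to the right.
original = [0.0, 0.9, 0.8, 0.0, 0.0, 0.0]
shifted = [0.0, 0.0, 0.0, 0.9, 0.8, 0.0]

print(top_k_intersection(original, original, k=2))
print(top_k_intersection(original, shifted, k=2))
```

The shifted map highlights a region right next to the original one, yet the two top-2 sets are disjoint, so the metric treats a small local shift the same as a completely different explanation.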
On Counterfactual Data Augmentation Under Confounding
Reddy, Abbavaram Gowtham, Bachu, Saketh, Dash, Saloni, Sharma, Charchit, Sharma, Amit, Balasubramanian, Vineeth N
Counterfactual data augmentation has recently emerged as a method to mitigate confounding biases in the training data. These biases, such as spurious correlations, arise due to various observed and unobserved confounding variables in the data generation process. In this paper, we formally analyze how confounding biases impact downstream classifiers and present a causal viewpoint on solutions based on counterfactual data augmentation. We explore how removing confounding biases serves as a means to learn invariant features, ultimately aiding in generalization beyond the observed data distribution. Additionally, we present a straightforward yet powerful algorithm for generating counterfactual images, which effectively mitigates the influence of confounding effects on downstream classifiers. Through experiments on MNIST variants and the CelebA dataset, we demonstrate how our simple augmentation method helps existing state-of-the-art methods achieve good results.
MADG: Margin-based Adversarial Learning for Domain Generalization
Dayal, Aveen, B., Vimal K., Cenkeramaddi, Linga Reddy, Mohan, C. Krishna, Kumar, Abhinav, Balasubramanian, Vineeth N
Domain Generalization (DG) techniques have emerged as a popular approach to address the challenges of domain shift in Deep Learning (DL), with the goal of generalizing well to a target domain unseen during training. In recent years, numerous methods have been proposed to address the DG setting, among which one popular approach is the adversarial learning-based methodology. The main idea behind adversarial DG methods is to learn domain-invariant features by minimizing a discrepancy metric. However, most adversarial DG methods use a 0-1 loss-based $\mathcal{H}\Delta\mathcal{H}$ divergence metric. In contrast, a margin loss-based discrepancy metric has the following advantages: it is more informative, tighter, practical, and efficiently optimizable. To bridge this gap, this work proposes a novel adversarial learning DG algorithm, MADG, motivated by a margin loss-based discrepancy metric. The proposed MADG model learns domain-invariant features across all source domains and uses adversarial training to generalize well to the unseen target domain. We also provide a theoretical analysis of the proposed MADG model based on the unseen target error bound. Specifically, we construct the link between the source and unseen domains in the real-valued hypothesis space and derive the generalization bound using margin loss and Rademacher complexity. We extensively experiment with the MADG model on popular real-world DG datasets: VLCS, PACS, OfficeHome, DomainNet, and TerraIncognita. We evaluate the proposed algorithm on DomainBed's benchmark and observe consistent performance across all the datasets.