Wu, Yinjun
DISCRET: Synthesizing Faithful Explanations For Treatment Effect Estimation
Wu, Yinjun, Keoliya, Mayank, Chen, Kan, Velingker, Neelay, Li, Ziyang, Getzen, Emily J, Long, Qi, Naik, Mayur, Parikh, Ravi B, Wong, Eric
Designing faithful yet accurate AI models is challenging, particularly in the field of individual treatment effect estimation (ITE). ITE prediction models deployed in critical settings such as healthcare should ideally be (i) accurate, and (ii) provide faithful explanations. However, current solutions are inadequate: state-of-the-art black-box models do not supply explanations, post-hoc explainers for black-box models lack faithfulness guarantees, and self-interpretable models greatly compromise accuracy. To address these issues, we propose DISCRET, a self-interpretable ITE framework that synthesizes faithful, rule-based explanations for each sample. A key insight behind DISCRET is that explanations can serve dually as database queries to identify similar subgroups of samples. We provide a novel RL algorithm to efficiently synthesize these explanations from a large search space. We evaluate DISCRET on diverse tasks involving tabular, image, and text data. DISCRET outperforms the best self-interpretable models and has accuracy comparable to the best black-box models while providing faithful explanations. DISCRET is available at https://github.com/wuyinjun-1993/DISCRET-ICML2024.
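The observation that a rule-based explanation can double as a database query can be illustrated with a minimal sketch. Everything specific here is an assumption made for illustration: the record fields (`age`, `bmi`, `treated`, `outcome`), the rule format, and the helper names `matches` and `subgroup_effect` are hypothetical. The rule is hand-written rather than synthesized by the RL algorithm the abstract describes, and the effect is estimated with a plain difference in means, which may differ from what DISCRET does.

```python
import operator

# Hypothetical rule format: a conjunction of (feature, comparison, threshold) literals.
OPS = {"<": operator.lt, "<=": operator.le, ">": operator.gt,
       ">=": operator.ge, "==": operator.eq}

def matches(sample, rule):
    """Return True if the sample satisfies every literal in the rule."""
    return all(OPS[op](sample[feat], thr) for feat, op, thr in rule)

def subgroup_effect(samples, rule):
    """Use the rule as a query, then compare mean outcomes of treated and control samples."""
    group = [s for s in samples if matches(s, rule)]
    treated = [s["outcome"] for s in group if s["treated"]]
    control = [s["outcome"] for s in group if not s["treated"]]
    if not treated or not control:
        return None  # the retrieved subgroup cannot support an estimate
    return sum(treated) / len(treated) - sum(control) / len(control)

# Toy records, invented for this example only.
samples = [
    {"age": 70, "bmi": 31, "treated": True, "outcome": 5.0},
    {"age": 68, "bmi": 33, "treated": False, "outcome": 2.0},
    {"age": 72, "bmi": 30, "treated": True, "outcome": 7.0},
    {"age": 66, "bmi": 35, "treated": False, "outcome": 4.0},
    {"age": 40, "bmi": 22, "treated": True, "outcome": 1.0},
]

# The explanation "age > 65 AND bmi >= 30" is also the query that selects the subgroup.
rule = [("age", ">", 65), ("bmi", ">=", 30)]
effect = subgroup_effect(samples, rule)
```

Because the same rule both selects the subgroup and is shown to the user, the explanation is tied directly to the samples that produced the estimate, which is the sense of faithfulness this sketch is meant to convey.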
MDB: Interactively Querying Datasets and Models
Naik, Aaditya, Stein, Adam, Wu, Yinjun, Wong, Eric, Naik, Mayur
As models are trained and deployed, developers need to be able to systematically debug errors that emerge in the machine learning pipeline. We present MDB, a debugging framework for interactively querying datasets and models. MDB integrates functional programming with relational algebra to build expressive queries over a database of datasets and model predictions. Queries are reusable and easily modified, enabling debuggers to rapidly iterate and refine queries to discover and characterize errors and model behaviors. We evaluate MDB on object detection, bias discovery, image classification, and data imputation tasks across self-driving videos, large language models, and medical records. Our experiments show that MDB enables up to 10x faster and 40% shorter queries than other baselines. In a user study, we find developers can successfully construct complex queries that describe errors of machine learning models.
Do Machine Learning Models Learn Statistical Rules Inferred from Data?
Naik, Aaditya, Wu, Yinjun, Naik, Mayur, Wong, Eric
Machine learning models can make critical errors that are easily hidden within vast amounts of data. Such errors often run counter to rules based on human intuition. However, rules based on human knowledge are challenging to scale or to even formalize. We thereby seek to infer statistical rules from the data and quantify the extent to which a model has learned them. We propose a framework SQRL that integrates logic-based methods with statistical inference to derive these rules from a model's training data without supervision. We further show how to adapt models at test time to reduce rule violations and produce more coherent predictions. SQRL generates up to 300K rules over datasets from vision, tabular, and language settings. We uncover up to 158K violations of those rules by state-of-the-art models for classification, object detection, and data imputation. Test-time adaptation reduces these violations by up to 68.7% with relative performance improvement up to 32%. SQRL is available at https://github.com/DebugML/sqrl.
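The workflow of inferring rules from training data and then counting how often model outputs violate them can be sketched in a few lines. This is a simplified illustration under stated assumptions, not SQRL's rule language or its statistical inference: the only rule form used here is "within a group, a value stays inside the range observed in training", and the function names, field names, and numbers are hypothetical.

```python
from collections import defaultdict

def infer_bounds(training_rows, group_key, value_key):
    """Infer one rule per group: the value lies within the range seen in training."""
    grouped = defaultdict(list)
    for row in training_rows:
        grouped[row[group_key]].append(row[value_key])
    return {group: (min(values), max(values)) for group, values in grouped.items()}

def count_violations(predictions, bounds, group_key, value_key):
    """Count predictions that fall outside the inferred range for their group."""
    violations = 0
    for row in predictions:
        if row[group_key] not in bounds:
            continue  # no rule was inferred for this group
        low, high = bounds[row[group_key]]
        if not low <= row[value_key] <= high:
            violations += 1
    return violations

# Toy training data: bounding-box widths per object class (invented values).
training = [
    {"label": "car", "width": 1.5},
    {"label": "car", "width": 2.0},
    {"label": "car", "width": 1.8},
    {"label": "person", "width": 0.4},
    {"label": "person", "width": 0.6},
]

# Toy model outputs to check against the inferred rules.
predictions = [
    {"label": "car", "width": 2.5},     # wider than any training car
    {"label": "car", "width": 1.7},
    {"label": "person", "width": 0.5},
    {"label": "person", "width": 0.9},  # wider than any training person
    {"label": "bike", "width": 1.0},    # unseen class, so no rule applies
]

bounds = infer_bounds(training, "label", "width")
num_violations = count_violations(predictions, bounds, "label", "width")
```

A hard min/max range is brittle in the presence of outliers; a quantile-based range would be a natural variation, but that choice is outside what the abstract specifies.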
Rectifying Group Irregularities in Explanations for Distribution Shift
Stein, Adam, Wu, Yinjun, Wong, Eric, Naik, Mayur
It is well-known that real-world changes constituting distribution shift adversely affect model performance. How to characterize those changes in an interpretable manner is poorly understood. Existing techniques to address this problem take the form of shift explanations that elucidate how to map samples from the original distribution toward the shifted one by reducing the disparity between these two distributions. However, these methods can introduce group irregularities, leading to explanations that are less feasible and robust. To address these issues, we propose Group-aware Shift Explanations (GSE), a method that produces interpretable explanations by leveraging worst-group optimization to rectify group irregularities. We demonstrate how GSE not only maintains group structures, such as demographic and hierarchical subpopulations, but also enhances feasibility and robustness in the resulting explanations in a wide range of tabular, language, and image settings.
Learning to Select Pivotal Samples for Meta Re-weighting
Wu, Yinjun, Stein, Adam, Gardner, Jacob, Naik, Mayur
Sample re-weighting strategies provide a promising mechanism to deal with imperfect training data in machine learning, such as noisily labeled or class-imbalanced data. One such strategy involves formulating a bi-level optimization problem called the meta re-weighting problem, whose goal is to optimize performance on a small set of perfect pivotal samples, called meta samples. Many approaches have been proposed to efficiently solve this problem. However, all of them assume that a perfect meta sample set is already provided, while we observe that the selection of the meta sample set is performance-critical. In this paper, we study how to learn to identify such a meta sample set from a large, imperfect training set; the selected samples are subsequently cleaned and used to optimize performance in the meta re-weighting setting. We propose a learning framework which reduces the meta sample selection problem to a weighted K-means clustering problem through rigorous theoretical analysis. We propose two clustering methods within our learning framework, the Representation-based clustering method (RBC) and the Gradient-based clustering method (GBC), for balancing performance and computational efficiency. Empirical studies demonstrate the performance advantage of our methods over various baseline methods.
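The reduction to weighted K-means can be made concrete with a small sketch. It rests on assumptions the abstract does not confirm: the per-sample weights are taken as given, the points stand in for whatever representation is clustered, the sample nearest each weighted centroid is picked as the meta sample, and the names `weighted_kmeans` and `select_meta_samples` are hypothetical. It is not the RBC or GBC method.

```python
def sq_dist(a, b):
    """Squared Euclidean distance between two points."""
    return sum((x - y) ** 2 for x, y in zip(a, b))

def weighted_kmeans(points, weights, k, iters=20):
    """Lloyd-style K-means where each centroid is the weighted mean of its members."""
    dim = len(points[0])
    centroids = [list(p) for p in points[:k]]  # deterministic init: first k points
    assign = [0] * len(points)
    for _ in range(iters):
        # Assignment step: each point goes to its nearest centroid.
        assign = [min(range(k), key=lambda c: sq_dist(p, centroids[c])) for p in points]
        # Update step: weighted mean of the members of each cluster.
        for c in range(k):
            members = [i for i, a in enumerate(assign) if a == c]
            total = sum(weights[i] for i in members)
            if total > 0:
                centroids[c] = [
                    sum(weights[i] * points[i][d] for i in members) / total
                    for d in range(dim)
                ]
    return centroids, assign

def select_meta_samples(points, weights, k):
    """Pick, for each cluster, the index of the sample closest to its centroid."""
    centroids, assign = weighted_kmeans(points, weights, k)
    selected = []
    for c in range(k):
        members = [i for i, a in enumerate(assign) if a == c]
        if members:
            selected.append(min(members, key=lambda i: sq_dist(points[i], centroids[c])))
    return sorted(selected)

# Toy 2-D representations forming two well-separated groups (invented values).
points = [(0, 0), (10, 10), (0, 1), (1, 0), (10, 11), (11, 10)]
uniform = [1, 1, 1, 1, 1, 1]
skewed = [1, 1, 1, 4, 1, 1]  # sample 3 is weighted more heavily

picked_uniform = select_meta_samples(points, uniform, 2)
picked_skewed = select_meta_samples(points, skewed, 2)
```

With uniform weights the selection is the sample nearest each plain centroid; raising the weight of sample 3 pulls its cluster's centroid toward it and changes which sample is selected.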
Dynamic Gaussian Mixture based Deep Generative Model For Robust Forecasting on Sparse Multivariate Time Series
Wu, Yinjun, Ni, Jingchao, Cheng, Wei, Zong, Bo, Song, Dongjin, Chen, Zhengzhang, Liu, Yanchi, Zhang, Xuchao, Chen, Haifeng, Davidson, Susan
Forecasting on sparse multivariate time series (MTS) aims to model the predictors of future values of time series given their incomplete past, which is important for many emerging applications. However, most existing methods process MTS's individually, and do not leverage the dynamic distributions underlying the MTS's, leading to sub-optimal results when the sparsity is high. To address this challenge, we propose a novel generative model, which tracks the transition of latent clusters, instead of isolated feature representations, to achieve robust modeling. It is characterized by a newly designed dynamic Gaussian mixture distribution, which captures the dynamics of clustering structures, and is used for emitting time series. The generative model is parameterized by neural networks. A structured inference network is also designed for enabling inductive analysis. A gating mechanism is further introduced to dynamically tune the Gaussian mixture distributions. Extensive experimental results on a variety of real-life datasets demonstrate the effectiveness of our method.
DeltaGrad: Rapid retraining of machine learning models
Wu, Yinjun, Dobriban, Edgar, Davidson, Susan B.
Machine learning models are not static and may need to be retrained on slightly changed datasets, for instance, with the addition or deletion of a set of data points. This has many applications, including privacy, robustness, bias reduction, and uncertainty quantification. However, it is expensive to retrain models from scratch. To address this problem, we propose the DeltaGrad algorithm for rapid retraining of machine learning models based on information cached during the training phase. We provide both theoretical and empirical support for the effectiveness of DeltaGrad, and show that it compares favorably to the state of the art.