Wald, Yoav
Time After Time: Deep-Q Effect Estimation for Interventions on When and What to do
Wald, Yoav, Goldstein, Mark, Efroni, Yonathan, van Amsterdam, Wouter A. C., Ranganath, Rajesh
Problems in fields such as healthcare, robotics, and finance require reasoning about the value both of what decision or action to take and when to take it. The prevailing hope is that artificial intelligence will support such decisions by estimating the causal effect of policies, such as how to treat patients or how to allocate resources over time. However, existing methods for estimating the effect of a policy struggle with \emph{irregular time}: they either discretize time or disregard the effect of timing policies. We present Earliest Disagreement Q-Evaluation (EDQ), a new deep-Q algorithm that estimates the effect of both when and what to do. EDQ uses a recursion for the Q-function that is compatible with flexible sequence models, such as transformers, and provides accurate estimates under standard assumptions. We validate the approach through experiments on survival time and tumor growth tasks.
Novel Node Category Detection Under Subpopulation Shift
Chung, Hsing-Huan, Chaudhari, Shravan, Wald, Yoav, Han, Xing, Ghosh, Joydeep
It is often important to detect nodes of novel categories under distribution shifts, for safety or insight-discovery purposes. We introduce a new approach, Recall-Constrained Optimization with Selective Link Prediction (RECO-SLIP), to detect nodes belonging to novel categories in attributed graphs under subpopulation shifts. By integrating a recall-constrained learning framework with a sample-efficient link prediction mechanism, RECO-SLIP addresses two challenges: resilience against subpopulation shifts and effective exploitation of graph structure. Our extensive empirical evaluation across multiple graph datasets demonstrates the superior performance of RECO-SLIP over existing methods.
Data Augmentations for Improved (Large) Language Model Generalization
Feder, Amir, Wald, Yoav, Shi, Claudia, Saria, Suchi, Blei, David
The reliance of text classifiers on spurious correlations can lead to poor generalization at deployment, raising concerns about their use in safety-critical domains such as healthcare. In this work, we propose to use counterfactual data augmentation, guided by knowledge of the causal structure of the data, to simulate interventions on spurious features and to learn more robust text classifiers. We show that this strategy is appropriate in prediction problems where the label is spuriously correlated with an attribute. Under the assumptions of such problems, we discuss the favorable sample complexity of counterfactual data augmentation compared to importance re-weighting. In practice, we match examples using auxiliary data, following a difference-in-differences (diff-in-diff) methodology, and use a large language model (LLM) to represent a conditional probability of text. Through extensive experiments on learning caregiver-invariant predictors of clinical diagnoses from medical narratives and on semi-synthetic data, we demonstrate that our method for simulating interventions improves out-of-distribution (OOD) accuracy compared to baseline invariant learning algorithms.
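To make the contrast between counterfactual augmentation and importance re-weighting concrete, here is a minimal numeric sketch under assumptions of our own. The data are toy data with a binary label `y`, a binary spurious attribute `a` that agrees with `y` 90% of the time, and one causal feature. We also assume we can intervene on `a` directly by flipping it. In text this step is much harder, and the abstract describes using an LLM for it. The helper names (`make_data`, `counterfactual_augment`, `reweighting_weights`) are hypothetical and are not the authors' code.

```python
import numpy as np


def make_data(n, rho, rng):
    """Toy data: binary label y, a binary spurious attribute a that agrees
    with y with probability rho, and a causal feature z = y + noise."""
    y = rng.integers(0, 2, size=n)
    a = np.where(rng.random(n) < rho, y, 1 - y)
    z = y + 0.5 * rng.standard_normal(n)
    return np.column_stack([z, a]).astype(float), y


def counterfactual_augment(x, y):
    """Append a copy of every example with the spurious attribute (column 1)
    flipped and the label kept fixed. This assumes the attribute can be
    intervened on directly, which is an idealization for text."""
    x_cf = x.copy()
    x_cf[:, 1] = 1.0 - x_cf[:, 1]
    return np.vstack([x, x_cf]), np.concatenate([y, y])


def reweighting_weights(a, y):
    """Importance weights w(y, a) = P(y) P(a) / P(y, a), estimated from
    empirical frequencies, so that a and y are independent after weighting."""
    w = np.zeros(len(y))
    for yv in (0, 1):
        for av in (0, 1):
            mask = (y == yv) & (a == av)
            if mask.any():
                w[mask] = (y == yv).mean() * (a == av).mean() / mask.mean()
    return w


rng = np.random.default_rng(0)
x, y = make_data(1000, rho=0.9, rng=rng)
a = x[:, 1]

# Counterfactual augmentation: P(a = 1 | y) becomes exactly 0.5 for both labels.
x_aug, y_aug = counterfactual_augment(x, y)
aug_rates = [x_aug[y_aug == yv, 1].mean() for yv in (0, 1)]

# Re-weighting: the weighted P(a = 1 | y) equals the marginal P(a = 1),
# but the rare (a != y) examples receive large weights.
w = reweighting_weights(a, y)
rw_rates = [np.average(a[y == yv], weights=w[y == yv]) for yv in (0, 1)]
print(aug_rates, rw_rates, a.mean(), w.max())
```

Both procedures remove the dependence between `a` and `y`. Re-weighting does so by up-weighting the rare groups, here by a factor of roughly 5, which raises variance. Augmentation instead adds new examples. This is only an illustration of that intuition, not the paper's sample-complexity analysis.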
Don't blame Dataset Shift! Shortcut Learning due to Gradients and Cross Entropy
Puli, Aahlad, Zhang, Lily, Wald, Yoav, Ranganath, Rajesh
Common explanations for shortcut learning assume that the shortcut improves prediction under the training distribution but not under the test distribution. Thus, models trained via the typical gradient-based optimization of cross-entropy, which we call default-ERM, utilize the shortcut. However, even when the stable feature determines the label in the training distribution and the shortcut provides no additional information, as in perception tasks, default-ERM still exhibits shortcut learning. Why are such solutions preferred when the loss for default-ERM can be driven to zero using the stable feature alone? By studying a linear perception task, we show that default-ERM's preference for maximizing the margin leads to models that depend more on the shortcut than on the stable feature, even without overparameterization. This insight suggests that default-ERM's implicit inductive bias toward max-margin is unsuitable for perception tasks. Instead, we develop an inductive bias toward uniform margins and show that this bias guarantees dependence only on the perfect stable feature in the linear perception task. We develop loss functions that encourage uniform-margin solutions, called margin control (MARG-CTRL). MARG-CTRL mitigates shortcut learning on a variety of vision and language tasks, showing that better inductive biases can remove the need for expensive two-stage shortcut-mitigating methods in perception tasks.
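The max-margin argument can be illustrated by looking at a single example's loss as a function of its margin m = y f(x). Cross-entropy (logistic loss) keeps decreasing as m grows and never reaches zero, so gradient-based training keeps rewarding larger margins. For contrast, we add a quadratic penalty on the margin, similar in spirit to spectral decoupling. This is our own assumed choice and not necessarily one of the MARG-CTRL losses. With the penalty, the loss has a single finite minimizer, so every example is pulled toward the same target margin. The helper `target_margin` and the value of `lam` are hypothetical.

```python
import math


def logistic_loss(m):
    """Cross-entropy for one binary example with margin m = y * f(x)."""
    return math.log1p(math.exp(-m))


def penalized_loss(m, lam):
    """Logistic loss plus an assumed quadratic penalty on the margin."""
    return logistic_loss(m) + 0.5 * lam * m * m


def penalized_grad(m, lam):
    # d/dm log(1 + e^{-m}) = -1 / (1 + e^{m}); the penalty adds lam * m.
    return -1.0 / (1.0 + math.exp(m)) + lam * m


def target_margin(lam, lo=0.0, hi=100.0, iters=100):
    """Bisection for the unique root of penalized_grad, which is strictly
    increasing in m (valid for lam > 0 with the default bracket)."""
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        if penalized_grad(mid, lam) < 0:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)


# Logistic loss keeps shrinking as the margin grows; it never hits zero.
for m in (1.0, 5.0, 10.0):
    print(m, logistic_loss(m))

# With the penalty, each example is pushed toward one finite margin.
for lam in (0.01, 0.1):
    print(lam, target_margin(lam))
```

An example whose margin exceeds the target is pushed back down, and one below it is pushed up. This is the "uniform margins" intuition in its simplest scalar form.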
Malign Overfitting: Interpolation Can Provably Preclude Invariance
Wald, Yoav, Yona, Gal, Shalit, Uri, Carmon, Yair
Learned classifiers should often possess certain invariance properties meant to encourage fairness, robustness, or out-of-distribution generalization. However, multiple recent works empirically demonstrate that common invariance-inducing regularizers are ineffective in the over-parameterized regime, in which classifiers perfectly fit (i.e., interpolate) the training data. This suggests that the phenomenon of ``benign overfitting,'' in which models generalize well despite interpolating, might not extend favorably to settings in which robustness or fairness is desirable. In this work we provide a theoretical justification for these observations. We prove that, even in the simplest of settings, any interpolating learning rule (with arbitrarily small margin) will not satisfy these invariance properties. We then propose and analyze an algorithm that, in the same setting, successfully learns a non-interpolating classifier that is provably invariant. We validate our theoretical observations on simulated data and the Waterbirds dataset.
Explaining in Style: Training a GAN to explain a classifier in StyleSpace
Lang, Oran, Gandelsman, Yossi, Yarom, Michal, Wald, Yoav, Elidan, Gal, Hassidim, Avinatan, Freeman, William T., Isola, Phillip, Globerson, Amir, Irani, Michal, Mosseri, Inbar
Image classification models can depend on multiple different semantic attributes of the image. An explanation of the decision of the classifier needs to both discover and visualize these properties. Here we present StylEx, a method that does so by training a generative model to specifically explain multiple attributes that underlie classifier decisions. A natural source for such attributes is the StyleSpace of StyleGAN, which is known to generate semantically meaningful dimensions in the image. However, because standard GAN training does not depend on the classifier, it may not represent the attributes that are important for the classifier decision, and the dimensions of StyleSpace may represent irrelevant attributes. To overcome this, we propose a training procedure for a StyleGAN that incorporates the classifier model, in order to learn a classifier-specific StyleSpace. Explanatory attributes are then selected from this space. These can be used to visualize the effect of changing multiple attributes per image, thus providing image-specific explanations. We apply StylEx to multiple domains, including animals, leaves, faces and retinal images. For these, we show how an image can be modified in different ways to change its classifier output. Our results show that the method finds attributes that align well with semantic ones, generate meaningful image-specific explanations, and are human-interpretable as measured in user studies.
Robust Conditional Probabilities
Wald, Yoav, Globerson, Amir
Conditional probabilities are a core concept in machine learning. For example, optimal prediction of a label $Y$ given an input $X$ corresponds to maximizing the conditional probability of $Y$ given $X$. A common approach to inference tasks is learning a model of conditional probabilities. However, these models are often based on strong assumptions (e.g., log-linear models), and hence their estimate of conditional probabilities is not robust and is highly dependent on the validity of their assumptions. Here we propose a framework for reasoning about conditional probabilities without assuming anything about the underlying distributions, except knowledge of their second order marginals, which can be estimated from data. We show how this setting leads to guaranteed bounds on conditional probabilities, which can be calculated efficiently in a variety of settings, including structured-prediction. Finally, we apply them to semi-supervised deep learning, obtaining results competitive with variational autoencoders.
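The idea of bounding a conditional probability using only second-order marginals can be shown on a toy instance. The setup is our own assumption and not the paper's algorithm. We take three binary variables, a label `y` and two inputs `x1` and `x2`, and assume all three pairwise marginals are known exactly. We then brute-force a linear program over the full 8-state joint. Because P(x1, x2) is fixed by its marginal, P(y=1 | x1, x2) is linear in the joint. Minimizing and maximizing it over all consistent joints therefore gives guaranteed bounds. This brute-force approach only scales to tiny problems; the efficient and structured-prediction algorithms mentioned in the abstract are not reproduced here. Names such as `conditional_bounds` are hypothetical.

```python
import itertools

import numpy as np
from scipy.optimize import linprog

# Joint states of three binary variables, ordered as (y, x1, x2).
STATES = list(itertools.product([0, 1], repeat=3))
PAIRS = [(0, 1), (0, 2), (1, 2)]


def pairwise_marginals(p):
    """Map each pair of variable indices (i, j) to the 2x2 table P(v_i, v_j)."""
    margs = {}
    for i, j in PAIRS:
        m = np.zeros((2, 2))
        for prob, s in zip(p, STATES):
            m[s[i], s[j]] += prob
        margs[(i, j)] = m
    return margs


def conditional_bounds(margs, x1, x2):
    """Min and max of P(y=1 | x1, x2) over all joints with the given pairwise marginals."""
    a_eq, b_eq = [], []
    for (i, j), m in margs.items():
        for u in (0, 1):
            for v in (0, 1):
                a_eq.append([1.0 if (s[i], s[j]) == (u, v) else 0.0 for s in STATES])
                b_eq.append(m[u, v])
    # P(x1, x2) is pinned down by the (1, 2) marginal, so the conditional
    # probability is a linear function of the joint and the problem is an LP.
    p_x = margs[(1, 2)][x1, x2]
    c = np.array([1.0 if s == (1, x1, x2) else 0.0 for s in STATES]) / p_x
    lo = linprog(c, A_eq=a_eq, b_eq=b_eq, bounds=(0, 1))
    hi = linprog(-c, A_eq=a_eq, b_eq=b_eq, bounds=(0, 1))
    if not (lo.success and hi.success):
        raise RuntimeError("LP failed: " + lo.message + " / " + hi.message)
    return lo.fun, -hi.fun


rng = np.random.default_rng(0)
p_true = rng.dirichlet(np.ones(len(STATES)))
margs = pairwise_marginals(p_true)
lower, upper = conditional_bounds(margs, x1=1, x2=0)
true_value = p_true[STATES.index((1, 1, 0))] / margs[(1, 2)][1, 0]
print(f"{lower:.3f} <= {true_value:.3f} <= {upper:.3f}")
```

The distribution that generated the marginals is one feasible point of the LP, so its conditional probability always lies inside the returned interval. The interval's width shows how much the pairwise marginals alone leave undetermined.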