Collaborating Authors: Chaudhari, Pratik


Adapting Machine Learning Diagnostic Models to New Populations Using a Small Amount of Data: Results from Clinical Neuroscience

arXiv.org Artificial Intelligence

Machine learning (ML) has shown great promise for revolutionizing a number of areas, including healthcare. However, it is also facing a reproducibility crisis, especially in medicine. ML models that are carefully constructed from and evaluated on a training set might not generalize well to data from different patient populations or acquisition instrument settings and protocols. We tackle this problem in the context of neuroimaging of Alzheimer's disease (AD), schizophrenia (SZ) and brain aging. We develop a weighted empirical risk minimization approach that optimally combines data from a source group (e.g., subjects stratified by attributes such as sex, age group, race and clinical cohort) to make predictions on a target group (e.g., a different sex or age group), using a small fraction (10%) of data from the target group. We apply this method to multi-source data of 15,363 individuals from 20 neuroimaging studies to build ML models for diagnosis of AD and SZ, and estimation of brain age. We found that this approach achieves substantially better accuracy than existing domain adaptation techniques: it obtains an area under the curve (AUC) greater than 0.95 for AD classification, an AUC greater than 0.7 for SZ classification, and a mean absolute error of less than 5 years for brain age prediction on all target groups, achieving robustness to variations of scanners, protocols, and demographic or clinical characteristics. In some cases, it is even better than training on all data from the target group, because it leverages the diversity and size of a larger training set. We also demonstrate the utility of our models for prognostic tasks such as predicting disease progression in individuals with mild cognitive impairment. Critically, our brain age prediction models lead to new clinical insights regarding correlations with neurophysiological tests.
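
As a rough illustration of the weighted empirical risk minimization described above, the sketch below fits a logistic-regression model to a convex combination of the source and target risks. The mixing weight alpha, the model choice, and all helper names are illustrative assumptions; the paper chooses the combination optimally rather than with a fixed weight.

```python
import numpy as np

def weighted_erm(X_src, y_src, X_tgt, y_tgt, alpha=0.5, lr=0.1, steps=2000):
    """Fit logistic regression by minimizing
    alpha * (target risk) + (1 - alpha) * (source risk).
    alpha is a hypothetical fixed weight standing in for the paper's
    optimally chosen combination."""
    d = X_src.shape[1]
    w, b = np.zeros(d), 0.0
    for _ in range(steps):
        gw, gb = np.zeros(d), 0.0
        for X, y, wt in [(X_src, y_src, 1 - alpha), (X_tgt, y_tgt, alpha)]:
            p = 1.0 / (1.0 + np.exp(-(X @ w + b)))  # predicted probabilities
            gw += wt * (X.T @ (p - y)) / len(y)     # gradient of weighted log-loss
            gb += wt * np.mean(p - y)
        w -= lr * gw
        b -= lr * gb
    return w, b
```

In the setting of the abstract, X_tgt and y_tgt would hold the small (10%) labeled fraction of data from the target group.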


A picture of the space of typical learnable tasks

arXiv.org Artificial Intelligence

We develop information geometric techniques to understand the representations learned by deep networks when they are trained on different tasks using supervised, meta-, semi-supervised, and contrastive learning. We shed light on the following phenomena that relate to the structure of the space of tasks: (1) the manifold of probabilistic models trained on different tasks using different representation learning methods is effectively low-dimensional; (2) supervised learning on one task results in a surprising amount of progress even on seemingly dissimilar tasks, and progress on other tasks is larger if the training task has diverse classes; (3) the structure of the space of tasks indicated by our analysis is consistent with parts of the WordNet phylogenetic tree; (4) episodic meta-learning algorithms and supervised learning traverse different trajectories during training but they fit similar models eventually; (5) contrastive and semi-supervised learning methods traverse trajectories similar to those of supervised learning. We use classification tasks constructed from the CIFAR-10 and ImageNet datasets to study these phenomena.
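
A minimal sketch of the flavor of information-geometric analysis described above: treat each trained network as a probabilistic model via its predictive distributions on a probe set, compute pairwise Bhattacharyya distances, and embed the models with a classical-MDS-style projection. The distance choice and the embedding are illustrative stand-ins, not the paper's exact construction.

```python
import numpy as np

def bhattacharyya(P, Q):
    """Mean Bhattacharyya distance between two models' predictive
    distributions P, Q of shape (num_samples, num_classes)."""
    return -np.log(np.sum(np.sqrt(P * Q), axis=1)).mean()

def embed_models(models_probs, k=2):
    """Embed models into k dimensions from pairwise distances
    (a crude stand-in for an InPCA-like procedure)."""
    n = len(models_probs)
    D = np.array([[bhattacharyya(P, Q) for Q in models_probs]
                  for P in models_probs])
    J = np.eye(n) - np.ones((n, n)) / n       # centering matrix
    W = -0.5 * J @ D @ J                      # double-centered Gram matrix
    vals, vecs = np.linalg.eigh(W)
    idx = np.argsort(np.abs(vals))[::-1][:k]  # top components by |eigenvalue|
    return vecs[:, idx] * np.sqrt(np.abs(vals[idx]))
```

If the manifold of models is effectively low-dimensional, most of the spectrum of W concentrates in a handful of eigenvalues.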


The Value of Out-of-Distribution Data

arXiv.org Artificial Intelligence

We expect the generalization error to improve with more samples from a similar task, and to deteriorate with more samples from an out-of-distribution (OOD) task. In this work, we show a counter-intuitive phenomenon: the generalization error of a task can be a non-monotonic function of the number of OOD samples. As the number of OOD samples increases, the generalization error on the target task improves before deteriorating beyond a threshold. In other words, there is value in training on small amounts of OOD data. We use Fisher's Linear Discriminant on synthetic datasets and deep networks on computer vision benchmarks such as MNIST, CIFAR-10, CINIC-10, PACS and DomainNet to demonstrate and analyze this phenomenon. In the idealistic setting where we know which samples are OOD, we show that these non-monotonic trends can be exploited using an appropriately weighted objective of the target and OOD empirical risk. While its practical utility is limited, this does suggest that if we can detect OOD samples, then there may be ways to benefit from them. When we do not know which samples are OOD, we show how a number of go-to strategies such as data augmentation, hyperparameter optimization, and pre-training are not enough to ensure that the target generalization error does not deteriorate with the number of OOD samples in the dataset.
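
The non-monotonic trend can be probed with a small synthetic experiment in the spirit of the Fisher's Linear Discriminant setup mentioned in the abstract. The Gaussian class model, the OOD mean shift, and all constants below are illustrative assumptions, not the paper's exact setup.

```python
import numpy as np

def fld_test_error(n_ood, n_target=20, d=10, ood_shift=2.0, n_test=5000,
                   trials=100, seed=0):
    """Average target-distribution test error of Fisher's Linear Discriminant
    trained on n_target target samples plus n_ood OOD samples whose class
    means are shifted by a hypothetical amount ood_shift."""
    rng = np.random.default_rng(seed)
    mu = np.ones(d) / np.sqrt(d)                 # target class means are +/- mu
    delta = np.zeros(d); delta[0] = ood_shift    # OOD class means are shifted
    def sample(n, center):
        y = rng.integers(0, 2, n)
        X = center * (2 * y - 1)[:, None] + rng.normal(size=(n, d))
        return X, y
    errs = []
    for _ in range(trials):
        Xt, yt = sample(n_target, mu)
        Xo, yo = sample(n_ood, mu + delta)
        X, y = np.vstack([Xt, Xo]), np.concatenate([yt, yo])
        m0, m1 = X[y == 0].mean(0), X[y == 1].mean(0)
        S = np.cov(X.T) + 1e-3 * np.eye(d)       # regularized pooled covariance
        w = np.linalg.solve(S, m1 - m0)
        b = -0.5 * w @ (m0 + m1)
        Xs, ys = sample(n_test, mu)
        errs.append(np.mean((Xs @ w + b > 0).astype(int) != ys))
    return np.mean(errs)

for n in [0, 5, 20, 100, 500]:
    print(n, fld_test_error(n))
```

Sweeping n_ood and plotting the returned error is one way to look for the improve-then-deteriorate pattern described in the abstract.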


Budgeting Counterfactual for Offline RL

arXiv.org Artificial Intelligence

The main challenge of offline reinforcement learning, where data is limited, arises from a sequence of counterfactual reasoning dilemmas within the realm of potential actions: what if we were to choose a different course of action? These circumstances frequently give rise to extrapolation errors, which tend to accumulate exponentially with the problem horizon. Hence, it becomes crucial to acknowledge that not all decision steps are equally important to the final outcome, and to budget the number of counterfactual decisions a policy makes in order to control extrapolation. In contrast to existing approaches that use regularization on either the policy or value function, we propose an approach to explicitly bound the amount of out-of-distribution actions during training. Specifically, our method utilizes dynamic programming to decide where to extrapolate and where not to, with an upper bound on the number of decisions that differ from the behavior policy. It balances the potential for improvement from taking out-of-distribution actions against the risk of making errors due to extrapolation. Theoretically, we justify our method by the constrained optimality of the fixed point solution to our Q updating rules. Empirically, we show that the overall performance of our method is better than that of state-of-the-art offline RL methods on tasks in the widely used D4RL benchmarks.
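
A hedged, tabular reconstruction of the budgeting idea: augment the state with a remaining budget b, and allow the backup to deviate from the behavior action only while budget remains. The update rule, names, and shapes below are assumptions for illustration, not the paper's exact algorithm.

```python
import numpy as np

def budgeted_q_sweep(Q, dataset, behavior, budget, gamma=0.99, lr=0.1):
    """One sweep of a tabular, budget-augmented Q-update. Q has shape
    (num_states, budget + 1, num_actions); the budget index b counts how many
    deviations from the behavior policy remain. dataset is a list of
    (s, a, r, s2) transitions; behavior maps state -> behavior action."""
    for s, a, r, s2 in dataset:
        for b in range(budget + 1):
            follow = Q[s2, b, behavior[s2]]        # stick to the behavior action
            if b > 0:
                deviate = Q[s2, b - 1].max()       # greedy action costs one unit
                v_next = max(follow, deviate)
            else:
                v_next = follow                    # no budget left: must follow
            Q[s, b, a] += lr * (r + gamma * v_next - Q[s, b, a])
    return Q
```

The max over "follow" and "deviate" is where dynamic programming decides at which steps extrapolation is worth spending the budget.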


The Training Process of Many Deep Networks Explores the Same Low-Dimensional Manifold

arXiv.org Artificial Intelligence

We develop information-geometric techniques to analyze the trajectories of the predictions of deep networks during training. By examining the underlying high-dimensional probabilistic models, we reveal that the training process explores an effectively low-dimensional manifold. Networks with a wide range of architectures and sizes, trained using different optimization methods, regularization techniques, data augmentation techniques, and weight initializations, lie on the same manifold in the prediction space. We study the details of this manifold to find that networks with different architectures follow distinguishable trajectories, but other factors have a minimal influence; larger networks train along a similar manifold as smaller networks, just faster; and networks initialized at very different parts of the prediction space converge to the solution along a similar manifold.
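
One simple way to trace such a trajectory in prediction space, assuming predictive probabilities have been saved at a series of training checkpoints. The Bhattacharyya-style distance to the final model is an illustrative choice, not necessarily the geometry used in the paper.

```python
import numpy as np

def trajectory_to_final(checkpoint_probs):
    """Given predictive probabilities of shape
    (num_checkpoints, num_samples, num_classes), return the mean
    Bhattacharyya distance of each checkpoint to the final model, which
    traces the training trajectory in prediction space."""
    P = np.asarray(checkpoint_probs)
    final = P[-1]                                          # last checkpoint
    affinity = np.sum(np.sqrt(P * final), axis=2).mean(axis=1)
    return -np.log(np.clip(affinity, 1e-12, 1.0))          # per-checkpoint distance
```

Stacking such per-checkpoint distances for many networks, and embedding them as in the task-space sketch above, is one way to check whether different training runs trace out the same low-dimensional manifold.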


Taming AI Bots: Controllability of Neural States in Large Language Models

arXiv.org Artificial Intelligence

We tackle the question of whether an agent can, by suitable choice of prompts, control an AI bot to any state. To that end, we first introduce a formal definition of "meaning" that is amenable to analysis. Then, we characterize "meaningful data" on which large language models (LLMs) are ostensibly trained, and "well-trained LLMs" through conditions that are largely met by today's LLMs. While a well-trained LLM constructs an embedding space of meanings that is Euclidean, meanings themselves do not form a vector (linear) subspace, but rather a quotient space within it. We then characterize the subset of meanings that can be reached by the state of the LLM for some input prompt, and show that a well-trained bot can reach any meaning, albeit with small probability. We then introduce a stronger notion of controllability as almost certain reachability, and show that, when restricted to the space of meanings, an AI bot is controllable. We do so after introducing a functional characterization of attentive AI bots, and finally derive necessary and sufficient conditions for controllability. The fact that AI bots are controllable means that an adversary could steer them towards any state. However, the sampling process can be designed to counteract adverse actions and avoid reaching undesirable regions of state space before their boundary is crossed.
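
As a toy illustration of reachability (far cruder than the paper's functional characterization), one can model the bot as a controlled Markov chain over a finite set of meaning-states and compute which states some prompt sequence reaches with positive probability. Everything below, including the transition tensor T, is a hypothetical abstraction.

```python
import numpy as np

def reachable_states(T, s0, tol=1e-12):
    """T[u, s, s2] is the probability of moving from meaning-state s to s2
    under prompt u. A state is reachable if some prompt sequence gives it
    positive probability; almost-certain reachability (the paper's stronger
    notion) would additionally require steering the probability to 1."""
    n = T.shape[1]
    reached, frontier = {s0}, [s0]
    while frontier:
        s = frontier.pop()
        for s2 in range(n):
            if s2 not in reached and (T[:, s, s2] > tol).any():
                reached.add(s2)       # some prompt makes this transition possible
                frontier.append(s2)
    return reached
```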


Learning Capacity: A Measure of the Effective Dimensionality of a Model

arXiv.org Artificial Intelligence

We exploit a formal correspondence between thermodynamics and inference, where the number of samples can be thought of as the inverse temperature, to define a "learning capacity", which is a measure of the effective dimensionality of a model. We show that the learning capacity is a tiny fraction of the number of parameters for many deep networks trained on typical datasets, depends upon the number of samples used for training, and is numerically consistent with notions of capacity obtained from the PAC-Bayesian framework. The test error as a function of the learning capacity does not exhibit double descent. We show that the learning capacity of a model saturates at very small and very large sample sizes; this provides guidelines as to whether one should procure more data or search for new architectures to improve performance. We show how the learning capacity can be used to understand the effective dimensionality even of non-parametric models such as random forests and k-nearest neighbor classifiers.
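
Under the thermodynamic analogy sketched above, with the number of samples n playing the role of inverse temperature and the average test loss U(n) playing the role of energy, one plausible heat-capacity-like estimator is -n^2 dU/dn. The finite-difference sketch below is an assumption-laden illustration; the paper's precise definition may differ.

```python
import numpy as np

def learning_capacity_estimate(sample_sizes, test_losses):
    """Numerically estimate a capacity-like quantity from test losses measured
    at several training-set sizes, via the (assumed) analogy
    capacity ~ -n^2 dU/dn with U(n) the average test loss."""
    n = np.asarray(sample_sizes, dtype=float)
    U = np.asarray(test_losses, dtype=float)
    dU_dn = np.gradient(U, n)      # derivative of test loss w.r.t. sample size
    return -n**2 * dU_dn
```

A flat capacity curve at very small and very large n would match the saturation behavior described in the abstract.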


Fast Diffusion Probabilistic Model Sampling through the lens of Backward Error Analysis

arXiv.org Artificial Intelligence

Denoising diffusion probabilistic models (DDPMs) are a class of powerful generative models. The past few years have witnessed the great success of DDPMs in generating high-fidelity samples. A significant limitation of DDPMs is the slow sampling procedure: they generally need hundreds or thousands of sequential function evaluations (steps) of neural networks to generate a sample. This paper aims to develop a fast sampling method for DDPMs that requires far fewer steps while retaining high sample quality. The inference process of DDPMs approximates solving the corresponding diffusion ordinary differential equations (diffusion ODEs) in the continuous limit. This work analyzes how the backward error affects the diffusion ODEs and the sample quality in DDPMs. We propose fast sampling through the Restricting Backward Error (RBE) schedule, based on dynamically moderating the long-time backward error. Our method accelerates DDPMs without any further training. Our experiments show that sampling with an RBE schedule generates high-quality samples within only 8 to 20 function evaluations on various benchmark datasets. We achieve an FID of 12.01 in 8 function evaluations on ImageNet 128×128, and a 20× speedup compared with previous baseline samplers.
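
The skeleton of few-step sampling looks roughly like the deterministic DDIM-style loop below. The eps_model signature and the quadratic timestep spacing are illustrative assumptions; the paper's RBE schedule chooses the steps by moderating backward error rather than by a fixed spacing.

```python
import torch

@torch.no_grad()
def ddim_sample(eps_model, alphas_cumprod, timesteps, shape):
    """Deterministic DDIM-style sampling along a short timestep schedule.
    eps_model(x, t) is assumed to predict the noise added at step t;
    alphas_cumprod is the usual cumulative noise schedule."""
    x = torch.randn(shape)
    for t, t_prev in zip(timesteps[:-1], timesteps[1:]):
        a_t, a_prev = alphas_cumprod[t], alphas_cumprod[t_prev]
        eps = eps_model(x, torch.full((shape[0],), int(t)))
        x0 = (x - (1 - a_t).sqrt() * eps) / a_t.sqrt()      # predicted clean sample
        x = a_prev.sqrt() * x0 + (1 - a_prev).sqrt() * eps  # step along the ODE
    return x

# A hypothetical 8-step schedule, denser near low noise levels (quadratic
# spacing as a placeholder for the RBE schedule's error-driven choice):
timesteps = (torch.linspace(1.0, 0.0, 9) ** 2 * 999).long()
```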


Beyond mAP: Towards better evaluation of instance segmentation

arXiv.org Artificial Intelligence

Correct instance segmentation requires counting the number of objects, localizing all predictions, and classifying each localized prediction correctly. Average Precision (AP) is the de facto metric used to measure all these constituents of segmentation. However, this metric does not penalize duplicate predictions in the high-recall range, and cannot distinguish instances that are localized correctly but categorized incorrectly. This weakness has inadvertently led to network designs that achieve significant gains in AP but also introduce a large number of false positives. We therefore cannot rely on AP to choose a model that provides an optimal tradeoff between false positives and high recall. To resolve this dilemma, we review alternative metrics in the literature and propose two new measures to explicitly quantify the amount of both spatial and categorical duplicate predictions. We also propose a Semantic Sorting and NMS module to remove these duplicates based on a pixel occupancy matching scheme. Experiments show that modern segmentation networks have significant gains in AP, but also produce a considerable number of duplicates. Our Semantic Sorting and NMS module can be added as a plug-and-play component to mitigate hedged predictions and preserve AP.
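
A hedged sketch of what a Semantic-Sorting-and-NMS step could look like: rank instance masks by how well they agree with a semantic segmentation map, then greedily drop masks that overlap too much with an already-kept mask. The agreement score and the IoU test below are illustrative stand-ins for the paper's pixel occupancy matching scheme.

```python
import numpy as np

def semantic_sort_nms(masks, classes, semantic_map, iou_thresh=0.5):
    """masks: list of (H, W) boolean instance masks; classes: predicted class
    per instance; semantic_map: (H, W) integer semantic segmentation.
    Returns indices of kept instances."""
    masks = np.asarray(masks, dtype=bool)
    agree = [(semantic_map[m] == c).mean() if m.any() else 0.0
             for m, c in zip(masks, classes)]     # semantic agreement score
    order = np.argsort(agree)[::-1]               # semantic sorting
    kept = []
    for i in order:
        dup = any((masks[i] & masks[j]).sum()
                  / max((masks[i] | masks[j]).sum(), 1) > iou_thresh
                  for j in kept)
        if not dup:
            kept.append(i)                        # suppress duplicate masks
    return kept
```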


Bias in Machine Learning Models Can Be Significantly Mitigated by Careful Training: Evidence from Neuroimaging Studies

arXiv.org Artificial Intelligence

Despite the great promise that machine learning has offered in many fields of medicine, it has also raised concerns about potential biases and poor generalization across genders, age distributions, races and ethnicities, hospitals, and data acquisition equipment and protocols. In the current study, and in the context of three brain diseases, we provide evidence which suggests that when properly trained, machine learning models can generalize well across diverse conditions and do not necessarily suffer from bias. Specifically, by using multi-study magnetic resonance imaging consortia for diagnosing Alzheimer's disease, schizophrenia, and autism spectrum disorder, we find that well-trained models have a high area under the curve (AUC) on subjects across different subgroups pertaining to attributes such as gender, age, racial groups, and different clinical studies, and are unbiased under multiple fairness metrics such as demographic parity difference, equalized odds difference, and equal opportunity difference. We find that models that incorporate multi-source data from demographic, clinical, and genetic factors and cognitive scores are also unbiased. These models have better predictive AUC across subgroups than those trained only with imaging features, but there are also situations in which these additional features do not help.
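
For concreteness, two of the fairness metrics named above can be computed directly from binary predictions and group labels. The max-gap aggregation across groups below is one common convention, not necessarily the one used in the paper, and it assumes each group contains both positive and negative labels.

```python
import numpy as np

def fairness_gaps(y_true, y_pred, group):
    """Return (demographic parity difference, equalized odds difference) for
    binary labels y_true, binary predictions y_pred, and subgroup labels."""
    y_true, y_pred, group = map(np.asarray, (y_true, y_pred, group))
    sel, tpr, fpr = [], [], []
    for g in np.unique(group):
        m = group == g
        sel.append(y_pred[m].mean())                  # P(yhat = 1 | group)
        tpr.append(y_pred[m & (y_true == 1)].mean())  # true positive rate
        fpr.append(y_pred[m & (y_true == 0)].mean())  # false positive rate
    dpd = max(sel) - min(sel)
    eod = max(max(tpr) - min(tpr), max(fpr) - min(fpr))
    return dpd, eod
```

Values of both gaps near zero across subgroups are what the abstract means by models being unbiased under these metrics.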