Kong, Xiangnan
End-to-End Deep Learning for Structural Brain Imaging: A Unified Framework
Su, Yao, Han, Keqi, Zeng, Mingjie, Sun, Lichao, Zhan, Liang, Yang, Carl, He, Lifang, Kong, Xiangnan
Brain imaging analysis is fundamental in neuroscience, providing valuable insights into brain structure and function. Traditional workflows follow a sequential pipeline (brain extraction, registration, segmentation, parcellation, network generation, and classification), treating each step as an independent task. These methods rely heavily on task-specific training data and expert intervention to correct intermediate errors, making them particularly burdensome for high-dimensional neuroimaging data, where annotations and quality control are costly and time-consuming. We introduce UniBrain, a unified end-to-end framework that integrates all processing steps into a single optimization process, allowing tasks to interact and refine each other. Unlike traditional approaches that require extensive task-specific annotations, UniBrain operates with minimal supervision, leveraging only low-cost labels (i.e., classification and extraction) and a single labeled atlas. By jointly optimizing extraction, registration, segmentation, parcellation, network generation, and classification, UniBrain enhances both accuracy and computational efficiency while significantly reducing annotation effort. Experimental results demonstrate its superiority over existing methods across multiple tasks, offering a more scalable and reliable solution for neuroimaging analysis.
SkipSNN: Efficiently Classifying Spike Trains with Event-attention
Yin, Hang, Su, Yao, Liu, Liping, Hartvigsen, Thomas, Dai, Xin, Kong, Xiangnan
Spike train classification has recently become an important topic in the machine learning community, where each spike train is a binary event sequence characterized by temporal sparsity of signals of interest and by temporal noise. A promising model for it should follow the design principle of performing intensive computation only when signals of interest appear. Such tasks therefore mainly use Spiking Neural Networks (SNNs), which account for the temporal sparsity of spike trains. However, the basic mechanism of SNNs ignores the temporal-noise issue, which makes them computationally expensive and thus power-hungry when analyzing spike trains on resource-constrained platforms. As an event-driven model, an SNN neuron reacts to every input signal, making it difficult to quickly find signals of interest. In this paper, we introduce an event-attention mechanism that enables SNNs to dynamically highlight useful signals of the original spike trains. To this end, we propose SkipSNN, which extends existing SNN models by learning to mask out noise by skipping membrane potential updates and shortening the effective size of the computational graph. This process is analogous to how people choose to open and close their eyes to filter the information they see. We evaluate SkipSNN on various neuromorphic tasks and demonstrate that it achieves significantly better computational efficiency and classification accuracy than other state-of-the-art SNNs.
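The idea of skipping membrane potential updates can be illustrated with a minimal sketch. This is not the SkipSNN model: it simulates a single leaky integrate-and-fire neuron, and the skip mask is hand-written here, whereas the abstract describes a mask that is learned. The function name `run_lif`, the decay and threshold values, and the toy spike train are all assumptions made for illustration.

```python
from typing import List, Optional, Tuple


def run_lif(spikes: List[int],
            skip_mask: Optional[List[bool]] = None,
            decay: float = 0.9,
            threshold: float = 2.0) -> Tuple[List[int], int]:
    """Simulate one leaky integrate-and-fire neuron over a binary spike train.

    skip_mask[t] == True means the membrane update at step t is skipped and
    the potential is carried over unchanged. Returns the output spikes and
    the number of membrane updates that were actually computed.
    """
    if skip_mask is None:
        skip_mask = [False] * len(spikes)
    potential = 0.0
    outputs: List[int] = []
    updates = 0
    for s, skip in zip(spikes, skip_mask):
        if skip:
            outputs.append(0)  # no computation and no output spike
            continue
        potential = decay * potential + s  # leak, then integrate the input
        updates += 1
        if potential >= threshold:
            outputs.append(1)
            potential = 0.0  # reset after firing
        else:
            outputs.append(0)
    return outputs, updates


# Toy train: isolated noise spikes followed by a short burst (the signal).
train = [1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0]
# Hypothetical hand-written mask that skips every step before the burst.
mask = [True] * 8 + [False] * 4

dense_out, dense_updates = run_lif(train)
skip_out, skip_updates = run_lif(train, mask)
print(dense_updates, skip_updates)
```

On this toy input the masked run computes 4 membrane updates instead of 12 while still firing once on the burst, which is the kind of saving the skipping mechanism targets.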
Multi-State Brain Network Discovery
Yin, Hang, Su, Yao, Liu, Xinyue, Hartvigsen, Thomas, Li, Yanhua, Kong, Xiangnan
Brain network discovery aims to find nodes and edges from the spatio-temporal signals obtained by neuroimaging data, such as fMRI scans of human brains. Existing methods tend to derive representative or average brain networks, assuming observed signals are generated by only a single brain activity state. However, the human brain usually involves multiple activity states, which jointly determine the brain activities. The brain regions and their connectivity usually exhibit intricate patterns that are difficult to capture with only a single-state network. Recent studies find that brain parcellation and connectivity change according to the brain activity state. We refer to such brain networks as multi-state, and this mixture can help us understand human behavior. Thus, compared to a single-state network, a multi-state network can prevent us from losing crucial information about cognitive brain networks. To achieve this, we propose a new model called MNGL (Multi-state Network Graphical Lasso), which models multi-state brain networks by combining CGL (coherent graphical lasso) with GMM (Gaussian Mixture Model). Using both synthetic and real-world ADHD-200 fMRI datasets, we demonstrate that MNGL outperforms recent state-of-the-art alternatives by discovering more explanatory and realistic results.
One-shot Joint Extraction, Registration and Segmentation of Neuroimaging Data
Su, Yao, Qian, Zhentian, Ma, Lei, He, Lifang, Kong, Xiangnan
Brain extraction, registration and segmentation are indispensable preprocessing steps in neuroimaging studies. The aim is to extract the brain from raw imaging scans (i.e., extraction step), align it with a target brain image (i.e., registration step) and label the anatomical brain regions (i.e., segmentation step). Conventional studies typically focus on developing separate methods for the extraction, registration and segmentation tasks in a supervised setting. The performance of these methods is largely contingent on the quantity of training samples and the extent of visual inspections carried out by experts for error correction. However, collecting voxel-level labels and performing manual quality control on high-dimensional neuroimages (e.g., 3D MRI) are expensive and time-consuming in many medical studies. In this paper, we study the problem of one-shot joint extraction, registration and segmentation in neuroimaging data, which exploits only one labeled template image (a.k.a. atlas) and a few unlabeled raw images for training. We propose a unified end-to-end framework, called JERS, to jointly optimize the extraction, registration and segmentation tasks, allowing feedback among them. Specifically, we use a group of extraction, registration and segmentation modules to learn the extraction mask, transformation and segmentation mask, where modules are interconnected and mutually reinforced by self-supervision. Empirical results on real-world datasets demonstrate that our proposed method performs exceptionally well in the extraction, registration and segmentation tasks. Our code and data can be found at https://github.com/Anonymous4545/JERS
Finding Short Signals in Long Irregular Time Series with Continuous-Time Attention Policy Networks
Hartvigsen, Thomas, Thadajarassiri, Jidapa, Kong, Xiangnan, Rundensteiner, Elke
Irregularly-sampled time series (ITS) are native to high-impact domains like healthcare, where measurements are collected over time at uneven intervals. However, for many classification problems, often only small portions of a long time series are relevant to the class label. In this case, existing ITS models often fail to classify long series since they rely on careful imputation, which easily over- or under-samples the relevant regions. Building on this insight, we propose CAT, a model that classifies multivariate ITS by explicitly seeking highly-relevant portions of an input series' timeline. CAT achieves this by integrating three components: (1) A Moment Network learns to seek relevant moments in an ITS's continuous timeline using reinforcement learning. (2) A Receptor Network models the temporal dynamics of both observations and their timing localized around predicted moments. (3) A recurrent Transition Model models the sequence of transitions between these moments, cultivating a representation with which the series is classified. Using synthetic and real data, we find that CAT outperforms ten state-of-the-art methods by finding short signals in long irregular time series.
One-Shot Learning on Attributed Sequences
Zhuang, Zhongfang, Kong, Xiangnan, Rundensteiner, Elke, Arora, Aditya, Zouaoui, Jihane
One-shot learning has become an important research topic in the last decade with many real-world applications. The goal of one-shot learning is to classify unlabeled instances when there is only one labeled example per class. The conventional problem setting of one-shot learning mainly focuses on data that is already in feature space (such as images). However, the data instances in real-world applications are often more complex and feature vectors may not be available. In this paper, we study the problem of one-shot learning on attributed sequences, where each instance is composed of a set of attributes (e.g., user profile) and a sequence of categorical items (e.g., clickstream). This problem is important for a variety of real-world applications ranging from fraud prevention to network intrusion detection. It is more challenging than conventional one-shot learning since there are dependencies between attributes and sequences. We design a deep learning framework, OLAS, to tackle this problem. OLAS utilizes a twin network to generalize the features from pairwise attributed sequence examples. Empirical results on real-world datasets demonstrate that OLAS can outperform the state-of-the-art methods under a rich variety of parameter settings.
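The problem setting above (one labeled attributed sequence per class, with a query assigned to the closest one) can be sketched as follows. This is not OLAS: the abstract describes a twin network that learns the comparison, whereas the `distance` function below is a hand-written stand-in. The class `AttributedSequence`, the bigram-overlap measure, and the toy fraud data are all assumptions made for illustration.

```python
from dataclasses import dataclass
from typing import Dict, Tuple


@dataclass(frozen=True)
class AttributedSequence:
    attributes: Tuple[float, ...]  # e.g. a user profile
    items: Tuple[str, ...]         # e.g. a clickstream of categorical items


def distance(a: AttributedSequence, b: AttributedSequence) -> float:
    """Hand-written stand-in for a learned pairwise distance."""
    # Euclidean distance between the attribute vectors.
    attr_dist = sum((x - y) ** 2
                    for x, y in zip(a.attributes, b.attributes)) ** 0.5
    # Jaccard distance between the sets of consecutive item pairs.
    bigrams_a = set(zip(a.items, a.items[1:]))
    bigrams_b = set(zip(b.items, b.items[1:]))
    union = bigrams_a | bigrams_b
    seq_dist = (1.0 - len(bigrams_a & bigrams_b) / len(union)) if union else 0.0
    return attr_dist + seq_dist


def one_shot_classify(query: AttributedSequence,
                      support: Dict[str, AttributedSequence]) -> str:
    """Return the label whose single labeled example is closest to the query."""
    return min(support, key=lambda label: distance(query, support[label]))


# One labeled example per class (toy data).
support = {
    "normal": AttributedSequence((0.2, 0.1), ("login", "browse", "cart", "pay")),
    "fraud": AttributedSequence((0.9, 0.8), ("login", "pay", "pay", "pay")),
}
query = AttributedSequence((0.8, 0.9), ("login", "pay", "pay"))
predicted = one_shot_classify(query, support)
print(predicted)
```

Summing the two distance terms treats attributes and sequences independently, which is exactly the simplification the abstract argues against; a learned model would capture the dependencies between them.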
Signed Distance-based Deep Memory Recommender
Tran, Thanh, Liu, Xinyue, Lee, Kyumin, Kong, Xiangnan
Personalized recommendation algorithms learn a user's preference for an item by measuring a distance/similarity between them. However, some of the existing recommendation models (e.g., matrix factorization) assume a linear relationship between the user and item. This approach limits the capacity of recommender systems, since the interactions between users and items in real-world applications are much more complex than a linear relationship. To overcome this limitation, we propose a deep learning framework called Signed Distance-based Deep Memory Recommender, which captures non-linear relationships between users and items explicitly and implicitly, and works well in both general recommendation and shopping basket-based recommendation tasks. Through an extensive empirical study on six real-world datasets in the two recommendation tasks, our proposed approach achieves significant improvement over ten state-of-the-art recommendation models.
TreeGAN: Syntax-Aware Sequence Generation with Generative Adversarial Networks
Liu, Xinyue, Kong, Xiangnan, Liu, Lei, Chiang, Kuorong
Generative Adversarial Networks (GANs) have shown great capacity in image generation, in which a discriminative model guides the training of a generative model to construct images that resemble real images. Recently, GANs have been extended from generating images to generating sequences (e.g., poems, music and code). Existing GANs for sequence generation mainly focus on general sequences, which are grammar-free. In many real-world applications, however, we need to generate sequences in a formal language with the constraint of its corresponding grammar. For example, to test the performance of a database, one may want to generate a collection of SQL queries, which are not only similar to the queries of real users, but also follow the SQL syntax of the target database. Generating such sequences is highly challenging because both the generator and discriminator of GANs need to consider the structure of the sequences and the given grammar in the formal language. To address these issues, we study the problem of syntax-aware sequence generation with GANs, in which a collection of real sequences and a set of pre-defined grammatical rules are given to both discriminator and generator. We propose a novel GAN framework, namely TreeGAN, to incorporate a given Context-Free Grammar (CFG) into the sequence generation process. In TreeGAN, the generator employs a recurrent neural network (RNN) to construct a parse tree. Each generated parse tree can then be translated to a valid sequence of the given grammar. The discriminator uses a tree-structured RNN to distinguish the generated trees from real trees. We show that TreeGAN can generate sequences for any CFG and its generation fully conforms with the given syntax. Experiments on synthetic and real data sets demonstrate that TreeGAN significantly improves the quality of the sequence generation in context-free languages.
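The step of building a parse tree and translating it into a grammar-conforming sequence can be sketched as follows. This is not TreeGAN: the abstract describes an RNN generator that chooses how to expand the tree, whereas this sketch picks productions uniformly at random and has no discriminator. The toy SQL-like grammar and the names `GRAMMAR`, `generate_tree`, and `tree_to_tokens` are assumptions made for illustration. The grammar is deliberately non-recursive so that generation always terminates.

```python
import random
from typing import Dict, List, Tuple, Union

# Toy context-free grammar (hypothetical): nonterminal -> list of productions.
# Any symbol that is not a key of GRAMMAR is treated as a terminal token.
GRAMMAR: Dict[str, List[List[str]]] = {
    "QUERY": [["SELECT", "COLS", "FROM", "TABLE"],
              ["SELECT", "COLS", "FROM", "TABLE", "WHERE", "COND"]],
    "COLS": [["*"], ["COL"], ["COL", ",", "COL"]],
    "COL": [["id"], ["name"]],
    "TABLE": [["users"], ["orders"]],
    "COND": [["COL", "=", "VALUE"]],
    "VALUE": [["1"], ["42"]],
}

# A tree is either a terminal token or (nonterminal, list of child trees).
Tree = Union[str, Tuple[str, List["Tree"]]]


def generate_tree(symbol: str, rng: random.Random) -> Tree:
    """Expand a symbol top-down into a parse tree.

    The random choice of production stands in for a learned generator.
    """
    if symbol not in GRAMMAR:
        return symbol  # terminal: becomes a leaf
    production = rng.choice(GRAMMAR[symbol])
    return (symbol, [generate_tree(s, rng) for s in production])


def tree_to_tokens(tree: Tree) -> List[str]:
    """Read the leaves left to right to obtain the generated sequence."""
    if isinstance(tree, str):
        return [tree]
    _, children = tree
    return [tok for child in children for tok in tree_to_tokens(child)]


rng = random.Random(0)
tree = generate_tree("QUERY", rng)
print(" ".join(tree_to_tokens(tree)))
```

Because every internal node is created by applying one production of the grammar, the leaf sequence is derivable from the start symbol by construction, which is why generating trees rather than flat token sequences guarantees syntactic validity.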
Learning Role-based Graph Embeddings
Ahmed, Nesreen K., Rossi, Ryan, Lee, John Boaz, Willke, Theodore L., Zhou, Rong, Kong, Xiangnan, Eldardiry, Hoda
Random walks are at the heart of many existing network embedding methods. However, such algorithms have many limitations that arise from the use of random walks, e.g., the features resulting from these methods are unable to transfer to new nodes and graphs as they are tied to vertex identity. In this work, we introduce the Role2Vec framework which uses the flexible notion of attributed random walks, and serves as a basis for generalizing existing methods such as DeepWalk, node2vec, and many others that leverage random walks. Our proposed framework enables these methods to be more widely applicable for both transductive and inductive learning as well as for use on graphs with attributes (if available). This is achieved by learning functions that generalize to new nodes and graphs. We show that our proposed framework is effective with an average AUC improvement of 16.55% while requiring on average 853x less space than existing methods on a variety of graphs.
Inductive Representation Learning in Large Attributed Graphs
Ahmed, Nesreen K., Rossi, Ryan A., Zhou, Rong, Lee, John Boaz, Kong, Xiangnan, Willke, Theodore L., Eldardiry, Hoda
Graphs (networks) are ubiquitous and allow us to model entities (nodes) and the dependencies (edges) between them. Learning a useful feature representation from graph data is central to the success of many machine learning tasks such as classification, anomaly detection, and link prediction, among many others. Many existing techniques use random walks as a basis for learning features or estimating the parameters of a graph model for a downstream prediction task. Examples include recent node embedding methods such as DeepWalk and node2vec, as well as graph-based deep learning algorithms. However, the simple random walk used by these methods is fundamentally tied to the identity of the node. This has three main disadvantages. First, these approaches are inherently transductive and do not generalize to unseen nodes and other graphs. Second, they are not space-efficient, as a feature vector is learned for each node, which is impractical for large graphs. Third, most of these approaches lack support for attributed graphs. To make these methods more generally applicable, we propose a framework for inductive network representation learning based on the notion of attributed random walk that is not tied to node identity and is instead based on learning a function $\Phi : \mathbf{x} \rightarrow w$ that maps a node attribute vector $\mathbf{x}$ to a type $w$. This framework serves as a basis for generalizing existing methods such as DeepWalk, node2vec, and many other previous methods that leverage traditional random walks.
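The notion of an attributed random walk can be sketched as follows. In the framework described above the function Φ is learned; here `phi` is a fixed, hand-written mapping used only to show that a walk records node types rather than node identities. The function names, the toy graphs, and the two-dimensional attribute vectors are assumptions made for illustration.

```python
import random
from typing import Dict, List, Tuple


def phi(x: Tuple[int, ...]) -> str:
    """Hypothetical stand-in for the learned map from attribute vector to type."""
    return "-".join(str(v) for v in x)


def attributed_walk(adj: Dict[str, List[str]],
                    attrs: Dict[str, Tuple[int, ...]],
                    start: str,
                    length: int,
                    rng: random.Random) -> List[str]:
    """Random walk that emits phi(attributes) at each step, never the node id."""
    node = start
    walk = [phi(attrs[node])]
    for _ in range(length - 1):
        neighbors = adj[node]
        if not neighbors:
            break  # dead end: stop the walk early
        node = rng.choice(neighbors)
        walk.append(phi(attrs[node]))
    return walk


# Training graph (toy data).
adj = {"a": ["b", "c"], "b": ["a"], "c": ["a", "d"], "d": ["c"]}
attrs = {"a": (1, 0), "b": (0, 1), "c": (1, 0), "d": (0, 1)}
walk = attributed_walk(adj, attrs, "a", 5, random.Random(0))

# An unseen graph with different node ids maps into the same type vocabulary.
adj2 = {"u": ["v"], "v": ["u"]}
attrs2 = {"u": (1, 0), "v": (0, 1)}
walk2 = attributed_walk(adj2, attrs2, "u", 4, random.Random(0))
print(walk, walk2)
```

Since the walks contain only types, an embedding trained on them is indexed by type rather than by node, which is what allows it to transfer to new nodes and graphs and keeps the number of learned vectors independent of graph size.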