Cheung, Ka Chun
Parallel Sequence Modeling via Generalized Spatial Propagation Network
Wang, Hongjun, Byeon, Wonmin, Xu, Jiarui, Gu, Jinwei, Cheung, Ka Chun, Wang, Xiaolong, Han, Kai, Kautz, Jan, Liu, Sifei
We present the Generalized Spatial Propagation Network (GSPN), a new attention mechanism optimized for vision tasks that inherently captures 2D spatial structures. Existing attention models, including transformers, linear attention, and state-space models like Mamba, process multi-dimensional data as 1D sequences, compromising spatial coherence and efficiency. GSPN overcomes these limitations by directly operating on spatially coherent image data and forming dense pairwise connections through a line-scan approach. Central to GSPN is the Stability-Context Condition, which ensures stable, context-aware propagation across 2D sequences and reduces the effective sequence length to $\sqrt{N}$ for a square map with $N$ elements, significantly enhancing computational efficiency. With learnable, input-dependent weights and no reliance on positional embeddings, GSPN achieves superior spatial fidelity and state-of-the-art performance in vision tasks, including ImageNet classification, class-guided image generation, and text-to-image generation. Notably, GSPN accelerates SD-XL with softmax-attention by over $84\times$ when generating 16K images.
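The line-scan idea can be illustrated with a minimal sketch. It assumes that the Stability-Context Condition amounts to each pixel's propagation weights being non-negative and summing to one, and that each pixel connects to three neighbours in the previous row; the names `line_scan`, `raw_w`, and `lam` are hypothetical and not taken from the paper. A top-to-bottom scan over an H x W map takes H sequential steps, which is $\sqrt{N}$ for a square map.

```python
import math

def normalize_row(raw):
    """Softmax: returns non-negative weights that sum to 1 (assumed stability condition)."""
    exps = [math.exp(r) for r in raw]
    total = sum(exps)
    return [e / total for e in exps]

def line_scan(x, raw_w, lam):
    """Top-to-bottom scan over an H x W map.

    x      : H x W inputs
    raw_w  : H x W x 3 unnormalized weights toward the (left, center, right)
             neighbours in the previous row (hypothetical parameterization)
    lam    : H x W gates applied to the current input
    """
    H, W = len(x), len(x[0])
    h = [[0.0] * W for _ in range(H)]
    for i in range(H):          # H sequential steps; each row could run in parallel
        for j in range(W):
            acc = 0.0
            if i > 0:
                # Keep only neighbours that fall inside the map, then renormalize.
                offsets = [d for d in (-1, 0, 1) if 0 <= j + d < W]
                w = normalize_row([raw_w[i][j][d + 1] for d in offsets])
                acc = sum(wk * h[i - 1][j + d] for wk, d in zip(w, offsets))
            h[i][j] = acc + lam[i][j] * x[i][j]
    return h
```

Because the weights in each step form a convex combination, the hidden state cannot blow up as it propagates: a constant first row stays constant when later inputs are gated off.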
TCM-FTP: Fine-Tuning Large Language Models for Herbal Prescription Prediction
Zhou, Xingzhi, Dong, Xin, Li, Chunhao, Bai, Yuning, Xu, Yulong, Cheung, Ka Chun, See, Simon, Song, Xinpeng, Zhang, Runshun, Zhou, Xuezhong, Zhang, Nevin L.
Traditional Chinese medicine (TCM) relies on specific combinations of herbs in prescriptions to treat symptoms and signs, a practice that spans thousands of years. Predicting TCM prescriptions presents a fascinating technical challenge with practical implications. However, this task is limited by the scarcity of high-quality clinical datasets and the intricate relationship between symptoms and herbs. To address these issues, we introduce DigestDS, a new dataset containing practical medical records from experienced experts in digestive system diseases. We also propose a method, TCM-FTP (TCM Fine-Tuning Pre-trained), to leverage pre-trained large language models (LLMs) through supervised fine-tuning on DigestDS. Additionally, we enhance computational efficiency using a low-rank adaptation technique. TCM-FTP also incorporates data augmentation by permuting herbs within prescriptions, capitalizing on their order-agnostic properties. TCM-FTP achieves an F1-score of 0.8031, significantly surpassing previous methods. Furthermore, it demonstrates remarkable accuracy in dosage prediction, achieving a normalized mean square error of 0.0604. In contrast, LLMs without fine-tuning perform poorly. Although LLMs have shown capabilities on a wide range of tasks, this work illustrates the importance of fine-tuning for TCM prescription prediction and proposes an effective way to do so.
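The permutation-based augmentation and a set-based F1-score can be sketched as follows. This is an illustration only: the helper names and the placeholder herb labels are hypothetical, and the abstract does not say exactly how the F1-score is aggregated, so a per-prescription set overlap is assumed here.

```python
import random

def augment_by_permutation(herbs, n_copies, seed=0):
    """Return n_copies shuffled versions of one prescription (order carries no meaning)."""
    rng = random.Random(seed)
    copies = []
    for _ in range(n_copies):
        shuffled = list(herbs)
        rng.shuffle(shuffled)
        copies.append(shuffled)
    return copies

def f1_score(predicted, reference):
    """Set-based F1 between predicted and reference herbs (assumed metric definition)."""
    predicted, reference = set(predicted), set(reference)
    true_positives = len(predicted & reference)
    if true_positives == 0:
        return 0.0
    precision = true_positives / len(predicted)
    recall = true_positives / len(reference)
    return 2 * precision * recall / (precision + recall)

# Placeholder labels, not real herb names.
prescription = ["herb_a", "herb_b", "herb_c", "herb_d"]
augmented = augment_by_permutation(prescription, n_copies=3)
```

Every augmented copy contains the same herbs as the original, so the training target is unchanged as a set while the token order seen by the model varies.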
Unlocking Continual Learning Abilities in Language Models
Du, Wenyu, Cheng, Shuang, Luo, Tongxu, Qiu, Zihan, Huang, Zeyu, Cheung, Ka Chun, Cheng, Reynold, Fu, Jie
Language models (LMs) exhibit impressive performance and generalization capabilities. However, LMs struggle with the persistent challenge of catastrophic forgetting, which undermines their long-term sustainability in continual learning (CL). Existing approaches usually address the issue by incorporating old task data or task-wise inductive bias into LMs. However, old data and accurate task information are often unavailable or costly to collect, limiting the applicability of current CL approaches to LMs. To address this limitation, we introduce MIGU (MagnItude-based Gradient Updating for continual learning), a rehearsal-free and task-label-free method that only updates the model parameters with large magnitudes of output in LMs' linear layers. MIGU is based on our observation that the L1-normalized magnitude distribution of the output in LMs' linear layers differs when LMs process data from different tasks. By imposing this simple constraint on the gradient update process, we can leverage the inherent behaviors of LMs, thereby unlocking their innate CL abilities. Our experiments demonstrate that MIGU is universally applicable to all three LM architectures (T5, RoBERTa, and Llama2), delivering state-of-the-art or on-par performance across continual finetuning and continual pre-training settings on four CL benchmarks. For example, MIGU brings a 15.2% average accuracy improvement over conventional parameter-efficient finetuning baselines in a 15-task CL benchmark. MIGU can also seamlessly integrate with all three existing CL types to further enhance performance. Code is available at https://github.com/wenyudu/MIGU.
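A rough sketch of magnitude-based gradient masking for a single linear layer is shown below. It is not the released MIGU code: the function names, the `keep_ratio` parameter, and the choice to mask whole weight rows (one row per output unit) are assumptions made for illustration.

```python
def migu_mask(outputs, keep_ratio):
    """Build a 0/1 mask over output units from L1-normalized output magnitudes.

    outputs    : list of output vectors of one linear layer (one per token)
    keep_ratio : fraction of output units whose weights may be updated (assumed knob)
    """
    dim = len(outputs[0])
    score = [0.0] * dim
    for y in outputs:
        l1 = sum(abs(v) for v in y) or 1.0   # guard against an all-zero vector
        for k, v in enumerate(y):
            score[k] += abs(v) / l1
    n_keep = max(1, round(keep_ratio * dim))
    top = sorted(range(dim), key=lambda k: score[k], reverse=True)[:n_keep]
    mask = [0] * dim
    for k in top:
        mask[k] = 1
    return mask

def masked_update(weight, grad, mask, lr):
    """SGD step that only touches rows (output units) selected by the mask."""
    return [
        [w - lr * g * mask[r] for w, g in zip(w_row, g_row)]
        for r, (w_row, g_row) in enumerate(zip(weight, grad))
    ]
```

Rows with small normalized output magnitude keep their old values, which is the mechanism by which this kind of masking is meant to limit interference with earlier tasks.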
Resilient Practical Test-Time Adaptation: Soft Batch Normalization Alignment and Entropy-driven Memory Bank
Zhou, Xingzhi, Tian, Zhiliang, Cheung, Ka Chun, See, Simon, Zhang, Nevin L.
Test-time domain adaptation effectively adjusts the source domain model to accommodate unseen domain shifts in a target domain during inference. However, the model performance can be significantly impaired by continuous distribution changes in the target domain and non-independent and identically distributed (non-i.i.d.) test samples often encountered in practical scenarios. While existing memory bank methodologies use memory to store samples and mitigate non-i.i.d. effects, they do not inherently prevent potential model degradation. To address this issue, we propose a resilient practical test-time adaptation (ResiTTA) method focused on parameter resilience and data quality. Specifically, we develop a resilient batch normalization with estimation on normalization statistics and soft alignments to mitigate overfitting and model degradation. We use an entropy-driven memory bank that accounts for timeliness, the persistence of over-confident samples, and sample uncertainty for high-quality data in adaptation. Our framework periodically adapts the source domain model using a teacher-student model through a self-training loss on the memory samples, incorporating soft alignment losses on batch normalization. We empirically validate ResiTTA across various benchmark datasets, demonstrating state-of-the-art performance.
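The entropy-driven memory bank can be pictured with a small sketch. The scoring rule below (prediction entropy plus a weighted age term, evicting the highest score) is a hypothetical stand-in, as are the class name and the `age_weight` parameter; it covers uncertainty and timeliness only and does not model the handling of persistent over-confident samples described in the abstract.

```python
import math

def entropy(probs):
    """Shannon entropy (in nats) of a predicted class distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0)

class EntropyMemoryBank:
    """Fixed-capacity bank that evicts the most uncertain / oldest sample."""

    def __init__(self, capacity, age_weight=0.1):
        self.capacity = capacity
        self.age_weight = age_weight   # hypothetical trade-off between age and entropy
        self.items = []                # entries are (sample, entropy, step_added)
        self.step = 0

    def _score(self, item):
        _, ent, added = item
        return ent + self.age_weight * (self.step - added)

    def add(self, sample, probs):
        self.step += 1
        self.items.append((sample, entropy(probs), self.step))
        if len(self.items) > self.capacity:
            # Drop the entry with the highest combined entropy-and-age score.
            self.items.remove(max(self.items, key=self._score))

    def samples(self):
        return [sample for sample, _, _ in self.items]
```

With `age_weight=0.0` the bank keeps the most confident predictions; a positive value gradually pushes out old entries so the bank tracks a drifting target distribution.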
SVD-PINNs: Transfer Learning of Physics-Informed Neural Networks via Singular Value Decomposition
Gao, Yihang, Cheung, Ka Chun, Ng, Michael K.
Physics-informed neural networks (PINNs) have attracted significant attention for solving partial differential equations (PDEs) in recent years because they alleviate the curse of dimensionality that appears in traditional methods. However, the main disadvantage of PINNs is that one neural network corresponds to one PDE. In practice, we usually need to solve a class of PDEs, not just one. With the explosive growth of deep learning, many techniques that are useful in general deep learning tasks are also suitable for PINNs. Transfer learning methods may reduce the cost for PINNs in solving a class of PDEs. In this paper, we propose a transfer learning method for PINNs that keeps singular vectors fixed and optimizes singular values (namely SVD-PINNs). Numerical experiments on high-dimensional PDEs (10-d linear parabolic equations and 10-d Allen-Cahn equations) show that SVD-PINNs work for solving a class of PDEs with different but close right-hand-side functions.
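The idea of freezing singular vectors and training only singular values can be sketched on a toy 2 x 2 weight matrix. Several simplifications are assumed: the orthogonal factors are hand-built rotations rather than the SVD of a pretrained layer, a plain least-squares loss stands in for the PDE residual loss, and all function names are hypothetical.

```python
import math

def rotation(theta):
    """2 x 2 orthogonal matrix, standing in for frozen singular vectors."""
    c, s = math.cos(theta), math.sin(theta)
    return [[c, -s], [s, c]]

def project(V, x):
    """Coefficients of x along the columns of V, i.e. V^T x."""
    return [sum(V[r][k] * x[r] for r in range(2)) for k in range(2)]

def apply_weight(U, sigma, V, x):
    """Compute W x with W = U diag(sigma) V^T."""
    coeff = project(V, x)
    return [sum(U[r][k] * sigma[k] * coeff[k] for k in range(2)) for r in range(2)]

def loss_and_grad(U, sigma, V, data):
    """Squared error and its gradient with respect to sigma only."""
    loss, grad = 0.0, [0.0, 0.0]
    for x, y in data:
        residual = [p - t for p, t in zip(apply_weight(U, sigma, V, x), y)]
        loss += sum(r * r for r in residual)
        coeff = project(V, x)
        for k in range(2):
            grad[k] += 2 * sum(residual[r] * U[r][k] for r in range(2)) * coeff[k]
    return loss, grad

def finetune_sigma(U, sigma, V, data, lr=0.05, steps=200):
    """Gradient descent on the singular values; U and V are never modified."""
    sigma = list(sigma)
    for _ in range(steps):
        _, grad = loss_and_grad(U, sigma, V, data)
        sigma = [s - lr * g for s, g in zip(sigma, grad)]
    return sigma
```

Only as many parameters as there are singular values are trained per layer, which is where the cost saving for a family of closely related problems would come from.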