A Good Start Matters: Enhancing Continual Learning with Data-Driven Weight Initialization

Harun, Md Yousuf, Kanan, Christopher

arXiv.org Artificial Intelligence 

To adapt to real-world data streams, continual learning (CL) systems must rapidly learn new concepts while preserving and utilizing prior knowledge. When new information is added to continually trained deep neural networks (DNNs), classifier weights for newly encountered categories are typically initialized randomly, leading to high initial training loss (spikes) and instability. Consequently, achieving optimal convergence and accuracy requires prolonged training, increasing computational costs. Inspired by Neural Collapse (NC), we propose a weight initialization strategy that improves learning efficiency in CL. In DNNs trained with mean-squared error, NC gives rise to a least-squares (LS) classifier in the last layer whose weights can be derived analytically from the learned features. Our method mitigates initial loss spikes and accelerates adaptation to new tasks. We evaluate our approach in large-scale CL settings, demonstrating faster adaptation and improved CL performance.

Deep learning models excel in static environments where the data follows an independent and identically distributed (IID) assumption. In real-world scenarios, however, data distributions shift over time (non-IID) and new data arrives sequentially. Conventional DNNs struggle under such conditions, often requiring periodic retraining from scratch, which is not only computationally expensive but also contributes significantly to the carbon footprint of AI (Schwartz et al., 2020). Even with frequent retraining from scratch, real-world models still suffer accuracy drops of up to 40% (Mallick et al., 2022). Continual learning aims to address this inefficiency by enabling models to learn from evolving data streams while preserving previously acquired knowledge (Parisi et al., 2019).
CL is a promising solution to model decay, where predictive performance deteriorates over time due to concept drift--a shift in the meaning or distribution of target variables (Tsymbal, 2004; Gama et al., 2014; Lu et al., 2018).
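The analytically derived LS classifier mentioned in the abstract can be illustrated with the standard closed-form least-squares solution, W = (H^T H + λI)^{-1} H^T Y, where H holds the learned features and Y the one-hot targets. The sketch below is a minimal NumPy illustration under these assumptions, not the authors' implementation; the function name, ridge term, and toy feature data are hypothetical.

```python
import numpy as np

def ls_classifier_weights(H, Y, ridge=1e-3):
    """Closed-form LS classifier weights: W = (H^T H + ridge*I)^{-1} H^T Y.

    H: (n, d) feature matrix, Y: (n, k) one-hot label matrix.
    A small ridge term keeps the normal equations well-conditioned.
    """
    d = H.shape[1]
    return np.linalg.solve(H.T @ H + ridge * np.eye(d), H.T @ Y)

# Toy example: features clustered around two class means, mimicking the
# collapsed within-class variability that NC describes (illustrative only).
rng = np.random.default_rng(0)
means = np.array([[2.0, 0.0], [0.0, 2.0]])
labels = rng.integers(0, 2, size=200)
H = means[labels] + 0.1 * rng.standard_normal((200, 2))
Y = np.eye(2)[labels]                      # one-hot targets

W = ls_classifier_weights(H, Y)            # analytic weights, no gradient steps
preds = np.argmax(H @ W, axis=1)
accuracy = (preds == labels).mean()
```

Because the weights come from a single linear solve over stored feature statistics, new-class classifier rows can be initialized without the random starting point that causes the loss spikes discussed above.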