KOPPA: Improving Prompt-based Continual Learning with Key-Query Orthogonal Projection and Prototype-based One-Versus-All

Quyen Tran, Lam Tran, Khoat Than, Toan Tran, Dinh Phung, Trung Le

arXiv.org Artificial Intelligence 

Drawing inspiration from prompt tuning techniques applied to Large Language Models, recent methods based on pre-trained ViT networks have achieved remarkable results in Continual Learning. Specifically, these approaches maintain a set of prompts and allocate a subset of them to learn each task using a key-query matching strategy. However, they have limited control over three factors: the correlations between old-task queries and the keys of future tasks, the shift of features in the latent space, and the relative separation of latent vectors learned in independent tasks. In this work, we introduce a novel key-query learning strategy based on orthogonal projection, inspired by model-agnostic meta-learning, to enhance prompt matching efficiency and address the challenge of shifting features. Furthermore, we introduce a One-Versus-All (OVA) prototype-based component that enhances the distinction made by the classification head. Experimental results on benchmark datasets demonstrate that our method surpasses current state-of-the-art approaches by a large margin of up to 20%. Our code is available at https://anonymous.4open.science/r/KOPPA/README.md.

Continual Learning (CL) is an evolving field in machine learning that aims to enable models to learn continuously from a sequence of tasks with varying data distributions. A challenging CL scenario is Class Incremental Learning (CIL), where a model sequentially learns new categories and must classify test samples from all seen classes without prior knowledge of their task IDs. This setting gives rise to a fundamental issue in CL known as Catastrophic Forgetting (CF) (French, 1999), where performance on earlier tasks degrades due to the absence of old task data and differences in data distributions.
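
To make the key-query matching strategy described in the abstract more concrete, the following is a minimal sketch and not the authors' implementation. It assumes that each prompt has a learnable key vector, that a query is a feature vector from the frozen pre-trained backbone, and that matching uses cosine similarity with top-k selection. The second function illustrates one generic way an orthogonal projection could limit the correlation between a new key update and old-task queries; the abstract does not specify the exact projection, so this form is an assumption. The names `select_prompts` and `project_orthogonal` are hypothetical.

```python
import numpy as np


def select_prompts(query, keys, top_k=2):
    """Pick the top_k prompts whose keys are most cosine-similar to the query.

    query: (d,) feature vector (assumed to come from a frozen backbone).
    keys:  (n_prompts, d) one key per prompt.
    Returns (indices of selected prompts, all similarity scores).
    """
    q = query / np.linalg.norm(query)
    k = keys / np.linalg.norm(keys, axis=1, keepdims=True)
    scores = k @ q                      # cosine similarity to every key
    order = np.argsort(-scores)         # descending similarity
    return order[:top_k], scores


def project_orthogonal(update, old_queries):
    """Remove from `update` its component in the span of old-task queries.

    update:      (d,) proposed change to a key of the current task.
    old_queries: (m, d) queries from earlier tasks; assumed to have
                 full row rank with m <= d (an assumption of this sketch).
    """
    # Orthonormal basis of the subspace spanned by the old queries.
    basis, _ = np.linalg.qr(old_queries.T)          # shape (d, m)
    return update - basis @ (basis.T @ update)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    keys = rng.normal(size=(5, 8))
    query = keys[3] + 0.01 * rng.normal(size=8)     # query close to key 3
    chosen, _ = select_prompts(query, keys, top_k=2)
    print("selected prompts:", chosen)

    old_queries = rng.normal(size=(3, 8))
    step = project_orthogonal(rng.normal(size=8), old_queries)
    print("max overlap with old queries:", np.abs(old_queries @ step).max())
```

In this sketch, the projected update has no component along any old-task query, so applying it to a key leaves that key's dot product with every old query unchanged. This is one way to read the stated goal of controlling correlations between old-task queries and future-task keys.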