Mitigating Forgetting in LLM Supervised Fine-Tuning and Preference Learning

Heshan Fernando, Han Shen, Parikshit Ram, Yi Zhou, Horst Samulowitz, Nathalie Baracaldo, Tianyi Chen

arXiv.org Machine Learning 

Post-training of pre-trained LLMs, which typically consists of a supervised fine-tuning (SFT) stage and a preference learning (RLHF or DPO) stage, is crucial for effective and safe LLM applications. The widely adopted approach for post-training popular open-source LLMs is to perform SFT and RLHF/DPO sequentially. However, sequential training is sub-optimal in terms of the trade-off between the SFT and RLHF/DPO objectives: the LLM gradually forgets the first stage's training while undergoing the second stage's training. We theoretically prove the sub-optimality of sequential post-training. Furthermore, we propose a practical joint post-training framework that has theoretical convergence guarantees and empirically outperforms sequential post-training, at a similar computational cost.

Recent years have witnessed the remarkable capabilities of large language models (LLMs) trained on large corpora of data (OpenAI, 2022; Dubey et al., 2024; Abdin et al., 2024). These models have been applied to a wide range of tasks, including virtual assistance (OpenAI, 2022), code development (Roziere et al., 2023), and education/research (Achiam et al., 2023). Typically, LLMs undergo a pre-training phase followed by a post-training phase.
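To make the contrast concrete, below is a minimal PyTorch sketch of the two training schemes, not the authors' actual algorithm: a sequential loop that runs SFT updates and then DPO-style updates, versus a joint loop that mixes both losses at every step. The toy model, the simplified `dpo_loss` (which omits the reference model used in practice), and the mixing weight `lam` are illustrative assumptions.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
model = torch.nn.Linear(16, 8)           # toy stand-in for an LLM's trainable parameters
opt = torch.optim.SGD(model.parameters(), lr=1e-2)

def sft_loss(model, x, y):
    # Supervised fine-tuning: cross-entropy on labeled demonstrations.
    return F.cross_entropy(model(x), y)

def dpo_loss(model, x_chosen, x_rejected, beta=0.1):
    # Simplified DPO-style preference loss: push the "chosen" response's score
    # above the "rejected" one (real DPO also uses a frozen reference model).
    s_c = model(x_chosen).logsumexp(-1)
    s_r = model(x_rejected).logsumexp(-1)
    return -F.logsigmoid(beta * (s_c - s_r)).mean()

# Toy batches standing in for the SFT and preference datasets.
x_sft, y_sft = torch.randn(4, 16), torch.randint(0, 8, (4,))
x_c, x_r = torch.randn(4, 16), torch.randn(4, 16)

# Sequential post-training: the later DPO stage can undo SFT progress (forgetting).
for _ in range(10):
    opt.zero_grad(); sft_loss(model, x_sft, y_sft).backward(); opt.step()
for _ in range(10):
    opt.zero_grad(); dpo_loss(model, x_c, x_r).backward(); opt.step()

# Joint post-training: optimize a weighted combination of both objectives at every
# step, so neither objective is abandoned while the other is being trained.
lam = 0.5  # hypothetical trade-off weight between the SFT and DPO objectives
for _ in range(20):
    opt.zero_grad()
    loss = lam * sft_loss(model, x_sft, y_sft) + (1 - lam) * dpo_loss(model, x_c, x_r)
    loss.backward(); opt.step()
```

The per-step cost of the joint loop is roughly one SFT and one DPO gradient evaluation, which is comparable to running the two stages back to back, consistent with the paper's claim of similar computational cost.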