Speculative Knowledge Distillation: Bridging the Teacher-Student Gap Through Interleaved Sampling

Wenda Xu, Rujun Han, Zifeng Wang, Long T. Le, Dhruv Madeka, Lei Li, William Yang Wang, Rishabh Agarwal, Chen-Yu Lee, Tomas Pfister

arXiv.org Artificial Intelligence 

Recent advances in knowledge distillation (KD) have enabled smaller student models to approach the performance of larger teacher models. However, popular methods such as supervised KD and on-policy KD are adversely impacted by the knowledge gap between teacher and student in practical scenarios. Supervised KD suffers from a distribution mismatch between training on a static dataset and inference over the student's own generated outputs. Conversely, on-policy KD, which uses student-generated samples for training, can suffer from low-quality training examples with which the teacher model is unfamiliar, resulting in inaccurate teacher feedback. To address these limitations, we introduce Speculative Knowledge Distillation (SKD), a novel approach that leverages cooperation between the student and teacher models to generate high-quality training data on the fly while aligning with the student's inference-time distribution. In SKD, the student proposes tokens, and the teacher replaces poorly ranked ones based on its own distribution, transferring high-quality knowledge adaptively. We evaluate SKD on various text generation tasks, including translation, summarization, math, and instruction following, and show that SKD consistently outperforms existing KD methods across different domains, data sizes, and model initialization strategies.

Figure 1 (caption): SKD outperforms supervised and on-policy KD on our tested tasks: Assamese-to-English translation, dialogue summarization, and arithmetic reasoning. Supervised KD is trained on ground-truth outputs, while on-policy KD uses self-generated data. All models use greedy decoding for evaluation. Left: SKD addresses the limitations of on-policy KD by filtering out low-quality student samples and replacing them with teacher-generated tokens.

* Work done as a student researcher at Google Cloud AI Research.

The substantial inference-time costs and memory footprint of LLMs present significant challenges for practical deployment (Agarwal et al., 2024). Compressing LLMs while maintaining their performance is therefore crucial for real-time applications. Knowledge distillation (KD) (Hinton et al., 2015) is a widely used method for compressing LLMs by transferring knowledge from a larger teacher model to a smaller student model. Traditional KD approaches, such as supervised KD (Sanh et al., 2020) and SeqKD (Kim & Rush, 2016b), rely on a static dataset of outputs to train the student model. However, this fixed dataset can lead to a distribution mismatch between the training data and the samples the student generates at inference time, hindering the student's learning.
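To make the interleaved sampling idea concrete, the sketch below shows one way the student can propose each next token while the teacher replaces proposals it ranks poorly. The top-k acceptance rule, the helper names (skd_generate, student_logits_fn, teacher_logits_fn, top_k), and the random-logit toy models are illustrative assumptions rather than the paper's implementation; a real implementation would likely score a block of student proposals with a single teacher forward pass, as in speculative decoding.

```python
import torch

def skd_generate(student_logits_fn, teacher_logits_fn, prefix,
                 max_new_tokens=32, top_k=25):
    """Sketch of interleaved sampling: the student proposes each next token,
    and any proposal outside the teacher's top-k is resampled from the teacher."""
    tokens = list(prefix)
    for _ in range(max_new_tokens):
        ctx = torch.tensor(tokens)
        # Student proposes a token from its own distribution (on-policy proposal).
        student_probs = torch.softmax(student_logits_fn(ctx), dim=-1)
        proposal = torch.multinomial(student_probs, 1).item()
        # Teacher scores the same context and checks how it ranks the proposal.
        teacher_logits = teacher_logits_fn(ctx)
        teacher_top_k = torch.topk(teacher_logits, top_k).indices
        if proposal not in teacher_top_k:
            # Poorly ranked proposal: replace it with a token sampled from the teacher.
            teacher_probs = torch.softmax(teacher_logits, dim=-1)
            proposal = torch.multinomial(teacher_probs, 1).item()
        tokens.append(proposal)
    return tokens

# Toy usage with random-logit stand-ins for the student and teacher models.
if __name__ == "__main__":
    vocab_size = 100
    student = lambda ctx: torch.randn(vocab_size)
    teacher = lambda ctx: torch.randn(vocab_size)
    print(skd_generate(student, teacher, prefix=[1, 2, 3], max_new_tokens=8))
```

Under this reading of the abstract, the mixed student-and-teacher sequences produced on the fly would then serve as the training data on which the teacher's token-level distributions are distilled into the student.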