Shen, Han
Fundamental Safety-Capability Trade-offs in Fine-tuning Large Language Models
Chen, Pin-Yu, Shen, Han, Das, Payel, Chen, Tianyi
Fine-tuning Large Language Models (LLMs) on task-specific datasets has become a primary way of using LLMs. However, it has been empirically observed that this approach to enhancing capability inevitably compromises safety, a phenomenon known as the safety-capability trade-off in LLM fine-tuning. This paper presents a theoretical framework for understanding the interplay between safety and capability in two primary safety-aware LLM fine-tuning strategies, providing new insights into the effects of data similarity, context overlap, and alignment loss landscape. Our theoretical results characterize the fundamental limits of the safety-capability trade-off in LLM fine-tuning, which are also validated by numerical experiments.
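As a rough illustration (our notation; not necessarily the two strategies analyzed in the paper), safety-aware fine-tuning is often posed either as a regularized objective or as a constrained one over the model parameters $\theta$:

\[
\min_{\theta}\ \mathcal{L}_{\mathrm{cap}}(\theta;\mathcal{D}_{\mathrm{task}})+\lambda\,\mathcal{L}_{\mathrm{safe}}(\theta;\mathcal{D}_{\mathrm{align}})
\qquad\text{or}\qquad
\min_{\theta}\ \mathcal{L}_{\mathrm{cap}}(\theta;\mathcal{D}_{\mathrm{task}})\ \ \text{s.t.}\ \ \mathcal{L}_{\mathrm{safe}}(\theta;\mathcal{D}_{\mathrm{align}})\le c,
\]

where $\mathcal{L}_{\mathrm{cap}}$ is the capability (fine-tuning) loss on the task data, $\mathcal{L}_{\mathrm{safe}}$ is an alignment loss on safety data, and $\lambda$ or $c$ picks the operating point on the safety-capability trade-off curve.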
Mitigating Forgetting in LLM Supervised Fine-Tuning and Preference Learning
Fernando, Heshan, Shen, Han, Ram, Parikshit, Zhou, Yi, Samulowitz, Horst, Baracaldo, Nathalie, Chen, Tianyi
Post-training of pre-trained LLMs, which typically consists of a supervised fine-tuning (SFT) stage and a preference learning (RLHF or DPO) stage, is crucial to effective and safe LLM applications. The widely adopted approach for post-training popular open-source LLMs is to perform SFT and RLHF/DPO sequentially. However, sequential training is sub-optimal with respect to the SFT and RLHF/DPO trade-off: the LLM gradually forgets the first stage's training while undergoing the second stage's training. We theoretically prove the sub-optimality of sequential post-training. Furthermore, we propose a practical joint post-training framework that has theoretical convergence guarantees and empirically outperforms sequential post-training at a similar computational cost.

Recent years have witnessed the remarkable capabilities of large language models (LLMs) trained on large corpora of data (OpenAI, 2022; Dubey et al., 2024; Abdin et al., 2024). These models have been applied to a wide range of tasks, including virtual assistance (OpenAI, 2022), code development (Roziere et al., 2023), and education/research (Achiam et al., 2023). Typically, LLMs undergo a pre-training phase and a post-training phase.
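A minimal sketch of the contrast with the sequential recipe: instead of finishing SFT and then switching entirely to the preference objective, each update takes a gradient step on a weighted sum of the SFT loss and a DPO-style preference loss, so optimizing preferences does not gradually erase the SFT solution. The mixing weight lam and the loss callables below are illustrative assumptions, not the paper's exact algorithm.

import torch
import torch.nn.functional as F

def dpo_loss(logp_chosen, logp_rejected, ref_chosen, ref_rejected, beta=0.1):
    # DPO preference loss on log-likelihood gaps relative to a frozen reference model
    margin = beta * ((logp_chosen - ref_chosen) - (logp_rejected - ref_rejected))
    return -F.logsigmoid(margin).mean()

def joint_step(optimizer, sft_loss_fn, pref_loss_fn, lam=0.5):
    # One joint update: weighted combination of the two post-training objectives,
    # where sft_loss_fn and pref_loss_fn compute losses on fresh mini-batches.
    loss = lam * sft_loss_fn() + (1.0 - lam) * pref_loss_fn()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()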
SEAL: Safety-enhanced Aligned LLM Fine-tuning via Bilevel Data Selection
Shen, Han, Chen, Pin-Yu, Das, Payel, Chen, Tianyi
Fine-tuning on task-specific data to boost downstream performance is a crucial step for leveraging Large Language Models (LLMs). However, previous studies have demonstrated that fine-tuning on a few adversarial samples, or even on benign data, can severely compromise the model's pre-equipped alignment and safety capabilities. In this work, we propose SEAL, a novel framework to enhance safety in LLM fine-tuning. SEAL learns a data ranker based on bilevel optimization that up-ranks safe, high-quality fine-tuning data and down-ranks unsafe or low-quality data. Models trained with SEAL demonstrate superior quality over multiple baselines, with win rate increases of 8.5% and 9.7% over random selection on the Llama-3-8b-Instruct and Merlinite-7b models, respectively. Our code is available at https://github.com/hanshen95/SEAL.
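Schematically (our notation; see the paper for the exact formulation), a bilevel data-selection objective of this kind places the data ranker in the upper level and fine-tuning on the re-weighted data in the lower level:

\[
\min_{w}\ \mathcal{L}_{\mathrm{safe}}\!\big(\theta^{\ast}(w)\big)
\quad\text{s.t.}\quad
\theta^{\ast}(w)\in\arg\min_{\theta}\ \sum_{i}\sigma(w_i)\,\ell_i(\theta),
\]

where $w_i$ scores the $i$-th fine-tuning example, $\sigma$ maps scores to non-negative weights, $\ell_i$ is the fine-tuning loss on example $i$, and the upper-level loss measures safety and quality on a curated reference set.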
Principled Penalty-based Methods for Bilevel Reinforcement Learning and RLHF
Shen, Han, Yang, Zhuoran, Chen, Tianyi
Bilevel optimization has recently been applied to many machine learning tasks. However, its applications have been restricted to the supervised learning setting, where static objective functions with benign structures are considered. In contrast, bilevel problems such as incentive design, inverse reinforcement learning (RL), and RL from human feedback (RLHF) are often modeled with dynamic objective functions that go beyond simple static structures, which poses significant challenges for existing bilevel solutions. To tackle this new class of bilevel problems, we introduce the first principled algorithmic framework for solving bilevel RL problems through the lens of penalty formulation. We provide theoretical studies of the problem landscape and its penalty-based (policy) gradient algorithms. We demonstrate the effectiveness of our algorithms via simulations in Stackelberg Markov games, RL from human feedback, and incentive design.
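The penalty viewpoint can be summarized for a generic bilevel problem (our notation, simplified relative to the RL setting of the paper):

\[
\min_{x,\,y}\ f(x,y)\ \ \text{s.t.}\ \ y\in\arg\min_{v}\ g(x,v)
\quad\Longrightarrow\quad
\min_{x,\,y}\ f(x,y)+\gamma\Big(g(x,y)-\min_{v}\ g(x,v)\Big),
\]

where $\gamma>0$ is a penalty constant; in bilevel RL, $g$ encodes the lower-level agent's objective, so the penalty term vanishes exactly when $y$ is an optimal lower-level response to $x$.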
Joint Unsupervised and Supervised Training for Automatic Speech Recognition via Bilevel Optimization
Saif, A F M, Cui, Xiaodong, Shen, Han, Lu, Songtao, Kingsbury, Brian, Chen, Tianyi
In general, bilevel optimization problems are optimization problems where the feasible set is determined (in part) using the solution set of a second optimization problem [10]. Determining the feasible set is generally called the lower-level problem and the second parametric optimization problem is called the upper-level problem [31, 29]. BL-JUST employs a lower- and upper-level optimization with an unsupervised loss and a supervised loss, respectively, leveraging recent advances in penalty-based bilevel optimization to solve this challenging ASR problem with affordable complexity and rigorous convergence guarantees. To evaluate BL-JUST, extensive experiments have been conducted.
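Read schematically (our notation; the paper's formulation may differ in its details), BL-JUST couples the two losses over the shared model parameters $\theta$ and relaxes the lower-level constraint with a penalty:

\[
\min_{\theta}\ \mathcal{L}_{\mathrm{sup}}(\theta)\ \ \text{s.t.}\ \ \theta\in\arg\min_{\phi}\ \mathcal{L}_{\mathrm{unsup}}(\phi)
\quad\leadsto\quad
\min_{\theta}\ \mathcal{L}_{\mathrm{sup}}(\theta)+\gamma\,\mathcal{L}_{\mathrm{unsup}}(\theta),
\]

with the unsupervised loss in the lower level and the supervised loss in the upper level, as stated above; since $\min_{\phi}\mathcal{L}_{\mathrm{unsup}}(\phi)$ is a constant, the penalized form differs from the value-function penalty only by a constant offset.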
On Penalty-based Bilevel Gradient Descent Method
Shen, Han, Xiao, Quan, Chen, Tianyi
Bilevel optimization enjoys a wide range of applications in hyper-parameter optimization, meta-learning and reinforcement learning. However, bilevel optimization problems are difficult to solve. Recent progress on scalable bilevel algorithms mainly focuses on bilevel optimization problems where the lower-level objective is either strongly convex or unconstrained. In this work, we tackle the bilevel problem through the lens of the penalty method. We show that under certain conditions, the penalty reformulation recovers the solutions of the original bilevel problem. Further, we propose the penalty-based bilevel gradient descent (PBGD) algorithm and establish its finite-time convergence for the constrained bilevel problem without lower-level strong convexity. Experiments showcase the efficiency of the proposed PBGD algorithm.
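A minimal sketch of the alternating scheme on a toy quadratic bilevel problem (the problem instance, penalty constant, and step size are illustrative, not the paper's experiments):

# A toy instance of penalty-based bilevel gradient descent: alternate a gradient
# step on an auxiliary variable v that tracks the lower-level minimizer with a
# gradient step on (x, y) for the penalized objective
#     f(x, y) + gamma * ( g(x, y) - g(x, v) ).
import torch

a = torch.tensor([1.0, -2.0])
def f(x, y):  # upper-level objective
    return 0.5 * (y - a).pow(2).sum() + 0.5 * x.pow(2).sum()
def g(x, y):  # lower-level objective, minimized at y = x
    return 0.5 * (y - x).pow(2).sum()

x = torch.zeros(2, requires_grad=True)
y = torch.zeros(2, requires_grad=True)
v = torch.zeros(2, requires_grad=True)   # approximates argmin_v g(x, v)
gamma, lr = 10.0, 0.01

for _ in range(2000):
    # lower-level step: move v toward the minimizer of g(x, .) for the current x
    gv = torch.autograd.grad(g(x, v), v)[0]
    with torch.no_grad():
        v -= lr * gv
    # upper-level step: gradient of the penalized objective w.r.t. (x, y)
    penalized = f(x, y) + gamma * (g(x, y) - g(x, v.detach()))
    gx, gy = torch.autograd.grad(penalized, (x, y))
    with torch.no_grad():
        x -= lr * gx
        y -= lr * gy

print(x.detach())  # approaches roughly a/2 = [0.5, -1.0] as gamma grows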
Alternating Implicit Projected SGD and Its Efficient Variants for Equality-constrained Bilevel Optimization
Xiao, Quan, Shen, Han, Yin, Wotao, Chen, Tianyi
Stochastic bilevel optimization, which captures the inherent nested structure of machine learning problems, is gaining popularity in many recent applications. Existing works on bilevel optimization mostly consider either unconstrained problems or constrained upper-level problems. This paper considers stochastic bilevel optimization problems with equality constraints in both the upper and lower levels. By leveraging the special structure of the equality constraints, the paper first presents an alternating implicit projected SGD approach and establishes the $\tilde{\cal O}(\epsilon^{-2})$ sample complexity that matches the state-of-the-art complexity of ALSET \citep{chen2021closing} for unconstrained bilevel problems. To further save the cost of projection, the paper presents two alternating implicit projection-efficient SGD approaches, where one algorithm enjoys the $\tilde{\cal O}(\epsilon^{-2}/T)$ upper-level and $\tilde{\cal O}(\epsilon^{-1.5}/T^{\frac{3}{4}})$ lower-level projection complexity with ${\cal O}(T)$ lower-level batch size, and the other enjoys $\tilde{\cal O}(\epsilon^{-1.5})$ upper-level and lower-level projection complexity with ${\cal O}(1)$ batch size. An application to federated bilevel optimization is presented to showcase the empirical performance of our algorithms. Our results demonstrate that equality-constrained bilevel optimization with strongly-convex lower-level problems can be solved as efficiently as stochastic single-level optimization problems.
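In schematic notation (ours), the problem class is

\[
\min_{x\,:\,h(x)=0}\ f\big(x,y^{\ast}(x)\big)
\quad\text{s.t.}\quad
y^{\ast}(x)\in\arg\min_{y\,:\,c(x,y)=0}\ g(x,y),
\]

with equality constraints $h$ at the upper level and $c$ at the lower level; roughly, the alternating implicit projected SGD steps approximate the implicit gradient of $f(x,y^{\ast}(x))$ while projecting the iterates back onto the respective equality-constraint sets.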
A Single-Timescale Analysis For Stochastic Approximation With Multiple Coupled Sequences
Shen, Han, Chen, Tianyi
Stochastic approximation (SA) with multiple coupled sequences has found broad applications in machine learning, such as bilevel learning and reinforcement learning (RL). In this paper, we study the finite-time convergence of nonlinear SA with multiple coupled sequences. Different from existing multi-timescale analyses, we seek scenarios where a fine-grained analysis provides a tight performance guarantee for multi-sequence single-timescale SA (STSA). At the heart of our analysis is the smoothness property of the fixed points in multi-sequence SA, which holds in many applications. When all sequences have strongly monotone increments, we establish an iteration complexity of $\mathcal{O}(\epsilon^{-1})$ to achieve $\epsilon$-accuracy, which improves the existing $\mathcal{O}(\epsilon^{-1.5})$ complexity for two coupled sequences. When all but the main sequence have strongly monotone increments, we establish an iteration complexity of $\mathcal{O}(\epsilon^{-2})$. The merit of our results is that applying them to stochastic bilevel and compositional optimization problems, as well as to RL problems, yields either relaxed assumptions or improvements over existing performance guarantees.
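Schematically (our notation), multi-sequence SA iterates a main sequence alongside $N$ coupled auxiliary sequences,

\[
x^{k+1}=x^{k}+\alpha_{k}\,v\big(x^{k},y_{1}^{k},\dots,y_{N}^{k};\xi^{k}\big),
\qquad
y_{n}^{k+1}=y_{n}^{k}+\beta_{k}\,h_{n}\big(x^{k},y_{1}^{k},\dots,y_{N}^{k};\zeta_{n}^{k}\big),\quad n=1,\dots,N,
\]

where $\xi^{k},\zeta_{n}^{k}$ are stochastic samples; "single-timescale" means the step sizes $\alpha_{k}$ and $\beta_{k}$ decay at the same rate rather than on separated timescales.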
Towards Understanding Asynchronous Advantage Actor-critic: Convergence and Linear Speedup
Shen, Han, Zhang, Kaiqing, Hong, Mingyi, Chen, Tianyi
Asynchronous and parallel implementation of standard reinforcement learning (RL) algorithms is a key enabler of the tremendous success of modern RL. Among the many asynchronous RL algorithms, arguably the most popular and effective one is the asynchronous advantage actor-critic (A3C) algorithm. Although A3C is becoming the workhorse of RL, its theoretical properties are still not well understood, including its non-asymptotic analysis and the performance gain of parallelism (a.k.a. linear speedup). This paper revisits the A3C algorithm and establishes its non-asymptotic convergence guarantees. Under both i.i.d. and Markovian sampling, we establish a local convergence guarantee for A3C with general policy approximation and a global convergence guarantee under softmax policy parameterization. Under i.i.d. sampling, A3C obtains a sample complexity of $\mathcal{O}(\epsilon^{-2.5}/N)$ per worker to achieve $\epsilon$ accuracy, where $N$ is the number of workers. Compared to the best-known sample complexity of $\mathcal{O}(\epsilon^{-2.5})$ for two-timescale AC, A3C achieves \emph{linear speedup}, which justifies the advantage of parallelism and asynchrony in AC algorithms theoretically for the first time. Numerical tests on a synthetic environment, OpenAI Gym environments, and Atari games are provided to verify our theoretical analysis.
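A thread-based sketch of the asynchronous update pattern behind A3C is given below; the toy transitions, one-step advantage, and the lock around the shared update are simplifications for illustration, not a faithful A3C implementation.

# Each worker generates its own (toy) transitions, forms an advantage
# actor-critic loss, and updates the shared actor/critic parameters
# independently of the other workers.
import threading
import torch
import torch.nn as nn

shared_actor = nn.Linear(4, 2)    # policy logits over 2 actions
shared_critic = nn.Linear(4, 1)   # state-value estimate
optimizer = torch.optim.SGD(
    list(shared_actor.parameters()) + list(shared_critic.parameters()), lr=1e-2)
lock = threading.Lock()

def worker(seed, steps=200):
    gen = torch.Generator().manual_seed(seed)
    for _ in range(steps):
        # toy transition: random state and reward; a real worker rolls out its own env copy
        state = torch.randn(4, generator=gen)
        reward = torch.randn(1, generator=gen).squeeze()
        dist = torch.distributions.Categorical(logits=shared_actor(state))
        action = dist.sample()
        value = shared_critic(state).squeeze()
        advantage = (reward - value).detach()       # one-step advantage estimate
        loss = -dist.log_prob(action) * advantage + 0.5 * (reward - value).pow(2)
        with lock:                                  # serialize the shared update in this sketch
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

workers = [threading.Thread(target=worker, args=(i,)) for i in range(4)]
for t in workers:
    t.start()
for t in workers:
    t.join()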