Tao, Zhixu
Task Arithmetic Through The Lens Of One-Shot Federated Learning
Tao, Zhixu, Mason, Ian, Kulkarni, Sanjeev, Boix, Xavier
Task Arithmetic is a model merging technique that enables the combination of multiple models' capabilities into a single model through simple arithmetic in the weight space, without the need for additional fine-tuning or access to the original training data. However, the factors that determine the success of Task Arithmetic remain unclear. In this paper, we examine Task Arithmetic for multi-task learning by framing it as a one-shot Federated Learning problem. We demonstrate that Task Arithmetic is mathematically equivalent to Federated Averaging (FedAvg), a commonly used Federated Learning algorithm. By leveraging well-established theoretical results from FedAvg, we identify two key factors that impact the performance of Task Arithmetic: data heterogeneity and training heterogeneity. To mitigate these challenges, we adapt several algorithms from Federated Learning to improve the effectiveness of Task Arithmetic. Our experiments demonstrate that applying these algorithms can often significantly boost the performance of the merged model compared to the original Task Arithmetic approach. This work bridges Task Arithmetic and Federated Learning, offering new theoretical perspectives on Task Arithmetic and improved practical methodologies for model merging.
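The equivalence described in this abstract can be illustrated with a small sketch. It assumes the usual formulation of Task Arithmetic, in which the merged weights are the pretrained weights plus a scaling coefficient times the sum of task vectors (fine-tuned weights minus pretrained weights); the abstract itself does not state the formula. The function names, the flat-list weight representation, and the toy numbers are hypothetical. With the coefficient set to 1/T for T tasks, the result coincides with a single equal-weight averaging round over the fine-tuned models, which is what one-shot FedAvg computes in this simplified setting.

```python
def task_arithmetic(pretrained, finetuned_models, lam):
    """Merge as theta_0 + lam * sum_t (theta_t - theta_0), per coordinate.

    Assumed formulation; weights are plain lists of floats for illustration.
    """
    merged = []
    for i, w0 in enumerate(pretrained):
        # Sum of task vectors (fine-tuned minus pretrained) at coordinate i.
        task_vector_sum = sum(model[i] - w0 for model in finetuned_models)
        merged.append(w0 + lam * task_vector_sum)
    return merged


def one_shot_fedavg(finetuned_models):
    """One equal-weight averaging round over the locally trained models."""
    n = len(finetuned_models)
    dim = len(finetuned_models[0])
    return [sum(model[i] for model in finetuned_models) / n for i in range(dim)]


# Toy weights (hypothetical): one pretrained model, three fine-tuned models.
theta_0 = [0.0, 1.0, -2.0]
theta_tasks = [
    [0.5, 1.0, -1.0],
    [0.0, 2.0, -2.5],
    [1.0, 0.0, -2.5],
]

merged_ta = task_arithmetic(theta_0, theta_tasks, lam=1 / len(theta_tasks))
merged_avg = one_shot_fedavg(theta_tasks)
print(merged_ta)
print(merged_avg)
```

Other values of the scaling coefficient correspond to a different effective server step size in the averaging view; the sketch only covers the equal-weight case.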
Byzantine-Robust Clustered Federated Learning
Tao, Zhixu, Yang, Kun, Kulkarni, Sanjeev R.
This paper focuses on the problem of adversarial attacks from Byzantine machines in a Federated Learning setting where non-Byzantine machines can be partitioned into disjoint clusters. In this setting, non-Byzantine machines in the same cluster have the same underlying data distribution, and different clusters of non-Byzantine machines have different learning tasks. Byzantine machines can adversarially attack any cluster and disturb the training process on the clusters they attack. In the presence of Byzantine machines, the goal of our work is to identify the cluster membership of non-Byzantine machines and optimize the models learned by each cluster. We adopt the Iterative Federated Clustering Algorithm (IFCA) framework of Ghosh et al. (2020) to alternately estimate cluster membership and optimize models. To make this framework robust against adversarial attacks from Byzantine machines, we use the coordinate-wise trimmed mean and coordinate-wise median aggregation methods used by Yin et al. (2018). Specifically, we propose a new Byzantine-Robust Iterative Federated Clustering Algorithm to improve on the results in Ghosh et al. (2019). We prove a convergence rate for this algorithm for strongly convex loss functions. We compare our convergence rate with the convergence rate of an existing algorithm, and we demonstrate the performance of our algorithm on simulated data.
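The two robust aggregation rules named in this abstract can be sketched as follows. This is a minimal illustration of coordinate-wise median and coordinate-wise trimmed mean in general, not the paper's algorithm: the function names, the integer `trim_count` parameter, and the toy updates are assumptions made for the example. Each rule operates on one coordinate at a time across all machines' update vectors, so a single extreme value cannot dominate the aggregate the way it does with a plain mean.

```python
from statistics import median


def coordinate_wise_median(updates):
    """Take the median of each coordinate across all machines' updates."""
    return [median(coord) for coord in zip(*updates)]


def coordinate_wise_trimmed_mean(updates, trim_count):
    """Per coordinate, drop the trim_count smallest and trim_count largest
    values, then average what remains. trim_count is a hypothetical
    parameter standing in for a trimming fraction times the machine count.
    """
    n = len(updates)
    if trim_count < 0 or 2 * trim_count >= n:
        raise ValueError("trim_count must leave at least one value per coordinate")
    result = []
    for coord in zip(*updates):
        kept = sorted(coord)[trim_count:n - trim_count]
        result.append(sum(kept) / len(kept))
    return result


# Toy data (hypothetical): four honest updates near [1, 2], one Byzantine update.
honest = [[1.0, 2.0], [1.1, 1.9], [0.9, 2.1], [1.0, 2.0]]
byzantine = [[100.0, -100.0]]
updates = honest + byzantine

naive_mean = [sum(coord) / len(updates) for coord in zip(*updates)]
robust_median = coordinate_wise_median(updates)
robust_trimmed = coordinate_wise_trimmed_mean(updates, trim_count=1)

print(naive_mean)      # pulled far from [1, 2] by the single outlier
print(robust_median)
print(robust_trimmed)
```

In a clustered setting such as the one the abstract describes, an aggregation rule of this kind would presumably be applied within each estimated cluster; the sketch leaves cluster assignment out.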