Purandare, Sanket
SimpleFSDP: Simpler Fully Sharded Data Parallel with torch.compile
Zhang, Ruisi, Liu, Tianyu, Feng, Will, Gu, Andrew, Purandare, Sanket, Liang, Wanchao, Massa, Francisco
Distributed training of large models consumes enormous computation resources and requires substantial engineering effort to compose various training techniques. This paper presents SimpleFSDP, a PyTorch-native, compiler-based Fully Sharded Data Parallel (FSDP) framework that is simple to maintain and compose, allows full computation-communication graph tracing, and improves performance through compiler backend optimizations. SimpleFSDP's novelty lies in its torch.compile-friendly implementation of collective communications using existing PyTorch primitives, namely parametrizations, selective activation checkpointing, and DTensor. It also features the first-of-its-kind bucketing and reordering of intermediate representation (IR) nodes in the TorchInductor backend for effective computation-communication overlap. As a result, users can apply these optimizations to wrap model components automatically or manually and minimize exposed communication. Extensive evaluations of SimpleFSDP on Llama 3 models (including the ultra-large 405B) using TorchTitan demonstrate up to 28.54% memory reduction and 68.67% throughput improvement compared to FSDP2, the most widely adopted eager-mode FSDP framework, when composed with other distributed training techniques.
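The sharding pattern that FSDP-style training relies on can be illustrated with a small, single-process Python simulation. This is a conceptual sketch only, not SimpleFSDP's code: SimpleFSDP expresses its collectives through parametrizations, selective activation checkpointing, and DTensor inside torch.compile, whereas the hypothetical helpers `shard`, `all_gather`, `reduce_scatter`, and `bucketed_all_gather` below operate on plain lists, with ranks simulated as list indices. The last helper mimics the idea behind bucketing: several parameters are gathered with one collective instead of one collective each.

```python
from typing import List

# Conceptual simulation only: ranks are list indices and "collectives" are
# list operations. Not SimpleFSDP's implementation; helper names are hypothetical.


def shard(flat_param: List[float], world_size: int) -> List[List[float]]:
    """Split a flattened parameter into world_size equal shards, zero-padding the tail."""
    shard_len = -(-len(flat_param) // world_size)  # ceiling division
    padded = flat_param + [0.0] * (shard_len * world_size - len(flat_param))
    return [padded[r * shard_len:(r + 1) * shard_len] for r in range(world_size)]


def all_gather(shards: List[List[float]], numel: int) -> List[float]:
    """Before compute, rebuild the full parameter from every rank's shard."""
    full = [v for s in shards for v in s]
    return full[:numel]  # drop the padding


def reduce_scatter(per_rank_grads: List[List[float]], world_size: int) -> List[List[float]]:
    """After backward, sum the full gradients from all ranks; each rank keeps one shard of the sum."""
    summed = [sum(vals) for vals in zip(*per_rank_grads)]
    return shard(summed, world_size)


def bucketed_all_gather(params_shards: List[List[List[float]]], numels: List[int]) -> List[List[float]]:
    """Gather several parameters with a single collective.

    params_shards[p][r] is rank r's shard of parameter p. Each rank packs its
    shards of every parameter into one buffer, the buffers are gathered once,
    and the result is unpacked per parameter."""
    world_size = len(params_shards[0])
    shard_lens = [len(p[0]) for p in params_shards]
    bucket_len = sum(shard_lens)
    packed = [[v for p in params_shards for v in p[r]] for r in range(world_size)]
    gathered = [v for buf in packed for v in buf]  # the one "collective"
    out, offset = [], 0
    for slen, numel in zip(shard_lens, numels):
        full: List[float] = []
        for r in range(world_size):
            start = r * bucket_len + offset
            full.extend(gathered[start:start + slen])
        out.append(full[:numel])
        offset += slen
    return out


if __name__ == "__main__":
    world_size = 4
    weight = [float(i) for i in range(10)]  # 10 elements: not divisible by 4
    shards = shard(weight, world_size)      # each simulated rank stores 3
    grads = [[1.0] * len(weight) for _ in range(world_size)]
    print(len(shards[0]), all_gather(shards, len(weight)) == weight,
          reduce_scatter(grads, world_size)[0])
    # prints: 3 True [4.0, 4.0, 4.0]
```

Zero-padding every parameter to a multiple of the world size keeps all shards the same length, which is what allows shards (and packed buckets) to be concatenated directly after a gather and sliced back apart by fixed offsets.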
TorchTitan: One-stop PyTorch native solution for production ready LLM pre-training
Liang, Wanchao, Liu, Tianyu, Wright, Less, Constable, Will, Gu, Andrew, Huang, Chien-Chin, Zhang, Iris, Feng, Wei, Huang, Howard, Wang, Junjie, Purandare, Sanket, Nadathur, Gokul, Idreos, Stratos
The development of large language models (LLMs) has been instrumental in advancing state-of-the-art natural language processing applications. Training LLMs with billions of parameters and trillions of tokens requires sophisticated distributed systems that enable composing and comparing several state-of-the-art techniques in order to scale efficiently across thousands of accelerators. However, existing solutions are complex, scattered across multiple libraries/repositories, lack interoperability, and are cumbersome to maintain. Thus, curating and empirically comparing training recipes requires non-trivial engineering effort. This paper introduces TorchTitan, an open-source, PyTorch-native distributed training system that unifies state-of-the-art techniques, streamlining integration and reducing overhead. TorchTitan enables 3D parallelism in a modular manner with elastic scaling, providing comprehensive logging, checkpointing, and debugging tools for production-ready training. It also incorporates hardware-software co-designed solutions, leveraging features like Float8 training and SymmetricMemory. As a flexible test bed, TorchTitan facilitates custom recipe curation and comparison, allowing us to develop optimized training recipes for Llama 3.1 and to provide guidance, based on our experience, on selecting techniques for maximum efficiency. We thoroughly assess TorchTitan on the Llama 3.1 family of LLMs, spanning 8 billion to 405 billion parameters, and showcase its performance, modular composability, and elastic scalability. By stacking training optimizations, we demonstrate speedups over optimized baselines on NVIDIA H100 GPUs of 65.08% with 1D parallelism at the 128-GPU scale (Llama 3.1 8B), an additional 12.59% with 2D parallelism at the 256-GPU scale (Llama 3.1 70B), and an additional 30% with 3D parallelism at the 512-GPU scale (Llama 3.1 405B).
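3D parallelism arranges GPUs in a logical three-dimensional mesh, and each parallelism dimension communicates only within its own subgroup of ranks. PyTorch exposes this abstraction as DeviceMesh (`torch.distributed.device_mesh.init_device_mesh`); the pure-Python sketch below only mimics the index arithmetic behind it. The mesh sizes, the dimension order (pipeline outermost, tensor parallel innermost), and the helper names `rank_to_coords` and `process_group` are assumptions for illustration, not TorchTitan's actual configuration.

```python
import math
from typing import Dict, List

# Hypothetical mesh for illustration: 2 pipeline stages x 2 data-parallel
# replicas x 2 tensor-parallel shards = 8 ranks. Dict order is outermost
# dimension first, innermost (fastest-varying) last.
MESH: Dict[str, int] = {"pp": 2, "dp": 2, "tp": 2}


def rank_to_coords(rank: int, mesh: Dict[str, int]) -> Dict[str, int]:
    """Decompose a global rank into mesh coordinates (row-major layout)."""
    coords: Dict[str, int] = {}
    for name in reversed(list(mesh)):
        coords[name] = rank % mesh[name]
        rank //= mesh[name]
    return {name: coords[name] for name in mesh}


def process_group(rank: int, dim: str, mesh: Dict[str, int]) -> List[int]:
    """Ranks that match `rank` on every mesh dimension except `dim`.

    These are the peers `rank` communicates with for that parallelism
    dimension (e.g. gradient reduction for dp, activation send/recv for pp)."""
    world_size = math.prod(mesh.values())
    me = rank_to_coords(rank, mesh)
    return [
        r for r in range(world_size)
        if all(c == me[n] for n, c in rank_to_coords(r, mesh).items() if n != dim)
    ]


if __name__ == "__main__":
    for dim in MESH:
        print(dim, process_group(5, dim, MESH))
    # prints:
    # pp [1, 5]
    # dp [5, 7]
    # tp [4, 5]
```

Placing tensor parallelism innermost keeps each TP group on consecutive ranks, which on typical clusters means the same node, so its frequent collectives can use the fastest interconnect; the outer dimensions communicate less often and tolerate slower links.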
Flash Inference: Near Linear Time Inference for Long Convolution Sequence Models and Beyond
Oncescu, Costin-Andrei, Purandare, Sanket, Idreos, Stratos, Kakade, Sham
A lot of recent progress in deep learning, particularly in the form of large language models (LLMs), has been driven by the transformer architecture [Vaswani et al., 2017]. Their high quality, however, comes at a computational cost that scales quadratically with sequence length, during both training and inference. This cost can become prohibitive for very long contexts, and as a result a number of alternative architectures with better computational scaling in context length have been proposed [Gu and Dao, 2023, Poli et al., 2023, Fu et al., 2024]. While most of these works have improved computational efficiency for training, some still scale quadratically in sequence length during inference and thus do not improve asymptotically over transformers. In this work, we propose a framework for optimizing inference efficiency for a general class of such models. As a case study, and the one that inspired the method, we focus on long convolution sequence models (LCSMs) [Poli et al., 2023, Fu et al., 2022, Romero et al., 2021, Li et al., 2022, Karami and Ghodsi, 2024, Fu et al., 2023a]. However, our approach is not limited to LCSMs, and we identify the properties that enable such inference speedups in the hope of guiding the design of future architectures. In the particular case of LCSMs (including Hyena), the core building block of the architecture is the convolution of the input sequence with a filter as long as the sequence (potentially underparameterized). If we let L be the sequence length (e.g.