behrouz
TNT: Improving Chunkwise Training for Test-Time Memorization
Li, Zeman, Behrouz, Ali, Deng, Yuan, Zhong, Peilin, Kacham, Praneeth, Karami, Mahdi, Razaviyayn, Meisam, Mirrokni, Vahab
Recurrent neural networks (RNNs) with deep test-time memorization modules, such as Titans and TTT, represent a promising, linearly scaling paradigm distinct from Transformers. While these expressive models do not yet match the peak performance of state-of-the-art Transformers, their potential has remained largely untapped due to prohibitively slow training and low hardware utilization. Existing parallelization methods force a fundamental conflict governed by the chunksize hyperparameter: large chunks boost speed but degrade performance, necessitating a fixed, suboptimal compromise. To solve this challenge, we introduce TNT, a novel training paradigm that decouples training efficiency from inference performance through a two-stage process. Stage one is an efficiency-focused pre-training phase utilizing a hierarchical memory: a global module processes large, hardware-friendly chunks for long-range context, while multiple parallel local modules handle fine-grained details. Crucially, by periodically resetting local memory states, we break sequential dependencies to enable massive context parallelization. Stage two is a brief fine-tuning phase where only the local memory modules are adapted to a smaller, high-resolution chunksize, maximizing accuracy with minimal overhead. Evaluated on Titans and TTT models, TNT achieves a substantial acceleration in training speed, up to 17 times faster than the most accurate baseline configuration, while simultaneously improving model accuracy. This improvement removes a critical scalability barrier, establishing a practical foundation for developing expressive RNNs and facilitating future work to close the performance gap with Transformers.
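As a rough illustration of the chunksize trade-off and the local-memory reset described above, the sketch below uses a linear associative memory M trained at test time by gradient descent on 0.5*||M k - v||^2. This is a simplifying assumption for illustration only; Titans and TTT use deeper memories and richer update rules. Within a chunk, every token's gradient is taken at the memory state from the start of the chunk, so the chunk can be processed in parallel, and chunk_size=1 recovers the fully sequential update. The functions chunkwise_memorize and tnt_style_memories are hypothetical names, not the TNT implementation.

```python
import numpy as np

def chunkwise_memorize(M0, keys, values, chunk_size, lr):
    """Linear memory trained at test time with chunkwise gradient descent.

    Loss per token: 0.5 * ||M k - v||^2, gradient: (M k - v) k^T.
    All gradients in a chunk are evaluated at the chunk's starting state,
    so larger chunks expose more parallelism but use staler memory states.
    chunk_size=1 is the fully sequential (online) update.
    """
    M = M0.copy()
    for start in range(0, len(keys), chunk_size):
        K = keys[start:start + chunk_size]     # (c, d_k)
        V = values[start:start + chunk_size]   # (c, d_v)
        err = K @ M.T - V                      # (c, d_v): M k_i - v_i per token
        M = M - lr * err.T @ K                 # sum_i (M k_i - v_i) k_i^T
    return M

def tnt_style_memories(keys, values, global_chunk, local_chunk, local_reset, lr):
    """Hypothetical two-level memory in the spirit of the abstract.

    Global memory: one pass over the whole sequence with large chunks.
    Local memories: restarted from the initial state every `local_reset`
    tokens, so the segments are independent and could run in parallel
    (here they are simply computed in a loop).
    """
    d_k, d_v = keys.shape[1], values.shape[1]
    M_init = np.zeros((d_v, d_k))
    M_global = chunkwise_memorize(M_init, keys, values, global_chunk, lr)
    M_locals = [
        chunkwise_memorize(M_init, keys[s:s + local_reset],
                           values[s:s + local_reset], local_chunk, lr)
        for s in range(0, len(keys), local_reset)
    ]
    return M_global, M_locals
```

In this toy, the second TNT stage would correspond to re-running only the local memories with a smaller local_chunk while leaving the global memory unchanged.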
Hydra: Dual Exponentiated Memory for Multivariate Time Series Analysis
Meskin, Asal, Mirrokni, Alireza, Najar, Ali, Behrouz, Ali
In recent years, effectively modeling multivariate time series has gained significant popularity, mainly due to its wide range of applications, from healthcare to financial markets and energy management. Transformers, MLPs, and linear models, the de facto backbones of modern time series models, have shown promising results in univariate and/or short-term forecasting. These models, however: (1) are permutation equivariant and so lack temporal inductive bias, making them less expressive in capturing temporal dynamics; (2) are naturally designed for the univariate setup, missing the inter-dependencies of the temporal and variate dimensions; and/or (3) are inefficient for long-term time series modeling. To overcome these training and inference inefficiencies as well as the lack of temporal inductive bias, linear Recurrent Neural Networks (RNNs) have recently gained attention as an alternative to Transformer-based models. These models, however, are inherently limited to a single sequence, missing inter-variate dependencies, and can propagate errors due to their additive nature. In this paper, we present Hydra, a by-design two-headed meta in-context memory module that learns how to memorize patterns at test time by prioritizing time series patterns that are more informative about the data. Hydra uses a 2-dimensional recurrence across both time and variate at each step, which is more powerful than mixing methods. Although the 2-dimensional nature of the model makes its training recurrent and non-parallelizable, we present a new 2D-chunk-wise training algorithm that approximates the actual recurrence with a $\times 10$ efficiency improvement while maintaining effectiveness. Our experimental results on a diverse set of tasks and datasets, including time series forecasting, classification, and anomaly detection, show the superior performance of Hydra compared to state-of-the-art baselines.
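To make "a 2-dimensional recurrence across both time and variate" concrete, here is a toy scalar recurrence h[t, v] = a*h[t-1, v] + b*h[t, v-1] + x[t, v]. This is our assumption for illustration, not Hydra's actual memory update. Each cell depends on its neighbor along both axes, so a naive loop is fully sequential; evaluating the same recurrence one anti-diagonal at a time is still exact and exposes some parallelism. This wavefront ordering is a generic scheme and is not the approximate 2D-chunk-wise algorithm proposed in the paper; the function names are hypothetical.

```python
import numpy as np

def recur_2d_sequential(x, a, b):
    """Toy 2D linear recurrence over a (time, variate) grid, row by row."""
    T, V = x.shape
    h = np.zeros((T, V))
    for t in range(T):
        for v in range(V):
            up = h[t - 1, v] if t > 0 else 0.0     # previous time step, same variate
            left = h[t, v - 1] if v > 0 else 0.0   # same time step, previous variate
            h[t, v] = a * up + b * left + x[t, v]
    return h

def recur_2d_wavefront(x, a, b):
    """Same recurrence, evaluated one anti-diagonal (t + v = d) at a time.

    Cells on one anti-diagonal depend only on the previous anti-diagonal,
    so each diagonal could be computed in parallel: T + V - 1 sequential
    steps instead of T * V.
    """
    T, V = x.shape
    h = np.zeros((T, V))
    for d in range(T + V - 1):
        t = np.arange(max(0, d - V + 1), min(T, d + 1))
        v = d - t
        up = np.where(t > 0, h[np.maximum(t - 1, 0), v], 0.0)
        left = np.where(v > 0, h[t, np.maximum(v - 1, 0)], 0.0)
        h[t, v] = a * up + b * left + x[t, v]
    return h

x = np.random.default_rng(1).standard_normal((5, 3))
print(np.allclose(recur_2d_sequential(x, 0.9, 0.5), recur_2d_wavefront(x, 0.9, 0.5)))  # True
```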
ATLAS: Learning to Optimally Memorize the Context at Test Time
Behrouz, Ali, Li, Zeman, Kacham, Praneeth, Daliri, Majid, Deng, Yuan, Zhong, Peilin, Razaviyayn, Meisam, Mirrokni, Vahab
Transformers have been established as the most popular backbones in sequence modeling, mainly due to their effectiveness in in-context retrieval tasks and their ability to learn at scale. Their quadratic memory and time complexity, however, limit their applicability to longer sequences and have motivated researchers to explore effective alternative architectures such as modern recurrent neural networks (a.k.a. long-term recurrent memory modules). Despite their recent success in diverse downstream tasks, these models struggle in tasks that require long-context understanding and extrapolation to longer sequences. We observe that these shortcomings stem from three disjoint aspects of their design: (1) limited memory capacity, bounded by the architecture of the memory and the feature mapping of the input; (2) the online nature of the update, i.e., optimizing the memory only with respect to the last input; and (3) less expressive management of their fixed-size memory. To enhance all three aspects, we present ATLAS, a long-term memory module with high capacity that learns to memorize the context by optimizing the memory based on the current and past tokens, overcoming the online nature of long-term memory models. Building on this insight, we present a new family of Transformer-like architectures, called DeepTransformers, that are strict generalizations of the original Transformer architecture. Our experimental results on language modeling, common-sense reasoning, recall-intensive, and long-context understanding tasks show that ATLAS surpasses the performance of Transformers and recent linear recurrent models. ATLAS further improves the long-context performance of Titans, achieving +80\% accuracy at a 10M context length on the BABILong benchmark.
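The contrast between an online update and one that optimizes the memory over current and past tokens can be sketched with a linear memory and plain gradient descent. Both are simplifying assumptions; ATLAS's actual memory architecture and optimizer are more elaborate. At each step, the hypothetical windowed_memory_update takes one gradient step on the reconstruction loss summed over the last `window` tokens; with window=1 it reduces to the online delta-rule update that sees only the most recent token.

```python
import numpy as np

def windowed_memory_update(keys, values, window, lr):
    """Linear memory M (d_v x d_k) updated once per token.

    Step t minimizes sum_{i in last `window` tokens} 0.5 * ||M k_i - v_i||^2
    with a single gradient step, so past tokens keep shaping the memory.
    window=1 is the purely online update (only the current token).
    """
    d_k, d_v = keys.shape[1], values.shape[1]
    M = np.zeros((d_v, d_k))
    for t in range(len(keys)):
        lo = max(0, t - window + 1)
        K = keys[lo:t + 1]        # keys in the current window
        V = values[lo:t + 1]      # values in the current window
        err = K @ M.T - V         # per-token reconstruction errors
        M = M - lr * err.T @ K    # gradient summed over the window
    return M

rng = np.random.default_rng(0)
K, V = rng.standard_normal((32, 8)), rng.standard_normal((32, 4))
M_online = windowed_memory_update(K, V, window=1, lr=0.05)
M_window = windowed_memory_update(K, V, window=4, lr=0.05)
```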
Chimera: Effectively Modeling Multivariate Time Series with 2-Dimensional State Space Models
Behrouz, Ali, Santacatterina, Michele, Zabih, Ramin
Modeling multivariate time series is a well-established problem with a wide range of applications, from healthcare to financial markets. Traditional State Space Models (SSMs) are classical approaches for univariate time series modeling due to their simplicity and expressive power in representing linear dependencies. They, however, have fundamentally limited expressive power to capture non-linear dependencies, are slow in practice, and fail to model the inter-variate information flow. Despite recent attempts to improve the expressive power of SSMs by using deep structured SSMs, the existing methods are either limited to univariate time series, fail to model complex patterns (e.g., seasonal patterns), fail to dynamically model the dependencies of the variate and time dimensions, and/or are input-independent. We present Chimera, which uses two input-dependent 2-D SSM heads with different discretization processes to learn long-term progression and seasonal patterns. To improve the efficiency of the complex 2D recurrence, we present a fast training algorithm using a new 2-dimensional parallel selective scan. We further present and discuss 2-dimensional Mamba and Mamba-2 as special cases of our 2D SSM. Our experimental evaluation shows the superior performance of Chimera on extensive and diverse benchmarks, including ECG and speech time series classification, long-term and short-term time series forecasting, and time series anomaly detection.
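A minimal scalar-state sketch of an input-dependent (selective) 2D SSM over a (time, variate) grid follows. We assume Mamba-style discretization (zero-order hold exp(ΔA) for A and the simplified Δ·B for B) and an input-dependent step size Δ = softplus(w·x + b). Chimera's two heads use different discretization processes that the abstract does not specify, so this is not its exact formulation. With a single variate, the toy recurrence collapses to a 1D selective scan, which the helper mamba_1d reproduces for comparison. Both function names are hypothetical.

```python
import numpy as np

def selective_2d_ssm(x, A_time, A_var, B, C, w_delta, b_delta):
    """Scalar-state 2D selective SSM:
    h[t, v] = exp(D A_time) h[t-1, v] + exp(D A_var) h[t, v-1] + D B x[t, v]
    y[t, v] = C h[t, v],  with input-dependent D = softplus(w_delta x + b_delta).
    """
    T, V = x.shape
    h = np.zeros((T, V))
    for t in range(T):
        for v in range(V):
            delta = np.logaddexp(0.0, w_delta * x[t, v] + b_delta)  # softplus
            up = h[t - 1, v] if t > 0 else 0.0      # state from previous time step
            left = h[t, v - 1] if v > 0 else 0.0    # state from previous variate
            h[t, v] = (np.exp(delta * A_time) * up
                       + np.exp(delta * A_var) * left
                       + delta * B * x[t, v])
    return C * h

def mamba_1d(x, A, B, C, w_delta, b_delta):
    """1D selective scan with the same discretization, for comparison."""
    h, out = 0.0, []
    for xt in x:
        delta = np.logaddexp(0.0, w_delta * xt + b_delta)
        h = np.exp(delta * A) * h + delta * B * xt
        out.append(C * h)
    return np.array(out)

# With one variate, the variate-axis term is always zero.
x = np.random.default_rng(0).standard_normal((16, 1))
y2d = selective_2d_ssm(x, A_time=-0.5, A_var=-1.0, B=1.0, C=2.0, w_delta=0.3, b_delta=0.1)
y1d = mamba_1d(x[:, 0], A=-0.5, B=1.0, C=2.0, w_delta=0.3, b_delta=0.1)
print(np.allclose(y2d[:, 0], y1d))  # True
```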