Towards a Theoretical Understanding of the 'Reversal Curse' via Training Dynamics
Hanlin Zhu, Baihe Huang, Shaolun Zhang, Michael Jordan, Jiantao Jiao, Yuandong Tian, Stuart Russell
arXiv.org Artificial Intelligence
The reversal curse (Berglund et al., 2023) refers to the phenomenon that an auto-regressive LLM that learns "A is B" during training fails to generalize to the reverse direction "B is A"; this task is also termed "inverse search" in Allen-Zhu and Li (2023). Although some previous works propose different methods to mitigate the reversal curse, including reversing the training dataset (Guo et al., 2024; Golovneva et al., 2024) and training on different objectives such as autoregressive blank infilling (Lv et al., 2023), these methods might negatively affect model performance on other tasks, since they alter either the dataset or the model architecture. Without dataset manipulation or changing the auto-regressive nature (causal structure) of the model, the reversal curse is hard to mitigate even with in-context learning (ICL) strategies such as chain-of-thought (Allen-Zhu and Li, 2023; Guo et al., 2024).

In this paper, we aim to theoretically study why the reversal curse happens for auto-regressive LLMs. Unlike previous work that studies the capacity of (transformer-based (Vaswani et al., 2017)) LLMs through the lens of expressivity (e.g., Yun et al. (2019); Pérez et al. (2021); Feng et al. (2024)), the reversal curse cannot be explained by expressivity, since a model that can express "A is B" is also able to express "B is A". We therefore analyze the reversal curse via training dynamics: even if a set of parameters can express a fact in both directions, it might not be reachable by popular training algorithms (e.g., gradient descent, AdamW (Loshchilov and Hutter, 2017)) when the training data present the fact in only one direction.

We summarize our main contributions as follows. We theoretically analyze the reversal curse, where training or test sequences have the form "A B" or "B A", via the training dynamics of (stochastic) gradient descent under two auto-regressive models: a bilinear model (Section 3) and one-layer transformers under assumptions similar to Tian et al. (2023a) (Section 4). The analysis of the training dynamics of both models reveals a core reason why the reversal curse happens: the weights of autoregressive models are asymmetric, i.e., an increase of the weights from token A to token B during training does not necessarily increase the weights from token B to token A.
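As a minimal sketch of this asymmetry intuition (our own toy construction, not the paper's Section 3 model, which treats trained embeddings under its own assumptions), consider a bilinear next-token model with one-hot token embeddings, so that the logit of token j following token i is simply the matrix entry W[i, j]. Training with gradient descent on the forward sequence "A B" only ever updates the row of W indexed by A, so the reverse-direction logit W[B, A] receives no gradient signal. The vocabulary size, token ids, learning rate, and step count below are all hypothetical.

```python
import numpy as np

# Toy sketch (assumed setup, not the paper's exact model): a bilinear
# next-token model with one-hot embeddings, so the logit of token j
# following token i is W[i, j].
rng = np.random.default_rng(0)
V = 8                          # hypothetical vocabulary size
A, B = 2, 5                    # hypothetical token ids for "A" and "B"
W = 0.01 * rng.standard_normal((V, V))

def softmax(z):
    z = z - z.max()
    e = np.exp(z)
    return e / e.sum()

lr = 0.5
for step in range(200):
    # Train only on the forward sequence "A B":
    # cross-entropy loss for predicting B given A.
    p = softmax(W[A])          # next-token distribution after A
    grad_row = p.copy()
    grad_row[B] -= 1.0         # d(loss)/d(W[A, :]) for target token B
    W[A] -= lr * grad_row      # only row A of W is ever touched

print("P(B | A) =", softmax(W[A])[B])  # close to 1: forward direction learned
print("P(A | B) =", softmax(W[B])[A])  # still about 1/V: reverse direction unlearned
```

After training, the forward probability P(B | A) approaches 1 while the reverse probability P(A | B) stays near the uniform value 1/V, illustrating (in this simplified setting) how one-directional training data can leave the reverse direction untouched.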
May 7, 2024