Kim, Juno
Metastable Dynamics of Chain-of-Thought Reasoning: Provable Benefits of Search, RL and Distillation
Kim, Juno, Wu, Denny, Lee, Jason, Suzuki, Taiji
A key paradigm to improve the reasoning capabilities of large language models (LLMs) is to allocate more inference-time compute to search against a verifier or reward model. This process can then be utilized to refine the pretrained model or distill its reasoning patterns into more efficient models. In this paper, we study inference-time compute by viewing chain-of-thought (CoT) generation as a metastable Markov process: easy reasoning steps (e.g., algebraic manipulations) form densely connected clusters, while hard reasoning steps (e.g., applying a relevant theorem) create sparse, low-probability edges between clusters, leading to phase transitions at longer timescales. Under this framework, we prove that implementing a search protocol that rewards sparse edges improves CoT by decreasing the expected number of steps to reach different clusters. In contrast, we establish a limit on reasoning capability when the model is restricted to local information of the pretrained graph. We also show that the information gained by search can be utilized to obtain a better reasoning model: (1) the pretrained model can be directly finetuned to favor sparse edges via policy gradient methods, and moreover (2) a compressed metastable representation of the reasoning dynamics can be distilled into a smaller, more efficient model.
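As a toy illustration of the metastable picture (a minimal sketch, not the paper's construction; cluster sizes and transition probabilities are illustrative assumptions), the snippet below builds a two-cluster Markov chain with a single sparse cross-cluster edge and shows how the expected number of steps to reach the second cluster shrinks when that edge is upweighted, mimicking a search protocol that rewards hard steps.

```python
import numpy as np

def hitting_time(P, start, target_set):
    """Expected number of steps to reach target_set from start,
    obtained by solving (I - Q) t = 1 on the non-target states."""
    n = P.shape[0]
    others = [i for i in range(n) if i not in target_set]
    Q = P[np.ix_(others, others)]
    t = np.linalg.solve(np.eye(len(others)) - Q, np.ones(len(others)))
    return t[others.index(start)]

def two_cluster_chain(eps):
    """Two 3-state clusters; within-cluster moves are dense and uniform,
    while the single cross-cluster ("hard step") edge 2 -> 3 has probability eps."""
    P = np.zeros((6, 6))
    for cluster in (range(3), range(3, 6)):
        for i in cluster:
            for j in cluster:
                P[i, j] = 1.0 / 3
    P[2, :3] *= (1 - eps) / P[2, :3].sum()  # renormalize the bridging row
    P[2, 3] = eps                            # sparse edge into the second cluster
    return P

base = hitting_time(two_cluster_chain(1e-3), start=0, target_set={3, 4, 5})
boosted = hitting_time(two_cluster_chain(1e-1), start=0, target_set={3, 4, 5})
print(f"expected steps to reach the second cluster: {base:.0f} -> {boosted:.0f}")
```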
Optimality and Adaptivity of Deep Neural Features for Instrumental Variable Regression
Kim, Juno, Meunier, Dimitri, Gretton, Arthur, Suzuki, Taiji, Li, Zhu
We provide a convergence analysis of deep feature instrumental variable (DFIV) regression (Xu et al., 2021), a nonparametric approach to IV regression using data-adaptive features learned by deep neural networks in two stages. We prove that the DFIV algorithm achieves the minimax optimal learning rate when the target structural function lies in a Besov space. This is shown under standard nonparametric IV assumptions and an additional smoothness assumption on the regularity of the conditional distribution of the covariate given the instrument, which controls the difficulty of Stage 1. We further demonstrate that DFIV, as a data-adaptive algorithm, is superior to fixed-feature (kernel or sieve) IV methods in two ways. First, when the target function possesses low spatial homogeneity (i.e., it has both smooth and spiky/discontinuous regions), DFIV still achieves the optimal rate, while fixed-feature methods are shown to be strictly suboptimal. Second, compared with kernel-based two-stage regression estimators, DFIV is provably more data-efficient in its use of Stage 1 samples.
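A minimal two-stage sketch in PyTorch, assuming small MLP feature maps for treatment and instrument and differentiating through closed-form ridge solutions; it abstracts away the alternating updates and sample splitting of the actual DFIV algorithm, and the data-generating process and hyperparameters are illustrative choices.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Synthetic IV data: instrument Z, confounded treatment X, outcome Y.
n = 2000
Z = torch.randn(n, 1)
U = torch.randn(n, 1)                              # unobserved confounder
X = Z + 0.5 * U + 0.1 * torch.randn(n, 1)
Y = torch.abs(X) + U + 0.1 * torch.randn(n, 1)     # structural function f(x) = |x|

def mlp(d_in, d_out):
    return nn.Sequential(nn.Linear(d_in, 32), nn.ReLU(), nn.Linear(32, d_out))

phi = mlp(1, 8)   # treatment features phi(X)
psi = mlp(1, 8)   # instrument features psi(Z)
lam1, lam2 = 1e-3, 1e-3
opt = torch.optim.Adam(list(phi.parameters()) + list(psi.parameters()), lr=1e-3)

def ridge(A, B, lam):
    # closed-form ridge solution W minimizing ||A W - B||^2 + lam * n * ||W||^2
    d = A.shape[1]
    return torch.linalg.solve(A.T @ A + lam * n * torch.eye(d), A.T @ B)

for step in range(500):
    F, G = phi(X), psi(Z)
    W1 = ridge(G, F, lam1)        # Stage 1: predict phi(X) from psi(Z)
    F_hat = G @ W1
    w2 = ridge(F_hat, Y, lam2)    # Stage 2: predict Y from projected features
    loss = ((F_hat @ w2 - Y) ** 2).mean()
    opt.zero_grad(); loss.backward(); opt.step()

# Evaluate the learned structural function estimate x -> phi(x)^T w2 on a grid.
x_grid = torch.linspace(-2, 2, 5).unsqueeze(1)
with torch.no_grad():
    W1 = ridge(psi(Z), phi(X), lam1)
    w2 = ridge(psi(Z) @ W1, Y, lam2)
    print(torch.cat([x_grid, phi(x_grid) @ w2], dim=1))
```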
Transformers Provably Solve Parity Efficiently with Chain of Thought
Kim, Juno, Suzuki, Taiji
This work provides the first theoretical analysis of training transformers to solve complex problems by recursively generating intermediate states, analogous to fine-tuning for chain-of-thought (CoT) reasoning. We consider training a one-layer transformer to solve the fundamental k-parity problem, extending the work on RNNs by Wies et al. (2023). We establish three key results: (1) any finite-precision gradient-based algorithm, without intermediate supervision, requires substantial iterations to solve parity with finite samples; (2) in contrast, when intermediate parities are incorporated into the loss as teacher-forced CoT supervision, the transformer can learn parity efficiently; and (3) even without teacher forcing, parity can be learned when augmented data is used to verify the self-consistency of the generated chain. Our findings, supported by numerical experiments, show that task decomposition and stepwise reasoning naturally arise from optimizing transformers with CoT; moreover, self-consistency checking can improve multistep reasoning ability, aligning with empirical studies of CoT.
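A small sketch of the k-parity task and the intermediate-parity chains that serve as stepwise (CoT) supervision; the transformer architecture and training procedure analyzed in the paper are not reproduced here, and the dimensions below are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)

def parity_with_cot(d, k, n):
    """Sample n inputs in {0,1}^d; the label is the parity of k hidden coordinates.
    The CoT chain records the running parity after each relevant bit, giving the
    kind of intermediate supervision that makes stepwise training tractable."""
    secret = rng.choice(d, size=k, replace=False)   # hidden relevant coordinates
    X = rng.integers(0, 2, size=(n, d))
    chains = np.cumsum(X[:, secret], axis=1) % 2    # intermediate parities (CoT steps)
    y = chains[:, -1]                               # final answer = k-parity of the input
    return X, chains, y

X, chains, y = parity_with_cot(d=16, k=4, n=3)
print("inputs:\n", X)
print("CoT chains (running parities):\n", chains)
print("labels:", y)
```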
A Novel Nuanced Conversation Evaluation Framework for Large Language Models in Mental Health
Marrapese, Alexander, Suleiman, Basem, Ullah, Imdad, Kim, Juno
Understanding the conversation abilities of Large Language Models (LLMs) can help lead to their more cautious and appropriate deployment. This is especially important for safety-critical domains like mental health, where someone's life may depend on the exact wording of a response to an urgent question. In this paper, we propose a novel framework for evaluating the nuanced conversation abilities of LLMs. Within it, we develop a series of quantitative metrics drawn from the psychotherapy conversation analysis literature. While we ensure that our framework and metrics are transferable to relevant adjacent domains, we apply them to the mental health field. We use our framework to evaluate several popular frontier LLMs, including GPT and Llama models, on a verified mental health dataset. Our results show that GPT4 Turbo performs significantly more similarly to verified therapists than the other selected LLMs. We conduct additional analysis to examine how LLM conversation performance varies across specific mental health topics. Our results indicate that GPT4 Turbo achieves high correlation with verified therapists on particular topics such as Parenting and Relationships. We believe our contributions will help researchers develop better LLMs that, in turn, will more positively support people's lives.
Transformers Learn Nonlinear Features In Context: Nonconvex Mean-field Dynamics on the Attention Landscape
Kim, Juno, Suzuki, Taiji
Large language models based on the Transformer architecture have demonstrated impressive capabilities to learn in context. However, existing theoretical studies of how this phenomenon arises are limited to the dynamics of a single layer of attention trained on linear regression tasks. In this paper, we study the optimization of a Transformer consisting of a fully connected layer followed by a linear attention layer. The MLP acts as a common nonlinear representation or feature map, greatly enhancing the power of in-context learning. We prove, in the mean-field and two-timescale limit, that the infinite-dimensional loss landscape for the distribution of parameters, while highly nonconvex, becomes quite benign. We also analyze the second-order stability of mean-field dynamics and show that Wasserstein gradient flow almost always avoids saddle points. Furthermore, we establish novel methods for obtaining concrete improvement rates both away from and near critical points. This represents the first saddle point analysis of mean-field dynamics in general, and the techniques are of independent interest.
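A minimal forward-pass sketch of the architecture studied (a shared MLP feature map followed by a single linear attention layer); the dimensions, initialization, and readout below are illustrative assumptions rather than the paper's exact parameterization.

```python
import torch
import torch.nn as nn

class MLPAttentionICL(nn.Module):
    """One nonlinear feature map shared across tokens, followed by a single
    linear attention layer that aggregates the in-context examples."""
    def __init__(self, d, d_feat=32):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(d, d_feat), nn.ReLU(), nn.Linear(d_feat, d_feat))
        self.gamma = nn.Parameter(torch.zeros(d_feat, d_feat))  # linear attention weights

    def forward(self, X_ctx, y_ctx, x_query):
        # X_ctx: (n, d) context inputs, y_ctx: (n,) context labels, x_query: (d,)
        F = self.mlp(X_ctx)        # nonlinear features of the context inputs
        f_q = self.mlp(x_query)    # features of the query
        # linear attention readout: mean_i (f_q^T Gamma f_i) * y_i
        scores = F @ self.gamma @ f_q
        return (scores * y_ctx).mean()

model = MLPAttentionICL(d=4)
X_ctx, y_ctx, x_q = torch.randn(8, 4), torch.randn(8), torch.randn(4)
print(model(X_ctx, y_ctx, x_q))
```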
$t^3$-Variational Autoencoder: Learning Heavy-tailed Data with Student's t and Power Divergence
Kim, Juno, Kwon, Jaehyuk, Cho, Mincheol, Lee, Hyunjong, Won, Joong-Ho
The variational autoencoder (VAE) typically employs a standard normal prior as a regularizer for the probabilistic latent encoder. However, the Gaussian tail often decays too quickly to effectively accommodate the encoded points, failing to preserve crucial structures hidden in the data. In this paper, we explore the use of heavy-tailed models to combat over-regularization. Drawing upon insights from information geometry, we propose $t^3$VAE, a modified VAE framework that incorporates Student's t-distributions for the prior, encoder, and decoder. This results in a joint model distribution of a power form which we argue can better fit real-world datasets. We derive a new objective by reformulating the evidence lower bound as joint optimization of KL divergence between two statistical manifolds and replacing it with $\gamma$-power divergence, a natural alternative for power families. $t^3$VAE demonstrates superior generation of low-density regions when trained on heavy-tailed synthetic data. Furthermore, we show that $t^3$VAE significantly outperforms other models on CelebA and imbalanced CIFAR-100 datasets.
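A small sketch of the scale-mixture view behind the Student-t latent: a Gaussian draw divided by an independent chi-squared factor yields heavier tails than the usual Gaussian encoder. The $\gamma$-power divergence objective and the full $t^3$VAE architecture are not reproduced here; the shapes and degrees of freedom below are illustrative.

```python
import torch

def sample_student_t(mu, sigma, nu):
    """Draw from a diagonal Student-t latent via the Gaussian scale mixture
    z = mu + sigma * eps / sqrt(chi2_nu / nu); small nu gives heavier tails."""
    eps = torch.randn_like(mu)
    chi2 = torch.distributions.Chi2(torch.tensor(float(nu))).sample(mu.shape[:1])
    return mu + sigma * eps / torch.sqrt(chi2.unsqueeze(-1) / nu)

mu = torch.zeros(4, 2)
sigma = torch.ones(4, 2)
print(sample_student_t(mu, sigma, nu=3.0))   # heavy-tailed latent samples
print(torch.randn(4, 2))                     # Gaussian samples for comparison
```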
Symmetric Mean-field Langevin Dynamics for Distributional Minimax Problems
Kim, Juno, Yamamoto, Kakei, Oko, Kazusato, Yang, Zhuoran, Suzuki, Taiji
In this paper, we extend mean-field Langevin dynamics to minimax optimization over probability distributions for the first time with symmetric and provably convergent updates. We propose mean-field Langevin averaged gradient (MFL-AG), a single-loop algorithm that implements gradient descent ascent in the distribution spaces with a novel weighted averaging, and establish average-iterate convergence to the mixed Nash equilibrium. We also study both time and particle discretization regimes and prove a new uniform-in-time propagation of chaos result which accounts for the dependency of the particle interactions on all previous distributions. Furthermore, we propose mean-field Langevin anchored best response (MFL-ABR), a symmetric double-loop algorithm based on best response dynamics with linear last-iterate convergence. Finally, we study applications to zero-sum Markov games and conduct simulations demonstrating long-term optimality.
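A toy particle sketch of Langevin gradient descent ascent with history averaging on a simple quadratic-bilinear payoff; the payoff, the linearly increasing history weights, and all hyperparameters are illustrative stand-ins rather than the exact MFL-AG scheme.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy distributional minimax: min over mu, max over nu of
#   F(mu, nu) = E_{x~mu, y~nu}[x*y + 0.1*x^2 - 0.1*y^2],
# approximated with entropic regularization by noisy (Langevin) particle updates.

m, eta, beta = 512, 0.05, 10.0           # particles, step size, inverse temperature
x = rng.normal(size=m)                    # min-player particles (samples of mu)
y = rng.normal(size=m)                    # max-player particles (samples of nu)
hist_x, hist_y = [x.mean()], [y.mean()]   # history of opponent statistics
noise = np.sqrt(2 * eta / beta)

for t in range(1, 2001):
    # gradients against a weighted average of past opponent distributions
    w = np.arange(1, t + 1, dtype=float); w /= w.sum()
    ybar = float(np.dot(w, hist_y))       # E[y] under the averaged opponent
    xbar = float(np.dot(w, hist_x))
    x += -eta * (ybar + 0.2 * x) + noise * rng.normal(size=m)   # descent for mu
    y += eta * (xbar - 0.2 * y) + noise * rng.normal(size=m)    # ascent for nu
    hist_x.append(x.mean()); hist_y.append(y.mean())

print("means of mu and nu at the end:", x.mean(), y.mean())  # both settle near zero
```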
Hessian Based Smoothing Splines for Manifold Learning
Kim, Juno
We propose a multidimensional smoothing spline algorithm in the context of manifold learning. We generalize the bending energy penalty of thin-plate splines to a quadratic form on the Sobolev space of a flat manifold, based on the Frobenius norm of the Hessian matrix. This leads to a natural definition of smoothing splines on manifolds, which minimizes squared error while optimizing a global curvature penalty. The existence and uniqueness of the solution are shown by applying the theory of reproducing kernel Hilbert spaces. The minimizer is expressed as a combination of Green's functions for the biharmonic operator and 'linear' functions with everywhere-vanishing Hessian. Furthermore, we utilize the Hessian estimation procedure from the Hessian Eigenmaps algorithm to approximate the spline loss when the true manifold is unknown. This yields a particularly simple quadratic optimization algorithm for smoothing response values without needing to fit the underlying manifold. Analyses of asymptotic error and robustness are given, as well as a discussion of out-of-sample prediction methods and applications.
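A minimal sketch of the resulting quadratic smoothing step: given responses y and a symmetric positive semidefinite curvature penalty matrix H, the smoothed values solve a single linear system. A squared path-graph Laplacian is used below as a stand-in for the Hessian-Eigenmaps-style penalty, which is not reproduced here.

```python
import numpy as np

def smooth_responses(y, H, lam):
    """Minimize ||f - y||^2 + lam * f^T H f over fitted values f.
    Setting the gradient to zero gives the linear system (I + lam*H) f = y."""
    n = len(y)
    return np.linalg.solve(np.eye(n) + lam * H, y)

# Stand-in penalty: squared path-graph Laplacian as a rough second-order surrogate.
n = 50
L = 2 * np.eye(n) - np.eye(n, k=1) - np.eye(n, k=-1)
L[0, 0] = L[-1, -1] = 1
H = L @ L
y = np.sin(np.linspace(0, 3, n)) + 0.3 * np.random.default_rng(0).normal(size=n)
print(smooth_responses(y, H, lam=5.0)[:5])  # first few smoothed response values
```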