

Prototype Language Models

arXiv.org Machine Learning

Knowing which training examples drive outputs is fundamental to auditing, correcting, and understanding language models, yet for modern LLMs this remains expensive, approximate, and largely post hoc. Standard language models generate tokens through a dense network pathway, so the influence of training data is distributed across parameters rather than organized along explicit, traceable components. We introduce a prototype language model architecture, Prototypes for Interpretable Sequence Modeling (PRISM), that forms each prediction via a sparse, non-negative mixture of learned prototypes, trained with clustering objectives that anchor each prototype to a coherent neighborhood of training examples. Across architectures from 130M to 1.6B parameters trained on up to 50B tokens, prototype language models either surpass matched dense baselines in average downstream accuracy or remain within 2.5 percentage points of them. We show that sparse prototype structure localizes curvature in the loss landscape, yielding a more tractable Hessian and enabling training data attribution that is ~500x faster than post hoc baselines at equivalent memory cost. Calibrating linear prototype controllers can improve downstream accuracy by roughly 3 points while tracing those corrections back to training neighborhoods, and targeted prototype suppression can remove model behaviors without fine-tuning or any measurable loss in generation quality.
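A sparse, non-negative mixture of prototypes can be illustrated with a short NumPy sketch. Everything specific here is an assumption rather than PRISM's published design: cosine similarity as the scoring function, top-k selection for sparsity, a softmax over the selected scores for non-negative weights, a per-prototype table of vocabulary logits, and the hypothetical names `prototype_mixture`, `prototypes`, and `prototype_logits`.

```python
import numpy as np

def prototype_mixture(h, prototypes, prototype_logits, k=2):
    """Form next-token logits from a sparse, non-negative mix of prototypes.

    h: (d,) hidden state; prototypes: (P, d) learned prototype vectors;
    prototype_logits: (P, V) vocabulary logits attached to each prototype.
    Returns (logits, weights), where weights has exactly k non-zero entries.
    """
    # Cosine similarity between the hidden state and every prototype.
    norms = np.linalg.norm(prototypes, axis=1) * np.linalg.norm(h) + 1e-8
    sims = prototypes @ h / norms
    # Keep only the k most similar prototypes (sparsity).
    top = np.argsort(sims)[-k:]
    # Softmax over the selected scores: non-negative weights that sum to 1.
    w_top = np.exp(sims[top] - sims[top].max())
    w_top /= w_top.sum()
    weights = np.zeros(len(prototypes))
    weights[top] = w_top
    # The prediction is a convex combination of the selected prototypes' logits.
    return weights @ prototype_logits, weights
```

Because every non-zero entry of `weights` names a prototype, the prototypes behind a given prediction can be read off directly in this toy version.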


Topology of Reasoning: Understanding Large Reasoning Models through Reasoning Graph Properties

Neural Information Processing Systems

Recent large-scale reasoning models have achieved state-of-the-art performance on challenging mathematical benchmarks, yet the internal mechanisms underlying their success remain poorly understood. In this work, we introduce the notion of a reasoning graph, extracted by clustering hidden-state representations at each reasoning step, and systematically analyze three key graph-theoretic properties (cyclicity, diameter, and small-world index) across multiple tasks (GSM8K, MATH500, AIME 2024). Our findings reveal that distilled reasoning models (e.g., DeepSeek-R1-Distill-Qwen-32B) exhibit significantly more recurrent cycles (about 5 per sample), substantially larger graph diameters, and more pronounced small-world characteristics (about 6x) than their base counterparts. Notably, these structural advantages grow with task difficulty and model capacity, with cycle detection peaking at the 14B scale and exploration diameter maximized in the 32B variant, correlating positively with accuracy. Furthermore, we show that supervised fine-tuning on an improved dataset systematically expands reasoning graph diameters in tandem with performance gains, offering concrete guidelines for dataset design aimed at boosting reasoning capabilities.
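Two of the graph properties above can be computed from a sequence of per-step cluster labels. The sketch assumes the hidden states have already been clustered (for example with k-means) into integer ids, and it uses simple stand-in definitions that may differ from the paper's: nodes are clusters, consecutive steps add an undirected edge, every return to a previously visited cluster counts as one cycle, and the diameter is the longest shortest path. `reasoning_graph_stats` is a hypothetical name.

```python
from collections import deque

def reasoning_graph_stats(labels):
    """Return (num_cycles, diameter) for one reasoning trace.

    labels: cluster id of the hidden state at each reasoning step, in order.
    """
    # Undirected graph: nodes are clusters, edges join consecutive steps.
    adj = {c: set() for c in labels}
    for a, b in zip(labels, labels[1:]):
        if a != b:
            adj[a].add(b)
            adj[b].add(a)

    # Each return to an already-visited cluster (not a self-repeat) closes a cycle.
    seen, cycles, prev = set(), 0, None
    for c in labels:
        if c in seen and c != prev:
            cycles += 1
        seen.add(c)
        prev = c

    def eccentricity(src):
        # Breadth-first search gives shortest-path distances from src.
        dist = {src: 0}
        queue = deque([src])
        while queue:
            u = queue.popleft()
            for v in adj[u]:
                if v not in dist:
                    dist[v] = dist[u] + 1
                    queue.append(v)
        return max(dist.values())

    # A trajectory graph is connected, so the diameter is the largest eccentricity.
    diameter = max(eccentricity(c) for c in adj)
    return cycles, diameter
```

For the trace `[0, 1, 2, 1, 3, 0]`, the walk returns to clusters 1 and 0, giving two cycles and a diameter of 2.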


Bigram Subnetworks: Mapping to Next Tokens in Transformer Language Models

Neural Information Processing Systems

In Transformer language models, activation vectors transform from current token embeddings to next token predictions as they pass through the model. To isolate a minimal form of this transformation, we identify language model subnetworks that make bigram predictions: naive next token predictions based only on the current token. We find that bigram subnetworks can be found in fully trained language models of up to 1B parameters, and that these subnetworks are critical for model performance even when they make up less than 0.2% of model parameters. Bigram subnetworks are concentrated in the first Transformer MLP layer, and they overlap significantly with subnetworks trained to optimally prune a given model. Mechanistically, the bigram subnetworks often recreate a pattern from the full models in which the first layer induces a sharp change that aligns activations with next token predictions rather than current token representations. Our results demonstrate that bigram subnetworks comprise a minimal subset of parameters that are both necessary and sufficient for basic next token predictions in language models, and that they help drive the transformation from current to next token activations in the residual stream. These subnetworks can lay a foundation for studying more complex language model circuits by building up from a minimal circuit.
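A bigram prediction in the sense above depends only on the current token. The simplest way to compute one is a count table over adjacent token pairs. The sketch below is that count-based baseline, not the paper's procedure for finding bigram subnetworks inside a trained Transformer, and `bigram_predictor` is a hypothetical name.

```python
from collections import Counter, defaultdict

def bigram_predictor(tokens):
    """Build a next-token predictor that conditions only on the current token."""
    counts = defaultdict(Counter)
    for cur, nxt in zip(tokens, tokens[1:]):
        counts[cur][nxt] += 1

    def predict(cur):
        # Most frequent follower of `cur`, or None if `cur` never preceded a token.
        if cur not in counts:
            return None
        return counts[cur].most_common(1)[0][0]

    return predict
```

For example, `bigram_predictor("the cat sat on the mat the cat ran".split())("the")` returns `"cat"`, since "cat" follows "the" twice and "mat" only once.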


TabSTAR: A Tabular Foundation Model for Tabular Data with Text Fields

Neural Information Processing Systems

While deep learning has achieved remarkable success across many domains, it has historically underperformed on tabular learning tasks, which remain dominated by gradient-boosted decision trees. However, recent advances are paving the way for Tabular Foundation Models, which can leverage real-world knowledge and generalize across diverse datasets, particularly when the data contains free text. Although incorporating language model capabilities into tabular tasks has been explored, most existing methods use static, target-agnostic textual representations, limiting their effectiveness. We introduce TabSTAR: a Tabular Foundation Model with Semantically Target-Aware Representations. TabSTAR is designed to enable transfer learning on tabular data with textual features, with an architecture free of dataset-specific parameters. It unfreezes a pretrained text encoder and takes target tokens as input; these give the model the context needed to learn task-specific embeddings. TabSTAR achieves state-of-the-art performance on both medium- and large-sized datasets across known benchmarks of classification tasks with text features, and its pretraining phase exhibits scaling laws in the number of datasets, offering a pathway to further performance improvements.


Improving the Straight-Through Estimator with Zeroth-Order Information

Neural Information Processing Systems

We study the problem of training neural networks with quantized parameters. Learning low-precision quantized parameters, with gradients computed via the Straight-Through Estimator (STE), can be challenging. While the STE enables back-propagation, which is a first-order method, recent works have explored the use of zeroth-order (ZO) gradient descent for fine-tuning. We note that the STE provides high-quality but biased gradients, whereas ZO gradients are unbiased but can be expensive. We thus propose First-Order-Guided Zeroth-Order Gradient Descent (FOGZO), which reduces STE bias while requiring less computation than ZO methods. Empirically, we show that FOGZO improves the tradeoff between quality and training time in Quantization-Aware Pre-Training. Specifically, versus STE at the same number of iterations, we show a 1-8% accuracy improvement for DeiT-Tiny/Small, a 1-2% accuracy improvement on ResNet-18/50, and a 1-22 point perplexity improvement for LLaMA models with up to 0.3 billion parameters. For the same loss, FOGZO yields a 796x reduction in computation versus n-SPSA for a 2-layer MLP on MNIST.
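The contrast the abstract draws between STE and ZO gradients can be seen on a toy problem. The sketch assumes a uniform quantizer with step 0.5 (Python's `round`, so ties go to even) and a squared-error loss on the quantized weights. `ste_grad` passes the gradient straight through the rounding, while `zo_grad` is a two-point estimate averaged over Gaussian directions, in the spirit of n-SPSA. How FOGZO combines the two is not reproduced here, and all names and constants are hypothetical.

```python
import random

def quantize(w, step=0.5):
    # Round each weight to the nearest multiple of `step` (toy uniform quantizer).
    return [step * round(x / step) for x in w]

def loss(w, target):
    # Squared error of the *quantized* weights against a target vector.
    return sum((q - t) ** 2 for q, t in zip(quantize(w), target))

def ste_grad(w, target):
    # STE: treat quantize() as the identity in the backward pass, so the
    # gradient is dL/dq evaluated at the quantized weights (cheap but biased).
    return [2 * (q - t) for q, t in zip(quantize(w), target)]

def zo_grad(w, target, n_dirs=64, mu=0.1, seed=0):
    # Two-point zeroth-order estimate averaged over n_dirs Gaussian directions.
    # It only queries the true quantized loss and is unbiased for the gradient
    # of a Gaussian-smoothed version of it, but costs 2 * n_dirs evaluations.
    rng = random.Random(seed)
    g = [0.0] * len(w)
    for _ in range(n_dirs):
        u = [rng.gauss(0.0, 1.0) for _ in w]
        plus = loss([x + mu * d for x, d in zip(w, u)], target)
        minus = loss([x - mu * d for x, d in zip(w, u)], target)
        scale = (plus - minus) / (2 * mu * n_dirs)
        g = [gi + scale * d for gi, d in zip(g, u)]
    return g
```

The STE gradient costs one backward pass but ignores the rounding entirely, whereas the ZO estimate sees the actual quantized loss at the price of many loss evaluations.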


Single-Step Operator Learning for Conditioned Time-Series Diffusion Models

Neural Information Processing Systems

Diffusion models have achieved significant success, yet their application to time-series data, particularly with regard to efficient sampling, remains an active area of research. We describe an operator-learning approach for conditioned time-series diffusion models that enables efficient single-step generation by leveraging insights from the frequency-domain characteristics of both the time-series data and the diffusion process itself. The forward diffusion process induces a structured, frequency-dependent smoothing of the data's probability density function. However, this frequency smoothing is related (e.g., via the likelihood function) to easily accessible frequency components of time-series data. This suggests that a module operating in the frequency space of the time series can potentially learn more effectively to reverse the frequency-dependent smoothing of the data distribution induced by the diffusion process. We set up an operator learning task, based on frequency-aware building blocks, that satisfies semigroup properties while exploiting the structure of time-series data. Evaluations on multiple datasets show that our single-step generation proposal achieves forecasting/imputation results comparable (or superior) to many multi-step diffusion schemes while significantly reducing inference costs.
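The frequency-dependent smoothing can be made concrete with a little arithmetic. Assuming additive white Gaussian noise of variance σ² (as in a variance-exploding forward process), every bin of an unnormalized N-point DFT receives expected noise power Nσ², while a smooth time series concentrates its power in a few low-frequency bins, so high-frequency components drop below the noise floor first. The sketch computes this per-bin signal-to-noise ratio with a naive DFT; the function names and the toy signal are illustrative assumptions, not the paper's operator.

```python
import cmath
import math

def dft_power(x):
    # Power |X_k|^2 of each DFT bin k = 0..N-1 (naive O(N^2) DFT).
    n = len(x)
    return [abs(sum(x[t] * cmath.exp(-2j * math.pi * k * t / n) for t in range(n))) ** 2
            for k in range(n)]

def per_bin_snr(x, sigma):
    # White noise with variance sigma^2 has expected power N * sigma^2 in every
    # bin, so the SNR is small wherever the signal's spectrum is small.
    n = len(x)
    return [p / (n * sigma ** 2) for p in dft_power(x)]
```

For a 32-sample signal made of a unit sinusoid at bin 2 plus a 0.1-amplitude sinusoid at bin 10, σ = 0.5 gives an SNR of 32 at bin 2 but only 0.32 at bin 10: the high-frequency detail is already buried while the low-frequency shape is still clearly visible.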


PointMapPolicy: Structured Point Cloud Processing for Multi-Modal Imitation Learning

Neural Information Processing Systems

Robotic manipulation systems benefit from complementary sensing modalities, each of which provides unique environmental information. Point clouds capture detailed geometric structure, while RGB images provide rich semantic context. Current point cloud methods struggle to capture fine-grained detail, especially for complex tasks, while RGB methods lack geometric awareness, which hinders their precision and generalization. We introduce PointMapPolicy, a novel approach that conditions diffusion policies on structured grids of points without downsampling. The resulting data type makes it easier to extract shape and spatial relationships from observations, and can be transformed between reference frames. Moreover, because the points are arranged in a regular grid, established computer vision techniques can be applied directly to 3D data. Using xLSTM as a backbone, our model efficiently fuses the point maps with RGB data for enhanced multi-modal perception. Through extensive experiments on the RoboCasa and CALVIN benchmarks and in real robot evaluations, we demonstrate that our method achieves state-of-the-art performance across diverse manipulation tasks. The overview and demos are available on our project page.
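A point map stores a 3D point at every pixel of an image grid. One common way to build it, assumed here rather than taken from the paper, is to unproject a depth image through a pinhole camera with intrinsics `fx`, `fy`, `cx`, `cy`; a rigid transform p' = Rp + t then moves it to another reference frame without changing the grid. The function names are hypothetical.

```python
def depth_to_point_map(depth, fx, fy, cx, cy):
    """Unproject an H x W depth image (rows v, columns u) into an H x W grid
    of camera-frame (x, y, z) points using the pinhole camera model."""
    return [[((u - cx) * z / fx, (v - cy) * z / fy, z)
             for u, z in enumerate(row)]
            for v, row in enumerate(depth)]

def transform_point_map(points, rotation, translation):
    """Apply p' = R p + t to every point while keeping the H x W grid layout."""
    def apply(p):
        return tuple(sum(rotation[i][j] * p[j] for j in range(3)) + translation[i]
                     for i in range(3))
    return [[apply(p) for p in row] for row in points]
```

Because the output keeps the image's H x W layout, image backbones can consume it directly, which is the property the abstract relies on.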


DMol: A Highly Efficient and Chemical Motif-Preserving Molecule Generation Platform

Neural Information Processing Systems

We introduce a new graph diffusion model for small drug molecule generation which simultaneously offers a 10-fold reduction in the number of diffusion steps when compared to existing methods, preservation of small molecule graph motifs via motif compression, and an average 3% improvement in SMILES validity over the DiGress model across all real-world molecule benchmarking datasets. Furthermore, our approach outperforms the state-of-the-art DeFoG method with respect to motif conservation by roughly 4%, as evidenced by high ChEMBL-likeness, QED and newly introduced shingles distance scores. The key ideas behind the approach are to use a combination of deterministic and random subgraph perturbations, so that the node and edge noise schedules are codependent; to modify the loss function of the training process in order to exploit the deterministic component of the schedule; and to "compress" a collection of highly relevant carbon ring and other motif structures into supernodes in a way that allows for simple subsequent integration into the molecular scaffold.
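Compressing a ring or other motif into a supernode is, at its core, a graph contraction: the motif's atoms collapse into one node, bonds inside the motif disappear, and bonds leaving the motif are reattached to the new node. The sketch shows only that contraction on an adjacency-set graph; DMol's supernode labeling and its later reintegration into the scaffold are not reproduced, and `contract_motif` is a hypothetical name.

```python
def contract_motif(adj, motif, supernode):
    """Contract the atoms in `motif` into a single node named `supernode`.

    adj: dict mapping each atom id to the set of its bonded neighbours (undirected).
    Returns a new adjacency dict in which bonds inside the motif disappear and
    bonds leaving the motif are reattached to the supernode.
    """
    motif = set(motif)
    new_adj = {supernode: set()}
    for node, nbrs in adj.items():
        if node in motif:
            # External bonds of motif atoms now belong to the supernode.
            new_adj[supernode] |= {n for n in nbrs if n not in motif}
        else:
            # Redirect any bond into the motif to the supernode.
            new_adj[node] = {supernode if n in motif else n for n in nbrs}
    return new_adj
```

For a six-membered ring (atoms 0 to 5) with a two-atom substituent chain 0-6-7, contracting the ring into `"R"` leaves the three-node chain R-6-7.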


Deeper with Riemannian Geometry: Overcoming Oversmoothing and Oversquashing for Graph Foundation Models

Neural Information Processing Systems

Message Passing Neural Networks (MPNNs) are the building blocks of graph foundation models, but they fundamentally suffer from oversmoothing and oversquashing. There has recently been a surge of interest in fixing both issues. Existing efforts primarily adopt global approaches, which may be beneficial in some regions but detrimental in others, ultimately leading to suboptimal expressiveness. In this paper, we begin by revisiting oversquashing through a global measure, the spectral gap λ, and prove that increasing λ leads to vanishing gradients with respect to the input features, thereby undermining the effectiveness of message passing. Motivated by these theoretical insights, we propose a local approach that adaptively adjusts message passing based on local structures. To achieve this, we connect local Riemannian geometry with MPNNs and establish a novel nonhomogeneous boundary condition to address both oversquashing and oversmoothing. Building on the Robin condition, we design a GBN network with local bottleneck adjustment, coupled with theoretical guarantees. Extensive experiments on homophilic and heterophilic graphs demonstrate the expressiveness of GBN. Furthermore, GBN does not exhibit performance degradation even when the network depth exceeds 256 layers.
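A common convention, assumed here, takes the spectral gap λ to be the second-smallest eigenvalue of the normalized Laplacian I - D^{-1/2} A D^{-1/2}; by Cheeger's inequality, a small λ signals a bottleneck. The NumPy sketch computes it; `spectral_gap` is a hypothetical name.

```python
import numpy as np

def spectral_gap(adj):
    """Second-smallest eigenvalue of the normalized Laplacian I - D^{-1/2} A D^{-1/2}.

    adj: symmetric (n, n) 0/1 adjacency matrix of a connected graph
    (no isolated nodes, so every degree is positive).
    """
    a = np.asarray(adj, dtype=float)
    d_inv_sqrt = 1.0 / np.sqrt(a.sum(axis=1))
    lap = np.eye(len(a)) - d_inv_sqrt[:, None] * a * d_inv_sqrt[None, :]
    # eigvalsh returns the eigenvalues of a symmetric matrix in ascending order.
    return np.linalg.eigvalsh(lap)[1]
```

For six nodes, the path graph has λ = 1 - cos(π/5) ≈ 0.19, while the complete graph has λ = 6/5 = 1.2.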


How Memory in Optimization Algorithms Implicitly Modifies the Loss

Neural Information Processing Systems

In modern optimization methods used in deep learning, each update depends on the history of previous iterations, often referred to as memory, and this dependence decays fast as the iterates go further into the past. For example, gradient descent with momentum has exponentially decaying memory through exponentially averaged past gradients. We introduce a general technique for identifying a memoryless algorithm that approximates an optimization algorithm with memory. It is obtained by replacing all past iterates in the update by the current one, and then adding a correction term arising from memory (also a function of the current iterate). This correction term can be interpreted as a perturbation of the loss, and the nature of this perturbation can inform how memory implicitly (anti-)regularizes the optimization dynamics. As an application of our theory, we find that Lion does not have the kind of implicit anti-regularization induced by memory that AdamW does, providing a theory-based explanation for Lion's better generalization performance recently documented [13]. Empirical evaluations confirm our theoretical findings.
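For heavy-ball momentum, the recipe above can be carried out by hand to first order in the step size h. The derivation below is an illustrative sketch rather than a result quoted from the paper; it assumes a smooth loss L, a momentum parameter β in [0, 1), and enough iterations that the initialization has been forgotten. Each past iterate is expanded around the current one using m_{n-j} ≈ ∇L(θ_n)/(1-β) at leading order, together with Σ_k β^k = 1/(1-β) and Σ_k kβ^k = β/(1-β)²:

```latex
\begin{aligned}
m_n &= \beta\, m_{n-1} + \nabla L(\theta_n), \qquad \theta_{n+1} = \theta_n - h\, m_n
  \;\Longrightarrow\; m_n = \sum_{k \ge 0} \beta^k\, \nabla L(\theta_{n-k}), \\
\theta_{n-k} &\approx \theta_n + \frac{k h}{1-\beta}\, \nabla L(\theta_n)
  \;\Longrightarrow\; \nabla L(\theta_{n-k}) \approx \nabla L(\theta_n) + \frac{k h}{1-\beta}\, \nabla^2 L(\theta_n)\, \nabla L(\theta_n), \\
m_n &\approx \frac{1}{1-\beta}\, \nabla L(\theta_n) + \frac{h \beta}{(1-\beta)^3}\, \nabla^2 L(\theta_n)\, \nabla L(\theta_n)
  = \frac{1}{1-\beta}\, \nabla \tilde{L}(\theta_n),
  \qquad \tilde{L} = L + \frac{h \beta}{2 (1-\beta)^2}\, \lVert \nabla L \rVert^2 .
\end{aligned}
```

Here the memory correction adds a positive multiple of the squared gradient norm to the loss, so for plain heavy-ball momentum it acts as an implicit regularizer at this order; the abstract reports that the corresponding memory correction anti-regularizes for AdamW but not for Lion.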