GAMformer: In-Context Learning for Generalized Additive Models

Andreas Mueller, Julien Siems, Harsha Nori, David Salinas, Arber Zela, Rich Caruana, Frank Hutter

arXiv.org Machine Learning 

Generalized Additive Models (GAMs) are widely recognized for their ability to create fully interpretable machine learning models for tabular data. Traditionally, training GAMs involves iterative learning algorithms, such as splines, boosted trees, or neural networks, which refine the additive components through repeated error reduction. In this paper, we introduce GAMformer, the first method to leverage in-context learning to estimate the shape functions of a GAM in a single forward pass, a significant departure from the conventional iterative approaches to GAM fitting. Building on previous research applying in-context learning to tabular data, we train GAMformer exclusively on complex synthetic data, yet find that it extrapolates well to real-world data. Our experiments show that GAMformer performs on par with other leading GAMs across various classification benchmarks while generating highly interpretable shape functions.

The growing importance of interpretability in machine learning is evident, especially in areas where transparency, fairness, and accountability are critical (Barocas and Selbst, 2016; Rudin et al., 2022). Interpretable models are essential for building trust between humans and AI systems because they allow users to understand the reasoning behind a model's predictions and decisions (Ribeiro et al., 2016). This is crucial in safety-critical fields such as healthcare, where incorrect or biased decisions can have severe consequences (Caruana et al., 2015). Interpretability is also vital for regulatory compliance in sectors such as finance and hiring, where model outcomes must be explained and justified (Arun et al., 2016; Dattner et al., 2019). Finally, interpretable models help detect and mitigate bias by revealing the factors that influence predictions, supporting fair decisions across different population groups (Mehrabi et al., 2021).
Generalized Additive Models (GAMs) have proven to be a popular choice for interpretable modeling due to their high accuracy and interpretability. In GAMs, the target variable (or, for classification, a link-transformed quantity such as the log-odds) is expressed as a sum of non-linearly transformed features.
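To make the additive structure concrete, the following is a minimal Python sketch of how a fitted GAM produces a prediction for binary classification: g(E[y]) = β0 + f_1(x_1) + ... + f_d(x_d), with g the logit link. It assumes piecewise-constant (binned) shape functions purely for illustration; the helper names `make_binned_shape`, `gam_logit`, `feature_contributions`, and `gam_predict_proba` are hypothetical and are not GAMformer's API. The sketch only evaluates an already-given model; it does not show how the shape functions are estimated.

```python
import math
from bisect import bisect_right
from typing import Callable, List, Sequence

# A shape function maps a single feature value to its additive contribution.
ShapeFn = Callable[[float], float]


def make_binned_shape(edges: Sequence[float], values: Sequence[float]) -> ShapeFn:
    """Hypothetical piecewise-constant shape function.

    `edges` are sorted bin boundaries; `values` holds one contribution per bin,
    so len(values) must equal len(edges) + 1.
    """
    if len(values) != len(edges) + 1:
        raise ValueError("need exactly one value per bin")
    edges, values = list(edges), list(values)

    def f(x: float) -> float:
        # bisect_right gives the index of the bin that x falls into.
        return values[bisect_right(edges, x)]

    return f


def feature_contributions(x: Sequence[float], shapes: Sequence[ShapeFn]) -> List[float]:
    """Per-feature terms f_j(x_j); these are what make a GAM directly readable."""
    if len(x) != len(shapes):
        raise ValueError("one shape function per feature is required")
    return [f(xj) for f, xj in zip(shapes, x)]


def gam_logit(x: Sequence[float], shapes: Sequence[ShapeFn], intercept: float = 0.0) -> float:
    """Additive score: intercept plus the sum of all shape-function terms."""
    return intercept + sum(feature_contributions(x, shapes))


def gam_predict_proba(x: Sequence[float], shapes: Sequence[ShapeFn], intercept: float = 0.0) -> float:
    """Apply the inverse logit link to the additive score."""
    return 1.0 / (1.0 + math.exp(-gam_logit(x, shapes, intercept)))
```

For example, with a binned shape `make_binned_shape([30.0, 50.0], [-0.5, 0.0, 0.7])` for one feature and a linear shape `lambda v: 0.02 * (v - 120.0)` for another, each term of `feature_contributions` can be inspected on its own, which is the property the text refers to when calling GAMs interpretable.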