Lehrach, Wolfgang
Improving Transformer World Models for Data-Efficient RL
Dedieu, Antoine, Ortiz, Joseph, Lou, Xinghua, Wendelken, Carter, Lehrach, Wolfgang, Guntupalli, J Swaroop, Lázaro-Gredilla, Miguel, Murphy, Kevin Patrick
We present an approach to model-based RL that achieves new state-of-the-art performance on the challenging Craftax-classic benchmark, an open-world 2D survival game that requires agents to exhibit a wide range of general abilities, such as strong generalization, deep exploration, and long-term reasoning. With a series of careful design choices aimed at improving sample efficiency, our MBRL algorithm achieves a reward of 67.4% after only 1M environment steps, significantly outperforming DreamerV3 (53.2%) and, for the first time, exceeding human performance (65.0%). Our method starts by constructing a SOTA model-free baseline, using a novel policy architecture that combines CNNs and RNNs. We then add three improvements to the standard MBRL setup: (a) "Dyna with warmup", which trains the policy on both real and imaginary data, (b) a "nearest neighbor tokenizer" on image patches, which improves the scheme used to create the transformer world model (TWM) inputs, and (c) "block teacher forcing", which allows the TWM to reason jointly about the future tokens of the next timestep.
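The nearest neighbor tokenizer in (b) can be illustrated with a short sketch. This is not the authors' implementation: it assumes that image patches arrive already flattened into tuples, that distances are Euclidean, and that a patch is added to the codebook as a new code whenever no existing code lies within a threshold `tau` (otherwise it is mapped to its nearest code). The class and parameter names are hypothetical.

```python
import math
from typing import List, Sequence, Tuple

Patch = Tuple[float, ...]


def l2(a: Sequence[float], b: Sequence[float]) -> float:
    """Euclidean distance between two flattened patches."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


class NearestNeighborTokenizer:
    """Hypothetical sketch of a nearest-neighbor patch tokenizer.

    Assumption: the codebook starts empty; a patch becomes a new code when no
    existing code is within distance `tau`, otherwise it is assigned the index
    of its nearest code. Codes are never modified after being added.
    """

    def __init__(self, tau: float) -> None:
        self.tau = tau
        self.codebook: List[Patch] = []

    def encode_patch(self, patch: Sequence[float]) -> int:
        patch = tuple(patch)
        if self.codebook:
            dists = [l2(patch, code) for code in self.codebook]
            best = min(range(len(dists)), key=dists.__getitem__)
            if dists[best] <= self.tau:
                return best
        # No code is close enough: store the patch itself as a new code.
        self.codebook.append(patch)
        return len(self.codebook) - 1

    def encode_image(self, patches: Sequence[Sequence[float]]) -> List[int]:
        """Map every (flattened) patch of one image to a token id."""
        return [self.encode_patch(p) for p in patches]


if __name__ == "__main__":
    tok = NearestNeighborTokenizer(tau=0.5)
    image = [(0.0, 0.0), (0.1, 0.0), (1.0, 1.0), (0.0, 0.05)]
    print(tok.encode_image(image))  # [0, 0, 1, 0]
    print(len(tok.codebook))        # 2
```

In this sketch, lowering `tau` yields finer-grained tokens at the cost of a larger vocabulary for the world model.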
Diffusion Model Predictive Control
Zhou, Guangyao, Swaminathan, Sivaramakrishnan, Raju, Rajkumar Vasudeva, Guntupalli, J. Swaroop, Lehrach, Wolfgang, Ortiz, Joseph, Dedieu, Antoine, Lázaro-Gredilla, Miguel, Murphy, Kevin
We propose Diffusion Model Predictive Control (D-MPC), a novel MPC approach that learns a multi-step action proposal and a multi-step dynamics model, both using diffusion models, and combines them for use in online MPC. On the popular D4RL benchmark, we show performance that is significantly better than existing model-based offline planning methods using MPC and competitive with state-of-the-art (SOTA) model-based and model-free reinforcement learning methods. We additionally illustrate D-MPC's ability to optimize novel reward functions at run time and adapt to novel dynamics, and highlight its advantages compared to existing diffusion-based planning baselines.
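The online MPC loop into which D-MPC plugs its two learned models can be sketched as a generic sampling-based MPC step. This sketch rests on stated assumptions: the diffusion action proposal and the diffusion dynamics model are replaced by arbitrary Python callables, each candidate plan is scored by summing a reward function over its predicted states, and only the first action of the best plan is executed. The names `mpc_step` and `integrate` and the toy 1-D task are hypothetical.

```python
import random
from typing import Callable, List

Plan = List[float]


def mpc_step(
    state: float,
    sample_plan: Callable[[float, int], Plan],
    predict: Callable[[float, Plan], List[float]],
    reward: Callable[[float], float],
    num_samples: int = 16,
    horizon: int = 5,
) -> float:
    """Choose the next action by sampling-based MPC (num_samples, horizon >= 1).

    `sample_plan` stands in for a learned multi-step action proposal and
    `predict` for a learned multi-step dynamics model. In D-MPC both are
    diffusion models; here they are arbitrary callables (assumption).
    """
    best_plan: Plan = []
    best_return = float("-inf")
    for _ in range(num_samples):
        plan = sample_plan(state, horizon)    # propose `horizon` actions at once
        states = predict(state, plan)         # predict the resulting future states
        ret = sum(reward(s) for s in states)  # reward is supplied at call time
        if ret > best_return:
            best_plan, best_return = plan, ret
    # Execute only the first action; the caller replans at the next step.
    return best_plan[0]


def integrate(state: float, plan: Plan) -> List[float]:
    """Toy 1-D dynamics for illustration: each action is added to the state."""
    out, s = [], state
    for a in plan:
        s += a
        out.append(s)
    return out


if __name__ == "__main__":
    rng = random.Random(0)
    goal = 3.0
    x = 0.0
    for _ in range(10):
        x += mpc_step(
            x,
            lambda s, h: [rng.uniform(-1.0, 1.0) for _ in range(h)],
            integrate,
            lambda s: -abs(s - goal),
        )
    print(f"final state: {x:.2f} (goal {goal})")
```

Because `reward` is an argument rather than something baked into the models, a planner of this shape can be pointed at a new objective at run time without retraining the proposal or dynamics model.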
DMC-VB: A Benchmark for Representation Learning for Control with Visual Distractors
Ortiz, Joseph, Dedieu, Antoine, Lehrach, Wolfgang, Guntupalli, Swaroop, Wendelken, Carter, Humayun, Ahmad, Zhou, Guangyao, Swaminathan, Sivaramakrishnan, Lázaro-Gredilla, Miguel, Murphy, Kevin
Learning from previously collected data via behavioral cloning or offline reinforcement learning (RL) is a powerful recipe for scaling generalist agents by avoiding the need for expensive online learning. Despite strong generalization in some respects, agents are often remarkably brittle to minor visual variations in control-irrelevant factors such as the background or camera viewpoint. In this paper, we present the DeepMind Control Visual Benchmark (DMC-VB), a dataset collected in the DeepMind Control Suite to evaluate the robustness of offline RL agents for solving continuous control tasks from visual input in the presence of visual distractors. In contrast to prior works, our dataset (a) combines locomotion and navigation tasks of varying difficulties, (b) includes static and dynamic visual variations, (c) considers data generated by policies with different skill levels, (d) systematically provides paired state and pixel observations, (e) is an order of magnitude larger, and (f) includes tasks with hidden goals. Accompanying our dataset, we propose three benchmarks to evaluate representation learning methods for pretraining, and carry out experiments on several recently proposed methods. First, we find that pretrained representations do not help policy learning on DMC-VB, and we highlight a large representation gap between policies learned on pixel observations and on states. Second, we demonstrate that when expert data is limited, policy learning can benefit from representations pretrained on (a) suboptimal data and (b) tasks with stochastic hidden goals. Our dataset and benchmark code to train and evaluate agents are available at: https://github.com/google-deepmind/dmc_vision_benchmark.
Learning Cognitive Maps from Transformer Representations for Efficient Planning in Partially Observed Environments
Dedieu, Antoine, Lehrach, Wolfgang, Zhou, Guangyao, George, Dileep, Lázaro-Gredilla, Miguel
Despite their stellar performance on a wide range of tasks, including in-context tasks only revealed during inference, vanilla transformers and variants trained for next-token prediction (a) do not learn an explicit world model of their environment that can be flexibly queried and (b) cannot be used for planning or navigation. In this paper, we consider partially observed environments (POEs), where an agent receives perceptually aliased observations as it navigates, which makes path planning hard. We introduce a transformer with (multiple) discrete bottleneck(s), TDB, whose latent codes learn a compressed representation of the history of observations and actions. After training a TDB to predict the future observation(s) given the history, we extract interpretable cognitive maps of the environment from its active bottleneck(s) indices. These maps are then paired with an external solver to solve (constrained) path planning problems. First, we show that a TDB trained on POEs (a) retains the near-perfect predictive performance of a vanilla transformer or an LSTM while (b) solving shortest path problems exponentially faster. Second, a TDB extracts interpretable representations from text datasets, while reaching higher in-context accuracy than vanilla sequence models. Finally, in new POEs, a TDB (a) reaches near-perfect in-context accuracy, (b) learns accurate in-context cognitive maps, and (c) solves in-context path planning problems.
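One plausible way to turn bottleneck indices into a map and hand it to an external solver is sketched below. It assumes that each timestep yields a single active code, and it uses an action-labeled graph over codes plus breadth-first search as a stand-in for the paper's map extraction and solver. Because codes rather than raw observations are the nodes, two locations that look identical can still be distinct nodes. All names are hypothetical.

```python
from collections import defaultdict, deque
from typing import Dict, Hashable, List, Optional, Sequence, Set, Tuple

# Hypothetical map format: code -> set of (action, next_code) edges.
Graph = Dict[Hashable, Set[Tuple[Hashable, Hashable]]]


def build_cognitive_map(codes: Sequence[Hashable], actions: Sequence[Hashable]) -> Graph:
    """Turn one trajectory of bottleneck codes into an action-labeled graph.

    codes[t] is the (assumed single) active bottleneck index at step t, and
    actions[t] is the action taken between codes[t] and codes[t + 1].
    """
    graph: Graph = defaultdict(set)
    for t, action in enumerate(actions):
        graph[codes[t]].add((action, codes[t + 1]))
    return dict(graph)


def shortest_action_path(
    graph: Graph, start: Hashable, goal: Hashable
) -> Optional[List[Hashable]]:
    """Breadth-first search: fewest actions from start to goal, or None."""
    parent: Dict[Hashable, Tuple[Hashable, Hashable]] = {}
    seen = {start}
    frontier = deque([start])
    while frontier:
        node = frontier.popleft()
        if node == goal:
            path: List[Hashable] = []
            while node != start:  # walk parent links back to the start
                node, action = parent[node]
                path.append(action)
            return path[::-1]
        for action, nxt in graph.get(node, ()):
            if nxt not in seen:
                seen.add(nxt)
                parent[nxt] = (node, action)
                frontier.append(nxt)
    return None


if __name__ == "__main__":
    g = build_cognitive_map([0, 1, 2, 1, 0, 3], ["R", "R", "L", "L", "U"])
    print(shortest_action_path(g, 2, 3))  # ['L', 'L', 'U']
```

Breadth-first search is used here only because it is the simplest exact shortest-path solver for unweighted graphs; a constrained planning problem would call for a different solver over the same graph.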
Query Training: Learning and inference for directed and undirected graphical models
Lázaro-Gredilla, Miguel, Lehrach, Wolfgang, Gothoskar, Nishad, Zhou, Guangyao, Dedieu, Antoine, George, Dileep
Probabilistic graphical models (PGMs) provide a compact representation of knowledge that can be queried in a flexible way: after learning the parameters of a graphical model, new probabilistic queries can be answered at test time without retraining. However, learning undirected graphical models is notoriously hard due to the intractability of the partition function. For directed models, a popular approach is to use variational autoencoders, but there is no systematic way to choose the encoder architecture given the PGM, and the encoder only amortizes inference for a single probabilistic query (i.e., new queries require separate training). We introduce Query Training (QT), a systematic method to turn any PGM structure (directed or not, with or without hidden variables) into a trainable inference network. This single network can approximate any inference query. We demonstrate experimentally that QT can be used to learn a challenging 8-connected grid Markov random field with hidden variables and that it consistently outperforms the state-of-the-art AdVIL when tested on three undirected models across multiple datasets.
Learning higher-order sequential structure with cloned HMMs
Dedieu, Antoine, Gothoskar, Nishad, Swingle, Scott, Lehrach, Wolfgang, Lázaro-Gredilla, Miguel, George, Dileep
Sequence modeling is a fundamental real-world problem with a wide range of applications. Recurrent neural networks (RNNs) are currently preferred in sequence prediction tasks due to their ability to model long-term and variable-order dependencies. However, RNNs have disadvantages in several applications because they cannot natively handle uncertainty and because their internal representations are inscrutable. Probabilistic sequence models like Hidden Markov Models (HMMs) have the advantage of more interpretable representations and the ability to handle uncertainty. Although overcomplete HMMs, which have many more hidden states than observed states, can in theory model long-term temporal dependencies [23], training HMMs is challenging due to credit diffusion [3]. For this reason, simpler but less flexible n-gram models are preferred to HMMs for tasks like language modeling. Tensor decomposition methods [1] have been suggested for learning HMMs in order to overcome the credit diffusion problem, but current methods are not applicable to the overcomplete setting, where the full-rank requirements on the transition and emission matrices are not fulfilled. Recently there has been renewed interest in training overcomplete HMMs for higher-order dependencies, with the expectation that sparsity structures could alleviate the credit diffusion problem [23]. In this paper we demonstrate that a particular sparsity structure on the emission matrix can help HMMs learn higher-order temporal structure using the standard Expectation-Maximization algorithm (Baum-Welch) [26] and its online variants.
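The emission sparsity described above can be illustrated with a forward pass. The sketch assumes a clone structure in which every hidden state emits exactly one observation and each observation owns `n_clones` contiguous hidden states; under that assumption the emission matrix drops out and the forward recursion only sums over the clones of the observed symbol. Baum-Welch training is not shown, the function names are hypothetical, and a practical implementation would rescale or work in log space to avoid underflow on long sequences.

```python
from typing import List, Sequence


def clone_states(obs: int, n_clones: int) -> range:
    """Hidden states allowed to emit `obs` (assumed: a contiguous block of clones)."""
    return range(obs * n_clones, (obs + 1) * n_clones)


def cloned_hmm_likelihood(
    x: Sequence[int],
    pi: Sequence[float],
    T: Sequence[Sequence[float]],
    n_clones: int,
) -> float:
    """P(x) via the forward algorithm for an HMM with clone-structured emissions.

    Each hidden state emits one symbol with probability 1, so the emission
    matrix drops out: at step t only the clones of x[t] can be active.
    """
    prev = clone_states(x[0], n_clones)
    alpha: List[float] = [pi[s] for s in prev]
    for obs in x[1:]:
        cur = clone_states(obs, n_clones)
        # alpha_t(j) = sum_i alpha_{t-1}(i) * T[i][j], restricted to clones of obs.
        alpha = [sum(a * T[i][j] for a, i in zip(alpha, prev)) for j in cur]
        prev = cur
    return sum(alpha)


if __name__ == "__main__":
    # Symbols A=0, B=1, C=2 with 2 clones each: A -> {0, 1}, B -> {2, 3}, C -> {4, 5}.
    # Hand-set transitions: B(2) -> A(0) -> B(2) and C(4) -> A(1) -> C(4);
    # states 3 and 5 are unused and simply self-loop.
    T = [
        [0, 0, 1, 0, 0, 0],
        [0, 0, 0, 0, 1, 0],
        [1, 0, 0, 0, 0, 0],
        [0, 0, 0, 1, 0, 0],
        [0, 1, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 1],
    ]
    pi = [0.0, 0.0, 0.5, 0.0, 0.5, 0.0]
    print(cloned_hmm_likelihood([1, 0, 1], pi, T, 2))  # B A B -> 0.5
    print(cloned_hmm_likelihood([1, 0, 2], pi, T, 2))  # B A C -> 0.0
```

In this toy example the two clones of A remember which symbol preceded A, so B A B and C A C each receive probability 0.5 while B A C and C A B receive 0. A model with a single hidden state per symbol could not express this second-order dependency.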
Generative Shape Models: Joint Text Recognition and Segmentation with Very Little Training Data
Lou, Xinghua, Kansky, Ken, Lehrach, Wolfgang, Laan, CC, Marthi, Bhaskara, Phoenix, D., George, Dileep
We demonstrate that a generative model for object shapes can achieve state-of-the-art results on challenging scene text recognition tasks with orders of magnitude fewer training images than competing discriminative methods require. In addition to transcribing text from challenging images, our method performs fine-grained instance segmentation of characters. We show that our model is more robust to both affine transformations and non-affine deformations than previous approaches.