Lázaro-Gredilla, Miguel
Diffusion Model Predictive Control
Zhou, Guangyao, Swaminathan, Sivaramakrishnan, Raju, Rajkumar Vasudeva, Guntupalli, J. Swaroop, Lehrach, Wolfgang, Ortiz, Joseph, Dedieu, Antoine, Lázaro-Gredilla, Miguel, Murphy, Kevin
We propose Diffusion Model Predictive Control (D-MPC), a novel MPC approach that learns a multi-step action proposal and a multi-step dynamics model, both using diffusion models, and combines them for use in online MPC. On the popular D4RL benchmark, we show performance that is significantly better than existing model-based offline planning methods using MPC and competitive with state-of-the-art (SOTA) model-based and model-free reinforcement learning methods. We additionally illustrate D-MPC's ability to optimize novel reward functions at run time and adapt to novel dynamics, and highlight its advantages compared to existing diffusion-based planning baselines.
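One simple way to combine an action proposal with a dynamics model in an MPC loop is to sample candidate action sequences, score the predicted trajectories, and execute the first action of the best one. The sketch below illustrates only that loop and is not the D-MPC implementation: `propose_actions` and `rollout` are hypothetical stand-ins (Gaussian noise and a toy 1-D integrator) for the learned diffusion models, and the reward is an assumed toy objective.

```python
import numpy as np


def propose_actions(rng, num_samples, horizon):
    # Hypothetical stand-in for a learned multi-step action proposal.
    # Plain Gaussian noise, NOT a diffusion model.
    return rng.normal(size=(num_samples, horizon))


def rollout(state, actions):
    # Hypothetical stand-in for a learned multi-step dynamics model.
    # Toy 1-D integrator s_{t+1} = s_t + a_t, applied to every sample.
    return state + np.cumsum(actions, axis=1)


def reward(states):
    # Assumed toy objective: keep the state close to the origin.
    return -np.sum(states ** 2, axis=1)


def mpc_step(rng, state, num_samples=256, horizon=8):
    """Sample action sequences, score their predicted trajectories,
    and return the first action of the highest-scoring sequence."""
    actions = propose_actions(rng, num_samples, horizon)
    states = rollout(state, actions)
    best = np.argmax(reward(states))
    return actions[best, 0]


def run_episode(seed=0, start=5.0, steps=30):
    """Replan at every step; the true dynamics here match the toy model."""
    rng = np.random.default_rng(seed)
    state = start
    for _ in range(steps):
        state = state + mpc_step(rng, state)
    return state
```

Because the reward enters only at scoring time, swapping in a different `reward` function changes the behavior without retraining either stand-in, which is the property the abstract refers to as optimizing novel rewards at run time.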
DMC-VB: A Benchmark for Representation Learning for Control with Visual Distractors
Ortiz, Joseph, Dedieu, Antoine, Lehrach, Wolfgang, Guntupalli, Swaroop, Wendelken, Carter, Humayun, Ahmad, Zhou, Guangyao, Swaminathan, Sivaramakrishnan, Lázaro-Gredilla, Miguel, Murphy, Kevin
Learning from previously collected data via behavioral cloning or offline reinforcement learning (RL) is a powerful recipe for scaling generalist agents by avoiding the need for expensive online learning. Despite strong generalization in some respects, agents are often remarkably brittle to minor visual variations in control-irrelevant factors such as the background or camera viewpoint. In this paper, we present the DeepMind Control Visual Benchmark (DMC-VB), a dataset collected in the DeepMind Control Suite to evaluate the robustness of offline RL agents for solving continuous control tasks from visual input in the presence of visual distractors. In contrast to prior works, our dataset (a) combines locomotion and navigation tasks of varying difficulties, (b) includes static and dynamic visual variations, (c) considers data generated by policies with different skill levels, (d) systematically returns pairs of state and pixel observation, (e) is an order of magnitude larger, and (f) includes tasks with hidden goals. Accompanying our dataset, we propose three benchmarks to evaluate representation learning methods for pretraining, and carry out experiments on several recently proposed methods. First, we find that pretrained representations do not help policy learning on DMC-VB, and we highlight a large representation gap between policies learned on pixel observations and on states. Second, we demonstrate that when expert data is limited, policy learning can benefit from representations pretrained on (a) suboptimal data, and (b) tasks with stochastic hidden goals. Our dataset and benchmark code to train and evaluate agents are available at: https://github.com/google-deepmind/dmc_vision_benchmark.
What type of inference is planning?
Lázaro-Gredilla, Miguel, Ku, Li Yang, Murphy, Kevin P., George, Dileep
Multiple types of inference are available for probabilistic graphical models, e.g., marginal, maximum-a-posteriori, and even marginal maximum-a-posteriori. Which one do researchers mean when they talk about "planning as inference"? There is no consistency in the literature: different types are used, and their ability to do planning is further entangled with specific approximations or additional constraints. In this work we use the variational framework to show that all commonly used types of inference correspond to different weightings of the entropy terms in the variational problem, and that planning corresponds _exactly_ to a _different_ set of weights. This means that all the tricks of variational inference are readily applicable to planning. We develop an analogue of loopy belief propagation that allows us to perform approximate planning in factored-state Markov decision processes without incurring intractability due to the exponentially large state space. The variational perspective shows that the previous types of inference for planning are only adequate in environments with low stochasticity, and allows us to characterize each type by its own merits, disentangling the type of inference from the additional approximations that its practical use requires. We validate these results empirically on synthetic MDPs and tasks posed in the International Planning Competition.
Learning Cognitive Maps from Transformer Representations for Efficient Planning in Partially Observed Environments
Dedieu, Antoine, Lehrach, Wolfgang, Zhou, Guangyao, George, Dileep, Lázaro-Gredilla, Miguel
Despite their stellar performance on a wide range of tasks, including in-context tasks only revealed during inference, vanilla transformers and variants trained for next-token prediction (a) do not learn an explicit world model of their environment which can be flexibly queried and (b) cannot be used for planning or navigation. In this paper, we consider partially observed environments (POEs), where an agent receives perceptually aliased observations as it navigates, which makes path planning hard. We introduce a transformer with (multiple) discrete bottleneck(s), TDB, whose latent codes learn a compressed representation of the history of observations and actions. After training a TDB to predict the future observation(s) given the history, we extract interpretable cognitive maps of the environment from its active bottleneck(s) indices. These maps are then paired with an external solver to solve (constrained) path planning problems. First, we show that a TDB trained on POEs (a) retains the near-perfect predictive performance of a vanilla transformer or an LSTM while (b) solving shortest path problems exponentially faster. Second, a TDB extracts interpretable representations from text datasets, while reaching higher in-context accuracy than vanilla sequence models. Finally, in new POEs, a TDB (a) reaches near-perfect in-context accuracy, (b) learns accurate in-context cognitive maps, and (c) solves in-context path planning problems.
Graph schemas as abstractions for transfer learning, inference, and planning
Guntupalli, J. Swaroop, Raju, Rajkumar Vasudeva, Kushagra, Shrinu, Wendelken, Carter, Sawyer, Danny, Deshpande, Ishan, Zhou, Guangyao, Lázaro-Gredilla, Miguel, George, Dileep
Transferring latent structure from one environment or problem to another is a mechanism by which humans and animals generalize with very little data. Inspired by cognitive and neurobiological insights, we propose graph schemas as a mechanism of abstraction for transfer learning. Graph schemas start with latent graph learning where perceptually aliased observations are disambiguated in the latent space using contextual information. Latent graph learning is also emerging as a new computational model of the hippocampus to explain map learning and transitive inference. Our insight is that a latent graph can be treated as a flexible template -- a schema -- that models concepts and behaviors, with slots that bind groups of latent nodes to the specific observations or groundings. By treating learned latent graphs (schemas) as prior knowledge, new environments can be quickly learned as compositions of schemas and their newly learned bindings. We evaluate graph schemas on two previously published challenging tasks: the memory & planning game and one-shot StreetLearn, which are designed to test rapid task solving in novel environments. Graph schemas can be learned in far fewer episodes than previous baselines, and can model and plan in a few steps in novel variations of these tasks. We also demonstrate learning, matching, and reusing graph schemas in more challenging 2D and 3D environments with extensive perceptual aliasing and size variations, and show how different schemas can be composed to model larger and more complex environments. To summarize, our main contribution is a unified system, inspired by and grounded in cognitive science, that facilitates rapid transfer learning of new environments using schemas via map-induction and composition that handles perceptual aliasing.
PGMax: Factor Graphs for Discrete Probabilistic Graphical Models and Loopy Belief Propagation in JAX
Zhou, Guangyao, Kumar, Nishanth, Lázaro-Gredilla, Miguel, Kushagra, Shrinu, George, Dileep
PGMax is an open-source Python package for easy specification of discrete Probabilistic Graphical Models (PGMs) as factor graphs, and automatic derivation of efficient and scalable loopy belief propagation (LBP) implementations in JAX. It supports general factor graphs, and can effectively leverage modern accelerators like GPUs for inference. Compared with existing alternatives, PGMax obtains higher-quality inference results with orders-of-magnitude inference speedups. PGMax additionally interacts seamlessly with the rapidly growing JAX ecosystem, opening up exciting new possibilities. Our source code, examples and documentation are available at https://github.com/vicariousinc/PGMax.
Sample-efficient L0-L2 constrained structure learning of sparse Ising models
Dedieu, Antoine, Lázaro-Gredilla, Miguel, George, Dileep
We consider the problem of learning the underlying graph of a sparse Ising model with $p$ nodes from $n$ i.i.d. samples. The most recent and best performing approaches combine an empirical loss (the logistic regression loss or the interaction screening loss) with a regularizer (an L1 penalty or an L1 constraint). This results in a convex problem that can be solved separately for each node of the graph. In this work, we leverage the cardinality constraint L0 norm, which is known to properly induce sparsity, and further combine it with an L2 norm to better model the non-zero coefficients. We show that our proposed estimators achieve an improved sample complexity, both (a) theoretically -- by reaching new state-of-the-art upper bounds for recovery guarantees -- and (b) empirically -- by showing sharper phase transitions between poor and full recovery for graph topologies studied in the literature -- when compared to their L1-based counterparts.
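To make the setup concrete, a node-wise estimator of this kind can be written as follows. This is a sketch under stated assumptions, not necessarily the exact estimator of the paper: it assumes spins in $\{-1,+1\}$, no external field, the logistic regression loss, and both norms imposed as constraints. The symbols $k$ (sparsity budget), $\lambda$ (L2 radius), $x_j^{(i)}$ (value of node $j$ in sample $i$) and $x_{-j}^{(i)}$ (the remaining $p-1$ nodes in sample $i$) are notation introduced here for illustration.

```latex
\hat{\theta}_j \in \arg\min_{\theta \in \mathbb{R}^{p-1}}
  \frac{1}{n} \sum_{i=1}^{n}
  \log\!\left(1 + \exp\!\left(-2\, x_j^{(i)}\, \theta^\top x_{-j}^{(i)}\right)\right)
\quad \text{s.t.} \quad \|\theta\|_0 \le k, \quad \|\theta\|_2 \le \lambda
```

The estimated neighborhood of node $j$ is then the support of $\hat{\theta}_j$. Replacing the two constraints with a single L1 penalty or constraint recovers the convex L1-based counterparts mentioned above.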
Query Training: Learning and inference for directed and undirected graphical models
Lázaro-Gredilla, Miguel, Lehrach, Wolfgang, Gothoskar, Nishad, Zhou, Guangyao, Dedieu, Antoine, George, Dileep
Probabilistic graphical models (PGMs) provide a compact representation of knowledge that can be queried in a flexible way: after learning the parameters of a graphical model, new probabilistic queries can be answered at test time without retraining. However, learning undirected graphical models is notoriously hard due to the intractability of the partition function. For directed models, a popular approach is to use variational autoencoders, but there is no systematic way to choose the encoder architecture given the PGM, and the encoder only amortizes inference for a single probabilistic query (i.e., new queries require separate training). We introduce Query Training (QT), a systematic method to turn any PGM structure (directed or not, with or without hidden variables) into a trainable inference network. This single network can approximate any inference query. We demonstrate experimentally that QT can be used to learn a challenging 8-connected grid Markov random field with hidden variables and that it consistently outperforms the state-of-the-art AdVIL when tested on three undirected models across multiple datasets.
From proprioception to long-horizon planning in novel environments: A hierarchical RL model
Gothoskar, Nishad, Lázaro-Gredilla, Miguel, George, Dileep
For an intelligent agent to flexibly and efficiently operate in complex environments, they must be able to reason at multiple levels of temporal, spatial, and conceptual abstraction. At the lower levels, the agent must interpret their proprioceptive inputs and control their muscles, and at the higher levels, the agent must select goals and plan how they will achieve those goals. It is clear that each of these types of reasoning is amenable to different types of representations, algorithms, and inputs. In this work, we introduce a simple, three-level hierarchical architecture that reflects these distinctions. The low-level controller operates on the continuous proprioceptive inputs, using model-free learning to acquire useful behaviors. These in turn induce a set of mid-level dynamics, which are learned by the mid-level controller and used for model-predictive control, to select a behavior to activate at each timestep. The high-level controller leverages a discrete, graph representation for goal selection and path planning to specify targets for the mid-level controller. We apply our method to a series of navigation tasks in the MuJoCo Ant environment, consistently demonstrating significant improvements in sample-efficiency compared to prior model-free, model-based, and hierarchical RL methods. Finally, as an illustrative example of the advantages of our architecture, we apply our method to a complex maze environment that requires efficient exploration and long-horizon planning.
Learning higher-order sequential structure with cloned HMMs
Dedieu, Antoine, Gothoskar, Nishad, Swingle, Scott, Lehrach, Wolfgang, Lázaro-Gredilla, Miguel, George, Dileep
Sequence modeling is a fundamental real-world problem with a wide range of applications. Recurrent neural networks (RNNs) are currently preferred in sequence prediction tasks due to their ability to model long-term and variable order dependencies. However, RNNs have disadvantages in several applications because of their inability to natively handle uncertainty, and because of their inscrutable internal representations. Probabilistic sequence models like Hidden Markov Models (HMMs) have the advantage of more interpretable representations and the ability to handle uncertainty. Although overcomplete HMMs with many more hidden states compared to the observed states can, in theory, model long-term temporal dependencies [23], training HMMs is challenging due to credit diffusion [3]. For this reason, simpler and inflexible n-gram models are preferred to HMMs for tasks like language modeling. Tensor decomposition methods [1] have been suggested for the learning of HMMs in order to overcome the credit diffusion problem, but current methods are not applicable to the overcomplete setting where the full-rank requirements on the transition and emission matrices are not fulfilled. Recently there has been renewed interest in the topic of training overcomplete HMMs for higher-order dependencies with the expectation that sparsity structures could potentially alleviate the credit diffusion problem [23]. In this paper we demonstrate that a particular sparsity structure on the emission matrix can help HMMs learn higher-order temporal structure using the standard Expectation-Maximization algorithm [26] (Baum-Welch) and its online variants.
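To illustrate how a sparse emission matrix can simplify inference, the sketch below computes the forward log-likelihood for an HMM in which every hidden state emits exactly one fixed symbol, so each symbol owns a block of "clone" states. That specific structure is an assumption made here, since the abstract does not spell it out, and the function name and the contiguous state layout are likewise hypothetical. Under this assumption the forward pass only touches the transition sub-block between the clones of consecutive symbols.

```python
import numpy as np


def forward_loglik(T, pi, n_clones, obs):
    """Log-likelihood of `obs` under an HMM with deterministic emissions.

    T        : (H, H) row-stochastic transition matrix over hidden states.
    pi       : (H,) initial distribution over hidden states.
    n_clones : n_clones[s] is the number of hidden states that emit symbol s.
               Assumed layout: states are grouped contiguously by symbol.
    obs      : sequence of observed symbol indices.
    """
    starts = np.concatenate(([0], np.cumsum(n_clones)))

    def block(symbol):
        # Slice of hidden-state indices that can emit `symbol`.
        return slice(starts[symbol], starts[symbol + 1])

    # Only the clones of the first symbol have non-zero emission probability.
    alpha = pi[block(obs[0])]
    loglik = 0.0
    for prev, cur in zip(obs[:-1], obs[1:]):
        norm = alpha.sum()
        loglik += np.log(norm)  # accumulate the scale to avoid underflow
        # Propagate through the (clones of prev) x (clones of cur) sub-block.
        alpha = (alpha / norm) @ T[block(prev), block(cur)]
    return loglik + np.log(alpha.sum())
```

With one clone per symbol this reduces to a first-order Markov chain over symbols. Adding clones lets the same symbol be represented by different hidden states depending on context, which is how higher-order structure can be captured.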