In-context Reinforcement Learning with Algorithm Distillation
Laskin, Michael, Wang, Luyu, Oh, Junhyuk, Parisotto, Emilio, Spencer, Stephen, Steigerwald, Richie, Strouse, DJ, Hansen, Steven, Filos, Angelos, Brooks, Ethan, Gazeau, Maxime, Sahni, Himanshu, Singh, Satinder, Mnih, Volodymyr
arXiv.org Artificial Intelligence
We propose Algorithm Distillation (AD), a method for distilling reinforcement learning (RL) algorithms into neural networks by modeling their training histories with a causal sequence model. Algorithm Distillation treats learning to reinforcement learn as an across-episode sequential prediction problem. A dataset of learning histories is generated by a source RL algorithm, and then a causal transformer is trained by autoregressively predicting actions given their preceding learning histories as context. Unlike sequential policy prediction architectures that distill post-learning or expert sequences, AD is able to improve its policy entirely in-context without updating its network parameters. We demonstrate that AD can reinforcement learn in-context in a variety of environments with sparse rewards, combinatorial task structure, and pixel-based observations, and find that AD learns a more data-efficient RL algorithm than the one that generated the source data.

Algorithm Distillation (AD) has two steps: (i) a dataset of learning histories is collected from individual single-task RL algorithms solving different tasks; (ii) a causal transformer predicts actions from these histories using across-episodic contexts. Since the RL policy improves throughout the learning histories, by predicting actions accurately, AD learns to output an improved policy relative to the one seen in its context. AD models state-action-reward tokens and does not condition on returns.

Transformers have emerged as powerful neural network architectures for sequence modeling (Vaswani et al., 2017). A striking property of pre-trained transformers is their ability to adapt to downstream tasks through prompt conditioning or in-context learning. After pre-training on large offline datasets, large transformers have been shown to generalize to downstream tasks in text completion (Brown et al., 2020), language understanding (Devlin et al., 2018), and image generation (Yu et al., 2022).
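The two-step recipe described in the abstract can be sketched in data terms: a source learner's episodes are kept in training order, cut into windows that may cross episode boundaries, and encoded as observation, action, and reward tokens with no return-to-go. This is a minimal sketch assuming discrete observations and actions; `Transition`, `flatten_history`, `training_windows`, `to_tokens`, and `action_targets` are hypothetical helpers, not the authors' code, and the stride-1 sliding window and string tags are illustrative choices. A real implementation would embed these tokens and train a causal transformer, typically with a negative log-likelihood loss on the action positions; here we only show which prefix each predicted action is conditioned on.

```python
from typing import List, NamedTuple, Tuple


class Transition(NamedTuple):
    obs: int
    action: int
    reward: float


Token = Tuple[str, float]  # (token kind, value); kinds are "obs", "act", "rew"


def flatten_history(episodes: List[List[Transition]]) -> List[Transition]:
    """Concatenate episodes in the order the source RL algorithm produced them."""
    return [step for episode in episodes for step in episode]


def training_windows(history: List[Transition], context_len: int) -> List[List[Transition]]:
    """Slide a window of context_len transitions over the flattened history.
    Windows may span episode boundaries, giving an across-episode context."""
    return [history[i:i + context_len] for i in range(len(history) - context_len + 1)]


def to_tokens(window: List[Transition]) -> List[Token]:
    """Encode each transition as observation, action, reward tokens (no returns)."""
    tokens: List[Token] = []
    for step in window:
        tokens += [("obs", step.obs), ("act", step.action), ("rew", step.reward)]
    return tokens


def action_targets(window: List[Transition]) -> List[Tuple[List[Token], int]]:
    """Pair each action with its causal prefix: every earlier transition in the
    window plus the current observation token."""
    tokens = to_tokens(window)
    return [(tokens[:3 * i + 1], step.action) for i, step in enumerate(window)]


# Toy learning history from a hypothetical source learner: an early episode
# followed by a later, improved one.
history = flatten_history([
    [Transition(0, 1, 0.0), Transition(1, 0, 0.0)],
    [Transition(0, 1, 0.0), Transition(1, 1, 1.0)],
])
windows = training_windows(history, context_len=3)
pairs = action_targets(windows[0])  # windows[0] crosses the episode boundary
```

Keeping episodes in training order is the point of the design: because later episodes come from a better policy, accurately predicting actions at the end of a window means reproducing the improvement visible in the context.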
Recent work demonstrated that transformers can also learn policies from offline data by treating offline Reinforcement Learning (RL) as a sequential prediction problem. While Chen et al. (2021) showed that transformers can learn single-task policies from offline RL data via imitation learning, subsequent work showed that transformers can also extract multi-task policies in both same-domain (Lee et al., 2022) and cross-domain settings (Reed et al., 2022).
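To make "improving its policy entirely in-context" concrete, the sketch below shows an evaluation loop in which the model is never updated and only its across-episode context changes. Everything here is hypothetical and illustrative: the deterministic bandit (`ARM_REWARDS`, `bandit_step`) and `stand_in_model`, a hand-coded rule standing in for a trained AD transformer. AD itself would produce each action from a learned causal transformer conditioned on the same kind of (observation, action, reward) history.

```python
from collections import deque
from typing import Callable, Deque, List, Tuple

Transition = Tuple[int, int, float]  # (obs, action, reward)

# Hypothetical deterministic 3-armed bandit with one-step episodes.
ARM_REWARDS = [0.1, 0.5, 0.9]


def bandit_step(obs: int, action: int) -> Tuple[int, float]:
    """Return (next_obs, reward); the bandit has a single observation, 0."""
    return 0, ARM_REWARDS[action]


def stand_in_model(context: List[Transition], obs: int) -> int:
    """Hand-coded stand-in for a frozen, trained model (NOT AD itself):
    try each arm not yet seen in context, then pick the best mean reward."""
    totals = [0.0] * len(ARM_REWARDS)
    counts = [0] * len(ARM_REWARDS)
    for _, action, reward in context:
        totals[action] += reward
        counts[action] += 1
    for arm, count in enumerate(counts):
        if count == 0:
            return arm
    return max(range(len(ARM_REWARDS)), key=lambda arm: totals[arm] / counts[arm])


def in_context_rollout(
    model: Callable[[List[Transition], int], int],
    env_step: Callable[[int, int], Tuple[int, float]],
    num_episodes: int,
    episode_len: int,
    context_len: int,
) -> List[float]:
    """Roll out a frozen model; only the context carried across episodes changes."""
    context: Deque[Transition] = deque(maxlen=context_len)  # most recent transitions only
    returns = []
    for _ in range(num_episodes):
        obs, episode_return = 0, 0.0
        for _ in range(episode_len):
            action = model(list(context), obs)  # no parameter update happens here
            next_obs, reward = env_step(obs, action)
            context.append((obs, action, reward))
            episode_return += reward
            obs = next_obs
        returns.append(episode_return)
    return returns


returns = in_context_rollout(stand_in_model, bandit_step,
                             num_episodes=5, episode_len=1, context_len=10)
```

With these settings the per-episode return rises from 0.1 to 0.9 purely because the context accumulates evidence. In this toy, shrinking `context_len` to 2 makes the rule forget the best arm, which illustrates why the context should span enough episodes to reflect improvement.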
Oct-25-2022