jax
Covariance-aware sampling for Diffusion Models
Schioppa, Andrea, Salimans, Tim
We present a covariance-aware sampler that improves the quality of pixel-space Diffusion Model (DM) sampling in the few-step regime. We hypothesize that samplers fail in the few-step regime because they rely solely on the predicted mean of the reverse distribution, whereas our solution explicitly models the reverse-process covariance. Our method combines Tweedie's formula to estimate the covariance with an efficient, structured Fourier-space decomposition of the covariance matrix. Implemented as an extension of DDIM, our method adds only minimal overhead: one extra Jacobian-Vector Product (JVP) per step. We demonstrate that for pixel-based DMs, our method consistently produces superior samples compared to state-of-the-art second-order samplers (Heun, DPM-Solver++) and the recent aDDIM sampler, at an identical number of function evaluations (NFE).
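The link between Tweedie's formula and a single JVP can be illustrated with a minimal sketch. It assumes the forward convention x_t = alpha * x_0 + sigma * eps, under which the second-order Tweedie identity gives Cov[x_0 | x_t] = (sigma^2 / alpha) * d E[x_0 | x_t] / d x_t, so a covariance-vector product costs one JVP of the denoiser. The names `denoiser`, `posterior_cov_times`, `ALPHA`, `SIGMA`, and `DATA_STD` are hypothetical, the denoiser is the closed-form posterior mean for toy Gaussian data rather than a trained network, and the paper's Fourier-space decomposition is not reproduced here.

```python
import jax
import jax.numpy as jnp

# Assumed toy setup: x_t = ALPHA * x_0 + SIGMA * eps, with x_0 ~ N(0, DATA_STD^2 I).
ALPHA, SIGMA, DATA_STD = 0.8, 0.6, 2.0

def denoiser(x_t):
    # Closed-form E[x_0 | x_t] for Gaussian data; stands in for a trained model.
    gain = ALPHA * DATA_STD**2 / (ALPHA**2 * DATA_STD**2 + SIGMA**2)
    return gain * x_t

def posterior_cov_times(x_t, v):
    # One JVP returns both the posterior mean and the Jacobian applied to v.
    x0_hat, jac_v = jax.jvp(denoiser, (x_t,), (v,))
    # Second-order Tweedie: Cov[x_0 | x_t] v = (sigma^2 / alpha) * J v.
    return x0_hat, (SIGMA**2 / ALPHA) * jac_v

x_t = jnp.array([0.5, -1.0, 2.0])
v = jnp.array([1.0, 0.0, -1.0])
x0_hat, cov_v = posterior_cov_times(x_t, v)

# Analytic posterior variance for the Gaussian toy case, used as a check.
expected_var = DATA_STD**2 * SIGMA**2 / (ALPHA**2 * DATA_STD**2 + SIGMA**2)
```

For this toy case the posterior covariance is isotropic, so `cov_v` should match `expected_var * v`; with a real network the same call yields a matrix-free covariance-vector product.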
1663fba7b56da1e96bed6e30546a07b0-Supplemental-Conference.pdf
Thus, the assumption of the policy being conditionally independent of zω given ziα corresponds well to the assumption of agents only using local information (rather than joint information) in MARL to inform their policy/decision-making. Note that we found that cyclically annealing [82] the β term in our variational lower bound from 0 to the values specified in Table 5 helps avoid KL vanishing.
A.2.4 Computational Details
For MARL trajectory data generation, we used an internal CPU cluster for both the 3-agent hillclimbing and 2-agent coordination domains, using TPUs only for the multiagent MuJoCo data generation. Given a characteristic of interest (e.g., the level of dispersion of agents), we define a training set consisting of joint latents zω and class labels y (e.g., classes corresponding to different intervals of team returns). Using these definitions, we can gauge the representational power of zω by learning a mapping g: ˆνc(zω) → y. In practice, g is a simple model (e.g., shallow network or linear projection) so as to gauge the expressivity of the latent space.
MPX: Mixed Precision Training for JAX
Gräfe, Alexander, Trimpe, Sebastian
Mixed-precision training has emerged in recent years as an indispensable tool for enhancing the efficiency of neural network training. Concurrently, JAX has grown in popularity as a versatile machine learning toolbox. However, JAX currently lacks robust support for mixed-precision training. We propose MPX, a mixed-precision training toolbox for JAX that simplifies and accelerates the training of large-scale neural networks while preserving model accuracy. MPX seamlessly integrates with popular toolboxes such as Equinox and Flax, allowing users to convert full-precision pipelines to mixed-precision versions with minimal modifications. By casting both inputs and outputs to half precision, and introducing a dynamic loss-scaling mechanism, MPX alleviates issues like gradient underflow and overflow that commonly arise in half-precision computations. Its design inherits critical features from JAX's type-promotion behavior, ensuring that operations take place in the correct precision and allowing for selective enforcement of full precision where needed (e.g., sums, means, or softmax). MPX further provides wrappers for automatic creation and management of mixed-precision gradients and optimizers, enabling straightforward integration into existing JAX training pipelines. MPX's source code, documentation, and usage examples are available at github.com/Data-Science-in-Mechanical-Engineering/mixed_precision_for_JAX.
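The combination of half-precision casting and dynamic loss scaling described above can be sketched in plain JAX. This is an illustrative sketch only and does not use or reproduce the MPX API: `loss_fn`, `scaled_grads`, and `update_scale` are hypothetical names, the model is a toy linear regressor, and the scale-update rule (double after a finite step, halve after an overflow) is a simplifying assumption; practical schemes usually grow the scale only after many consecutive finite steps.

```python
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    pred = x @ params["w"] + params["b"]
    # Keep the reduction in full precision, as is common for means and sums.
    return jnp.mean((pred.astype(jnp.float32) - y.astype(jnp.float32)) ** 2)

def scaled_grads(params, x, y, loss_scale):
    def scaled_loss(p32):
        # Cast float32 master parameters and inputs to half precision.
        p16 = jax.tree_util.tree_map(lambda a: a.astype(jnp.float16), p32)
        # Scaling the loss keeps small half-precision gradients from underflowing.
        return loss_fn(p16, x.astype(jnp.float16), y) * loss_scale

    grads = jax.grad(scaled_loss)(params)  # float32, same dtype as params
    grads = jax.tree_util.tree_map(lambda g: g / loss_scale, grads)
    leaves = jax.tree_util.tree_leaves(grads)
    finite = jnp.all(jnp.stack([jnp.all(jnp.isfinite(g)) for g in leaves]))
    return grads, finite

def update_scale(loss_scale, finite, growth=2.0, backoff=0.5):
    # Assumed rule: grow after a finite step, back off after an overflow.
    return jnp.where(finite, loss_scale * growth, loss_scale * backoff)

params = {"w": jnp.array([0.5, -0.25, 1.0]), "b": jnp.array(0.0)}
x = jnp.array([[1.0, 2.0, 3.0], [0.5, -1.0, 0.0],
               [2.0, 0.0, 1.0], [-1.0, 1.0, 0.5]])
y = jnp.array([1.0, 0.0, 2.0, -1.0])

grads, finite = scaled_grads(params, x, y, 1024.0)
new_scale = update_scale(1024.0, finite)
```

In a training loop, the optimizer step would be skipped whenever `finite` is false, and `new_scale` would be carried into the next iteration.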
ABMax: A JAX-based Agent-based Modeling Framework
Chaturvedi, Siddharth, El-Gazzar, Ahmed, van Gerven, Marcel
Agent-based modeling (ABM) is a principal approach for studying complex systems. By decomposing a system into simpler, interacting agents, ABM allows researchers to observe the emergence of complex phenomena. High-performance array computing libraries like JAX can help scale such computational models to a large number of agents by using automatic vectorization and just-in-time (JIT) compilation. One caveat of using JAX to achieve such scaling is that the shapes of the arrays used in the computational model must remain fixed throughout the simulation. In ABM, this constrains agent manipulation operations that require flexible data structures. One class of such operations is updating a dynamically selected number of agents by applying distinct changes to them during a simulation. To this end, we introduce ABMax, an ABM framework based on JAX that implements multiple JIT-compilable algorithms to provide this functionality. On the canonical predation model benchmark, ABMax achieves runtime performance comparable to state-of-the-art implementations. Further, we show that this functionality can also be vectorized, making it possible to run many similar agent-based models in parallel. We also present two examples, a traffic-flow model and a financial market model, to illustrate use cases of ABMax.
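The fixed-shape constraint and one common workaround can be sketched as follows. This is not ABMax's algorithm: it is a minimal mask-based illustration in which agents live in a fixed-capacity array, a boolean mask selects a data-dependent number of them, and each selected agent receives its own change. The names `update_selected`, `energy`, `active`, and `deltas` are hypothetical.

```python
import jax
import jax.numpy as jnp

@jax.jit
def update_selected(energy, active, select_mask, deltas):
    # Only agents that are both alive and selected are changed; array shapes
    # never change, so the function compiles once under jit.
    apply = active & select_mask
    return jnp.where(apply, energy + deltas, energy)

energy = jnp.array([5.0, 3.0, 0.0, 8.0])       # fixed capacity of 4 agent slots
active = jnp.array([True, True, False, True])  # slot 2 is an empty placeholder
select = energy > 4.0                          # number selected depends on data
deltas = jnp.array([-1.0, -2.0, -3.0, -4.0])   # a distinct change per agent

new_energy = update_selected(energy, active, select, deltas)

# The same update vectorizes over a batch of independent models with vmap.
batched_update = jax.vmap(update_selected)
batch_energy = batched_update(
    jnp.stack([energy, energy]),
    jnp.stack([active, active]),
    jnp.stack([select, select]),
    jnp.stack([deltas, deltas]),
)
```

The trade-off of masking is that every slot is touched on every update, including inactive ones; the cost is paid in exchange for static shapes that JIT compilation and vectorization require.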