
Collaborating Authors

Kushagra, Shrinu


Model Predictive Simulation Using Structured Graphical Models and Transformers

arXiv.org Artificial Intelligence

We propose an approach to simulating trajectories of multiple interacting agents (road users) based on transformers and probabilistic graphical models (PGMs), and apply it to the Waymo SimAgents challenge. The transformer baseline is based on the MTR model, which predicts multiple future trajectories conditioned on past trajectories and static road-layout features. We then improve these generated trajectories using a PGM whose factors encode prior knowledge, such as a preference for smooth trajectories and the avoidance of collisions with static obstacles and other moving agents. We perform (approximate) MAP inference in this PGM using the Gauss-Newton method. Finally, we sample $K=32$ trajectories for each of the $N \sim 100$ agents for the next $T=8\Delta$ time steps, where $\Delta=10$ is the sampling rate (steps per second), i.e., an 8-second horizon. Following the Model Predictive Control (MPC) paradigm, we return only the first element of each forecasted trajectory at every step and then replan, so that the simulation constantly adapts to its changing environment. We therefore call our approach "Model Predictive Simulation" (MPS). We show that MPS improves upon the MTR baseline, especially on safety-critical metrics such as collision rate. Furthermore, our approach is compatible with any underlying forecasting model and requires no extra training, so we believe it is a valuable contribution to the community.
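The receding-horizon idea in the abstract above (plan $T$ steps, execute only the first, replan) can be illustrated with a minimal sketch. The agents here are hypothetical 1-D points; `toy_forecast` stands in for a forecaster such as MTR, and the optional `refine` hook stands in for the PGM's Gauss-Newton MAP step. Neither is the authors' implementation.

```python
from typing import Callable, List, Sequence

State = float  # toy 1-D position; real agents would carry (x, y, heading, ...)


def simulate_mps(
    initial_states: Sequence[State],
    forecast: Callable[[List[State]], List[List[State]]],
    refine: Callable[[List[List[State]]], List[List[State]]] = lambda plans: plans,
    n_steps: int = 10,
) -> List[List[State]]:
    """Receding-horizon rollout: forecast a multi-step plan per agent,
    optionally refine the plans jointly, execute only the first step, replan."""
    states = list(initial_states)
    history = [list(states)]
    for _ in range(n_steps):
        plans = refine(forecast(states))      # one T-step plan per agent
        states = [plan[0] for plan in plans]  # commit to the first step only
        history.append(list(states))
    return history


HORIZON = 8 * 10  # T = 8 * Delta steps with Delta = 10 steps per second (8 s)


def toy_forecast(states: List[State]) -> List[List[State]]:
    """Hypothetical forecaster: each agent halves its distance to 0 per step."""
    return [[x * 0.5 ** t for t in range(1, HORIZON + 1)] for x in states]
```

For example, `simulate_mps([8.0, -4.0], toy_forecast, n_steps=3)` rolls both agents forward three steps, replanning the full 80-step horizon each time while only ever committing to one step.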


Graph schemas as abstractions for transfer learning, inference, and planning

arXiv.org Artificial Intelligence

Transferring latent structure from one environment or problem to another is a mechanism by which humans and animals generalize from very little data. Inspired by cognitive and neurobiological insights, we propose graph schemas as a mechanism of abstraction for transfer learning. Graph schemas start with latent graph learning, in which perceptually aliased observations are disambiguated in the latent space using contextual information. Latent graph learning is also emerging as a new computational model of the hippocampus that explains map learning and transitive inference. Our insight is that a latent graph can be treated as a flexible template (a schema) that models concepts and behaviors, with slots that bind groups of latent nodes to specific observations or groundings. By treating learned latent graphs (schemas) as prior knowledge, new environments can be learned quickly as compositions of schemas and their newly learned bindings. We evaluate graph schemas on two previously published challenging tasks, the memory & planning game and one-shot StreetLearn, which are designed to test rapid task solving in novel environments. Graph schemas can be learned in far fewer episodes than previous baselines, and can model and plan in a few steps in novel variations of these tasks. We also demonstrate learning, matching, and reusing graph schemas in more challenging 2D and 3D environments with extensive perceptual aliasing and size variations, and show how different schemas can be composed to model larger and more complex environments. To summarize, our main contribution is a unified system, inspired by and grounded in cognitive science, that facilitates rapid transfer learning in new environments using schemas via map induction and composition, and that handles perceptual aliasing.
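As a toy illustration of reusing a latent graph with new bindings, the sketch below treats a hypothetical 2x2 room as a schema: latent nodes joined by action-labelled edges, plus a binding that maps each node to an observation. Two nodes may share an observation (perceptual aliasing), yet planning happens over latent nodes, so the plan stays well defined. Swapping in a new binding models a new environment with the same structure. This is an assumed simplification, not the authors' learned-schema method.

```python
from collections import deque
from typing import Dict, Hashable, List, Optional, Tuple

Schema = Dict[int, Dict[str, int]]  # latent node -> {action: next latent node}

# Hypothetical schema: a 2x2 room.
ROOM_2X2: Schema = {
    0: {"right": 1, "down": 2},
    1: {"left": 0, "down": 3},
    2: {"up": 0, "right": 3},
    3: {"up": 1, "left": 2},
}


def plan(schema: Schema, start: int, goal: int) -> Optional[List[str]]:
    """Breadth-first search over latent nodes; returns a shortest action list."""
    parent: Dict[int, Optional[Tuple[int, str]]] = {start: None}
    frontier = deque([start])
    while frontier:
        node = frontier.popleft()
        if node == goal:
            break
        for action, nxt in schema[node].items():
            if nxt not in parent:
                parent[nxt] = (node, action)
                frontier.append(nxt)
    if goal not in parent:
        return None
    actions: List[str] = []
    node = goal
    while parent[node] is not None:
        node, action = parent[node]
        actions.append(action)
    return actions[::-1]


def plan_to_observation(schema: Schema, binding: Dict[int, Hashable],
                        start: int, target: Hashable) -> Optional[List[str]]:
    """Plan to any latent node bound to `target`; several nodes may be bound to
    the same observation (aliasing), so keep the shortest plan."""
    plans = [plan(schema, start, n) for n, obs in binding.items() if obs == target]
    plans = [p for p in plans if p is not None]
    return min(plans, key=len) if plans else None
```

The same `ROOM_2X2` schema can be paired with `{0: "red", 1: "blue", 2: "red", 3: "green"}` in one environment and a different binding in another; only the binding changes, not the graph.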


PGMax: Factor Graphs for Discrete Probabilistic Graphical Models and Loopy Belief Propagation in JAX

arXiv.org Machine Learning

PGMax is an open-source Python package for easy specification of discrete Probabilistic Graphical Models (PGMs) as factor graphs, and for automatic derivation of efficient and scalable loopy belief propagation (LBP) implementations in JAX. It supports general factor graphs and can effectively leverage modern accelerators like GPUs for inference. Compared with existing alternatives, PGMax obtains higher-quality inference results with orders-of-magnitude inference speedups. PGMax additionally interacts seamlessly with the rapidly growing JAX ecosystem, opening up exciting new possibilities. Our source code, examples and documentation are available at https://github.com/vicariousinc/PGMax.
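To make the inference algorithm concrete, here is a minimal pure-Python sketch of sum-product loopy belief propagation on a pairwise discrete model with a synchronous (flooding) schedule. It does not use or reproduce PGMax's API, which is vectorized in JAX; the dictionary-based model format below is an assumption chosen for readability. A brute-force marginal computation is included as a check, since BP is exact on trees.

```python
import itertools
from typing import Dict, List, Tuple

Var = int
Pairwise = Dict[Tuple[Var, Var], List[List[float]]]  # (u, v) -> table[x_u][x_v]


def loopy_bp(n_states: Dict[Var, int], unary: Dict[Var, List[float]],
             pairwise: Pairwise, n_iters: int = 50) -> Dict[Var, List[float]]:
    """Sum-product loopy BP with a synchronous (flooding) schedule."""
    neighbors: Dict[Var, List[Var]] = {v: [] for v in n_states}
    for u, v in pairwise:
        neighbors[u].append(v)
        neighbors[v].append(u)

    def psi(u: Var, v: Var, xu: int, xv: int) -> float:
        # Each edge table is stored once; look it up in either orientation.
        return pairwise[(u, v)][xu][xv] if (u, v) in pairwise else pairwise[(v, u)][xv][xu]

    # msgs[(u, v)][x_v]: message from u to v, initialised to uniform.
    msgs = {(u, v): [1.0] * n_states[v] for u in n_states for v in neighbors[u]}
    for _ in range(n_iters):
        new_msgs = {}
        for u, v in msgs:
            out = []
            for xv in range(n_states[v]):
                total = 0.0
                for xu in range(n_states[u]):
                    term = unary[u][xu] * psi(u, v, xu, xv)
                    for w in neighbors[u]:
                        if w != v:  # all incoming messages except the one from v
                            term *= msgs[(w, u)][xu]
                    total += term
                out.append(total)
            z = sum(out)
            new_msgs[(u, v)] = [m / z for m in out]  # normalise for stability
        msgs = new_msgs

    beliefs = {}
    for v in n_states:
        b = list(unary[v])
        for w in neighbors[v]:
            b = [bi * msgs[(w, v)][x] for x, bi in enumerate(b)]
        z = sum(b)
        beliefs[v] = [bi / z for bi in b]
    return beliefs


def exact_marginals(n_states, unary, pairwise):
    """Brute-force enumeration; only feasible for tiny models, used as a check."""
    vs = sorted(n_states)
    marg = {v: [0.0] * n_states[v] for v in vs}
    for assignment in itertools.product(*(range(n_states[v]) for v in vs)):
        x = dict(zip(vs, assignment))
        p = 1.0
        for v in vs:
            p *= unary[v][x[v]]
        for (u, v), table in pairwise.items():
            p *= table[x[u]][x[v]]
        for v in vs:
            marg[v][x[v]] += p
    return {v: [m / sum(marg[v]) for m in marg[v]] for v in vs}
```

On a chain the two functions agree; adding an edge that closes a cycle makes BP approximate, which is the regime loopy BP is designed for.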


Better Together: Resnet-50 accuracy with $13 \times$ fewer parameters and at $3\times$ speed

arXiv.org Machine Learning

Recent research on compressing deep neural networks has focused on reducing the number of parameters. Smaller networks are easier to export and deploy on edge devices. We introduce Adjoined networks as a training approach that can regularize and compress any CNN-based neural architecture. Our one-shot learning paradigm trains both the original and the smaller network together, and the parameters of the smaller network are shared across both architectures. We prove strong theoretical guarantees on the regularization behavior of the adjoint training paradigm, and complement this theoretical analysis with an extensive empirical evaluation of both the compression and regularization behavior of adjoined networks. For ResNet-50 trained adjointly on ImageNet, we achieve a $13.7\times$ reduction in the number of parameters and a $3\times$ improvement in inference time without any significant drop in accuracy. (For the size comparison, we ignore the parameters in the last linear layer, as they vary by dataset and are typically dropped during fine-tuning; including them, the reductions are $11.5\times$ and $95\times$ for ImageNet and CIFAR-100, respectively.) For the same architecture on CIFAR-100, we achieve a $99.7\times$ reduction in the number of parameters and a $5\times$ improvement in inference time. On both datasets, the original network trained in the adjoint fashion gains about $3\%$ in top-1 accuracy compared with the same network trained in the standard fashion.
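To see how parameter sharing can yield large reductions, consider a sketch that assumes (hypothetically; the abstract does not specify the construction) that the smaller network keeps only the first $1/\alpha$ of each layer's channels. A $k \times k$ convolution has $k^2 c_{\text{in}} c_{\text{out}}$ weights, so slicing both channel dimensions shrinks a layer by $\alpha^2$, except the first layer, whose image input channels cannot be sliced. The three-layer stack below is illustrative, not ResNet-50.

```python
from typing import List, Tuple

Layer = Tuple[int, int, int]  # (kernel_size, in_channels, out_channels)

# Hypothetical three-layer CNN stem, for illustration only.
LAYERS: List[Layer] = [(3, 3, 64), (3, 64, 128), (3, 128, 256)]


def conv_params(k: int, c_in: int, c_out: int) -> int:
    """Weights of a k x k convolution (bias omitted)."""
    return k * k * c_in * c_out


def count_params(layers: List[Layer], alpha: int = 1) -> int:
    """Parameter count when every layer keeps only 1/alpha of its channels
    (assumed channel slicing). The first layer's input channels are the image
    channels and are left unsliced."""
    total = 0
    for i, (k, c_in, c_out) in enumerate(layers):
        c_in_small = c_in if i == 0 else c_in // alpha
        total += conv_params(k, c_in_small, c_out // alpha)
    return total
```

With $\alpha = 4$, every layer but the first shrinks by $16\times$, and the unsliced first layer pulls the overall ratio slightly below that.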


On sampling from data with duplicate records

arXiv.org Machine Learning

Data deduplication is the task of detecting records in a database that correspond to the same real-world entity. Our goal is to develop a procedure that samples uniformly from the set of entities present in a database that contains duplicates. We accomplish this with a two-stage process. In the first stage, we estimate the frequencies of all the entities in the database. In the second stage, we use rejection sampling to obtain an (approximately) uniform sample from the set of entities. However, efficiently estimating the frequency of every entity is non-trivial and not attainable in the general case. Hence, we consider various natural properties of the data under which such frequency estimation (and consequently uniform sampling) is possible. Under each of these assumptions, we provide sampling algorithms and prove bounds on the complexity (both statistical and computational) of our approach. We complement our study with extensive experiments on both real and synthetic datasets.
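The second stage can be sketched as follows, assuming the entity frequencies are already known exactly (the hard part the abstract addresses is *estimating* them). A record is drawn uniformly, so entity $e$ is proposed with probability $f_e/n$; accepting it with probability $f_{\min}/f_e$ makes every entity equally likely to be returned. The `(record_id, entity_id)` record format is an assumption for illustration.

```python
import random
from collections import Counter
from typing import Dict, Hashable, List, Tuple

Record = Tuple[int, Hashable]  # (record_id, entity_id); entity ids assumed known


def sample_entity(records: List[Record], freq: Dict[Hashable, int],
                  rng: random.Random) -> Hashable:
    """Draw one entity uniformly at random via rejection sampling."""
    min_freq = min(freq.values())
    while True:
        _, entity = rng.choice(records)                 # proposal: uniform over records
        if rng.random() < min_freq / freq[entity]:      # correct for duplicates
            return entity


def entity_frequencies(records: List[Record]) -> Dict[Hashable, int]:
    """Exact frequencies; in the paper's setting these must be estimated."""
    return dict(Counter(entity for _, entity in records))
```

Using $f_{\min}$ rather than $1$ in the numerator keeps the scheme uniform while raising the acceptance rate when no entity appears only once.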


Record fusion: A learning approach

arXiv.org Machine Learning

Record fusion is the task of aggregating multiple records that correspond to the same real-world entity in a database. We can view record fusion as a machine learning problem where the goal is to predict the "correct" value of each attribute for each entity. Given a database, we use a combination of attribute-level, record-level, and database-level signals to construct a feature vector for each cell (i.e., each (row, column) pair) of that database. We use this feature vector along with the ground-truth information to learn a classifier for each attribute of the database. Our learning algorithm uses a novel stagewise additive model. At each stage, we construct a new feature vector by combining part of the original feature vector with features computed from the predictions of the previous stage. We then learn a softmax classifier over the new feature space. This greedy stagewise approach can be viewed as a deep model in which each stage adds more complicated non-linear transformations of the original feature vector. We show that our approach fuses records with an average precision of ~98% when source information of records is available, and ~94% without source information, across a diverse array of real-world datasets. We compare our approach to a comprehensive collection of data fusion and entity consolidation methods considered in the literature, and show that it achieves an average precision improvement of ~20%/~45% with/without source information, respectively.
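A minimal sketch of the stagewise scheme described above: each stage fits a softmax classifier, and the next stage's features are a fixed subset of the original features (indices `keep`, an assumed parameter) concatenated with the previous stage's predicted class probabilities. The classifier here is plain multinomial logistic regression trained by full-batch gradient descent; the paper's exact features, training procedure and stopping rule are not reproduced.

```python
import math
from typing import List, Sequence, Tuple

Model = Tuple[List[List[float]], List[float]]  # (weights per class, biases)


def softmax(z: Sequence[float]) -> List[float]:
    m = max(z)  # subtract the max for numerical stability
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]


def predict_proba(model: Model, x: Sequence[float]) -> List[float]:
    W, b = model
    return softmax([sum(w * xi for w, xi in zip(W[k], x)) + b[k] for k in range(len(W))])


def train_softmax(X: List[List[float]], y: List[int], n_classes: int,
                  lr: float = 0.5, epochs: int = 300) -> Model:
    """Multinomial logistic regression fitted by full-batch gradient descent."""
    d, n = len(X[0]), len(X)
    W = [[0.0] * d for _ in range(n_classes)]
    b = [0.0] * n_classes
    for _ in range(epochs):
        gW = [[0.0] * d for _ in range(n_classes)]
        gb = [0.0] * n_classes
        for x, t in zip(X, y):
            p = predict_proba((W, b), x)
            for k in range(n_classes):
                g = p[k] - (1.0 if k == t else 0.0)  # d(cross-entropy)/d(logit_k)
                gb[k] += g
                for j in range(d):
                    gW[k][j] += g * x[j]
        for k in range(n_classes):
            b[k] -= lr * gb[k] / n
            for j in range(d):
                W[k][j] -= lr * gW[k][j] / n
    return W, b


def stage_features(x: Sequence[float], keep: Sequence[int], probs: List[float]) -> List[float]:
    """Part of the original feature vector plus the previous stage's predictions."""
    return [x[j] for j in keep] + probs


def fit_stagewise(X, y, n_classes: int, keep, n_stages: int = 3) -> List[Model]:
    models: List[Model] = []
    feats = [list(x) for x in X]  # stage 0 sees the full original features
    for _ in range(n_stages):
        model = train_softmax(feats, y, n_classes)
        models.append(model)
        feats = [stage_features(x, keep, predict_proba(model, f)) for x, f in zip(X, feats)]
    return models


def predict_stagewise(models: List[Model], x: Sequence[float], keep) -> int:
    f = list(x)
    for model in models:
        probs = predict_proba(model, f)
        f = stage_features(x, keep, probs)
    return max(range(len(probs)), key=probs.__getitem__)
```

Each stage is a shallow, convex problem, yet composing stages feeds non-linear functions of the original features into later classifiers, which is the sense in which the abstract calls the model "deep".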


Semi-supervised clustering for de-duplication

arXiv.org Machine Learning

Data de-duplication is the task of detecting multiple records that correspond to the same real-world entity in a database. In this work, we view de-duplication as a clustering problem where the goal is to put records corresponding to the same physical entity in the same cluster and records corresponding to different physical entities in different clusters. We introduce a framework which we call promise correlation clustering. Given a complete graph $G$ with edges labelled $0$ and $1$, the goal is to find a clustering that minimizes the number of $0$ edges within a cluster plus the number of $1$ edges across different clusters (the correlation loss). The optimal clustering can also be viewed as a complete graph $G^*$ in which edges between points in the same cluster are labelled $1$ and all other edges are labelled $0$, so the correlation loss of a clustering is exactly the number of edges on which its graph differs from $G$. Under the promise that the edge difference between $G$ and $G^*$ is "small", we prove that finding the optimal clustering (or $G^*$) is still NP-hard. [Ashtiani et al., 2016] introduced the framework of semi-supervised clustering, in which the learning algorithm has access to an oracle that answers whether two points belong to the same cluster or to different clusters. We further prove that even with access to such a same-cluster oracle, the promise version remains NP-hard as long as the number of queries to the oracle is not too large ($o(n)$, where $n$ is the number of vertices). Given these negative results, we consider a restricted version of correlation clustering. As before, the goal is to find a clustering that minimizes the correlation loss; however, we restrict ourselves to a given class $\mathcal F$ of clusterings. We offer a semi-supervised algorithmic approach that solves this restricted variant with success guarantees.
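The correlation loss is easy to evaluate for a given clustering; what is hard is minimizing it. The sketch below computes the loss (reading label $1$ as "same entity" and $0$ as "different") and finds the optimum by exhaustive search over all set partitions, whose count (the Bell number) grows super-exponentially, so this only works for a handful of points. It is an illustrative baseline, not the paper's semi-supervised algorithm.

```python
from typing import Dict, Iterator, List, Tuple

Edge = Tuple[int, int]  # (i, j) with i < j


def correlation_loss(labels: Dict[Edge, int], clustering: List[List[int]]) -> int:
    """Number of 0-edges inside a cluster plus 1-edges across clusters."""
    cluster_of = {p: c for c, block in enumerate(clustering) for p in block}
    loss = 0
    for (i, j), label in labels.items():
        same = cluster_of[i] == cluster_of[j]
        loss += int((same and label == 0) or (not same and label == 1))
    return loss


def set_partitions(items: List[int]) -> Iterator[List[List[int]]]:
    """Enumerate every clustering (set partition) of `items`."""
    if not items:
        yield []
        return
    first, rest = items[0], items[1:]
    for part in set_partitions(rest):
        yield [[first]] + part                       # `first` in its own cluster
        for i in range(len(part)):                   # or joined to an existing one
            yield part[:i] + [[first] + part[i]] + part[i + 1:]


def best_clustering(n: int, labels: Dict[Edge, int]) -> Tuple[int, List[List[int]]]:
    """Exhaustive minimization of the correlation loss; viable only for tiny n."""
    return min(((correlation_loss(labels, c), c) for c in set_partitions(list(range(n)))),
               key=lambda pair: pair[0])
```

For four points where only edges $(0,1)$, $(1,2)$ and $(2,3)$ are labelled $1$, the labels are inconsistent (no clustering has zero loss), and the search returns $\{0,1\},\{2,3\}$ with loss $1$.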


Clustering with Same-Cluster Queries

Neural Information Processing Systems

We propose a Semi-Supervised Active Clustering (SSAC) framework, in which the learner is allowed to interact with a domain expert by asking whether two given instances belong to the same cluster. We study the query and computational complexity of clustering in this framework. We consider a setting where the expert conforms to a center-based clustering with a notion of margin. We show that there is a trade-off between computational complexity and query complexity: we prove that for $k$-means clustering (i.e., when the expert conforms to a solution of $k$-means), access to relatively few such queries allows efficient solutions to otherwise NP-hard problems. In particular, we provide a probabilistic polynomial-time (BPP) algorithm for clustering in this setting that asks $O(k^2\log k + k\log n)$ same-cluster queries and runs in time $O(kn\log n)$, where $k$ is the number of clusters and $n$ is the number of instances. The success of the algorithm is guaranteed for data satisfying the margin condition under which, without queries, we show that the problem is NP-hard. We also prove a lower bound on the number of queries needed for a computationally efficient clustering algorithm in this setting.
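The source of the $k\log n$ term can be illustrated with a simplified sketch: estimate a cluster center from a queried random sample, sort the remaining points by distance to that estimate, and, assuming a margin condition makes each cluster a prefix of that order, binary-search the boundary with $O(\log n)$ same-cluster queries. The sample size, the grouping step and the `same_cluster` oracle interface are illustrative assumptions; the sketch does not reproduce the paper's constants or guarantees.

```python
import math
import random
from typing import Callable, List, Sequence

Point = Sequence[float]


def ssac_sketch(points: List[Point], same_cluster: Callable[[int, int], bool],
                k: int, sample_size: int, rng: random.Random) -> List[List[int]]:
    """Simplified two-phase loop: estimate a center from queried samples,
    then binary-search the cluster boundary with same-cluster queries."""
    remaining = list(range(len(points)))
    clusters: List[List[int]] = []
    while remaining and len(clusters) < k:
        # Phase 1: group a random sample by querying against group representatives.
        groups: List[List[int]] = []
        for idx in rng.sample(remaining, min(sample_size, len(remaining))):
            for group in groups:
                if same_cluster(idx, group[0]):
                    group.append(idx)
                    break
            else:
                groups.append([idx])
        largest = max(groups, key=len)
        dim = len(points[0])
        center = [sum(points[i][d] for i in largest) / len(largest) for d in range(dim)]
        # Phase 2: under the margin assumption the cluster is a prefix of the
        # remaining points sorted by distance to the estimated center.
        order = sorted(remaining, key=lambda i: math.dist(points[i], center))
        lo, hi = 0, len(order)
        while lo < hi:  # find the first index that is NOT in the cluster
            mid = (lo + hi) // 2
            if same_cluster(order[mid], largest[0]):
                lo = mid + 1
            else:
                hi = mid
        cluster = order[:lo]
        clusters.append(cluster)
        taken = set(cluster)
        remaining = [i for i in remaining if i not in taken]
    return clusters
```

In a simulation, the oracle can be built from ground-truth labels, e.g. `lambda i, j: truth[i] == truth[j]`; with well-separated, tight clusters the prefix assumption holds and each cluster is recovered exactly.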


Clustering with Same-Cluster Queries

arXiv.org Machine Learning

We propose a Semi-Supervised Active Clustering (SSAC) framework, in which the learner is allowed to interact with a domain expert by asking whether two given instances belong to the same cluster. We study the query and computational complexity of clustering in this framework. We consider a setting where the expert conforms to a center-based clustering with a notion of margin. We show that there is a trade-off between computational complexity and query complexity: we prove that for $k$-means clustering (i.e., when the expert conforms to a solution of $k$-means), access to relatively few such queries allows efficient solutions to otherwise NP-hard problems. In particular, we provide a probabilistic polynomial-time (BPP) algorithm for clustering in this setting that asks $O(k^2\log k + k\log n)$ same-cluster queries and runs in time $O(kn\log n)$, where $k$ is the number of clusters and $n$ is the number of instances. The algorithm succeeds with high probability for data satisfying margin conditions under which, without queries, we show that the problem is NP-hard. We also prove a lower bound on the number of queries needed for a computationally efficient clustering algorithm in this setting.