Neural Bregman Divergences for Distance Learning

Fred Lu, Edward Raff, Francis Ferraro

arXiv.org Artificial Intelligence 

Many metric learning tasks, such as triplet learning, nearest neighbor retrieval, and visualization, are treated primarily as embedding tasks: the final metric is some variant of the Euclidean distance (e.g., cosine or Mahalanobis), and the algorithm must learn to embed points into this pre-chosen space. Non-Euclidean geometries are rarely explored, which we believe is due to a lack of tools for learning non-Euclidean measures of distance. Recent work has shown that Bregman divergences can be learned from data, opening a promising approach to learning asymmetric distances. We propose a new approach to learning arbitrary Bregman divergences in a differentiable manner via input convex neural networks and show that it overcomes significant limitations of previous works. We also demonstrate that our method more faithfully learns divergences over a set of both new and previously studied tasks, including asymmetric regression, ranking, and clustering. Our tests further extend to tasks that are known to be asymmetric but not Bregman, where our method still performs competitively despite misspecification, showing the general utility of our approach for asymmetric learning.

Learning a task-relevant metric among samples is a common application of machine learning, with uses in retrieval, clustering, and ranking. A classic example of retrieval is visual recognition: given an image of an object, the system tries to identify its class based on an existing labeled dataset. To do this, the model can learn a measure of similarity between pairs of images, assigning small distances between images of the same object type. Given the broad successes of deep learning, there has been a recent surge of interest in deep metric learning, which uses neural networks to learn these similarities automatically (Hoffer & Ailon, 2015; Huang et al., 2016; Zhang et al., 2020).
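The Bregman divergences referred to in the abstract have a standard definition: for a strictly convex, differentiable generator phi, D_phi(x, y) = phi(x) - phi(y) - <grad phi(y), x - y>. The sketch below is only an illustration of that definition, not the method proposed here: the paper learns phi with an input convex neural network, whereas this example assumes two fixed, hand-chosen generators (the helper names `bregman`, `sq_norm`, and `neg_entropy` are hypothetical). It shows how the choice of phi controls symmetry: the squared norm recovers the squared Euclidean distance, while negative entropy yields the generalized KL divergence, which is asymmetric.

```python
import numpy as np


def bregman(phi, grad_phi, x, y):
    """Bregman divergence D_phi(x, y) = phi(x) - phi(y) - <grad phi(y), x - y>."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    return phi(x) - phi(y) - np.dot(grad_phi(y), x - y)


# Generator 1: phi(v) = ||v||^2, which recovers the squared Euclidean distance.
def sq_norm(v):
    return np.dot(v, v)


def sq_norm_grad(v):
    return 2.0 * v


# Generator 2: phi(v) = sum v log v (negative entropy, requires v > 0),
# which yields the generalized KL divergence sum x log(x / y) - sum x + sum y.
def neg_entropy(v):
    return np.sum(v * np.log(v))


def neg_entropy_grad(v):
    return np.log(v) + 1.0


p = np.array([0.7, 0.2, 0.1])
q = np.array([0.4, 0.4, 0.2])

# Usage: the squared-norm divergence is symmetric, the entropy-based one is not.
print(bregman(sq_norm, sq_norm_grad, p, q), bregman(sq_norm, sq_norm_grad, q, p))
print(bregman(neg_entropy, neg_entropy_grad, p, q), bregman(neg_entropy, neg_entropy_grad, q, p))
```

Because every Bregman divergence is fully determined by its generator, learning a convex phi is one way to learn an asymmetric distance without fixing its form in advance.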
The traditional approach to deep metric learning learns an embedding function over the input space so that a simple distance measure between pairs of embeddings corresponds to task-relevant spatial relations between the inputs. The embedding function f is computed by a neural network, which is trained to encode those spatial relations. The choice of distance therefore plays two roles. First, it is used to define the loss function, such as triplet or contrastive loss, which dictates how this distance should capture task-relevant properties of the input space. Second, since f is trained to optimize the loss function, the distance influences the learned embedding f.
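The two roles of the distance can be seen in a minimal triplet-loss sketch. It assumes a hypothetical linear map `W` standing in for the neural network f, a fixed Euclidean distance between embeddings, and a hinge-style triplet loss with an assumed `margin` parameter; none of these names come from the paper. The distance appears inside the loss (first role), and since the loss depends on `W` only through the embeddings, minimizing it reshapes f around that distance (second role).

```python
import numpy as np


def embed(W, x):
    """Hypothetical embedding f: a single linear map standing in for a neural network."""
    return W @ np.asarray(x, dtype=float)


def euclidean(u, v):
    """The fixed, pre-chosen distance applied between embeddings."""
    return float(np.linalg.norm(u - v))


def triplet_loss(W, anchor, positive, negative, margin=1.0):
    """Hinge-style triplet loss: zero once the anchor is closer to the positive
    than to the negative by at least `margin` in the embedding space."""
    fa, fp, fn = embed(W, anchor), embed(W, positive), embed(W, negative)
    return max(0.0, euclidean(fa, fp) - euclidean(fa, fn) + margin)


# Usage with an identity embedding in 2-D.
W = np.eye(2)
anchor, positive = [0.0, 0.0], [0.5, 0.0]
print(triplet_loss(W, anchor, positive, negative=[3.0, 0.0]))  # negative far away
print(triplet_loss(W, anchor, positive, negative=[1.0, 0.0]))  # negative too close
```

In the first call the margin is already satisfied, so the loss is zero and no gradient would reach f; in the second, the negative violates the margin and training would move the embeddings apart under the Euclidean distance.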