
Learning Elastic Costs to Shape Monge Displacements

Part of Advances in Neural Information Processing Systems 37 (NeurIPS 2024) Main Conference Track


Authors

Michal Klein, Aram-Alexandre Pooladian, Pierre Ablin, Eugène Ndiaye, Jonathan Niles-Weed, Marco Cuturi

Abstract

Given a source and a target probability measure, the Monge problem studies efficient ways to map the former onto the latter. This efficiency is quantified by defining a *cost* function between source and target data. Such a cost is often set by default in the machine learning literature to the squared-Euclidean distance, \ell^2_2(\mathbf{x},\mathbf{y}):=\tfrac{1}{2}\|\mathbf{x}-\mathbf{y}\|^2_2. The benefits of using *elastic* costs, defined using a regularizer \tau as c(\mathbf{x},\mathbf{y}):=\ell^2_2(\mathbf{x},\mathbf{y})+\tau(\mathbf{x}-\mathbf{y}), were recently highlighted in (Cuturi et al., 2023). Such costs shape the *displacements* of Monge maps T, namely the difference between a source point and its image, T(\mathbf{x})-\mathbf{x}, by giving them a structure that matches that of the proximal operator of \tau. In this work, we make two important contributions to the study of elastic costs: *(i)* for any elastic cost, we propose a numerical method to compute Monge maps that are provably optimal. This provides a much-needed routine to create synthetic problems where the ground-truth OT map is known, by analogy to the Brenier theorem, which states that the gradient of any convex potential is always a valid Monge map for the \ell_2^2 cost; *(ii)* we propose a loss to *learn* the parameter \theta of a parameterized regularizer \tau_\theta, and apply it in the case where \tau_{A}(\mathbf{z}):=\|A^\perp \mathbf{z}\|^2_2. This regularizer promotes displacements that lie on a low-dimensional subspace of \mathbb{R}^d, spanned by the p rows of A\in\mathbb{R}^{p\times d}. We illustrate the soundness of our procedure on synthetic data, generated using our first contribution, in which we show near-perfect recovery of A's subspace using only samples. We demonstrate the applicability of this method by showing predictive improvements on single-cell data tasks.
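
To make the two definitions above concrete, here is a minimal NumPy sketch, not the authors' implementation, of the elastic cost c(\mathbf{x},\mathbf{y})=\ell^2_2(\mathbf{x},\mathbf{y})+\tau_A(\mathbf{x}-\mathbf{y}), interpreting A^\perp as the projection onto the orthogonal complement of the row space of A. The helper names `subspace_regularizer` and `elastic_cost` are illustrative, not from the paper's code.

```python
# A minimal sketch (assumed interpretation, not the paper's code) of the
# elastic cost with subspace regularizer tau_A(z) = ||A_perp z||_2^2.
import numpy as np

def subspace_regularizer(z: np.ndarray, A: np.ndarray) -> float:
    """Penalize the component of the displacement z that falls
    outside the row space of A (A has shape p x d)."""
    # Orthonormal basis Q for the row space of A, via QR on A^T.
    Q, _ = np.linalg.qr(A.T)       # Q is d x p, orthonormal columns
    z_parallel = Q @ (Q.T @ z)     # projection of z onto span(rows of A)
    z_perp = z - z_parallel        # component orthogonal to that subspace
    return float(z_perp @ z_perp)

def elastic_cost(x: np.ndarray, y: np.ndarray, A: np.ndarray) -> float:
    """c(x, y) = 0.5 * ||x - y||_2^2 + tau_A(x - y)."""
    z = x - y
    return 0.5 * float(z @ z) + subspace_regularizer(z, A)

# Usage: a displacement inside span(A's rows) incurs no extra penalty.
rng = np.random.default_rng(0)
d, p = 10, 2
A = rng.normal(size=(p, d))
Q, _ = np.linalg.qr(A.T)
x = rng.normal(size=d)
y_in = x - Q @ rng.normal(size=p)   # displacement within the subspace
y_out = x - rng.normal(size=d)      # generic displacement
print(elastic_cost(x, y_in, A), elastic_cost(x, y_out, A))
```

Displacements contained in the span of A's rows incur no extra penalty, which is how such a cost promotes the low-dimensional displacement structure described in the abstract.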