Provably Fast Finite Particle Variants of SVGD via Virtual Particle Stochastic Approximation

Part of Advances in Neural Information Processing Systems 36 (NeurIPS 2023) Main Conference Track



Aniket Das, Dheeraj Nagaraj


Stein Variational Gradient Descent (SVGD) is a popular particle-based variational inference algorithm with impressive empirical performance across various domains. Although the population (i.e., infinite-particle) limit dynamics of SVGD are well characterized, its behavior in the finite-particle regime is far less understood. To this end, our work introduces the notion of *virtual particles* to develop novel stochastic approximations of population-limit SVGD dynamics in the space of probability measures that are exactly realizable using finite particles. As a result, we design two computationally efficient variants of SVGD, namely VP-SVGD and GB-SVGD, with provably fast finite-particle convergence rates. Our algorithms can be viewed as specific random-batch approximations of SVGD, which are computationally more efficient than ordinary SVGD. We show that the $n$ particles output by VP-SVGD and GB-SVGD, run for $T$ steps with batch size $K$, are at least as good as i.i.d. samples from a distribution whose Kernel Stein Discrepancy to the target is at most $O(\tfrac{d^{1/3}}{(KT)^{1/6}})$ under standard assumptions. Our results also hold under a mild growth condition on the potential function, which is much weaker than the isoperimetric (e.g., Poincaré inequality) or information-transport conditions (e.g., Talagrand's inequality $\mathsf{T}_1$) generally considered in prior works. As a corollary, we analyze the convergence of the empirical measure of the particles output by VP-SVGD and GB-SVGD to the target distribution and demonstrate a **double exponential improvement** over the best known finite-particle analysis of SVGD. Beyond this, our results present the **first known oracle complexities for this setting with polynomial dimension dependence**, thereby completely eliminating the curse of dimensionality exhibited by previously known finite-particle rates.
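To make the phrase "random-batch approximations of SVGD" concrete, here is a minimal NumPy sketch. It assumes an RBF kernel with a fixed bandwidth `h`, a target specified through its score function `grad_log_p`, and a batch of `K` particles drawn uniformly without replacement at every step to estimate the SVGD drift for all `n` particles. This is a generic random-batch SVGD written purely for illustration; it is not claimed to be the exact VP-SVGD or GB-SVGD update analyzed in the paper, and the function names are hypothetical.

```python
import numpy as np


def rbf_kernel_and_grad(xb, x, h):
    """RBF kernel between batch particles xb (K, d) and all particles x (n, d).

    Returns k of shape (K, n) with k[j, i] = exp(-||xb_j - x_i||^2 / (2 h^2)),
    and grad_k of shape (K, n, d) with grad_k[j, i] = d k(xb_j, x_i) / d xb_j.
    """
    diff = xb[:, None, :] - x[None, :, :]        # (K, n, d): xb_j - x_i
    sq_dist = np.sum(diff ** 2, axis=-1)         # (K, n)
    k = np.exp(-sq_dist / (2.0 * h ** 2))
    grad_k = -diff / h ** 2 * k[:, :, None]      # gradient w.r.t. the batch particle
    return k, grad_k


def random_batch_svgd(x0, grad_log_p, steps, batch_size, step_size, h=1.0, seed=0):
    """Generic random-batch SVGD (illustrative assumption, not the paper's exact update).

    Each step draws batch_size particles uniformly without replacement and uses
    them to estimate the SVGD drift
        phi(x_i) = mean_j [ k(x_j, x_i) grad_log_p(x_j) + grad_{x_j} k(x_j, x_i) ]
    for every particle x_i.
    """
    rng = np.random.default_rng(seed)
    x = np.array(x0, dtype=float)
    n = x.shape[0]
    for _ in range(steps):
        idx = rng.choice(n, size=batch_size, replace=False)
        xb = x[idx]                                  # (K, d) batch particles
        k, grad_k = rbf_kernel_and_grad(xb, x, h)
        score = grad_log_p(xb)                       # (K, d)
        # Driving term (kernel-weighted score) plus repulsive term (kernel gradient),
        # averaged over the K batch particles instead of over all n particles.
        drift = (k[:, :, None] * score[:, None, :] + grad_k).mean(axis=0)  # (n, d)
        x = x + step_size * drift
    return x


# Usage: standard Gaussian target N(0, I) in 2D, whose score is grad log p(z) = -z.
x0 = np.random.default_rng(1).normal(loc=3.0, scale=1.0, size=(100, 2))
x = random_batch_svgd(x0, lambda z: -z, steps=1000, batch_size=10, step_size=0.1)
print(x.mean(axis=0), x.std(axis=0))
```

In this sketch each step evaluates the kernel on $K \times n$ pairs rather than the $n \times n$ pairs needed by full SVGD, which is where the computational saving of a random-batch scheme comes from. Setting `batch_size=n` recovers the ordinary (full-interaction) SVGD update.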