Variational Schrödinger Diffusion Models

Authors: Wei Deng, Weijian Luo, Yixin Tan, Marin Biloš, Yu Chen, Yuriy Nevmyvaka, Ricky T. Q. Chen

What

This paper presents the Variational Schrödinger Diffusion Model (VSDM), a diffusion model that uses variational inference to make the Schrödinger bridge (SB) more scalable for optimizing transport plans while preserving efficient transport.

Why

While SB offers optimal transport guarantees, it faces scalability limitations due to the need for costly simulated trajectories. VSDM overcomes this by linearizing forward score functions, leading to closed-form updates and enabling simulation-free training of backward score functions. This enhances scalability and makes the algorithm more tuning-friendly for large-scale experiments.

How

The authors employ variational inference to approximate the forward score function in SB using a locally linear function, leading to the variational FB-SDE. They then utilize a multivariate OU process for the forward diffusion and derive closed-form expressions for the backward score function. They also use stochastic approximation to adaptively optimize the variational score for efficient transport.
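To make the closed-form property concrete, the following is a minimal sketch and not the authors' implementation. It assumes a time-homogeneous multivariate OU process dx = -0.5 * beta * M x dt + sqrt(beta) dW with a constant symmetric positive definite matrix M and a constant scalar beta. The names `ou_marginal`, `conditional_score`, `M`, and `beta` are hypothetical, and the paper's own parameterization (for example, time-dependent coefficients) may differ. Under these assumptions the transition x_t | x_0 is Gaussian with a known mean and covariance, so noisy samples and their conditional scores can be computed directly, without simulating trajectories.

```python
import numpy as np


def ou_marginal(x0, M, beta, t):
    """Mean and covariance of x_t | x_0 for dx = -0.5*beta*M x dt + sqrt(beta) dW.

    Assumes M is symmetric positive definite and beta is a constant scalar.
    """
    # Eigendecomposition M = Q diag(lam) Q^T (valid because M is symmetric).
    lam, Q = np.linalg.eigh(M)
    # Eigenvalues of the matrix exponential exp(-0.5 * beta * t * M).
    decay = np.exp(-0.5 * beta * t * lam)
    mean = Q @ (decay * (Q.T @ x0))
    # Eigenvalues of the covariance M^{-1} (I - exp(-beta * t * M)).
    var = (1.0 - decay**2) / lam
    cov = (Q * var) @ Q.T
    return mean, cov


def conditional_score(xt, mean, cov):
    """Gradient of log N(xt; mean, cov) with respect to xt."""
    return -np.linalg.solve(cov, xt - mean)


# Illustrative values only: an anisotropic 2-D drift matrix.
rng = np.random.default_rng(0)
M = np.array([[2.0, 0.5], [0.5, 1.0]])
x0 = np.array([1.0, -1.0])

mean, cov = ou_marginal(x0, M, beta=1.0, t=0.7)
xt = rng.multivariate_normal(mean, cov)      # sample x_t without simulating the SDE
score = conditional_score(xt, mean, cov)     # regression target for score matching
```

With M set to the identity, this sketch reduces to the familiar isotropic case, where the mean is exp(-0.5 * beta * t) * x0 and the covariance is (1 - exp(-beta * t)) * I.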

Result

VSDM demonstrates effectiveness in generating anisotropic shapes and produces straighter sample trajectories, indicating more efficient transport, compared to single-variate diffusions. It achieves competitive performance in image generation on CIFAR10 and conditional time series modeling, all without relying on warm-up initializations. Furthermore, VSDM is observed to be significantly faster than the original SB with nonlinear forward scores.

Limitations and Future Work

The paper acknowledges that linearizing the forward score function inevitably results in sub-optimal transport in general cases. Future work includes exploring critically damped (momentum) acceleration and Hessian approximations to develop advanced optimization techniques akin to “ADAM” for diffusion models.

Abstract

Schrödinger bridge (SB) has emerged as the go-to method for optimizing transportation plans in diffusion models. However, SB requires estimating the intractable forward score functions, inevitably resulting in the costly implicit training loss based on simulated trajectories. To improve the scalability while preserving efficient transportation plans, we leverage variational inference to linearize the forward score functions (variational scores) of SB and restore simulation-free properties in training backward scores. We propose the variational Schrödinger diffusion model (VSDM), where the forward process is a multivariate diffusion and the variational scores are adaptively optimized for efficient transport. Theoretically, we use stochastic approximation to prove the convergence of the variational scores and show the convergence of the adaptively generated samples based on the optimal variational scores. Empirically, we test the algorithm in simulated examples and observe that VSDM is efficient in generations of anisotropic shapes and yields straighter sample trajectories compared to the single-variate diffusion. We also verify the scalability of the algorithm in real-world data and achieve competitive unconditional generation performance in CIFAR10 and conditional generation in time series modeling. Notably, VSDM no longer depends on warm-up initializations and has become tuning-friendly in training large-scale experiments.