Score-Matching Langevin Dynamics (SMLD)

This is a learning note for this series of videos.

Link to the paper: https://arxiv.org/abs/2403.18103

Core components: 1. Langevin Dynamics; 2. Stein score function; 3. score-matching loss

Goal: generate (sample) data $\{ x_1, \dots, x_T \}$ from a distribution $p(x)$. If $p(x)$ is known, plug $\nabla_x \log p(x)$ into Langevin dynamics to generate samples. If $p(x)$ is unknown, train a neural network $s_{\theta}(x) \approx \nabla_x \log p(x)$ to approximate the score.

1. Langevin Dynamics is gradient descent

Definition 1.1: The Langevin dynamics for sampling from a known distribution $p(x)$ is an iterative procedure for $t = 1, \dots, T$:

$$x_{t+1} = x_t + \tau \nabla_x \log p(x_t) + \sqrt{2 \tau}\, z, \qquad z \sim \mathcal{N}(0, \bold{I})$$

where $\tau$ is a step size which users can control, and $x_0$ is white noise.
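As a concrete illustration, here is a minimal numpy sketch of Definition 1.1 for a toy 1-D target (the choice $p(x) = \mathcal{N}(2, 1)$ and all constants are illustrative assumptions), whose score $\nabla_x \log p(x) = -(x - 2)$ is known in closed form:

```python
import numpy as np

def langevin_sample(score, tau=0.01, T=2000, n=5000, seed=0):
    """Langevin dynamics: x_{t+1} = x_t + tau * score(x_t) + sqrt(2*tau) * z."""
    rng = np.random.default_rng(seed)
    x = rng.normal(size=n) * 5.0  # x_0: (scaled) white-noise initialization
    for _ in range(T):
        z = rng.normal(size=n)
        x = x + tau * score(x) + np.sqrt(2 * tau) * z
    return x

# Target: p(x) = N(2, 1), whose score is grad_x log p(x) = -(x - 2).
samples = langevin_sample(lambda x: -(x - 2.0))
print(samples.mean(), samples.std())  # should be close to 2 and 1
```

Each of the `n` chains is an independent run of the procedure, so after enough steps the collection of final iterates behaves like a set of samples from $p(x)$.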

Remark: Without the noise term $\sqrt{2\tau}\, z$, Langevin dynamics is gradient descent.

The goal of finding the most likely sample is equivalent to solving the optimization:

$$x^* = \argmax_{x} \log p(x)$$

This optimization can be solved by gradient ascent (i.e., gradient descent on $-\log p(x)$). One gradient step is:

$$x_{t+1} = x_t + \tau \nabla_x \log p(x_t)$$

2. Langevin Dynamics is not only gradient descent, but also stochastic gradient descent

Why stochastic? Because we are interested in sampling rather than optimization: the noise term $\sqrt{2\tau}\, z$ makes every update stochastic, so the iterates do not collapse onto the mode $x^*$ but keep exploring, and in the long run they are distributed according to $p(x)$.
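This difference is easy to see numerically. The sketch below runs the same iteration with and without the noise term on an illustrative toy target $\mathcal{N}(2, 1)$: without noise every iterate converges to the mode $x^* = 2$, while with noise the iterates spread out like samples.

```python
import numpy as np

rng = np.random.default_rng(0)
score = lambda x: -(x - 2.0)          # score of the toy target N(2, 1)
tau, T = 0.01, 2000
x_gd = rng.normal(size=1000) * 5.0    # gradient-ascent iterates (no noise)
x_ld = x_gd.copy()                    # Langevin iterates (with noise)
for _ in range(T):
    x_gd = x_gd + tau * score(x_gd)
    x_ld = x_ld + tau * score(x_ld) + np.sqrt(2 * tau) * rng.normal(size=1000)

print(x_gd.std())  # ~0: every chain collapses onto the mode x* = 2
print(x_ld.std())  # ~1: chains spread out like samples from N(2, 1)
```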

3. Langevin Dynamics from a physics perspective

Relationship between force $\bold{F}$, mass $m$ and velocity $\bold{v}(t)$ (Newton's second law):

$$\bold{F} = m \cdot \frac{d \bold{v}(t)}{dt}$$

Relationship between force $\bold{F}$ and the potential energy $U(\bold{x})$: the force is the negative gradient of the potential,

$$\bold{F} = -\nabla_x U(\bold{x})$$

The randomness of Langevin dynamics comes from Brownian motion. Suppose a particle moves in the potential $U$ inside a bag of molecules bouncing around it. Its motion follows the Langevin equation: the total force is the sum of a friction term $-\lambda \bold{v}(t)$, the potential force $-\nabla_x U(\bold{x})$, and a random collision force $\bold{\eta}$:

$$m \frac{d\bold{v}(t)}{dt} = -\lambda \bold{v}(t) - \nabla_x U(\bold{x}) + \bold{\eta}, \quad \text{where } \bold{\eta} \sim \mathcal{N}(0, \sigma^2 \bold{I})$$

In the overdamped regime, friction dominates inertia, so we can set $m \frac{d\bold{v}(t)}{dt} \approx 0$ and solve for the velocity:

$$\bold{v}(t) = -\frac{1}{\lambda} \nabla_x U(\bold{x}) + \frac{1}{\lambda} \bold{\eta}$$

Since $\frac{d\bold{x}}{dt} = \bold{v}(t)$ and $\bold{\eta} \sim \mathcal{N}(0, \sigma^2 \bold{I})$ can be written as $\bold{\eta} = \sigma \bold{z}$, we have

$$\frac{d\bold{x}}{dt} = -\frac{1}{\lambda} \nabla_x U(\bold{x}) + \frac{\sigma}{\lambda} \bold{z}, \quad \text{where } \bold{z} \sim \mathcal{N}(0, \bold{I})$$

If we let $\tau = \frac{dt}{\lambda}$ and discretize the above differential equation, we obtain:

$$\bold{x}_{t+1} = \bold{x}_t - \frac{1}{\lambda} \nabla_x U(\bold{x}_t)\, dt + \frac{\sigma}{\lambda} \bold{z}_t\, dt = \bold{x}_t - \tau \nabla_x U(\bold{x}_t) + \sigma \tau \bold{z}_t$$

A lazy choice to determine the potential energy is using the Boltzmann distribution, which has the form:

$$p(\bold{x}) = \frac{1}{Z} \exp \{ -U(\bold{x}) \}$$

Therefore,

$$\nabla_x \log p(\bold{x}) = \nabla_x \left( -U(\bold{x}) - \log Z \right) = -\nabla_x U(\bold{x})$$

If we choose $\sigma = \sqrt{\frac{2}{\tau}}$, we obtain:

$$\begin{aligned} \bold{x}_{t+1} &= \bold{x}_t - \tau \nabla_x U(\bold{x}_t) + \sigma \tau \bold{z}_t \\ &= \bold{x}_t + \tau \nabla_x \log p(\bold{x}_t) + \sigma \tau \bold{z}_t \\ &= \bold{x}_t + \tau \nabla_x \log p(\bold{x}_t) + \sqrt{2 \tau}\, \bold{z}_t \end{aligned}$$

which is the Langevin Dynamics.

4. Stein’s Score Function

Definition 4.1 (Stein's score function):

$$\bold{s}_{\theta}(\bold{x}) = \nabla_{\bold{x}} \log p_{\theta}(\bold{x})$$

Distinguish it from the ordinary score function:

$$\bold{s}_{\bold{x}}(\theta) = \nabla_{\theta} \log p_{\theta}(\bold{x})$$

Example 4.1: If $p(x)$ is a Gaussian with $p(x) = \frac{1}{\sqrt{2\pi \sigma^2}} e^{-\frac{(x-\mu)^2}{2\sigma^2}}$, then

$$s(x) = \nabla_x \log p(x) = \nabla_x \left( -\frac{(x-\mu)^2}{2 \sigma^2} \right) = -\frac{x - \mu}{\sigma^2}$$
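We can sanity-check Example 4.1 numerically by comparing the closed-form score against a central finite difference of $\log p(x)$ (the values of $\mu$ and $\sigma$ below are illustrative):

```python
import numpy as np

mu, sigma = 1.5, 0.8
# log p(x) for a Gaussian N(mu, sigma^2)
log_p = lambda x: -(x - mu) ** 2 / (2 * sigma ** 2) - 0.5 * np.log(2 * np.pi * sigma ** 2)

x = np.linspace(-1.0, 4.0, 7)
closed_form = -(x - mu) / sigma ** 2                        # s(x) from Example 4.1
eps = 1e-5
numerical = (log_p(x + eps) - log_p(x - eps)) / (2 * eps)   # central difference

print(np.max(np.abs(closed_form - numerical)))  # tiny: the two scores agree
```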

5. Score-Matching Techniques

Recall that the score network should match the true score:

$$\bold{s}_{\theta}(\bold{x}) = \nabla_{\bold{x}} \log p(\bold{x})$$

Problem: We don't know $p(\bold{x})$, so we cannot evaluate the training target $\nabla_{\bold{x}} \log p(\bold{x})$.

Goal: Learn $\bold{s}_{\theta}(\bold{x}) \approx \nabla_{\bold{x}} \log p(\bold{x})$ without knowing the real distribution $p(\bold{x})$.

5.1 Explicit Score-Matching

Use $q(\bold{x})$ as an approximation of the true data distribution $p(\bold{x})$. We can use classical kernel density estimation to obtain $q(\bold{x})$:

$$q(\bold{x}) = \frac{1}{M} \sum_{m=1}^{M} \frac{1}{h} K\left(\frac{\bold{x} - \bold{x}_m}{h}\right)$$

where $h$ is the bandwidth hyperparameter of the kernel function $K(\cdot)$, and $\bold{x}_m$ is the $m$-th sample in the training set.

Since $q(\bold{x})$ is an approximation of $p(\bold{x})$, we can learn $\bold{s}_{\theta}(\bold{x})$ based on $q(\bold{x})$. This leads to the explicit score-matching loss:

$$J_{\text{ESM}}(\theta) = \mathbb{E}_{q(\bold{x})} \|\bold{s}_{\theta}(\bold{x}) - \nabla_{\bold{x}} \log q(\bold{x})\|^2$$

By substituting the kernel density estimate, we can show that the loss is:

$$\begin{aligned} J_{\text{ESM}}(\theta) &= \mathbb{E}_{q(\bold{x})} \|\bold{s}_{\theta}(\bold{x}) - \nabla_{\bold{x}} \log q(\bold{x})\|^2 \\ &= \int \|\bold{s}_{\theta}(\bold{x}) - \nabla_{\bold{x}} \log q(\bold{x})\|^2 \left[ \frac{1}{M} \sum_{m=1}^{M} \frac{1}{h} K\left(\frac{\bold{x} - \bold{x}_m}{h}\right) \right] d\bold{x} \\ &= \frac{1}{M} \sum_{m=1}^{M} \int \|\bold{s}_{\theta}(\bold{x}) - \nabla_{\bold{x}} \log q(\bold{x})\|^2 \frac{1}{h} K\left(\frac{\bold{x} - \bold{x}_m}{h}\right) d\bold{x} \end{aligned}$$
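To make the ESM target concrete, the following 1-D numpy sketch evaluates the KDE score $\nabla_{\bold{x}} \log q(\bold{x})$ for a Gaussian kernel; the target distribution, bandwidth $h$, and sample size are illustrative assumptions:

```python
import numpy as np

def kde_log_q_grad(x, data, h):
    """Score of a Gaussian-kernel KDE: grad_x log q(x), 1-D sketch."""
    diffs = x[:, None] - data[None, :]            # (n_query, M) pairwise x - x_m
    k = np.exp(-0.5 * (diffs / h) ** 2)           # unnormalized Gaussian kernel values
    w = k / k.sum(axis=1, keepdims=True)          # responsibility of each x_m for x
    return (w * (-diffs / h ** 2)).sum(axis=1)    # weighted average of per-kernel scores

rng = np.random.default_rng(0)
data = rng.normal(loc=2.0, scale=1.0, size=5000)  # "training set" drawn from p = N(2, 1)
x = np.array([0.0, 2.0, 4.0])
print(kde_log_q_grad(x, data, h=0.3))             # close to the true score -(x - 2)
```

Note that every query point requires a sum over all $M$ training samples, which hints at the cost issue discussed next.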

Problem with ESM: when the sample size $M$ is large, evaluating $q(\bold{x})$ is expensive; and when the sample size is limited and the data are high-dimensional, kernel density estimation performs poorly.

5.2 Denoising Score Matching

The key difference is that we replace the distribution $q(\bold{x})$ by a conditional distribution $q(\bold{x} | \bold{x}')$:

$$J_{\text{DSM}}(\theta) = \mathbb{E}_{q(\bold{x}, \bold{x}')} \left[ \frac{1}{2} \|\bold{s}_{\theta}(\bold{x}) - \nabla_{\bold{x}} \log q(\bold{x} | \bold{x}')\|^2 \right]$$

Specifically, we set $q(\bold{x} | \bold{x}') = \mathcal{N}(\bold{x} \mid \bold{x}', \sigma^2 \bold{I})$, so we can write $\bold{x} = \bold{x}' + \sigma \bold{z}$ with $\bold{z} \sim \mathcal{N}(0, \bold{I})$. This leads to:

$$\begin{aligned} \nabla_{\bold{x}} \log q(\bold{x} | \bold{x}') &= \nabla_{\bold{x}} \log \frac{1}{(\sqrt{2 \pi \sigma^2})^d} \exp \left( -\frac{\|\bold{x} - \bold{x}'\|^2}{2 \sigma^2} \right) \\ &= \nabla_{\bold{x}} \left( -\frac{\|\bold{x} - \bold{x}'\|^2}{2 \sigma^2} - \log (\sqrt{2 \pi \sigma^2})^d \right) \\ &= -\frac{\bold{x} - \bold{x}'}{\sigma^2} = -\frac{\bold{z}}{\sigma} \end{aligned}$$

As a result, the loss function of denoising score matching becomes:

$$\begin{aligned} J_{\text{DSM}}(\theta) &= \mathbb{E}_{q(\bold{x}, \bold{x}')} \left[ \frac{1}{2} \|\bold{s}_{\theta}(\bold{x}) - \nabla_{\bold{x}} \log q(\bold{x} | \bold{x}')\|^2 \right] \\ &= \mathbb{E}_{q(\bold{x}')} \left[ \frac{1}{2} \left\| \bold{s}_{\theta}(\bold{x}' + \sigma \bold{z}) + \frac{\bold{z}}{\sigma} \right\|^2 \right] \end{aligned}$$

where the expectation is also taken over $\bold{z} \sim \mathcal{N}(0, \bold{I})$.

Replace the dummy variable $\bold{x}'$ by $\bold{x}$, and note that sampling from $q(\bold{x})$ can be replaced by sampling from $p(\bold{x})$ given a training dataset. We can then conclude the denoising score-matching loss function:

$$J_{\text{DSM}}(\theta) = \mathbb{E}_{p(\bold{x})} \left[ \frac{1}{2} \left\| \bold{s}_{\theta}(\bold{x} + \sigma \bold{z}) + \frac{\bold{z}}{\sigma} \right\|^2 \right]$$

Remark: The above loss function is highly interpretable. The quantity $\bold{x} + \sigma \bold{z}$ is a clean image $\bold{x}$ corrupted by the noise $\sigma \bold{z}$. The score network takes this noisy image and is trained to predict $-\frac{\bold{z}}{\sigma}$, i.e., the (scaled) noise that was added.

The training step can be described simply as follows. Given a training dataset $\{ \bold{x}^{(l)} \}_{l=1}^L$, we train a network with parameters $\theta$ to solve

$$\theta^* = \argmin_{\theta} \frac{1}{L} \sum_{l=1}^L \frac{1}{2} \left\| \bold{s}_{\theta}(\bold{x}^{(l)} + \sigma \bold{z}^{(l)}) + \frac{\bold{z}^{(l)}}{\sigma} \right\|^2 \qquad \text{where } \bold{z}^{(l)} \sim \mathcal{N}(0, \bold{I})$$
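A minimal end-to-end sketch of this training objective in numpy, assuming 1-D data from $\mathcal{N}(2, 1)$ and a linear score model $s_\theta(x) = a x + b$ (sufficient here because the noisy data remain Gaussian; a real SMLD model is a neural network, and fresh noise $\bold{z}^{(l)}$ is usually redrawn every step rather than fixed once per sample as below):

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy training set from p = N(2, 1) and one fixed DSM noise draw per sample.
x_data = rng.normal(loc=2.0, scale=1.0, size=20000)
sigma = 0.5
z = rng.normal(size=x_data.size)
x_noisy = x_data + sigma * z          # x^(l) + sigma * z^(l)
target = -z / sigma                   # s_theta should predict -z^(l) / sigma

# Full-batch gradient descent on the DSM loss 0.5 * mean((a*x_noisy + b - target)^2).
a, b, lr = 0.0, 0.0, 0.1
for _ in range(2000):
    resid = a * x_noisy + b - target
    a -= lr * np.mean(resid * x_noisy)
    b -= lr * np.mean(resid)

# Noisy data follow N(2, 1 + sigma^2), so the optimal score is -(x - 2) / (1 + sigma^2).
print(a, b)  # roughly -1/1.25 = -0.8 and 2/1.25 = 1.6
```

The learned model matches the score of the *noise-perturbed* distribution, which is exactly what the DSM objective asks for.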