This is a learning note for this series of videos.
Link to the paper: https://arxiv.org/abs/2403.18103
Core components: 1. Langevin dynamics; 2. Stein's score function; 3. the score-matching loss.
**Goal**: generate (sample) data $\{x_1, \dots, x_T\}$ from a distribution $p(x)$. If $p(x)$ is known, plug $\nabla_x \log p(x)$ into Langevin dynamics to generate samples. If $p(x)$ is unknown, train a neural network $s_\theta(x) \approx \nabla_x \log p(x)$ as an approximation.
## 1. Langevin Dynamics is gradient descent

**Definition 1.1**: The Langevin dynamics for sampling from a known distribution $p(x)$ is an iterative procedure for $t = 1, \dots, T$:
$$x_{t+1} = x_t + \tau \nabla_x \log p(x_t) + \sqrt{2\tau}\, z, \qquad z \sim \mathcal{N}(0, \mathbf{I})$$

where $\tau$ is the step size, which the user can control, and $x_0$ is white noise.
**Remark**: Without the noise term $\sqrt{2\tau}\, z$, Langevin dynamics is exactly gradient descent (more precisely, gradient ascent on $\log p(x)$).
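As a sanity check, here is a minimal NumPy sketch of the update rule (the function name `langevin_sample` and all constants are illustrative, not from the paper), sampling from a 1-D Gaussian whose score is known in closed form:

```python
import numpy as np

def langevin_sample(score, x0, tau=0.01, n_steps=5000, seed=0):
    """Iterate x_{t+1} = x_t + tau * score(x_t) + sqrt(2*tau) * z."""
    rng = np.random.default_rng(seed)
    x = np.asarray(x0, dtype=float)
    for _ in range(n_steps):
        z = rng.standard_normal(x.shape)
        x = x + tau * score(x) + np.sqrt(2 * tau) * z
    return x

# Target p(x) = N(mu, sigma^2); its score is -(x - mu) / sigma^2.
mu, sigma = 2.0, 0.5
score = lambda x: -(x - mu) / sigma**2

# Run 10,000 independent chains, each starting from white noise x_0.
x0 = np.random.default_rng(1).standard_normal(10_000)
samples = langevin_sample(score, x0)
print(samples.mean(), samples.std())  # close to mu = 2.0 and sigma = 0.5
```

For small $\tau$ the empirical mean and standard deviation of the chains match the target Gaussian; a larger $\tau$ introduces a visible discretization bias.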
The goal of sampling is equivalent to solving the optimization:
$$x^* = \argmax_{x} \log p(x)$$

This optimization can be solved by gradient descent (ascent, since we maximize). One step is:

$$x_{t+1} = x_t + \tau \nabla_x \log p(x_t)$$
## 2. Langevin Dynamics is not only gradient descent, but also stochastic gradient descent

Why stochastic? Because we are more interested in sampling than in optimization: the noise term keeps the iterates spread out around the modes of $p(x)$ instead of collapsing onto $x^*$.
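To see the difference, a small NumPy experiment (all values illustrative): run the same update with and without the noise term. The deterministic version collapses every chain onto the mode $x^* = \mu$, while the stochastic version produces samples spread out like $p(x)$:

```python
import numpy as np

mu, sigma, tau = 2.0, 0.5, 0.01
score = lambda x: -(x - mu) / sigma**2   # score of N(mu, sigma^2)

rng = np.random.default_rng(0)
x_det = rng.standard_normal(1000)        # pure gradient ascent on log p
x_sto = x_det.copy()                     # Langevin dynamics
for _ in range(3000):
    x_det = x_det + tau * score(x_det)
    noise = np.sqrt(2 * tau) * rng.standard_normal(1000)
    x_sto = x_sto + tau * score(x_sto) + noise

print(x_det.std())   # ~0: all chains sit at the mode (optimization)
print(x_sto.std())   # ~sigma: chains look like samples from p(x) (sampling)
```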
## 3. Langevin Dynamics from a physics perspective

Relationship between force $\mathbf{F}$, mass $m$, and velocity $v(t)$:
$$\mathbf{F} = m \cdot \frac{d v(t)}{dt}$$

Relationship between force $\mathbf{F}$ and the potential energy $U(x)$:
$$\mathbf{F} = \nabla_x U(x)$$

(note: the usual physics convention is $\mathbf{F} = -\nabla_x U(x)$; the sign here is chosen so the algebra below stays simple). The randomness of Langevin dynamics comes from Brownian motion. Suppose there is a bag of molecules moving around. Their motion can be described by the Brownian motion model:
$$\frac{d\mathbf{v}(t)}{dt} = -\frac{\lambda}{m} \mathbf{v}(t) + \frac{1}{m} \boldsymbol{\eta}, \qquad \boldsymbol{\eta} \sim \mathcal{N}(0, \sigma^2 \mathbf{I})$$

Combining the above three equations, we have:
$$\nabla_x U(x) = \mathbf{F} = m \cdot \frac{d\mathbf{v}(t)}{dt} = -\lambda \mathbf{v}(t) + \boldsymbol{\eta} \quad\Rightarrow\quad \mathbf{v}(t) = -\frac{1}{\lambda}\nabla_x U(x) + \frac{1}{\lambda}\boldsymbol{\eta}$$

Since $\frac{d\mathbf{x}}{dt} = \mathbf{v}(t)$ and $\boldsymbol{\eta} \sim \mathcal{N}(0, \sigma^2 \mathbf{I})$, we have
$$\frac{d\mathbf{x}}{dt} = -\frac{1}{\lambda}\nabla_x U(x) + \frac{\sigma}{\lambda}\mathbf{z}, \qquad \mathbf{z} \sim \mathcal{N}(0, \mathbf{I})$$

If we let $\tau = \frac{dt}{\lambda}$ and discretize the above differential equation, we obtain:
$$x_{t+1} = x_t - \frac{1}{\lambda}\nabla_x U(x_t)\, dt + \frac{\sigma}{\lambda}\mathbf{z}\, dt = x_t - \tau \nabla_x U(x_t) + \sigma\tau\, z_t$$

A convenient ("lazy") choice for the potential energy is to use the Boltzmann distribution:
$$p(\mathbf{x}) = \frac{1}{Z} \exp\{-U(\mathbf{x})\}$$

Therefore,
$$\nabla_x \log p(\mathbf{x}) = \nabla_x\left(-U(\mathbf{x}) - \log Z\right) = -\nabla_x U(\mathbf{x})$$

If we further choose $\sigma = \sqrt{\frac{2}{\tau}}$, we obtain:
$$\begin{aligned} \mathbf{x}_{t+1} &= \mathbf{x}_t - \tau \nabla_x U(\mathbf{x}_t) + \sigma\tau\,\mathbf{z}_t \\ &= \mathbf{x}_t + \tau \nabla_x \log p(\mathbf{x}_t) + \sigma\tau\,\mathbf{z}_t \\ &= \mathbf{x}_t + \tau \nabla_x \log p(\mathbf{x}_t) + \sqrt{2\tau}\,\mathbf{z}_t \end{aligned}$$

which is exactly Langevin dynamics.
## 4. Stein's Score Function

**Definition 4.1** (Stein's score function):
$$\mathbf{s}_{\theta}(\mathbf{x}) = \nabla_{\mathbf{x}} \log p_{\theta}(\mathbf{x})$$

Distinguish it from the ordinary score function, which differentiates with respect to the parameters:
$$\mathbf{s}_{\mathbf{x}}(\theta) = \nabla_{\theta} \log p_{\theta}(\mathbf{x})$$

**Example 4.1**: If $p(x)$ is a Gaussian with $p(x) = \frac{1}{\sqrt{2\pi\sigma^2}} e^{-\frac{(x-\mu)^2}{2\sigma^2}}$, then
$$s(x) = \nabla_x \log p(x) = \nabla_x\left(-\frac{(x-\mu)^2}{2\sigma^2}\right) = -\frac{x-\mu}{\sigma^2}$$
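A quick numerical check of Example 4.1 (a sketch; the values of $\mu$, $\sigma$, and the finite-difference step `eps` are arbitrary): the analytic score $-\frac{x-\mu}{\sigma^2}$ should match a central-difference gradient of $\log p(x)$:

```python
import numpy as np

mu, sigma = 1.0, 2.0
log_p = lambda x: -(x - mu)**2 / (2 * sigma**2) - 0.5 * np.log(2 * np.pi * sigma**2)
analytic_score = lambda x: -(x - mu) / sigma**2

# Central finite difference: (log p(x+eps) - log p(x-eps)) / (2*eps).
x = np.linspace(-3.0, 5.0, 9)
eps = 1e-5
fd_score = (log_p(x + eps) - log_p(x - eps)) / (2 * eps)
err = np.max(np.abs(fd_score - analytic_score(x)))
print(err)  # tiny: the two gradients agree
```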
## 5. Score-Matching Techniques

We want $\mathbf{s}_{\theta}(\mathbf{x}) \approx \nabla_{\mathbf{x}} \log p(\mathbf{x})$.

**Problem**: We don't know $p(\mathbf{x})$, so we cannot compute the training target $\nabla_{\mathbf{x}} \log p(\mathbf{x})$ directly.
**Goal**: estimate $\nabla_{\mathbf{x}} \log p(\mathbf{x})$ without knowing the real distribution $p(\mathbf{x})$.
### 5.1 Explicit Score Matching

Use $q(\mathbf{x})$ as an approximation of the true data distribution $p(\mathbf{x})$. We can use classical kernel density estimation (KDE) to obtain $q(\mathbf{x})$:
$$q(\mathbf{x}) = \frac{1}{M} \sum_{m=1}^{M} \frac{1}{h} K\!\left(\frac{\mathbf{x} - \mathbf{x}_m}{h}\right)$$

where $h$ is the bandwidth hyperparameter of the kernel function $K(\cdot)$, and $\mathbf{x}_m$ is the $m$-th sample in the training set.
Since $q(\mathbf{x})$ is an approximation of $p(\mathbf{x})$, we can learn $\mathbf{s}_{\theta}(\mathbf{x})$ based on $q(\mathbf{x})$. This leads to the explicit score-matching loss:
$$J_{\text{ESM}}(\theta) = \mathbb{E}_{q(\mathbf{x})}\left\|\mathbf{s}_{\theta}(\mathbf{x}) - \nabla_{\mathbf{x}} \log q(\mathbf{x})\right\|^2$$

By substituting the kernel density estimate, we can show that the loss is:
$$\begin{aligned} J_{\text{ESM}}(\theta) &= \mathbb{E}_{q(\mathbf{x})}\left\|\mathbf{s}_{\theta}(\mathbf{x}) - \nabla_{\mathbf{x}} \log q(\mathbf{x})\right\|^2 \\ &= \int \left\|\mathbf{s}_{\theta}(\mathbf{x}) - \nabla_{\mathbf{x}} \log q(\mathbf{x})\right\|^2 \left[\frac{1}{M}\sum_{m=1}^{M}\frac{1}{h}K\!\left(\frac{\mathbf{x}-\mathbf{x}_m}{h}\right)\right] d\mathbf{x} \\ &= \frac{1}{M}\sum_{m=1}^{M}\int \left\|\mathbf{s}_{\theta}(\mathbf{x}) - \nabla_{\mathbf{x}} \log q(\mathbf{x})\right\|^2 \frac{1}{h}K\!\left(\frac{\mathbf{x}-\mathbf{x}_m}{h}\right) d\mathbf{x} \end{aligned}$$

**Problem with ESM**: When the sample size $M$ is large, evaluating $q(\mathbf{x})$ is expensive. And when the sample size is limited and the data lie in a high-dimensional space, kernel density estimation performs poorly.
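For intuition, here is a small NumPy sketch of the quantity ESM needs, $\nabla_{\mathbf{x}} \log q(\mathbf{x})$, for a 1-D Gaussian-kernel KDE (the function name `kde_score` and all constants are illustrative). Note the $O(M)$ cost per query point, which is the expense mentioned above:

```python
import numpy as np

def kde_score(x, data, h):
    """Score of q(x) = (1/M) sum_m N(x | x_m, h^2), a Gaussian-kernel KDE.

    grad_x log q(x) = sum_m w_m(x) * (x_m - x) / h^2,
    where w_m(x) are softmax-normalized kernel weights.
    """
    logk = -(x[:, None] - data[None, :])**2 / (2 * h**2)   # (N, M) log-kernels
    w = np.exp(logk - logk.max(axis=1, keepdims=True))
    w /= w.sum(axis=1, keepdims=True)                      # normalize weights
    return (w * (data[None, :] - x[:, None])).sum(axis=1) / h**2

# With data from N(0, 1), the KDE is approximately N(0, 1 + h^2),
# so its score should be close to -x / (1 + h^2).
rng = np.random.default_rng(0)
data = rng.standard_normal(5000)
x = np.linspace(-1.0, 1.0, 5)
grads = kde_score(x, data, h=0.3)
print(grads)
```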
### 5.2 Denoising Score Matching

The key difference is that we replace the distribution $q(\mathbf{x})$ by a conditional distribution $q(\mathbf{x} \mid \mathbf{x}')$:
$$J_{\text{DSM}}(\theta) = \mathbb{E}_{q(\mathbf{x}, \mathbf{x}')}\left[\frac{1}{2}\left\|\mathbf{s}_{\theta}(\mathbf{x}) - \nabla_{\mathbf{x}} \log q(\mathbf{x} \mid \mathbf{x}')\right\|^2\right]$$

Specifically, we set $q(\mathbf{x} \mid \mathbf{x}') = \mathcal{N}(\mathbf{x} \mid \mathbf{x}', \sigma^2 \mathbf{I})$, i.e. $\mathbf{x} = \mathbf{x}' + \sigma\mathbf{z}$ with $\mathbf{z} \sim \mathcal{N}(0, \mathbf{I})$. This leads to:
$$\begin{aligned} \nabla_{\mathbf{x}} \log q(\mathbf{x} \mid \mathbf{x}') &= \nabla_{\mathbf{x}} \log \frac{1}{(\sqrt{2\pi\sigma^2})^d} \exp\left(-\frac{\|\mathbf{x} - \mathbf{x}'\|^2}{2\sigma^2}\right) \\ &= \nabla_{\mathbf{x}}\left(-\frac{\|\mathbf{x} - \mathbf{x}'\|^2}{2\sigma^2} - \log(\sqrt{2\pi\sigma^2})^d\right) \\ &= -\frac{\mathbf{x} - \mathbf{x}'}{\sigma^2} = -\frac{\mathbf{z}}{\sigma} \end{aligned}$$

As a result, the loss function of denoising score matching becomes:
$$\begin{aligned} J_{\text{DSM}}(\theta) &= \mathbb{E}_{q(\mathbf{x}, \mathbf{x}')}\left[\frac{1}{2}\left\|\mathbf{s}_{\theta}(\mathbf{x}) - \nabla_{\mathbf{x}} \log q(\mathbf{x} \mid \mathbf{x}')\right\|^2\right] \\ &= \mathbb{E}_{q(\mathbf{x}')}\left[\frac{1}{2}\left\|\mathbf{s}_{\theta}(\mathbf{x}' + \sigma\mathbf{z}) + \frac{\mathbf{z}}{\sigma}\right\|^2\right] \end{aligned}$$

Replacing the dummy variable $\mathbf{x}'$ by $\mathbf{x}$, and noting that sampling from $q(\mathbf{x})$ can be replaced by sampling from $p(\mathbf{x})$ given a training dataset, we conclude with the denoising score-matching loss:
$$J_{\text{DSM}}(\theta) = \mathbb{E}_{p(\mathbf{x})}\left[\frac{1}{2}\left\|\mathbf{s}_{\theta}(\mathbf{x} + \sigma\mathbf{z}) + \frac{\mathbf{z}}{\sigma}\right\|^2\right]$$

**Remark**: The above loss function is highly interpretable. The quantity $\mathbf{x} + \sigma\mathbf{z}$ is a clean image $\mathbf{x}$ corrupted by the noise $\sigma\mathbf{z}$. The score network takes this noisy image and predicts the negatively scaled noise $-\frac{\mathbf{z}}{\sigma}$.
The training procedure is simple to describe. Given a training dataset $\{\mathbf{x}^{(l)}\}_{l=1}^{L}$, we train a network $\mathbf{s}_{\theta}$ with the goal:
$$\theta^* = \argmin_{\theta} \frac{1}{L}\sum_{l=1}^{L}\frac{1}{2}\left\|\mathbf{s}_{\theta}\!\left(\mathbf{x}^{(l)} + \sigma\mathbf{z}^{(l)}\right) + \frac{\mathbf{z}^{(l)}}{\sigma}\right\|^2, \qquad \mathbf{z}^{(l)} \sim \mathcal{N}(0, \mathbf{I})$$
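As an end-to-end sketch of this objective (a toy setup, not the paper's: the "network" is just a linear model $s_\theta(y) = a\,y + b$ trained by SGD with hand-coded gradients, and all constants are illustrative). With clean data $x \sim \mathcal{N}(\mu, s^2)$ and noise level $\sigma$, the noisy input is distributed as $\mathcal{N}(\mu, s^2 + \sigma^2)$, whose true score is $-(y - \mu)/(s^2 + \sigma^2)$, and DSM recovers exactly that line:

```python
import numpy as np

rng = np.random.default_rng(0)
mu, s, sigma = 1.0, 0.5, 0.3
X = mu + s * rng.standard_normal(100_000)   # "training set" drawn from p(x)

a, b, lr = 0.0, 0.0, 0.05                   # linear score model s(y) = a*y + b
for _ in range(2000):
    x = rng.choice(X, size=256)             # minibatch of clean samples
    z = rng.standard_normal(256)
    y = x + sigma * z                       # noisy input x + sigma*z
    resid = a * y + b + z / sigma           # s_theta(y) + z/sigma from the loss
    a -= lr * np.mean(resid * y)            # gradient of (1/2)*resid^2 w.r.t. a
    b -= lr * np.mean(resid)                # gradient of (1/2)*resid^2 w.r.t. b

print(a, b)  # close to -1/(s^2 + sigma^2) and mu/(s^2 + sigma^2)
```

The fitted slope and intercept match the score of the *noisy* data distribution, which is what DSM targets; as $\sigma \to 0$ this approaches the score of $p(x)$ itself.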