The other day I went to the cinema and watched M3GAN, a true movie masterpiece[1] about the death and carnage that ensues when you simply train your extremely complex ML model and don’t do proper ethics work. And that, of course, made me want to write a little bit about something relatively hip, hop, and happening[2] in the ML/AI space. But, like, I’m not gonna be that on trend[3] because fuck that noise, so I’m gonna talk about diffusion models.
It’s worth noting that I know bugger all about diffusion models. But when they first came out, I had a quick look at how they worked and then promptly forgot about them because, let’s face it, I work on different things. But hey. If that’s not enough[4] knowledge to write a blog post, I don’t know what is.
And here’s the thing. Most of the time when I blog about something I know a lot about it. Sometimes too much. But this is not one of those times. There are plenty of resources on the internet if you want to learn about diffusion models from an expert. Oodles. But where else but here can you read the barely proof-read writing of a man who read a couple of papers yesterday?
And who doesn’t want[5] that?
A prelude: Measure transport for sampling from arbitrary distributions
One of the fundamental tasks in computational statistics is to sample from a probability distribution. There are millions of ways of doing this, but the most popular generic method is Markov chain Monte Carlo. But this is not the post about MCMC methods. I’ve already made a post about MCMC methods.
Instead, let’s focus on stranger ways to do it. In particular, let’s think about methods that create a mapping \(T: \mathbb{R}^d \rightarrow \mathbb{R}^d\), possibly depending on some properties of the target distribution, such that the following procedure constructs a sample \(x \sim p(x)\):
- Sample \(u \sim q(u)\) for some known distribution \(q\)
- Set \(x = T(u)\)
The general problem of starting with a distribution \(q(\cdot)\) and mapping it to another distribution \(p(\cdot)\) is an example of a problem known as measure transport. Transport problems have been studied by mathematicians for yonks. It turns out that there are an infinite number of mappings \(T\) that will do the job, so it’s up to us to choose a good one.
Probably the most famous[6] transport problem is the optimal transport problem, first studied by Monge and Kantorovich, which tries to find a mapping \(T\) that minimises \[ \mathbb{E}_{x \sim q}(c(x, T(x))) \] subject to the constraint that \(T(x) \sim p\) whenever \(x \sim q\), where \(c(x,y)\) is some sort of cost function. There are canonical choices of cost function, but for the most part we are free to choose something that is convenient.
The measure transport concept underlies the method of normalising flows, but the presentation that I’m most familiar with is due to Youssef Marzouk and his collaborators in 2011 and predates the big sexy normalising flow papers by a few years.
Continuous distributions in 1D
If \(p\) and \(q\) are both continuous, univariate distributions, it is pretty easy to construct a transport map. In particular, if \(F_p\) is the cumulative distribution function of \(p\), then \[ T(x) = F_p^{-1}(F_q(x)) \] is a transport map. This works because, if \(x \sim q\), then \(F_q(x) \sim \text{Unif}(0,1)\). From this, we can use everyone’s favourite result that you can sample from a continuous univariate random variable \(p\) by evaluating the quantile function at a uniform random value.
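Just to convince ourselves that this works, here’s a quick numpy/scipy sketch. The choices of \(q\) (a standard normal) and \(p\) (a Gamma) are mine, picked purely because scipy hands us both CDFs and quantile functions on a plate.

```python
import numpy as np
from scipy import stats

# Reference q = N(0, 1) and target p = Gamma(2); both CDFs are known here.
q, p = stats.norm(), stats.gamma(a=2.0)

rng = np.random.default_rng(30)
u = q.rvs(size=100_000, random_state=rng)

# The transport map T(u) = F_p^{-1}(F_q(u)).
x = p.ppf(q.cdf(u))

# The pushed-forward sample should be indistinguishable from the target.
print(stats.kstest(x, p.cdf))
```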
There are, of course, two problems with this: it only works in one dimension and we usually don’t know \(F_p^{-1}\) explicitly.
The second of these isn’t really a problem if we are willing to do something splendifferously dumb. And I am. Because I’m gay and frivolous[7].
If I write \(Q(t) = F^{-1}(t)\) then I can differentiate this to get \[ \frac{dQ}{dt} = \frac{1}{p(Q)},\qquad Q(0) = -\infty. \] This is a very non-linear differential equation. We can make it even more non-linear by repeating the procedure to get \[ \frac{d^2Q}{dt^2} = -\frac{1}{p(Q)^2} p'(Q)\frac{dQ}{dt}. \] Noting that \(Q' = 1/p(Q)\) we get \[ \frac{d^2 Q}{dt^2} = -\frac{p'(Q)}{p(Q)} \left(\frac{dQ}{dt}\right)^2. \] This is a rubbish differential equation, but it has the singular advantage that it doesn’t depend[8] on the normalising constant for \(p\), which can be useful. The downside is that the boundary conditions are infinite on both ends.
Regardless of that particular challenge, we could use this to build a generic algorithm.
- Sample \(u \sim \text{Unif}(0,1)\)
- Use a numerical differential equation solver to solve the equation with boundary conditions \[ Q(0) = -M, \quad Q(1) = M \] for some sufficiently large number \(M\) and return \(x = Q(u)\)
This will sample from \(p(x)\) truncated to \([-M, M]\).
I was going to write some python code to do this, but honestly it hurts my soul. So I shan’t.
Transport maps: A less terrible method that works on general densities
Outside of one dimension, there is (to the best of my knowledge) no direct solution to the transport problem. That means that we need to construct one ourselves. Thankfully, the glorious Youssef Marzouk and a bunch of his collaborators have spent some quality time mapping out this idea. A really nice survey of their results can be found in this paper.
Essentially the idea is that we can try to find the most convenient transport map available to us. In particular, it’s useful to minimise the Kullback-Leibler divergence between the transported \(q\) and the target. After a little bit[9] of maths, this turns out to be equivalent to maximising \[ \mathbb{E}_{x \sim q}\left(\log p(T(x)) + \log \det \nabla T(x)\right), \] where \(\nabla T(x)\) is the Jacobian of \(T\). To finish the specification of the optimisation problem, it’s enough to consider triangular maps[10] \[ T(x) = \begin{pmatrix} T_1(x_1) \\ T_2(x_1,x_2) \\ \vdots \\ T_d(x_1, \ldots, x_d) \end{pmatrix} \] with the additional constraint that their Jacobians have positive determinants. Using a triangular map has two distinct advantages: it’s parsimonious and it makes the positive determinant constraint much easier to deal with. Triangular maps are also sufficient for the problem (my man Bogachev showed it in 2005).
That said, this can be a somewhat tricky optimisation problem. Youssef and his friends have spilt a lot of ink on this topic. And if you’re the sort of person who just fucking loves a weird optimisation problem, I’m sure you’ve got thoughts. With and without the triangular constraint, this can be parameterised as the composition of a sequence of simple functions, in which case you turn three times and scream neural net and a normalising[11] flow appears.
What if we only have samples from the target density?
All of that is very lovely. And quite nice in its context. But what happens if you don’t actually have access to the (unnormalised) log density of the target? What if you only have samples?
The good news is that you’re not shit out of luck. But it’s a bit tricky. And once again, that lovely review paper by Youssef and friends will tell us how to do it.
In particular, they noticed that if you swap the direction of the KL divergence, you get the optimisation problem for the inverse mapping \(S(x) = T^{-1}(x)\) that aims to maximise \[ \mathbb{E}_{x \sim p}\left(\log q(S(x)) + \log \det \nabla S(x)\right) \] where \(S\) is once again a triangular map subject to the monotonicity constraints \[ \frac{\partial S_k}{\partial x_k} > 0. \] Because we have the freedom to choose the reference density \(q(x)\), we can choose it to be iid standard normals, in which case we get the optimisation problem \[\begin{align*} \min_S\quad &\mathbb{E}_{x \sim p}\left[\sum_{k = 1}^d \frac{1}{2}\left(S_k(x_1, \ldots, x_k)\right)^2 - \log \frac{\partial S_k}{\partial x_k} \right]\\ \text{s.t.}\quad &\frac{\partial S_k}{\partial x_k} >0 \\ &S \text{ is triangular}, \end{align*}\] which is a convex, separable optimisation problem that can be solved[12] using, for instance, a stochastic gradient method. This can be turned into an unconstrained optimisation problem by explicitly parameterising the monotonicity constraint.
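To make that a bit less abstract, here’s a minimal sketch of the sample-only version. Everything in it is my own invention for illustration: the target is a correlated Gaussian, the map is linear with an exp-parameterised diagonal (Marzouk and friends use much richer function classes), and I’ve lazily handed the empirical objective to a generic optimiser rather than doing anything stochastic.

```python
import numpy as np
from scipy.optimize import minimize

rng = np.random.default_rng(30)

# Samples from the "unknown" target p: x = A @ eps with eps ~ N(0, I).
A = np.array([[1.0, 0.0], [1.5, 0.5]])
xs = rng.standard_normal((5_000, 2)) @ A.T

# Triangular map with positive diagonal via exp:
#   S_1(x_1) = exp(c1) * x_1 + a1
#   S_2(x_1, x_2) = b * x_1 + exp(c2) * x_2 + a2
def objective(theta):
    a1, a2, b, c1, c2 = theta
    s1 = np.exp(c1) * xs[:, 0] + a1
    s2 = b * xs[:, 0] + np.exp(c2) * xs[:, 1] + a2
    # Empirical E[0.5 * S_k^2 - log dS_k/dx_k], summed over k = 1, 2.
    return np.mean(0.5 * s1**2 + 0.5 * s2**2) - c1 - c2

a1, a2, b, c1, c2 = minimize(objective, np.zeros(5)).x

# S should (approximately) whiten the sample: S(x) ~ N(0, I).
z = np.column_stack([np.exp(c1) * xs[:, 0] + a1,
                     b * xs[:, 0] + np.exp(c2) * xs[:, 1] + a2])
print(np.round(z.mean(axis=0), 2))
print(np.round(np.cov(z.T), 2))
```

For this toy target the optimum is \(S = A^{-1}\), so you can check the fitted coefficients against \(\exp(c_1) \approx 1\), \(b \approx -3\), and \(\exp(c_2) \approx 2\).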
The monotonicity of \(S\) makes the resulting nonlinear solve to compute \(T = S^{-1}\) relatively straightforward. In fact, if \(d\) isn’t too big you can solve this sequentially dimension-by-dimension, as in the sketch below. But, of course, when you’ve got a lot of parameters this is a poor method and it would make more sense[13] to attack it with some sort of gradient descent method. It might even be worth taking the time to learn the inverse function \(T = S^{-1}\) so that it can be applied for, essentially, free.
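Here’s what that dimension-by-dimension solve looks like, again as a sketch of my own: because \(S_k\) is increasing in its last argument, each coordinate is a one-dimensional root-find. The bracketing interval is a hypothetical choice you’d want to set more carefully in real life.

```python
import numpy as np
from scipy.optimize import brentq

def invert_triangular(S, u, lo=-50.0, hi=50.0):
    """Solve S(x) = u coordinate-by-coordinate for a monotone
    lower-triangular map. S[k] takes (x_1, ..., x_{k+1}) and is
    increasing in its final argument."""
    x = []
    for S_k, u_k in zip(S, u):
        # Monotonicity in the last argument makes each step a 1D root-find.
        x.append(brentq(lambda t: S_k(*x, t) - u_k, lo, hi))
    return np.array(x)

# The (hypothetical) whitening map fitted above: S = A^{-1}.
S = [lambda x1: x1,
     lambda x1, x2: -3.0 * x1 + 2.0 * x2]
print(invert_triangular(S, u=np.array([0.3, -1.2])))
```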
So does it work?
To some extent, the answer is yes. This is very much normalising flows in their most embryonic form: they work, to some extent. And this presentation makes some of the problems fairly obvious:
- There’s no real guarantee that \(T\) is going to be a nice smooth map, which means that we may have problems moving beyond the training sample.
- The most natural way to organise the computations is sequentially, with sweeps across the \(d\) parameters. This can be difficult to parallelise efficiently on modern architectures.
- The complexity of the triangular map is going to depend on the order of the variables. This is fine if you’re processing something that is inherently sequential, but if you’re working with image data, this can be challenging.
Of course, there are a pile of ways that these problems can be overcome in whole or in part. I’d point you to the last five years of ML conference papers. You’re welcome.
Continuous normalising flows: Making the problem easier by making it harder
A really clever idea, which is related to normalising flows, is to ask what if, instead of looking for a single[14] map \(S(x) = T^{-1}(x)\), we tried to find a sequence of maps \(S(x,t)\) that smoothly move from the identity map to the transport map.
This seems like it would be a harder problem. And it is. You need to make an infinite number of maps. But the saving grace is that as \(t\) changes slightly, the map \(S(\cdot, t)\) is also only going to change slightly. This means that we can parameterise the change relatively simply.
To this end, we write \[ \frac{\partial S}{\partial t} = f(S, t), \] for some relatively simple function \(f\) that models the infinitesimal change in the transport map as we move along the path. The hope is that learning the vector field \(f\) will be easier than learning \(S\) directly. To finish the specification, we require that \[ S(x,0) = x. \]
The question is: can we learn the function \(f\) from data? If we can, it will be (relatively) easy to evaluate the transport map for any sample by just solving[15] the differential equation.
It turns out that the map \(S\) is most useful for training the normalising flow, while \(T\) is useful for generating samples from the trained model. If we were using the methods in the previous section, we would have had to commit to either modelling \(S\) or \(T\). One of the real advantages of the continuous formulation is that we can just as easily impose the terminal condition[16] \[ S(x,1) = u \] and solve the equation backwards in time to calculate \(T(u) = S(x, 0)\)! The dynamics of both equations are driven by the vector field \(f\)!
A very quick introduction to inverse problems
It turns out that learning the parameters of differential equations (and other physical models) has a long and storied history in applied mathematics under the name of inverse problems. If that sounds like statistics, you’d be right. It’s statistics, except with no interest in measurement or, classically, uncertainty.
The classic inverse problem framing involves a forward map \(\mathcal{F}(f)(t, x)\) that takes as its input some parameters (often a function) and returns the full state of a system (often another function). For instance, the forward map could be the solution of a differential equation like \[ \frac{\partial S}{\partial t} = f(S, t), \qquad S(0) = x. \] The thing that you should notice about this is that the forward map is a) possibly expensive to compute, b) not explicitly known, and c) extremely[17] non-linear.
The problem is specified by \(n\) observations \(y_1, \ldots, y_n\) taken at the points \((t_1, x_1), \ldots, (t_n, x_n)\), and the aim is to find the value of \(f\) that best fits the data. The traditional choice is to minimise the mean-square error \[ \hat{f} = \arg \min_f \sum_{i=1}^n \left(y_i - \mathcal{F}(f)(t_i,x_i)\right)^2. \]
Now every single one of you will know immediately that this question is both vague and ill-posed. There are many functions \(f\) that will fit the data. This means that we need to enforce[18] some sort of complexity penalty on \(f\). This leads to the method known as Tikhonov regularisation[19] \[ \hat{f} = \arg \min_{f \in B} \sum_{i=1}^n \left(y_i - \mathcal{F}(f)(t_i,x_i)\right)^2 + \lambda\|f\|_B^2, \] where \(B\) is some Banach space and \(\lambda>0\) is some tuning parameter.
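In its simplest incarnation, where the forward map is linear and \(B\) is an \(\ell^2\) space, Tikhonov regularisation is just ridge regression. A tiny sketch, with a made-up smoothing operator standing in for the expensive forward map:

```python
import numpy as np

rng = np.random.default_rng(30)

# A linear forward map F: a discretised Gaussian blur, the classic
# ill-posed toy. Real forward maps are non-linear and expensive.
n, d = 40, 200
t = np.linspace(0, 1, n)[:, None]
s = np.linspace(0, 1, d)[None, :]
F = np.exp(-50.0 * (t - s) ** 2) / d

f_true = np.sin(2 * np.pi * s.ravel()) * (s.ravel() > 0.5)
y = F @ f_true + 1e-4 * rng.standard_normal(n)

# Tikhonov: minimise ||y - F f||^2 + lam * ||f||^2, which has the
# closed form (F^T F + lam I)^{-1} F^T y.
lam = 1e-6
f_hat = np.linalg.solve(F.T @ F + lam * np.eye(d), F.T @ y)
print(np.linalg.norm(f_hat - f_true) / np.linalg.norm(f_true))
```

Cranking \(\lambda\) up or down trades data fit against wildness in \(\hat{f}\), which is the entire game.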
As you can imagine, there’s a lot of maths under this about when there is a unique minimum, how the reconstruction behaves as \(n\rightarrow \infty\) and \(\lambda \rightarrow 0\), and how the choice of \(B\) affects the estimation of \(f\). There is also quite a lot of work[20] looking at how to actually solve these sorts of optimisation problems.
Eventually, the field evolved and people started to realise that it’s actually fairly important to quantify the uncertainty in the estimate. This is … tricky under the Tikhonov regularisation framework, which became a big motivation for Bayesian inverse problems.
As with all Bayesianifications, we just need to turn the above into a likelihood and a prior. Easy. Well, the likelihood part, at least, is easy. If we want to line up with Tikhonov regularisation, we can choose a Gaussian likelihood \[ y_i \mid f, x_i, t_i, \sigma \sim N(\mathcal{F}(f)(t_i,x_i), \sigma^2). \]
This is familiar to statisticians: the forward model is essentially working as a non-standard link function in a generalised linear model. There are two big practical differences. The first one is that \(\mathcal{F}\) is very non-linear and almost certainly not monotone. The second problem is that evaluations of \(\mathcal{F}\) are typically very[21] expensive. For instance, you may need to solve a system of differential equations. This means that any computational method[22] is going to need to minimise the number of likelihood evaluations.
The choice of prior on \(f\) can, however, be a bit tricky. The problem is that in most traditional inverse problems \(f\) is a function[23] and so we need to put a carefully specified prior on it. And there is a lot of really interesting work on what this means in a Bayesian setting. This is really the topic for another blogpost, but it’s certainly an area where you need to be aware of the limitations of different high-dimensional priors and how they perform in various contexts. For instance, if the function you are trying to reconstruct is likely to have a lot of sharp boundaries[24] then you need to make sure that your prior can support functions with sharp boundaries. My little soldier bois[25] don’t, so you need to get more[26] creative.
The likelihood for a normalising flow
Our aim now is to cast the normalising flow idea into the inverse problems framework. To do this, we remember that we begin our flow from a sample from \(p(x)\) and we then deform it until it becomes a sample from \(q(u)\) at some known time (which I’m going to choose as \(t=1\)). This means that if \(x_i \sim p\), then \[ S(x_i, 1) \sim q. \]
We can now derive a relationship between \(p\) and \(q\) using the change of variables formula. In particular, \[ p(x \mid f) = q(S(x,1))\left|\det \left( \frac{d S(x,1)}{dx }\right)\right|, \] which means that our log likelihood will be \[ \log p(x \mid f) = \log q(S(x,1)) + \log \left|\det \left( \frac{d S(x,1)}{dx }\right)\right|. \]
The log-determinant term looks like it might cause some trouble. If \(S\) is parameterised as a triangular map it can be written explicitly, but there is, of course, another route.
For notational ease, let’s consider \(z_t = S(x, t)\), for some \(t <1\). Then \[ \log p(z_t \mid f) = \log q(S(x,1)) + \log \left|\det \left( \frac{d S(x,t)}{dx }\right)\right|. \] We can differentiate this with respect to \(t\)[27] to get \[ \frac{\partial \log p(z_t \mid f)}{\partial t} = \operatorname{tr}\left(\frac{df}{dx}(z_t,t)\right), \] where I used one of those magical vector calculus identities to get that trace. Remembering that \(S(x,0) = x\), the log-determinant of the Jacobian at \(t=0\) is zero and so we get the initial condition \[ \log p(z_0 \mid f) = \log q(S(x,1)). \]
The likelihood can be evaluated[28] by solving the system of differential equations \[\begin{align*} \frac{d z_t}{dt} &= f(z_t, t) \\ \frac{d \ell}{dt} &=\operatorname{tr}\left(\frac{df}{dx}(z_t,t)\right) \\ z_0 &= x \\ \ell(0) &= 0, \end{align*}\] and the log likelihood is evaluated as \[ \log p(x \mid f) = \log q(z_1) + \ell(1). \]
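Here’s that calculation as code. The vector field is a deliberately boring one I made up, \(f(z,t) = -z\) in one dimension, because then the answer is available by hand: \(S(x,1) = xe^{-1}\) and \(\ell(1) = -1\), so \(p\) is \(N(0, e^2)\) and we can check the ODE-based log likelihood against it.

```python
import numpy as np
from scipy.integrate import solve_ivp
from scipy.stats import norm

# A made-up vector field f(z, t) = -z, for which tr(df/dz) = -1.
def f(z, t):
    return -z

def trace_df_dz(z, t):
    return -1.0

def log_lik(x):
    # Augmented state (z_t, ell_t): dz/dt = f, d(ell)/dt = tr(df/dz).
    def rhs(t, state):
        z, ell = state
        return [f(z, t), trace_df_dz(z, t)]

    sol = solve_ivp(rhs, (0.0, 1.0), [x, 0.0], rtol=1e-8, atol=1e-10)
    z1, ell1 = sol.y[:, -1]
    return norm.logpdf(z1) + ell1   # log q(z_1) + ell(1)

# Check against the exact answer: p = N(0, e^2).
print(log_lik(1.3), norm.logpdf(1.3, scale=np.e))
```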
It turns out that you can take gradients of the log-likelihood efficiently by solving an augmented system of differential equations that’s twice the size of the original. This allows for all kinds of gradient-driven inferential shenanigans.
But oh that complexity
One big problem with normalising flows as written is that we only have two pieces of information about the entire trajectory \(z_t\):
- we know that \(z_1 \sim q\), and
- we know that \(z_0 \sim p\).
We know absolutely nothing about \(z_t\) outside of those boundary conditions. This means that our model for \(f\) basically gets to freestyle in those areas.
We can avoid this to some extent by choosing appropriate neural network architectures and/or appropriate penalties in the classical case or priors in the Bayesian case. There’s a whole mini-literature on choosing appropriate penalties.
Just to show how complex it is, let me quickly sketch what Finlay et al. suggest as a way to keep the dynamics as boring as possible in the information desert. They lean into the literature on optimal transport theory to come up with the double penalty \[ \min_f \sum_{i=1}^n \left(-\log p(x_i) + \lambda_1 \int_0^T \|f(S(x_i,s),s)\|_2^2\,ds + \lambda_2 \int_0^T\left\|\nabla f(S(x_i,s),s)\right\|_F^2\,ds\right), \] where the first term minimises the kinetic energy and, essentially, finds the least exciting path from \(p\) to \(q\), while the second term ensures that the Jacobian of \(f\) doesn’t get too big[29], which means that the mapping doesn’t have many sharp changes. Both of these penalty terms are designed both to aid generalisation and to make sure the differential equation isn’t unnecessarily difficult for an ODE solver.
A slightly odd feature of these penalties is that they are both data dependent. That suggests that a prior would, probably, require an amount of work. This is work that I don’t feel like doing today. Especially because this blog post isn’t about bloody normalising flows.
Diffusion models
Ok, so normalising flows are cool, but there are a couple of places where they could potentially be improved. There is a long literature on diffusion models, but the one I’m mostly stealing from is this one by Song et al. (2021).
Firstly, the vector field \(f\) directly affects how easy the differential equations are to solve. This means that if \(f\) is too complicated, it can take a long time to both train the model and generate samples from the trained model. To get around this you need to put fairly strict penalties[30] and/or structural assumptions on \(f\).
Secondly, we only have information[31] at the two ends of the flow. The problem would become a lot easier if we could somehow get information about intermediate states. In the inverse problems literature, there’s a concept of value of information that talks about how useful sampling a particular time point can be in terms of reducing model uncertainty. In general this, or other criteria, can be used to design a set of useful sampling times. I don’t particularly feel like working any of this out but one thing I am fairly certain of is that no optimal design would only have information at \(t=0\) and \(t=1\)!
Diffusion models fix these two aspects of normalising flows at the cost of both a more complex mathematical formulation and some inexactness[32] around the base distribution \(q\) when generating new samples.
Diffusions and stochastic differential equations
Diffusions are to applied mathematicians what gaffer tape is to[33] a roadie. They are ubiquitous and convenient, and they hold down the fort when nothing else works.
There are a number of diffusions that are familiar in statistics and machine learning. The most famous one is probably the Langevin diffusion \[ dX_t = \frac{\sigma^2}{2}\nabla \log p(X_t)\, dt + \sigma\, dW_t, \] which is asymptotically distributed according to \(p\). This forms the basis of a bunch of MCMC methods as well as some faster, less adjusted methods.
But that is not the only diffusion. Today’s friend is the Ornstein-Uhlenbeck (OU) process, which is the Gaussian process satisfying \[ dX_t = - \frac{1}{2} X_t \,dt + \sigma dW_t. \] The OU process can be thought of as a mean-reverting Brownian motion. As such, it has continuous but nowhere differentiable sample paths.
The stationary distribution of \(X_t\) is \(X_\infty \sim N(0, \sigma^2I)\), where \(I\) is the identity matrix. In fact, if we start the diffusion at stationarity by setting \[ X_0 \sim N(0, \sigma^2I), \] then \(X_t\) is a stationary Gaussian process with covariance function \[ c(t, t') = \sigma^2e^{-\frac{1}{2} |t-t'|}I. \]
More interesting in our context, however, is what happens if we start the diffusion from a fixed point \(x\), which will eventually be a sample from \(p(x)\). In that case, we can solve the linear stochastic differential equation exactly to get \[ X_t = xe^{-\frac{1}{2}t} + \sigma \int_0^t e^{\frac{1}{2}(s-t)}\,dW_s, \] where the integral on the right hand side can be interpreted[34] as a white noise integral and so \[ X_t \sim N\left(xe^{-\frac{1}{2}t}, \sigma^2\int_0^t e^{s-t}\,ds\right), \] and the variance is \[ \sigma^2\int_0^t e^{s-t}\,ds = \sigma^2 e^{-t}\left(e^{t} - 1\right) = \sigma^2(1-e^{-t}). \] From these equations, we see that the mean of the diffusion hurtles exponentially fast towards zero and the variance towards \(\sigma^2\).
More importantly, this means that, given a starting point \(X_0 = x\), we can generate data from any part of the diffusion \(X_t\)! If we want a sequence of observations from the same trajectory, we can generate them sequentially using the fact that an OU process is a Markov[35] process. This means that we are no longer limited to information at just two points along the trajectory.
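In code, that exactness looks like this: no SDE solver, just the Gaussian transition density we derived above, shifted to start at time \(s\). The specific \(x_0\), \(\sigma\), and times are arbitrary choices of mine.

```python
import numpy as np

rng = np.random.default_rng(30)

def ou_trajectory(x0, ts, sigma=1.0):
    """Exact samples from one OU path at increasing times ts, using
    X_t | X_s = x ~ N(x e^{-(t-s)/2}, sigma^2 (1 - e^{-(t-s)}))."""
    x, prev_t, out = x0, 0.0, []
    for t in ts:
        dt = t - prev_t
        mean = x * np.exp(-0.5 * dt)
        sd = sigma * np.sqrt(1.0 - np.exp(-dt))
        x = mean + sd * rng.standard_normal()
        out.append(x)
        prev_t = t
    return np.array(out)

# Noise up a data point at whatever times we fancy.
print(ou_trajectory(x0=3.0, ts=[0.1, 0.5, 1.0, 4.0]))
```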
Reversing the diffusion
So far, there is nothing to learn here. The OU process has a known drift and variance, so everything is splendid. It’s even easy to simulate from. The challenge pops up when we try to reverse the diffusion, that is, when we try to remove noise from a sample rather than add noise to it.
In some sense, this shouldn’t be too disgusting. A diffusion is a Markov process and, if we run the Markov process back in time, we still get a Markov process. In fact, we are going to get another diffusion process.
The twist is that the new diffusion process is going to be quite a bit more complex than the original one. The problem is that unless \(X_0\) comes from a Gaussian distribution, this process will be non-Gaussian, and thus somewhat tricky to find the reverse trajectory of.
To see this, consider \(s>t\) and recall that \[ p(X_0, X_t, X_s) = p(X_s \mid X_t)p(X_t \mid X_0)p(X_0) \] and \[ p(X_t, X_s) = \int_{\mathbb{R}^d} p(X_s \mid X_t) p(X_t \mid X_0) p(X_0)\,dX_0. \] The first two terms in that integrand are Gaussian densities and thus their product is a bivariate Gaussian density \[ X_t, X_s \mid X_0 \sim N\left(X_0\begin{pmatrix}e^{-\frac{t}{2}}\\e^{-\frac{s}{2}}\end{pmatrix}, \sigma^2 \begin{pmatrix} 1 - e^{-t} & e^{-\frac{s-t}{2}} - e^{-\frac{s+t}{2}} \\ e^{-\frac{s-t}{2}} - e^{-\frac{s+t}{2}} & 1 - e^{-s}\end{pmatrix}\right). \] Unfortunately, as \(X_0\) is not Gaussian, the marginal distribution will be non-Gaussian. This means that our reverse time transition density \[ p(X_t \mid X_s) = \frac{ \int_{\mathbb{R}^d} p(X_t,X_s \mid X_0) p(X_0)\,dX_0}{ \int_{\mathbb{R}^d} p(X_s \mid X_0) p(X_0)\,dX_0} \] is also going to be very non-Gaussian.
In order to work out a stochastic differential equation that runs backwards in time and generates the same trajectory, we need a little bit of theory on how the unconditional density \(p(X_t)\) and the transition density \(p(X_t \mid X_s)\) evolve in time \(t\) (here and everywhere \(s > t\)). These are related through the Kolmogorov equations.
To introduce these, we need to briefly consider the more general diffusion \[ dX_t = f(X_t, t)dt + g(X_t,t)dW_t \] for nice[36] vector/matrix-valued functions \(f\) and \(g\). Kolmogorov showed that the unconditional density \(p(X_t) = p(x,t)\) evolves according to the partial differential equation \[ \frac{\partial p(x,t)}{\partial t} = - \sum_{i=1}^d \frac{\partial}{\partial x_i}\left(f_i(x,t)p(x,t)\right) + \frac{1}{2}\sum_{i,j,k = 1}^d\frac{\partial^2}{\partial x_i \partial x_j}\left( g_{ik}(x,t)g_{jk}(x,t)p(x,t)\right) \] subject to the initial condition \[ p(x,0) =p(x). \] This is known as Kolmogorov’s forward equation or the Fokker-Planck equation.
The other key result is about the transition density into some future value \(X_s = u\), \(s \geq t\). We write this density as \(q(x,t; u,s) = p(X_s = u\mid X_t = x)\) and it satisfies the partial differential equation \[ \frac{\partial q(x,t;u,s)}{\partial t} = -\sum_{i=1}^d f_i(x,t)\frac{\partial q(x,t;u,s)}{\partial x_i} - \frac{1}{2}\sum_{i,j,k=1}^d g_{ik}(x,t)g_{jk}(x,t)\frac{\partial^2 q(x,t;u,s)}{\partial x_i\partial x_j} \] subject to the terminal condition \[ q(x,s;u,s) = \delta(x - u). \] This is known as the Kolmogorov backward equation. Great names. Beautiful names.
Let’s consider a differential equation for the joint density \[ p(X_t = x, X_s= u) = p(x,t,u,s) = q(x,t;u,s)p(x,t). \] Going ham with the product rule gives \[ \begin{align*} \frac{\partial p(x,t,u,s)}{\partial t} &= p(x, t)\frac{\partial q(x,t;u,s)}{\partial t} + q(x,t;u,s) \frac{\partial p(x,t)}{\partial t} \\ &=-\sum_{i=1}^d p(x,t)f_i(x,t)\frac{\partial q(x,t;u,s)}{\partial x_i} - \frac{1}{2}\sum_{ijk} p(x,t)g_{ik}(x,t)g_{jk}(x,t) \frac{\partial^2 q(x,t;u,s)}{\partial x_i \partial x_j} \\ &\qquad-\sum_{i=1}^dq(x,t;u,s)\frac{\partial}{\partial x_i}(p(x,t)f_i(x,t)) + \frac{1}{2} \sum_{ijk}q(x,t;u,s)\frac{\partial^2}{\partial x_i \partial x_j}(g_{ik}(x,t)g_{jk}(x,t)p(x,t)) . \end{align*} \tag{1}\] The first-order derivatives simplify, using the product rule, to \[ -\sum_{i=1}^d\frac{\partial}{\partial x_i}(p(x,t,u,s)f_i(x,t)). \]
Staring at this for a moment, we notice that it has the same structure as the first-order term in the forward equation. In that case, the second-order term would be \[ \begin{align*} &\frac{1}{2}\sum_{i,j,k=1}^d\frac{\partial^2}{\partial x_i \partial x_j}[p(x,t,u,s) g_{ik}(x,t)g_{jk}(x,t)] = \frac{1}{2}\sum_{i,j,k=1}^d\frac{\partial^2}{\partial x_i \partial x_j}[q(x,t;u,s) (p(x,t)g_{ik}(x,t)g_{jk}(x,t))] \\ &\qquad\qquad=\frac{1}{2}\sum_{i,j,k=1}^d\frac{\partial}{\partial x_i}\left[ q(x,t;u,s)\frac{\partial }{\partial x_j}\left(p(x,t)g_{ik}(x,t)g_{jk}(x,t)\right) + p(x,t)g_{ik}(x,t)g_{jk}(x,t) \frac{\partial q(x,t;u,s)}{\partial x_j}\right] \end{align*} \]
If we notice that \[ \begin{align*} \frac{\partial}{\partial x_i}\left[p(x,t)g_{ik}(x,t)g_{jk}(x,t) \frac{\partial q(x,t;u,s)}{\partial x_j}\right] =& p(x,t)g_{ik}(x,t)g_{jk}(x,t) \frac{\partial^2 q(x,t;u,s)}{\partial x_i \partial x_j} \\ &\quad+ \frac{\partial}{\partial x_i} \left[p(x,t)g_{ik}(x,t)g_{jk}(x,t)\right]\left[ \frac{\partial q(x,t;u,s)}{ \partial x_j}\right] \end{align*} \] and \[ \begin{align*} \frac{\partial}{\partial x_i}\left[ q(x,t;u,s)\frac{\partial }{\partial x_j}\left(p(x,t)g_{ik}(x,t)g_{jk}(x,t)\right)\right] =& q(x,t;u,s)\frac{\partial^2 }{\partial x_i \partial x_j} \left[p(x,t)g_{ik}(x,t)g_{jk}(x,t)\right] \\ &\quad+ \frac{\partial}{\partial x_i} \left[p(x,t)g_{ik}(x,t)g_{jk}(x,t)\right]\left[ \frac{\partial q(x,t;u,s)}{ \partial x_j}\right], \end{align*} \] we can re-write the second-order derivative terms in Equation 1 as \[ \frac{1}{2}\sum_{i,j,k=1}^d\frac{\partial}{\partial x_i}\left[ q(x,t;u,s)\frac{\partial }{\partial x_j}\left(p(x,t)g_{ik}(x,t)g_{jk}(x,t)\right) - p(x,t)g_{ik}(x,t)g_{jk}(x,t) \frac{\partial q(x,t;u,s)}{\partial x_j}\right]. \]
This is almost, but not quite, what we want. We are a single minus sign away. Remembering that \(q(x,t;u,s) = p(x,t,u,s)/p(x,t)\), we probably don’t want it to turn up in any derivatives[37]. To this end, let’s make the substitution \[ \begin{align*} \frac{1}{2}\sum_{i,j,k=1}^d\frac{\partial}{\partial x_i}\left[ p(x,t)g_{ik}(x,t)g_{jk}(x,t) \frac{\partial q(x,t;u,s)}{\partial x_j}\right] =& \frac{1}{2}\sum_{i,j,k=1}^d\frac{\partial^2}{\partial x_i\partial x_j}[p(x,t,u,s) g_{ik}(x,t)g_{jk}(x,t)]\\ & -\frac{1}{2}\sum_{i,j,k=1}^d\frac{\partial}{\partial x_i}\left[ q(x,t;u,s)\frac{\partial }{\partial x_j}\left(p(x,t)g_{ik}(x,t)g_{jk}(x,t)\right) \right]. \end{align*} \] With this substitution the second order terms are \[ \sum_{i=1}^d\frac{\partial}{\partial x_i}\left[ p(x,t,u,s) h_i(x,t)\right] - \frac{1}{2}\sum_{i,j,k=1}^d\frac{\partial^2}{\partial x_i\partial x_j}[p(x,t,u,s) g_{ik}(x,t)g_{jk}(x,t)], \] where \[ h_i(x,t) = \frac{1}{p(x,t)}\sum_{j,k=1}^d\frac{\partial}{\partial x_j}\left[p(x,t)g_{ik}(x,t)g_{jk}(x,t)\right]. \]
If we write \[ \bar{f}_i(x,t) = f_i(x,t) - h_i(x,t) = f_i(x,t) - \frac{1}{p(x,t)}\sum_{j,k=1}^d\frac{\partial}{\partial x_j}\left[p(x,t)g_{ik}(x,t)g_{jk}(x,t)\right], \] we get the joint PDE \[ \frac{\partial p(x,t,u,s)}{\partial t} = -\sum_{i=1}^d\frac{\partial}{\partial x_i}[p(x,t,u,s)\bar{f}_i(x,t)] - \frac{1}{2}\sum_{i,j,k=1}^d\frac{\partial^2}{\partial x_i\partial x_j}[p(x,t,u,s) g_{ik}(x,t)g_{jk}(x,t)]. \tag{2}\]
In order to identify the reverse time diffusion, we are going to find the reverse time backward equation, which, confusingly, is for \[ q(u,s; x,t) =\frac{p(X_t = x, X_s =u)}{p(X_s =u)} =\frac{p(x,t,u,s)}{p(u,s)}. \] As \(p(u,s)\) is a constant in both \(x\) and \(t\), we can divide both sides of Equation 2 by it to get \[ \frac{\partial q(u,s;x,t)}{\partial t} = -\sum_{i=1}^d\frac{\partial}{\partial x_i}[q(u,s;x,t)\bar{f}_i(x,t)] - \frac{1}{2}\sum_{i,j,k=1}^d\frac{\partial^2}{\partial x_i\partial x_j}[q(u,s;x,t) g_{ik}(x,t)g_{jk}(x,t)], \] where again \(s>t\) and \(s\) and \(u\) are known.
This is the forward Kolmogorov equation for the time-reversed[38] diffusion \[ dX_t = \bar{f}(X_t, t)dt + g(X_t, t)d\tilde{W}_t, \qquad X_s = u, \] where \(d \tilde{W}_t\) is another white noise. Anderson (1982) shows how to connect the white noise \(dW_t\) that’s driving the forward dynamics with the white noise that’s driving the reverse dynamics \(d\tilde{W}_t\), but that’s overkill for our present situation.
In the context of an OU process, we get the reverse equation \[ dX_t= -\left[\frac{1}{2} X_t + \sigma^2 \nabla \log p(X_t, t)\right]\,dt + \sigma\, dW_t, \] where time runs backwards and I’ve used the formula for the logarithmic derivative.
Unlike the forward process, the reverse process is the solution to a non-linear stochastic differential equation. In general, this cannot be solved in closed form and we need to use a numerical SDE solver to generate a sample.
It’s worth noting that the OU process is an overly simple cartoon of a diffusion model. In practice, \(\sigma = \sigma_t\) is usually an increasing function of time so the system injects more noise as the diffusion moves along. This changes some of the exact equations slightly, but you can still sample \(X_t \mid X_0\) analytically for any \(t\) (as long as you choose a fairly simple function for \(\sigma_t\)). There is a large literature on these choices and, to be honest, I can’t be bothered going through them here. But obviously if you want to implement a diffusion model yourself you should look this stuff up.
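To see the whole pipeline end to end, here’s a toy sampler of my own devising. I cheat on the hard part: the initial distribution is a two-component Gaussian mixture, so \(p(x,t)\) is also a Gaussian mixture with known parameters and the score is available exactly. With the exact score in hand, Euler–Maruyama on the reverse SDE (constant \(\sigma\), as in the cartoon above) recovers the bimodal distribution from pure noise.

```python
import numpy as np

rng = np.random.default_rng(30)
sigma, T, n_steps, n_samples = 1.0, 6.0, 2_000, 5_000

# p(X_0): mixture of N(-2, 0.1^2) and N(2, 0.1^2). The forward OU kernel
# is Gaussian, so p(x, t) is a mixture with known means and variances,
# and we can use the exact score instead of a learned one.
means0, var0 = np.array([-2.0, 2.0]), 0.1**2

def score(x, t):
    means = means0 * np.exp(-0.5 * t)
    var = var0 * np.exp(-t) + sigma**2 * (1.0 - np.exp(-t))
    logits = -0.5 * (x[:, None] - means) ** 2 / var
    w = np.exp(logits - logits.max(axis=1, keepdims=True))
    w /= w.sum(axis=1, keepdims=True)          # mixture responsibilities
    return (w * (means - x[:, None]) / var).sum(axis=1)

# Start at (approximately) the stationary distribution and integrate the
# reverse SDE dX = -[X/2 + sigma^2 s_t(X)] dt + sigma dW back to t = 0.
x = sigma * rng.standard_normal(n_samples)
dt = T / n_steps
for i in range(n_steps):
    t = T - i * dt
    x += dt * (0.5 * x + sigma**2 * score(x, t))
    x += sigma * np.sqrt(dt) * rng.standard_normal(n_samples)

print(np.mean(x < 0))                    # roughly 0.5
print(x[x > 0].mean(), x[x > 0].std())   # roughly 2.0 and 0.1
```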
Estimating the score
The reverse dynamics are driven by the score function \[ s_t(x) = \nabla \log(p(x,t)). \] Typically, we do not know the marginal density \(p(x,t)\): the transition density \(p(X_t = x \mid X_0 = x_0)\) is Gaussian, but it gets mixed over the unknown distribution of \(X_0\). And while we could solve the forward equation in order to estimate \(p(x,t)\), that is wildly inefficient in high dimensions.
If we could assume that, for each \(t\), \(X_t\) is approximately \(N(\mu_t, \Sigma_t)\), then the resulting reverse diffusion would be linear \[ dX_t = \left[\sigma^2\Sigma_t^{-1}(X_t - \mu_t) - \frac{1}{2} X_t\right]dt + \sigma dW_t, \qquad X_T = u. \] In this case \(X_t \mid X_T = u\) is Gaussian with a mean and covariance that have closed form solutions in terms of \(\Sigma_t\) and \(\mu_t\) (perhaps after some numerical quadrature and matrix exponentials).
Unfortunately, as discussed above, this is not true. A better approximation would be a mixture of Gaussians but, in general, we can use any method to approximate the score \(s_t(x)\). There are no particular constraints on it, except that we expect it to be fairly smooth[39] in both \(t\) and \(x\). Hence, we can just learn the score.
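For concreteness, here’s one way to do that learning. This is a sketch of the standard denoising score matching trick rather than anything derived above: for the OU process, \(X_t = e^{-t/2}X_0 + s(t)\varepsilon\) with \(\varepsilon \sim N(0,1)\), so the conditional score is \(-\varepsilon/s(t)\), and least-squares regression of \(-\varepsilon/s(t)\) onto functions of \(X_t\) recovers the marginal score. The radial basis expansion is my stand-in for the neural net that everyone actually uses.

```python
import numpy as np

rng = np.random.default_rng(30)
sigma = 1.0

# Training samples from the "unknown" target (the +/-2 mixture again).
n = 50_000
x0 = np.where(rng.random(n) < 0.5, -2.0, 2.0) + 0.1 * rng.standard_normal(n)

ts = np.linspace(0.05, 6.0, 40)        # time grid for the score model
centres = np.linspace(-4.0, 4.0, 25)   # radial basis features in x

def features(x):
    return np.exp(-0.5 * (x[:, None] - centres) ** 2 / 0.5**2)

weights = []
for t in ts:
    m, s = np.exp(-0.5 * t), sigma * np.sqrt(1.0 - np.exp(-t))
    eps = rng.standard_normal(n)
    xt = m * x0 + s * eps              # exact forward sample at time t
    target = -eps / s                  # conditional score at (xt, t)
    W, *_ = np.linalg.lstsq(features(xt), target, rcond=None)
    weights.append(W)

def s_hat(x, t_index):
    return features(x) @ weights[t_index]

# Near t = 0 the score should point back towards the nearest mode.
print(s_hat(np.array([1.7, 2.3]), 0))  # positive, then negative
```

The fitted \(\hat{s}_t\) can then be dropped straight into the reverse-SDE sampler from the previous section in place of the exact score.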
As we are going to solve the SDE numerically, we only need to estimate the score at a finite set of locations. In every application that I’ve seen, these are pre-specified; however, it would also be possible to use a basis function expansion to interpolate to arbitrary time points. But, to be honest, I think every single example I’ve seen just uses a regularly spaced grid.
So how do we estimate \(s_t\)? Well just like every other situation, we need to define a likelihood (or, I guess, an optimisation criterion). One way to think about this would be to note that you’ll never perfectly recover the initial signal. This is because we need to solve a non-linear stochastic differential equation and there will, inherently, be noise in that solution. So instead, assume that we have an initial sample \(x_0 \sim p(X_0)\) and that after solving the backward equation we have an unbiased estimator of \(x_0\) with standard deviation \(\tau_N\), where \(N\) is the number of time steps. We know a lot about how the error of SDE solvers scales with \(N\) and so we can use that to set an appropriate scale for \(\tau_N\). For instance, if you’re using the Euler–Maruyama method, then it has strong order \(1/2\) and \(\tau_N = \mathcal{O}(N^{-1/2})\) would likely be an appropriate scaling.
This strongly suggests a likelihood that looks like \[ \hat{X}_0(x_0, t) \mid s_t, x_0, t \sim N(x_0, \tau_N^2), \] where \(\hat{X}_0(x_0,t)\) is the estimate of \(X_0\) you get by running the reverse diffusion conditioned on \(X_t = X_t(x_0)\), where \(X_t(x_0)\) is an exact sample at time \(t\) from the forward diffusion started at \(X_0 = x_0\).
This is the key to the success of diffusion models: given our training sample \(\{x_0^{(i)}\}_{i=1}^n\), we generate new data \(X_t(x_0)\) and we can generate as much of that data as we want. Furthermore, we can choose any set of \(t\)s we want. We can sample a single \((t, x_0)\) pair multiple times or we can spread our effort across a diversity of pairs.
We can even try to recover an intermediate state \(\hat{X}_{t_1}(x_0,t_2)\) from information about a future state \(X_{t_2}(x_0)\), \(t_2 >t_1 \geq 0\). This gives us quite the opportunity to target our learning to areas of the \((t,x)\) space where we have relatively poor estimates of the score function.
Of course, that’s not what people do. They do stochastic gradient descent to solve \[ \min_{s_t}\mathbb{E}_{x_0 \sim p(X_0), t \sim \text{Unif}[0,1]}\left(\|x_0 - \hat{X}_0(x_0,t)\|^2\right), \] possibly subject to some penalties on \(s_t\). In fact, the distribution on \(t\) is usually a discrete uniform. As with any sufficiently complex task, there is a lot of detailed work on exactly how to best parameterise, solve, and evaluate this optimisation procedure.
Generating samples
Once the model is trained and we have an estimate \(\hat{s}_t\) of the score function, we can generate new samples by first sampling \(u \sim N(0, \sigma^2 I)\) and running the reverse diffusion starting from \(X_t = u\) for some sufficiently large \(t\). One of the advantages of using a variant of the OU process with a non-constant \(\sigma\) is that we can choose \(t\) to be smaller. Nevertheless, there will always be a little bit of error introduced by the fact that \(X_t\) is only approximately \(N(0, \sigma^2 I)\). But really, in the context of all of the other errors, this one is pretty small.
Anyway, run the diffusion backwards and if you’ve estimated \(s_t(x)\) well for the entire trajectory, you will get something that looks a lot like a new sample from \(p(X_0)\).
Some closing thoughts
So there you have it, a very high-level mathematical introduction to diffusion models. Along the way, I accidentally put them in some sort of historical context, which hopefully helped make some things clearer.
Obviously there are a lot of cool things that can happen. The ability to, essentially, design our training trajectories should definitely be utilised. To do that, we would need some measure of uncertainty in the recovery of \(s_t\). A possible way to do this would be to insert a probabilistic layer into the neural net architecture. If this isn’t the final layer in the network, it should be possible to clean up any artifacts it introduces with further layers, but the uncertainty estimates from this hidden layer would still be indicative of the uncertainty in the recovery of the scores. Assuming, of course, that this is successful, it would be possible to target the training at improving the uncertainty.
Beyond the possibility of using a non-uniform distribution for \(t\), these uncertainty estimates might also help indicate the reliability of the generated sample. If the reverse diffusion spends too much time in areas with highly uncertain scores, it is unlikely that the generated data will be a good sample.
I am also somewhat curious about whether or not this type of system could be a reasonable alternative to bootstrap resampling in some contexts. I mean image creation is cool, but it’s not the only time people want to sample from a distribution that we only know empirically.
Footnotes
1. Maybe my favourite running gag was Ronny Chieng refusing to use the American pronunciation of Megan.
2. I mean, my last post was recounting literature on the Markov property from the 70s and 80s. My only desire for this blog is for it to be very difficult to guess the topic of the next post.
3. I can’t stress enough that I made that tomato and feta tiktok pasta for dinner. Because that’s exactly how on trend I am.
4. I am very much managing expectations here.
5. I cannot stress enough that this post will not help you implement a diffusion model. It might help you understand what is being implemented, but it also might not.
6. Really fucking relative.
7. Find a lesbian and follow her blog. Then you’ll get the good shit. There are tonnes of queer women in statistics. If you don’t know any it’s because they probably hate you.
8. The wokerati among you will notice that the quotient is the derivative of \(\log p(Q)\).
9. Look. I love you all. But I don’t want to introduce measure push-forwards. So if you want the maths read the damn paper.
10. This is the Knothe-Rosenblatt rearrangement of the optimal transport problem if you’re curious. And let’s face it, you’re not curious.
11. The normalising flow literature also has a lot of nice chats about how to model the \(T_j\)s using masked versions of the same neural net.
12. If you don’t have too much data, you could just replace that expectation with its empirical approximation. But when there is a lot of data, that will be expensive and stochastic gradient methods will perform better.
13. And be more likely to appropriately use your computational resources.
14. We will see later that it doesn’t matter if we model \(T\) or \(S\), but the likelihood calculations come out nicer if we map from \(p(x)\) to \(q(u)\) rather than the other way around.
15. There is a tonne of excellent software for efficiently solving differential equations!
16. My notation here is a bit awkward. The \(x\) in \(S(x,t)\) is keeping track of the initial condition, which in this case we do not know. But hey. Whatever.
17. Potentially even multi-modal.
18. Classically this is done with a penalty, but you could also do it with things like early stopping and specific representations of the function. Which is nice because the continuous normalising flow people use neural nets.
19. The square on the norm isn’t always there.
20. This was a big-sexy area in optimisation.
21. Or at least a lot more expensive than, say, evaluating an exponential!
22. If you’re familiar with scalable ML methods, you might think well we have solved this problem. But I promise that it is not solved. The problem is that there’s no convenient analogue to subsampling the data. You can’t be half pregnant and you can’t half evaluate the forward map. There are, however, a pile of fabulous techniques that do their best to use multiple resolutions to get something that resembles a sensible MCMC scheme.
23. In our context, it’s a vector-valued function.
24. Examples abound, but they include image reconstruction, tomographic inversion, and really anything where you’re estimating diffusivity.
25. Gaussian processes.
26. But not necessarily too creative. Not every transformation of a penalty makes a sensible prior. I’m looking at you lasso on increments.
27. Using the “well known” fact that the derivative of the log-determinant is the trace.
28. There are some complexities in practice around computing that trace. A straightforward implementation would require \(d\) autodiff sweeps, which would make the model totally impractical. There are basically two options: massively simplify \(f\) to be something like \(f(x) = h(Ax + b)\) for a smooth function \(h\) or use a stochastic trace estimator.
29. Measured in the Frobenius norm, of course.
30. Or priors.
31. Data + distributional assumptions = information.
32. \(q\) will be the asymptotic distribution of the diffusion, but it isn’t achieved at finite time.
33. Arguably, gradient descent is to machine learners what arse crack is to roadies. It’s always present, but with just enough variation to make it interesting.
34. Technically it’s an Ito integral, but because the integrand is deterministic it reduces to a white noise integral.
35. The Markov property implies that \(p(X_{t_1}, X_{t_2}\mid X_0 = x) = p(X_{t_1}\mid X_0 = x)p(X_{t_2} \mid X_{t_1})\).
36. Lipschitz and bounded.
37. I hate the quotient rule.
38. This is why the signs don’t seem to match the forwards equation from before, but you can convince yourself if you do the change of variables \(\tau = s - t\): the new variable \(\tau\) runs forward in time and \(\bar{f}\) switches signs, which gives the right forwards equation (with different signs on the first and second order terms) in \((\tau,x)\).
39. If \(p(X_0)\) is very rough then, for very small \(t\), \(p(x,t)\) will also be quite rough, but it will quickly become infinitely differentiable. It turns out that mathematicians know quite a lot about parabolic equations!