```
import jax
import jax.numpy as jnp

def f(x: jax.Array, n: int) -> jax.Array:
    out = jnp.zeros_like(x)
    for j in range(n):
        out = out + x
    return out
```

It’s worth noting that I know bugger all about diffusion models. But when they first came out, I had a quick look at how they worked and then promptly forgot about them because, let’s face it, I work on different things. But hey. If that’s not enough^{4} knowledge to write a blog post, I don’t know what is.

And here’s the thing. Most of the time when I blog about something I know a lot about it. Sometimes too much. But this is not one of those times. There are *plenty* of resources on the internet if you want to learn about diffusion models from an expert. Oodles. But where else but here can you read the barely proof-read writing of a man who read a couple of papers yesterday?

And who doesn’t want^{5} that?

One of the fundamental tasks in computational statistics is to sample from a probability distribution. There are millions of ways of doing this, but the most popular generic method is Markov chain Monte Carlo. But this is not the post about MCMC methods. I’ve already made a post about MCMC methods.

Instead, let’s focus on stranger ways to do it. In particular, let’s think about methods that create a mapping $T$, possibly depending on some properties of the target distribution $\pi$, such that the following procedure constructs a sample $x \sim \pi$:

- Sample $z \sim \mu$ for some known distribution $\mu$
- Set $x = T(z)$.

The general problem of starting with a distribution and mapping it to another distribution is an example of a problem known as *measure transport*. Transport problems have been studied by mathematicians for yonks. It turns out that there are an infinite number of mappings that will do the job, so it’s up to us to choose a good one.

Probably the most famous^{6} transport problem is the *optimal transport problem*, first studied by Monge and Kantorovich, which tries to find a mapping $T$ that minimises $\int c(x, T(x))\,\mu(dx)$ subject to the constraint that $T(x) \sim \pi$ whenever $x \sim \mu$, where $c(x, y)$ is some sort of cost function. There are canonical choices of cost function, but for the most part we are free to choose something that is convenient.

The measure transport concept underlies the method of normalising flows, but the presentation that I’m most familiar with is due to Youssef Marzouk and his collaborators in 2011 and predates the big sexy normalising flow papers by a few years.

If $\mu$ and $\pi$ are both continuous, univariate distributions, it is pretty easy to construct a transport map. In particular, if $F_\mu$ and $F_\pi$ are the cumulative distribution functions of $\mu$ and $\pi$, then $T = F_\pi^{-1} \circ F_\mu$ is a transport map. This works because, if $x \sim \mu$, then $F_\mu(x) \sim \text{Uniform}(0, 1)$. From this, we can use everyone’s favourite result that you can sample from a continuous univariate random variable by evaluating the quantile function at a uniform random value.
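To make that concrete, here’s a toy sketch (the choice of $\mu = N(0,1)$ and $\pi = \text{Exponential}(1)$, and all the function names, are mine, not from any paper): both CDFs are available in closed form, so $T = F_\pi^{-1} \circ F_\mu$ can be evaluated directly.

```python
import math
import random

def std_normal_cdf(x: float) -> float:
    """F_mu: CDF of the reference mu = N(0, 1)."""
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))

def exp_quantile(u: float) -> float:
    """F_pi^{-1}: quantile function of the target pi = Exponential(1)."""
    return -math.log1p(-u)

def transport(x: float) -> float:
    """T = F_pi^{-1} o F_mu pushes samples from mu onto pi."""
    return exp_quantile(std_normal_cdf(x))

random.seed(1)
samples = [transport(random.gauss(0.0, 1.0)) for _ in range(100_000)]
sample_mean = sum(samples) / len(samples)   # E[Exponential(1)] = 1
```

The sample mean comes out close to 1 (the mean of an Exponential(1)), which is about all the validation a toy like this deserves.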

There are, of course, two problems with this: it only works in one dimension and we usually don’t know $F_\pi^{-1}$ explicitly.

The second of these isn’t really a problem if we are willing to do something splendifferously dumb. And I am. Because I’m gay and frivolous^{7}.

If I write $F_\pi(T(x)) = F_\mu(x)$, then I can differentiate this to get $\pi(T(x))\,T'(x) = \mu(x)$. This is a *very* non-linear differential equation. We can make it an even more non-linear differential equation by repeating the procedure to get $\pi'(T(x))\,T'(x)^2 + \pi(T(x))\,T''(x) = \mu'(x)$. Noting that $(\log\pi)' = \pi'/\pi$, we get $T''(x) = T'(x)\left[(\log\mu)'(x) - (\log\pi)'(T(x))\,T'(x)\right]$. This is a rubbish differential equation, but it has the singular advantage that it doesn’t depend^{8} on the normalising constant for $\pi$, which can be useful. The downside is that the boundary conditions are infinite on both ends.

Regardless of that particular challenge, we could use this to build a generic algorithm.

- Sample $x \sim \mu$.

- Use a numerical differential equation solver to solve the equation with boundary conditions $T(-R) = -R$ and $T(R) = R$ for some sufficiently large number $R$, and return $T(x)$.

This will (approximately) sample from $\pi$ truncated to $[-R, R]$.

I was going to write some python code to do this, but honestly it hurts my soul. So I shan’t.

Outside of one dimension, there is (to the best of my knowledge) no direct solution to the transport problem. That means that we need to construct our own. Thankfully, the glorious Youssef Marzouk and a bunch of his collaborators have spent some quality time mapping out this idea. A really nice survey of their results can be found in this paper.

Essentially the idea is that we can try to find the most convenient transport map available to us. In particular, it’s useful to minimise the *Kullback-Leibler* divergence between $\pi$ and the pushforward of $\mu$ under $T$. After a little bit^{9} of maths, this is equivalent to minimising $\mathbb{E}_{x\sim\mu}\left[-\log\pi(T(x)) - \log\det\nabla T(x)\right]$, where $\nabla T$ is the Jacobian of $T$. To finish the specification of the optimisation problem, it’s enough to consider *triangular* maps^{10} with the additional constraint that their Jacobians have positive determinants. Using a triangular map has two distinct advantages: it’s parsimonious and it makes the positive determinant constraint *much* easier to deal with. Triangular maps are also sufficient for the problem (my man Bogachev showed it in 2005).

That said, this can be a somewhat tricky optimisation problem. Youssef and his friends have spilt a lot of ink on this topic. And if you’re the sort of person who just fucking loves a weird optimisation problem, I’m sure you’ve got thoughts. With and without the triangular constraint, this can be parameterised as the composition of a sequence of simple functions, in which case you turn three times and scream *neural net* and a normalising^{11} flow appears.

All of that is very lovely. And quite nice in its context. But what happens if you don’t actually have access to the (unnormalised) log density of the target? What if you only have samples?

The good news is that you’re not shit out of luck. But it’s a bit tricky. And once again, that lovely review paper by Youssef and friends will tell us how to do it.

In particular, they noticed that if you swap the direction of the KL divergence, you get the optimisation problem for the inverse mapping $S = T^{-1}$, which aims to minimise $\mathbb{E}_{x\sim\pi}\left[-\log\mu(S(x)) - \log\det\nabla S(x)\right]$, where $S$ is once again a triangular map, subject to the monotonicity constraints $\partial S_k/\partial x_k > 0$. Because we have the freedom to choose the reference density $\mu$, we can choose it to be iid standard normals, in which case we get the optimisation problem $\min_S\, \mathbb{E}_{x\sim\pi}\left[\sum_{k} \left(\tfrac{1}{2}S_k(x)^2 - \log\tfrac{\partial S_k}{\partial x_k}(x)\right)\right]$, which is a convex, separable optimisation problem that can be solved^{12} using, for instance, a stochastic gradient method. This can be turned into an unconstrained optimisation problem by explicitly parameterising the monotonicity constraint.
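As a hedged illustration of the sample-based objective, here’s the simplest possible 1-D version: the “triangular map” is just a monotone affine map $S(x) = a + e^b x$ (my choice, far more restrictive than anything Marzouk and friends would use), fit by plain gradient descent on $\mathbb{E}\left[\tfrac{1}{2}S(x)^2 - b\right]$.

```python
import math
import random

random.seed(2)
# samples from an "unknown" 1-D target pi (here secretly N(3, 2^2))
data = [random.gauss(3.0, 2.0) for _ in range(400)]
n = len(data)

# monotone affine map S(x) = a + exp(b) * x; exp(b) > 0 enforces monotonicity
a, b = 0.0, 0.0
lr = 0.01
for _ in range(5000):
    eb = math.exp(b)
    Sx = [a + eb * x for x in data]
    # objective per point: S(x)^2 / 2 - b   (standard normal reference)
    grad_a = sum(Sx) / n
    grad_b = sum(S * eb * x for S, x in zip(Sx, data)) / n - 1.0
    a -= lr * grad_a
    b -= lr * grad_b

# the fitted map should send the data to (roughly) a standard normal
mapped = [a + math.exp(b) * x for x in data]
m = sum(mapped) / n
s = math.sqrt(sum((y - m) ** 2 for y in mapped) / n)
```

For this affine family the optimum just standardises the data, so the mapped sample ends up with mean near 0 and standard deviation near 1.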

The monotonicity of $S$ makes the resulting nonlinear solve to compute $T = S^{-1}$ relatively straightforward. In fact, if the dimension isn’t too big you can solve this sequentially dimension-by-dimension. But, of course, when you’ve got a lot of parameters this is a poor method and it would make more sense^{13} to attack it with some sort of gradient descent method. It might even be worth taking the time to learn the inverse function so that it can be applied for, essentially, free.

To some extent, the answer is *yes*. This is *very much* normalising flows in its most embryonic form. They work to some extent. And this presentation makes some of the problems fairly obvious:

There’s no real guarantee that $S$ is going to be a nice smooth map, which means that we may have problems moving beyond the training sample.

The most natural way to organise the computations is sequentially, involving sweeps across the parameters. This can be difficult to parallelise efficiently on modern architectures.

The complexity of the triangular map is going to depend on the order of variables. This is fine if you’re processing something that is inherently sequential, but if you’re working with image data, this can be challenging.

Of course, there are a *pile* of ways that these problems can be overcome in whole or in part. I’d point you to the last five years of ML conference papers. You’re welcome.

A really clever idea, which is related to normalising flows, is to ask *what if, instead of looking for a single*^{14} *map* $T$, *we tried to find a sequence of maps* $T_t$, $t \in [0, 1]$, *that smoothly move from the identity map to the transport map*.

This seems like it would be a harder problem. And it is. You need to make an infinite number of maps. But the saving grace is that as $t$ changes slightly, the map $T_t$ is also only going to change slightly. This means that we can parameterise the *change* relatively simply.

To this end, we write $\frac{dT_t(x)}{dt} = v_t(T_t(x))$ for some relatively simple function $v_t$ that models the infinitesimal change in the transport map as we move along the path. The hope is that learning the vector field $v_t$ will be *easier* than learning $T$ directly. To finish the specification, we require that $T_0(x) = x$.

The question is: *can we learn the function $v_t$ from data?* If we can, it will be (relatively) easy to evaluate the transport map for any sample by just solving^{15} the differential equation.
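A minimal numerical sketch of the idea, with a vector field simple enough ($v_t(x) = -x$, my choice) that the exact time-one map $T_1(x) = e^{-1}x$ is known:

```python
import math

# Euler integration of dT_t/dt = v_t(T_t) for the toy field v_t(x) = -x;
# the exact time-one map is T_1(x) = exp(-1) * x
def flow_map(x: float, n_steps: int = 1000) -> float:
    h = 1.0 / n_steps
    for _ in range(n_steps):
        x += h * (-x)   # one Euler step of dx/dt = -x
    return x

numeric = flow_map(2.0)
exact = 2.0 * math.exp(-1.0)
```

With 1000 Euler steps the numerical map agrees with the exact one to a few parts in ten thousand; a real implementation would hand this to a proper ODE solver.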

It turns out that the inverse map $T_t^{-1}$ is most useful for *training* the normalising flow, while $T_t$ is useful for generating samples from the trained model. If we were using the methods in the previous section, we would have had to commit to *either* modelling $T$ *or* $T^{-1}$. One of the real advantages of the continuous formulation is that we can just as easily solve the equation with the *terminal condition*^{16} $T_1(x) = x$ and solve the equation backwards in time to calculate $T_t^{-1}$! The dynamics of both equations are driven by the same vector field $v_t$!

It turns out that learning parameters of differential equation (and other physical models) has a long and storied history in applied mathematics under the name of *inverse problems*. If that sounds like statistics, you’d be right. It’s statistics, except with no interest in measurement or, classically, uncertainty.

The classic inverse problem framing involves a *forward map* $G$ that takes as its input some parameters $u$ (often a function) and returns the full state of a system (often another function). For instance, the forward map could be the solution operator of a partial differential equation. The thing that you should notice about this is that the forward map is a) possibly expensive to compute, b) not explicitly known, and c) extremely^{17} non-linear.

The problem is specified with data points $y_1, \ldots, y_n$ and the aim is to find the value of $u$ that best fits the data. The traditional choice is to minimise the mean-square error $\|y - G(u)\|^2$.

Now every single one of you will know immediately that this question is both vague and ill-posed. There are *many* functions that will fit the data. This means that we need to enforce^{18} some sort of complexity penalty on $u$. This leads to the method known as Tikhonov regularisation^{19}, which minimises $\|y - G(u)\|^2 + \lambda\|u\|_B^2$, where $B$ is some Banach space and $\lambda$ is some tuning parameter.
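As a toy illustration (a one-parameter linear forward map, nothing like a real PDE-based inverse problem), Tikhonov regularisation with $G(u) = ux$ has a closed-form minimiser, and you can watch the penalty shrink the estimate:

```python
# 1-D linear forward map G(u) = u * x_i observed with noise; the Tikhonov
# objective sum_i (y_i - u x_i)^2 + lam * u^2 has the closed form below
xs = [0.5, 1.0, 1.5, 2.0, 2.5]
ys = [1.1, 1.9, 3.2, 3.9, 5.1]   # roughly y = 2 x + noise

def tikhonov(lam: float) -> float:
    sxy = sum(x * y for x, y in zip(xs, ys))
    sxx = sum(x * x for x in xs)
    return sxy / (sxx + lam)

u_ols = tikhonov(0.0)   # unpenalised least squares
u_reg = tikhonov(1.0)   # shrunk towards zero by the penalty
```

The regularised estimate sits strictly below the least-squares one; as $\lambda \to 0$ you recover the unpenalised fit, and as $\lambda \to \infty$ everything gets dragged to zero.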

As you can imagine, there’s a lot of maths under this about when there is a unique minimum, how the reconstruction behaves as $n \to \infty$ and $\lambda \to 0$, and how the choice of $B$ affects the estimation of $u$. There is also quite a lot of work^{20} looking at how to actually solve these sorts of optimisation problems.

Eventually, the field evolved and people started to realise that it’s actually fairly important to quantify the uncertainty in the estimate. This is … tricky under the Tikhonov regularisation framework, which became a big motivation for *Bayesian* inverse problems.

As with all Bayesianifications, we just need to turn the above into a likelihood and a prior. Easy. Well, the likelihood part, at least, is easy. If we want to line up with Tikhonov regularisation, we can choose a Gaussian likelihood $y \mid u \sim N(G(u), \sigma^2 I)$.

This is familiar to statisticians: the forward model is essentially working as a non-standard link function in a generalised linear model. There are two big practical differences. The first one is that $G$ is *very* non-linear and almost certainly not monotone. The second problem is that evaluations of $G$ are typically very^{21} expensive. For instance, you may need to solve a system of differential equations. This means that any computational method^{22} is going to need to minimise the number of likelihood evaluations.

The choice of prior on $u$ can, however, be a bit tricky. The problem is that in most traditional inverse problems $u$ is a function^{23} and so we need to put a carefully specified prior on it. And there is a lot of really interesting work on what this means in a Bayesian setting. This is really the topic for another blogpost, but it’s certainly an area where you need to be aware of the limitations of different high-dimensional priors and how they perform in various contexts. For instance, if the function you are trying to reconstruct is likely to have a lot of sharp boundaries^{24} then you need to make sure that your prior can support functions with sharp boundaries. My little soldier bois^{25} don’t, so you need to get more^{26} creative.

Our aim now is to cast the normalising flow idea into the inverse problems framework. To do this, we remember that we begin our flow from a sample from $\pi$ and we then deform it until it becomes a sample from $\mu$ at some known time (which I’m going to choose as $t = 1$). This means that if $x(0) \sim \pi$, then $x(1) \sim \mu$.

We can now derive a relationship between $\pi$ and $\mu$ using the change of variables formula. In particular, $\pi(x) = \mu(T_1(x))\det\nabla T_1(x)$, which means that our log likelihood will be $\log\pi(x) = \log\mu(x(1)) + \log\det\nabla T_1(x)$.

The log-determinant term looks like it might cause some trouble. If $T_1$ is parameterised as a triangular map it can be written explicitly, but there is, of course, another route.

For notational ease, let’s consider $J(t) = \nabla T_t(x)$, for some fixed $x$. Then $\frac{dJ(t)}{dt} = \nabla v_t(T_t(x))\,J(t)$. We can differentiate with respect to $t$ to get^{27} $\frac{d}{dt}\log\det J(t) = \operatorname{tr}\left(\nabla v_t(T_t(x))\right)$, where I used one of those *magical* vector calculus identities to get that trace. Remembering that $T_0(x) = x$, the log-determinant of the Jacobian at zero is zero and so we get the initial condition $\log\det J(0) = 0$.

The likelihood can be evaluated^{28} by solving the system of differential equations $\frac{dx(t)}{dt} = v_t(x(t))$, $\frac{dL(t)}{dt} = \operatorname{tr}\left(\nabla v_t(x(t))\right)$, with $x(0) = x$ and $L(0) = 0$, and the log likelihood is evaluated as $\log\pi(x) = \log\mu(x(1)) + L(1)$.
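Here’s a 1-D sanity check of that augmented system, using a linear field $v(x) = ax$ (my choice) so that the trace is just the constant $a$ and the exact log density is available in closed form:

```python
import math

# linear field v(x) = a * x, so tr(grad v) = a; augment the state with the
# running log-determinant L(t) and Euler-integrate both from t = 0 to t = 1
a = -0.5

def log_mu(z: float) -> float:
    """Log density of the reference mu = N(0, 1)."""
    return -0.5 * math.log(2.0 * math.pi) - 0.5 * z * z

def log_density(x: float, n_steps: int = 2000) -> float:
    h = 1.0 / n_steps
    L = 0.0
    for _ in range(n_steps):
        L += h * a        # dL/dt = tr(grad v) = a
        x += h * a * x    # dx/dt = v(x)
    return log_mu(x) + L  # log pi(x) = log mu(x(1)) + L(1)

# exact answer: the flow x -> exp(a) x pulls mu back to pi = N(0, exp(-2a))
x0 = 0.7
var_pi = math.exp(-2.0 * a)
exact = -0.5 * math.log(2.0 * math.pi * var_pi) - 0.5 * x0 * x0 / var_pi
numeric = log_density(x0)
```

The Euler-integrated log density matches the analytic $N(0, e^{-2a})$ log density to well within the discretisation error.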

It turns out that you can take gradients of the log-likelihood efficiently by solving an augmented system of differential equations that’s twice the size of the original. This allows for all kinds of gradient-driven inferential shenanigans.

One big problem with normalising flows as written is that we only have two pieces of information about the entire trajectory $x(t)$:

we know that $x(0) \sim \pi$, and

we know that $x(1) \sim \mu$.

We know *absolutely nothing* about $x(t)$ outside of those boundary conditions. This means that our model for $v_t$ basically gets to freestyle in those areas.

We can avoid this to some extent by choosing appropriate neural network architectures and/or appropriate penalties in the classical case or priors in the Bayesian case. There’s a whole mini-literature on choosing appropriate penalties.

Just to show how complex it is, let me quickly sketch what Finlay et al. suggest as a way to keep the dynamics as boring as possible in the information desert. They lean into the literature on optimal transport theory to come up with the double penalty $\int_0^1 \mathbb{E}\left[\|v_t(x(t))\|^2\right]dt + \lambda \int_0^1 \mathbb{E}\left[\|\nabla v_t(x(t))\|_F^2\right]dt$, where the first term minimises the kinetic energy and, essentially, finds the least exciting path from $\pi$ to $\mu$, while the second term ensures that the Jacobian of $v_t$ doesn’t get too big^{29}, which means that the mapping doesn’t have many sharp changes. Both of these penalty terms are designed to both aid generalisation and to make sure the differential equation isn’t unnecessarily difficult for an ODE solver.
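To make the two penalty terms concrete, here’s a Monte Carlo sketch along Euler trajectories of a toy linear field (the field, constants, and discretisation are all my own illustration, not Finlay et al.’s setup):

```python
import math
import random

# toy 1-D field v_t(x) = a * x; its "Jacobian" is the constant a, so the
# Frobenius penalty integrates to exactly a^2 over the unit time interval
a = -0.5
random.seed(3)

h, n_steps, n_paths = 0.01, 100, 2000
kinetic, jacobian = 0.0, 0.0
for _ in range(n_paths):
    x = random.gauss(0.0, 1.0)          # x(0) drawn from a standard normal
    for _ in range(n_steps):
        kinetic += (a * x) ** 2 * h     # |v_t(x(t))|^2 dt
        jacobian += a * a * h           # |grad v_t(x(t))|_F^2 dt
        x += h * a * x                  # Euler step of dx/dt = v_t(x)
kinetic /= n_paths
jacobian /= n_paths
```

The Jacobian penalty comes out to exactly $a^2 = 0.25$ here, while the kinetic term is a Monte Carlo estimate of $a^2\int_0^1 \mathbb{E}[x(t)^2]\,dt \approx 0.159$; note both are averages over the data distribution, which is the data-dependence mentioned below.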

A slightly odd feature of these penalties is that they are both data dependent. That suggests that a prior would, probably, require an *amount* of work. This is work that I don’t feel like doing today. Especially because this blog post isn’t about bloody normalising flows.

Ok, so normalising flows are cool, but there are a couple of places where they could potentially be improved. There is a *long* literature on diffusion models, but the one I’m mostly stealing from is this one by Song et al. (2021).

Firstly, the vector field $v_t$ *directly* affects how easy the differential equations are to solve. This means that if $v_t$ is too complicated, it can take a long time to both train the model and generate samples from the trained model. To get around this you need to put fairly strict penalties^{30} and/or structural assumptions on $v_t$.

Secondly, we only have information^{31} at the two ends of the flow. The problem would become *a lot* easier if we could somehow get information about intermediate states. In the inverse problems literature, there’s a concept of *value of information* that talks about how useful sampling a particular time point can be in terms of reducing model uncertainty. In general this, or other criteria, can be used to design a set of useful sampling times. I don’t particularly feel like working any of this out but one thing I am fairly certain of is that no optimal design would only have information at $t = 0$ and $t = 1$!

Diffusion models fix these two aspects of normalising flows at the cost of both a more complex mathematical formulation and some inexactness^{32} around the base distribution when generating new samples.

Diffusions are to applied mathematicians what gaffer tape is to^{33} a roadie. They are ubiquitous, convenient, and they hold down the fort when nothing else works.

There are a number of diffusions that are familiar in statistics and machine learning. The most famous one is probably the Langevin diffusion $dX_t = \tfrac{1}{2}\nabla\log\pi(X_t)\,dt + dW_t$, which is asymptotically distributed according to $\pi$. This forms the basis of a bunch of MCMC methods as well as some faster, less adjusted methods.
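For concreteness, here’s an unadjusted Langevin sampler for a 1-D Gaussian target (the target, step size, and burn-in are my choices; a real application would worry much harder about all three):

```python
import math
import random

# unadjusted Langevin:  x <- x + (h/2) grad log pi(x) + sqrt(h) N(0, 1),
# targeting pi = N(2, 1), so grad log pi(x) = -(x - 2)
random.seed(4)
h, n_steps, burn = 0.05, 100_000, 1_000
x = 0.0
chain = []
for i in range(n_steps):
    x += 0.5 * h * (-(x - 2.0)) + math.sqrt(h) * random.gauss(0.0, 1.0)
    if i >= burn:
        chain.append(x)

mean = sum(chain) / len(chain)
var = sum((c - mean) ** 2 for c in chain) / len(chain)
# mean and variance should be near 2 and 1, up to O(h) discretisation bias
```

Because there is no Metropolis correction, the chain targets a slightly biased version of $\pi$; shrinking $h$ (or adjusting) removes the bias at the cost of slower mixing.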

But that is not the only diffusion. Today’s friend is the Ornstein-Uhlenbeck (OU) process, which is the Gaussian process that solves the stochastic differential equation $dX_t = -X_t\,dt + \sqrt{2}\,dW_t$. The OU process can be thought of as a mean-reverting Brownian motion. As such, it has continuous but nowhere differentiable sample paths.

The stationary distribution of $X_t$ is $N(0, I)$, where $I$ is the identity matrix. In fact, if we *start* the diffusion at stationarity by setting $X_0 \sim N(0, I)$, then $X_t$ is a *stationary* Gaussian process with covariance function $c(s, t) = e^{-|t - s|}I$.

More interesting in our context, however, is what happens if we start the diffusion from a fixed point $X_0 = x_0$, which will eventually be a sample from $\pi$. In that case, we can solve the linear stochastic differential equation exactly to get $X_t = e^{-t}x_0 + \sqrt{2}\int_0^t e^{-(t-s)}\,dW_s$, where the integral on the right hand side can be interpreted^{34} as a white noise integral, and so $\mathbb{E}(X_t) = e^{-t}x_0$ and the variance is $\operatorname{Var}(X_t) = (1 - e^{-2t})I$. From these equations, we see that the mean of the diffusion hurtles exponentially fast towards zero and the variance moves at the same speed towards $I$.

More importantly, this means that, given a starting point $x_0$, we can generate data from any part of the diffusion! If we want a sequence of observations from the same trajectory, we can generate them sequentially using the fact that an OU process is a Markov^{35} process. This means that we are no longer limited to information at just two points along the trajectory.
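This exact transition is easy to code up. A quick check of the formulas (everything here follows directly from the mean and variance above):

```python
import math
import random

def ou_sample(x0: float, t: float) -> float:
    """Exact draw from X_t | X_0 = x0 for dX = -X dt + sqrt(2) dW."""
    mean = math.exp(-t) * x0
    var = 1.0 - math.exp(-2.0 * t)
    return mean + math.sqrt(var) * random.gauss(0.0, 1.0)

random.seed(5)
t, x0 = 0.5, 4.0
draws = [ou_sample(x0, t) for _ in range(50_000)]
m = sum(draws) / len(draws)
v = sum((d - m) ** 2 for d in draws) / len(draws)
# theory: mean = 4 exp(-0.5) ~ 2.43,  variance = 1 - exp(-1) ~ 0.63
```

No SDE solver in sight: because the forward process is linear and Gaussian, any $(x_0, x_t)$ pair can be generated in one shot.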

So far, there is nothing to learn here. The OU process has a known drift and variance, so everything is splendid. It’s even easy to simulate from. The challenge pops up when we try to reverse the diffusion, that is, when we try to *remove* noise from a sample rather than add noise to it.

In some sense, this shouldn’t be too disgusting. A diffusion is a Markov process and, if we run the Markov process back in time, we still get a Markov process. In fact, we are going to get another diffusion process.

The twist is that the new diffusion process is going to be quite a bit more complex than the original one. The problem is that unless $x_0$ comes from a Gaussian distribution, this process will be non-Gaussian, and thus somewhat tricky to find the reverse trajectory of.

To see this, consider the joint density of $(X_s, X_t)$, $s < t$, and recall that $p(x_s, x_t) = \int p(x_t \mid x_s)\,p(x_s \mid x_0)\,\pi(x_0)\,dx_0$. The first two terms in that integrand are Gaussian densities and thus their product is a bivariate Gaussian density. Unfortunately, as $\pi$ is not Gaussian, the marginal distribution will be non-Gaussian. This means that our reverse time transition density is also going to be *very* non-linear.

In order to work out a stochastic differential equation that runs backwards in time and generates the same trajectory, we need a little bit of theory on how the unconditional density $p(x, t)$ and the transition density $p(x, t \mid y, s)$ evolve in time (here and everywhere $s < t$). These are related through the Kolmogorov equations.

To introduce these, we need to briefly consider the more general diffusion $dX_t = b(X_t, t)\,dt + \sigma(X_t, t)\,dW_t$ for nice^{36} vector/matrix-valued functions $b$ and $\sigma$. Kolmogorov showed that the unconditional density $p(x, t)$ evolves according to the partial differential equation $\frac{\partial p}{\partial t} = -\sum_i \frac{\partial}{\partial x_i}\left[b_i p\right] + \frac{1}{2}\sum_{i,j}\frac{\partial^2}{\partial x_i \partial x_j}\left[(\sigma\sigma^T)_{ij}\, p\right]$, subject to the initial condition given by the starting density. This is known as Kolmogorov’s forward equation or the Fokker-Planck equation.

The other key result is about the transition density to some future value $X_t = y$, viewed as a function of the earlier state $(x, s)$, $s < t$. We write this density as $p(y, t \mid x, s)$ and it satisfies the partial differential equation $-\frac{\partial p}{\partial s} = \sum_i b_i \frac{\partial p}{\partial x_i} + \frac{1}{2}\sum_{i,j}(\sigma\sigma^T)_{ij}\frac{\partial^2 p}{\partial x_i \partial x_j}$, subject to the *terminal* condition $p(y, t \mid x, t) = \delta(y - x)$. This is known as the Kolmogorov backward equation. Great names. Beautiful names.

Let’s consider a differential equation for the joint density $p(y, t; x, s) = p(y, t \mid x, s)\,p(x, s)$. Going ham with the product rule gives an evolution equation for this joint density, and the first-order derivatives simplify, using the product rule once more, into a recognisable form.

Staring at this for a moment, we notice that this has the same structure as the first-order term in the forward equation. In that case, the second-order term would be

If we notice that and we can re-write the second-order derivative terms in Equation 1 as

This is almost, but not quite, what we want. We are a single minus sign away. Remembering that $p(x, s)$ is known, we probably don’t want it to turn up in any derivatives^{37}. To this end, let’s make a substitution; with this substitution the second order terms come out in the form we need.

If we write we get the joint PDE

In order to identify the reverse time diffusion, we are going to find the reverse time backward equation which, confusingly, is for the conditional density. As $p(x, s)$ is a constant in both $y$ and $t$, we can divide both sides of Equation 2 by it, to get an equation in which, once again, all of the coefficients are known.

This is the forward Kolmogorov equation for the time-reversed^{38} diffusion $dX_t = \left[b(X_t, t) - \sigma\sigma^T\,\nabla_x\log p(X_t, t)\right]dt + \sigma(X_t, t)\,d\bar{W}_t$ (written here for $\sigma$ that doesn’t depend on $x$), where $\bar W_t$ is another white noise and time runs backwards. Anderson (1982) shows how to connect the white noise that’s driving the forward dynamics with the white noise $\bar W_t$ that’s driving the reverse dynamics, but that’s overkill for our present situation.

In the context of an OU process, we get the reverse equation $dX_t = \left[-X_t - 2\,\nabla_x\log p(X_t, t)\right]dt + \sqrt{2}\,d\bar{W}_t$, where time runs backwards and I’ve used the formula for the logarithmic derivative.

Unlike the forward process, the reverse process is the solution to a *non-linear* stochastic differential equation. In general, this cannot be solved in closed form and we need to use a numerical SDE solver to generate a sample.
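Here’s a hedged sketch of those reverse dynamics in the one case where the score is available exactly: a Gaussian “target” $\pi = N(0, \sigma_0^2)$, so that $p(x, t) = N(0, v(t))$ and the score is $-x/v(t)$. I’ve rewritten the backwards-in-time SDE in the reversed clock $\tau = T - t$ and discretised with Euler–Maruyama (all constants are my choices):

```python
import math
import random

# forward OU dX = -X dt + sqrt(2) dW started from pi = N(0, sigma0^2) has
# p(x, t) = N(0, v(t)) with v(t) = sigma0^2 exp(-2t) + 1 - exp(-2t),
# so the score -x / v(t) is exact and the reverse dynamics can be tested
sigma0_sq = 4.0

def v(t: float) -> float:
    return sigma0_sq * math.exp(-2.0 * t) + 1.0 - math.exp(-2.0 * t)

random.seed(6)
T, h = 3.0, 0.01
n_steps = int(T / h)
n_paths = 5_000
finals = []
for _ in range(n_paths):
    y = random.gauss(0.0, 1.0)      # approximate draw from p(., T) ~ N(0, 1)
    for k in range(n_steps):
        t = T - k * h               # current forward time; tau = T - t
        score = -y / v(t)
        # Euler-Maruyama step of dY = [Y + 2 * score] dtau + sqrt(2) dW
        y += h * (y + 2.0 * score) + math.sqrt(2.0 * h) * random.gauss(0.0, 1.0)
    finals.append(y)

m = sum(finals) / n_paths
var = sum((f - m) ** 2 for f in finals) / n_paths   # should approach sigma0^2
```

Starting from (approximately) $N(0, 1)$ noise, the simulated reverse process recovers samples with variance close to $\sigma_0^2 = 4$; when $\pi$ isn’t Gaussian you need a learned score in place of $-y/v(t)$, which is the whole point of the next section.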

It’s worth noting that the OU process is an overly simple cartoon of a diffusion model. In practice, the diffusion coefficient is usually an increasing function of time, so the system injects more noise as the diffusion moves along. This changes some of the exact equations slightly, but you can still sample $X_t$ analytically for any $t$ (as long as you choose a fairly simple function for the diffusion coefficient). There is a *large* literature on these choices and, to be honest, I can’t be bothered going through them here. But obviously if you want to implement a diffusion model yourself you should look this stuff up.

The reverse dynamics are driven by the score function $\nabla_x \log p(x, t)$. Typically, we do not know the density $p(x, t)$, and while we could solve the forward equation in order to estimate it, that is wildly inefficient in high dimensions.

If we could assume that, for each $t$, $p(x, t)$ is approximately Gaussian, then the resulting reverse diffusion would be linear. In this case $X_t$ is Gaussian with a mean and covariance that have closed form solutions in terms of the drift and diffusion coefficients (perhaps after some numerical quadrature and matrix exponentials).

Unfortunately, as discussed above, this is not true. A better approximation would be a mixture of Gaussians but, in general, we can use *any* method to approximate the score $s(x, t) \approx \nabla_x \log p(x, t)$. There are no particular constraints on it, except we expect it to be fairly smooth^{39} in both $x$ and $t$. Hence, we can just learn the score.

As we are going to solve the SDE numerically, we only need to estimate the score at a finite set of locations. In every application that I’ve seen, these are pre-specified; however, it would also be possible to use a basis function expansion to interpolate to arbitrary time points. But, to be honest, I think every single example I’ve seen just uses a regularly spaced grid.

So how do we estimate the score? Well, just like every other situation, we need to define a likelihood (or, I guess, an optimisation criterion). One way to think about this would be to note that you’ll never *perfectly* recover the initial signal. This is because we need to solve a non-linear stochastic differential equation and there will, inherently, be noise in that solution. So instead, assume that we have an initial sample $x_0$ and that after solving the backward equation we have an unbiased estimator $\hat{x}_0$ of $x_0$ with standard deviation $\epsilon_N$, where $N$ is the number of time steps. We know a lot about how the error of SDE solvers scales with $N$ and so we can use that to set an appropriate scale for $\epsilon_N$. For instance, if you’re using the Euler–Maruyama method, then it has strong order $1/2$ and $\epsilon_N \propto N^{-1/2}$ would likely be an appropriate scaling.

This strongly suggests a likelihood that looks like $x_0 \mid \hat{x}_0 \sim N(\hat{x}_0, \epsilon_N^2 I)$, where $\hat{x}_0$ is the estimate of $x_0$ you get by running the reverse diffusion conditioned on $x_t$, where $x_t$ is an exact sample at time $t$ from the forward diffusion started at $x_0$.

This is the key to the success of diffusion models: given our training sample $x_0^{(1)}, \ldots, x_0^{(n)}$, we generate new data pairs $(x_0, x_t)$, and we can generate as much of that data as we want. Furthermore, we can choose any set of $t$s we want. We can sample a single pair multiple times or we can look at a diversity of sampling times.

We can even try to recover an intermediate state $x_s$ from information about a future state $x_t$, $s < t$. This gives us quite the opportunity to target our learning to areas of the space where we have relatively poor estimates of the score function.

Of course, that’s not what people do. They do stochastic gradient descent to minimise $\mathbb{E}_{t}\,\mathbb{E}_{x_0}\,\mathbb{E}_{x_t \mid x_0}\left[\left\|s(x_t, t) - \nabla_{x_t}\log p(x_t \mid x_0)\right\|^2\right]$, possibly subject to some penalties on $s$. In fact, the distribution on $t$ is usually a discrete uniform. As with any sufficiently complex task, there is a lot of detailed work on exactly how to best parameterise, solve, and evaluate this optimisation procedure.
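A minimal sketch of that denoising objective at a single fixed time (my setup: 1-D Gaussian data and a linear score model $s(x) = wx$, which makes the fit a closed-form regression rather than SGD):

```python
import math
import random

random.seed(7)
sigma0, t = 2.0, 0.5
decay = math.exp(-t)
beta = 1.0 - math.exp(-2.0 * t)   # Var(X_t | X_0) for the standard OU process

num, den, n = 0.0, 0.0, 50_000
for _ in range(n):
    x0 = random.gauss(0.0, sigma0)              # "training" sample from pi
    eps = random.gauss(0.0, 1.0)
    xt = decay * x0 + math.sqrt(beta) * eps     # exact forward OU draw
    target = -eps / math.sqrt(beta)             # grad log p(x_t | x_0)
    num += target * xt
    den += xt * xt
w_hat = num / den                               # least-squares slope

# the true score of p(., t) = N(0, v(t)) has slope -1 / v(t)
v_t = sigma0 ** 2 * math.exp(-2.0 * t) + beta
w_true = -1.0 / v_t
```

Even though each regression target is a noisy conditional score, the fit converges to the slope of the true marginal score, which is exactly the magic that makes denoising score matching work.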

Once the model is trained and we have an estimate of the score function, we can generate new samples by first sampling $x_T \sim N(0, I)$ and running the reverse diffusion starting from $x_T$ for some sufficiently large $T$. One of the advantages of using a variant of the OU process with a non-constant diffusion coefficient is that we can choose $T$ to be smaller. Nevertheless, there will always be a little bit of error introduced by the fact that $X_T$ is only *approximately* $N(0, I)$. But really, in the context of all of the other errors, this one is pretty small.

Anyway, run the diffusion backwards and if you’ve estimated the score well along the entire trajectory, you will get something that looks a lot like a new sample from $\pi$.

So there you have it, a very high-level mathematical introduction to diffusion models. Along the way, I accidentally put them in some sort of historical context, which hopefully helped make some things clearer.

Obviously there are *a lot* of cool things that can happen. The ability to, essentially, design our training trajectories should definitely be utilised. To do that, we would need some measure of uncertainty in the recovery of the score. A possible way to do this would be to insert a probabilistic layer into the neural net architecture. If this isn’t the final layer in the network, it should be possible to clean up any artifacts it introduces with further layers, but the uncertainty estimates from this hidden layer would still be indicative of the uncertainty in the recovery of the scores. Assuming, of course, that this is successful, it would be possible to target the training at improving the uncertainty.

Beyond the possibility of using a non-uniform distribution for $t$, these uncertainty estimates might also help indicate the reliability of the generated sample. If the reverse diffusion spends too much time in areas with highly uncertain scores, it is unlikely that the generated data will be a good sample.

I am also somewhat curious about whether or not this type of system could be a reasonable alternative to bootstrap resampling in some contexts. I mean image creation is cool, but it’s not the only time people want to sample from a distribution that we only know empirically.

Maybe my favourite running gag was Ronny Chieng refusing to use the American pronunciation of Megan. ↩︎

I mean, my last post was recounting literature on the Markov property from the 70s and 80s. My only desire for this blog is for it to be very difficult to guess the topic of the next post.↩︎

I can’t stress enough that I made that tomato and feta tiktok pasta for dinner. Because that’s exactly how on trend I am.↩︎

I am very much managing expectations here↩︎

I cannot stress enough that this post will not help you implement a diffusion model. It might help you understand what is being implemented, but it also might not.↩︎

Really fucking relative.↩︎

Find a lesbian and follow her blog. Then you’ll get the good shit. There are tonnes of queer women in statistics. If you don’t know any it’s because they probably hate you.↩︎

The wokerati among you will notice that the quotient $\pi'/\pi$ is the derivative of $\log\pi$.↩︎

Look. I love you all. But I don’t want to introduce measure push-forwards. So if you want the maths read the damn paper.↩︎

This is the Knothe-Rosenblatt rearrangement of the optimal transport problem if you’re curious. And let’s face it, you’re not curious.↩︎

The normalising flow literature also has a lot of nice chats about how to model the s using masked versions of the same neural net.↩︎

If you don’t have too much data, you could just replace that expectation with its empirical approximation. But when there is a lot of data, that will be expensive and stochastic gradient methods will perform better.↩︎

And be more likely to appropriately use your computational resources↩︎

We will see later that it doesn’t matter if we model $T$ or $T^{-1}$, but the likelihood calculations come out nicer if we map from $\pi$ to $\mu$ rather than the other way around↩︎

There is a tonne of excellent software for efficiently solving differential equations!↩︎

My notation here is a bit awkward. The $x$ in $T_t(x)$ is keeping track of the *initial condition*, which in this case we do not know. But hey. Whatever.↩︎

Potentially even multi-modal↩︎

Classically this is done with a penalty, but you could also do it with things like early stopping and specific representations of the function. Which is nice because the continuous normalising flow people use neural nets↩︎

The square on the norm isn’t always there↩︎

This was a big-sexy area in optimisation.↩︎

or at least a lot more expensive than, say, evaluating an exponential!↩︎

If you’re familiar with scalable ML methods, you might think *well, we have solved this problem*. But I promise that it is not solved. The problem is that there’s no convenient analogue to subsampling the data. You can’t be half pregnant and you can’t half evaluate the forward map. There are, however, a pile of fabulous techniques that do their best to use multiple resolutions to get something that resembles a sensible MCMC scheme.↩︎

In our context, it’s a vector-valued function↩︎

Examples abound, but they include image reconstruction, tomographic inversion, and really anything where you’re estimating diffusivity↩︎

Gaussian processes↩︎

But not necessarily too creative. Not every transformation of a penalty makes a sensible prior. I’m looking at you lasso on increments.↩︎

Using the “well known” fact that the derivative of the log-determinant is the trace ↩︎

There are some complexities in practice around computing that trace. A straightforward implementation would require a separate autodiff sweep per dimension, which would make the model totally impractical. There are basically two options: massively simplify $v_t$ so that the trace is cheap to compute, or use a stochastic trace estimator.↩︎

Measured in the Frobenius norm, of course↩︎

or priors↩︎

data + distributional assumptions = information↩︎

$N(0, I)$ will be the asymptotic distribution of the diffusion, but it isn’t achieved at finite time.↩︎

Arguably, gradient descent is to machine learners what arse crack is to roadies. It’s always present, but with just enough variation to make it interesting.↩︎

Technically it’s an Ito integral, but because the integrand is deterministic it reduces to a white noise integral↩︎

The Markov property implies that . ↩︎

Lipschitz and bounded↩︎

I hate the quotient rule↩︎

This is why the signs don’t seem to match the forwards equation from before, but you can convince yourself if you do the change of variables , the new variable runs forward in time and switches signs, which gives the right forwards equations (with different signs on the first and second order terms) in .↩︎

If the is very rough, then, for very small , will also be quite rough but it will quickly become infinitely differentiable. It turns out that mathematicians know quite a lot about parabolic equations!↩︎

BibTeX citation:

```
@online{simpson2023,
  author = {Dan Simpson},
  editor = {},
  title = {Diffusion Models; or {Yet} Another Way to Sample from an
    Arbitrary Distribution},
  date = {2023-02-09},
  url = {https://dansblog.netlify.app/posts/},
  langid = {en}
}
```

For attribution, please cite this work as:

Dan Simpson. 2023. “Diffusion Models; or Yet Another Way to Sample
from an Arbitrary Distribution.” February 9, 2023. https://dansblog.netlify.app/posts/.

There’s a long literature on effective approximation to Gaussian Processes that don’t turn out to be computational nightmares. I’m definitely not going to summarise them here, but I’ll point to an earlier (quite technical) post that mentioned some of them. The particular computational approximation that I am most fond of makes use of the Markov property and efficient sparse matrix computations to reduce memory use and make the linear algebra operations significantly faster.

One of the odder challenges with Markov models is that information about how Markov structures work in more than one dimension can be quite difficult to find. So in this post I am going to lay out some of the theory.

A much more practical (and readable) introduction to this topic can be found in this lovely paper by Finn, David, and Håvard. So don’t feel the burning urge to read this post if you don’t want to. I’m approaching the material from a different viewpoint and, to be very frank with you, I was writing something else and this section just became extremely long so I decided to pull it out into a blog post.

So please enjoy today’s entry in *Dan writes about the weird corners of Gaussian processes*. I promise that even though this post doesn’t make it seem like this stuff is useful, it really is. If you want to know anything else about this topic, essentially all of the Markov property parts of this post come from Rozanov’s excellent book Markov Random Fields.

By the end of today’s post we will have defined^{1} a Markovian process in terms of its reproducing kernel Hilbert space (RKHS), that is, the space of functions that contains the posterior mean^{2} when there are Gaussian observations. This space always exists and its inner product is entirely determined by the covariance function of a GP. That said, for a given covariance function, the RKHS can be difficult to find. Furthermore, the problem with basing our modelling off an RKHS is that it is not immediately obvious how we will do the associated computations. This is in contrast to a covariance function approach, where it is quite easy^{3} to work out how to convert the model specification to something you can attack with a computer. By the end of this post we will have tackled that.

The extra complexity of the RKHS pays off in modelling flexibility, both in terms of the types of model that can be built and the spaces^{4} you can build them on. I am telling you this now because things are about to get a little mathematical.

To motivate the technique, let’s consider the covariance operator where is the domain over which the GP is defined (usually but maybe you’re feeling frisky).

To see how this could be useful, we are going to need to think a little bit about how we can simulate a multivariate Gaussian random variable . To do this, we first compute the square root^{5} and sample a vector of iid standard normal variables . Then . You can check this by computing the covariance. (It’s ok. I’ll wait.)
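In case you don’t want to wait, here’s a minimal numpy sketch (all names are mine, not from the post) of that recipe: compute the square root via the Cholesky factor, sample iid standard normals, transform, and check the covariance empirically:

```python
import numpy as np

rng = np.random.default_rng(0)

# Target distribution: x ~ N(mu, Sigma)
mu = np.array([1.0, -2.0])
Sigma = np.array([[2.0, 0.5],
                  [0.5, 1.0]])

# Square root via the Cholesky factor: Sigma = L @ L.T
L = np.linalg.cholesky(Sigma)

# iid standard normals, transformed into correlated samples
z = rng.standard_normal((2, 100_000))
x = mu[:, None] + L @ z

# The empirical covariance should be close to Sigma
print(np.cov(x))
```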

While the square root of the covariance operator is a fairly straightforward mathematical object^{6}, the analogue of the iid vector of standard normal random variables is a bit more complex.

Thankfully I’ve covered this in a previous blog. The engineering definition of white noise as a GP such that for every , is an iid random variable is not good enough for our purposes. Such a process is hauntingly irregular^{7} and it’s fairly difficult to actually do anything with it. Instead, we consider white noise as a random function defined on the subsets of our domain. This feels like it’s just needless technicality, but it turns out to actually be very very useful.

**Definition 1 (White noise) **A (complex) Gaussian white noise is a random measure^{8} such that every^{9} disjoint^{10} pair of sets satisfies the following properties

- If and are disjoint then
- If and are disjoint then and are uncorrelated^{11}, ie .

This doesn’t feel like we are helping very much because how on *earth* am I going to define the product ? Well the answer, you may be shocked to discover, requires a little bit more maths. We need to define an integral, which turns out to not be *shockingly* difficult to do. The trick is to realise that if I have an indicator function then^{12} In that calculation, I just treated like I would any other measure. (If you’re more of a probability type of girl, it’s the same thing as noticing .)

We can extend the above by taking the sum of two indicator functions where and are disjoint and and are any real numbers. By the same reasoning above, and using the linearity of the integral, we get that where the last line follows by doing the ordinary^{13} integral of .

It turns out that every interesting function can be written as the limit of piecewise constant functions^{14} and we can therefore *define* for any function^{15}

With this notion in hand, we can finally define the action of an operator on white noise.

**Definition 2 (The action of an operator on white noise) **Let be an operator on some Hilbert space of functions with adjoint , then we define to be the random measure that satisfies, for every ,

One of those inconvenient things that you may have noticed from above is that is *not* going to be a function. It is going to be a measure or, as it is more commonly known, a *generalised Gaussian process*. This is the GP analogue of a generalised function and, as such, only gives an actual value when you integrate it against some sufficiently smooth function.

**Definition 3 (Generalised Gaussian Process) **A generalised Gaussian process is a random signed measure (or a random generalised function) that, for any , is Gaussian. We will often write which helps us understand that a generalised GP is indexed by functions.

In order to separate this out from the ordinary GP , we will write it as These two ideas coincide in the special case where which will occur when smooths the white noise sufficiently. In all of the cases we really care about today, this happens. But there are plenty of Gaussian processes that can only be considered as generalised GPs^{16}

This type of construction for is used in two different situations: kernel convolution methods directly use the representation, and the SPDE methods of Lindgren, Lindström and Rue use it indirectly.

I’m interested in the SPDE method, as it ties into today’s topic. Also because it works really well. This method uses a slightly modified version of the above equation where is the (left) inverse of . I have covered this method in a previous post, but to remind you, the SPDE method in its simplest form involves three steps:

1. Approximate for some set of weights and a set of deterministic functions that we are going to use to approximate the GP

2. Approximate^{17} the *test function* for some set of deterministic weights

3. Plug these approximations into the equation to get the equation

As this has to be true for *every* vector , this is equivalent to the linear system where and .

Obviously this method is only going to be useful if it’s possible to compute the elements of and efficiently. In the special case where is a differential operator^{18} and the basis functions are chosen to have compact support^{19}, these calculations form the basis of the finite element method for solving partial differential equations.

The most important thing, however, is that if is a differential operator *and* the basis functions have compact support, the matrix is sparse and the matrix can be made^{20} diagonal, which means that has a sparse precision matrix. This can be used to make inference with these GPs very efficient and is the basis for GPs in the INLA software.

A natural question to ask is *when will we end up with a sparse precision matrix*? The answer is not quite when is a differential operator. Although that will lead to a sparse precision matrix (and a Markov process), it is not required. So the purpose of the rest of this post is to quantify all of the cases where a GP has the Markov property and we can make use of the resulting computational savings.

Part of the reason why I introduced the notion of a generalised Gaussian process is that it is useful in the definition of the Markov process. Intuitively, we know what this definition is going to be: if I split my space into three disjoint sets , and in such a way that you can’t get from to without passing through , then the Markov property should say, roughly, that every random variable is conditionally independent of every random variable *given* (or conditional on) knowing the values of the entire set .

That definition is all well and good for a hand-wavey approach, but unfortunately it doesn’t quite hold up to mathematics. In particular, if we try to make a line^{21}, we will hit a few problems. So instead let’s do this properly.

All of the material here is covered in Rozanov’s excellent but unimaginatively named book *Markov Random Fields*.

To set us up, we should consider the types of sets we have. There are three main sets that we are going to be using: the open^{22} set , its boundary^{23} , and its open complement . For a 2D example, if is the *interior* of the unit circle, then would be the unit circle, and would be the *exterior* of the unit circle.

One problem with these sets is that while will be a 2D set, is only one dimensional (it’s a circle, so it’s a line!). This causes some troubles mathematically, which we need to get around by using the fattening of , which is the set where is the distance from to the nearest point in .

With all of this in hand we can now give a general definition of the Markov property.

**Definition 4 (The Markov property for a generalised Gaussian process) **Consider a zero mean generalised GP^{24} . For any^{25} subset , we define the collection of random variables^{26} We will call the *random field*^{27} associated with .

Let be a system of domains^{28} in . We say that has the Markov^{29} property (with respect to ) if, for all and for any sufficiently small , where and .

The Markov property defined above is great and everything, but in order to manipulate it, we need to think carefully about how the domains , and can be used to divide up the space . To do this, we need to basically localise the Markov property to one set of , , . This concept is called a *splitting*^{30} of and by .

**Definition 5 **For some domain and , set . The space splits and if where is the sum of orthogonal components^{31} and if and only if there is some such that^{32}

This emphasises that we can split our space into three separate components: inside , outside , and on the boundary of . The ability to do that for any^{33} domain is the key part of the Markov^{34} property.

A slightly more convenient way to deal with splitting spaces is the case where we have overlapping sets , that cover the domain (ie ) and the splitting set is their intersection . In this case, the splitting equation becomes I shan’t lie: that looks wild. But it makes sense when you take and , in which case and .

The final thing to add before we can get to business is a way to get rid of all of the annoying s. The idea is to take the intersection of all of the as the splitting space. If we define we can re-write^{35} the splitting equation as

This gives the following statement of the Markov property.

**Definition 6 **Let be a system of domains^{36} in . We say that has the Markov property (with respect to ) if, for all , ,, we have, for some and

We are going to fall further down the abstraction rabbit hole in the hope of ending up somewhere useful. In this case, we are going to invent an object that has no reason to exist and we will show that it can be used to compactly restate the Markov property. It will turn out in the next section that it is actually a useful characterization that will lead (finally) to an operational characterisation of a Markovian Gaussian process.

**Definition 7 (Dual random field) **Let be a generalised Gaussian process with an associated random field , and let be a complete system of open domains in . The *dual* to the random field , on the system is the random field , that satisfies and

This definition looks frankly a bit wild, but I promise you, we will use it.

The reason for its structure is that it directly relates to the Markov property. In particular, the existence of a dual field implies that, if we have any , then That’s the first thing we need to show to demonstrate the Markov property.

The second part is much easier. If we note that , it follows that

This gives us our third (and final) characterisation of the (second-order) Markov property.

**Definition 8 **Let be a system of domains^{37} in . Assume that the random field has an associated dual random field .

We say that , has the Markov property (with respect to^{38} ) if and only if for all , When this holds, we say that the dual field is *orthogonal* with respect to .

There is probably more to say about dual fields. For instance, the dual of the dual field is the original field. Neat, huh. But really, all we need to do is know that an orthogonal dual field implies the Markov property. Because next we are going to construct a dual field, which will give us an actually useful characterisation of Markovian GPs.

In this section, our job is to construct a dual random field. To do this, we are going to exploit the notion of a *conjugate ^{39} Gaussian process*, which is a generalised

We will return to the issue of whether or not actually exists later, but assuming it does, let’s see how its associated random field relates to for . While it is not always true that these things are equal, it *is* always true that We will consider when equality holds in the next section. But first let’s show the inclusion.

The space contains all random variables of the form , where the support of is compact in , which means that it is a positive distance from . That means that, for some , the support of is outside^{42} of . So if we fix that and consider any smooth with support in^{43} , then, from the definition of the conjugate GP, we have^{44} This means that is perpendicular to and, therefore, . Now, is defined as the intersection of these spaces, but it turns out that^{45} for any spaces and , This is because and so every function that’s orthogonal to functions in is also orthogonal to functions in . The same goes for . We have shown that and every is in for some . This gives the inclusion

To give conditions for when it’s an actual equality is a bit more difficult. It, maybe surprisingly, involves thinking carefully about the reproducing kernel Hilbert space of . We are going to take this journey together in two steps. First we will give a condition on the RKHS that guarantees that exists. Then we will look at when .

First off, though, we need to make sure that exists. Obviously^{46} if it exists then it is unique and .

But does it exist? The answer turns out to be *sometimes*. But also *usually*. To show this, we need to do something that is, frankly, just a little bit fancy. We need to deal with the reproducing kernel Hilbert space^{47}. This feels somewhat surprising, but it turns out that it is a fundamental object^{48} and intrinsically tied to the space .

The reproducing kernel space, which we will now^{49} call because we are using for something else in this section, is a set of deterministic generalised functions , that can be evaluated at functions^{50} as A generalised function is in if there is a corresponding random variable in that satisfies It can be shown^{51} that there is a one-to-one correspondence between and , in the sense that for every there is a unique .

We can use this correspondence to endow with an inner product

So far, so abstract. The point of the conjugate GP is that it gives us an explicit construction of the^{52} mapping . And, importantly for the discussion of existence, if there is a conjugate GP then the RKHS has a particular relationship with .

To see this, let’s assume exists. Then, for each , the generalised function is in because, by the definition of we have that Hence, the embedding is given by .

Now, if we do a bit of mathematical trickery and equate things that are isomorphic, . On its face, that doesn’t make much sense because on the left we have a space of actual functions and on the right we have a space of generalised functions. To make it work, we associate each smooth function with the generalised function defined above.

This makes the closure^{53} of under the norm and hence we have shown that if there is a conjugate GP, then It turns out that if is dense in then there exists a conjugate function defined through the isomorphism . This is because and is continuous. Hence if we choose then .

We have shown the following.

**Theorem 1 **A conjugate GP exists if and only if is dense in .

This is our first step towards turning statements about the stochastic process into statements about the RKHS. We shall continue along this road.

You might, at this point, be wondering if that condition ever actually holds. The answer is yes. It does fairly often. For instance, if is a stationary GP with spectral density , the biorthogonal function exists if and only if there is some such that This basically says that the theory we are developing doesn’t work for GPs with extremely smooth sample paths (like a GP with the square-exponential covariance function). This is not a restriction that bothers me at all.

For non-stationary GPs that aren’t too smooth, this will also hold as long as nothing too bizarre is happening at infinity.

We have shown already^{54} that (that last bit with all the complements can be read as “the support of is inside and always more than from the boundary.”). It follows then that This is nice because it shows that is related to the space that is if is a function that is the limit of a sequence of functions with for some , then and *every* such random variable has an associated .

So, in the sense^{55} of isomorphisms these are equivalent, that is

This means that if we can show that , then we have two spaces that are isomorphic to the same space *and* use the same isomorphism . This would mean that the spaces are equivalent.

This can also be placed in the language of function spaces. Recall that Hence will be isomorphic to if and only if that is, if and only if every is the limit of a sequence of smooth functions compactly supported within .

This turns out to not *always* be true, but it’s true in the situations that we most care about. In particular, we get the following theorem, which I am certainly not going to prove.

Assume that the conjugate GP exists. Assume that *either* of the following holds:

- Multiplication by a function is bounded in , ie

- The shift operator is bounded under both the RKHS norm and the covariance norm^{56} for small , ie holds in both norms for all , sufficiently small.

Then is the dual of over the system of sets that are bounded or have bounded complements in .

The second condition is particularly important because it *always* holds for stationary GPs with as their covariance structure is shift invariant. It’s not impossible to come up with examples of generalised GPs that don’t satisfy this condition, but they’re all a bit weird (eg the “derivative” of white noise). So as long as your GP is not too weird, you should be fine.

And with that, we are finally here! We have that is the dual random field to , *and* we have a lovely characterisation of in terms of the RKHS . We can combine this with our definition of a Markov property for GPs with a dual random field and get that a GP is Markovian if and only if We can use the isomorphism to say that if , , then there is a such that Moreover, this isomorphism is unitary (aka it preserves the inner product) and so Hence, has the Markov property if and only if

Let’s memorialise this as a theorem.

**Theorem 2 **A GP with a conjugate GP is Markov if and only if its RKHS is local, ie if and have disjoint supports, then

This result is *particularly* nice because it entirely characterises the RKHS inner product of a Markovian GP. The reason for this is a deep result from functional analysis called Peetre’s Theorem, which states, in our context, that locality implies that the inner product has the form where^{57} are integrable functions and only a finite number of them are non-zero at any point .

This connection between the RKHS and the dual space also gives the following result for stationary GPs.

**Theorem 3 **Let be a stationary Gaussian process. Then has the Markov property if and only if its spectral density is the inverse of a non-negative, symmetric polynomial.

This follows from the characterisation of the RKHS as having the inner product as where is the Fourier transform of and the fact that a differential operator is transformed to a polynomial in Fourier space.
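For a concrete instance of Theorem 3 (my example, not from the post): the Whittle–Matérn family has spectral density

```latex
% Whittle–Matérn spectral density on R^d (example, not from the post):
S(\xi) \propto \left(\kappa^2 + \|\xi\|^2\right)^{-(\nu + d/2)}
```

which is the inverse of a non-negative symmetric polynomial exactly when the exponent is an integer, so those Matérn fields are Markov. The squared-exponential covariance, by contrast, has a spectral density that decays faster than any polynomial, so it can never satisfy the condition.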

*Waaaay* back near the top of the post I described a way to write a (generalised) GP in terms of its covariance operator and the white noise process From the discussions above, it follows that the corresponding conjugate GP is given by This means that the RKHS inner product is given by From the discussion above, if is Markovian, then is^{58} a differential^{59} operator.

To close out this post, let’s look at how we can use the RKHS to build an approximation to a Markovian GP. This is equivalent^{60} to the SPDE method that was very briefly sketched above, but it only requires knowledge of the RKHS inner product.

In particular, if we have a set of basis functions , , we can define the approximate RKHS as the space of all functions equipped with the inner product where the LHS and are functions and on the right they are the vectors of weights, and

For a finite dimensional GP, the matrix that defines the RKHS inner product is^{61} the inverse of the covariance matrix. Hence the finite dimensional GP associated with the RKHS is the random function where the weights .

If the GP is Markovian *and* the basis functions have compact support, then is a sparse matrix and maybe he’ll love me again.
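To see the computational payoff in miniature, here is a hedged numpy sketch (the tridiagonal precision matrix is a stand-in for what a 1D Markov GP with compactly supported basis functions would produce): sampling the weights needs only a Cholesky factor of the sparse precision and a triangular solve.

```python
import numpy as np

rng = np.random.default_rng(1)

# Stand-in sparse precision matrix: tridiagonal, as a 1D Markov GP
# with compactly supported (tent) basis functions would give.
n = 6
Q = 2.1 * np.eye(n) - np.eye(n, k=1) - np.eye(n, k=-1)

# Sample weights w ~ N(0, Q^{-1}): factor Q = L @ L.T, then solve L.T w = z.
# With a sparse Q, this factorisation is where the computational win lives.
L = np.linalg.cholesky(Q)
z = rng.standard_normal((n, 200_000))
w = np.linalg.solve(L.T, z)

# The empirical covariance of the weights should be close to inv(Q)
print(np.cov(w))
```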

or redefined if you’ve read my other post↩︎

For other observation models it contains the posterior mode↩︎

Step 1: Open Rasmussen and Williams.↩︎

For example, the process I’m about to describe is not meaningfully different for a process on a sphere. Whereas if you want to use a covariance function on a sphere you are stuck trying to find a whole new class of positive definite functions. It’s frankly very annoying. Although if you want to build a career out of characterising positive definite functions on increasingly exotic spaces, you probably don’t find it annoying.↩︎

Or the Cholesky factor if you add a bunch of transposes in the right places, but let’s not kid ourselves this is not a practical discussion of how to do it↩︎

Albeit a bit advanced. It’s straightforward in the sense that for an infinite-dimensional operator it happens to work a whole like a symmetric positive semi-definite matrix. It is not straightforward in the sense that your three year old could do it. Your three year old can’t do it. But it will keep them quiet in the back seat of the car while you pop into the store for some fags. It’s ok. The window’s down.↩︎

For any subset , *and* ↩︎

Countably additive set-valued function taking any value in ↩︎

measurable↩︎

↩︎

If is also Gaussian then this is the same as them being independent↩︎

Recall that is our whole space. Usually , but it doesn’t matter here.↩︎

A bit of a let down really.↩︎

like but with more subsets↩︎

is the space of functions with the property that .↩︎

eg the Gaussian free field in physics, or the de Wijs process.↩︎

You can use a separate set of basis functions here, but I’m focusing on simplicity↩︎

The standard example is ↩︎

In particular piecewise linear tent functions built on a triangulation↩︎

Read the paper, it’s a further approximation but the error is negligible↩︎

()-dimensional sub-manifold↩︎

This set does not include its boundary↩︎

This is defined as the set , where is the closure of . But let’s face it. It’s the fucking boundary. It means what you think it means.↩︎

I’m using here as a *generic* generalised GP, rather than , which is built using an ordinary GP. This doesn’t really make much of a difference (the Markov property for one is the same as the other), but it makes me feel better.↩︎

measurable↩︎

Here is the support of , that is the values of such that .↩︎

This is the terminology of Rozanov. Random Field is also another term for stochastic process. Why only let words mean one thing?↩︎

non-empty connected open sets↩︎

Strictly, this is the *weak* or *second-order* Markov property↩︎

If you’re curious, this is basically the same thing as a splitting -algebra. But, you know, sans the -algebra bullshit.↩︎

That is, any can be written as the sum , where , , and are *mutually orthogonal* (ie !).↩︎

This is using the idea that the conditional expectation is a projection.↩︎

Typically any open set, or any open connected set, or any open, bounded set. A subtlety that I don’t really want to dwell on is that it is possible to have a GP that is Markov with respect to one system of domains but not another.↩︎

The Markov property can be restated in this language as for every system of complementary domains and boundary , , , there exists a small enough such that splits and ↩︎

Technically we are assuming that for small enough . This is not a particularly onerous assumption.↩︎

non-empty connected open sets↩︎

non-empty connected open sets↩︎

The result works with some subsystem . To prove it for it’s enough to prove it for some subset that separates points of . This is a wildly technical aside and if it makes no sense to you, that’s very much ok. Frankly I’m impressed you’ve hung in this long.↩︎

Rozanov also calls this the *biorthogonal* GP. I like conjugate more.↩︎

Up to this point, it hasn’t been technically necessary for the GP to be generalised. However, here it very much is. It turns out that if realisations of are almost surely continuous, then realisations of are almost surely generalised functions.↩︎

I’m writing this as if all of these GPs are real valued, but for full generality, we should be dealing with complex GPs. Just imagine I put complex conjugates in all the correct places. I can’t stop you.↩︎

That is, inside and more than from the boundary↩︎

can be non-zero inside but only if it’s less than away from the boundary.↩︎

It’s zero because the two functions are never non-zero at the same time, so their product is zero.↩︎

Here, and probably in a lot of other places, we are taking the union of spaces to be the span of their sum. Sorry.↩︎

Really Daniel. Really. (It’s an isomorphism so if you do enough analysis courses this is obvious. If that’s not clear to you, you should just trust me. Trust issues aren’t sexy. Unless you have cum gutters. In which case, I’ll just spray my isomorphisms on them and you can keep scrolling TikTok.)↩︎

This example is absolutely why I hate that we’ve settled on RKHS as a name for this object because the thing that we are about to construct does not always have a reproducing kernel property. Cameron-Martin space is less confusing. But hey. Whatever. The RKHS for the rest of this section is not always a Hilbert space with a reproducing kernel. We are just going to have to be ok with that.↩︎

Nothing about this analysis relies on Gaussianity. So this is a general characterisation of a Markov property for *any* stochastic process with second moments.↩︎

In previous blogs, this was denoted and truly it was too confusing when I tried to do it here. And by that point I wasn’t going back and re-naming .↩︎

is the space of all infinitely differentiable compactly supported functions on ↩︎

The trick is to notice that the set of all possible is dense in .↩︎

unitary↩︎

the space containing the limits (in the -norm) of all sequences in ↩︎

If you take some limits↩︎

I mean, really. Basically we say that if there is an isomorphism between and . Could I be more explicit? Yes. Would that make this unreadable? Also yes.↩︎

.↩︎

is a multi-index, which can be interpreted as , and ↩︎

in every local coordinate system↩︎

Because defines an inner product, it’s actually a symmetric elliptic differential operator↩︎

Technically, you need to choose different basis functions for . In particular, you need to choose where . This is then called a Petrov-Galerkin approximation and truly we don’t need to think about it at all. Also I am completely eliding issues of smoothness in all of this. It matters, but it doesn’t matter too much. So let’s just assume everything exists.↩︎

If you don’t believe me you are welcome to read the monster blog post, where it’s an example.↩︎

BibTeX citation:

```
@online{simpson2023,
  author = {Dan Simpson},
  editor = {},
  title = {Markovian {Gaussian} Processes: {A} Lot of Theory and Some
    Practical Stuff},
  date = {2023-01-21},
  url = {https://dansblog.netlify.app/posts/},
  langid = {en}
}
```

For attribution, please cite this work as:

Dan Simpson. 2023. “Markovian Gaussian Processes: A Lot of Theory
and Some Practical Stuff.” January 21, 2023. https://dansblog.netlify.app/posts/.

And, along the way, I’m going to re-do the sparse Cholesky factorisation to make it, well, better.

In order to temper expectations, I will tell you that this post does not do the numerical factorisation, only the symbolic one. Why? Well I wrote most of it on a long-haul flight and I didn’t get to the numerical part. And this was long enough. So hold your breaths for Part 7b, which will come as soon as I write it.

You can consider this a *much* better re-do of Part 2. This is no longer my first python coding exercise in a decade, so hopefully the code is better. And I’m definitely trying a lot harder to think about the limitations of JAX.

Before I start, I should probably say why I’m doing this. JAX is a truly magical thing that will compute gradients and everything else just by clever processing of the Jacobian-vector product code. Unfortunately, this is only possible if the Jacobian-vector product code is JAX traceable and this code is structurally extremely similar^{1} to the code for the sparse Cholesky factorisation.

I am doing this in the hope of (eventually getting to) autodiff. But that won’t be this blog post. This blog post is complicated enough.

The first and most important rule of programming with JAX is that loops will break your heart. I mean, whatever, I guess they’re fine. But there’s a problem. Imagine the following function

```
def f(x: jax.Array, n: Int) -> jax.Array:
    out = jnp.zeros_like(x)
    for j in range(n):
        out = out + x
    return out
```

This is, basically, the worst implementation of multiplication by an integer that you can possibly imagine. This code will run fine in Python, but if you try to JIT compile it, JAX is gonna get *angry*. It will produce the machine code equivalent of

```
def f_n(x):
    out = x
    out = out + x
    out = out + x
    # do this n times
    return out
```

There are two bad things happening here. First, note that the “compiled” code depends on `n` and will have to be compiled anew each time `n` changes. Secondly, the loop has been replaced by `n` copies of the loop body. This is called *loop unrolling* and, when used judiciously by a clever compiler, is a great way to speed up code. When done completely for *every* loop, this is a nightmare and the corresponding code will take a geological amount of time to compile.
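We can watch this unrolling happen by looking at the size of the traced program. This is a small sketch using `jax.make_jaxpr`, which shows the intermediate representation growing linearly with `n`:

```python
import jax
import jax.numpy as jnp

def f(x, n):
    out = jnp.zeros_like(x)
    for j in range(n):
        out = out + x
    return out

# The traced program (jaxpr) grows with n: each loop iteration
# becomes its own equation in the trace.
small = jax.make_jaxpr(f, static_argnums=1)(1.0, 3)
big = jax.make_jaxpr(f, static_argnums=1)(1.0, 30)
print(len(small.jaxpr.eqns), len(big.jaxpr.eqns))
```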

A similar thing^{2} happens when you need to run autodiff on `f(x, n)`. For each `n`, an expression graph is constructed that contains the unrolled for loop. This suggests that autodiff might also end up being quite slow (or, more problematically, more memory-hungry).

So the first rule of JAX is to avoid for loops. But if you can’t do that, there are three built-in loop structures that play nicely with JIT compilation and sometimes^{3} differentiation. These three constructs are

- A while loop: `jax.lax.while_loop(cond_func, body_func, init)`
- An accumulator: `jax.lax.scan(body_func, init, xs)`
- A for loop: `jax.lax.fori_loop(lower, upper, body_func, init)`

Of those three, the first and third work mostly as you’d expect, while the second is a bit more hairy. The `while_loop` function is roughly equivalent to

```
def jax_lax_while_loop(cond_func, body_func, init):
    x = init
    while cond_func(x):
        x = body_func(x)
    return x
```

So basically it’s just a while loop. The thing that’s important is that it compiles down to a single XLA operation^{4} instead of some unrolled mess.

One thing that is important to realise is that while loops are only forwards-mode differentiable, which means that it is *very* expensive^{5} to compute gradients. The reason for this is that we simply do not know how long that loop actually is and so it’s impossible to build a fixed-size expression graph.
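To make that concrete, here is a small sketch (the function and its values are my own example, not from the docs) of a while loop whose trip count depends on the data, which forward-mode handles fine:

```python
import jax
from jax import lax

# A loop whose trip count depends on the value of x: halve until below 1.
def halve_until_small(x):
    return lax.while_loop(lambda v: v >= 1.0, lambda v: v / 2.0, x)

print(halve_until_small(10.0))                 # 0.625 (four halvings)
# Forward-mode differentiation works through the dynamic loop:
print(jax.jacfwd(halve_until_small)(10.0))     # 0.0625 = 1/16
# jax.grad (reverse-mode) would raise an error here, because the
# trip count is not known when the expression graph is built.
```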

The `jax.lax.scan` function is probably the one that people will be least familiar with. That said, it’s also the one that is roughly “how a for loop should work”. The concept that’s important here is a for-loop with *carry over*. Carry over is information that changes from one step of the loop to the next. This is what separates us from a `map` statement, which would apply the same function independently to each element of a list.

The scan function looks like

```
def jax_lax_scan(body_func, init, xs):
    shape0 = np.shape(xs[0])
    if not all(np.shape(x) == shape0 for x in xs):
        raise ValueError("All xs entries must have the same shape!!")
    carry = init
    ys = []
    for x in xs:
        carry, y = body_func(carry, x)
        ys.append(y)
    return carry, np.stack(ys)
```

A critically important limitation to `jax.lax.scan` is that every `x` in `xs` must have the same shape! This means, for example, that

`xs = [[1], [2,3], [4], 5,6,7]`

is not a valid argument. Like all limitations in JAX, this serves to make the code transformable into efficiently compiled code across various different processors.

For example, if I wanted to use `jax.lax.scan` on my example from before, I would get

```
from jax import lax
from jax import numpy as jnp

def f(x, n):
    init = jnp.zeros_like(x)
    xs = jnp.repeat(x, n)
    def body_func(carry, y):
        val = carry + y
        return (val, val)
    final, journey = lax.scan(body_func, init, xs)
    return (final, journey)

final, journey = f(1.2, 7)
print(final)
print(journey)
```

```
8.4
[1.2 2.4 3.6000001 4.8 6. 7.2 8.4 ]
```

This translation is a bit awkward compared to the for loop but it’s the sort of thing that you get used to.

This function can be differentiated^{6} and compiled. To differentiate it, I need a version that returns a scalar, which is easy enough to do with a lambda.

```
from jax import jit, grad
f2 = lambda x, n: f(x,n)[0]
f2_grad = grad(f2, argnums = 0)
print(f2_grad(1.2, 7))
```

`7.0`

The `argnums` option tells JAX that we are only differentiating with respect to the first argument.
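A tiny sketch of how `argnums` behaves (my own toy function, not from the post):

```python
from jax import grad

# Differentiate with respect to the *second* argument only.
g = grad(lambda x, y: x * y + y ** 2, argnums=1)
print(g(2.0, 3.0))  # d/dy (x*y + y^2) = x + 2y = 8.0
```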

JIT compilation is a tiny bit more delicate. If we try the natural thing, we are going to get an error.

```
f_jit_bad = jit(f)
bad = f_jit_bad(1.2, 7)
```

```
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
When jit-compiling jnp.repeat, the total number of repeats must be static. To fix this, either specify a static value for `repeats`, or pass a static value to `total_repeat_length`.
The error occurred while tracing the function f at /var/folders/08/4p5p665j4d966tr7nvr0v24c0000gn/T/ipykernel_24749/3851190413.py:4 for jit. This concrete value was not available in Python because it depends on the value of the argument 'n'.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
```

In order to compile a function, JAX needs to know how big everything is. And right now it does not know what `n` is. This shows itself through the `ConcretizationTypeError`, which basically says that as JAX was looking through your code it found something it can’t manipulate. In this case, it was in the `jnp.repeat` function.

We can fix this problem by declaring this parameter `static`.

```
f_jit = jit(f, static_argnums=(1,))
print(f_jit(1.2,7)[0])
```

`8.4`

A static parameter is a parameter value that is known at compile time. If we define `n` to be static, then the first time you call `f_jit(x, 7)` it will compile and then it will reuse the compiled code for any other value of `x`. If we then call `f_jit(x, 9)`, the code will *compile again*.

To see this, we can make use of a JAX oddity: if a function prints something^{7}, then it will only be printed upon compilation and never again. This means that we can’t do *debug by print*. But on the upside, it’s easy to check when things are compiling.

```
def f2(x, n):
    print(f"compiling: n = {n}")
    return f(x, n)[0]

f2_jit = jit(f2, static_argnums=(1,))
print(f2_jit(1.2, 7))
print(f2_jit(1.8, 7))
print(f2_jit(1.2, 9))
print(f2_jit(1.8, 7))
```

```
compiling: n = 7
8.4
12.6
compiling: n = 9
10.799999
12.6
```

This is a perfectly ok solution as long as the static parameters don’t change very often. In our context, this is going to have to do with the sparsity pattern.

Finally, we can talk about `jax.lax.fori_loop`, the in-built for loop. This is basically a convenience wrapper for `jax.lax.scan` (when `lower` and `upper` are static) or `jax.lax.while_loop` (when they are not). The Python pseudocode is

```
def jax_lax_fori_loop(lower, upper, body_func, init):
    out = init
    for i in range(lower, upper):
        out = body_func(i, out)
    return out
```
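A one-liner sketch of `fori_loop` in action (my own toy example): the second argument of the body is the carried accumulator.

```python
from jax import lax
import jax.numpy as jnp

# Accumulate the sum of squares 0^2 + 1^2 + ... + 9^2 in the carry.
total = lax.fori_loop(0, 10, lambda i, acc: acc + i * i, jnp.float32(0.0))
print(total)  # 285.0
```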

To close out this bit where I repeat the docs, there is also a traceable if/else, `jax.lax.cond`, which has the pseudocode

```
def jax_lax_cond(pred, true_fun, false_fun, val):
    if pred:
        return true_fun(val)
    else:
        return false_fun(val)
```
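A quick sketch of `lax.cond` (my own example, not from the post). Both branches must return the same shape and dtype, and only the selected branch is executed:

```python
from jax import lax
import jax.numpy as jnp

# A traceable branch: a Python if/else on a traced value would fail
# under jit, lax.cond does not.
def safe_sqrt(x):
    x = jnp.asarray(x, dtype=jnp.float32)
    return lax.cond(x >= 0, jnp.sqrt, jnp.zeros_like, x)

print(safe_sqrt(4.0))   # 2.0
print(safe_sqrt(-1.0))  # 0.0
```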

In order to build a JAX-traceable sparse Cholesky factorisation of a matrix A, we are going to need to build up a few moving parts.

1. Build the elimination tree of A and find the number of non-zeros in each column of the factor L.
2. Build the *symbolic factorisation*^{8} of A (aka the location of the non-zeros of L).
3. Do the actual numerical decomposition.

In the previous post we did not explicitly form the elimination tree. Instead, I used dynamic memory allocation. This time I’m being more mature.

The elimination tree^{9} is a (forest of) rooted tree(s) that compactly represents the non-zero pattern of the Cholesky factor L. In particular, the elimination tree has the property that, for any `i > j`, `L[i, j] != 0` if and only if there is a path from `j` to `i` in the tree. Or, in the language of trees, `L[i, j] != 0` if and only if `j` is a descendant of `i` in the tree T.

We can describe^{10} T by listing the parent of each node. The parent node of `j` in the tree is the smallest `i > j` with `L[i, j] != 0`.
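That parent rule is easy to sanity-check on a dense pattern. Here is a small sketch (the 4x4 pattern is a made-up example) that reads the parents straight off a lower-triangular non-zero pattern:

```python
import numpy as np

# Hypothetical 4x4 lower-triangular non-zero pattern (True = non-zero).
L_pattern = np.array([
    [1, 0, 0, 0],
    [1, 1, 0, 0],
    [0, 1, 1, 0],
    [1, 0, 1, 1],
], dtype=bool)

# parent[j] = smallest i > j with L[i, j] != 0 (or -1 for a root).
def etree_from_pattern(L):
    n = L.shape[0]
    parent = np.full(n, -1)
    for j in range(n):
        below = np.nonzero(L[j + 1:, j])[0]
        if below.size > 0:
            parent[j] = j + 1 + below[0]
    return parent

print(etree_from_pattern(L_pattern))  # [ 1  2  3 -1]
```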

We can turn this into an algorithm. An efficient version, which is described in Tim Davis’ book, takes nearly O(|A|) operations. But I’m going to program up a slower one that takes O(|L|) operations, but has the added benefit^{11} of giving me the column counts for free.

To do this, we are going to walk the tree and dynamically add up the column counts as we go.

To start off, let’s do this in standard python so that we can see what the algorithm looks like. The key concept is that if we write T_j as the elimination tree encoding the structure of^{12} `L[:j, :j]`, then we can ask how this tree connects with node `j`.

A theorem gives a very simple answer to this.

**Theorem 1** If `i > j`, then `L[i, j] != 0` implies that `j` is a descendant of `i` in T_i. In particular, that means that there is a directed path in T_i from `j` to `i`.

This tells us how T_{j-1} connects with node `j`: for each non-zero element of the `j`th row of A, say `A[j, k] != 0` with `k < j`, there must be a path in T_j from `k` to `j`. So we can walk up T_{j-1} from `k` and we will eventually get to a node that has no parent in T_{j-1}. Because there *must* be a path from `k` to `j` in T_j, it means that the parent of this terminal node must be `j`.

As with everything Cholesky related, this works because the algorithm proceeds from left to right, which in this case means that the node label associated with *any* descendant of `j` is always less than `j`.

The algorithm is then a fairly run-of-the-mill^{13} tree traversal, where we keep track of where we have been so we don’t double count our columns.

Probably the most important thing here is that I am using the *full* sparse matrix rather than just its lower triangle. This is, basically, convenience. I need access to the left half of the `j`th row of A, which is conveniently the same as the top half of the `j`th column. And sometimes you just don’t want to be dicking around with swapping between row- and column-based representations.

```
import numpy as np

def etree_base(A_indices, A_indptr):
    n = len(A_indptr) - 1
    parent = [-1] * n
    mark = [-1] * n
    col_count = [1] * n
    for j in range(n):
        mark[j] = j
        for indptr in range(A_indptr[j], A_indptr[j+1]):
            node = A_indices[indptr]
            while node < j and mark[node] != j:
                if parent[node] == -1:
                    parent[node] = j
                mark[node] = j
                col_count[node] += 1
                node = parent[node]
    return (parent, col_count)
```

To convince ourselves this works, let’s run an example and compare the column counts we get to our previous method.

```
from scipy import sparse
import scipy as sp

def make_matrix(n):
    one_d = sparse.diags([[-1.]*(n-2), [2.]*n, [-1.]*(n-2)], [-2, 0, 2])
    A = (sparse.kronsum(one_d, one_d) + sparse.eye(n*n))
    A_csc = A.tocsc()
    A_csc.eliminate_zeros()
    A_lower = sparse.tril(A_csc, format="csc")
    A_index = A_lower.indices
    A_indptr = A_lower.indptr
    A_x = A_lower.data
    return (A_index, A_indptr, A_x, A_csc)

def _symbolic_factor(A_indices, A_indptr):
    # Assumes A_indices and A_indptr index the lower triangle of A ONLY.
    n = len(A_indptr) - 1
    L_sym = [np.array([], dtype=int) for j in range(n)]
    children = [np.array([], dtype=int) for j in range(n)]
    for j in range(n):
        L_sym[j] = A_indices[A_indptr[j]:A_indptr[j + 1]]
        for child in children[j]:
            tmp = L_sym[child][L_sym[child] > j]
            L_sym[j] = np.unique(np.append(L_sym[j], tmp))
        if len(L_sym[j]) > 1:
            p = L_sym[j][1]
            children[p] = np.append(children[p], j)
    L_indptr = np.zeros(n+1, dtype=int)
    L_indptr[1:] = np.cumsum([len(x) for x in L_sym])
    L_indices = np.concatenate(L_sym)
    return L_indices, L_indptr
```

```
# A_indices/A_indptr are the lower triangle, A is the entire matrix
A_indices, A_indptr, A_x, A = make_matrix(37)
parent, col_count = etree_base(A.indices, A.indptr)
L_indices, L_indptr = _symbolic_factor(A_indices, A_indptr)
true_parent = L_indices[L_indptr[:-2] + 1]
true_parent[np.where(np.diff(L_indptr[:-1]) == 1)] = -1
print(all(x == y for (x,y) in zip(parent[:-1], true_parent)))
true_col_count = np.diff(L_indptr)
print(all(true_col_count == col_count))
```

```
True
True
```

Excellent. Now we just need to convert it to JAX.

Or do we?

To be honest, this is a little pointless. This function is only run once per matrix so we won’t really get much speedup^{14} from compilation.

Nevertheless, we might try.

```
@jit
def etree(A_indices, A_indptr):
    ## innermost while loop
    def body_while(val):
        j, node, parent, col_count, mark = val
        update_parent = lambda x: x[0].at[x[1]].set(x[2])
        parent = lax.cond(lax.eq(parent[node], -1), update_parent, lambda x: x[0], (parent, node, j))
        mark = mark.at[node].set(j)
        col_count = col_count.at[node].add(1)
        return (j, parent[node], parent, col_count, mark)

    def cond_while(val):
        j, node, parent, col_count, mark = val
        return lax.bitwise_and(lax.lt(node, j), lax.ne(mark[node], j))

    ## Inner for loop
    def body_inner_for(indptr, val):
        j, A_indices, A_indptr, parent, col_count, mark = val
        node = A_indices[indptr]
        j, node, parent, col_count, mark = lax.while_loop(cond_while, body_while, (j, node, parent, col_count, mark))
        return (j, A_indices, A_indptr, parent, col_count, mark)

    ## Outer for loop
    def body_out_for(j, val):
        A_indices, A_indptr, parent, col_count, mark = val
        mark = mark.at[j].set(j)
        j, A_indices, A_indptr, parent, col_count, mark = lax.fori_loop(A_indptr[j], A_indptr[j+1], body_inner_for, (j, A_indices, A_indptr, parent, col_count, mark))
        return (A_indices, A_indptr, parent, col_count, mark)

    ## Body of code
    n = len(A_indptr) - 1
    parent = jnp.repeat(-1, n)
    mark = jnp.repeat(-1, n)
    col_count = jnp.repeat(1, n)
    init = (A_indices, A_indptr, parent, col_count, mark)
    A_indices, A_indptr, parent, col_count, mark = lax.fori_loop(0, n, body_out_for, init)
    return parent, col_count
```

Wow. That is *ugly*. But let’s see^{15} if it works!

```
parent_jax, col_count_jax = etree(A.indices, A.indptr)
print(all(x == y for (x,y) in zip(parent_jax[:-1], true_parent)))
print(all(true_col_count == col_count_jax))
```

`True`

`True`

Success!

I guess we could ask ourselves if we gained any speed.

Here is the pure python code.

```
import timeit
A_indices, A_indptr, A_x, A = make_matrix(20)
times = timeit.repeat(lambda: etree_base(A.indices, A.indptr),number = 1)
print(f"n = {A.shape[0]}: {[round(t,2) for t in times]}")
A_indices, A_indptr, A_x, A = make_matrix(50)
times = timeit.repeat(lambda: etree_base(A.indices, A.indptr),number = 1)
print(f"n = {A.shape[0]}: {[round(t,2) for t in times]}")
A_indices, A_indptr, A_x, A = make_matrix(200)
times = timeit.repeat(lambda: etree_base(A.indices, A.indptr),number = 1)
print(f"n = {A.shape[0]}: {[round(t,2) for t in times]}")
```

```
n = 400: [0.0, 0.0, 0.0, 0.0, 0.0]
n = 2500: [0.03, 0.03, 0.03, 0.03, 0.03]
n = 40000: [0.83, 0.82, 0.82, 0.82, 0.82]
```

And here is our JAX’d and JIT’d code.

```
A_indices, A_indptr, A_x, A = make_matrix(20)
times = timeit.repeat(lambda: etree(A.indices, A.indptr),number = 1)
print(f"n = {A.shape[0]}: {[round(t,2) for t in times]}")
A_indices, A_indptr, A_x, A = make_matrix(50)
times = timeit.repeat(lambda: etree(A.indices, A.indptr),number = 1)
print(f"n = {A.shape[0]}: {[round(t,2) for t in times]}")
A_indices, A_indptr, A_x, A = make_matrix(200)
times = timeit.repeat(lambda: etree(A.indices, A.indptr),number = 1)
print(f"n = {A.shape[0]}: {[round(t,2) for t in times]}")
parent, col_count= etree(A.indices, A.indptr)
L_indices, L_indptr = _symbolic_factor(A_indices, A_indptr)
A_indices, A_indptr, A_x, A = make_matrix(1000)
times = timeit.repeat(lambda: etree(A.indices, A.indptr),number = 1)
print(f"n = {A.shape[0]}: {[round(t,2) for t in times]}")
```

```
n = 400: [0.13, 0.0, 0.0, 0.0, 0.0]
n = 2500: [0.12, 0.0, 0.0, 0.0, 0.0]
n = 40000: [0.14, 0.02, 0.02, 0.02, 0.02]
n = 1000000: [2.24, 2.11, 2.12, 2.12, 2.12]
```

You can see that there is some decent speedup. For the first three examples, the computation time is dominated by the compilation time, but we see that when the matrix has a million unknowns the compilation time is negligible. At this scale it would probably be worth using the fancy algorithm. That said, it is probably not worth sweating a three-second computation that is only done once when your problem is that big!

Now that we know how many non-zeros there are, it’s time to populate them. Last time, I used some dynamic memory allocation to make this work, but JAX is certainly not going to allow me to do that. So instead I’m going to have to do the worst thing possible: think.

The way that we went about it last time was, to be honest, a bit arse-backwards. The main reason for this is that I did not have access to the elimination tree. But now that we do, we can actually use it.

The trick is to slightly rearrange^{16} the order of operations to get something that is more convenient for working out the structure.

Recall from last time that we used the *left-looking* Cholesky factorisation, which can be written in the dense case as

```
def dense_left_cholesky(A):
    n = A.shape[0]
    L = np.zeros_like(A)
    for j in range(n):
        L[j, j] = np.sqrt(A[j, j] - np.inner(L[j, :j], L[j, :j]))
        L[(j+1):, j] = (A[(j+1):, j] - L[(j+1):, :j] @ L[j, :j].transpose()) / L[j, j]
    return L
```

This is not the only way to organise those operations. An alternative is the *up-looking* Cholesky factorisation, which can be implemented in the dense case as

```
def dense_up_cholesky(A):
    n = A.shape[0]
    L = np.zeros_like(A)
    L[0, 0] = np.sqrt(A[0, 0])
    for i in range(1, n):
        L[i, :i] = (np.linalg.solve(L[:i, :i], A[:i, i])).transpose()
        L[i, i] = np.sqrt(A[i, i] - np.inner(L[i, :i], L[i, :i]))
    return L
```

This is quite a different looking beast! It scans row by row rather than column by column. And while the left-looking algorithm is based on matrix-vector multiplies, the up-looking algorithm is based on triangular solves. So maybe we should pause for a moment to check that these are the same algorithm!

```
A = np.random.rand(15, 15)
A = A + A.transpose()
A = A.transpose() @ A + 1*np.eye(15)
L_left = dense_left_cholesky(A)
L_up = dense_up_cholesky(A)
print(round(np.abs(L_left - L_up).sum(), 2))
```

`0.0`

They are the same!!

The reason for considering the up-looking algorithm is that it gives a slightly nicer description of the non-zeros of row `i`, which will let us find the location of the non-zeros in the whole matrix. In particular, the non-zeros to the left of the diagonal on row `i` correspond to the non-zero indices of the solution `x` to the lower triangular linear system^{17} `L[:i, :i] x = A[:i, i]`. Because A is sparse, this is a system of equations with a sparse right-hand side, rather than the `i` dense equations that we would have in the dense case. That means that the sparsity pattern of `x` will be the union of the sparsity patterns of the columns of L that correspond to the non-zero entries of `A[:i, i]`.

This means two things. Firstly, if `A[j, i] != 0`, then `x[j] != 0`. Secondly, if `x[j] != 0` *and* `L[k, j] != 0`, then `x[k] != 0`. These two facts give us a way of finding the non-zero set of `x` if we remember just one more fact: a definition of the elimination tree is that `L[k, j] != 0` only if `j` is a descendant of `k` in the elimination tree.

This reduces the problem of finding the non-zero elements of row `i` to the problem of finding all of the descendants of `i` in the subtree T_i. And if there is one thing that people who are ok at programming are *excellent* at, it is walking down a damn tree.
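That walk can be sketched in a few lines of pure python. This is my own illustrative sketch (the tiny chain tree is made up), not the post's implementation:

```python
import numpy as np

# Sketch: find the non-zero columns of row i of L by walking up parent
# pointers from each node k < i with A[k, i] != 0.
def row_pattern(parent, a_col_rows, i):
    seen = set()
    for node in a_col_rows:          # non-zeros of A[:i, i]
        while node < i and node not in seen:
            seen.add(node)
            node = parent[node]
    return sorted(seen) + [i]        # the diagonal is always non-zero

# Tiny hypothetical chain tree 0 -> 1 -> 2 -> 3:
parent = np.array([1, 2, 3, -1])
print(row_pattern(parent, [0], 3))   # [0, 1, 2, 3]
```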

So let’s do that. Well, I’ve already done it. In fact, that was how I found the column counts in the first place! With this interpretation, the outer loop is taking us across the rows. And once I am in row `j`^{18}, I then find a starting node `node` (which is a non-zero in the `j`th row of A) and I walk up the tree from that node, checking each time if I’ve actually seen that node^{19} before. If I haven’t seen it before, I add one to the column count of column `node`^{20}.

To allocate the non-zero structure, I just need to replace that counter increment with an assignment.

We will do the pure python version first.

```
def symbolic_cholesky_base(A_indices, A_indptr, parent, col_count):
    n = len(A_indptr) - 1
    col_ptr = np.repeat(1, n+1)
    col_ptr[1:] += np.cumsum(col_count)
    L_indices = np.zeros(sum(col_count), dtype=int)
    L_indptr = np.zeros(n+1, dtype=int)
    L_indptr[1:] = np.cumsum(col_count)
    mark = [-1] * n
    for i in range(n):
        mark[i] = i
        L_indices[L_indptr[i]] = i
        for indptr in range(A_indptr[i], A_indptr[i+1]):
            node = A_indices[indptr]
            while node < i and mark[node] != i:
                mark[node] = i
                L_indices[col_ptr[node]] = i
                col_ptr[node] += 1
                node = parent[node]
    return (L_indices, L_indptr)
```
```

Does it work?

```
A_indices, A_indptr, A_x, A = make_matrix(13)
parent, col_count = etree_base(A.indices, A.indptr)
L_indices, L_indptr = symbolic_cholesky_base(A.indices, A.indptr, parent, col_count)
L_indices_true, L_indptr_true = _symbolic_factor(A_indices, A_indptr)
print(all(x==y for (x,y) in zip(L_indices, L_indices_true)))
print(all(x==y for (x,y) in zip(L_indptr, L_indptr_true)))
```

```
True
True
```

Fabulosa!

Now let’s do the compiled version.

```
from functools import partial

@partial(jit, static_argnums=(4,))
def symbolic_cholesky(A_indices, A_indptr, L_indptr, parent, nnz):
    ## innermost while loop
    def body_while(val):
        i, L_indices, L_indptr, node, parent, col_ptr, mark = val
        mark = mark.at[node].set(i)
        L_indices = L_indices.at[col_ptr[node]].set(i)
        col_ptr = col_ptr.at[node].add(1)
        return (i, L_indices, L_indptr, parent[node], parent, col_ptr, mark)

    def cond_while(val):
        i, L_indices, L_indptr, node, parent, col_ptr, mark = val
        return lax.bitwise_and(lax.lt(node, i), lax.ne(mark[node], i))

    ## Inner for loop
    def body_inner_for(indptr, val):
        i, A_indices, A_indptr, L_indices, L_indptr, parent, col_ptr, mark = val
        node = A_indices[indptr]
        i, L_indices, L_indptr, node, parent, col_ptr, mark = lax.while_loop(cond_while, body_while, (i, L_indices, L_indptr, node, parent, col_ptr, mark))
        return (i, A_indices, A_indptr, L_indices, L_indptr, parent, col_ptr, mark)

    ## Outer for loop
    def body_out_for(i, val):
        A_indices, A_indptr, L_indices, L_indptr, parent, col_ptr, mark = val
        mark = mark.at[i].set(i)
        L_indices = L_indices.at[L_indptr[i]].set(i)
        i, A_indices, A_indptr, L_indices, L_indptr, parent, col_ptr, mark = lax.fori_loop(A_indptr[i], A_indptr[i+1], body_inner_for, (i, A_indices, A_indptr, L_indices, L_indptr, parent, col_ptr, mark))
        return (A_indices, A_indptr, L_indices, L_indptr, parent, col_ptr, mark)

    ## Body of code
    n = len(A_indptr) - 1
    col_ptr = L_indptr + 1
    L_indices = jnp.zeros(nnz, dtype=int)
    mark = jnp.repeat(-1, n)
    init = (A_indices, A_indptr, L_indices, L_indptr, parent, col_ptr, mark)
    A_indices, A_indptr, L_indices, L_indptr, parent, col_ptr, mark = lax.fori_loop(0, n, body_out_for, init)
    return L_indices
```

Now let’s check it works

```
A_indices, A_indptr, A_x, A = make_matrix(20)
parent, col_count = etree(A.indices, A.indptr)
L_indptr = np.zeros(A.shape[0]+1, dtype=int)
L_indptr[1:] = np.cumsum(col_count)
L_indices = symbolic_cholesky(A.indices, A.indptr, L_indptr, parent, nnz = L_indptr[-1])
L_indices_true, L_indptr_true = _symbolic_factor(A_indices, A_indptr)
print(all(L_indices == L_indices_true))
print(all(L_indptr == L_indptr_true))
A_indices, A_indptr, A_x, A = make_matrix(31)
parent, col_count = etree(A.indices, A.indptr)
L_indptr = np.zeros(A.shape[0]+1, dtype=int)
L_indptr[1:] = np.cumsum(col_count)
L_indices = symbolic_cholesky(A.indices, A.indptr, L_indptr, parent, nnz = L_indptr[-1])
L_indices_true, L_indptr_true = _symbolic_factor(A_indices, A_indptr)
print(all(L_indices == L_indices_true))
print(all(L_indptr == L_indptr_true))
```

```
True
True
```

```
True
True
```

Success!

One *minor* problem. This is slow. as. balls.

```
A_indices, A_indptr, A_x, A = make_matrix(50)
parent, col_count = etree_base(A.indices, A.indptr)
times = timeit.repeat(lambda: symbolic_cholesky_base(A.indices, A.indptr, parent, col_count),number = 1)
print(f"n = {A.shape[0]}: {[round(t,2) for t in times]}")
A_indices, A_indptr, A_x, A = make_matrix(200)
parent, col_count = etree_base(A.indices, A.indptr)
times = timeit.repeat(lambda: symbolic_cholesky_base(A.indices, A.indptr, parent, col_count),number = 1)
print(f"n = {A.shape[0]}: {[round(t,2) for t in times]}")
```

`n = 2500: [0.05, 0.04, 0.04, 0.04, 0.04]`

`n = 40000: [1.97, 2.09, 2.04, 2.03, 1.92]`

And here is our JAX’d and JIT’d code.

```
A_indices, A_indptr, A_x, A = make_matrix(50)
parent, col_count = etree(A.indices, A.indptr)
L_indptr = np.zeros(A.shape[0]+1, dtype=int)
L_indptr[1:] = np.cumsum(col_count)
times = timeit.repeat(lambda:symbolic_cholesky(A.indices, A.indptr, L_indptr, parent, nnz = L_indptr[-1]),number = 1, repeat = 1)
print(f"n = {A.shape[0]}: {[round(t,2) for t in times]}")
A_indices, A_indptr, A_x, A = make_matrix(200)
parent, col_count = etree(A.indices, A.indptr)
L_indptr = np.zeros(A.shape[0]+1, dtype=int)
L_indptr[1:] = np.cumsum(col_count)
times = timeit.repeat(lambda:symbolic_cholesky(A.indices, A.indptr, L_indptr, parent, nnz = L_indptr[-1]),number = 1, repeat = 1)
print(f"n = {A.shape[0]}: {[round(t,2) for t in times]}")
```

`n = 2500: [0.15]`

`n = 40000: [29.19]`

Oooof. Something is going horribly wrong.

The first thing to check is if it’s the compile time. We can do this by explicitly *lowering* the JIT’d function to its XLA representation and then compiling it.

```
A_indices, A_indptr, A_x, A = make_matrix(50)
parent, col_count = etree(A.indices, A.indptr)
L_indptr = np.zeros(A.shape[0]+1, dtype=int)
L_indptr[1:] = np.cumsum(col_count)
times = timeit.repeat(lambda: jit(partial(symbolic_cholesky, nnz=int(L_indptr[-1]))).lower(A.indices, A.indptr, L_indptr, parent).compile(),number = 1, repeat = 5)
print(f"Compilation time: n = {A.shape[0]}: {[round(t,2) for t in times]}")
A_indices, A_indptr, A_x, A = make_matrix(200)
parent, col_count = etree(A.indices, A.indptr)
L_indptr = np.zeros(A.shape[0]+1, dtype=int)
L_indptr[1:] = np.cumsum(col_count)
times = timeit.repeat(lambda: jit(partial(symbolic_cholesky, nnz=int(L_indptr[-1]))).lower(A.indices, A.indptr, L_indptr, parent).compile(),number = 1, repeat = 5)
print(f"Compilation time: n = {A.shape[0]}: {[round(t,2) for t in times]}")
```

`Compilation time: n = 2500: [0.15, 0.15, 0.15, 0.16, 0.15]`

`Compilation time: n = 40000: [0.16, 0.15, 0.14, 0.16, 0.15]`

It is not the compile time.

And that is actually a good thing, because it suggests that we aren’t having problems with the compiler unrolling all of our wonderful loops! But it does mean that we have to look a bit deeper into the code. Some smart people would probably be able to look at the `jaxpr` intermediate representation to diagnose the problem. But I couldn’t see anything there.

Instead I thought *if I were a clever, efficient compiler, what would I have problems with?*. And the answer is the classic sparse matrix answer: indirect indexing.

The only structural difference between the `etree` function and the `symbolic_cholesky` function is this line in the `body_while()` function:

` L_indices = L_indices.at[col_ptr[node]].set(i)`

In order to evaluate this code, the compiler has to resolve *two levels* of indirection. By contrast, the indexing in `etree()` was always direct. So let’s see what happens if we take the same function and remove that double indirection.

```
@partial(jit, static_argnums=(4,))
def test_fun(A_indices, A_indptr, L_indptr, parent, nnz):
    ## innermost while loop
    def body_while(val):
        i, L_indices, L_indptr, node, parent, col_ptr, mark = val
        mark = mark.at[node].set(i)
        L_indices = L_indices.at[node].set(i)  # direct indexing only
        col_ptr = col_ptr.at[node].add(1)
        return (i, L_indices, L_indptr, parent[node], parent, col_ptr, mark)

    def cond_while(val):
        i, L_indices, L_indptr, node, parent, col_ptr, mark = val
        return lax.bitwise_and(lax.lt(node, i), lax.ne(mark[node], i))

    ## Inner for loop
    def body_inner_for(indptr, val):
        i, A_indices, A_indptr, L_indices, L_indptr, parent, col_ptr, mark = val
        node = A_indices[indptr]
        i, L_indices, L_indptr, node, parent, col_ptr, mark = lax.while_loop(cond_while, body_while, (i, L_indices, L_indptr, node, parent, col_ptr, mark))
        return (i, A_indices, A_indptr, L_indices, L_indptr, parent, col_ptr, mark)

    ## Outer for loop
    def body_out_for(i, val):
        A_indices, A_indptr, L_indices, L_indptr, parent, col_ptr, mark = val
        mark = mark.at[i].set(i)
        L_indices = L_indices.at[L_indptr[i]].set(i)
        i, A_indices, A_indptr, L_indices, L_indptr, parent, col_ptr, mark = lax.fori_loop(A_indptr[i], A_indptr[i+1], body_inner_for, (i, A_indices, A_indptr, L_indices, L_indptr, parent, col_ptr, mark))
        return (A_indices, A_indptr, L_indices, L_indptr, parent, col_ptr, mark)

    ## Body of code
    n = len(A_indptr) - 1
    col_ptr = L_indptr + 1
    L_indices = jnp.zeros(nnz, dtype=int)
    mark = jnp.repeat(-1, n)
    init = (A_indices, A_indptr, L_indices, L_indptr, parent, col_ptr, mark)
    A_indices, A_indptr, L_indices, L_indptr, parent, col_ptr, mark = lax.fori_loop(0, n, body_out_for, init)
    return L_indices

A_indices, A_indptr, A_x, A = make_matrix(50)
parent, col_count = etree(A.indices, A.indptr)
L_indptr = np.zeros(A.shape[0]+1, dtype=int)
L_indptr[1:] = np.cumsum(col_count)
times = timeit.repeat(lambda: test_fun(A.indices, A.indptr, L_indptr, parent, nnz=L_indptr[-1]), number=1, repeat=1)
print(f"n = {A.shape[0]}: {[round(t,2) for t in times]}")
A_indices, A_indptr, A_x, A = make_matrix(200)
parent, col_count = etree(A.indices, A.indptr)
L_indptr = np.zeros(A.shape[0]+1, dtype=int)
L_indptr[1:] = np.cumsum(col_count)
times = timeit.repeat(lambda: test_fun(A.indices, A.indptr, L_indptr, parent, nnz=L_indptr[-1]), number=1, repeat=1)
print(f"n = {A.shape[0]}: {[round(t,2) for t in times]}")
```

`n = 2500: [0.14]`

`n = 40000: [0.17]`

That isn’t conclusive, but it does indicate that this might^{21} be the problem.

And this is a *big* problem for us! The sparse Cholesky algorithm has similar amounts of indirection. So we need to fix it.

Now. I want to pretend that I’ve got elegant ideas about this. But I don’t. So let’s just do it. The most obvious thing to do is to use the algorithm to get the non-zero structure of the *rows* of L. These are the things that are being indexed by `col_ptr[node]`, so if we have these explicitly we don’t need multiple indirection. We also don’t need a while loop.

In fact, if we have the non-zero structure of the rows of L, we can turn that into the non-zero structure of the columns in linear-ish^{22} time.
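For the curious, that rows-to-columns conversion is just a counting pass plus a placement pass. This is my own illustrative sketch (function and variable names are made up), not the code used later in the post:

```python
import numpy as np

# Sketch: turn a CSR-style row pattern into a CSC-style column pattern.
def rows_to_cols(row_indptr, row_indices, n_cols):
    col_count = np.bincount(row_indices, minlength=n_cols)
    col_indptr = np.zeros(n_cols + 1, dtype=int)
    col_indptr[1:] = np.cumsum(col_count)
    next_slot = col_indptr[:-1].copy()      # next free slot in each column
    col_indices = np.zeros(len(row_indices), dtype=int)
    for i in range(len(row_indptr) - 1):
        for ptr in range(row_indptr[i], row_indptr[i + 1]):
            j = row_indices[ptr]
            col_indices[next_slot[j]] = i   # row i appears in column j
            next_slot[j] += 1
    return col_indptr, col_indices

# Rows: {0}, {0, 1}, {1, 2}
col_indptr, col_indices = rows_to_cols([0, 1, 3, 5], np.array([0, 0, 1, 1, 2]), 3)
print(col_indptr)   # [0 2 4 5]
print(col_indices)  # [0 1 1 2 2]
```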

All we need to do is make sure that our `etree()` function is also counting the number of non-zeros in each row.

```
@jit
def etree(A_indices, A_indptr):
    ## innermost while loop
    def body_while(val):
        j, node, parent, col_count, row_count, mark = val
        update_parent = lambda x: x[0].at[x[1]].set(x[2])
        parent = lax.cond(lax.eq(parent[node], -1), update_parent, lambda x: x[0], (parent, node, j))
        mark = mark.at[node].set(j)
        col_count = col_count.at[node].add(1)
        row_count = row_count.at[j].add(1)
        return (j, parent[node], parent, col_count, row_count, mark)

    def cond_while(val):
        j, node, parent, col_count, row_count, mark = val
        return lax.bitwise_and(lax.lt(node, j), lax.ne(mark[node], j))

    ## Inner for loop
    def body_inner_for(indptr, val):
        j, A_indices, A_indptr, parent, col_count, row_count, mark = val
        node = A_indices[indptr]
        j, node, parent, col_count, row_count, mark = lax.while_loop(cond_while, body_while, (j, node, parent, col_count, row_count, mark))
        return (j, A_indices, A_indptr, parent, col_count, row_count, mark)

    ## Outer for loop
    def body_out_for(j, val):
        A_indices, A_indptr, parent, col_count, row_count, mark = val
        mark = mark.at[j].set(j)
        j, A_indices, A_indptr, parent, col_count, row_count, mark = lax.fori_loop(A_indptr[j], A_indptr[j+1], body_inner_for, (j, A_indices, A_indptr, parent, col_count, row_count, mark))
        return (A_indices, A_indptr, parent, col_count, row_count, mark)

    ## Body of code
    n = len(A_indptr) - 1
    parent = jnp.repeat(-1, n)
    mark = jnp.repeat(-1, n)
    col_count = jnp.repeat(1, n)
    row_count = jnp.repeat(1, n)
    init = (A_indices, A_indptr, parent, col_count, row_count, mark)
    A_indices, A_indptr, parent, col_count, row_count, mark = lax.fori_loop(0, n, body_out_for, init)
    return (parent, col_count, row_count)
```

Let’s check that the code is actually doing what I want.

```
A_indices, A_indptr, A_x, A = make_matrix(57)
parent, col_count, row_count = etree(A.indices, A.indptr)
L_indices, L_indptr = _symbolic_factor(A_indices, A_indptr)
true_parent = L_indices[L_indptr[:-2] + 1]
true_parent[np.where(np.diff(L_indptr[:-1]) == 1)] = -1
print(all(x == y for (x,y) in zip(parent[:-1], true_parent)))
true_col_count = np.diff(L_indptr)
print(all(true_col_count == col_count))
true_row_count = np.array([len(np.where(L_indices == i)[0]) for i in range(57**2)])
print(all(true_row_count == row_count))
```

```
True
True
True
```

Excellent! With this we can modify our previous function to give us the row-indices of the non-zero pattern instead. Just for further chaos, please note that we are using a CSC representation of `A` to get a CSR representation of `L`.

Once again, we will prototype in pure Python and then translate to JAX. The thing to look out for this time is that we *know* how many non-zeros there are in a row and we know where we need to put them. This suggests that we can compute these things in `body_inner_for` and then do a vectorised version of our indirect indexing. This should compile down to a single XLA `scatter` call, which will reduce the overall number of `scatter` calls from one per non-zero of `L` to one per row. And hopefully this will fix things.
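To see the difference, compare element-wise updates with a single vectorised update (a toy sketch of the idea, not the real symbolic factorisation):

```python
import jax.numpy as jnp

out = jnp.zeros(6, dtype=int)
row_ind = jnp.array([1, 3, 4])

# One vectorised scatter instead of a Python loop of single-element .at[].set()
out = out.at[row_ind].set(7)
print(out.tolist())  # [0, 7, 0, 7, 7, 0]
```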

```
def symbolic_cholesky2_base(A_indices, A_indptr, L_indptr, row_count, parent, nnz):
    n = len(A_indptr) - 1
    col_ptr = L_indptr + 1
    L_indices = np.zeros(L_indptr[-1], dtype=int)
    mark = [-1] * n
    for i in range(n):
        mark[i] = i
        row_ind = np.repeat(nnz + 1, row_count[i])
        row_ind[-1] = L_indptr[i]
        counter = 0
        for indptr in range(A_indptr[i], A_indptr[i+1]):
            node = A_indices[indptr]
            while node < i and mark[node] != i:
                mark[node] = i
                row_ind[counter] = col_ptr[node]
                col_ptr[node] += 1
                node = parent[node]
                counter += 1
        L_indices[row_ind] = i
    return L_indices

A_indices, A_indptr, A_x, A = make_matrix(13)
parent, col_count, row_count = etree(A.indices, A.indptr)
L_indptr = np.zeros(A.shape[0]+1, dtype=int)
L_indptr[1:] = np.cumsum(col_count)
L_indices = symbolic_cholesky2_base(A.indices, A.indptr, L_indptr, row_count, parent, L_indptr[-1])
L_indices_true, L_indptr_true = _symbolic_factor(A_indices, A_indptr)
print(all(x == y for (x, y) in zip(L_indices, L_indices_true)))
print(all(x == y for (x, y) in zip(L_indptr, L_indptr_true)))
```

```
True
True
```

Excellent. Now let’s JAX this. The JAX-heads among you will notice that we have a subtle^{23} problem: in a `fori_loop`, JAX does not treat `i` as static, which means that the length of the repeat (`row_count[i]`) can never be static and it therefore can’t be traced.

Shit.
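To see the problem in isolation, here’s a minimal made-up example of an array whose length depends on a traced value:

```python
import jax
import jax.numpy as jnp

@jax.jit
def make_row_ind(row_count, i):
    # The output length depends on a traced value, so JAX cannot
    # build an array whose shape it only learns at run time.
    return jnp.repeat(-1, row_count[i])

try:
    make_row_ind(jnp.array([2, 3, 4]), 1)
    traced_ok = True
except Exception:
    traced_ok = False
print(traced_ok)  # False: the repeat length must be static
```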

It is hard to think of a good option here. A few months back Junpeng Lao^{24} sent me a script with his attempts at making the Cholesky stuff JAX transformable. And he hit the same problem. I was, in an act of hubris, trying very hard to not end up here. But that was tragically slow. So here we are.

He came up with two methods.

1. Pad out `row_ind` so it’s always long enough. This only costs memory. The maximum size of `row_ind` is `n`. Unfortunately, this maximum is attained whenever `A` has a dense row. Sadly, for Bayesian^{25} linear mixed models this will happen if we put Gaussian priors on the covariate coefficients^{26} and we try to marginalise them out with the other multivariate Gaussian parts. It is possible to write routines that deal with dense rows and columns explicitly, but it’s a pain in the arse.
2. Do some terrifying work with `lax.scan` and dynamic slicing.

I’m going to try the first of these options.
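The padding trick in miniature (a made-up toy, not the real routine): fill `row_ind` to a static maximum length with an out-of-range index, and let `mode="drop"` discard the padding during the scatter:

```python
import jax.numpy as jnp

nnz, max_row = 5, 4
L_indices = jnp.ones(nnz, dtype=int) * (-1)

# Pad to the static length max_row with the out-of-range index nnz + 1
row_ind = jnp.repeat(nnz + 1, max_row)
row_ind = row_ind.at[0].set(1).at[1].set(3)  # only two real entries this row

# With mode="drop", out-of-bounds indices are silently skipped, not clamped
L_indices = L_indices.at[row_ind].set(9, mode="drop")
print(L_indices.tolist())  # [-1, 9, -1, 9, -1]
```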

```
@partial(jit, static_argnums = (5, 6))
def symbolic_cholesky2(A_indices, A_indptr, L_indptr, row_count, parent, nnz, max_row):
    ## innermost while loop
    def body_while(val):
        i, counter, row_ind, node, col_ptr, mark = val
        mark = mark.at[node].set(i)
        row_ind = row_ind.at[counter].set(col_ptr[node])
        col_ptr = col_ptr.at[node].add(1)
        return (i, counter + 1, row_ind, parent[node], col_ptr, mark)

    def cond_while(val):
        i, counter, row_ind, node, col_ptr, mark = val
        return lax.bitwise_and(lax.lt(node, i), lax.ne(mark[node], i))

    ## Inner for loop
    def body_inner_for(indptr, val):
        i, counter, row_ind, parent, col_ptr, mark = val
        node = A_indices[indptr]
        i, counter, row_ind, node, col_ptr, mark = lax.while_loop(cond_while, body_while, (i, counter, row_ind, node, col_ptr, mark))
        return (i, counter, row_ind, parent, col_ptr, mark)

    ## Outer for loop
    def body_out_for(i, val):
        L_indices, parent, col_ptr, mark = val
        mark = mark.at[i].set(i)
        row_ind = jnp.repeat(nnz + 1, max_row)
        row_ind = row_ind.at[row_count[i] - 1].set(L_indptr[i])
        counter = 0
        i, counter, row_ind, parent, col_ptr, mark = lax.fori_loop(A_indptr[i], A_indptr[i+1], body_inner_for, (i, counter, row_ind, parent, col_ptr, mark))
        L_indices = L_indices.at[row_ind].set(i, mode = "drop")
        return (L_indices, parent, col_ptr, mark)

    ## Body of code
    n = len(A_indptr) - 1
    col_ptr = jnp.array(L_indptr + 1)
    L_indices = jnp.ones(nnz, dtype=int) * (-1)
    mark = jnp.repeat(-1, n)
    ## Make everything a jnp array. Really should use jaxtyping
    A_indices = jnp.array(A_indices)
    A_indptr = jnp.array(A_indptr)
    L_indptr = jnp.array(L_indptr)
    row_count = jnp.array(row_count)
    parent = jnp.array(parent)
    init = (L_indices, parent, col_ptr, mark)
    L_indices, parent, col_ptr, mark = lax.fori_loop(0, n, body_out_for, init)
    return L_indices
```

Ok. Let’s see if that worked.

```
A_indices, A_indptr, A_x, A = make_matrix(20)
parent, col_count, row_count = etree(A.indices, A.indptr)
L_indptr = np.zeros(A.shape[0]+1, dtype=int)
L_indptr[1:] = np.cumsum(col_count)
L_indices = symbolic_cholesky2(A.indices, A.indptr, L_indptr, row_count, parent, nnz = int(L_indptr[-1]), max_row = int(max(row_count)))
L_indices_true, L_indptr_true = _symbolic_factor(A_indices, A_indptr)
print(all(L_indices == L_indices_true))
print(all(L_indptr == L_indptr_true))
A_indices, A_indptr, A_x, A = make_matrix(31)
parent, col_count, row_count = etree(A.indices, A.indptr)
L_indptr = np.zeros(A.shape[0]+1, dtype=int)
L_indptr[1:] = np.cumsum(col_count)
L_indices = symbolic_cholesky2(A.indices, A.indptr, L_indptr, row_count, parent, nnz = int(L_indptr[-1]), max_row = int(max(row_count)))
L_indices_true, L_indptr_true = _symbolic_factor(A_indices, A_indptr)
print(all(L_indices == L_indices_true))
print(all(L_indptr == L_indptr_true))
```

```
True
True
```

```
True
True
```

Ok. Once more into the breach. Is this any better?

```
A_indices, A_indptr, A_x, A = make_matrix(50)
parent, col_count, row_count = etree(A.indices, A.indptr)
L_indptr = np.zeros(A.shape[0]+1, dtype=int)
L_indptr[1:] = np.cumsum(col_count)
times = timeit.repeat(lambda: symbolic_cholesky2(A.indices, A.indptr, L_indptr, row_count, parent, nnz = int(L_indptr[-1]), max_row = int(max(row_count))), number = 1, repeat = 1)
print(f"n = {A.shape[0]}: {[round(t,2) for t in times]}")
A_indices, A_indptr, A_x, A = make_matrix(200)
parent, col_count, row_count = etree(A.indices, A.indptr)
L_indptr = np.zeros(A.shape[0]+1, dtype=int)
L_indptr[1:] = np.cumsum(col_count)
times = timeit.repeat(lambda: symbolic_cholesky2(A.indices, A.indptr, L_indptr, row_count, parent, nnz = int(L_indptr[-1]), max_row = int(max(row_count))), number = 1, repeat = 1)
print(f"n = {A.shape[0]}: {[round(t,2) for t in times]}")
# A_indices, A_indptr, A_x, A = make_matrix(300)
# parent, col_count, row_count = etree(A.indices, A.indptr)
# L_indptr = np.zeros(A.shape[0]+1, dtype=int)
# L_indptr[1:] = np.cumsum(col_count)
# times = timeit.repeat(lambda: symbolic_cholesky2(A.indices, A.indptr, L_indptr, row_count, parent, nnz = int(L_indptr[-1]), max_row = int(max(row_count))), number = 1, repeat = 1)
# print(f"n = {A.shape[0]}: {[round(t,2) for t in times]}")
```

`n = 2500: [0.28]`

`n = 40000: [28.31]`

Fuck.

Right. Let’s try again. What if, instead of doing all those scatters, we just store two vectors and sort? Because at this point I will try fucking anything. What if we just list out the column index and row index as we find them (aka build the sparse matrix in COO^{27} format)? The `jax.experimental.sparse` module has support for (blocked) COO objects but doesn’t implement this transformation. `scipy.sparse` has a fast conversion routine, so I’m going to use it. In the interest of being 100% JAX, I tried a version with `index[1][jnp.lexsort((index[1], index[0]))]`, which basically does the same thing but is a lot slower.
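The scipy conversion I’m leaning on is just this (a tiny standalone example):

```python
import numpy as np
from scipy import sparse

# COO triplets in whatever order the algorithm discovers them
rows = np.array([2, 0, 1, 2, 1])
cols = np.array([0, 0, 1, 1, 0])

# tocsc() sorts the entries into column-major order with sorted row indices
A = sparse.coo_array((np.ones(5), (rows, cols)), shape=(3, 3)).tocsc()
print(A.indices.tolist())  # [0, 1, 2, 1, 2]
print(A.indptr.tolist())   # [0, 3, 5, 5]
```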

```
def symbolic_cholesky3(A_indices, A_indptr, L_indptr, parent, nnz):
    @partial(jit, static_argnums = (4,))
    def _inner(A_indices_, A_indptr_, L_indptr, parent, nnz):
        ## Make everything a jnp array. Really should use jaxtyping
        A_indices_ = jnp.array(A_indices_)
        A_indptr_ = jnp.array(A_indptr_)
        L_indptr = jnp.array(L_indptr)
        parent = jnp.array(parent)

        ## innermost while loop
        def body_while(val):
            index, i, counter, node, mark = val
            mark = mark.at[node].set(i)
            index[0] = index[0].at[counter].set(node)  # column
            index[1] = index[1].at[counter].set(i)     # row
            return (index, i, counter + 1, parent[node], mark)

        def cond_while(val):
            index, i, counter, node, mark = val
            return lax.bitwise_and(lax.lt(node, i), lax.ne(mark[node], i))

        ## Inner for loop
        def body_inner_for(indptr, val):
            index, i, counter, mark = val
            node = A_indices_[indptr]
            index, i, counter, node, mark = lax.while_loop(cond_while, body_while, (index, i, counter, node, mark))
            return (index, i, counter, mark)

        ## Outer for loop
        def body_out_for(i, val):
            index, counter, mark = val
            mark = mark.at[i].set(i)
            index[0] = index[0].at[counter].set(i)
            index[1] = index[1].at[counter].set(i)
            counter = counter + 1
            index, i, counter, mark = lax.fori_loop(A_indptr_[i], A_indptr_[i+1], body_inner_for, (index, i, counter, mark))
            return (index, counter, mark)

        ## Body of code
        n = len(A_indptr_) - 1
        mark = jnp.repeat(-1, n)
        index = [jnp.zeros(nnz, dtype=int), jnp.zeros(nnz, dtype=int)]
        counter = 0
        init = (index, counter, mark)
        index, counter, mark = lax.fori_loop(0, n, body_out_for, init)
        return index

    n = len(A_indptr) - 1
    index = _inner(A_indices, A_indptr, L_indptr, parent, nnz)
    ## return index[1][jnp.lexsort((index[1], index[0]))]
    return sparse.coo_array((np.ones(nnz), (index[1], index[0])), shape = (n, n)).tocsc().indices
```

First things first, let’s check how fast this is.

```
A_indices, A_indptr, A_x, A = make_matrix(15)
parent, col_count, row_count = etree(A.indices, A.indptr)
L_indptr = np.zeros(A.shape[0]+1, dtype=int)
L_indptr[1:] = np.cumsum(col_count)
L_indices = symbolic_cholesky3(A.indices, A.indptr, L_indptr, parent, nnz = int(L_indptr[-1]))
L_indices_true, L_indptr_true = _symbolic_factor(A_indices, A_indptr)
print(all(L_indices == L_indices_true))
print(all(L_indptr == L_indptr_true))
A_indices, A_indptr, A_x, A = make_matrix(31)
parent, col_count, row_count = etree(A.indices, A.indptr)
L_indptr = np.zeros(A.shape[0]+1, dtype=int)
L_indptr[1:] = np.cumsum(col_count)
L_indices = symbolic_cholesky3(A.indices, A.indptr, L_indptr, parent, nnz = int(L_indptr[-1]))
L_indices_true, L_indptr_true = _symbolic_factor(A_indices, A_indptr)
print(all(L_indices == L_indices_true))
print(all(L_indptr == L_indptr_true))
A_indices, A_indptr, A_x, A = make_matrix(50)
parent, col_count, row_count = etree(A.indices, A.indptr)
L_indptr = np.zeros(A.shape[0]+1, dtype=int)
L_indptr[1:] = np.cumsum(col_count)
times = timeit.repeat(lambda: symbolic_cholesky3(A.indices, A.indptr, L_indptr, parent, nnz = int(L_indptr[-1])), number = 1, repeat = 5)
print(f"n = {A.shape[0]}: {[round(t,2) for t in times]}")
A_indices, A_indptr, A_x, A = make_matrix(200)
parent, col_count, row_count = etree(A.indices, A.indptr)
L_indptr = np.zeros(A.shape[0]+1, dtype=int)
L_indptr[1:] = np.cumsum(col_count)
times = timeit.repeat(lambda: symbolic_cholesky3(A.indices, A.indptr, L_indptr, parent, nnz = int(L_indptr[-1])), number = 1, repeat = 5)
print(f"n = {A.shape[0]}: {[round(t,2) for t in times]}")
A_indices, A_indptr, A_x, A = make_matrix(300)
parent, col_count, row_count = etree(A.indices, A.indptr)
L_indptr = np.zeros(A.shape[0]+1, dtype=int)
L_indptr[1:] = np.cumsum(col_count)
times = timeit.repeat(lambda: symbolic_cholesky3(A.indices, A.indptr, L_indptr, parent, nnz = int(L_indptr[-1])), number = 1, repeat = 5)
print(f"n = {A.shape[0]}: {[round(t,2) for t in times]}")
A_indices, A_indptr, A_x, A = make_matrix(1000)
parent, col_count, row_count = etree(A.indices, A.indptr)
L_indptr = np.zeros(A.shape[0]+1, dtype=int)
L_indptr[1:] = np.cumsum(col_count)
times = timeit.repeat(lambda: symbolic_cholesky3(A.indices, A.indptr, L_indptr, parent, nnz = int(L_indptr[-1])), number = 1, repeat = 1)
print(f"n = {A.shape[0]}: {[round(t,2) for t in times]}")
```

```
True
True
True
True
```

`n = 2500: [0.13, 0.13, 0.13, 0.14, 0.14]`

`n = 40000: [0.19, 0.19, 0.19, 0.19, 0.19]`

`n = 90000: [0.43, 0.32, 0.32, 0.33, 0.33]`

`n = 1000000: [11.91]`

You know what? I’ll take it. It’s not perfect; in particular, I would prefer a pure JAX solution. But everything I tried hit hard against the indirect memory access issue. The best I found was using `jnp.lexsort`, but even it had noticeable performance degradation as `nnz` increased, relative to the scipy solution.
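For the record, the pure-JAX ordering works roughly like this toy example: `lexsort` with the column as the primary key and the row as the secondary key produces exactly CSC order:

```python
import jax.numpy as jnp

# COO indices for a tiny pattern, in discovery order
row = jnp.array([2, 0, 1, 2])
col = jnp.array([0, 0, 1, 1])

# lexsort sorts by the last key first: by column, then by row within a column
order = jnp.lexsort((row, col))
print(row[order].tolist())  # [0, 2, 1, 2] -- the CSC row indices
```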

So that’s where I’m going to leave it. I am off my flight and I’ve slept very well and now I’m going to be on holidays for a little while.

The next big thing to do is look at the numerical factorisation. We are going to run headlong into all of the problems we’ve hit today, so that should be fun. The reason I’m putting it in a separate post^{28} is that I want to actually test all of those things out properly.

So next time you can expect:

- Classes! Because frankly this code is getting far too messy, especially now that certain things need to be passed as static arguments. The only reason I’ve avoided it up to now is that I think it hides too much of the algorithm in boilerplate. But now the boilerplate is ruining my life and causing far too many dumb typos^{29}.
- Type hints! Because for a language where types aren’t explicit, they sure are important. Also, since I’m going to class it up, I might as well do it properly.
- Some helper routines! I’m going to need a sparse-matrix scatter operation (aka the structured copy of `A` to have the sparsity pattern of `L`)! And I’m certainly going to need some reorderings^{30}.
- A battle royale between padded and non-padded methods!

It should be a fun time!

If you’re wondering about the break between sparse matrix posts, I realised this pretty much immediately and just didn’t want to deal with it!↩︎

If a person who actually knows how the JAX autodiff works happens across this blog, I’m so sorry.↩︎

omg you guys. So many details↩︎

These are referred to as HLOs (Higher-level operations)↩︎

Instead of doing one pass of reverse-mode, you would need to do $d$ passes of forwards mode to get the gradient with respect to a $d$-dimensional parameter.↩︎

Unlike `jax.lax.while`, which is only forwards differentiable, `jax.lax.scan` is fully differentiable.↩︎

In general, if the function has state.↩︎

This is the version of the symbolic factorisation that is most appropriate for us, as we will be doing a lot of Cholesky factorisations with the same sparsity structure. If we rearrange the algorithm to the up-looking Cholesky decomposition, we only need the column counts and this is also called the symbolic factorisation. This is, incidentally, how Eigen’s sparse Cholesky works.↩︎

Actually it’s a forest↩︎

Because we are talking about a tree, each child node has at most one parent. If it doesn’t have a parent it’s the root of the tree. I remember a lecturer saying that it should be called “father and son” or “mother and daughter” because every child has 2 parents but only one mother or one father. The 2000s were a wild time.↩︎

These can also be computed in approximately time, which is much faster. But the algorithm is, frankly, pretty tricky and I’m not in the mood to program it up. This difference would be quite important if I wasn’t storing the full symbolic factorisation and was instead computing it every time, but in my context it is less clear that this is worth the effort.↩︎

Python notation! This is rows/cols 0 to `j-1`↩︎

Python, it turns out, does not have a `do while` construct because, apparently, everything is empty and life is meaningless.↩︎

The argument for JIT works by amortizing the compile time over several function evaluations. If I wanted to speed this algorithm up, I’d implement the more complex version.↩︎

Obviously it did not work the first time. A good way to debug JIT’d code is to use the Python translations of the control flow primitives. Why? Well, for one thing, there is an annoying tendency for JAX to fail silently when there is an out-of-bounds indexing error. Which happens, just for example, if you replace `node = A_indices[indptr]` with `node = A_indices[A_indptr[indptr]]` because you got a text message half way through the line.↩︎

We will still use the left-looking algorithm for the numerical computation. The two algorithms are equivalent in exact arithmetic and, in particular, have identical sparsity structures.↩︎

I’m mixing 1-based indexing in the maths with 0-based in the code because I think we need more chaos in our lives.↩︎

Yes. I know. I’m swapping the meaning of row and column, but you know that’s because in a symmetric matrix rows and columns are a bit similar. The upper half of column `j` is the left half of row `j`, after all.↩︎

If `mark[node]==j` then I have already found `node` and all of its ancestors in my sweep of row `j`.↩︎

This is because `L[j,node] != 0` by our logic.↩︎

I mean, I’m pretty sure it is. I’m writing this post in order, so I don’t know yet. But surely the compiler can’t reason about the possible values of `node`, which would be the only thing that would speed this up.↩︎

Convert from CSR to `(i, j, val)` (called COO, which has a convenient implementation in `jax.experimental.sparse`) to CSC. This involves a linear pass, a sort, and another linear pass. So it’s `n log n`-ish. Hire me fancy tech companies. I can count. Just don’t ask me to program quicksort.↩︎

Replace “subtle” with “fairly obvious once I realised how it’s converted to a `lax.scan`, but not at all obvious to me originally”.↩︎

This also happens with the profile likelihood in non-Bayesian methods.↩︎

the $\beta$s↩︎

COO stands for *coordinate list* and it’s the least space-efficient of our options. It directly stores 3 length-`nnz` vectors `(row, col, value)`. It’s great for specifying matrices and it’s pretty easy to convert from this format to any of the others.↩︎

other than holiday↩︎

`A_index` and `A.index` are different↩︎

I’m probably going to bind Eigen’s AMD decomposition. I’m certainly not writing it myself.↩︎

BibTeX citation:

```
@online{simpson2022,
  author = {Dan Simpson},
  editor = {},
  title = {Sparse Matrices Part 7a: {Another} Shot at {JAX-ing} the
    {Cholesky} Decomposition},
  date = {2022-12-02},
  url = {https://dansblog.netlify.app/posts/2022-11-27-sparse7/sparse7.html},
  langid = {en}
}
```

For attribution, please cite this work as:

Dan Simpson. 2022. “Sparse Matrices Part 7a: Another Shot at
JAX-Ing the Cholesky Decomposition.” December 2, 2022. https://dansblog.netlify.app/posts/2022-11-27-sparse7/sparse7.html.

This question comes up a bunch. In this context, they were switching from double to single precision^{3} and were a little worried that some of their operations would be a bit more inexact than they were used to. Would this tank MCMC? Would everything still be fine?

Markov chain Monte Carlo (MCMC) is, usually, guess-and-check for people who want to be fancy.

It is a class of algorithms that allows you to construct a^{4} Markov chain that has a given *stationary distribution*^{5} $p(\theta)$. In Bayesian applications, we usually want to choose $p(\theta)$ to be the posterior distribution $p(\theta \mid y)$, but there are other applications of MCMC.

Most^{6} MCMC algorithms live in the Metropolis-Hastings family of algorithms. These methods require only one component: a proposal distribution $q(\theta' \mid \theta)$. Given basically any^{7} proposal distribution, we can go from our current state $\theta_t$ to the new state $\theta_{t+1}$ using the following three steps:

- Propose a potential new state $\theta' \sim q(\cdot \mid \theta_t)$
- Sample a Bernoulli random variable $A$ with success probability $\alpha(\theta_t, \theta') = \min\left\{1,\ \frac{p(\theta')}{p(\theta_t)}\, \frac{q(\theta_t \mid \theta')}{q(\theta' \mid \theta_t)}\right\}$
- Set $\theta_{t+1}$ according to the formula $\theta_{t+1} = A\,\theta' + (1 - A)\,\theta_t$ (accept the proposal if $A = 1$, stay put otherwise)

The acceptance probability^{8} $\alpha(\theta_t, \theta')$ is chosen^{9} to balance^{10} out the proposal $q$ with the target distribution $p$.

You can interpret the two ratios in the acceptance probability separately. The first one prefers proposals from high-density regions over proposals from low-density regions. The second ratio balances this by down-weighting proposed states that were *easy* to propose from the current location. When the proposal is symmetric, ie $q(\theta' \mid \theta) = q(\theta \mid \theta')$, the second ratio is always 1. However, in better algorithms like MALA^{11}, the proposal is not symmetric. If we look at the MALA proposal, $\theta' \sim N\!\left(\theta_t + \tfrac{h}{2} \nabla \log p(\theta_t),\ h I\right)$, it’s pretty easy to see that we are biasing our samples towards the mode of the distribution. If we did not have the second ratio in the acceptance probability we would severely under-sample the tails of the distribution.
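To make those three steps concrete, here’s a minimal random-walk Metropolis sketch. This is my own toy example: the target is a standard normal, and the Gaussian proposal is symmetric so the second ratio cancels:

```python
import numpy as np

rng = np.random.default_rng(0)

def log_p(theta):
    # Standard normal target; a stand-in for a real (unnormalised) posterior
    return -0.5 * theta ** 2

theta, samples = 0.0, []
for _ in range(20000):
    prop = theta + rng.normal()                       # symmetric proposal
    log_alpha = min(0.0, log_p(prop) - log_p(theta))  # log acceptance probability
    if np.log(rng.uniform()) < log_alpha:             # the Bernoulli(alpha) draw
        theta = prop
    samples.append(theta)

print(abs(float(np.mean(samples))) < 0.1)  # the sample mean should be near 0
```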

With this definition in hand, it’s now possible to re-cast the question my friend asked as:

> What happens to my MCMC algorithm if, instead of $\alpha(\theta_t, \theta')$, I accidentally compute some perturbed $\tilde{\alpha}(\theta_t, \theta')$ and use that instead to simulate $A$?

So let’s go about answering that!

Unsurprisingly, this type of question has popped up over and over again in the literature:

- This exact question was asked by Gareth Roberts and Jeff Rosenthal, first^{12} with Peter Schwartz and a second, more^{13}^{14} realistic, time with Laird Breyer. They found that as long as the chain’s convergence is sufficiently nice^{15}, the perturbed chain will converge nicely and have^{16} a central limit theorem.
- About 10 years ago, an absolute orgy^{17}^{18} of research happened around the question *What happens if the acceptance probability is random but unbiased?*. These are the *exact approximate*^{19} or *pseudo-marginal* methods. They have had some success in situations^{20} where the likelihood has a *parameter-dependent* normalising constant that can’t be computed exactly, but can be estimated unbiasedly. The problem with this class of methods is that the extra noise tends to make the Markov chain perform pretty badly^{21}. This limits their practical use to models where we really can’t do anything else^{22}. That said, there is some interesting literature on random sub-sampling of data, both where it doesn’t really work and where it does work.
- A third branch of literature is on truly approximate algorithms. These try to understand what happens if you’re just wrong with $\alpha$ and you don’t do anything to correct it. There are a lot of papers on this, and I’m not going to do anything approaching a thorough review. I have work^{23}^{24} to do. So I will just list two older papers that were influential for me. The first was by Geoff Nicholls, Colin Fox, and Alexis Muir Watt, which looks at what happens when you don’t correct your pseudo-marginal method correctly. It’s a really neat theory paper that is a great presentation^{25} of the concepts. The second paper is by Pierre Alquier, Nial Friel, Richard Everitt, and Aidan Boland, which looks at general approximate Markov chains. They show empirically that these methods work extremely well relative to pseudo-marginal methods in practical settings. There are also some nice results on perturbations of Markov chains in general, for instance this paper by Daniel Rudolf and Nikolaus Schweizer.

So how do I think about noisy Markov chains? Despite all appearances^{26}, I am not really a theory person. So while I know that there’s a massive literature on the stability of Markov chains, it doesn’t really influence how I think about it.

Instead, I think about it in terms of that Nicholls, Fox, and Muir Watt paper. Or, specifically, a talk I saw Colin give at some point that was really clear.

The important thing to recognise is that *it is not important how well you compute* $\alpha$. What is important is whether you get the same outcome. Imagine we have two random variables $A \sim \text{Bernoulli}(\alpha)$ and $\tilde{A} \sim \text{Bernoulli}(\tilde{\alpha})$. If our realisation of $A$ is the same as our realisation of $\tilde{A}$, then we get the same $\theta_{t+1}$. Or, to put it another way, when $A = \tilde{A}$, no one can tell^{27} that it’s an approximate Markov chain.

This means that one way to understand inexact MCMC is to think of the Markov chain $(\theta_t, W_t)$, where^{28} $W_t$ indicates whether or not we made the wrong decision at step $t$. It’s important to note that while $\theta_t$ is marginally a Markov chain, $W_t$ is not. You can actually think of $W_t$ as the observation of a hidden Markov model if you want to. I won’t stop you. Nothing will. There is no morality, there is no law. It is The Purge.

Although we can never actually observe $W_t$, thinking about it is really useful. In particular, we note that until $W_t = 1$ for the first time, the samples of $\theta_t$ are *identical* to a correct Metropolis-Hastings algorithm. After this point, the approximate chain and the (imaginary) exact chain will be different. But we can iterate this argument.

To do this, we can let $T_j$ be the time of the $j$th wrong decision (the $j$th time that $W_t = 1$), with $T_0 = 0$, and define the length of the stretch of chain that is identical to an exact MCMC algorithm started at $\theta_{T_{j-1}}$ by $\tau_j = T_j - T_{j-1}$.

If we run our algorithm for $N$ steps, we can then think of the output as being the same as running $J$ Markov chains of different lengths. The $j$th chain starts at $\theta_{T_{j-1}}$ and has length $\tau_j$. It is worth remembering that these chains are not started from independent points. In particular, if $\tau_j$ is small, then the starting positions of the $j$th and the $(j+1)$th chains will be heavily correlated.

To think about this we need to think about what happens after $k$ steps of a Markov chain. We are going to need the notation $P^k(\theta_0, \cdot)$, which denotes the distribution of the state after $k$ steps of the exact algorithm started at $\theta_0$.

The topic of convergence of Markov chains is a complex business, but we are going to assume that our exact Markov chain is^{29} *geometrically ergodic*, which means that $\|P^k(\theta_0, \cdot) - p\| \leq C\, V(\theta_0)\, \rho^k$ for some function^{30} $V$ and some $\rho < 1$.

Geometric ergodicity is a great condition because, among other things, it ensures that sample means from the Markov chain satisfy a central limit theorem. It’s also bloody impossible to prove. But usually indicators like R-hat do a decent job at suggesting that there might be problems. Also if you are spending a lot of time rejecting proposals in certain parts of the space, there’s a solid chance that you’re not geometrically ergodic.

Now let’s assume that we are interested in computing $E_p[f(\theta)]$ for some nice^{31} function $f$. Then the nice thing about Markov chains is that, give or take^{32}, $|E[f(\theta_k)] - E_p[f(\theta)]| \leq C_f\, V(\theta_0)\, \rho^k$, where $C_f$ might depend on $\theta_0$ if $f$ is unbounded.

This suggests that the overall error is bounded by, roughly, the sum of the corresponding single-chain error bounds over the $J$ virtual chains.

This suggests a few things:

- If $J$ is small relative to $N$, we are going to get *very* similar estimates to just running $J$ parallel Markov chains and combining them *without removing any warm-up iterations*. In particular, if almost all the $\tau_j$ are big, it will be *a lot* like combining $J$ warmed-up *independent* chains.
- Effective sample size and Monte Carlo standard error estimates will potentially be very wrong. This is because instead of computing them based on multiple dependent chains, we are pretending that all of our samples came from a single ergodic Markov chain. Is this a problem? I really don’t know. Again, if the $\tau_j$s are usually large, we will be fine.
- Because $V(\theta)$ can be pretty large when $\theta$ is far out, we might have some problems. It’s easy to imagine cases where we get stuck out in a tail and we just fire off a lot of $W_t = 1$ events when $V(\theta_{T_j})$ is really big. This will be a problem. But also, if we are stuck out in a tail, we are rightly fucked anyway and all of the MCMC diagnostics should be screaming at you. We can take heart that $V$ is usually finite^{33} and not, you know, massive.

So the takeaway from the last section was that if the random variables $\tau_j$ are usually pretty big, then everything will work OK. Intuitively this makes sense. If the $\tau_j$s were always small, it would be very difficult to ever get close to any sort of stationary distribution.

The paper by Nicholls, Fox, and Muir Watt talks about potential sizes for $\tau_j$. The general construction that they use is a *coupling*, which is a bivariate Markov chain $(\theta_t, \tilde{\theta}_t)$ whose two components start from the same position and are updated as follows:

- Propose $\theta' \sim q(\cdot \mid \theta_t)$
- Generate a uniform random number $U \sim \text{Uniform}(0, 1)$
- Update $\theta_{t+1} = \theta'$ if $U < \alpha(\theta_t, \theta')$, and $\theta_{t+1} = \theta_t$ otherwise
- Update $\tilde{\theta}_{t+1} = \theta'$ if $U < \tilde{\alpha}(\tilde{\theta}_t, \theta')$, and $\tilde{\theta}_{t+1} = \tilde{\theta}_t$ otherwise

This Markov chain is coupled in three ways. The chains start at the same value $\theta_0 = \tilde{\theta}_0$, the proposed $\theta'$ is the same for both chains, and the randomness^{34} used to do the accept/reject step is the same. Together, these things mean that $\theta_t = \tilde{\theta}_t$ until the first time the two chains make different accept/reject decisions.
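Here’s a small simulation of that coupled pair (my own toy example: a standard normal target, a random-walk proposal, and a hypothetical constant perturbation `eps` of the acceptance probability):

```python
import numpy as np

rng = np.random.default_rng(0)

def log_p(x):
    return -0.5 * x ** 2

eps = 0.01  # hypothetical perturbation of the acceptance probability

x = y = 0.0   # exact chain and perturbed chain, same starting value
T = None      # first time the two chains make different decisions
for t in range(20000):
    step = rng.normal()   # shared proposal randomness
    u = rng.uniform()     # shared accept/reject randomness
    a_exact = np.exp(min(0.0, log_p(x + step) - log_p(x)))
    a_tilde = min(1.0, np.exp(min(0.0, log_p(y + step) - log_p(y))) + eps)
    x = x + step if u < a_exact else x
    y = y + step if u < a_tilde else y
    if x != y:
        T = t   # the chains have uncoupled
        break

print(T is not None)  # the chains eventually make different decisions
```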

For this coupling construction, we can get the exact distribution of the $\tau_j$. To do this, we remember that we will only make different decisions in the two chains (or uncouple) if $U$ lands on different sides of the two acceptance probabilities. The probability of this happening is $|\alpha(\theta_t, \theta') - \tilde{\alpha}(\theta_t, \theta')|$.

I guess you could write down the distribution of the $\tau_j$ in terms of this. In particular, you get something like a geometric random variable with a state-dependent success probability, but honestly it would be an absolute nightmare.

When people get stuck on probability questions, the natural thing to do is to make the problem so abstract that you can make the answer up. In that spirit, let’s ask a slightly different question: what is the distribution of the *maximal* decoupling time between the exact and the approximate chain? This is the distribution of the longest possible coupling of the two chains over all^{35} possible random sequences such that the distribution of $(\theta_t)$ is the same as our exact Markov chain and the distribution of $(\tilde{\theta}_t)$ is the same as our approximate Markov chain.

This maximal value of $\tau$ is called the *maximal agreement coupling time* or, more whimsically, the MEXIT time. It turns out that getting the distribution of the MEXIT time is … difficult, but we^{36} can construct a random variable that is independent of the chains, bounds the MEXIT time from below almost surely, and has a distribution controlled by the distance between $P$ and $\tilde{P}$, where $P$ is the transition distribution for the exact Markov^{37} chain and $\tilde{P}$ is the transition distribution for the approximate Markov chain.

For a Metropolis-Hastings algorithm, the transition distribution has the form $P(\theta, d\theta') = \alpha(\theta, \theta')\, q(d\theta' \mid \theta) + r(\theta)\, \delta_{\theta}(d\theta')$, where $q$ is the probability measure associated with the proposal density and I have been very explicit about the dependence of the acceptance probability on $\theta$. (The $r(\theta)\, \delta_{\theta}(d\theta')$ term takes into account the probability of starting at $\theta$ and not accepting the proposed state.)

That definition looks pretty nasty, but it’s not too bad: in particular, the infimum only cares about where the two transition kernels differ. This means that the condition simplifies to a statement about the difference between the two acceptance probabilities.

This simplifies further if we assume that the proposal distribution is absolutely continuous and has a strictly positive density. Then it truly does not matter what $q$ is. For the first term, it just cancels, while the second term is monotone^{38} in the acceptance probability, so we can take this term to be either zero or one and get^{39} a bound in terms of $\sup_{\theta, \theta'} |\alpha(\theta, \theta') - \tilde{\alpha}(\theta, \theta')|$.

This is, as the Greeks would say, not too bad.

If, for instance, we know the relative error of the computed acceptance probability, then we can bound this supremum directly, and if we know^{40} the size of the floating point error, we get an explicit bound. Similarly, if we instead control the errors in the log-density and the log-proposal, then we get a corresponding bound.

The nice thing is that we can choose our upper bounds so that they hold uniformly in $\theta$ and $\theta'$, and get an upper bound on the per-step uncoupling probability. It follows that the maximal decoupling time is stochastically bounded below by a geometric random variable with that success probability.

Now this is a bit nasty. It’s an upper bound on the probability of a lower bound on the maximal decoupling time. Probability, eh.

Probably the most useful thing we can get from this is an upper bound on the expected decoupling time $E[\tau]$, which is^{41} of the order of one over the uniform bound on the acceptance probability error.

This confirms our intuition that if the relative error is large, we will have, on average, quite small $\tau$s. It’s not quite enough to show the opposite (good floating point error begets big $\tau$s), but that’s probably true as well.

And that is where we end this saga. There is definitely more that could be said, but I decided to spend exactly one day writing this post and that time is now over.

Usually this is a lie, but it was actually a thing that happened last week↩︎

Don’t judge me (or my friends) based on this. I promise we also talk about other shit.↩︎

Hi GPUs!↩︎

usually reversible, although a lot of cool but not ready for prime time work is being done on non-reversible chains.↩︎

A stationary distribution, if it exists, is the distribution that is preserved by the Markov chain. If $p$ is the stationary distribution and $\theta_0 \sim p$, then if we construct $\theta_1, \theta_2, \ldots$ by running the Markov chain, then for every $t$, the marginal distribution of $\theta_t$ is $p$.↩︎

But critically not all! The dynamic HMC algorithm used in Stan, for instance, is not a Metropolis-Hastings algorithm. Instead of doing an accept/reject step it samples from the proposed trajectory. Betancourt’s long intro to Hamiltonian Monte Carlo covers this very well.↩︎

The conditions for this to work are *very* light. But that’s because the definition of “working” only thinks about what happens after infinitely many steps. To get a practically useful Metropolis-Hastings algorithm, you’ve got to work very hard on choosing your proposal density.↩︎

sometimes called the Hastings correction↩︎

This is not the only choice that will work, but in some sense it is the most efficient one.↩︎

Technically, it is chosen by requiring that the Markov proposal satisfies the detailed balance condition , but everything about that equation is beyond the scope of this particular post.↩︎

Metropolis-adjusted Langevin Algorithm↩︎

Under the assumption that the total floating point error was bounded by a constant ↩︎

This time the assumption was that the rounding error for the acceptance probability at state was bounded by . This is a lot closer to how floating point arithmetic actually works. The trade off is that it requires a tighter condition on the drift function .↩︎

IEEE floating point arithmetic represents a real number using bits. Typically (double precision) or (single precision). You can read a great intro to this on Nick Higham’s blog. But in general, the *best* we can do is represent a real number by a floating point number that satisfies where in single precision and in double precision. Of course, the acceptance probability is a non-linear combination of floating point numbers, so the actual error is going to be more complicated than that. I strongly recommend you read Nick Higham’s book on the subject.↩︎

-geometrically ergodic with some light conditions on ↩︎

Geometric ergodicity implies the existence of a CLT! Which is nice, because all of our intuition about how to use the output from MCMC depends on a CLT.↩︎

Like all good orgies, this one was mostly populated by men↩︎

Yes, I know. My (limited) contribution to this literature was some small contributions to a paper led by Anne-Marie Lyne. But if years of compulsory catholicism taught me anything (other than “If you’re drinking with a nun or an aging homosexual, don’t try to keep up”) it’s that something does not have to be literally true to be morally true.↩︎

We have to slightly redefine the word “exact” to mean “targets the correct stationary distribution” for this name to make sense↩︎

Random graph models and point processes are two great examples↩︎

for instance, it gets stuck for long times at single values↩︎

the aforementioned point process and graph models↩︎

Playing God of War: Ragnarok↩︎

The first run of God of War games was not my cup of tea, but the 2018 game, which is essentially a detailed simulation of what happens when a muscle bear is entrusted with walking an 11 year old up a hill, was really enjoyable. So far this is too.↩︎

Does it talk about involutions for no fucking reason? Of course it does. Read past that.↩︎

Yeah, like I have also read my blog. Think of it as being like social media. It is not a representation of me as a whole person. It’s actually biased towards stuff that I have either found or find difficult.↩︎

A friend of mine has a “No one knows I’m a transexual” t-shirt that she likes to wear to supermarkets.↩︎

Note that both and are computed using the *same* value .↩︎

The norm here is usually either the total variation norm or the -norm. But truly it’s not important for the hand waving.↩︎

In most cases as .↩︎

Bounded and continuous always works. But everything is probably ok for unbounded functions as long as has a pile of finite moments.↩︎

This is roughly true. I basically used the geometric ergodicity bound to bound and summed it up. There are smarter things to do, but it’s close enough for government work. ↩︎

Sometimes, if you squint, this term will kinda, sorta start to look like , which isn’t usually toooo big. But also, sometimes it looks totally different. Theory is wild.↩︎

If you’ve ever wondered how `rbinom(1, 1, p)` works, there you are.↩︎

Think of this as the opposite of an adversarial example. We are trying to find the exact chain that is scared to leave the approximate chain behind. Which is either romantic or creepy, depending on finer details.↩︎

Well not me. Florian Völlering did it in his Theorem 1.4. I most certainly could not have done it.↩︎

Well the result does not need this to be a Markov chain!↩︎

it goes up if ; otherwise it goes down↩︎

The 1 case can basically never happen except in the trivial case where both acceptance probabilities are the same. And if we thought that was going to happen we would’ve done something bloody else.↩︎

The relative error being bounded does not stop the absolute error growing!↩︎

Look above and recognize the Geometric distribution↩︎

BibTeX citation:

```
@online{simpson2022,
author = {Dan Simpson},
editor = {},
title = {MCMC with the Wrong Acceptance Probability},
date = {2022-11-23},
url = {https://dansblog.netlify.app/posts/2022-11-23-wrong-mcmc/wrong-mcmc.html},
langid = {en}
}
```

For attribution, please cite this work as:

Dan Simpson. 2022. “MCMC with the Wrong Acceptance
Probability.” November 23, 2022. https://dansblog.netlify.app/posts/2022-11-23-wrong-mcmc/wrong-mcmc.html.

Paradoxes and counterexamples live in statistics as our morality plays and our ghost stories. They serve as the creepy gas station attendants that populate the roads leading to the curséd woods; existing not to force change on the adventurer, but to signpost potential danger.^{1}

As a rule, we should also look askance at attempts to resolve these paradoxes and counterexamples. That is not what they are for. They are community resources, objects of our collective culture, monuments to thwarted desire.

But sometimes, driven by the endless thirst for content, it’s worth diving down into a counterexample and resolving it. This quixotic quest is not to somehow patch a hole, but to rather expand the hole until it can comfortably encase our wants, needs, and prayers.

To that end, let’s gather ’round the campfire and attend the tale of The Bayesian and the Ancillary Coin.

This example^{2} was introduced by Robins and Ritov, and greatly popularised (and frequently reformulated) by Larry Wasserman^{3}. It says^{4} this:

A committed subjective Bayesian (one who cleaves to the likelihood principle tighter than Rose clings to that door) will sometimes get a very wrong answer under some simple, but realistic, forms of randomization. Only a less committed Bayesian will be able to skirt the danger.

So this is what we’re going to do now. First let’s introduce a version of the problem that does not trigger the counterexample. We then introduce the randomization scheme that leads to the error and talk about exactly how things go wrong. As someone who is particularly skeptical of any claims to purity^{5}, the next job is going to be deconstructing this idea of a committed^{6} subjective Bayesian. I will, perhaps unsurprisingly, argue that this is the only part of the Robins and Ritov (and Wasserman) conclusions that are somewhat questionable. In fact, a *true* committed subjective Bayesian^{7} can solve the problem. It’s just a matter of looking at it through the correct lens.

This example exists in a number of forms that each add important corners to the problem, but in the interest of simplicity, we will start with a simple situation where no problems occur.

Assume that there is a large, but fixed, finite number , and unknown parameters , . The large number can be thought of as the number of strata in a population, while are the means of the corresponding stratum. Now construct an experiment where you draw To close out the generative model, we assume that the covariates have a known distribution .

A classical problem in mathematical statistics is to construct a -consistent^{8} estimator of the vector . But in the setting of this problem, this is quite difficult. The challenge is that if is a very large number, then we would need a gargantuan^{9} number of observations () in order to resolve all of the parameters properly.

But there is a saving grace! The *population*^{10} average can be estimated fairly easily. In fact, the sample mean (aka the most obvious estimator) is going to be -consistent.

Similarly, if we were to construct a Bayesian estimate of the population mean based off the prior and , then the posterior estimate of the population mean is, for large enough^{11} , This means that the^{12} Bayesian resolution of this problem is roughly the same as the classical resolution. This is a nice thing. For very simple problems, these estimators should be fairly similar. It’s only when shit gets complicated where things become subtle.
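To make this concrete, here's a hedged simulation sketch (the number of strata, the uniform covariate distribution, and the unit observation noise are all invented for illustration). Each individual stratum mean is estimated badly, but the sample mean of the whole thing lands close to the population average.

```python
import numpy as np

rng = np.random.default_rng(42)
K = 1000                                   # number of strata (illustrative)
theta = rng.normal(0.0, 1.0, size=K)       # unknown stratum means
q = np.full(K, 1.0 / K)                    # known covariate distribution (uniform here)
pop_mean = q @ theta                       # the population average we actually want

n = 20_000                                 # ~20 observations per stratum: nowhere near
x = rng.choice(K, size=n, p=q)             # enough to resolve 1000 parameters well
y = theta[x] + rng.normal(size=n)          # Y_i ~ N(theta_{X_i}, 1)

sample_mean = y.mean()                     # the most obvious estimator

# per-stratum estimates are much worse than the overall mean
stratum_est = np.bincount(x, weights=y, minlength=K) / np.maximum(
    np.bincount(x, minlength=K), 1
)
worst_stratum_err = np.abs(stratum_est - theta).max()
```

The worst per-stratum error is an order of magnitude bigger than the error in the sample mean, which is the whole point of the example.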

This scenario, where a model is parameterized by an extremely high dimensional parameter but the quantity of inferential interest is a low-dimensional summary of , is widely and deeply studied under the name of semi-parametric statistics.

Semi-parametric statistics is, unsurprisingly, harder than parametric statistics, but it is also quite a bit more challenging than non-parametric statistics. The reason is that if we want to guarantee a good estimate of a particular finite dimensional summary, it turns out that it’s not enough to generically get a “good” estimate of the high-dimensional parameter. In fact, getting a good estimate of the high-dimensional parameter is often not possible (see the example we just considered).

Instead, understanding semi-parametric models becomes the fine art of understanding what needs to be done well and what we can half arse. A description of this would take us *well* outside the scope of a mere blog post, but if you want to learn more about the topic, that’s what to google.

In order to destroy all that is right and good about the previous example, we only need to do one thing: randomize in a nefarious way. Robins and Ritov (actually, Wasserman who proposed the case with a finite ) add to their experiment biased coins with the property that for some *known* , and some .

They then go through the data and add a column . The new data is now a three dimensional vector . It’s important to this problem that the are known and that we have the conditional independence structure .

Robins, Ritov, and Wasserman all ask the same question: Can we still estimate the population mean if we only observe samples from the *conditional* distribution ?

It’s going to turn out that there is a perfectly good estimator from classical survey statistics, but that a Bayesian estimator is a bit more challenging to find.

Before we get there, it’s worth noting that unlike the problem in the previous section, this problem is at least a little bit interesting. It’s a cartoon of a very common situation where there is covariate-dependent randomization in a clinical trial. Or, maybe even more cleanly, a cartoon of a simple probability survey.

A critical feature of this problem is that because the are known and is known, the joint likelihood factors as so is ancillary^{13} for .

The simplest classical estimator for is the Horvitz-Thompson estimator It’s easy to show that this is a -consistent estimator. Better yet, it’s *uniform* over in the sense that the convergence of the estimator isn’t affected (to leading order) by the specific values. This uniformity is quite useful as it gives some hope of good finite-data behaviour.
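Here's a hedged numerical sketch of the Horvitz-Thompson estimator (the Gaussian stratum means, the uniform covariate distribution, and the particular values of the sampling probabilities are my inventions; only the reweighting of each observed response by its known inclusion probability is the actual estimator).

```python
import numpy as np

rng = np.random.default_rng(1)
K = 500
theta = rng.normal(0.0, 1.0, size=K)       # unknown stratum means
pi = rng.uniform(0.1, 0.9, size=K)         # *known* sampling probabilities
pop_mean = theta.mean()                    # uniform covariate distribution assumed

n = 200_000
x = rng.integers(0, K, size=n)             # X_i
y = theta[x] + rng.normal(size=n)          # Y_i ~ N(theta_{X_i}, 1)
r = (rng.uniform(size=n) < pi[x]).astype(float)   # R_i ~ Bernoulli(pi_{X_i})

# Horvitz-Thompson: reweight each *observed* Y_i by 1 / pi_{X_i} and
# average over the full sample size n (unobserved Y_i contribute zero)
ht = np.mean(r * y / pi[x])
```

Despite only seeing roughly half the responses, the reweighted average recovers the population mean.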

So now that we know that the problem *can* be solved, let’s see if we can solve it in a Bayesian way. Robins and Ritov gave the following result.

There is no uniformly consistent Bayesian estimator of the parameter unless the prior depends on the values.

Robins and Ritov argue that a “committed subjective Bayesian” would, by the Likelihood Principle, never allow their prior to depend on the ancillary statistic as the Likelihood Principle clearly states that inference should be independent of ancillary information.

There are, of course, ways to construct priors that depend on the sampling probabilities. Wasserman calls this “frequentist chasing”.

So let’s investigate this, by talking about what went wrong, how to fix it, and whether fixing it makes us bad Bayesians.

So what is the likelihood principle and why is it being such a bastard to us poor liddle bayesians?

The likelihood principle says, roughly, that all of the information needed for parameter inference^{14} should be contained in the likelihood function.

In particular, if we follow the likelihood principle, then if we have two likelihoods that are scalar multiples of each other, our estimates of the parameters should be the same.

Ok. Sure.

Why on earth do people care about the likelihood principle? I guess it’s because they aren’t happy with the fact that Bayesian methods actually work in practice and instead want to do some extremely boring philosophy-ish stuff to “prove” the superiority and purity of Bayesian methods. And, you know, all power^{15} to them. Your kink is not my kink.

In this context, it means that because is ancillary to for estimating we should avoid using the s (and the s) to estimate . This is in direct opposition to what the Horvitz-Thompson estimator uses.

What happens if we follow this principle? We get a bad estimate.

It’s pretty easy to see that the posterior mean will, eventually, converge to the true value. All that has to happen is you need to see enough observations in each category. So if you get enough data, you will eventually get a good estimate.

Unfortunately, when is large, this will potentially take a very very long^{16} time.

Let’s go a bit deeper and see why this behaviour is not wrong, *per se*, it’s just Bayesian.

Bayesian inference produces a posterior distribution, which is conditional on an observed sample. This posterior distribution is an update to the prior that describes how compatible different parameter configurations are with the observed sample.

The thing is, our sample only sees a small sample of the values of . This means that we are, essentially, estimating where is the set of observed values of , which depends on . This target changes as we get more data and see more levels of and eventually coalesces towards the thing we are trying to compute.

But, and this is critical, we *cannot* say *anything* about for unless we can assume that they are, in some sense, very strongly related. Unfortunately, the whole point of this example is that we are not allowed^{17} to assume that!

In this extremely flexible model, it’s possible to have a sequence that is highly correlated^{18} with . If, for instance, were^{19} equally spaced on for some small , you would have the situation where you are very likely to see the largest values of and quite unlikely to see any of the smaller values. This would gravely bias your sample mean upwards.
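Here's a hedged sketch of that adversarial setup (the sorted Gaussian stratum means and the linearly increasing sampling probabilities are an artificial worst case of my own construction). The unweighted mean of the observed responses gets pulled well above the population mean, while the design-aware Horvitz-Thompson estimator does not.

```python
import numpy as np

rng = np.random.default_rng(2)
K = 500
theta = np.sort(rng.normal(0.0, 1.0, size=K))   # stratum means, in increasing order
pi = np.linspace(0.01, 0.99, K)                 # adversarial: pi_j rises with theta_j
pop_mean = theta.mean()

n = 200_000
x = rng.integers(0, K, size=n)
y = theta[x] + rng.normal(size=n)
r = rng.uniform(size=n) < pi[x]

naive = y[r].mean()            # unweighted mean of the observed Y: biased upwards,
                               # because high-theta strata are over-observed
ht = np.mean(r * y / pi[x])    # Horvitz-Thompson: still on target
```

The gap between `naive` and `pop_mean` does not shrink with more data; it is a bias baked into the design, not noise.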

This construction is similar to the one that Robins and Ritov use to prove that there is always a parameter value where the posterior mean converges^{20} to the true mean at a rate no faster than , which would require an exponentially large number of samples to do any sort of inference.

A reasonable criticism of this argument is that surely most problems will not have strong correlation between the sampling probabilities and the conditional means. In a follow up paper, Ritov *et al.* argue that it’s not necessarily all that rare. For instance, if they are both realisations of independent GPs^{21} the empirical correlation between the two observed sequences can be far from zero! Less abstractly, it’s pretty easy to imagine something that is more popular with old people (who often answer their phones) than with young people (who don’t typically answer their phones). So this type of adversarial correlation certainly can happen in practice.

No.

Bayes does not need to be saved. She is doing exactly what she set out to do and is living her best life. Do not interfere^{22}.

So let’s look at why we don’t need to fix things.

Once again, recall the setting: we are observing the triple^{23} In particular, we can process this data to get some quantities:

- : The total sample size
- : The number of observed
- : The total number of times group was sampled
- : The number of times an observation from group was recorded.

Because of the structure of the problem, most observed values of and will be zero or one.

Nevertheless, we persist.

We now need priors on the . There are probably a tonne of options here, but I’m going to go with the simplest one, which is just to make them iid for some fixed and known value . We can then fit the resulting model and get the posterior for each . Note that because of the data sparsity, most of the posteriors will just be the same as the prior.

Then we can ask ourselves a much more Bayesian question: What would the average in our sample have been if we had recorded every ? Our best estimate of that quantity is

That’s all well and good. And, again, if I had small enough or large enough that I had a good estimate for all of the , this would be a good estimate. Moreover, for finite data this is likely to be a much better estimator than as it at least partially corrects for any potential imbalance in the covariate sampling.
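Here's a hedged sketch of that imputation idea (the data-generating numbers, the standard normal priors, and the known unit observation noise are all illustrative choices of mine): keep every recorded response, impute every unrecorded one with its group's posterior mean, and average the completed sample.

```python
import numpy as np

rng = np.random.default_rng(3)
K, n, tau2 = 500, 200_000, 1.0
theta = np.sort(rng.normal(0.0, 1.0, size=K))
pi = np.linspace(0.05, 0.95, K)            # sampling probability rises with theta
x = rng.integers(0, K, size=n)
y = theta[x] + rng.normal(size=n)
r = rng.uniform(size=n) < pi[x]

# Conjugate update per group: theta_j ~ N(0, tau2), y | theta_j ~ N(theta_j, 1),
# so the posterior mean is (sum of observed y in group j) / (n_j + 1/tau2).
n_j = np.bincount(x[r], minlength=K)
s_j = np.bincount(x[r], weights=y[r], minlength=K)
post_mean = s_j / (n_j + 1.0 / tau2)       # groups with n_j = 0 keep the prior mean, 0

# "What would the average of the complete sample have been?":
# keep the recorded y's, impute the rest with their group's posterior mean.
completed = np.where(r, y, post_mean[x]).mean()
naive = y[r].mean()                        # ignores the imbalance entirely
```

With this biased design the naive observed mean is well above the complete-sample average, while the imputation-corrected estimate sits close to it.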

It’s also worth noting here that there is nothing “Bayesian” about this. I am simply taking the knowledge I have from the sample I observed and processing the posterior to compute a quantity that I am interested in.

But, of course, that isn’t actually the quantity that I’m interested in. I’m interested in that quantity averaged over realisations of . We can compute this if we can quantify the effect that has on .

We can do this pretty easily. Our priors are iid^{24}, so this decouples into independent normal-normal models.

For any , denote as the subset of that are in category . We have that^{25}

If we expand the density for a we get Matching terms in these two expressions we get that while the posterior mean is where I’ve suppressed the dependence on the sample in the and notation because, as a true^{26} Bayesian, my sample is fixed and known. Hence

Then I get the following estimator for the mean of the complete sample We can also compute the posterior variance^{27} Note that most of the groups won’t have a corresponding observation, so, recalling that is the set of s that have been updated in the sample, we get where the term that multiplies is less than 1.
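The normal-normal algebra above fits in a few lines. This is a hedged sketch under my standing illustrative assumptions (a N(0, tau2) prior and unit observation noise, which may differ from the suppressed formulas), showing the shrinkage towards the prior and that unobserved groups just keep their prior.

```python
import numpy as np

def posterior_for_group(y_obs, tau2=1.0):
    """Normal-normal update for one group: prior theta_j ~ N(0, tau2),
    likelihood y | theta_j ~ N(theta_j, 1) for each of the n_j observations.
    Returns (posterior mean, posterior variance)."""
    y_obs = np.asarray(y_obs, dtype=float)
    n_j = y_obs.size
    precision = n_j + 1.0 / tau2      # precisions add
    mean = y_obs.sum() / precision    # sample sum, shrunk towards the prior mean 0
    return mean, 1.0 / precision

# one observation, tau2 = 1: the posterior mean is shrunk halfway to zero
m1, v1 = posterior_for_group([2.0])   # (1.0, 0.5)
# no observations: the posterior is just the prior
m0, v0 = posterior_for_group([])      # (0.0, 1.0)
```

The factor multiplying the sample mean, n_j / (n_j + 1/tau2), is always strictly less than 1, which is the "less than 1" term noted above.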

So that’s all well and good, but that isn’t really the thing we were trying to estimate. We are actually interested in estimating the population mean, which we will get if we let .

So let’s see if we can do this without violating any of the universally agreed upon sacred strictures of Bayes.

Here’s the thing, though. We have computed our posterior distributions and we can now use them as a generative model^{28} for our data. We also have the composition of the complete data set (the s) and full knowledge about how a new sample of the s would come into our world.

We can put these things together! And that’s not in any way violating our Bayesian oaths! We are simply using our totally legally obtained posterior distribution to compute things. We are still true committed^{29} subjective Bayesians.

So we are going to ask ourselves a simple question. Imagine, for a given , we have iid samples^{30} What is the posterior mean ? In fact, because this is random data drawn from a hypothetical sample, we can (and should^{31}) ask questions about its distribution! To be brutally francis with you, I am too lazy to work out the variance of the posterior mean. So I’m just going to look at the mean of the posterior mean.

First things first, we need to look at the (average) posterior for when . The exact calculation we did before gives us And, while I said I wasn’t going to focus on the variance, it’s easy enough to write down as where the second term takes into account the variance due to the imputation.

With this, we can estimate sample mean for any number and any set of that sum to and any set of as where in the last line I’ve used the fact that the empirical proportion converges to and the posterior mean converges to . The little-o^{32} error term is as (and hence and ) goes to infinity.

To turn this into a practical estimate, we can plug in our values of and to get our Bayesian approximation to the population mean which is (up to the small term in brackets) the Horvitz-Thompson estimator!

I stress, again, that there is nothing inherently non-Bayesian about this derivation. Except possibly the question that it is asking. What I did was compute the posterior distribution and then I took it seriously and used it to compute a quantity of interest.

The only oddity is that the quantity of interest (the population mean) has a slightly awkward link to the observed sample. Hence, I estimated something that had a more direct link to the population mean: the sample mean of the completely observed sample under different realisations of the randomisation .

In order to estimate the sample mean under different realisations of the randomisation, I needed to use the posterior predictive distribution to impute these fictional samples. I then averaged over the imputed samples and sent the sample size to infinity to get an estimator^{33}.

Or, to put it differently, I used Bayes to get a posterior estimate for new data and then used this probabilistic model to estimate . There was no reason to use Bayesian methods to do this. Non-Bayesian questions do not invite Bayesian answers.

Now, would I go to all of this effort in real life? Probably not. And in the applications that I’ve come across, I’ve never had to. I’ve done a bunch of MRP^{34}, which is structurally quite similar to this problem except we can reasonably model the dependence structure between the s. This paper I wrote with Alex Gao, Lauren Kennedy, and Andrew Gelman is an example of the type of modelling you can do.

Wasserman derides “frequentist chasing” Bayesians, making the point that if they want a frequentist guarantee so badly, why not just do it the easy way.

Now. Laz. Mate.

Let me tell you that a lot of my self esteem has been traditionally gathered from chasers, so I absolutely refuse to be party to the slander.

But more than that, let’s be clear. Bayes is a way to probabilistically describe data. That is not enough in and of itself to be useful. For it to be useful, we need to *do something* with that posterior distribution.

So really, let’s talk about what a *true committed subjective Bayesian* does about this. Firstly, I mean really. There is no such thing^{35}. But leaving that aside, the closest I can get to a working definition is that a true committed subjective Bayesian is a person who understands that parameters are polite fictions that are used to describe the data. They are not, inherently, linked to any population quantity (for a true committed subjective Bayesian, such a thing does not exist).

The *only* way to link parameters in a Bayesian model to a population quantity of interest is to use some sort of extra-Bayesian^{36} information.

For instance, in the first example (the one without the ancillary coin), I made that link in secret using assumptions about the sample. We all know that those types of assumptions are fraught and the reason that people spend so much time whispering DAG into the ears of their sleeping lovers.

For the ancillary coin example, we used the given information about the sampling mechanism as our extra information to link our posterior distribution to the population quantity of interest. None of this changes the *purity*^{37} of the Bayesian analysis. Or makes a non-Bayesian solution preferable. (Although, in this case, a non-Bayesian solution is a fuckload easier to come up with.)

Of course Wasserman (and I presume Robins and Ritov) know all of this. But it’s fun to write it all down.

Moreover, I think that the three lessons here are fairly transferable:

- If you’re going to go to the trouble of computing a posterior, take it seriously. Use it to do things! You can even put it in as part of a probabilistic model.
- If you’re going to make Bayes work for you, think in terms of observables (eg the mean of the complete sample) rather than parameters.
- Appeals to purity are a bit of a waste of time.

Huge thanks to Sameer Deshpande for great comments!↩︎

I first came across this in a series of posts on Larry Wasserman’s now defunct but quite excellent blog.↩︎

It’s worth saying that these three people do fabulous statistics of the form that I don’t usually do. But that doesn’t make it less important to understand their contributions. You could say that while I am not a Lazbian, I think it’s important to know the theory.↩︎

I might have slightly reworded it.↩︎

Purity is needed in good olive oil and that’s it↩︎

A committed subjective Bayesian prefers Dutch baby to a Dutch book.↩︎

A true committed subjective Bayesian doesn’t wear anything under his kilt.↩︎

That is, an estimator where for all . This, roughly, means that you can find a such that with high probability.↩︎

The asymptotics say that we should count our data in multiples of , so we’d to get even one decimal place of accuracy.↩︎

Remember .↩︎

Theorem 2 of Harmeling and Toussaint↩︎

a↩︎

If you’ve not come across it, *ancillary* is the term used for parts of the data that don’t influence parameter estimates. It’s the opposite of a sufficient statistic. One way to see that it’s ancillary for *any* model , is to consider the log of the joint density , where the last two terms are constant in .↩︎

You need to be specific here. Obviously this would be false if you were trying to do a statistical prediction. Or if you were trying to make a decision. Those things necessarily depend on extra stuff!↩︎

This is a lie. Insisting on talking about this shit rather than actually making Bayes useful and using it in new and exciting ways to do things that are hard to do without Bayesian methods is a waste of time. Worse than that, when you start pretending your method of choice is the only possible thing that a sensible and principled person would use, you start to look like a bit of a dickhead. It also turns people off trying these very flexible and useful methods. So yeah. I maybe do care a little bit. ↩︎

The expected number of samples to see one draw where is . The expected number of draws where that you need to actually observe the corresponding is . This suggests it will potentially take *a lot* of draws to even have effectively one sample from each category, let alone the 20-100 you’d need to, practically, get some sort of reasonable estimate.↩︎

Robins and Ritov have always been open that if there is a true parametric model for the (or if that function is “very smooth” in some technical sense, eg a realisation of a smooth Gaussian process) then the Bayesian estimator that incorporates this information will do perfectly well.↩︎

The RR example uses binary data, so there it’s the correlation between and , but the exact same argument works if is correlated with something like . I went with the Gaussian version because at one point I thought I might end up having to derive posteriors and I’m all about simplicity.↩︎

expit is the inverse of the logit transform↩︎

Check the paper for the details as the situation is slightly different to the one I’m sketching out here, but there’s no real substantive difference.↩︎

Of course, if this were true we could use a GP prior for the s and we’d probably get a decent estimator anyway.↩︎

If you want to interfere, there are plenty of ways to build priors that incorporate the information. The Ritov *et al.* paper has nice references to the various things that sprung up from this example. Are these useful beyond simply making sure the posterior mean of estimates ? Not really. They are priors designed to solve exactly one problem.↩︎

I’m using the C/C++ ternary operator. In R this would be parsed as `ifelse(r[i] == 1, y[i], NA)`.↩︎

Not exchangeable—there are no shared parameters!↩︎

Remember that . If we wanted a more flexible variance, we could obviously have one, but it makes no real difference to anything.↩︎

I promise I’m just rolling my eyes to see if I can see my brain.↩︎

Remember everything is independent!↩︎

This is the posterior predictive distribution!↩︎

A true committed subjective Bayesian knows that DP stands for Dirichlet Process. No matter the context.↩︎

The variance is because this is the posterior predictive distribution.↩︎

Does this seem like a frequentist question? I guess. But really it’s a question we can always ask about the posterior. Should we? Well if you are trying to estimate a population quantity you sort of have to. Because there isn’t really a concept of a population parameter within a Bayesian framework (true committed subjective or otherwise).↩︎

Remember that this means that the error (which is a random variable) goes to 0 as . A more careful person could probably work out how fast it would happen.↩︎

I only computed the mean, so feel free to pretend that I’m minimizing a loss function↩︎

Multilevel regression with poststratification, a survey modelling technique↩︎

No true Scotsman etc↩︎

or meta-Bayesian in the event that we are doing things like building a Bayesian pseudo-model of on the space of all considered models that just happens to give every model equal probability because Harold Fucking Jeffreys gave you an erection and you could either process that event like an adult or build a whole personality around it. And you chose the latter.↩︎

Can you tell that I hate this entire discussion?↩︎

BibTeX citation:

```
@online{simpson2022,
author = {Dan Simpson},
editor = {},
title = {On That Example of {Robins} and {Ritov;} or {A} Sleeping Dog
in Harbor Is Safe, but That’s Not What Sleeping Dogs Are For},
date = {2022-11-15},
url = {https://dansblog.netlify.app/posts/2022-11-12-robins-ritov/robins-ritov.html},
langid = {en}
}
```

For attribution, please cite this work as:

Dan Simpson. 2022. “On That Example of Robins and Ritov; or A
Sleeping Dog in Harbor Is Safe, but That’s Not What Sleeping Dogs Are
For.” November 15, 2022. https://dansblog.netlify.app/posts/2022-11-12-robins-ritov/robins-ritov.html.

Twice.

The first time was when I needed to understand approximation properties of a certain class of GPs. I wrote a post about it. It’s intense^{1}.

The second time that I really needed to dive into their arcana and apocrypha^{2} was when I foolishly asked the question *can we compute Penalised Complexity (PC) priors^{3} ^{4} for Gaussian processes?*

The answer was yes. But it’s a bit tricky.

So today I’m going to walk you through the ideas. There’s no real need to read the GP post before reading the first half of this one^{5}, but it would be immensely useful to have at least glanced at the post on PC priors.

This post is *very* long, but that’s mostly because it tries to be reasonably self-contained. In particular, if you only care about the fat stuff, you really only need to read the first part. After that there’s a long introduction to the theory of stationary Gaussian processes. All of this stuff is standard, but it’s hard to find everything that I need to derive the PC prior collected in one place. The third part actually derives the PC prior using many of the methods from the previous part.

We are in the situation where we have a model that looks something like this^{6} ^{7} where is a covariance function with parameters and we need to specify a joint prior on the GP parameters .

The simplest case of this would be GP regression, but a key thing here is that, in general, the structure (or functional form) of the priors on probably shouldn’t be too tightly tied to the specific likelihood. Why do I say that? Well the *scaling* of a GP should depend on information about the likelihood, but it’s less clear that anything else in the prior needs to know about the likelihood.

Now this view is predicated on us wanting to make an informative prior. In some very special cases, people with too much time on their hands have derived reference priors for specific models involving GPs. These priors care *deeply* about which likelihood you use. In fact, if you use them with a different model^{8}, you may not end up with a proper^{9} posterior. We will talk about those later.

To start, let’s look at the simplest way to build a PC prior. We will then talk about why this is not a good idea.

As always, the best place to start is the simplest possible option. There’s always a hope^{10} that we won’t need to pull out the big guns.

So what is the simplest solution? Well, it’s to treat a GP as just a specific multivariate Gaussian distribution $u \sim N(0, \sigma^2 R)$, where $R$ is a correlation matrix.

The nice thing about a multivariate Gaussian is that we have a clean expression for its Kullback-Leibler divergence, which Wikipedia gives for an $n$-dimensional multivariate Gaussian.

To build a PC prior we need to consider a base model. That’s tricky in generality, but as we’ve assumed that the covariance matrix can be decomposed into the variance $\sigma^2$ and a correlation matrix $R$, we can at least specify an easy base model for $\sigma$. As always, the simplest model is one with no GP in it, which corresponds to $\sigma = 0$. From here, we can follow the usual steps to specify the PC prior, where we choose the rate $\lambda$ for some upper bound $U$ and some tail probability $\alpha$ so that $\Pr(\sigma > U) = \alpha$.

The specific choice of $U$ will depend on the context. For instance, if it’s logistic regression we probably want something like^{11} . If we have a GP on the log-mean of a Poisson distribution, then we probably want to choose $U$ so that the *mean* of the Poisson distribution is less than the maximum integer^{12} in R. In most data, you’re gonna want^{13} something smaller than that. If the GP is on the mean of a normal distribution, the choice of $U$ will depend on the context and scaling of the data.
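The formulas in that paragraph didn’t survive whatever ate the LaTeX, so here’s my reconstruction of the pieces being used. The KL formula is the standard one for zero-mean Gaussians; the exponential form of the PC prior on $\sigma$ and the rate $\lambda$ follow the usual PC-prior recipe, so check them against the PC priors post before trusting my algebra.

```latex
% KL divergence between zero-mean n-dimensional Gaussians
\operatorname{KL}\big(N(0, \Sigma_1) \,\|\, N(0, \Sigma_0)\big)
  = \frac{1}{2}\left(\operatorname{tr}\big(\Sigma_0^{-1}\Sigma_1\big) - n
      + \log\frac{\det \Sigma_0}{\det \Sigma_1}\right)
% The PC prior puts an exponential prior on the distance
% d(sigma) = sqrt(2 KL(sigma)), which here works out to an
% exponential prior on sigma itself:
\pi(\sigma) = \lambda e^{-\lambda \sigma},
\qquad
\lambda = -\frac{\log \alpha}{U}
\quad \text{so that} \quad
\Pr(\sigma > U) = \alpha
```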

Without more assumptions about the form of the covariance function, it is impossible to choose a base model for the other parameters .

That said, there is one special case that’s important: the case where $\theta$ is a single parameter $\ell$ controlling the intrinsic length scale, that is, the distance at which the correlation between two points is approximately zero. The larger $\ell$ is, the more correlated observations of the GP are and, hence, the less wiggly its realisation is. On the other hand, as $\ell \to 0$, the observations of the GP often behave like realisations from an iid Gaussian and the GP becomes^{14} wilder and wilder.

This suggests that a good base model for the length-scale parameter would be $\ell \to \infty$. We note that if both the base model and the alternative have the same value of $\sigma$, then it cancels out in the KL divergence. Under this assumption, we can compute the distance, where I’m being a bit cheeky putting that limit in, as we might need to do some singular model jiggery-pokery of the same type we needed to do for the standard deviation. We will formalise this, I promise.

As the model gets more complex as the length scale decreases, we want our prior to control the smallest value $\ell$ can take. This suggests we want to choose $\lambda$ to ensure $\Pr(\ell < \ell_0) = \alpha$. How do we choose the lower bound $\ell_0$? One idea is that our prior should have very little probability of the length scale being smaller than the length scale of the data. So we can choose $\ell_0$ to be the smallest distance between observations (if the data is regularly spaced) or a low quantile of the distribution of distances between nearest neighbours.

All of this will specify a PC prior for a Gaussian process. So let’s now discuss why that prior is a bit shit.

The prior on the standard deviation is fine.

The prior on the length scale is more of an issue. There are a couple of bad things about this prior. The first one might seem innocuous at first glance. We decided to treat the GP as a multivariate Gaussian with covariance matrix . This is not a neutral choice. In order to do it, we need to *commit* to a certain set of observation locations^{15}. Why? The matrix depends entirely on the observation locations and if we use this matrix to define the prior we are tied to those locations.

This means that if we change the amount of data in the model we will need to change the prior. This is going to play havoc^{16} with any sort of cross-validation! It’s worth saying that the other two sources of information (the minimum length scale and the upper bound on $\sigma$) are not nearly as sensitive to small changes in the data. This information is, in some sense, fundamental to the problem at hand and is, therefore, much more stable ground to build your prior upon.

There’s another problem, of course: this prior is expensive to compute. The KL divergence involves dense matrix operations (determinants and traces of solves) that cost as much as another log-density evaluation for the Gaussian process (which is to say it’s very expensive).

So this prior is going to be *deeply* inconvenient if we have varying amounts of data (through cross-validation or sequential data gathering). It’s also going to be wildly more computationally expensive than you expect a one-dimensional prior to be.

All in all, it seems a bit shit.

It won’t be possible to derive a prior for a general Gaussian process, so we are going to need to make some simplifying assumptions. The assumption that we are going to make is that the covariance comes from the Whittle-Matérn^{17} ^{18} class, where $\nu$ is the *smoothness* parameter, $\ell$ is the *length-scale* parameter, $\sigma$ is the *marginal standard deviation*, and $K_\nu$ is the modified Bessel^{19} function of the second kind.
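The covariance formula itself got eaten, so for reference, here is one common way to write the Whittle-Matérn class. Conventions for the scaling constant inside the Bessel function vary between papers, so treat the $\sqrt{8\nu}$ factor as one choice among several:

```latex
c(h) = \sigma^2 \, \frac{2^{1-\nu}}{\Gamma(\nu)}
       \left(\frac{\sqrt{8\nu}\, h}{\ell}\right)^{\nu}
       K_\nu\!\left(\frac{\sqrt{8\nu}\, h}{\ell}\right)
```

With this particular scaling, $\ell$ is roughly the distance at which the correlation drops to about 0.1, which is what makes it interpretable as a length scale.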

This class of covariance function is extremely important in practice. It interpolates between two of the most common covariance functions:

- when $\nu = 1/2$, it corresponds to the exponential covariance function,
- when $\nu \to \infty$, it corresponds to the squared exponential covariance.

There are years of experience suggesting that Matérn covariance functions with finite $\nu$ will often perform better than the squared exponential covariance.

Common practice is to fix^{20} the value of $\nu$. There are a few reasons for this. One of the most compelling practical reasons is that we can’t easily evaluate the derivative of the covariance function with respect to $\nu$, which rules out most modern optimisation and MCMC algorithms. It’s also *very* difficult to think about how you would set a prior on it. The techniques in this post will not help and, as far as I’ve ever been able to tell, nothing else will either. Finally, you could expect there to be *horrible* confounding between $\nu$, $\ell$, and $\sigma$, which will make inference very hard (both numerically and morally).

It turns out that even with $\nu$ fixed, we will run into a few problems. But to understand those, we are going to need to know a bit more about how inferring parameters in a Gaussian process actually works.

Just for future warning, I will occasionally refer to a GP with a Matérn covariance function as a “Matérn field”^{21}.

Let’s take a brief detour into classical inference for a moment and ask ourselves *when can we recover the parameters of a Gaussian process*? For most models we run into in statistics, the answer to that question is *when we get enough data*. But for Gaussian processes, the story is more complex.

First of all, there is the very real question of what we mean by getting more data. When our observations are iid, this is so easy that when asked how she got more data, Kylie just said she “did it again”.

But this is more complex once data has dependence. For instance, in a multilevel model you could have the number of groups staying fixed while the number of observations in each group goes to infinity, you could have the number of observations in each group staying fixed while the number of groups goes to infinity, or you could have both^{22} going to infinity.

For Gaussian processes it also gets quite complicated. Here is a non-exhaustive list of options:

- You observe the same realisation of the GP at an increasing number of points that eventually cover the *whole of* $\mathbb{R}^d$ (this is called the *increasing domain* or *outfill* regime); or
- You observe the same realisation of the GP at an increasing number of points *that stay within a fixed domain* (this is called the *fixed domain* or *infill* regime); or
- You observe multiple realisations of the same GP at a finite number of points that stay in the same location (this does not have a name; in space-time it’s sometimes called *monitoring data*); or
- You observe multiple realisations of the same GP at a (possibly different) finite number of points that can be in different locations for different realisations; or
- You observe realisations of a process that evolves in space *and* time (not really a different regime so much as a different problem).

One of the truly unsettling things about Gaussian processes is that the ability to estimate the parameters depends on which of these regimes you choose!

Of course, we all know that asymptotic regimes are just polite fantasies that statisticians concoct in order to self-soothe. They are not reflections on reality. They serve approximately the same purpose^{23} as watching a chain of Law and Order episodes.

The point of thinking about what happens when we get more data is to use it as a loose approximation of what happens with the data you have. So the real question is *which regime is the most realistic for my data?*

One way you can approach this question is to ask yourself what you would do if you had the budget to get more data. My work has mostly been in spatial statistics, in which case the answer is *usually*^{24} that you would sample more points in the same area. This suggests that fixed-domain asymptotics is a good fit for my needs. I’d expect that in most GP regression cases, we’re not expecting^{25} that further observations would be on new parts of the covariate space, which would suggest fixed-domain asymptotics are useful there too.

This, it turns out, is awkward.

The problem with a GP with the Matérn covariance function on a fixed domain is that it’s not possible^{26} to estimate all of its parameters at the same time. This isn’t the case for the other asymptotic regimes, but you’ve got to dance with who you came to the dance with.

To make this more concrete, we need to think about a Gaussian process as a realisation of a function rather than as a vector of observations. Why? Because under fixed-domain asymptotics we are seeing values of the function closer and closer together until we essentially see the entire function on that domain.

Of course, this is why I wrote a long and technical blog post on understanding Gaussian processes as random functions. But don’t worry. You don’t need to have read that part.

The key thing is that because a GP is a function, we need to think of its probability of being in a set of functions. There will be a set of functions $S$, which we call the *support* of $u$, that is the smallest (closed) set such that $\Pr(u \in S) = 1$. Every GP has an associated support and, while you probably don’t think much about it, GPs are *obsessed* with their supports. They love them. They hug them. They share them with their friends. They keep them from their enemies. And they are one of the key things that we need to think about in order to understand why it’s hard to estimate parameters in a Matérn covariance function.

There is a key theorem that is unique^{27} to Gaussian processes. It’s usually phrased in terms of *Gaussian measures*, which are just the probabilities associated with a GP. For example, if $u$ is a GP then $\mu_u(A) = \Pr(u \in A)$ is the corresponding Gaussian measure. We can express the support of $u$ as the smallest set of functions $S$ such that $\mu_u(S) = 1$.

**Theorem 1 (Feldman-Hájek theorem) **Two Gaussian measures $\mu_1$ and $\mu_2$ with corresponding GPs $u_1$ and $u_2$ on a locally convex space^{28} either satisfy, for every^{29} set $A$,

$$\mu_1(A) = 0 \quad \Longleftrightarrow \quad \mu_2(A) = 0,$$

in which case we say that $\mu_1$ and $\mu_2$ are *equivalent*^{30} (confusingly^{31} written $\mu_1 \sim \mu_2$) and $u_1$ and $u_2$ have the same support, **or** there is a set $A$ with $\mu_1(A) = 1$ and $\mu_2(A) = 0$, in which case we say $\mu_1$ and $\mu_2$ are *singular* (written $\mu_1 \perp \mu_2$) and $u_1$ and $u_2$ have disjoint supports.

Later on in the post, we will see some precise conditions for when two Gaussian measures are equivalent, but for now it’s worth saying that it is a *very* delicate property. In fact, if $u \sim \mu$ and $\mu_c$ is the measure of the rescaled process $cu$ for any $c \neq 1$, then^{32} $\mu \perp \mu_c$!

This seems like it will cause problems. And it can^{33}. But it’s *fabulous* for inference.

To see this, we can use one of the implications of singularity: $\mu_1 \perp \mu_2$ if and only if $\operatorname{KL}(\mu_1 \,\|\, \mu_2) = \infty$, where the Kullback-Leibler divergence can be interpreted as the expectation of the log-likelihood ratio of $\mu_1$ vs $\mu_2$ under $\mu_1$. Hence, if $\mu_1$ and $\mu_2$ are singular, we can (on average) choose the correct one using a likelihood ratio test. This means that we will be able to correctly recover the true^{34} parameter.

It turns out the opposite is also true.

**Theorem 2 **If $\mu_\theta$, $\theta \in \Theta$, is a family of Gaussian measures corresponding to the GPs $u_\theta$, and $\mu_\theta \sim \mu_{\theta'}$ for all values of $\theta, \theta' \in \Theta$, then there is *no* sequence of estimators $\hat\theta_n$ such that, for all $\theta \in \Theta$, $$\Pr_\theta\left(\lim_{n \to \infty} \hat\theta_n = \theta\right) = 1,$$ where $\Pr_\theta$ is the probability under data drawn with true parameter $\theta$. That is, there is no estimator that is (strongly) consistent for all $\theta \in \Theta$.

*Proof*. We are going to do this by contradiction. So assume that there is a sequence $\hat\theta_n$ that is strongly consistent for every $\theta$. For some $\epsilon > 0$, let $A_n = \{\|\hat\theta_n - \theta\| > \epsilon\}$. Then we can re-state our almost sure convergence as $$\Pr_\theta\left(\limsup_{n \to \infty} A_n\right) = 0,$$ where the limit superior is defined^{35} as $$\limsup_{n \to \infty} A_n = \bigcap_{N \geq 1} \bigcup_{n \geq N} A_n.$$

For any $\theta'$ with $\mu_{\theta'} \sim \mu_\theta$, the definition of equivalent measures tells us that $\Pr_{\theta'}\left(\limsup_{n \to \infty} A_n\right) = 0$ as well, and therefore $\hat\theta_n \to \theta$ almost surely under $\Pr_{\theta'}$. The problem with this is that the data is generated using $\theta'$, but the estimator converges to $\theta$ instead of $\theta'$. Hence, the estimator isn’t uniformly (strongly) consistent.

This seems bad but, you know, it’s a pretty strong version of convergence. And sometimes our brothers and sisters in Christ who are more theoretically minded like to give themselves a treat and consider weaker forms of convergence. It turns out that that’s a disaster too.

**Theorem 3 **If $\mu_\theta$, $\theta \in \Theta$, is a family of Gaussian measures corresponding to the GPs $u_\theta$, and $\mu_\theta \sim \mu_{\theta'}$ for all values of $\theta, \theta' \in \Theta$, then there is *no* sequence of estimators $\hat\theta_n$ such that, for all $\theta$ and all $\epsilon > 0$, $$\lim_{n \to \infty} \Pr_\theta\left(\|\hat\theta_n - \theta\| > \epsilon\right) = 0.$$ That is, there is no estimator that is (weakly) consistent for all $\theta$.

If you can’t tell the difference between these two theorems that’s ok. You probably weren’t trying to sublimate some childhood trauma and all of your sexual energy into maths just so you didn’t have to deal with the fact that you might be gay and you were pretty sure that wasn’t an option and anyway it’s not like it’s *that* important. Like whatever, you don’t need physical or emotional intimacy. You’ve got a pile of books on measure theory next to your bed. You are living your best life. Anyway. It makes almost no practical difference. BUT I WILL PROVE IT ANYWAY.

*Proof*. This proof is based on a kinda advanced fact, which involves every mathematician’s favourite question: what happens along a sub-sequence? The fact is this: a sequence of random variables converges in probability if and only if every sub-sequence has a further sub-sequence that converges almost surely.

This basically says that the two modes of convergence are quite similar, except convergence in probability is relaxed enough to have some^{36} values that aren’t doing so good at the whole converging thing.

With this in hand, let us build a contradiction. Assume that is weakly consistent for all . Then, if we generate data under , then we get that, along a sub-sequence

Now, if is weakly consistent for all , then so is . Then, by our assumption, for every and every

Our probability fact tells us that there is a *further* infinite sub-sub-sequence such that But Theorem 2 tells us that (and hence ) satisfies This is a contradiction unless , which proves the assertion.

All of that lead-up immediately becomes extremely relevant once we learn one thing about Gaussian processes with Matérn covariance functions.

**Theorem 4 **Let $\mu_{\sigma, \ell}$ be the Gaussian measure corresponding to the GP with Matérn covariance function with parameters $(\sigma, \ell)$ (and fixed $\nu$), let $D$ be any finite domain in $\mathbb{R}^d$ with $d \leq 3$, and let $(\sigma_1, \ell_1)$ and $(\sigma_2, \ell_2)$ be two sets of parameters. Then, restricted to $D$, $\mu_{\sigma_1, \ell_1} \sim \mu_{\sigma_2, \ell_2}$ if and only if $$\frac{\sigma_1^2}{\ell_1^{2\nu}} = \frac{\sigma_2^2}{\ell_2^{2\nu}}.$$

I’ll go through the proof of this later, but the techniques require a lot of warm up, so let’s just deal with the consequences for now.

Basically, Theorem 4 says that we can’t consistently estimate both the range and the marginal standard deviation for a one, two, or three dimensional Gaussian process. Hao Zhang noted this, and that it remains true^{37} when dealing with non-Gaussian data.

The good news, I guess, is that in more than four^{38} dimensions the measures are always singular.

Now, I don’t give one single solitary shit about the existence of consistent estimators. I am doing Bayesian things and this post is supposed to be about setting prior distributions. But it is important. Let’s take a look at some simulations.

First up, let’s look at what happens in 2D when we directly (i.e. with no noise) observe a zero-mean GP with exponential covariance function ($\nu = 1/2$) at $n$ points in the unit square. In this case, the log-likelihood is, up to an additive constant, $$\log p(y \mid \sigma, \ell) = -\frac{1}{2} \log\det V - \frac{1}{2} y^\top V^{-1} y, \qquad V_{ij} = \sigma^2 \exp\left(-\frac{\|s_i - s_j\|}{\ell}\right).$$

The R code is not pretty but I’m trying to be relatively efficient with my Cholesky factors.

```
set.seed(24601)
library(tidyverse)
cov_fun <- \(h,sigma, ell) sigma^2 * exp(-h/ell)
log_lik <- function(sigma, ell, y, h) {
  V <- cov_fun(h, sigma, ell)
  R <- chol(V)
  -sum(log(diag(R))) - 0.5 * sum(y * backsolve(R, backsolve(R, y, transpose = TRUE)))
}
```
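If the nested `backsolve` gymnastics are hard to parse, here is the same computation sketched in numpy. This is an illustrative translation, not code from the post; `h` is the distance matrix and `y` the vector of observations.

```python
import numpy as np

def log_lik(sigma, ell, y, h):
    """Gaussian log-likelihood, up to an additive constant, for
    y ~ N(0, V) with V = sigma^2 * exp(-h / ell) (exponential covariance)."""
    V = sigma**2 * np.exp(-h / ell)
    L = np.linalg.cholesky(V)      # V = L L', with L lower triangular
    alpha = np.linalg.solve(L, y)  # so that |alpha|^2 = y' V^{-1} y
    # -sum(log diag L) is -0.5 * log det V
    return -np.log(np.diag(L)).sum() - 0.5 * alpha @ alpha
```

The R version does the same trick with an upper-triangular factor, which is why it needs `transpose = TRUE` in the inner `backsolve`.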

We can now simulate 500 data points on the unit square, compute their distances, and simulate from the GP.

```
n <- 500
dat <- tibble(
  s1 = runif(n), s2 = runif(n),
  dist_mat = as.matrix(dist(cbind(s1, s2))),
  y = MASS::mvrnorm(mu = rep(0, n),
                    Sigma = cov_fun(dist_mat, 1.0, 0.2))
)
```

With all of this in hand, let’s look at the likelihood surface along^{39} the line $\sigma^2 = c\ell$ for various values of $c$. I’m using some `purrr` trickery^{40} here to deal with the fact that sometimes the Cholesky factorisation will throw an error.

```
m <- 100
f_direct <- partial(log_lik, y = dat$y, h = dat$dist_mat)
pars <- \(c) tibble(ell = seq(0.05, 1, length.out = m),
                    sigma = sqrt(c * ell),
                    c = rep(c, m))
ll <- map_df(3:8, pars) |>
  mutate(contour = factor(c),
         ll = map2_dbl(sigma, ell,
                       possibly(f_direct, otherwise = NA_real_)))
ll |> ggplot(aes(ell, ll, colour = contour)) +
  geom_line() +
  scale_color_brewer(palette = "Set1") +
  theme_bw()
```

We can see the same thing in 2D (albeit at a lower resolution for computational reasons). I’m also not computing a bunch of values that I know will just be massively negative.

```
f_trim <- \(sigma, ell) ifelse(sigma^2 < 3 * ell | sigma^2 > 8 * ell,
                               NA_real_, f_direct(sigma, ell))
m <- 50
surf <- expand_grid(ell = seq(0.05, 1, length.out = m),
                    sigma = seq(0.1, 4, length.out = m)) |>
  mutate(ll = map2_dbl(sigma, ell,
                       possibly(f_trim, otherwise = NA_real_)))
surf |> filter(ll > 50) |>
  ggplot(aes(ell, sigma, fill = ll)) +
  geom_raster() +
  scale_fill_viridis_c() +
  theme_bw()
```

Clearly there is a ridge in the likelihood surface, which suggests that our posterior is going to be driven by the prior along that ridge.

For completeness, let’s run the same experiment again when we have some known observation noise, that is, $y_i = u(s_i) + \epsilon_i$ with $\epsilon_i \sim N(0, 1)$. In this case, the log-likelihood is, up to an additive constant, $$\log p(y \mid \sigma, \ell) = -\frac{1}{2} \log\det(V + I) - \frac{1}{2} y^\top (V + I)^{-1} y.$$

Let us do the exact same thing again!

```
n <- 500
dat <- tibble(
  s1 = runif(n), s2 = runif(n),
  dist_mat = as.matrix(dist(cbind(s1, s2))),
  mu = MASS::mvrnorm(mu = rep(0, n),
                     Sigma = cov_fun(dist_mat, 1.0, 0.2)),
  y = rnorm(n, mu, 1)
)

log_lik <- function(sigma, ell, y, h) {
  V <- cov_fun(h, sigma, ell)
  R <- chol(V + diag(dim(V)[1]))
  -sum(log(diag(R))) - 0.5 * sum(y * backsolve(R, backsolve(R, y, transpose = TRUE)))
}

m <- 100
f <- partial(log_lik, y = dat$y, h = dat$dist_mat)
pars <- \(c) tibble(ell = seq(0.05, 1, length.out = m),
                    sigma = sqrt(c * ell),
                    c = rep(c, m))
ll <- map_df(seq(0.1, 10, length.out = 30), pars) |>
  mutate(contour = factor(c),
         ll = map2_dbl(sigma, ell,
                       possibly(f, otherwise = NA_real_)))
ll |> ggplot(aes(ell, ll, colour = contour)) +
  geom_line(show.legend = FALSE) +
  #scale_color_brewer(palette = "Set1") +
  theme_bw()
```

```
f_trim <- \(sigma, ell) ifelse(sigma^2 < 0.1 * ell | sigma^2 > 10 * ell,
                               NA_real_, f(sigma, ell))
m <- 20
surf <- expand_grid(ell = seq(0.05, 1, length.out = m),
                    sigma = seq(0.1, 4, length.out = m)) |>
  mutate(ll = map2_dbl(sigma, ell,
                       possibly(f_trim, otherwise = NA_real_)))
surf |> filter(ll > -360) |>
  ggplot(aes(ell, sigma, fill = ll)) +
  geom_raster() +
  scale_fill_viridis_c() +
  theme_bw()
```

Once again, we can see that there is going to be a ridge in the likelihood surface! It’s a bit less disastrous this time, but it’s not excellent even with 500 observations (which is a decent number on a unit square). The weird structure of the likelihood is still going to lead to a long, non-elliptical shape in your posterior that your computational engine (and the person interpreting the results) are going to have to come to terms with. In particular, if you only look at the posterior marginal distributions for $\sigma$ and $\ell$, you may miss the fact that the ratio $\sigma^2/\ell$ is quite well estimated by the data even though the marginals for both $\sigma$ and $\ell$ are very wide.

This ridge in the likelihood is going to translate somewhat into a ridge in the prior. We will see below that how much of that ridge we see is going to be very dependent on how we specify the prior. The entire purpose of the PC prior is to meaningfully resolve this ridge using sensible prior information.

But before we get to the (improved) PC prior, it’s worthwhile to survey some other priors that have been proposed in the literature.

That ridge in the likelihood surface does not go away in low dimensions, which essentially means that our inference along that ridge is going to be driven by the prior.

Possibly the worst choice you could make in this situation is trying to make a minimally informative prior. Of course, that’s what somebody did when they made a reference prior for the problem. In fact, it was the first paper^{41} to look rigorously at prior distributions on the parameters of GPs. It’s just unfortunate that it’s quite shit. It has still been cited quite a lot, and there have been some technical advances to the theory of reference priors, but if you use it you just find yourself mapping out that damn ridge.

On top of being, structurally, a bad choice, the reference prior has a few other downsides:

- It is very computationally intensive and quite complex. Not unlike the bad version of the PC prior!
- It requires *strong* assumptions about the likelihood. The first version assumed that there was no observation noise. Later papers allowed there to be observation noise. But only if it’s Gaussian.
- It is derived under the asymptotic regime where an infinite sequence of different independent realisations of the GP are observed at the same finite set of points. This is not the most useful regime for GPs.

All in all, it’s a bit of a casserole.

From the other end, there’s a very interesting contribution from Aad van der Vaart and Harry van Zanten, who wrote a very lovely theoretical paper that looked at which priors on the length scale could result in theoretically optimal contraction rates for the posterior of the GP. They argued for a Gamma distribution on a power of the inverse length scale. Within the Matérn class, their results are only valid for the squared exponential covariance function.

One of the stranger things that I have never fully understood is that the argument I’m going to make below ends up with a gamma distribution on a different power of the inverse length scale, which is somewhat different to van der Vaart and van Zanten. If I was being forced to bullshit some justification, I’d probably say something about the Matérn process depending only on the distance between observations making the $d$-sphere the natural geometry rather than the $d$-cube. But that would be total bullshit. I simply have no idea. Their proposal comes via the time-honoured tradition of “constant chasing” in some fairly tricky proofs, so I have absolutely no intuition for it.

We also found in other contexts that using the KL divergence rather than its square root tended to perform worse. So I’m kinda happy with our scaling and, really, their paper doesn’t cover the covariance functions I’m considering in this post.

Neither^{42} of these papers consider that ridge in the likelihood surface.

This lack of consideration—as well as their success in everything else we tried them on—was a big part of our push to make a useful version of a PC prior for Gaussian processes.

It has been a long journey, but we are finally where I wanted us to be. So let’s talk about how to fix the PC prior. In particular, I’m going to go through how to derive a prior on the length scale that has a simple form.

In order to solve this problem, we are going to do three things in the rest of this post:

- Restrict our attention to stationary^{43} GPs.
- Restrict our attention to the Matérn class of covariance functions.
- Greatly increase our mathematical^{44} sophistication.

But before we do that, I’m going to walk you through the punchline.

This work was originally done with the magnificent Geir-Arne Fuglstad, the glorious Finn Lindren, and the resplendent Håvard Rue. If you want to read the original paper, the preprint is here^{45}.

The PC prior is derived using the base model $\ell \to \infty$, which might seem like a slightly weird choice. The intuition behind it is that if there is strong dependence between far away points, the realisations of the GP cannot be too wiggly. In some contexts people talk about $\ell$ as a *“smoothness”*^{46} parameter, because realisations with large $\ell$ “look”^{47} smoother than realisations with small $\ell$.

Another way to see the same thing is to note that as $\ell \to \infty$ a Matérn field approaches a^{48} smoothing spline prior, in which case $\ell$ plays the role of the “smoothing parameter” of the spline. In that case, the natural base model of $\ell \to \infty$ interacts with the base model of $\sigma = 0$ to shrink towards an increasingly flat surface centred on zero.

We still need to choose a quantity of interest in order to encode some explicit information in the prior. In this case, I’m going to use the idea that for any data set, we only have information up to a certain spatial resolution. In that case, we don’t want to put prior mass on the length scale being less than that resolution. Why? Well, any inference about the GP at a smaller scale than the data resolution is going to be driven entirely by unverifiable model assumptions. And that feels a bit awkward. This suggests that we choose a minimum^{49} length scale $\ell_0$ and choose the scaling parameter $\lambda$ in the PC prior so that $$\Pr(\ell < \ell_0) = \alpha$$ for some small tail probability $\alpha$.

Under these assumptions, the PC prior for the length scale $\ell$ in a $d$-dimensional space is^{50} a Fréchet distribution^{51} with shape parameter $d/2$. That is, $$\pi(\ell) = \frac{d \lambda}{2} \, \ell^{-d/2 - 1} \exp\left(-\lambda \ell^{-d/2}\right),$$ where we choose $\lambda$ to ensure that $\Pr(\ell < \ell_0) = \alpha$.

In two dimensions, this is an inverse gamma prior, which gives rigorous justification to a commonly used prior in spatial statistics.
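To make the tail calibration concrete, here is a tiny Python sketch of that condition. The parameterisation is my reconstruction (check it against the paper), and `pc_rate`/`pc_cdf` are names I made up:

```python
import math

def pc_rate(ell0, alpha, d):
    """Rate lambda such that P(ell < ell0) = alpha under the prior
    pi(ell) = (d/2) * lam * ell^(-d/2 - 1) * exp(-lam * ell^(-d/2))."""
    return -math.log(alpha) * ell0 ** (d / 2)

def pc_cdf(ell, lam, d):
    """CDF of the Frechet-type PC prior: P(L <= ell) = exp(-lam * ell^(-d/2))."""
    return math.exp(-lam * ell ** (-d / 2))

# In 2D the shape is d/2 = 1, which is the inverse gamma prior mentioned above.
lam = pc_rate(ell0=0.1, alpha=0.05, d=2)
print(pc_cdf(0.1, lam, d=2))  # ~0.05, by construction
```

The calibration is exact by construction: plugging $\ell_0$ back into the CDF recovers the chosen tail probability, whatever $d$ is.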

Ok, so let’s actually see how much of a difference using a weakly informative prior makes relative to using the reference prior.

In the interest of computational speed, I’m going to use the simplest possible model setup, and I’m only going to use 25 observations.

In this case^{52} the model is $y \sim N(0, \sigma^2 R(\ell))$, where $R(\ell)$ is the correlation matrix of an exponential covariance function.

Even with this limited setup, it took a lot of work to make Stan sample this posterior. You’ll notice that I did a ridge-aware reparameterisation. I also had to run twice as much warm up as I ordinarily would.

The Stan code is under the fold.

```
functions {
  matrix cov(int N, matrix s, real ell) {
    matrix[N, N] R;
    row_vector[2] s1, s2;
    for (i in 1:N) {
      for (j in 1:N) {
        s1 = s[i, 1:2];
        s2 = s[j, 1:2];
        R[i, j] = exp(-sqrt(dot_self(s1 - s2)) / ell);
      }
    }
    return 0.5 * (R + R');
  }
  matrix cov_diff(int N, matrix s, real ell) {
    // elementwise derivative of the correlation matrix with respect to ell
    matrix[N, N] R;
    row_vector[2] s1, s2;
    for (i in 1:N) {
      for (j in 1:N) {
        s1 = s[i, 1:2];
        s2 = s[j, 1:2];
        R[i, j] = sqrt(dot_self(s1 - s2)) * exp(-sqrt(dot_self(s1 - s2)) / ell) / ell^2;
      }
    }
    return 0.5 * (R + R');
  }
  real log_prior(int N, matrix s, real sigma2, real ell) {
    matrix[N, N] R = cov(N, s, ell);
    matrix[N, N] W = cov_diff(N, s, ell) / R;
    return 0.5 * log(trace(W * W) - (1.0 / N) * trace(W)^2) - log(sigma2);
  }
}
data {
  int<lower=0> N;
  vector[N] y;
  matrix[N, 2] s;
}
parameters {
  real<lower=0> sigma2;
  real<lower=0> ell;
}
model {
  {
    matrix[N, N] R = cov(N, s, ell);
    target += multi_normal_lpdf(y | rep_vector(0.0, N), sigma2 * R);
  }
  target += log_prior(N, s, sigma2, ell);
}
generated quantities {
  real sigma = sqrt(sigma2);
}
```

By comparison, the code for the PC prior is fairly simple.

```
functions {
  matrix cov(int N, matrix s, real sigma, real ell) {
    matrix[N, N] R;
    row_vector[2] s1, s2;
    real sigma2 = sigma * sigma;
    for (i in 1:N) {
      for (j in 1:N) {
        s1 = s[i, 1:2];
        s2 = s[j, 1:2];
        R[i, j] = sigma2 * exp(-sqrt(dot_self(s1 - s2)) / ell);
      }
    }
    return 0.5 * (R + R');
  }
}
data {
  int<lower=0> N;
  vector[N] y;
  matrix[N, 2] s;
  real<lower=0> lambda_ell;
  real<lower=0> lambda_sigma;
}
parameters {
  real<lower=0> sigma;
  real<lower=0> ell;
}
model {
  matrix[N, N] R = cov(N, s, sigma, ell);
  y ~ multi_normal(rep_vector(0.0, N), R);
  sigma ~ exponential(lambda_sigma);
  ell ~ frechet(1, lambda_ell); // Only in 2D
}
// generated quantities {
//   real check = 0.0; // should be the same as lp__
//   { // I don't want to print R!
//     matrix[N, N] R = cov(N, s, sigma, ell);
//     check -= 0.5 * dot_product(y, (R \ y)) + 0.5 * log_determinant(R);
//     check += log(sigma) - lambda_sigma * sigma;
//     check += log(ell) - 2.0 * log(ell) - lambda_ell / ell;
//   }
// }
```

This is *a lot* easier than the code for the reference prior.

Let’s compare the results on some simulated data. Here I’m choosing $n = 25$, $\sigma = 1$, and $\ell = 0.2$.

```
library(cmdstanr)
library(posterior)

n <- 25
dat <- tibble(
  s1 = runif(n), s2 = runif(n),
  dist_mat = as.matrix(dist(cbind(s1, s2))),
  y = MASS::mvrnorm(mu = rep(0, n),
                    Sigma = cov_fun(dist_mat, 1.0, 0.2))
)

stan_dat <- list(y = dat$y,
                 s = cbind(dat$s1, dat$s2),
                 N = n,
                 lambda_ell = -log(0.05) * sqrt(0.05),
                 lambda_sigma = -log(0.05) / 5)

mod_ref <- cmdstan_model("gp_ref_no_mean.stan")
mod_pc <- cmdstan_model("gp_pc_no_mean.stan")
```

First off, let’s look at the parameter estimates from the reference prior

```
fit_ref <- mod_ref$sample(data = stan_dat,
                          seed = 30127,
                          parallel_chains = 4,
                          iter_warmup = 2000,
                          iter_sampling = 2000,
                          refresh = 0)
```

```
Running MCMC with 4 parallel chains...
Chain 1 finished in 41.6 seconds.
Chain 2 finished in 43.4 seconds.
Chain 4 finished in 44.8 seconds.
Chain 3 finished in 47.0 seconds.
All 4 chains finished successfully.
Mean chain execution time: 44.2 seconds.
Total execution time: 47.2 seconds.
```

`fit_ref$print(digits = 2)`

```
variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
lp__ -30.95 -30.57 1.24 0.89 -33.46 -29.79 1.00 1397 896
sigma2 32.56 1.28 823.19 0.58 0.69 7.19 1.00 979 562
ell 9.04 0.26 240.39 0.16 0.11 1.88 1.00 927 542
sigma 1.67 1.13 5.46 0.27 0.83 2.68 1.00 979 562
```

It also took a bloody long time.

Now let’s check in with the PC prior.

```
fit_pc <- mod_pc$sample(data = stan_dat,
                        seed = 30127,
                        parallel_chains = 4,
                        iter_sampling = 2000,
                        refresh = 0)
```

```
Running MCMC with 4 parallel chains...
Chain 1 finished in 4.9 seconds.
Chain 4 finished in 5.1 seconds.
Chain 3 finished in 5.4 seconds.
Chain 2 finished in 5.5 seconds.
All 4 chains finished successfully.
Mean chain execution time: 5.2 seconds.
Total execution time: 5.6 seconds.
```

`fit_pc$print(digits = 2)`

```
variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
lp__ -10.36 -10.05 1.02 0.76 -12.42 -9.36 1.00 2160 3228
sigma 1.52 1.36 0.60 0.41 0.92 2.72 1.00 1424 1853
ell 0.67 0.45 0.72 0.27 0.19 1.89 1.00 1338 1694
```

You’ll notice two things there: it did a much better job at sampling and it was *much* faster.

Finally, let’s look at some plots. First off, let’s look at some 2D density plots.

```
library(cowplot)
samps_ref <- fit_ref$draws(format = "draws_df")
samps_pc <- fit_pc$draws(format = "draws_df")

p1 <- samps_ref |> ggplot(aes(ell, sigma)) +
  geom_hex() +
  scale_color_viridis_c()
p2 <- samps_pc |> ggplot(aes(ell, sigma)) +
  geom_hex() +
  scale_color_viridis_c()
plot_grid(p1, p2)
```

It would be interesting to look at how different the densities for $\ell$ are.

```
samps_pc |> ggplot(aes(ell)) +
  geom_density() +
  geom_density(aes(samps_ref$ell), colour = "red") +
  xlim(0, 2)
```

As expected, the PC prior (black) pulls the posterior towards the base model ($\ell \to \infty$), but what is interesting to me is that the posterior for the reference prior (red) has so much mass near zero. That’s the one thing we didn’t want to happen.

We can look closer at this by looking at the posterior for $\kappa = 2/\ell$.

```
p3 <- samps_ref |>
  mutate(kappa = 2 / ell) |>
  ggplot(aes(kappa, sigma)) +
  geom_hex() +
  scale_color_viridis_c()
p4 <- samps_pc |>
  mutate(kappa = 2 / ell) |>
  ggplot(aes(kappa, sigma)) +
  geom_hex() +
  scale_color_viridis_c()
plot_grid(p3, p4)
```

To be brutally francis with you all, I’m not sure how much I trust that Stan posterior, so I’m going to look at the posterior along the ridge.

```
log_prior <- function(sigma, ell) {
  V <- cov_fun(dat$dist_mat, sigma, ell)
  dV <- (V * dat$dist_mat) / ell^2
  U <- t(solve(V, dV))
  0.5 * log(sum(diag(U %*% U)) - sum(diag(U))^2 / n) - log(sigma)
}
log_posterior <- \(sigma, ell) log_prior(sigma, ell) + f_direct(sigma, ell)

m <- 500
pars <- \(c) tibble(ell = seq(0.001, 2, length.out = m),
                    sigma = sqrt(c * ell), c = rep(c, m))
lpost <- map_df(seq(0.001, 8, length.out = 200), pars) |>
  mutate(tau = c,
         log_posterior = map2_dbl(sigma, ell,
                                  possibly(log_posterior, otherwise = NA_real_)))
lpost |>
  filter(log_posterior > -20) |>
  ggplot(aes(ell, log_posterior, colour = tau, group = tau)) +
  geom_line() +
  theme_bw()
```

We can compare this with the likelihood surface.

```
llik <- map_df(seq(0.001, 8, length.out = 200), pars) |>
  mutate(tau = c,
         log_likelihood = map2_dbl(sigma, ell,
                                   possibly(f_direct, otherwise = NA_real_)))
lprior <- map_df(seq(0.001, 8, length.out = 200), pars) |>
  mutate(tau = c,
         log_prior = map2_dbl(sigma, ell,
                              possibly(log_prior, otherwise = NA_real_)))
p1 <- llik |>
  filter(log_likelihood > -50) |>
  ggplot(aes(ell, log_likelihood, colour = tau, group = tau)) +
  geom_line() +
  #scale_color_brewer(palette = "Set1") +
  theme_bw()
p2 <- lprior |>
  filter(log_prior > -20) |>
  ggplot(aes(ell, log_prior, colour = tau, group = tau)) +
  geom_line() +
  #scale_color_brewer(palette = "Set1") +
  theme_bw()
plot_grid(p1, p2)
```

You can see here that the prior is putting *a lot* of weight at zero relative to the likelihood surface, which is relatively flat.

It’s also important to notice that the ridge isn’t as flat with as it is with . It would be very interesting to repeat this with larger values of , but frankly I do not have the time.

There is *a lot* more to say on this topic. But honestly this blog post is already enormous (you are a bit over halfway if you choose to read the technical guff). So I’m just going to summarise some of the things that I think are important here.

Firstly, the rigorous construction of the PC prior only makes sense when . This is a bit annoying, but it is what it is. I would argue that this construction is still fairly reasonable in moderate dimensions. (In high dimensions I think we need more research.)

There are two ways to see that. Firstly, if you look at the derivation of the distance, it involves an infinite sum that only converges when . But mathematically, if we can show^{53} that the partial sums can be bounded independently of , then we can just send another thing to infinity when we send the domain size and the base model length scale there.

A different way to see this is to note that the PC prior distance is . This is proportional to the inverse of the volume of the -sphere^{54} of radius . This doesn’t seem like a massively useful observation, but just wait.

What if we ask ourselves “what is the average variance of over a ball of radius ?”. If we write as the Matérn covariance function, then^{55} where for all . If we remember that , then we can write this as Hence the PC prior on is penalising the change in average standard deviation over a ball relative to the unit length scale. With this interpretation, the base model is, once again, zero standard deviation. This reasoning carries over to the length scale parameter in *any*^{56} Gaussian process.

This post only covers the simplest version of Matérn GPs. One simple extension is to construct a non-stationary GP by replacing the Euclidean distance with the distance on a manifold with volume element . This might seem like a weird and abstract thing to do, but it’s an intrinsic specification of the popular deformation method due to Guttorp and Samson. Our paper covers the prior specification in this case.

The other common case that I’ve not considered here is the extension where there is a different length scale^{57} in each dimension. In this case, we could compute a PC prior independently for each dimension (so for each prior). To be completely honest with you, I worry a little bit about that choice in high dimensions^{58} (products of independent priors being notoriously weird), but I don’t have a better suggestion.

So you might have noticed that even though the previous section is a “conclusion” section, there is quite a bit more blog to go. I shan’t lie: this whole thing up to this point is a tl;dr that got wildly out of control.

The rest of the post is the details.

There are two parts. The first part covers enough^{59} of the theory of stationary GPs to allow us to understand the second part, which actually derives the PC prior.

It’s going to get a bit hairy and I’m going to assume you’ve at least skimmed through the first 2 definitions in my previous post defining GPs.

I fully expect that most people will want to stop reading here. But you shouldn’t. Because if I had to suffer you all have to suffer.

Gaussian processes with the Matérn covariance function are an excellent example of stationary^{60} Gaussian processes, which are characterised^{61} ^{62} by having covariance functions of the form where I am abusing notation and using for both the two parameter and one parameter functions. This assumption means that the correlation structure does not depend on where you are in space, only on the distance between points.

The assumption of stationarity massively simplifies GPs. Firstly, the stationarity assumption greatly reduces the number of parameters you need to describe a GP as we don’t need to worry about location-specific parameters. Secondly, it increases the statistical power of the data. If two subsets of the domain are more than apart, they are essentially independent replicates of the GP with the same parameters. This means that if the locations vary across a large enough area (relative to the natural length scale), we get multiple effective replicates^{63} from the same realisation of the process.

In practice, stationarity^{64} is often a *good enough* assumption when the mean has been modelled carefully, especially given the limitations of the data. That said, priors on non-stationary processes can be set using the PC prior methodology by using a stationary process as the base model. The supplementary material of our paper gives a simple, but useful, example of this.

The restriction to stationary processes is *extremely* powerful. It opens us up to using Fourier analysis as a potent tool for understanding GPs. We are going to need this to construct our KL divergence, and so with some trepidation, let’s dive into the moonee ponds of spectral representations.

The first thing that we need to do is remember what a *Fourier transform* is. A Fourier transform of a square integrable function is^{65}

If you have bad memories^{66} of desperately trying to compute Fourier integrals in undergrad, I promise you that we are not doing that today. We are simply affirming their right to exist (and my right to look them up in a table).

The reason I care about Fourier^{67} transforms is that if I have a non-negative measure^{68} , I can define a function If measures freak you out, you can—with some loss of generality—assume that there is a function such that We are going to call the spectral measure and the corresponding , if it exists, is called the spectral density.

I put it to you that, defined this way, is a (complex) positive definite function.

Recall^{69} that a function is positive definite if, for every , every , and every , where is the complex conjugate of .

Using our assumption about we can write the left hand side as where .

We have shown that if , then it is a valid covariance function. This is also true, although much harder to prove, in the other direction and the result is known as Bochner’s theorem.
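If you want to see the "spectral measure gives a positive definite function" direction in action, here's a toy numerical sketch (entirely my own construction: a Gaussian spectral density, chosen purely because its Fourier transform is the Gaussian covariance $\exp(-h^2/2)$ and the quadrature is painless). The covariance matrix built this way at arbitrary locations should come out positive semi-definite.

```python
import numpy as np

# Toy check of Bochner's "if" direction: a non-negative spectral density gives
# a positive definite covariance. Spectral density: standard Gaussian, whose
# Fourier transform is c(h) = exp(-h^2 / 2).
rng = np.random.default_rng(0)
s = rng.uniform(0.0, 10.0, size=30)              # arbitrary locations
omega = np.linspace(-10.0, 10.0, 2001)           # quadrature grid
domega = omega[1] - omega[0]
f = np.exp(-omega**2 / 2) / np.sqrt(2 * np.pi)   # non-negative spectral density

def cov(h):
    # c(h) = integral of exp(i*omega*h) f(omega) d omega; f is symmetric so
    # only the cosine part survives
    return np.sum(np.cos(omega * h) * f) * domega

C = np.array([[cov(si - sj) for sj in s] for si in s])
min_eig = np.linalg.eigvalsh(C).min()
# min_eig should only dip below zero at floating-point-noise level
```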

**Theorem 5 (Bochner’s theorem) **A function is positive definite, ie for every , every , and every if and only if there is a non-negative finite measure such that

Just as a covariance function^{70} is enough to completely specify a zero-mean Gaussian process, a spectral measure is enough to completely specify a zero mean *stationary* Gaussian process.

Our lives are mathematically much easier when represents a density that satisfies This function, when it exists, is precisely the Fourier transform of . Unfortunately, this will not exist^{71} for all possible positive definite functions. But as we drift further and further down this post, we will begin to assume that we’re only dealing with cases where exists.

The case of particular interest to us is the Matérn covariance function. The parameterisation used above is really lovely, but for mathematical convenience, we are going to set^{72} , which has^{73} Fourier transform where is defined implicitly above and is a constant (as we are keeping fixed).
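For reference — and hedged appropriately, because I have never once got a constant right — the standard Matérn pair has the following shape, where $\kappa$ is the re-scaled inverse length scale from above and the constant depends only on $\nu$ and $d$:

$$
c(h) = \sigma^2 \frac{2^{1-\nu}}{\Gamma(\nu)} (\kappa \|h\|)^{\nu} K_\nu(\kappa \|h\|)
\qquad\Longleftrightarrow\qquad
f(\omega) = \mathrm{const}(\nu, d)\, \sigma^2 \kappa^{2\nu} \left(\kappa^2 + \|\omega\|^2\right)^{-(\nu + d/2)}.
$$

The only thing we will actually use is the polynomial decay on the right hand side.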

To see this, we need a tiny bit of machinery. Specifically, we need the concept of a Gaussian -noise and its corresponding integral.

**Definition 1 (Complex -noise) **A (complex) -noise^{74} is a random measure^{75} such that, for every^{76} disjoint^{77} pair of sets , the following properties hold:

- has mean zero and variance ,
- If and are disjoint then
- If and are disjoint then and are uncorrelated^{78}, ie .

This definition might not seem like much, but imagine a simple^{79} piecewise constant function where and the sets are pairwise disjoint and . Then we can define an integral with respect to the -noise as which has mean and variance where the first equality comes from noting that and are uncorrelated and the last equality comes from the definition of an integral of a piecewise constant function.

Moreover, we get the covariance
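That mean/variance bookkeeping is easy to check by simulation. Here's a sketch (every name and number below is mine, made up for illustration): integrate a simple function against a complex Gaussian noise and compare the Monte Carlo variance with $\sum_k |f_k|^2 \varepsilon(A_k)$.

```python
import numpy as np

# Integral of a simple function against a complex Gaussian noise, simulated.
# The sets A_k are disjoint, eps_k = eps(A_k), and the noise attaches an
# independent complex Gaussian with variance eps_k to each set.
rng = np.random.default_rng(1)
f_k = np.array([1.0 + 2.0j, -0.5j, 2.0, 0.25 - 1.0j])  # value of f on each A_k
eps_k = np.array([0.3, 1.2, 0.5, 2.0])                  # eps(A_k)

n_sim = 200_000
# complex Gaussian with variance eps_k: real and imaginary parts N(0, eps_k / 2)
Z = (rng.normal(size=(n_sim, 4)) + 1j * rng.normal(size=(n_sim, 4))) * np.sqrt(eps_k / 2)
I = Z @ f_k                                             # the integral, n_sim times

mc_mean = I.mean()
mc_var = np.mean(np.abs(I) ** 2)
target_var = np.sum(np.abs(f_k) ** 2 * eps_k)           # = 5.925 for these values
```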

A nice thing is that while these piecewise constant functions are quite simple, we can approximate *any*^{80} function arbitrarily well by a simple function. This is the same fact we use to build ourselves ordinary^{81} integrals.

In particular, the brave and the bold among you might just say “we can take limits here and *define*” an integral with respect to the -noise this way. And, indeed, that works. You get that, for any ,

and, for any ,

If we define then it follows immediately that is mean zero and has covariance function That is, is the spectral measure associated with the correlation function.

Combining this with Bochner’s theorem, we have just proved^{82} the spectral representation theorem for general^{83} (weakly) stationary^{84} random fields^{85}.

**Theorem 6 (Spectral representation theorem) **If is a finite, non-negative measure on and is a complex -noise, then the complex-valued process has mean zero and covariance and is therefore weakly stationary. If then is a Gaussian process.

Furthermore, every mean-square continuous mean zero stationary Gaussian process with covariance function and corresponding spectral measure has an associated -noise such that holds in the mean-square sense for all .

is called the *spectral process* ^{86} associated with . When it exists, the density of , denoted by , is called the *spectral density* or the *power spectrum*.

All throughout here I used complex numbers and complex Gaussian processes because, believe it or not, it makes things easier. But you will be pleased to know that will be real-valued as long as the spectral density is symmetric around the origin. And it always is.
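To make the spectral representation feel less abstract, here's a simulation sketch (my construction, not anything from a paper): the spectral density of the exponential covariance $c(h) = e^{-|h|}$ is the standard Cauchy density $1/(\pi(1+\omega^2))$, so we can draw frequencies directly and synthesise the process from random cosines. The covariance of this construction is $\mathbb{E}[\cos(\omega h)]$, which should match $e^{-|h|}$.

```python
import numpy as np

# Spectral synthesis of a real stationary process with covariance exp(-|h|).
rng = np.random.default_rng(2)
M = 500_000
omega = rng.standard_cauchy(M)             # frequencies drawn from the spectral density
phi = rng.uniform(0.0, 2.0 * np.pi, M)     # independent uniform phases

def X(s):
    # one realisation of the process, evaluated at location s
    return np.sqrt(2.0 / M) * np.sum(np.cos(omega * s + phi))

# the covariance of this construction is E[cos(omega * (s - t))], so the
# empirical average of cos(omega * h) should be close to exp(-|h|)
emp_cov = {h: np.mean(np.cos(omega * h)) for h in [0.0, 0.5, 1.0, 2.0]}
```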

One particular advantage of stationary processes is that we get a straightforward characterization of the Cameron-Martin space inner product. Recall that the Cameron-Martin space (or reproducing kernel Hilbert space) associated with a Gaussian process is the^{88} space of all functions of the form where is finite, are real, and are distinct points in . This is the space that the posterior mean for GP regression lives in.

The inner product associated with this space can be written in terms of the spectral density as^{89} In particular, for a Matérn Gaussian process, the corresponding norm is For those of you familiar with function spaces, this is equivalent to the norm on . One way to interpret this is that the *set* of functions in the Cameron-Martin space for a Matérn GP only depends on , while the norm and inner product (and hence the posterior mean and all that stuff) depend on , , and . This observation is going to be important.

It would’ve been a bit of an odd choice to spend all this time talking about spectral representations and never using them. So in this section, I’m going to cover the reason for the season: singularity or absolute continuity of Gaussian measures.

The Feldman-Hájek theorem quoted is true on quite general sets of functions. However, if we are willing to restrict ourselves to a separable^{90} Hilbert^{91} space there is a much more refined version of the theorem that we can use.

**Theorem 7 (Feldman-Hájek theorem (Taylor’s ^{92} version)) **Two Gaussian measures (mean , covariance operator ) are either equivalent or singular. They are equivalent if and only if all three of the following conditions hold:

- The Cameron-Martin spaces associated with and are the same (considered as sets of functions; they usually will not have the same inner products),
- The difference in means is in the ^{94}Cameron-Martin space, and
- The operator is a Hilbert-Schmidt operator, that is, it has a countable set of eigenvalues and corresponding eigenfunctions that satisfy and

When these three conditions are fulfilled, the Radon-Nikodym derivative is where is an iid sequence of N(0,1) random variables^{95} ^{96} (under ).

Otherwise, the two measures are singular.

This version of Feldman-Hájek is considerably more useful than its previous incarnation. The first condition basically says that the posterior means from the two priors will have the same smoothness and is rarely a problem. Typically the second condition is fulfilled in practice (for example, we always set the mean to zero).

The third condition is where all of the action is. This is, roughly speaking, a condition that says that and aren’t toooooo different. To understand this, we need to look a little at what the values actually are. It turns out to actually be easier to ask about , which are the eigenvalues of . In that case, we are trying to find the orthonormal system of functions such that where .

Hence, we can roughly interpret the as the eigenvalues of The Hilbert-Schmidt condition is then requiring that is not infinitely far from the identity mapping.

A particularly nice version of this theorem occurs when and have the *same* eigenvectors. This is a fairly restrictive assumption, but we are going to end up using it later, so it’s worth specialising. In that case, assuming has eigenvalues and corresponding -orthogonal eigenfunctions , we can write^{97} Using the orthogonality of the eigenfunctions, we can show^{98} that

With a bit of effort, we can see that and so From that, we get^{99} the KL divergence

Possibly unsurprisingly, this is simply the sum of the one dimensional divergences It’s fun to convince yourself that that is sufficient to ensure the sum converges.
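If you don’t feel like convincing yourself analytically, here’s a numerical nudge (with toy eigenvalue ratios I made up): when the ratios approach 1 fast enough that $\sum_k(\text{ratio}_k - 1)^2$ converges — the Hilbert-Schmidt condition — the summed one-dimensional divergences stabilise; when they approach 1 too slowly, the sum wanders off to infinity.

```python
import numpy as np

def kl_1d(ratio):
    # KL( N(0, l1) || N(0, l2) ) depends only on ratio = l1 / l2
    return 0.5 * (ratio - np.log(ratio) - 1.0)

k = np.arange(1, 200_001)

# ratios -> 1 like 1/k: sum of (ratio - 1)^2 converges, so the KL sum converges
converging = np.cumsum(kl_1d(1.0 + 1.0 / k))

# ratios -> 1 like 1/sqrt(k): sum of (ratio - 1)^2 is the harmonic series, so
# the KL sum grows without bound (like 0.25 * log k)
diverging = np.cumsum(kl_1d(1.0 + 1.0 / np.sqrt(k)))

tail_gap = converging[-1] - converging[99_999]   # essentially flat by now
growth = diverging[-1] - diverging[99_999]       # still growing
```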

Ok. So I lied. I suggested that we’d use all of that spectral stuff in the last section. And we didn’t! Because I’m dastardly. But this time I promise we will!

It turns out that even with our fancy version of Feldman-Hájek, it can be difficult^{100} to work out whether two Gaussian processes are singular or equivalent. One of the big challenges is that the eigenvalues and eigenfunctions depend on the domain and so we would, in principle, have to check this quite complex condition for every single domain.

Thankfully, there is an easy-to-parse sufficient condition that shows when two GPs are equivalent on *every* bounded domain. These conditions are stated in terms of the spectral densities.

**Theorem 8 (Sufficient condition for equivalence (Thm 4 of Skorokhod and Yadrenko)) **Let and be mean-zero Gaussian processes with spectral densities , . Assume that is bounded away from zero and infinity for some^{101} and Then the joint distributions of and are equivalent measures for every bounded region .

The proof of this is pretty nifty. Essentially it constructs the operator in a sneaky^{102} way and then bounds its trace on a rectangle containing . That upper bound is finite precisely when the above integral is finite.

Now that we have a relatively simple condition for equivalence, let’s look at Matérn fields. In particular, we will assume , are two Matérn GPs with the same smoothness parameter and other parameters^{103} . We can save ourselves some trouble by considering two cases separately.

**Case 1:** .

In this case, we can make the change to spherical coordinates via the substitution and, again to save my poor fingers, let’s set . The condition becomes To check that this integral is finite, first note that, near , the integrand is^{104} , so there is no problem there. Near (aka the other place bad stuff can happen), the integrand is This is integrable for large whenever^{105} . Hence, the two fields are equivalent whenever and . It is harder, but possible to show that the fields are singular when . The case with is boring and nobody cares.

**Case 2: ** .

Let’s define . Then it’s clear that and therefore the Matérn field with parameters is equivalent to .

We will now show that and are singular, which implies that and are singular. To do this, we just need to note that, as and have the *same* value of , We know, from the previous blog post, that and will be singular unless , but this only happens when , which is not true by assumption.

Hence we have proved the first part of the following Theorem due, in this form, to Zhang^{106} (2004) and Anderes^{107} (2010).

**Theorem 9 (Thm 2 of Zhang (2004)) **Two Gaussian processes on , , with Matérn covariance functions with parameters , induce equivalent Gaussian measures if and only if When , the measures are always singular (Anderes, 2010).

With all of that in hand, we are finally (finally!) in a position to show that, in 3 or fewer dimensions, the PC prior distance is . After this, we can put everything together! Hooray!

Now, you can find a proof of this in the appendix of our JASA paper, but to be honest it’s quite informal. But although you can sneak any old shite into JASA, this is a blog goddammit and a blog has integrity. So let’s do a significantly more rigorous proof of our argument.

To do this, we will need to find the KL divergence between , with parameters and a base model with parameters , where is some fixed, small number and is fixed. We will actually be interested in the behaviour of the KL divergence as goes to zero. Why? Because is our base model.

The specific choice of standard deviation in both models ensures that and so the KL divergence is finite.

In order to approximate the KL divergence, we are going to find a basis that simultaneously diagonalises both processes. In the paper, we simply declared that we could do this. And, morally, we can. But, as I said, a blog aspires to a higher standard than mere morality. Here we strive for meaningless rigour.

To that end, we are going to spend a moment thinking about how this can be done in a way that isn’t intrinsically tied to a given domain . There may well be a lot of different ways to do this, but the most obvious one is to notice that if is *periodic* on the cube for some , then it can be considered as a GP on a -dimensional torus. If is large enough that , then we might be able to focus on our cube and forget all about the specific domain .

A nice thing about periodic GPs is that we actually know their Karhunen-Loève^{108} representation. In particular, if is a stationary covariance function on a torus, then we^{109} know that its eigenfunctions are and its eigenvalues are This gives^{110}
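The discrete version of this fact is worth seeing once (all choices below are mine): on a grid of equally spaced points around a circle, a stationary covariance gives a circulant matrix, its eigenvectors are the discrete Fourier basis, and its eigenvalues are the DFT of the first row. Here I use a wrapped exponential covariance, for which the repeated version happens to stay positive definite.

```python
import numpy as np

# Discrete analogue of the torus Karhunen-Loeve representation: a stationary
# covariance on n equally spaced points around a circle of circumference L is
# circulant, and its eigenvalues are the DFT of its first row. Covariance:
# exp(-d(s, t)) with d the shorter distance around the circle.
n, L = 256, 20.0
h = np.arange(n) * (L / n)
dist = np.minimum(h, L - h)          # distance around the circle
row = np.exp(-dist)                  # first row of the circulant covariance
C = np.array([[row[(j - i) % n] for j in range(n)] for i in range(n)])

fft_eigs = np.fft.fft(row).real      # eigenvalues via the Fourier transform
dense_eigs = np.linalg.eigvalsh(C)   # eigenvalues the slow way
# sorted, these should agree; for this covariance they are all positive
```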

Now we have some work to do. Firstly, our process is not periodic^{111} on . That’s a bit of a barrier. Secondly, even if it were, we don’t actually know what is going to be. This is probably^{112} an issue.

So let’s make this sucker periodic. The trick is to note that, at long enough distances, and are almost uncorrelated. In particular, if , then . This means that if we are interested in on a fixed domain , then we can replace it with a GP whose covariance function is the periodic extension of from to (aka we just repeat it!).

This repetition won’t be noticed on as long as is big enough. But we can run into the small^{113} problem. This procedure can lead to a covariance function that is *not* positive definite. Big problem. Huge.

It turns out that one way to fix this is to use a smooth cutoff function that is 1 on and 0 outside of , where is big enough so that and . We can then build the periodic extension of a stationary covariance function as It’s important^{114} to note that this is not the same thing as simply repeating the covariance function in a periodic manner. Near the boundaries (but outside of the domain) there will be some reach-around contamination. Bachmayr, Cohen, and Migliorati show that this *does not work* for general stationary covariance functions, but does work under the additional condition that is big enough and there exist some and such that This condition obviously holds for the Matérn covariance function and Bachmayr, Graham, Nguyen, and Scheichl^{115} showed that for some explicit function that only depends on and is sufficient to make this work.

The nice thing about this procedure is that as long as , which means that our inference is going to be *identical* on our sample as it would be with the non-periodic covariance function! Splendid!

Now that we have made a valid periodic extension (and hence we know what the eigenfunctions are), we need to work out what the corresponding eigenvalues are.

We know that But it is not clear what will happen when we take the Fourier transform of .

Thankfully, the convolution theorem is here to help us and we know that, if , then where is the convolution operator.

In the perfect world, would be very close to zero, so we can just replace the Fourier transform of with the Fourier transform of . And thank god we live in a perfect world.

The specifics here are a bit tedious^{116}, but you can show that as . For Matérn fields, Bachmayr et al. performed some heroic calculations to show that the difference is exponentially small as and that, as long as , everything is positive definite and lovely.

So after a bunch of effort and a bit of a literature dive, we have finally got a simultaneous eigenbasis and we can write our KL divergence as We can write this as for some constant that you can actually work out but I really don’t need to. The important thing is that the error is exponentially small in , which is very large and spiraling rapidly out towards infinity.

Then, noticing that the sum is just a trapezium rule approximation to a -dimensional integral, we get, as (and hence ), The integral converges whenever .
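The sum-to-integral step is the usual one: a lattice sum with spacing $1/T$ is a trapezium-rule approximation whose error vanishes (very quickly, for smooth decaying integrands) as $T$ grows. A toy version, with a made-up integrand standing in for the KL summand:

```python
import numpy as np

def g(x):
    # a smooth, integrable stand-in for the KL summand
    return 1.0 / (1.0 + x**2) ** 2

exact = np.pi / 2                         # integral of g over the real line
k = np.arange(-200_000, 200_001)

errors = {}
for T in [2.0, 8.0]:
    lattice_sum = np.sum(g(k / T)) / T    # trapezium rule with spacing 1/T
    errors[T] = abs(lattice_sum - exact)
# the error is already small at T = 2 and at floating-point level by T = 8
```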

This suggests that we can re-scale the distance by absorbing the into the constant in the PC prior, and get

This distance does not depend on the specific domain (or the observation locations), which is an improvement over the PC prior I derived in the introduction. Instead, it only assumes that is bounded, which isn’t really a big restriction in practice.

With all of this in hand, we can now construct the PC prior. Instead of working directly with , we will instead derive the prior for the estimable parameter , and the non-estimable parameter .

We know that multiplies the covariance function of , so it makes sense to treat like a standard deviation parameter. In this case, the PC prior is The canny among you would have noticed that I have made the scaling parameter depend on . I have done this because the quantity of interest that we want our prior to control is the marginal standard deviation , which is a function of . If we want to ensure , we need
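In its simplest form (ignoring the length-scale dependence of the scaling for a second), this tail calibration is a one-liner; the $U$ and $\alpha$ below are arbitrary user choices, not values from anywhere in the post.

```python
import numpy as np

# Tail calibration for an exponential prior on sigma. For sigma ~ Exp(lam),
# P(sigma > U) = exp(-lam * U), so requiring P(sigma > U) = alpha gives
# lam = -log(alpha) / U.
U, alpha = 3.0, 0.05
lam = -np.log(alpha) / U

tail_prob = np.exp(-lam * U)   # should recover alpha exactly

# sanity check by simulation
rng = np.random.default_rng(3)
draws = rng.exponential(scale=1.0 / lam, size=500_000)
empirical_tail = np.mean(draws > U)
```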

We can now derive the PC prior for . The distance that we just spent all that effort calculating, and an exponential prior on leads^{117} to the prior Note that in this case, does not depend on any other parameters: this is because is our identifiable parameter. If we require , we get
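The length-scale calibration works the same way once you notice that, in this parameterisation, it is $\ell^{-d/2}$ that gets the exponential prior. The $d$, $L_0$, and $\alpha$ below are again my arbitrary choices.

```python
import numpy as np

# Tail calibration for the length scale, assuming ell**(-d/2) ~ Exp(lam).
# P(ell < L0) = P(ell**(-d/2) > L0**(-d/2)) = exp(-lam * L0**(-d/2)),
# so requiring P(ell < L0) = alpha gives lam = -log(alpha) * L0**(d/2).
d, L0, alpha = 2, 0.1, 0.05
lam = -np.log(alpha) * L0 ** (d / 2)

small_prob = np.exp(-lam * L0 ** (-d / 2))   # should recover alpha exactly

# sanity check by transforming exponential draws
rng = np.random.default_rng(4)
phi = rng.exponential(scale=1.0 / lam, size=500_000)
ell = phi ** (-2.0 / d)
empirical_prob = np.mean(ell < L0)
```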

Hence the joint PC prior on , which is emphatically *not* the product of two independent priors, is

Great gowns, beautiful gowns.

Of course, we don’t want the prior on some weird parameterisation (even though we needed that parameterisation to derive it). We want it on the original parameterisation. And here is where some magic happens! When we transform this prior to -space it magically^{118} becomes the product of two independent priors! In particular, the PC prior that encodes and is
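For the record — and modulo my well-documented inability to track constants, so do check the JASA paper before quoting this — the joint density has the shape

$$
\pi(\ell, \sigma)
= \frac{d}{2}\,\tilde\lambda_1\, \ell^{-d/2 - 1}
  \exp\!\left(-\tilde\lambda_1\, \ell^{-d/2}\right)
  \times \tilde\lambda_2 \exp\!\left(-\tilde\lambda_2\, \sigma\right),
$$

with $\tilde\lambda_1 = -\log(\alpha_1)\,\ell_0^{d/2}$ chosen so that $P(\ell < \ell_0) = \alpha_1$ and $\tilde\lambda_2 = -\log(\alpha_2)/\sigma_0$ chosen so that $P(\sigma > \sigma_0) = \alpha_2$.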

It. Is. Finished.

The most common feedback was “I hung in for as long as I could”.↩︎

If you don’t think we’re gonna get our Maccabees on you’re dreamin’. Hell, I might have to post Enoch-ussy on main.↩︎

Penalised Complexity priors (or PC priors) are my favourite thing. If you’re unfamiliar with them, I strongly recommend you read the previous post on PC priors to get a good grip on what they are, but essentially they’re a way to construct principled, weakly informative prior distributions. The key tool for PC priors is the Kullback-Leibler divergence between a model with parameter and a fixed base model with parameter . Computing the KL divergence between two GPs is, as we will see, a challenge.↩︎

Fun fact: when we were starting to work on PC priors we were calling them PCP priors, but then I remembered that one episode of CSI where some cheerleaders took PCP and ate their friend and we all agreed that that wasn’t the vibe we were going for.↩︎

you might just need to trust me at some points↩︎

It could easily be more complex with multilevel components, multiple GPs, time series components etc etc. But the simplest example is a GP regression.↩︎

The GP has mean zero for the same reason we usually centre our covariates: it lets the intercept model the overall mean.↩︎

Not just the likelihood but also everything else in the model↩︎

A challenge with reference priors is that they are often improper (aka they don’t integrate to 1). This causes some conceptual difficulties, but there is a whole theory of Bayes that’s mostly fine with this as long as the resulting posterior integrates to one. But this is by no means guaranteed and is typically only checked in very specific cases. Jim Berger, one of the bigger proponents of reference priors, used to bring his wife to conference poster sessions. When she got bored, she would simply find a grad student and ask them if they’d checked if the posterior was proper. Sometimes you need to make your own fun.↩︎

Hope has no place in statistics.↩︎

Remember that any number on the logit scale outside of might as well be the same number↩︎

`log(.Machine$integer.max) = 21.48756`, so 70% of the prior mass is less than that. 90% of the prior mass is less than and 99% is less than . This is still a weak prior.↩︎

Conceptually. The mathematics of what happens as aren’t really worth focusing on.↩︎

Or, you know, linear functionals↩︎

You can find Bayesians who say that they don’t care if cross validation works or not. You can find Bayesians who will say just about anything.↩︎

There are lots of parameterisations, but they’re all easy to move between. Compared to wikipedia, we use the scaling rather than the scaling.↩︎

Everything in this post can be easily generalised to having different length scales on each dimension.↩︎

If you’ve not run into these before, is finite at zero and decreases monotonically in an exponential-ish fashion as .↩︎

Possibly trying several values and either selecting the best or stacking all of the models↩︎

Field because by rights GPs with multidimensional parameter spaces should be called *Gaussian Fields* but we can’t have nice things so whatever. Live your lives.↩︎

At which point you need to ask yourself if one gets there faster. It’s chaos.↩︎

Asymptotics as copaganda.↩︎

I mean, if you can repeat experiments that’s obviously amazing, but there are lots of situations where that is either not possible or not the greatest use of resources. There’s an interesting sub-field of statistical earth sciences that focuses on working out the value of getting new types of observations in spatial data. This particular variant of the value of information problem throws up some fun corners.↩︎

or hoping↩︎

in 3 or fewer dimensions↩︎

I have not fact checked this↩︎

Basically everything you care about. Feel free to google the technical definition. But any space with a metric is locally convex. Lots of things that aren’t metric spaces are too.↩︎

measurable↩︎

This will seem a bit weird if it’s the first time you’ve seen the concept. In finite dimensions (aka most of statistics) *every* Gaussian is equivalent to every other Gaussian. In fact, it’s equivalent to every other continuous distribution with non-zero density on the whole of . But shit gets weird when you’re dealing with functions and we just need to take a hit of the video head cleaner and breathe until we get used to it.↩︎

These measures *are not the same*. They just happen to be non-zero on the same sets.↩︎

eg, computationally where Metropolis-Hastings acceptance probabilities have an annoying tendency to go to zero unless you are extraordinarily careful.↩︎

if it exists↩︎

This can be interpreted as the event that happens infinitely many times for every epsilon. If this event occurs with any probability, it would strongly suggest that the estimator is not bloody converging.↩︎

or even many↩︎

Technically, a recent paper in JRSSSB said that if you add an iid Gaussian process you will get identifiability, but that’s maybe not the most realistic asymptotic approximation.↩︎

The fourth dimension is where mathematicians go to die↩︎

It’s computationally pretty expensive to plot the whole likelihood surface, so I’m just doing it along lines↩︎

`partial` freezes a few parameter values, and `possibly` replaces any calls that return an error with an NA↩︎

That I could find↩︎

To be fair to van der Vaart and van Zanten, their particular problem doesn’t necessarily have a ridge!↩︎

Saddle up for some spectral theory.↩︎

I’m terribly sorry.↩︎

I’m moderately sure that the preprint is pretty similar to the published version but I am not going to check.↩︎

Can’t stress enough that this is smoothness in a qualitative sense rather than in the more technical “how differentiable is it?” sense.↩︎

Truly going wild with the scare quotes. Always a sign of excellent writing.↩︎

For the usual smoothing spline with the square of the Laplacian, you need . Other values of still give you splines, just with different differentiability assumptions.↩︎

If your data is uniformly spaced, you can use the minimum. Otherwise, I suggest a low quantile of the distribution of distances. Or just a bit of nous.↩︎

The second half of this post is devoted to proving this. And it is *long*.↩︎

With this parameterisation it’s sometimes known as a Type-II Gumbel distribution. Because why not.↩︎

And *only* in this case! The reference prior changes a lot when there is a non-zero mean, when there are other covariates, when there is observation noise, etc etc. It really is quite a wobbly construction.↩︎

Readers, I have not bothered to show.↩︎

Part of why I’m reluctant to claim this is a good idea in particularly high dimensions is that volume in high dimensions is frankly a bit gross.↩︎

I, for one, love a sneaky transformation to spherical coordinates.↩︎

So why do all the technical shit to derive the PC prior when this option is just sitting there? Fuck you, that’s why.↩︎

This is sometimes called “automatic relevance determination” because words don’t have meaning anymore. Regardless, it’s a pretty sensible idea when you have a lot of covariates that can be quite different.↩︎

It is possible that a horseshoe-type prior on would serve better, but there are going to be some issues as that will shrink the geometric mean of the length scales towards 1.↩︎

Part of the motivation for writing this was to actually have enough of the GP theory needed to think about these priors in a single place.↩︎

In fact, it’s isotropic, which is a stricter condition on most spaces. But there’s no real reason to specialise to isotropic processes so we simply won’t.↩︎

We are assuming that the mean is zero, but absent that assumption, we need to assume that the mean is constant.↩︎

For non-Gaussian processes, this property is known as *second-order* stationarity. For GPs this corresponds to strong stationarity, which is a property of the distribution rather than the covariance function.↩︎

If you’ve been exposed to the concept of ergodicity of random fields you may be eligible for compensation.↩︎

Possibly with different length scales in different directions or some other form of anisotropy↩︎

This normalisation is to make my life easier.↩︎

Let’s not lie, I just jumped straight to complex numbers. Some of you are having flashbacks.↩︎

Fourier-Stieltjes↩︎

countably additive set-valued function. Like a probability but it doesn’t have to total to one↩︎

and complexify↩︎

or a Cameron-Martin space↩︎

That is, this measure bullshit isn’t just me pretending to be smart. It’s necessary.↩︎

Feeling annoyed by a reparameterisation this late in the blog post? Well tough. I’ve got to type this shit out and if I had to track all of those s I would simply curl up and die.↩︎

In my whole damn life I have never successfully got the constant correct, so maybe check that yourself. But truly it does not matter. All that matters for the purposes of this post is the density as a function of .↩︎

This is not restricted to being Gaussian, but for all intents and porpoises it is.↩︎

Countably additive set-valued function taking any value in ↩︎

-measurable↩︎

↩︎

If is also Gaussian then this is the same as them being independent↩︎

This is the technical term for this type of function because mathematicians weren’t hugged enough as children.↩︎

for a particular value of “any”↩︎

for a particular value of “ordinary”↩︎

Well enough for a statistician anyway. You can look up the details but if you desperately need to formalise it, you build an isomorphism between and and use that to construct . It’s not *wildly* difficult but it’s also not actually interesting except for mathturbatory reasons.↩︎

Non-Gaussian!↩︎

On more general spaces, the same construction still works. Just use whatever Fourier transform you have available.↩︎

or stochastic processes↩︎

Yes, it’s a stochastic process over some -algebra of sets in my definition. *Sometimes* people use as the spectral process and interpret the integrals as Lebesgue-Stieltjes integrals. All power to them! So cute! It makes literally no difference and truly I do not think it makes anything easier. By the time you’re like “you know what, I reckon Stieltjes integrals are the way to go” you’ve left “easier” a few miles back. You’ve still got to come up with an appropriate concept of an integral.↩︎

Also known as the Reproducing Kernel Hilbert Space even though it doesn’t actually have to be one. This is the space of all means. See the previous GP blog.↩︎

closure of the↩︎

In the previous post, I wrote this in terms of the inverse of the covariance operator. For a stationary operator, the covariance operator is (by the convolution theorem) and it should be pretty easy to convince yourself that ↩︎

ie one where we can represent functions using a Fourier series rather than a Fourier transform↩︎

ie one with an inner product↩︎

Bogachev’s Gaussian Measures book, Corollary 6.4.11 with some interpretation work to make it slightly more human-readable. I also added the minus sign he missed in the density.↩︎

Recall that this is the integral operator .↩︎

Because of condition 1 if it’s in one of them it’s in the other too!↩︎

Technically, they are an orthonormal basis in the closure of under the norm, but let’s just be friendly to ourselves and pretend have zero mean so these spaces are the same. The theorem is very explicit about what they are. If are the (-orthonormal) eigenfunctions corresponding to , then where is the spectral process associated with . Give or take, this the same thing I said in the main text.↩︎

After reading all of that, let me tell you that it simply does not matter even a little bit.↩︎

Yes - this is Mercer’s theorem again. The only difference is that we are assuming that the eigenfunctions are the same for each so they don’t need an index.↩︎

↩︎

You simply cannot make me care enough to prove that we can swap summation and expectation. Of course we bloody can. Also .↩︎

But not impossible. Kristin Kirchner and David Bolin have done some very nice work on this recently.↩︎

This is a stronger condition than the one in the paper, but it’s a) readily verifiable and b) domain independent.↩︎

This is legitimately quite hard to parse. You’ve got to back-transform their orthogonal basis to an orthogonal basis on , which is where those inverse square roots come from!↩︎

Remember because Daddy hates typing.↩︎

Through the magical power of WolframAlpha or, you know, my own ability to do simple Taylor expansions.↩︎

↩︎

↩︎

↩︎

The other KL. The spicy, secret KL. KL after dark. What Loève but a second-hand Karhunen?↩︎

This is particularly bold use of the inclusive voice here. You may or may not know. Nevertheless it is true.↩︎

Specifically, this kinda funky set of normalisation choices that statisticians love to make gives↩︎

If you think a bit about it, a periodic function on can be thought of as a process on a torus by joining the appropriate edges together!↩︎

We will see that this is not an issue, but you better bloody believe that our JASA paper just breezed the fuck past these considerations. Proof by citations that didn’t actually say what we needed them to say but were close enough for government work. Again, this is one of those situations where the thing we are doing is obviously valid, but the specifics (which are unimportant for our situation because we are going to send and in a way that’s *much* faster than ) are tedious and, I cannot stress this enough, completely unimportant in this context. But it’s a fucking blog and a blog has a type of fucking integrity that the Journal of the American Fucking Statistical Association does not even almost claim to have. I’ve had some red wine.↩︎

big↩︎

I cannot stress enough that we’re not bloody implementing this scheme, so it’s not even slightly important. Scan on, McDuff.↩︎

Fun fact. I worked in the same department as authors 2 and 4 for a while and they are both very lovely.↩︎

Check out either of the Bachmayr *et al.* papers if you’re interested.↩︎

Thanks Mr Jacobian!↩︎

I feel like I’ve typed enough, if you want to see the Jacobian read the appendices of the paper.↩︎

BibTeX citation:

```
@online{simpson2022,
  author = {Dan Simpson},
  editor = {},
  title = {Priors for the Parameters in a {Gaussian} Process},
  date = {2022-09-27},
  url = {https://dansblog.netlify.app/posts/2022-09-07-priors5/priors5.html},
  langid = {en}
}
```

For attribution, please cite this work as:

Dan Simpson. 2022. “Priors for the Parameters in a Gaussian
Process.” September 27, 2022. https://dansblog.netlify.app/posts/2022-09-07-priors5/priors5.html.

In about 2016, Almeling *et al.* published a paper that suggested aged Barbary macaques maintained interest in members of their own species while losing interest in novel non-social stimuli (eg toys or puzzles with food inside).

This is where Eliza—who knows a little something about monkeys—comes into frame: this did not gel with her experiences at all.

So Eliza (and Mark^{1} ^{2}, who also knows a little something about monkeys) decided to look into it.

A big motivation for studying macaques and other non-human primates is that they’re good models of humans. This means that if there was solid evidence of macaques becoming less interested in novel stimuli as they age (while maintaining interest in people), this could suggest an evolutionary reason for this (commonly observed) behaviour in humans.

So if this result is true, it could help us understand the psychology of humans as they age (and in particular, the learned vs evolved trade-off they are making).

There are a few things you can do when confronted with a result that contradicts your experience: you can complain about it on the Internet, you can mobilize a direct replication effort, or you can conduct your own experiments. Eliza and Mark opted for the third option, designing a *conceptual replication*.

Direct replications tell you more about the specific experiment that was conducted, but not necessarily more about the phenomenon under investigation. In a study involving aged monkeys^{4}, it’s difficult to imagine how a direct replication could take place.

On the other hand, a conceptual replication has a lot more flexibility. It allows you to probe the question in a more targeted manner, appropriate for incremental science. In this case, Eliza and Mark opted to study only the claim that the monkeys lose interest in novel stimuli as they age (paper here). They did not look into the social claim. They also used a slightly different species of macaque (*M. mulatta* rather than *M. sylvanus*). This is reasonable insofar as the aim is understanding macaques as a model for human behaviour.

The experiment used 243^{5} monkeys aged between 4 and 30 and gave them a novel puzzle task (opening a fancy tube with food in it). The puzzle was fitted with an activity tracker. Each monkey had two tries at the puzzle over two days, with access to the puzzle for around^{6} 20 minutes each time.

In order to match the original study’s analysis, Eliza and Mark divided the first two minutes into 15 second intervals and counted the number of intervals where the monkey interacted with the puzzle. They also measured the same thing over 20 minutes in order to see if there was a difference between short-term curiosity and more sustained exploration.

For each monkey, we have the following information:

- Monkey ID
- Age (4-30)
- Day (one or two)
- Number of active intervals in the first two minutes (0-8)
- Number of active intervals in the first twenty minutes (0-80)

The data and their analysis are freely^{7} available here.

```
library(tidyverse)

acti_data <- read_csv("activity_data.csv")

activity_2mins <- acti_data |>
  filter(obs < 9) |>
  group_by(subj_id, Day) |>
  summarize(total = sum(Activity),
            active_bins = sum(Activity > 0),
            age = min(age)) |>
  rename(monkey = subj_id, day = Day) |>
  ungroup()

activity_20mins <- acti_data |>
  filter(obs < 81) |>
  group_by(subj_id, Day) |>
  summarize(total = sum(Activity),
            active_bins = sum(Activity > 0),
            age = min(age)) |>
  rename(monkey = subj_id, day = Day) |>
  ungroup()

glimpse(activity_20mins)
```

```
Rows: 485
Columns: 5
$ monkey <dbl> 0, 0, 88, 88, 636, 636, 760, 760, 1257, 1257, 1607, 1607, …
$ day <dbl> 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2…
$ total <dbl> 9881, 6356, 15833, 4988, 572, 308, 1097, 2916, 4884, 2366,…
$ active_bins <int> 42, 34, 43, 19, 10, 4, 12, 23, 50, 33, 9, 11, 13, 7, 30, 3…
$ age <dbl> 29, 29, 29, 29, 28, 28, 30, 30, 27, 27, 27, 27, 27, 27, 26…
```

Eliza and Mark’s monkey data is an example of a fairly common type of experimental data, where the same subject is measured multiple times. It is useful to break the covariates down into three types: *grouping variables*, *group-level covariates*, and *individual-level covariates*.

*Grouping variables* indicate what *group* each observation is in. We will see a lot of different ways of defining groups as we go on, but a core idea is that observations within a group should be conceptually more similar to each other than observations in different groups. For Eliza and Mark, their grouping variable is `monkey`. This encodes the idea that different monkeys might have very different levels of curiosity, but the same monkey across two different days would probably have fairly similar levels of curiosity.

*Group-level covariates* are covariates that describe a feature of the *group* rather than the observation. In this example, `age` is a group-level covariate, because the monkeys are the same age at each observation.

*Individual-level covariates* are covariates that describe a feature that is specific to an observation. (The nomenclature here can be a bit confusing: the “individual” refers to individual observations, not to individual monkeys. All good naming conventions go to shit eventually.) The individual-level covariate is experiment day. This can be a bit harder to see than the other designations, but it’s a little clearer if you think of it as an indicator of whether this is the first time the monkey has seen the task or the second time. Viewed this way, it is very clearly a property of an observation rather than of a group.

Eliza and Mark’s monkey data is an example of a fairly general type of experimental data where subjects (our groups) are given the same task under different experimental conditions (described through individual-level covariates). As we will see, it’s not uncommon to have much more complex group definitions (that involve several grouping covariates) and larger sets of both group-level and individual-level covariates.

So how do we fit a model to this data?

The temptation with this sort of data is to fit a linear regression to it as a first model. In this case, we are using grouping, group-level, and individual-level covariates in the same way. Let’s suck it and see.

```
library(broom)
fit_lm <- lm(active_bins ~ age*factor(day) + factor(monkey), data = activity_2mins)
tidy(fit_lm)
```

So the first thing you will notice is that that is *a lot* of regression coefficients! There are 243 monkeys and 2 days, but only 485 observations. This isn’t enough data to reliably estimate all of these parameters. (Look at the standard errors for the monkey-related coefficients. They are huge!)

So what are we to do?

The problem is the monkeys. If we use `monkey` as a factor variable, we only have (at most) two observations of each factor level. This is simply not enough observations per level to estimate a different intercept for each monkey!

This type of model is often described as having *no pooling*, which indicates that there is no explicit dependence between the intercepts for each group (`monkey`). (There is some dependence between groups due to the group-level covariate `age`.)

Our first attempt at a regression model didn’t work particularly well, but that doesn’t mean we should give up^{8}. A second option is that we can assume that there is, fundamentally, no difference between monkeys. If all monkeys of the same age have similar amounts of interest in new puzzles, this would be a reasonable assumption. The best case scenario is that not accounting for differences between individual monkeys would still lead to approximately normal residuals, albeit with probably a larger residual variance.

This type of modelling assumption is called *complete pooling* as it pools the information between groups by treating them all as the same.

Let’s see what happens in this case!

```
fit_lm_pool <- lm(active_bins ~ age*factor(day), data = activity_2mins)
summary(fit_lm_pool)
```

```
Call:
lm(formula = active_bins ~ age * factor(day), data = activity_2mins)
Residuals:
Min 1Q Median 3Q Max
-4.5249 -1.5532 0.1415 1.6731 4.1884
Coefficients:
Estimate Std. Error t value Pr(>|t|)
(Intercept) 3.789718 0.344466 11.002 <2e-16 ***
age 0.003126 0.021696 0.144 0.885
factor(day)2 0.056112 0.488818 0.115 0.909
age:factor(day)2 0.025170 0.030759 0.818 0.414
---
Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
Residual standard error: 2.103 on 481 degrees of freedom
Multiple R-squared: 0.01365, Adjusted R-squared: 0.0075
F-statistic: 2.219 on 3 and 481 DF, p-value: 0.0851
```

On the up side, the regression runs and doesn’t have too many parameters!

The brave and the bold might even try to interpret the coefficients and say something like *there doesn’t seem to be a strong effect of age*. But there’s real danger in trying to interpret regression coefficients in the presence of a potential confounder (in this case, the monkey ID). And it’s particularly bad form to do this without ever looking at any sort of regression diagnostics. Linear regression is not a magic eight ball.

Let’s look at the diagnostic plots.

```
library(broom)

augment(fit_lm_pool) |>
  ggplot(aes(x = .fitted, y = active_bins - .fitted)) +
  geom_point() +
  geom_smooth(method = "lm", se = FALSE) +
  theme_classic()
```

```
augment(fit_lm_pool) |>
  ggplot(aes(sample = .std.resid)) +
  stat_qq() +
  geom_abline(slope = 1, intercept = 0, linetype = "dashed") +
  theme_classic()
```

There are certainly some patterns in those residuals (and some suggestion that the errors need a heavier tail for this model to make sense).

We are in a Goldilocks situation: no pooling results in a model that has too many independent parameters for the amount of data that we’ve got, while complete pooling has too few parameters to correctly account for the differences between the monkeys. So what is our perfectly tempered porridge^{9}?

The answer is to assume that each monkey has its own intercept, but that its intercept can only be *so far* from the overall intercept (that we would’ve gotten from complete pooling). There are a bunch of ways to realize this concept, but the classical method is to use a normal distribution.

In particular, if the th monkey has observations , , then we can write our model as

The effects of age and day and the data standard deviation () are just like they’d be in an ordinary linear regression model. Our modification comes in how we treat the .
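With some notation chosen here for illustration (the symbols are mine, not the post’s: $u_i$ is monkey $i$’s deviation from the overall intercept, and the age-by-day interaction is omitted for brevity), a model of this shape can be sketched as:

```latex
y_{ij} = \beta_0 + \beta_a \,\mathrm{age}_i + \beta_d \,\mathrm{day}_{ij}
         + u_i + \varepsilon_{ij},
\qquad \varepsilon_{ij} \sim N(0, \sigma^2).
```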

In a classical linear regression model, we would fit the s independently, perhaps with some weakly informative prior distribution. But we’ve already discussed that that won’t work.

Instead we will make the *exchangeable* rather than independent. Exchangeability relaxes the independence assumption to instead encode that we have no idea which of the intercepts will do what. That is, if we switch around the labels of our intercepts the prior should not change. There is a long and storied history of exchangeable models in statistics, but the short version, which is more than sufficient for our purposes, is that they usually^{10} take the form

In a regression context, we typically assume that for some and that will need their own priors.
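Spelled out (again with symbols of my own choosing: $\theta_j$ for the group intercepts), the exchangeable normal form is:

```latex
\theta_j \mid \mu, \tau \;\overset{\text{iid}}{\sim}\; N(\mu, \tau^2),
\qquad (\mu, \tau) \sim p(\mu, \tau),
```

so the $\theta_j$ are independent only *conditionally* on the shared $(\mu, \tau)$; marginally they are exchangeable but dependent.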

We can explore this difference mathematically. The regression model, which assumes independence of the , uses as the joint prior on . On the other hand, the exchangeable model, which forms the basis of multilevel models, assumes the joint prior for some prior on .

This might not seem like much of a change, but it can be quite profound. In both cases, the prior is saying that each is, with high probability, at most away from the overall mean . The difference is that the classical least squares formulation uses a fixed value of that needs to be specified by the modeller, while the exchangeable model lets adapt to the data.

This data adaptation is really nifty! It means that if the groups have similar means, they can borrow information from the other groups (via the narrowing of ) in order to improve their precision over an unpooled estimate. On the other hand, if there is a meaningful difference between the groups^{11}, this model can still represent that, unlike the unpooled model.
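The mechanics of this borrowing can be seen in the conjugate normal-normal case, where the conditional posterior mean of a group’s intercept is a precision-weighted average of that group’s sample mean and the overall mean. Here is a minimal numerical sketch (in Python rather than R, and `pooled_mean` is a made-up helper for illustration, not anything from the analysis code):

```python
def pooled_mean(ybar, n, mu, sigma, tau):
    """Posterior mean of a group intercept in a normal-normal model.

    ybar:  sample mean of the group's observations
    n:     number of observations in the group
    mu:    overall (population-level) mean
    sigma: data-level standard deviation
    tau:   standard deviation of the group intercepts
    """
    data_precision = n / sigma**2    # information from the group's own data
    prior_precision = 1.0 / tau**2   # information from the exchangeable prior
    weight = data_precision / (data_precision + prior_precision)
    # Precision-weighted compromise between the group mean and the overall mean
    return weight * ybar + (1.0 - weight) * mu

# With only two observations per monkey, the weight on the data is modest,
# so group estimates get pulled noticeably towards the overall mean.
print(pooled_mean(ybar=1.0, n=2, mu=0.0, sigma=0.4, tau=0.3))
```

A huge `tau` recovers the unpooled estimate, a tiny `tau` recovers complete pooling, and intermediate values interpolate between the two.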

In our context, however, we need a tiny bit more. We have a *group-level covariate* (specifically `age`) that we think is going to affect the group mean. So the model we want is
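In two levels (symbols mine, chosen for illustration), such a model looks like:

```latex
\begin{aligned}
y_{ij} &\sim N\!\left(\mu_i + \beta_d\,\mathrm{day}_{ij},\ \sigma^2\right)
  && \text{(observation level)}\\
\mu_i &= \beta_0 + \beta_a\,\mathrm{age}_i + u_i,
\qquad u_i \sim N(0, \tau^2)
  && \text{(group level)}
\end{aligned}
```

with priors still needed for the regression coefficients, $\sigma$, and $\tau$.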

In order to fully specify the model we need to set the four prior distributions.

This is an example of a *multilevel*^{12} *model*. The name comes from the data having multiple levels (in this case two: the observation level and the group level). Both levels have an appropriate model for their mean.

This mathematical representation does a good job in separating out the two different levels. However, there are a lot of other ways of writing multilevel models. An important example is the extended formula notation created^{13} by R’s `lme4` package. In their notation, we would write this model as

`formula <- active_bins_scaled ~ age_centred*day + (1 | monkey)`

The first bit of this formula is the same as the formula used in linear regression. The interesting bit is the `(1 | monkey)`. This is the way to tell R that the intercept (aka `1` in formula notation) is going to be grouped by `monkey` and we are going to put an exchangeable normal prior on it. For more complex models there are more complex variations on this theme, but for the moment we won’t go any further.

We need to set priors. The canny amongst you may have noticed that I did not set priors in the previous two examples. There are two reasons for this: firstly I didn’t feel like it, and secondly none but the most terrible prior distributions would have meaningfully changed the conclusions. This is, it turns out, one of the great truths when it comes to prior distributions: *they do not matter until they do*^{14}.

In particular, if you have a parameter that *directly* sees the data (eg it’s in the likelihood) and there is nothing weird going on^{15}, then the prior distribution will usually not do much as any prior will be quickly overwhelmed by the data.

The problem is that we have one parameter in our model () that does not directly see the data. Instead of directly telling us about an observation, it tells us about how different the *groups* of observations are. There is usually less information in the data about this type of parameter and, consequently, the prior distribution will be more important. This is especially true when you have more than one grouping variable, or when a variable only has a small number of groups.

So let’s pay some proper attention to the priors.

To begin with, let’s set priors on , , and (aka the data-level parameters). This is a *considerably* easier task if the data is scaled. Otherwise, you need to encode information about the usual scale^{16} of the data into your priors. Sometimes this is a sensible and easy thing to do, but usually it’s easier to simply scale the data. (A lot of software will simply scale your data for you, but it is *always* better to do it yourself!)

So let’s scale our data. We have two variables that need scaling: `age` (aka the covariate that isn’t categorical) and `active_bins` (aka the response). For age, we are going to want to measure it as either *years from the youngest monkey* or *years from the average monkey*. I think, in this situation, the first version could make a lot of sense, but we are going with the second. This allows us to interpret as the overall mean. Otherwise, would tell us about the overall average activity of 4 year old monkeys. We will use to estimate how much the activity changes, on average, keeping all other aspects constant, as the monkey ages.

On the other hand, we have no sensible baseline for activity, so deviation from the average seems like a sensible scaling. I also don’t know, *a priori*, how variable activity is going to be, so I might want to scale^{17} it by its standard deviation. In this case, I’m not going to do that because we have a sensible fixed^{18} upper limit (8), which I can scale by.

One important thing here is that if we scale the data by data-dependent quantities (like the minimum, the mean, or the standard deviation) we *must* keep track of this information. This is because *any* future data we try to predict with this model will need to be transformed *the same way using the same*^{19} *numbers*! This particularly has implications when you are doing things like test/training set validation or cross validation: in the first case, the test set needs to be scaled in the same way the training set was; in the second case, each cross validation training set needs to be scaled independently and that scaling needs to be used on the corresponding left-out data^{20}.

```
age_centre <- mean(activity_2mins$age)
age_scale <- diff(range(activity_2mins$age))/2
active_bins_centre <- 4

activity_2mins_scaled <- activity_2mins |>
  mutate(monkey = factor(monkey),
         day = factor(day),
         age_centred = (age - age_centre)/age_scale,
         active_bins_scaled = (active_bins - active_bins_centre)/4)

glimpse(activity_2mins_scaled)
```

```
Rows: 485
Columns: 7
$ monkey <fct> 0, 0, 88, 88, 636, 636, 760, 760, 1257, 1257, 1607,…
$ day <fct> 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, …
$ total <dbl> 495, 1003, 2642, 524, 199, 282, 363, 445, 96, 495, …
$ active_bins <int> 6, 6, 8, 6, 2, 3, 3, 4, 3, 8, 6, 5, 3, 3, 6, 5, 8, …
$ age <dbl> 29, 29, 29, 29, 28, 28, 30, 30, 27, 27, 27, 27, 27,…
$ age_centred <dbl> 1.1054718, 1.1054718, 1.1054718, 1.1054718, 1.02854…
$ active_bins_scaled <dbl> 0.50, 0.50, 1.00, 0.50, -0.50, -0.25, -0.25, 0.00, …
```

With our scaling completed, we can now start thinking about prior distributions. The trick with priors is to make them wide enough to cover all plausible values of a parameter without making them so wide that they put a whole bunch of weight on essentially silly values.

We know, for instance, that our unscaled activity will go between 0 and 8. That means that it’s unlikely for the mean of the scaled process to be much bigger than 3 or 4. These considerations, along with the fact that we have centred the data so the mean should be closer to zero, suggest that a prior should be appropriate for .

As we normalised our age data relative to the smallest age, we should think more carefully about the scaling of . Macaques live for 20-30^{21} years, so we need to think about, for instance, an ordinary aged macaque that would be 15 years older than the baseline. Thanks to our scaling, the largest change that we can have is around 1, which strongly suggests that if was too much larger than we are going to be in unreasonable territory. So let’s put a prior^{22} on and . For we can use a prior.

Similarly, the scaling of `active_bins` suggests that a prior would be sufficient for the data-level standard deviation .

That just leaves us with our choice of prior for the standard deviation of the intercept^{23} , . Thankfully, we considered this case in detail in the previous blog post. There I argued that a sensible prior for would be an exponential prior. To be quite honest with you, a half-normal or a half-t also would be fine. But I’m going to stick to my guns. For the scaling, again, it would be a touch surprising (given the scaling of the data) if the group means were more than 3 apart, so choosing in the exponential distribution should give a relatively weak prior without being so wide that we are putting prior mass on a bunch of values that we would never actually want to put prior mass on.

We can then fit the model with `brms`. In this case, I’m using the `cmdstanr` back end, because it’s fast and I like it.

To specify the model, we use the `lme4`-style formula notation discussed above.

To set the priors, we will use `brms`. Now, if you are Paul you might be able to remember how to set priors in `brms` without having to look it up, but I am sadly not Paul^{24}, so every time I need to set priors in `brms` I write the formula and use the convenient `get_prior` function.

```
library(cmdstanr)
library(brms)
get_prior(formula, activity_2mins_scaled)
```

From this, we can see that the default prior on is an improper flat prior, and the default prior on the intercept is a Student-t with 3 degrees of freedom centred at zero with scale 2.5. The same prior (restricted to positive numbers) is put on all of the standard deviation parameters. These default prior distributions are, to be honest, probably fine in this context^{25}, but it is good practice to always set your prior.

We do this as follows. (Note that `brms` uses Stan, which parameterises the normal distribution by its mean and *standard deviation*!)

```
priors <- prior(normal(0, 0.2), coef = "age_centred") +
  prior(normal(0, 0.2), coef = "age_centred:day2") +
  prior(normal(0, 1), coef = "day2") +
  prior(normal(0, 1), class = "sigma") +
  prior(exponential(1), class = sd) + # tau
  prior(normal(0, 1), class = "Intercept")

priors
```

So we have specified some priors using the power of *our thoughts*. But we should probably check to see if they are broadly sensible. A great thing about Bayesian modelling is that we are explicitly specifying our *a priori* (or pre-data) assumptions about the data generating process. That means that we can do a fast validation of our priors by simulating from them and checking that they’re not too wild.

There are lots of ways to do this, but the easiest^{26} is to use the `sample_prior = "only"` option in the `brm()` function.

```
prior_draws <- brm(formula,
                   data = activity_2mins_scaled,
                   prior = priors,
                   sample_prior = "only",
                   backend = "cmdstanr",
                   cores = 4,
                   refresh = 0)
```

`Start sampling`

```
Running MCMC with 4 parallel chains...
Chain 1 finished in 0.7 seconds.
Chain 2 finished in 0.7 seconds.
Chain 3 finished in 0.7 seconds.
Chain 4 finished in 0.7 seconds.
All 4 chains finished successfully.
Mean chain execution time: 0.7 seconds.
Total execution time: 0.9 seconds.
```

Now that we have samples from the prior distribution, we can assemble them to work out what our prior tells us we would, pre-data, predict for the number of active bins for a single monkey (in this case, a single monkey^{27} that is 10 years older than the baseline).

```
pred_data <- data.frame(age_centred = 10, day = 1, monkey = "88")

tibble(pred = brms::posterior_predict(prior_draws, newdata = pred_data)) |>
  ggplot(aes(pred)) +
  geom_histogram(aes(y = after_stat(density)), fill = "lightgrey") +
  geom_vline(xintercept = -1, linetype = "dashed") +
  geom_vline(xintercept = 1, linetype = "dashed") +
  xlim(c(-20, 20)) +
  theme_bw()
```

The vertical lines are (approximately) the minimum and maximum of the data. This^{28} suggests that the implied priors are definitely wider than our observed data, but they are not several orders of magnitude too wide. This is a good situation to be in: it gives enough room in the priors that we might be wrong with our specification while also not allowing for truly wild values of the parameters (and implied predictive distribution). One could even go so far as to say that the prior is weakly informative.

Let’s compare this to the default priors on the standard deviation parameters. (The default priors on the regression parameters are improper so we can’t simulate from them. So I replaced the improper prior with a much narrower prior. If you make the prior on the wider the prior predictive distribution also gets wider.)

```
priors_default <- prior(normal(0, 10), class = "b")

prior_draws_default <- brm(formula,
                           data = activity_2mins_scaled,
                           prior = priors_default,
                           sample_prior = "only",
                           backend = "cmdstanr",
                           cores = 4,
                           refresh = 0)
```

```
Running MCMC with 4 parallel chains...
Chain 1 finished in 0.6 seconds.
Chain 2 finished in 0.6 seconds.
Chain 3 finished in 0.6 seconds.
Chain 4 finished in 0.6 seconds.
All 4 chains finished successfully.
Mean chain execution time: 0.6 seconds.
Total execution time: 0.8 seconds.
```

```
tibble(pred = brms::posterior_predict(prior_draws_default, newdata = pred_data)) |>
  ggplot(aes(pred)) +
  geom_histogram(aes(y = after_stat(density)), fill = "lightgrey") +
  geom_vline(xintercept = -1, linetype = "dashed") +
  geom_vline(xintercept = 1, linetype = "dashed") +
  theme_bw()
```

This is considerably wider.

With all of that in hand, we can now fit the data. Hooray. This is done with the same command (minus the `sample_prior` bit).

```
posterior_draws <- brm(formula,
                       data = activity_2mins_scaled,
                       prior = priors,
                       backend = "cmdstanr",
                       cores = 4,
                       refresh = 0)
```

`Start sampling`

```
Running MCMC with 4 parallel chains...
Chain 1 finished in 1.7 seconds.
Chain 3 finished in 1.8 seconds.
Chain 2 finished in 1.8 seconds.
Chain 4 finished in 1.8 seconds.
All 4 chains finished successfully.
Mean chain execution time: 1.8 seconds.
Total execution time: 2.0 seconds.
```

`posterior_draws`

```
 Family: gaussian 
  Links: mu = identity; sigma = identity 
Formula: active_bins_scaled ~ age_centred * day + (1 | monkey) 
   Data: activity_2mins_scaled (Number of observations: 485) 
  Draws: 4 chains, each with iter = 2000; warmup = 1000; thin = 1;
         total post-warmup draws = 4000

Group-Level Effects: 
~monkey (Number of levels: 243) 
              Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sd(Intercept)     0.31      0.03     0.25     0.37 1.00     1070     1766

Population-Level Effects: 
                 Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
Intercept           -0.04      0.03    -0.11     0.02 1.00     4222     3171
age_centred          0.02      0.07    -0.11     0.14 1.00     3671     3150
day2                 0.10      0.04     0.03     0.18 1.00     8022     2911
age_centred:day2     0.07      0.07    -0.08     0.22 1.00     6170     2584

Family Specific Parameters: 
      Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sigma     0.43      0.02     0.39     0.47 1.00     1613     2430

Draws were sampled using sample(hmc). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).
```

There doesn’t seem to be much of an effect of age in this data.

If you’re curious, this matches well^{29} with the output of `lme4`, which is a nice sense check for simple models. Generally speaking, if they’re the same then they’re both fine. If they are different^{30}, then you’ve got to look deeper.

```
library(lme4)
fit_lme4 <- lmer(formula, activity_2mins_scaled)
fit_lme4
```

```
Linear mixed model fit by REML ['lmerMod']
Formula: active_bins_scaled ~ age_centred * day + (1 | monkey)
   Data: activity_2mins_scaled
REML criterion at convergence: 734.9096
Random effects:
 Groups   Name        Std.Dev.
 monkey   (Intercept) 0.3091  
 Residual             0.4253  
Number of obs: 485, groups:  monkey, 243
Fixed Effects:
     (Intercept)       age_centred              day2  age_centred:day2  
        -0.04114           0.01016           0.10507           0.08507  
```

We can also compare the fit using leave-one-out cross-validation. This is similar to AIC, but more directly interpretable. It is the average of $\log p(y_i \mid y_{-i}) = \log \int p(y_i \mid \theta)\, p(\theta \mid y_{-i})\, d\theta$, where $\theta$ is a vector of all of the parameters in the model. The notation $y_{-i}$ is the data *without* the $i$th observation. This average is sometimes called the *expected log predictive density* or elpd.
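
(If you want to see the mechanics without any of the `brms` machinery, here’s a toy Python sketch. The model, a normal mean with a conjugate normal prior, is purely hypothetical and has nothing to do with the monkeys, but it’s one of the few cases where the leave-one-out predictive density is available in closed form, so we can compute the elpd exactly by refitting on each $y_{-i}$.)

```python
import numpy as np

def loo_elpd_normal(y, sigma=1.0, tau=10.0):
    """Exact leave-one-out elpd for a toy conjugate model:
    y_i ~ Normal(theta, sigma^2), with prior theta ~ Normal(0, tau^2).
    For each i we condition on y_{-i} and evaluate log p(y_i | y_{-i}),
    which has a closed form for this prior/likelihood pair."""
    y = np.asarray(y, dtype=float)
    lpd = np.empty(len(y))
    for i in range(len(y)):
        y_rest = np.delete(y, i)
        # Conjugate posterior for theta given y_{-i}
        post_prec = 1.0 / tau**2 + len(y_rest) / sigma**2
        post_mean = (y_rest.sum() / sigma**2) / post_prec
        # Posterior predictive for y_i: Normal(post_mean, 1/post_prec + sigma^2)
        pred_var = 1.0 / post_prec + sigma**2
        lpd[i] = -0.5 * np.log(2 * np.pi * pred_var) \
            - 0.5 * (y[i] - post_mean) ** 2 / pred_var
    return lpd.mean()
```

For a real model you can’t afford to refit $n$ times, which is why the `loo` package approximates this with Pareto-smoothed importance sampling instead.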

To compare it with the two linear regression models, I need to fit them in `brms`. I will use a normal(0, 1) prior for the monkey intercepts and the same priors as the previous model for the other parameters.

```
priors_lm <- prior(normal(0, 1), class = "b") +
  prior(normal(0, 0.2), coef = "age_centred") +
  prior(normal(0, 0.2), coef = "age_centred:day2") +
  prior(normal(0, 1), coef = "day2") +
  prior(normal(0, 1), class = "Intercept") +
  prior(normal(0, 1), class = "sigma")
posterior_nopool <- brm(
  active_bins_scaled ~ age_centred * day + monkey,
  data = activity_2mins_scaled,
  prior = priors_lm,
  backend = "cmdstanr",
  cores = 4,
  refresh = 0)
```

```
Running MCMC with 4 parallel chains...
Chain 1 finished in 4.5 seconds.
Chain 3 finished in 4.5 seconds.
Chain 2 finished in 4.5 seconds.
Chain 4 finished in 4.5 seconds.
All 4 chains finished successfully.
Mean chain execution time: 4.5 seconds.
Total execution time: 4.7 seconds.
```

```
posterior_pool <- brm(
  active_bins_scaled ~ age_centred * day,
  data = activity_2mins_scaled,
  prior = priors_lm,
  backend = "cmdstanr",
  cores = 4,
  refresh = 0)
```

```
Running MCMC with 4 parallel chains...
Chain 1 finished in 0.1 seconds.
Chain 2 finished in 0.1 seconds.
Chain 3 finished in 0.1 seconds.
Chain 4 finished in 0.1 seconds.
All 4 chains finished successfully.
Mean chain execution time: 0.1 seconds.
Total execution time: 0.3 seconds.
```

We can now use the `loo_compare` function to compare the models. By default, the best model is listed first and the other models are listed below it with the difference in elpd values given. To do this, we need to tell `brms` to compute the `loo` criterion using the `add_criterion` function.

`posterior_draws <- add_criterion(posterior_draws, "loo")`

```
Warning: Found 2 observations with a pareto_k > 0.7 in model 'posterior_draws'.
It is recommended to set 'moment_match = TRUE' in order to perform moment
matching for problematic observations.
```

`posterior_nopool <- add_criterion(posterior_nopool, "loo")`

```
Warning: Found 63 observations with a pareto_k > 0.7 in model
'posterior_nopool'. It is recommended to set 'moment_match = TRUE' in order to
perform moment matching for problematic observations.
```

```
posterior_pool <- add_criterion(posterior_pool, "loo")
loo_compare(posterior_draws, posterior_nopool, posterior_pool)
```

```
                 elpd_diff se_diff
posterior_draws    0.0       0.0  
posterior_pool   -29.0       7.4  
posterior_nopool -53.3       9.0  
```

There are some warnings there suggesting that we could recompute these using a slower method, but for the purposes of today I’m not going to do that and I shall declare that the multilevel model performs *far better* than the other two models.

Of course, we would be fools to just assume that because we fit a model and compared it to some other models, the model is a good representation of the data. To do that, we need to look at some posterior checks.

The easiest thing to look at is the predictions themselves.

```
fitted <- activity_2mins_scaled |>
  cbind(t(posterior_predict(posterior_draws, ndraws = 200))) |>
  pivot_longer(8:207, names_to = "draw", values_to = "fitted")
day_labs <- c("Day 1", "Day 2")
names(day_labs) <- c("1", "2")
violin_plot <- fitted |>
  ggplot(aes(x = age, y = 4 * fitted + active_bins_centre, group = age)) +
  geom_violin(colour = "lightgrey") +
  geom_point(aes(y = active_bins), colour = "red") +
  facet_wrap(~day, labeller = labeller(day = day_labs)) +
  theme_bw()
violin_plot
```

That appears to be a reasonably good fit, although it’s possible that the prediction intervals are a bit wide. We can also look at the plot of the posterior residuals vs the fitted values. Here the fitted values are the mean of the posterior predictive distribution.

Next, let’s check for evidence of non-linearity in `age`.

```
plot_data <- activity_2mins_scaled |>
  mutate(fitted_mean = colMeans(posterior_epred(posterior_draws, ndraws = 200)))
age_plot <- plot_data |>
  ggplot(aes(x = age, y = active_bins_scaled - fitted_mean)) +
  geom_point() +
  theme_bw()
age_plot
```

There doesn’t seem to be any obvious evidence of non-linearity in the residuals, which suggests the linear model for age was sufficient.

We can also check the distributional assumption^{31} that the residuals have a Gaussian distribution. We can check this with a qq-plot. Here we are using the posterior mean to define our residuals.

We can look at the qq-plot to see how we’re doing with normality.

```
distribution_plot <- plot_data |>
  ggplot(aes(sample = (active_bins_scaled - fitted_mean) /
               sd(active_bins_scaled - fitted_mean))) +
  stat_qq() +
  geom_abline(slope = 1, intercept = 0, linetype = "dashed") +
  theme_classic()
distribution_plot
```

That’s not too bad. A bit of a deviation from normality in the tails but nothing that would make me weep. It could well be an artifact of how I defined and normalised the residuals.

We can also look at the so-called k-hat plot, which can be useful for finding high-leverage observations in general models.

```
loo_posterior <- LOO(posterior_draws) #warnings suppressed
loo_posterior
```

```
Computed from 4000 by 485 log-likelihood matrix

         Estimate   SE
elpd_loo   -349.8 12.4
p_loo       117.8  5.2
looic       699.7 24.7
------
Monte Carlo SE of elpd_loo is NA.

Pareto k diagnostic values:
                         Count Pct.    Min. n_eff
(-Inf, 0.5]   (good)     418   86.2%   902       
 (0.5, 0.7]   (ok)        65   13.4%   443       
   (0.7, 1]   (bad)        1    0.2%   272       
   (1, Inf)   (very bad)   1    0.2%   59        
See help('pareto-k-diagnostic') for details.
```

`plot(loo_posterior)`

This suggests that observations 393, 394 are potentially high leverage and we should check them more carefully. I won’t be doing that today.

Finally, let’s look at the residuals vs the fitted values. This is a commonly used diagnostic plot in linear regression and it can be very useful for visually detecting non-linear patterns and heteroskedasticity in the residuals. So let’s make the plot^{32}.

```
problem_plot <- plot_data |>
  ggplot(aes(x = fitted_mean, y = active_bins_scaled - fitted_mean)) +
  geom_point() +
  geom_smooth(method = "lm", se = FALSE, linetype = "dashed", colour = "blue") +
  facet_wrap(~day) +
  theme_bw() +
  theme(legend.position = "none") +
  xlim(c(-1, 1)) +
  ylim(c(-1, 1))
problem_plot
```

Hmmmm. That’s not *excellent*. The stripes are related to the 8 distinct values the response can take, but there is definitely a trend in the residuals. In particular, we are under-predicting small values and over-predicting large values. *There is something here and we will look into it*!

The thing is, multilevel models are notorious for having patterns that are essentially a product of the data design and not of any type of statistical misspecification. In a really great paper that you should all read, Adam Loy, Heike Hofmann, and Di Cook talk extensively about the challenges with interpreting diagnostic plots for linear mixed effects models^{33}.

I’m not going to fully follow their recommendations, mostly because I’m too lazy^{34} to write a for loop, but I am going to appropriate the guts of their idea.

They note that strange patterns can occur in diagnostic plots *even for correctly specified models*. Moreover, we simply do not know what these patterns will be. It’s too complex a function of the design, the structure, the data, and the potential misspecification. That sounds bad, but they note that *we don’t need to know what pattern to expect*. Why not? Because we can simulate it!

So this is the idea: Let’s simulate some fake^{35} data from a correctly specified model that otherwise matches with our data. We can then compare the diagnostic plots from the fake data with diagnostic plots from the real data and see if the patterns are meaningfully different.

In order to do this, we should have a method to construct *multiple* fake data sets. Why? Well a plot is nothing but another test statistic and we *must* take this variability into account.

(That said, do what I say, not what I do. This is a blog. I’m not going to code well enough to make this clean and straightforward, so I’m just going to do one.)

There is an entire theory of *visual inference* that uses these lineups of diagnostic plots, where one uses the real data and the rest use realisations of the null data, which is really quite interesting and *well* beyond the scope of this post. But if you want to know more, read the Loy, Hofmann, and Cook paper!

The first thing that we need to do is to work out how to simulate fake data from a correctly specified model with the same structure. Following the Loy et al. paper, I’m going to do a simple parametric bootstrap, where I take the posterior medians of the fitted distribution and simulate data from them.

That said, there are a bunch of other options. Specifically, we have a whole bag of samples from our posterior distribution and it would be possible to use those to select parameter values^{36} for our simulation.

So let’s make some fake data and fit the model to it!

```
monkey_effect <- tibble(monkey = unique(activity_2mins_scaled$monkey),
                        monkey_effect = rnorm(243, 0, 0.31))
data_fake <- activity_2mins_scaled |>
  left_join(monkey_effect, by = "monkey") |>
  mutate(active_bins_scaled = rnorm(length(age_centred),
    mean = -0.04 + 0.01 * age_centred + monkey_effect +
      if_else(day == "2", 0.1 + 0.085 * age_centred, 0.0),
    sd = 0.43))
posterior_draws_fake <- brm(formula,
                            data = data_fake,
                            prior = priors,
                            backend = "cmdstanr",
                            cores = 4,
                            refresh = 0)
```

```
Running MCMC with 4 parallel chains...
Chain 1 finished in 1.6 seconds.
Chain 2 finished in 1.6 seconds.
Chain 3 finished in 1.6 seconds.
Chain 4 finished in 1.6 seconds.
All 4 chains finished successfully.
Mean chain execution time: 1.6 seconds.
Total execution time: 1.8 seconds.
```

First up, let’s look at the violin plot.

```
library(cowplot)
fitted_fake <- data_fake |>
  cbind(t(posterior_predict(posterior_draws_fake, ndraws = 200))) |>
  pivot_longer(8:207, names_to = "draw", values_to = "fitted")
day_labs <- c("Day 1", "Day 2")
names(day_labs) <- c("1", "2")
violin_fake <- fitted_fake |>
  ggplot(aes(x = age, y = 4 * fitted + active_bins_centre, group = age)) +
  geom_violin(colour = "lightgrey") +
  geom_point(aes(y = active_bins), colour = "red") +
  facet_wrap(~day, labeller = labeller(day = day_labs)) +
  theme_bw()
plot_grid(violin_plot, violin_fake, labels = c("Real", "Fake"))
```

That’s very similar to our data plot.

Next up, we will look at the residuals ordered by age.

```
plot_data_fake <- data_fake |>
  mutate(fitted_mean = colMeans(posterior_epred(posterior_draws_fake, ndraws = 200)))
age_fake <- plot_data_fake |>
  ggplot(aes(x = age, y = active_bins_scaled - fitted_mean)) +
  geom_point() +
  theme_bw()
plot_grid(age_plot, age_fake, labels = c("Real", "Fake"))
```

Fabulous!

Now let’s check the distributional assumption on the residuals!

```
distribution_fake <- plot_data_fake |>
  ggplot(aes(sample = (active_bins_scaled - fitted_mean) /
               sd(active_bins_scaled - fitted_mean))) +
  stat_qq() +
  geom_abline(slope = 1, intercept = 0, linetype = "dashed") +
  theme_classic()
plot_grid(distribution_plot, distribution_fake, labels = c("Real", "Fake"))
```

Excellent!

Finally, we can look at the k-hat plot. Because I’m lazy, I’m not going to put them side by side. You can scroll.

`loo_fake <- LOO(posterior_draws_fake)`

```
Warning: Found 4 observations with a pareto_k > 0.7 in model
'posterior_draws_fake'. It is recommended to set 'moment_match = TRUE' in order
to perform moment matching for problematic observations.
```

`loo_fake`

```
Computed from 4000 by 485 log-likelihood matrix

         Estimate   SE
elpd_loo   -372.1 14.9
p_loo       115.4  6.1
looic       744.2 29.7
------
Monte Carlo SE of elpd_loo is NA.

Pareto k diagnostic values:
                         Count Pct.    Min. n_eff
(-Inf, 0.5]   (good)     422   87.0%   579       
 (0.5, 0.7]   (ok)        59   12.2%   220       
   (0.7, 1]   (bad)        4    0.8%   118       
   (1, Inf)   (very bad)   0    0.0%   <NA>      
See help('pareto-k-diagnostic') for details.
```

`plot(loo_fake)`

And look: we get some extreme values. (Depending on the run we get more or less.) This suggests that while it would be useful to look at the data points flagged by the k-hat statistic, it may just be sampling variation.

All of this suggests our model assumptions are not being grossly violated. All except for that residual vs fitted values plot…

Now let’s look at our residual vs fitted plot.

```
problem_fake <- plot_data_fake |>
  ggplot(aes(x = fitted_mean, y = active_bins_scaled - fitted_mean)) +
  geom_point() +
  geom_smooth(method = "lm", se = FALSE, linetype = "dashed", colour = "blue") +
  facet_wrap(~day) +
  theme_bw() +
  theme(legend.position = "none") +
  xlim(c(-1, 1)) +
  ylim(c(-1, 1))
plot_grid(problem_plot, problem_fake, labels = c("Real", "Fake"))
```

And what do you know! They look the same. (Well, minus the discretisation artefacts.)

Great question! It turns out that this is one of those cases where our intuition from linear models *does not* transfer over to multilevel models.

We can actually reason this out by thinking about a model where we have no covariates.

If we have no pooling then the observations for every monkey are, essentially, averaged to get our estimate of $\mu_j$. If we repeat this, we will find that our $\hat{\mu}_j$ are basically^{37} unbiased and the corresponding residuals will have mean zero.

But that’s not what happens when we have partial pooling. When we have partial pooling we are *combining* our naive average^{38} with the global average in a way that accounts for the size of the group relative to other groups as well as the within-group variability relative to the between-group variability.

The short version is that there is some magical number $\lambda_j \in (0, 1)$, which depends on $n_j$, $\sigma$, and $\tau$, such that $$\hat{\mu}_j = \lambda_j \bar{y}_j + (1 - \lambda_j)\bar{y}.$$ Because of this, the residuals are suddenly *not* going to have mean zero.

In fact, if we think about it a bit more, we will realise that the model will drag extreme groups to the centre, which accounts for the positive slope in the residuals vs the fitted values.

The slope in this example is quite extreme because the groups are very small (only one or two individuals). But it is a general phenomenon and it’s discussed extensively in Chapter 7 of Jim Hodges’ excellent book. His suggestion is that there isn’t really a good, general way to remove the trend. But that doesn’t mean the plot is useless. It is still able to pinpoint outliers and heteroskedasticity. You’ve just got to tilt your head.
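
If you don’t believe me, here’s a little Python simulation (not the monkey data; fake groups of size one or two with made-up values of the variances) showing that a *correctly specified* partial pooling estimate still produces a positive slope in the residuals vs fitted plot. The shrinkage factor $\lambda_j = (n_j/\sigma^2)/(n_j/\sigma^2 + 1/\tau^2)$ used below is the standard one for a Gaussian random intercept model with a known global mean of zero.

```python
import numpy as np

rng = np.random.default_rng(1)

# Simulate a correctly specified random-intercept model with tiny groups
# (1-2 observations each), mimicking the structure of the monkey data.
n_groups, sigma, tau = 250, 0.4, 0.3
group_sizes = rng.integers(1, 3, size=n_groups)
mu_j = rng.normal(0.0, tau, size=n_groups)

fitted, resid = [], []
for j in range(n_groups):
    y_j = mu_j[j] + rng.normal(0.0, sigma, size=group_sizes[j])
    # Partial pooling: shrink the group mean towards the global mean (0 here)
    # by lambda_j = (n_j / sigma^2) / (n_j / sigma^2 + 1 / tau^2).
    lam = (group_sizes[j] / sigma**2) / (group_sizes[j] / sigma**2 + 1 / tau**2)
    mu_hat = lam * y_j.mean()
    fitted.extend([mu_hat] * group_sizes[j])
    resid.extend(y_j - mu_hat)

fitted, resid = np.array(fitted), np.array(resid)
# The residuals trend upwards with the fitted values even though the model
# is correct: extreme groups are dragged towards the centre.
slope = np.polyfit(fitted, resid, 1)[0]
```

The slope comes out clearly positive, exactly the pattern in the residual vs fitted plot above, with no misspecification anywhere in sight.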

But for the purposes of today we can notice that there don’t seem to be any extreme outliers so everything is probably ok.

So what have we done? Well we’ve gone through the process of fitting and scrutinising a simple Bayesian multilevel model. We’ve talked about some of the challenges associated with graphical diagnostics for structured data. And we’ve all^{39} learnt something about the residual-vs-fitted plot for a multilevel model.

Most importantly, we’ve all learnt the value of using fake data simulated from the posterior model to help us understand our diagnostics.

There is more to the scientific story here. It turns out that while there is no effect over 2 minutes, there is a slight effect over 20 minutes. So the conceptual replication failed, but the study still found some interesting things.

Of course, I’ve ignored one big elephant in the room: the data was discrete. In the end, our distributional diagnostics didn’t throw up any massive red flags, but nevertheless it could be an interesting exercise to see what happens if we use a more problem-adapted likelihood.

Last, and certainly not least, I barely scratched the surface^{40} of the Loy, Hofmann, and Cook paper. Anyone who is interested in fitting Gaussian multilevel models should definitely give it a read.

Mark insisted that I link to his google scholar rather than his website. He’s cute that way.↩︎

Mark wants me to tell you that he’s not vain he’s just moving. Sure Jan.↩︎

I know that marmosets suffer from lesbian bed death, but I’m told that a marmoset is not a macaque, which in turn is not a macaw. Ecology is fascinating.↩︎

A real problem in the world is that there aren’t enough monkeys for animal research at the best of times. Once you need aged monkeys, it’s an even smaller population. Non-human primate research is *hard*.↩︎

Actually 244, but one of them turned out to be blind. Animal research is a journey.↩︎

It turns out that some of the monkeys didn’t want to give up the puzzle after 20 minutes. One held out for 72 minutes before the data collection ended. Cheeky monkeys.↩︎

Did Mark make me do unspeakable, degrading, borderline immoral things to get the data? No. It’s open source. Truly the first time I’ve been disappointed that something was open source.↩︎

If statisticians abandoned linear regression we would have nothing left. We would be desiccated husks propping up the bar at 3am talking about how we used to do loads of lines in the 80s.↩︎

Our perfect amount of pool? I don’t know how metaphors work↩︎

They *always* take this form if there is a countable collection of exchangeable random variables. For a finite set there are a few more options. But no one talks about those.↩︎

monkeys↩︎

Also known as a mixed effects or a linear mixed effects model.↩︎

There are *many* other ways to represent Gaussian multilevel models. My former colleague Emi Tanaka and Francis Hui wrote a great paper on this topic.↩︎

Some particularly bold and foolish people take this to mean that priors aren’t important. They usually get their arse handed to them the moment they try to fit an even mildly complex model.↩︎

A non-exhaustive set of weird things: categorical regressors with a rare category, tail parameters, mixture models↩︎

There are situations where this is not true. For instance if you have a log or logit link function you can put reasonable bounds on your coefficients regardless of the scaling of your data. That said, the computational procedures *always* appreciate a bit of scaling. If there’s one thing that computers hate more than big numbers it’s small numbers.↩︎

Of course, we know that there are only 8 fifteen second intervals in two minutes, so we could use this information to make a data-independent scaling. To be brutally francis with you, that’s what you should probably do in this situation, but I’m trying to be pedagogical so let’s at least think about scaling it by the standard deviation.↩︎

Fixed scaling is always easier than data-dependent scaling↩︎

A real trick for young players is scaling new data by the mean and standard deviation of the new data rather than the old data. That’s a very subtle bug that can be *very* hard to squash.↩︎

The `tidymodels` package in R is a great example of an ecosystem that does this properly. Max and Julia’s book on using `tidymodels` is very excellent and well worth a read.↩︎

Of all of the things in this post, this has been the most aggressively fact checked one↩︎

In prior width and on grindr, you should always expect that he’s rounding up.↩︎

In some places, we would call this a random effect.↩︎

He is very lovely. Many people would prefer that I was him.↩︎

It’s possible that the prior on might be too wide. If we were doing a logistic regression, these priors would definitely be too wide. And if we had a lot of different random terms (eg if we had lots of different species or lots of different labs) then they would also probably be too wide. But they are better than not having priors.↩︎

Not the most computationally efficient, but the easiest. Also because it’s the same code we will later use to fit the model, we are evaluating the priors that are actually used and not the ones that we think we’re using.↩︎

It’s number 88, but because our prior is exchangeable it does not matter which monkey we do this for!↩︎

I also checked different values of `age` as well as looking at the posterior mean (via `posterior_epred`) and the conclusions stay the same.↩︎

The numbers will never be exactly equal, but they are of similar orders of magnitude.↩︎

Or if you get some sort of error or warning from `lme4`↩︎

So there’s a wrinkle here. Technically, all of the residuals have different variances, which is annoying. You typically studentise them using the leverage scores, but this is a touch trickier for multilevel models. Chapter 7 of Jim Hodges’s excellent book contains a really good discussion.↩︎

Once again, we are not studentizing the residuals. I’m sorry.↩︎

Another name for a multilevel model with a Gaussian response↩︎

Also because all of my data plots are gonna be stripey as hell, and that kinda destroys the point of visual inference.↩︎

They call it *null data*.↩︎

Note that I am *not* using values of ! I will simulate those from the normal distribution to ensure correct model specification. For the same reason, I am not using a residual bootstrap. The aim here is not to assess uncertainty so much as it is to ↩︎

This is a bit more complex when you’re Bayesian, but the intuition still holds. The difference is that now it is asymptotic↩︎

This is the average of all observations in group j. ↩︎

I mean, some of us knew this. Personally, I only remembered after I saw it and swore a bit.↩︎

In particular, they have an interesting discussion on assessing the distributional assumption for .↩︎

BibTeX citation:

```
@online{simpson2022,
  author = {Dan Simpson},
  editor = {},
  title = {A First Look at Multilevel Regression; or {Everybody’s} Got
    Something to Hide Except Me and My Macaques},
  date = {2022-09-06},
  url = {https://dansblog.netlify.app/2022-09-04-everybodys-got-something-to-hide-except-me-and-my-monkey.html},
  langid = {en}
}
```

For attribution, please cite this work as:

Dan Simpson. 2022. “A First Look at Multilevel Regression; or
Everybody’s Got Something to Hide Except Me and My Macaques.”
September 6, 2022. https://dansblog.netlify.app/2022-09-04-everybodys-got-something-to-hide-except-me-and-my-monkey.html.

I am suddenly^{1} in the mood to write some more on this^{2} topic.

The thing is, so far I’ve only really talked about methods for setting prior distributions that I don’t particularly care for. Fuck that. Let’s talk about things I like. There is enough negative energy^{3} in the world.

So let’s talk about priors. But the good stuff. The aim is to give my answer to the question “how should you set a prior distribution?”.

You don’t. No one does. They’re not real.

Parameters are polite fictions that we use to get through the day. They’re our weapons of mass destruction. They’re the magazines we only bought for the articles. They are our girlfriends who live in Canada^{4}.

One way we can see this is to ask ourselves a simple^{5} question:

The answer^{6} ^{7} would be two.

But let me ask a different question. How many parameters are there in this model^{8}?

One answer to this question would be . In this interpretation of the question everything in the model that isn’t directly observed is a parameter.

But there is another view.

Mathematically, these two models are equivalent. That is, if you marginalise^{9} out the random effects you get *exactly* the negative binomial distribution with mean $\mu$ and variance $\mu(1 + a\mu)$.

So maybe there are two parameters.
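
If you don’t trust the algebra, you can check the equivalence by simulation. This Python sketch assumes the standard parameterisation (an assumption on my part, since I’m being lazy about notation): $u_i \sim \text{Gamma}(\phi, \phi)$ has mean one, and $y_i \mid u_i \sim \text{Poisson}(\mu u_i)$ marginalises to a negative binomial with mean $\mu$ and variance $\mu + \mu^2/\phi$.

```python
import numpy as np

rng = np.random.default_rng(2)

# Poisson-Gamma mixture: u_i ~ Gamma(phi, phi) (mean 1),
# y_i | u_i ~ Poisson(mu * u_i). The marginal is negative binomial
# with mean mu and variance mu + mu^2 / phi.
mu, phi, n = 5.0, 2.0, 200_000
u = rng.gamma(shape=phi, scale=1.0 / phi, size=n)
y = rng.poisson(mu * u)

sim_mean, sim_var = y.mean(), y.var()
# Theoretical negative binomial moments
theory_mean = mu                 # 5.0
theory_var = mu + mu**2 / phi    # 17.5
```

The simulated mean and variance land right on top of the negative binomial values, which is as close to a proof as a blog post needs.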

Does it make a difference? Sometimes. For instance, if you were following ordinary practice in Bayesian machine learning, you would (approximately) marginalise out the random effects in the first model, but in the second model you’d probably treat them as tuning hyper-parameters^{10} and optimise^{11} over them.

Moreover, in the second model we can ask *what other priors could we put on the random effects?* There is no equivalent question for the first model. This could be useful, for instance, if we believe that the overdispersion may differ among population groups. It is considerably easier to extend the random effects formulation into a multilevel model.

Ok. So it doesn’t really matter too much. It really depends on what you’re going to do with the model when you’re breaking your model into *things that we need to set priors for* and *things where the priors are a structural part of the model*.

There are a lot of ways to set prior distributions. I’ve covered some in previous posts and there are certainly more. But today I’m going to focus on one constructive method that I’m particularly fond of: penalised complexity priors.

These priors fall out from a certain way of seeing parameters. The idea is that some parameters in a model function as *flexibility parameters*. These naturally have a base value, which corresponds to the simplest model that they index. I’ll refer to the distribution you get when the parameter takes its base value as the *base model*.

**Example 1 (Overdispersion of a negative binomial) **The negative binomial distribution has two parameters: a mean $\mu$ and an overdispersion parameter $a$, so that the variance is $\mu(1 + a\mu)$. The mean parameter is *not* a flexibility parameter. Conceptually, changing the mean^{12} does not make a distribution more or less complex, it simply shuttles it around.

On the other hand, the overdispersion parameter *is* a flexibility parameter. Its special value is $a = 0$, which corresponds to a Poisson distribution, which is the base model for the negative binomial distribution.

**Example 2 (Student-t degrees of freedom) **The three parameter student-t distribution can be parameterised by its mean $\mu$, its standard deviation $\sigma$ (assuming $\nu > 2$!), and its degrees of freedom $\nu$. This has mean $\mu$ and variance $\sigma^2$. The slightly strange parameterisation and the restriction to $\nu > 2$ is useful because it lets us specify a prior on the *variance* itself and not some parameter that is the variance divided by some function^{13} of $\nu$.

The natural base model here is the Gaussian distribution, which corresponds to $\nu = \infty$.
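
(A quick numpy check of that parameterisation, in case you don’t believe me. The specific scaling $\sigma\sqrt{(\nu - 2)/\nu}$ is my assumption about how to standardise the variance; it follows from the standard t having variance $\nu/(\nu - 2)$.)

```python
import numpy as np

rng = np.random.default_rng(3)

# A standard student-t with nu degrees of freedom has variance nu / (nu - 2)
# (for nu > 2), so scaling by sigma * sqrt((nu - 2) / nu) gives a
# distribution whose variance is sigma^2 regardless of nu.
nu, sigma = 5.0, 1.5
scale = sigma * np.sqrt((nu - 2.0) / nu)
draws = scale * rng.standard_t(nu, size=500_000)
sample_var = draws.var()  # should be close to sigma^2 = 2.25
```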

**Example 3 (Variance of a Gaussian random effect) **A Gaussian distribution has two parameters: a mean $\mu$ and a standard deviation $\sigma$. Once again, $\mu$ is not a flexibility parameter, but in some circumstances $\sigma$ can be.

To see this, imagine that we have a simple random intercept model $$y_{ij} = \mu + u_j + \epsilon_{ij}, \qquad u_j \sim N(0, \tau^2), \qquad \epsilon_{ij} \sim N(0, \sigma^2).$$ In this case, we don’t really view $\sigma$ as a flexibility parameter, but $\tau$ is. Why the distinction? Well let’s think about what happens at the special value $0$.

When $\sigma = 0$ we are saying that there is no variability in the data if we know the corresponding $u_j$. This is, frankly, quite weird and it’s not necessarily a base model we would believe^{14} in.

On the other hand, if $\tau = 0$, then we are saying that all of the groups have the same mean. This is a useful and interesting base model that could absolutely happen in most data. So we say that while $\sigma$ isn’t necessarily a flexibility parameter in the model, $\tau$ definitely is.

In this case the base model is the degenerate distribution^{15} where the mean of each group is equal to $\mu$.

The third example shows that the idea of a flexibility parameter is deeply contextual. Once again, we run into the idea that Statistical Arianism^{16} is bad. *Parameters and their prior distributions can only be fully understood if you know their context within the entire model.*

Now that we have the concept of a flexibility parameter, let’s think about how we should use it. In particular, we should ask exactly what we want our prior to do. In the paper we listed 8 things that we want the prior to do:

- The prior should contain information^{17}^{18}^{19}
- The prior should be aware of model structure
- If we move our model to a new application, it should be clear how we can change the information contained in our prior. We can do this by *explicitly* including specific information in the prior.
- The prior should limit^{20} the flexibility of an overparameterised model
- Restrictions of the prior to identifiable sub-manifolds^{21} of the parameter space should be sensible.
- The prior should be specified to control what a parameter *does* in the context^{22} of the model (rather than its numerical value)
- The prior should be computationally^{23} feasible
- The prior should perform well^{24}.

These desiderata are *aspirational* and I in no way claim that we successfully satisfied them. But we tried. And we came up with a pretty useful proposal.

The idea is simple: if our model has a flexibility parameter we should put a prior on it that *penalises the complexity* of the model. That is, we want most of the prior mass to be near^{25} the base value.

In practice, we try to do this by penalising the complexity of each *component* of a model. For instance, consider the following model for a flexible regression: $$y_i = \beta_0 + f_1(x_{i1}) + f_2(x_{i2}) + \epsilon_i,$$ where $f_1$ and $f_2$ are smoothing splines. The exact definition^{26} of a smoothing spline that we are using is not wildly important, but it is specified^{27} by a smoothing parameter $\tau$, and when $\tau = 0$ we get our base model (a function that is equal to zero everywhere). This model has two components ($f_1$ and $f_2$) and they each have one smoothing parameter ($\tau_1$, with base model at $\tau_1 = 0$, and $\tau_2$, with base model at $\tau_2 = 0$).

The nice thing about splitting a model up into components and building priors for each component is that we can build generic priors for each component that can potentially be tuned to make them appropriate for the global model. Is this a perfect way to realise our second aim? No. But it’s an ok place to start^{28}.

Ok. So you’re Brad Pitt. Wait. No.

Ok. So we need to build a prior that penalises complexity by putting most of its prior mass near the base model. In order to do this we need to first specify what we mean by *near*.

There are *a lot* of things that we could mean. The easiest choice would be to just use the natural distance from the base model in the parameter space. But this isn’t necessarily a good idea. Firstly, it falls flat when the base model is at infinity. But more importantly, it violates our 6th aim by ignoring the context of the parameter and just setting a prior on its numerical value.

So instead we are going to parameterise distance by asking ourselves a simple question: for a component with flexibility parameter $\xi$, how much more complex would our model component be if we used the value $\xi$ instead of the base value $\xi_0$?

We can measure this complexity using the Kullback-Leibler divergence (or KL divergence if you’re nasty) $$\text{KL}(p \,\|\, q) = \int p(x) \log\left(\frac{p(x)}{q(x)}\right)\, dx.$$ This is a quantity from information theory that directly measures how much information would be lost^{29} if we replaced the more complex model $p$ with the simpler model $q$. The more information that would be lost, the more complex $p$ is relative to $q$.

While the Kullback-Leibler divergence looks a bit intimidating the first time you see it, it’s got a lot of nice properties:

- It’s always non-negative.

- It doesn’t depend on how you parameterise the distribution: if you apply a smooth, invertible change of variables to both distributions, the KL divergence remains unchanged.

- It’s related to the information matrix and the Fisher distance. In particular, let be a family of distributions parameterised by . Then, near , where is the Fisher information. The quantity on the right hand side is the square of a distance from the base model.

- It can be related to the total variation distance^{30}

But it also has some less charming properties:

- The KL divergence is *not* a distance!

- The KL divergence is *not* symmetric, that is, in general $\text{KL}(p \,\|\, q) \neq \text{KL}(q \,\|\, p)$.

The first of these properties is irrelevant to us. The second is interesting. I’d argue that it is an advantage. We can think of it through an analogy: if your base model is a point at the bottom of a valley, there is a big practical difference between how much effort it takes to get from the base model to another model that is on top of a hill compared to the amount of effort it takes to go in the other direction. This type of asymmetry is relevant to us: it’s easier for data to tell a simple model that it should be more complex than it is to tell a complex model to be simpler. We want our prior information to somewhat even this out, so we put less prior mass on models that are more complex and more on models that are simpler.

There is one more little annoyance: if you look at the two distance measures that the KL divergence is related to, you’ll notice that in both cases, the KL divergence is related to the *square* of the distance and not the distance itself.

If we use the KL divergence itself as a distance proxy, it will increase too sharply^{31} and we may end up over-penalising. To that end, we use the following “distance” measure $d(\xi) = \sqrt{2\,\text{KL}(p_\xi \,\|\, p_0)}$, where $p_0$ is the base model. If you’re wondering about that 2, it doesn’t really matter but it makes a couple of things ever so slightly cleaner down the road.

Ok. Let’s compute some of these distances!

**Example 4 (Overdispersion of a negative binomial (continued)) **The negative binomial distribution is discrete, so the KL divergence is an infinite sum. This has two problems: I can’t work out what it is and it might^{32} end up depending on .

Thankfully we can use our alternative representation of the negative binomial to note that and so we could just as well consider the model component that we want to penalise the complexity of. In this case we need the KL divergence^{33} between Gamma distributions where is the digamma function.

As , the KL divergence becomes^{34}

Now, you will notice that as the KL divergence heads off to infinity. This happens a lot when the base model is much simpler than the flexible model. Thankfully, we will see later that we can ignore the factor of and get a PC prior that’s valid against the base model for *all* sufficiently small . This is not legally the same thing as having one for , but it is morally the same.

With this, we get

If the digamma function is a bit too hardcore for you, the approximation $\psi(x) \approx \log x - \frac{1}{2x}$ gives the approximate distance $d(\alpha) \approx \sqrt{\alpha}$. That is, the distance we are using is approximately the *standard deviation* of the gamma mixing distribution.

Let’s see if this approximation^{35} is any good.

```
library(tidyverse)
tibble(
  alpha = seq(0.01, 20, length.out = 1000),
  exact = sqrt(2 * log(1 / alpha) - 2 * digamma(1 / alpha)),
  approx = sqrt(alpha)
) |>
  ggplot(aes(x = alpha, y = exact)) +
  geom_line(colour = "red") +
  geom_line(aes(y = approx), colour = "blue", linetype = "dashed") +
  theme_bw()
```

It’s ok but it’s not perfect.
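For a second opinion on that plot, here’s a small Python sketch of the same comparison; the finite-difference digamma built from `math.lgamma` is my own stand-in so the snippet needs nothing outside the standard library:

```python
import math

def digamma(x, h=1e-5):
    # Central-difference approximation to psi(x) = d/dx log Gamma(x).
    return (math.lgamma(x + h) - math.lgamma(x - h)) / (2 * h)

def d_exact(alpha):
    # d(alpha) = sqrt(2 log(1/alpha) - 2 psi(1/alpha)), as in the R code above.
    return math.sqrt(2 * math.log(1 / alpha) - 2 * digamma(1 / alpha))

def d_approx(alpha):
    # Small-alpha approximation: d(alpha) ~ sqrt(alpha).
    return math.sqrt(alpha)

for alpha in (0.01, 0.1, 1.0):
    print(alpha, d_exact(alpha), d_approx(alpha))
```

The agreement is close for small $\alpha$ and drifts as $\alpha$ grows, which matches the plot.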

**Example 5 (Student-t degrees of freedom (Continued)) **In our original paper, we computed the distance for the degrees of freedom numerically. However, Yongqiang Tang derived an analytic expression for it.

If we note that we can use this (and the above asymptotic expansion of the digamma function) to get We can use the same asymptotic approximations as above to get

Let’s check this approximation numerically.

```
tibble(
  nu = seq(2.1, 300, length.out = 1000),
  exact = sqrt(1 + log(2 / (nu - 2)) +
                 2 * lgamma((nu + 1) / 2) - 2 * lgamma(nu / 2) -
                 (nu + 1) * (digamma((nu + 1) / 2) - digamma(nu / 2))),
  approx = sqrt(log(nu^2 / ((nu - 2) * (nu + 1))) - (nu + 2) / (3 * nu * (nu + 1)))
) |>
  ggplot(aes(x = nu, y = exact)) +
  geom_line(colour = "red") +
  geom_line(aes(y = approx), colour = "blue", linetype = "dashed") +
  theme_bw()
```

Once again, this is not a terrible approximation, but it’s also not an excellent one.

**Example 6 (Variance of a Gaussian random effect (Continued)) **The distance calculation for the standard deviation of a Gaussian random effect has a very similar structure to the negative binomial case. We note, via Wikipedia, that

This implies that We shall see later that the scaling on the doesn’t matter so for all intents and purposes

So now that we have a distance measure, we need to turn it into a prior. There are lots of ways we can do this. Essentially any prior we put on the distance can be transformed into a prior on the flexibility parameter $\xi$. We do this through the change of variables formula $p(\xi) = p_d(d(\xi))\left|\frac{\partial d(\xi)}{\partial \xi}\right|$, where $p_d$ is the prior density for the distance parameterisation.

But which prior should we use on the distance? A good default choice is a prior that penalises at a constant rate. That is, we want $\frac{p_d(d + \delta)}{p_d(d)} = r^{\delta}$ for some constant $0 < r < 1$. This condition says that the rate at which the density decreases does not change as we move through the parameter space. This is extremely useful because any other (monotone) distribution is going to have a point at which the bulk changes to the tail. As we are putting our prior on the distance, we won’t necessarily be able to reason about this point.

Constant-rate penalisation implies that the prior on the distance scale is an exponential distribution and, hence, we get our generic PC prior for a flexibility parameter
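Spelling that out (with $\xi$ the flexibility parameter and $d(\xi)$ the distance, my notation for the missing displays): an exponential density on the distance, pushed back through the change of variables, gives

```latex
p_d(d) = \lambda e^{-\lambda d}
\quad\Longrightarrow\quad
p(\xi) = \lambda e^{-\lambda d(\xi)} \left| \frac{\partial d(\xi)}{\partial \xi} \right|
```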

**Example 7 (Overdispersion of a negative binomial (continued)) **The exact PC prior for the overdispersion parameter in the negative binomial distribution is where is the derivative of the digamma function.

On the other hand, if we use the approximate distance we get

```
lambda <- 1
dat <- tibble(
  alpha = seq(0.01, 20, length.out = 1000),
  exact = lambda / alpha^2 * abs(trigamma(1 / alpha) - alpha) /
    sqrt(2 * log(1 / alpha) - 2 * digamma(1 / alpha)) *
    exp(-lambda * sqrt(2 * log(1 / alpha) - 2 * digamma(1 / alpha))),
  approx = lambda / (2 * sqrt(alpha)) * exp(-lambda * sqrt(alpha))
)
dat |>
  ggplot(aes(x = alpha, y = exact)) +
  geom_line(colour = "red") +
  geom_line(aes(y = approx), colour = "blue", linetype = "dashed") +
  theme_bw()
```

```
dat |>
ggplot(aes(x = alpha, y = exact - approx)) +
geom_line(colour = "black") +
theme_bw()
```

That’s a pretty good agreement!

**Example 8 (Student-t degrees of freedom (Continued)) **An interesting feature of the PC prior (and any prior where the density on the distance scale takes its maximum at the base model) is that the implied prior on has no finite moments. In fact, if your prior on has finite moments, the density on the distance scale is zero at zero!

The exact PC prior for the degrees of freedom in a Student-t distribution is where is given above.

The approximate PC prior is Let’s look at the difference.

```
dist_ex <- \(nu) sqrt(1 + log(2 / (nu - 2)) +
                        2 * lgamma((nu + 1) / 2) - 2 * lgamma(nu / 2) -
                        (nu + 1) * (digamma((nu + 1) / 2) - digamma(nu / 2)))
dist_ap <- \(nu) sqrt(log(nu^2 / ((nu - 2) * (nu + 1))) - (nu + 2) / (3 * nu * (nu + 1)))
lambda <- 1
dat <- tibble(
  nu = seq(2.1, 30, length.out = 1000),
  exact = lambda * (1 / (nu - 2) + (nu + 1) / 2 * (trigamma((nu + 1) / 2) - trigamma(nu / 2))) /
    (4 * dist_ex(nu)) * exp(-lambda * dist_ex(nu)),
  approx = lambda * (nu * (nu + 2) * (2 * nu + 9) + 4) / (3 * nu^2 * (nu + 1)^2 * (nu - 2)) *
    exp(-lambda * dist_ap(nu))
)
dat |>
  ggplot(aes(x = nu, y = exact)) +
  geom_line(colour = "red") +
  geom_line(aes(y = approx), colour = "blue", linetype = "dashed") +
  theme_bw()
```

```
dat |>
ggplot(aes(x = nu, y = exact - approx)) +
geom_line(colour = "black") +
theme_bw()
```

The approximate prior isn’t so good for $\nu$ near 2. In the original paper, the distance was tabulated for and a different high-precision asymptotic expansion was given for .

In the original paper, we also plotted some common priors for the degrees of freedom on the distance scale to show just how informative flat-ish priors on $\nu$ can be! Note that the wider the uniform prior on $\nu$ is, the more informative it is on the distance scale.

**Example 9 (Variance of a Gaussian random effect (Continued)) **This is the easy one because the distance is equal to the standard deviation! The PC prior for the standard deviation of a Gaussian distribution is an exponential prior More generally, if is a multivariate normal distribution, than the PC prior for is still The corresponding prior on is Sometimes, for instance if you’re converting a model from BUGS or you’re looking at the smoothing parameter of a smoothing spline, you might specify your normal distribution in terms of the precision, which is the inverse of the variance. If , then the corresponding PC prior (using the change of variables ) is
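As a numerical sanity check on that last change of variables, here’s a Python sketch (numpy only; the choice $\lambda = 1$ is arbitrary) that integrates the implied density on the precision $\tau = \sigma^{-2}$ and confirms it is a proper prior:

```python
import numpy as np

def pc_precision(tau, lam=1.0):
    # Exponential(lam) prior on sigma, pushed through tau = sigma^(-2):
    # |d sigma / d tau| = tau^(-3/2) / 2, so
    # p(tau) = (lam / 2) * tau^(-3/2) * exp(-lam / sqrt(tau)).
    return 0.5 * lam * tau**-1.5 * np.exp(-lam / np.sqrt(tau))

# Trapezoid rule on a wide log-spaced grid. The tau^(-3/2) tail is heavy,
# so the region beyond 1e8 still holds roughly 1e-4 of the mass.
tau = np.logspace(-6, 8, 400_001)
f = pc_precision(tau)
total = np.sum(0.5 * (f[1:] + f[:-1]) * np.diff(tau))
```

That heavy $\tau^{-3/2}$ tail means the prior on the precision doesn’t even have a finite mean, which is exactly the flavour of behaviour flagged in Example 8.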

This case was explored extensively in the context of structured additive regression models (think GAMs but moreso) by Klein and Kneib, who found that the choice of exponential prior on the distance scale gave more consistent performance than either a half-normal or a half-Cauchy distribution.

The big unanswered question is how to choose $\lambda$. The scaling of a prior distribution is *vital* to its success, so this is an important question.

And I will just say this: work it out your damn self.

The thing about prior distributions that shamelessly include information is that, at some point, you need to include^{36} some information. And there is no way for anyone other than the data analyst to know what the information to include is.

But I can outline a general procedure.

Imagine that for your flexibility parameter you have some interpretable transformation of it . For instance if , then a good choice for would be . This is because standard deviations are on the same scale as the observations^{37}, and we have intuition about what happens one standard deviation from the mean.

Problem-specific information can then help us set a natural scale for . We do this by choosing $\lambda$ so that $P(Q(\xi) > U) = \alpha$ for some $U$, which we would consider large^{38} for our problem, and some small probability $\alpha$.

From the properties of the exponential distribution, we can see that we can satisfy this if we choose This can be found numerically if it needs to be.

The simplest case is the standard deviation of the normal distribution, because in this case and . In general, if and is not a correlation matrix, you should take into account the diagonal of when choosing . For instance, choosing to be the geometric mean^{39} of the marginal variances of the is a good idea!
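For the record, the standard-deviation case can be solved in closed form: with $d(\sigma) = \sigma$ and an exponential prior on the distance, the tail condition pins down $\lambda$ as

```latex
P(\sigma > U) = e^{-\lambda U} = \alpha
\quad\Longrightarrow\quad
\lambda = \frac{-\ln \alpha}{U}
```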

When a model has more than one component, or a component has more than one flexibility parameter, it can be the case that depends on multiple parameters. For instance, if I hadn’t reparameterised the Student-t distribution to have variance independent of , a PC prior on would have a quantity of interest that depends on . We will also see this if I ever get around to writing about priors for Gaussian processes.

Thus we can put together a PC prior as the unique prior that follows the following four principles:

Occam’s razor: We have a base model that represents simplicity and we prefer our base model.

Measuring complexity: We define the prior using the square root of the KL divergence between the base model and the more flexible model. The square root ensures that the divergence is on a similar scale to a distance, but we maintain the asymmetry of the divergence as a feature (not a bug).

Constant penalisation: We use an exponential prior on the distance scale to ensure that our prior mass decreases evenly as we move farther away from the base model.

User-defined scaling: We need the user to specify a quantity of interest $Q(\xi)$ and a scale $U$. We choose the scaling of the prior so that $P(Q(\xi) > U) = \alpha$. This ensures that when we move to a new context, we are able to modify the prior by using the relevant information about $Q(\xi)$.

These four principles define a PC prior. I think the value of laying them out explicitly is that users and critics can clearly and cleanly identify if these principles are relevant to their problem and, if they are, they can implement them. Furthermore, if you need to modify the principles (say by choosing a different distance measure), there is a clear way to do that.

I’ve come to the end of my energy for this blog post, so I’m going to try to wrap it up. I will write more on the topic later, but for now there are a couple of things I want to say.

These priors can seem quite complex, but I assure you that they are a) useful, b) used, and c) not too terrible in practice. Why? Well fundamentally because you usually don’t have to derive them yourselves. Moreover, a lot of that complexity is the price we pay for dealing with densities. We think that this is worth it, and the lesson that the parameterisation that you are given may not be the correct parameterisation to use when specifying your prior is an important one!

The original paper contains a bunch of other examples. The paper was discussed and we wrote a rejoinder^{40}, which contains an out-of-date list of other PC priors people have derived. If you are interested in some other people’s views of this idea, a good place to start is the discussion of the original paper.

There are also PC priors for Gaussian Processes, disease mapping models, AR(p) processes, variance parameters in multilevel models, and many more applications.

PC priors are all over the INLA software package and its documentation contains a bunch more examples.

Try them out. They’ll make you happy.

I’ve not turned on my computer for six weeks and tbh I finished 3 games and I’m caught up on TV and the weather is shite.↩︎

“But what about sparse matrices?!” exactly 3 people ask. I’ll get back to them. But this is what I’m feeling today.↩︎

I am told my Mercury is in Libra and truly I am not living that with those posts. Maybe Mercury was in Gatorade when I wrote them. So if we can’t be balanced at least let’s like things.↩︎

Our weapons of ass destruction that lives in Canada?↩︎

Negative binomial parameterised by mean and overdispersion so that its mean is and the variance is because we are not flipping fucking coins here↩︎

Hello and welcome to Statistics for Stupid Children. My name is Daniel and I will be your host today.↩︎

If we didn’t have stupid children we’d never get dumb adults and then who would fuck me? You? You don’t have that sort of time. You’ve got a mortgage to service and interest rates are going up. You’ve got your Warhammer collection and it is simply not going to paint itself. You’ve been meaning to learn how to cook Thai food. You simply do not have the time. (I’m on SSRIs so it’s never clear what will come first: the inevitable decay and death of you and your children and your children’s children; the interest, eventual disinterest, and inevitable death of the family archivist from the far future who digs up your name from the digital graveyard; the death of the final person who will ever think of you, thereby removing you from the mortal realm entirely; the death of the universe; or me. Fucking me is a real time commitment.)↩︎

Gamma is parameterised by shape and rate, so has mean 1 and variance .↩︎

integrate↩︎

Sometimes, people still refer to these as *hyperparameters* and put priors on them, which would clarify things, but like everything in statistics there’s no real agreed upon usage. Because why would anyone want that?↩︎

somehow↩︎

location parameter↩︎

This is critical: we *do not know* so the only way we can put a sensible prior on the scaling parameter is if we disentangle the role of these two parameters!↩︎

In fact, if my model estimated the data-level variance to be nearly zero I would assume I’ve fucked something up elsewhere and my model is either over-fitting or I have a redundancy in my model (like if ).↩︎

There are some mathematical peculiarities that we will run into later when the base model is singular. But they’re not too bad.↩︎

The Arianist heresy is that God, Jesus, and the Holy Spirit are three separate beings rather than consubstantial. It’s the reason for that bit of the Nicene Creed. The statistical version most commonly occurs when you consider your model for your data conditional on the parameters (your likelihood) and your model for the parameters (your prior) as separate objects. This can lead to really dumb priors and bad inferences.↩︎

Complaining that a prior is adding information is like someone complaining to you that his boyfriend has stopped fucking him and you subsequently discovering that this is because his boyfriend died a few weeks ago. Like I’m sorry Jonathan, I know even the sight of a traffic cone sets your bussy a-quiverin’, but there really are bigger concerns and I’m gonna need you to focus.↩︎

In this story, the bigger concerns are things like misspecification, incorrect assumptions, data problems etc etc, the traffic cone is an unbiased estimator, Jonathan is our stand in for a generic data analyst, and Jonathan’s bussy is said data scientist’s bussy.↩︎

Yes, I know that there are problems with giving my generic data analyst a male name. Did I carefully think through the gender and power dynamics in my bussy simile? I think the answer to that is obvious.↩︎

We use priors for the same reason that other people use penalties: we don’t want to go into a weird corner of our model space *unless* our data explicitly drags us there↩︎

This is a bit technical. When a model is over-parameterised, it’s not always possible to recover all of the parameters. So we ideally want to make sure that if there are a bunch of asymptotically equivalent parameters, our prior operates sensibly on that set. An example of this will come in a future post where I’ll talk about priors for the parameters of a Gaussian process.↩︎

That Arianism thing creeping in again!↩︎

There are examples of theoretically motivated priors where it’s wildly expensive to compute their densities. We will see one in a later post about GPs.↩︎

Sure, Jan. Of course we want that. But we believed that it was important to include this in a list of desiderata because we *never* want to say “our prior has motivation X and therefore it is good”. It is not enough to be pure, you actually have to work.↩︎

What do I mean by near? Read on, McDuff.↩︎

Think of it as a P-spline if you must. The important thing is that the weights of the basis functions are jointly normal with mean zero and precision matrix .↩︎

Given the knots, which are fixed↩︎

I might talk about more advanced solutions at some point.↩︎

Strictly how many bits would we need ↩︎

The largest absolute difference between the probability that an event happens under and .↩︎

When performing the battered sav, it’s important to not speed up too quickly lest you over-batter.↩︎

It also might not. I don’t care to work it out.↩︎

The “easy” way to get this is to use the fact that the Gamma is in the exponential family and use the general formula for KL divergences in exponential families. The easier way is to look it up on Wikipedia↩︎

Using asymptotic expansions for the log of a Gamma function at infinity↩︎

I’ll be dead before I declare that something is an approximation without bloody checking how good it is.↩︎

We have already included information that is a flexibility parameter with base model , but that is model-specific information. Now we move on to *problem*-specific information.↩︎

they have the same units↩︎

Same thing happens if we want a particular quantity not to be too small, just swap the signs↩︎

Always average on the natural scale. For non-negative parameters geometric means make a lot more sense than arithmetic means!↩︎

Homosexually titled *You just keep on pushing my love over the borderline: a rejoinder*. I’m still not sure how I got away with that.↩︎

BibTeX citation:

```
@online{simpson2022,
author = {Dan Simpson},
editor = {},
title = {Priors Part 4: {Specifying} Priors That Appropriately
Penalise Complexity},
date = {2022-09-03},
url = {https://dansblog.netlify.app/2022-08-29-priors4/2022-08-29-priors4.html},
langid = {en}
}
```

For attribution, please cite this work as:

Dan Simpson. 2022. “Priors Part 4: Specifying Priors That
Appropriately Penalise Complexity.” September 3, 2022. https://dansblog.netlify.app/2022-08-29-priors4/2022-08-29-priors4.html.

A very simple way to get a decent estimate of is to use *importance sampling*, that is, taking draws , from some proposal distribution . Then, noting that we can use Monte Carlo to estimate the second integral. This leads to the importance sampling estimator $\hat{\mu} = \frac{1}{S}\sum_{s=1}^{S} f(\theta_s)\, r(\theta_s)$, where $r = p/q$ is the ratio of the target density $p$ to the proposal density $q$.

This all seems marvellous, but there is a problem. Even though is probably a very pleasant function and is a nice friendly distribution, the ratio can be an absolute beast. Why? Well it’s^{1} the ratio of two densities and there is no guarantee that the ratio of two nice functions is itself a nice function. In particular, if the bulk of the distributions and are in different places, you’ll end up with the situation where the ratio is very small^{2} for most draws and HUGE^{3} for a few.

This will lead to an extremely unstable estimator.

It is pretty well known that the raw importance sampler will behave nicely (that is, it will be unbiased with finite variance) precisely when the distribution of the importance ratios has finite variance.

Elementary treatments stop there, but they miss two very big problems. The most obvious one is that it’s basically impossible to check if the variance of the ratios is finite. A second, much larger but much more subtle problem, is that the variance can be finite but *massive*. This is probably the most common case in high dimensions. MacKay has an excellent example where the importance ratios are *bounded*, but that bound is so large that it is infinite for all intents and purposes.

All of which is to say that importance sampling doesn’t work unless you work on it.

If the problem is the fucking ratios then by gum we will fix the fucking ratios. Or so the saying goes.

The trick turns out to be modifying the largest ratios enough that we stabilise the variance, but not so much as to overly bias the estimate.

The first version of this was truncated importance sampling (TIS), which selects a threshold $\tau$ and estimates the expectation as $\hat{\mu}_{\tau} = \frac{1}{S}\sum_{s=1}^{S} f(\theta_s)\min(r_s, \tau)$. It’s pretty obvious that this has finite variance for any fixed $\tau$, but we should be pretty worried about the bias. Unsurprisingly, there is going to be a trade-off between the variance and the bias. So let’s explore that.
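Here’s a toy Python sketch of TIS. Everything about the setup is my own invention for illustration: the target is $N(0,1)$, the proposal is the lighter-tailed $N(0, 0.8^2)$ (so the ratios are unbounded), $f(\theta) = \theta$ so the true answer is $0$, and $\tau = 10$ is an arbitrary threshold:

```python
import numpy as np

rng = np.random.default_rng(0)

# Target p = N(0, 1); proposal q = N(0, 0.8^2) has lighter tails,
# so the importance ratios r = p(theta) / q(theta) are unbounded.
S = 100_000
theta = rng.normal(0.0, 0.8, S)
log_r = -0.5 * theta**2 + 0.5 * (theta / 0.8) ** 2 + np.log(0.8)
r = np.exp(log_r)

tau = 10.0
is_est = np.mean(theta * r)                    # plain importance sampling
tis_est = np.mean(theta * np.minimum(r, tau))  # truncated importance sampling
n_clipped = int(np.sum(r > tau))               # how many ratios were truncated
```

Both estimators land near zero here; the payoff from truncation shows up in the variance once the ratios are genuinely heavy-tailed.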

To get an expression for the bias, first let us write and for . Occasionally we will talk about the joint distribution or . Sometimes we will also need to use the indicator variables .

Then, we can write^{4}

How does this relate to TIS? Well. Let be the random variable denoting the number of times . Then,

Hence the bias in TIS is

To be honest, this doesn’t look phenomenally interesting for fixed , however if we let $\tau$ depend on the sample size then, as long as $\tau_S \to \infty$, we get vanishing bias.

We can get more specific if we make the assumption about the tail of the importance ratios. In particular, we will assume that^{5} for some^{6} .

While it seems like this will only be useful for estimating , it turns out that under some mild^{7} technical conditions, the conditional excess distribution function^{8} is well approximated by a Generalised Pareto Distribution as . Or, in maths, as , for some and . The shape^{9} parameter is very important for us, as it tells us how many moments the distribution has. In particular, if a distribution has shape parameter $k$, then it has finite moments of order less than $1/k$. We will focus exclusively on the case where the shape parameter is positive. When it is less than $1/2$, the distribution has finite variance.

If , then the conditional exceedence function is which suggests that as , converges to a generalised Pareto distribution with shape parameter and scale parameter .

All of this work lets us approximate the distribution of and use the formula for the mean of a generalised Pareto distribution. This gives us the estimate which estimates the bias when is constant^{10} as

For what it’s worth, Ionides got the same result more directly in the TIS paper, but he wasn’t trying to do what I’m trying to do.

The variance is a little bit more annoying. We want it to go to zero.

As before, we condition on (or, equivalently, ) and then use the law of total variance. We know from the bias calculation that

A similarly quick calculation tells us that To close it out, we recall that is the sum of Bernoulli random variables so

With this, we can get an expression for the unconditional variance. To simplify the expression, let’s write . Then,

There are four terms in the variance. The first and third terms are clearly harmless: they go to zero no matter how we choose . Our problem terms are the second and fourth. We can tame the fourth term if we choose . But that doesn’t seem to help with the second term. But it turns out it is enough. To see this, we note that where the second inequality uses the fact that and the third comes from the law of total probability.

So the TIS estimator has vanishing bias and variance as long as the truncation and . Once again, this is in the TIS paper, where it is proved in a much more compact way.

It can also be useful to have an understanding of how wild the fluctuations are. For traditional importance sampling, we know that if is finite, then the fluctuations are, asymptotically, normally distributed with mean zero. Non-asymptotic results were given by Chatterjee and Diaconis that also hold even when the estimator has infinite variance.

For TIS, it’s pretty obvious that for fixed and , will be asymptotically normal (it is, after all, the sum of bounded random variables). For growing sequences it’s a tiny bit more involved: it is now a triangular array^{11} rather than a sequence of random variables. But in the end very classical results tell us that for bounded^{12} , the fluctuations of the TIS estimator are asymptotically normal.

It’s worth saying that when is unbounded, it *might* be necessary to truncate the product rather than just . This is especially relevant if grows rapidly with . Personally, I can’t think of a case where this happens: usually grows (super-)exponentially in while usually grows polynomially, which implies grows (poly-)logarithmically.

The other important edge case is that when can be both positive and negative, it might be necessary to truncate both above *and* below.

TIS has lovely theoretical properties, but it’s a bit challenging to use in practice. The problem is, there’s really no practical guidance on how to choose the truncation sequence.

So let’s do this differently. What if instead of specifying a threshold directly, we instead decided that the $M$ largest values are potentially problematic and should be modified? Recall that for TIS, the number of samples that exceeded the threshold, , was random while the threshold was fixed. This is the opposite situation: the number of exceedences is fixed but the threshold is random.

The threshold is now the $M$-th largest value of . We denote this using order statistics notation: we re-order the sample so that $r_{(1)} \leq r_{(2)} \leq \cdots \leq r_{(S)}$. With this notation, the threshold is $r_{(S-M+1)}$ and the Winsorized importance sampler (WIS) is $\hat{\mu}_{\text{WIS}} = \frac{1}{S}\sum_{s=1}^{S} f(\theta_{(s)}) \min(r_{(s)}, r_{(S-M+1)})$, where $(\theta_{(s)}, r_{(s)})$ are the pairs *ordered* so that the $r_{(s)}$ are increasing. Note that the $\theta_{(s)}$ are not necessarily in increasing order: they are known as *concomitants* of the $r_{(s)}$, which is just a fancy way to say that they’re along for the ride. It’s *very* important that we reorder the $\theta_s$ when we reorder the $r_s$, otherwise we won’t preserve the joint distribution and we’ll end up with absolute rubbish.
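In code, the concomitant bookkeeping is a single `argsort` applied to both arrays. A Python sketch, with a made-up target/proposal pair ($N(0,1)$ target, $N(0, 0.8^2)$ proposal, $f(\theta) = \theta$) purely for illustration:

```python
import numpy as np

rng = np.random.default_rng(1)

S, M = 100_000, 50
theta = rng.normal(0.0, 0.8, S)
r = np.exp(-0.5 * theta**2 + 0.5 * (theta / 0.8) ** 2 + np.log(0.8))

order = np.argsort(r)                  # sort the ratios ...
r_s, theta_s = r[order], theta[order]  # ... theta rides along as a concomitant
threshold = r_s[S - M]                 # the order statistic r_{(S-M+1)}
r_wins = np.minimum(r_s, threshold)    # winsorise: clip the top ratios to it
wis_est = np.mean(theta_s * r_wins)
n_clipped = int(np.sum(r_s > threshold))  # M - 1 values sit strictly above it
```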

We can already see that this is both much nicer and much wilder than the TIS distribution. It is *convenient* that is no longer random! But what the hell are we going to do about those order statistics? Well, the answer is very much the same thing as before: condition on them and hope for the best.

Conditioned on the event^{13} , we get From this, we get that the bias, conditional on is

You should immediately notice that we are in quite a different situation from TIS, where only the tail contributed to the bias. By fixing and randomising the threshold, we have bias contributions from both the bulk (due, essentially, to a weighting error) and from the tail (due to both the weighting error and the truncation). This is going to require us to be a bit creative.

We could probably do something more subtle and clever here, but that is not my way. Instead, let’s use the triangle inequality to say and so the first term in the bias can be bounded if we can bound the relative error

Now the more sensible among you will say *Daniel, No! That’s a ratio! That’s going to be hard to bound*. And, of course, you are right. But here’s the thing: if is small relative to , it is *tremendously* unlikely that is anywhere near zero. This is intuitively true, but also mathematically true.

To attack this expectation, we are going to look at a slightly different quantity that has the good grace of being non-negative.

**Lemma 1 **Let , be an iid sample from , let be an integer. Then and where is an F-distributed random variable with parameters .

*Proof*. For any , where is the incomplete Beta function.

You could, quite reasonably, ask where the hell that incomplete Beta function came from. And if I had thought to look this up, I would say that it came from Equation 2.1.5 in David and Nagaraja’s book on order statistics. Unfortunately, I did not look this up. I derived it, which is honestly not very difficult. The trick is to basically note that the event is the same as the event that at least of the samples are less than or equal to . Because the are independent, this is the probability of observing at least heads from a coin with the probability of a head . If you look this up on Wikipedia^{14} you see^{15} that it is . The rest just comes from noting that and using the symmetry .

To finish this off, we note that From which, we see that

The second result follows the same way and by noting that is also F-distributed with parameters .

*The proof has ended*
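The Beta fact that powers the proof (and the `rbeta` trick in the R snippet below) is that the $k$-th order statistic of $S$ iid Uniform(0,1) draws is $\mathrm{Beta}(k, S - k + 1)$. A quick Python spot-check of its first two moments:

```python
import numpy as np

rng = np.random.default_rng(2)

# k-th order statistic of S uniforms ~ Beta(k, S - k + 1).
S, k, N = 100, 90, 20_000
u = rng.uniform(size=(N, S))
kth = np.sort(u, axis=1)[:, k - 1]   # k-th smallest in each of N replicates

beta_mean = k / (S + 1)
beta_var = k * (S - k + 1) / ((S + 1) ** 2 * (S + 2))
```

Drawing the order statistic directly from this Beta (as the R code does via `rbeta`) avoids simulating all $S$ uniforms every time.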

Now, obviously, in this house we do not trust mathematics. Which is to say that I made a stupid mistake the first time I did this and forgot that when is binomial, and had a persistent off-by-one error in my derivation. But we test out our results so we don’t end up doing the dumb thing.

So let’s do that. For this example, we will use generalised Pareto-distributed .

```
library(tidyverse)
xi <- 0.7
s <- 2
u <- 4
samp <- function(S, k, p,
                 Q = \(x) u + s * ((1 - x)^(-xi) - 1) / xi,
                 # NB: the argument F shadows base R's F (FALSE) inside the function
                 F = \(x) 1 - (1 + xi * (x - u) / s)^(-1 / xi)) {
  # Use theory to draw x_{k:S}
  xk <- Q(rbeta(1, k, S - k + 1))
  c(1 - p / F(xk), 1 - (1 - p) / (1 - F(xk)))
}
S <- 1000
M <- 50
k <- S - M + 1
p <- 1 - M / S
N <- 100000
fs <- rf(N, 2 * (S - k + 1), 2 * k)
tibble(theoretical = 1 - p - p * fs * (S - k + 1) / k,
       xks = map_dbl(1:N, \(x) samp(S, k, p)[1])) %>%
  ggplot() +
  stat_ecdf(aes(x = xks), colour = "black") +
  stat_ecdf(aes(x = theoretical), colour = "red", linetype = "dashed") +
  ggtitle(expression(1 - frac(1 - M/S, R(r[S - M + 1:S]))))
```

```
tibble(theoretical = p - (1-p) * k/(fs * (S - k + 1)),
xks = map_dbl(1:N, \(x) samp(S, k, p)[2])) %>%
ggplot() + stat_ecdf(aes(x = xks), colour = "black") +
stat_ecdf(aes(x = theoretical), colour = "red", linetype = "dashed") +
ggtitle(expression(1 - frac(M/S , 1-R(r[S-M+1:S]))))
```

Fabulous. It follows then that where has an F-distribution with degrees of freedom. As , it follows that this term goes to zero as long as . This shows that the first term in the bias goes to zero.

It’s worth noting here that we’ve also calculated that the bias is *at most* , however, this rate is extremely sloppy. That upper bound we just computed is *unlikely* to be tight. A better person than me would probably check, but honestly I just don’t give a shit^{16}

The second term in the bias is As before, we can write this as By our lemma, we know that the distribution of the term in the absolute value when is the same as where , which has mean and variance From Jensen’s inequality, we get It follows that and so we get vanishing bias as long as and .

Once again, I make no claims of tightness^{17}. Just because it’s a bit sloppy at this point doesn’t mean the job isn’t done.

**Theorem 1 **Let , be an iid sample from and let . Assume that

is absolutely continuous

and

Then Winsorized importance sampling converges in and is asymptotically unbiased.

Ok so that’s nice. But you’ll notice that I did not mention our piss-poor rate. That’s because there is absolutely no way in hell that the bias is ! That rate is an artefact of a *very* sloppy bound on .

Unfortunately, Mathematica couldn’t help me out. Its asymptotic abilities shit the bed at the sight of , which is everywhere in the exact expression (which I’ve put below in the fold).

```
-(((M/(1 + S))^(-(1/2) - S/2)*Gamma[(1 + S)/2]*
(6*(M/(1 + S))^(1/2 + M/2 + S/2)*((1 + S)/(1 - M + S))^(M/2 + S/2) -
5*M*(M/(1 + S))^(1/2 + M/2 + S/2)*((1 + S)/(1 - M + S))^(M/2 + S/2) +
M^2*(M/(1 + S))^(1/2 + M/2 + S/2)*((1 + S)/(1 - M + S))^(M/2 + S/2) +
8*S*(M/(1 + S))^(1/2 + M/2 + S/2)*((1 + S)/(1 - M + S))^(M/2 + S/2) -
6*M*S*(M/(1 + S))^(1/2 + M/2 + S/2)*((1 + S)/(1 - M + S))^(M/2 + S/2) +
M^2*S*(M/(1 + S))^(1/2 + M/2 + S/2)*((1 + S)/(1 - M + S))^(M/2 + S/2) +
2*S^2*(M/(1 + S))^(1/2 + M/2 + S/2)*((1 + S)/(1 - M + S))^(M/2 + S/2) -
M*S^2*(M/(1 + S))^(1/2 + M/2 + S/2)*((1 + S)/(1 - M + S))^(M/2 + S/2) -
6*Sqrt[-(M/(-1 + M - S))]*Sqrt[(-1 - S)/(-1 + M - S)]*
(M/(1 - M + S))^(M/2 + S/2)*Hypergeometric2F1[1, (1/2)*(-1 + M - S),
M/2, M/(-1 + M - S)] + 8*M*Sqrt[-(M/(-1 + M - S))]*
Sqrt[(-1 - S)/(-1 + M - S)]*(M/(1 - M + S))^(M/2 + S/2)*
Hypergeometric2F1[1, (1/2)*(-1 + M - S), M/2, M/(-1 + M - S)] -
2*M^2*Sqrt[-(M/(-1 + M - S))]*Sqrt[(-1 - S)/(-1 + M - S)]*
(M/(1 - M + S))^(M/2 + S/2)*Hypergeometric2F1[1, (1/2)*(-1 + M - S),
M/2, M/(-1 + M - S)] - 8*Sqrt[-(M/(-1 + M - S))]*
Sqrt[(-1 - S)/(-1 + M - S)]*S*(M/(1 - M + S))^(M/2 + S/2)*
Hypergeometric2F1[1, (1/2)*(-1 + M - S), M/2, M/(-1 + M - S)] +
4*M*Sqrt[-(M/(-1 + M - S))]*Sqrt[(-1 - S)/(-1 + M - S)]*S*
(M/(1 - M + S))^(M/2 + S/2)*Hypergeometric2F1[1, (1/2)*(-1 + M - S),
M/2, M/(-1 + M - S)] - 2*Sqrt[-(M/(-1 + M - S))]*
Sqrt[(-1 - S)/(-1 + M - S)]*S^2*(M/(1 - M + S))^(M/2 + S/2)*
Hypergeometric2F1[1, (1/2)*(-1 + M - S), M/2, M/(-1 + M - S)] +
6*M*(M/(1 + S))^(M/2)*((1 + S)/(1 - M + S))^(M/2 + S/2)*
Hypergeometric2F1[(1 + S)/2, (1/2)*(1 - M + S), (1/2)*(3 - M + S),
(-1 + M - S)/M] - 5*M^2*(M/(1 + S))^(M/2)*((1 + S)/(1 - M + S))^
(M/2 + S/2)*Hypergeometric2F1[(1 + S)/2, (1/2)*(1 - M + S),
(1/2)*(3 - M + S), (-1 + M - S)/M] + M^3*(M/(1 + S))^(M/2)*
((1 + S)/(1 - M + S))^(M/2 + S/2)*Hypergeometric2F1[(1 + S)/2,
(1/2)*(1 - M + S), (1/2)*(3 - M + S), (-1 + M - S)/M] +
2*M*S*(M/(1 + S))^(M/2)*((1 + S)/(1 - M + S))^(M/2 + S/2)*
Hypergeometric2F1[(1 + S)/2, (1/2)*(1 - M + S), (1/2)*(3 - M + S),
(-1 + M - S)/M] - M^2*S*(M/(1 + S))^(M/2)*((1 + S)/(1 - M + S))^
(M/2 + S/2)*Hypergeometric2F1[(1 + S)/2, (1/2)*(1 - M + S),
(1/2)*(3 - M + S), (-1 + M - S)/M] - 2*M*(M/(1 + S))^(M/2)*
((1 + S)/(1 - M + S))^(M/2 + S/2)*Hypergeometric2F1[(1 + S)/2,
(1/2)*(3 - M + S), (1/2)*(5 - M + S), (-1 + M - S)/M] +
3*M^2*(M/(1 + S))^(M/2)*((1 + S)/(1 - M + S))^(M/2 + S/2)*
Hypergeometric2F1[(1 + S)/2, (1/2)*(3 - M + S), (1/2)*(5 - M + S),
(-1 + M - S)/M] - M^3*(M/(1 + S))^(M/2)*((1 + S)/(1 - M + S))^
(M/2 + S/2)*Hypergeometric2F1[(1 + S)/2, (1/2)*(3 - M + S),
(1/2)*(5 - M + S), (-1 + M - S)/M] - 2*M*S*(M/(1 + S))^(M/2)*
((1 + S)/(1 - M + S))^(M/2 + S/2)*Hypergeometric2F1[(1 + S)/2,
(1/2)*(3 - M + S), (1/2)*(5 - M + S), (-1 + M - S)/M] +
M^2*S*(M/(1 + S))^(M/2)*((1 + S)/(1 - M + S))^(M/2 + S/2)*
Hypergeometric2F1[(1 + S)/2, (1/2)*(3 - M + S), (1/2)*(5 - M + S),
(-1 + M - S)/M]))/(((1 + S)/(1 - M + S))^S*
(2*(-2 + M)*M*Sqrt[(-1 - S)/(-1 + M - S)]*Gamma[M/2]*
Gamma[(1/2)*(5 - M + S)])))
```

But do not fear: we can recover, at the cost of an assumption about the tails of . (We’re also going to assume that is bounded because it makes things ever so slightly easier, although unbounded is ok^{18} as long as it doesn’t grow too quickly relative to .)

We are going to make the assumption that is in the domain of attraction of a generalised Pareto distribution with shape parameter . A sufficient condition, due to von Mises, is that

This seems like a weird condition, but it’s basically just a regularity condition at infinity. For example if is regularly varying at infinity^{19} and is, eventually, monotone^{20} decreasing, then this condition holds.

The von Mises condition is very natural for us, as Falk and Marohn (1993) show that the relative error we get when approximating the tail of by a generalised Pareto density is the same as the relative error in the von Mises condition. That is, if then where are constants and is the density of a generalised Pareto distribution.

Anyway, under those two assumptions, we can swap out the density of with its asymptotic approximation and get that, conditional on ,

Hence, the second term in the bias goes to zero if goes to zero.

Now this is not particularly pleasant, but it helps to recognise that even if a distribution doesn’t have finite moments, away from the extremes, its order statistics always do. This means that we can use Cauchy-Schwarz to get

Arguably, the most alarming term is the first one, but that can^{21} be tamed. To do this, we lean on a result from Bickel (1967): if you examine the proof, translate some obscurely-stated conditions, and fix a typo^{22}, you get that You might worry that this is going to grow too quickly. But it doesn’t. Noting that , we can rewrite the upper bound in terms of the Beta function to get

To show that this doesn’t grow too quickly, we use the identity From this, it follows that In this case, we are interested in , so
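Growth estimates like this ultimately come down to the standard asymptotic for ratios of Gamma (equivalently, Beta) functions: Γ(x + a)/Γ(x) ≈ x^a for large x. A rough numerical aside (just the flavour of the bound, not the original derivation):

```python
import numpy as np
from scipy.special import gammaln

def gamma_ratio(x, a):
    # Gamma(x + a) / Gamma(x), computed stably in log space
    return np.exp(gammaln(x + a) - gammaln(x))

# The ratio behaves like x^a, so dividing by x^a should approach 1
for x in [10.0, 100.0, 1000.0]:
    print(x, gamma_ratio(x, 0.5) / x**0.5)
```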

Hence we get that . This is increasing^{23} in , but we will see that it is not going up too fast.

For the second half of this shindig, we are going to attack A standard result^{24} from extreme value theory is that has the same distribution as the th order statistic from a sample of iid random variables. Hence^{25}, It follows^{26} that and Adding these together and doing some asymptotic expansions, we get which goes to zero^{27} like if .

We can multiply this rate together and get that the second term in the bias is bounded above by

Putting all of this together we have proved the following Corollary.

**Corollary 1 **Let , be an iid sample from and let . Assume that

is absolutely continuous and satisfies the von Mises condition

^{28}is bounded

^{29}

Winsorized importance sampling converges in with a rate of, at most, , which is balanced when . Hence, WIS is^{30} -consistent.

Right, that was a bit of a journey, but let’s keep going to the variance.

It turns out that following the route I thought I was going to follow does not end well. That lovely set of tricks breaking up the variance into two conditional terms turns out to be very very unnecessary. Which is good, because I thoroughly failed to make the argument work.

If you’re curious, the problem is that the random variable is an absolute *bastard* to bound. The problem is that and so the usual trick of bounding that truncated expectation by or some such thing will prove that the variance is *finite* but not that it goes to zero. There is a solid chance that the Cauchy-Schwarz inequality would work. But truly that is just bloody messy^{31}.

So let’s do it the easy way, shall we? Fundamentally, we will use Noting that we can write compactly as Hence,

This goes to zero as long as .

Bickel (1967) shows that, noting that , and so the variance is bounded.

The previous argument shows that the variance is . We can refine that if we assume the von Mises condition holds. In that case we know that as and therefore Bickel (1967) shows that so combining this with the previous result gives a variance of . If we take , this gives , which is smaller than the previous bound for . It’s worth noting that . Hence the variance goes to zero.

The argument that we used here is a modification of the argument in the TIS paper. This led to a great deal of panic: did I just make my life extremely difficult? Could I have modified the TIS proof to show the bias goes to zero? To be honest, someone might be able to, but I can’t.

So anyway, we’ve proved the following theorem.

**Theorem 2 **Let , be an iid sample from and let . Assume that

is absolutely continuous

and

.

The variance in Winsorized importance sampling is at most .

Pareto-smoothed importance sampling (or PSIS) takes the observation that the tails are approximately Pareto distributed to add some bias correction to the mix. Essentially, it works by approximating , where is the median^{32} th order statistic in an iid sample of generalised Pareto random variables with tail parameters fitted to the distribution.

This is a … funky … quadrature rule. To see that, we can write If we approximate the distribution of by and approximate the conditional probability by

Empirically, this is a very good choice (with the mild caveat that you need to truncate the largest expected order statistic by the observed maximum in order to avoid some variability issues). I would love to have a good analysis of why that is so, but honestly I do not.
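The smoothing step being described can be sketched as follows; this is a hedged illustration (`psis_smooth_tail` is my own name, the GPD fit is scipy's plain MLE rather than the regularised fit used in the real `loo`/ArviZ implementations, and the expected order statistics are approximated by quantiles at the usual plotting positions):

```python
import numpy as np
from scipy import stats

def psis_smooth_tail(w, M):
    """Sketch of the smoothing step: replace the M largest importance weights
    by expected order statistics of a generalised Pareto fitted to the tail.
    (Illustrative only; real implementations fit the GPD more carefully.)"""
    w = np.sort(np.asarray(w, dtype=float))
    S = len(w)
    u = w[S - M - 1]                     # threshold: the (S - M)-th order statistic
    exceed = w[S - M:] - u               # the M exceedances over the threshold
    # Plain maximum-likelihood GPD fit to the exceedances (location fixed at 0)
    xi, _, sigma = stats.genpareto.fit(exceed, floc=0.0)
    # Replace the tail weights with GPD quantiles at the usual plotting positions
    z = (np.arange(1, M + 1) - 0.5) / M
    smoothed = u + stats.genpareto.ppf(z, xi, loc=0.0, scale=sigma)
    # Truncate at the observed maximum to control the variance
    smoothed = np.minimum(smoothed, w[-1])
    out = w.copy()
    out[S - M:] = smoothed
    return out, xi

# Example: heavy-tailed raw weights
w = stats.genpareto.rvs(0.3, size=2000, random_state=1)
smoothed, xi_hat = psis_smooth_tail(w, 100)
```

The bulk of the weights is untouched; only the top M are replaced by their "expected" values under the fitted tail.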

But, to the issue of this blog post: the convergence and vanishing variance results still hold. To see this, we note that So we are just re-weighting our tail samples by

Recalling that when , we had , this term is at most . This will not trouble either of our convergence proofs.

This leads to the following modification of our previous results.

**Theorem 3 **Let , be an iid sample from and let . Assume that

is absolutely continuous.

and are known with .

Pareto smoothed importance sampling converges in , its variance goes to zero, and it is consistent and asymptotically unbiased.

**Corollary 2 **Assume further that

R satisfies the von Mises condition

^{33}is bounded

^{34}.

Then the L^1 convergence occurs at a rate of, at most, . Furthermore, the variance of the PSIS estimator goes to zero at least as fast as .

Hence, under these additional conditions PSIS is^{35} -consistent.

So that’s what truncation and winsorization does to importance sampling estimates. I haven’t touched on the fairly important topic of asymptotic normality. Essentially, Griffin (1988), in a fairly complex^{36} paper, suggests that if you winsorize the product *and* winsorize it at both ends, the von Mises condition^{37} implies that the WIS estimator is asymptotically normal.

Why is this important? Well, the same proof shows that doubly winsorized importance sampling (dWIS) applied to the vector-valued function will also be asymptotically normal, which implies, via the delta method, that the *self normalized* dWIS estimator is consistent, where is the th order statistic of .

It is very very likely that this can be shown (perhaps under some assumptions) for something closer to the version of PSIS we use in practice. But that is an open question.

proportional to↩︎

because is very small↩︎

because is a reasonable size, but is tiny.↩︎

I have surreptitiously dropped the subscript because I am gay and sneaky.↩︎

That it’s parameterised by is an artefact of history.↩︎

We need to be finite, so we need .↩︎

very fucking complex↩︎

I have used that old trick of using the same letter for the CDF as the random variable when I have a lot of random variables. ↩︎

aka the tail index↩︎

This is a relevant case. But if you think a little bit about it, our problem happens when grows *much* faster than . For example if and for , then , and if , then , which is a slowly growing function.↩︎

Because the truncation depends on , moving from the th partial sum to the th partial sum changes the distribution of . This is exactly why the dead Russians gifted us with triangular arrays.↩︎

Also practical unbounded , but it’s just easier for bounded ↩︎

Shut up. I know. Don’t care.↩︎

or, hell, even in a book↩︎

Straight up, though, I spent 2 days dicking around with tail bounds on sums of Bernoulli random variables for some bloody reason before I just looked at the damn formula.↩︎

Ok. I checked. And yeah. Same technique as below using Jensen in its . If you put that together you get something that goes to zero like , which is for our usual choice of . Which confirms the suspicion that the first term in the bias goes to zero *much* faster than the second (remembering, of course, that Jensen’s inequality is notoriously loose!).↩︎

It’s Pride month↩︎

The result holds exactly if and with a turning up somewhere if it’s .↩︎

for a slowly varying function (eg a power of a logarithm) .↩︎

A property that implies this is that is differentiable and *convex at infinity*, which is to say that there is some finite such that exists for all and is a monotone function on .↩︎

There’s a condition here that has to be large enough, but it’s enough if .↩︎

The first in the equation below is missing in the paper. If you miss this, you suddenly get the expected value converging to zero, which would be *very* surprising. Always sense-check the proofs, people. Even if a famous person did it in the 60s.↩︎

We need to take to be able to estimate the tail index from a sample, which gives an upper bound by a constant.↩︎

Note that if , then . Because this is monotone, it doesn’t change the ordering of the sample.↩︎

This is, incidentally, how Bickel got the upper bound on the moments. He combined this with an upper bound on the quantile function.↩︎

Save the cheerleader, save the world. Except it’s one minus a beta is still beta but with the parameters reversed.↩︎

As long as ↩︎

The rate here is probably not optimal, but it will guarantee that the error in the Pareto approximation doesn’t swamp the other terms.↩︎

Or doesn’t grow too quickly, with some modification of the rates in the unlikely case that it grows polynomially.↩︎

almost, there’s an epsilon gap but I don’t give a shit↩︎

And girl do not get me started on messy. I ended up going down a route where I used the [inequality](https://www.sciencedirect.com/science/article/pii/0167715288900077) which holds for any supported on with differentiable density. And let me tell you. If you dick around with enough beta distributions you can get something. Is it what you want? Fucking no. It is *a lot* of work, including having to differentiate the conditional expectation, and it gives you sweet bugger all.↩︎

Or, the expected within ↩︎

The rate here is probably not optimal, but it will guarantee that the error in the Pareto approximation doesn’t swamp the other terms.↩︎

Or doesn’t grow too quickly, with some modification of the rates in the unlikely case that it grows polynomially.↩︎

almost, there’s an epsilon gap but I don’t give a shit↩︎

I mean, the tools are elementary. It’s just a lot of detailed estimates and Berry-Esseen as far as the eye can see.↩︎

and more general things↩︎

BibTeX citation:

```
@online{simpson2022,
  author = {Dan Simpson},
  editor = {},
  title = {Tail Stabilization of Importance Sampling Estimators: {A} Bit
    of Theory},
  date = {2022-06-15},
  url = {https://dansblog.netlify.app/2022-06-03-that-psis-proof},
  langid = {en}
}
```

For attribution, please cite this work as:

Dan Simpson. 2022. “Tail Stabilization of Importance Sampling
Estimators: A Bit of Theory.” June 15, 2022. https://dansblog.netlify.app/2022-06-03-that-psis-proof.

Last time, we looked at making JAX primitives. We built four of them. Today we are going to implement the corresponding differentiation rules! For three^{1} of them.

So strap yourselves in. This is gonna be detailed.

If you’re interested in the code^{2}, the git repo for this post is linked at the bottom and in there you will find a folder with the python code in a python file.

Derivatives are computed in JAX through the glory and power of automatic differentiation. If you came to this blog hoping for a great description of how autodiff works, I am terribly sorry but I absolutely do not have time for that. Might I suggest google? Or maybe flick through this survey by Charles Margossian.

The most important thing to remember about algorithmic differentiation is that it is *not* symbolic differentiation. That is, it does not create the functional form of the derivative of the function and compute that. Instead, it is a system for cleverly composing derivatives in each bit of the program to compute the *value* of the derivative of the function.

But for that to work, we need to implement those clever little mini-derivatives. In particular, every function needs to have a function to compute the corresponding Jacobian-vector product where the matrix has entries
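To get a concrete feel for what a Jacobian-vector product is before we build our own, here is a tiny sketch using JAX's built-in `jax.jvp` on a toy function (nothing here is specific to our primitives):

```python
import jax
import jax.numpy as jnp

def f(x):
    # A toy nonlinear map from R^3 to R^2
    return jnp.stack([jnp.sin(x[0]) * x[1], x[2] ** 2])

x = jnp.array([0.1, 2.0, 3.0])
v = jnp.array([1.0, 0.0, 0.0])

# jax.jvp evaluates (f(x), J(x) @ v) without ever forming the Jacobian J
y, jv = jax.jvp(f, (x,), (v,))

# Sanity check against the explicitly computed Jacobian
J = jax.jacfwd(f)(x)
print(jnp.allclose(jv, J @ v))
```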

Ok. So let’s get onto this. We are going to derive and implement some Jacobian-vector products. And all of the assorted accoutrement. And by crikey. We are going to do it all in a JAX-traceable way.

The first of the derivatives that we need to work out is the derivative of a linear solve . Now, intrepid readers, the obvious thing to do is look the damn derivative up. You get exactly no hero points for computing it yourself.

But I’m not you, I’m a dickhead.

So I’m going to derive it. I could pretend there are reasons^{3}, but that would just be lying. I’m doing it because I can.

Beyond the obvious fun of working out a matrix derivative from first principles, this is fun because we have *two* arguments instead of just one. Double the fun.

And we really should make sure the function is differentiated with respect to every reasonable argument. Why? Because if you write code other people might use, you don’t get to control how they use it (or what they will email you about). So it’s always good practice to limit surprises (like a function not being differentiable wrt some argument) to cases^{4} where it is absolutely necessary. This reduces the emails.

To that end, let’s take an arbitrary SPD matrix with a *fixed* sparsity pattern. Let’s take another symmetric matrix with *the same sparsity pattern* and assume that is small enough^{5} that is still symmetric positive definite. We also need a vector with a small .

Now let’s get algebraing.

Easy^{6} as.

We’ve actually calculated the derivative now, but it’s a little more work to recognise it.

To do that, we need to remember the practical definition of the Jacobian of a function that takes an -dimensional input and produces an -dimensional output. It is the matrix such that

The formulas further simplify if we write . Then, if we want the Jacobian-vector product for the first argument, it is while the Jacobian-vector product for the second argument is

The only wrinkle in doing this is we need to remember that we are only storing the lower triangle of . Because we need to represent the same way, it is represented as a vector `Delta_x` that contains only the lower triangle of . So we need to make sure we remember to form the *whole* matrix before we do the matrix-vector product !

But otherwise, the implementation is going to be pretty straightforward. The Jacobian-vector product costs one additional linear solve (beyond the one needed to compute the value ).
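Before wiring this into the primitive, the formula is worth a sanity check on a dense analogue: for x = A^{-1}b, the directional derivative in direction (ΔA, Δb) is A^{-1}(Δb - ΔA x), i.e. one extra solve. A hedged numpy sketch (dense, not the sparse implementation):

```python
import numpy as np

rng = np.random.default_rng(0)
n = 6
A = rng.standard_normal((n, n))
A = A @ A.T + n * np.eye(n)            # an SPD test matrix
b = rng.standard_normal(n)
dA = rng.standard_normal((n, n))
dA = (dA + dA.T) / 2                   # symmetric perturbation direction
db = rng.standard_normal(n)

x = np.linalg.solve(A, b)
# The derived Jacobian-vector product: one additional linear solve
jvp = np.linalg.solve(A, db - dA @ x)

# Finite-difference approximation of the same directional derivative
eps = 1e-6
fd = (np.linalg.solve(A + eps * dA, b + eps * db) - x) / eps
print(np.linalg.norm(jvp - fd))
```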

In the language of JAX (and autodiff in general), we refer to and as *tangent vectors*. In search of a moderately coherent naming convention, we are going to refer to the tangent associated with the variable `x` as `xt`.

So let’s implement this. Remember: it needs^{7} to be JAX traceable.

For some sense of continuity, we are going to keep the naming of the primitives from the last blog post, but we are *not* going to attack them in the same order. Why not? Because we work in order of complexity.

So first off we are going to do the triangular solve. As I have yet to package up the code (I promise, that will happen next^{8}), I’m just putting it here under the fold.

```
from scipy import sparse
import numpy as np
from jax import numpy as jnp
from jax import core
from jax._src import abstract_arrays

sparse_triangular_solve_p = core.Primitive("sparse_triangular_solve")

def sparse_triangular_solve(L_indices, L_indptr, L_x, b, *, transpose: bool = False):
    """A JAX traceable sparse triangular solve"""
    return sparse_triangular_solve_p.bind(L_indices, L_indptr, L_x, b, transpose = transpose)

@sparse_triangular_solve_p.def_impl
def sparse_triangular_solve_impl(L_indices, L_indptr, L_x, b, *, transpose = False):
    """The implementation of the sparse triangular solve. This is not JAX traceable."""
    L = sparse.csc_array((L_x, L_indices, L_indptr))
    assert L.shape[0] == L.shape[1]
    assert L.shape[0] == b.shape[0]
    if transpose:
        return sparse.linalg.spsolve_triangular(L.T, b, lower = False)
    else:
        return sparse.linalg.spsolve_triangular(L.tocsr(), b, lower = True)

@sparse_triangular_solve_p.def_abstract_eval
def sparse_triangular_solve_abstract_eval(L_indices, L_indptr, L_x, b, *, transpose = False):
    assert L_indices.shape[0] == L_x.shape[0]
    assert b.shape[0] == L_indptr.shape[0] - 1
    return abstract_arrays.ShapedArray(b.shape, b.dtype)
```

```
from jax._src import ad_util
from jax.interpreters import ad
from jax import lax
from jax.experimental import sparse as jsparse

def sparse_triangular_solve_value_and_jvp(arg_values, arg_tangent, *, transpose):
    """
    A jax-traceable jacobian-vector product. In order to make it traceable,
    we use the experimental sparse CSC matrix in JAX.

    Input:
        arg_values: A tuple of (L_indices, L_indptr, L_x, b) that describe
                    the triangular matrix L and the rhs vector b
        arg_tangent: A tuple of tangent values (same length as arg_values).
                     The first two values are nonsense - we don't differentiate
                     wrt integers!
        transpose: (boolean) If true, solve L^Tx = b. Otherwise solve Lx = b.
    Output: A tuple containing maybe_transpose(L)^{-1}b and the corresponding
            Jacobian-vector product.
    """
    L_indices, L_indptr, L_x, b = arg_values
    _, _, L_xt, bt = arg_tangent
    value = sparse_triangular_solve(L_indices, L_indptr, L_x, b, transpose=transpose)
    if type(bt) is ad.Zero and type(L_xt) is ad.Zero:
        # I legit do not think this ever happens. But I'm honestly not sure.
        print("I have arrived!")
        return value, lax.zeros_like_array(value)
    if type(L_xt) is not ad.Zero:
        # L is variable
        if transpose:
            Delta = jsparse.CSC((L_xt, L_indices, L_indptr), shape = (b.shape[0], b.shape[0])).transpose()
        else:
            Delta = jsparse.CSC((L_xt, L_indices, L_indptr), shape = (b.shape[0], b.shape[0]))
        jvp_Lx = sparse_triangular_solve(L_indices, L_indptr, L_x, Delta @ value, transpose = transpose)
    else:
        jvp_Lx = lax.zeros_like_array(value)
    if type(bt) is not ad.Zero:
        # b is variable
        jvp_b = sparse_triangular_solve(L_indices, L_indptr, L_x, bt, transpose = transpose)
    else:
        jvp_b = lax.zeros_like_array(value)
    return value, jvp_b - jvp_Lx

ad.primitive_jvps[sparse_triangular_solve_p] = sparse_triangular_solve_value_and_jvp
```

Before we see if this works, let’s first have a talk about the structure of the function I just wrote. Generally speaking, we want a function that takes in the primals and tangents as tuples and then returns the value and the^{9} Jacobian-vector product.

The main thing you will notice in the code is that there is *a lot* of checking for `ad.Zero`. This is a special type defined in JAX that is, essentially, telling the autodiff system that we are not differentiating wrt that variable. This is different to a tangent that just happens to be numerically equal to zero. Any code for a Jacobian-vector product needs to handle this special value.

As we have two arguments, we have 3 interesting options:

- Both `L_xt` and `bt` are `ad.Zero`: This means the function is a constant and the derivative is zero. I am fairly certain that we do not need to manually handle this case, but because I don’t know and I do not like surprises, it’s in there.
- `L_xt` is *not* `ad.Zero`: This means that we need to differentiate wrt the matrix. In this case we need to compute or , depending on the `transpose` argument. In order to do this, I used the `jax.experimental.sparse.CSC` class, which has some very limited sparse matrix support (basically matrix-vector products). This is *extremely* convenient because it means I don’t need to write the matrix-vector product myself!
- `bt` is *not* `ad.Zero`: This means that we need to differentiate wrt the rhs vector. This part of the formula is pretty straightforward: just an application of the primal.

In the case that either `L_xt` or `bt` are `ad.Zero`, we simply set the corresponding contribution to the jvp to zero.
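From the user's side, this `ad.Zero` machinery is what fires when you only differentiate with respect to some of the arguments. A small sketch of how that looks through the public API (using JAX's dense solve as a stand-in for our primitive):

```python
import jax
import jax.numpy as jnp

A = jnp.array([[4.0, 1.0], [1.0, 3.0]])
b = jnp.array([1.0, 2.0])

# Differentiating only wrt b: internally the tangent for A is a symbolic
# zero (ad.Zero), not an array of zeros, so that branch does no work
df_db = jax.jacfwd(jnp.linalg.solve, argnums=1)(A, b)

# For a linear solve, the Jacobian wrt b is just A^{-1}
print(jnp.allclose(df_db, jnp.linalg.inv(A)))
```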

It’s worth saying that you can bypass all of this `ad.Zero` logic by writing separate functions for the JVP contribution from each input and then chaining them together using^{10} `ad.defjvp2()`. This is what the `lax.linalg.triangular_solve()` implementation does.

So why didn’t I do this? I avoided this because in the other primitives I have to implement, there are expensive computations (like Cholesky factorisations) that I want to share between the primal and the various tangent calculations. The `ad.defjvp` frameworks don’t allow for that. So I decided not to demonstrate/learn two separate patterns.

Now I’ve never actively wanted a Jacobian-vector product in my whole life. I’m sorry. I want a gradient. Gimme a gradient. I am the Veruca Salt of gradients.

In many autodiff systems, if you want^{11} a gradient, you need to implement vector-Jacobian products^{12} explicitly.

One of the odder little innovations in JAX is that instead of forcing you to implement this as well^{13}, you only need to implement half of it.

You see, some clever analysis that, as far as I can tell^{14}, is detailed in this paper shows that you only need to form explicit vector-Jacobian products for the structurally linear arguments of the function.

In JAX (and maybe elsewhere), this is known as a *transposition rule*. The combination of a transposition rule and a JAX-traceable Jacobian-vector product is enough for JAX to compute all of the directional derivatives and gradients we could ever hope for.

As far as I understand, it is all about functions that are *structurally linear* in some arguments. For instance, if is a matrix-valued function and and are vectors, then the function is structurally linear in in the sense that for every fixed value of , the function is linear in . The resulting transposition rule is then

```
def f_transpose(x, y):
    Ax = A(x)
    gx = g(x)
    return (None, Ax.T @ y + gx)
```

The first element of the return is `None` because is not^{15} structurally linear in so there is nothing to transpose. The second element simply takes the matrix in the linear function and transposes it.

If you know anything about autodiff, you’ll think “this doesn’t *feel* like enough” and it’s not. JAX deals with the non-linear part of by tracing the evaluation tree for its Jacobian-vector product and … manipulating^{16} it.
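The user-facing face of this machinery is `jax.linear_transpose`, which takes a structurally linear function and hands back its transpose. A small sketch:

```python
import jax
import jax.numpy as jnp

A = jnp.arange(6.0).reshape(2, 3)

def f(x):
    # Structurally linear in x
    return A @ x

# linear_transpose needs an example input to fix shapes/dtypes
f_T = jax.linear_transpose(f, jnp.zeros(3))

ct = jnp.array([1.0, 2.0])
(xt,) = f_T(ct)

# The transpose of x -> A @ x is ct -> A.T @ ct
print(jnp.allclose(xt, A.T @ ct))
```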

We already built the abstract evaluation function last time around, so the tracing part can be done. All we need is the transposition rule.

The linear solve is non-linear in the first argument but linear in the second argument. So we only need to implement where the subscript indicates we’re only computing the Jacobian wrt .

Initially, I struggled to work out what needed to be implemented here. The thing that clarified the process for me was looking at JAX’s internal implementation of the Jacobian-vector product for a dense matrix. From there, I understood what this had to look like for a vector-valued function and this is the result.

```
def sparse_triangular_solve_transpose_rule(cotangent, L_indices, L_indptr, L_x, b, *, transpose):
    """
    Transposition rule for the triangular solve.
    Translated from here https://github.com/google/jax/blob/41417d70c03b6089c93a42325111a0d8348c2fa3/jax/_src/lax/linalg.py#L747.

    Inputs:
        cotangent: Output cotangent (aka adjoint). (produced by JAX)
        L_indices, L_indptr, L_x: Representation of the sparse matrix. L_x should be concrete.
        b: The right hand side. Must be a jax.interpreters.ad.UndefinedPrimal
        transpose: (boolean) True: solve $L^Tx = b$. False: Solve $Lx = b$.
    Output:
        A 4-tuple with the adjoints (None, None, None, b_adjoint)
    """
    assert not ad.is_undefined_primal(L_x) and ad.is_undefined_primal(b)
    if type(cotangent) is ad_util.Zero:
        cot_b = ad_util.Zero(b.aval)
    else:
        cot_b = sparse_triangular_solve(L_indices, L_indptr, L_x, cotangent, transpose = not transpose)
    return None, None, None, cot_b

ad.primitive_transposes[sparse_triangular_solve_p] = sparse_triangular_solve_transpose_rule
```

If this doesn’t make a lot of sense to you, that’s because it’s confusing.

One way to think of it is in terms of the more ordinary notation. Mike Giles has a classic paper that covers these results for basic linear algebra. The idea is to imagine that, as part of your larger program, you need to compute .

Forward-mode autodiff computes the *sensitivity* of , usually denoted , from the sensitivities and . These have already been computed. The formula in Giles is The canny reader will recognise this as exactly^{17} the formula for the Jacobian-vector product.

So what does reverse-mode autodiff do? Well it moves through the program in the other direction. So instead of starting with the sensitivities and already computed, we instead start with the^{18} *adjoint sensitivity* . Our aim is to compute and from .

The details of how to do this are^{19} *beyond the scope*, but without tooooooo much effort you can show that which you should recognise as the equation that was just implemented.

The thing that we *do not* have to implement in JAX is the other adjoint that, for dense matrices^{20}, is Through the healing power of … something? (Truly I do not know.) JAX can work that bit out itself. Woo.
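We can at least check both dense adjoint formulas against JAX's own differentiable solve (a sketch with `jnp.linalg.solve`, not our sparse primitive): for x = A^{-1}b, Giles gives bbar = A^{-T} xbar and Abar = -bbar x^T.

```python
import jax
import jax.numpy as jnp

A = jnp.array([[4.0, 1.0], [1.0, 3.0]])
b = jnp.array([1.0, 2.0])

# Reverse mode through the solve: vjp_fn maps the adjoint of x back to (Abar, bbar)
x, vjp_fn = jax.vjp(jnp.linalg.solve, A, b)
xbar = jnp.array([0.5, -1.0])
Abar, bbar = vjp_fn(xbar)

# Giles' formulas for x = A^{-1} b: bbar = A^{-T} xbar, Abar = -bbar x^T
bbar_manual = jnp.linalg.solve(A.T, xbar)
Abar_manual = -jnp.outer(bbar_manual, x)
print(jnp.allclose(bbar, bbar_manual), jnp.allclose(Abar, Abar_manual))
```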

So let’s see if this works. I’m not going to lie, I’m flying by the seat of my pants here. I’m not super familiar with the JAX internals, so I have written a lot of test cases. You may wish to skip this part. But rest assured that almost every single one of these cases was useful to me working out how this thing actually worked!

```
def make_matrix(n):
    one_d = sparse.diags([[-1.]*(n-1), [2.]*n, [-1.]*(n-1)], [-1,0,1])
    A = (sparse.kronsum(one_d, one_d) + sparse.eye(n*n)).tocsc()
    A_lower = sparse.tril(A, format = "csc")
    A_index = A_lower.indices
    A_indptr = A_lower.indptr
    A_x = A_lower.data
    return (A_index, A_indptr, A_x, A)

A_indices, A_indptr, A_x, A = make_matrix(10)
```

This is the same test case as the last blog. We will just use the lower triangle of as the test matrix.

First things first, let’s check out the numerical implementation of the function. We will do that by comparing the implemented Jacobian-vector product with the *definition* of the Jacobian-vector product (aka the forward^{21} difference approximation).

There are lots of things that we could do here to turn these into *actual* tests. For instance, the test suite inside JAX has a lot of nice convenience functions for checking implementations of derivatives. But I went with homespun because that was how I was feeling.

You’ll also notice that I’m using random numbers here, which is fine for a blog. Not so fine for a test that you don’t want to be potentially^{22} flaky.

The choice of `eps = 1e-4` is roughly^{23} because it’s the square root of the single precision machine epsilon^{24}. A very rough back of the envelope calculation for the forward difference approximation to the derivative shows that the square root of the machine epsilon is about the size you want your perturbation to be.
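That back-of-the-envelope number is easy to check directly:

```python
import numpy as np

# Single precision machine epsilon and its square root
eps32 = np.finfo(np.float32).eps
print(eps32)           # ~1.19e-07
print(np.sqrt(eps32))  # ~3.45e-04, i.e. roughly the eps = 1e-4 used below
```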

```
b = np.random.standard_normal(100)
bt = np.random.standard_normal(100)
bt /= np.linalg.norm(bt)
A_xt = np.random.standard_normal(len(A_x))
A_xt /= np.linalg.norm(A_xt)
arg_values = (A_indices, A_indptr, A_x, b )
arg_tangent_A = (None, None, A_xt, ad.Zero(type(b)))
arg_tangent_b = (None, None, ad.Zero(type(A_xt)), bt)
arg_tangent_Ab = (None, None, A_xt, bt)
p, t_A = sparse_triangular_solve_value_and_jvp(arg_values, arg_tangent_A, transpose = False)
_, t_b = sparse_triangular_solve_value_and_jvp(arg_values, arg_tangent_b, transpose = False)
_, t_Ab = sparse_triangular_solve_value_and_jvp(arg_values, arg_tangent_Ab, transpose = False)
pT, t_AT = sparse_triangular_solve_value_and_jvp(arg_values, arg_tangent_A, transpose = True)
_, t_bT = sparse_triangular_solve_value_and_jvp(arg_values, arg_tangent_b, transpose = True)
eps = 1e-4
tt_A = (sparse_triangular_solve(A_indices, A_indptr, A_x + eps * A_xt, b) - p) /eps
tt_b = (sparse_triangular_solve(A_indices, A_indptr, A_x, b + eps * bt) - p) / eps
tt_Ab = (sparse_triangular_solve(A_indices, A_indptr, A_x + eps * A_xt, b + eps * bt) - p) / eps
tt_AT = (sparse_triangular_solve(A_indices, A_indptr, A_x + eps * A_xt, b, transpose = True) - pT) / eps
tt_bT = (sparse_triangular_solve(A_indices, A_indptr, A_x, b + eps * bt, transpose = True) - pT) / eps
print(f"""
Transpose = False:
Error A varying: {np.linalg.norm(t_A - tt_A): .2e}
Error b varying: {np.linalg.norm(t_b - tt_b): .2e}
Error A and b varying: {np.linalg.norm(t_Ab - tt_Ab): .2e}
Transpose = True:
Error A varying: {np.linalg.norm(t_AT - tt_AT): .2e}
Error b varying: {np.linalg.norm(t_bT - tt_bT): .2e}
""")
```

```
Transpose = False:
Error A varying: 1.08e-07
Error b varying: 0.00e+00
Error A and b varying: 4.19e-07
Transpose = True:
Error A varying: 1.15e-07
Error b varying: 0.00e+00
```

Brilliant! Everything correct within single precision!

Making the numerical implementation work is only half the battle. We also have to make it work *in the context of JAX*.

Now I would be lying if I pretended this process went smoothly. But the first time is for experience. It’s mostly a matter of just reading the documentation carefully and going through similar examples that have already been implemented.

And testing. I learnt how this was supposed to work by testing it.

(For full disclosure, I also wrote a big block f-string in the `sparse_triangular_solve()` function at one point that told me the types, shapes, and what `transpose` was, which was how I worked out that my code was breaking because I forgot the first two `None` outputs in the transposition rule. When in doubt, print shit.)

As you will see from my testing code, I was not going for elegance. I was running the damn permutations. If you’re looking for elegance, look elsewhere.

```
from jax import jvp, grad
from jax import scipy as jsp
def f(theta):
    Ax_theta = jnp.array(A_x)
    Ax_theta = Ax_theta.at[A_indptr[20]].add(theta[0])
    Ax_theta = Ax_theta.at[A_indptr[50]].add(theta[1])
    b = jnp.ones(100)
    return sparse_triangular_solve(A_indices, A_indptr, Ax_theta, b, transpose = True)
def f_jax(theta):
    Ax_theta = jnp.array(sparse.tril(A).todense())
    Ax_theta = Ax_theta.at[20, 20].add(theta[0])
    Ax_theta = Ax_theta.at[50, 50].add(theta[1])
    b = jnp.ones(100)
    return jsp.linalg.solve_triangular(Ax_theta, b, lower = True, trans = "T")
def g(theta):
    Ax_theta = jnp.array(A_x)
    b = jnp.ones(100)
    b = b.at[0].set(theta[0])
    b = b.at[51].set(theta[1])
    return sparse_triangular_solve(A_indices, A_indptr, Ax_theta, b, transpose = True)
def g_jax(theta):
    Ax_theta = jnp.array(sparse.tril(A).todense())
    b = jnp.ones(100)
    b = b.at[0].set(theta[0])
    b = b.at[51].set(theta[1])
    return jsp.linalg.solve_triangular(Ax_theta, b, lower = True, trans = "T")
def h(theta):
    Ax_theta = jnp.array(A_x)
    Ax_theta = Ax_theta.at[A_indptr[20]].add(theta[0])
    b = jnp.ones(100)
    b = b.at[51].set(theta[1])
    return sparse_triangular_solve(A_indices, A_indptr, Ax_theta, b, transpose = False)
def h_jax(theta):
    Ax_theta = jnp.array(sparse.tril(A).todense())
    Ax_theta = Ax_theta.at[20, 20].add(theta[0])
    b = jnp.ones(100)
    b = b.at[51].set(theta[1])
    return jsp.linalg.solve_triangular(Ax_theta, b, lower = True, trans = "N")
def no_diff(theta):
    return sparse_triangular_solve(A_indices, A_indptr, A_x, jnp.ones(100), transpose = False)
def no_diff_jax(theta):
    return jsp.linalg.solve_triangular(jnp.array(sparse.tril(A).todense()), jnp.ones(100), lower = True, trans = "N")
A_indices, A_indptr, A_x, A = make_matrix(10)
primal1, jvp1 = jvp(f, (jnp.array([-142., 342.]),), (jnp.array([1., 2.]),))
primal2, jvp2 = jvp(f_jax, (jnp.array([-142., 342.]),), (jnp.array([1., 2.]),))
grad1 = grad(lambda x: jnp.mean(f(x)))(jnp.array([-142., 342.]))
grad2 = grad(lambda x: jnp.mean(f_jax(x)))(jnp.array([-142., 342.]))
primal3, jvp3 = jvp(g, (jnp.array([-142., 342.]),), (jnp.array([1., 2.]),))
primal4, jvp4 = jvp(g_jax, (jnp.array([-142., 342.]),), (jnp.array([1., 2.]),))
grad3 = grad(lambda x: jnp.mean(g(x)))(jnp.array([-142., 342.]))
grad4 = grad(lambda x: jnp.mean(g_jax(x)))(jnp.array([-142., 342.]))
primal5, jvp5 = jvp(h, (jnp.array([-142., 342.]),), (jnp.array([1., 2.]),))
primal6, jvp6 = jvp(h_jax, (jnp.array([-142., 342.]),), (jnp.array([1., 2.]),))
grad5 = grad(lambda x: jnp.mean(h(x)))(jnp.array([-142., 342.]))
grad6 = grad(lambda x: jnp.mean(h_jax(x)))(jnp.array([-142., 342.]))
primal7, jvp7 = jvp(no_diff, (jnp.array([-142., 342.]),), (jnp.array([1., 2.]),))
primal8, jvp8 = jvp(no_diff_jax, (jnp.array([-142., 342.]),), (jnp.array([1., 2.]),))
grad7 = grad(lambda x: jnp.mean(no_diff(x)))(jnp.array([-142., 342.]))
grad8 = grad(lambda x: jnp.mean(no_diff_jax(x)))(jnp.array([-142., 342.]))
print(f"""
Variable L:
Primal difference: {np.linalg.norm(primal1 - primal2): .2e}
JVP difference: {np.linalg.norm(jvp1 - jvp2): .2e}
Gradient difference: {np.linalg.norm(grad1 - grad2): .2e}
Variable b:
Primal difference: {np.linalg.norm(primal3 - primal4): .2e}
JVP difference: {np.linalg.norm(jvp3 - jvp4): .2e}
Gradient difference: {np.linalg.norm(grad3 - grad4): .2e}
Variable L and b:
Primal difference: {np.linalg.norm(primal5 - primal6): .2e}
JVP difference: {np.linalg.norm(jvp5 - jvp6): .2e}
Gradient difference: {np.linalg.norm(grad5 - grad6): .2e}
No diff:
Primal difference: {np.linalg.norm(primal7 - primal8)}
JVP difference: {np.linalg.norm(jvp7 - jvp8)}
Gradient difference: {np.linalg.norm(grad7 - grad8)}
""")
```

```
Variable L:
Primal difference: 1.98e-07
JVP difference: 2.58e-12
Gradient difference: 0.00e+00
Variable b:
Primal difference: 7.94e-06
JVP difference: 1.83e-08
Gradient difference: 3.29e-10
Variable L and b:
Primal difference: 2.08e-06
JVP difference: 1.08e-08
Gradient difference: 2.33e-10
No diff:
Primal difference: 2.2101993124579167e-07
JVP difference: 0.0
Gradient difference: 0.0
```

Stunning!

Ok. So this is a very similar problem to the one that we just solved. But, as fate would have it, the solution is going to look quite different. Why? Because we need to compute a Cholesky factorisation.

First things first, though, we are going to need a JAX-traceable way to compute a Cholesky factor. This means that we need^{25} to tell our `sparse_solve` function how many non-zeros the sparse Cholesky factor will have. Why? Well. It has to do with how the function is used.

When `sparse_cholesky()` is called with concrete inputs^{26}, it can quite happily work out the sparsity structure of $L$. But when JAX is preparing to transform the code, eg when it's building a gradient, it calls `sparse_cholesky()` using abstract arguments that only share the shape information of the inputs. This is *not* enough to compute the sparsity structure. We *need* the `indices` and `indptr` arrays.

This means that we need `sparse_cholesky()` to throw an error if `L_nse` isn't passed. This wasn't implemented well last time, so here it is done properly.

(If you're wondering about that `None` argument, it is the identity transform. So if `A_indices` is a concrete value, `ind = A_indices`. Otherwise an error is raised.)

```
sparse_cholesky_p = core.Primitive("sparse_cholesky")
def sparse_cholesky(A_indices, A_indptr, A_x, *, L_nse: int = None):
    """A JAX-traceable sparse Cholesky decomposition."""
    if L_nse is None:
        err_string = "You need to pass a value to L_nse when doing fancy sparse_cholesky."
        ind = core.concrete_or_error(None, A_indices, err_string)
        ptr = core.concrete_or_error(None, A_indptr, err_string)
        L_ind, _ = _symbolic_factor(ind, ptr)
        L_nse = len(L_ind)
    return sparse_cholesky_p.bind(A_indices, A_indptr, A_x, L_nse = L_nse)
```

```
@sparse_cholesky_p.def_impl
def sparse_cholesky_impl(A_indices, A_indptr, A_x, *, L_nse):
    """The implementation of the sparse Cholesky. This is not JAX-traceable."""
    L_indices, L_indptr = _symbolic_factor(A_indices, A_indptr)
    if L_nse is not None:
        assert len(L_indices) == L_nse
    L_x = _structured_copy(A_indices, A_indptr, A_x, L_indices, L_indptr)
    L_x = _sparse_cholesky_impl(L_indices, L_indptr, L_x)
    return L_indices, L_indptr, L_x

def _symbolic_factor(A_indices, A_indptr):
    # Assumes A_indices and A_indptr index the lower triangle of A ONLY.
    n = len(A_indptr) - 1
    L_sym = [np.array([], dtype=int) for j in range(n)]
    children = [np.array([], dtype=int) for j in range(n)]
    for j in range(n):
        L_sym[j] = A_indices[A_indptr[j]:A_indptr[j + 1]]
        for child in children[j]:
            tmp = L_sym[child][L_sym[child] > j]
            L_sym[j] = np.unique(np.append(L_sym[j], tmp))
        if len(L_sym[j]) > 1:
            p = L_sym[j][1]
            children[p] = np.append(children[p], j)
    L_indptr = np.zeros(n + 1, dtype=int)
    L_indptr[1:] = np.cumsum([len(x) for x in L_sym])
    L_indices = np.concatenate(L_sym)
    return L_indices, L_indptr

def _structured_copy(A_indices, A_indptr, A_x, L_indices, L_indptr):
    n = len(A_indptr) - 1
    L_x = np.zeros(len(L_indices))
    for j in range(0, n):
        copy_idx = np.nonzero(np.in1d(L_indices[L_indptr[j]:L_indptr[j + 1]],
                                      A_indices[A_indptr[j]:A_indptr[j + 1]]))[0]
        L_x[L_indptr[j] + copy_idx] = A_x[A_indptr[j]:A_indptr[j + 1]]
    return L_x

def _sparse_cholesky_impl(L_indices, L_indptr, L_x):
    n = len(L_indptr) - 1
    descendant = [[] for j in range(0, n)]
    for j in range(0, n):
        tmp = L_x[L_indptr[j]:L_indptr[j + 1]]
        for bebe in descendant[j]:
            k = bebe[0]
            Ljk = L_x[bebe[1]]
            pad = np.nonzero(
                L_indices[L_indptr[k]:L_indptr[k + 1]] == L_indices[L_indptr[j]])[0][0]
            update_idx = np.nonzero(np.in1d(
                L_indices[L_indptr[j]:L_indptr[j + 1]],
                L_indices[(L_indptr[k] + pad):L_indptr[k + 1]]))[0]
            tmp[update_idx] = tmp[update_idx] - \
                Ljk * L_x[(L_indptr[k] + pad):L_indptr[k + 1]]
        diag = np.sqrt(tmp[0])
        L_x[L_indptr[j]] = diag
        L_x[(L_indptr[j] + 1):L_indptr[j + 1]] = tmp[1:] / diag
        for idx in range(L_indptr[j] + 1, L_indptr[j + 1]):
            descendant[L_indices[idx]].append((j, idx))
    return L_x

@sparse_cholesky_p.def_abstract_eval
def sparse_cholesky_abstract_eval(A_indices, A_indptr, A_x, *, L_nse):
    return core.ShapedArray((L_nse,), A_indices.dtype), \
           core.ShapedArray(A_indptr.shape, A_indptr.dtype), \
           core.ShapedArray((L_nse,), A_x.dtype)
```

Ok. So now on to the details. If we try to repeat our previous pattern it would look like this.

```
def sparse_solve_value_and_jvp(arg_values, arg_tangents, *, L_nse):
    """
    JAX-traceable Jacobian-vector product implementation for sparse_solve.
    """
    A_indices, A_indptr, A_x, b = arg_values
    _, _, A_xt, bt = arg_tangents
    # Needed for shared computation
    L_indices, L_indptr, L_x = sparse_cholesky(A_indices, A_indptr, A_x)
    # Make the primal
    primal_out = sparse_triangular_solve(L_indices, L_indptr, L_x, b, transpose = False)
    primal_out = sparse_triangular_solve(L_indices, L_indptr, L_x, primal_out, transpose = True)
    if type(A_xt) is not ad.Zero:
        Delta_lower = jsparse.CSC((A_xt, A_indices, A_indptr), shape = (b.shape[0], b.shape[0]))
        # We need to do Delta @ primal_out, but we only have the lower triangle
        rhs = Delta_lower @ primal_out + Delta_lower.transpose() @ primal_out - A_xt[A_indptr[:-1]] * primal_out
        jvp_Ax = sparse_triangular_solve(L_indices, L_indptr, L_x, rhs)
        jvp_Ax = sparse_triangular_solve(L_indices, L_indptr, L_x, jvp_Ax, transpose = True)
    else:
        jvp_Ax = lax.zeros_like_array(primal_out)
    if type(bt) is not ad.Zero:
        jvp_b = sparse_triangular_solve(L_indices, L_indptr, L_x, bt)
        jvp_b = sparse_triangular_solve(L_indices, L_indptr, L_x, jvp_b, transpose = True)
    else:
        jvp_b = lax.zeros_like_array(primal_out)
    return primal_out, jvp_b - jvp_Ax
```

That’s all well and good. Nothing weird there.

The problem comes when you need to implement the transposition rule. Remembering that the transposition rule for a linear solve needs to solve a system with $A^T$, you might see the issue: we are going to need the Cholesky factorisation. *But we have no way to pass $L$ to the transpose function*.

This means that we would need to compute *two* Cholesky factorisations per gradient instead of one. As the Cholesky factorisation is our slowest operation, we do not want to do extra ones! We want to compute the Cholesky triangle once and pass it around like a party bottom^{27}. We do not want each of our functions to have to make a deep and meaningful connection with the damn matrix^{28}.

So how do we pass around our Cholesky triangle? Well, I do love a good class so my first thought was “fuck it. I’ll make a class and I’ll pass it that way”. But the developers of JAX had a *much* better idea.

Their idea was to abstract the idea of a linear solve and its gradients. They do this through `lax.custom_linear_solve`. This is a function that takes all of the bits that you would need to compute $A^{-1}b$ and all of its derivatives. In particular it takes^{29}:

- `matvec`: A function `matvec(x)` that computes $Ax$. This might seem a bit weird, but abstracting^{30} a matrix to a linear mapping is the most common atrocity committed by mathematicians. So we might as well just suck it up.
- `b`: The right hand side vector^{31}.
- `solve`: A function that takes the `matvec` and a vector so that^{32} `solve(matvec, matvec(x)) == x`.
- `symmetric`: A boolean indicating if $A$ is symmetric.

The idea (happily copped from the implementation of `jax.scipy.linalg.solve`) is to wrap our Cholesky decomposition in the solve function. Through the never-ending miracle of partial evaluation.

```
from functools import partial
def sparse_solve(A_indices, A_indptr, A_x, b, *, L_nse = None):
    """
    A JAX-traceable sparse solve. For the moment, only for vector b.
    """
    assert b.shape[0] == A_indptr.shape[0] - 1
    assert b.ndim == 1
    L_indices, L_indptr, L_x = sparse_cholesky(
        lax.stop_gradient(A_indices),
        lax.stop_gradient(A_indptr),
        lax.stop_gradient(A_x), L_nse = L_nse)
    def chol_solve(L_indices, L_indptr, L_x, b):
        out = sparse_triangular_solve(L_indices, L_indptr, L_x, b, transpose = False)
        return sparse_triangular_solve(L_indices, L_indptr, L_x, out, transpose = True)
    def matmult(A_indices, A_indptr, A_x, b):
        A_lower = jsparse.CSC((A_x, A_indices, A_indptr), shape = (b.shape[0], b.shape[0]))
        return A_lower @ b + A_lower.transpose() @ b - A_x[A_indptr[:-1]] * b
    solver = partial(
        lax.custom_linear_solve,
        lambda x: matmult(A_indices, A_indptr, A_x, x),
        solve = lambda _, x: chol_solve(L_indices, L_indptr, L_x, x),
        symmetric = True)
    return solver(b)
```

There are three things of note in that implementation.

- The calls to `lax.stop_gradient()`: These tell JAX to not bother computing the gradient of these terms. The relevant parts of the derivatives are computed explicitly by `lax.custom_linear_solve` in terms of `matmult` and `solve`, neither of which need the explicit derivative of the Cholesky factorisation.
- That definition of `matmult()`^{33}: Look. I don't know what to tell you. Neither addition nor indexing is implemented for `jsparse.CSC` objects. So we did it the semi-manual way. (I am thankful that matrix-vector multiplication is available.)
- The definition of `solver()`: Partial evaluation is a wonderful, wonderful thing. `functools.partial()` transforms `lax.custom_linear_solve()` from a function that takes 3 arguments (and some keywords) into a function `solver()` that takes one^{34} argument^{35} (`b`, the only positional argument of `lax.custom_linear_solve()` that isn't specified).
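That `matmult()` trick (the stored lower triangle, plus its transpose, minus the double-counted diagonal) is easy to sanity-check in plain scipy; this is my own check, not code from the post.

```python
import numpy as np
from scipy import sparse

rng = np.random.default_rng(1)
M = rng.standard_normal((6, 6))
A = M @ M.T + 6 * np.eye(6)  # small symmetric matrix

# Store only the lower triangle, as the post does.
A_lower = sparse.tril(sparse.csc_matrix(A), format="csc")

b = rng.standard_normal(6)
# Lower part + transposed part double-counts the diagonal, so subtract it once.
full_matvec = A_lower @ b + A_lower.T @ b - A_lower.diagonal() * b
```

(I use `A_lower.diagonal()` here; the post's `A_x[A_indptr[:-1]]` indexing is the same thing when the diagonal entry comes first in each column.)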

```
def f(theta):
    Ax_theta = jnp.array(theta[0] * A_x)
    Ax_theta = Ax_theta.at[A_indptr[:-1]].add(theta[1])
    b = jnp.ones(100)
    return sparse_solve(A_indices, A_indptr, Ax_theta, b)
def f_jax(theta):
    Ax_theta = jnp.array(theta[0] * A.todense())
    Ax_theta = Ax_theta.at[np.arange(100), np.arange(100)].add(theta[1])
    b = jnp.ones(100)
    return jsp.linalg.solve(Ax_theta, b)
def g(theta):
    Ax_theta = jnp.array(A_x)
    b = jnp.ones(100)
    b = b.at[0].set(theta[0])
    b = b.at[51].set(theta[1])
    return sparse_solve(A_indices, A_indptr, Ax_theta, b)
def g_jax(theta):
    Ax_theta = jnp.array(A.todense())
    b = jnp.ones(100)
    b = b.at[0].set(theta[0])
    b = b.at[51].set(theta[1])
    return jsp.linalg.solve(Ax_theta, b)
def h(theta):
    Ax_theta = jnp.array(A_x)
    Ax_theta = Ax_theta.at[A_indptr[:-1]].add(theta[0])
    b = jnp.ones(100)
    b = b.at[51].set(theta[1])
    return sparse_solve(A_indices, A_indptr, Ax_theta, b)
def h_jax(theta):
    Ax_theta = jnp.array(A.todense())
    Ax_theta = Ax_theta.at[np.arange(100), np.arange(100)].add(theta[0])
    b = jnp.ones(100)
    b = b.at[51].set(theta[1])
    return jsp.linalg.solve(Ax_theta, b)
primal1, jvp1 = jvp(f, (jnp.array([2., 3.]),), (jnp.array([1., 2.]),))
primal2, jvp2 = jvp(f_jax, (jnp.array([2., 3.]),), (jnp.array([1., 2.]),))
grad1 = grad(lambda x: jnp.mean(f(x)))(jnp.array([2., 3.]))
grad2 = grad(lambda x: jnp.mean(f_jax(x)))(jnp.array([2., 3.]))
primal3, jvp3 = jvp(g, (jnp.array([-142., 342.]),), (jnp.array([1., 2.]),))
primal4, jvp4 = jvp(g_jax, (jnp.array([-142., 342.]),), (jnp.array([1., 2.]),))
grad3 = grad(lambda x: jnp.mean(g(x)))(jnp.array([-142., 342.]))
grad4 = grad(lambda x: jnp.mean(g_jax(x)))(jnp.array([-142., 342.]))
primal5, jvp5 = jvp(h, (jnp.array([2., 342.]),), (jnp.array([1., 2.]),))
primal6, jvp6 = jvp(h_jax, (jnp.array([2., 342.]),), (jnp.array([1., 2.]),))
grad5 = grad(lambda x: jnp.mean(h(x)))(jnp.array([2., 342.]))
grad6 = grad(lambda x: jnp.mean(h_jax(x)))(jnp.array([2., 342.]))
print(f"""
Check the plumbing!
Variable A:
Primal difference: {np.linalg.norm(primal1 - primal2): .2e}
JVP difference: {np.linalg.norm(jvp1 - jvp2): .2e}
Gradient difference: {np.linalg.norm(grad1 - grad2): .2e}
Variable b:
Primal difference: {np.linalg.norm(primal3 - primal4): .2e}
JVP difference: {np.linalg.norm(jvp3 - jvp4): .2e}
Gradient difference: {np.linalg.norm(grad3 - grad4): .2e}
Variable A and b:
Primal difference: {np.linalg.norm(primal5 - primal6): .2e}
JVP difference: {np.linalg.norm(jvp5 - jvp6): .2e}
Gradient difference: {np.linalg.norm(grad5 - grad6): .2e}
""")
```

```
Check the plumbing!
Variable A:
Primal difference: 1.98e-07
JVP difference: 1.43e-07
Gradient difference: 0.00e+00
Variable b:
Primal difference: 4.56e-06
JVP difference: 6.52e-08
Gradient difference: 9.31e-10
Variable A and b:
Primal difference: 8.10e-06
JVP difference: 1.83e-06
Gradient difference: 1.82e-12
```

Yes.

The other option for making this work would’ve been to implement the Cholesky factorisation as a primitive (~which we are about to do!~ which we will do another day) and then write the sparse solver directly as a pure JAX function.

```
def sparse_solve_direct(A_indices, A_indptr, A_x, b, *, L_nse = None):
    L_indices, L_indptr, L_x = sparse_cholesky(A_indices, A_indptr, A_x, L_nse = L_nse)
    out = sparse_triangular_solve(L_indices, L_indptr, L_x, b)
    return sparse_triangular_solve(L_indices, L_indptr, L_x, out, transpose = True)
```

This function is JAX-traceable^{36} and, therefore, we could compute its gradient directly. It turns out that this is going to be a bad idea.

Why? Because the derivative of `sparse_cholesky`, which we would have to chain together with the derivatives from the solver, is pretty complicated. Basically, this means that we'd have to do a lot more work^{37} than we do if we just implement the symbolic formula for the derivatives.

Ok, so now we get to the good one. The log-determinant of $A$. The first thing that we need to do is wrench out a derivative. This is not as easy as it was for the linear solve. So what follows is a modification for sparse matrices of Appendix A of Boyd's convex optimisation book.

It's pretty easy to convince yourself that
$$
\log\det(A + \epsilon \Delta) = \log\det A + \log\det(I + \epsilon A^{-1}\Delta).
$$
It is harder to convince yourself how this could possibly be a useful fact.

If we write $\lambda_1, \ldots, \lambda_n$ as the eigenvalues of $A^{-1}\Delta$, then we have
$$
\log\det(I + \epsilon A^{-1}\Delta) = \sum_{i=1}^{n} \log(1 + \epsilon \lambda_i).
$$
Remembering that $\epsilon$ is very small, it follows that $\epsilon A^{-1}\Delta$ will *also* be small. That translates to the eigenvalues of $\epsilon A^{-1}\Delta$ all being small. Therefore, we can use the approximation $\log(1 + \epsilon\lambda_i) \approx \epsilon\lambda_i$.

This means that^{38}
$$
\log\det(A + \epsilon\Delta) \approx \log\det A + \epsilon \operatorname{tr}(A^{-1}\Delta),
$$
which follows from the cyclic property of the trace.

If we recall the formula from the last section defining the Jacobian-vector product, in our context the function is $\log\det A$, the input is the vector of non-zero entries of the lower triangle of $A$ stacked by column, and the tangent is the vector of non-zero entries of the lower triangle of $\Delta$. That means the Jacobian-vector product is $\operatorname{tr}(A^{-1}\Delta)$.

Remembering that $\Delta$ is sparse with the same sparsity pattern as $A$, we see that the Jacobian-vector product requires us to know the values of $A^{-1}$ that correspond to non-zero elements of $A$. That's good news because we will see that these entries are relatively cheap and easy to compute, whereas the full inverse is dense and very expensive to compute.
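A quick dense numpy check of my own that the directional derivative really is the trace term: compare $\operatorname{tr}(A^{-1}\Delta)$ with a finite difference of the log-determinant.

```python
import numpy as np

rng = np.random.default_rng(2)
M = rng.standard_normal((8, 8))
A = M @ M.T + 8 * np.eye(8)   # symmetric positive definite
D = rng.standard_normal((8, 8))
Delta = (D + D.T) / 2         # symmetric perturbation direction

logdet = lambda X: np.linalg.slogdet(X)[1]
jvp_exact = np.trace(np.linalg.solve(A, Delta))  # tr(A^{-1} Delta)

# Finite-difference reference for the directional derivative.
eps = 1e-6
jvp_fd = (logdet(A + eps * Delta) - logdet(A)) / eps
```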

But before we get to that, I need to point out a trap for young players^{39}. Lest your implementations go down faster than me when someone asks politely.

The problem comes from how we store our matrix. A mathematician would suggest that it’s our representation. A physicist^{40} would shit on about being coordinate free with such passion that he^{41} will keep going even after you quietly leave the room.

The problem is that we only store the non-zero entries of the lower-triangular part of $A$. This means that *we need to be careful* that when we compute the Jacobian-vector product we properly compute the matrix-vector product.

Let `A_indices` and `A_indptr` define the sparsity structure of $A$ (and $\Delta$). Then if `A_x` is our input and `v` is our tangent vector, we need to do the following steps to compute the Jacobian-vector product:

- Compute `Ainv_x` (aka the non-zero elements of $A^{-1}$ that correspond to the sparsity pattern of $A$)
- Compute the matrix-vector product as `jvp = 2 * sum(Ainv_x * v) - sum(Ainv_x[A_indptr[:-1]] * v[A_indptr[:-1]])`

Why does it look like that? Well, we need to add the contribution from the upper triangle as well as the lower triangle. And one way to do that is to just double the sum and then subtract off the diagonal terms that we've counted twice.

(I'm making a pretty big assumption here, which is fine in our context, that $A$ has a non-zero diagonal. If that doesn't hold, it's just a change of the indexing in the second term to pull out the diagonal terms.)
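Here is a numpy/scipy sketch of my own that works entirely from the flattened lower-triangle storage and checks the doubled sum against the dense trace. (The matrix is dense, so the "sparsity pattern" is every lower-triangle entry, but the indexing is exactly the CSC scheme the post uses, with the diagonal first in each column.)

```python
import numpy as np
from scipy import sparse

rng = np.random.default_rng(3)
n = 6
M = rng.standard_normal((n, n))
A = M @ M.T + n * np.eye(n)   # symmetric positive definite
D = rng.standard_normal((n, n))
Delta = (D + D.T) / 2         # tangent with the same (dense) pattern as A

# Flattened lower-triangle CSC storage for A, Delta, and the inverse of A.
A_indptr = sparse.tril(sparse.csc_matrix(A), format="csc").indptr
v = sparse.tril(sparse.csc_matrix(Delta), format="csc").data
Ainv = np.linalg.inv(A)
Ainv_x = sparse.tril(sparse.csc_matrix(Ainv), format="csc").data

# Double the lower-triangle sum, then un-double the diagonal contribution.
jvp = 2 * np.sum(Ainv_x * v) - np.sum(Ainv_x[A_indptr[:-1]] * v[A_indptr[:-1]])

jvp_dense = np.trace(Ainv @ Delta)  # dense reference: tr(A^{-1} Delta)
```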

Using similar reasoning, we can compute the Jacobian from `Ainv_x`, where `Ainv_x` is the vector that stacks, column by column, the elements of $A^{-1}$ that correspond to the non-zero elements of $A$. (Yikes!)

So now we need to actually work out how to compute this *partial inverse* of a symmetric positive definite matrix . To do this, we are going to steal a technique that goes back to Takahashi, Fagan, and Chen^{42} in 1973. (For this presentation, I’m basically pillaging Håvard Rue and Sara Martino’s 2007 paper.)

Their idea was that if we write $A = LDL^T$, where $L$ is a lower-triangular matrix with ones on the diagonal and $D$ is diagonal. This links up with our usual Cholesky factorisation through the identity $\tilde{L} = LD^{1/2}$. It follows that if $S = A^{-1}$, then $L^T S = D^{-1}L^{-1}$. Then, we make some magic manipulations^{43} to get
$$
S = D^{-1}L^{-1} + (I - L^T)S.
$$

Once again, this does not look super-useful. The trick is to notice 2 things.

Because $L$ is lower triangular, $L^{-1}$ is also lower triangular and the diagonal elements of $L^{-1}$ are the inverses of the diagonal elements of $L$ (aka they are all 1). Therefore, $D^{-1}L^{-1}$ is a lower triangular matrix with a diagonal given by the diagonal of $D^{-1}$.

$I - L^T$ is an upper triangular matrix and its diagonal is zero.

These two things together lead to the somewhat unexpected situation where the upper triangle of that identity defines a set of recursions for the upper triangle of $S$. (And, therefore, all of $S$ because $S$ is symmetric!) These are sometimes referred to as the Takahashi recursions.

But we don't want the whole upper triangle of $S$, we just want the entries that correspond to the non-zero elements of $A$. Unfortunately, the set of recursions is not, in general, solvable using only that subset of $S$. But we are in luck: they are solvable using the elements of $S$ that correspond to the non-zeros of $L$, which, as we know from a few posts ago, is a superset of the non-zero elements of $A$!
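The identity these manipulations produce, $S = D^{-1}L^{-1} + (I - L^T)S$ with $S = A^{-1}$, can be verified numerically for a small dense matrix. This is my own sketch; it peels the unit-diagonal $L$ and the diagonal $D$ out of a standard Cholesky factor.

```python
import numpy as np

rng = np.random.default_rng(4)
n = 5
M = rng.standard_normal((n, n))
A = M @ M.T + n * np.eye(n)

# A = C C^T, then A = L D L^T with L unit lower triangular and D = diag(d).
C = np.linalg.cholesky(A)
d = np.diag(C) ** 2
L = C / np.diag(C)            # divide each column by its diagonal entry

S = np.linalg.inv(A)
identity_rhs = np.diag(1 / d) @ np.linalg.inv(L) + (np.eye(n) - L.T) @ S
```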

From this, we get the recursions, running from $j = n$ down to $j = 1$ and, within each column, $i = n$ down to $i = j$ (the order is important!), such that
$$
S_{ij} = \frac{\delta_{ij}}{D_{jj}} - \sum_{k > j} L_{kj} S_{ki}, \qquad i \geq j.
$$

If you recall our discussion way back when about the way the non-zero structure of the $j$th column of $L$ relates to the non-zero structure of the $k$th column for $k > j$, it's clear that we have computed enough^{44} of $S$ at every step to complete the recursions.
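Here is a dense version of those recursions (my own sketch; a real sparse implementation only visits the non-zeros of $L$), run from the bottom-right corner backwards and checked against the full inverse:

```python
import numpy as np

rng = np.random.default_rng(5)
n = 5
M = rng.standard_normal((n, n))
A = M @ M.T + n * np.eye(n)

C = np.linalg.cholesky(A)
d = np.diag(C) ** 2
L = C / np.diag(C)            # unit lower triangular, A = L diag(d) L^T

# Takahashi recursions: columns right-to-left, rows bottom-to-top.
S = np.zeros((n, n))
for j in range(n - 1, -1, -1):
    for i in range(j, -1, -1):
        S[i, j] = (1.0 / d[i] if i == j else 0.0) - L[i + 1:, i] @ S[i + 1:, j]
        S[j, i] = S[i, j]     # symmetry fills in the lower triangle
```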

Now we just need to Python it. (And thanks to Finn Lindgren who helped me understand how to implement this, which he may or may not remember because it happened about five years ago.)

Actually, we need this to be JAX-traceable, so we are going to implement a very basic primitive. In particular, we don’t need to implement a derivative or anything like that, just an abstract evaluation and an implementation.

```
sparse_partial_inverse_p = core.Primitive("sparse_partial_inverse")
def sparse_partial_inverse(L_indices, L_indptr, L_x, out_indices, out_indptr):
    """
    Computes the elements (out_indices, out_indptr) of the inverse of a sparse matrix
    (A_indices, A_indptr, A_x) with Cholesky factor (L_indices, L_indptr, L_x).
    (out_indices, out_indptr) is assumed to be either the sparsity pattern of A or
    a subset of it in lower triangular form.
    """
    return sparse_partial_inverse_p.bind(L_indices, L_indptr, L_x, out_indices, out_indptr)

@sparse_partial_inverse_p.def_abstract_eval
def sparse_partial_inverse_abstract_eval(L_indices, L_indptr, L_x, out_indices, out_indptr):
    return core.ShapedArray(out_indices.shape, L_x.dtype)

@sparse_partial_inverse_p.def_impl
def sparse_partial_inverse_impl(L_indices, L_indptr, L_x, out_indices, out_indptr):
    n = len(L_indptr) - 1
    Linv = sparse.dok_array((n, n), dtype = L_x.dtype)
    counter = len(L_x) - 1
    for col in range(n - 1, -1, -1):
        for row in L_indices[L_indptr[col]:L_indptr[col + 1]][::-1]:
```