Sparse Matrices 4: Design is my passion

Just some harmless notes. Like the ones Judi Dench took in that movie.

Sparse matrices
Sparse Cholesky factorisation

May 16, 2022

This is the fourth post in a series where I try to squeeze autodiffable sparse matrices into JAX with the aim of speeding up some model classes in PyMC. So far, I have:

I am in the process of writing a blog on building new primitives1 into JAX, but as I was doing it I accidentally wrote a long section about options for exposing sparse matrices. It really didn’t fit very well into that blog, so here it is.

What are we trying to do here?

If you recall from the first blog, we need to be able to compute the value and gradients of the (un-normalised) log-posterior \[ \log(p(\theta \mid y)) = \frac{1}{2} \mu_{u\mid y, \theta}(\theta)^TA^TW^{-1}y + \frac{1}{2} \log(|Q(\theta)|) - \frac{1}{2}\log(|Q_{u\mid y, \theta}(\theta)|) + \text{const}, \] where \(Q(\theta)\) is a sparse matrix, and \[ \mu_{u\mid y, \theta}(\theta) = \frac{1}{\sigma^2} Q_{u\mid y,\theta}(\theta)^{-1} A^TW^{-1}y. \]

Overall, our task is to design a system where this un-normalised log-posterior can be evaluated and differentiated efficiently. As with all design problems, there are a lot of different ways that we can implement it. They share a bunch of similarities, so we will actually end up implementing the guts of all of the systems.

To that end, let’s think of all of the ways we can implement our target2.

Option 1: The direct design

  • \(A \rightarrow \log(|A|)\), for a sparse, symmetric positive definite matrix \(A\)
  • \((A,b) \rightarrow A^{-1}b\), for a sparse, symmetric positive definite matrix \(A\) and a vector \(b\)

This option is, in some sense, the most straightforward. We implement primitives for both of the major components of our target and combine them using existing JAX primitives (like addition, scalar multiplication, and dot products).

This is a bad idea.

The problem is that both primitives require the Cholesky decomposition of \(A\), so if we take this route we might end up computing an extra Cholesky decomposition. And you may ask yourself: what’s an extra Cholesky decomposition between friends?

Well, Jonathan, it’s the most expensive operation we are doing for these models, so perhaps we should avoid the 1/3 increase in running time!
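To make the duplication concrete, here is a minimal sketch of the direct design, assuming dense JAX arrays stand in for the sparse matrix. The function names `logdet_spd` and `solve_spd` are hypothetical and are not part of JAX; the point is only that each one factorises `A` on its own.

```python
import jax.numpy as jnp
from jax.scipy.linalg import cho_solve

def logdet_spd(A):
    # log|A| = 2 * sum(log(diag(L))) where A = L L^T
    L = jnp.linalg.cholesky(A)           # first factorisation
    return 2.0 * jnp.sum(jnp.log(jnp.diag(L)))

def solve_spd(A, b):
    L = jnp.linalg.cholesky(A)           # second factorisation of the same matrix
    return cho_solve((L, True), b)       # (L, True) marks L as lower-triangular

A = jnp.array([[4.0, 1.0], [1.0, 3.0]])
b = jnp.array([1.0, 2.0])

# Combining the two pieces with ordinary JAX operations
value = logdet_spd(A) + b @ solve_spd(A, b)
```

Written this way, nothing in the code shares the factor between the two calls.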

There are some ways around this. We might implement sparse, symmetric positive definite matrices as a class that, upon instantiation, computes the Cholesky factorisation.

class SPDError(Exception):
  """Raised when a matrix is not symmetric positive definite."""

class SPDSparse:
  def __init__(self, A_indices, A_indptr, A_x):
    self._perm, self._iperm = self._find_perm(A_indices, A_indptr)
    self._A_indices, self._A_indptr, self._A_x = self._twist(self._perm, A_indices, A_indptr, A_x)
    try:
      self._L_indices, self._L_indptr, self._L_x = self._compute_cholesky()
    except SPDError:
      print("Matrix is not symmetric positive definite to machine precision.")
  def _find_perm(self, indices, indptr):
    """Finds the best fill-reducing permutation"""
    raise NotImplementedError("_find_perm")
  def _twist(self, perm, indices, indptr, x):
    """Returns A[perm, perm]"""
    raise NotImplementedError("_twist")
  def _compute_cholesky(self):
    """Compute the Cholesky decomposition of the permuted matrix"""
    raise NotImplementedError("_compute_cholesky")
  # Not pictured: a whole forest of gets

In contexts where we need a Cholesky decomposition of every SPD matrix we instantiate, this design might be useful. It might also be useful to write a constructor that takes a jax.experimental.sparse.CSC matrix, so that we could build a differentiable matrix and then just absolutely slam it into our filthy little Cholesky context3.

In order to use this type of pattern with JAX, we would need to register it as a Pytree class, which involves writing flatten and unflatten routines. The CSC class in jax.experimental.sparse is a good example of how to implement this type of thing. Some care would be needed to make sure the differentiation rules don’t try to do something stupid like differentiate with respect to self._iperm or self._L_x. This is beyond the extra autodiff sugar in the experimental sparse library.
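As a rough illustration of the flatten/unflatten idea, here is a minimal sketch using `jax.tree_util.register_pytree_node_class`. The class `SPDValues` is hypothetical and much smaller than the class above: it assumes the only thing we ever differentiate is the array of non-zero values, and that the index structure can be stored as plain tuples so JAX treats it as static auxiliary data.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class

@register_pytree_node_class
class SPDValues:
    """Hypothetical container: only the non-zero values are a pytree leaf."""
    def __init__(self, x, indices, indptr):
        self.x = x                # non-zero values (the only differentiable part)
        self.indices = indices    # static structure, stored as hashable tuples
        self.indptr = indptr

    def tree_flatten(self):
        # children are traced and differentiated; aux data is treated as static
        return (self.x,), (self.indices, self.indptr)

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(children[0], *aux)

def sum_of_squares(A):
    # Stand-in for a real computation on the matrix values
    return jnp.sum(A.x ** 2)

# Lower triangle of a 2x2 matrix in CSC layout
A = SPDValues(jnp.array([2.0, 1.0, 3.0]), indices=(0, 1, 1), indptr=(0, 2, 3))
g = jax.grad(sum_of_squares)(A)   # g is an SPDValues holding the gradient in g.x
```

Anything placed in the auxiliary data, such as a permutation, never shows up as something to differentiate with respect to. Whether that scales gracefully to a stored Cholesky factor is exactly the delicate part.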

Implementing this would be quite an undertaking, but it’s certainly an option. The most obvious downside of this pattern (plus a fully functional sparse matrix class) is that it may end up being quite delicate to have this volume of auxiliary information4 in a pytree while making everything differentiate properly. This doesn’t seem to be how most parts of JAX have been built. There are also a couple of sharp corners we could run into with instantiation.

To close this out, it’s worth noting a variation on this pattern that comes up: the optional Cholesky. The idea is that rather than compute the permutations and the Cholesky factorisation on initialisation, we store a boolean flag is_cholesky in the class. Whenever we need a Cholesky factor we check the flag: if it’s True we use the stored factor; otherwise we compute it and set is_cholesky = True.

This pattern introduces state to the object: it is no longer set and forget. This will not work within JAX5, where objects need to be immutable. It’s also not an exceptional pattern in general: it is considerably easier to debug code with stateless objects.
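For concreteness, the optional Cholesky pattern looks something like the sketch below. It is written with dense NumPy rather than JAX, and the class name `LazyCholesky` is hypothetical. The thing to notice is that calling `cholesky()` changes the object.

```python
import numpy as np

class LazyCholesky:
    """Hypothetical dense stand-in for the optional-Cholesky pattern."""
    def __init__(self, A):
        self.A = A
        self.is_cholesky = False   # no factor computed yet
        self._L = None

    def cholesky(self):
        if not self.is_cholesky:
            self._L = np.linalg.cholesky(self.A)
            self.is_cholesky = True   # the object mutates here
        return self._L

m = LazyCholesky(np.array([[4.0, 1.0], [1.0, 3.0]]))
before = m.is_cholesky
L_first = m.cholesky()    # computes and stores the factor
L_second = m.cholesky()   # reuses the stored factor
```

The second call is cheap, but the same object now behaves differently depending on its history, which is the statefulness described above.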

Option 2: Implement all of the combinations of functions that we need

Rather than dicking around with classes, we could just implement primitives that compute

  • \(A \rightarrow \log(|A|)\), for a sparse, symmetric positive definite matrix \(A\)
  • \((A,b, c) \rightarrow \log(|A|) + c^TA^{-1}b\), for a sparse, symmetric positive definite matrix \(A\) and vectors \(b\) and \(c\).

This is exactly what we need to do our task and nothing more. It won’t result in any unnecessary Cholesky factors. It doesn’t need us to store computed Cholesky factors. We can simply eat, prey, love.

The obvious downside to this option is it’s going to just massively expand the codebase if there are more things that we want to do. It’s also not obvious why we would do this instead of just making \(\log p(\theta \mid y)\) a primitive6.
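Here is a minimal dense sketch of what the second function would compute, assuming dense JAX arrays in place of a sparse matrix. The name `logdet_plus_quadform` is hypothetical, and this is an ordinary function rather than a registered JAX primitive; it only shows that one Cholesky factor is enough for both terms.

```python
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

def logdet_plus_quadform(A, b, c):
    """log|A| + c^T A^{-1} b from a single Cholesky factor (dense stand-in)."""
    L = jnp.linalg.cholesky(A)                      # the only factorisation
    logdet = 2.0 * jnp.sum(jnp.log(jnp.diag(L)))
    Linv_b = solve_triangular(L, b, lower=True)     # L^{-1} b
    Linv_c = solve_triangular(L, c, lower=True)     # L^{-1} c
    # c^T A^{-1} b = (L^{-1} c)^T (L^{-1} b) because A = L L^T
    return logdet + Linv_c @ Linv_b

A = jnp.array([[4.0, 1.0], [1.0, 3.0]])
b = jnp.array([1.0, 2.0])
c = jnp.array([0.0, 1.0])
value = logdet_plus_quadform(A, b, c)
```

The factor `L` lives and dies inside the function, so nothing has to be stored between calls.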

Option 3: Just compute the Cholesky

Our third option is to simply compute (and differentiate) the Cholesky factor directly. We can then compute \(\log(|A|)\) and \(A^{-1}b\) through a combination of differentiable operations on the elements of the Cholesky factor (for \(\log(|A|)\)) and triangular linear solves \(L^{-1}b\) and \(L^{-T}c\) (for \(A^{-1}b\)).

Hence we require the following two7 JAX primitives:

  • \(A \rightarrow \operatorname{chol}(A)\), where \(\operatorname{chol}(A)\) is the Cholesky factor of \(A\),
  • \((L, b) \rightarrow L^{-1} b\) and \((L, b) \rightarrow L^{-T}b\) for lower-triangular sparse matrix \(L\).

This is pretty close to how the dense version of this function would be implemented.
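For reference, the dense version can be sketched with functions that already exist in JAX. The parameterisation `A(theta) = theta * I + B` is a made-up example chosen so the gradient is easy to check by hand; it is not the model from this series.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

# Fixed SPD matrix used by the hypothetical parameterisation below
B = jnp.array([[2.0, 1.0], [1.0, 2.0]])

def target(theta, b):
    A = theta * jnp.eye(2) + B                      # hypothetical A(theta)
    L = jnp.linalg.cholesky(A)                      # A -> chol(A)
    logdet = 2.0 * jnp.sum(jnp.log(jnp.diag(L)))    # elementwise ops on L
    y = solve_triangular(L, b, lower=True)          # L^{-1} b
    x = solve_triangular(L, y, lower=True, trans=1) # L^{-T} y = A^{-1} b
    return logdet + b @ x

b = jnp.array([1.0, 0.0])
value = target(1.0, b)
grad_theta = jax.grad(target)(1.0, b)   # differentiates through the Cholesky
```

For this example the derivative is tr(A^{-1}) - b^T A^{-2} b, so the result of `jax.grad` can be checked against a hand calculation.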

There are two little challenges with this pattern:

  1. We are adding another large-ish node \(L\) to our autodiff tree. As we saw in other patterns, this is unnecessary storage for our problem at hand.

  2. The number of non-zeros in \(L\) is a function of the non-zero pattern of \(A\). This means the Cholesky will need to be implemented very carefully to ensure that it’s traceable enough.

The second point here might actually be an issue. To be honest, I have no idea. I think maybe it’s fine? But I need to do a close read on the adding primitives doc. Essentially, as long as the abstract traces just need shapes but not dimensions, we should be ok.
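One thing that may help here (an assumption on my part, not a confirmed design) is that the number of non-zeros in \(L\) depends only on the pattern of \(A\), not on its values, so it can be worked out with plain NumPy before anything is traced. The sketch below does a naive symbolic elimination on a dense boolean pattern. The function name `symbolic_cholesky_nnz` is hypothetical, and a real implementation would work on the compressed structure rather than a dense array.

```python
import numpy as np

def symbolic_cholesky_nnz(pattern):
    """Count the non-zeros of the Cholesky factor L from the pattern of A.

    Uses only the boolean pattern and assumes no numerical cancellation.
    """
    S = np.tril(np.asarray(pattern, dtype=bool)).copy()
    n = S.shape[0]
    for k in range(n):
        # rows with a non-zero below the diagonal in column k
        rows = np.flatnonzero(S[k + 1:, k]) + k + 1
        for j in rows:
            # eliminating column k couples every pair of those rows (fill-in)
            S[rows[rows >= j], j] = True
    return int(S.sum())

n = 4
arrow_first = np.eye(n, dtype=bool)
arrow_first[0, :] = True
arrow_first[:, 0] = True      # dense first row and column

arrow_last = np.eye(n, dtype=bool)
arrow_last[-1, :] = True
arrow_last[:, -1] = True      # same matrix with the ordering reversed

nnz_first = symbolic_cholesky_nnz(arrow_first)   # complete fill-in
nnz_last = symbolic_cholesky_nnz(arrow_last)     # no fill-in
```

The two arrow matrices differ only by a permutation, which is a small reminder of why the fill-reducing permutation matters.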

For adding this to something like Stan, however, we will likely need to do some extra work to make sure we know the number of parameters.

The advantage of this type of design pattern is that it gives users the flexibility to do whatever perverted thing they want to do with the Cholesky triangle. For example, they might want to do a centring/non-centring transformation. In Option 1, we would need to write explicit functions to let them do that (not difficult, but there’s a lot of code to write, which has the annoying tendency to increase the maintenance burden).
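As one example of what a user might do with the factor, here is a dense sketch of a non-centred parameterisation. It assumes the Gaussian is specified by its precision matrix \(Q = LL^T\), so that \(u = L^{-T}z\) has covariance \(Q^{-1}\) when \(z\) is standard normal. The function name `non_centred` is hypothetical.

```python
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

def non_centred(Q, z):
    """Map standard-normal z to u = L^{-T} z, where Q = L L^T (dense stand-in)."""
    L = jnp.linalg.cholesky(Q)
    # trans=1 solves L^T u = z using the lower-triangular factor L
    return solve_triangular(L, z, lower=True, trans=1)

Q = jnp.array([[4.0, 1.0], [1.0, 3.0]])
z = jnp.array([1.0, -1.0])
u = non_centred(Q, z)
```

With direct access to `L` this is a single triangular solve; without it, the transformation would need its own dedicated function.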

Option 4: Functors!

A slightly wilder design pattern would be to abandon sparse matrices and just make functions A(theta, ...) that return a sparse matrix. If that function is differentiable wrt its first argument, then we can build this whole thing up that way.

In reality, the only way I can think of to implement this pattern would be to implement a whole differentiable sparse matrix arithmetic (make operations like alpha * A + beta * B, C * D work for sparse matrices). At which point, we’ve basically just recreated option 1.
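A tiny sketch of the idea, assuming every matrix involved shares one fixed sparsity pattern so that a "sparse matrix" is just its array of non-zero values. The function `Q_values` and the arrays `I_x` and `G_x` are made up for illustration.

```python
import jax
import jax.numpy as jnp

# Shared 3x3 tridiagonal pattern in CSC order:
# column 0 holds rows (0, 1), column 1 holds rows (0, 1, 2), column 2 holds rows (1, 2)
I_x = jnp.array([1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0])      # identity on that pattern
G_x = jnp.array([2.0, -1.0, -1.0, 2.0, -1.0, -1.0, 2.0])  # second-difference matrix

def Q_values(theta):
    """Non-zero values of Q(theta) = alpha * I + beta * G on the shared pattern."""
    alpha, beta = theta
    return alpha * I_x + beta * G_x

theta = jnp.array([0.5, 2.0])
q = Q_values(theta)
J = jax.jacobian(Q_values)(theta)   # shape (7, 2): d(values) / d(theta)
```

This only works so painlessly because the pattern is fixed and shared; general sparse arithmetic would need to track how patterns combine.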

I’m really only bringing up functors because, unlike sparse matrices, they are actually a pretty good model for implementing Gaussian processes with general covariance functions. There’s a little bit of the idea in this Stan issue that, to my knowledge, hasn’t gone anywhere. More recently, a variant has been used successfully in the (as yet un-merged) Laplace approximation feature in Stan.

Which one should we use?

We don’t really need to make that choice yet. So we won’t.

But personally, I like option 1. I expect everyone else on earth would prefer option 3. For densities that see a lot of action, it would make quite a bit of sense to consider making that density a primitive when it has a complex derivative (à la option 2).

But for now, let’s park this and start getting in on the implementations.


  1. functions that have explicit transformations written for them (eg explicit instruction on how to JIT or how to differentiate)↩︎

  2. I get sick of typing “unnormalised log-posterior”↩︎

  3. I am sorry. I have had some wine.↩︎

  4. Permutations, Cholesky, etc↩︎

  5. This also won’t work in Stan, because all Stan objects are stateless.↩︎

  6. This is actually what Stan has done for a bunch of its GLM-type models. It’s very efficient and fast. But with a maintenance burden.↩︎

  7. or three, but you can implement both triangular solves in one function↩︎



BibTeX citation:
  author = {Simpson, Dan},
  title = {Sparse {Matrices} 4: {Design} Is My Passion},
  date = {2022-05-16},
  url = {},
  langid = {en}
For attribution, please cite this work as:
Simpson, Dan. 2022. “Sparse Matrices 4: Design Is My Passion.” May 16, 2022.