Sparse Matrices 3: Failing at JAX

Takes a long drag on cigarette. JAX? Where was he when I had my cancer?

Sparse matrices
Sparse Cholesky factorisation

May 14, 2022

This is part three of an ongoing exercise in hubris. Part one is here. Part two is here. The overall aim of this series of posts is to look at how sparse Cholesky factorisations work, how JAX works, and how to marry the two with the ultimate aim of putting a bit of sparse matrix support into PyMC, which should allow for faster inference in linear mixed models and Gaussian spatial models. And hopefully, if anyone ever gets around to putting the Laplace approximation in, all sorts of GLMMs and non-Gaussian models with splines and spatial effects.

It’s been a couple of weeks since the last blog, but I’m going to just assume that you are fully on top of all of those details. To that end, let’s jump in.

What is JAX?

JAX is a minor miracle. It will take python+numpy code and make it cool. It will let you JIT1 compile it! It will let you differentiate it! It will let you batch2. JAX refers to these three operations as transformations.
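
As a rough sketch of what those three transformations look like in practice (the function sum_of_squares is just a toy I made up for illustration, not anything from this series):

```python
import jax
import jax.numpy as jnp

def sum_of_squares(x):
    # Plain numpy-style code: nothing here is specific to any transformation.
    return jnp.sum(x ** 2)

fast = jax.jit(sum_of_squares)      # JIT compile it
dfdx = jax.grad(sum_of_squares)     # differentiate it
batched = jax.vmap(sum_of_squares)  # batch it over the leading axis

x = jnp.array([1.0, 2.0, 3.0])
value = fast(x)                                   # 1 + 4 + 9
gradient = dfdx(x)                                # 2 * x
batch_values = batched(jnp.stack([x, 2.0 * x]))   # one value per row
```

The transformations also compose, so something like jax.jit(jax.grad(sum_of_squares)) is fair game.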

But, as The Mountain Goats tell us, God is present in the sweeping gesture, but the devil is in the details. And oh boy are those details going to be really fucking important to us.

There are going to be two key things that will make our lives more difficult:

  1. Not every operation can be transformed by every transformation. For example, you can’t always JIT or take gradients of a for loop. This means that some things have to be re-written carefully to make sure it’s possible to get the advantages we need.

  2. JAX arrays are immutable. That means that once an array is created its entries cannot be changed. This means that things like a[0] = a[0] + 1 are not allowed! If you’ve come from an R/Python/C/Fortran world, this is the weirdest thing to deal with.

There are really excellent reasons for both of these restrictions. And looking into the reasons is fascinating. But that is not a topic for this blog3.

JAX has some pretty decent4 documentation, a core piece of which outlines some of the sharp edges you will run into. As you read through the documentation, the design choices become clearer.

So let’s go and find some sharp edges together!

To JAX or not to JAX

But first, we need to ask ourselves which functions do we need to JAX?

In the context of our problem we, so far, have three functions:

  1. _symbolic_factor_csc(A_indices, A_indptr), which finds the non-zero indices of the sparse Cholesky factor and returns them in CSC format,
  2. _deep_copy_csc(A_indices, A_indptr, A_x, L_indices, L_indptr), which takes the entries of the matrix \(A\) and re-creates them so they can be indexed within the larger pattern of non-zero elements of \(L\),
  3. _sparse_cholesky_csc_impl(L_indices, L_indptr, L_x), which actually does the sparse Cholesky factorisation.

Let’s take them piece by piece, which is also a good opportunity to remind everyone what the code looked like.

Symbolic factorisation

def _symbolic_factor_csc(A_indices, A_indptr):
  # Assumes A_indices and A_indptr index the lower triangle of $A$ ONLY.
  n = len(A_indptr) - 1
  L_sym = [np.array([], dtype=int) for j in range(n)]
  children = [np.array([], dtype=int) for j in range(n)]
  for j in range(n):
    L_sym[j] = A_indices[A_indptr[j]:A_indptr[j + 1]]
    for child in children[j]:
      tmp = L_sym[child][L_sym[child] > j]
      L_sym[j] = np.unique(np.append(L_sym[j], tmp))
    if len(L_sym[j]) > 1:
      p = L_sym[j][1]
      children[p] = np.append(children[p], j)
  L_indptr = np.zeros(n+1, dtype=int)
  L_indptr[1:] = np.cumsum([len(x) for x in L_sym])
  L_indices = np.concatenate(L_sym)
  return L_indices, L_indptr

This function only needs to be computed once per non-zero pattern. In the applications I outlined in the first post, this non-zero pattern is fixed. This means that you only need to run this function once per analysis (unlike the others, which you will have to run once per iteration!).

As a general rule, if you only do something once, it isn’t all that necessary to devote too much time to optimising it. There are, however, some obvious things we could do.
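
To make the reminder a bit more concrete, here is a small sanity check. The 4 x 4 "arrow" matrix is my own toy example (it is not from the earlier posts), chosen because its Cholesky factor fills in completely. The function is restated so the snippet runs on its own, and the result is compared against the non-zero pattern of a dense Cholesky factor.

```python
import numpy as np

def _symbolic_factor_csc(A_indices, A_indptr):
    # Assumes A_indices and A_indptr index the lower triangle of A ONLY.
    n = len(A_indptr) - 1
    L_sym = [np.array([], dtype=int) for j in range(n)]
    children = [np.array([], dtype=int) for j in range(n)]
    for j in range(n):
        L_sym[j] = A_indices[A_indptr[j]:A_indptr[j + 1]]
        for child in children[j]:
            tmp = L_sym[child][L_sym[child] > j]
            L_sym[j] = np.unique(np.append(L_sym[j], tmp))
        if len(L_sym[j]) > 1:
            p = L_sym[j][1]
            children[p] = np.append(children[p], j)
    L_indptr = np.zeros(n + 1, dtype=int)
    L_indptr[1:] = np.cumsum([len(x) for x in L_sym])
    L_indices = np.concatenate(L_sym)
    return L_indices, L_indptr

# Hypothetical test matrix: dense first row/column, diagonal elsewhere.
A = np.array([[4.0, 1.0, 1.0, 1.0],
              [1.0, 4.0, 0.0, 0.0],
              [1.0, 0.0, 4.0, 0.0],
              [1.0, 0.0, 0.0, 4.0]])

# Lower triangle of A in CSC format, written out by hand.
A_indices = np.array([0, 1, 2, 3, 1, 2, 3])
A_indptr = np.array([0, 4, 5, 6, 7])

L_indices, L_indptr = _symbolic_factor_csc(A_indices, A_indptr)

# Non-zero rows of each column of the dense Cholesky factor, for comparison.
L_dense = np.linalg.cholesky(A)
dense_pattern = np.concatenate([np.nonzero(L_dense[:, j])[0] for j in range(4)])
```

Here A has 7 non-zeros in its lower triangle while L has 10, so the output arrays are longer than the inputs by an amount you can only know after looking at the values in A_indices.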

It is, for instance, pretty easy to see how you would implement this with an explicit tree5 structure instead of constantly np.appending the children array. This is far better from a memory standpoint.

It’s also easy to imagine this as a two-pass algorithm, where you build the tree and count the number of non-zero elements in the first pass and then build and populate L_indices in the second pass.

The thing is, neither of these things fixes the core problem for using JAX to JIT this: the dimensions of the internal arrays depend on the values of the inputs. JAX cannot JIT a function like that.
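
A minimal sketch of the failure mode, using a made-up function (positive_entries is mine, not part of the Cholesky code) whose output length depends on the values of its input:

```python
import jax
import jax.numpy as jnp

def positive_entries(x):
    # The length of the result depends on the *values* stored in x.
    return x[jnp.nonzero(x > 0)[0]]

x = jnp.array([-1.0, 2.0, -3.0, 4.0])

# Outside of jit the values are known, so this is fine.
eager = positive_entries(x)

# Under jit only shapes and dtypes are known while tracing, so JAX gives up.
try:
    jax.jit(positive_entries)(x)
    jit_failed = False
except jax.errors.ConcretizationTypeError:
    jit_failed = True
```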

It seems like this would be a huge limitation, but in reality it isn’t. Most functions aren’t like this one! And, if we remember that JAX is a domain-specific language focussing mainly on ML applications, value-dependent shapes like this come up very rarely. It is always good to remember context!

So what are our options? We have two.

  1. Leave it in Python and just eat the speed.
  2. Build a new JAX primitive and write the XLA compilation rule6.

Today we are opting for the first option!

The structure-changing copy

def _deep_copy_csc(A_indices, A_indptr, A_x, L_indices, L_indptr):
  n = len(A_indptr) - 1
  L_x = np.zeros(len(L_indices))
  for j in range(0, n):
    copy_idx = np.nonzero(np.in1d(L_indices[L_indptr[j]:L_indptr[j + 1]],
                                  A_indices[A_indptr[j]:A_indptr[j+1]]))[0]
    L_x[L_indptr[j] + copy_idx] = A_x[A_indptr[j]:A_indptr[j+1]]
  return L_x

This is, fundamentally, a piece of bookkeeping. An annoyance of sparse matrices. Or, if you will, an explicit cast between different sparse matrix types7. This is a thing that we do actually need to be able to differentiate, so it needs to live in JAX.

So where are the potential problems? Let’s go line by line.

  1. n = len(A_indptr) - 1: This is lovely. n is used in a for loop later, but because it is a function of the shape of A_indptr, it is considered static and we will be able to JIT over it!

  2. L_x = np.zeros(len(L_indices)): Again, this is fine. Sizes are derived from shapes, life is peachy.

  3. for j in range(0, n):: This could be a problem if n was an argument or derived from values of the arguments, but it’s derived from a shape so it is static. Praise be! Well, actually it’s a bit more involved than that.

The problem with the for loop is what will happen when it is JIT’d. Essentially, the loop will be statically unrolled8. That is fine for small loops, but it’s a bit of a pain in the arse when n is large.

In this case, we might want to use the structured control flow in jax.lax9. Specifically, we would need jax.lax.fori_loop(start, end, body_fun, init_value). This makes the code look less pythonic, but it should probably make it faster. It is also, and I cannot stress this enough, an absolute dick to use.

(In actuality, we will see that we do not need this particular corner of the language here!)
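
For the curious, here is a toy comparison of the two styles. It is only a sketch on a made-up summation (nothing to do with the copy function), and it uses jax.make_jaxpr to count how many operations end up in the traced program.

```python
import jax
import jax.numpy as jnp

def total_python(x):
    # A Python loop is unrolled while tracing: every iteration adds its
    # own indexing and addition operations to the traced program.
    total = jnp.zeros((), dtype=x.dtype)
    for j in range(x.shape[0]):
        total = total + x[j]
    return total

def total_lax(x):
    # body_fun takes the loop index and the carried value, and returns
    # the new carried value.
    def body_fun(j, total):
        return total + x[j]
    init = jnp.zeros((), dtype=x.dtype)
    return jax.lax.fori_loop(0, x.shape[0], body_fun, init)

x = jnp.arange(100.0)

# Number of top-level operations in each traced program.
n_eqns_python = len(jax.make_jaxpr(total_python)(x).jaxpr.eqns)
n_eqns_lax = len(jax.make_jaxpr(total_lax)(x).jaxpr.eqns)
```

Both functions return the same number, but the Python loop version has a traced program that grows with the length of x, while the fori_loop version keeps the loop as a single structured operation.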

  4. copy_idx = np.nonzero(...): This looks like it’s going to be complicated, but actually it is a perfectly reasonable composition of numpy functions. Hence, we can use the same jax.numpy functions with minimal changes. The one change that we are going to need to make in order to end up with a JIT-able and differentiable function is that we need to tell JAX how many non-zero elements there are. Thankfully, we know this! Because the non-zero pattern of \(A\) is a subset of the non-zero pattern of \(L\), we know that
np.in1d(L_indices[L_indptr[j]:L_indptr[j + 1]], A_indices[A_indptr[j]:A_indptr[j+1]])

will have exactly len(A_indices[A_indptr[j]:A_indptr[j+1]]) True values, and so np.nonzero(...) will have that many. We can pass this information to jnp.nonzero() using the optional size argument.
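
A small sketch of the size argument in action, for a single made-up column. The arrays L_rows and A_rows are hypothetical, and I am using jnp.isin here in the role that np.in1d plays above.

```python
import jax
import jax.numpy as jnp

# One hypothetical column: the rows of L and the (subset of) rows of A.
L_rows = jnp.array([0, 1, 2, 3])
A_rows = jnp.array([0, 3])

@jax.jit
def find_positions(L_rows, A_rows):
    mask = jnp.isin(L_rows, A_rows)
    # size comes from a shape, so it is known while tracing.
    return jnp.nonzero(mask, size=A_rows.shape[0])[0]

copy_idx = find_positions(L_rows, A_rows)
```

This only works because A_rows is passed in as its own array, so its length is a shape.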

Oh no! We have a problem! This return size is a function of the values of A_indptr rather than a function of the shape. This means we’re a bit fucked.

There are two routes out:

  1. Declare A_indptr to be a static parameter, or
  2. Change the representation from CSC to something more convenient.

In this case we could do either of these things, but I’m going to opt for the second option, as it’s going to be more useful going forward.
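
For completeness, here is roughly what the first route could look like. This is a sketch on a made-up function (column_sums is mine), and it assumes the index pointer is passed as a tuple of Python ints, because static arguments need to be hashable and arrays are not.

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames=("indptr",))
def column_sums(x, indptr):
    # indptr is an ordinary tuple of Python ints while tracing, so
    # slicing with it produces arrays with known shapes.
    sums = [jnp.sum(x[indptr[j]:indptr[j + 1]]) for j in range(len(indptr) - 1)]
    return jnp.stack(sums)

x = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
sums = column_sums(x, indptr=(0, 2, 4, 6))
```

The price is that the function is re-compiled every time it sees a new value of indptr.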

But before we do that, let’s look at the final line in the code.

  5. L_x[L_indptr[j] + copy_idx] = A_x[A_indptr[j]:A_indptr[j+1]]: The final non-trivial line of the code is also a problem. The issue is that these arrays are immutable and we are asking to change the values! That is not allowed!

The solution here is to use a clunkier syntax. In JAX, we need to replace

x[ind] = a

with the less pleasant

x = x.at[ind].set(a)

What is going on under the hood to make the second option ok while the first is an error is well beyond the scope of this little post. But the important thing is that they compile down to an in-place10 update, which is all we really care about.
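
A tiny illustration with made-up values:

```python
import jax.numpy as jnp

x = jnp.zeros(5)
ind = jnp.array([1, 3])
a = jnp.array([10.0, 20.0])

# numpy-style assignment is an error for JAX arrays.
try:
    x[ind] = a
    assignment_failed = False
except TypeError:
    assignment_failed = True

# The functional version returns a new array and leaves x alone.
y = x.at[ind].set(a)
```

Outside of JIT this really does hand back a new array, which is why x still holds zeros afterwards.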

Re-doing the data structure.

Ok. So we need a new data structure. That’s annoying. The rule, I guess, is always that if you need to innovate, you should innovate very little if you can get away with it, or a lot if you have to.

We are going to innovate only the tiniest of bits.

The idea is to keep the core structure of the CSC data structure, but to replace the indptr array with explicitly storing the row indices and row values as a list of np.arrays. So A_index will now be a list of n arrays that contain the row indices of the non-zero elements of \(A\), while A_x will now be a list of n arrays that contain the values of the non-zero elements of \(A\).

This means that the matrix \[ B = \begin{pmatrix} 1 &&5 \\ 2&3& \\ &4&6 \end{pmatrix} \] would be stored as

B_index = [np.array([0,1]), np.array([1,2]), np.array([0,2])]
B_x = [np.array([1,2]), np.array([3,4]), np.array([5,6])]

This is a considerably more pythonic11 version of CSC. So I guess that’s an advantage.

We can easily go from CSC storage to this modified storage.

def to_pythonic_csc(indices, indptr, x):
  index = np.split(indices, indptr[1:-1])
  x = np.split(x, indptr[1:-1])
  return index, x
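
As a quick check, here is the conversion applied to the matrix \(B\) from above, with scipy building the CSC arrays. The function is restated so the snippet runs on its own.

```python
import numpy as np
from scipy import sparse

def to_pythonic_csc(indices, indptr, x):
    # Split the flat CSC arrays at the column boundaries.
    index = np.split(indices, indptr[1:-1])
    x = np.split(x, indptr[1:-1])
    return index, x

B = sparse.csc_matrix(np.array([[1, 0, 5],
                                [2, 3, 0],
                                [0, 4, 6]]))

B_index, B_x = to_pythonic_csc(B.indices, B.indptr, B.data)
```

This should reproduce the B_index and B_x lists written out by hand earlier.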

A JAX-traceable structure-changing copy

So now it’s time to come back to that damn for loop. As flagged earlier, for loops can be a bit picky in JAX. If we use them as is, then the code that is generated and then compiled is unrolled. You can think of this as if the JIT compiler automatically writes a C++ program and then compiles it. If you were to examine that code, the for loop would be replaced by n almost identical blocks of code with only the index j changing between them. This leads to a potentially very large program to compile12 and it limits the compiler’s ability to do clever things to make the compiled code run faster13.

The lax.fori_loop() function, on the other hand, compiles down to the equivalent of a single operation14. This lets the compiler be super clever.

But we don’t actually need this here. Because if you take a look at the original for loop we are just applying the same two lines of code to each triple of lists in A_index, A_x, and L_index (in our new15 data structure).

This just screams out for a map applying a single function independently to each column.

The challenge is to find the right map function. An obvious hope would be jax.vmap. Sadly, jax.vmap does not do that. (At least not without more padding16 than a drag queen.) The problem here is a misunderstanding of what different parts of JAX are for. Functions like jax.vmap are made for applying the same function to arrays of the same size. This makes sense in their context. (JAX is, after all, made for machine learning and these shape assumptions fit really well in that paradigm. They just don’t fit here.)

And I won’t lie. After this point I went wild. That did not help. And I honest to god tried lax.scan, which will solve the problem, but at what cost?

But at some point, you read enough of the docs to find the answer.

The correct answer here is to use the JAX concept of a pytree. Pytrees are essentially17 lists of arrays. They’re very flexible and they have a jax.tree_map function that lets you map over them! We are saved!
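
Here is a minimal sketch of mapping over a couple of ragged lists at once. The lists rows and vals are made up, and I import tree_map from jax.tree_util, which is one of the places it lives.

```python
import jax.numpy as jnp
from jax.tree_util import tree_map

# Two pytrees with the same structure: lists of arrays of differing lengths.
rows = [jnp.array([0, 1]), jnp.array([1, 2, 3]), jnp.array([3])]
vals = [jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0, 5.0]), jnp.array([6.0])]

# The function is called once per pair of matching leaves.
weighted = tree_map(lambda r, v: jnp.sum(r * v), rows, vals)
```

The result has the same structure as the inputs: a list with one entry per column. No padding required.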

import numpy as np
from jax import numpy as jnp
from jax.tree_util import tree_map

def _structured_copy_csc(A_index, A_x, L_index):
    def body_fun(A_rows, A_vals, L_rows):
      out = jnp.zeros(len(L_rows))
      copy_idx = jnp.nonzero(jnp.in1d(L_rows, A_rows), size = len(A_rows))[0]
      out = out.at[copy_idx].set(A_vals)
      return out
    L_x = tree_map(body_fun, A_index, A_x, L_index)
    return L_x
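
To see what this does, here is a small hand-made example. The function is restated compactly so the snippet runs on its own, with jnp.isin standing in for in1d (my substitution), and the index lists are hypothetical.

```python
import jax.numpy as jnp
from jax.tree_util import tree_map

def _structured_copy_csc(A_index, A_x, L_index):
    def body_fun(A_rows, A_vals, L_rows):
        out = jnp.zeros(len(L_rows))
        # Positions in this column of L that also appear in A.
        copy_idx = jnp.nonzero(jnp.isin(L_rows, A_rows), size=len(A_rows))[0]
        return out.at[copy_idx].set(A_vals)
    return tree_map(body_fun, A_index, A_x, L_index)

# Hypothetical 3 x 3 example: A is sparse, L has a full lower triangle.
A_index = [jnp.array([0, 2]), jnp.array([1]), jnp.array([2])]
A_x = [jnp.array([1.0, 2.0]), jnp.array([3.0]), jnp.array([4.0])]
L_index = [jnp.array([0, 1, 2]), jnp.array([1, 2]), jnp.array([2])]

L_x = _structured_copy_csc(A_index, A_x, L_index)
```

The values of \(A\) land in the slots of the larger pattern and everything else is an explicit zero.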

Testing it out

Ok so now let’s see if it works. To do that I’m going to define a very simple function \[ f(A, \alpha, \beta) = \|\alpha I + \beta \operatorname{tril}(A)\|_F^2, \] that is the sum of the squares of all of the elements of \(\alpha I + \beta \operatorname{tril}(A)\). There’s obviously an easy way to do this, but I’m going to do it in a way that uses the function we just built.

def test_func(A_index, A_x, params):
  I_index = [jnp.array([j]) for j in range(len(A_index))]
  I_x = [jnp.array([params[0]]) for j in range(len(A_index))]
  I_x2 = _structured_copy_csc(I_index, I_x, A_index)
  return jnp.sum((jnp.concatenate(I_x2) + params[1] * jnp.concatenate(A_x))**2)

Next, we need a test case. Once again, we will use the 2D Laplacian on a regular \(n \times n\) grid (up to a scaling). This is a nice little function because it’s easy to make test problems of different sizes.

from scipy import sparse

def make_matrix(n):
    one_d = sparse.diags([[-1.]*(n-1), [2.]*n, [-1.]*(n-1)], [-1,0,1])
    A_lower = sparse.tril(sparse.kronsum(one_d, one_d) + sparse.eye(n*n), format = "csc")
    A_index = jnp.split(jnp.array(A_lower.indices), A_lower.indptr[1:-1])
    A_x = jnp.split(jnp.array(A_lower.data), A_lower.indptr[1:-1])
    return (A_index, A_x)

With our test case in hand, we can check to see if JAX will differentiate for us!

from jax import grad, jit
from jax.test_util import check_grads

grad_func = grad(test_func, argnums = 2)

A_index, A_x = make_matrix(50)
print(f"The value at (2.0, 2.0) is {test_func(A_index, A_x, (2.0, 2.0))}.")
print(f"The gradient is {np.array(grad_func(A_index, A_x, (2.0, 2.0)))}.")
The value at (2.0, 2.0) is 379600.0.
The gradient is [ 60000. 319600.].

Fabulous! That works!
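
These numbers can be cross-checked by hand. The lower triangle of this test matrix has \(n^2\) diagonal entries equal to 5 and \(2n(n-1)\) entries below the diagonal equal to \(-1\), so \(f = n^2(\alpha + 5\beta)^2 + 2n(n-1)\beta^2\). The sketch below (the helper names are mine) checks that formula against a dense computation on a small grid and then evaluates it at \(n = 50\).

```python
import numpy as np
from scipy import sparse

def dense_check(n, alpha, beta):
    # Same matrix as make_matrix(n), but built densely and squared directly.
    one_d = sparse.diags([[-1.0] * (n - 1), [2.0] * n, [-1.0] * (n - 1)], [-1, 0, 1])
    A = (sparse.kronsum(one_d, one_d) + sparse.eye(n * n)).toarray()
    M = alpha * np.eye(n * n) + beta * np.tril(A)
    return np.sum(M ** 2)

def closed_form(n, alpha, beta):
    # n^2 diagonal entries (alpha + 5 beta), 2 n (n - 1) entries equal to -beta.
    return n**2 * (alpha + 5 * beta) ** 2 + 2 * n * (n - 1) * beta**2

def closed_form_grad(n, alpha, beta):
    d_alpha = 2 * n**2 * (alpha + 5 * beta)
    d_beta = 10 * n**2 * (alpha + 5 * beta) + 4 * n * (n - 1) * beta
    return d_alpha, d_beta

small_dense = dense_check(5, 2.0, 2.0)
small_formula = closed_form(5, 2.0, 2.0)
value_50 = closed_form(50, 2.0, 2.0)
grad_50 = closed_form_grad(50, 2.0, 2.0)
```

At \(n = 50\) the formula gives 379600 for the value and (60000, 319600) for the gradient, which agrees with what JAX printed.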

But what about JIT?

JIT took fucking ages. I’m talking “it threw a message” amounts of time. I’m not even going to pretend that I understand why. But I can hazard a guess.

My running assumption, taken from the docs, is that as long as the function only relies on quantities that are derived from the shapes of the inputs (and not the values), then JAX will be able to trace through and JIT through the functions with ease.

This might not be true for tree_maps. The docs are, as far as I can tell, silent on this matter. And a cursory look through the github repo did not give me any hints as to how tree_map() is translated.

Let’s take a look to see if this is true.

import timeit
from functools import partial
jit_test_func = jit(test_func)

A_index, A_x = make_matrix(5)
times = timeit.repeat(partial(jit_test_func, A_index, A_x, (2.0, 2.0)), number = 1)
print(f"n = 5: {[round(t, 4) for t in times]}")
n = 5: [1.6695, 0.0001, 0.0, 0.0, 0.0]

We can see that the first run includes compilation time, but after that it runs a bunch faster. This is how a JIT system is supposed to work! But the question is: will it recompile when we run it for a different matrix?

_ = jit_test_func(A_index, A_x, (2.0, 2.0)) 
A_index, A_x = make_matrix(20)
times = timeit.repeat(partial(jit_test_func, A_index, A_x, (2.0, 2.0)), number = 1)
print(f"n = 20: {[round(t, 4) for t in times]}")
n = 20: [38.5779, 0.0006, 0.0003, 0.0003, 0.0003]

Damn. It recompiles. But, as we will see, it does not recompile if we only change A_x.
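
One cheap way to watch this happen without a stopwatch is to count traces. This sketch uses a made-up function and relies on the fact that Python side effects inside a jitted function only run while it is being traced.

```python
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def total(xs):
    global trace_count
    trace_count += 1  # only runs while JAX is tracing, not on cached calls
    return sum(jnp.sum(x) for x in xs)

small = [jnp.ones(2), jnp.ones(3)]

total(small)
total([x + 1.0 for x in small])  # same structure and shapes, new values
after_same_shapes = trace_count

total([jnp.ones(2), jnp.ones(3), jnp.ones(4)])  # a different list structure
after_new_structure = trace_count
```

New values with the same shapes reuse the cached program, while a list with a different number (or different sizes) of arrays gets traced again.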

# What if we change A_x only
_ = jit_test_func(A_index, A_x, (2.0, 2.0)) 
A_x2 = tree_map(lambda x: x + 1.0, A_x)
times = timeit.repeat(partial(jit_test_func, A_index, A_x2, (2.0, 2.0)), number = 1)
print(f"n = 20, new A_x: {[round(t, 4) for t in times]}")
n = 20, new A_x: [0.0006, 0.0007, 0.0005, 0.0003, 0.0003]

This gives us some hope! This is because the structure of A (aka A_index) is fixed in our application, but the values A_x change. So as long as the initial JIT compilation is reasonable, we should be ok.

Unfortunately, there is something bad happening with the compilation. For \(n=10\), it takes (on my machine) about 2 seconds for the initial compilation. For \(n=20\), that increases to 16 seconds. Once \(n = 30\), this balloons up to 51 seconds. Once we reach the lofty peaks18 of \(n=40\), we are up at 149 seconds to compile.

This is not good. The function we are JIT-ing is very simple: just one tree_map. I do not know enough19 about the internals of JAX, so I don’t want to speculate too wildly. But it seems like it might be unrolling the tree_map before compilation, which is … bad.
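
One thing that can be checked cheaply is how big the traced program gets. tree_map is an ordinary Python function that calls its argument once per leaf, so every leaf contributes its own operations to the program that is handed to the compiler. The sketch below (on a made-up function, not test_func) counts those operations with jax.make_jaxpr. It does not prove that this is what makes compilation so slow, but it is at least consistent with the guess.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import tree_map

def scale_all(xs):
    # One multiplication per leaf of the pytree.
    return tree_map(lambda x: 2.0 * x, xs)

def n_equations(n_leaves):
    xs = [jnp.ones(3) for _ in range(n_leaves)]
    # Number of top-level operations in the traced program.
    return len(jax.make_jaxpr(scale_all)(xs).jaxpr.eqns)

counts = {k: n_equations(k) for k in (4, 8, 16)}
```

Doubling the number of leaves doubles the number of operations in the traced program.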

Let’s admit failure

Ok. So that didn’t bloody work. I’m not going to make such broad statements as you can’t use the JAX library in python to write a transformable sparse Cholesky factorisation, but I am more than prepared to say that I cannot do such a thing.

But, if I’m totally honest, I’m not enormously surprised. Even in looking at the very simple operation we focussed on today, it’s pretty clear that the operations required to work on a sparse matrix don’t look an awful lot like the types of operations you need to do the types of machine learning work that is JAX’s raison d’être.

And it is never surprising to find that a library designed to do a fundamentally different thing does not easily adapt to whatever random task I decide to throw at it.

But there is a light: JAX is an extensible language. We can build a new JAX primitive (or, new JAX primitives) and manually write all of the transformations (batching, JIT, and autodiffing).

And that is what we shall do next! It’s gonna be a blast!


  1. If you’ve never come across this term before, you can Google it for actual details, but the squishy version is that it will compile your code so it runs fast (like C code) instead of slow (like python code). JIT stands for just in time, which means that the code is compiled when it’s needed rather than before everything else is run. It’s a good thing. It makes the machine go bing faster.↩︎

  2. I give less of a shit about the third transformation in this context. I’m not completely sure what you would batch when you’re dealing with a linear mixed-ish model. But hey. Why not.↩︎

  3. If you’ve ever spoken to a Scala advocate (or an advocate of any other pure functional language), you can probably see the edges of why the arrays need to be immutable.

     Restrictions to JIT-able control flow have to do with how it’s translated onto the XLA compiler, which involves tracing through the code with an abstract data type with the same shape as the one that it’s being called with. Because this abstract data type does not have any values, structural parts of the code that require knowledge of specific values of the arguments will be lost. You can get around this partially by declaring those important values to be static, which would make the JIT compiler re-compile the function each time that value changes. We are not going to do that.

     Restrictions to gradients have to do (I assume) with reverse-mode autodiff needing to construct the autodiff tree at compile time, which means you need to be able to compute the number of operations from the types and shapes of the input variables and not from their values.↩︎

  4. Coverage is pretty good on the using bit, but, as is usual, the bits on extending the system are occasionally a bit … sparse. (What in the hairy Christ is a transposition rule actually supposed to do????)↩︎

  5. Forest↩︎

  6. aka implement the damn thing in C++ and then do some proper work on it.↩︎

  7. It is useful to think of a sparse matrix type as the triple (value_type, indices, indptr). This means that if we are going to do something like add sparse matrices, we need to first cast them both to have the same type. After the cast, addition of two different sparse matrices becomes the addition of their x attributes. The same holds for scalar multiplication. Sparse matrix-matrix multiplication is a bit different because you once again need to symbolically work out the sparsity structure (aka the type) of the product. ↩︎

  8. I think. That’s certainly what’s implied by the docs, but I don’t want to give the impression that I’m sure. Because this is complicated.↩︎

  9. What is jax.lax? Oh honey you don’t want to know.↩︎

  10. aka there’s no weird copying↩︎

  11. Whatever that means anyway↩︎

  12. slowwwwww to compile↩︎

  13. The XLA compiler does very clever things. Incidentally, loop unrolling is actually one of the optimisations that compilers have in their pocket. Just not one that’s usually used for loops as large as this.↩︎

  14. Read about XLA High Level Operations (HLOs) here. The XLA documentation is not extensive, but there’s still a lot to read.↩︎

  15. This is why we have a new data structure.↩︎

  16. My kingdom for a ragged array.↩︎

  17. Yes. They are more complicated than this. But for our purposes they are lists of arrays.↩︎

  18. \(n=50\) takes so long it prints a message telling us what to do if we want to file a bug! Compilation eventually clocks in at 361 seconds.↩︎

  19. aka I know sweet bugger all↩︎



BibTeX citation:
  author = {Simpson, Dan},
  title = {Sparse {Matrices} 3: {Failing} at {JAX}},
  date = {2022-05-14},
  url = {},
  langid = {en}
For attribution, please cite this work as:
Simpson, Dan. 2022. “Sparse Matrices 3: Failing at JAX.” May 14, 2022.