```
import numpy as np

def _symbolic_factor_csc(A_indices, A_indptr):
    # Assumes A_indices and A_indptr index the lower triangle of $A$ ONLY.
    n = len(A_indptr) - 1
    L_sym = [np.array([], dtype=int) for j in range(n)]
    children = [np.array([], dtype=int) for j in range(n)]

    for j in range(n):
        # Start from the non-zero pattern of column j of A
        L_sym[j] = A_indices[A_indptr[j]:A_indptr[j + 1]]
        # Merge in the below-diagonal pattern of every child column
        for child in children[j]:
            tmp = L_sym[child][L_sym[child] > j]
            L_sym[j] = np.unique(np.append(L_sym[j], tmp))
        # The first off-diagonal non-zero is the parent of column j
        if len(L_sym[j]) > 1:
            p = L_sym[j][1]
            children[p] = np.append(children[p], j)

    L_indptr = np.zeros(n+1, dtype=int)
    L_indptr[1:] = np.cumsum([len(x) for x in L_sym])
    L_indices = np.concatenate(L_sym)

    return L_indices, L_indptr
```

This is part three of an ongoing exercise in hubris. Part one is here. Part two is here. The overall aim of this series of posts is to look at how sparse Cholesky factorisations work, how JAX works, and how to marry the two with the ultimate aim of putting a bit of sparse matrix support into PyMC, which should allow for faster inference in linear mixed models and Gaussian spatial models. And hopefully, if anyone ever gets around to putting the Laplace approximation in, all sorts of GLMMs and non-Gaussian models with splines and spatial effects.

It’s been a couple of weeks since the last blog, but I’m going to just assume that you are fully on top of all of those details. To that end, let’s jump in.

## What is JAX?

JAX is a minor miracle. It will take python+numpy code and make it cool. It will let you JIT^{1} compile it! It will let you differentiate it! It will let you batch^{2}. JAX refers to these three operations as *transformations*.

But, as The Mountain Goats tell us *God is present in the sweeping gesture, but the devil is in the details*. And oh boy are those details going to be really fucking important to us.

There are going to be two key things that will make our lives more difficult:

- Not every operation can be transformed by every transformation. For example, you can’t always JIT or take gradients of a `for` loop. This means that some things have to be re-written carefully to make sure it’s possible to get the advantages we need.
- JAX arrays are *immutable*. That means that once an array is created it *cannot be changed* in place. This means that things like `a[0] = a[0] + 1` are not allowed! If you’ve come from an R/Python/C/Fortran world, this is the weirdest thing to deal with.

There are really excellent reasons for both of these restrictions. And looking into the reasons is fascinating. But not a topic for this blog^{3}.

JAX has some pretty decent^{4} documentation, a core piece of which outlines some of the sharp edges you will run into. As you read through the documentation, the design choices become clearer.

So let’s go and find some sharp edges together!

## To JAX or not to JAX

But first, we need to ask ourselves *which functions do we need to JAX*?

In the context of our problem we, so far, have three functions:

- `_symbolic_factor_csc(A_indices, A_indptr)`, which finds the non-zero indices of the sparse Cholesky factor and returns them in CSC format,
- `_deep_copy_csc(A_indices, A_indptr, A_x, L_indices, L_indptr)`, which takes the *entries* of the matrix \(A\) and re-creates them so they can be indexed within the larger pattern of non-zero elements of \(L\),
- `_sparse_cholesky_csc_impl(L_indices, L_indptr, L_x)`, which actually does the sparse Cholesky factorisation.

Let’s take them piece by piece, which is also a good opportunity to remind everyone what the code looked like.

## Symbolic factorisation

This function only needs to be computed once per non-zero pattern. In the applications I outlined in the first post, this non-zero pattern is *fixed*. This means that you only need to run this function *once* per analysis (unlike the others, that you will have to run once per iteration!).

As a general rule, if you only do something once, it isn’t all that necessary to devote *too much* time to optimising it. There are, however, some obvious things we could do.

It is, for instance, pretty easy to see how you would implement this with an explicit tree^{5} structure instead of constantly `np.append`ing the `children` array. This is *far* better from a memory standpoint.

It’s also easy to imagine this as a two-pass algorithm, where you build the tree and count the number of non-zero elements in the first pass and then build and populate `L_indices` in the second pass.

The thing is, neither of these things fixes the core problem for using JAX to JIT this: the dimensions of the internal arrays depend on the *values* of the inputs. JIT needs every array shape to be known at trace time, so this cannot be compiled as is.
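To make the value-dependence concrete, here is a small sketch in plain NumPy. The two 4×4 lower-triangular patterns are made up for illustration; they have identically shaped inputs (8 stored indices, 5 column pointers) but produce outputs of different lengths, because only one of them suffers fill-in.

```python
import numpy as np

def _symbolic_factor_csc(A_indices, A_indptr):
    # Same algorithm as _symbolic_factor_csc in this post (lower triangle only).
    n = len(A_indptr) - 1
    L_sym = [np.array([], dtype=int) for j in range(n)]
    children = [np.array([], dtype=int) for j in range(n)]
    for j in range(n):
        L_sym[j] = A_indices[A_indptr[j]:A_indptr[j + 1]]
        for child in children[j]:
            tmp = L_sym[child][L_sym[child] > j]
            L_sym[j] = np.unique(np.append(L_sym[j], tmp))
        if len(L_sym[j]) > 1:
            p = L_sym[j][1]
            children[p] = np.append(children[p], j)
    L_indptr = np.zeros(n + 1, dtype=int)
    L_indptr[1:] = np.cumsum([len(x) for x in L_sym])
    L_indices = np.concatenate(L_sym)
    return L_indices, L_indptr

# Hypothetical pattern 1: column 0 has an entry in row 3, which causes fill-in at (3, 1)
fill_indices = np.array([0, 1, 3, 1, 2, 2, 3, 3])
fill_indptr = np.array([0, 3, 5, 7, 8])

# Hypothetical pattern 2: same number of stored entries, but no fill-in
nofill_indices = np.array([0, 1, 1, 2, 3, 2, 3, 3])
nofill_indptr = np.array([0, 2, 5, 7, 8])

L_fill, L_fill_ptr = _symbolic_factor_csc(fill_indices, fill_indptr)
L_nofill, L_nofill_ptr = _symbolic_factor_csc(nofill_indices, nofill_indptr)

# Identical input shapes, different output shapes
print(fill_indices.shape == nofill_indices.shape)  # True
print(len(L_fill), len(L_nofill))                  # 9 8
```

A tracer only ever sees the shapes `(8,)` and `(5,)`, so it has no way to tell these two cases apart.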

It seems like this would be a huge limitation, but in reality it isn’t. Most functions aren’t like this one! And, if we remember that JAX is a domain language focussing mainly on ML applications, this is *very rarely* the case. It is always good to remember context!

So what are our options? We have two.

- Leave it in Python and just eat the speed.
- Build a new JAX primitive and write the XLA compilation rule^{6}.

Today we are opting for the first option!

## The structure-changing copy

```
def _deep_copy_csc(A_indices, A_indptr, A_x, L_indices, L_indptr):
    n = len(A_indptr) - 1
    L_x = np.zeros(len(L_indices))

    for j in range(0, n):
        # Positions within column j of L that are also non-zero in column j of A
        copy_idx = np.nonzero(np.in1d(L_indices[L_indptr[j]:L_indptr[j + 1]],
                                      A_indices[A_indptr[j]:A_indptr[j+1]]))[0]
        L_x[L_indptr[j] + copy_idx] = A_x[A_indptr[j]:A_indptr[j+1]]
    return L_x
```

This is, fundamentally, a piece of bookkeeping. An annoyance of sparse matrices. Or, if you will, an explicit *cast* between different sparse matrix types^{7}. This is a thing that we do actually need to be able to differentiate, so it needs to live in JAX.

So where are the potential problems? Let’s go line by line.

- `n = len(A_indptr) - 1`: This is lovely. `n` is used in a for loop later, but because it is a function of the *shape* of `A_indptr`, it is considered static and we will be able to JIT over it!
- `L_x = np.zeros(len(L_indices))`: Again, this is fine. Sizes are derived from shapes, life is peachy.
- `for j in range(0, n):`: This could be a problem if `n` was an argument or derived from *values* of the arguments, but it’s derived from a shape so it is static. Praise be! Well, actually it’s a bit more involved than that.

The problem with the `for` loop is what will happen when it is JIT’d. Essentially, the loop will be statically unrolled^{8}. That is fine for small loops, but it’s a bit of a pain in the arse when `n` is large.

In this case, we might want to use the structured control flow in `jax.lax`^{9}. Here we would need `jax.lax.fori_loop(lower, upper, body_fun, init_val)`. This makes the code look less *pythonic*, but probably should make it faster. It is also, and I cannot stress this enough, an absolute dick to use.

(In actuality, we will see that we do not need this particular corner of the language here!)

- `copy_idx = np.nonzero(...)`: This looks like it’s going to be complicated, but actually it is a perfectly reasonable composition of `numpy` functions. Hence, we can use the same `jax.numpy` functions with minimal changes. The one change that we are going to need to make in order to end up with a JIT-able and differentiable function is that we need to tell JAX how many non-zero elements there are. Thankfully, we know this! Because the non-zero pattern of \(A\) is a subset of the non-zero pattern of \(L\), we know that `np.in1d(L_indices[L_indptr[j]:L_indptr[j + 1]], A_indices[A_indptr[j]:A_indptr[j+1]])` will have exactly `len(A_indices[A_indptr[j]:A_indptr[j+1]])` `True` values, and so `np.nonzero(...)` will return that many indices. We can pass this information to `jnp.nonzero()` using the optional `size` argument.

**Oh no! We have a problem!** This return size is *a function of the values* of `A_indptr` rather than a function of the shape. This means we’re a bit fucked.
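Here is a minimal sketch of that failure on a toy mask. The array names and the two helper functions are made up for illustration. A `size` that is a plain Python constant traces happily, while a `size` computed from the *values* of a pointer array does not.

```python
import jax
import jax.numpy as jnp

mask = jnp.array([False, True, False, True, True])
indptr = jnp.array([0, 3])  # toy stand-in for a slice of A_indptr

@jax.jit
def idx_static(m):
    # size is a Python constant, known before tracing starts
    return jnp.nonzero(m, size=3)[0]

@jax.jit
def idx_from_values(m, ptr):
    # size is computed from the values of ptr, which are abstract under jit
    return jnp.nonzero(m, size=ptr[1] - ptr[0])[0]

static_result = idx_static(mask)
print(static_result)  # [1 3 4]

try:
    idx_from_values(mask, indptr)
    value_sized_ok = True
except Exception as err:
    # JAX refuses to turn a traced value into a concrete array size
    value_sized_ok = False
    print(type(err).__name__)
```

Outside of `jit` both versions run, because then `ptr[1] - ptr[0]` is an ordinary concrete number. The restriction only bites once tracing is involved.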

There are two routes out:

- Declare `A_indptr` to be a static parameter, or
- Change the representation from CSC to something more convenient.

In this case we could do either of these things, but I’m going to opt for the second option, as it’s going to be more useful going forward.

But before we do that, let’s look at the final line in the code.

- `L_x[L_indptr[j] + copy_idx] = A_x[A_indptr[j]:A_indptr[j+1]]`: The final non-trivial line of the code is also a problem. The issue is that these arrays are *immutable* and we are asking to change the values! That is not allowed!

The solution here is to use a clunkier syntax. In JAX, we need to replace `L_x[L_indptr[j] + copy_idx] = A_x[A_indptr[j]:A_indptr[j+1]]` with the less pleasant `L_x = L_x.at[L_indptr[j] + copy_idx].set(A_x[A_indptr[j]:A_indptr[j+1]])`.

What is going on under the hood to make the second option ok while the first is an error is well beyond the scope of this little post. But the important thing is that, once JIT’d, the second form *compiles down* to an in-place^{10} update, which is all we really care about.
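A tiny sketch of the difference, on a throwaway array rather than the real `L_x`:

```python
import jax.numpy as jnp

x = jnp.zeros(5)

# NumPy-style item assignment is rejected outright
try:
    x[1] = 3.0
    mutated = True
except TypeError:
    mutated = False

# The functional update returns a new array and leaves x alone
y = x.at[jnp.array([1, 3])].set(jnp.array([3.0, 4.0]))

print(mutated)  # False
print(x)        # [0. 0. 0. 0. 0.]
print(y)        # [0. 3. 0. 4. 0.]
```

Outside of `jit` this really does make a copy. The in-place behaviour is something the compiler can arrange when the old array is no longer needed.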

## Re-doing the data structure.

Ok. So we need a new data structure. That’s annoying. The rule, I guess, is always that if you need to innovate, you should innovate very little if you can get away with it, or a lot if you have to.

We are going to innovate only the tiniest of bits.

The idea is to keep the core structure of the CSC data structure, but to replace the `indptr` array with explicitly storing the row indices and row values as a *list* of `np.arrays`. So `A_index` will now be a *list* of `n` arrays that contain the row indices of the non-zero elements of \(A\), while `A_x` will now be a *list* of `n` arrays that contain the values of the non-zero elements of \(A\).

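This means that the matrix \[ B = \begin{pmatrix} 1 & & 5 \\ 2 & 3 & \\ & 4 & 6 \end{pmatrix} \] would be stored as `B_index = [np.array([0, 1]), np.array([1, 2]), np.array([0, 2])]` and `B_x = [np.array([1, 2]), np.array([3, 4]), np.array([5, 6])]`.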

This is a considerably more *pythonic*^{11} version of CSC. So I guess that’s an advantage.

We can easily go from CSC storage to this modified storage.
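A minimal sketch of that conversion, using the matrix \(B\) from above. The helper name `csc_to_column_lists` is made up for this example; the only real trick is that `indptr[1:-1]` holds exactly the positions where one column stops and the next one starts, which is what `np.split` wants.

```python
import numpy as np

def csc_to_column_lists(indices, indptr, x):
    # Hypothetical helper: cut the flat CSC arrays at the column boundaries
    index = np.split(indices, indptr[1:-1])
    values = np.split(x, indptr[1:-1])
    return index, values

# Standard CSC storage of B
B_indices = np.array([0, 1, 1, 2, 0, 2])
B_indptr = np.array([0, 2, 4, 6])
B_data = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])

B_index, B_x = csc_to_column_lists(B_indices, B_indptr, B_data)

print([col.tolist() for col in B_index])  # [[0, 1], [1, 2], [0, 2]]
print([col.tolist() for col in B_x])      # [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
```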

## A JAX-traceable structure-changing copy

So now it’s time to come back to that damn `for` loop. As flagged earlier, `for` loops can be a bit picky in JAX. If we use them *as is*, then the code that is generated and then compiled is *unrolled*. You can think of this as if the JIT compiler automatically writes a C++ program and then compiles it. If you were to examine that code, the for loop would be replaced by `n` almost identical blocks of code with only the index `j` changing between them. This leads to a potentially very large program to compile^{12} and it limits the compiler’s ability to do clever things to make the compiled code run faster^{13}.

The `lax.fori_loop()` function, on the other hand, compiles down to the equivalent of a single operation^{14}. This lets the compiler be super clever.
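One hedged way to see the difference is to count the equations in the traced program with `jax.make_jaxpr`. The toy sum-of-squares functions below are made up for illustration and have nothing to do with the sparse code.

```python
import jax
import jax.numpy as jnp
from jax import lax

def sum_sq_unrolled(x):
    total = jnp.zeros((), dtype=x.dtype)
    # Plain Python loop: the body is traced once per iteration
    for j in range(x.shape[0]):
        total = total + x[j] ** 2
    return total

def sum_sq_fori(x):
    def body_fun(j, total):
        # j is a traced index here, total is the carried value
        return total + x[j] ** 2
    init = jnp.zeros((), dtype=x.dtype)
    return lax.fori_loop(0, x.shape[0], body_fun, init)

x = jnp.arange(50.0)

# Number of top-level equations in each traced program
n_eqns_unrolled = len(jax.make_jaxpr(sum_sq_unrolled)(x).jaxpr.eqns)
n_eqns_fori = len(jax.make_jaxpr(sum_sq_fori)(x).jaxpr.eqns)
print(n_eqns_unrolled, n_eqns_fori)

# Both versions compute the same number
print(jax.jit(sum_sq_unrolled)(x), jax.jit(sum_sq_fori)(x))
```

The unrolled count grows with the length of `x`, while the structured loop stays a handful of equations no matter how long the loop is.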

But we don’t actually need this here. Because if you take a look at the original for loop we are just applying the same two lines of code to each triple of lists in `A_index`, `A_x`, and `L_index` (in our new^{15} data structure).

This just *screams* out for a map applying a single function independently to each column.

The challenge is to find the right map function. An obvious hope would be `jax.vmap`. Sadly, `jax.vmap` does not do that. (At least not without more padding^{16} than a drag queen.) The problem here is a misunderstanding of what different parts of JAX are for. Functions like `jax.vmap` are made for applying the same function to arrays *of the same size*. This makes sense in their context. (JAX is, after all, made for machine learning and these shape assumptions fit really well in that paradigm. They just don’t fit here.)
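For the curious, this is roughly what the padding route looks like on three made-up ragged columns. It is a sketch of the general idea, not something used in the rest of this post: every column is padded to the widest one and a boolean mask carries the information about which entries are real.

```python
import jax
import jax.numpy as jnp

# Three ragged "columns" (toy values)
cols = [jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0, 5.0]), jnp.array([6.0])]

width = max(len(c) for c in cols)

# Pad every column with trailing zeros so they can be stacked into one array
padded = jnp.stack([jnp.pad(c, (0, width - len(c))) for c in cols])
# Remember which entries are real and which are padding
mask = jnp.stack([jnp.arange(width) < len(c) for c in cols])

def masked_sum_sq(values, is_real):
    # Padding entries are zeroed out before they can contribute
    return jnp.sum(jnp.where(is_real, values, 0.0) ** 2)

col_sum_sq = jax.vmap(masked_sum_sq)(padded, mask)
print(col_sum_sq)  # [ 5. 50. 36.]
```

It works, but every column now costs as much memory and compute as the longest one, which is a poor trade for a sparse matrix.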

And I won’t lie. After this point I went *wild*. `lax.map` did not help. And I honest to god tried `lax.scan`, which will solve the problem, but at what cost?

But at some point, you read enough of the docs to find the answer.

The correct answer here is to use the JAX concept of a `pytree`. Pytrees are essentially^{17} lists of arrays. They’re very flexible and they have a `tree_map` function (it lives in `jax.tree_util`) that lets you map over them! We are saved!

```
import numpy as np
from jax import numpy as jnp
from jax.tree_util import tree_map  # the stable home of tree_map

def _structured_copy_csc(A_index, A_x, L_index):
    def body_fun(A_rows, A_vals, L_rows):
        # One column: zeros in the pattern of L, then drop in the entries of A
        out = jnp.zeros(len(L_rows))
        copy_idx = jnp.nonzero(jnp.in1d(L_rows, A_rows), size = len(A_rows))[0]
        out = out.at[copy_idx].set(A_vals)
        return out
    # Apply body_fun independently to each column
    L_x = tree_map(body_fun, A_index, A_x, L_index)
    return L_x
```

### Testing it out

Ok so now let’s see if it works. To do that I’m going to define a very simple function \[ f(A, \alpha, \beta) = \|\alpha I + \beta \operatorname{tril}(A)\|_F^2, \] that is the sum of the squares of all of the elements of \(\alpha I + \beta \operatorname{tril}(A)\). There’s obviously an easy way to do this, but I’m going to do it in a way that uses the function we just built.

Next, we need a test case. Once again, we will use the 2D Laplacian on a regular \(n \times n\) grid (up to a scaling). This is a nice little function because it’s easy to make test problems of different sizes.

```
from scipy import sparse

def make_matrix(n):
    # 1D second-difference matrix, then the 2D Laplacian (plus identity) via a Kronecker sum
    one_d = sparse.diags([[-1.]*(n-1), [2.]*n, [-1.]*(n-1)], [-1,0,1])
    A_lower = sparse.tril(sparse.kronsum(one_d, one_d) + sparse.eye(n*n), format = "csc")
    # Split the flat CSC arrays into one array per column
    A_index = jnp.split(jnp.array(A_lower.indices), A_lower.indptr[1:-1])
    A_x = jnp.split(jnp.array(A_lower.data), A_lower.indptr[1:-1])
    return (A_index, A_x)
```

With our test case in hand, we can check to see if JAX will differentiate for us!

```
from jax import grad, jit
from jax.test_util import check_grads
grad_func = grad(test_func, argnums = 2)
A_index, A_x = make_matrix(50)
print(f"The value at (2.0, 2.0) is {test_func(A_index, A_x, (2.0, 2.0))}.")
print(f"The gradient is {np.array(grad_func(A_index, A_x, (2.0, 2.0)))}.")
```

`The value at (2.0, 2.0) is 379600.0.`

`The gradient is [ 60000. 319600.].`

Fabulous! That works!

## But what about JIT?

JIT took fucking *ages*. I’m talking “it threw a message” amounts of time. I’m not even going to pretend that I understand why. But I can hazard a guess.

My running assumption, taken from the docs, is that as long as the function only relies on quantities that are derived from the *shapes* of the inputs (and not the values), then JAX will be able to trace through and JIT through the functions with ease.

This might not be true for `tree_map`s. The docs are, as far as I can tell, silent on this matter. And a cursory look through the github repo did not give me any hints as to how `tree_map()` is translated.

Let’s take a look to see if this is true.

```
import timeit
from functools import partial
jit_test_func = jit(test_func)
A_index, A_x = make_matrix(5)
times = timeit.repeat(partial(jit_test_func, A_index, A_x, (2.0, 2.0)), number = 1)
print(f"n = 5: {[round(t, 4) for t in times]}")
```

`n = 5: [1.6695, 0.0001, 0.0, 0.0, 0.0]`

We can see that the first run includes compilation time, but after that it runs a bunch faster. This is how a JIT system is supposed to work! But the question is: will it recompile when we run it for a different matrix?

```
_ = jit_test_func(A_index, A_x, (2.0, 2.0))
A_index, A_x = make_matrix(20)
times = timeit.repeat(partial(jit_test_func, A_index, A_x, (2.0, 2.0)), number = 1)
print(f"n = 20: {[round(t, 4) for t in times]}")
```

`n = 20: [38.5779, 0.0006, 0.0003, 0.0003, 0.0003]`

Damn. It recompiles. But, as we will see, it does not recompile if we only change `A_x`.

```
# What if we change A_x only
_ = jit_test_func(A_index, A_x, (2.0, 2.0))
A_x2 = tree_map(lambda x: x + 1.0, A_x)
times = timeit.repeat(partial(jit_test_func, A_index, A_x2, (2.0, 2.0)), number = 1)
print(f"n = 20, new A_x: {[round(t, 4) for t in times]}")
```

`n = 20, new A_x: [0.0006, 0.0007, 0.0005, 0.0003, 0.0003]`

This gives us some hope! This is because the *structure* of A (aka `A_index`) is fixed in our application, but the values in `A_x` change. So as long as the initial JIT compilation is reasonable, we should be ok.
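The same behaviour can be seen on a toy function by counting how often JAX actually traces it. The trick (a sketch, with made-up names) is that Python side effects inside a jitted function only run during tracing, so a list append works as a trace counter.

```python
import jax
import jax.numpy as jnp

trace_log = []

@jax.jit
def sum_sq(x):
    # Python side effect: only executed while JAX is tracing the function
    trace_log.append(x.shape)
    return jnp.sum(x ** 2)

sum_sq(jnp.ones(3))
sum_sq(2.0 * jnp.ones(3))        # same shape, new values: cached program is reused
n_after_same_shape = len(trace_log)

sum_sq(jnp.ones(4))              # new shape: traced and compiled again
n_after_new_shape = len(trace_log)

print(n_after_same_shape, n_after_new_shape)  # 1 2
```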

Unfortunately, there is something bad happening with the compilation. For \(n=10\), it takes (on my machine) about 2 seconds for the initial compilation. For \(n=20\), that increases to 16 seconds. Once \(n = 30\), this balloons up to 51 seconds. Once we reach the lofty peaks^{18} of \(n=40\), we are up at 149 seconds to compile.

This is not good. The function we are JIT-ing is *very* simple: just one `tree_map`. I do not know enough^{19} about the internals of JAX, so I don’t want to speculate too wildly. But it seems like it might be unrolling the `tree_map` before compilation, which is … bad.
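One way to poke at that guess is to look at the traced program directly. `tree_map` applies its function to each leaf from ordinary Python, so every column should contribute its own copy of the operations to the traced program. The sketch below uses a made-up, much simpler per-column function than the real copy and counts the equations reported by `jax.make_jaxpr`.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import tree_map

def scale_columns(cols):
    # Stand-in for the real per-column work
    return tree_map(lambda c: 2.0 * c + 1.0, cols)

def n_eqns(n_cols):
    # Ragged toy columns of lengths 1, 2, 3, 1, 2, 3, ...
    cols = [jnp.ones(j % 3 + 1) for j in range(n_cols)]
    return len(jax.make_jaxpr(scale_columns)(cols).jaxpr.eqns)

eqns_10 = n_eqns(10)
eqns_20 = n_eqns(20)
print(eqns_10, eqns_20)
```

The equation count grows with the number of columns, which is at least consistent with the unrolling story. It does not by itself explain why compile time grows as fast as it does.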

## Let’s admit failure

Ok. So that didn’t bloody work. I’m not going to make such broad statements as *you can’t use the JAX library in python to write a transformable sparse Cholesky factorisation*, but I am more than prepared to say that *I* cannot do such a thing.

But, if I’m totally honest, I’m not *enormously* surprised. Even in looking at the very simple operation we focussed on today, it’s pretty clear that the operations required to work on a sparse matrix don’t look an awful lot like the types of operations you need to do the types of machine learning work that is JAX’s *raison d’être*.

And it is *never* surprising to find that a library designed to do a fundamentally different thing does not easily adapt to whatever random task I decide to throw at it.

But there is a light: JAX is an extensible language. We can build a new JAX primitive (or, new JAX primitives) and manually write all of the transformations (batching, JIT, and autodiffing).

And that is what we shall do next! It’s gonna be a blast!

## Footnotes

If you’ve never come across this term before, you can Google it for actual details, but the squishy version is that it will *compile* your code so it runs fast (like C code) instead of slow (like python code). JIT stands for *just in time*, which means that the code is compiled when it’s needed rather than before everything else is run. It’s a good thing. It makes the machine go *bing* faster.↩︎

I give less of a shit about the third transformation in this context. I’m not completely sure what you would batch when you’re dealing with a linear mixed-ish model. But hey. Why not.↩︎

If you’ve ever spoken to a Scala advocate (or an advocate of any other pure functional language), you can probably see the edges of why the arrays need to be immutable.

Restrictions to JIT-able control flow have to do with how it’s translated onto the XLA compiler, which involves *tracing* through the code with an abstract data type with the same shape as the one that it’s being called with. Because this abstract data type does not have any values, structural parts of the code that *require* knowledge of specific values of the arguments will be lost. You can get around this partially by declaring those important values to be *static*, which would make the JIT compiler re-compile the function each time that value changes. We are not going to do that.

Restrictions to gradients have to do (I assume) with reverse-mode autodiff needing to construct the autodiff tree at compile time, which means you need to be able to compute the number of operations from the types and shapes of the input variables and not from their values.↩︎

Coverage is pretty good on the *using* bit, but, as is usual, the bits on extending the system are occasionally a bit … sparse. (What in the hairy Christ is a transposition rule actually supposed to do????)↩︎

Forest↩︎

aka implement the damn thing in C++ and then do some proper work on it.↩︎

It is useful to think of a sparse matrix type as the triple `(value_type, indices, indptr)`. This means that if we are going to do something like add sparse matrices, we need to first cast them both to have the same type. After the cast, addition of two different sparse matrices becomes the addition of their `x` attributes. The same holds for scalar multiplication. Sparse matrix-matrix multiplication is a bit different because you once again need to symbolically work out the sparsity structure (aka the type) of the product.↩︎

I think. That’s certainly what’s implied by the docs, but I don’t want to give the impression that I’m sure. Because this is complicated.↩︎

What is `jax.lax`? Oh honey you don’t want to know.↩︎

aka there’s no weird copying↩︎

slowwwwww to compile↩︎

The XLA compiler does very clever things. Incidentally, loop unrolling is actually one of the optimisations that compilers have in their pocket. Just not one that’s usually used for loops as large as this.↩︎

Read about XLA High Level Operations (HLOs) here. The XLA documentation is not extensive, but there’s still a lot to read.↩︎

This is why we have a new data structure.↩︎

My kingdom for a ragged array.↩︎

Yes. They are more complicated than this. But for our purposes they are lists of arrays.↩︎

\(n=50\) takes so long it prints a message telling us what to do if we want to file a bug! Compilation eventually clocks in at 361 seconds.↩︎

aka I know sweet bugger all↩︎

## Citation

```
@online{simpson2022,
author = {Simpson, Dan},
title = {Sparse {Matrices} 3: {Failing} at {JAX}},
date = {2022-05-14},
url = {https://dansblog.netlify.app/2022-05-14-jax-ing-a-sparse-cholesky-factorisation-part-3-in-an-ongoing-journey},
langid = {en}
}
```