The time has come once more to resume my journey into sparse matrices. There’s been a bit of a pause, mostly because I realised that I didn’t know how to implement the sparse Cholesky factorisation in a JAX-traceable way. But now the time has come. It is time for me to get on top of JAX’s weird control-flow constructs.

And, along the way, I’m going to re-do the sparse Cholesky factorisation to make it, well, better.

In order to temper expectations, I will tell you that this post does not do the numerical factorisation, only the symbolic one. Why? Well I wrote most of it on a long-haul flight and I didn’t get to the numerical part. And this was long enough. So hold your breaths for Part 7b, which will come as soon as I write it.

You can consider this a *much* better re-do of Part 2. This is no longer my first python coding exercise in a decade, so hopefully the code is better. And I’m definitely trying a lot harder to think about the limitations of JAX.

Before I start, I should probably say why I’m doing this. JAX is a truly magical thing that will compute gradients and everything else just by clever processing of the Jacobian-vector product code. Unfortunately, this is only possible if the Jacobian-vector product code is JAX traceable and this code is structurally extremely similar^{1} to the code for the sparse Cholesky factorisation.

I am doing this in the hope of (eventually getting to) autodiff. But that won’t be this blog post. This blog post is complicated enough.

## Control flow of the damned

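The first and most important rule of programming with JAX is that loops will break your heart. I mean, whatever, I guess they’re fine. But there’s a problem. Imagine a function `f(x, n)` that multiplies `x` by an integer `n` by adding `x` to a running total, one loop iteration at a time.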
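A minimal version of this kind of function, which multiplies `x` by an integer `n` through repeated addition, might look like this (a plain-Python sketch of the idea):

```python
def f(x, n):
    # Multiply x by the integer n the slow way: add x to a running total n times.
    out = 0
    for _ in range(n):
        out = out + x
    return out
```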

This is, basically, the worst implementation of multiplication by an integer that you can possibly imagine. This code will run fine in Python, but if you try to JIT compile it, JAX is gonna get *angry*. It will effectively produce the machine code equivalent of writing the loop body out by hand, once for every iteration.
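For `n = 7`, say, the traced program is morally equivalent to the hand-unrolled function below. This illustrates what unrolling means rather than reproducing literal JAX output, and `f_7` is a hypothetical name.

```python
def f_7(x):
    # The Python loop is gone: its body `out = out + x` has been copied once
    # for each of the 7 iterations, so the value of n is baked into the code.
    out = 0
    out = out + x
    out = out + x
    out = out + x
    out = out + x
    out = out + x
    out = out + x
    out = out + x
    return out
```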

There are two bad things happening here. First, note that the “compiled” code depends on `n` and will have to be compiled anew each time `n` changes. Secondly, the loop has been replaced by `n` copies of the loop body. This is called *loop unrolling* and, when used judiciously by a clever compiler, is a great way to speed up code. When done completely for *every* loop, this is a nightmare and the corresponding code will take a geological amount of time to compile.

A similar thing^{2} happens when you need to run autodiff on `f(x,n)`. For each `n` an expression graph is constructed that contains the unrolled for loop. This suggests that autodiff might also end up being quite slow (or, more problematically, more memory-hungry).

So the first rule of JAX is to avoid for loops. But if you can’t do that, there are three built-in loop structures that play nicely with JIT compilation and sometimes^{3} differentiation. These three constructs are:

- A while loop: `jax.lax.while_loop(cond_func, body_func, init)`
- An accumulator: `jax.lax.scan(body_func, init, xs)`
- A for loop: `jax.lax.fori_loop(lower, upper, body_fun, init)`

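Of those three, the first and third work mostly as you’d expect, while the second is a bit more hairy. The `while_loop` function is roughly equivalent to a plain Python `while` loop: it keeps replacing the loop state with `body_func(state)` for as long as `cond_func(state)` is true, and then returns the final state.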
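In pure Python, and ignoring tracing entirely, those semantics can be sketched as follows (this mirrors the pseudocode in the JAX documentation; the doubling example is our own):

```python
def while_loop(cond_func, body_func, init):
    # Pure-Python semantics of jax.lax.while_loop: keep updating the loop
    # state with body_func for as long as cond_func(state) is True.
    val = init
    while cond_func(val):
        val = body_func(val)
    return val

# Keep doubling 3 until it reaches at least 100: 3, 6, 12, 24, 48, 96, 192.
result = while_loop(lambda v: v < 100, lambda v: 2 * v, 3)
print(result)
```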

So basically it’s just a while loop. The thing that’s important is that it compiles down to a single XLA operation^{4} instead of some unrolled mess.

One thing that is important to realise is that while loops are only forwards-mode differentiable, which means that it is *very* expensive^{5} to compute gradients. The reason for this is that we simply do not know how long that loop actually is and so it’s impossible to build a fixed-size expression graph.
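A quick check of this claim, using a toy function of our own: `jax.jvp` (forward mode) pushes a tangent through `lax.while_loop` without complaint, while `jax.grad` (reverse mode) refuses. The sketch catches a generic `Exception` rather than relying on the exact error type JAX raises.

```python
import jax
from jax import lax

def double_until_big(x):
    # The number of iterations depends on the *value* of x, so JAX cannot
    # build a fixed-size expression graph for a reverse pass.
    return lax.while_loop(lambda v: v < 100.0, lambda v: 2.0 * v, x)

# Forward mode works: starting at 3.0 the loop doubles six times, so the
# output is 3 * 2**6 = 192 and its derivative with respect to x is 2**6 = 64.
y, dy = jax.jvp(double_until_big, (3.0,), (1.0,))
print(y, dy)

# Reverse mode (which is what jax.grad uses) is rejected at trace time.
try:
    jax.grad(double_until_big)(3.0)
    reverse_failed = False
except Exception as err:  # exact exception type deliberately not pinned down
    reverse_failed = True
    print("jax.grad failed with", type(err).__name__)
```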

The `jax.lax.scan` function is probably the one that people will be least familiar with. That said, it’s also the one that is roughly “how a for loop should work”. The concept that’s important here is a for-loop with *carry over*. Carry over is information that changes from one step of the loop to the next. This is what separates it from a `map` statement, which would apply the same function independently to each element of a list.

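The scan function looks like a for loop over the elements of `xs` that threads a carry through the loop body and stacks up the per-step outputs.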
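Here is a pure-Python sketch of those semantics, simplified from the pseudocode in the JAX documentation. It ignores the `length` and `reverse` options and assumes `xs` is a single array (or list) rather than a pytree.

```python
import numpy as np

def scan(f, init, xs):
    # Pure-Python semantics of jax.lax.scan: f maps (carry, x) to
    # (new_carry, y); the carry is threaded through, the ys are stacked.
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, np.stack(ys)

# Running sum: the carry is the total so far, and we also emit it at each step.
final, ys = scan(lambda c, x: (c + x, c + x), 0, [1, 2, 3])
print(final, ys)
```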

A critically important limitation to `jax.lax.scan` is that every `x` in `xs` must have the same shape! This means, for example, that a ragged list such as `xs = [[1, 2], [3, 4, 5]]` is not a valid argument. Like all limitations in JAX, this serves to make the code transformable into efficiently compiled code across various different processors.
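In practice this means `xs` is an array, and `jax.lax.scan` walks along its leading axis, handing one equally shaped slice to the body at each step. A small sketch with made-up numbers, where each step adds a length-2 slice to the carry and also emits a scalar:

```python
import jax.numpy as jnp
from jax import lax

# Three slices of shape (2,), stacked along the leading axis: [[0, 1], [2, 3], [4, 5]].
xs = jnp.arange(6.0).reshape(3, 2)

def body(carry, x):
    # carry and x both have shape (2,); each step also emits a scalar output.
    carry = carry + x
    return carry, carry.sum()

final, sums = lax.scan(body, jnp.zeros(2), xs)
print(final, sums)
```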

For example, if I wanted to use `jax.lax.scan` on my example from before, I would get

```
from jax import lax
from jax import numpy as jnp

def f(x, n):
    init = jnp.zeros_like(x)
    xs = jnp.repeat(x, n)
    def body_func(carry, y):
        val = carry + y
        return (val, val)
    final, journey = lax.scan(body_func, init, xs)
    return (final, journey)

final, journey = f(1.2, 7)
print(final)
print(journey)
```

```
8.4
[1.2 2.4 3.6000001 4.8 6. 7.2 8.4 ]
```

This translation is a bit awkward compared to the for loop but it’s the sort of thing that you get used to.

This function can be differentiated^{6} and compiled. To differentiate it, I need a version that returns a scalar, which is easy enough to do with a lambda.

```
from jax import jit, grad
f2 = lambda x, n: f(x,n)[0]
f2_grad = grad(f2, argnums = 0)
print(f2_grad(1.2, 7))
```

`7.0`

The `argnums` option tells JAX that we are only differentiating wrt the first argument.

JIT compilation is a tiny bit more delicate. If we try the natural thing, `jit(f)(1.2, 7)`, we are going to get an error.

```
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
When jit-compiling jnp.repeat, the total number of repeats must be static. To fix this, either specify a static value for `repeats`, or pass a static value to `total_repeat_length`.
The error occurred while tracing the function f at /var/folders/08/4p5p665j4d966tr7nvr0v24c0000gn/T/ipykernel_24749/3851190413.py:4 for jit. This concrete value was not available in Python because it depends on the value of the argument 'n'.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
```

In order to compile a function, JAX needs to know how big everything is. And right now it does not know what `n` is. This shows itself through the `ConcretizationTypeError`, which basically says that as JAX was looking through your code it found something it can’t manipulate. In this case, it was in the `jnp.repeat` function.

We can fix this problem by declaring this parameter `static`.

A static parameter is a parameter value that is known at compile time. If we define `n` to be static, then the first time you call `f_jit(x, 7)` it will compile and then it will reuse the compiled code for any other value of `x`. If we then call `f_jit(x, 9)`, the code will *compile again*.

To see this, we can make use of a JAX oddity: if a function prints something^{7}, then it will only be printed when the function is traced for compilation and never again. This means that we can’t do *debug by print*. But on the upside, it’s easy to check when things are compiling.

```
def f2(x, n):
    print(f"compiling: n = {n}")
    return f(x,n)[0]

f2_jit = jit(f2, static_argnums=(1,))
print(f2_jit(1.2,7))
print(f2_jit(1.8,7))
print(f2_jit(1.2,9))
print(f2_jit(1.8,7))
```

```
compiling: n = 7
8.4
12.6
compiling: n = 9
10.799999
12.6
```

This is a perfectly ok solution as long as the static parameters don’t change very often. In our context, the static parameters will be things determined by the sparsity pattern (like the number of non-zeros), which is fixed for a given matrix.

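Finally, we can talk about `jax.lax.fori_loop`, the in-built for loop. This is basically a convenience wrapper for `jax.lax.scan` (when `lower` and `upper` are static) or `jax.lax.while_loop` (when they are not). In Python pseudocode, it is a `for i in range(lower, upper)` loop that replaces the loop state `val` with `body_fun(i, val)` at each step.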

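To close out this bit where I repeat the docs, there is also a traceable if/else, `jax.lax.cond(pred, true_fun, false_fun, *operands)`, which in pseudocode returns `true_fun(*operands)` if `pred` is true and `false_fun(*operands)` otherwise.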
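As a pure-Python sketch (again mirroring the JAX documentation's pseudocode), `cond` is just an `if` statement. The comment notes the extra constraint that applies under tracing.

```python
def cond(pred, true_fun, false_fun, *operands):
    # Pure-Python semantics of jax.lax.cond. When traced by JAX, both branches
    # must return outputs with identical shapes and dtypes, because both
    # branches are compiled and the choice between them happens at run time.
    if pred:
        return true_fun(*operands)
    else:
        return false_fun(*operands)
```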

## Building a JAX-traceable symbolic sparse Cholesky factorisation

In order to build a JAX-traceable sparse Cholesky factorisation \(A = LL^T\), we are going to need to build up a few moving parts:

1. Build the elimination tree of \(A\) and find the number of non-zeros in each column of \(L\).
2. Build the *symbolic factorisation*^{8} of \(L\) (aka the location of the non-zeros of \(L\)).
3. Do the actual numerical decomposition.

In the previous post we did not explicitly form the elimination tree. Instead, I used dynamic memory allocation. This time I’m being more mature.

### Building the elimination tree

The elimination tree^{9} \(\mathcal{T}_A\) is a rooted tree (or a forest of rooted trees) that compactly represents the non-zero pattern of the Cholesky factor \(L\). In particular, the elimination tree has the property that, for any \(k > j\), if \(L_{kj} \neq 0\) then there is a path from \(j\) to \(k\) in the tree. Or, in the language of trees, \(L_{kj} \neq 0\) implies that \(j\) is a descendant of \(k\) in the tree \(\mathcal{T}_A\). The converse does not hold in general: for a tridiagonal matrix the tree is a single path \(0 \to 1 \to \cdots \to n-1\), but \(L\) is only bidiagonal.

We can describe^{10} \(\mathcal{T}_A\) by listing the parent of each node. The parent node of \(j\) in the tree is the smallest \(i > j\) with \(L_{ij} \neq 0\).
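As a small numerical illustration of this definition, the sketch below reads the parent of each column off a dense Cholesky factor computed with `np.linalg.cholesky`. It relies on exact zeros in the numerical factor, which holds for this tridiagonal example but is not a substitute for a symbolic computation, and `etree_from_L` is a hypothetical helper name. For a tridiagonal matrix, the tree is a single path.

```python
import numpy as np

def etree_from_L(L):
    # parent[j] = smallest i > j with L[i, j] != 0, or -1 if column j has no
    # non-zeros below the diagonal (j is a root of the forest).
    n = L.shape[0]
    parent = np.full(n, -1)
    for j in range(n):
        below = np.nonzero(L[j + 1:, j])[0]
        if below.size > 0:
            parent[j] = j + 1 + below[0]
    return parent

# Tridiagonal SPD matrix: L is lower bidiagonal, so every column has exactly
# one non-zero below the diagonal and the tree is the path 0 -> 1 -> 2 -> 3 -> 4.
n = 5
A = 2.0 * np.eye(n) - np.eye(n, k=1) - np.eye(n, k=-1)
L = np.linalg.cholesky(A)
parent = etree_from_L(L)
print(parent)
```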

We can turn this into an algorithm. An efficient version, which is described in Tim Davis’s book, takes about \(\mathcal{O}(\operatorname{nnz}(A))\) operations. But I’m going to program up a slower one that takes \(\mathcal{O}(\operatorname{nnz}(L))\) operations, but has the added benefit^{11} of giving me the column counts for free.

To do this, we are going to walk the tree and dynamically add up the column counts as we go.

To start off, let’s do this in standard python so that we can see what the algorithm looks like. The key concept is that if we write \(\mathcal{T}_{j-1}\) as the elimination tree encoding the structure of^{12} `L[:j, :j]`, then we can ask about how this tree connects with node `j`.

A theorem gives a very simple answer to this.

**Theorem 1** If \(j > i\), then \(A_{j,i} \neq 0\) implies that \(i\) is a descendant of \(j\) in \(\mathcal{T}_A\). In particular, that means that there is a directed path in \(\mathcal{T}_A\) from \(i\) to \(j\).

This tells us how \(\mathcal{T}_{j-1}\) connects to node \(j\): for each non-zero element \(i < j\) of the \(j\)th row of \(A\), we can walk up \(\mathcal{T}_{j-1}\) from \(i\), and we will eventually get to a node that has no parent in \(\{0,\ldots, j-1\}\). Because there *must* be a path from \(i\) to \(j\) in \(\mathcal{T}_j\), the parent of this terminal node must be \(j\).

As with everything Cholesky related, this works because the algorithm proceeds from left to right, which in this case means that the node label associated with *any* descendant of \(j\) is always less than \(j\).

The algorithm is then a fairly run-of-the-mill^{13} tree traversal, where we keep track of where we have been so we don’t double count our columns.

Probably the most important thing here is that I am using the *full* sparse matrix rather than just its lower triangle. This is, basically, convenience. I need access to the left half of the \(j\)th row of \(A\), which is conveniently the same as the top half of the \(j\)th column. And sometimes you just don’t want to be dicking around with swapping between row- and column-based representations.

```
import numpy as np

def etree_base(A_indices, A_indptr):
    n = len(A_indptr) - 1
    parent = [-1] * n
    mark = [-1] * n
    col_count = [1] * n
    for j in range(n):
        mark[j] = j
        for indptr in range(A_indptr[j], A_indptr[j+1]):
            node = A_indices[indptr]
            while node < j and mark[node] != j:
                if parent[node] == -1:
                    parent[node] = j
                mark[node] = j
                col_count[node] += 1
                node = parent[node]
    return (parent, col_count)
```

To convince ourselves this works, let’s run an example and compare the column counts we get to our previous method.

```
## Some boilerplate from previous editions.
from scipy import sparse
import scipy as sp

def make_matrix(n):
    one_d = sparse.diags([[-1.]*(n-2), [2.]*n, [-1.]*(n-2)], [-2,0,2])
    A = (sparse.kronsum(one_d, one_d) + sparse.eye(n*n))
    A_csc = A.tocsc()
    A_csc.eliminate_zeros()
    A_lower = sparse.tril(A_csc, format = "csc")
    A_index = A_lower.indices
    A_indptr = A_lower.indptr
    A_x = A_lower.data
    return (A_index, A_indptr, A_x, A_csc)

def _symbolic_factor(A_indices, A_indptr):
    # Assumes A_indices and A_indptr index the lower triangle of A ONLY.
    n = len(A_indptr) - 1
    L_sym = [np.array([], dtype=int) for j in range(n)]
    children = [np.array([], dtype=int) for j in range(n)]
    for j in range(n):
        L_sym[j] = A_indices[A_indptr[j]:A_indptr[j + 1]]
        for child in children[j]:
            tmp = L_sym[child][L_sym[child] > j]
            L_sym[j] = np.unique(np.append(L_sym[j], tmp))
        if len(L_sym[j]) > 1:
            p = L_sym[j][1]
            children[p] = np.append(children[p], j)
    L_indptr = np.zeros(n+1, dtype=int)
    L_indptr[1:] = np.cumsum([len(x) for x in L_sym])
    L_indices = np.concatenate(L_sym)
    return L_indices, L_indptr
```

```
# A_indices/A_indptr are the lower triangle, A is the entire matrix
A_indices, A_indptr, A_x, A = make_matrix(37)
parent, col_count = etree_base(A.indices, A.indptr)
L_indices, L_indptr = _symbolic_factor(A_indices, A_indptr)
true_parent = L_indices[L_indptr[:-2] + 1]
true_parent[np.where(np.diff(L_indptr[:-1]) == 1)] = -1
print(all(x == y for (x,y) in zip(parent[:-1], true_parent)))
true_col_count = np.diff(L_indptr)
print(all(true_col_count == col_count))
```

```
True
True
```

Excellent. Now we just need to convert it to JAX.

Or do we?

To be honest, this is a little pointless. This function is only run once per matrix so we won’t really get much speedup^{14} from compilation.

Nevertheless, we might try.

```
@jit
def etree(A_indices, A_indptr):
    # print("(Re-)compiling etree(A_indices, A_indptr)")
    ## innermost while loop
    def body_while(val):
        # print(val)
        j, node, parent, col_count, mark = val
        update_parent = lambda x: x[0].at[x[1]].set(x[2])
        parent = lax.cond(lax.eq(parent[node], -1), update_parent, lambda x: x[0], (parent, node, j))
        mark = mark.at[node].set(j)
        col_count = col_count.at[node].add(1)
        return (j, parent[node], parent, col_count, mark)

    def cond_while(val):
        j, node, parent, col_count, mark = val
        return lax.bitwise_and(lax.lt(node, j), lax.ne(mark[node], j))

    ## Inner for loop
    def body_inner_for(indptr, val):
        j, A_indices, A_indptr, parent, col_count, mark = val
        node = A_indices[indptr]
        j, node, parent, col_count, mark = lax.while_loop(cond_while, body_while, (j, node, parent, col_count, mark))
        return (j, A_indices, A_indptr, parent, col_count, mark)

    ## Outer for loop
    def body_out_for(j, val):
        A_indices, A_indptr, parent, col_count, mark = val
        mark = mark.at[j].set(j)
        j, A_indices, A_indptr, parent, col_count, mark = lax.fori_loop(A_indptr[j], A_indptr[j+1], body_inner_for, (j, A_indices, A_indptr, parent, col_count, mark))
        return (A_indices, A_indptr, parent, col_count, mark)

    ## Body of code
    n = len(A_indptr) - 1
    parent = jnp.repeat(-1, n)
    mark = jnp.repeat(-1, n)
    col_count = jnp.repeat(1, n)
    init = (A_indices, A_indptr, parent, col_count, mark)
    A_indices, A_indptr, parent, col_count, mark = lax.fori_loop(0, n, body_out_for, init)
    return parent, col_count
```

Wow. That is *ugly*. But let’s see^{15} if it works!

```
parent_jax, col_count_jax = etree(A.indices, A.indptr)
print(all(x == y for (x,y) in zip(parent_jax[:-1], true_parent)))
print(all(true_col_count == col_count_jax))
```

```
True
True
```

Success!

I guess we could ask ourselves if we gained any speed.

Here is the pure python code.

```
import timeit
A_indices, A_indptr, A_x, A = make_matrix(20)
times = timeit.repeat(lambda: etree_base(A.indices, A.indptr),number = 1)
print(f"n = {A.shape[0]}: {[round(t,2) for t in times]}")
A_indices, A_indptr, A_x, A = make_matrix(50)
times = timeit.repeat(lambda: etree_base(A.indices, A.indptr),number = 1)
print(f"n = {A.shape[0]}: {[round(t,2) for t in times]}")
A_indices, A_indptr, A_x, A = make_matrix(200)
times = timeit.repeat(lambda: etree_base(A.indices, A.indptr),number = 1)
print(f"n = {A.shape[0]}: {[round(t,2) for t in times]}")
```

```
n = 400: [0.0, 0.0, 0.0, 0.0, 0.0]
n = 2500: [0.03, 0.03, 0.03, 0.03, 0.03]
n = 40000: [0.83, 0.82, 0.82, 0.82, 0.82]
```

And here is our JAX’d and JIT’d code.

```
A_indices, A_indptr, A_x, A = make_matrix(20)
times = timeit.repeat(lambda: etree(A.indices, A.indptr),number = 1)
print(f"n = {A.shape[0]}: {[round(t,2) for t in times]}")
A_indices, A_indptr, A_x, A = make_matrix(50)
times = timeit.repeat(lambda: etree(A.indices, A.indptr),number = 1)
print(f"n = {A.shape[0]}: {[round(t,2) for t in times]}")
A_indices, A_indptr, A_x, A = make_matrix(200)
times = timeit.repeat(lambda: etree(A.indices, A.indptr),number = 1)
print(f"n = {A.shape[0]}: {[round(t,2) for t in times]}")
parent, col_count= etree(A.indices, A.indptr)
L_indices, L_indptr = _symbolic_factor(A_indices, A_indptr)
A_indices, A_indptr, A_x, A = make_matrix(1000)
times = timeit.repeat(lambda: etree(A.indices, A.indptr),number = 1)
print(f"n = {A.shape[0]}: {[round(t,2) for t in times]}")
```

```
n = 400: [0.13, 0.0, 0.0, 0.0, 0.0]
n = 2500: [0.12, 0.0, 0.0, 0.0, 0.0]
n = 40000: [0.14, 0.02, 0.02, 0.02, 0.02]
n = 1000000: [2.24, 2.11, 2.12, 2.12, 2.12]
```

You can see that there is some decent speedup. For the first three examples, the first call is dominated by the compilation time, but when the matrix has a million unknowns the compilation time is negligible. At this scale it would probably be worth using the fancy algorithm. That said, it is probably not worth sweating a two-second computation that is only done once when your problem is that big!

### The non-zero pattern of \(L\)

Now that we know how many non-zeros there are, it’s time to populate them. Last time, I used some dynamic memory allocation to make this work, but JAX is certainly not going to allow me to do that. So instead I’m going to have to do the worst thing possible: think.

The way that we went about it last time was, to be honest, a bit arse-backwards. The main reason for this is that I did not have access to the elimination tree. But now that we do, we can actually use it.

The trick is to slightly rearrange^{16} the order of operations to get something that is more convenient for working out the structure.

Recall from last time that we used the *left-looking* Cholesky factorisation, which in the dense case builds \(L\) one column at a time.
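A sketch of a dense left-looking Cholesky, written with plain NumPy. It is one reasonable way to write the `dense_left_cholesky` used below, not necessarily the listing from the previous post. Each column needs one matrix-vector product with the columns to its left.

```python
import numpy as np

def dense_left_cholesky(A):
    # Left-looking dense Cholesky: build L one column at a time.
    n = A.shape[0]
    L = np.zeros_like(A, dtype=float)
    for j in range(n):
        # Diagonal entry: remove the contributions of the columns to the left.
        L[j, j] = np.sqrt(A[j, j] - L[j, :j] @ L[j, :j])
        # Below the diagonal: a matrix-vector product with the columns to the left.
        L[j + 1:, j] = (A[j + 1:, j] - L[j + 1:, :j] @ L[j, :j]) / L[j, j]
    return L
```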

This is not the only way to organise those operations. An alternative is the *up-looking* Cholesky factorisation, which in the dense case builds \(L\) one row at a time.
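A matching sketch of a dense up-looking Cholesky, assuming SciPy’s `scipy.linalg.solve_triangular` for the triangular solves. Again, this is one plausible `dense_up_cholesky`, not necessarily the original listing. Since \(A = LL^T\), the part of row \(i\) to the left of the diagonal solves \(L_{:i,:i} x = A_{:i,i}\), and the diagonal entry is whatever is left of \(A_{ii}\).

```python
import numpy as np
from scipy.linalg import solve_triangular

def dense_up_cholesky(A):
    # Up-looking dense Cholesky: build L one row at a time.
    n = A.shape[0]
    L = np.zeros_like(A, dtype=float)
    L[0, 0] = np.sqrt(A[0, 0])
    for i in range(1, n):
        # Row i left of the diagonal: a lower-triangular solve with the rows above.
        x = solve_triangular(L[:i, :i], A[:i, i], lower=True)
        L[i, :i] = x
        # Diagonal entry: what is left of A[i, i] after removing ||x||^2.
        L[i, i] = np.sqrt(A[i, i] - x @ x)
    return L
```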

This is quite a different looking beast! It scans row by row rather than column by column. And while the left-looking algorithm is based on matrix-vector multiplies, the up-looking algorithm is based on triangular solves. So maybe we should pause for a moment to check that these are the same algorithm!

```
A = np.random.rand(15, 15)
A = A + A.transpose()
A = A.transpose() @ A + 1*np.eye(15)
L_left = dense_left_cholesky(A)
L_up = dense_up_cholesky(A)
print(round(np.abs(L_left - L_up).sum(), 2))
```

`0.0`

They are the same!!

The reason for considering the up-looking algorithm is that it gives a slightly nicer description of the non-zeros of row `i`, which will let us find the location of the non-zeros in the whole matrix. In particular, the non-zeros to the left of the diagonal on row `i` correspond to the non-zero indices of the solution to the lower triangular linear system^{17}
\[
L_{1:(i-1),1:(i-1)} x^{(i)} = A_{1:(i-1), i}.
\]
Because \(A\) is sparse, the right-hand side of this system has only \(\operatorname{nnz}(A_{1:(i-1),i})\) non-zero entries, rather than the \(i-1\) that we would have in the dense case. That means that the sparsity pattern of \(x^{(i)}\) is built up from the sparsity patterns of the columns of \(L_{1:(i-1),1:(i-1)}\), starting from the columns that correspond to the non-zero entries of \(A_{1:(i-1), i}\).

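This means two things. Firstly, if \(A_{ji}\neq 0\), then \(x^{(i)}_j \neq 0\). Secondly, if \(x^{(i)}_j \neq 0\) *and* \(L_{kj}\neq 0\), then \(x^{(i)}_k \neq 0\). (As usual for symbolic factorisations, both statements ignore accidental numerical cancellation.) These two facts give us a way of finding the non-zero set of \(x^{(i)}\) if we remember just one more fact about the elimination tree: the parent of \(j\) is the first off-diagonal non-zero in column \(j\) of \(L\), so \(L_{\operatorname{parent}(j), j} \neq 0\), and \(L_{kj} \neq 0\) only if \(k\) is an ancestor of \(j\).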

This reduces the problem of finding the non-zero elements of \(x^{(i)}\) to the problem of finding all of the nodes in \(\{j < i: A_{ji} \neq 0\}\) together with their ancestors in the subtree \(\mathcal{T}_{i-1}\). And if there is one thing that people who are ok at programming are *excellent* at, it is walking up a damn tree.
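Here is a minimal sketch of that tree walk for a single row, using a hypothetical helper `row_pattern` and a 4-by-4 example of our own. In the example, the (0-based) entry `L[3, 2]` is fill-in: it is zero in \(A\) but the walk \(0 \to 2 \to 3\) finds it, and a dense numerical factor agrees.

```python
import numpy as np

def row_pattern(i, above_diag, parent):
    # Structural non-zeros of L[i, :i]: start from every j < i with
    # A[j, i] != 0 and walk up the elimination tree until we reach i
    # (or a node already visited while processing this row).
    seen = set()
    for node in above_diag:
        while node < i and node not in seen:
            seen.add(node)
            node = parent[node]
    return sorted(seen)

# A 4-by-4 SPD example with A[0, 2], A[1, 2] and A[0, 3] non-zero (plus symmetry).
A = 4.0 * np.eye(4)
for (r, c) in [(0, 2), (1, 2), (0, 3)]:
    A[r, c] = A[c, r] = 1.0
parent = [2, 2, 3, -1]  # elimination tree of this A (read off its Cholesky factor)
L = np.linalg.cholesky(A)

print(row_pattern(3, [0], parent))  # the walk visits 0 and then 2
print(np.nonzero(L[3, :3])[0])      # the numerical factor has the same pattern
```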

So let’s do that. Well, I’ve already done it. In fact, that was how I found the column counts in the first place! With this interpretation, the outer loop is taking us across the rows. And once I am in row `j`^{18}, I find a starting node `node` (which is a non-zero in \(A_{1:(i-1),i}\)) and I walk up the tree from that node, checking each time if I’ve actually seen that node^{19} before. If I haven’t seen it before, I add one to the column count of column `node`^{20}.

To allocate the non-zero structure, I just need to replace that counter increment with an assignment.

### Attempt 1: Lord that’s slow

We will do the pure python version first.

```
def symbolic_cholesky_base(A_indices, A_indptr, parent, col_count):
    n = len(A_indptr) - 1
    col_ptr = np.repeat(1, n+1)
    col_ptr[1:] += np.cumsum(col_count)
    L_indices = np.zeros(sum(col_count), dtype=int)
    L_indptr = np.zeros(n+1, dtype=int)
    L_indptr[1:] = np.cumsum(col_count)
    mark = [-1] * n
    for i in range(n):
        mark[i] = i
        L_indices[L_indptr[i]] = i
        for indptr in range(A_indptr[i], A_indptr[i+1]):
            node = A_indices[indptr]
            while node < i and mark[node] != i:
                mark[node] = i
                L_indices[col_ptr[node]] = i
                col_ptr[node] += 1
                node = parent[node]
    return (L_indices, L_indptr)
```

Does it work?

```
A_indices, A_indptr, A_x, A = make_matrix(13)
parent, col_count = etree_base(A.indices, A.indptr)
L_indices, L_indptr = symbolic_cholesky_base(A.indices, A.indptr, parent, col_count)
L_indices_true, L_indptr_true = _symbolic_factor(A_indices, A_indptr)
print(all(x==y for (x,y) in zip(L_indices, L_indices_true)))
print(all(x==y for (x,y) in zip(L_indptr, L_indptr_true)))
```

```
True
True
```

Fabulosa!

Now let’s do the compiled version.

```
from functools import partial

@partial(jit, static_argnums = (4,))
def symbolic_cholesky(A_indices, A_indptr, L_indptr, parent, nnz):
    ## innermost while loop
    def body_while(val):
        i, L_indices, L_indptr, node, parent, col_ptr, mark = val
        mark = mark.at[node].set(i)
        L_indices = L_indices.at[col_ptr[node]].set(i)
        col_ptr = col_ptr.at[node].add(1)
        return (i, L_indices, L_indptr, parent[node], parent, col_ptr, mark)

    def cond_while(val):
        i, L_indices, L_indptr, node, parent, col_ptr, mark = val
        return lax.bitwise_and(lax.lt(node, i), lax.ne(mark[node], i))

    ## Inner for loop
    def body_inner_for(indptr, val):
        i, A_indices, A_indptr, L_indices, L_indptr, parent, col_ptr, mark = val
        node = A_indices[indptr]
        i, L_indices, L_indptr, node, parent, col_ptr, mark = lax.while_loop(cond_while, body_while, (i, L_indices, L_indptr, node, parent, col_ptr, mark))
        return (i, A_indices, A_indptr, L_indices, L_indptr, parent, col_ptr, mark)

    ## Outer for loop
    def body_out_for(i, val):
        A_indices, A_indptr, L_indices, L_indptr, parent, col_ptr, mark = val
        mark = mark.at[i].set(i)
        L_indices = L_indices.at[L_indptr[i]].set(i)
        i, A_indices, A_indptr, L_indices, L_indptr, parent, col_ptr, mark = lax.fori_loop(A_indptr[i], A_indptr[i+1], body_inner_for, (i, A_indices, A_indptr, L_indices, L_indptr, parent, col_ptr, mark))
        return (A_indices, A_indptr, L_indices, L_indptr, parent, col_ptr, mark)

    ## Body of code
    n = len(A_indptr) - 1
    col_ptr = L_indptr + 1
    L_indices = jnp.zeros(nnz, dtype=int)
    mark = jnp.repeat(-1, n)
    init = (A_indices, A_indptr, L_indices, L_indptr, parent, col_ptr, mark)
    A_indices, A_indptr, L_indices, L_indptr, parent, col_ptr, mark = lax.fori_loop(0, n, body_out_for, init)
    return L_indices
```

Now let’s check that it works.

```
A_indices, A_indptr, A_x, A = make_matrix(20)
parent, col_count = etree(A.indices, A.indptr)
L_indptr = np.zeros(A.shape[0]+1, dtype=int)
L_indptr[1:] = np.cumsum(col_count)
L_indices = symbolic_cholesky(A.indices, A.indptr, L_indptr, parent, nnz = L_indptr[-1])
L_indices_true, L_indptr_true = _symbolic_factor(A_indices, A_indptr)
print(all(L_indices == L_indices_true))
print(all(L_indptr == L_indptr_true))
A_indices, A_indptr, A_x, A = make_matrix(31)
parent, col_count = etree(A.indices, A.indptr)
L_indptr = np.zeros(A.shape[0]+1, dtype=int)
L_indptr[1:] = np.cumsum(col_count)
L_indices = symbolic_cholesky(A.indices, A.indptr, L_indptr, parent, nnz = L_indptr[-1])
L_indices_true, L_indptr_true = _symbolic_factor(A_indices, A_indptr)
print(all(L_indices == L_indices_true))
print(all(L_indptr == L_indptr_true))
```

```
True
True
```

```
True
True
```

Success!

One *minor* problem. This is slow. as. balls.

```
A_indices, A_indptr, A_x, A = make_matrix(50)
parent, col_count = etree_base(A.indices, A.indptr)
times = timeit.repeat(lambda: symbolic_cholesky_base(A.indices, A.indptr, parent, col_count),number = 1)
print(f"n = {A.shape[0]}: {[round(t,2) for t in times]}")
A_indices, A_indptr, A_x, A = make_matrix(200)
parent, col_count = etree_base(A.indices, A.indptr)
times = timeit.repeat(lambda: symbolic_cholesky_base(A.indices, A.indptr, parent, col_count),number = 1)
print(f"n = {A.shape[0]}: {[round(t,2) for t in times]}")
```

```
n = 2500: [0.05, 0.04, 0.04, 0.04, 0.04]
n = 40000: [1.97, 2.09, 2.04, 2.03, 1.92]
```

And here is our JAX’d and JIT’d code.

```
A_indices, A_indptr, A_x, A = make_matrix(50)
parent, col_count = etree(A.indices, A.indptr)
L_indptr = np.zeros(A.shape[0]+1, dtype=int)
L_indptr[1:] = np.cumsum(col_count)
times = timeit.repeat(lambda:symbolic_cholesky(A.indices, A.indptr, L_indptr, parent, nnz = L_indptr[-1]),number = 1, repeat = 1)
print(f"n = {A.shape[0]}: {[round(t,2) for t in times]}")
A_indices, A_indptr, A_x, A = make_matrix(200)
parent, col_count = etree(A.indices, A.indptr)
L_indptr = np.zeros(A.shape[0]+1, dtype=int)
L_indptr[1:] = np.cumsum(col_count)
times = timeit.repeat(lambda:symbolic_cholesky(A.indices, A.indptr, L_indptr, parent, nnz = L_indptr[-1]),number = 1, repeat = 1)
print(f"n = {A.shape[0]}: {[round(t,2) for t in times]}")
```

```
n = 2500: [0.15]
n = 40000: [29.19]
```

Oooof. Something is going horribly wrong.

### Why is it so slow?

The first thing to check is if it’s the compile time. We can do this by explicitly *lowering* the JIT’d function to its XLA representation and then compiling it.

```
A_indices, A_indptr, A_x, A = make_matrix(50)
parent, col_count = etree(A.indices, A.indptr)
L_indptr = np.zeros(A.shape[0]+1, dtype=int)
L_indptr[1:] = np.cumsum(col_count)
times = timeit.repeat(lambda: jit(partial(symbolic_cholesky, nnz=int(L_indptr[-1]))).lower(A.indices, A.indptr, L_indptr, parent).compile(),number = 1, repeat = 5)
print(f"Compilation time: n = {A.shape[0]}: {[round(t,2) for t in times]}")
A_indices, A_indptr, A_x, A = make_matrix(200)
parent, col_count = etree(A.indices, A.indptr)
L_indptr = np.zeros(A.shape[0]+1, dtype=int)
L_indptr[1:] = np.cumsum(col_count)
times = timeit.repeat(lambda: jit(partial(symbolic_cholesky, nnz=int(L_indptr[-1]))).lower(A.indices, A.indptr, L_indptr, parent).compile(),number = 1, repeat = 5)
print(f"Compilation time: n = {A.shape[0]}: {[round(t,2) for t in times]}")
```

```
Compilation time: n = 2500: [0.15, 0.15, 0.15, 0.16, 0.15]
Compilation time: n = 40000: [0.16, 0.15, 0.14, 0.16, 0.15]
```

It is not the compile time.

And that is actually a good thing because that suggests that we aren’t having problems with the compiler unrolling all of our wonderful loops! But that does mean that we have to look a bit deeper into the code. Some smart people would probably be able to look at the `jaxpr` intermediate representation to diagnose the problem. But I couldn’t see anything there.

Instead I thought: *if I were a clever, efficient compiler, what would I have problems with?* And the answer is the classic sparse matrix answer: indirect indexing.

The only structural difference between the `etree` function and the `symbolic_cholesky` function is this line in the `body_while()` function: `L_indices = L_indices.at[col_ptr[node]].set(i)`.

In order to evaluate this code, the compiler has to resolve *two levels* of indirection. By contrast, the indexing in `etree()` only ever went one level deep. So let’s see what happens if we take the same function and remove that double indirection.

```
@partial(jit, static_argnums = (4,))
def test_fun(A_indices, A_indptr, L_indptr, parent, nnz):
    ## innermost while loop
    def body_while(val):
        i, L_indices, L_indptr, node, parent, col_ptr, mark = val
        mark = mark.at[node].set(i)
        L_indices = L_indices.at[node].set(i)
        col_ptr = col_ptr.at[node].add(1)
        return (i, L_indices, L_indptr, parent[node], parent, col_ptr, mark)

    def cond_while(val):
        i, L_indices, L_indptr, node, parent, col_ptr, mark = val
        return lax.bitwise_and(lax.lt(node, i), lax.ne(mark[node], i))

    ## Inner for loop
    def body_inner_for(indptr, val):
        i, A_indices, A_indptr, L_indices, L_indptr, parent, col_ptr, mark = val
        node = A_indices[indptr]
        i, L_indices, L_indptr, node, parent, col_ptr, mark = lax.while_loop(cond_while, body_while, (i, L_indices, L_indptr, node, parent, col_ptr, mark))
        return (i, A_indices, A_indptr, L_indices, L_indptr, parent, col_ptr, mark)

    ## Outer for loop
    def body_out_for(i, val):
        A_indices, A_indptr, L_indices, L_indptr, parent, col_ptr, mark = val
        mark = mark.at[i].set(i)
        L_indices = L_indices.at[L_indptr[i]].set(i)
        i, A_indices, A_indptr, L_indices, L_indptr, parent, col_ptr, mark = lax.fori_loop(A_indptr[i], A_indptr[i+1], body_inner_for, (i, A_indices, A_indptr, L_indices, L_indptr, parent, col_ptr, mark))
        return (A_indices, A_indptr, L_indices, L_indptr, parent, col_ptr, mark)

    ## Body of code
    n = len(A_indptr) - 1
    col_ptr = L_indptr + 1
    L_indices = jnp.zeros(nnz, dtype=int)
    mark = jnp.repeat(-1, n)
    init = (A_indices, A_indptr, L_indices, L_indptr, parent, col_ptr, mark)
    A_indices, A_indptr, L_indices, L_indptr, parent, col_ptr, mark = lax.fori_loop(0, n, body_out_for, init)
    return L_indices

A_indices, A_indptr, A_x, A = make_matrix(50)
parent, col_count = etree(A.indices, A.indptr)
L_indptr = np.zeros(A.shape[0]+1, dtype=int)
L_indptr[1:] = np.cumsum(col_count)
times = timeit.repeat(lambda:test_fun(A.indices, A.indptr, L_indptr, parent, nnz = L_indptr[-1]),number = 1, repeat = 1)
print(f"n = {A.shape[0]}: {[round(t,2) for t in times]}")
A_indices, A_indptr, A_x, A = make_matrix(200)
parent, col_count = etree(A.indices, A.indptr)
L_indptr = np.zeros(A.shape[0]+1, dtype=int)
L_indptr[1:] = np.cumsum(col_count)
times = timeit.repeat(lambda:test_fun(A.indices, A.indptr, L_indptr, parent, nnz = L_indptr[-1]),number = 1, repeat = 1)
print(f"n = {A.shape[0]}: {[round(t,2) for t in times]}")
```

```
n = 2500: [0.14]
n = 40000: [0.17]
```

That isn’t conclusive, but it does indicate that this might^{21} be the problem.

And this is a *big* problem for us! The sparse Cholesky algorithm has similar amounts of indirection. So we need to fix it.

### Attempt 2: After some careful thought, things stayed the same

Now. I want to pretend that I’ve got elegant ideas about this. But I don’t. So let’s just do it. The most obvious thing to do is to use the algorithm to get the non-zero structure of the *rows* of \(L\). These are the things that are being indexed by `col_ptr[node]`, so if we have these explicitly we don’t need multiple indirection. We also don’t need a while loop.

In fact, if we have the non-zero structure of the rows of \(L\), we can turn that into the non-zero structure of the columns in linear-ish^{22} time.
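A minimal sketch of that conversion in plain NumPy (not JAX-traced), assuming the row structure is stored CSR-style as `row_indptr` plus the column indices `row_cols` of each row; these names and `row_pattern_to_col_pattern` are hypothetical. It is a counting sort: one pass counts the entries in each column and a second pass drops each row index into its column. Because the rows are visited in increasing order, the row indices within each column come out sorted.

```python
import numpy as np

def row_pattern_to_col_pattern(row_indptr, row_cols, n):
    # Input (CSR-style): row_cols[row_indptr[i]:row_indptr[i + 1]] holds the
    # column indices of the non-zeros in row i.
    # Output (CSC-style): col_rows[col_indptr[j]:col_indptr[j + 1]] holds the
    # row indices of the non-zeros in column j, in increasing order.
    row_cols = np.asarray(row_cols)
    col_count = np.bincount(row_cols, minlength=n)
    col_indptr = np.zeros(n + 1, dtype=int)
    col_indptr[1:] = np.cumsum(col_count)
    next_slot = col_indptr[:-1].copy()  # next free position in each column
    col_rows = np.empty(len(row_cols), dtype=int)
    for i in range(n):  # visiting rows in order keeps each column sorted
        for p in range(row_indptr[i], row_indptr[i + 1]):
            j = row_cols[p]
            col_rows[next_slot[j]] = i
            next_slot[j] += 1
    return col_indptr, col_rows

# Rows of a 3-by-3 lower-triangular pattern: row 0 = {0}, row 1 = {0, 1}, row 2 = {1, 2}.
col_indptr, col_rows = row_pattern_to_col_pattern([0, 1, 3, 5], [0, 0, 1, 1, 2], 3)
print(col_indptr, col_rows)
```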

All we need to do is make sure that our `etree()` function is also counting the number of non-zeros in each row.

```
@jit
def etree(A_indices, A_indptr):
    # print("(Re-)compiling etree(A_indices, A_indptr)")
    ## innermost while loop
    def body_while(val):
        # print(val)
        j, node, parent, col_count, row_count, mark = val
        update_parent = lambda x: x[0].at[x[1]].set(x[2])
        parent = lax.cond(lax.eq(parent[node], -1), update_parent, lambda x: x[0], (parent, node, j))
        mark = mark.at[node].set(j)
        col_count = col_count.at[node].add(1)
        row_count = row_count.at[j].add(1)
        return (j, parent[node], parent, col_count, row_count, mark)
    def cond_while(val):
        j, node, parent, col_count, row_count, mark = val
        return lax.bitwise_and(lax.lt(node, j), lax.ne(mark[node], j))
    ## Inner for loop
    def body_inner_for(indptr, val):
        j, A_indices, A_indptr, parent, col_count, row_count, mark = val
        node = A_indices[indptr]
        j, node, parent, col_count, row_count, mark = lax.while_loop(cond_while, body_while, (j, node, parent, col_count, row_count, mark))
        return (j, A_indices, A_indptr, parent, col_count, row_count, mark)
    ## Outer for loop
    def body_out_for(j, val):
        A_indices, A_indptr, parent, col_count, row_count, mark = val
        mark = mark.at[j].set(j)
        j, A_indices, A_indptr, parent, col_count, row_count, mark = lax.fori_loop(A_indptr[j], A_indptr[j+1], body_inner_for, (j, A_indices, A_indptr, parent, col_count, row_count, mark))
        return (A_indices, A_indptr, parent, col_count, row_count, mark)
    ## Body of code
    n = len(A_indptr) - 1
    parent = jnp.repeat(-1, n)
    mark = jnp.repeat(-1, n)
    col_count = jnp.repeat(1, n)
    row_count = jnp.repeat(1, n)
    init = (A_indices, A_indptr, parent, col_count, row_count, mark)
    A_indices, A_indptr, parent, col_count, row_count, mark = lax.fori_loop(0, n, body_out_for, init)
    return (parent, col_count, row_count)
```
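In the spirit of debugging JIT’d code with plain-Python translations of the control flow, here is a minimal pure-Python sketch of the same computation. The name `etree_python` is hypothetical. It assumes CSC storage of either the full symmetric matrix or just its upper triangle (only entries above the diagonal are used) and mirrors the JAX loops one for one, so it can be used to spot-check `etree()` on small examples.

```python
import numpy as np

def etree_python(A_indices, A_indptr):
    """Elimination tree and row/column counts of L (pure-Python reference sketch)."""
    n = len(A_indptr) - 1
    parent = [-1] * n
    mark = [-1] * n
    col_count = [1] * n  # every column of L contains its diagonal
    row_count = [1] * n  # every row of L contains its diagonal
    for j in range(n):
        mark[j] = j
        for ptr in range(A_indptr[j], A_indptr[j + 1]):
            node = int(A_indices[ptr])
            # Walk up the tree; stop at j or at a node already visited for row j.
            while node < j and mark[node] != j:
                if parent[node] == -1:
                    parent[node] = j
                mark[node] = j
                col_count[node] += 1  # L[j, node] != 0
                row_count[j] += 1
                node = parent[node]
    return parent, col_count, row_count

# Upper triangle (CSC) of a 4x4 arrow matrix with its dense row/column LAST.
arrow_last = etree_python(np.array([0, 1, 2, 0, 1, 2, 3]), np.array([0, 1, 2, 3, 7]))
# The same arrow with its dense row/column FIRST.
arrow_first = etree_python(np.array([0, 0, 1, 0, 2, 0, 3]), np.array([0, 1, 3, 5, 7]))

print(arrow_last)   # ([3, 3, 3, -1], [2, 2, 2, 1], [1, 1, 1, 4])
print(arrow_first)  # ([1, 2, 3, -1], [4, 3, 2, 1], [1, 2, 3, 4])
```

Both count vectors sum to \(\operatorname{nnz}(L)\): 7 for the first ordering and 10 (a completely dense \(L\)) for the second, which is a small reminder of why orderings matter.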

Let’s check that the code is actually doing what I want.

```
A_indices, A_indptr, A_x, A = make_matrix(57)
parent, col_count, row_count = etree(A.indices, A.indptr)
L_indices, L_indptr = _symbolic_factor(A_indices, A_indptr)
true_parent = L_indices[L_indptr[:-2] + 1]
true_parent[np.where(np.diff(L_indptr[:-1]) == 1)] = -1
print(all(x == y for (x,y) in zip(parent[:-1], true_parent)))
true_col_count = np.diff(L_indptr)
print(all(true_col_count == col_count))
true_row_count = np.array([len(np.where(L_indices == i)[0]) for i in range(57**2)])
print(all(true_row_count == row_count))
```

```
True
True
True
```

Excellent! With this we can modify our previous function to give us the row-indices of the non-zero pattern instead. Just for further chaos, please note that we are using a CSC representation of \(A\) to get a CSR representation of \(L\).

Once again, we will prototype in pure Python and then translate to JAX. The thing to look out for this time is that we *know* how many non-zeros there are in each row and we know where we need to put them. This suggests that we can compute these positions in `body_inner_for` and then do a vectorised version of our indirect indexing, which should compile down to a single XLA `scatter` call per row. This reduces the total number of `scatter` calls from \(\operatorname{nnz}(L)\) to \(n\). And hopefully this will fix things.
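The difference between many small scatters and one vectorised scatter is easy to see in isolation. This is a minimal JAX sketch with made-up positions (nothing here comes from the Cholesky code): the loop version issues one scalar `.at[k].set(...)` per entry, while the vectorised version hands the whole index vector to a single `.at[idx].set(...)`.

```python
import jax.numpy as jnp
from jax import lax

i = 2                             # the row being processed
positions = jnp.array([1, 4, 6])  # hypothetical slots in L_indices for row i
L_indices = jnp.full(8, -1, dtype=jnp.int32)

# One scalar scatter per entry, inside a loop.
def body(k, arr):
    return arr.at[positions[k]].set(i)

looped = lax.fori_loop(0, positions.shape[0], body, L_indices)

# The same update as a single vectorised scatter.
vectorised = L_indices.at[positions].set(i)

print(vectorised.tolist())  # [-1, 2, -1, -1, 2, -1, 2, -1]
```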

```
def symbolic_cholesky2_base(A_indices, A_indptr, L_indptr, row_count, parent, nnz):
    n = len(A_indptr) - 1
    col_ptr = L_indptr + 1
    L_indices = np.zeros(L_indptr[-1], dtype=int)
    mark = [-1] * n
    for i in range(n):
        mark[i] = i
        row_ind = np.repeat(nnz+1, row_count[i])
        row_ind[-1] = L_indptr[i]
        counter = 0
        for indptr in range(A_indptr[i], A_indptr[i+1]):
            node = A_indices[indptr]
            while node < i and mark[node] != i:
                mark[node] = i
                row_ind[counter] = col_ptr[node]
                col_ptr[node] += 1
                node = parent[node]
                counter +=1
        L_indices[row_ind] = i
    return L_indices

A_indices, A_indptr, A_x, A = make_matrix(13)
parent, col_count, row_count = etree(A.indices, A.indptr)
L_indptr = np.zeros(A.shape[0]+1, dtype=int)
L_indptr[1:] = np.cumsum(col_count)
L_indices = symbolic_cholesky2_base(A.indices, A.indptr, L_indptr, row_count, parent, L_indptr[-1])
L_indices_true, L_indptr_true = _symbolic_factor(A_indices, A_indptr)
print(all(x==y for (x,y) in zip(L_indices, L_indices_true)))
print(all(x==y for (x,y) in zip(L_indptr, L_indptr_true)))
```

```
True
True
```

Excellent. Now let’s JAX this. The JAX-heads among you will notice that we have a subtle^{23} problem: in a `fori_loop`, JAX does not treat `i` as static, which means that the length of the repeat (`row_count[i]`) is not known at trace time. JAX needs every array shape to be static, so this can’t be traced.

Shit.
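Here is a minimal reproduction with a hypothetical `row_count` vector. Inside `fori_loop` the loop index is a tracer, so `row_count[i]` is one too, and `jnp.repeat` cannot build an array whose length is only known at run time. The exact exception type and message depend on the JAX version, so the sketch just catches `Exception`.

```python
import jax.numpy as jnp
from jax import lax

row_count = jnp.array([1, 2, 3])  # hypothetical row counts

def body(i, total):
    # i is a tracer inside fori_loop, so row_count[i] is too,
    # and the output length of jnp.repeat is no longer static.
    row_ind = jnp.repeat(-1, row_count[i])
    return total + row_ind.shape[0]

try:
    lax.fori_loop(0, 3, body, 0)
    failed = False
except Exception as err:  # typically a ConcretizationTypeError
    failed = True
    print(type(err).__name__)

print(failed)  # True
```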

It is hard to think of a good option here. A few months back Junpeng Lao^{24} sent me a script with his attempts at making the Cholesky stuff JAX transformable. And he hit the same problem. I was, in an act of hubris, trying very hard to not end up here. But that was tragically slow. So here we are.

He came up with two methods.

1. Pad out `row_ind` so it’s always long enough. This only costs memory. In the worst case `row_ind` needs length `n`, and unfortunately that happens whenever \(A\) has a dense row. Sadly, for Bayesian^{25} linear mixed models this will happen if we put Gaussian priors on the covariate coefficients^{26} and we try to marginalise them out with the other multivariate Gaussian parts. It is possible to write routines that deal with dense rows and columns explicitly, but it’s a pain in the arse.
2. Do some terrifying work with `lax.scan` and dynamic slicing.
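The padding trick in the first option can be seen in miniature in this minimal JAX sketch. The values of `row_count` and `row_start` are made up and the names are hypothetical, chosen for illustration. Every iteration builds an index vector of the same static length `max_row`; slots beyond the row’s real count point at an out-of-bounds sentinel, and `mode="drop"` tells the scatter to ignore them.

```python
import jax.numpy as jnp
from jax import lax

nnz = 6             # length of the output (static)
max_row = 3         # static upper bound on the number of entries per row
sentinel = nnz + 1  # out-of-bounds index: scatters to it are dropped

row_count = jnp.array([1, 2, 3])  # hypothetical, only known at run time
row_start = jnp.array([0, 1, 3])  # hypothetical start of each row's entries

def body(i, out):
    slots = jnp.arange(max_row)
    # Fixed-length index vector: real positions first, sentinel padding after.
    row_ind = jnp.where(slots < row_count[i], row_start[i] + slots, sentinel)
    return out.at[row_ind].set(i, mode="drop")

out = lax.fori_loop(0, 3, body, jnp.full(nnz, -1, dtype=jnp.int32))
print(out.tolist())  # [0, 1, 1, 2, 2, 2]
```

The price is that every row pays for `max_row` slots, which is exactly the dense-row problem described in the first option.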

I’m going to try the first of these options.

```
@partial(jit, static_argnums = (5, 6))
def symbolic_cholesky2(A_indices, A_indptr, L_indptr, row_count, parent, nnz, max_row):
    ## innermost while loop
    def body_while(val):
        i, counter, row_ind, node, col_ptr, mark = val
        mark = mark.at[node].set(i)
        row_ind = row_ind.at[counter].set(col_ptr[node])
        col_ptr = col_ptr.at[node].add(1)
        return (i, counter + 1, row_ind, parent[node], col_ptr, mark)
    def cond_while(val):
        i, counter, row_ind, node, col_ptr, mark = val
        return lax.bitwise_and(lax.lt(node, i), lax.ne(mark[node], i))
    ## Inner for loop
    def body_inner_for(indptr, val):
        i, counter, row_ind, parent, col_ptr, mark = val
        node = A_indices[indptr]
        i, counter, row_ind, node, col_ptr, mark = lax.while_loop(cond_while, body_while, (i, counter, row_ind, node, col_ptr, mark))
        return (i, counter, row_ind, parent, col_ptr, mark)
    ## Outer for loop
    def body_out_for(i, val):
        L_indices, parent, col_ptr, mark = val
        mark = mark.at[i].set(i)
        row_ind = jnp.repeat(nnz+1, max_row)
        row_ind = row_ind.at[row_count[i]-1].set(L_indptr[i])
        counter = 0
        i, counter, row_ind, parent, col_ptr, mark = lax.fori_loop(A_indptr[i], A_indptr[i+1], body_inner_for, (i, counter, row_ind, parent, col_ptr, mark))
        L_indices = L_indices.at[row_ind].set(i, mode = "drop")
        return (L_indices, parent, col_ptr, mark)
    ## Body of code
    n = len(A_indptr) - 1
    col_ptr = jnp.array(L_indptr + 1)
    L_indices = jnp.ones(nnz, dtype=int) * (-1)
    mark = jnp.repeat(-1, n)
    ## Make everything a jnp array. Really should use jaxtyping
    A_indices = jnp.array(A_indices)
    A_indptr = jnp.array(A_indptr)
    L_indptr = jnp.array(L_indptr)
    row_count = jnp.array(row_count)
    parent = jnp.array(parent)
    init = (L_indices, parent, col_ptr, mark)
    L_indices, parent, col_ptr, mark = lax.fori_loop(0, n, body_out_for, init)
    return L_indices
```

Ok. Let’s see if that worked.

```
A_indices, A_indptr, A_x, A = make_matrix(20)
parent, col_count, row_count = etree(A.indices, A.indptr)
L_indptr = np.zeros(A.shape[0]+1, dtype=int)
L_indptr[1:] = np.cumsum(col_count)
L_indices = symbolic_cholesky2(A.indices, A.indptr, L_indptr, row_count, parent, nnz = int(L_indptr[-1]), max_row = int(max(row_count)))
L_indices_true, L_indptr_true = _symbolic_factor(A_indices, A_indptr)
print(all(L_indices == L_indices_true))
print(all(L_indptr == L_indptr_true))
A_indices, A_indptr, A_x, A = make_matrix(31)
parent, col_count, row_count = etree(A.indices, A.indptr)
L_indptr = np.zeros(A.shape[0]+1, dtype=int)
L_indptr[1:] = np.cumsum(col_count)
L_indices = symbolic_cholesky2(A.indices, A.indptr, L_indptr, row_count, parent, nnz = int(L_indptr[-1]), max_row = int(max(row_count)))
L_indices_true, L_indptr_true = _symbolic_factor(A_indices, A_indptr)
print(all(L_indices == L_indices_true))
print(all(L_indptr == L_indptr_true))
```

```
True
True
True
True
```

Ok. Once more into the breach. Is this any better?

```
A_indices, A_indptr, A_x, A = make_matrix(50)
parent, col_count, row_count = etree(A.indices, A.indptr)
L_indptr = np.zeros(A.shape[0]+1, dtype=int)
L_indptr[1:] = np.cumsum(col_count)
times = timeit.repeat(lambda:symbolic_cholesky2(A.indices, A.indptr, L_indptr, row_count, parent, nnz = int(L_indptr[-1]), max_row = int(max(row_count))),number = 1, repeat = 1)
print(f"n = {A.shape[0]}: {[round(t,2) for t in times]}")
A_indices, A_indptr, A_x, A = make_matrix(200)
parent, col_count, row_count = etree(A.indices, A.indptr)
L_indptr = np.zeros(A.shape[0]+1, dtype=int)
L_indptr[1:] = np.cumsum(col_count)
times = timeit.repeat(lambda:symbolic_cholesky2(A.indices, A.indptr, L_indptr, row_count, parent, nnz = int(L_indptr[-1]), max_row = int(max(row_count))),number = 1, repeat = 1)
print(f"n = {A.shape[0]}: {[round(t,2) for t in times]}")
# A_indices, A_indptr, A_x, A = make_matrix(300)
# parent, col_count, row_count = etree(A.indices, A.indptr)
# L_indptr = np.zeros(A.shape[0]+1, dtype=int)
# L_indptr[1:] = np.cumsum(col_count)
# times = timeit.repeat(lambda:symbolic_cholesky2(A.indices, A.indptr, L_indptr, row_count, parent, nnz = int(L_indptr[-1]), max_row = int(max(row_count))),number = 1, repeat = 1)
# print(f"n = {A.shape[0]}: {[round(t,2) for t in times]}")
```

`n = 2500: [0.28]`

`n = 40000: [28.31]`

Fuck.

### Attempt 3: A desperate attempt to make this bloody work

Right. Let’s try again. What if, instead of doing all those scatters, we just store two vectors and sort? Because at this point I will try fucking anything. What if we just list out the column index and row index as we find them (aka build the sparse matrix in COO^{27} format)? The `jax.experimental.sparse` module has support for (blocked) COO objects but doesn’t implement the conversion to CSC. `scipy.sparse` has a fast conversion routine, so I’m going to use it. In the interest of being 100% JAX, I tried a version with `index[1][jnp.lexsort((index[1], index[0]))]`, which basically does the same thing but is a lot slower.
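To see that a sort really does the COO-to-CSC conversion, here is a minimal sketch on a made-up pattern: a dense 4×4 lower-triangular \(L\), recorded as `(row, col)` pairs in the order a row-by-row sweep would find them. `np.lexsort` treats its *last* key as the primary one, so `np.lexsort((rows, cols))` orders by column and then by row, and reading `rows` in that order gives the CSC row indices. The scipy route is shown alongside; `sort_indices()` is called only so the comparison does not depend on scipy’s internal ordering.

```python
import numpy as np
from scipy import sparse

n = 4
# (row, col) pairs of a dense 4x4 lower-triangular L, in the order a
# row-by-row sweep finds them: diagonal first, then the earlier columns.
rows = np.array([0, 1, 1, 2, 2, 2, 3, 3, 3, 3])
cols = np.array([0, 1, 0, 2, 0, 1, 3, 0, 1, 2])

# Sort by column, then by row. np.lexsort uses the LAST key as the primary key.
perm = np.lexsort((rows, cols))
csc_indices = rows[perm]
csc_indptr = np.concatenate(([0], np.cumsum(np.bincount(cols, minlength=n))))

# The scipy route: build a COO array and let scipy convert it to CSC.
L = sparse.coo_array((np.ones(len(rows)), (rows, cols)), shape=(n, n)).tocsc()
L.sort_indices()

print(csc_indices.tolist())  # [0, 1, 2, 3, 1, 2, 3, 2, 3, 3]
print(csc_indptr.tolist())   # [0, 4, 7, 9, 10]
```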

```
def symbolic_cholesky3(A_indices, A_indptr, L_indptr, parent, nnz):
    @partial(jit, static_argnums = (4,))
    def _inner(A_indices_, A_indptr_, L_indptr, parent, nnz):
        ## Make everything a jnp array. Really should use jaxtyping
        A_indices_ = jnp.array(A_indices_)
        A_indptr_ = jnp.array(A_indptr_)
        L_indptr = jnp.array(L_indptr)
        parent = jnp.array(parent)
        ## innermost while loop
        def body_while(val):
            index, i, counter, node, mark = val
            mark = mark.at[node].set(i)
            index[0] = index[0].at[counter].set(node) #column
            index[1] = index[1].at[counter].set(i) # row
            return (index, i, counter + 1, parent[node], mark)
        def cond_while(val):
            index, i, counter, node, mark = val
            return lax.bitwise_and(lax.lt(node, i), lax.ne(mark[node], i))
        ## Inner for loop
        def body_inner_for(indptr, val):
            index, i, counter, mark = val
            node = A_indices_[indptr]
            index, i, counter, node, mark = lax.while_loop(cond_while, body_while, (index, i, counter, node, mark))
            return (index, i, counter, mark)
        ## Outer for loop
        def body_out_for(i, val):
            index, counter, mark = val
            mark = mark.at[i].set(i)
            index[0] = index[0].at[counter].set(i)
            index[1] = index[1].at[counter].set(i)
            counter = counter + 1
            index, i, counter, mark = lax.fori_loop(A_indptr_[i], A_indptr_[i+1], body_inner_for, (index, i, counter, mark))
            return (index, counter, mark)
        ## Body of code
        n = len(A_indptr_) - 1
        mark = jnp.repeat(-1, n)
        index = [jnp.zeros(nnz, dtype=int), jnp.zeros(nnz, dtype=int)]
        counter = 0
        init = (index, counter, mark)
        index, counter, mark = lax.fori_loop(0, n, body_out_for, init)
        return index
    n = len(A_indptr) - 1
    index = _inner(A_indices, A_indptr, L_indptr, parent, nnz)
    ## Pure-JAX alternative (slower): return index[1][jnp.lexsort((index[1], index[0]))]
    return sparse.coo_array((np.ones(nnz), (index[1], index[0])), shape = (n,n)).tocsc().indices
```

First things first, let’s check that it works and how fast it is.

```
A_indices, A_indptr, A_x, A = make_matrix(15)
parent, col_count, row_count = etree(A.indices, A.indptr)
L_indptr = np.zeros(A.shape[0]+1, dtype=int)
L_indptr[1:] = np.cumsum(col_count)
L_indices = symbolic_cholesky3(A.indices, A.indptr, L_indptr, parent, nnz = int(L_indptr[-1]))
L_indices_true, L_indptr_true = _symbolic_factor(A_indices, A_indptr)
print(all(L_indices == L_indices_true))
print(all(L_indptr == L_indptr_true))
A_indices, A_indptr, A_x, A = make_matrix(31)
parent, col_count, row_count = etree(A.indices, A.indptr)
L_indptr = np.zeros(A.shape[0]+1, dtype=int)
L_indptr[1:] = np.cumsum(col_count)
L_indices = symbolic_cholesky3(A.indices, A.indptr, L_indptr, parent, nnz = int(L_indptr[-1]))
L_indices_true, L_indptr_true = _symbolic_factor(A_indices, A_indptr)
print(all(L_indices == L_indices_true))
print(all(L_indptr == L_indptr_true))
A_indices, A_indptr, A_x, A = make_matrix(50)
parent, col_count, row_count = etree(A.indices, A.indptr)
L_indptr = np.zeros(A.shape[0]+1, dtype=int)
L_indptr[1:] = np.cumsum(col_count)
times = timeit.repeat(lambda: symbolic_cholesky3(A.indices, A.indptr, L_indptr, parent, nnz = int(L_indptr[-1])),number = 1, repeat = 5)
print(f"n = {A.shape[0]}: {[round(t,2) for t in times]}")
A_indices, A_indptr, A_x, A = make_matrix(200)
parent, col_count, row_count = etree(A.indices, A.indptr)
L_indptr = np.zeros(A.shape[0]+1, dtype=int)
L_indptr[1:] = np.cumsum(col_count)
times = timeit.repeat(lambda: symbolic_cholesky3(A.indices, A.indptr, L_indptr, parent, nnz = int(L_indptr[-1])),number = 1, repeat = 5)
print(f"n = {A.shape[0]}: {[round(t,2) for t in times]}")
A_indices, A_indptr, A_x, A = make_matrix(300)
parent, col_count, row_count = etree(A.indices, A.indptr)
L_indptr = np.zeros(A.shape[0]+1, dtype=int)
L_indptr[1:] = np.cumsum(col_count)
times = timeit.repeat(lambda: symbolic_cholesky3(A.indices, A.indptr, L_indptr, parent, nnz = int(L_indptr[-1])),number = 1, repeat = 5)
print(f"n = {A.shape[0]}: {[round(t,2) for t in times]}")
A_indices, A_indptr, A_x, A = make_matrix(1000)
parent, col_count, row_count = etree(A.indices, A.indptr)
L_indptr = np.zeros(A.shape[0]+1, dtype=int)
L_indptr[1:] = np.cumsum(col_count)
times = timeit.repeat(lambda: symbolic_cholesky3(A.indices, A.indptr, L_indptr, parent, nnz = int(L_indptr[-1])),number = 1, repeat = 1)
print(f"n = {A.shape[0]}: {[round(t,2) for t in times]}")
```

```
True
True
True
True
```

`n = 2500: [0.13, 0.13, 0.13, 0.14, 0.14]`

`n = 40000: [0.19, 0.19, 0.19, 0.19, 0.19]`

`n = 90000: [0.43, 0.32, 0.32, 0.33, 0.33]`

`n = 1000000: [11.91]`

You know what? I’ll take it. It’s not perfect; in particular, I would prefer a pure JAX solution. But everything I tried hit hard against the indirect memory access issue. The best I found used `jnp.lexsort`, but even that slowed down noticeably relative to the scipy solution as `nnz` increased.
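For reference, the pure-JAX route can be sketched as follows. This is a minimal sketch, not a tuned or benchmarked version: it assumes the `(column, row)` pairs arrive as two equal-length integer arrays and that `n` is static. `coo_to_csc_pattern` is a hypothetical name. `jnp.lexsort` follows NumPy’s convention (the last key is the primary one), and `jnp.bincount` needs a static `length` under `jit`.

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnums=(2,))
def coo_to_csc_pattern(cols, rows, n):
    """Sparsity pattern only: returns (indptr, indices) of the CSC form."""
    perm = jnp.lexsort((rows, cols))  # primary key: column; secondary: row
    indices = rows[perm]
    counts = jnp.bincount(cols, length=n)  # entries per column (static length)
    indptr = jnp.concatenate([jnp.zeros(1, dtype=counts.dtype), jnp.cumsum(counts)])
    return indptr, indices

# A dense 4x4 lower-triangular pattern, listed row by row.
rows = jnp.array([0, 1, 1, 2, 2, 2, 3, 3, 3, 3])
cols = jnp.array([0, 1, 0, 2, 0, 1, 3, 0, 1, 2])
indptr, indices = coo_to_csc_pattern(cols, rows, 4)
print(indptr.tolist())   # [0, 4, 7, 9, 10]
print(indices.tolist())  # [0, 1, 2, 3, 1, 2, 3, 2, 3, 3]
```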

# Next time on Sparse Matrices with Dan

So that’s where I’m going to leave it. I am off my flight and I’ve slept very well and now I’m going to be on holidays for a little while.

The next big thing to do is look at the numerical factorisation. We are going to run headlong into all of the problems we’ve hit today, so that should be fun. The reason I’m putting it in its own post^{28} is that I want to actually test all of those things out properly.

So next time you can expect:

- Classes! Because frankly this code is getting far too messy, especially now that certain things need to be passed as static arguments. The only reason I’ve avoided it up to now is that I think it hides too much of the algorithm in boilerplate. But now the boilerplate is ruining my life and causing far too many dumb typos^{29}.
- Type hints! Because for a language where types aren’t explicit, they sure are important. Also, because I’m going to class it up, I might as well do it properly.
- Some helper routines! I’m going to need a sparse-matrix scatter operation (aka the structured copy of \(A\) to have the sparsity pattern of \(L\)), and I’m certainly going to need some reorderings^{30}.
- A battle royale between padded and non-padded methods!

It should be a fun time!

## Footnotes

1. If you’re wondering about the break between sparse matrix posts, I realised this pretty much immediately and just didn’t want to deal with it!
2. If a person who actually knows how the JAX autodiff works happens across this blog, I’m so sorry.
3. omg you guys. So many details
4. These are referred to as HLOs (Higher-level operations)
5. Instead of doing one pass of reverse mode, you would need to do \(d\) passes of forward mode to get the gradient with respect to a \(d\)-dimensional parameter.
6. Unlike `jax.lax.while_loop`, which is only forward-mode differentiable, `jax.lax.scan` is fully differentiable.
7. In general, if the function has state.
8. This is the version of the symbolic factorisation that is most appropriate for us, as we will be doing a lot of Cholesky factorisations with the same sparsity structure. If we rearrange the algorithm to the up-looking Cholesky decomposition, we only need the column counts and this is also called the symbolic factorisation. This is, incidentally, how Eigen’s sparse Cholesky works.
9. Actually it’s a forest
10. Because we are talking about a tree, each child node has at most one parent. If it doesn’t have a parent it’s the root of the tree. I remember a lecturer saying that it should be called “father and son” or “mother and daughter” because every child has 2 parents but only one mother or one father. The 2000s were a wild time.
11. These can also be computed in approximately \(\mathcal{O}(\operatorname{nnz}(A))\) time, which is much faster. But the algorithm is, frankly, pretty tricky and I’m not in the mood to program it up. This difference would be quite important if I wasn’t storing the full symbolic factorisation and was instead computing it every time, but in my context it is less clear that this is worth the effort.
12. Python notation! This is rows/cols 0 to `j-1`
13. Python, it turns out, does not have a `do while` construct because, apparently, everything is empty and life is meaningless.
14. The argument for JIT works by amortizing the compile time over several function evaluations. If I wanted to speed this algorithm up, I’d implement the more complex \(\mathcal{O}(\operatorname{nnz}(A))\) version.
15. Obviously it did not work the first time. A good way to debug JIT’d code is to use the pure-Python translations of the control-flow primitives. Why? Well, for one thing, there is an annoying tendency for JAX to fail silently when there is an out-of-bounds indexing error. Which happens, just for example, if you replace `node = A_indices[indptr]` with `node = A_indices[A_indptr[indptr]]` because you got a text message halfway through the line.
16. We will still use the left-looking algorithm for the numerical computation. The two algorithms are equivalent in exact arithmetic and, in particular, have identical sparsity structures.
17. I’m mixing 1-based indexing in the maths with 0-based in the code because I think we need more chaos in our lives.
18. Yes. I know. I’m swapping the meaning of \(i\) and \(j\), but that’s because in a symmetric matrix rows and columns are a bit similar. The upper half of column \(j\) is the left half of row \(j\) after all.
19. If `mark[node]==j` then I have already found `node` and all of its ancestors in my sweep of row `j`.
20. This is because `L[j,node] != 0` by our logic.
21. I mean, I’m pretty sure it is. I’m writing this post in order, so I don’t know yet. But surely the compiler can’t reason about the possible values of `node`, which would be the only thing that would speed this up.
22. Convert from CSR to `(i, j, val)` (called COO, which has a convenient implementation in `jax.experimental.sparse`) to CSC. This involves a linear pass, a sort, and another linear pass. So it’s \(n \log n\)-ish. Hire me, fancy tech companies. I can count. Just don’t ask me to program quicksort.
23. Replace “subtle” with “fairly obvious once I realised how it’s converted to a `lax.scan`, but not at all obvious to me originally”.
24. Who demanded a footnote.
25. This also happens with the profile likelihood in non-Bayesian methods.
26. the \(\beta\)s
27. COO stands for *coordinate list* and it’s the least space-efficient of our options. It directly stores three length-`nnz` vectors `(row, col, value)`. It’s great for specifying matrices and it’s pretty easy to convert from this format to any of the others.
28. other than holiday
29. `A_index` and `A.index` are different
30. I’m probably going to bind Eigen’s AMD ordering. I’m certainly not writing it myself.

## Citation

```
@online{simpson2022,
author = {Dan Simpson},
editor = {},
title = {Sparse Matrices Part 7a: {Another} Shot at {JAX-ing} the
{Cholesky} Decomposition},
date = {2022-12-02},
url = {https://dansblog.netlify.app/posts/2022-11-27-sparse7/sparse7.html},
langid = {en}
}
```