Sparse matrices part 7a: Another shot at JAX-ing the Cholesky decomposition

I work in R a lot so I should be used to weird syntax. This part looks at the non-zero pattern.

JAX
Sparse matrices
Autodiff
Author

Dan Simpson

Published

December 2, 2022

The time has come once more to resume my journey into sparse matrices. There’s been a bit of a pause, mostly because I realised that I didn’t know how to implement the sparse Cholesky factorisation in a JAX-traceable way. But now the time has come. It is time for me to get on top of JAX’s weird control-flow constructs.

And, along the way, I’m going to re-do the sparse Cholesky factorisation to make it, well, better.

In order to temper expectations, I will tell you that this post does not do the numerical factorisation, only the symbolic one. Why? Well I wrote most of it on a long-haul flight and I didn’t get to the numerical part. And this was long enough. So hold your breaths for Part 7b, which will come as soon as I write it.

You can consider this a much better re-do of Part 2. This is no longer my first python coding exercise in a decade, so hopefully the code is better. And I’m definitely trying a lot harder to think about the limitations of JAX.

Before I start, I should probably say why I’m doing this. JAX is a truly magical thing that will compute gradients and everything else just by clever processing of the Jacobian-vector product code. Unfortunately, this is only possible if the Jacobian-vector product code is JAX traceable, and this code is structurally extremely similar1 to the code for the sparse Cholesky factorisation.

I am doing this in the hope of (eventually getting to) autodiff. But that won’t be this blog post. This blog post is complicated enough.

Control flow of the damned

The first and most important rule of programming with JAX is that loops will break your heart. I mean, whatever, I guess they’re fine. But there’s a problem. Imagine the following function

def f(x: jax.Array, n: int) -> jax.Array:
  out = jnp.zeros_like(x)
  for j in range(n):
    out = out + x
  return out

This is, basically, the worst implementation of multiplication by an integer that you can possibly imagine. This code will run fine in Python, but if you try to JIT compile it, JAX is gonna get angry. It will produce the machine code equivalent of

def f_n(x):
  out = x
  out = out + x
  out = out + x
  # ... and so on, until x has been added n times in total
  return out

There are two bad things happening here. First, note that the “compiled” code depends on n and will have to be compiled anew each time n changes. Secondly, the loop has been replaced by n copies of the loop body. This is called loop unrolling and, when used judiciously by a clever compiler, is a great way to speed up code. When done completely for every loop this is a nightmare and the corresponding code will take a geological amount of time to compile.
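If you want to see the unrolling with your own eyes, you can peek at the jaxpr that tracing produces. (A quick sketch; the exact printout depends on your JAX version.)

import jax
from jax import numpy as jnp

def f(x, n):
  out = jnp.zeros_like(x)
  for j in range(n):
    out = out + x
  return out

# One add primitive per iteration: the traced program grows with n, and
# every new value of n triggers a fresh trace and compile.
print(jax.make_jaxpr(lambda x: f(x, 3))(1.0))
print(jax.make_jaxpr(lambda x: f(x, 6))(1.0))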

A similar thing2 happens when you need to run autodiff on f(x,n). For each n an expression graph is constructed that contains the unrolled for loop. This suggests that autodiff might also end up being quite slow (or, more problematically, more memory-hungry).

So the first rule of JAX is to avoid for loops. But if you can’t do that, there are three built-in loop structures that play nicely with JIT compilation and sometimes3 differentiation. These three constructs are

  1. A while loop jax.lax.while_loop(cond_func, body_func, init)
  2. An accumulator jax.lax.scan(body_func, init, xs)
  3. A for loop jax.lax.fori_loop(lower, upper, body_fun, init)

Of those three, the first and third work mostly as you’d expect, while the second is a bit more hairy. The while_loop function is roughly equivalent to

def jax_lax_while_loop(cond_func, body_func, init):
  x  = init
  while cond_func(x):
    x = body_func(x)
  return x

So basically it’s just a while loop. The thing that’s important is that it compiles down to a single XLA operation4 instead of some unrolled mess.

One thing that is important to realise is that while loops are only forwards-mode differentiable, which means that it is very expensive5 to compute gradients. The reason for this is that we simply do not know how long that loop actually is and so it’s impossible to build a fixed-size expression graph.
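To make that concrete, here’s a toy sketch (a made-up function, not anything real): the loop jits and forward-differentiates just fine, but reverse mode will refuse.

import jax
from jax import lax

def g(x):
  # The carry is (counter, value); we multiply by x three times, so g(x) = x**3.
  def cond_fun(val):
    i, _ = val
    return i < 3

  def body_fun(val):
    i, y = val
    return (i + 1, y * x)

  return lax.while_loop(cond_fun, body_fun, (0, 1.0))[1]

print(jax.jit(g)(2.0))     # 8.0
print(jax.jacfwd(g)(2.0))  # 12.0
# jax.grad(g)(2.0) would raise an error: while_loop is not reverse-mode differentiable.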

The jax.lax.scan function is probably the one that people will be least familiar with. That said, it’s also the one that is roughly “how a for loop should work”. The concept that’s important here is a for-loop with carry over. Carry over is information that changes from one step of the loop to the next. This is what separates us from a map statement, which would apply the same function independently to each element of a list.

The scan function looks like

def jax_lax_scan(body_func, init, xs):
  len_x0 = len(xs[0])
  if not all(len(x) == len_x0 for x in xs):
    raise ValueError("All x must have the same length!!")
  carry = init
  ys = []
  for x in xs:
    carry, y = body_func(carry, x)
    ys.append(y)
  
  return carry, np.stack(ys)

A critically important limitation of jax.lax.scan is that every x in xs must have the same shape! This means, for example, that

xs = [[1], [2,3], [4], [5,6,7]]

is not a valid argument. Like all limitations in JAX, this serves to make the code transformable into efficiently compiled code across various different processors.

For example, if I wanted to use jax.lax.scan on my example from before I would get

from jax import lax
from jax import numpy as jnp

def f(x, n):
  init = jnp.zeros_like(x)
  xs = jnp.repeat(x, n)
  def body_func(carry, y):
    val = carry + y
    return (val, val)
  
  final, journey = lax.scan(body_func, init, xs)
  return (final, journey)

final, journey = f(1.2, 7)
print(final)
print(journey)
8.4
[1.2       2.4       3.6000001 4.8       6.        7.2       8.4      ]

This translation is a bit awkward compared to the for loop but it’s the sort of thing that you get used to.

This function can be differentiated6 and compiled. To differentiate it, I need a version that returns a scalar, which is easy enough to do with a lambda.

from jax import jit, grad

f2 = lambda x, n: f(x,n)[0]
f2_grad = grad(f2, argnums = 0)

print(f2_grad(1.2, 7))
7.0

The argnums option tells JAX that we are only differentiating wrt the first argument.

JIT compilation is a tiny bit more delicate. If we try the natural thing, we are going to get an error.

f_jit_bad = jit(f)
bad = f_jit_bad(1.2, 7)
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
When jit-compiling jnp.repeat, the total number of repeats must be static. To fix this, either specify a static value for `repeats`, or pass a static value to `total_repeat_length`.
The error occurred while tracing the function f at /var/folders/08/4p5p665j4d966tr7nvr0v24c0000gn/T/ipykernel_24749/3851190413.py:4 for jit. This concrete value was not available in Python because it depends on the value of the argument 'n'.

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError

In order to compile a function, JAX needs to know how big everything is. And right now it does not know what n is. This shows itself through the ConcretizationTypeError, which basically says that as JAX was looking through your code it found something it can’t manipulate. In this case, it was in the jnp.repeat function.

We can fix this problem by declaring this parameter static.

f_jit = jit(f, static_argnums=(1,))
print(f_jit(1.2,7)[0])
8.4

A static parameter is a parameter value that is known at compile time. If we define n to be static, then the first time you call f_jit(x, 7) it will compile and then it will reuse the compiled code for any other value of x. If we then call f_jit(x, 9), the code will compile again.

To see this, we can make use of a JAX oddity: if a function prints something7, it will only be printed upon compilation and never again. This means that we can’t debug by printing. But on the upside, it’s easy to check when things are compiling.

def f2(x, n):
  print(f"compiling: n = {n}")
  return f(x,n)[0]

f2_jit = jit(f2, static_argnums=(1,))
print(f2_jit(1.2,7))
print(f2_jit(1.8,7))
print(f2_jit(1.2,9))
print(f2_jit(1.8,7))
compiling: n = 7
8.4
12.6
compiling: n = 9
10.799999
12.6

This is a perfectly ok solution as long as the static parameters don’t change very often. In our context, this is going to have to do with the sparsity pattern.

Finally, we can talk about jax.lax.fori_loop, the in-built for loop. This is basically a convenience wrapper for jax.lax.scan (when lower and upper are static) or jax.lax.while_loop (when they are not). The Python pseudocode is

def jax_lax_fori_loop(lower, upper, body_func, init):
  out = init
  for i in range(lower, upper):
    out = body_func(i, out)
  return out
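For instance, my toy multiply-by-n function from before could be written with fori_loop as something like this (a sketch, and certainly not the only way to do it).

from jax import jit, lax
from jax import numpy as jnp

def f_fori(x, n):
  # The body takes the loop index and the carry; here we ignore the index.
  body = lambda i, out: out + x
  return lax.fori_loop(0, n, body, jnp.zeros_like(x))

print(jit(f_fori)(1.2, 7))  # roughly 8.4, same as before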

To close out this bit where I repeat the docs, there is also a traceable if/else: jax.lax.cond which has the pseudocode

def jax_lax_cond(pred, true_fun, false_fun, val):
  if pred:
    return true_fun(val)
  else:
    return false_fun(val)
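And just to see it in action, here’s a tiny sketch (a made-up function, nothing more).

from jax import jit, lax
from jax import numpy as jnp

# Both branches get traced, only one gets executed, and the predicate can be
# a traced value (unlike a Python if inside a jitted function).
def clip_negative(x):
  return lax.cond(x < 0, lambda v: jnp.zeros_like(v), lambda v: v, x)

print(jit(clip_negative)(-3.0))  # 0.0
print(jit(clip_negative)(2.5))   # 2.5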

Building a JAX-traceable symbolic sparse Cholesky factorisation

In order to build a JAX-traceable sparse Cholesky factorisation \(A = LL^T\), we are going to need to build up a few moving parts.

  1. Build the elimination tree of \(A\) and find the number of non-zeros in each column of \(L\)

  2. Build the symbolic factorisation8 of \(L\) (aka the location of the non-zeros of \(L\))

  3. Do the actual numerical decomposition.

In the previous post we did not explicitly form the elimination tree. Instead, I used dynamic memory allocation. This time I’m being more mature.

Building the elimination tree

The elimination tree9 \(\mathcal{T}_A\) is a (forest of) rooted tree(s) that compactly represent the non-zero pattern of the Cholesky factor \(L\). In particular, the elimination tree has the property that, for any \(k > j\) , \(L_{kj} \neq 0\) if and only if there is a path from \(j\) to \(k\) in the tree. Or, in the language of trees, \(L_{kj} \neq 0\) if and only if \(j\) is a descendant of \(k\) in the tree \(\mathcal{T}_A\).

We can describe10 \(\mathcal{T}_A\) by listing the parent of each node. The parent node of \(j\) in the tree is the smallest \(i > j\) with \(L_{ij} \neq 0\).

We can turn this into an algorithm. An efficient version, which is described in Tim Davis’s book, takes about \(\mathcal{O}(\operatorname{nnz}(A))\) operations. But I’m going to program up a slower one that takes \(\mathcal{O}(\operatorname{nnz}(L))\) operations, but has the added benefit11 of giving me the column counts for free.

To do this, we are going to walk the tree and dynamically add up the column counts as we go.

To start off, let’s do this in standard python so that we can see what the algorithm looks like. The key concept is that if we write \(\mathcal{T}_{j-1}\) for the elimination tree encoding the structure of12 L[:j, :j], then we can ask how this tree connects with node j.

A theorem gives a very simple answer to this.

Theorem 1 If \(j > i\), then \(A_{j,i} \neq 0\) implies that \(i\) is a descendant of \(j\) in \(\mathcal{T}_A\). In particular, that means that there is a directed path in \(\mathcal{T}_A\) from \(i\) to \(j\).

This tells us how \(\mathcal{T}_{j-1}\) connects to node \(j\): for each non-zero element \(i\) of the \(j\)th row of \(A\), we can walk up \(\mathcal{T}_{j-1}\) starting from \(i\) until we eventually reach a node that has no parent in \(\{0,\ldots, j-1\}\). Because there must be a path from \(i\) to \(j\) in \(\mathcal{T}_j\), the parent of that terminal node must be \(j\).

As with everything Cholesky related, this works because the algorithm proceeds from left to right, which in this case means that the node label associated with any descendant of \(j\) is always less than \(j\).

The algorithm is then a fairly run-of-the-mill13 tree traversal, where we keep track of where we have been so we don’t double count our columns.

Probably the most important thing here is that I am using the full sparse matrix rather than just its lower triangle. This is, basically, convenience. I need access to the left half of the \(j\)th row of \(A\), which is conveniently the same as the top half of the \(j\)th column. And sometimes you just don’t want to be dicking around with swapping between row- and column-based representations.

import numpy as np

def etree_base(A_indices, A_indptr):
  n = len(A_indptr) - 1
  parent = [-1] * n     # parent[j] == -1 means j is (so far) a root
  mark = [-1] * n       # mark[node] == j if node was already visited while processing column j
  col_count = [1] * n   # every column of L has at least its diagonal entry
  for j in range(n):
    mark[j] = j
    for indptr in range(A_indptr[j], A_indptr[j+1]):
      node = A_indices[indptr]
      # Walk up the tree from each non-zero in column j, stopping at j
      # or at a node we have already seen on this sweep.
      while node < j and mark[node] != j:
        if parent[node] == -1:
          parent[node] = j
        mark[node] = j
        col_count[node] += 1
        node = parent[node]
  return (parent, col_count)
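Before the big example, here is a matrix that’s small enough to check by hand (a toy pattern I made up for this).

from scipy import sparse
import numpy as np

# Column 1 couples to node 0 and column 3 couples to nodes 1 and 2, so the
# parents should be [1, 3, 3, -1] and every column except the last picks up
# exactly one below-diagonal non-zero.
A_small = sparse.csc_matrix(np.array([
    [2., 1., 0., 0.],
    [1., 2., 0., 1.],
    [0., 0., 2., 1.],
    [0., 1., 1., 2.]]))
print(etree_base(A_small.indices, A_small.indptr))
# ([1, 3, 3, -1], [2, 2, 2, 1])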

To convince ourselves this works, let’s run an example and compare the column counts we get to our previous method.

Some boilerplate from previous editions.
from scipy import sparse
import scipy as sp
    

def make_matrix(n):
  one_d = sparse.diags([[-1.]*(n-2), [2.]*n, [-1.]*(n-2)], [-2,0,2])
  A = (sparse.kronsum(one_d, one_d) + sparse.eye(n*n))
  A_csc = A.tocsc()
  A_csc.eliminate_zeros()
  A_lower = sparse.tril(A_csc, format = "csc")
  A_index = A_lower.indices
  A_indptr = A_lower.indptr
  A_x = A_lower.data
  return (A_index, A_indptr, A_x, A_csc)

def _symbolic_factor(A_indices, A_indptr):
  # Assumes A_indices and A_indptr index the lower triangle of $A$ ONLY.
  n = len(A_indptr) - 1
  L_sym = [np.array([], dtype=int) for j in range(n)]
  children = [np.array([], dtype=int) for j in range(n)]
  
  for j in range(n):
    L_sym[j] = A_indices[A_indptr[j]:A_indptr[j + 1]]
    for child in children[j]:
      tmp = L_sym[child][L_sym[child] > j]
      L_sym[j] = np.unique(np.append(L_sym[j], tmp))
    if len(L_sym[j]) > 1:
      p = L_sym[j][1]
      children[p] = np.append(children[p], j)
        
  L_indptr = np.zeros(n+1, dtype=int)
  L_indptr[1:] = np.cumsum([len(x) for x in L_sym])
  L_indices = np.concatenate(L_sym)
  
  return L_indices, L_indptr
# A_indices/A_indptr are the lower triangle, A is the entire matrix
A_indices, A_indptr, A_x, A = make_matrix(37)
parent, col_count = etree_base(A.indices, A.indptr)
L_indices, L_indptr = _symbolic_factor(A_indices, A_indptr)

true_parent = L_indices[L_indptr[:-2] + 1]
true_parent[np.where(np.diff(L_indptr[:-1]) == 1)] = -1
print(all(x == y for (x,y) in zip(parent[:-1], true_parent)))

true_col_count  = np.diff(L_indptr)
print(all(true_col_count == col_count))
True
True

Excellent. Now we just need to convert it to JAX.

Or do we?

To be honest, this is a little pointless. This function is only run once per matrix so we won’t really get much speedup14 from compilation.

Nevertheless, we might try.

@jit
def etree(A_indices, A_indptr):
 # print("(Re-)compiling etree(A_indices, A_indptr)")
  ## innermost while loop
  def body_while(val):
  #  print(val)
    j, node, parent, col_count, mark = val
    update_parent = lambda x: x[0].at[x[1]].set(x[2])
    parent = lax.cond(lax.eq(parent[node], -1), update_parent, lambda x: x[0], (parent, node, j))
    mark = mark.at[node].set(j)
    col_count = col_count.at[node].add(1)
    return (j, parent[node], parent, col_count, mark)

  def cond_while(val):
    j, node, parent, col_count, mark = val
    return lax.bitwise_and(lax.lt(node, j), lax.ne(mark[node], j))

  ## Inner for loop
  def body_inner_for(indptr, val):
    j, A_indices, A_indptr, parent, col_count, mark = val
    node = A_indices[indptr]
    j, node, parent, col_count, mark = lax.while_loop(cond_while, body_while, (j, node, parent, col_count, mark))
    return (j, A_indices, A_indptr, parent, col_count, mark)
  
  ## Outer for loop
  def body_out_for(j, val):
     A_indices, A_indptr, parent, col_count, mark = val
     mark = mark.at[j].set(j)
     j, A_indices, A_indptr, parent, col_count, mark = lax.fori_loop(A_indptr[j], A_indptr[j+1], body_inner_for, (j, A_indices, A_indptr, parent, col_count, mark))
     return (A_indices, A_indptr, parent, col_count, mark)

  ## Body of code
  n = len(A_indptr) - 1
  parent = jnp.repeat(-1, n)
  mark = jnp.repeat(-1, n)
  col_count = jnp.repeat(1,  n)
  init = (A_indices, A_indptr, parent, col_count, mark)
  A_indices, A_indptr, parent, col_count, mark = lax.fori_loop(0, n, body_out_for, init)
  return parent, col_count

Wow. That is ugly. But let’s see15 if it works!

parent_jax, col_count_jax = etree(A.indices, A.indptr)

print(all(x == y for (x,y) in zip(parent_jax[:-1], true_parent)))
print(all(true_col_count == col_count_jax))
True
True

Success!

I guess we could ask ourselves if we gained any speed.

Here is the pure python code.

import timeit
A_indices, A_indptr, A_x, A = make_matrix(20)

times = timeit.repeat(lambda: etree_base(A.indices, A.indptr),number = 1)
print(f"n = {A.shape[0]}: {[round(t,2) for t in times]}")

A_indices, A_indptr, A_x, A = make_matrix(50)
times = timeit.repeat(lambda: etree_base(A.indices, A.indptr),number = 1)
print(f"n = {A.shape[0]}: {[round(t,2) for t in times]}")


A_indices, A_indptr, A_x, A = make_matrix(200)
times = timeit.repeat(lambda: etree_base(A.indices, A.indptr),number = 1)
print(f"n = {A.shape[0]}: {[round(t,2) for t in times]}")
n = 400: [0.0, 0.0, 0.0, 0.0, 0.0]
n = 2500: [0.03, 0.03, 0.03, 0.03, 0.03]
n = 40000: [0.83, 0.82, 0.82, 0.82, 0.82]

And here is our JAX’d and JIT’d code.

A_indices, A_indptr, A_x, A = make_matrix(20)
times = timeit.repeat(lambda: etree(A.indices, A.indptr),number = 1)
print(f"n = {A.shape[0]}: {[round(t,2) for t in times]}")

A_indices, A_indptr, A_x, A = make_matrix(50)
times = timeit.repeat(lambda: etree(A.indices, A.indptr),number = 1)
print(f"n = {A.shape[0]}: {[round(t,2) for t in times]}")


A_indices, A_indptr, A_x, A = make_matrix(200)
times = timeit.repeat(lambda: etree(A.indices, A.indptr),number = 1)
print(f"n = {A.shape[0]}: {[round(t,2) for t in times]}")

parent, col_count= etree(A.indices, A.indptr)
L_indices, L_indptr = _symbolic_factor(A_indices, A_indptr)

A_indices, A_indptr, A_x, A = make_matrix(1000)
times = timeit.repeat(lambda: etree(A.indices, A.indptr),number = 1)
print(f"n = {A.shape[0]}: {[round(t,2) for t in times]}")
n = 400: [0.13, 0.0, 0.0, 0.0, 0.0]
n = 2500: [0.12, 0.0, 0.0, 0.0, 0.0]
n = 40000: [0.14, 0.02, 0.02, 0.02, 0.02]
n = 1000000: [2.24, 2.11, 2.12, 2.12, 2.12]

You can see that there is some decent speedup. For the first three examples, the computation time is dominated by the compilation time, but we see that when the matrix has a million unknowns, the compilation time is negligible. At this scale it would probably be worth using the fancy algorithm. That said, it is probably not worth sweating a couple of seconds of work that is only done once when your problem is that big!

The non-zero pattern of \(L\)

Now that we know how many non-zeros there are, it’s time to populate them. Last time, I used some dynamic memory allocation to make this work, but JAX is certainly not going to allow me to do that. So instead I’m going to have to do the worst thing possible: think.

The way that we went about it last time was, to be honest, a bit arse-backwards. The main reason for this is that I did not have access to the elimination tree. But now we do, we can actually use it.

The trick is to slightly rearrange16 the order of operations to get something that is more convenient for working out the structure.

Recall from last time that we used the left-looking Cholesky factorisation, which can be written in the dense case as

def dense_left_cholesky(A):
  n = A.shape[0]
  L = np.zeros_like(A)
  for j in range(n):
    L[j,j] = np.sqrt(A[j,j] - np.inner(L[j, :j], L[j, :j]))
    L[(j+1):, j] = (A[(j+1):, j] - L[(j+1):, :j] @ L[j, :j].transpose()) / L[j,j]
  return L

This is not the only way to organise those operations. An alternative is the up-looking Cholesky factorisation, which can be implemented in the dense case as

def dense_up_cholesky(A):
  n = A.shape[0]
  L = np.zeros_like(A)
  L[0,0] = np.sqrt(A[0,0])
  for i in range(1,n):
    #if i > 0:
    L[i, :i] = (np.linalg.solve(L[:i, :i], A[:i,i])).transpose()
    L[i, i] = np.sqrt(A[i,i] - np.inner(L[i, :i], L[i, :i]))
  return L

This is quite a different looking beast! It scans row by row rather than column by column. And while the left-looking algorithm is based on matrix-vector multiplies, the up-looking algorithm is based on triangular solves. So maybe we should pause for a moment to check that these are the same algorithm!

A = np.random.rand(15, 15)
A = A + A.transpose()
A = A.transpose() @ A + 1*np.eye(15)

L_left = dense_left_cholesky(A)
L_up = dense_up_cholesky(A)

print(round(sum(sum(abs(L_left - L_up))), 2))
0.0

They are the same!!

The reason for considering the up-looking algorithm is that it gives a slightly nicer description of the non-zeros of row i, which will let us find the location of the non-zeros in the whole matrix. In particular, the non-zeros to the left of the diagonal on row i correspond to the non-zero indices of the solution to the lower triangular linear system17 \[ L_{1:(i-1),1:(i-1)} x^{(i)} = A_{1:i-1, i}. \] Because \(A\) is sparse, this is a system of \(\operatorname{nnz}(A_{1:i-1,i})\) linear equations, rather than \((i-1)\) equations that we would have in the dense case. That means that the sparsity pattern of \(x^{(i)}\) will be the union of the sparsity patterns of the columns of \(L_{1:(i-1),1:(i-1)}\) that correspond to the non-zero entries of \(A_{1:i-1, i}\).

This means two things. Firstly, if \(A_{ji}\neq 0\), then \(x^{(i)}_j \neq 0\). Secondly, if \(x^{(i)}_j \neq 0\) and \(L_{kj}\neq 0\), then \(x^{(i)}_k \neq 0\). These two facts give us a way of finding the non-zero set of \(x^{(i)}\) if we remember just one more fact: a definition of the elimination tree is that \(L_{kj} \neq 0\) if \(j\) is a descendant of \(k\) in the elimination tree.

This reduces the problem of finding the non-zero elements of \(x^{(i)}\) to the problem of finding all of the descendants of \(\{j: A_{ji} \neq 0\}\) in the subtree \(\mathcal{T}_{i-1}\). And if there is one thing that people who are ok at programming are excellent at it is walking down a damn tree.

So let’s do that. Well, I’ve already done it. In fact, that was how I found the column counts in the first place! With this interpretation, the outer loop is taking us across the rows. And once I am in row j18, I find a starting node node (which is a non-zero in \(A_{1:(i-1),i}\)) and I walk up the tree from that node, checking each time whether I’ve actually seen that node19 before. If I haven’t seen it before, I add one to the column count of column node20.
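Stripped of all the counting, the walk for a single row looks roughly like this (a toy sketch with made-up inputs).

# The non-zeros in row i of L come from following parent pointers, starting
# at each below-diagonal non-zero of column i of A and stopping when we hit
# i or a node we've already visited for this row.
def row_pattern(i, a_col_nonzeros, parent):
  seen = set()
  for node in a_col_nonzeros:
    while node < i and node not in seen:
      seen.add(node)
      node = parent[node]
  return sorted(seen) + [i]

parent = [1, 3, 3, -1]                 # a toy elimination tree
print(row_pattern(3, [1, 2], parent))  # [1, 2, 3]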

To allocate the non-zero structure, I just need to replace that counter increment with an assignment.

Attempt 1: Lord that’s slow

We will do the pure python version first.

def symbolic_cholesky_base(A_indices, A_indptr, parent, col_count):
  n = len(A_indptr) - 1
  col_ptr = np.repeat(1, n+1)
  col_ptr[1:] += np.cumsum(col_count) 
  L_indices = np.zeros(sum(col_count), dtype=int)
  L_indptr = np.zeros(n+1, dtype=int)
  L_indptr[1:] = np.cumsum(col_count)
  mark = [-1] * n

  for i in range(n):
    mark[i] = i
    L_indices[L_indptr[i]] = i

    for indptr in range(A_indptr[i], A_indptr[i+1]):
      node = A_indices[indptr]
      while node < i and mark[node] != i:
        mark[node] = i
        L_indices[col_ptr[node]] = i
        col_ptr[node] += 1
        node = parent[node]
  
  return (L_indices, L_indptr)

Does it work?

A_indices, A_indptr, A_x, A = make_matrix(13)
parent, col_count = etree_base(A.indices, A.indptr)

L_indices, L_indptr = symbolic_cholesky_base(A.indices, A.indptr, parent, col_count)
L_indices_true, L_indptr_true = _symbolic_factor(A_indices, A_indptr)

print(all(x==y for (x,y) in zip(L_indices, L_indices_true)))
print(all(x==y for (x,y) in zip(L_indptr, L_indptr_true)))
True
True

Fabulosa!

Now let’s do the compiled version.

from functools import partial
@partial(jit, static_argnums = (4,))
def symbolic_cholesky(A_indices, A_indptr, L_indptr, parent, nnz):
  
  ## innermost while loop
  def body_while(val):
    i, L_indices, L_indptr, node, parent, col_ptr, mark = val
    mark = mark.at[node].set(i)
    L_indices = L_indices.at[col_ptr[node]].set(i)
    col_ptr = col_ptr.at[node].add(1)
    return (i, L_indices, L_indptr, parent[node], parent, col_ptr, mark)

  def cond_while(val):
    i, L_indices, L_indptr, node, parent, col_ptr, mark = val
    return lax.bitwise_and(lax.lt(node, i), lax.ne(mark[node], i))

  ## Inner for loop
  def body_inner_for(indptr, val):
    i, A_indices, A_indptr, L_indices, L_indptr, parent, col_ptr, mark = val
    node = A_indices[indptr]
    i, L_indices, L_indptr, node, parent, col_ptr, mark = lax.while_loop(cond_while, body_while, (i, L_indices, L_indptr, node, parent, col_ptr, mark))
    return (i, A_indices, A_indptr, L_indices, L_indptr, parent, col_ptr, mark)
  
  ## Outer for loop
  def body_out_for(i, val):
     A_indices, A_indptr, L_indices, L_indptr, parent, col_ptr, mark = val
     mark = mark.at[i].set(i)
     L_indices = L_indices.at[L_indptr[i]].set(i)
     i, A_indices, A_indptr, L_indices, L_indptr, parent, col_ptr, mark = lax.fori_loop(A_indptr[i], A_indptr[i+1], body_inner_for, (i, A_indices, A_indptr, L_indices, L_indptr, parent, col_ptr, mark))
     return (A_indices, A_indptr, L_indices, L_indptr, parent, col_ptr, mark)

  ## Body of code
  n = len(A_indptr) - 1
  col_ptr = L_indptr + 1
  L_indices = jnp.zeros(nnz, dtype=int)
  
  mark = jnp.repeat(-1, n)
  
  init = (A_indices, A_indptr, L_indices, L_indptr, parent, col_ptr, mark)
  A_indices, A_indptr, L_indices, L_indptr, parent, col_ptr, mark = lax.fori_loop(0, n, body_out_for, init)
  return L_indices

Now let’s check it works

A_indices, A_indptr, A_x, A = make_matrix(20)
parent, col_count = etree(A.indices, A.indptr)
L_indptr = np.zeros(A.shape[0]+1, dtype=int)
L_indptr[1:] = np.cumsum(col_count)


L_indices = symbolic_cholesky(A.indices, A.indptr, L_indptr, parent, nnz = L_indptr[-1])
L_indices_true, L_indptr_true = _symbolic_factor(A_indices, A_indptr)
print(all(L_indices == L_indices_true))
print(all(L_indptr == L_indptr_true))

A_indices, A_indptr, A_x, A = make_matrix(31)

parent, col_count = etree(A.indices, A.indptr)
L_indptr = np.zeros(A.shape[0]+1, dtype=int)
L_indptr[1:] = np.cumsum(col_count)


L_indices = symbolic_cholesky(A.indices, A.indptr, L_indptr, parent, nnz = L_indptr[-1])
L_indices_true, L_indptr_true = _symbolic_factor(A_indices, A_indptr)

print(all(L_indices == L_indices_true))
print(all(L_indptr == L_indptr_true))
True
True
True
True

Success!

One minor problem. This is slow. as. balls.

A_indices, A_indptr, A_x, A = make_matrix(50)
parent, col_count = etree_base(A.indices, A.indptr)
times = timeit.repeat(lambda: symbolic_cholesky_base(A.indices, A.indptr, parent, col_count),number = 1)
print(f"n = {A.shape[0]}: {[round(t,2) for t in times]}")

A_indices, A_indptr, A_x, A = make_matrix(200)
parent, col_count = etree_base(A.indices, A.indptr)
times = timeit.repeat(lambda: symbolic_cholesky_base(A.indices, A.indptr, parent, col_count),number = 1)
print(f"n = {A.shape[0]}: {[round(t,2) for t in times]}")
n = 2500: [0.05, 0.04, 0.04, 0.04, 0.04]
n = 40000: [1.97, 2.09, 2.04, 2.03, 1.92]

And here is our JAX’d and JIT’d code.

A_indices, A_indptr, A_x, A = make_matrix(50)
parent, col_count = etree(A.indices, A.indptr)
L_indptr = np.zeros(A.shape[0]+1, dtype=int)
L_indptr[1:] = np.cumsum(col_count)
times = timeit.repeat(lambda:symbolic_cholesky(A.indices, A.indptr, L_indptr, parent, nnz = L_indptr[-1]),number = 1, repeat = 1)
print(f"n = {A.shape[0]}: {[round(t,2) for t in times]}")

A_indices, A_indptr, A_x, A = make_matrix(200)
parent, col_count = etree(A.indices, A.indptr)
L_indptr = np.zeros(A.shape[0]+1, dtype=int)
L_indptr[1:] = np.cumsum(col_count)
times = timeit.repeat(lambda:symbolic_cholesky(A.indices, A.indptr, L_indptr, parent, nnz = L_indptr[-1]),number = 1, repeat = 1)
print(f"n = {A.shape[0]}: {[round(t,2) for t in times]}")
n = 2500: [0.15]
n = 40000: [29.19]

Oooof. Something is going horribly wrong.

Why is it so slow?

The first thing to check is whether it’s the compile time. We can do this by explicitly lowering the JIT’d function to its XLA representation and then compiling it.

A_indices, A_indptr, A_x, A = make_matrix(50)
parent, col_count = etree(A.indices, A.indptr)
L_indptr = np.zeros(A.shape[0]+1, dtype=int)
L_indptr[1:] = np.cumsum(col_count)
times = timeit.repeat(lambda: jit(partial(symbolic_cholesky, nnz=int(L_indptr[-1]))).lower(A.indices, A.indptr, L_indptr, parent).compile(),number = 1, repeat = 5)
print(f"Compilation time: n = {A.shape[0]}: {[round(t,2) for t in times]}")

A_indices, A_indptr, A_x, A = make_matrix(200)
parent, col_count = etree(A.indices, A.indptr)
L_indptr = np.zeros(A.shape[0]+1, dtype=int)
L_indptr[1:] = np.cumsum(col_count)
times = timeit.repeat(lambda: jit(partial(symbolic_cholesky, nnz=int(L_indptr[-1]))).lower(A.indices, A.indptr, L_indptr, parent).compile(),number = 1, repeat = 5)
print(f"Compilation time: n = {A.shape[0]}: {[round(t,2) for t in times]}")
Compilation time: n = 2500: [0.15, 0.15, 0.15, 0.16, 0.15]
Compilation time: n = 40000: [0.16, 0.15, 0.14, 0.16, 0.15]

It is not the compile time.

And that is actually a good thing because that suggests that we aren’t having problems with the compiler unrolling all of our wonderful loops! But that does mean that we have to look a bit deeper into the code. Some smart people would probably be able to look at the jaxpr intermediate representation to diagnose the problem. But I couldn’t see anything there.

Instead I thought: if I were a clever, efficient compiler, what would I have problems with? And the answer is the classic sparse matrix answer: indirect indexing.

The only structural difference between the etree function and the symbolic_cholesky function is this line in the body_while() function:

 L_indices = L_indices.at[col_ptr[node]].set(i)

In order to evaluate this code, the compiler has to resolve two levels of indirection. By contrast, the indexing in etree() was always direct. So let’s see what happens if we take the same function and remove that double indirection.

@partial(jit, static_argnums = (4,))
def test_fun(A_indices, A_indptr, L_indptr, parent, nnz):
  
  ## innermost while loop
  def body_while(val):
    i, L_indices, L_indptr, node, parent, col_ptr, mark = val
    mark = mark.at[node].set(i)
    L_indices = L_indices.at[node].set(i)
    col_ptr = col_ptr.at[node].add(1)
    return (i, L_indices, L_indptr, parent[node], parent, col_ptr, mark)

  def cond_while(val):
    i, L_indices, L_indptr, node, parent, col_ptr, mark = val
    return lax.bitwise_and(lax.lt(node, i), lax.ne(mark[node], i))

  ## Inner for loop
  def body_inner_for(indptr, val):
    i, A_indices, A_indptr, L_indices, L_indptr, parent, col_ptr, mark = val
    node = A_indices[indptr]
    i, L_indices, L_indptr, node, parent, col_ptr, mark = lax.while_loop(cond_while, body_while, (i, L_indices, L_indptr, node, parent, col_ptr, mark))
    return (i, A_indices, A_indptr, L_indices, L_indptr, parent, col_ptr, mark)
  
  ## Outer for loop
  def body_out_for(i, val):
     A_indices, A_indptr, L_indices, L_indptr, parent, col_ptr, mark = val
     mark = mark.at[i].set(i)
     L_indices = L_indices.at[L_indptr[i]].set(i)
     i, A_indices, A_indptr, L_indices, L_indptr, parent, col_ptr, mark = lax.fori_loop(A_indptr[i], A_indptr[i+1], body_inner_for, (i, A_indices, A_indptr, L_indices, L_indptr, parent, col_ptr, mark))
     return (A_indices, A_indptr, L_indices, L_indptr, parent, col_ptr, mark)

  ## Body of code
  n = len(A_indptr) - 1
  col_ptr = L_indptr + 1
  L_indices = jnp.zeros(nnz, dtype=int)
  
  mark = jnp.repeat(-1, n)
  
  init = (A_indices, A_indptr, L_indices, L_indptr, parent, col_ptr, mark)
  A_indices, A_indptr, L_indices, L_indptr, parent, col_ptr, mark = lax.fori_loop(0, n, body_out_for, init)
  return L_indices

A_indices, A_indptr, A_x, A = make_matrix(50)
parent, col_count = etree(A.indices, A.indptr)
L_indptr = np.zeros(A.shape[0]+1, dtype=int)
L_indptr[1:] = np.cumsum(col_count)
times = timeit.repeat(lambda:test_fun(A.indices, A.indptr, L_indptr, parent, nnz = L_indptr[-1]),number = 1, repeat = 1)
print(f"n = {A.shape[0]}: {[round(t,2) for t in times]}")

A_indices, A_indptr, A_x, A = make_matrix(200)
parent, col_count = etree(A.indices, A.indptr)
L_indptr = np.zeros(A.shape[0]+1, dtype=int)
L_indptr[1:] = np.cumsum(col_count)
times = timeit.repeat(lambda:test_fun(A.indices, A.indptr, L_indptr, parent, nnz = L_indptr[-1]),number = 1, repeat = 1)
print(f"n = {A.shape[0]}: {[round(t,2) for t in times]}")
n = 2500: [0.14]
n = 40000: [0.17]

That isn’t conclusive, but it does indicate that this might21 be the problem.

And this is a big problem for us! The sparse Cholesky algorithm has similar amounts of indirection. So we need to fix it.

Attempt 2: After some careful thought, things stayed the same

Now. I want to pretend that I’ve got elegant ideas about this. But I don’t. So let’s just do it. The most obvious thing to do is to use the algorithm to get the non-zero structure of the rows of \(L\). These are the things that are being indexed by col_ptr[node], so if we have these explicitly we don’t need multiple indirection. We also don’t need a while loop.

In fact, if we have the non-zero structure of the rows of \(L\), we can turn that into the non-zero structure of the columns in linear-ish22 time.
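The conversion itself I’m happy to leave to scipy; roughly this sort of thing (a sketch on a toy pattern).

from scipy import sparse
import numpy as np

# Row-wise (CSR) pattern of a 3x3 lower triangle; scipy's CSR -> CSC
# conversion does the sort-and-repack for us.
row_indptr = np.array([0, 1, 3, 5])
col_indices = np.array([0, 0, 1, 1, 2])
L_csr = sparse.csr_matrix((np.ones(5), col_indices, row_indptr), shape=(3, 3))
L_csc = L_csr.tocsc()
print(L_csc.indices, L_csc.indptr)  # [0 1 1 2 2] [0 2 4 5]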

All we need to do is make sure that our etree() function is also counting the number of nonzeros in each row.

@jit
def etree(A_indices, A_indptr):
 # print("(Re-)compiling etree(A_indices, A_indptr)")
  ## innermost while loop
  def body_while(val):
  #  print(val)
    j, node, parent, col_count, row_count, mark = val
    update_parent = lambda x: x[0].at[x[1]].set(x[2])
    parent = lax.cond(lax.eq(parent[node], -1), update_parent, lambda x: x[0], (parent, node, j))
    mark = mark.at[node].set(j)
    col_count = col_count.at[node].add(1)
    row_count = row_count.at[j].add(1)
    return (j, parent[node], parent, col_count, row_count, mark)

  def cond_while(val):
    j, node, parent, col_count, row_count, mark = val
    return lax.bitwise_and(lax.lt(node, j), lax.ne(mark[node], j))

  ## Inner for loop
  def body_inner_for(indptr, val):
    j, A_indices, A_indptr, parent, col_count, row_count, mark = val
    node = A_indices[indptr]
    j, node, parent, col_count, row_count, mark = lax.while_loop(cond_while, body_while, (j, node, parent, col_count, row_count, mark))
    return (j, A_indices, A_indptr, parent, col_count, row_count, mark)
  
  ## Outer for loop
  def body_out_for(j, val):
     A_indices, A_indptr, parent, col_count, row_count, mark = val
     mark = mark.at[j].set(j)
     j, A_indices, A_indptr, parent, col_count, row_count, mark = lax.fori_loop(A_indptr[j], A_indptr[j+1], body_inner_for, (j, A_indices, A_indptr, parent, col_count, row_count, mark))
     return (A_indices, A_indptr, parent, col_count, row_count, mark)

  ## Body of code
  n = len(A_indptr) - 1
  parent = jnp.repeat(-1, n)
  mark = jnp.repeat(-1, n)
  col_count = jnp.repeat(1,  n)
  row_count = jnp.repeat(1, n)
  init = (A_indices, A_indptr, parent, col_count, row_count, mark)
  A_indices, A_indptr, parent, col_count, row_count, mark = lax.fori_loop(0, n, body_out_for, init)
  return (parent, col_count, row_count)

Let’s check that the code is actually doing what I want.

A_indices, A_indptr, A_x, A = make_matrix(57)
parent, col_count, row_count = etree(A.indices, A.indptr)
L_indices, L_indptr = _symbolic_factor(A_indices, A_indptr)

true_parent = L_indices[L_indptr[:-2] + 1]
true_parent[np.where(np.diff(L_indptr[:-1]) == 1)] = -1
print(all(x == y for (x,y) in zip(parent[:-1], true_parent)))

true_col_count  = np.diff(L_indptr)
print(all(true_col_count == col_count))

true_row_count = np.array([len(np.where(L_indices == i)[0]) for i in range(57**2)])
print(all(true_row_count == row_count))
True
True
True

Excellent! With this we can modify our previous function to give us the row-indices of the non-zero pattern instead. Just for further chaos, please note that we are using a CSC representation of \(A\) to get a CSR representation of \(L\).

Once again, we will prototype in pure python and then translate to JAX. The thing to look out for this time is that we know how many non-zeros there are in a row and we know where we need to put them. This suggests that we can compute these things in body_inner_for and then do a vectorised version of our indirect indexing. This should compile down to a single XLA scatter call. This will reduce the number of overall scatter calls from \(\operatorname{nnz}(L)\) to \(n\). And hopefully this will fix things.
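The vectorised indirect write I have in mind is just the array-of-indices form of .at[...] (a minimal sketch).

from jax import numpy as jnp

# One scatter call fills several slots of L_indices at once, instead of one
# scatter per non-zero like in the previous attempt.
L_indices = jnp.zeros(8, dtype=int)
row_ind = jnp.array([2, 5, 7])
print(L_indices.at[row_ind].set(4))  # [0 0 4 0 0 4 0 4]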

def symbolic_cholesky2_base(A_indices, A_indptr, L_indptr, row_count, parent, nnz):

  n = len(A_indptr) - 1
  col_ptr = L_indptr + 1
  L_indices = np.zeros(L_indptr[-1], dtype=int)
  mark = [-1] * n

  for i in range(n):
    mark[i] = i
    row_ind = np.repeat(nnz+1, row_count[i])
    row_ind[-1] = L_indptr[i]
    counter = 0
    for indptr in range(A_indptr[i], A_indptr[i+1]):
      node = A_indices[indptr]
      while node < i and mark[node] != i:
        mark[node] = i
        row_ind[counter] = col_ptr[node]
        col_ptr[node] += 1
        node = parent[node]
        counter +=1
    L_indices[row_ind] = i
  
  return L_indices


A_indices, A_indptr, A_x, A = make_matrix(13)
parent, col_count, row_count = etree(A.indices, A.indptr)
L_indptr = np.zeros(A.shape[0]+1, dtype=int)
L_indptr[1:] = np.cumsum(col_count)


L_indices = symbolic_cholesky2_base(A.indices, A.indptr, L_indptr, row_count, parent, L_indptr[-1])
L_indices_true, L_indptr_true = _symbolic_factor(A_indices, A_indptr)

print(all(x==y for (x,y) in zip(L_indices, L_indices_true)))
print(all(x==y for (x,y) in zip(L_indptr, L_indptr_true)))
True
True

Excellent. Now let’s JAX this. The JAX-heads among you will notice that we have a subtle23 problem: in a fori_loop, JAX does not treat i as static, which means that the length of the repeat (row_count[i]) can never be static and it therefore can’t be traced.

Shit.

It is hard to think of a good option here. A few months back Junpeng Lao24 sent me a script with his attempts at making the Cholesky stuff JAX transformable. And he hit the same problem. I was, in an act of hubris, trying very hard to not end up here. But that was tragically slow. So here we are.

He came up with two methods.

  1. Pad out row_ind so it’s always long enough. This only costs memory. The maximum size of row_ind is n. Unfortunately, this happens whenever \(A\) has a dense row. Sadly, for Bayesian25 linear mixed models this will happen if we put Gaussian priors on the covariate coefficients26 and we try to marginalise them out with the other multivariate Gaussian parts. It is possible to write the routines that deal with dense rows and columns explicitly, but it’s a pain in the arse.

  2. Do some terrifying work with lax.scan and dynamic slicing.

I’m going to try the first of these options.
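The padding trick leans on JAX’s out-of-bounds scatter semantics; roughly this (a toy sketch).

from jax import numpy as jnp

# Padding entries point past the end of the array; with mode="drop" those
# writes are silently discarded, so every row can use the same fixed-length
# index buffer.
nnz = 6
padded_idx = jnp.array([1, 4, nnz + 1, nnz + 1])  # two real slots, two pads
print(jnp.full(nnz, -1).at[padded_idx].set(9, mode="drop"))  # [-1  9 -1 -1  9 -1]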

@partial(jit, static_argnums = (5, 6))
def symbolic_cholesky2(A_indices, A_indptr, L_indptr, row_count, parent, nnz, max_row):
  ## innermost while loop
  def body_while(val):
    i, counter, row_ind, node, col_ptr, mark = val
    mark = mark.at[node].set(i)
    row_ind = row_ind.at[counter].set(col_ptr[node])
    col_ptr = col_ptr.at[node].add(1)
    return (i, counter + 1, row_ind, parent[node], col_ptr, mark)

  def cond_while(val):
    i, counter, row_ind, node, col_ptr, mark = val
    return lax.bitwise_and(lax.lt(node, i), lax.ne(mark[node], i))

  ## Inner for loop
  def body_inner_for(indptr, val):
    i, counter, row_ind, parent, col_ptr, mark = val
    node = A_indices[indptr]
    i, counter, row_ind, node, col_ptr, mark = lax.while_loop(cond_while, body_while, (i, counter, row_ind, node, col_ptr, mark))
    return (i, counter, row_ind, parent, col_ptr, mark)
  
  ## Outer for loop
  def body_out_for(i, val):
     L_indices, parent, col_ptr, mark = val
     mark = mark.at[i].set(i)
     row_ind = jnp.repeat(nnz+1, max_row)
     row_ind = row_ind.at[row_count[i]-1].set(L_indptr[i])
     counter = 0

     i, counter, row_ind, parent, col_ptr, mark = lax.fori_loop(A_indptr[i], A_indptr[i+1], body_inner_for, (i, counter, row_ind, parent, col_ptr, mark))

     L_indices = L_indices.at[row_ind].set(i, mode = "drop")
     return (L_indices, parent, col_ptr, mark)

  ## Body of code
  n = len(A_indptr) - 1

  col_ptr = jnp.array(L_indptr + 1)
  L_indices = jnp.ones(nnz, dtype=int) * (-1)
  mark = jnp.repeat(-1, n)

  ## Make everything a jnp array. Really should use jaxtyping
  A_indices = jnp.array(A_indices)
  A_indptr = jnp.array(A_indptr)
  L_indptr = jnp.array(L_indptr)
  row_count = jnp.array(row_count)
  parent = jnp.array(parent)

  init = (L_indices, parent, col_ptr, mark)
  L_indices, parent, col_ptr, mark = lax.fori_loop(0, n, body_out_for, init)
  return L_indices

Ok. Let’s see if that worked.

A_indices, A_indptr, A_x, A = make_matrix(20)
parent, col_count, row_count = etree(A.indices, A.indptr)
L_indptr = np.zeros(A.shape[0]+1, dtype=int)
L_indptr[1:] = np.cumsum(col_count)


L_indices = symbolic_cholesky2(A.indices, A.indptr, L_indptr, row_count, parent, nnz = int(L_indptr[-1]), max_row = int(max(row_count)))
L_indices_true, L_indptr_true = _symbolic_factor(A_indices, A_indptr)
print(all(L_indices == L_indices_true))
print(all(L_indptr == L_indptr_true))

A_indices, A_indptr, A_x, A = make_matrix(31)
parent, col_count, row_count = etree(A.indices, A.indptr)
L_indptr = np.zeros(A.shape[0]+1, dtype=int)
L_indptr[1:] = np.cumsum(col_count)


L_indices = symbolic_cholesky2(A.indices, A.indptr, L_indptr, row_count, parent, nnz = int(L_indptr[-1]), max_row = int(max(row_count)))
L_indices_true, L_indptr_true = _symbolic_factor(A_indices, A_indptr)
print(all(L_indices == L_indices_true))
print(all(L_indptr == L_indptr_true))
True
True
True
True

Ok. Once more into the breach. Is this any better?

A_indices, A_indptr, A_x, A = make_matrix(50)
parent, col_count, row_count = etree(A.indices, A.indptr)
L_indptr = np.zeros(A.shape[0]+1, dtype=int)
L_indptr[1:] = np.cumsum(col_count)
times = timeit.repeat(lambda:symbolic_cholesky2(A.indices, A.indptr, L_indptr, row_count, parent, nnz = int(L_indptr[-1]), max_row = int(max(row_count))),number = 1, repeat = 1)
print(f"n = {A.shape[0]}: {[round(t,2) for t in times]}")

A_indices, A_indptr, A_x, A = make_matrix(200)
parent, col_count, row_count = etree(A.indices, A.indptr)
L_indptr = np.zeros(A.shape[0]+1, dtype=int)
L_indptr[1:] = np.cumsum(col_count)
times = timeit.repeat(lambda:symbolic_cholesky2(A.indices, A.indptr, L_indptr, row_count, parent, nnz = int(L_indptr[-1]), max_row = int(max(row_count))),number = 1, repeat = 1)
print(f"n = {A.shape[0]}: {[round(t,2) for t in times]}")

# A_indices, A_indptr, A_x, A = make_matrix(300)
# parent, col_count, row_count = etree(A.indices, A.indptr)
# L_indptr = np.zeros(A.shape[0]+1, dtype=int)
# L_indptr[1:] = np.cumsum(col_count)
# times = timeit.repeat(lambda:symbolic_cholesky2(A.indices, A.indptr, L_indptr, row_count, parent, nnz = int(L_indptr[-1]), max_row = int(max(row_count))),number = 1, repeat = 1)
# print(f"n = {A.shape[0]}: {[round(t,2) for t in times]}")
n = 2500: [0.28]
n = 40000: [28.31]

Fuck.

Attempt 3: A desperate attempt to make this bloody work

Right. Let’s try again. What if, instead of doing all those scatters, we just, idk, store two vectors and sort? Because at this point I will try fucking anything. What if we just list out the column index and row index as we find them (aka build the sparse matrix in COO27 format)? The jax.experimental.sparse module has support for (blocked) COO objects but doesn’t implement this transformation. scipy.sparse has a fast conversion routine so I’m going to use it. In the interest of being 100% JAX, I tried a version based on jnp.lexsort (roughly index[1][jnp.lexsort((index[1], index[0]))]), which basically does the same thing but is a lot slower.

def symbolic_cholesky3(A_indices, A_indptr, L_indptr, parent, nnz):
  @partial(jit, static_argnums = (4,))
  def _inner(A_indices_, A_indptr_, L_indptr, parent, nnz):
    ## Make everything a jnp array. Really should use jaxtyping
    A_indices_ = jnp.array(A_indices_)
    A_indptr_ = jnp.array(A_indptr_)
    L_indptr = jnp.array(L_indptr)
    parent = jnp.array(parent)

    ## innermost while loop
    def body_while(val):
      index, i, counter,  node,  mark = val
      mark = mark.at[node].set(i)
      index[0] = index[0].at[counter].set(node) #column
      index[1] = index[1].at[counter].set(i) # row
      return (index, i, counter + 1, parent[node], mark)

    def cond_while(val):
      index, i, counter,  node,  mark = val
      return lax.bitwise_and(lax.lt(node, i), lax.ne(mark[node], i))

    ## Inner for loop
    def body_inner_for(indptr, val):
      index, i, counter, mark = val
      node = A_indices_[indptr]
      
      index, i, counter,  node,  mark = lax.while_loop(cond_while, body_while, (index, i, counter,  node,  mark))
      return (index, i, counter,  mark)
    
    ## Outer for loop
    def body_out_for(i, val):
      index, counter,  mark = val
      mark = mark.at[i].set(i)
      index[0] = index[0].at[counter].set(i)
      index[1] = index[1].at[counter].set(i)
      counter = counter + 1
      index, i, counter, mark = lax.fori_loop(A_indptr_[i], A_indptr_[i+1], body_inner_for, (index, i, counter,  mark))

      return (index, counter,  mark)

    ## Body of code
    n = len(A_indptr_) - 1
    mark = jnp.repeat(-1, n)

    index = [jnp.zeros(nnz, dtype=int), jnp.zeros(nnz, dtype=int)]
    counter = 0

    init = (index, counter, mark)
    index, counter, mark = lax.fori_loop(0, n, body_out_for, init)
    
    return index
  n = len(A_indptr) - 1
  index = _inner(A_indices, A_indptr, L_indptr, parent, nnz)
  ## return index[1][jnp.lexsort((index[1], index[0]))]
  return sparse.coo_array((np.ones(nnz), (index[1], index[0])), shape = (n,n)).tocsc().indices

First things first, let’s check that it works and then see how fast it is.

A_indices, A_indptr, A_x, A = make_matrix(15)
parent, col_count, row_count = etree(A.indices, A.indptr)
L_indptr = np.zeros(A.shape[0]+1, dtype=int)
L_indptr[1:] = np.cumsum(col_count)
L_indices = symbolic_cholesky3(A.indices, A.indptr, L_indptr, parent, nnz = int(L_indptr[-1]))
L_indices_true, L_indptr_true = _symbolic_factor(A_indices, A_indptr)
print(all(L_indices == L_indices_true))
print(all(L_indptr == L_indptr_true))

A_indices, A_indptr, A_x, A = make_matrix(31)
parent, col_count, row_count = etree(A.indices, A.indptr)
L_indptr = np.zeros(A.shape[0]+1, dtype=int)
L_indptr[1:] = np.cumsum(col_count)
L_indices = symbolic_cholesky3(A.indices, A.indptr, L_indptr, parent, nnz = int(L_indptr[-1]))
L_indices_true, L_indptr_true = _symbolic_factor(A_indices, A_indptr)
print(all(L_indices == L_indices_true))
print(all(L_indptr == L_indptr_true))

A_indices, A_indptr, A_x, A = make_matrix(50)
parent, col_count, row_count = etree(A.indices, A.indptr)
L_indptr = np.zeros(A.shape[0]+1, dtype=int)
L_indptr[1:] = np.cumsum(col_count)
times = timeit.repeat(lambda: symbolic_cholesky3(A.indices, A.indptr, L_indptr, parent, nnz = int(L_indptr[-1])),number = 1, repeat = 5)
print(f"n = {A.shape[0]}: {[round(t,2) for t in times]}")

A_indices, A_indptr, A_x, A = make_matrix(200)
parent, col_count, row_count = etree(A.indices, A.indptr)
L_indptr = np.zeros(A.shape[0]+1, dtype=int)
L_indptr[1:] = np.cumsum(col_count)
times = timeit.repeat(lambda: symbolic_cholesky3(A.indices, A.indptr, L_indptr, parent, nnz = int(L_indptr[-1])),number = 1, repeat = 5)
print(f"n = {A.shape[0]}: {[round(t,2) for t in times]}")

A_indices, A_indptr, A_x, A = make_matrix(300)
parent, col_count, row_count = etree(A.indices, A.indptr)
L_indptr = np.zeros(A.shape[0]+1, dtype=int)
L_indptr[1:] = np.cumsum(col_count)
times = timeit.repeat(lambda: symbolic_cholesky3(A.indices, A.indptr, L_indptr, parent, nnz = int(L_indptr[-1])),number = 1, repeat = 5)
print(f"n = {A.shape[0]}: {[round(t,2) for t in times]}")

A_indices, A_indptr, A_x, A = make_matrix(1000)
parent, col_count, row_count = etree(A.indices, A.indptr)
L_indptr = np.zeros(A.shape[0]+1, dtype=int)
L_indptr[1:] = np.cumsum(col_count)
times = timeit.repeat(lambda: symbolic_cholesky3(A.indices, A.indptr, L_indptr, parent, nnz = int(L_indptr[-1])),number = 1, repeat = 1)
print(f"n = {A.shape[0]}: {[round(t,2) for t in times]}")
True
True
True
True
n = 2500: [0.13, 0.13, 0.13, 0.14, 0.14]
n = 40000: [0.19, 0.19, 0.19, 0.19, 0.19]
n = 90000: [0.43, 0.32, 0.32, 0.33, 0.33]
n = 1000000: [11.91]

You know what? I’ll take it. It’s not perfect; in particular, I would prefer a pure JAX solution. But everything I tried hit hard against the indirect memory access issue. The best I found was using jnp.lexsort, but even that showed noticeable performance degradation relative to the scipy solution as nnz increased.
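For the record, the pure-JAX variant I tried looked roughly like this (a sketch with toy COO data).

from jax import numpy as jnp

# Sort the COO pairs by (column, then row) and read the row indices off in
# that order; jnp.lexsort treats its *last* key as the primary sort key.
col = jnp.array([2, 0, 1, 0, 1])
row = jnp.array([2, 0, 1, 1, 2])
order = jnp.lexsort((row, col))
print(row[order])  # [0 1 1 2 2], i.e. the CSC indices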

Next time on Sparse Matrices with Dan

So that’s where I’m going to leave it. I am off my flight and I’ve slept very well and now I’m going to be on holidays for a little while.

The next big thing to do is look at the numerical factorisation. We are going to run headlong into all of the problems we’ve hit today, so that should be fun. The reason I’m putting it in a separate post28 is that I want to actually test all of those things out properly.

So next time you can expect

  1. Classes! Because frankly this code is getting far too messy, especially now that certain things need to be passed as static arguments. The only reason I’ve avoided it up to now is that I think it hides too much of the algorithm in boilerplate. But now the boilerplate is ruining my life and causing far too many dumb typos29.

  2. Type hints! Because for a language where types aren’t explicit, they sure are important. Also because I’m going to class it up I might as well do it properly.

  3. Some helper routines! I’m going to need a sparse-matrix scatter operation (aka a structured copy of \(A\) into the sparsity pattern of \(L\))! And I’m certainly going to need some reorderings30

  4. A battle royale between padded and non-padded methods!

It should be a fun time!

Footnotes

  1. If you’re wondering about the break between sparse matrix posts, I realised this pretty much immediately and just didn’t want to deal with it!↩︎

  2. If a person who actually knows how the JAX autodiff works happens across this blog, I’m so sorry.↩︎

  3. omg you guys. So many details↩︎

  4. These are referred to as HLOs (Higher-level operations)↩︎

  5. Instead of doing one pass of reverse-mode, you would need to do \(d\) passes of forwards mode to get the gradient with respect to a d-dimensional parameter.↩︎

  6. Unlike jax.lax.while_loop, which is only forwards differentiable, jax.lax.scan is fully differentiable.↩︎

  7. In general, if the function has state.↩︎

  8. This is the version of the symbolic factorisation that is most appropriate for us, as we will be doing a lot of Cholesky factorisations with the same sparsity structure. If we rearrange the algorithm to the up-looking Cholesky decomposition, we only need the column counts and this is also called the symbolic factorisation. This is, incidentally, how Eigen’s sparse Cholesky works.↩︎

  9. Actually it’s a forest↩︎

  10. Because we are talking about a tree, each child node has at most one parent. If it doesn’t have a parent it’s the root of the tree. I remember a lecturer saying that it should be called “father and son” or “mother and daughter” because every child has 2 parents but only one mother or one father. The 2000s were a wild time.↩︎

  11. These can also be computed in approximately \(\mathcal{O}(\operatorname{nnz}(A))\) time, which is much faster. But the algorithm is, frankly, pretty tricky and I’m not in the mood to program it up. This difference would be quite important if I wasn’t storing the full symbolic factorisation and was instead computing it every time, but in my context it is less clear that this is worth the effort.↩︎

  12. Python notation! This is rows/cols 0 to j-1↩︎

  13. Python, it turns out, does not have a do while construct because, apparently, everything is empty and life is meaningless.↩︎

  14. The argument for JIT works by amortizing the compile time over several function evaluations. If I wanted to speed this algorithm up, I’d implement the more complex \(\mathcal{O}(\operatorname{nnz}(A))\) version.↩︎

  15. Obviously it did not work the first time. A good way to debug JIT’d code is to use the python translations of the control flow literals. Why? Well, for one thing there is an annoying tendency for JAX to fail silently when there is an out-of-bounds indexing error. Which happens, just for example, if you replace node = A_indices[indptr] with node = A_indices[A_indptr[indptr]] because you got a text message half way through the line.↩︎

  16. We will still use the left-looking algorithm for the numerical computation. The two algorithms are equivalent in exact arithmetic and, in particular, have identical sparsity structures.↩︎

  17. I’m mixing 1-based indexing in the maths with 0-based in the code because I think we need more chaos in our lives.↩︎

  18. Yes. I know. I’m swapping the meaning of \(i\) and \(j\) but you know that’s because in a symmetric matrix rows and columns are a bit similar. The upper half of column \(j\) is the left half of row \(j\) after all.↩︎

  19. If mark[node]==j then I have already found node and all of its ancestors in my sweep of row j↩︎

  20. This is because L[j,node] != 0 by our logic.↩︎

  21. I mean, I’m pretty sure it is. I’m writing this post in order, so I don’t know yet. But surely the compiler can’t reason about the possible values of node, which would be the only thing that would speed this up.↩︎

  22. Convert from CSR to (i, j, val) (called COO, which has a convenient implementation in jax.experimental.sparse) to CSC. This involves a linear pass, a sort, and another linear pass. So it’s \(n \log n\)-ish. Hire me fancy tech companies. I can count. Just don’t ask me to program quicksort.↩︎

  23. Replace “subtle” with “fairly obvious once I realised how it’s converted to a lax.scan, but not at all obvious to me originally”.↩︎

  24. Who demanded a footnote.↩︎

  25. This also happens with the profile likelihood in non-Bayesian methods.↩︎

  26. the \(\beta\)s↩︎

  27. COO stands for coordinate list and it’s the least space-efficient of our options. It directly stores 3 length n vectors (row, col, value). It’s great for specifying matrices and it’s pretty easy to convert from this format to any of the others.↩︎

  28. other than holiday↩︎

  29. A_index and A.index are different↩︎

  30. I’m probably going to bind Eigen’s AMD decomposition. I’m certainly not writing it myself.↩︎


Citation

BibTeX citation:
@online{simpson2022,
  author = {Dan Simpson},
  editor = {},
  title = {Sparse Matrices Part 7a: {Another} Shot at {JAX-ing} the
    {Cholesky} Decomposition},
  date = {2022-12-02},
  url = {https://dansblog.netlify.app/posts/2022-11-27-sparse7/sparse7.html},
  langid = {en}
}
For attribution, please cite this work as:
Dan Simpson. 2022. “Sparse Matrices Part 7a: Another Shot at JAX-Ing the Cholesky Decomposition.” December 2, 2022. https://dansblog.netlify.app/posts/2022-11-27-sparse7/sparse7.html.