This is part five of our ongoing series on implementing differentiable sparse linear algebra in JAX. In some sense this is the last boring post before we get to the derivatives. Was this post going to include the derivatives? It sure was, but then I realised that a different choice was to go to bed so I could get up nice and early in the morning and vote in our election.
It goes without saying that before I split the posts, it was more than twice as long and I was nowhere near finished. So probably the split was a good choice.
But how do you add a primitive to JAX?
Well, the first step is you read the docs.
They tell you that you need to implement a few things:
- An implementation of the call with “abstract types”
- An implementation of the call with concrete types (aka evaluating the damn function)

Then,

- if you want your primitive to be JIT-able, you need to implement a compilation rule.
- if you want your primitive to be batch-able, you need to implement a batching rule.
- if you want your primitive to be differentiable, you need to implement the derivatives in a way that allows them to be propagated appropriately.
In this post, we are going to do the first task: we are going to register JAX-traceable versions of the four main primitives we are going to need for our task. For the most part, the implementations here will be replaced with C++ bindings (because only a fool writes their own linear algebra code). But this is the beginning1 of our serious journey into JAX.
First things first, some primitives
In JAX-speak, a primitive is a function that is JAX-traceable2. It is not necessary for every possible transformation to be implemented. In fact, today I’m not going to implement any transformations. That is a problem for future Dan.
We have enough today problems.
Because today we need to write four new primitives.
But first of all, let’s build up a test matrix so we can at least check that this code runs. This is the same example from blog 3. You can tell my PhD was in numerical analysis because I fucking love a 2D Laplacian.

from scipy import sparse
import numpy as np

def make_matrix(n):
    one_d = sparse.diags([[-1.]*(n-1), [2.]*n, [-1.]*(n-1)], [-1, 0, 1])
    A = (sparse.kronsum(one_d, one_d) + sparse.eye(n*n)).tocsc()
    A_lower = sparse.tril(A, format = "csc")
    A_index = A_lower.indices
    A_indptr = A_lower.indptr
    A_x = A_lower.data
    return (A_index, A_indptr, A_x, A)

A_indices, A_indptr, A_x, A = make_matrix(10)
Primitive one: \(A^{-1}b\)
Because I’m feeling lazy today and we don’t actually need the Cholesky directly for any of this, I’m going to just use scipy. Why? Well, honestly, just because I’m lazy. But also so I can prove an important point: the implementation of the primitive does not need to be JAX traceable. So I’m implementing it in a way that is not now and will likely never be JAX traceable3.
First off, we need to write the solve function and bind it4 to JAX. Specific information about what exactly some of these commands are doing can be found in the docs, but the key thing is that there is no reason to dick around with JAX types in any of these implementation functions. They are only ever called using (essentially) numpy5 arrays. So we can just program like normal human beings.
from jax import numpy as jnp
from jax import core

sparse_solve_p = core.Primitive("sparse_solve")

def sparse_solve(A_indices, A_indptr, A_x, b):
    """A JAX traceable sparse solve"""
    return sparse_solve_p.bind(A_indices, A_indptr, A_x, b)

@sparse_solve_p.def_impl
def sparse_solve_impl(A_indices, A_indptr, A_x, b):
    """The implementation of the sparse solve. This is not JAX traceable."""
    A_lower = sparse.csc_array((A_x, A_indices, A_indptr))
    assert A_lower.shape[0] == A_lower.shape[1]
    assert A_lower.shape[0] == b.shape[0]
    A = A_lower + A_lower.T - sparse.diags(A_lower.diagonal())
    return sparse.linalg.spsolve(A, b)
## Check it works
b = jnp.ones(100)
x = sparse_solve(A_indices, A_indptr, A_x, b)
print(f"The error in the sparse sovle is {np.sum(np.abs(b - A @ x)): .2e}")
The error in the sparse sovle is 0.00e+00
In order to facilitate its transformations, JAX will occasionally6 call functions using abstract data types. These data types know the shape of the inputs and their data type. So our next step is to specialise the sparse_solve function for this case. We might as well do some shape checking while we’re just hanging around. But the essential part of this function is just saying that the output of \(A^{-1}b\) is the same shape as \(b\) (which is usually a vector, but the code is no more complex if it’s a [dense] matrix).
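Here’s a minimal sketch of what that rule could look like, using core.ShapedArray and the def_abstract_eval decorator from the JAX primitives guide (the exact body is my reconstruction, not gospel):

```python
# A sketch of an abstract evaluation rule for sparse_solve (assumed, not verbatim).
# The only essential claim: A^{-1} b has the same shape and dtype as b.
@sparse_solve_p.def_abstract_eval
def sparse_solve_abstract_eval(A_indices, A_indptr, A_x, b):
    assert A_indices.shape == A_x.shape          # one value per stored entry
    assert A_indptr.shape[0] == b.shape[0] + 1   # n + 1 column pointers
    return core.ShapedArray(b.shape, b.dtype)
```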
Primitive two: The triangular solve
This is very similar. We need to have a function that computes \(L^{-1}b\) and \(L^{-T}b\). The extra wrinkle from the last time around is that we need to pass a keyword argument transpose to indicate which system should be solved.

Once again, we are going to use the appropriate scipy function (in this case sparse.linalg.spsolve_triangular). There’s a little bit of casting between sparse matrix types here as sparse.linalg.spsolve_triangular assumes the matrix is in CSR format.
sparse_triangular_solve_p = core.Primitive("sparse_triangular_solve")

def sparse_triangular_solve(L_indices, L_indptr, L_x, b, *, transpose: bool = False):
    """A JAX traceable sparse triangular solve"""
    return sparse_triangular_solve_p.bind(L_indices, L_indptr, L_x, b, transpose = transpose)

@sparse_triangular_solve_p.def_impl
def sparse_triangular_solve_impl(L_indices, L_indptr, L_x, b, *, transpose = False):
    """The implementation of the sparse triangular solve. This is not JAX traceable."""
    L = sparse.csc_array((L_x, L_indices, L_indptr))
    assert L.shape[0] == L.shape[1]
    assert L.shape[0] == b.shape[0]
    if transpose:
        return sparse.linalg.spsolve_triangular(L.T, b, lower = False)
    else:
        return sparse.linalg.spsolve_triangular(L.tocsr(), b, lower = True)
Now we can check if it works. We can use the fact that our matrix (A_indices, A_indptr, A_x) is lower-triangular (because we only store the lower triangle) to make our test case.
## Check if it works
b = np.random.standard_normal(100)
x1 = sparse_triangular_solve(A_indices, A_indptr, A_x, b)
x2 = sparse_triangular_solve(A_indices, A_indptr, A_x, b, transpose = True)
print(f"""Error in trianglular solve: {np.sum(np.abs(b - sparse.tril(A) @ x1)): .2e}
Error in triangular transpose solve: {np.sum(np.abs(b - sparse.triu(A) @ x2)): .2e}""")
Error in trianglular solve: 3.53e-15
Error in triangular transpose solve: 5.08e-15
And we can also do the abstract evaluation.
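As before, this is a sketch under the same assumptions (core.ShapedArray plus def_abstract_eval); the keyword argument just gets passed along:

```python
# Sketch of the abstract evaluation for the triangular solve (assumed, not verbatim).
@sparse_triangular_solve_p.def_abstract_eval
def sparse_triangular_solve_abstract_eval(L_indices, L_indptr, L_x, b, *, transpose=False):
    assert L_indices.shape == L_x.shape
    assert L_indptr.shape[0] == b.shape[0] + 1
    # L^{-1} b and L^{-T} b have the same shape and dtype as b.
    return core.ShapedArray(b.shape, b.dtype)
```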
Great! Now on to the next one!
Primitive three: The sparse Cholesky
Ok. This one is gonna be a pain in the arse. But we need to do it. Why? Because we are going to need a JAX-traceable version further on down the track.
The issue here is that the non-zero pattern of the Cholesky decomposition is computed on the fly. This is absolutely not allowed in JAX. It must know the shape of all things at the moment it is called.
This is going to make for a somewhat shitty user experience for this function. It’s unavoidable with JAX designed7 the way it is.
The code in jax.experimental.sparse.bcoo.fromdense has this exact problem. In their case, they are turning a dense matrix into a sparse matrix and they can’t know until they see the dense matrix how many non-zeros there are. So they do the sensible thing and ask the user to specify it. They do this using the nse keyword parameter. If you’re curious what nse stands for, it turns out it’s not “non-standard evaluation” but rather “number of specified entries”. Most other systems use the abbreviation nnz for “number of non-zeros”, but I’m going to stick with the JAX notation.
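As an aside, here’s a tiny, hypothetical illustration of that keyword (my own snippet, not from the JAX docs); the point is simply that nse pins down the output shapes even when the values are abstract:

```python
# Hypothetical illustration of nse on BCOO.fromdense. Under jit the number of
# stored entries must be supplied up front, because it determines output shapes.
import jax
from jax.experimental import sparse as jsparse

M = jnp.array([[1., 0.], [0., 2.]])
M_sp = jsparse.BCOO.fromdense(M, nse=2)                              # eager
M_sp_jit = jax.jit(lambda m: jsparse.BCOO.fromdense(m, nse=2))(M)    # jitted
```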
The one little thing we need to add to this code is a guard to make sure that if the sparse_cholesky function is called without specifying L_nse, it only works when the inputs are concrete; if it’s being traced with abstract values, the user gets an informative error instead of something cryptic.
sparse_cholesky_p = core.Primitive("sparse_cholesky")

def sparse_cholesky(A_indices, A_indptr, A_x, *, L_nse: int = None):
    """A JAX traceable sparse cholesky decomposition"""
    if L_nse is None:
        err_string = "You need to pass a value to L_nse when doing fancy sparse_cholesky."
        _ = core.concrete_or_error(None, A_x, err_string)
    return sparse_cholesky_p.bind(A_indices, A_indptr, A_x, L_nse = L_nse)

@sparse_cholesky_p.def_impl
def sparse_cholesky_impl(A_indices, A_indptr, A_x, *, L_nse = None):
    """The implementation of the sparse cholesky. This is not JAX traceable."""
    L_indices, L_indptr = _symbolic_factor(A_indices, A_indptr)
    if L_nse is not None:
        assert len(L_indices) == L_nse
    L_x = _structured_copy(A_indices, A_indptr, A_x, L_indices, L_indptr)
    L_x = _sparse_cholesky_impl(L_indices, L_indptr, L_x)
    return L_indices, L_indptr, L_x
The rest of the code is just the sparse Cholesky code from blog 2 and I’ve hidden it under the fold. (You would think I would package this up properly, but I simply haven’t. Why not? Who knows8.)
def _symbolic_factor(A_indices, A_indptr):
    # Assumes A_indices and A_indptr index the lower triangle of $A$ ONLY.
    n = len(A_indptr) - 1
    L_sym = [np.array([], dtype=int) for j in range(n)]
    children = [np.array([], dtype=int) for j in range(n)]

    for j in range(n):
        L_sym[j] = A_indices[A_indptr[j]:A_indptr[j + 1]]
        for child in children[j]:
            tmp = L_sym[child][L_sym[child] > j]
            L_sym[j] = np.unique(np.append(L_sym[j], tmp))
        if len(L_sym[j]) > 1:
            p = L_sym[j][1]
            children[p] = np.append(children[p], j)

    L_indptr = np.zeros(n+1, dtype=int)
    L_indptr[1:] = np.cumsum([len(x) for x in L_sym])
    L_indices = np.concatenate(L_sym)
    return L_indices, L_indptr

def _structured_copy(A_indices, A_indptr, A_x, L_indices, L_indptr):
    n = len(A_indptr) - 1
    L_x = np.zeros(len(L_indices))

    for j in range(0, n):
        copy_idx = np.nonzero(np.in1d(L_indices[L_indptr[j]:L_indptr[j + 1]],
                                      A_indices[A_indptr[j]:A_indptr[j+1]]))[0]
        L_x[L_indptr[j] + copy_idx] = A_x[A_indptr[j]:A_indptr[j+1]]
    return L_x

def _sparse_cholesky_impl(L_indices, L_indptr, L_x):
    n = len(L_indptr) - 1
    descendant = [[] for j in range(0, n)]
    for j in range(0, n):
        tmp = L_x[L_indptr[j]:L_indptr[j + 1]]
        for bebe in descendant[j]:
            k = bebe[0]
            Ljk = L_x[bebe[1]]
            pad = np.nonzero(
                L_indices[L_indptr[k]:L_indptr[k+1]] == L_indices[L_indptr[j]])[0][0]
            update_idx = np.nonzero(np.in1d(
                L_indices[L_indptr[j]:L_indptr[j+1]],
                L_indices[(L_indptr[k] + pad):L_indptr[k+1]]))[0]
            tmp[update_idx] = tmp[update_idx] - \
                Ljk * L_x[(L_indptr[k] + pad):L_indptr[k + 1]]
        diag = np.sqrt(tmp[0])
        L_x[L_indptr[j]] = diag
        L_x[(L_indptr[j] + 1):L_indptr[j + 1]] = tmp[1:] / diag
        for idx in range(L_indptr[j] + 1, L_indptr[j + 1]):
            descendant[L_indices[idx]].append((j, idx))
    return L_x
Once again, we can check to see if this worked!
L_indices, L_indptr, L_x = sparse_cholesky(A_indices, A_indptr, A_x)
L = sparse.csc_array((L_x, L_indices, L_indptr))
print(f"The error in the sparse cholesky is {np.sum(np.abs((A - L @ L.T).todense())): .2e}")
The error in the sparse cholesky is 1.02e-13
And, of course, we can do abstract evaluation. Here is where we actually need to use L_nse to work out the dimension of our output.
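Here is a sketch of how that might look. The multiple_results flag and the exact shapes are my guesses at the set-up (the primitive returns three arrays, so JAX needs to know it has multiple results), again using core.ShapedArray:

```python
# Sketch of the abstract evaluation for sparse_cholesky (assumed, not verbatim).
# The factor is returned as three arrays, so the primitive is flagged as having
# multiple results, and L_nse gives the number of stored entries in L.
sparse_cholesky_p.multiple_results = True

@sparse_cholesky_p.def_abstract_eval
def sparse_cholesky_abstract_eval(A_indices, A_indptr, A_x, *, L_nse=None):
    return (
        core.ShapedArray((L_nse,), A_indices.dtype),        # L_indices
        core.ShapedArray(A_indptr.shape, A_indptr.dtype),   # L_indptr (n + 1 entries)
        core.ShapedArray((L_nse,), A_x.dtype),               # L_x
    )
```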
Primitive four: \(\log(|A|)\)
And now we have our final primitive: the log determinant! Wow. So much binding. For this one, we compute the Cholesky factorisation and note that \[\begin{align*} |A| = |LL^T| = |L||L^T| = |L|^2. \end{align*}\] If we successfully remember that the determinant of a triangular matrix is the product of its diagonal entries, we have a formula we can implement.
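Putting those two facts together, the quantity we actually compute is \[ \log |A| = 2 \log |L| = 2 \sum_{i=1}^{n} \log L_{ii}. \]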
Same deal as last time.
sparse_log_det_p = core.Primitive("sparse_log_det")
def sparse_log_det(A_indices, A_indptr, A_x):
"""A JAX traceable sparse log-determinant"""
return sparse_log_det_p.bind(A_indices, A_indptr, A_x)
@sparse_log_det_p.def_impl
def sparse_log_det_impl(A_indices, A_indptr, A_x):
"""The implementation of the sparse log-determinant. This is not JAX traceable.
"""
L_indices, L_indptr, L_x = sparse_cholesky_impl(A_indices, A_indptr, A_x)
return 2.0 * sum(np.log(L_x[L_indptr[:-1]]))
A canny reader may notice that I’m assuming that the first element in each column is the diagonal. This will be true as long as the diagonal elements of \(L\) are non-zero, which is true as long as \(A\) is symmetric positive definite.
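If you want to reassure yourself, you can check that assumption directly on the factor we just computed (a quick sanity check of my own, not from the post):

```python
# Check that the first stored entry in each column of L is the diagonal element.
n = len(L_indptr) - 1
assert np.all(L_indices[L_indptr[:-1]] == np.arange(n))
```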
Let’s test9 it out.
ld = sparse_log_det(A_indices, A_indptr, A_x)
LU = sparse.linalg.splu(A)
ld_true = sum(np.log(LU.U.diagonal()))
print(f"The error in the log-determinant is {ld - ld_true: .2e}")
The error in the log-determinant is 0.00e+00
Finally, we can do the abstract evaluation.
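Since the log-determinant is a scalar, this one is about as simple as an abstract evaluation rule gets (same caveat as before: a sketch, not the original code):

```python
# Sketch of the abstract evaluation for sparse_log_det (assumed, not verbatim).
@sparse_log_det_p.def_abstract_eval
def sparse_log_det_abstract_eval(A_indices, A_indptr, A_x):
    # The output is a scalar with the same dtype as the matrix values.
    return core.ShapedArray((), A_x.dtype)
```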
Where are we now but nowhere?
So we are done for today. Our next step will be to implement all of the bits that are needed to make the derivatives work. So in the next instalment we will differentiate log-determinants, Cholesky decompositions, and all kinds of other fun things.
It should be a blast.
Footnotes
1. The second half of this post is half written but, to be honest, I want to go to bed more than I want to implement more derivatives, so I’m splitting the post.↩︎
2. aka JAX can map out how the pieces of the function go together and it can then use that map to make its weird transformations↩︎
3. But mostly because although I’m going to have to implement the Cholesky and triangular solves later on down the line, I’m writing this in order and I don’t wanna.↩︎
4. The JAX docs don’t use decorators for their bindings but I use decorators because I like decorators.↩︎
5. Something something duck type. They’re arrays with numbers in them that work in numpy and scipy. Get off my arse.↩︎
6. This is mostly for JIT, so it’s not necessary today, but to be very honest it’s the only easy thing to do here and I’m not above giving myself a participation trophy.↩︎
7. This is a … fringe problem in JAX-land, so it makes sense that there is a less than beautiful solution to the problem. I think this would be less of a design problem in Stan, where it’s possible to make the number of unknowns in the autodiff tree depend on int arrays in a complex way.↩︎
8. Well, me. I’m who knows. I’m still treating this like scratch code in a notepad. Although we are moving towards the point where I’m going to have to set everything out properly. Maybe that’s the next post?↩︎
9. Full disclosure: first time out I forgot to multiply by two. This is why we test.↩︎
Citation
@online{simpson2022,
  author = {Simpson, Dan},
  title = {Sparse {Matrices} 5: {I} Bind You {Nancy}},
  date = {2022-05-20},
  url = {https://dansblog.netlify.app/2022-05-18-sparse4-some-primatives},
  langid = {en}
}