Sparse matrices 6: To catch a derivative, first you’ve got to think like a derivative

Open up the kennels, Kenneth. Mamma’s coming home tonight.

JAX
Sparse matrices
Autodiff
Author

Dan Simpson

Published

May 30, 2022

Welcome to part six!!! of our ongoing series on making sparse linear algebra differentiable in JAX with the eventual hope to be able to do some cool statistical shit. We are nowhere near done.

Last time, we looked at making JAX primitives. We built four of them. Today we are going to implement the corresponding differentiation rules! For three1 of them.

So strap yourselves in. This is gonna be detailed.

If you’re interested in the code2, the git repo for this post is linked at the bottom and in there you will find a folder with the python code in a python file.

She is beauty and she is grace. She is queen of 50 states. She is elegance and taste. She is miss autodiff

Derivatives are computed in JAX through the glory and power of automatic differentiation. If you came to this blog hoping for a great description of how autodiff works, I am terribly sorry but I absolutely do not have time for that. Might I suggest google? Or maybe flick through this survey by Charles Margossian.

The most important thing to remember about algorithmic differentiation is that it is not symbolic differentiation. That is, it does not create the functional form of the derivative of the function and compute that. Instead, it is a system for cleverly composing derivatives in each bit of the program to compute the value of the derivative of the function.

But for that to work, we need to implement those clever little mini-derivatives. In particular, every function \(f(\cdot): \mathbb{R}^n \rightarrow \mathbb{R}^m\) needs to have a function to compute the corresponding Jacobian-vector product \[ (\theta, v) \rightarrow J(\theta) v, \] where the \(m \times n\) matrix \(J(\theta)\) has entries \[ J(\theta)_{ij} = \frac{\partial f_i }{\partial \theta_j}. \]
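To make the Jacobian-vector product concrete, here is a small sketch using JAX's public `jax.jvp` and `jax.jacfwd`. The toy function `f` and the vectors `theta` and `v` are invented for the illustration; the point is only that `jax.jvp` returns \(J(\theta)v\) without ever forming \(J(\theta)\).

```python
import jax
import jax.numpy as jnp

def f(theta):
    # A toy map from R^3 to R^2 (hypothetical, just for illustration).
    return jnp.array([theta[0] * theta[1], jnp.sin(theta[2]) + theta[0]])

theta = jnp.array([1.0, 2.0, 0.5])
v = jnp.array([0.1, -0.2, 0.3])

# Forward mode: the value f(theta) and the product J(theta) v in one pass.
value, jvp_out = jax.jvp(f, (theta,), (v,))

# For comparison only: build the full Jacobian and multiply by hand.
J = jax.jacfwd(f)(theta)
jvp_by_hand = J @ v
```

For a function with 3 inputs and 2 outputs, `J` has 2 rows and 3 columns, so `J @ v` lives in the output space.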

Ok. So let’s get onto this. We are going to derive and implement some Jacobian-vector products. And all of the assorted accoutrement. And by crikey. We are going to do it all in a JAX-traceable way.

JVP number one: The linear solve.

The first of the derivatives that we need to work out is the derivative of a linear solve \(A^{-1}b\). Now, intrepid readers, the obvious thing to do is look the damn derivative up. You get exactly no hero points for computing it yourself.

But I’m not you, I’m a dickhead.

So I’m going to derive it. I could pretend there are reasons3, but that would just be lying. I’m doing it because I can.

Beyond the obvious fun of working out a matrix derivative from first principles, this is fun because we have two arguments instead of just one. Double the fun.

And we really should make sure the function is differentiated with respect to every reasonable argument. Why? Because if you write code other people might use, you don’t get to control how they use it (or what they will email you about). So it’s always good practice to limit surprises (like a function not being differentiable wrt some argument) to cases4 where it is absolutely necessary. This reduces the emails.

To that end, let’s take an arbitrary SPD matrix \(A\) with a fixed sparsity pattern. Let’s take another symmetric matrix \(\Delta\) with the same sparsity pattern and assume that \(\Delta\) is small enough5 that \(A + \Delta\) is still symmetric positive definite. We also need a vector \(\delta\) with a small \(\|\delta\|\).

Now let’s get algebraing. \[\begin{align*} f(A + \Delta, b + \delta) &= (A+\Delta)^{-1}(b + \delta) \\ &= (I + A^{-1}\Delta)^{-1}A^{-1}(b + \delta) \\ &= (I - A^{-1}\Delta + o(\|\Delta\|))A^{-1}(b + \delta) \\ &= A^{-1}b + A^{-1}(\delta - \Delta A^{-1}b ) + o(\|\Delta\| + \|\delta\|) \end{align*}\]

Easy6 as.

We’ve actually calculated the derivative now, but it’s a little more work to recognise it.

To do that, we need to remember the practical definition of the Jacobian of a function \(f(x)\) that takes an \(n\)-dimensional input and produces an \(m\)-dimensional output. It is the \(m \times n\) matrix \(J_f(x)\) such that \[ f(x + \delta) = f(x) + J_f(x)\delta + o(\|\delta\|). \]

The formulas further simplify if we write \(c = A^{-1}b\). Then, if we want the Jacobian-vector product for the first argument, it is \[ -A^{-1}\Delta c, \] while the Jacobian-vector product for the second argument is \[ A^{-1}\delta. \]

The only wrinkle in doing this is we need to remember that we are only storing the lower triangle of \(A\). Because we need to represent \(\Delta\) the same way, it is represented as a vector Delta_x that contains only the lower triangle of \(\Delta\). So we need to make sure we remember to form the whole matrix before we do the matrix-vector product \(\Delta c\)!

But otherwise, the implementation is going to be pretty straightforward. The Jacobian-vector product costs one additional linear solve (beyond the one needed to compute the value \(c = A^{-1}b\)).
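A quick numerical sanity check of the two formulas, as a sketch only. It uses a small dense SPD matrix in NumPy rather than the sparse representation from this series, and every name in it (`Delta_lower`, `jvp_formula`, `jvp_fd`) is made up for the illustration. It also shows the "form the whole matrix from its lower triangle" step.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 5

# A small dense SPD matrix standing in for the sparse A.
M = rng.standard_normal((n, n))
A = M @ M.T + n * np.eye(n)
b = rng.standard_normal(n)

# The tangent Delta is stored as a lower triangle, so rebuild the full
# symmetric matrix: lower + lower^T, minus the double-counted diagonal.
Delta_lower = np.tril(rng.standard_normal((n, n)))
Delta = Delta_lower + Delta_lower.T - np.diag(np.diag(Delta_lower))
delta = rng.standard_normal(n)

# One solve for the value, one more solve for the combined JVP.
c = np.linalg.solve(A, b)
jvp_formula = np.linalg.solve(A, delta - Delta @ c)

# Forward difference approximation of the same directional derivative.
eps = 1e-6
jvp_fd = (np.linalg.solve(A + eps * Delta, b + eps * delta) - c) / eps

err = np.linalg.norm(jvp_formula - jvp_fd)
```

The two should agree up to the usual forward-difference error, which is of order `eps` here because everything is in double precision.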

In the language of JAX (and autodiff in general), we refer to \(\Delta\) and \(\delta\) as tangent vectors. In search of a moderately coherent naming convention, we are going to refer to the tangent associated with the variable x as xt.

So let’s implement this. Remember: it needs7 to be JAX traceable.

Primitive two: The triangular solve

For some sense of continuity, we are going to keep the naming of the primitives from the last blog post, but we are not going to attack them in the same order. Why not? Because we work in order of complexity.

So first off we are going to do the triangular solve. As I have yet to package up the code (I promise, that will happen next8), I’m just putting it here under the fold.

The primal implementation
from scipy import sparse
import numpy as np
from jax import numpy as jnp
from jax import core
from jax._src import abstract_arrays

sparse_triangular_solve_p = core.Primitive("sparse_triangular_solve")

def sparse_triangular_solve(L_indices, L_indptr, L_x, b, *, transpose: bool = False):
  """A JAX traceable sparse triangular solve"""
  return sparse_triangular_solve_p.bind(L_indices, L_indptr, L_x, b, transpose = transpose)

@sparse_triangular_solve_p.def_impl
def sparse_triangular_solve_impl(L_indices, L_indptr, L_x, b, *, transpose = False):
  """The implementation of the sparse triangular solve. This is not JAX traceable."""
  L = sparse.csc_array((L_x, L_indices, L_indptr)) 
  
  assert L.shape[0] == L.shape[1]
  assert L.shape[0] == b.shape[0]
  
  if transpose:
    return sparse.linalg.spsolve_triangular(L.T, b, lower = False)
  else:
    return sparse.linalg.spsolve_triangular(L.tocsr(), b, lower = True)

@sparse_triangular_solve_p.def_abstract_eval
def sparse_triangular_solve_abstract_eval(L_indices, L_indptr, L_x, b, *, transpose = False):
  assert L_indices.shape[0] == L_x.shape[0]
  assert b.shape[0] == L_indptr.shape[0] - 1
  return abstract_arrays.ShapedArray(b.shape, b.dtype)

The Jacobian-vector product

from jax._src import ad_util
from jax.interpreters import ad
from jax import lax
from jax.experimental import sparse as jsparse

def sparse_triangular_solve_value_and_jvp(arg_values, arg_tangent, *, transpose):
  """
  A jax-traceable jacobian-vector product. In order to make it traceable, 
  we use the experimental sparse CSC matrix in JAX.
  
  Input:
    arg_values:   A tuple of (L_indices, L_indptr, L_x, b) that describe
                  the triangular matrix L and the rhs vector b
    arg_tangent:  A tuple of tangent values (same length as arg_values).
                  The first two values are nonsense - we don't differentiate
                  wrt integers!
    transpose:    (boolean) If true, solve L^Tx = b. Otherwise solve Lx = b.
  Output:         A tuple containing the maybe_transpose(L)^{-1}b and the corresponding
                  Jacobian-vector product.
  """
  L_indices, L_indptr, L_x, b = arg_values
  _, _, L_xt, bt = arg_tangent
  value = sparse_triangular_solve(L_indices, L_indptr, L_x, b, transpose=transpose)
  if type(bt) is ad.Zero and type(L_xt) is ad.Zero:
    # I legit do not think this ever happens. But I'm honestly not sure.
    print("I have arrived!")
    return value, lax.zeros_like_array(value) 
  
  if type(L_xt) is not ad.Zero:
    # L is variable
    if transpose:
      Delta = jsparse.CSC((L_xt, L_indices, L_indptr), shape = (b.shape[0], b.shape[0])).transpose()
    else:
      Delta = jsparse.CSC((L_xt, L_indices, L_indptr), shape = (b.shape[0], b.shape[0]))

    jvp_Lx = sparse_triangular_solve(L_indices, L_indptr, L_x, Delta @ value, transpose = transpose) 
  else:
    jvp_Lx = lax.zeros_like_array(value) 

  if type(bt) is not ad.Zero:
    # b is variable
    jvp_b = sparse_triangular_solve(L_indices, L_indptr, L_x, bt, transpose = transpose)
  else:
    jvp_b = lax.zeros_like_array(value)

  return value, jvp_b - jvp_Lx

ad.primitive_jvps[sparse_triangular_solve_p] = sparse_triangular_solve_value_and_jvp

Before we see if this works, let’s first have a talk about the structure of the function I just wrote. Generally speaking, we want a function that takes in the primals and tangents as tuples and then returns the value and the9 Jacobian-vector product.

The main thing you will notice in the code is that there is a lot of checking for ad.Zero. This is a special type defined in JAX that is, essentially, telling the autodiff system that we are not differentiating wrt that variable. This is different to a tangent that just happens to be numerically equal to zero. Any code for a Jacobian-vector product needs to handle this special value.

As we have two arguments, we have 3 interesting options:

  1. Both L_xt and bt are ad.Zero: This means the function is a constant and the derivative is zero. I am fairly certain that we do not need to manually handle this case, but because I don’t know and I do not like surprises, it’s in there.

  2. L_xt is not ad.Zero: This means that we need to differentiate wrt the matrix. In this case we need to compute \(\Delta c\) or \(\Delta^T c\), depending on the transpose argument. In order to do this, I used the jax.experimental.sparse.CSC class, which has some very limited sparse matrix support (basically matrix-vector products). This is extremely convenient because it means I don’t need to write the matrix-vector product myself!

  3. bt is not ad.Zero: This means that we need to differentiate wrt the rhs vector. This part of the formula is pretty straightforward: just an application of the primal.

In the case that either L_xt or bt is ad.Zero, we simply set the corresponding contribution to the jvp to zero.
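If you want to convince yourself of the formula behind cases 2 and 3 without any of the sparse machinery, here is a dense sketch. It checks \(L^{-1}(\dot b - \Delta c)\) (and the transposed version) against the JVP that JAX already knows for `jax.scipy.linalg.solve_triangular`. The matrices and the names `Lt`, `bt`, `solve` and `solve_T` are invented for the example.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

# A small dense lower-triangular matrix and right hand side.
L = jnp.array([[2.0, 0.0, 0.0],
               [1.0, 3.0, 0.0],
               [0.5, -1.0, 4.0]])
b = jnp.array([1.0, 2.0, 3.0])

# Tangents: Lt shares the lower-triangular pattern of L.
Lt = jnp.array([[0.3, 0.0, 0.0],
                [-0.2, 0.1, 0.0],
                [0.4, 0.5, -0.6]])
bt = jnp.array([0.5, -1.0, 0.25])

def solve(L, b):
    return solve_triangular(L, b, lower=True)             # solves L x = b

def solve_T(L, b):
    return solve_triangular(L, b, lower=True, trans="T")  # solves L^T x = b

# transpose = False: tangent is L^{-1}(bt - Lt c)
c, jvp_jax = jax.jvp(solve, (L, b), (Lt, bt))
jvp_formula = solve(L, bt - Lt @ c)

# transpose = True: tangent is L^{-T}(bt - Lt^T c)
cT, jvp_jax_T = jax.jvp(solve_T, (L, b), (Lt, bt))
jvp_formula_T = solve_T(L, bt - Lt.T @ cT)
```

This is the same `jvp_b - jvp_Lx` structure as the sparse rule, just written in one line.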

It’s worth saying that you can bypass all of this ad.Zero logic by writing separate functions for the JVP contribution from each input and then chaining them together using10 ad.defjvp2(). This is what the lax.linalg.triangular_solve() implementation does.

So why didn’t I do this? I avoided this because in the other primitives I have to implement, there are expensive computations (like Cholesky factorisations) that I want to share between the primal and the various tangent calculations. The ad.defjvp frameworks don’t allow for that. So I decided not to demonstrate/learn two separate patterns.
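To illustrate what "sharing an expensive computation between the primal and the tangent" looks like, here is a rough dense sketch. Note that it swaps in a different mechanism from the one used in this post: it uses the public `jax.custom_jvp` decorator on a dense solve rather than registering a rule in `ad.primitive_jvps`. The function names are hypothetical. The only point is that a single value-and-JVP function lets the Cholesky factor be computed once and used for both solves.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve

@jax.custom_jvp
def dense_spd_solve(A, b):
    # Primal only: factorise, then solve.
    return cho_solve(cho_factor(A, lower=True), b)

@dense_spd_solve.defjvp
def dense_spd_solve_jvp(primals, tangents):
    A, b = primals
    At, bt = tangents
    factor = cho_factor(A, lower=True)    # the expensive bit, done once
    c = cho_solve(factor, b)              # primal value
    ct = cho_solve(factor, bt - At @ c)   # tangent reuses the same factor
    return c, ct

A = jnp.array([[4.0, 1.0, 0.0],
               [1.0, 3.0, 1.0],
               [0.0, 1.0, 2.0]])
b = jnp.array([1.0, 2.0, 3.0])
At = jnp.array([[0.1, 0.2, 0.0],
                [0.2, -0.3, 0.1],
                [0.0, 0.1, 0.5]])
bt = jnp.array([1.0, 0.0, -1.0])

c, ct = jax.jvp(dense_spd_solve, (A, b), (At, bt))
ct_reference = jnp.linalg.solve(A, bt - At @ jnp.linalg.solve(A, b))
```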

Transposition

Now I’ve never actively wanted a Jacobian-vector product in my whole life. I’m sorry. I want a gradient. Gimme a gradient. I am the Veruca Salt of gradients.

In many autodiff systems, if you want11 a gradient, you need to implement vector-Jacobian products12 explicitly.

One of the odder little innovations in JAX is that instead of forcing you to implement this as well13, you only need to implement half of it.

You see, some clever analysis that, as far as I can tell14, is detailed in this paper shows that you only need to form explicit vector-Jacobian products for the structurally linear arguments of the function.

In JAX (and maybe elsewhere), this is known as a transposition rule. The combination of a transposition rule and a JAX-traceable Jacobian-vector product is enough for JAX to compute all of the directional derivatives and gradients we could ever hope for.

As far as I understand, it is all about functions that are structurally linear in some arguments. For instance, if \(A(x)\) is a matrix-valued function and \(x\) and \(y\) are vectors, then the function \[ f(x, y) = A(x)y + g(x) \] is structurally linear in \(y\) in the sense that for every fixed value of \(x\), the function \[ f_x(y) = A(x) y + g(x) \] is linear in \(y\). The resulting transposition rule is then

def f_transpose(x, y):
  Ax = A(x)
  gx = g(x)
  return (None, Ax.T @ y + gx)

The first element of the return is None because \(f(x,y)\) is not15 structurally linear in \(x\) so there is nothing to transpose. The second element simply takes the matrix in the linear function and transposes it.
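You can watch JAX transpose a structurally linear function using the public `jax.linear_transpose`. This is only a sketch of the purely linear case \(y \mapsto Ay\), with a made-up matrix `A` and cotangent `w`. The example argument passed to `jax.linear_transpose` is used for its shape and dtype only.

```python
import jax
import jax.numpy as jnp

A = jnp.array([[1.0, 2.0],
               [3.0, 4.0],
               [5.0, 6.0]])

def linear_in_y(y):
    # Linear in y: maps R^2 to R^3.
    return A @ y

y_example = jnp.zeros(2)                       # shape/dtype template only
transposed = jax.linear_transpose(linear_in_y, y_example)

w = jnp.array([1.0, -1.0, 2.0])                # a cotangent in the output space
(y_bar,) = transposed(w)                       # one cotangent per input
```

The result is \(A^T w\), which is exactly the "take the matrix in the linear function and transpose it" recipe.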

If you know anything about autodiff, you’ll think “this doesn’t feel like enough” and it’s not. JAX deals with the non-linear part of \(f(x,y)\) by tracing the evaluation tree for its Jacobian-vector product and … manipulating16 it.

We already built the abstract evaluation function last time around, so the tracing part can be done. All we need is the transposition rule.

The linear solve \(f(A, b) = A^{-1}b\) is non-linear in the first argument but linear in the second argument. So we only need to implement \[ J^T_b(A,b)w = A^{-T}w, \] where the subscript \(b\) indicates we’re only computing the Jacobian wrt \(b\).

Initially, I struggled to work out what needed to be implemented here. The thing that clarified the process for me was looking at JAX’s internal implementation of the Jacobian-vector product for a dense matrix. From there, I understood what this had to look like for a vector-valued function and this is the result.

def sparse_triangular_solve_transpose_rule(cotangent, L_indices, L_indptr, L_x, b, *, transpose):
  """
  Transposition rule for the triangular solve. 
  Translated from here https://github.com/google/jax/blob/41417d70c03b6089c93a42325111a0d8348c2fa3/jax/_src/lax/linalg.py#L747.
  Inputs:
    cotangent: Output cotangent (aka adjoint). (produced by JAX)
    L_indices, L_indptr, L_x: Representation of sparse matrix. L_x should be concrete
    b: The right hand side. Must be a jax.interpreters.ad.UndefinedPrimal
    transpose: (boolean) True: solve $L^Tx = b$. False: Solve $Lx = b$.
  Output:
    A 4-tuple with the adjoints (None, None, None, b_adjoint)
  """
  assert not ad.is_undefined_primal(L_x) and ad.is_undefined_primal(b)
  if type(cotangent) is ad_util.Zero:
    cot_b = ad_util.Zero(b.aval)
  else:
    cot_b = sparse_triangular_solve(L_indices, L_indptr, L_x, cotangent, transpose = not transpose)
  return None, None, None, cot_b

ad.primitive_transposes[sparse_triangular_solve_p] = sparse_triangular_solve_transpose_rule

If this doesn’t make a lot of sense to you, that’s because it’s confusing.

One way to think of it is in terms of the more ordinary notation. Mike Giles has a classic paper that covers these results for basic linear algebra. The idea is to imagine that, as part of your larger program, you need to compute \(c = A^{-1}b\).

Forward-mode autodiff computes the sensitivity of \(c\), usually denoted \(\dot c\) from the sensitivities \(\dot A\) and \(\dot b\). These have already been computed. The formula in Giles is \[ \dot c = A^{-1}(\dot b - \dot A c). \] The canny reader will recognise this as exactly17 the formula for the Jacobian-vector product.

So what does reverse-mode autodiff do? Well it moves through the program in the other direction. So instead of starting with the sensitivities \(\dot A\) and \(\dot b\) already computed, we instead start with the18 adjoint sensitivity \(\bar c\). Our aim is to compute \(\bar A\) and \(\bar b\) from \(\bar c\).

The details of how to do this are19 beyond the scope, but without tooooooo much effort you can show that \[ \bar b = A^{-T} \bar c, \] which you should recognise as the equation that was just implemented.

The thing that we do not have to implement in JAX is the other adjoint that, for dense matrices20, is \[ \bar{A} = -\bar{b}c^T. \] Through the healing power of … something?—Truly I do not know.— JAX can work that bit out itself. woo.
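Both adjoint formulas are easy to check on a small dense example with `jax.vjp`, which hands back a pullback from \(\bar c\) to \((\bar A, \bar b)\). This is a sketch on an arbitrary made-up matrix, using JAX's built-in dense `jnp.linalg.solve` rather than the sparse primitives.

```python
import jax
import jax.numpy as jnp

A = jnp.array([[4.0, 1.0, 0.0],
               [2.0, 5.0, 1.0],
               [0.0, 1.0, 3.0]])
b = jnp.array([1.0, 2.0, 3.0])

# c = A^{-1} b, plus the function that pulls an output adjoint back.
c, pullback = jax.vjp(jnp.linalg.solve, A, b)

c_bar = jnp.array([1.0, 0.0, -1.0])       # an arbitrary output adjoint
A_bar, b_bar = pullback(c_bar)

# The formulas from Giles, written out by hand.
b_bar_formula = jnp.linalg.solve(A.T, c_bar)    # A^{-T} c_bar
A_bar_formula = -jnp.outer(b_bar_formula, c)    # -b_bar c^T
```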

Testing the numerical implementation of the Jacobian-vector product

So let’s see if this works. I’m not going to lie, I’m flying by the seat of my pants here. I’m not super familiar with the JAX internals, so I have written a lot of test cases. You may wish to skip this part. But rest assured that almost every single one of these cases was useful to me working out how this thing actually worked!

def make_matrix(n):
    one_d = sparse.diags([[-1.]*(n-1), [2.]*n, [-1.]*(n-1)], [-1,0,1])
    A = (sparse.kronsum(one_d, one_d) + sparse.eye(n*n)).tocsc()
    A_lower = sparse.tril(A, format = "csc")
    A_index = A_lower.indices
    A_indptr = A_lower.indptr
    A_x = A_lower.data
    return (A_index, A_indptr, A_x, A)

A_indices, A_indptr, A_x, A = make_matrix(10)

This is the same test case as the last blog. We will just use the lower triangle of \(A\) as the test matrix.

First things first, let’s check out the numerical implementation of the function. We will do that by comparing the implemented Jacobian-vector product with the definition of the Jacobian-vector product (aka the forward21 difference approximation).

There are lots of things that we could do here to turn these into actual tests. For instance, the test suite inside JAX has a lot of nice convenience functions for checking implementations of derivatives. But I went with homespun because that was how I was feeling.

You’ll also notice that I’m using random numbers here, which is fine for a blog. Not so fine for a test that you don’t want to be potentially22 flaky.

The choice of eps = 1e-4 is roughly23 because it’s the square root of the single precision machine epsilon24. A very rough back of the envelope calculation for the forward difference approximation to the derivative shows that the square root of the machine epsilon is about the size you want your perturbation to be.
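The step-size trade-off is easy to see on a toy problem. The sketch below differentiates `sin` at 1 in single precision with three different perturbations; the function name and the choice of test function are mine, not part of the test code in this post. Too big a step and the truncation error dominates, too small and rounding error eats you alive.

```python
import numpy as np

def forward_difference_error(eps):
    """Absolute error of a float32 forward difference of sin at x = 1."""
    x = np.float32(1.0)
    eps = np.float32(eps)
    approx = (np.sin(x + eps) - np.sin(x)) / eps   # all in float32
    return abs(float(approx) - np.cos(1.0))        # exact derivative is cos(1)

sqrt_eps = np.sqrt(np.finfo(np.float32).eps)       # roughly 3e-4

err_big = forward_difference_error(1e-1)        # truncation error dominates
err_sqrt = forward_difference_error(sqrt_eps)   # the balanced choice
err_tiny = forward_difference_error(1e-7)       # rounding error dominates
```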

b = np.random.standard_normal(100)

bt = np.random.standard_normal(100)
bt /= np.linalg.norm(bt)

A_xt = np.random.standard_normal(len(A_x))
A_xt /= np.linalg.norm(A_xt)

arg_values = (A_indices, A_indptr, A_x, b )

arg_tangent_A = (None, None, A_xt, ad.Zero(type(b)))
arg_tangent_b = (None, None, ad.Zero(type(A_xt)), bt)
arg_tangent_Ab = (None, None, A_xt, bt)

p, t_A = sparse_triangular_solve_value_and_jvp(arg_values, arg_tangent_A, transpose = False)
_, t_b = sparse_triangular_solve_value_and_jvp(arg_values, arg_tangent_b, transpose = False)
_, t_Ab = sparse_triangular_solve_value_and_jvp(arg_values, arg_tangent_Ab, transpose = False)
pT, t_AT = sparse_triangular_solve_value_and_jvp(arg_values, arg_tangent_A, transpose = True)
_, t_bT = sparse_triangular_solve_value_and_jvp(arg_values, arg_tangent_b, transpose = True)

eps = 1e-4
tt_A = (sparse_triangular_solve(A_indices, A_indptr, A_x + eps * A_xt, b) - p) /eps
tt_b = (sparse_triangular_solve(A_indices, A_indptr, A_x, b + eps * bt) - p) / eps
tt_Ab = (sparse_triangular_solve(A_indices, A_indptr, A_x + eps * A_xt, b + eps * bt) - p) / eps
tt_AT = (sparse_triangular_solve(A_indices, A_indptr, A_x + eps * A_xt, b, transpose = True) - pT) / eps
tt_bT = (sparse_triangular_solve(A_indices, A_indptr, A_x, b + eps * bt, transpose = True) - pT) / eps

print(f"""
Transpose = False:
  Error A varying: {np.linalg.norm(t_A - tt_A): .2e}
  Error b varying: {np.linalg.norm(t_b - tt_b): .2e}
  Error A and b varying: {np.linalg.norm(t_Ab - tt_Ab): .2e}

Transpose = True:
  Error A varying: {np.linalg.norm(t_AT - tt_AT): .2e}
  Error b varying: {np.linalg.norm(t_bT - tt_bT): .2e}
""")

Transpose = False:
  Error A varying:  1.08e-07
  Error b varying:  0.00e+00
  Error A and b varying:  4.19e-07

Transpose = True:
  Error A varying:  1.15e-07
  Error b varying:  0.00e+00

Brilliant! Everything is correct to within single precision!

Checking on the plumbing

Making the numerical implementation work is only half the battle. We also have to make it work in the context of JAX.

Now I would be lying if I pretended this process went smoothly. But the first time is for experience. It’s mostly a matter of just reading the documentation carefully and going through similar examples that have already been implemented.

And testing. I learnt how this was supposed to work by testing it.

(For full disclosure, I also wrote a big block f-string in the sparse_triangular_solve() function at one point that told me the types, shapes, and what transpose was, which was how I worked out that my code was breaking because I forgot the first two None outputs in the transposition rule. When in doubt, print shit.)

As you will see from my testing code, I was not going for elegance. I was running the damn permutations. If you’re looking for elegance, look elsewhere.

from jax import jvp, grad
from jax import scipy as jsp

def f(theta):
  Ax_theta = jnp.array(A_x)
  Ax_theta = Ax_theta.at[A_indptr[20]].add(theta[0])
  Ax_theta = Ax_theta.at[A_indptr[50]].add(theta[1])
  b = jnp.ones(100)
  return sparse_triangular_solve(A_indices, A_indptr, Ax_theta, b, transpose = True)

def f_jax(theta):
  Ax_theta = jnp.array(sparse.tril(A).todense())
  Ax_theta = Ax_theta.at[20,20].add(theta[0])
  Ax_theta = Ax_theta.at[50,50].add(theta[1])
  b = jnp.ones(100)
  return jsp.linalg.solve_triangular(Ax_theta, b, lower = True, trans = "T")

def g(theta):
  Ax_theta = jnp.array(A_x)
  b = jnp.ones(100)
  b = b.at[0].set(theta[0])
  b = b.at[51].set(theta[1])
  return sparse_triangular_solve(A_indices, A_indptr, Ax_theta, b, transpose = True)

def g_jax(theta):
  Ax_theta = jnp.array(sparse.tril(A).todense())
  b = jnp.ones(100)
  b = b.at[0].set(theta[0])
  b = b.at[51].set(theta[1])
  return jsp.linalg.solve_triangular(Ax_theta, b, lower = True, trans = "T")

def h(theta):
  Ax_theta = jnp.array(A_x)
  Ax_theta = Ax_theta.at[A_indptr[20]].add(theta[0]) 
  b = jnp.ones(100)
  b = b.at[51].set(theta[1])
  return sparse_triangular_solve(A_indices, A_indptr, Ax_theta, b, transpose = False)

def h_jax(theta):
  Ax_theta = jnp.array(sparse.tril(A).todense())
  Ax_theta = Ax_theta.at[20,20].add(theta[0])
  b = jnp.ones(100)
  b = b.at[51].set(theta[1])
  return jsp.linalg.solve_triangular(Ax_theta, b, lower = True, trans = "N")

def no_diff(theta):
  return sparse_triangular_solve(A_indices, A_indptr, A_x, jnp.ones(100), transpose = False)

def no_diff_jax(theta):
  return jsp.linalg.solve_triangular(jnp.array(sparse.tril(A).todense()), jnp.ones(100), lower = True, trans = "N")

A_indices, A_indptr, A_x, A = make_matrix(10)
primal1, jvp1 = jvp(f, (jnp.array([-142., 342.]),), (jnp.array([1., 2.]),))
primal2, jvp2 = jvp(f_jax, (jnp.array([-142., 342.]),), (jnp.array([1., 2.]),))
grad1 = grad(lambda x: jnp.mean(f(x)))(jnp.array([-142., 342.]))
grad2 = grad(lambda x: jnp.mean(f_jax(x)))(jnp.array([-142., 342.]))

primal3, jvp3 = jvp(g, (jnp.array([-142., 342.]),), (jnp.array([1., 2.]),))
primal4, jvp4 = jvp(g_jax, (jnp.array([-142., 342.]),), (jnp.array([1., 2.]),))
grad3 = grad(lambda x: jnp.mean(g(x)))(jnp.array([-142., 342.]))
grad4 = grad(lambda x: jnp.mean(g_jax(x)))(jnp.array([-142., 342.]))  

primal5, jvp5 = jvp(h, (jnp.array([-142., 342.]),), (jnp.array([1., 2.]),))
primal6, jvp6 = jvp(h_jax, (jnp.array([-142., 342.]),), (jnp.array([1., 2.]),))
grad5 = grad(lambda x: jnp.mean(h(x)))(jnp.array([-142., 342.]))
grad6 = grad(lambda x: jnp.mean(h_jax(x)))(jnp.array([-142., 342.]))

primal7, jvp7 = jvp(no_diff, (jnp.array([-142., 342.]),), (jnp.array([1., 2.]),))
primal8, jvp8 = jvp(no_diff_jax, (jnp.array([-142., 342.]),), (jnp.array([1., 2.]),))
grad7 = grad(lambda x: jnp.mean(no_diff(x)))(jnp.array([-142., 342.]))
grad8 = grad(lambda x: jnp.mean(no_diff_jax(x)))(jnp.array([-142., 342.]))

print(f"""
Variable L:
  Primal difference: {np.linalg.norm(primal1 - primal2): .2e}
  JVP difference: {np.linalg.norm(jvp1 - jvp2): .2e}
  Gradient difference: {np.linalg.norm(grad1 - grad2): .2e}

Variable b:
  Primal difference: {np.linalg.norm(primal3 - primal4): .2e}
  JVP difference: {np.linalg.norm(jvp3 - jvp4): .2e}
  Gradient difference: {np.linalg.norm(grad3 - grad4): .2e} 

Variable L and b:
  Primal difference: {np.linalg.norm(primal5 - primal6): .2e}
  JVP difference: {np.linalg.norm(jvp5 - jvp6): .2e}
  Gradient difference: {np.linalg.norm(grad5 - grad6): .2e}

No diff:
  Primal difference: {np.linalg.norm(primal7 - primal8)}
  JVP difference: {np.linalg.norm(jvp7 - jvp8)}
  Gradient difference: {np.linalg.norm(grad7 - grad8)}
""")

Variable L:
  Primal difference:  1.98e-07
  JVP difference:  2.58e-12
  Gradient difference:  0.00e+00

Variable b:
  Primal difference:  7.94e-06
  JVP difference:  1.83e-08
  Gradient difference:  3.29e-10 

Variable L and b:
  Primal difference:  2.08e-06
  JVP difference:  1.08e-08
  Gradient difference:  2.33e-10

No diff:
  Primal difference: 2.2101993124579167e-07
  JVP difference: 0.0
  Gradient difference: 0.0

Stunning!

Primitive one: The general \(A^{-1}b\)

Ok. So this is a very similar problem to the one that we just solved. But, as fate would have it, the solution is going to look quite different. Why? Because we need to compute a Cholesky factorisation.

First things first, though, we are going to need a JAX-traceable way to compute a Cholesky factor. This means that we need25 to tell our sparse_solve function how many non-zeros the sparse Cholesky will have. Why? Well. It has to do with how the function is used.

When sparse_cholesky() is called with concrete inputs26, then it can quite happily work out the sparsity structure of \(L\). But when JAX is preparing to transform the code, eg when it’s building a gradient, it calls sparse_cholesky() using abstract arguments that only share the shape information from the inputs. This is not enough to compute the sparsity structure. We need the indices and indptr arrays.

This means that we need sparse_cholesky() to throw an error if L_nse isn’t passed. This wasn’t implemented well last time, so here it is done properly.

(If you’re wondering about that None argument, it is the identity transform. So if A_indices is a concrete value, ind = A_indices. Otherwise an error is raised.)
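Here is a toy version of the concrete-versus-abstract problem, using only public JAX. The helper `count_nonzeros` is a hypothetical stand-in for the symbolic factorisation: it needs the actual integer values, so it works on concrete arrays but falls over when JAX traces it with abstract arguments. Passing the count in as a static argument, which is the role L_nse plays, gets around that.

```python
from functools import partial

import numpy as np
import jax
import jax.numpy as jnp

def count_nonzeros(indices):
    # Needs real integer values, like a symbolic factorisation would.
    return len(np.unique(np.asarray(indices)))

indices = jnp.array([0, 1, 1, 2])
n_concrete = count_nonzeros(indices)    # fine: the values are concrete

@jax.jit
def traced(indices):
    # Under jit, `indices` is a tracer that only carries shape and dtype.
    return jnp.zeros(count_nonzeros(indices))

try:
    traced(indices)
    failed = False
except TypeError:
    # JAX's tracer-conversion errors are subclasses of TypeError.
    failed = True

@partial(jax.jit, static_argnames="nse")
def traced_ok(indices, *, nse):
    # The size is supplied from outside, so nothing needs the values.
    return jnp.zeros(nse)

out = traced_ok(indices, nse=n_concrete)
```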

sparse_cholesky_p = core.Primitive("sparse_cholesky")

def sparse_cholesky(A_indices, A_indptr, A_x, *, L_nse: int = None):
  """A JAX traceable sparse cholesky decomposition"""
  if L_nse is None:
    err_string = "You need to pass a value to L_nse when doing fancy sparse_cholesky."
    ind = core.concrete_or_error(None, A_indices, err_string)
    ptr = core.concrete_or_error(None, A_indptr, err_string)
    L_ind, _ = _symbolic_factor(ind, ptr)
    L_nse = len(L_ind)
  
  return sparse_cholesky_p.bind(A_indices, A_indptr, A_x, L_nse = L_nse)

The rest of the Cholesky code
@sparse_cholesky_p.def_impl
def sparse_cholesky_impl(A_indices, A_indptr, A_x, *, L_nse):
  """The implementation of the sparse cholesky. This is not JAX traceable."""
  
  L_indices, L_indptr= _symbolic_factor(A_indices, A_indptr)
  if L_nse is not None:
    assert len(L_indices) == L_nse
    
  L_x = _structured_copy(A_indices, A_indptr, A_x, L_indices, L_indptr)
  L_x = _sparse_cholesky_impl(L_indices, L_indptr, L_x)
  return L_indices, L_indptr, L_x

def _symbolic_factor(A_indices, A_indptr):
  # Assumes A_indices and A_indptr index the lower triangle of $A$ ONLY.
  n = len(A_indptr) - 1
  L_sym = [np.array([], dtype=int) for j in range(n)]
  children = [np.array([], dtype=int) for j in range(n)]
  
  for j in range(n):
    L_sym[j] = A_indices[A_indptr[j]:A_indptr[j + 1]]
    for child in children[j]:
      tmp = L_sym[child][L_sym[child] > j]
      L_sym[j] = np.unique(np.append(L_sym[j], tmp))
    if len(L_sym[j]) > 1:
      p = L_sym[j][1]
      children[p] = np.append(children[p], j)
        
  L_indptr = np.zeros(n+1, dtype=int)
  L_indptr[1:] = np.cumsum([len(x) for x in L_sym])
  L_indices = np.concatenate(L_sym)
  
  return L_indices, L_indptr



def _structured_copy(A_indices, A_indptr, A_x, L_indices, L_indptr):
  n = len(A_indptr) - 1
  L_x = np.zeros(len(L_indices))
  
  for j in range(0, n):
    copy_idx = np.nonzero(np.in1d(L_indices[L_indptr[j]:L_indptr[j + 1]],
                                  A_indices[A_indptr[j]:A_indptr[j+1]]))[0]
    L_x[L_indptr[j] + copy_idx] = A_x[A_indptr[j]:A_indptr[j+1]]
  return L_x

def _sparse_cholesky_impl(L_indices, L_indptr, L_x):
  n = len(L_indptr) - 1
  descendant = [[] for j in range(0, n)]
  for j in range(0, n):
    tmp = L_x[L_indptr[j]:L_indptr[j + 1]]
    for bebe in descendant[j]:
      k = bebe[0]
      Ljk= L_x[bebe[1]]
      pad = np.nonzero(                                                       \
          L_indices[L_indptr[k]:L_indptr[k+1]] == L_indices[L_indptr[j]])[0][0]
      update_idx = np.nonzero(np.in1d(                                        \
                    L_indices[L_indptr[j]:L_indptr[j+1]],                     \
                    L_indices[(L_indptr[k] + pad):L_indptr[k+1]]))[0]
      tmp[update_idx] = tmp[update_idx] -                                     \
                        Ljk * L_x[(L_indptr[k] + pad):L_indptr[k + 1]]
            
    diag = np.sqrt(tmp[0])
    L_x[L_indptr[j]] = diag
    L_x[(L_indptr[j] + 1):L_indptr[j + 1]] = tmp[1:] / diag
    for idx in range(L_indptr[j] + 1, L_indptr[j + 1]):
      descendant[L_indices[idx]].append((j, idx))
  return L_x

@sparse_cholesky_p.def_abstract_eval
def sparse_cholesky_abstract_eval(A_indices, A_indptr, A_x, *, L_nse):
  return core.ShapedArray((L_nse,), A_indices.dtype),                   \
         core.ShapedArray(A_indptr.shape, A_indptr.dtype),             \
         core.ShapedArray((L_nse,), A_x.dtype)

Why do we need a new pattern for this very very similar problem?

Ok. So now on to the details. If we try to repeat our previous pattern it would look like this.

def sparse_solve_value_and_jvp(arg_values, arg_tangents, *, L_nse):
  """ 
  Jax-traceable jacobian-vector product implementation for sparse_solve.
  """
  
  A_indices, A_indptr, A_x, b = arg_values
  _, _, A_xt, bt = arg_tangents

  # Needed for shared computation
  L_indices, L_indptr, L_x = sparse_cholesky(A_indices, A_indptr, A_x, L_nse = L_nse)

  # Make the primal
  primal_out = sparse_triangular_solve(L_indices, L_indptr, L_x, b, transpose = False)
  primal_out = sparse_triangular_solve(L_indices, L_indptr, L_x, primal_out, transpose = True)

  if type(A_xt) is not ad.Zero:
    Delta_lower = jsparse.CSC((A_xt, A_indices, A_indptr), shape = (b.shape[0], b.shape[0]))
    # We need to do Delta @ primal_out, but we only have the lower triangle
    rhs = Delta_lower @ primal_out + Delta_lower.transpose() @ primal_out - A_xt[A_indptr[:-1]] * primal_out
    jvp_Ax = sparse_triangular_solve(L_indices, L_indptr, L_x, rhs)
    jvp_Ax = sparse_triangular_solve(L_indices, L_indptr, L_x, jvp_Ax, transpose = True)
  else:
    jvp_Ax = lax.zeros_like_array(primal_out)

  if type(bt) is not ad.Zero:
    jvp_b = sparse_triangular_solve(L_indices, L_indptr, L_x, bt)
    jvp_b = sparse_triangular_solve(L_indices, L_indptr, L_x, jvp_b, transpose = True)
  else:
    jvp_b = lax.zeros_like_array(primal_out)

  return primal_out, jvp_b - jvp_Ax

That’s all well and good. Nothing weird there.

The problem comes when you need to implement the transposition rule. Remembering that \(\bar b = A^{-T}\bar c = A^{-1}\bar c\), you might see the issue: we are going to need the Cholesky factorisation. But we have no way to pass \(L\) to the transpose function.

This means that we would need to compute two Cholesky factorisations per gradient instead of one. As the Cholesky factorisation is our slowest operation, we do not want to do extra ones! We want to compute the Cholesky triangle once and pass it around like a party bottom27. We do not want each of our functions to have to make a deep and meaningful connection with the damn matrix28.

A different solution

So how do we pass around our Cholesky triangle? Well, I do love a good class so my first thought was “fuck it. I’ll make a class and I’ll pass it that way”. But the developers of JAX had a much better idea.

Their idea was to abstract the idea of a linear solve and its gradients. They do this through lax.custom_linear_solve. This is a function that takes all of the bits that you would need to compute \(A^{-1}b\) and all of its derivatives. In particular it takes29:

  • matvec: A function matvec(x) that computes \(Ax\). This might seem a bit weird, but the most common atrocity committed by mathematicians is abstracting30 a matrix to a linear mapping. So we might as well just suck it up.
  • b: The right hand side vector31
  • solve: A function that takes the matvec and a vector so that32 solve(matvec, matvec(x)) == x
  • symmetric: A boolean indicating if \(A\) is symmetric.

The idea (happily copped from the implementation of jax.scipy.linalg.solve) is to wrap our Cholesky decomposition in the solve function. Through the never ending miracle of partial evaluation.

from functools import partial

def sparse_solve(A_indices, A_indptr, A_x, b, *, L_nse = None):
  """
  A JAX-traceable sparse solve. For this moment, only for vector b
  """
  assert b.shape[0] == A_indptr.shape[0] - 1
  assert b.ndim == 1
  
  L_indices, L_indptr, L_x = sparse_cholesky(
    lax.stop_gradient(A_indices), 
    lax.stop_gradient(A_indptr), 
    lax.stop_gradient(A_x), L_nse = L_nse)
  
  def chol_solve(L_indices, L_indptr, L_x, b):
    out = sparse_triangular_solve(L_indices, L_indptr, L_x, b, transpose = False)
    return sparse_triangular_solve(L_indices, L_indptr, L_x, out, transpose = True)
  
  def matmult(A_indices, A_indptr, A_x, b):
    A_lower = jsparse.CSC((A_x, A_indices, A_indptr), shape = (b.shape[0], b.shape[0]))
    return A_lower @ b + A_lower.transpose() @ b - A_x[A_indptr[:-1]] * b

  solver = partial(
    lax.custom_linear_solve,
    lambda x: matmult(A_indices, A_indptr, A_x, x),
    solve = lambda _, x: chol_solve(L_indices, L_indptr, L_x, x),
    symmetric = True)

  return solver(b)

There are three things of note in that implementation.

  1. The calls to lax.stop_gradient(): These tell JAX to not bother computing the gradient of these terms. The relevant parts of the derivatives are computed explicitly by lax.custom_linear_solve in terms of matmult and solve, neither of which needs the explicit derivative of the Cholesky factorisation!

  2. That definition of matmult()33: Look. I don’t know what to tell you. Neither addition nor indexing is implemented for jsparse.CSC objects. So we did it the semi-manual way. (I am thankful that matrix-vector multiplication is available)

  3. The definition of solver(): Partial evaluation is a wonderful wonderful thing. functools.partial() transforms lax.custom_linear_solve() from a function that takes 3 arguments (and some keywords), into a function solver() that takes one34 argument35 (b, the only positional argument of lax.custom_linear_solve() that isn’t specified).
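If the sparse plumbing is obscuring the pattern, here is a rough dense analogue. It is a sketch under assumptions, not the code from this post: dense_spd_solve and make_A are hypothetical names, the matrix is a made-up 3 x 3 SPD example, and the factorisation comes from jax.scipy.linalg.cho_factor rather than our sparse Cholesky. The shape is the same though: factorise once behind lax.stop_gradient(), then hand a matvec and a solve to lax.custom_linear_solve.

```python
import jax
import jax.numpy as jnp
from jax import lax
from jax.scipy.linalg import cho_factor, cho_solve

def dense_spd_solve(A, b):
    # Factorise once; stop_gradient keeps autodiff away from the factorisation.
    chol = cho_factor(lax.stop_gradient(A), lower=True)
    matvec = lambda x: A @ x                  # derivatives flow through this
    solve = lambda _, x: cho_solve(chol, x)   # reused for primal and tangents
    return lax.custom_linear_solve(matvec, b, solve, symmetric=True)

def make_A(theta):
    # Hypothetical SPD matrix that depends on two parameters.
    base = jnp.array([[4.0, 1.0, 0.0],
                      [1.0, 3.0, 1.0],
                      [0.0, 1.0, 2.0]])
    return theta[0] * base + theta[1] * jnp.eye(3)

b = jnp.ones(3)
theta = jnp.array([2.0, 3.0])

loss_custom = lambda th: jnp.mean(dense_spd_solve(make_A(th), b))
loss_ref = lambda th: jnp.mean(jnp.linalg.solve(make_A(th), b))

grad_custom = jax.grad(loss_custom)(theta)
grad_ref = jax.grad(loss_ref)(theta)
print(jnp.linalg.norm(grad_custom - grad_ref))
```

The gradient never touches the derivative of the factorisation. It only needs to differentiate matvec, which is just a matrix-vector product.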

Does it work?

def f(theta):
  Ax_theta = jnp.array(theta[0] * A_x)
  Ax_theta = Ax_theta.at[A_indptr[:-1]].add(theta[1])
  b = jnp.ones(100)
  return sparse_solve(A_indices, A_indptr, Ax_theta, b)

def f_jax(theta):
  Ax_theta = jnp.array(theta[0] * A.todense())
  Ax_theta = Ax_theta.at[np.arange(100),np.arange(100)].add(theta[1])
  b = jnp.ones(100)
  return jsp.linalg.solve(Ax_theta, b)

def g(theta):
  Ax_theta = jnp.array(A_x)
  b = jnp.ones(100)
  b = b.at[0].set(theta[0])
  b = b.at[51].set(theta[1])
  return sparse_solve(A_indices, A_indptr, Ax_theta, b)

def g_jax(theta):
  Ax_theta = jnp.array(A.todense())
  b = jnp.ones(100)
  b = b.at[0].set(theta[0])
  b = b.at[51].set(theta[1])
  return jsp.linalg.solve(Ax_theta, b)

def h(theta):
  Ax_theta = jnp.array(A_x)
  Ax_theta = Ax_theta.at[A_indptr[:-1]].add(theta[0])
  b = jnp.ones(100)
  b = b.at[51].set(theta[1])
  return sparse_solve(A_indices, A_indptr, Ax_theta, b)

def h_jax(theta):
  Ax_theta = jnp.array(A.todense())
  Ax_theta = Ax_theta.at[np.arange(100),np.arange(100)].add(theta[0])
  b = jnp.ones(100)
  b = b.at[51].set(theta[1])
  return jsp.linalg.solve(Ax_theta, b)

primal1, jvp1 = jvp(f, (jnp.array([2., 3.]),), (jnp.array([1., 2.]),))
primal2, jvp2 = jvp(f_jax, (jnp.array([2., 3.]),), (jnp.array([1., 2.]),))
grad1 = grad(lambda x: jnp.mean(f(x)))(jnp.array([2., 3.]))
grad2 = grad(lambda x: jnp.mean(f_jax(x)))(jnp.array([2., 3.]))


primal3, jvp3 = jvp(g, (jnp.array([-142., 342.]),), (jnp.array([1., 2.]),))
primal4, jvp4 = jvp(g_jax, (jnp.array([-142., 342.]),), (jnp.array([1., 2.]),))
grad3 = grad(lambda x: jnp.mean(g(x)))(jnp.array([-142., 342.]))
grad4 = grad(lambda x: jnp.mean(g_jax(x)))(jnp.array([-142., 342.]))

primal5, jvp5 = jvp(h, (jnp.array([2., 342.]),), (jnp.array([1., 2.]),))
primal6, jvp6 = jvp(h_jax, (jnp.array([2., 342.]),), (jnp.array([1., 2.]),))
grad5 = grad(lambda x: jnp.mean(h(x)))(jnp.array([2., 342.]))
grad6 = grad(lambda x: jnp.mean(h_jax(x)))(jnp.array([2., 342.]))

print(f"""
Check the plumbing!
Variable A:
  Primal difference: {np.linalg.norm(primal1 - primal2): .2e}
  JVP difference: {np.linalg.norm(jvp1 - jvp2): .2e}
  Gradient difference: {np.linalg.norm(grad1 - grad2): .2e}
  
Variable b:
  Primal difference: {np.linalg.norm(primal3 - primal4): .2e}
  JVP difference: {np.linalg.norm(jvp3 - jvp4): .2e}
  Gradient difference: {np.linalg.norm(grad3 - grad4): .2e} 
    
Variable A and b:
  Primal difference: {np.linalg.norm(primal5 - primal6): .2e}
  JVP difference: {np.linalg.norm(jvp5 - jvp6): .2e}
  Gradient difference: {np.linalg.norm(grad5 - grad6): .2e}
  """)

Check the plumbing!
Variable A:
  Primal difference:  1.98e-07
  JVP difference:  1.43e-07
  Gradient difference:  0.00e+00
  
Variable b:
  Primal difference:  4.56e-06
  JVP difference:  6.52e-08
  Gradient difference:  9.31e-10 
    
Variable A and b:
  Primal difference:  8.10e-06
  JVP difference:  1.83e-06
  Gradient difference:  1.82e-12
  

Yes.

Why is this better than just differentiating through the Cholesky factorisation?

The other option for making this work would’ve been to implement the Cholesky factorisation as a primitive (~which we are about to do!~ which we will do another day) and then write the sparse solver directly as a pure JAX function.

def sparse_solve_direct(A_indices, A_indptr, A_x, b, *, L_nse = None):
  L_indices, L_indptr, L_x = sparse_cholesky(A_indices, A_indptr, A_x)
  out = sparse_triangular_solve(L_indices, L_indptr, L_x, b)
  return sparse_triangular_solve(L_indices, L_indptr, L_x, out, transpose = True)

This function is JAX-traceable36 and, therefore, we could compute the gradient of it directly. It turns out that this is going to be a bad idea.

Why? Because the derivative of sparse_cholesky, which we would have to chain together with the derivatives from the solver, is pretty complicated. Basically, this means that we’d have to do a lot more work37 than we do if we just implement the symbolic formula for the derivatives.

Primitive three: The dreaded log determinant

Ok, so now we get to the good one. The log-determinant of \(A\). The first thing that we need to do is wrench out a derivative. This is not as easy as it was for the linear solve. So what follows is a modification for sparse matrices from Appendix A of Boyd’s convex optimisation book.

It’s pretty easy to convince yourself that \[\begin{align*} \log(|A + \Delta|) &= \log\left( \left|A^{1/2}(I + A^{-1/2}\Delta A^{-1/2})A^{1/2}\right|\right) \\ &= \log(|A|) + \log\left( \left|I + A^{-1/2}\Delta A^{-1/2}\right|\right). \end{align*}\]

It is harder to convince yourself how this could possibly be a useful fact.

If we write \(\lambda_i\), \(i = 1, \ldots, n\) as the eigenvalues of \(A^{-1/2}\Delta A^{-1/2}\), then we have \[ \log(|A + \Delta |) = \log(|A|) + \sum_{i=1}^n \log( 1 + \lambda_i). \] Remembering that \(\Delta\) is very small, it follows that \(A^{-1/2}\Delta A^{-1/2}\) will also be small. That translates to the eigenvalues of \(A^{-1/2}\Delta A^{-1/2}\) all being small. Therefore, we can use the approximation \(\log(1 + \lambda_i) = \lambda_i + \mathcal{O}(\lambda_i^2)\).

This means that38 \[\begin{align*} \log(|A + \Delta |) &= \log(|A|) + \sum_{i=1}^n \lambda_i + \mathcal{O}\left(\|\Delta\|^2\right) \\ &=\log(|A|) + \operatorname{tr}\left(A^{-1/2} \Delta A^{-1/2} \right) + \mathcal{O}\left(\|\Delta\|^2\right) \\ &= \log(|A|) + \operatorname{tr}\left(A^{-1} \Delta \right) + \mathcal{O}\left(\|\Delta\|^2\right), \end{align*}\] which follows from the cyclic property of the trace.
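A quick numerical sanity check of that expansion, as a dense NumPy sketch with made-up random matrices: if the first-order term really is \(\operatorname{tr}(A^{-1}\Delta)\), then the remainder should shrink like \(t^2\) when we perturb by \(t\Delta\).

```python
import numpy as np

rng = np.random.default_rng(1)
n = 5
B = rng.standard_normal((n, n))
A = B @ B.T + n * np.eye(n)      # SPD
D = rng.standard_normal((n, n))
Delta = D + D.T                  # symmetric direction

def logdet(M):
    # slogdet avoids overflow in the determinant itself
    sign, val = np.linalg.slogdet(M)
    return val

first_order = np.trace(np.linalg.solve(A, Delta))   # tr(A^{-1} Delta)

errors = []
for t in [1e-2, 1e-3, 1e-4]:
    remainder = logdet(A + t * Delta) - logdet(A) - t * first_order
    errors.append(abs(remainder))
    print(f"t = {t:.0e}, remainder = {abs(remainder):.2e}")
```

Each time \(t\) drops by a factor of 10, the remainder should drop by roughly a factor of 100.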

If we recall the formula from the last section defining the Jacobian-vector product, in our context \(m = 1\), \(x\) is the vector of non-zero entries of the lower triangle of \(A\) stacked by column, and \(\delta\) is the vector of non-zero entries of the lower triangle of \(\Delta\). That means the Jacobian-vector product is \[ J(x)\delta = \operatorname{tr}\left(A^{-1} \Delta \right) = \sum_{i=1}^n\sum_{j=1}^n[A^{-1}]_{ij} \Delta_{ij}. \]

Remembering that \(\Delta\) is sparse with the same sparsity pattern as \(A\), we see that the Jacobian-vector product requires us to know the values of \(A^{-1}\) that correspond to non-zero elements of \(A\). That’s good news because we will see that these entries are relatively cheap and easy to compute. Whereas the full inverse is dense and very expensive to compute.

But before we get to that, I need to point out a trap for young players39. Lest your implementations go down faster than me when someone asks politely.

The problem comes from how we store our matrix. A mathematician would suggest that it’s our representation. A physicist40 would shit on about being coordinate free with such passion that he41 will keep going even after you quietly leave the room.

The problem is that we only store the non-zero entries of the lower-triangular part of \(A\). This means that when we compute the Jacobian-vector product, we need to be careful to properly compute the matrix-vector product.

Let A_indices and A_indptr define the sparsity structure of \(A\) (and \(\Delta\)). Then if \(A_x\) is our input and \(v\) is our vector, we need to do the following steps to compute the Jacobian-vector product:

  1. Compute Ainv_x (aka the non-zero elements of \(A^{-1}\) that correspond to the sparsity pattern of \(A\))
  2. Compute the matrix vector product as
jvp = 2 * sum(Ainv_x * v) - sum(Ainv_x[A_indptr[:-1]] * v[A_indptr[:-1]])

Why does it look like that? Well we need to add the contribution from the upper triangle as well as the lower triangle. And one way to do that is to just double the sum and then subtract off the diagonal terms that we’ve counted twice.

(I’m making a pretty big assumption here, which is fine in our context, that \(A\) has a non-zero diagonal. If that doesn’t hold, it’s just a change of the indexing in the second term to just pull out the diagonal terms.)
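Here is a small dense NumPy sketch of the double-then-subtract bookkeeping. The storage is a made-up stand-in for our CSC arrays: a fully dense lower triangle stacked column by column, so the first stored entry of each column is the diagonal, and the names rows, cols and indptr are only for illustration.

```python
import numpy as np

rng = np.random.default_rng(2)
n = 5
B = rng.standard_normal((n, n))
A = B @ B.T + n * np.eye(n)
D = rng.standard_normal((n, n))
Delta = D + D.T
Ainv = np.linalg.inv(A)

# Lower triangle (diagonal included) stacked by column, CSC style.
cols = np.concatenate([np.full(n - j, j) for j in range(n)])
rows = np.concatenate([np.arange(j, n) for j in range(n)])
indptr = np.concatenate([[0], np.cumsum([n - j for j in range(n)])])

Ainv_x = Ainv[rows, cols]   # entries of A^{-1} on the stored pattern
v = Delta[rows, cols]       # stored entries of the tangent
diag = indptr[:-1]          # position of each column's diagonal entry

naive = np.sum(Ainv_x * v)  # forgets the upper triangle
jvp = 2 * np.sum(Ainv_x * v) - np.sum(Ainv_x[diag] * v[diag])
truth = np.trace(Ainv @ Delta)

print(f"naive: {naive - truth:.2e}, corrected: {jvp - truth:.2e}")
```

The naive sum is what you get if you fall into the trap. The corrected one should match the dense trace to rounding error.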

Using similar reasoning, we can compute the Jacobian as \[ [J_f(x)]_{i1} = \begin{cases} \operatorname{partial-inverse}(x)_i, \qquad & x_i \text{ is a diagonal element of }A \\ 2\operatorname{partial-inverse}(x)_i, \qquad & \text{otherwise}, \end{cases} \] where \(\operatorname{partial-inverse}(x)\) is the vector that stacks the columns of the elements of \(A^{-1}\) that correspond to the non-zero elements of \(A\). (Yikes!)

Computing the partial inverse

So now we need to actually work out how to compute this partial inverse of a symmetric positive definite matrix \(A\). To do this, we are going to steal a technique that goes back to Takahashi, Fagan, and Chen42 in 1973. (For this presentation, I’m basically pillaging Håvard Rue and Sara Martino’s 2007 paper.)

Their idea was to write \(A = VDV^T\), where \(V\) is a lower-triangular matrix with ones on the diagonal and \(D\) is diagonal. This links up with our usual Cholesky factorisation through the identity \(L = VD^{1/2}\). It follows that if \(S = A^{-1}\), then \(VDV^TS = I\). Then, we make some magic manipulations43. \[\begin{align*} V^TS &= D^{-1}V^{-1} \\ S + V^TS &= S + D^{-1}V^{-1} \\ S &= D^{-1}V^{-1} + (I - V^T)S. \end{align*}\]

Once again, this does not look super-useful. The trick is to notice 2 things.

  1. Because \(V\) is lower triangular, \(V^{-1}\) is also lower triangular and the diagonal elements of \(V^{-1}\) are the inverses of the diagonal elements of \(V\) (aka they are all 1). Therefore, \(D^{-1}V^{-1}\) is a lower triangular matrix with a diagonal given by the diagonal of \(D^{-1}\).

  2. \(I - V^T\) is an upper triangular matrix and \([I - V^T]_{nn} = 0\).

These two things together lead to the somewhat unexpected situation where the upper triangle of \(S = D^{-1}V^{-1} + (I- V^T)S\) defines a set of recursions for the upper triangle of \(S\). (And, therefore, all of \(S\) because \(S\) is symmetric!) These are sometimes referred to as the Takahashi recursions.
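None of this needs to be taken on faith. The following dense NumPy sketch (random test matrix, nothing sparse) builds \(V\) and \(D\) from a Cholesky factor and checks the identity and the two structural facts it relies on.

```python
import numpy as np

rng = np.random.default_rng(3)
n = 5
B = rng.standard_normal((n, n))
A = B @ B.T + n * np.eye(n)

L = np.linalg.cholesky(A)   # A = L L^T
d_half = np.diag(L)         # diagonal of D^{1/2}
V = L / d_half              # V = L D^{-1/2}: column j scaled by 1 / L_jj
D = np.diag(d_half ** 2)
S = np.linalg.inv(A)

lower_part = np.linalg.inv(D) @ np.linalg.inv(V)   # D^{-1} V^{-1}
upper_part = np.eye(n) - V.T                       # I - V^T

# A = V D V^T and the fixed-point identity for S
print(np.allclose(V @ D @ V.T, A))
print(np.allclose(S, lower_part + upper_part @ S))
# D^{-1} V^{-1} has nothing above the diagonal, and its diagonal is 1 / D_ii
print(np.allclose(np.triu(lower_part, k=1), 0.0))
print(np.allclose(np.diag(lower_part), 1.0 / d_half ** 2))
# I - V^T has a zero diagonal, so row i of S only needs rows below it
print(np.allclose(np.diag(upper_part), 0.0))
```

Every check should come back True, which is exactly why the upper triangle of the identity can be solved from the bottom right corner upwards.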

But we don’t want the whole upper triangle of \(S\), we just want the ones that correspond to the non-zero elements of \(A\). Unfortunately, the set of recursions is not, in general, solvable using only that subset of \(S\). But we are in luck: they are solvable using the elements of \(S\) that correspond to the non-zeros of \(L + L^T\), which, as we know from a few posts ago, is a superset of the non-zero elements of \(A\)!

From this, we get the recursions running from \(i = n, \ldots, 1\), \(j = n, \ldots, i\) (the order is important!) such that \(L_{ji} \neq 0\) \[ S_{ji} = \begin{cases} \frac{1}{L_{ii}^2} - \frac{1}{L_{ii}}\sum_{k=i+1}^{n} L_{ki} S_{kj} \qquad& \text{if } i=j, \\ - \frac{1}{L_{ii}}\sum_{k=i+1}^{n} L_{ki} S_{kj} & \text{otherwise}. \end{cases} \]

If you recall our discussion way back when about the way the non-zero structure of the \(j\)th column of \(L\) relates to the non-zero structure of the \(i\)th column for \(j \geq i\), it’s clear that we have computed enough44 of \(S\) at every step to complete the recursions.
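To see the recursions do their thing without any sparse bookkeeping, here is a dense NumPy sketch that pretends every entry of \(L\) is non-zero. In that case the recursions have to reproduce the whole inverse. The function name takahashi_dense is made up, and this is a sanity check rather than anything you would want to run at scale.

```python
import numpy as np

def takahashi_dense(L):
    """Inverse of A = L L^T via the recursions, treating every entry as non-zero."""
    n = L.shape[0]
    S = np.zeros((n, n))
    for i in range(n - 1, -1, -1):           # i = n, ..., 1
        for j in range(n - 1, i - 1, -1):    # j = n, ..., i
            # -(1 / L_ii) * sum_{k > i} L_ki S_kj; every S_kj needed here
            # was filled in by an earlier step of the loops.
            acc = -(L[i + 1:, i] @ S[i + 1:, j]) / L[i, i]
            if i == j:
                acc += 1.0 / L[i, i] ** 2
            S[j, i] = S[i, j] = acc
    return S

rng = np.random.default_rng(4)
n = 6
B = rng.standard_normal((n, n))
A = B @ B.T + n * np.eye(n)

S = takahashi_dense(np.linalg.cholesky(A))
print(np.linalg.norm(S - np.linalg.inv(A)))
```

The loop order matters: swap either loop to run forwards and the sums will read entries of S that are still zero.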

Now we just need to Python it. (And thanks to Finn Lindgren who helped me understand how to implement this, which he may or may not remember because it happened about five years ago.)

Actually, we need this to be JAX-traceable, so we are going to implement a very basic primitive. In particular, we don’t need to implement a derivative or anything like that, just an abstract evaluation and an implementation.

sparse_partial_inverse_p = core.Primitive("sparse_partial_inverse")

def sparse_partial_inverse(L_indices, L_indptr, L_x, out_indices, out_indptr):
  """
  Computes the elements (out_indices, out_indptr) of the inverse of a sparse matrix (A_indices, A_indptr, A_x)
   with Cholesky factor (L_indices, L_indptr, L_x). (out_indices, out_indptr) is assumed to be either
   the sparsity pattern of A or a subset of it in lower triangular form. 
  """
  return sparse_partial_inverse_p.bind(L_indices, L_indptr, L_x, out_indices, out_indptr)

@sparse_partial_inverse_p.def_abstract_eval
def sparse_partial_inverse_abstract_eval(L_indices, L_indptr, L_x, out_indices, out_indptr):
  return abstract_arrays.ShapedArray(out_indices.shape, L_x.dtype)

@sparse_partial_inverse_p.def_impl
def sparse_partial_inverse_impl(L_indices, L_indptr, L_x, out_indices, out_indptr):
  n = len(L_indptr) - 1
  Linv = sparse.dok_array((n,n), dtype = L_x.dtype)
  counter = len(L_x) - 1
  for col in range(n-1, -1, -1):
    for row in L_indices[L_indptr[col]:L_indptr[col+1]][::-1]:
      if row != col:
        Linv[row, col] = Linv[col, row] = 0.0
      else:
        Linv[row, col] = 1 / L_x[L_indptr[col]]**2
      L_col  = L_x[L_indptr[col]+1:L_indptr[col+1]] / L_x[L_indptr[col]]
 
      for k, L_kcol in zip(L_indices[L_indptr[col]+1:L_indptr[col+1]], L_col):
         Linv[col,row] = Linv[row,col] =  Linv[row, col] -  L_kcol * Linv[k, row]
        
  Linv_x = sparse.tril(Linv, format = "csc").data
  if len(out_indices) == len(L_indices):
    return Linv_x

  out_x = np.zeros(len(out_indices))
  for col in range(n):
    ind = np.nonzero(np.in1d(L_indices[L_indptr[col]:L_indptr[col+1]],
      out_indices[out_indptr[col]:out_indptr[col+1]]))[0]
    out_x[out_indptr[col]:out_indptr[col+1]] = Linv_x[L_indptr[col] + ind]
  return out_x

The implementation makes use of the45 dictionary of keys representation of a sparse matrix from scipy.sparse. This is an efficient storage scheme when you need to modify the sparsity structure (as we are doing here) or do a lot of indexing. It would definitely be possible to implement this directly on the CSC data structure, but it gets a little bit tricky to access the elements of Linv that are above the diagonal. The resulting code is honestly a mess and there’s lots of non-local memory access anyway, so I implemented it this way.

But let’s be honest: this thing is crying out for a proper symmetric matrix class with sensible reverse iterators. But hey. Python.

The second chunk of the code is just the opposite of our _structured_copy() function. It takes a matrix with the sparsity pattern of \(L\) and returns one with the sparsity pattern of out (which is assumed to be a subset, and is usually the sparsity pattern of \(A\) or a diagonal matrix).
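A toy version of that second chunk, on a made-up 3 x 3 pattern where \(L\) has one fill-in entry, (2, 0), that the output pattern does not have. All of the arrays are invented for illustration, and I have used np.isin, which is the newer spelling of np.in1d.

```python
import numpy as np

# Hypothetical CSC pattern of L: column 0 has rows 0, 1, 2; column 1 has
# rows 1, 2; column 2 has row 2. The values are just labels.
L_indptr = np.array([0, 3, 5, 6])
L_indices = np.array([0, 1, 2, 1, 2, 2])
L_vals = np.array([10., 11., 12., 13., 14., 15.])

# Output pattern: the same, minus the fill-in entry (2, 0).
out_indptr = np.array([0, 2, 4, 5])
out_indices = np.array([0, 1, 1, 2, 2])

out_vals = np.zeros(len(out_indices))
for col in range(len(L_indptr) - 1):
    L_rows = L_indices[L_indptr[col]:L_indptr[col + 1]]
    out_rows = out_indices[out_indptr[col]:out_indptr[col + 1]]
    # positions within this column of L that also appear in the output pattern
    keep = np.nonzero(np.isin(L_rows, out_rows))[0]
    out_vals[out_indptr[col]:out_indptr[col + 1]] = L_vals[L_indptr[col] + keep]

print(out_vals)  # [10. 11. 13. 14. 15.]
```

The value 12, which lives at the fill-in position, is the only one dropped. This relies on the row indices within each column being sorted in both patterns.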

Let’s check that it works.

A_indices, A_indptr, A_x, A = make_matrix(15)
n = len(A_indptr) - 1


L_indices, L_indptr, L_x = sparse_cholesky(A_indices, A_indptr, A_x)

a_inv_L = sparse_partial_inverse(L_indices, L_indptr, L_x, L_indices, L_indptr)

col_counts_L = [L_indptr[i+1] - L_indptr[i] for i in range(n)]
cols_L = np.repeat(range(n), col_counts_L)

true_inv = np.linalg.inv(A.todense())
truth_L = true_inv[L_indices, cols_L]

a_inv_A = sparse_partial_inverse(L_indices, L_indptr, L_x, A_indices, A_indptr)
col_counts_A = [A_indptr[i+1] - A_indptr[i] for i in range(n)]
cols_A = np.repeat(range(n), col_counts_A)
truth_A = true_inv[A_indices, cols_A]

print(f"""
Error in partial inverse (all of L): {np.linalg.norm(a_inv_L - truth_L): .2e}
Error in partial inverse (all of A): {np.linalg.norm(a_inv_A - truth_A): .2e}
""")

Error in partial inverse (all of L):  1.57e-15
Error in partial inverse (all of A):  1.53e-15

Putting the log-determinant together

All of our bits are in place, so now all we need is to implement the primitive for the log-determinant. One nice thing here is that we don’t need to implement a transposition rule as the function is not structurally linear in any of its arguments. At this point we take our small wins where we can get them.

There isn’t anything particularly interesting in the implementation. But do note that the trace has been implemented in a way that’s aware that we’re only storing the bottom triangle of \(A\).
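Before the real thing, a dense NumPy sanity check of the one identity the primal computation leans on: if \(A = LL^T\) then \(\log(|A|) = 2\sum_i \log(L_{ii})\). The random test matrix is made up, and np.linalg.slogdet is just the reference.

```python
import numpy as np

rng = np.random.default_rng(4)
n = 6
B = rng.standard_normal((n, n))
A = B @ B.T + n * np.eye(n)   # SPD, so the Cholesky factor exists

L = np.linalg.cholesky(A)
logdet_chol = 2.0 * np.sum(np.log(np.diag(L)))

sign, logdet_ref = np.linalg.slogdet(A)
print(abs(logdet_chol - logdet_ref))
```

Working on the log scale the whole way is the point. The determinant itself overflows long before its logarithm gets interesting.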

sparse_log_det_p = core.Primitive("sparse_log_det")

def sparse_log_det(A_indices, A_indptr, A_x):
  return sparse_log_det_p.bind(A_indices, A_indptr, A_x)

@sparse_log_det_p.def_impl
def sparse_log_det_impl(A_indices, A_indptr, A_x):
  L_indices, L_indptr, L_x = sparse_cholesky(A_indices, A_indptr, A_x)
  return 2.0 * jnp.sum(jnp.log(L_x[L_indptr[:-1]]))

@sparse_log_det_p.def_abstract_eval
def sparse_log_det_abstract_eval(A_indices, A_indptr, A_x):
  return abstract_arrays.ShapedArray((1,), A_x.dtype)

def sparse_log_det_value_and_jvp(arg_values, arg_tangent):
  A_indices, A_indptr, A_x = arg_values
  _, _, A_xt = arg_tangent
  L_indices, L_indptr, L_x = sparse_cholesky(A_indices, A_indptr, A_x)
  value = 2.0 * jnp.sum(jnp.log(L_x[L_indptr[:-1]]))
  Ainv_x = sparse_partial_inverse(L_indices, L_indptr, L_x, A_indices, A_indptr)
  jvp = 2.0 * sum(Ainv_x * A_xt) - sum(Ainv_x[A_indptr[:-1]] * A_xt[A_indptr[:-1]])
  return value, jvp

ad.primitive_jvps[sparse_log_det_p] = sparse_log_det_value_and_jvp

Finally, we can test it out.

ld_true = np.log(np.linalg.det(A.todense())) #np.sum(np.log(lu.U.diagonal()))
print(f"Error in log-determinant = {ld_true - sparse_log_det(A_indices, A_indptr, A_x): .2e}")

def f(theta):
  Ax_theta = jnp.array(theta[0] * A_x) / n
  Ax_theta = Ax_theta.at[A_indptr[:-1]].add(theta[1])
  return sparse_log_det(A_indices, A_indptr, Ax_theta)

def f_jax(theta):
  Ax_theta = jnp.array(theta[0] * A.todense()) / n 
  Ax_theta = Ax_theta.at[np.arange(n),np.arange(n)].add(theta[1])
  L = jnp.linalg.cholesky(Ax_theta)
  return 2.0*jnp.sum(jnp.log(jnp.diag(L)))

primal1, jvp1 = jvp(f, (jnp.array([2., 3.]),), (jnp.array([1., 2.]),))
primal2, jvp2 = jvp(f_jax, (jnp.array([2., 3.]),), (jnp.array([1., 2.]),))

eps = 1e-4
jvp_fd = (f(jnp.array([2.,3.]) + eps * jnp.array([1., 2.]) ) - f(jnp.array([2.,3.]))) / eps

grad1 = grad(f)(jnp.array([2., 3.]))
grad2 = grad(f_jax)(jnp.array([2., 3.]))

print(f"""
Check the Derivatives!
Variable A:
  Primal difference: {np.linalg.norm(primal1 - primal2)}
  JVP difference: {np.linalg.norm(jvp1 - jvp2)}
  JVP difference (FD): {np.linalg.norm(jvp1 - jvp_fd)}
  Gradient difference: {np.linalg.norm(grad1 - grad2)}
""")

Error in log-determinant =  0.00e+00

Check the Derivatives!
Variable A:
  Primal difference: 0.0
  JVP difference: 0.000885009765625
  JVP difference (FD): 0.221893310546875
  Gradient difference: 1.526623782410752e-05

I’m not going to lie, I am not happy with that JVP difference. I was somewhat concerned that there was a bug somewhere in my code. I did a little bit of exploring and the error got larger as the problem got larger. It also depended a little bit more than I was comfortable on how I had implemented46 the baseline dense version.

That second fact suggested to me that it might be a floating point problem. By default, JAX uses single precision (32-bit) floating point. Most modern systems that don’t try and run on GPUs use double precision (64-bit) floating point. So I tried it with double precision and lo and behold, the problem disappears.

Matrix factorisations are bloody hard in single precision.
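To get a feel for the size of the effect without JAX in the loop, here is a rough NumPy-only sketch. It is not a reproduction of the numbers in this post: the matrix is a made-up dense SPD matrix, the direction is the identity (so the exact directional derivative is \(\operatorname{tr}(A^{-1})\)), and fd_error is a hypothetical helper. It just compares a forward finite difference of the Cholesky-based log-determinant in the two precisions, using the same step sizes as the tests here.

```python
import numpy as np

def logdet_chol(A):
    # log|A| = 2 * sum(log(diag(L))) for A = L L^T
    L = np.linalg.cholesky(A)
    return 2.0 * np.sum(np.log(np.diag(L)))

def fd_error(dtype, eps):
    rng = np.random.default_rng(5)
    n = 50
    B = rng.standard_normal((n, n))
    A = (B @ B.T / n + np.eye(n)).astype(dtype)
    Delta = np.eye(n, dtype=dtype)
    # Reference value tr(A^{-1} Delta) = tr(A^{-1}), computed in double precision
    exact = np.trace(np.linalg.inv(A.astype(np.float64)))
    step = dtype(eps)
    fd = (logdet_chol(A + step * Delta) - logdet_chol(A)) / step
    return abs(float(fd) - exact)

err32 = fd_error(np.float32, 1e-4)
err64 = fd_error(np.float64, 1e-7)
print(f"float32: {err32:.2e}, float64: {err64:.2e}")
```

In single precision the rounding error in the two log-determinants gets divided by a small step, so there is only so much accuracy a finite difference check can promise.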

from jax.config import config
config.update("jax_enable_x64", True)

ld_true = np.log(np.linalg.det(A.todense())) #np.sum(np.log(lu.U.diagonal()))
print(f"Error in log-determinant = {ld_true - sparse_log_det(A_indices, A_indptr, A_x): .2e}")

def f(theta):
  Ax_theta = jnp.array(theta[0] * A_x, dtype = jnp.float64) / n
  Ax_theta = Ax_theta.at[A_indptr[:-1]].add(theta[1])
  return sparse_log_det(A_indices, A_indptr, Ax_theta)

def f_jax(theta):
  Ax_theta = jnp.array(theta[0] * A.todense(), dtype = jnp.float64) / n 
  Ax_theta = Ax_theta.at[np.arange(n),np.arange(n)].add(theta[1])
  L = jnp.linalg.cholesky(Ax_theta)
  return 2.0*jnp.sum(jnp.log(jnp.diag(L)))

primal1, jvp1 = jvp(f, (jnp.array([2., 3.], dtype = jnp.float64),), (jnp.array([1., 2.], dtype = jnp.float64),))
primal2, jvp2 = jvp(f_jax, (jnp.array([2., 3.], dtype = jnp.float64),), (jnp.array([1., 2.], dtype = jnp.float64),))

eps = 1e-7
jvp_fd = (f(jnp.array([2.,3.], dtype = jnp.float64) + eps * jnp.array([1., 2.], dtype = jnp.float64) ) - f(jnp.array([2.,3.], dtype = jnp.float64))) / eps

grad1 = grad(f)(jnp.array([2., 3.], dtype = jnp.float64))
grad2 = grad(f_jax)(jnp.array([2., 3.], dtype = jnp.float64))

print(f"""
Check the Derivatives!
Variable A:
  Primal difference: {np.linalg.norm(primal1 - primal2)}
  JVP difference: {np.linalg.norm(jvp1 - jvp2)}
  JVP difference (FD): {np.linalg.norm(jvp1 - jvp_fd)}
  Gradient difference: {np.linalg.norm(grad1 - grad2)}
""")

Error in log-determinant =  0.00e+00

Check the Derivatives!
Variable A:
  Primal difference: 0.0
  JVP difference: 8.526512829121202e-13
  JVP difference (FD): 4.171707900013644e-06
  Gradient difference: 8.881784197001252e-16

Much better!

Wrapping up

And that is where we will leave it for today. Next up, I’m probably going to need to do the autodiff for the Cholesky factorisation. It’s not hard, but it is tedious47 and this post is already very long.

After that we need a few more things:

  1. Compilation rules for all of these things. For the most part, we can just wrap the relevant parts of Eigen. The only non-trivial code would be the partial inverse. That will allow us to JIT shit.

  2. We need to beef up the sparse matrix class a little. In particular, we are going to need addition and scalar multiplication at the very minimum to make this useful.

  3. Work out how Aesara works so we can try to prototype a PyMC model.

That will be a lot more blog posts. But I’m having fun. So why the hell not.

Footnotes

  1. I am sorry Cholesky factorisation, this blog is already too long and there is simply too much code I need to make nicer to even start on that journey. So it will happen in a later blog.↩︎

  2. Which I have spent zero effort making pretty or taking to any level above scratch code↩︎

  3. Like making it clear how this works for a sparse matrix compared to a general one↩︎

  4. To the best of my knowledge, for example, we don’t know how to differentiate with respect to the order parameter \(\nu\) in the modified Bessel function of the second kind \(K_\nu(x)\). This is important in spatial statistics (and general GP stuff).↩︎

  5. You may need to convince yourself that this is possible. But it is. The cone of SPD matrices is very nice.↩︎

  6. Don’t despair if you don’t recognise the third line, it’s the Neumann series, which gives an approximation to \((I + B)^{-1}\) whenever \(\|B\| \ll 1\).↩︎

  7. I recognise that I’ve not explained why everything needs to be JAX-traceable. Basically it’s because JAX does clever transformations to the Jacobian-vector product code to produce things like gradients. And the only way that can happen is if the JVP code can take abstract JAX types. So we need to make it traceable because we really want to have gradients!↩︎

  8. Why not now, Daniel? Why not now? Well mostly because I might need to do some tweaking down the line, so I am not messing around until I am done.↩︎

  9. This is the primary difference between implementing forward mode and reverse mode: there is only one output here. When we move onto reverse mode, we will output a tuple of Jacobian-transpose-vector products, one for each input. You can see the structure of that reflected in the transposition rule we are going to write later.↩︎

  10. Some things: Firstly your function needs to have the correct signature for this to work. Secondly, you could also use ad.defjvp() if you didn’t need to use the primal value to define the tangent (recall one of our tangents is \(A^{-1}\Delta c\), where \(c = A^{-1}b\) is the primal value).↩︎

  11. This is because it is the efficient way of computing a gradient. Forward-mode autodiff chains together Jacobian-vector products in such a way that a single sweep of the entire function computes a single directional derivative. Reverse-mode autodiff chains together Jacobian-transpose-vector products (aka vector-Jacobian products) in such a way that a single sweep produces an entire gradient. (This happens at the cost of quite a bit of storage.) Depending on what you are trying to do, you usually want one or the other (or sometimes a clever combination of both).↩︎

  12. or gradients or some sort of thing.↩︎

  13. to be honest, in Stan we sometimes just don’t dick around with the forward-mode autodiff, because gradients are our bread and butter.↩︎

  14. I mean, love you programming language people. But fuck me this paper could’ve been written in Babylonic cuneiform for all I understood it.↩︎

  15. That is, if you fix a value of \(y\), \(f_y(x) = f(x, y)\) is not an affine function.↩︎

  16. Details bore me.↩︎

  17. In general, there might need to be a little bit of reshaping, but it’s equivalent.↩︎

  18. Have you noticed this is like the third name I’ve used for this equivalent concept. Or the fourth? The code calls it a cotangent because that’s another damn synonym. I’m so very sorry.↩︎

  19. not difficult, I’m just lazy and Mike does it better than I can. Read his paper.↩︎

  20. For sparse matrices it’s just the non-zero mask of that.↩︎

  21. Yes. I know. Central differences. I am what I am.↩︎

  22. Some of the stuff I’ve done like normalising all of the inputs would help make these tests more stable. You should also just pick up Nick Higham’s backwards error analysis book to get some ideas of what your guarantees actually are in floating point, but I truly cannot be bothered. This is scratch code.↩︎

  23. It should be slightly bigger, it isn’t.↩︎

  24. The largest number \(\epsilon\) such that float(1.0) == float(1.0 + machine_eps) in single precision floating point.↩︎

  25. Fun fact: I implemented this and the error never spawned, so I guess JAX is keeping the index arrays concrete, which is very nice of it!↩︎

  26. actual damn numbers↩︎

  27. We want that auld triangle to go jingle bloody jangle↩︎

  28. We definitely do not want someone to write an eight hour, two part play that really seems to have the point of view that our Cholesky triangle deserved his downfall. Espoused while periodically reading deadshit tumblr posts. I mean, it would win a Tony. But we still do not want that.↩︎

  29. There are more arguments. Read the help. This is what we need↩︎

  30. What if I told you that this would work perfectly well if \(A\) was a linear partial differential operator or an integral operator? Probably not much because why would you give a shit?↩︎

  31. It can be more general, but it isn’t↩︎

  32. I think there is a typo in the docs↩︎

  33. Full disclosure: I screwed this up multiple times today and my tests caught it. What does that look like? The derivatives for \(A\) being off, but everything else being good.↩︎

  34. And some optional keyword arguments, but we don’t need to worry about those↩︎

  35. This is not quite the same but similar to something that functional programming people call currying, which was named after famous Australian Olympic swimmer Lisa Curry.↩︎

  36. and a shitload simpler!↩︎

  37. And we have to store a bunch more. This is less of a big deal when \(L\) is sparse, but for an ordinary linear solve, we’d be hauling around an extra \(\mathcal{O}(n^2)\) floats containing tangents for no good reason.↩︎

  38. If you are worrying about the suppressed constant, remember that \(A\) (and therefore \(n\) and \(\|A\|\)) is fixed.↩︎

  39. I think I’ve made this mistake about four times already while writing this blog. So I am going to write it out.↩︎

  40. Not to “some of my best friends are physicists”, but I do love them. I just wish a man would talk about me the way they talk about being coordinate free. Rather than with the same ambivalence physicists use when speaking about a specific atlas. I’ve been listening to lesbian folk music all evening. I’m having feelings.↩︎

  41. pronoun on purpose↩︎

  42. Takahashi, K., Fagan, J., Chen, M.S., 1973. Formation of a sparse bus impedance matrix and its application to short circuit study. In: Eighth PICA Conference Proceedings.IEEE Power Engineering Society, pp. 63–69 (Papers Presented at the 1973 Power Industry Computer Application Conference in Minneapolis, MN).↩︎

  43. Thanks to Jerzy Baranowski for finding a very very bad LaTeX error that made these equations quite wrong!↩︎

  44. Indeed, in the notation of post two \(\mathcal{L}_i \cap \{i+1, \dots, n\} \subseteq \mathcal{L}_j\) for all \(i \leq j\), where \(\mathcal{L}_i\) is the set of non-zeros in the \(i\)th column of \(L\).↩︎

  45. The sparse matrix is stored as a dictionary {(i,j): value}, which is a very natural way to build a sparse matrix, even if it’s quite inefficient to do anything with it in that form.↩︎

  46. You can’t just use jnp.linalg.det() because there’s a tendency towards nans. (The true value is something like exp(250.49306761204593)!)↩︎

  47. Would it be less tedious if my implementation of the Cholesky was less shit? Yes. But hey. It was the first non-trivial piece of python code I’d written in more than a decade (or maybe ever?) so it is what it is. Anyway. I’m gonna run into the same problem I had in Part 3↩︎

Citation

BibTeX citation:
@online{simpson2022,
  author = {Dan Simpson},
  editor = {},
  title = {Sparse Matrices 6: {To} Catch a Derivative, First You’ve Got
    to Think Like a Derivative},
  date = {2022-05-30},
  url = {https://dansblog.netlify.app/to-catch-a-derivative-first-youve-got-to-think-like-a-derivative},
  langid = {en}
}
For attribution, please cite this work as:
Dan Simpson. 2022. “Sparse Matrices 6: To Catch a Derivative, First You’ve Got to Think Like a Derivative.” May 30, 2022. https://dansblog.netlify.app/to-catch-a-derivative-first-youve-got-to-think-like-a-derivative.