Polymorphic Operations on Linear Maps

Often when writing numerical code, there are efficient specializations for structured matrices, e.g. diagonal, lower-triangular, or sparse ones. I'd like to be able to write just one version of this code without having to fall back to casting everything to dense, unstructured matrices. This example is for brainstorming how the typeclass system might support efficient polymorphism for linear algebra over different types of structured maps.

import linalg

Todo: factor out a more general linear map typeclass, and make this one inherit from it.

-- interface [VSpace m, VSpace in, VSpace out] LinearMap m in out
--   apply: m -> in -> out
--   transpose': m -> m

Linear Endomorphisms

a.k.a. linear maps from a space back to that same space.

interface HasDeterminant(m)
  determinant': (m) -> Float
  transposeType : Type    -- unused for now
  transpose': (m) -> m    -- In future: m -> transposeType m
  identity': m

interface LinearEndo(m|HasDeterminant, v|VSpace)
  apply: (m, v) -> v
  diag: (m) -> v
  solve': (m, v) -> v

We'd like to remove v from the LinearEndo interface, and instead use associated types to specify a v for each m. This would let us combine it with HasDeterminant. But for now, fields in typeclasses can't refer to one another. This means that determinant' and other operations that don't mention v can't be part of LinearEndo yet, because v would be ambiguous at every usage site.


Scalar maps

struct ScalarMap =
  val: Float

instance HasDeterminant(ScalarMap)
  def determinant'(a) = a.val
  transposeType = Float
  def transpose'(x) = x
  identity' = ScalarMap(1.0)

instance LinearEndo(ScalarMap, v) given (v|Mul|VSpace)
  def apply(a, b) = a.val .* b
  def diag(a) = a.val .* one
  def solve'(a, b) = b / a.val

instance Arbitrary(ScalarMap)
  def arb(k) = ScalarMap $ arb k

Diagonal maps

struct DiagMap(n|Ix) =
  val: (n=>Float)

instance HasDeterminant(DiagMap n) given (n|Ix)
  def determinant'(x) = prod x.val
  transposeType = n=>Float
  def transpose'(x) = x
  identity' = DiagMap one

instance LinearEndo(DiagMap n, n=>v) given (n|Ix, v|Mul|VSpace)
  def apply(x, y) = for i. x.val[i] .* y[i]
  def diag(a) = for i. a.val[i] .* one
  def solve'(a, b) = for i. b[i] / a.val[i]

instance Arbitrary(DiagMap n) given (n|Ix)
  def arb(k) = DiagMap $ arb k

Dense (unstructured) matrices

I didn't use a newtype for these, but I'm not sure if that's the right call.

instance HasDeterminant(n=>n=>Float) given (n|Ix)
  def determinant'(m) = determinant m
  transposeType = n=>n=>Float
  def transpose'(m) = transpose m
  identity' = eye

instance LinearEndo(n=>n=>Float, n=>v) given (n|Ix, v|Mul|VSpace)
  def apply(x, y) = for i. sum for j. x[i, j] .* y[j]
  def diag(x) = for i. x[i, i] .* one
  def solve'(m, vec) = solve m vec

Lower-triangular maps

struct LowerTriMap(n|Ix) =
  val : ((i:n)=>(..i)=>Float)

instance HasDeterminant(LowerTriMap n) given (n|Ix)
  def determinant'(x) = prod $ lower_tri_diag x.val
  transposeType = UpperTriMat n Float
  def transpose'(_) = error "Can't transpose to different types yet."
  identity' = LowerTriMap for i j. select (ordinal i == ordinal j) one zero

instance LinearEndo(LowerTriMap n, n=>v) given (n|Ix, v|Mul|VSpace)
  def apply(x, y) = for i. sum for j. x.val[i, j] .* y[inject j]
  def diag(x) = for i. x.val[i, unsafe_project i] .* one
  def solve'(x, y) = forward_substitute x.val y

instance Arbitrary(LowerTriMap n) given (n|Ix)
  def arb(k) = LowerTriMap $ arb k
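
To make the triangular storage concrete, here is a minimal Python sketch of forward substitution on ragged rows, where row i holds only the i+1 entries up to the diagonal, mirroring the `(i:n)=>(..i)=>Float` layout. The Python names `forward_substitute` and `apply_lower` are illustrative stand-ins written for this note, not the Dex `linalg` functions, and the sketch assumes a nonzero diagonal.

```python
def forward_substitute(rows, b):
    """Solve L x = b, where rows[i] holds the i+1 entries L[i][0..i]."""
    x = []
    for i, row in enumerate(rows):
        # Subtract the contribution of the unknowns that are already solved.
        partial = sum(row[j] * x[j] for j in range(i))
        # Assumes row[i] (the diagonal entry) is nonzero.
        x.append((b[i] - partial) / row[i])
    return x


def apply_lower(rows, x):
    """Multiply the packed lower-triangular matrix by a vector."""
    return [sum(row[j] * x[j] for j in range(i + 1)) for i, row in enumerate(rows)]


# A 3x3 lower-triangular matrix stored as ragged rows.
rows = [[2.0], [1.0, 3.0], [4.0, 5.0, 6.0]]
b = [2.0, 7.0, 32.0]
x = forward_substitute(rows, b)
```

The solve touches each stored entry once, so it costs O(n^2) instead of the O(n^3) of a dense factorization, which is the kind of saving the structured instances are meant to expose.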

Upper-triangular maps

struct UpperTriMap(n|Ix) =
  val: ((i:n)=>(i..)=>Float)

instance HasDeterminant(UpperTriMap n) given (n|Ix)
  def determinant'(x) = prod $ upper_tri_diag x.val
  transposeType = LowerTriMat n Float
  def transpose'(_) = error "Can't transpose to different types yet."
  identity' = UpperTriMap for i j. select (0 == ordinal j) one zero

instance LinearEndo(UpperTriMap n, n=>v) given (n|Ix, v|Mul|VSpace)
  def apply(x, y) = for i. sum for j. x.val[i, j] .* y[inject j]
  def diag(x) = for i. x.val[i, 0@_] .* one
  def solve'(x, y) = backward_substitute x.val y

instance Arbitrary(UpperTriMap n) given (n|Ix)
  def arb(k) = UpperTriMap $ arb k

Skew-symmetric maps

struct SkewSymmetricMap(n|Ix) =
  val: ((i:n)=>(..<i)=>Float)

instance HasDeterminant(SkewSymmetricMap n) given (n|Ix)
  def determinant'(a) = case is_odd (size n) of
    True -> zero
    False ->
      dense_rep = skew_symmetric_prod a.val eye
      determinant dense_rep  -- Naive algorithm, could be done with Pfaffian
  transposeType = (i:n)=>(..<i)=>Float
  def transpose'(x) = SkewSymmetricMap (-x.val)
  identity' = error "Skew symmetric matrices can't represent the identity map."

instance LinearEndo(SkewSymmetricMap n, n=>v) given (n|Ix, v|Mul|VSpace)
  def apply(x, y) = skew_symmetric_prod x.val y
  def diag(x) = zero
  def solve'(x, y) =
    dense_rep = skew_symmetric_prod x.val eye  -- Fall back to naive algorithm
    solve dense_rep y

instance Arbitrary(SkewSymmetricMap n) given (n|Ix)
  def arb(k) = SkewSymmetricMap $ arb k

-- Todo: Sparse matrices.  Need hashmaps for these to be practical.
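
For the skew-symmetric map above, only the strictly-lower triangle is stored, so the product has to reuse each entry twice with opposite signs. A minimal Python sketch of that idea follows. The sign convention (stored entries are the ones below the diagonal, and their mirror images are negated) is my assumption about how the storage is interpreted, and this Python `skew_symmetric_prod` is an illustrative stand-in rather than the Dex `linalg` function.

```python
def skew_symmetric_prod(lower, x):
    """Compute A x, where lower[i][j] = A[i][j] for j < i and A[j][i] = -A[i][j]."""
    n = len(x)
    y = [0.0] * n
    for i in range(n):
        for j in range(i):
            a = lower[i][j]
            y[i] += a * x[j]  # stored entry below the diagonal
            y[j] -= a * x[i]  # mirrored entry above the diagonal, negated
    return y


# Represents A = [[0, -1, -2], [1, 0, -3], [2, 3, 0]].
lower = [[], [1.0], [2.0, 3.0]]
x = [1.0, 2.0, 4.0]
y = skew_symmetric_prod(lower, x)
# For any skew-symmetric A, x is orthogonal to A x.
x_dot_y = sum(xi * yi for xi, yi in zip(x, y))
```

The zero diagonal is implicit in the storage, which is why `diag` can simply return zero.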

Application 1: Gaussians

-- This typeclass will be obsolete once the `Basis` typeclass can be written.
interface HasStandardNormal(a:Type)
  randNormal : (Key) -> a

instance HasStandardNormal(Float32)
  def randNormal(k) = randn k

instance HasStandardNormal(n=>a) given (n|Ix, a|HasStandardNormal)
  def randNormal(key) = for i. randNormal (ixkey key i)

def multivariate_gaussian_sample(mean:v, covroot:m, key:Key) -> v
    given (v|HasStandardNormal, m) (LinearEndo m v) =
  noise = randNormal key
  mean + apply covroot noise

:t multivariate_gaussian_sample 1.0 (ScalarMap 2.0) (new_key 0)

Generic log pdf of a multivariate Gaussian

This single definition of a Gaussian log pdf should work efficiently for any type of covariance root for which an efficient solve and determinant are known.

-- This helper will be obsolete once the basis typeclass works.
def get_VSpace_dim(x:v) -> Float given (v|Mul|VSpace|InnerProd) =
  one' : v = one
  inner_prod one' one'

def gaussian_log_pdf(mean:v, covroot:m, x:v) -> Float
    given (m, v|Mul|InnerProd) (LinearEndo m v) =
  dim = get_VSpace_dim x
  squarepart = inner_prod (x - mean) (solve' (transpose' covroot) (solve' covroot (x - mean)))
  const = dim * log (2.0 * pi) + log (sq (determinant' covroot))
  -0.5 * (squarepart + const)
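
As a cross-check of the formula, here is a rough Python sketch of the same generic log pdf, using duck typing in place of typeclasses. The classes and method names (`solve`, `transpose`, `determinant`) are hypothetical stand-ins for the interfaces above, vectors are plain lists, and `ScalarMap` is only used on 1-D vectors so that its determinant matches the scalar case in this file.

```python
import math


class ScalarMap:
    """Hypothetical stand-in for a scalar map acting on a 1-D vector."""

    def __init__(self, val):
        self.val = val

    def solve(self, v):
        return [vi / self.val for vi in v]

    def transpose(self):
        return self

    def determinant(self):
        return self.val


class DiagMap:
    """Hypothetical stand-in for a diagonal map stored as its diagonal."""

    def __init__(self, vals):
        self.vals = vals

    def solve(self, v):
        return [vi / d for vi, d in zip(v, self.vals)]

    def transpose(self):
        return self

    def determinant(self):
        return math.prod(self.vals)


def gaussian_log_pdf(mean, covroot, x):
    """Log density of N(mean, R R^T), where covroot plays the role of R."""
    diff = [xi - mi for xi, mi in zip(x, mean)]
    # (x - mean)^T R^-T R^-1 (x - mean), computed with two structured solves.
    whitened = covroot.transpose().solve(covroot.solve(diff))
    squarepart = sum(d * w for d, w in zip(diff, whitened))
    # log det(R R^T) = log(det(R)^2)
    const = len(x) * math.log(2.0 * math.pi) + math.log(covroot.determinant() ** 2)
    return -0.5 * (squarepart + const)


def normal_log_pdf(mean, std, x):
    """Closed-form 1-D reference, used only to check the generic version."""
    return -0.5 * ((x - mean) / std) ** 2 - math.log(std) - 0.5 * math.log(2.0 * math.pi)


one_d = gaussian_log_pdf([-0.1], ScalarMap(0.07), [0.0])
two_d = gaussian_log_pdf([0.0, 1.0], DiagMap([2.0, 0.5]), [1.0, 0.5])
```

With a diagonal root the 2-D density factorizes, so `two_d` should equal the sum of two 1-D log densities. Neither class ever builds a dense matrix.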


vec_len = (Fin 4)
full_mat_type = (vec_len=>vec_len=>Float)
v = vec_len=>Float

Check that applying the identity is a no-op.

def check_identity(m|HasDeterminant, given () (LinearEndo m v)) -> Bool =
  i : m = identity'
  vec : v = (arb $ new_key 0)
  vec ~~ apply i vec

check_identity $ LowerTriMap vec_len
check_identity $ UpperTriMap vec_len
check_identity $ vec_len=>vec_len=>Float
check_identity $ DiagMap vec_len
check_identity $ ScalarMap

def check_inverse(m|Arbitrary, given () (LinearEndo m v, LinearEndo m full_mat_type)) -> Bool =
  a : m = arb $ new_key 0
  vec : v = arb $ new_key 0
  inv : full_mat_type = solve' a eye
  full : full_mat_type = apply a eye
  apply inv full ~~ eye

check_inverse $ SkewSymmetricMap vec_len
check_inverse $ LowerTriMap vec_len
check_inverse $ UpperTriMap vec_len
check_inverse $ vec_len=>vec_len=>Float
check_inverse $ DiagMap vec_len
check_inverse ScalarMap

def check_transpose(m|HasDeterminant|Arbitrary, given () (LinearEndo m v)) -> Bool =
  a : m = arb $ new_key 0
  (vec1, vec2) : (v, v) = arb $ new_key 1
  hitleft = inner_prod (apply a vec1) vec2
  hitright = inner_prod (apply (transpose' a) vec2) vec1
  hitleft ~~ hitright

check_transpose $ SkewSymmetricMap vec_len
-- Disabled until associated types are implemented.
-- check_transpose $ LowerTriMap vec_len
-- > True
-- check_transpose $ UpperTriMap vec_len
-- > True
check_transpose $ vec_len=>vec_len=>Float
check_transpose $ DiagMap vec_len
check_transpose ScalarMap

-- Check that 1D Gaussian sums to 1
sizen = Fin 2000
span = 10.0
xs = linspace sizen (-span) span
integral = 2.0 * span * mean for i. exp $ gaussian_log_pdf (-0.1) (ScalarMap 0.07) xs[i]
integral ~~ 1.0
def check_2D_Gaussian_normalizes(m|Arbitrary, given () (LinearEndo m ((Fin 2) => Float32))) -> Bool =
  sizen = Fin 200
  span = 10.0
  xs = linspace sizen (-span) span
  covroot : m = arb $ new_key 0
  meanvec : ((Fin 2) => Float32) = arb $ new_key 1
  integral = (sq (2.0 * span)) * mean for pair:(sizen, sizen).
    (i, j) = pair
    x = [xs[i], xs[j]]
    exp $ gaussian_log_pdf meanvec covroot x
  integral ~~ 1.0

-- Disabled due to FP tolerance issue
-- check_2D_Gaussian_normalizes (SkewSymmetricMap (Fin 2))
-- > True
-- Disabled until associated types are implemented.
-- check_2D_Gaussian_normalizes (LowerTriMat (Fin 2) Float)
-- > True
-- check_2D_Gaussian_normalizes (UpperTriMat (Fin 2) Float)
-- > True
check_2D_Gaussian_normalizes $ (Fin 2)=>(Fin 2)=>Float
check_2D_Gaussian_normalizes $ DiagMap (Fin 2)

Application 2: SDEs

struct Time =
  val: Float

def Drift(v|VSpace) -> Type = (v, Time) -> v
def Diffusion(m:Type, v:Type, given () (LinearEndo m v)) -> Type = (v, Time) -> m
def SDE(m:Type, v:Type, given () (LinearEndo m v)) -> Type = (Drift v, Diffusion m v)

def radon_nikodym(
    drift1:Drift s,
    drift2:Drift s,
    diffusion:Diffusion m s,
    state:s,
    t:Time
    ) -> Float given (m, s|InnerProd) (LinearEndo m s) =
  -- By Girsanov's theorem, this gives a simple Monte Carlo estimator
  -- of the (loosely speaking) instantaneous KL divergence between
  -- two SDEs that share a diffusion function.
  difference = (drift1 state t) - (drift2 state t)
  cur_diffusion = diffusion state t
  a = solve' cur_diffusion difference
  0.5 * inner_prod a a
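
Written out as a formula, the quantity computed by radon_nikodym is the following. The symbols are notation introduced here for illustration: f_1 and f_2 stand for the two drifts, sigma for the shared diffusion map, x for the state, and t for the time.

```latex
\[
  a = \sigma(x, t)^{-1}\,\bigl(f_1(x, t) - f_2(x, t)\bigr),
  \qquad
  \tfrac{1}{2}\,\langle a, a \rangle
  = \tfrac{1}{2}\,\bigl\lVert \sigma(x, t)^{-1}\bigl(f_1(x, t) - f_2(x, t)\bigr) \bigr\rVert^{2}.
\]
```

The only linear-algebra operation needed is one solve against the diffusion map, so the cost depends entirely on which structured instance is chosen.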

Stationary SDEs

From Equation 3, Section 2.1 of "A Complete Recipe for Stochastic Gradient MCMC": Every SDE with a stationary distribution can be parameterized by:

  1. A state-dependent energy function
  2. A state-dependent skew-symmetric matrix
  3. A state-dependent diffusion matrix

The function below converts these three parts into the drift and diffusion of an SDE which, if followed, will converge to a stationary distribution whose marginal log-density is equal to the negative energy function (plus a constant).
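
For reference, this is the parameterization as I read it from the cited paper, with H the energy function, Q the skew-symmetric matrix, and D the positive semidefinite diffusion matrix. It is a sketch of the paper's notation and worth checking against the paper itself. How it lines up with the code is my interpretation: the code works with the negative energy, so the sign on the gradient term flips, and it appears to build D as half the square of the diffusion root.

```latex
\[
  \mathrm{d}z = f(z)\,\mathrm{d}t + \sqrt{2 D(z)}\,\mathrm{d}W(t),
  \qquad
  f(z) = -\bigl[D(z) + Q(z)\bigr]\nabla H(z) + \Gamma(z),
\]
\[
  \Gamma_i(z) = \sum_{j} \frac{\partial}{\partial z_j}\bigl(D_{ij}(z) + Q_{ij}(z)\bigr).
\]
```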

def StationaryDiffusion(m:Type, v:Type, given () (LinearEndo m v)) -> Type = (v)->m
def NegEnergyFunc(v:Type) -> Type = (v)->Float
def SkewSymmetricFunc(n|Ix, v:Type) -> Type = (v)->((i:n)=>(..<i)=>Float)
def StationarySDEParts(n|Ix, v|VSpace, given () (LinearEndo n v)) -> Type =
  (NegEnergyFunc v, SkewSymmetricFunc n v, StationaryDiffusion n v)

def stationary_SDE_parts_to_SDE(
    parts:StationarySDEParts n (n=>v)
    ) -> (SDE n (n=>v)) given (n|Ix, v|Mul|VSpace) (LinearEndo n (n=>v)) =
  (neg_energy_func, skew_symm_map, diffusion_func) = parts
  drift = \state time.
    diffusion_prod = \vec.
      -- Square the root of the covariance matrix
      cur_diffusion_root = diffusion_func state
      0.5 .* (apply cur_diffusion_root (apply cur_diffusion_root vec))
    neg_energy_grad = (grad neg_energy_func) state
    skew_term = skew_symmetric_prod (skew_symm_map state) neg_energy_grad
    diff_term = diffusion_prod neg_energy_grad
    gammapart = \state.
      skew_term' = skew_symmetric_prod (skew_symm_map state) one
      diff_term' = diffusion_prod one
      skew_term' + diff_term'
    gamma_term = jvp gammapart state one
    skew_term + diff_term + gamma_term
  diffusion = \state time. diffusion_func state
  (drift, diffusion)

-- Todo: tests for SDE functions.