Polymorphic Operations on Linear Maps
Numerical code often has efficient specializations for structured matrices,
e.g. diagonal, lower-triangular, or sparse matrices.
I'd like to be able to write just one version of such code without having to
fall back to casting everything to dense, unstructured matrices.
This example is for brainstorming how the typeclass system might support
efficient polymorphism for linear algebra over different types of
structured maps.
Todo: factor out a more general linear map typeclass, and make this one inherit from it.
Linear Endomorphisms
a.k.a. linear maps from a space back to that same space.
interface HasDeterminant(m)
  determinant': (m) -> Float
  transposeType : Type
  transpose': (m) -> m
  identity': m

interface LinearEndo(m|HasDeterminant, v|VSpace)
  apply: (m, v) -> v
  diag: (m) -> v
  solve': (m, v) -> v
We'd like to remove v from the LinearEndo interface,
and instead use associated types to specify a v for each m.
This would let us combine it with HasDeterminant.
But for now, fields in typeclasses can't refer to one another.
This means that determinant' and other operations can't be part
of this typeclass yet, because v is always ambiguous at its usage site.
struct ScalarMap = val: Float
instance HasDeterminant(ScalarMap)
  def determinant'(a) = a.val
  transposeType = Float
  def transpose'(x) = x
  identity' = ScalarMap(1.0)

instance LinearEndo(ScalarMap, v) given (v|Mul|VSpace)
  def apply(a, b) = a.val .* b
  def diag(a) = a.val .* one
  def solve'(a, b) = b / a.val

instance Arbitrary(ScalarMap)
  def arb(k) = ScalarMap $ arb k
struct DiagMap(n|Ix) = val: (n=>Float)
instance HasDeterminant(DiagMap n) given (n|Ix)
  def determinant'(x) = prod x.val
  transposeType = n=>Float
  def transpose'(x) = x
  identity' = DiagMap one

instance LinearEndo(DiagMap n, n=>v) given (n|Ix, v|Mul|VSpace)
  def apply(x, y) = for i. x.val[i] .* y[i]
  def diag(a) = for i. a.val[i] .* one
  def solve'(a, b) = for i. b[i] / a.val[i]

instance Arbitrary(DiagMap n) given (n|Ix)
  def arb(k) = DiagMap $ arb k
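To illustrate the idea outside of Dex, here is a minimal Python sketch of the two structured maps defined so far. The class and function names (`ScalarMapSketch`, `DiagMapSketch`, `roundtrip`) are hypothetical stand-ins and not part of this example's code; vectors are plain lists, and the scalar map's determinant is its value, mirroring the ScalarMap instance above.

```python
import math
from dataclasses import dataclass
from typing import List


@dataclass
class ScalarMapSketch:
    """Hypothetical analogue of ScalarMap: scales every entry by `val`."""
    val: float

    def apply(self, v: List[float]) -> List[float]:
        return [self.val * x for x in v]

    def solve(self, v: List[float]) -> List[float]:
        return [x / self.val for x in v]

    def determinant(self) -> float:
        # Mirrors the Dex instance, which returns the stored value.
        return self.val


@dataclass
class DiagMapSketch:
    """Hypothetical analogue of DiagMap: stores only the diagonal."""
    val: List[float]

    def apply(self, v: List[float]) -> List[float]:
        return [d * x for d, x in zip(self.val, v)]

    def solve(self, v: List[float]) -> List[float]:
        return [x / d for d, x in zip(self.val, v)]

    def determinant(self) -> float:
        return math.prod(self.val)


def roundtrip(m, v):
    """Generic code: works for any object exposing apply and solve."""
    return m.solve(m.apply(v))
```

The point of the sketch is `roundtrip`: it is written once and never materializes a dense matrix, which is the kind of reuse the typeclass version is after.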
Dense (unstructured) matrices.
I didn't use a newtype for these, but I'm not sure if that's the right call.
instance HasDeterminant(n=>n=>Float) given (n|Ix)
  def determinant'(m) = determinant m
  transposeType = n=>n=>Float
  def transpose'(m) = transpose m
  identity' = eye

instance LinearEndo(n=>n=>Float, n=>v) given (n|Ix, v|Mul|VSpace)
  def apply(x, y) = for i. sum for j. x[i, j] .* y[j]
  def diag(x) = for i. x[i, i] .* one
  def solve'(m, vec) = solve m vec
struct LowerTriMap(n|Ix) = val : ((i:n)=>(..i)=>Float)
instance HasDeterminant(LowerTriMap n) given (n|Ix)
  def determinant'(x) = prod $ lower_tri_diag x.val
  transposeType = UpperTriMat n Float
  def transpose'(_) = error "Can't transpose to different types yet."
  identity' = LowerTriMap for i j. select (ordinal i == ordinal j) one zero

instance LinearEndo(LowerTriMap n, n=>v) given (n|Ix, v|Mul|VSpace)
  def apply(x, y) = for i. sum for j. x.val[i, j] .* y[inject j]
  def diag(x) = for i. x.val[i, unsafe_project i] .* one
  def solve'(x, y) = forward_substitute x.val y

instance Arbitrary(LowerTriMap n) given (n|Ix)
  def arb(k) = LowerTriMap $ arb k
struct UpperTriMap(n|Ix) = val: ((i:n)=>(i..)=>Float)
instance HasDeterminant(UpperTriMap n) given (n|Ix)
  def determinant'(x) = prod $ upper_tri_diag x.val
  transposeType = LowerTriMat n Float
  def transpose'(_) = error "Can't transpose to different types yet."
  identity' = UpperTriMap for i j. select (0 == ordinal j) one zero

instance LinearEndo(UpperTriMap n, n=>v) given (n|Ix, v|Mul|VSpace)
  def apply(x, y) = for i. sum for j. x.val[i, j] .* y[inject j]
  def diag(x) = for i. x.val[i, 0@_] .* one
  def solve'(x, y) = backward_substitute x.val y

instance Arbitrary(UpperTriMap n) given (n|Ix)
  def arb(k) = UpperTriMap $ arb k
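The triangular instances lean on library routines (forward_substitute, backward_substitute) whose bodies aren't shown here. As a rough sketch of what a solve on the ragged lower-triangular layout can look like, here is the textbook forward-substitution algorithm in Python. The function names are hypothetical, and this is not claimed to be how the Dex routine is implemented; `rows[i]` holds only the entries of row i up to and including the diagonal, like the (i:n)=>(..i)=>Float table.

```python
from typing import List


def lower_tri_apply_sketch(rows: List[List[float]], v: List[float]) -> List[float]:
    """Compute L v, touching only the stored (lower-triangular) entries."""
    return [sum(row[j] * v[j] for j in range(len(row))) for row in rows]


def forward_substitute_sketch(rows: List[List[float]], b: List[float]) -> List[float]:
    """Solve L x = b where rows[i] holds L[i][0..i] (ragged lower triangle)."""
    x: List[float] = []
    for i, row in enumerate(rows):
        # Subtract the contribution of the entries solved so far,
        # then divide by the diagonal entry row[i].
        partial = sum(row[j] * x[j] for j in range(i))
        x.append((b[i] - partial) / row[i])
    return x


# Usage: a 3x3 lower-triangular map stored as ragged rows.
rows = [[2.0], [1.0, 4.0], [0.5, 1.0, 8.0]]
b = lower_tri_apply_sketch(rows, [1.0, 2.0, 3.0])
x = forward_substitute_sketch(rows, b)
```

Both routines cost on the order of n^2 / 2 multiplications, versus roughly n^3 for a dense solve, which is the kind of saving the structured types are meant to preserve.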
struct SkewSymmetricMap(n|Ix) = val: ((i:n)=>(..<i)=>Float)
instance HasDeterminant(SkewSymmetricMap n) given (n|Ix)
  def determinant'(a) = case is_odd (size n) of
    True -> zero
    False ->
      dense_rep = skew_symmetric_prod a.val eye
      determinant dense_rep
  transposeType = (i:n)=>(..<i)=>Float --
  def transpose'(x) = SkewSymmetricMap (-x.val)
  identity' = error "Skew symmetric matrices can't represent the identity map."

instance LinearEndo(SkewSymmetricMap n, n=>v) given (n|Ix, v|Mul|VSpace)
  def apply(x, y) = skew_symmetric_prod x.val y
  def diag(x) = zero
  def solve'(x, y) =
    dense_rep = skew_symmetric_prod x.val eye
    solve dense_rep y

instance Arbitrary(SkewSymmetricMap n) given (n|Ix)
  def arb(k) = SkewSymmetricMap $ arb k
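The shortcut for odd sizes in determinant' and the negation in transpose' both follow from the defining property of a skew-symmetric matrix. Spelled out for an n-by-n matrix A (the symbols A and n are just notation for this note):

```latex
A^\top = -A
\quad\Longrightarrow\quad
\det(A) = \det(A^\top) = \det(-A) = (-1)^n \det(A)
\quad\Longrightarrow\quad
\det(A) = 0 \ \text{ when } n \text{ is odd.}
```

The same property forces every diagonal entry to be zero, which is why diag returns zero and why the identity map can't be represented.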
interface HasStandardNormal(a:Type)
  randNormal : (Key) -> a

instance HasStandardNormal(Float32)
  def randNormal(k) = randn k

instance HasStandardNormal(n=>a) given (n|Ix, a|HasStandardNormal)
  def randNormal(key) =
    for i. randNormal (ixkey key i)

def multivariate_gaussian_sample(mean:v, covroot:m, key:Key)
  -> v given (v|HasStandardNormal, m) (LinearEndo m v) =
  noise = randNormal key
  mean + apply covroot noise
:t multivariate_gaussian_sample 1.0 (ScalarMap 2.0) (new_key 0)
Float32
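The sampler above draws standard normal noise and pushes it through the covariance root. A minimal Python sketch of the same recipe for a diagonal root, assuming plain lists for vectors and the standard library's `random.Random` in place of Dex keys (the function name is hypothetical):

```python
import random
from typing import List


def diag_gaussian_sample_sketch(mean: List[float], covroot_diag: List[float],
                                rng: random.Random) -> List[float]:
    """Draw mean + L z with z ~ N(0, I) and L = diag(covroot_diag)."""
    noise = [rng.gauss(0.0, 1.0) for _ in mean]
    # Applying a diagonal map is an entrywise product.
    return [m + d * z for m, d, z in zip(mean, covroot_diag, noise)]


# Usage: one draw from a 2D Gaussian with standard deviations 0.5 and 2.0.
sample = diag_gaussian_sample_sketch([1.0, -2.0], [0.5, 2.0], random.Random(0))
```

Unlike Dex's splittable keys, the `rng` object here is stateful, so repeated calls with the same `rng` give different draws.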
Generic log pdf of a multivariate Gaussian
This single definition of a Gaussian log pdf should work
efficiently for any type of covariance square root for which
an efficient solve and determinant are known.
def get_VSpace_dim(x:v) -> Float given (v|Mul|VSpace|InnerProd) =
  one' : v = one
  inner_prod one' one'

def gaussian_log_pdf(mean:v, covroot:m, x:v)
  -> Float given (m, v|Mul|InnerProd) (LinearEndo m v) =
  dim = get_VSpace_dim x
  squarepart = inner_prod (x - mean) (solve' (transpose' covroot)
                                     (solve' covroot (x - mean)))
  const = dim * log (2.0 * pi) + log (sq (determinant' covroot))
  -0.5 * (squarepart + const)
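For reference, here is the formula that gaussian_log_pdf computes, read off term by term from the code above. The symbols are notation for this note only: L stands for covroot, \mu for mean, d for the dimension returned by get_VSpace_dim, and the covariance is assumed to be \Sigma = L L^\top.

```latex
\log p(x) = -\tfrac{1}{2}\Big[\,
  (x-\mu)^\top L^{-\top} L^{-1} (x-\mu)
  \;+\; d \,\log(2\pi)
  \;+\; \log\!\big(\det(L)^2\big)
\Big],
\qquad \Sigma = L L^\top,\quad \det\Sigma = \det(L)^2 .
```

The first term is squarepart (two solves, one against L and one against its transpose), and the last two terms are const.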
full_mat_type = (vec_len=>vec_len=>Float)
Check that applying the identity is a no-op.
def check_identity(m|HasDeterminant, given () (LinearEndo m v)) -> Bool =
  i : m = identity'
  vec : v = (arb $ new_key 0)
  vec ~~ apply i vec
check_identity $ LowerTriMap vec_len
True
check_identity $ UpperTriMap vec_len
True
check_identity $ vec_len=>vec_len=>Float
True
check_identity $ DiagMap vec_len
True
check_identity $ ScalarMap
True
def check_inverse(m|Arbitrary, given () (LinearEndo m v, LinearEndo m full_mat_type))
  -> Bool =
  a : m = arb $ new_key 0
  vec : v = arb $ new_key 0
  inv : full_mat_type = solve' a eye
  full : full_mat_type = apply a eye
  apply inv full ~~ eye
check_inverse $ SkewSymmetricMap vec_len
True
check_inverse $ LowerTriMap vec_len
True
check_inverse $ UpperTriMap vec_len
True
check_inverse $ vec_len=>vec_len=>Float
True
check_inverse $ DiagMap vec_len
True
check_inverse ScalarMap
True
def check_transpose(m|HasDeterminant|Arbitrary, given () (LinearEndo m v)) -> Bool =
  a : m = arb $ new_key 0
  (vec1, vec2) : (v, v) = arb $ new_key 1
  hitleft = inner_prod (apply a vec1) vec2
  hitright = inner_prod (apply (transpose' a) vec2) vec1
  hitleft ~~ hitright
check_transpose $ SkewSymmetricMap vec_len
True
check_transpose $ vec_len=>vec_len=>Float
True
check_transpose $ DiagMap vec_len
True
check_transpose ScalarMap
True
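The property that check_transpose tests is the defining identity of the transpose (adjoint) with respect to the inner product. In the notation of the check, with A for the map a and v_1, v_2 for the two random vectors:

```latex
\langle A v_1,\; v_2 \rangle \;=\; \langle A^\top v_2,\; v_1 \rangle
\qquad \text{for all } v_1, v_2 .
```

The triangular maps are not checked here because their transpose' is still an error: the transpose of a lower-triangular map has a different type.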
xs = linspace sizen (-span) span
integral = 2.0 * span * mean for i.
  exp $ gaussian_log_pdf (-0.1) (ScalarMap 0.07) xs[i]
def check_2D_Gaussian_normalizes(m|Arbitrary, given () (LinearEndo m ((Fin 2) => Float32)))
  -> Bool =
  sizen = Fin 200
  span = 10.0
  xs = linspace sizen (-span) span
  covroot : m = arb $ new_key 0
  meanvec : ((Fin 2) => Float32) = arb $ new_key 1
  integral = (sq (2.0 * span)) * mean for pair:(sizen, sizen).
    (i, j) = pair
    x = [xs[i], xs[j]]
    exp $ gaussian_log_pdf meanvec covroot x
  integral ~~ 1.0
check_2D_Gaussian_normalizes $ (Fin 2)=>(Fin 2)=>Float
True
check_2D_Gaussian_normalizes $ DiagMap (Fin 2)
True
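The normalization checks above can be reproduced in a few lines of Python for the diagonal case. This is a sketch under stated assumptions, not a port: the function names are hypothetical, the covariance root is a list of diagonal entries, and the grid uses cell midpoints rather than linspace endpoints, so the sum is an ordinary midpoint-rule integral.

```python
import math
from typing import List


def diag_gaussian_log_pdf_sketch(mean: List[float], covroot_diag: List[float],
                                 x: List[float]) -> float:
    """Log density of N(mean, L L^T) with L = diag(covroot_diag)."""
    # For a diagonal L, solving against L and then L^T divides by d^2 entrywise.
    squarepart = sum((xi - mi) ** 2 / d ** 2
                     for xi, mi, d in zip(x, mean, covroot_diag))
    log_det_sq = math.log(math.prod(covroot_diag) ** 2)
    const = len(x) * math.log(2.0 * math.pi) + log_det_sq
    return -0.5 * (squarepart + const)


def integrate_2d_sketch(mean: List[float], covroot_diag: List[float],
                        n: int = 200, span: float = 10.0) -> float:
    """Midpoint-rule estimate of the density's integral over [-span, span]^2."""
    h = 2.0 * span / n
    grid = [-span + (i + 0.5) * h for i in range(n)]
    total = sum(math.exp(diag_gaussian_log_pdf_sketch(mean, covroot_diag, [a, b]))
                for a in grid for b in grid)
    return total * h * h


# Usage: should come out close to 1 when the density is well inside the box.
estimate = integrate_2d_sketch([0.3, -0.7], [0.5, 1.5])
```

The estimate degrades if the standard deviations are much smaller than the grid spacing or if the mass is not well inside the box, which is worth keeping in mind when the covariance root is drawn at random.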
def Drift(v|VSpace) -> Type = (v, Time) -> v
def Diffusion(m:Type, v:Type, given () (LinearEndo m v)) -> Type = (v, Time) -> m
def SDE(m:Type, v:Type, given () (LinearEndo m v)) -> Type =
  (Drift v, Diffusion m v)
def radon_nikodym(
    drift1:Drift s,
    drift2:Drift s,
    diffusion:Diffusion m s,
    state:s,
    t:Time
    ) -> Float given (m, s|InnerProd) (LinearEndo m s) =
  difference = (drift1 state t) - (drift2 state t)
  cur_diffusion = diffusion state t
  a = solve' cur_diffusion difference
  0.5 * inner_prod a a
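Written as a formula, radon_nikodym evaluates the following quantity at a single state and time. The symbols are notation for this note: f_1 and f_2 are the two drifts and \sigma is the shared diffusion map.

```latex
r(s, t) \;=\; \tfrac{1}{2}\,
\big\| \sigma(s, t)^{-1} \big( f_1(s, t) - f_2(s, t) \big) \big\|^2
```

My reading, which is an assumption since the code doesn't say so, is that this is the integrand that appears in Girsanov's theorem for two SDEs sharing a diffusion term: integrating it along a path, in expectation, would give the KL divergence between the two path measures.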
Stationary SDEs
From Equation 3, Section 2.1 of "A Complete Recipe for Stochastic Gradient
MCMC":
Every SDE with a stationary distribution can be parameterized
by:
- A state-dependent energy function
- A state-dependent skew-symmetric matrix
- A state-dependent diffusion matrix
The function below converts these parts into the drift and diffusion of an SDE which,
if followed, will converge to a stationary distribution whose marginal
log-density is equal to the negative energy function (plus a constant).
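For orientation, this is the parameterization from the cited paper as I recall it, so treat it as a hedged restatement rather than a quote. The symbols H (energy), D (positive semidefinite diffusion matrix), Q (skew-symmetric matrix) and \Gamma (correction term) are the paper's notation, not names from this file:

```latex
dz = f(z)\,dt + \sqrt{2 D(z)}\; dW(t), \qquad
f(z) = -\big[ D(z) + Q(z) \big] \nabla H(z) + \Gamma(z), \qquad
\Gamma_i(z) = \sum_{j} \frac{\partial}{\partial z_j} \big( D_{ij}(z) + Q_{ij}(z) \big),
```

with stationary density proportional to \exp(-H(z)). The code below works with the negative energy, so the sign on the gradient term is absorbed, and it appears to build D from the diffusion root by applying the root twice and scaling by 0.5.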
def StationaryDiffusion(m:Type, v:Type, given () (LinearEndo m v)) -> Type = (v)->m
def NegEnergyFunc(v:Type) -> Type = (v)->Float
def SkewSymmetricFunc(n|Ix, v:Type) -> Type = (v)->((i:n)=>(..<i)=>Float)
def StationarySDEParts(n|Ix, v|VSpace, given () (LinearEndo n v)) -> Type =
  (NegEnergyFunc v, SkewSymmetricFunc n v, StationaryDiffusion n v)
def stationary_SDE_parts_to_SDE(
    parts:StationarySDEParts n (n=>v)
    ) -> (SDE n (n=>v)) given (n|Ix, v|Mul|VSpace) (LinearEndo n (n=>v)) =
  (neg_energy_func, skew_symm_map, diffusion_func) = parts
  drift = \state time.
    diffusion_prod = \vec.
      cur_diffusion_root = diffusion_func state
      0.5 .* (apply cur_diffusion_root (apply cur_diffusion_root vec))
    neg_energy_grad = (grad neg_energy_func) state
    skew_term = skew_symmetric_prod (skew_symm_map state) neg_energy_grad
    diff_term = diffusion_prod neg_energy_grad
    gammapart = \state.
      skew_term' = skew_symmetric_prod (skew_symm_map state) one
      diff_term' = diffusion_prod one
      skew_term' + diff_term'
    gamma_term = jvp gammapart state one
    skew_term + diff_term + gamma_term
  diffusion = \state time. diffusion_func state
  (drift, diffusion)