Polymorphic Operations on Linear Maps
Numerical code often has efficient specializations for structured matrices,
e.g. diagonal, lower-triangular, or sparse ones.
I'd like to be able to write just one version of such code without having to
fall back to casting everything to dense, unstructured matrices.
This example is for brainstorming how the typeclass system might support
efficient polymorphism for linear algebra over different types of
structured maps.
Todo: factor out a more general linear map typeclass, and make this one inherit from it.
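As a rough illustration of the goal, here is a minimal Python sketch (not Dex, and not part of this file's API) in which one generic function works over two structured maps. The `LinearEndo` protocol, the `roundtrip` helper, and the list-based `Vector` alias are hypothetical names chosen for the sketch; `ScalarMap` and `DiagMap` only loosely mirror the structures defined later in this file.

```python
from dataclasses import dataclass
from typing import List, Protocol

Vector = List[float]


class LinearEndo(Protocol):
    """Hypothetical interface: a linear map from a space back to itself."""

    def apply(self, v: Vector) -> Vector: ...
    def solve(self, v: Vector) -> Vector: ...


@dataclass
class ScalarMap:
    val: float  # a scalar multiple of the identity

    def apply(self, v: Vector) -> Vector:
        return [self.val * x for x in v]

    def solve(self, v: Vector) -> Vector:
        return [x / self.val for x in v]


@dataclass
class DiagMap:
    val: Vector  # only the diagonal entries are stored

    def apply(self, v: Vector) -> Vector:
        return [d * x for d, x in zip(self.val, v)]

    def solve(self, v: Vector) -> Vector:
        # O(n) elementwise division instead of a dense O(n^3) solve.
        return [x / d for d, x in zip(self.val, v)]


def roundtrip(m: LinearEndo, v: Vector) -> Vector:
    """Generic code: written once, dispatches to each structure's own solve."""
    return m.solve(m.apply(v))
```

Neither structure is ever expanded into a dense matrix, which is the property the typeclasses in this file aim to preserve.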
Linear Endomorphisms
a.k.a. linear maps from a space back to that same space.
interface HasDeterminant(m)
  determinant': (m) -> Float
  transposeType : Type
  transpose': (m) -> m
  identity': m

interface LinearEndo(m|HasDeterminant, v|VSpace)
  apply: (m, v) -> v
  diag: (m) -> v
  solve': (m, v) -> v
We'd like to remove v from the LinearEndo interface,
and instead use associated types to specify a v for each m.
This would let us combine it with HasDeterminant.
But for now, fields in typeclasses can't refer to one another.
This means that determinant' and other operations can't be part
of this typeclass yet, because v is always ambiguous at its usage site.
struct ScalarMap = val: Float

instance HasDeterminant(ScalarMap)
  def determinant'(a) = a.val
  transposeType = Float
  def transpose'(x) = x
  identity' = ScalarMap(1.0)

instance LinearEndo(ScalarMap, v) given (v|Mul|VSpace)
  def apply(a, b) = a.val .* b
  def diag(a) = a.val .* one
  def solve'(a, b) = b / a.val

instance Arbitrary(ScalarMap)
  def arb(k) = ScalarMap $ arb k
struct DiagMap(n|Ix) = val: (n=>Float)

instance HasDeterminant(DiagMap n) given (n|Ix)
  def determinant'(x) = prod x.val
  transposeType = n=>Float
  def transpose'(x) = x
  identity' = DiagMap one

instance LinearEndo(DiagMap n, n=>v) given (n|Ix, v|Mul|VSpace)
  def apply(x, y) = for i. x.val[i] .* y[i]
  def diag(a) = for i. a.val[i] .* one
  def solve'(a, b) = for i. b[i] / a.val[i]

instance Arbitrary(DiagMap n) given (n|Ix)
  def arb(k) = DiagMap $ arb k
Dense (unstructured) matrices.
I didn't use a newtype for these, but I'm not sure if that's the right call.
instance HasDeterminant(n=>n=>Float) given (n|Ix)
  def determinant'(m) = determinant m
  transposeType = n=>n=>Float
  def transpose'(m) = transpose m
  identity' = eye

instance LinearEndo(n=>n=>Float, n=>v) given (n|Ix, v|Mul|VSpace)
  def apply(x, y) = for i. sum for j. x[i, j] .* y[j]
  def diag(x) = for i. x[i, i] .* one
  def solve'(m, vec) = solve m vec
struct LowerTriMap(n|Ix) = val : ((i:n)=>(..i)=>Float)

instance HasDeterminant(LowerTriMap n) given (n|Ix)
  def determinant'(x) = prod $ lower_tri_diag x.val
  transposeType = UpperTriMat n Float
  def transpose'(_) = error "Can't transpose to different types yet."
  identity' = LowerTriMap for i j. select (ordinal i == ordinal j) one zero

instance LinearEndo(LowerTriMap n, n=>v) given (n|Ix, v|Mul|VSpace)
  def apply(x, y) = for i. sum for j. x.val[i, j] .* y[inject j]
  def diag(x) = for i. x.val[i, unsafe_project i] .* one
  def solve'(x, y) = forward_substitute x.val y

instance Arbitrary(LowerTriMap n) given (n|Ix)
  def arb(k) = LowerTriMap $ arb k
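To make the triangular storage concrete, the following Python sketch (an illustration only, not the Dex `forward_substitute` implementation) stores row `i` as a list of `i + 1` entries, with the diagonal element last, and solves `L x = b` by forward substitution. The function names and the list-of-lists layout are assumptions made for this sketch.

```python
from typing import List


def forward_substitute(rows: List[List[float]], b: List[float]) -> List[float]:
    """Solve L x = b, where rows[i] holds L[i][0..i] (diagonal entry last)."""
    x: List[float] = []
    for i, row in enumerate(rows):
        # Subtract the contribution of the already-solved components.
        partial = sum(row[j] * x[j] for j in range(i))
        x.append((b[i] - partial) / row[i])
    return x


def lower_tri_determinant(rows: List[List[float]]) -> float:
    """The determinant of a triangular matrix is the product of its diagonal."""
    det = 1.0
    for row in rows:
        det *= row[-1]
    return det


# A 3x3 lower-triangular matrix in ragged storage.
L_rows = [[2.0], [1.0, 4.0], [3.0, 0.0, 5.0]]
solution = forward_substitute(L_rows, [2.0, 9.0, 18.0])
```

The solve costs O(n^2) and the determinant O(n), compared with O(n^3) for a dense matrix of the same size.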
struct UpperTriMap(n|Ix) = val: ((i:n)=>(i..)=>Float)

instance HasDeterminant(UpperTriMap n) given (n|Ix)
  def determinant'(x) = prod $ upper_tri_diag x.val
  transposeType = LowerTriMat n Float
  def transpose'(_) = error "Can't transpose to different types yet."
  identity' = UpperTriMap for i j. select (0 == ordinal j) one zero

instance LinearEndo(UpperTriMap n, n=>v) given (n|Ix, v|Mul|VSpace)
  def apply(x, y) = for i. sum for j. x.val[i, j] .* y[inject j]
  def diag(x) = for i. x.val[i, 0@_] .* one
  def solve'(x, y) = backward_substitute x.val y

instance Arbitrary(UpperTriMap n) given (n|Ix)
  def arb(k) = UpperTriMap $ arb k
struct SkewSymmetricMap(n|Ix) = val: ((i:n)=>(..<i)=>Float)

instance HasDeterminant(SkewSymmetricMap n) given (n|Ix)
  def determinant'(a) = case is_odd (size n) of
    True -> zero
    False ->
      dense_rep = skew_symmetric_prod a.val eye
      determinant dense_rep
  transposeType = (i:n)=>(..<i)=>Float
  def transpose'(x) = SkewSymmetricMap (-x.val)
  identity' = error "Skew symmetric matrices can't represent the identity map."

instance LinearEndo(SkewSymmetricMap n, n=>v) given (n|Ix, v|Mul|VSpace)
  def apply(x, y) = skew_symmetric_prod x.val y
  def diag(x) = zero
  def solve'(x, y) =
    dense_rep = skew_symmetric_prod x.val eye
    solve dense_rep y

instance Arbitrary(SkewSymmetricMap n) given (n|Ix)
  def arb(k) = SkewSymmetricMap $ arb k
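The properties relied on above (zero diagonal, transpose equal to negation, zero determinant in odd dimensions) can be checked on a small case. This Python sketch stores only the strictly lower triangle, so row `i` has `i` entries; the helper names `to_dense`, `skew_apply`, and `det3` are hypothetical and exist only for this illustration.

```python
from typing import List


def to_dense(rows: List[List[float]]) -> List[List[float]]:
    """Expand strictly-lower-triangular storage into a full skew-symmetric matrix."""
    n = len(rows)
    a = [[0.0] * n for _ in range(n)]
    for i, row in enumerate(rows):
        for j, val in enumerate(row):
            a[i][j] = val
            a[j][i] = -val  # skew symmetry: A[j][i] = -A[i][j]
    return a


def skew_apply(rows: List[List[float]], y: List[float]) -> List[float]:
    """Matrix-vector product computed directly from the compact storage."""
    out = [0.0] * len(rows)
    for i, row in enumerate(rows):
        for j, val in enumerate(row):
            out[i] += val * y[j]
            out[j] -= val * y[i]
    return out


def det3(a: List[List[float]]) -> float:
    """Determinant of a 3x3 matrix by cofactor expansion along the first row."""
    return (a[0][0] * (a[1][1] * a[2][2] - a[1][2] * a[2][1])
            - a[0][1] * (a[1][0] * a[2][2] - a[1][2] * a[2][0])
            + a[0][2] * (a[1][0] * a[2][1] - a[1][1] * a[2][0]))


skew_rows = [[], [1.0], [2.0, 3.0]]
dense = to_dense(skew_rows)
```

For this 3x3 example the determinant is exactly zero, which is why the `True` branch of `determinant'` can return `zero` without building a dense representation.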
interface HasStandardNormal(a:Type)
  randNormal : (Key) -> a

instance HasStandardNormal(Float32)
  def randNormal(k) = randn k

instance HasStandardNormal(n=>a) given (n|Ix, a|HasStandardNormal)
  def randNormal(key) =
    for i. randNormal (ixkey key i)

def multivariate_gaussian_sample(mean:v, covroot:m, key:Key)
  -> v given (v|HasStandardNormal, m) (LinearEndo m v) =
  noise = randNormal key
  mean + apply covroot noise
:t multivariate_gaussian_sample 1.0 (ScalarMap 2.0) (new_key 0)
Float32
Generic log pdf of a multivariate Gaussian
This single definition of a Gaussian log pdf should work
efficiently for any type of covariance matrix for which
an efficient solve and determinant are known.
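For reference, the quantity being computed can be written out as follows. This assumes, as the argument name `covroot` suggests, that the map $L$ passed in is a square root of the covariance, $\Sigma = L L^\top$, and that $d$ is the dimension of the vector space.

```latex
\log p(x) = -\tfrac{1}{2}\Bigl[
    (x - \mu)^\top L^{-\top} L^{-1} (x - \mu)
    + d \log(2\pi)
    + \log\bigl((\det L)^2\bigr)
\Bigr],
\qquad \Sigma = L L^\top,\quad \det \Sigma = (\det L)^2 .
```

The first term needs only two solves (one with $L$, one with $L^\top$) and the last term needs only the determinant of $L$, so no dense inverse of $\Sigma$ is required.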
def get_VSpace_dim(x:v) -> Float given (v|Mul|VSpace|InnerProd) =
  one' : v = one
  inner_prod one' one'

def gaussian_log_pdf(mean:v, covroot:m, x:v)
  -> Float given (m, v|Mul|InnerProd) (LinearEndo m v) =
  dim = get_VSpace_dim x
  squarepart = inner_prod (x - mean) (solve' (transpose' covroot)
                                      (solve' covroot (x - mean)))
  const = dim * log (2.0 * pi) + log (sq (determinant' covroot))
  -0.5 * (squarepart + const)
full_mat_type = (vec_len=>vec_len=>Float)
Check that applying the identity is a no-op.
def check_identity(m|HasDeterminant, given () (LinearEndo m v)) -> Bool =
  i : m = identity'
  vec : v = (arb $ new_key 0)
  vec ~~ apply i vec
check_identity $ LowerTriMap vec_len
True
check_identity $ UpperTriMap vec_len
True
check_identity $ vec_len=>vec_len=>Float
True
check_identity $ DiagMap vec_len
True
check_identity $ ScalarMap
True
def check_inverse(m|Arbitrary, given () (LinearEndo m v, LinearEndo m full_mat_type))
  -> Bool =
  a : m = arb $ new_key 0
  vec : v = arb $ new_key 0
  inv : full_mat_type = solve' a eye
  full : full_mat_type = apply a eye
  apply inv full ~~ eye
check_inverse $ SkewSymmetricMap vec_len
True
check_inverse $ LowerTriMap vec_len
True
check_inverse $ UpperTriMap vec_len
True
check_inverse $ vec_len=>vec_len=>Float
True
check_inverse $ DiagMap vec_len
True
check_inverse ScalarMap
True
def check_transpose(m|HasDeterminant|Arbitrary, given () (LinearEndo m v)) -> Bool =
  a : m = arb $ new_key 0
  (vec1, vec2) : (v, v) = arb $ new_key 1
  hitleft = inner_prod (apply a vec1) vec2
  hitright = inner_prod (apply (transpose' a) vec2) vec1
  hitleft ~~ hitright
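The transpose check rests on the adjoint identity: the inner product of `A x` with `y` equals the inner product of `A^T y` with `x`. A minimal Python sketch of the same test on a dense random matrix follows; the helper names, the fixed seed, and the tolerance are choices made for this illustration rather than anything taken from the Dex code.

```python
import random
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def matvec(a: Matrix, x: Vector) -> Vector:
    return [sum(aij * xj for aij, xj in zip(row, x)) for row in a]


def transpose(a: Matrix) -> Matrix:
    return [list(col) for col in zip(*a)]


def inner(x: Vector, y: Vector) -> float:
    return sum(xi * yi for xi, yi in zip(x, y))


def check_transpose(a: Matrix, x: Vector, y: Vector, tol: float = 1e-9) -> bool:
    """Test that <A x, y> agrees with <A^T y, x> up to rounding error."""
    hit_left = inner(matvec(a, x), y)
    hit_right = inner(matvec(transpose(a), y), x)
    return abs(hit_left - hit_right) <= tol


rng = random.Random(0)  # fixed seed so the check is reproducible
n = 4
a = [[rng.gauss(0.0, 1.0) for _ in range(n)] for _ in range(n)]
x = [rng.gauss(0.0, 1.0) for _ in range(n)]
y = [rng.gauss(0.0, 1.0) for _ in range(n)]
```

Because the test only uses `apply` and `transpose'`, the same property can be checked for any structured map without converting it to a dense matrix.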
check_transpose $ SkewSymmetricMap vec_len
True
check_transpose $ vec_len=>vec_len=>Float
True
check_transpose $ DiagMap vec_len
True
check_transpose ScalarMap
True
xs = linspace sizen (-span) span
integral = 2.0 * span * mean for i.
  exp $ gaussian_log_pdf (-0.1) (ScalarMap 0.07) xs[i]
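The same one-dimensional normalization check can be sketched in plain Python. The grid size of 2000 points, the span of 10.0, and the `linspace` convention (left endpoints of equal-width cells) are assumptions made for this sketch, since the values used by the Dex code are not shown here.

```python
import math
from typing import List


def gaussian_log_pdf_1d(mean: float, covroot: float, x: float) -> float:
    """Scalar case of the generic formula: solve is division, determinant is covroot."""
    z = (x - mean) / covroot
    const = math.log(2.0 * math.pi) + math.log(covroot * covroot)
    return -0.5 * (z * z + const)


def linspace(n: int, low: float, high: float) -> List[float]:
    """n points starting at low with spacing (high - low) / n (assumed convention)."""
    dx = (high - low) / n
    return [low + i * dx for i in range(n)]


n_points = 2000  # assumed grid size
span = 10.0      # assumed half-width of the integration interval
xs = linspace(n_points, -span, span)

# Riemann sum: interval width times the mean of the density over the grid.
density = [math.exp(gaussian_log_pdf_1d(-0.1, 0.07, x)) for x in xs]
integral = 2.0 * span * sum(density) / n_points
```

With these settings the grid spacing is about one seventh of the standard deviation, which is fine enough for the sum to be very close to 1.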
def check_2D_Gaussian_normalizes(m|Arbitrary, given () (LinearEndo m ((Fin 2) => Float32)))
  -> Bool =
  sizen = Fin 200
  span = 10.0
  xs = linspace sizen (-span) span
  covroot : m = arb $ new_key 0
  meanvec : ((Fin 2) => Float32) = arb $ new_key 1
  integral = (sq (2.0 * span)) * mean for pair:(sizen, sizen).
    (i, j) = pair
    x = [xs[i], xs[j]]
    exp $ gaussian_log_pdf meanvec covroot x
  integral ~~ 1.0
check_2D_Gaussian_normalizes $ (Fin 2)=>(Fin 2)=>Float
True
check_2D_Gaussian_normalizes $ DiagMap (Fin 2)
True
def Drift(v|VSpace) -> Type = (v, Time) -> v
def Diffusion(m:Type, v:Type, given () (LinearEndo m v)) -> Type = (v, Time) -> m
def SDE(m:Type, v:Type, given () (LinearEndo m v)) -> Type =
  (Drift v, Diffusion m v)
def radon_nikodym(
    drift1:Drift s,
    drift2:Drift s,
    diffusion:Diffusion m s,
    state:s,
    t:Time
    ) -> Float given (m, s|InnerProd) (LinearEndo m s) =
  difference = (drift1 state t) - (drift2 state t)
  cur_diffusion = diffusion state t
  a = solve' cur_diffusion difference
  0.5 * inner_prod a a
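For a concrete instance of what `radon_nikodym` evaluates, the Python sketch below specializes it to a diagonal diffusion, where the solve reduces to elementwise division. The function name `radon_nikodym_diag` and the example drift and diffusion functions are hypothetical and used only for illustration.

```python
from typing import Callable, List

Vector = List[float]
Drift = Callable[[Vector, float], Vector]
DiagDiffusion = Callable[[Vector, float], Vector]  # returns the diagonal entries


def radon_nikodym_diag(drift1: Drift, drift2: Drift, diffusion: DiagDiffusion,
                       state: Vector, t: float) -> float:
    """Return 0.5 * |sigma^{-1} (f1 - f2)|^2 for a diagonal diffusion sigma."""
    d1 = drift1(state, t)
    d2 = drift2(state, t)
    sigma = diffusion(state, t)
    # Solving against a diagonal map is elementwise division.
    a = [(u - v) / s for u, v, s in zip(d1, d2, sigma)]
    return 0.5 * sum(ai * ai for ai in a)


def example_drift1(state: Vector, t: float) -> Vector:
    return [2.0 * x for x in state]


def example_drift2(state: Vector, t: float) -> Vector:
    return [0.0 for _ in state]


def example_diffusion(state: Vector, t: float) -> Vector:
    return [2.0, 4.0]


value = radon_nikodym_diag(example_drift1, example_drift2, example_diffusion,
                           [1.0, 2.0], 0.0)
```

Here the drift difference is `[2.0, 4.0]`, the solve gives `[1.0, 1.0]`, and the result is half its squared norm.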
Stationary SDEs
From Equation 3, Section 2.1 of "A Complete Recipe for Stochastic Gradient
MCMC":
Every SDE with a stationary distribution can be parameterized
by:
- A state-dependent energy function
- A state-dependent skew-symmetric matrix
- A state-dependent diffusion matrix
The function below converts these three components into the drift and
diffusion of an SDE which, if followed, will converge to a stationary
distribution whose marginal log-density is equal to the negative energy
function (plus a constant).
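For orientation, the parameterization in the cited paper has, to the best of my reading, the following form, where $H$ is the energy, $D$ a positive semidefinite diffusion matrix, and $Q$ a skew-symmetric matrix. The symbols follow the paper rather than this file, and the statement should be checked against the source before relying on it.

```latex
dz = f(z)\,dt + \sqrt{2 D(z)}\; dW(t),
\qquad
f(z) = -\bigl[D(z) + Q(z)\bigr]\nabla H(z) + \Gamma(z),
\qquad
\Gamma_i(z) = \sum_{j} \frac{\partial}{\partial z_j}\bigl(D_{ij}(z) + Q_{ij}(z)\bigr).
```

Under this form the stationary density is proportional to $\exp(-H(z))$. The code in this file works with the negative energy and with a root of the diffusion matrix, so signs and constant factors differ from the expression above.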
def StationaryDiffusion(m:Type, v:Type, given () (LinearEndo m v)) -> Type = (v)->m
def NegEnergyFunc(v:Type) -> Type = (v)->Float
def SkewSymmetricFunc(n|Ix, v:Type) -> Type = (v)->((i:n)=>(..<i)=>Float)
def StationarySDEParts(n|Ix, v|VSpace, given () (LinearEndo n v)) -> Type =
  (NegEnergyFunc v, SkewSymmetricFunc n v, StationaryDiffusion n v)

def stationary_SDE_parts_to_SDE(
    parts:StationarySDEParts n (n=>v)
    ) -> (SDE n (n=>v)) given (n|Ix, v|Mul|VSpace) (LinearEndo n (n=>v)) =
  (neg_energy_func, skew_symm_map, diffusion_func) = parts
  drift = \state time.
    diffusion_prod = \vec.
      cur_diffusion_root = diffusion_func state
      0.5 .* (apply cur_diffusion_root (apply cur_diffusion_root vec))
    neg_energy_grad = (grad neg_energy_func) state
    skew_term = skew_symmetric_prod (skew_symm_map state) neg_energy_grad
    diff_term = diffusion_prod neg_energy_grad
    gammapart = \state.
      skew_term' = skew_symmetric_prod (skew_symm_map state) one
      diff_term' = diffusion_prod one
      skew_term' + diff_term'
    gamma_term = jvp gammapart state one
    skew_term + diff_term + gamma_term
  diffusion = \state time. diffusion_func state
  (drift, diffusion)