'# Basis Function Regression

import plot
-- Conjugate gradients solver.
-- Solves `mat x = b` starting from x = 0, running one CG step per element of
-- the index set `m`. Textbook CG assumes `mat` is symmetric positive-definite;
-- that is not checked here.
-- If the search direction has collapsed (p . (mat p) == 0, which happens when
-- the residual reaches exactly zero before the last step, or when b = 0), the
-- step size is forced to zero so the current iterate is kept instead of being
-- overwritten with NaN from a 0/0 division.
def solve {m} (mat:m=>m=>Float) (b:m=>Float) : m=>Float =
  x0 = for i:m. 0.0
  ax = mat **. x0
  r0 = b - ax
  (xOut, _, _) = fold (x0, r0, r0) $
     \s:m (x, r, p).
       ap = mat **. p
       rr = vdot r r
       pAp = vdot p ap
       alpha = select (pAp == 0.0) 0.0 (rr / pAp)
       x' = x + alpha .* p
       r' = r - alpha .* ap
       beta = vdot r' r' / (rr + 0.000001)
       p' = r' + beta .* p
       (x', r', p')
  xOut

'Make some synthetic data

-- Index set for the 100 training points, and the scale applied to the noise.
Nx = Fin 100
noise = 0.1
-- Two keys split from one seed: k1 drives the inputs, k2 drives the noise.
[k1, k2] = split_key (new_key 0)
-- Ground-truth function that the regression below tries to recover.
def trueFun (x:Float) : Float =
  x + sin (5.0 * x)
-- Inputs drawn with `rand`, one key per index.
xs : Nx=>Float = for i. rand (ixkey k1 i)
-- Targets: the true function plus `noise`-scaled `randn` draws.
ys : Nx=>Float = for i.
  eps = randn (ixkey k2 i)
  trueFun xs.i + noise * eps
:html show_plot (xy_plot xs ys)

'Implement basis function regression

-- Least-squares fit over basis functions.
-- Maps every raw input through `featurize` to build the n-by-d design matrix,
-- then hands the normal equations (X^T X) w = X^T y to `solve` and returns w.
def regress {d n} (featurize: Float -> d=>Float) (xRaw:n=>Float) (y:n=>Float) : d=>Float =
  design = map featurize xRaw
  designT = transpose design
  gram = designT ** design
  rhs = designT **. y
  solve gram rhs

'Fit a third-order polynomial

-- Monomial features: entry i of the result is `x` raised to the power
-- `ordinal i`, so the size of the index set `d` fixes the number of terms.
def poly {d} [Ix d] (x:Float) : d=>Float =
  for i.
    degree = n_to_f (ordinal i)
    pow x degree
-- Fitted coefficients for the four monomial features x^0 .. x^3.
params : (Fin 4)=>Float = regress poly xs ys
-- Evaluate the fitted polynomial at a single input.
def predict (x:Float) : Float =
  features : (Fin 4)=>Float = poly x
  vdot params features
-- Evaluation grid produced by `linspace (Fin 200) 0.0 1.0`.
xsTest = linspace (Fin 200) 0.0 1.0
:html show_plot (xy_plot xsTest (map predict xsTest))

'RMS error

-- Root-mean-square error: the square root of the mean of the elementwise
-- squared differences between `pred` and `truth`.
def rmsErr {n} (truth:n=>Float) (pred:n=>Float) : Float =
  sqErrs = for i. sq (pred.i - truth.i)
  sqrt (mean sqErrs)
-- RMS error of the fitted model on the training inputs.
:p rmsErr ys (for i. predict xs.i)
> 0.239831
-- Isomorphism between the labeled variant {left:a | right:b} and the plain
-- sum type (a|b): `left` pairs with `Left`, `right` with `Right`.
def lrToEither {a b} : Iso {left:a | right:b} (a|b) =
  toEither = \lr. case lr of
    {|left=l|} -> Left l
    {|right=r|} -> Right r
  fromEither = \e. case e of
    Left l -> {|left=l|}
    Right r -> {|right=r|}
  MkIso {fwd=toEither, bwd=fromEither}
-- Index-set instance for the two-way labeled variant. Every method converts
-- through `lrToEither` and then delegates to the `Ix` instance of the
-- resulting sum type (a|b).
instance Ix {left:n | right:m} given {n m} [Ix n, Ix m]
  size = size n + size m
  ordinal = app_iso lrToEither >>> ordinal
  unsafe_from_ordinal = rev_iso lrToEither <<< unsafe_from_ordinal _
-- Concatenate two tables into one indexed by the variant {left:n | right:m}:
-- a `left` index reads from `xs`, a `right` index reads from `ys`.
def tabCat {n m a} (xs:n=>a) (ys:m=>a) : ({left:n|right:m})=>a =
  for idx. case idx of
    {| left = l |} -> xs.l
    {| right = r |} -> ys.r
-- Training points and fitted-curve points joined into one table each, so a
-- single plot can show both.
xsPlot = tabCat xs xsTest
ysPlot = tabCat ys (map predict xsTest)
-- Color channel: 0.0 for `left` entries (training data), 1.0 for `right`
-- entries (predictions on the test grid).
:html show_plot $ xyc_plot xsPlot ysPlot $ for i. case i of
  {| left = _ |} -> 0.0
  {| right = _ |} -> 1.0