ODE Integrator

Integrate systems of ordinary differential equations (ODEs) using the Dormand-Prince method with adaptive step-size control. This version is a port of the JAX implementation. One difference is that it uses a lower-triangular matrix type for the Butcher tableau, and so avoids zero-padding everywhere.

import plot
-- Alias for Float, used in this file for time points and step sizes.
Time = Float
-- Should this go in the prelude?
-- Euclidean norm of a Float-valued table: the square root of the sum of the
-- squared entries.
def length {d} (x: d=>Float) : Float =
  squared_entries = for i. sq x.i
  sqrt (sum squared_entries)
-- Elementwise quotient of two Float-valued tables over the same index set.
def (./) {d} (x: d=>Float) (y: d=>Float) : d=>Float =
  for idx.
    numerator = x.idx
    denominator = y.idx
    numerator / denominator
def fit_4th_order_polynomial {v} [VSpace v]
    (z0:v) (z1:v) (z_mid:v) (dz0:v) (dz1:v) (dt:Time) : (Fin 5)=>v =
  -- Builds five polynomial coefficients from two end values (z0, z1), a
  -- middle value (z_mid) and two gradient evaluations (dz0, dz1), with the
  -- gradients scaled by dt.
  -- The coefficients are returned in the order
  -- [quartic, cubic, quadratic, linear, constant]; the constant term is z0 and
  -- the linear term is dt .* dz0.
  -- NOTE(review): presumably this is the highest-degree-first order that
  -- evalpoly expects -- confirm against the prelude.
  constant = z0
  linear = dt .* dz0
  quadratic = (-4. * dt) .* dz0 + dt .* dz1 - 11. .* z0 - 5. .* z1 + 16. .* z_mid
  cubic = (5. * dt) .* dz0 - (3. * dt) .* dz1 + 18. .* z0 + 14. .* z1 - 32. .* z_mid
  quartic = (-2. * dt) .* dz0 + (2. * dt) .* dz1 - 8. .* z0 - 8. .* z1 + 16. .* z_mid
  [quartic, cubic, quadratic, linear, constant]  -- polynomial coefficients.
-- Seven weights, one per Runge-Kutta stage derivative. interp_fit_dopri dots
-- them with the stage derivatives k to form z_mid. Each entry is written as a
-- ratio that is then halved.
dps_c_mid = [ 6025192743. /30085553152. /2., 0., 51252292925. /65400821598. /2., -2691868925. /45128329728. /2., 187940372067. /1594534317056. /2., -1776094331. /19743644256. /2., 11237099. /235043384. /2.]
def interp_fit_dopri {v} [VSpace v]
    (z0:v) (z1:v) (k:(Fin 7)=>v) (dt:Time) : (Fin 5)=>v =
  -- Fit a polynomial to the results of a Runge-Kutta step.
  -- Args:
  --   z0: state at the start of the step.
  --   z1: state at the end of the step.
  --   k: the seven derivative evaluations made during the step.
  --   dt: length of the step.
  -- Returns:
  --   The five coefficients produced by fit_4th_order_polynomial.
  z_mid = z0 + dt .* (dot dps_c_mid k)
  -- The fit needs the derivative at both ends of the step. The first stage,
  -- k.(0@_), is the derivative at z0. The last stage, k.(6@_), is the one
  -- runge_kutta_step returns as the derivative for the following step, i.e.
  -- the derivative at z1. The second stage k.(1@_) is evaluated part-way
  -- through the step and is not an end slope.
  dz0 = k.(0@_)
  dz1 = k.(6@_)
  fit_4th_order_polynomial z0 z1 z_mid dz0 dz1 dt
def initial_step_size {d} (fun:d=>Float -> Time -> d=>Float) (t0:Time)
    (z0:d=>Float) (order:Int) (rtol:Float) (atol:Float) (f0:d=>Float) : Time =
  -- Chooses a first step size from the initial state z0, its derivative f0,
  -- and one extra evaluation of fun a short distance ahead.
  -- Algorithm from: E. Hairer, S. P. Norsett G. Wanner,
  -- Solving Ordinary Differential Equations I: Nonstiff Problems, Sec. II.4.
  tolerance = for j. atol + ((abs z0.j) * rtol)
  state_norm = length (z0 ./ tolerance)
  slope_norm = length (f0 ./ tolerance)

  -- First guess: a fixed tiny step when either scaled norm is very small.
  tiny_norm = (state_norm < 1.0e-5) || (slope_norm < 1.0e-5)
  first_guess = select tiny_norm 1.0e-6 (0.01 * (state_norm / slope_norm))

  -- Probe the derivative after one explicit step of the first guess.
  probe_z = z0 + first_guess .* f0
  probe_f = fun probe_z (t0 + first_guess)
  change_norm = (length ((probe_f - f0) ./ tolerance)) / first_guess

  -- Second guess: a fallback when both norms are essentially zero, otherwise
  -- a power law in the combined norms with exponent 1 / (order + 1).
  flat = (slope_norm <= 1.0e-15) && (change_norm <= 1.0e-15)
  fallback = max 1.0e-6 (first_guess * 1.0e-3)
  exponent = 1. / ((i_to_f order) + 1.)
  from_order = pow (0.01 / (slope_norm + change_norm)) exponent
  second_guess = select flat fallback from_order

  min (100. * first_guess) second_guess
-- Dopri5 Butcher tableaux
-- alpha.i is the fraction of dt added to t0 for the stage that
-- runge_kutta_step writes into slot ordinal(i) + 1.
alpha = [1. / 5., 3. / 10., 4. / 5., 8. / 9., 1., 1.]
-- beta.i holds the weights that runge_kutta_step applies to the stage
-- derivatives already computed when it forms the state for that stage.
beta : ((i:Fin 6)=>(..i)=>Float) = -- triangular array type.
  [[1. / 5.],
   [3. / 40., 9. / 40.],
   [44. / 45., -56. / 15., 32.0 / 9.],
   [19372. / 6561., -25360. / 2187., 64448. / 6561., -212. / 729.],
   [9017./3168., -355./33., 46732./5247., 49./176., -5103./18656.],
   [35. / 384., 0., 500. / 1113., 125./192., -2187./6784., 11./84.]]
-- Weights on the seven stage derivatives that give the new state.
c_sol = [35. /384., 0., 500. /1113., 125. /192., -2187. /6784., 11. / 84., 0.]
-- Weights on the seven stage derivatives that give the local error estimate.
c_error = [35. / 384. - 1951. / 21600., 0., 500. / 1113. - 22642. / 50085., 125. / 192. - 451. / 720., -2187. / 6784. + 12231. / 42400., 11. / 84. - 649. / 6300., -1. / 60.]
def runge_kutta_step {v} [VSpace v] (func:v->Time->v)
    (z0:v) (f0:v) (t0:Time) (dt:Time) : (v & v & v & (Fin 7)=>v) =
  -- Takes one Runge-Kutta step of length dt from state z0 at time t0, where
  -- f0 is the derivative already evaluated at (z0, t0).
  -- Returns (new state, derivative stored in the last stage slot,
  --          error estimate, all seven stage derivatives).

  -- Slot 0 holds the supplied derivative; the other slots start at zero.
  seeded = yield_state zero \slots.
    slots!(0@_) := f0

  -- Fill slots 1..6 in order. Each stage reads only the slots written so far.
  stages = yield_state seeded \stage_ref. for s:(Fin 6).
    earlier = for p:(..s). get stage_ref!((ordinal p)@_)
    stage_t = t0 + dt .* alpha.s
    stage_z = z0 + dt .* dot beta.s earlier
    stage_ref!(((ordinal s) + 1)@_) := func stage_z stage_t

  f_next = stages.(6@_)
  z_error = dt .* dot c_error stages
  z_next = z0 + dt .* dot c_sol stages
  (z_next, f_next, z_error, stages)
def error_ratio {d} (error_estimates:d=>Float) (rtol:Float) (atol:Float)
    (z0:d=>Float) (z1:d=>Float) : Float =
  -- Mean of the squared ratios between each error estimate and its tolerance.
  -- The tolerance for an entry is atol plus rtol times the larger magnitude
  -- of that entry in z0 and z1.
  tolerances = for j.
    magnitude = max (abs z0.j) (abs z1.j)
    atol + rtol * magnitude
  scaled = error_estimates ./ tolerances
  mean for j. sq scaled.j
def optimal_step_size (last_step:Time) (mean_error_ratio:Float) : Time =
  -- Proposes the next step size from the previous one and the mean squared
  -- error ratio of the step just attempted.
  -- The step grows by at most ifactor. It shrinks by at most base_dfactor,
  -- and does not shrink at all when the error ratio is below 1.
  order = 5.
  safety = 0.9
  ifactor = 10.
  base_dfactor = 0.2
  dfactor = select (mean_error_ratio < 1.) 1. base_dfactor
  err_ratio = sqrt mean_error_ratio
  scaled = (pow err_ratio (1. / order)) / safety
  factor = max (1. / ifactor) (min scaled (1. / dfactor))
  -- A zero error ratio gets the maximum growth directly.
  grown = last_step * ifactor
  rescaled = last_step / factor
  select (mean_error_ratio == 0.) grown rescaled
def odeint {d n} (func: d=>Float -> Time -> d=>Float)
    (z0: d=>Float) (t0: Time) (times: n=>Time) : n=>d=>Float =
  -- Adaptive stepsize (Dormand-Prince) Runge-Kutta odeint implementation.
  -- Args:
  --   func: time derivative of the solution z at time t.
  --   z0: the initial value for the state.
  --   t0: the time at which the state equals z0.
  --   times: times for evaluation. values must be strictly increasing.
  -- Returns:
  --   Values of the solution at each time point in times.
  rtol = 1.4e-8  -- relative local error tolerance for solver.
  atol = 1.4e-8  -- absolute local error tolerance for solver.
  max_iters = 10000  -- cap on solver steps attempted per requested time.

  integrate_to_next_time = \i init_carry.
    target_t = times.i

    continue_condition = \(_, _, t, dt, _, _).
      -- State of solver: (next state, next f, next time, dt, t, interp coeffs)
      -- def State (v:Type) : Type = (v & v & Time & Time & Time & (Fin 5)=>v)
      -- This ended up being unnecessary to spell anywhere, but was
      -- useful for debugging.
      (t < target_t) && (dt > 0.0)

    possible_step = \(z, f, t, dt, last_t, interp_coeff).
      (next_z, next_f, next_z_error, k) = runge_kutta_step func z f t dt
      next_t = t + dt
      ratio = error_ratio next_z_error rtol atol z next_z
      new_interp_coeff = interp_fit_dopri z next_z k dt
      new_dt = optimal_step_size dt ratio
      -- Accept the step only when the error ratio is within tolerance.
      -- Either way the step size is updated for the next attempt.
      move_state = (next_z, next_f, next_t, new_dt, t, new_interp_coeff)
      stay_state = (z, f, t, new_dt, last_t, interp_coeff)
      select (ratio <= 1.0) move_state stay_state

    -- Take steps until we pass target_t. The loop index counts the steps
    -- attempted for this target, so max_iters bounds the work per target
    -- rather than depending on the position of the target in `times`.
    new_state = yield_state init_carry \stateRef.
      iter \step.
        state = get stateRef
        if (step < max_iters) && continue_condition state
          then
            stateRef := possible_step state
            Continue
          else Done ()
    (z, _, t, _, last_t, interp_coeff) = new_state

    -- Interpolate to the target time.
    relative_output_time = (target_t - last_t) / (t - last_t)
    interpolated = evalpoly interp_coeff relative_output_time
    -- If no step has been accepted yet (e.g. the target equals t0), then
    -- t == last_t, the quotient above is 0/0 and the coefficients are still
    -- the dummy initial ones. Return the current state in that case.
    z_target = select (t == last_t) z interpolated
    (new_state, z_target)

  f0 = func z0 t0
  init_step = initial_step_size func t0 z0 4 rtol atol f0
  init_interp_coeff = zero  -- dummy vals
  init_carry = (z0, f0, t0, init_step, t0, init_interp_coeff)
  snd $ scan init_carry integrate_to_next_time

Example: Linear dynamics

-- Dynamics used by the example: the derivative is the state itself and the
-- time argument is ignored.
def myDyn {a} (z : a) (t:Time) : a = z
-- One-dimensional initial state, start time, and a single requested output time.
z0 = [1.0]
t0 = 0.0
t1 = [1.0]
approx_e = odeint myDyn z0 t0 t1
:p approx_e
[[2.719158]]
exact_e = [[exp 1.0]]
-- NOTE(review): the recorded difference below is far larger than the
-- rtol/atol of 1.4e-8 set inside odeint -- re-record these printed values
-- after confirming the solver's interpolation.
:p (approx_e - exact_e) -- amount of numerical error
[[0.000877]]
-- Solve again at 100 requested times and plot the first state component.
-- NOTE(review): the range presumably starts just above t0 so that no
-- requested time coincides with t0 -- confirm.
times = linspace (Fin 100) 0.00001 1.0
ys = odeint myDyn z0 t0 times
:html show_plot $ y_plot for i. ys.i.(from_ordinal _ 0)