Dex prelude

Runs before every Dex program unless an alternative is provided with --prelude.

Essentials

Primitive Types

-- Aliases for the compiler's built-in types. `%`-prefixed names are
-- primitives supplied by the compiler; the rest of the prelude builds on them.
Unit = %UnitType
-- Kind-level aliases: the kind of ordinary types, of effect rows, and of
-- labeled field rows.
Type = %TyKind
Effects = %EffKind
Fields = %LabeledRowKind
-- Fixed-width scalar types.
Int64 = %Int64
Int32 = %Int32
Float64 = %Float64
Float32 = %Float32
Word8 = %Word8
Word32 = %Word32
Word64 = %Word64
-- Bytes and characters are both plain Word8 values.
Byte = Word8
Char = Byte
Label = %Label
-- Untyped pointer (the primitive is named %Word8Ptr).
RawPtr : Type = %Word8Ptr
-- Default widths used by unannotated code: 32-bit ints and floats.
Int = Int32
Float = Float32

Casting

-- Convert `x` to the explicitly given type `b` with the compiler's `%cast`
-- primitive; the source type `a` is inferred from `x`.
def internal_cast {a} (b:Type) (x:a) : b = %cast b x

-- Fixed scalar conversions. Each names its target type explicitly rather than
-- leaving it for inference from the result annotation.
def f64_to_f (x : Float64) : Float = internal_cast Float x
def f32_to_f (x : Float32) : Float = internal_cast Float x
def f_to_f64 (x : Float) : Float64 = internal_cast Float64 x
def f_to_f32 (x : Float) : Float32 = internal_cast Float32 x
def i64_to_i (x : Int64) : Int = internal_cast Int x
def i32_to_i (x : Int32) : Int = internal_cast Int x
def w8_to_i (x : Word8) : Int = internal_cast Int x
def i_to_i64 (x : Int) : Int64 = internal_cast Int64 x
def i_to_i32 (x : Int) : Int32 = internal_cast Int32 x
def i_to_w8 (x : Int) : Word8 = internal_cast Word8 x
def i_to_w32 (x : Int) : Word32 = internal_cast Word32 x
def i_to_w64 (x : Int) : Word64 = internal_cast Word64 x
def w32_to_w64 (x : Word32): Word64 = internal_cast Word64 x
def i_to_f (x:Int) : Float = internal_cast Float x
def f_to_i (x:Float) : Int = internal_cast Int x
def raw_ptr_to_i64 (x:RawPtr) : Int64 = internal_cast Int64 x
-- Natural numbers: a compiler newtype whose representation is a Word32.
Nat = %Nat
NatRep = Word32

-- Unwrap / wrap the Nat newtype. No range check is performed.
def nat_to_rep (x : Nat) : NatRep = %projNewtype x
def rep_to_nat (x : NatRep) : Nat = %newtype Nat x

-- Nat -> other scalars: unwrap to the representation, then cast.
def n_to_w8  (x : Nat) : Word8   = internal_cast Word8   (nat_to_rep x)
def n_to_w32 (x : Nat) : Word32  = internal_cast Word32  (nat_to_rep x)
def n_to_w64 (x : Nat) : Word64  = internal_cast Word64  (nat_to_rep x)
def n_to_i32 (x : Nat) : Int32   = internal_cast Int32   (nat_to_rep x)
def n_to_i64 (x : Nat) : Int64   = internal_cast Int64   (nat_to_rep x)
def n_to_f32 (x : Nat) : Float32 = internal_cast Float32 (nat_to_rep x)
def n_to_f64 (x : Nat) : Float64 = internal_cast Float64 (nat_to_rep x)
def n_to_f   (x : Nat) : Float   = internal_cast Float   (nat_to_rep x)

-- Other scalars -> Nat: cast to the representation, then wrap. Nothing here
-- checks that the input is non-negative or fits in a NatRep.
def w8_to_n  (x : Word8)   : Nat = rep_to_nat (internal_cast NatRep x)
def w32_to_n (x : Word32)  : Nat = rep_to_nat (internal_cast NatRep x)
def w64_to_n (x : Word64)  : Nat = rep_to_nat (internal_cast NatRep x)
def i32_to_n (x : Int32)   : Nat = rep_to_nat (internal_cast NatRep x)
def i64_to_n (x : Int64)   : Nat = rep_to_nat (internal_cast NatRep x)
def f32_to_n (x : Float32) : Nat = rep_to_nat (internal_cast NatRep x)
def f64_to_n (x : Float64) : Nat = rep_to_nat (internal_cast NatRep x)
def f_to_n   (x : Float)   : Nat = rep_to_nat (internal_cast NatRep x)
-- Types that can be built from a Word64.
-- NOTE(review): presumably this is what gives unsigned integer literals
-- (e.g. `zero = 0` below) their type — confirm.
interface FromUnsignedInteger a
  from_unsigned_integer : Word64 -> a

instance FromUnsignedInteger Word8
  from_unsigned_integer = \w. internal_cast Word8 w
instance FromUnsignedInteger Word32
  from_unsigned_integer = \w. internal_cast Word32 w
instance FromUnsignedInteger Word64
  from_unsigned_integer = \w. w
instance FromUnsignedInteger Int32
  from_unsigned_integer = \w. internal_cast Int32 w
instance FromUnsignedInteger Int64
  from_unsigned_integer = \w. internal_cast Int64 w
instance FromUnsignedInteger Float32
  from_unsigned_integer = \w. internal_cast Float32 w
instance FromUnsignedInteger Float64
  from_unsigned_integer = \w. internal_cast Float64 w
instance FromUnsignedInteger Nat
  from_unsigned_integer = w64_to_n

-- Types that can be built from a signed Int64.
interface FromInteger a
  from_integer : Int64 -> a

instance FromInteger Float32
  from_integer = \i. internal_cast Float32 i
instance FromInteger Int32
  from_integer = \i. internal_cast Int32 i
instance FromInteger Float64
  from_integer = \i. internal_cast Float64 i
instance FromInteger Int64
  from_integer = \i. i

Bitwise operations

-- Bitwise operations. The shift amount is an Int, converted to the operand's
-- word type before the shift primitive is applied.
interface Bits a
  (.<<.) : a -> Int -> a
  (.>>.) : a -> Int -> a
  (.|.) : a -> a -> a
  (.&.) : a -> a -> a
  (.^.) : a -> a -> a

instance Bits Word8
  (.<<.) = \w k. %shl w (i_to_w8 k)
  (.>>.) = \w k. %shr w (i_to_w8 k)
  (.|.) = \p q. %or p q
  (.&.) = \p q. %and p q
  (.^.) = \p q. %xor p q

instance Bits Word32
  (.<<.) = \w k. %shl w (i_to_w32 k)
  (.>>.) = \w k. %shr w (i_to_w32 k)
  (.|.) = \p q. %or p q
  (.&.) = \p q. %and p q
  (.^.) = \p q. %xor p q

instance Bits Word64
  (.<<.) = \w k. %shl w (i_to_w64 k)
  (.>>.) = \w k. %shr w (i_to_w64 k)
  (.|.) = \p q. %or p q
  (.&.) = \p q. %and p q
  (.^.) = \p q. %xor p q
-- Split a Word64 into two Word32 halves.
-- NOTE(review): despite the names, `low_word` shifts right by 32 before
-- casting, so (assuming the Word64 -> Word32 `internal_cast` keeps the
-- least-significant bits) it returns the MOST-significant half, while
-- `high_word` returns the least-significant half. Other code may depend on
-- this exact ordering; confirm before renaming or swapping.
def low_word (x : Word64) : Word32 = internal_cast _ (x .>>. 32)
def high_word (x : Word64) : Word32 = internal_cast _ x

Basic Arithmetic

Add

Things that can be added. This defines the Add group and its operators.

-- `Add`: types with an addition `add` and an additive identity `zero`.
interface Add a
  add : a -> a -> a
  zero : a

-- `Sub`: additive types that also support subtraction.
interface [Add a] Sub a
  sub : a -> a -> a

def (+) {a} [Add a] (x:a) (y:a) : a = add x y
def (-) {a} [Sub a] (x:a) (y:a) : a = sub x y

-- Floating-point instances use the `%f...` primitives.
instance Add Float64
  add = \x y. %fadd x y
  zero = 0
instance Sub Float64
  sub = \x y. %fsub x y

instance Add Float32
  add = \x y. %fadd x y
  zero = 0
instance Sub Float32
  sub = \x y. %fsub x y

-- Integer and word instances use the `%i...` primitives.
instance Add Int64
  add = \x y. %iadd x y
  zero = 0
instance Sub Int64
  sub = \x y. %isub x y

instance Add Int32
  add = \x y. %iadd x y
  zero = 0
instance Sub Int32
  sub = \x y. %isub x y

instance Add Word8
  add = \x y. %iadd x y
  zero = 0
instance Sub Word8
  sub = \x y. %isub x y

instance Add Word32
  add = \x y. %iadd x y
  zero = 0
instance Sub Word32
  sub = \x y. %isub x y

instance Add Word64
  add = \x y. %iadd x y
  zero = 0
instance Sub Word64
  sub = \x y. %isub x y

-- Nat adds through its representation. There is deliberately no `Sub Nat`;
-- see `-|` and `unsafe_nat_diff`.
instance Add Nat
  add = \x y. rep_to_nat (%iadd (nat_to_rep x) (nat_to_rep y))
  zero = 0

instance Add Unit
  add = \_ _. ()
  zero = ()
instance Sub Unit
  sub = \_ _. ()

-- Functions add and subtract pointwise.
instance Add (n->a) given {n a} [Add a]
  add = \f g x. (f x) + (g x)
  zero = \_. zero
instance Sub (n->a) given {n a} [Sub a]
  sub = \f g x. (f x) - (g x)

Mul

Things that can be multiplied. This defines the Mul Monoid, and its operator.

-- `Mul`: types with a multiplication `mul` and a multiplicative identity `one`.
interface Mul a
  mul : a -> a -> a
  one : a

def (*) {a} [Mul a] (x:a) (y:a) : a = mul x y

instance Mul Float64
  mul = \x y. %fmul x y
  one = f_to_f64 1.0
instance Mul Float32
  mul = \x y. %fmul x y
  one = f_to_f32 1.0
instance Mul Int64
  mul = \x y. %imul x y
  one = 1
instance Mul Int32
  mul = \x y. %imul x y
  one = 1
instance Mul Word8
  mul = \x y. %imul x y
  one = 1
instance Mul Word32
  mul = \x y. %imul x y
  one = 1
instance Mul Word64
  mul = \x y. %imul x y
  one = 1
-- Nat multiplies through its representation.
instance Mul Nat
  mul = \x y. rep_to_nat (%imul (nat_to_rep x) (nat_to_rep y))
  one = 1
instance Mul Unit
  mul = \_ _. ()
  one = ()
-- Functions multiply pointwise.
instance Mul (n->a) given {n a} [Mul a]
  mul = \f g x. (f x) * (g x)
  one = \_. one

Integral

Integer-like things.

-- Integer division and remainder. These map directly onto the `%idiv` and
-- `%irem` primitives; nothing here handles a zero divisor.
interface Integral a
  idiv : a -> a -> a
  rem : a -> a -> a

instance Integral Int64
  idiv = \n d. %idiv n d
  rem = \n d. %irem n d
instance Integral Int32
  idiv = \n d. %idiv n d
  rem = \n d. %irem n d
instance Integral Word8
  idiv = \n d. %idiv n d
  rem = \n d. %irem n d
instance Integral Word32
  idiv = \n d. %idiv n d
  rem = \n d. %irem n d
instance Integral Word64
  idiv = \n d. %idiv n d
  rem = \n d. %irem n d
-- Nat works on its NatRep and re-wraps the result.
instance Integral Nat
  idiv = \n d. rep_to_nat (%idiv (nat_to_rep n) (nat_to_rep d))
  rem = \n d. rep_to_nat (%irem (nat_to_rep n) (nat_to_rep d))

Fractional

Rational-like things. Includes floating point and two field rational representations.

-- Types with a true (non-truncating) division.
interface Fractional a
  divide : a -> a -> a

instance Fractional Float64
  divide = \num den. %fdiv num den
instance Fractional Float32
  divide = \num den. %fdiv num den

Index set interface and instances

-- Index sets: types whose elements are numbered by `ordinal`, with `size n`
-- elements in total. `size` and `unsafe_from_ordinal` take the index type
-- explicitly (e.g. `size q`). `unsafe_from_ordinal` maps a number back to an
-- element; as its name says, callers are responsible for passing a value
-- below `size n`.
interface Ix n size n : Nat ordinal : n -> Nat unsafe_from_ordinal n : Nat -> n
-- `Fin n`: the compiler's built-in finite index set (presumably n elements).
def Fin (n:Nat) : Type = %Fin n
-- Saturating subtraction on Nats: x - y, or 0 when x < y.
def (-|) (x: Nat) (y:Nat) : Nat =
  xr = nat_to_rep x
  yr = nat_to_rep y
  rep_to_nat $ %select (%ilt xr yr) 0 (%isub xr yr)

-- Nat subtraction with no underflow check; callers must ensure y <= x.
-- NOTE(review): for y > x the NatRep subtraction presumably wraps — confirm.
def unsafe_nat_diff (x:Nat) (y:Nat) : Nat =
  rep_to_nat $ %isub (nat_to_rep x) (nat_to_rep y)
-- Sub-ranges of an index set `q`, relative to an element `i : q`. Each value
-- stores a Nat: its ordinal *within the range*, not within `q` (see the Ix
-- and Subset instances for these types).
-- `(i..)` parses as `RangeFrom _ i`
-- Elements of `q` from `i` onwards, including `i`.
data RangeFrom q:Type i:q = UnsafeMkRangeFrom Nat
-- `(i<..)` parses as `RangeFromExc _ i`
-- Elements of `q` strictly after `i`.
data RangeFromExc q:Type i:q = UnsafeMkRangeFromExc Nat
-- `(..i)` parses as `RangeTo _ i`
-- Elements of `q` up to and including `i`.
data RangeTo q:Type i:q = UnsafeMkRangeTo Nat
-- `(..<i)` parses as `RangeToExc _ i`
-- Elements of `q` strictly before `i`.
data RangeToExc q:Type i:q = UnsafeMkRangeToExc Nat
-- Ix for the ranges: a range element is stored as its ordinal within the
-- range, so `ordinal` / `unsafe_from_ordinal` just unwrap / wrap that Nat.
-- Sizes: (i..) has size q - ordinal i elements, (i<..) one fewer,
-- (..i) has ordinal i + 1, and (..<i) has ordinal i.
instance Ix (RangeFrom q i) given {q:Type} {i:q} [Ix q]
  size = unsafe_nat_diff (size q) (ordinal i)
  ordinal = \(UnsafeMkRangeFrom k). k
  unsafe_from_ordinal = \k. UnsafeMkRangeFrom k

instance Ix (RangeFromExc q i) given {q:Type} {i:q} [Ix q]
  size = unsafe_nat_diff (size q) (1 + ordinal i)
  ordinal = \(UnsafeMkRangeFromExc k). k
  unsafe_from_ordinal = \k. UnsafeMkRangeFromExc k

instance Ix (RangeTo q i) given {q:Type} {i:q} [Ix q]
  size = 1 + ordinal i
  ordinal = \(UnsafeMkRangeTo k). k
  unsafe_from_ordinal = \k. UnsafeMkRangeTo k

instance Ix (RangeToExc q i) given {q:Type} {i:q} [Ix q]
  size = ordinal i
  ordinal = \(UnsafeMkRangeToExc k). k
  unsafe_from_ordinal = \k. UnsafeMkRangeToExc k

-- Unit is the one-element index set.
instance Ix Unit
  size = 1
  ordinal = \_. 0
  unsafe_from_ordinal = \_. ()

-- Table from each element of `n` to its ordinal, built with `view`.
def iota (n:Type) [Ix n] : n=>Nat = view k. ordinal k

Arithmetic instances for table types

-- Elementwise arithmetic on tables: plain tables and the triangular tables
-- whose rows are indexed by the range sets.
instance Add (n=>a) given {a n} [Add a]
  add = \xs ys. for k. xs.k + ys.k
  zero = for _. zero
instance Sub (n=>a) given {a n} [Sub a]
  sub = \xs ys. for k. xs.k - ys.k

-- Upper triangular tables
instance Add ((i:n) => (i..) => a) given {a n} [Add a]
  add = \xs ys. for k. xs.k + ys.k
  zero = for _. zero
instance Sub ((i:n) => (i..) => a) given {a n} [Sub a]
  sub = \xs ys. for k. xs.k - ys.k

-- Lower triangular tables
instance Add ((i:n) => (..i) => a) given {a n} [Add a]
  add = \xs ys. for k. xs.k + ys.k
  zero = for _. zero
instance Sub ((i:n) => (..i) => a) given {a n} [Sub a]
  sub = \xs ys. for k. xs.k - ys.k

-- Strictly lower triangular tables
instance Add ((i:n) => (..<i) => a) given {a n} [Add a]
  add = \xs ys. for k. xs.k + ys.k
  zero = for _. zero
instance Sub ((i:n) => (..<i) => a) given {a n} [Sub a]
  sub = \xs ys. for k. xs.k - ys.k

-- Strictly upper triangular tables
instance Add ((i:n) => (i<..) => a) given {a n} [Add a]
  add = \xs ys. for k. xs.k + ys.k
  zero = for _. zero
instance Sub ((i:n) => (i<..) => a) given {a n} [Sub a]
  sub = \xs ys. for k. xs.k - ys.k

instance Mul (n=>a) given {a n} [Mul a]
  mul = \xs ys. for k. xs.k * ys.k
  one = for _. one

Basic polymorphic functions and types

-- Pair type and its constructor.
def (&) (a:Type) (b:Type) : Type = %PairType a b
def (,) {a b} (x:a) (y:b) : (a & b) = %pair x y

-- Projections and swapping.
def fst {a b} ((first, _): (a & b)) : a = first
def snd {a b} ((_, second): (a & b)) : b = second
def swap {a b} ((l, r):(a&b)) : (b&a) = (r, l)

-- Pairs add and subtract componentwise.
instance Add (a & b) given {a b} [Add a, Add b]
  add = \(x1, y1) (x2, y2). (x1 + x2, y1 + y2)
  zero = (zero, zero)
instance Sub (a & b) given {a b} [Sub a, Sub b]
  sub = \(x1, y1) (x2, y2). (x1 - x2, y1 - y2)

-- Pair index sets are numbered row-major: the second component varies fastest.
instance Ix (a & b) given {a b} [Ix a, Ix b]
  size = size a * size b
  ordinal = \(i, j). ordinal j + ordinal i * size b
  unsafe_from_ordinal = \o.
    bs = size b
    (unsafe_from_ordinal a (idiv o bs), unsafe_from_ordinal b (rem o bs))

-- Function composition, right-to-left and left-to-right.
def (<<<) {a b c} (f: b -> c) (g: a -> b) : a -> c = \x. f (g x)
def (>>>) {a b c} (g: a -> b) (f: b -> c) : a -> c = f <<< g

-- Argument plumbing.
def flip {a b c} (f: a -> b -> c) : (b -> a -> c) = \y x. f x y
def uncurry {a b c} (f: a -> b -> c) : (a & b) -> c = \p. f (fst p) (snd p)
def const {a b} (x: a) (_: b) : a = x

Vector spaces

-- Vector spaces over Float: additive groups with scalar multiplication.
interface [Add a, Sub a] VSpace a
  scale_vec : Float -> a -> a

-- Scalar multiplication with the scalar on the left / right.
def (.*) {a} [VSpace a] (s:Float) (v:a) : a = scale_vec s v
def (*.) {a} [VSpace a] (v:a) (s:Float) : a = scale_vec s v
-- Division by a scalar, computed as multiplication by its reciprocal.
def (/) {a} [VSpace a] (v:a) (s:Float) : a = scale_vec (divide 1.0 s) v
def neg {a} [VSpace a] (v:a) : a = scale_vec (-1.0) v

instance VSpace Float
  scale_vec = \s x. s * x
instance VSpace (n=>a) given {a n} [VSpace a]
  scale_vec = \s xs. for k. s .* xs.k
instance VSpace (a & b) given {a b} [VSpace a, VSpace b]
  scale_vec = \s (u, v). (scale_vec s u, scale_vec s v)
-- Triangular tables scale row by row.
instance VSpace ((i:n) => (..i) => a) given {a n} [VSpace a]
  scale_vec = \s xs. for k. s .* xs.k
instance VSpace ((i:n) => (i..) => a) given {a n} [VSpace a]
  scale_vec = \s xs. for k. s .* xs.k
instance VSpace ((i:n) => (..<i) => a) given {a n} [VSpace a]
  scale_vec = \s xs. for k. s .* xs.k
instance VSpace ((i:n) => (i<..) => a) given {a n} [VSpace a]
  scale_vec = \s xs. for k. s .* xs.k
-- Functions scale pointwise.
instance VSpace (n->a) given {a n} [VSpace a]
  scale_vec = \s f x. s .* (f x)
instance VSpace Unit
  scale_vec = \_ _. ()

Boolean type

-- Booleans. The logical operators work on the Word8 constructor tags, so
-- both operands are always evaluated (no short-circuiting).
data Bool =
  False
  True

def b_to_w8 (x : Bool) : Word8 = %dataConTag x
def w8_to_b (x : Word8) : Bool = %toEnum Bool x

def (&&) (x:Bool) (y:Bool) : Bool = w8_to_b (%and (b_to_w8 x) (b_to_w8 y))
def (||) (x:Bool) (y:Bool) : Bool = w8_to_b (%or (b_to_w8 x) (b_to_w8 y))
-- NOTE(review): assumes `%not` on a Bool tag yields a valid Bool tag
-- (i.e. acts as logical rather than bitwise negation) — confirm.
def not (x:Bool) : Bool = w8_to_b (%not (b_to_w8 x))

More Boolean operations

TODO: move these with the others?

-- `x` when `p` is True, otherwise `y`.
def select {a} (p:Bool) (x:a) (y:a) : a =
  case p of
    False -> y
    True -> x

-- Bool -> numeric conversions via the Bool's Word8 tag.
def b_to_i (x:Bool) : Int = w8_to_i (b_to_w8 x)
def b_to_n (x:Bool) : Nat = w8_to_n (b_to_w8 x)
def b_to_f (x:Bool) : Float = i_to_f $ b_to_i x

Ordering

TODO: move this down to with Ord?

-- Result of a three-way comparison: less than, equal, greater than.
data Ordering = LT EQ GT
-- Constructor tag of an Ordering as a Word8 (used by `Eq Ordering`).
def o_to_w8 (x : Ordering) : Word8 = %dataConTag x

Sum types

A sum type, or tagged union, can hold values from a fixed set of types, distinguished by tags. For those familiar with the C language, they can be thought of as a combination of an enum with a union. Here we define several basic kinds, and some operators on them.

-- Optional value: either `Nothing` or `Just` a value of type `a`.
data Maybe a =
  Nothing
  Just a

def is_nothing {a} (x:Maybe a) : Bool =
  case x of
    Just _ -> False
    Nothing -> True

def is_just {a} (x:Maybe a) : Bool =
  case x of
    Just _ -> True
    Nothing -> False

-- Eliminate a Maybe: apply `f` to a `Just` payload, or return the default `d`.
def maybe {a b} (d: b) (f : (a -> b)) (x: Maybe a) : b =
  case x of
    Just v -> f v
    Nothing -> d
-- Binary sum type: a value is either `Left` of an `a` or `Right` of a `b`.
data (|) a b = Left a Right b
-- Index set of a sum: `Left` values take ordinals 0 .. size a - 1 and `Right`
-- values follow, shifted up by `size a`. `unsafe_from_ordinal` compares and
-- subtracts on the raw NatRep because `<` and Nat subtraction helpers are
-- only defined later in the file (see the TODOs inside).
instance Ix (a | b) given {a b} [Ix a, Ix b] size = size a + size b ordinal = \i. case i of Left ai -> ordinal ai Right bi -> ordinal bi + size a unsafe_from_ordinal = \o. as = nat_to_rep $ size a o' = nat_to_rep o -- TODO: Reshuffle the prelude to be able to use (<) here case w8_to_b $ %ilt o' as of True -> Left $ unsafe_from_ordinal a o -- TODO: Reshuffle the prelude to be able to use `diff_nat` here False -> Right $ unsafe_from_ordinal b (rep_to_nat (%isub o' as))

Subtraction on Nats

-- TODO: think more about the right API here

-- Int -> Nat with no sign check.
-- NOTE(review): the result for a negative input depends on `%cast` — confirm.
def unsafe_i_to_n (x:Int) : Nat = rep_to_nat (internal_cast NatRep x)

def n_to_i (x:Nat) : Int = internal_cast Int (nat_to_rep x)

-- Checked Int -> Nat: Nothing for negative inputs.
def i_to_n (x:Int) : Maybe Nat =
  case w8_to_b (%ilt x (0::Int)) of
    True -> Nothing
    False -> Just (unsafe_i_to_n x)

Fencepost index sets

-- Fenceposts of an index set: the `size segment + 1` boundaries around and
-- between its elements. Post k is the boundary just before element k.
data Post segment:Type = UnsafeMkPost Nat

instance Ix (Post segment) given {segment} [Ix segment]
  size = size segment + 1
  ordinal = \(UnsafeMkPost k). k
  unsafe_from_ordinal = \k. UnsafeMkPost k

-- The fenceposts immediately before / after element `i`.
def left_post {n} [Ix n] (i:n) : Post n = UnsafeMkPost (ordinal i)
def right_post {n} [Ix n] (i:n) : Post n = UnsafeMkPost (ordinal i + 1)

-- Index sets that are known to have at least one element.
interface [Ix n] NonEmpty n
  first_ix : n

-- The element with the largest ordinal (size n - 1, computed through Int).
def last_ix {n} [NonEmpty n] : n =
  top = n_to_i (size n) - 1
  unsafe_from_ordinal n (unsafe_i_to_n top)

instance NonEmpty (Post n) given {n} [Ix n]
  first_ix = UnsafeMkPost 0
instance NonEmpty Unit
  first_ix = ()

Monoid

A monoid is a type with an associative binary operator and an identity element. This is a very useful and general class of things. It includes:

  • Addition and Multiplication of Numbers
  • Boolean Logic
  • Concatenation of Lists (including strings)

Monoids support fold operations, and similar.
-- Types with an associative combining operation and its identity element.
interface Monoid a
  mempty : a
  -- can't use `<>` just for parser reasons?
  mcombine : a -> a -> a

def (<>) {a} [Monoid a] (x:a) (y:a) : a = mcombine x y

-- Tables combine elementwise.
instance Monoid (n=>a) given {a n} [Monoid a]
  mempty = for _. mempty
  mcombine = \xs ys. for k. mcombine xs.k ys.k

-- Named instances are chosen explicitly, e.g. `yield_accum AndMonoid ...`.
named-instance AndMonoid : Monoid Bool
  mempty = True
  mcombine = \p q. p && q

named-instance OrMonoid : Monoid Bool
  mempty = False
  mcombine = \p q. p || q

named-instance AddMonoid : Monoid a given (a:Type) [Add a]
  mempty = zero
  mcombine = \x y. add x y

named-instance MulMonoid : Monoid a given (a:Type) [Mul a]
  mempty = one
  mcombine = \x y. mul x y

Effects

-- Reference type: a handle to a mutable cell holding an `a`, tied to the
-- region/handle type that appears in effect rows such as `{State h}`.
def Ref (r:Type) (a:Type) : Type = %Ref r a
-- Read the current value; requires the `State h` effect.
def get {h s} (ref:Ref h s) : {State h} s = %get ref
-- Overwrite the value; requires the `State h` effect.
def (:=) {h s} (ref:Ref h s) (x:s) : {State h} Unit = %put ref x
-- Read a reader reference; requires the `Read h` effect.
def ask {h r} (ref:Ref h r) : {Read h} r = %ask ref
-- Packages a type `b` with a `Monoid b` dictionary. `h` and `w` only index
-- the type; the constructor's fields do not mention them.
data AccumMonoidData h w = UnsafeMkAccumMonoidData b:Type (Monoid b)
-- Evidence of which monoid to use when accumulating into a `Ref h w`.
interface AccumMonoid h w getAccumMonoidData : AccumMonoidData h w
-- Accumulating into a table reuses the element type's monoid data unchanged.
-- NOTE(review): presumably `%mextend` then applies it elementwise — confirm.
instance AccumMonoid h (n=>w) given {n h w} [Ix n] [am : AccumMonoid h w] getAccumMonoidData = (UnsafeMkAccumMonoidData b bm) = %projMethod0 am UnsafeMkAccumMonoidData b bm
-- Accumulate `x` into `ref`. Pulls the monoid's identity (method 0, `mempty`)
-- and combining function (method 1, `mcombine`) out of the dictionary and
-- hands them to the `%mextend` primitive. Requires the `Accum h` effect.
def (+=) {h w} [am:AccumMonoid h w] (ref:Ref h w) (x:w) : {Accum h} Unit = (UnsafeMkAccumMonoidData b bm) = %projMethod0 am empty = %projMethod0 bm combine = %projMethod1 bm %mextend ref empty (\x y. combine x y) x
-- Reference to element `i` of the table held in `ref`.
def (!) {h n a} (ref:Ref h (n=>a)) (i:n) : Ref h a = %indexRef ref i
-- References to the first / second component of the pair held in `ref`.
def fst_ref {h a b} (ref: Ref h (a & b)) : Ref h a = %fstRef ref
def snd_ref {h a b} (ref: Ref h (a & b)) : Ref h b = %sndRef ref
-- Run `action` with a read-only reference initialised to `init`. The local
-- wrapper makes the region type an explicit argument for `%runReader`.
def run_reader {r a eff} (init:r) (action: ((h:Type) ?-> Ref h r -> {Read h|eff} a)) : {|eff} a = def explicitAction (h':Type) (ref:Ref h' r) : {Read h'|eff} a = action ref %runReader init explicitAction
-- Same as `run_reader`.
def with_reader {r a eff} (init:r) (action: ((h:Type) ?-> Ref h r -> {Read h|eff} a)) : {|eff} a = run_reader init action
-- Evidence that an accumulator monoid on `b` can be lifted to one on `w`,
-- for any region `h`.
def MonoidLifter (b:Type) (w:Type) : Type = (h:Type) ?-> AccumMonoid h b ?=> AccumMonoid h w
-- Run `action` with an accumulator reference built from monoid `bm`. Returns
-- the action's result paired with the final accumulated value. The wrapper
-- builds the `AccumMonoid h' b` dictionary explicitly, applies `action` to
-- it, and hands the result to the `%runWriter` primitive.
def run_accum {a b w eff} [mlift:MonoidLifter b w] (bm:Monoid b) (action: ((h:Type) ?-> AccumMonoid h b ?=> Ref h w -> {Accum h|eff} a)) : {|eff} (a & w) = empty = %projMethod0 bm combine = %projMethod1 bm def explicitAction (h':Type) (ref:Ref h' w) : {Accum h'|eff} a = accumMonoidData : (AccumMonoidData h' b) = UnsafeMkAccumMonoidData b bm accumBaseMonoid = %explicitDict (AccumMonoid h' b) accumMonoidData action' = %explicitApply (%explicitApply action h') accumBaseMonoid action' ref %runWriter empty (\x y. combine x y) explicitAction
-- Like `run_accum`, but returns only the accumulated value.
def yield_accum {a b w eff} [mlift:MonoidLifter b w] (m:Monoid b) (action: ((h:Type) ?-> AccumMonoid h b ?=> Ref h w -> {Accum h|eff} a)) : {|eff} w = snd $ run_accum m action
-- Run `action` with a mutable reference initialised to `init`. Returns the
-- result paired with the final state.
def run_state {a s eff} (init:s) (action: (h:Type) ?-> Ref h s -> {State h |eff} a) : {|eff} (a & s) = def explicitAction (h':Type) (ref:Ref h' s) : {State h'|eff} a = action ref %runState init explicitAction
-- Like `run_state`, keeping only the result.
def with_state {a s eff} (init:s) (action: (h:Type) ?-> Ref h s -> {State h |eff} a) : {|eff} a = fst $ run_state init action
-- Like `run_state`, keeping only the final state.
def yield_state {a s eff} (init:s) (action: (h:Type) ?-> Ref h s -> {State h |eff} a) : {|eff} s = snd $ run_state init action
-- Run an `IO` action from non-IO code via `%runIO`. Unsafe: the caller is
-- responsible for the side effects being acceptable there.
def unsafe_io {a eff} (f: Unit -> {IO|eff} a) : {|eff} a = %runIO f
-- Marks code paths that must never execute. Raises an error via `%throwError`
-- inside `unsafe_io`.
def unreachable {a} (():Unit) : a = unsafe_io do %throwError a

Type classes

Eq and Ord

Eq

Equatable. Things that we can tell if they are equal or not to other things.

-- Types with an equality test.
interface Eq a
  (==) : a -> a -> Bool

-- Inequality is the negation of `==`.
def (/=) {a} [Eq a] (x:a) (y:a) : Bool = not (x == y)

Ord

Orderable / Comparable. Things that can be placed in a total order, i.e. things that can be compared to other things to find whether they are larger, smaller or equal in value.

We take the standard falsehood and pretend that this applies to Floats, even though strictly speaking this is not true: our floats follow IEEE 754, and thus have NaN < 1.0 == false and 1.0 < NaN == false.

-- Types with strict comparisons; the non-strict ones are derived below.
interface [Eq a] Ord a
  (>) : a -> a -> Bool
  (<) : a -> a -> Bool

-- Non-strict comparisons: equal, or strictly ordered.
def (<=) {a} [Ord a] (x:a) (y:a) : Bool = x == y || x < y
def (>=) {a} [Ord a] (x:a) (y:a) : Bool = x == y || x > y
-- Scalar equality maps onto the compiler's comparison primitives, whose
-- Word8 result is converted to Bool.
instance Eq Float64
  (==) = \x y. w8_to_b (%feq x y)
instance Eq Float32
  (==) = \x y. w8_to_b (%feq x y)
instance Eq Int64
  (==) = \x y. w8_to_b (%ieq x y)
instance Eq Int32
  (==) = \x y. w8_to_b (%ieq x y)
instance Eq Word8
  (==) = \x y. w8_to_b (%ieq x y)
instance Eq Word32
  (==) = \x y. w8_to_b (%ieq x y)
instance Eq Word64
  (==) = \x y. w8_to_b (%ieq x y)
-- Bools compare by constructor tag.
instance Eq Bool
  (==) = \x y. b_to_w8 x == b_to_w8 y
instance Eq Unit
  (==) = \_ _. True

-- Sums are equal only when both sides use the same constructor with equal
-- payloads.
instance Eq (a | b) given {a b} [Eq a, Eq b]
  (==) = \x y. case x of
    Left xl -> case y of
      Left yl -> xl == yl
      Right _ -> False
    Right xr -> case y of
      Left _ -> False
      Right yr -> xr == yr

instance Eq (Maybe a) given {a} [Eq a]
  (==) = \x y. case x of
    Just xv -> case y of
      Just yv -> xv == yv
      Nothing -> False
    Nothing -> case y of
      Just _ -> False
      Nothing -> True

-- Raw pointers compare through their Int64 casts.
instance Eq RawPtr
  (==) = \p q. raw_ptr_to_i64 p == raw_ptr_to_i64 q

instance Ord Float64
  (>) = \x y. w8_to_b (%fgt x y)
  (<) = \x y. w8_to_b (%flt x y)
instance Ord Float32
  (>) = \x y. w8_to_b (%fgt x y)
  (<) = \x y. w8_to_b (%flt x y)
instance Ord Int64
  (>) = \x y. w8_to_b (%igt x y)
  (<) = \x y. w8_to_b (%ilt x y)
instance Ord Int32
  (>) = \x y. w8_to_b (%igt x y)
  (<) = \x y. w8_to_b (%ilt x y)
instance Ord Word8
  (>) = \x y. w8_to_b (%igt x y)
  (<) = \x y. w8_to_b (%ilt x y)
instance Ord Word32
  (>) = \x y. w8_to_b (%igt x y)
  (<) = \x y. w8_to_b (%ilt x y)
instance Ord Word64
  (>) = \x y. w8_to_b (%igt x y)
  (<) = \x y. w8_to_b (%ilt x y)
instance Ord Unit
  (>) = \_ _. False
  (<) = \_ _. False

-- Pairs compare lexicographically: first components, then second.
instance Eq (a & b) given {a b} [Eq a, Eq b]
  (==) = \(x1, x2) (y1, y2). x1 == y1 && x2 == y2
instance Ord (a & b) given {a b} [Ord a, Ord b]
  (>) = \(x1, x2) (y1, y2). (x1 == y1 && x2 > y2) || x1 > y1
  (<) = \(x1, x2) (y1, y2). (x1 == y1 && x2 < y2) || x1 < y1

-- Orderings compare by constructor tag.
instance Eq Ordering
  (==) = \x y. o_to_w8 x == o_to_w8 y

-- Nats compare through their NatRep.
instance Eq Nat
  (==) = \x y. nat_to_rep x == nat_to_rep y
instance Ord Nat
  (>) = \x y. nat_to_rep x > nat_to_rep y
  (<) = \x y. nat_to_rep x < nat_to_rep y
-- TODO: we want Eq and Ord for all index sets, not just `Fin n`
-- `Fin n` compares by ordinal.
instance Eq (Fin n) given {n}
  (==) = \x y. ordinal x == ordinal y
instance Ord (Fin n) given {n}
  (>) = \x y. ordinal x > ordinal y
  (<) = \x y. ordinal x < ordinal y

-- Bool as an index set: False is 0, True is 1.
instance Ix Bool
  size = 2
  ordinal = \b. case b of
    False -> 0
    True -> 1
  unsafe_from_ordinal = \k. k > 0

-- Maybe as an index set: `Just` values keep their ordinals and `Nothing`
-- comes last, at ordinal `size a`.
instance Ix (Maybe a) given {a} [Ix a]
  size = size a + 1
  ordinal = \m. case m of
    Just v -> ordinal v
    Nothing -> size a
  unsafe_from_ordinal = \k. case k == size a of
    True -> Nothing
    False -> Just $ unsafe_from_ordinal a k

-- In each case the element with ordinal 0 exists.
instance NonEmpty Bool
  first_ix = unsafe_from_ordinal Bool 0
instance NonEmpty (a & b) given {a b} [NonEmpty a, NonEmpty b]
  first_ix = unsafe_from_ordinal (a & b) 0
instance NonEmpty (a|b) given {a b} [Ix b, NonEmpty a]
  first_ix = unsafe_from_ordinal (a|b) 0
-- The below instance is valid, but causes "multiple candidate dictionaries"
-- errors if both Left and Right are NonEmpty.
-- instance NonEmpty (a|b) given {a b} [Ix a, NonEmpty b]
-- first_ix = unsafe_from_ordinal _ 0
instance NonEmpty (Maybe a) given {a} [Ix a]
  first_ix = unsafe_from_ordinal (Maybe a) 0
-- Loop over `n` threading a carry of type `a`; each step also emits a `b`.
-- Returns the final carry together with the table of emitted values.
def scan {a b n} [Ix n] (init:a) (body:n->a->(a&b)) : (a & n=>b) =
  swap $ run_state init \ref. for i.
    (next, out) = body i (get ref)
    ref := next
    out

-- Like `scan` without per-step outputs: just the final carry.
def fold {a n} [Ix n] (init:a) (body:(n->a->a)) : a =
  fst (scan init \i acc. (body i acc, ()))

-- Three-way comparison built from `<` and `==`.
def compare {a} [Ord a] (x:a) (y:a) : Ordering =
  case x < y of
    True -> LT
    False -> case x == y of
      True -> EQ
      False -> GT

-- Combining Orderings keeps the first non-EQ result.
instance Monoid Ordering
  mempty = EQ
  mcombine = \x y. case x of
    EQ -> y
    LT -> LT
    GT -> GT

-- Tables are equal when every pair of entries is equal.
instance Eq (n=>a) given {a n} [Eq a]
  (==) = \xs ys. yield_accum AndMonoid \acc. for k. acc += xs.k == ys.k

-- Tables compare lexicographically: the first differing entry (in iteration
-- order) decides.
instance Ord (n=>a) given {a n} [Ord a]
  (>) = \xs ys.
    verdict : Ordering = fold EQ \i c. c <> compare xs.i ys.i
    verdict == GT
  (<) = \xs ys.
    verdict : Ordering = fold EQ \i c. c <> compare xs.i ys.i
    verdict == LT

Subset class

-- `Subset a_sub a`: every value of `a_sub` can be embedded in `a`.
-- `inject'` embeds. `project'` goes back, returning Nothing for values of `a`
-- outside the subset. `unsafe_project'` goes back without that check.
interface Subset a_sub a inject' : a_sub -> a project' : a -> Maybe a_sub unsafe_project' : a -> a_sub
-- Wrappers that take the other type explicitly, e.g. `project a_sub x`.
def project {a} (a_sub:Type) [Subset a_sub a] (x:a) : Maybe a_sub = project' x
def unsafe_project {a} (a_sub:Type) [Subset a_sub a] (x:a) : a_sub = unsafe_project' x
def inject {a_sub} (a:Type) [Subset a_sub a] (x:a_sub) : a = inject' x
-- Transitivity: a subset of a subset is a subset, going through `b`.
instance Subset a c given {a b c} [Subset a b, Subset b c] inject' = \x. inject c $ inject b x project' = \x. case project b x of Nothing -> Nothing Just y -> project a y unsafe_project' = \x. unsafe_project a $ unsafe_project b x
-- Project `j : q` into `(i..)` by subtracting `ordinal i`; no check that j >= i.
def unsafe_project_rangefrom {q:Type} {i:q} [Ix q] (j:q) : RangeFrom q i = UnsafeMkRangeFrom $ unsafe_nat_diff (ordinal j) (ordinal i)
-- (i..) inside q: range ordinals are offset by `ordinal i`; projection fails
-- for elements before i.
instance Subset (RangeFrom q i) q given {q:Type} {i:q} [Ix q] inject' = \(UnsafeMkRangeFrom j). unsafe_from_ordinal _ $ j + ordinal i project' = \j. j' = ordinal j i' = ordinal i if j' < i' then Nothing else Just $ UnsafeMkRangeFrom $ unsafe_nat_diff j' i' unsafe_project' = \j. UnsafeMkRangeFrom $ unsafe_nat_diff (ordinal j) (ordinal i)
-- (i<..) inside q: range ordinals are offset by `ordinal i + 1`; projection
-- fails for i and everything before it.
instance Subset (RangeFromExc q i) q given {q:Type} {i:q} [Ix q] inject' = \(UnsafeMkRangeFromExc j). unsafe_from_ordinal _ $ j + ordinal i + 1 project' = \j. j' = ordinal j i' = ordinal i if j' <= i' then Nothing else Just $ UnsafeMkRangeFromExc $ unsafe_nat_diff j' (i' + 1) unsafe_project' = \j. UnsafeMkRangeFromExc $ unsafe_nat_diff (ordinal j) (ordinal i + 1)
-- (..i) inside q: ordinals coincide; projection fails past i.
instance Subset (RangeTo q i) q given {q:Type} {i:q} [Ix q] inject' = \(UnsafeMkRangeTo j). unsafe_from_ordinal _ j project' = \j. j' = ordinal j i' = ordinal i if j' > i' then Nothing else Just $ UnsafeMkRangeTo j' unsafe_project' = \j. UnsafeMkRangeTo (ordinal j)
-- (..<i) inside q: ordinals coincide; projection fails at i and beyond.
instance Subset (RangeToExc q i) q given {q:Type} {i:q} [Ix q] inject' = \(UnsafeMkRangeToExc j). unsafe_from_ordinal _ j project' = \j. j' = ordinal j i' = ordinal i if j' >= i' then Nothing else Just $ UnsafeMkRangeToExc j' unsafe_project' = \j. UnsafeMkRangeToExc (ordinal j)
-- (..<i) inside (..i): ordinals coincide; projection fails only for the last
-- element of (..i), i.e. i itself.
instance Subset (RangeToExc q i) (RangeTo q i) given {q:Type} {i:q} [Ix q] inject' = \(UnsafeMkRangeToExc j). unsafe_from_ordinal _ j project' = \j. j' = ordinal j i' = ordinal i if j' >= i' then Nothing else Just $ UnsafeMkRangeToExc j' unsafe_project' = \j. UnsafeMkRangeToExc (ordinal j)

Elementary/Special Functions

This is more or less the standard LibM fare. Roughly, it lines up with common definitions of the sets of elementary and/or special functions. In truth, nothing is elementary or special except that we humans have decided it is. Many, but not all, of these functions are transcendental.

-- Elementary and special functions (the usual libm set).
interface Floating a
  exp : a -> a
  exp2 : a -> a
  log : a -> a
  log2 : a -> a
  log10 : a -> a
  log1p : a -> a
  sin : a -> a
  cos : a -> a
  tan : a -> a
  sinh : a -> a
  cosh : a -> a
  tanh : a -> a
  floor : a -> a
  ceil : a -> a
  round : a -> a
  sqrt : a -> a
  pow : a -> a -> a
  lgamma : a -> a
  erf : a -> a
  erfc : a -> a

-- Log of the Beta function: lgamma x + lgamma y - lgamma (x + y).
def lbeta {a} [Sub a, Floating a] (x:a) (y:a) : a =
  (lgamma x + lgamma y) - lgamma (x + y)
-- Hyperbolic functions for Float32, written with primitives only.
-- Using %exp here to avoid circular definition problems (the Floating
-- instances are defined in terms of these functions).
-- Todo: better numerics for very large and small values in sinh/cosh.
def float32_sinh (x:Float32) : Float32 =
  %fdiv (%fsub (%exp x) (%exp (%fsub 0.0 x))) 2.0

def float32_cosh (x:Float32) : Float32 =
  %fdiv (%fadd (%exp x) (%exp (%fsub 0.0 x))) 2.0

-- tanh |x| = (1 - e) / (1 + e) with e = exp(-2|x|), and the sign of x is
-- then restored. Because e <= 1 the exponential can never overflow. The
-- direct form (e^x - e^-x) / (e^x + e^-x) evaluates to inf / inf = NaN once
-- e^x overflows (|x| above roughly 88.7 in Float32), where tanh is really +/-1.
-- NaN inputs still propagate to a NaN result.
def float32_tanh (x:Float32) : Float32 =
  is_neg = w8_to_b $ %flt x 0.0
  ax = select is_neg (%fsub 0.0 x) x
  e = %exp (%fsub 0.0 (%fadd ax ax))
  t = %fdiv (%fsub 1.0 e) (%fadd 1.0 e)
  select is_neg (%fsub 0.0 t) t
-- Todo: unify this with float32 functions.
-- Hyperbolic functions for Float64, written with primitives only (see the
-- Float32 versions). Float64 literals are written via `f_to_f64`.
def float64_sinh (x:Float64) : Float64 =
  %fdiv (%fsub (%exp x) (%exp (%fsub (f_to_f64 0.0) x))) (f_to_f64 2.0)

def float64_cosh (x:Float64) : Float64 =
  %fdiv (%fadd (%exp x) (%exp (%fsub (f_to_f64 0.0) x))) (f_to_f64 2.0)

-- tanh |x| = (1 - e) / (1 + e) with e = exp(-2|x|), then the sign of x is
-- restored. Because e <= 1 the exponential can never overflow. The direct
-- (e^x - e^-x) / (e^x + e^-x) gives inf / inf = NaN once e^x overflows
-- (|x| above roughly 709.8 in Float64).
def float64_tanh (x:Float64) : Float64 =
  zero64 = f_to_f64 0.0
  one64 = f_to_f64 1.0
  is_neg = w8_to_b $ %flt x zero64
  ax = select is_neg (%fsub zero64 x) x
  e = %exp (%fsub zero64 (%fadd ax ax))
  t = %fdiv (%fsub one64 e) (%fadd one64 e)
  select is_neg (%fsub zero64 t) t
-- Floating for the scalar float types. Each method calls the matching
-- compiler primitive, except the hyperbolic functions, which use the
-- primitive-based float64_* / float32_* definitions from earlier in the file.
instance Floating Float64
  exp = \v. %exp v
  exp2 = \v. %exp2 v
  log = \v. %log v
  log2 = \v. %log2 v
  log10 = \v. %log10 v
  log1p = \v. %log1p v
  sin = \v. %sin v
  cos = \v. %cos v
  tan = \v. %tan v
  sinh = float64_sinh
  cosh = float64_cosh
  tanh = float64_tanh
  floor = \v. %floor v
  ceil = \v. %ceil v
  round = \v. %round v
  sqrt = \v. %sqrt v
  pow = \v w. %fpow v w
  lgamma = \v. %lgamma v
  erf = \v. %erf v
  erfc = \v. %erfc v

instance Floating Float32
  exp = \v. %exp v
  exp2 = \v. %exp2 v
  log = \v. %log v
  log2 = \v. %log2 v
  log10 = \v. %log10 v
  log1p = \v. %log1p v
  sin = \v. %sin v
  cos = \v. %cos v
  tan = \v. %tan v
  sinh = float32_sinh
  cosh = float32_cosh
  tanh = float32_tanh
  floor = \v. %floor v
  ceil = \v. %ceil v
  round = \v. %round v
  sqrt = \v. %sqrt v
  pow = \v w. %fpow v w
  lgamma = \v. %lgamma v
  erf = \v. %erf v
  erfc = \v. %erfc v

Raw pointer operations

-- Typed pointer: a RawPtr tagged with the element type `a`.
data Ptr a = MkPtr RawPtr

-- Reinterpret a pointer as pointing to another element type; the underlying
-- raw pointer is unchanged.
def cast_ptr {a b} (ptr: Ptr a) : Ptr b =
  case ptr of
    MkPtr raw -> MkPtr raw
-- Types that can be written to and read from memory through a `Ptr`.
-- `storage_size a` is the number of bytes one value occupies (1 for Word8,
-- 4 for the 32-bit types, 8 for pointers in the instances below).
interface Storable a store : Ptr a -> a -> {IO} Unit load : Ptr a -> {IO} a storage_size a : Nat
-- Word8 uses the raw byte pointer directly.
instance Storable Word8 store = \(MkPtr ptr) x. %ptrStore ptr x load = \(MkPtr ptr) . %ptrLoad ptr storage_size = 1
-- The other scalars cast the raw pointer to a typed primitive pointer first.
instance Storable Int32 store = \(MkPtr ptr) x. %ptrStore (internal_cast %Int32Ptr ptr) x load = \(MkPtr ptr) . %ptrLoad (internal_cast %Int32Ptr ptr) storage_size = 4
instance Storable Word32 store = \(MkPtr ptr) x. %ptrStore (internal_cast %Word32Ptr ptr) x load = \(MkPtr ptr) . %ptrLoad (internal_cast %Word32Ptr ptr) storage_size = 4
instance Storable Float32 store = \(MkPtr ptr) x. %ptrStore (internal_cast %Float32Ptr ptr) x load = \(MkPtr ptr) . %ptrLoad (internal_cast %Float32Ptr ptr) storage_size = 4
-- Nat is stored as its NatRep.
instance Storable Nat store = \(MkPtr ptr) x. store (MkPtr ptr) $ nat_to_rep x load = \(MkPtr ptr) . rep_to_nat $ load (MkPtr ptr) storage_size = storage_size NatRep
-- Pointers are stored as raw pointers. The hard-coded size of 8 assumes
-- 64-bit pointers (see the TODO).
instance Storable (Ptr a) given {a} store = \(MkPtr ptr) (MkPtr x). %ptrStore (internal_cast %PtrPtr ptr) x load = \(MkPtr ptr) . MkPtr $ %ptrLoad (internal_cast %PtrPtr ptr) storage_size = 8 -- TODO: something more portable?
-- TODO: Storable instances for other types

-- Allocate room for `n` values of type `a` (n * storage_size a bytes).
def malloc {a} [Storable a] (n:Nat) : {IO} (Ptr a) =
  MkPtr $ %alloc (nat_to_rep (storage_size a * n))

-- Release memory obtained from `malloc`.
def free {a} (ptr:Ptr a) : {IO} Unit =
  case ptr of
    MkPtr raw -> %free raw

-- Pointer arithmetic: advance by `i` elements. The offset handed to
-- `%ptrOffset` is i * storage_size a.
def (+>>) {a} [Storable a] (ptr:Ptr a) (i:Nat) : Ptr a =
  (MkPtr raw) = ptr
  MkPtr $ %ptrOffset raw (nat_to_rep (i * storage_size a))

-- TODO: consider making a Storable instance for tables instead
-- Write the entries of `tab` to consecutive elements starting at `ptr`.
def store_table {a n} [Storable a] (ptr: Ptr a) (tab:n=>a) : {IO} Unit =
  for_ k. store (ptr +>> ordinal k) tab.k

-- Copy `n` consecutive elements from `src` to `dest`, one element at a time.
def memcpy {a} [Storable a] (dest:Ptr a) (src:Ptr a) (n:Nat) : {IO} Unit =
  for_ k:(Fin n).
    off = ordinal k
    store (dest +>> off) (load $ src +>> off)

-- TODO: generalize these brackets to allow other effects
-- TODO: make sure that freeing happens even if there are run-time errors
-- Allocate `n` elements, run `action` on the buffer, free it, and return
-- the action's result.
def with_alloc {a b} [Storable a] (n:Nat) (action: Ptr a -> {IO} b) : {IO} b =
  buf = malloc n
  out = action buf
  free buf
  out

-- Copy `xs` into a freshly allocated buffer and run `action` on it.
def with_table_ptr {a b n} [Storable a] (xs:n=>a) (action : Ptr a -> {IO} b) : {IO} b =
  with_alloc (size n) \ptr.
    for k. store (ptr +>> ordinal k) xs.k
    action ptr

-- Read `size n` consecutive elements starting at `ptr` into a table.
def table_from_ptr {a} [Storable a] (n:Type) [Ix n] (ptr:Ptr a) : {IO} n=>a =
  for k. load (ptr +>> ordinal k)

Miscellaneous common utilities

pi : Float = 3.141592653589793

def id {a} (x:a) : a = x
def dup {a} (x:a) : (a & a) = (x, x)

-- Apply `f` to every entry; `f` may perform the effects `eff`.
def map {a b n eff} (f:a->{|eff} b) (xs: n=>a) : {|eff} (n=>b) = for k. f xs.k

-- Pair up two tables over the same index set (built with `view`).
def zip {a b n} (xs:n=>a) (ys:n=>b) : (n=>(a&b)) = view k. (xs.k, ys.k)

-- Split a table of pairs into a pair of tables.
def unzip {a b n} (xys:n=>(a&b)) : (n=>a & n=>b) =
  firsts = map fst xys
  seconds = map snd xys
  (firsts, seconds)

-- A table over `n` whose every entry is `x`.
def fanout {a} (n:Type) [Ix n] (x:a) : n=>a = view i. x

def sq {a} [Mul a] (x:a) : a = mul x x

-- `x` when x > zero, otherwise `zero - x`.
def abs {a} [Sub a, Ord a] (x:a) : a = select (x > zero) x (zero - x)

-- Computes rem (rem x y + y) y.
-- NOTE(review): for y > 0 and a truncating `rem` this lands in [0, y);
-- `rem x y + y` may overflow for very large y — confirm for intended types.
def mod {a} [Add a, Integral a] (x:a) (y:a) : a = rem (rem x y + y) y

Table Operations

-- Tables apply every Floating function elementwise; `pow` pairs up entries.
instance Floating (n=>a) given {a n} [Floating a]
  exp = map exp
  exp2 = map exp2
  log = map log
  log2 = map log2
  log10 = map log10
  log1p = map log1p
  sin = map sin
  cos = map cos
  tan = map tan
  sinh = map sinh
  cosh = map cosh
  tanh = map tanh
  floor = map floor
  ceil = map ceil
  round = map round
  sqrt = map sqrt
  pow = \xs ys. for k. pow xs.k ys.k
  lgamma = map lgamma
  erf = map erf
  erfc = map erfc

Axis Restructuring

-- Move the second axis to the front.
def axis1 {a b c} (x : a => b => c) : b => a => c = for j i. x.i.j
-- Move the third axis to the front.
def axis2 {a b c d} (x : a => b => c => d) : c => a => b => d = for k i j. x.i.j.k
-- Re-index a table through `ixr`: entry k of the result is tab.(ixr k).
def reindex {a b v} [Ix b] (ixr: b -> a) (tab: a=>v) : b=>v = for k. tab.(ixr k)

Reductions

-- `combine` should be commutative and associative, and form a
-- commutative monoid with `identity`.
def reduce {a n} (identity:a) (combine:(a->a->a)) (xs:n=>a) : a =
  -- TODO: implement with the accumulator effect
  fold identity \i acc. combine acc xs.i

-- TODO: call this `scan` and call the current `scan` something else
-- Inclusive scan: entry i is the carry after applying `body` to entry i.
def scan' {a n} [Ix n] (init:a) (body:n->a->a) : n=>a =
  snd $ scan init \i c.
    c' = body i c
    (c', c')

-- TODO: allow tables-via-lambda and get rid of this
-- Float sum via the accumulator effect.
def fsum {n} (xs:n=>Float) : Float =
  yield_accum (AddMonoid Float) \acc. for i. acc += xs.i

def sum {n v} [Add v] (xs:n=>v) : v = reduce zero add xs
def prod {n v} [Mul v] (xs:n=>v) : v = reduce one mul xs

-- Arithmetic mean: the sum divided by the number of entries.
def mean {n v} [VSpace v] (xs:n=>v) : v = sum xs / n_to_f (size n)
-- Population standard deviation: sqrt of the mean squared deviation from the
-- mean. Deviations are taken from the mean first (two passes). The one-pass
-- form mean (x^2) - (mean x)^2 suffers catastrophic cancellation and can come
-- out slightly negative for near-constant data, turning the sqrt into NaN.
def std {n v} [Mul v, Sub v, VSpace v, Floating v] (xs:n=>v) : v =
  mu = mean xs
  sqrt $ mean for i. sq (xs.i - mu)
-- True when at least one / every entry is True.
def any {n} (xs:n=>Bool) : Bool = reduce False (\p q. p || q) xs
def all {n} (xs:n=>Bool) : Bool = reduce True (\p q. p && q) xs

apply_n

-- Apply `f` to `x` exactly `n` times, one application after another.
def apply_n {a} (n:Nat) (x:a) (f:a -> a) : a =
  yield_state x \st.
    for _:(Fin n).
      st := f (get st)

Linear Algebra

-- `size n` points starting at `low`, spaced (high - low) / size n apart.
-- The interval is half-open: `high` itself is not included.
def linspace (n:Type) [Ix n] (low:Float) (high:Float) : n=>Float =
  step = (high - low) / n_to_f (size n)
  for i:n.
    k = n_to_f (ordinal i)
    low + k * step

-- Swap the two axes (lazily, as a view).
def transpose {n m a} (x:n=>m=>a) : m=>n=>a =
  view col row. x.row.col

-- Dot product of two Float vectors.
def vdot {n} (x:n=>Float) (y:n=>Float) : Float =
  products = view i. x.i * y.i
  fsum products

-- Linear combination of the vectors `vs` with weights `s`.
def dot {n v} [VSpace v] (s:n=>Float) (vs:n=>v) : v =
  scaled = for j. s.j .* vs.j
  sum scaled

-- matmul. Better symbol to use? `@`?
-- TODO: Improve auto-quantification to hoist the Ix n constraint to the type binder
def (**) {l m n} [Ix n] (x: l=>m=>Float) (y: m=>n=>Float) : (l=>n=>Float) =
  for row col. vdot x.row (view mid. y.mid.col)

-- Matrix-vector product: entry i is the dot product of row i with `v`.
def (**.) {n m} (mat: n=>m=>Float) (v: m=>Float) : (n=>Float) =
  for row. vdot mat.row v

-- `mat **. v` with the arguments flipped. Note that this still computes
-- mat . v (with `mat` indexed n=>m), not v^T . mat.
def (.**) {n m} (v: m=>Float) (mat: n=>m=>Float) : (n=>Float) =
  (**.) mat v

-- Bilinear form x^T . mat . y
def inner {n m} (x:n=>Float) (mat:n=>m=>Float) (y:m=>Float) : Float =
  fsum view (r, c). x.r * mat.r.c * y.c

-- Identity matrix: `one` on the diagonal, `zero` elsewhere.
def eye {n a} [Add a, Mul a, Ix n] : n=>n=>a =
  for r c. if ordinal r == ordinal c then one else zero

cumulative sum

TODO: Move this to be with reductions? It's a kind of scan.

-- Inclusive prefix sums: entry i is xs.0 + ... + xs.i.
def cumsum {n a} [Add a] (xs: n=>a) : n=>a =
  with_state zero \running.
    for i.
      updated = get running + xs.i
      running := updated
      updated

-- Exclusive prefix sums: entry i is xs.0 + ... + xs.(i-1), so entry 0 is zero.
def cumsum_low {n a} [Add a] (xs: n=>a) : n=>a =
  with_state zero \running.
    for i.
      before = get running
      running := before + xs.i
      before

Automatic differentiation

AD operations

-- TODO: add vector space constraints
-- The value of `f` at `x`, paired with the linear map tangent to `f` at `x`.
def linearize {a b} (f:a->b) (x:a) : (b & a --o b) =
  %linearize (\x0. f x0) x

-- Forward-mode derivative: the linear map from input to output tangents at `x`.
def jvp {a b} (f:a->b) (x:a) : a --o b =
  (fx, df) = linearize f x
  df

-- Transpose of a linear map. `f` is eta-expanded into an explicitly typed
-- linear lambda before it is handed to the primitive.
def transpose_linear {a b} (f:a --o b) : b --o a =
  f_lin : a --o b = \v. f v
  %linearTranspose f_lin

-- Reverse mode: the value of `f` at `x` with its pullback of output cotangents.
def vjp {a b} (f:a->b) (x:a) : (b & b --o a) =
  (fx, df) = linearize f x
  (fx, transpose_linear df)

-- Gradient of a scalar function: the pullback applied to 1.0.
def grad {a} (f:a->Float) (x:a) : a =
  (fx, pullback) = vjp f x
  pullback 1.0

-- Derivative of a scalar function, computed in forward mode.
def deriv (f:Float->Float) (x:Float) : Float =
  tangent_map = jvp f x
  tangent_map 1.0

-- Derivative of a scalar function, computed in reverse mode.
def deriv_rev (f:Float->Float) (x:Float) : Float =
  (fx, pullback) = vjp f x
  pullback 1.0
-- XXX: Watch out when editing this data type! We depend on its structure
-- deep inside the compiler (mostly in linearization and during rule registration).
-- A tangent that is either known to be zero (`ZeroTangent`) or carries an
-- explicit value (`SomeTangent`); see `someTangent` for materializing it.
data SymbolicTangent a =
  ZeroTangent
  SomeTangent a
-- Turn a symbolic tangent into a concrete one; `ZeroTangent` becomes `zero`.
def someTangent {a} [VSpace a] (x:SymbolicTangent a) : a =
  case x of
    SomeTangent t -> t
    ZeroTangent -> zero

Approximate Equality

TODO: move this outside the AD section to be with equality?

-- Approximate equality with explicit tolerances. The instances in this
-- file take the arguments in the order atol, rtol, x, y.
interface HasAllClose a
  allclose : a -> a -> a -> a -> Bool

-- Default absolute (`atol`) and relative (`rtol`) tolerances for a type.
interface HasDefaultTolerance a
  atol : a
  rtol : a

-- Approximate equality using the type's default tolerances.
def (~~) {a} [HasAllClose a, HasDefaultTolerance a] : a -> a -> Bool = allclose atol rtol
-- Scalars are close when |x - y| <= atol + rtol * |y|.
-- Note the test is asymmetric: `y` is the reference value.
instance HasAllClose Float32
  allclose = \atol rtol x y.
    diff = abs (x - y)
    diff <= (atol + rtol * abs y)

instance HasAllClose Float64
  allclose = \atol rtol x y.
    diff = abs (x - y)
    diff <= (atol + rtol * abs y)

instance HasDefaultTolerance Float32
  atol = f_to_f32 0.00001
  rtol = f_to_f32 0.0001

instance HasDefaultTolerance Float64
  atol = f_to_f64 0.00000001
  rtol = f_to_f64 0.00001
-- Pairs are close when each component is close under the corresponding
-- component of the tolerances passed in. The supplied `atol`/`rtol` pairs
-- are used as given, not the components' default tolerances, so callers of
-- `allclose` can actually control the comparison. (`~~` on pairs still
-- uses the defaults, via the HasDefaultTolerance (a & b) instance.)
instance HasAllClose (a & b) given {a b} [HasAllClose a, HasAllClose b]
  allclose = \(atol1, atol2) (rtol1, rtol2) (x1, x2) (y1, y2).
    allclose atol1 rtol1 x1 y1 && allclose atol2 rtol2 x2 y2
-- A pair's default tolerance is the pair of its components' defaults.
instance HasDefaultTolerance (a & b) given {a b} [HasDefaultTolerance a, HasDefaultTolerance b]
  atol = (atol, atol)
  rtol = (rtol, rtol)

-- Tables are close when every entry is close under the matching entries
-- of the (table-valued) tolerances.
instance HasAllClose (n=>t) given {n t} [HasAllClose t]
  allclose = \atol rtol xs ys.
    entrywise = for i:n. allclose atol.i rtol.i xs.i ys.i
    all entrywise

-- A table's default tolerance repeats the entry type's default at every index.
instance HasDefaultTolerance (n=>t) given {n t} [HasDefaultTolerance t]
  atol = for ix. atol
  rtol = for ix. rtol

AD Checking tools

-- Check that forward-mode and reverse-mode derivatives of `f` at `x` both
-- agree (under `~~`) with a central finite difference of step 0.01.
def check_deriv_base (f:Float->Float) (x:Float) : Bool =
  eps = 0.01
  numeric = (f (x + eps) - f (x - eps)) / (2. * eps)
  forward = deriv f x
  backward = deriv_rev f x
  forward ~~ numeric && backward ~~ numeric

-- As `check_deriv_base`, additionally checking the derivative of `deriv f`.
def check_deriv (f:Float->Float) (x:Float) : Bool =
  first_ok = check_deriv_base f x
  second_ok = check_deriv_base (deriv f) x
  first_ok && second_ok

Length-erased lists

-- A list whose length `n` is a runtime value rather than part of its type:
-- the length paired with a table of exactly that many elements.
data List a =
  AsList n:Nat elements:(Fin n => a)
-- Lists are equal when they have the same length and equal elements.
instance Eq (List a) given {a} [Eq a]
  (==) = \(AsList nx xs) (AsList ny ys).
    case nx == ny of
      False -> False
      True -> all for i:(Fin nx). xs.i == ys.(unsafe_from_ordinal _ (ordinal i))

-- Reinterpret `xs` as a table over `m`, matching indices by ordinal.
-- Unchecked: the caller must ensure size m <= size n.
def unsafe_cast_table {n a} (m:Type) [Ix m] (xs:n=>a) : m=>a =
  for j. xs.(unsafe_from_ordinal _ (ordinal j))

-- Forget the index set of a table, keeping its elements in ordinal order.
def to_list {n a} (xs:n=>a) : List a =
  len = size n
  AsList len $ unsafe_cast_table (Fin len) xs
-- Lists form a monoid under concatenation, with the empty list as identity.
instance Monoid (List a) given {a}
  mempty = AsList _ []
  mcombine = \x y.
    (AsList nx xs) = x
    (AsList ny ys) = y
    total = nx + ny
    AsList _ $ for i:(Fin total).
      pos = ordinal i
      if pos < nx
        then xs.(unsafe_from_ordinal _ pos)
        else ys.(unsafe_from_ordinal _ $ unsafe_nat_diff pos nx)

named-instance ListMonoid : Monoid (List a) given (a:Type)
  mempty = mempty
  mcombine = mcombine

-- TODO Eliminate or reimplement this operation, since it costs O(n)
-- where n is the length of the list held in the reference.
-- Accumulate the single-element list [x] into `list`.
def append {a h} [AccumMonoid h (List a)] (list: Ref h (List a)) (x:a) : {Accum h} Unit =
  singleton = AsList _ [x]
  list += singleton

Isomorphisms

-- A pair of functions `fwd : a -> b` and `bwd : b -> a`, intended (not
-- checked) to be inverse to each other.
data Iso a b = MkIso { fwd: a -> b & bwd: b -> a }
-- Run an isomorphism forwards.
def app_iso {a b} (iso: Iso a b) (x:a) : b =
  (MkIso {fwd=forward, bwd=backward}) = iso
  forward x

-- Swap the directions of an isomorphism.
def flip_iso {a b} (iso: Iso a b) : Iso b a =
  (MkIso {fwd=forward, bwd=backward}) = iso
  MkIso {fwd=backward, bwd=forward}

-- Run an isomorphism backwards.
def rev_iso {a b} (iso: Iso a b) (x:b) : a =
  flipped = flip_iso iso
  app_iso flipped x

-- The identity isomorphism.
def id_iso {a} : Iso a a =
  MkIso {fwd=id, bwd=id}

-- Compose left to right: first `iso1`, then `iso2`.
def (&>>) {a b c} (iso1: Iso a b) (iso2: Iso b c) : Iso a c =
  (MkIso {fwd=fwd1, bwd=bwd1}) = iso1
  (MkIso {fwd=fwd2, bwd=bwd2}) = iso2
  MkIso {fwd=(\x. fwd2 (fwd1 x)), bwd=(\z. bwd1 (bwd2 z))}

-- Compose right to left: `iso2 <<& iso1` is `iso1 &>> iso2`.
def (<<&) {a b c} (iso2: Iso b c) (iso1: Iso a b) : Iso a c =
  (&>>) iso1 iso2

Lens-like accessors

Note: `#foo` is an `Iso {foo: a & ...b} (a & {&...b})`.

-- Read the focused part of `a`.
def get_at {a b c} (iso: Iso a (b & c)) : a -> b =
  \whole. fst (app_iso iso whole)

-- Read everything except the focused part.
def pop_at {a b c} (iso: Iso a (b & c)) : a -> c =
  \whole. snd (app_iso iso whole)

-- Rebuild an `a` from a focus `x` and the remainder `r`.
def push_at {a b c} (iso: Iso a (b & c)) (x:b) (r:c) : a =
  parts = (x, r)
  rev_iso iso parts

-- Replace the focused part of `r` with `x`.
def set_at {a b c} (iso: Iso a (b & c)) (x:b) (r:a) : a =
  rest = pop_at iso r
  push_at iso x rest

Prism-like accessors

Note: `#?foo` is an `Iso {foo: a | ...b} (a | {|...b})`.

-- `Just` the focused case if `x` is in it, otherwise `Nothing`.
def match_with {a b c} (iso: Iso a (b | c)) (x: a) : Maybe b =
  case app_iso iso x of
    Left hit -> Just hit
    Right _ -> Nothing

-- Inject a value of the focused case back into `a`.
def build_with {a b c} (iso: Iso a (b | c)) (x: b) : a =
  rev_iso iso (Left x)

-- Swap the components of a pair (its own inverse).
def swap_pair_iso {a b} : Iso (a & b) (b & a) =
  MkIso {fwd = \(p, q). (q, p), bwd = \(q, p). (p, q)}

Complement lens

Complement the focus of a lens-like isomorphism

-- Same decomposition as `iso`, but with the focus and remainder swapped.
def except_lens {a b c} (iso: Iso a (b & c)) : Iso a (c & b) =
  (&>>) iso swap_pair_iso

-- Exchange the Left and Right cases of a sum.
def swap_either_iso {a b} : Iso (a | b) (b | a) =
  to_swapped = \v. case v of
    Left l -> Right l
    Right r -> Left r
  from_swapped = \v. case v of
    Left r -> Right r
    Right l -> Left l
  MkIso {fwd=to_swapped, bwd=from_swapped}

Complement prism

Complement the focus of a prism-like isomorphism

-- Same decomposition as the given prism, with the focused and other cases swapped.
def except_prism {a b c} : Iso a (b | c) -> Iso a (c | b) =
  \prism. prism &>> swap_either_iso

-- Use a lens-like iso to split a 1d table into a 2d table
def over_lens {a b c v} [Ix b, Ix c] (iso: Iso a (b & c)) (tab: a=>v) : (b=>c=>v) =
  for outer inner. tab.(rev_iso iso (outer, inner))

Zipper

Zipper isomorphisms to easily specify many record/variant fields:

#&foo is an Iso ({&...l} & {foo:a & ...r}) ({foo:a & ...l} & {&...r})
#|foo is an Iso ({|...l} | {foo:a | ...r}) ({foo:a | ...l} | {|...r})

Convert a record zipper isomorphism to a normal lens-like isomorphism by composing it after `split_r`, i.e. `split_r &>> iso`.

-- Pair a value with the empty record.
def split_r {a} : Iso a ({&} & a) =
  MkIso {fwd = \v. ({}, v), bwd = \({}, v). v}

-- Split a table indexed by records into a 2d table, using a record zipper iso.
def over_fields {a b c v} [Ix b, Ix c] (iso: Iso ({&} & a) (b & c)) (tab: a=>v) : b=>c=>v =
  lens = split_r &>> iso
  over_lens lens tab

Convert a variant zipper isomorphism to a normal prism-like isomorphism by composing it after `split_v`, i.e. `split_v &>> iso`.

-- Embed a value as the Right case next to the empty variant.
def split_v {a} : Iso a ({ | } | a) =
  MkIso {fwd = \val. Right val, bwd = \s. case s of Right val -> val}

-- Restrict a table indexed by variants to the cases selected by a variant zipper iso.
def slice_fields {a b c v} [Ix b] (iso: Iso ({ | } | a) (b | c)) (tab: a=>v) : b=>v =
  prism = split_v &>> iso
  reindex (build_with prism) tab

-- TODO: replace `slice` with this?
-- Elements of `xs` between the posts `start` and `end`, as a list.
def post_slice {a n} (xs:n=>a) (start:Post n) (end:Post n) : List a =
  offset = ordinal start
  count = unsafe_nat_diff (ordinal end) offset
  to_list for i:(Fin count). xs.(unsafe_from_ordinal n (ordinal i + offset))

Dynamic buffer

-- TODO: would be nice to be able to use records here
-- Growable buffer state, kept behind pointers so it can be updated in IO:
--   size    -- number of elements currently stored
--   maxSize -- capacity of the current allocation, in elements
--   buffer  -- pointer to the current element storage
data DynBuffer a = MkDynBuffer { size : Ptr Nat & maxSize : Ptr Nat & buffer : Ptr (Ptr a) }
-- Create an empty growable buffer (initial capacity 256 elements), run
-- `action` with it, free the element storage, and return `action`'s result.
def with_dynamic_buffer {a b} [Storable a] (action: DynBuffer a -> {IO} b) : {IO} b =
  initialCapacity = 256
  with_alloc 1 \lenPtr.
    with_alloc 1 \capPtr.
      with_alloc 1 \storagePtr.
        store lenPtr 0
        store capPtr initialCapacity
        store storagePtr $ malloc initialCapacity
        result = action $ MkDynBuffer { size = lenPtr, maxSize = capPtr, buffer = storagePtr }
        -- Free whatever storage is current: growing may have replaced it.
        free $ load storagePtr
        result
-- Ensure the buffer can hold `sizeDelta` more elements. If not, reallocate
-- the storage at a larger capacity, copy the current elements over, and
-- free the old storage.
def maybe_increase_buffer_size {a} [Storable a] ((MkDynBuffer db): DynBuffer a) (sizeDelta:Nat) : {IO} Unit =
  size = load $ get_at #size db
  maxSize = load $ get_at #maxSize db
  bufPtr = load $ get_at #buffer db
  newSize = sizeDelta + size
  if newSize > maxSize then
    -- Double the capacity until it reaches newSize, in exact integer
    -- arithmetic. Computing 2 ** ceil (log2 newSize) in Float32 instead
    -- rounds once newSize exceeds 2^24. It can then produce a capacity
    -- below newSize, so later writes would overrun the allocation.
    -- Starting from a power of two (with_dynamic_buffer uses 256), this
    -- yields the smallest power of two >= newSize. Assumes maxSize > 0 and
    -- that the result fits in NatRep; 32 doublings cover that range.
    newMaxSize = yield_state maxSize \cap.
      for _:(Fin 32).
        cur = get cap
        cap := select (cur < newSize) (2 * cur) cur
    newBufPtr = malloc newMaxSize
    memcpy newBufPtr bufPtr size
    free bufPtr
    store (get_at #maxSize db) newMaxSize
    store (get_at #buffer db) newBufPtr
-- Increment the Nat stored at `ptr` by `n`.
def add_at_nat_ptr (ptr: Ptr Nat) (n:Nat) : {IO} Unit =
  old = load ptr
  store ptr (old + n)

-- Append every element of `new` to the buffer, growing it first if needed.
def extend_dynamic_buffer {a} [Storable a] (buf: DynBuffer a) (new:List a) : {IO} Unit =
  (AsList count elems) = new
  maybe_increase_buffer_size buf count
  (MkDynBuffer fields) = buf
  -- Load the storage pointer only after growing, since growing may replace it.
  dest = load $ get_at #buffer fields
  used = load $ get_at #size fields
  store_table (dest +>> used) elems
  add_at_nat_ptr (get_at #size fields) count

-- Read the buffer's current contents as a List.
def load_dynamic_buffer {a} [Storable a] (buf: DynBuffer a) : {IO} (List a) =
  (MkDynBuffer fields) = buf
  dest = load $ get_at #buffer fields
  used = load $ get_at #size fields
  AsList used $ table_from_ptr _ dest

-- Append a single element to the buffer.
def push_dynamic_buffer {a} [Storable a] (buf: DynBuffer a) (x:a) : {IO} Unit =
  singleton = AsList _ [x]
  extend_dynamic_buffer buf singleton

Strings and Characters

-- Strings are length-erased lists of bytes.
String : Type = List Char

-- Build a String from the `n` bytes starting at `ptr` (via `table_from_ptr`).
def string_from_char_ptr (n:Word32) (ptr:Ptr Char) : {IO} String =
  len = rep_to_nat n
  AsList len $ table_from_ptr _ ptr

-- TODO. This is ASCII code point. It really should be Int32 for Unicode codepoint
-- The numeric value of a byte.
def codepoint (c:Char) : Int =
  w8_to_i c
-- A raw pointer to NUL-terminated character data, as passed to C functions
-- (see `with_c_string`, which appends the terminator).
data CString = MkCString RawPtr
-- TODO: check the string contains no nulls
-- Append a NUL terminator to `s` and run `action` on a pointer to the result.
def with_c_string {a} (s:String) (action: CString -> {IO} a) : {IO} a =
  terminated = s <> "\NUL"
  (AsList _ chars) = terminated
  with_table_ptr chars \(MkPtr raw). action (MkCString raw)

Show interface

For things that can be shown. `show` gives a string representation of its input. No particular promises are made about exactly what that representation will contain. In particular, it is not promised to be parseable, nor does it promise any particular level of precision for numeric values.

-- Types whose values can be rendered as a String.
interface Show a
  show : a -> String

-- A String is shown as itself (no quoting).
instance Show String
  show = \s. s

-- The runtime helpers below return the rendered characters as a
-- (length, pointer) pair.
foreign "showInt32" showInt32 : Int32 -> {IO} (Word32 & RawPtr)
instance Show Int32
  show = \x. unsafe_io do
    (len, chars) = showInt32 x
    string_from_char_ptr len (MkPtr chars)

foreign "showInt64" showInt64 : Int64 -> {IO} (Word32 & RawPtr)
instance Show Int64
  show = \x. unsafe_io do
    (len, chars) = showInt64 x
    string_from_char_ptr len (MkPtr chars)

-- Shown through Int64, which covers the whole NatRep (Word32) range.
instance Show Nat
  show = \x. show (n_to_i64 x)

foreign "showFloat32" showFloat32 : Float32 -> {IO} (Word32 & RawPtr)
instance Show Float32
  show = \x. unsafe_io do
    (len, chars) = showFloat32 x
    string_from_char_ptr len (MkPtr chars)

foreign "showFloat64" showFloat64 : Float64 -> {IO} (Word32 & RawPtr)
instance Show Float64
  show = \x. unsafe_io do
    (len, chars) = showFloat64 x
    string_from_char_ptr len (MkPtr chars)

-- Pairs are shown as "(first, second)".
instance Show (a & b) given {a b} [Show a, Show b]
  show = \(first, second).
    "(" <> show first <> ", " <> show second <> ")"

Parse interface

For types that can be parsed from a String.

-- Types that can be parsed from a String (Nothing on failure).
interface Parse a
  parseString : String -> Maybe a

foreign "strtof" strtofFFI : RawPtr -> RawPtr -> {IO} Float

-- Parse a Float with C's `strtof`. It succeeds only when strtof consumes
-- the whole string. The empty string is rejected explicitly: strtof
-- consumes nothing there and returns 0.0, so the length check alone would
-- accept it as `Just 0.0`.
instance Parse Float
  parseString = \str. unsafe_io do
    (AsList str_len _) = str
    with_c_string str \(MkCString str_ptr).
      with_alloc 1 \end_ptr:(Ptr (Ptr Char)).
        (MkPtr raw_end_ptr) = end_ptr
        result = strtofFFI str_ptr raw_end_ptr
        (MkPtr str_end_ptr) = load end_ptr
        -- Number of bytes strtof consumed.
        consumed = raw_ptr_to_i64 str_end_ptr - raw_ptr_to_i64 str_ptr
        if (str_len > 0) && (consumed == n_to_i64 str_len)
          then Just result
          else Nothing

pipe-like reverse function application

TODO: move this

-- Reverse application, for left-to-right pipelines: `x |> f` is `f x`.
def (|>) {a b} (x:a) (f: a -> b) : b =
  f x

Floating-point helper functions

TODO: Move these to be with Elementary/Special functions. Or move those to be here.

-- 1.0 for positive `x`, -1.0 for negative `x`, and `x` itself otherwise,
-- so zero and NaN are returned unchanged.
def sign (x:Float) : Float =
  select (x > 0.0) 1.0 (select (x < 0.0) (-1.0) x)

-- NOTE: not IEEE copysign. Returns `a` when b > 0, `-a` when b < 0, and
-- 0.0 when b is zero or NaN. The magnitude of `a` is not taken, so a
-- negative `a` with a positive `b` stays negative.
def copysign (a:Float) (b:Float) : Float =
  select (b > 0.0) a (select (b < 0.0) (-a) 0.0)
-- Todo: use IEEE floating-point builtins.
-- Special values obtained by dividing by zero.
infinity = 1.0 / 0.0
nan = 0.0 / 0.0

-- Todo: use IEEE floating-point builtins.
-- True for +infinity and -infinity.
def isinf (x:Float) : Bool =
  abs x == infinity

-- NaN is the only value that is not ordered against itself.
def isnan (x:Float) : Bool =
  not (x >= x) || not (x <= x)

-- Todo: use IEEE-754R 5.11: Floating Point Comparison Relation cmpUnordered.
def either_is_nan (x:Float) (y:Float) : Bool =
  any [isnan x, isnan y]

File system operations

FilePath : Type = String

-- True when `ptr` is the null pointer (address 0).
def is_null_raw_ptr (ptr:RawPtr) : Bool =
  address = raw_ptr_to_i64 ptr
  address == 0

-- Nothing for a null pointer, otherwise the pointer wrapped as a typed Ptr.
def from_nullable_raw_ptr {a} (ptr:RawPtr) : Maybe (Ptr a) =
  case is_null_raw_ptr ptr of
    True -> Nothing
    False -> Just (MkPtr ptr)

-- The character pointer inside a CString, or Nothing if it is null.
def c_string_ptr (s:CString) : Maybe (Ptr Char) =
  (MkCString raw) = s
  from_nullable_raw_ptr raw
-- Whether a stream is opened for reading ("r") or writing ("w"); see `fopen`.
data StreamMode =
  ReadMode
  WriteMode

-- The raw handle returned by C `fopen`, tagged at the type level with its mode.
data Stream mode:StreamMode = MkStream RawPtr

Stream IO

foreign "fopen" fopenFFI : RawPtr -> RawPtr -> {IO} RawPtr
foreign "fclose" fcloseFFI : RawPtr -> {IO} Int64
foreign "fwrite" fwriteFFI : RawPtr -> Int64 -> Int64 -> RawPtr -> {IO} Int64
foreign "fread" freadFFI : RawPtr -> Int64 -> Int64 -> RawPtr -> {IO} Int64
foreign "fflush" fflushFFI : RawPtr -> {IO} Int64

-- Open `path` with C fopen in "r" or "w" mode. The result is not checked
-- for null here; see `with_file`.
def fopen (path:String) (mode:StreamMode) : {IO} (Stream mode) =
  cMode = case mode of
    ReadMode -> "r"
    WriteMode -> "w"
  with_c_string path \(MkCString cPath).
    with_c_string cMode \(MkCString cModePtr).
      handle = fopenFFI cPath cModePtr
      MkStream handle

-- Close a stream, ignoring fclose's return code.
def fclose {mode} (stream:Stream mode) : {IO} Unit =
  (MkStream handle) = stream
  fcloseFFI handle
  ()

-- Write the bytes of `s` to the stream and flush it immediately.
-- Return codes from fwrite and fflush are ignored.
def fwrite (stream:Stream WriteMode) (s:String) : {IO} Unit =
  (MkStream handle) = stream
  (AsList len chars) = s
  with_table_ptr chars \(MkPtr bytesPtr).
    fwriteFFI bytesPtr (i_to_i64 1) (n_to_i64 len) handle
  fflushFFI handle
  ()

Iteration

TODO: move this out of the file-system section

-- Call `body` repeatedly until it returns False. The body always runs at
-- least once, since it computes the loop condition itself.
def while {eff} (body: Unit -> {|eff} Bool) : {|eff} Unit =
  -- The primitive expects a Word8 flag, so convert the Bool result.
  body' : Unit -> {|eff} Word8 = \u. b_to_w8 (body u)
  %while body'
-- The result of one `iter` step: keep looping, or stop with a final value.
data IterResult a =
  Continue
  Done a
-- TODO: can we improve effect inference so we don't need this?
-- Just applies `f`; the signature only widens its effect row by `State h`.
def lift_state {a b c h eff} (ref: Ref h c) (f:a -> {|eff} b) (x:a) : {State h|eff} b =
  f x

-- A little iteration combinator
-- Call `body 0`, `body 1`, ... until one returns `Done x`, then return `x`.
def iter {a eff} (body: Nat -> {|eff} IterResult a) : {|eff} a =
  outcome = yield_state Nothing \found.
    step <- with_state 0
    while do
      keep_going = is_nothing $ get found
      if keep_going then
        case lift_state found (lift_state step body) (get step) of
          Continue -> step := get step + 1
          Done ans -> found := Just ans
      keep_going
  case outcome of
    Just ans -> ans
    Nothing -> unreachable ()

-- Like `iter`, but gives up with `fallback` once `maxIters` steps have run.
def bounded_iter {a eff} (maxIters:Nat) (fallback:a) (body: Nat -> {|eff} IterResult a) : {|eff} a =
  iter \step. case step >= maxIters of
    True -> Done fallback
    False -> body step

Environment Variables

-- Copy a NUL-terminated C string into a String; Nothing for a null pointer.
def from_c_string (s:CString) : {IO} (Maybe String) =
  case c_string_ptr s of
    Nothing -> Nothing
    Just ptr -> Just $ with_dynamic_buffer \buf.
      iter \i.
        c = load $ ptr +>> i
        if c /= '\NUL'
          then
            push_dynamic_buffer buf c
            Continue
          else Done $ load_dynamic_buffer buf

foreign "getenv" getenvFFI : RawPtr -> {IO} RawPtr

-- Value of environment variable `name`, or Nothing if it is unset.
def get_env (name:String) : {IO} Maybe String =
  with_c_string name \(MkCString cName).
    raw = getenvFFI cName
    from_c_string (MkCString raw)

-- True when environment variable `name` is set.
def check_env (name:String) : {IO} Bool =
  case get_env name of
    Just _ -> True
    Nothing -> False

More Stream IO

-- Read the rest of the stream into a String, 4096 bytes at a time, stopping
-- after the first short read.
def fread (stream:Stream ReadMode) : {IO} String =
  (MkStream handle) = stream
  -- TODO: allow reading longer files!
  chunk = 4096
  scratch:(Ptr Char) <- with_alloc chunk
  buf <- with_dynamic_buffer
  iter \_.
    (MkPtr rawScratch) = scratch
    got = i_to_w32 $ i64_to_i $ freadFFI rawScratch (i_to_i64 1) (n_to_i64 chunk) handle
    extend_dynamic_buffer buf $ string_from_char_ptr got scratch
    case got == n_to_w32 chunk of
      True -> Continue
      False -> Done ()
  load_dynamic_buffer buf

Print

-- The runtime's output stream.
def get_output_stream (_:Unit) : {IO} Stream WriteMode =
  MkStream $ %outputStream

-- Write `s` followed by a newline to the output stream.
def print (s:String) : {IO} Unit =
  out = get_output_stream ()
  fwrite out s
  fwrite out "\n"

Shelling Out

foreign "popen" popenFFI : RawPtr -> RawPtr -> {IO} RawPtr
foreign "pclose" pcloseFFI : RawPtr -> {IO} Int32
foreign "remove" removeFFI : RawPtr -> {IO} Int64
foreign "mkstemp" mkstempFFI : RawPtr -> {IO} Int32
foreign "close" closeFFI : Int32 -> {IO} Int32

-- Run `command` through the shell with popen and return everything it
-- writes to stdout. The pipe is closed with pclose afterwards, which also
-- waits for the child process. Without it, the FILE* and the child
-- process would leak on every call.
-- NOTE(review): a null result from popen is not checked; reading from it
-- would crash.
def shell_out (command:String) : {IO} String =
  modeStr = "r"
  with_c_string command \(MkCString commandPtr).
    with_c_string modeStr \(MkCString modePtr).
      rawPipe = popenFFI commandPtr modePtr
      output = fread $ MkStream rawPipe
      pcloseFFI rawPipe
      output

Partial functions

A partial function, in this context, is a function that can error, i.e. a function that is not actually defined on all of its supposed domain. It is not to be confused with a partially applied function.

Error throwing

-- Print `s`, then raise a runtime error through the %throwError primitive.
-- Usable at any result type because it does not return normally.
def error {a} (s:String) : a =
  unsafe_io do
    print s
    %throwError a

-- Placeholder for unimplemented code; fails at runtime.
def todo {a} : a =
  error "TODO: implement it!"

File Operations

-- Remove the file at `f`, ignoring the return code of C `remove`.
def delete_file (f:FilePath) : {IO} Unit =
  with_c_string f \(MkCString cPath).
    removeFFI cPath
    ()

-- Open `f`, run `action` on the stream, close it, and return the result.
-- Raises an error if the file cannot be opened.
def with_file {a} (f:FilePath) (mode:StreamMode) (action: Stream mode -> {IO} a) : {IO} a =
  stream = fopen f mode
  (MkStream handle) = stream
  case is_null_raw_ptr handle of
    True -> error $ "Unable to open file: " <> f
    False ->
      result = action stream
      fclose stream
      result

-- Overwrite the file at `f` with `s`.
def write_file (f:FilePath) (s:String) : {IO} Unit =
  with_file f WriteMode \out. fwrite out s

-- The whole contents of the file at `f`.
def read_file (f:FilePath) : {IO} String =
  with_file f ReadMode \src. fread src

-- True when `f` can be opened for reading.
def has_file (f:FilePath) : {IO} Bool =
  stream = fopen f ReadMode
  (MkStream handle) = stream
  case is_null_raw_ptr handle of
    True -> False
    False ->
      fclose stream
      True

Temporary Files

-- Create a fresh empty file under /tmp with mkstemp, close its descriptor,
-- and return its path.
-- NOTE(review): a failing mkstemp (fd == -1) is not detected.
def new_temp_file (_:Unit) : {IO} FilePath =
  template = "/tmp/dex-XXXXXX"
  (AsList template_len _) = template
  with_c_string template \(MkCString cTemplate).
    fd = mkstempFFI cTemplate
    closeFFI fd
    -- mkstemp rewrote the X's in place; read the same number of bytes back.
    string_from_char_ptr (n_to_w32 template_len) (MkPtr cTemplate)

-- Run `action` with a fresh temporary file, deleting it afterwards.
def with_temp_file {a} (action: FilePath -> {IO} a) : {IO} a =
  path = new_temp_file ()
  result = action path
  delete_file path
  result

-- Run `action` with one fresh temporary file per index, deleting them afterwards.
def with_temp_files {n a} [Ix n] (action: (n=>FilePath) -> {IO} a) : {IO} a =
  paths = for _:n. new_temp_file ()
  result = action paths
  for p. delete_file paths.p
  result

Table operations

-- Error message for an out-of-range ordinal.
@noinline
def from_ordinal_error (i:Nat) (upper:Nat) : String =
  "Ordinal index out of range:" <> show i <> " >= " <> show upper

-- The index of `n` with ordinal `i`; raises an error when i >= size n.
def from_ordinal (n:Type) [Ix n] (i:Nat) : n =
  limit = size n
  if i < limit
    then unsafe_from_ordinal _ i
    else error $ from_ordinal_error i limit

-- TODO: could make an `unsafeCastIndex` and this could avoid the runtime copy
-- TODO: safe (runtime-checked) and unsafe versions
-- Reindex `xs` by `m`; raises an error unless both index sets have the same size.
def cast_table {n a} (m:Type) [Ix m] (xs:n=>a) : m=>a =
  if size m == size n
    then unsafe_cast_table _ xs
    else error $ "Table size mismatch in cast: " <> show (size m) <> " vs " <> show (size n)

-- Checked conversion from an ordinal to an index (index type inferred).
def asidx {n} [Ix n] (i:Nat) : n =
  from_ordinal n i

-- Infix checked conversion: `i@n` is the index of `n` with ordinal `i`.
def (@) (i:Nat) (n:Type) [Ix n] : n =
  from_ordinal n i

-- Entries of `xs` from ordinal `start` onward, one per index of `m` (bounds-checked).
def slice {n a} (xs:n=>a) (start:Nat) (m:Type) [Ix m] : m=>a =
  for j. xs.(from_ordinal _ (start + ordinal j))

-- The entry with ordinal 0.
def head {n a} (xs:n=>a) : a =
  xs.(from_ordinal _ 0)

-- The entries from ordinal `start` to the end, as a list (empty if start >= size n).
def tail {n a} (xs:n=>a) (start:Nat) : List a =
  remaining = size n -| start
  to_list $ slice xs start (Fin remaining)

Pseudorandom number generator utilities

Dex does not use a stateful random number generator. Instead, it uses what is known as a splittable random number generator, which is based on a hash function. Dex's PRNG system is modelled directly after JAX's, which is based on a well-established but shockingly underused idea from the functional programming community: the splittable PRNG. It's a good idea for many reasons, and it is especially helpful in a parallel setting. If you want to read more, Splittable pseudorandom number generators using cryptographic hashing describes the splitting model itself, and D.E. Shaw Research's counter-based PRNG proposes the particular hash function we use.

Key functions

-- TODO: newtype
-- PRNG keys are plain 64-bit words.
Key = Word64

-- Threefry-2x32 hash of a 64-bit key and a 64-bit counter; this is what
-- `hash` calls. Both inputs are split into 32-bit low/high words. Five rounds
-- run, each with four add/rotate/xor mixing steps, followed by adding key
-- words (and the round number) back in. The two 32-bit results come back
-- packed as (x << 32) | y.
@noinline
def threefry_2x32 (k:Word64) (count:Word64) : Word64 =
  -- Based on jax's threefry_2x32 by Matt Johnson and Peter Hawkins
  -- Rotation amounts for even (rotations1) and odd (rotations2) rounds.
  rotations1 = [13, 15, 26, 6]
  rotations2 = [17, 29, 16, 24]

  k0 = low_word k
  k1 = high_word k
  -- TODO: add a fromHex
  k2 = k0 .^. k1 .^. (n_to_w32 466688986) -- 0x1BD11BDA

  x = low_word count
  y = high_word count

  -- Initial key injection.
  x = x + k0
  y = y + k1

  rotations = [rotations1, rotations2]
  ks = [k1, k2, k0]

  (x, y) = yield_state (x, y) \ref.
    for i:(Fin 5).
      -- Four mixing steps: x += y; y = rotl(y, rot) xor x.
      for j.
        (x, y) = get ref
        rotationIndex = unsafe_from_ordinal _ (mod (ordinal i) 2)
        rot = rotations.rotationIndex.j
        x = x + y
        y = (y .<<. rot) .|. (y .>>. (32 - rot))
        y = x .^. y
        ref := (x, y)
      -- Key injection: cycle through ks and add the round counter i+1.
      (x, y) = get ref
      x = x + ks.(unsafe_from_ordinal _ (mod (ordinal i) 3))
      y = y + ks.(unsafe_from_ordinal _ (mod ((ordinal i)+1) 3)) + n_to_w32 ((ordinal i)+1)
      ref := (x, y)

  (w32_to_w64 x .<<. 32) .|. (w32_to_w64 y)
-- Mix a Nat into a key with the threefry hash.
def hash (x:Key) (y:Nat) : Key =
  threefry_2x32 x (n_to_w64 y)

-- Root key for seed `x`: `x` hashed under the all-zero key.
def new_key (x:Nat) : Key =
  zero_key = n_to_w64 0
  hash zero_key x

-- Key for index `i`: the ordinal of `i` hashed into `k`.
def ixkey {n} (k:Key) (i:n) [Ix n] : Key =
  hash k (ordinal i)

-- Apply `f` to the key derived from `k` for index `i`.
def many {a n} [Ix n] (f:Key->a) (k:Key) (i:n) : a =
  f (ixkey k i)

-- `n` independent keys derived from `k`, one per index.
def split_key {n} (k:Key) : Fin n => Key =
  for idx. ixkey k idx

Sample Generators

These functions generate samples from different distributions. For example, `rand_mat` fills a matrix with i.i.d. samples: each element is drawn by applying a sampler (such as `rand`, for uniform values) to its own independent key. Note that additional standard distributions are provided by the stats library.

-- Uniform Float in [0, 1). Builds a Float32 in [1, 2) from a fixed exponent
-- and 23 random mantissa bits taken from the key's high word, then
-- subtracts 1.
def rand (k:Key) : Float =
  one_exponent = 1065353216 -- 1065353216 = 127 << 23
  random_mantissa = (high_word k .&. 8388607) -- 8388607 == (1 << 23) - 1
  bits = one_exponent .|. random_mantissa
  in_one_to_two = %bitcast Float bits
  in_one_to_two - 1.0

-- `n` samples of `f`, each drawn with its own key derived from `k`.
def rand_vec {a} (n:Nat) (f: Key -> a) (k: Key) : Fin n => a =
  for idx:(Fin n). f $ ixkey k idx

-- An n-by-m table of samples of `f`, each drawn with its own key derived from `k`.
def rand_mat {a} (n:Nat) (m:Nat) (f: Key -> a) (k: Key) : Fin n => Fin m => a =
  for r c. f $ ixkey k (r, c)

-- Standard normal sample (Box-Muller transform).
def randn (k:Key) : Float =
  [k1, k2] = split_key k
  -- rand is uniform between 0 and 1, but implemented such that it rounds to 0
  -- (in float32) once every few million draws, but never rounds to 1.
  u1 = 1.0 - (rand k1)
  u2 = rand k2
  radius = sqrt ((-2.0) * log u1)
  angle = 2.0 * pi * u2
  radius * cos angle

-- TODO: Make this better...
-- A Nat below 2147483647 taken from the key's bits.
def rand_int (k:Key) : Nat =
  mod (internal_cast Nat k) 2147483647

-- True with probability `p`.
def bern (p:Float) (k:Key) : Bool =
  rand k < p

-- A table of independent standard normal samples.
def randn_vec {n} [Ix n] (k:Key) : n=>Float =
  for idx. randn $ ixkey k idx

-- A uniformly chosen index of `n`.
def rand_idx {n} [Ix n] (k:Key) : n =
  scaled = rand k * n_to_f (size n)
  unsafe_from_ordinal n $ f_to_n $ floor scaled

Inner product typeclass

-- Vector spaces with a Float-valued inner product.
interface [VSpace v] InnerProd v
  inner_prod : v->v->Float

-- On scalars, the inner product is plain multiplication.
instance InnerProd Float
  inner_prod = \u v. u * v

-- On tables: the sum of the entrywise inner products.
instance InnerProd (n=>a) given {a n} [InnerProd a]
  inner_prod = \u v. sum (for i. inner_prod u.i v.i)

Arbitrary

Type class for generating example values

-- Types that can produce a pseudorandom example value from a Key.
interface Arbitrary a
  arb : Key -> a

-- Decided by the lowest bit of the key.
instance Arbitrary Bool
  arb = \key. (key .&. 1) == 0

instance Arbitrary Float32
  arb = \key. randn key

instance Arbitrary Int32
  arb = \key.
    scaled = randn key * 5.0
    f_to_i scaled

-- NOTE(review): negative draws are passed to f_to_n; the resulting Nat
-- depends on how f_to_n treats negative input. Confirm this is intended.
instance Arbitrary Nat
  arb = \key.
    scaled = randn key * 5.0
    f_to_n scaled

-- Each entry is drawn with its own key derived from `key` and the index.
instance Arbitrary (n=>a) given {n a} [Arbitrary a]
  arb = \key. for idx. arb (ixkey key idx)
-- Arbitrary triangular tables. Row `i` is drawn with a key derived from the
-- caller's key and `i`. These instances used to ignore the key entirely,
-- so every key produced the same table.
instance Arbitrary ((i:n)=>(..<i) => a) given {n a} [Arbitrary a]
  arb = \key. for i. arb $ ixkey key i

instance Arbitrary ((i:n)=>(..i) => a) given {n a} [Arbitrary a]
  arb = \key. for i. arb $ ixkey key i

instance Arbitrary ((i:n)=>(i..) => a) given {n a} [Arbitrary a]
  arb = \key. for i. arb $ ixkey key i

instance Arbitrary ((i:n)=>(i<..) => a) given {n a} [Arbitrary a]
  arb = \key. for i. arb $ ixkey key i
-- Each component is drawn with its own key split from `key`.
instance Arbitrary (a & b) given {a b} [Arbitrary a, Arbitrary b]
  arb = \key.
    [left_key, right_key] = split_key key
    (arb left_key, arb right_key)

-- A uniformly chosen element of `Fin n`.
instance Arbitrary (Fin n) given {n}
  arb = \key. rand_idx key

Ord on Arrays

Searching

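Returns the highest index `i` such that `xs.i <= x`, or `Nothing` if there is no such index (in particular when `xs` is empty). `xs` must be sorted in ascending order.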

-- Binary search in ascending `xs`: the highest index i with xs.i <= x, or
-- Nothing when `xs` is empty or x < xs.0.
def search_sorted {n a} [Ord a] (xs:n=>a) (x:a) : Maybe n =
  case size n == 0 of
    True -> Nothing
    False -> case x < xs.(from_ordinal _ 0) of
      True -> Nothing
      False ->
        -- Invariant: xs.lo <= x, and x < xs.hi whenever hi < size n.
        with_state 0 \lo.
          with_state (size n) \hi.
            iter \_.
              width = n_to_i (get hi) - n_to_i (get lo)
              case width == 1 of
                True -> Done $ Just $ from_ordinal _ $ get lo
                False ->
                  mid = get lo + unsafe_i_to_n (idiv width 2)
                  case x < xs.(from_ordinal _ mid) of
                    True -> hi := mid
                    False -> lo := mid
                  Continue

min / max etc

-- Whichever of x, y has the smaller key; y on ties.
def min_by {a o} [Ord o] (f:a->o) (x:a) (y:a) : a =
  if f x < f y then x else y

-- Whichever of x, y has the larger key; y on ties.
def max_by {a o} [Ord o] (f:a->o) (x:a) (y:a) : a =
  if f x > f y then x else y

def min {o} [Ord o] (x1: o) (x2: o) : o = min_by id x1 x2
def max {o} [Ord o] (x1: o) (x2: o) : o = max_by id x1 x2

-- Entry with the smallest key; requires a non-empty table.
def minimum_by {a n o} [Ord o] (f:a->o) (xs:n=>a) : a =
  first = xs.(0@_)
  reduce first (\acc v. min_by f acc v) xs

-- Entry with the largest key; requires a non-empty table.
def maximum_by {a n o} [Ord o] (f:a->o) (xs:n=>a) : a =
  first = xs.(0@_)
  reduce first (\acc v. max_by f acc v) xs

def minimum {n o} [Ord o] (xs:n=>o) : o = minimum_by id xs
def maximum {n o} [Ord o] (xs:n=>o) : o = maximum_by id xs

argmin/argmax

TODO: put in the same section as `search_sorted`.

-- Index of the entry that wins the pairwise comparison `comp`. The
-- running best is kept while `comp best candidate` holds; otherwise the
-- candidate replaces it, so ties go to the later index.
def argscan {n o} (comp:o->o->Bool) (xs:n=>o) : n =
  start = (0@_, xs.(0@_))
  keep_better = \(i1, v1) (i2, v2). select (comp v1 v2) (i1, v1) (i2, v2)
  indexed = for i. (i, xs.i)
  (best_idx, best_val) = reduce start keep_better indexed
  best_idx

def argmin {n o} [Ord o] (xs:n=>o) : n = argscan (<) xs
def argmax {n o} [Ord o] (xs:n=>o) : n = argscan (>) xs

def lexical_order {n} [Ord n] (compareElements:n->n->Bool) (compareLengths:Nat->Nat->Bool) ((AsList nx xs):List n) ((AsList ny ys):List n) : Bool =
  -- Orders Lists according to the order of their elements,
  -- in the same way a dictionary does.
  -- For example, this lets us sort Strings.
  --
  -- More precisely, it returns True iff compareElements xs.i ys.i is true
  -- at the first location they differ.
  --
  -- This function operates serially and short-circuits
  -- at the first difference. One could also write this
  -- function as a parallel reduction, but it would be
  -- wasteful in the case where there is an early difference,
  -- because we can't short circuit.
  iter \i.
    if i == min nx ny
      then Done $ compareLengths nx ny
      else
        xi = xs.(unsafe_from_ordinal _ i)
        yi = ys.(unsafe_from_ordinal _ i)
        case compareElements xi yi of
          True -> Done True
          False -> if xi == yi then Continue else Done False

-- Lists are ordered lexicographically; a proper prefix sorts first.
instance Ord (List n) given {n} [Ord n]
  (>) = lexical_order (>) (>)
  (<) = lexical_order (<) (<)

clip

-- Clamp `x` into [low, high].
def clip {a} [Ord a] ((low,high):(a&a)) (x:a) : a =
  at_least_low = max low x
  min high at_least_low

Trigonometric functions

TODO: these should be with the other Elementary/Special Functions

atan/atan2

-- Polynomial approximation of atan on [-1, 1]. With s = x * x, the
-- result is x + x * s * P(s), where P is evaluated by Horner's rule from
-- the coefficients below.
def atan_inner (x:Float) : Float =
  -- From "Computing accurate Horner form approximations to
  -- special functions in finite precision arithmetic"
  -- https://arxiv.org/abs/1508.03211
  -- Only accurate in the range [-1, 1]
  s = x * x
  r = 0.0027856871
  r = r * s - 0.0158660002
  r = r * s + 0.042472221
  r = r * s - 0.0749753043
  r = r * s + 0.106448799
  r = r * s - 0.142070308
  r = r * s + 0.199934542
  r = r * s - 0.333331466
  r = r * s
  r * x + x
-- The smaller and the larger of x and y, found with a single comparison.
def min_and_max {a} [Ord a] (x:a) (y:a) : (a & a) =
  if x < y then (x, y) else (y, x)
-- Angle of the point (x, y) in radians, in [-pi, pi].
def atan2 (y:Float) (x:Float) : Float =
  -- Based off of the Tensorflow implementation at
  -- github.com/tensorflow/mlir-hlo/blob/master/lib/
  -- Dialect/mhlo/transforms/legalize_trigonometric_to_approximation.cc#L147
  -- With a fix to the nan propagation.
  abs_x = abs x
  abs_y = abs y
  (min_abs_x_y, max_abs_x_y) = min_and_max abs_x abs_y
  a = atan_inner (min_abs_x_y / max_abs_x_y)
  a = select (abs_x <= abs_y) ((pi / 2.0) -a) a
  a = select (x < 0.0) (pi - a) a
  t = select (x < 0.0) pi 0.0
  a = select (y == 0.0) t a
  t = select (x < 0.0) (3.0 * pi / 4.0) (pi / 4.0)
  a = select (isinf x && isinf y) t a -- Handle infinite inputs.
  -- `a` is non-negative here; give it the sign of `y`. The prelude's
  -- `copysign` is not used because it returns 0.0 whenever y == 0.0, which
  -- would turn atan2 0.0 x (x < 0) into 0.0 instead of pi. A negative zero
  -- `y` is treated like +0.0.
  a = select (y < 0.0) (-a) a
  select (either_is_nan x y) nan a -- Propagate NaNs.
-- Arctangent, computed as the angle of the point (1, x).
def atan (x:Float) : Float =
  atan2 x 1.0

Miscellaneous utilities

TODO: all of these should be in some other section

-- Mirror an index: ordinal k maps to ordinal size n - 1 - k.
def reflect {n} [Ix n] (i:n) : n =
  mirrored = unsafe_nat_diff (size n) (ordinal i + 1)
  unsafe_from_ordinal n mirrored

-- Reverse the order of a table's entries.
def reverse {n a} (x:n=>a) : n=>a =
  for i. x.(reflect i)

-- The index of `n` whose ordinal is `i` modulo size n.
def wrap_periodic (n:Type) [Ix n] (i:Nat) : n =
  wrapped = mod i (size n)
  unsafe_from_ordinal n wrapped

-- Copy `xs` into a table over `m`, filling positions past the end of `xs` with `x`.
def pad_to {n a} (m:Type) [Ix m] (x:a) (xs:n=>a) : m=>a =
  n' = size n
  for i.
    i' = ordinal i
    if i' < n' then xs.(i'@_) else x

-- Integer division rounding up.
def idiv_ceil (x:Nat) (y:Nat) : Nat =
  has_remainder = rem x y /= 0
  idiv x y + b_to_n has_remainder

-- x / 2 (rounded down), via a right shift.
def intdiv2 (x:Nat) : Nat =
  rep_to_nat $ %shr (nat_to_rep x) (1 :: NatRep)

-- 2 ^ power, via a left shift.
def intpow2 (power:Nat) : Nat =
  rep_to_nat $ %shl (1 :: NatRep) (nat_to_rep power)

def is_odd (x:Nat) : Bool = rem x 2 == 1
def is_even (x:Nat) : Bool = rem x 2 == 0

def is_power_of_2 (x:Nat) : Bool =
  -- A fast trick based on bitwise AND.
  -- This works on integer types larger than 8 bits.
  -- Note: The bitwise and operator (.&.)
  -- is only defined for Byte, which is why
  -- we use %and here. TODO: Make (.&.) polymorphic.
  x' = nat_to_rep x
  case x' == 0 of
    True -> False
    False -> 0 == %and x' (%isub x' (1::NatRep))
-- floor (log2 x) for x >= 1, found by counting how many times x can be
-- halved before reaching 0. Returns 0 for x == 0, where log2 is undefined.
-- Halving cannot overflow. Doubling a comparison value instead would
-- wrap to 0 for x >= 2^31 and loop forever.
def natlog2 (x:Nat) : Nat =
  yield_state 0 \ans.
    rest <- with_state (intdiv2 x)
    while do
      if get rest > 0
        then
          ans := get ans + 1
          rest := intdiv2 (get rest)
          True
        else False
def general_integer_power {a} (times:a->a->a) (one:a) (base:a) (power:Nat) : a =
  -- Implements exponentiation by squaring.
  -- This could be nicer if there were a way to explicitly
  -- specify which typeclass instance to use for Mul.
  yield_state one \acc.
    remaining <- with_state power
    squared <- with_state base
    while do
      case get remaining > 0 of
        False -> False
        True ->
          if is_odd (get remaining) then acc := times (get acc) (get squared)
          squared := times (get squared) (get squared)
          remaining := intdiv2 (get remaining)
          True

-- base ^ power for any `Mul` type, using O(log power) multiplications.
def intpow {a} [Mul a] (base:a) (power:Nat) : a =
  general_integer_power (*) one base power

-- Unwrap a `Just`; partial (no case for `Nothing`).
def from_just {a} (x:Maybe a) : a =
  case x of
    Just v -> v

-- True when `f` holds for at least one entry.
def any_sat {a n} (f:a -> Bool) (xs:n=>a) : Bool =
  any for i. f xs.i

def seq_maybes {n a} (xs : n=>Maybe a) : Maybe (n => a) =
  -- is it possible to implement this safely? (i.e. without using partial
  -- functions)
  if any_sat is_nothing xs
    then Nothing
    else Just $ map from_just xs

-- The index of the last entry equal to `query`, or Nothing.
def linear_search {n a} [Eq a] (xs:n=>a) (query:a) : Maybe n =
  yield_state Nothing \found.
    for i. if xs.i == query then found := Just i

-- Number of elements in a list.
def list_length {a} (l:List a) : Nat =
  (AsList len _) = l
  len
-- This is for efficiency (rather than using `<>` repeatedly)
-- TODO: we want this for any monoid but this implementation won't work.
-- Concatenate a table of lists into one list, filling each output slot
-- from a (list, element) cursor that skips over exhausted or empty lists.
def concat {n a} (lists:n=>(List a)) : List a =
  totalSize = sum for i. list_length lists.i
  AsList _ $ with_state 0 \listIdx.
    eltIdx <- with_state 0
    for i:(Fin totalSize).
      -- Advance to the next list until the cursor points at an element.
      while do
        continue = get eltIdx >= list_length (lists.((get listIdx)@_))
        if continue
          then
            eltIdx := 0
            listIdx := get listIdx + 1
          else ()
        continue
      (AsList _ xs) = lists.((get listIdx)@_)
      eltIdxVal = get eltIdx
      eltIdx := eltIdxVal + 1
      xs.(eltIdxVal@_)
-- The payloads of the `Just` entries of `xs`, in index order.
-- First pass: record the index of each `Just` in res_inds and count them.
-- Second pass: read those entries back out.
def cat_maybes {a n} (xs:n=>Maybe a) : List a =
  (num_res, res_inds) = yield_state (0::Nat, for i:n. Nothing) \ref.
    for i. case xs.i of
      Just _ ->
        ix = get $ fst_ref ref
        (snd_ref ref) ! (unsafe_from_ordinal _ ix) := Just i
        fst_ref ref := ix + 1
      Nothing -> ()
  to_list $ for i:(Fin num_res).
    case res_inds.(unsafe_from_ordinal _ $ ordinal i) of
      Just j -> case xs.j of
        Just x -> x
        Nothing -> todo -- Impossible
      Nothing -> todo -- Impossible
-- The entries satisfying `condition`, in index order.
def filter {a n} [Ix n] (condition:a->Bool) (xs:n=>a) : List a =
  candidates = for i.
    x = xs.i
    case condition x of
      True -> Just x
      False -> Nothing
  cat_maybes candidates

-- The indices whose entries satisfy `condition`, in index order.
def arg_filter {a n} [Ix n] (condition:a->Bool) (xs:n=>a) : List n =
  candidates = for i.
    case condition xs.i of
      True -> Just i
      False -> Nothing
  cat_maybes candidates
-- TODO: use `ix_offset : [Ix n] -> n -> Int -> Maybe n` instead
def prev_ix {n} [Ix n] (i:n) : Maybe n = case i_to_n (n_to_i (ordinal i) - 1) of Nothing -> Nothing Just i_prev -> Just $ unsafe_from_ordinal n i_prev
-- Split `source` into the segments that end at each '\n' character.
-- NOTE(review): one output line is produced per '\n', so any characters
-- after the final '\n' are not returned.
def lines (source:String) : List String =
  (AsList _ s) = source
  -- Positions of every newline character, in order.
  (AsList num_lines newline_ixs) = cat_maybes for i_char.
    if s.i_char == '\n'
      then Just i_char
      else Nothing
  to_list for i_line:(Fin num_lines).
    -- A line starts just after the previous newline, or at the start of
    -- the string for the first line.
    start = case prev_ix i_line of
      Nothing -> first_ix
      Just i -> right_post newline_ixs.i
    -- It ends at the post on the left of its own newline. Presumably this
    -- excludes the '\n' itself; confirm against `left_post`.
    end = left_post newline_ixs.i_line
    post_slice s start end

Probability

-- cdf should include 0.0 but not 1.0
-- Inverse-CDF sampling: draw `rand key` and return the bucket of `cdf`
-- that `search_sorted` places it in.
def categorical_from_cdf {n} (cdf: n=>Float) (key: Key) : n =
  -- Partial: there is no `Nothing` branch. This relies on `cdf` starting at
  -- 0.0 so that every draw falls into some bucket.
  case search_sorted cdf (rand key) of
    Just i -> i
def normalize_pdf {d} (xs: d=>Float) : d=>Float =
  -- Rescale `xs` so that its entries sum to one.
  total = sum xs
  xs / total
def cdf_for_categorical {n} (logprobs: n=>Float) : n=>Float =
  -- Convert unnormalized log-probabilities into the cumulative table used
  -- by `categorical_from_cdf`. The maximum is subtracted before `exp` so
  -- the largest term is exp 0 = 1.
  shift = maximum logprobs
  probs = normalize_pdf for i. exp (logprobs.i - shift)
  cumsum_low probs
def categorical {n} (logprobs: n=>Float) (key: Key) : n =
  -- Sample one index from the distribution described by `logprobs`.
  cdf = cdf_for_categorical logprobs
  categorical_from_cdf cdf key
-- batch variant to share the work of forming the cumsum
-- (alternatively we could rely on hoisting of loop constants)
def categorical_batch {n m} [Ix m] (logprobs: n=>Float) (key: Key) : m=>n =
  -- The cdf is built once and reused for every draw; each draw gets its
  -- own key derived with `ixkey`.
  shared_cdf = cdf_for_categorical logprobs
  for j:m.
    categorical_from_cdf shared_cdf (ixkey key j)
def logsumexp {n} (x: n=>Float) : Float =
  -- log (sum (exp x)), with the maximum factored out so the largest
  -- exponent is exp 0.
  shift = maximum x
  shifted_total = sum for i. exp (x.i - shift)
  shift + log shifted_total
def logsoftmax {n} (x: n=>Float) : n=>Float =
  -- Log of softmax: subtract the log-normalizer from every entry.
  normalizer = logsumexp x
  for i. x.i - normalizer
def softmax {n} (x: n=>Float) : n=>Float =
  -- Exponentiate and normalize to sum to one. The maximum is subtracted
  -- first so the largest exponent is exp 0.
  shift = maximum x
  unnormalized = for i. exp (x.i - shift)
  total = sum unnormalized
  for i. unnormalized.i / total

Polynomials

TODO: Move this somewhere else

def evalpoly {n v} [VSpace v] (coefficients:n=>v) (x:Float) : v =
  -- Evaluate a polynomial at x. Same as Numpy's polyval.
  -- Horner's rule: each step scales the accumulator by x and adds the next
  -- coefficient. If `fold` visits indices in order, the first coefficient
  -- ends up multiplying the highest power of x.
  fold zero \i acc.
    coefficients.i + x .* acc

TestMode

TODO: move this to be in Testing Helpers

def dex_test_mode (():Unit) : Bool =
  -- Whether `check_env` reports the DEX_TEST_MODE environment variable.
  -- Runs the IO action through `unsafe_io`, so the result is a plain Bool.
  unsafe_io do
    check_env "DEX_TEST_MODE"

Exception effect

TODO: move `error` and `todo` here.

-- Run `f`, handling its `Except` effect: `Just` its result if it returns
-- normally, `Nothing` if it throws. The remaining effects `eff` pass through.
def catch {a eff} (f:Unit -> {Except|eff} a) : {|eff} Maybe a =
  ans = %catchException f
  -- The primitive returns a sum type; view it as a variant to match on.
  -- NOTE(review): the first alternative (unit payload) is taken to be the
  -- "exception thrown" case and the second the "returned `val`" case.
  case %sumToVariant ans of
    {|c=() |} -> Nothing
    {|c|c=val|} -> Just val
-- Raise an exception in the `Except` effect. The result type `a` is free, so
-- `throw ()` can appear wherever any value is expected.
def throw {a} (_:Unit) : {Except} a =
  %throwException a
def assert (b:Bool) : {Except} Unit =
  -- Throw (via `throw`) unless `b` holds.
  if b
    then ()
    else throw ()

Miscellaneous instances that require `error`

-- `a` embeds into the sum `a|b` as its `Left` branch.
instance Subset a (a|b) given {a b}
  inject' = \x. Left x
  -- Recover the `a` only when the value is a `Left`.
  project' = \x. case x of
    Left y -> Just y
    Right x -> Nothing
  -- Like project', but raises `error` on a `Right` value.
  unsafe_project' = \x. case x of
    Left x -> x
    Right x -> error "Can't project Right branch to Left branch"
-- `a` embeds into the sum `b|a` as its `Right` branch.
instance Subset a (b|a) given {a b}
  inject' = \x. Right x
  -- Recover the `a` only when the value is a `Right`.
  project' = \x. case x of
    Left x -> Nothing
    Right y -> Just y
  -- Like project', but raises `error` on a `Left` value.
  unsafe_project' = \x. case x of
    Left x -> error "Can't project Left branch to Right branch"
    Right x -> x

Testing Helpers

-- -- Reliably causes a segfault if pointers aren't initialized to zero.
-- -- TODO: add this test when we cache modules
-- justSomeDataToTestCaching = toList for i:(Fin 100).
-- if ordinal i == 0
-- then Left (toList [1,2,3])
-- else Right 1

Index set for tables

def int_to_reversed_digits {a b} [Ix a, Ix b] (k:Nat) : a=>b =
  -- Write `k` in base `size b`, one digit per index of `a`, least
  -- significant digit first. Digits beyond `size a` are discarded.
  base = size b
  snd $ scan k \_ remaining.
    (idiv remaining base, unsafe_from_ordinal b (mod remaining base))
def reversed_digits_to_int {a b} [Ix a, Ix b] (digits: a=>b) : Nat =
  -- Inverse of int_to_reversed_digits: read `digits` as a base-`size b`
  -- number whose first index holds the least significant digit.
  base = size b
  (total, _) = fold (0, 1) \j (acc, place).
    (acc + ordinal digits.j * place, place * base)
  total
-- Tables are enumerable: a table `a=>b` is treated as a `size a`-digit
-- number in base `size b`.
instance Ix (a=>b) given {a b} [Ix a, Ix b]
  -- 0@a is the least significant digit,
  -- while (size a - 1)@a is the most significant digit.
  -- There are (size b)^(size a) distinct tables.
  size = intpow (size b) (size a)
  ordinal = reversed_digits_to_int
  unsafe_from_ordinal = int_to_reversed_digits
-- A table type is non-empty when its element type is non-empty. The first
-- index is the table with ordinal 0 under the `Ix (a=>b)` enumeration.
instance NonEmpty (a=>b) given {a b} [Ix a, NonEmpty b]
  first_ix = unsafe_from_ordinal _ 0