Dex prelude

Runs before every Dex program unless an alternative is provided with --prelude.

Essentials

Primitive Types

-- Built-in kinds, bound directly to compiler primitives: the kind of types,
-- the kind of heap parameters (used by `Ref` and the effect handlers below)
-- and the kind of effect rows.
Type = %TyKind()
Heap = %HeapType()
Effects = %EffKind()
-- Fixed-width machine scalar types.
Int64 = %Int64()
Int32 = %Int32()
Float64 = %Float64()
Float32 = %Float32()
Word8 = %Word8()
Word32 = %Word32()
Word64 = %Word64()
-- Byte-sized aliases; `Char` is a single Word8, not a Unicode code point.
Byte = Word8
Char = Byte
-- Untyped byte pointer; `Ptr a` further down wraps it with an element type.
RawPtr : Type = %Word8Ptr()
-- Default-width aliases used throughout the rest of this file.
Int = Int32
Float = Float32
-- Type ascription helper: `the(T, e)` is just `e`, checked at type `T`.
def the(a:Type, x:a) -> a = x
-- `Data` is used as a constraint (`a|Data`) throughout this file.
-- NOTE(review): presumably it marks types with a runtime data representation;
-- confirm against the compiler. Per the method's name, the compiler relies
-- on the invariant this interface protects, so never implement it by hand.
interface Data(a:Type)
  do_not_implement_this_interface_for_the_compiler_relies_on_the_invariant_it_protects : (a) -> a

Casting

-- Low-level conversions built on compiler primitives.
--   internal_cast       : `%cast` from the argument's type to the expected
--                         result type
--   unsafe_coerce       : `%unsafeCoerce` of a Data value to another Data type
--   uninitialized_value : `%garbageVal`, an unspecified value of type `a`
def internal_cast(x:from) -> to given (from, to) = %cast(to, x)
def unsafe_coerce(x:from) -> to given (from|Data, to|Data) = %unsafeCoerce(to, x)
def uninitialized_value() -> a given (a|Data) = %garbageVal(a)

-- Float <-> fixed-width floats
def f64_to_f(x: Float64) -> Float = internal_cast(x)
def f32_to_f(x: Float32) -> Float = internal_cast(x)
def f_to_f64(x: Float) -> Float64 = internal_cast(x)
def f_to_f32(x: Float) -> Float32 = internal_cast(x)

-- Int <-> fixed-width integers and words
def i64_to_i(x: Int64) -> Int = internal_cast(x)
def i32_to_i(x: Int32) -> Int = internal_cast(x)
def w8_to_i(x: Word8) -> Int = internal_cast(x)
def i_to_i64(x: Int) -> Int64 = internal_cast(x)
def i_to_i32(x: Int) -> Int32 = internal_cast(x)
def i_to_w8(x: Int) -> Word8 = internal_cast(x)
def i_to_w32(x: Int) -> Word32 = internal_cast(x)
def i_to_w64(x: Int) -> Word64 = internal_cast(x)
def w32_to_w64(x: Word32) -> Word64 = internal_cast(x)

-- Int <-> Float, and pointer addresses viewed as integers
def i_to_f(x:Int) -> Float = internal_cast(x)
def f_to_i(x:Float) -> Int = internal_cast(x)
def raw_ptr_to_i64(x:RawPtr) -> Int64 = internal_cast(x)
-- Natural numbers: a newtype (`%projNewtype` / `%NatCon`) over a Word32
-- representation.
Nat = %Nat()
NatRep = Word32
def nat_to_rep(x : Nat) -> NatRep = %projNewtype(x)
def rep_to_nat(x : NatRep) -> Nat = %NatCon(x)

-- Nat -> scalar: unwrap to the representation, then cast.
def n_to_w8(x: Nat) -> Word8 = internal_cast(nat_to_rep(x))
def n_to_w32(x: Nat) -> Word32 = internal_cast(nat_to_rep(x))
def n_to_w64(x: Nat) -> Word64 = internal_cast(nat_to_rep(x))
def n_to_i32(x: Nat) -> Int32 = internal_cast(nat_to_rep(x))
def n_to_i64(x: Nat) -> Int64 = internal_cast(nat_to_rep(x))
def n_to_f32(x: Nat) -> Float32 = internal_cast(nat_to_rep(x))
def n_to_f64(x: Nat) -> Float64 = internal_cast(nat_to_rep(x))
def n_to_f(x: Nat) -> Float = internal_cast(nat_to_rep(x))

-- scalar -> Nat: cast to the representation, then wrap (no range check).
def w8_to_n(x : Word8) -> Nat = rep_to_nat(internal_cast(x))
def w32_to_n(x : Word32) -> Nat = rep_to_nat(internal_cast(x))
def w64_to_n(x : Word64) -> Nat = rep_to_nat(internal_cast(x))
def i32_to_n(x : Int32) -> Nat = rep_to_nat(internal_cast(x))
def i64_to_n(x : Int64) -> Nat = rep_to_nat(internal_cast(x))
def f32_to_n(x : Float32) -> Nat = rep_to_nat(internal_cast(x))
def f64_to_n(x : Float64) -> Nat = rep_to_nat(internal_cast(x))
def f_to_n(x : Float) -> Nat = rep_to_nat(internal_cast(x))
-- Conversion from an unsigned Word64 carrier.
-- NOTE(review): presumably the compiler elaborates non-negative integer
-- literals through `from_unsigned_integer` (e.g. `zero = 0` at Float64) --
-- confirm.
interface FromUnsignedInteger(a:Type)
  from_unsigned_integer : (Word64) -> a

instance FromUnsignedInteger(Word8)
  def from_unsigned_integer(w) = internal_cast(w)
instance FromUnsignedInteger(Word32)
  def from_unsigned_integer(w) = internal_cast(w)
-- Word64 is already the carrier type.
instance FromUnsignedInteger(Word64)
  def from_unsigned_integer(w) = w
instance FromUnsignedInteger(Int32)
  def from_unsigned_integer(w) = internal_cast(w)
instance FromUnsignedInteger(Int64)
  def from_unsigned_integer(w) = internal_cast(w)
instance FromUnsignedInteger(Float32)
  def from_unsigned_integer(w) = internal_cast(w)
instance FromUnsignedInteger(Float64)
  def from_unsigned_integer(w) = internal_cast(w)
instance FromUnsignedInteger(Nat)
  def from_unsigned_integer(w) = w64_to_n w

-- Conversion from a signed Int64 carrier.
interface FromInteger(a:Type)
  from_integer : (Int64) -> a

instance FromInteger(Float32)
  def from_integer(i) = internal_cast(i)
instance FromInteger(Int32)
  def from_integer(i) = internal_cast(i)
instance FromInteger(Float64)
  def from_integer(i) = internal_cast(i)
-- Int64 is already the carrier type.
instance FromInteger(Int64)
  def from_integer(i) = i

Bitwise operations

-- Bitwise operations. Shift amounts are always `Int` and are converted to the
-- operand's width before calling the primitive.
interface Bits(a:Type)
  (.<<.) : (a, Int) -> a
  (.>>.) : (a, Int) -> a
  (.|.) : (a, a) -> a
  (.&.) : (a, a) -> a
  (.^.) : (a, a) -> a

instance Bits(Word8)
  def (.<<.)(x, n) = %shl(x, i_to_w8(n))
  def (.>>.)(x, n) = %shr(x, i_to_w8(n))
  def (.|.)(x, y) = %or(x, y)
  def (.&.)(x, y) = %and(x, y)
  def (.^.)(x, y) = %xor(x, y)

instance Bits(Word32)
  def (.<<.)(x, n) = %shl(x, i_to_w32(n))
  def (.>>.)(x, n) = %shr(x, i_to_w32(n))
  def (.|.)(x, y) = %or(x, y)
  def (.&.)(x, y) = %and(x, y)
  def (.^.)(x, y) = %xor(x, y)

instance Bits(Word64)
  def (.<<.)(x, n) = %shl(x, i_to_w64(n))
  def (.>>.)(x, n) = %shr(x, i_to_w64(n))
  def (.|.)(x, y) = %or(x, y)
  def (.&.)(x, y) = %and(x, y)
  def (.^.)(x, y) = %xor(x, y)
-- Split a Word64 into two Word32 halves via `internal_cast`.
-- NOTE(review): assuming `internal_cast` truncates Word64 -> Word32 to the
-- low 32 bits, these names are inverted relative to the usual convention:
-- `low_word` yields the *most* significant half and `high_word` the *least*
-- significant half. Existing callers may depend on this pairing, so the
-- behaviour is intentionally left unchanged here.
def low_word(x : Word64) -> Word32 =
  upper_bits = x .>>. 32
  internal_cast(upper_bits)
def high_word(x : Word64) -> Word32 = internal_cast x

Basic Arithmetic

Add

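Things that can be added. This defines the `Add` interface (a binary `+` with identity element `zero`) and the `Sub` interface, which adds `-` for types that are already `Add`.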

-- `Add`: a binary `+` with identity element `zero`.
-- `Sub`: subtraction, only for types that are already `Add`.
interface Add(a|Data)
  (+) : (a, a) -> a
  zero : a

interface Sub(a|Add)
  (-) : (a, a) -> a

-- Floating point
instance Add(Float64)
  def (+)(x, y) = %fadd(x, y)
  zero = 0
instance Add(Float32)
  def (+)(x, y) = %fadd(x, y)
  zero = 0
instance Sub(Float64)
  def (-)(x, y) = %fsub(x, y)
instance Sub(Float32)
  def (-)(x, y) = %fsub(x, y)

-- Fixed-width integers and words
instance Add(Int64)
  def (+)(x, y) = %iadd(x, y)
  zero = 0
instance Add(Int32)
  def (+)(x, y) = %iadd(x, y)
  zero = 0
instance Add(Word8)
  def (+)(x, y) = %iadd(x, y)
  zero = 0
instance Add(Word32)
  def (+)(x, y) = %iadd(x, y)
  zero = 0
instance Add(Word64)
  def (+)(x, y) = %iadd(x, y)
  zero = 0
instance Sub(Int64)
  def (-)(x, y) = %isub(x, y)
instance Sub(Int32)
  def (-)(x, y) = %isub(x, y)
instance Sub(Word8)
  def (-)(x, y) = %isub(x, y)
instance Sub(Word32)
  def (-)(x, y) = %isub(x, y)
instance Sub(Word64)
  def (-)(x, y) = %isub(x, y)

-- Nat adds its underlying representations. There is no `Sub(Nat)`; the
-- clamping `-|` defined further down plays that role.
instance Add(Nat)
  def (+)(x, y) =
    total = %iadd(nat_to_rep x, nat_to_rep y)
    rep_to_nat total
  zero = 0

-- The unit type is trivially additive.
instance Add(())
  def (+)(_, _) = ()
  zero = ()
instance Sub(())
  def (-)(_, _) = ()

Mul

Things that can be multiplied. This defines the `Mul` monoid: the operator `*` together with its identity element `one`.

-- `Mul`: a binary `*` with identity element `one`.
interface Mul(a|Data)
  (*) : (a, a) -> a
  one : a

-- Float literals default to `Float`, so the float identities are converted.
instance Mul(Float64)
  def (*)(x, y) = %fmul(x, y)
  one = f_to_f64(1.0)
instance Mul(Float32)
  def (*)(x, y) = %fmul(x, y)
  one = f_to_f32(1.0)

instance Mul(Int64)
  def (*)(x, y) = %imul(x, y)
  one = 1
instance Mul(Int32)
  def (*)(x, y) = %imul(x, y)
  one = 1
instance Mul(Word8)
  def (*)(x, y) = %imul(x, y)
  one = 1
instance Mul(Word32)
  def (*)(x, y) = %imul(x, y)
  one = 1
instance Mul(Word64)
  def (*)(x, y) = %imul(x, y)
  one = 1

-- Nat multiplies its underlying representations.
instance Mul(Nat)
  def (*)(x, y) =
    product_rep = %imul(nat_to_rep x, nat_to_rep y)
    rep_to_nat product_rep
  one = 1

instance Mul(())
  def (*)(_, _) = ()
  one = ()

Integral

Integer-like things.

-- Integer division (`idiv`) and remainder (`rem`), via `%idiv` / `%irem`.
interface Integral(a)
  idiv : (a, a) -> a
  rem : (a, a) -> a

instance Integral(Int64)
  def idiv(x, y) = %idiv(x, y)
  def rem(x, y) = %irem(x, y)
instance Integral(Int32)
  def idiv(x, y) = %idiv(x, y)
  def rem(x, y) = %irem(x, y)
instance Integral(Word8)
  def idiv(x, y) = %idiv(x, y)
  def rem(x, y) = %irem(x, y)
instance Integral(Word32)
  def idiv(x, y) = %idiv(x, y)
  def rem(x, y) = %irem(x, y)
instance Integral(Word64)
  def idiv(x, y) = %idiv(x, y)
  def rem(x, y) = %irem(x, y)

-- Nat works on its Word32 representation and re-wraps the result.
instance Integral(Nat)
  def idiv(x, y) =
    quotient = %idiv(nat_to_rep x, nat_to_rep y)
    rep_to_nat quotient
  def rem(x, y) =
    remainder = %irem(nat_to_rep x, nat_to_rep y)
    rep_to_nat remainder

Fractional

Rational-like things, such as floating-point numbers and two-field (numerator/denominator) rational representations. Only the floating-point instances are defined here.

-- Exact-style division for field-like types.
interface Fractional(a)
  divide : (a, a) -> a

instance Fractional(Float64)
  def divide(num, den) = %fdiv(num, den)
instance Fractional(Float32)
  def divide(num, den) = %fdiv(num, den)

Index set interface and instances

-- Index sets: types whose elements are numbered by `ordinal` in
-- [0, size) and can be rebuilt from such a number by `unsafe_from_ordinal`
-- (which does no bounds check).
interface Ix(n|Data)
  size' : () -> Nat
  ordinal : (n) -> Nat
  unsafe_from_ordinal : (Nat) -> n

-- Number of elements of the index set `n`.
def size(n|Ix) -> Nat = size'(n=n)

-- Built-in index set parameterised by a Nat (presumably {0, ..., n-1}).
def Fin(n:Nat) -> Type = %Fin(n)

-- Subtraction on Nats that clamps at zero instead of wrapping around.
def (-|)(x: Nat, y:Nat) -> Nat =
  xr = nat_to_rep x
  yr = nat_to_rep y
  would_underflow = %ilt(xr, yr)
  rep_to_nat %select(would_underflow, 0, %isub(xr, yr))

-- Nat subtraction without an underflow check; the caller guarantees x >= y.
def unsafe_nat_diff(x:Nat, y:Nat) -> Nat =
  rep_to_nat %isub(nat_to_rep x, nat_to_rep y)
-- Sub-ranges of an index set `q`, cut at a pivot element `i : q`. Each value
-- stores its position *within the range* in `val`; the `Subset` instances
-- further down convert to and from positions in `q`.
-- TODO: need a way to indicate the constructors as private

-- `(i..)` parses as `RangeFrom _ i`: elements from i onwards, inclusive.
struct RangeFrom(q:Type, i:q) =
  val : Nat

-- `(i<..)` parses as `RangeFromExc _ i`: elements strictly after i.
struct RangeFromExc(q:Type, i:q) =
  val : Nat

-- `(..i)` parses as `RangeTo _ i`: elements up to and including i.
struct RangeTo(q:Type, i:q) =
  val : Nat

-- `(..<i)` parses as `RangeToExc _ i`: elements strictly before i.
struct RangeToExc(q:Type, i:q) =
  val : Nat

instance Ix(RangeFrom q i) given (q|Ix, i:q)
  def size'() = unsafe_nat_diff(size q, ordinal i)
  def ordinal(k) = k.val
  def unsafe_from_ordinal(k) = RangeFrom(k)

instance Ix(RangeFromExc q i) given (q|Ix, i:q)
  def size'() = unsafe_nat_diff(size q, ordinal i + 1)
  def ordinal(k) = k.val
  def unsafe_from_ordinal(k) = RangeFromExc(k)

instance Ix(RangeTo q i) given (q|Ix, i:q)
  def size'() = ordinal i + 1
  def ordinal(k) = k.val
  def unsafe_from_ordinal(k) = RangeTo(k)

instance Ix(RangeToExc q i) given (q|Ix, i:q)
  def size'() = ordinal i
  def ordinal(k) = k.val
  def unsafe_from_ordinal(k) = RangeToExc(k)

-- The unit type is an index set with exactly one element.
instance Ix(())
  def size'() = 1
  def ordinal(_) = 0
  def unsafe_from_ordinal(_) = ()

-- Table mapping each index of `n` to its ordinal.
def iota(n|Ix) -> n=>Nat = for k. ordinal k

Arithmetic instances for table types

-- Elementwise arithmetic on tables; the identities are constant tables.
instance Add(n=>a) given (a|Add, n|Ix)
  def (+)(xs, ys) = for k. xs[k] + ys[k]
  zero = for _. zero
instance Sub(n=>a) given (a|Sub, n|Ix)
  def (-)(xs, ys) = for k. xs[k] - ys[k]

-- Upper triangular tables (row k holds columns k..)
instance Add((i:n) => (i..) => a) given (a|Add, n|Ix)
  def (+)(xs, ys) = for k. xs[k] + ys[k]
  zero = for _. zero
instance Sub((i:n) => (i..) => a) given (a|Sub, n|Ix)
  def (-)(xs, ys) = for k. xs[k] - ys[k]

-- Lower triangular tables (row k holds columns ..k)
instance Add((i:n) => (..i) => a) given (a|Add, n|Ix)
  def (+)(xs, ys) = for k. xs[k] + ys[k]
  zero = for _. zero
instance Sub((i:n) => (..i) => a) given (a|Sub, n|Ix)
  def (-)(xs, ys) = for k. xs[k] - ys[k]

-- Strictly lower triangular tables (row k holds columns ..<k)
instance Add((i:n) => (..<i) => a) given (a|Add, n|Ix)
  def (+)(xs, ys) = for k. xs[k] + ys[k]
  zero = for _. zero
instance Sub((i:n) => (..<i) => a) given (a|Sub, n|Ix)
  def (-)(xs, ys) = for k. xs[k] - ys[k]

-- Strictly upper triangular tables (row k holds columns k<..)
instance Add((i:n) => (i<..) => a) given (a|Add, n|Ix)
  def (+)(xs, ys) = for k. xs[k] + ys[k]
  zero = for _. zero
instance Sub((i:n) => (i<..) => a) given (a|Sub, n|Ix)
  def (-)(xs, ys) = for k. xs[k] - ys[k]

instance Mul(n=>a) given (a|Mul, n|Ix)
  def (*)(xs, ys) = for k. xs[k] * ys[k]
  one = for _. one

Basic polymorphic functions and types

-- Pair projections and swapping.
def fst(pair:(a, b)) -> a given (a, b) = pair.0
def snd(pair:(a, b)) -> b given (a, b) = pair.1
def swap(pair:(a, b)) -> (b, a) given (a, b) = (pair.1, pair.0)

-- Componentwise arithmetic on pairs.
instance Add((a, b)) given (a|Add, b|Add)
  def (+)(x, y) = (x.0 + y.0, x.1 + y.1)
  zero = (zero, zero)
instance Sub((a, b)) given (a|Sub, b|Sub)
  def (-)(x, y) = (x.0 - y.0, x.1 - y.1)

-- Tuples are index sets in row-major order: the last component varies fastest.
instance Ix((a, b)) given (a|Ix, b|Ix)
  def size'() = size a * size b
  def ordinal(pair) = ordinal(pair.0) * size b + ordinal(pair.1)
  def unsafe_from_ordinal(o) =
    inner = size b
    (unsafe_from_ordinal(n=a, idiv(o, inner)), unsafe_from_ordinal(n=b, rem(o, inner)))

-- Larger tuples reuse the pair instance by nesting to the right.
instance Ix((a, b, c)) given (a|Ix, b|Ix, c|Ix)
  def size'() = size a * size b * size c
  def ordinal(tup) =
    (x, y, z) = tup
    ordinal((x, (y, z)))
  def unsafe_from_ordinal(o) =
    nested : (a, (b, c)) = unsafe_from_ordinal o
    (x, (y, z)) = nested
    (x, y, z)

instance Ix((a, b, c, d)) given (a|Ix, b|Ix, c|Ix, d|Ix)
  def size'() = size a * size b * size c * size d
  def ordinal(tup) =
    (w, x, y, z) = tup
    ordinal((w, (x, (y, z))))
  def unsafe_from_ordinal(o) =
    nested : (a, (b, (c, d))) = unsafe_from_ordinal o
    (w, (x, (y, z))) = nested
    (w, x, y, z)

Vector spaces

-- Vector spaces over `Float`: `s .* v` scales `v` by the scalar `s`.
interface VSpace(a|Add|Sub)
  (.*) : (Float, a) -> a

-- Scaling with the scalar on the right.
def (*.)(v:a, s:Float) -> a given (a|VSpace) = s .* v
-- Division by a scalar, implemented as scaling by its reciprocal.
def (/)( v:a, s:Float) -> a given (a|VSpace) =
  recip = divide(1.0, s)
  recip .* v
-- Additive inverse.
def neg( v:a) -> a given (a|VSpace) = (-1.0) .* v

instance VSpace(Float)
  def (.*)(s, x) = s * x
instance VSpace(n=>a) given (a|VSpace, n|Ix)
  def (.*)(s, xs) = for k. s .* xs[k]
instance VSpace((a, b)) given (a|VSpace, b|VSpace)
  def (.*)(s, pair) = (s .* pair.0, s .* pair.1)
-- Triangular tables scale row by row.
instance VSpace((i:n) => (..i) => a) given (n|Ix, a|VSpace)
  def (.*)(s, xs) = for k. s .* xs[k]
instance VSpace((i:n) => (i..) => a) given (n|Ix, a|VSpace)
  def (.*)(s, xs) = for k. s .* xs[k]
instance VSpace((i:n) => (..<i) => a) given (n|Ix, a|VSpace)
  def (.*)(s, xs) = for k. s .* xs[k]
instance VSpace((i:n) => (i<..) => a) given (n|Ix, a|VSpace)
  def (.*)(s, xs) = for k. s .* xs[k]
instance VSpace(())
  def (.*)(_, _) = ()

Boolean type

data Bool =
  False
  True

-- Bool <-> Word8 via the constructor tag (`%dataConTag` / `%toEnum`).
def b_to_w8(x:Bool) -> Word8 = %dataConTag(x)
def w8_to_b(x:Word8) -> Bool = %toEnum(Bool, x)

-- Logical operators computed on the Word8 tags. Both operands are ordinary
-- (already evaluated) arguments, so there is no short-circuiting.
def (&&)(x:Bool, y:Bool) -> Bool = w8_to_b %and(b_to_w8 x, b_to_w8 y)
def (||)(x:Bool, y:Bool) -> Bool = w8_to_b %or(b_to_w8 x, b_to_w8 y)
def not(x:Bool) -> Bool = w8_to_b %not(b_to_w8 x)

More Boolean operations

TODO: move these with the others?

-- Choose `x` when `p` holds, otherwise `y`, for any type `a`.
-- `%select` is not usable here: it lowers to `ISelect`, which requires `a`
-- to be a `BaseTy`.
def select(p:Bool, x:a, y:a) -> a given (a) =
  case p of
    False -> y
    True -> x

-- Bool -> numeric conversions, all going through the Word8 tag.
def b_to_i(x:Bool) -> Int = w8_to_i $ b_to_w8 x
def b_to_n(x:Bool) -> Nat = w8_to_n $ b_to_w8 x
def b_to_f(x:Bool) -> Float = i_to_f $ b_to_i x

Ordering

TODO: move this down to be with Ord?

-- Result of a three-way comparison.
data Ordering =
  LT
  EQ
  GT

-- Constructor tag of an Ordering; `Eq(Ordering)` compares these tags.
def o_to_w8(x:Ordering) -> Word8 =
  %dataConTag(x)

Sum types

A sum type, or tagged union, can hold values from a fixed set of types, distinguished by tags. For those familiar with the C language, they can be thought of as a combination of an enum with a union. Here we define two basic sum types, `Maybe` and `Either`, and some operations on them.

-- An optional value.
data Maybe(a:Type) =
  Nothing
  Just(a)

def is_nothing(x:Maybe a) -> Bool given (a) =
  case x of
    Just(_) -> False
    Nothing -> True

def is_just(x:Maybe a) -> Bool given (a) =
  case x of
    Just(_) -> True
    Nothing -> False

-- Eliminate a Maybe: apply `f` to a present value, otherwise return `d`.
def maybe(d:b, f:(a)->b, x:Maybe a) -> b given (a, b) =
  case x of
    Just(v) -> f v
    Nothing -> d

-- A value of one of two types.
data Either(a:Type, b:Type) =
  Left(a)
  Right(b)

-- Left values take ordinals [0, size a); Right values follow, offset by size a.
instance Ix(Either(a, b)) given (a|Ix, b|Ix)
  def size'() = size a + size b
  def ordinal(e) =
    case e of
      Left(l) -> ordinal l
      Right(r) -> ordinal r + size a
  def unsafe_from_ordinal(o) =
    left_count = size a
    -- Ord(Nat) is not defined yet at this point, so compare the reps directly.
    -- TODO: Reshuffle the prelude to be able to use (<) here
    in_left = w8_to_b %ilt(nat_to_rep o, nat_to_rep left_count)
    if in_left
      then Left $ unsafe_from_ordinal(n=a, o)
      else Right $ unsafe_from_ordinal(n=b, unsafe_nat_diff(o, left_count))

Subtraction on Nats

-- TODO: think more about the right API here

-- Reinterpret an Int as a Nat without checking its sign.
def unsafe_i_to_n(x:Int) -> Nat = rep_to_nat(internal_cast(x))

def n_to_i(x:Nat) -> Int = internal_cast(nat_to_rep(x))

-- Checked conversion: `Nothing` for negative inputs.
def i_to_n(x:Int) -> Maybe Nat =
  negative = w8_to_b %ilt(x, 0::Int)
  case negative of
    True -> Nothing
    False -> Just $ unsafe_i_to_n x

Monoid

A monoid is a type with an associative binary operator and an identity element for it. This is a very useful and general class of things. It includes:

  • Addition and Multiplication of Numbers
  • Boolean Logic
  • Concatenation of lists (including strings)

Monoids support fold operations and similar reductions.
-- Associative `<>` with identity `mempty`.
interface Monoid(a|Data)
  mempty : a
  (<>) : (a, a) -> a

-- Tables combine elementwise.
instance Monoid(n=>a) given (a|Monoid, n|Ix)
  mempty = for _. mempty
  def (<>)(x, y) = for k. x[k] <> y[k]

-- Bool has two useful monoids, so both are named and chosen explicitly.
named-instance AndMonoid : Monoid(Bool)
  mempty = True
  def (<>)(x, y) = x && y
named-instance OrMonoid : Monoid(Bool)
  mempty = False
  def (<>)(x, y) = x || y

-- Monoids derived from the arithmetic interfaces.
named-instance AddMonoid(a|Add) -> Monoid(a)
  mempty = zero
  def (<>)(x, y) = x + y
named-instance MulMonoid(a|Mul) -> Monoid(a)
  mempty = one
  def (<>)(x, y) = x * y

Effects

-- A reference in heap `r` holding values of type `a`.
def Ref(r:Heap, a|Data) -> Type = %Ref(r, a)
-- Read the current value of a reference under the State effect.
def get(ref:Ref h s) -> {State h} s given (h, s) = %get(ref)
-- Overwrite the value of a reference under the State effect.
def (:=)(ref:Ref h s, x:s) -> {State h} () given (h, s) = %put(ref, x)
-- Read the value of a reference under the Read effect.
def ask(ref:Ref h r) -> {Read h} r given (h, r) = %ask(ref)
-- Packs a type `b` together with a Monoid dictionary for it; used as the
-- accumulation monoid for references of type `w` in heap `h`.
-- NOTE(review): the invariant implied by "Unsafe" in the constructor name
-- (how `b` relates to `w`) is not visible in this file -- confirm.
data AccumMonoidData(h:Heap, w:Type) =
  UnsafeMkAccumMonoidData(b:Type, Monoid b)
-- Supplies the accumulation monoid data for references of type `w` in heap `h`.
interface AccumMonoid(h:Heap, w)
  getAccumMonoidData : AccumMonoidData(h, w)
-- A table reference reuses the element type's accumulation monoid data.
instance AccumMonoid(h, n=>w) given (n|Ix, h, w) (am:AccumMonoid(h, w))
  getAccumMonoidData =
    UnsafeMkAccumMonoidData(b, bm) = %applyMethod0(am)
    UnsafeMkAccumMonoidData(b, bm)
-- Accumulate `x` into `ref` using the monoid from AccumMonoid(h, w).
def (+=)(ref:Ref h w, x:w) -> {Accum h} () given (h, w) (am:AccumMonoid(h, w)) =
  UnsafeMkAccumMonoidData(b, bm) = %applyMethod0(am)
  -- presumably method 0 / method 1 of Monoid are `mempty` / `<>`, following
  -- the declaration order of the Monoid interface.
  empty = %applyMethod0(bm)
  -- The lambda's `x`/`y` shadow the outer `x`; the final argument is the
  -- outer `x`, the value being accumulated.
  %mextend(ref, empty, \x:b y:b. %applyMethod1(bm, x, y), x)
-- Reference to element `i` of a table reference.
def (!)(ref: Ref h (n=>a), i:n) -> Ref h a given (n|Ix, a|Data, h) = %indexRef(ref, i)
-- References to the first / second component of a pair reference.
def fst_ref(ref: Ref h (a,b)) -> Ref h a given (b, a|Data, h) = ref.0
def snd_ref(ref: Ref h (a,b)) -> Ref h b given (a, b|Data, h) = ref.1
-- Run `action` with a Read reference initialised to `init`, discharging the
-- Read effect.
def run_reader(
    init:r,
    action:(given (h), Ref h r) -> {Read h|eff} a
    ) -> {|eff} a given (r|Data, a, eff) =
  -- Make the implicit heap parameter explicit for the primitive handler.
  def explicitAction(h':Heap, ref:Ref h' r) -> {Read h'|eff} a = action ref
  %runReader(init, explicitAction)
-- Same as `run_reader`.
def with_reader(
    init:r,
    action: (given (h), Ref(h,r)) -> {Read h|eff} a
    ) -> {|eff} a given (r|Data, a, eff) =
  run_reader(init, action)
-- Evidence that, in any heap `h`, an accumulation monoid on `b` yields one
-- on `w`.
def MonoidLifter(b:Type, w:Type) -> Type = (given (h) (AccumMonoid(h, b))) ->> AccumMonoid(h, w)
-- Build an AccumMonoid dictionary from explicit AccumMonoidData.
named-instance mk_accum_monoid (given (h, w), d:AccumMonoidData(h, w)) -> AccumMonoid(h, w)
  getAccumMonoidData = d
-- Run `action` with an Accum reference of type `w`, using the monoid `bm`
-- (lifted to `w` via MonoidLifter). Returns the action's result paired with
-- the accumulated value.
def run_accum(
    bm:Monoid b,
    action: (given (h) (AccumMonoid(h, b)), Ref h w) -> {Accum h|eff} a
    ) -> {|eff} (a, w) given (a, b, w|Data, eff) (MonoidLifter(b,w)) =
  -- presumably `mempty` of `bm` (method 0), used as the initial value.
  empty = %applyMethod0(bm)
  -- Instantiate the action at an explicit heap and hand it an AccumMonoid
  -- dictionary built from `bm`.
  def explicitAction(h':Heap, ref:Ref h' w) -> {Accum h'|eff} a =
    accumMonoidData : AccumMonoidData h' b = UnsafeMkAccumMonoidData b bm
    accumBaseMonoid = mk_accum_monoid accumMonoidData
    %explicitApply(action, h', accumBaseMonoid, ref)
  %runWriter(empty, \x:b y:b. %applyMethod1(bm, x, y), explicitAction)
-- Like `run_accum`, returning only the accumulated value.
def yield_accum(
    m:Monoid b,
    action: (given (h) (AccumMonoid(h, b)), Ref h w) -> {Accum h|eff} a
    ) -> {|eff} w given (a, b, w|Data, eff) (MonoidLifter b w) =
  snd $ run_accum(m, action)
-- Run `action` with a State reference initialised to `init`; returns the
-- action's result paired with the final state.
def run_state(
    init:s,
    action: (given (h), Ref h s) -> {State h |eff} a
    ) -> {|eff} (a,s) given (a, s|Data, eff) =
  def explicitAction(h':Heap, ref:Ref h' s) -> {State h'|eff} a = action ref
  %runState(init, explicitAction)
-- Like `run_state`, returning only the action's result.
def with_state(
    init:s,
    action: (given (h), Ref h s) -> {State h |eff} a
    ) -> {|eff} a given (a, s|Data, eff) =
  fst $ run_state(init, action)
-- Like `run_state`, returning only the final state.
def yield_state(
    init:s,
    action: (given (h), Ref h s) -> {State h |eff} a
    ) -> {|eff} s given (a, s|Data, eff) =
  snd $ run_state(init, action)
-- Discharge the IO effect of `f` with `%runIO`, so it can be called from
-- code that does not itself have IO.
def unsafe_io(
    f:()->{IO|eff} a
    ) -> {|eff} a given (a, eff) =
  f' : (() -> {IO|eff} a) = \. f()
  %runIO(f')
-- Raises a runtime error (`%throwError`) if evaluated.
def unreachable() -> a given (a|Data) = unsafe_io \. %throwError(a)

Type classes

Eq and Ord

Eq

Equatable. Things that we can tell if they are equal or not to other things.

-- Types with a decidable equality.
interface Eq(a|Data)
  (==) : (a, a) -> Bool

-- Inequality, the negation of `==`.
def (/=)(x:a, y:a) -> Bool given (a|Eq) =
  not (x == y)

Ord

Orderable / Comparable. Things that can be placed in a total order, i.e. things that can be compared to other things to find whether they are larger, smaller or equal in value.

We take the standard falsehood and pretend that this applies to Floats, even though strictly speaking this is not true: our floats follow IEEE 754, and thus have `NaN < 1.0 == False` and `1.0 < NaN == False`.

-- Strict comparisons; equality comes from the `Eq` superclass.
interface Ord(a|Eq)
  (>) : (a, a) -> Bool
  (<) : (a, a) -> Bool

-- Non-strict comparisons, spelled out as "strict or equal". Defining `<=`
-- as `not (x > y)` would give a different answer for unordered values such
-- as NaN.
def (<=)(x:a, y:a) -> Bool given (a|Ord) =
  lt = x < y
  lt || x == y
def (>=)(x:a, y:a) -> Bool given (a|Ord) =
  gt = x > y
  gt || x == y
-- Scalars compare through the `%feq` / `%ieq` primitives.
instance Eq(Float64)
  def (==)(x, y) = w8_to_b %feq(x, y)
instance Eq(Float32)
  def (==)(x, y) = w8_to_b %feq(x, y)
instance Eq(Int64)
  def (==)(x, y) = w8_to_b %ieq(x, y)
instance Eq(Int32)
  def (==)(x, y) = w8_to_b %ieq(x, y)
instance Eq(Word8)
  def (==)(x, y) = w8_to_b %ieq(x, y)
instance Eq(Word32)
  def (==)(x, y) = w8_to_b %ieq(x, y)
instance Eq(Word64)
  def (==)(x, y) = w8_to_b %ieq(x, y)

-- Bools compare by constructor tag.
instance Eq(Bool)
  def (==)(x, y) = b_to_w8 x == b_to_w8 y

instance Eq(())
  def (==)(_, _) = True

-- Equal only when both sides use the same constructor with equal payloads.
instance Eq(Either(a, b)) given (a|Eq, b|Eq)
  def (==)(x, y) =
    case x of
      Left(l) ->
        case y of
          Left(l') -> l == l'
          Right(_) -> False
      Right(r) ->
        case y of
          Left(_) -> False
          Right(r') -> r == r'

instance Eq(Maybe a) given (a|Eq)
  def (==)(x, y) =
    case x of
      Just(u) ->
        case y of
          Just(v) -> u == v
          Nothing -> False
      Nothing -> is_nothing y

-- Pointers compare by address.
instance Eq(RawPtr)
  def (==)(x, y) = raw_ptr_to_i64 x == raw_ptr_to_i64 y
-- Scalar orderings via the comparison primitives.
instance Ord(Float64)
  def (>)(x, y) = w8_to_b %fgt(x, y)
  def (<)(x, y) = w8_to_b %flt(x, y)
instance Ord(Float32)
  def (>)(x, y) = w8_to_b %fgt(x, y)
  def (<)(x, y) = w8_to_b %flt(x, y)
instance Ord(Int64)
  def (>)(x, y) = w8_to_b %igt(x, y)
  def (<)(x, y) = w8_to_b %ilt(x, y)
instance Ord(Int32)
  def (>)(x, y) = w8_to_b %igt(x, y)
  def (<)(x, y) = w8_to_b %ilt(x, y)
instance Ord(Word8)
  def (>)(x, y) = w8_to_b %igt(x, y)
  def (<)(x, y) = w8_to_b %ilt(x, y)
instance Ord(Word32)
  def (>)(x, y) = w8_to_b %igt(x, y)
  def (<)(x, y) = w8_to_b %ilt(x, y)
instance Ord(Word64)
  def (>)(x, y) = w8_to_b %igt(x, y)
  def (<)(x, y) = w8_to_b %ilt(x, y)

-- The single unit value is never strictly greater or less than itself.
instance Ord(())
  def (>)(_, _) = False
  def (<)(_, _) = False
-- Pairs: componentwise equality, lexicographic ordering.
instance Eq((a, b)) given (a|Eq, b|Eq)
  def (==)(p, q) = p.0 == q.0 && p.1 == q.1
instance Ord((a, b)) given (a|Ord, b|Ord)
  def (>)(p, q) = p.0 > q.0 || (p.0 == q.0 && p.1 > q.1)
  def (<)(p, q) = p.0 < q.0 || (p.0 == q.0 && p.1 < q.1)

-- Orderings compare by constructor tag.
instance Eq(Ordering)
  def (==)(x, y) = o_to_w8 x == o_to_w8 y

-- Nats compare through their Word32 representation.
instance Eq(Nat)
  def (==)(x, y) = nat_to_rep x == nat_to_rep y
instance Ord(Nat)
  def (>)(x, y) = nat_to_rep x > nat_to_rep y
  def (<)(x, y) = nat_to_rep x < nat_to_rep y

-- TODO: we want Eq and Ord for all index sets, not just `Fin n`
instance Eq(Fin n) given (n)
  def (==)(x, y) = ordinal x == ordinal y
instance Ord(Fin n) given (n)
  def (>)(x, y) = ordinal x > ordinal y
  def (<)(x, y) = ordinal x < ordinal y
-- Bool as a two-element index set: False is 0, True is 1.
instance Ix(Bool)
  def size'() = 2
  def ordinal(b) = select(b, 1, 0)
  def unsafe_from_ordinal(i) = i > 0

-- Maybe a: the Just values take ordinals [0, size a); Nothing comes last.
instance Ix(Maybe a) given (a|Ix)
  def size'() = size a + 1
  def ordinal(m) =
    case m of
      Just(v) -> ordinal v
      Nothing -> size a
  def unsafe_from_ordinal(o) =
    if o == size a
      then Nothing
      else Just $ unsafe_from_ordinal o

-- Index sets with at least one element; `first_ix` is the element with
-- ordinal 0.
interface NonEmpty(n|Ix)
  first_ix : n

instance NonEmpty(())
  first_ix = unsafe_from_ordinal 0
instance NonEmpty(Bool)
  first_ix = unsafe_from_ordinal 0
instance NonEmpty((a,b)) given (a|NonEmpty, b|NonEmpty)
  first_ix = unsafe_from_ordinal 0
instance NonEmpty(Either(a,b)) given (a|NonEmpty, b|Ix)
  first_ix = unsafe_from_ordinal 0
-- The below instance is valid, but causes "multiple candidate dictionaries"
-- errors if both Left and Right are NonEmpty.
-- instance NonEmpty (a|b) given {a b} [Ix a, NonEmpty b]
--   first_ix = unsafe_from_ordinal _ 0
-- Maybe a always contains at least `Nothing`.
instance NonEmpty(Maybe a) given (a|Ix)
  first_ix = unsafe_from_ordinal 0

Fencepost index sets

-- Fenceposts around the elements of an index set: `size segment + 1` positions.
struct Post(segment:Type) =
  val : Nat

instance Ix(Post segment) given (segment|Ix)
  def size'() = size segment + 1
  def ordinal(p) = p.val
  def unsafe_from_ordinal(k) = Post(k)

-- The posts immediately before and after element `i`.
def left_post(i:n) -> Post n given (n|Ix) = Post(ordinal i)
def right_post(i:n) -> Post n given (n|Ix) = Post(ordinal i + 1)

-- The element to the left of a post; Nothing for the first post.
def left_fence(p:Post n) -> Maybe n given (n|Ix) =
  k = ordinal p
  case k == 0 of
    True -> Nothing
    False -> Just $ unsafe_from_ordinal $ k -| 1

-- The element to the right of a post; Nothing for the last post.
def right_fence(p:Post n) -> Maybe n given (n|Ix) =
  k = ordinal p
  case k == size n of
    True -> Nothing
    False -> Just $ unsafe_from_ordinal k

-- The element with the largest ordinal; NonEmpty guarantees size n >= 1.
def last_ix() ->> n given (n|NonEmpty) =
  unsafe_from_ordinal(unsafe_nat_diff(size n, 1))

instance NonEmpty(Post n) given (n|Ix)
  first_ix = unsafe_from_ordinal(n=Post n, 0)
-- Sequential scan: threads a carry through `body` over the indices of `n` in
-- order, collecting each step's output. Returns (final carry, outputs).
def scan(
    init:a,
    body:(n, a)->(a,b)
    ) -> (a, n=>b) given (a|Data, b, n|Ix) =
  swap $ run_state(init) \carry.
    for i.
      (next, out) = body(i, get carry)
      carry := next
      out

-- Left fold over the indices of `n`, keeping only the final carry.
def fold(init:a, body:(n,a)->a) -> a given (n|Ix, a|Data) =
  final_and_units = scan init \i x. (body(i, x), ())
  fst final_and_units

-- Three-way comparison.
def compare(x:a, y:a) -> Ordering given (a|Ord) =
  if x < y
    then LT
    else select(x == y, EQ, GT)

-- Lexicographic combination: the first non-EQ result wins.
instance Monoid(Ordering)
  mempty = EQ
  def (<>)(x, y) =
    case x of
      EQ -> y
      LT -> LT
      GT -> GT

-- Tables are equal when every pair of corresponding elements is equal.
instance Eq(n=>a) given (n|Ix, a|Eq)
  def (==)(xs, ys) =
    yield_accum AndMonoid \acc.
      for i. acc += xs[i] == ys[i]

-- Tables are ordered lexicographically, in index order.
instance Ord(n=>a) given (n|Ix, a|Ord)
  def (>)(xs, ys) =
    verdict : Ordering = fold EQ \i acc. acc <> compare(xs[i], ys[i])
    verdict == GT
  def (<)(xs, ys) =
    verdict : Ordering = fold EQ \i acc. acc <> compare(xs[i], ys[i])
    verdict == LT

Subset class

-- Embedding of `subset` into `superset`, with a checked and an unchecked
-- projection back.
interface Subset(subset, superset)
  inject' : (subset) -> superset
  project' : (superset) -> Maybe subset
  unsafe_project' : (superset) -> subset

-- wrappers with more helpful implicit arg names
def inject(x:from) -> to given (to, from) (Subset(from, to)) =
  inject'(x)
def project(x:from) -> Maybe to given (to, from) (Subset(to, from)) =
  project'(x)
def unsafe_project(x:from) -> to given (to, from) (Subset(to, from)) =
  unsafe_project'(x)

-- Subset is transitive: a ⊂ b and b ⊂ c give a ⊂ c.
instance Subset(a, c) given (a, b, c) (Subset(a, b), Subset(b, c))
  def inject'(x) = inject(inject(to=b, x))
  def project'(x) =
    case project(to=b, x) of
      Just(y) -> project y
      Nothing -> Nothing
  def unsafe_project'(x) = unsafe_project(unsafe_project(to=b, x))
-- Re-express `j : q` as a position within `(i..)`; the caller guarantees
-- that j comes at or after i.
def unsafe_project_rangefrom(j:q) -> RangeFrom(q, i) given (q|Ix, i:q) =
  RangeFrom(unsafe_nat_diff(ordinal j, ordinal i))

-- `(i..)` sits inside q at offset `ordinal i`.
instance Subset(RangeFrom(q, i), q) given (q|Ix, i:q)
  def inject'(j) = unsafe_from_ordinal(j.val + ordinal i)
  def project'(j) =
    jo = ordinal j
    io = ordinal i
    if jo < io
      then Nothing
      else Just(RangeFrom(unsafe_nat_diff(jo, io)))
  def unsafe_project'(j) = RangeFrom(unsafe_nat_diff(ordinal j, ordinal i))

-- `(i<..)` sits inside q at offset `ordinal i + 1`.
instance Subset(RangeFromExc(q, i), q) given (q|Ix, i:q)
  def inject'(j) = unsafe_from_ordinal(j.val + ordinal i + 1)
  def project'(j) =
    jo = ordinal j
    io = ordinal i
    if jo <= io
      then Nothing
      else Just(RangeFromExc(unsafe_nat_diff(jo, io + 1)))
  def unsafe_project'(j) = RangeFromExc(unsafe_nat_diff(ordinal j, ordinal i + 1))

-- `(..i)` is a prefix of q, so positions coincide.
instance Subset(RangeTo(q, i), q) given (q|Ix, i:q)
  def inject'(j) = unsafe_from_ordinal(j.val)
  def project'(j) =
    jo = ordinal j
    io = ordinal i
    if jo > io
      then Nothing
      else Just(RangeTo(jo))
  def unsafe_project'(j) = RangeTo(ordinal j)

-- `(..<i)` is a prefix of q, so positions coincide.
instance Subset(RangeToExc(q, i), q) given (q|Ix, i:q)
  def inject'(j) = unsafe_from_ordinal(j.val)
  def project'(j) =
    jo = ordinal j
    io = ordinal i
    if jo >= io
      then Nothing
      else Just(RangeToExc(jo))
  def unsafe_project'(j) = RangeToExc(ordinal j)

-- `(..<i)` is also a prefix of `(..i)`.
instance Subset(RangeToExc(q, i), RangeTo(q, i)) given (q|Ix, i:q)
  def inject'(j) = unsafe_from_ordinal(j.val)
  def project'(j) =
    jo = ordinal j
    io = ordinal i
    if jo >= io
      then Nothing
      else Just(RangeToExc(jo))
  def unsafe_project'(j) = RangeToExc(ordinal j)

Elementary/Special Functions

This is more or less the standard LibM fare. Roughly, it lines up with some definitions of the set of Elementary and/or Special functions. In truth, nothing is elementary or special except that we humans have decided it is. Many, but not all, of these functions are transcendental.

-- Elementary and special functions on floating-point-like types.
interface Floating(a:Type)
  exp : (a) -> a
  exp2 : (a) -> a
  log : (a) -> a
  log2 : (a) -> a
  log10 : (a) -> a
  log1p : (a) -> a
  sin : (a) -> a
  cos : (a) -> a
  tan : (a) -> a
  sinh : (a) -> a
  cosh : (a) -> a
  tanh : (a) -> a
  floor : (a) -> a
  ceil : (a) -> a
  round : (a) -> a
  sqrt : (a) -> a
  pow : (a, a) -> a
  lgamma : (a) -> a
  erf : (a) -> a
  erfc : (a) -> a

-- Log of the Beta function: log B(x, y) = lgamma x + lgamma y - lgamma (x + y).
def lbeta(x:a, y:a) -> a given (a|Sub|Floating) =
  log_numerator = lgamma x + lgamma y
  log_numerator - lgamma (x + y)
-- Hyperbolic functions for Float32, written against `%exp` and the other
-- primitives directly to avoid circular definition problems: the Floating
-- instances are defined in terms of these.
-- Todo: better numerics for very small values (sinh and tanh lose relative
-- precision near zero through cancellation) and for sinh/cosh near overflow.
def float32_sinh(x:Float32) -> Float32 = %fdiv(%fsub(%exp(x), %exp(%fsub(0.0,x))), 2.0)
def float32_cosh(x:Float32) -> Float32 = %fdiv(%fadd(%exp(x), %exp(%fsub(0.0,x))), 2.0)
-- tanh(x) = sign(x) * (1 - e) / (1 + e), with e = exp(-2|x|).
-- Only a non-positive argument is ever exponentiated, so e stays in [0, 1]
-- and the result saturates to +/-1 for large |x|. The textbook form
-- (e^x - e^-x) / (e^x + e^-x) evaluates inf/inf = NaN once exp overflows
-- Float32 (|x| above roughly 88.7). NaN inputs still propagate to NaN.
def float32_tanh(x:Float32) -> Float32 =
  is_neg = %flt(x, 0.0)
  abs_x = %select(is_neg, %fsub(0.0, x), x)
  e = %exp(%fsub(0.0, %fadd(abs_x, abs_x)))
  t = %fdiv(%fsub(1.0, e), %fadd(1.0, e))
  -- tanh is odd: restore the sign of the input.
  %select(is_neg, %fsub(0.0, t), t)
-- Hyperbolic functions for Float64, written against `%exp` and the other
-- primitives directly; float literals are Float32 by default, hence `f_to_f64`.
-- Todo: unify this with float32 functions.
def float64_sinh(x:Float64) -> Float64 = %fdiv(%fsub(%exp(x), %exp(%fsub(f_to_f64 0.0, x))), f_to_f64 2.0)
def float64_cosh(x:Float64) -> Float64 = %fdiv(%fadd(%exp(x), %exp(%fsub(f_to_f64 0.0, x))), f_to_f64 2.0)
-- tanh(x) = sign(x) * (1 - e) / (1 + e), with e = exp(-2|x|).
-- Only a non-positive argument is ever exponentiated, so e stays in [0, 1]
-- and the result saturates to +/-1 for large |x|. The textbook form
-- (e^x - e^-x) / (e^x + e^-x) evaluates inf/inf = NaN once exp overflows
-- Float64 (|x| above roughly 709.8). NaN inputs still propagate to NaN.
def float64_tanh(x:Float64) -> Float64 =
  zero64 = f_to_f64 0.0
  one64 = f_to_f64 1.0
  is_neg = %flt(x, zero64)
  abs_x = %select(is_neg, %fsub(zero64, x), x)
  e = %exp(%fsub(zero64, %fadd(abs_x, abs_x)))
  t = %fdiv(%fsub(one64, e), %fadd(one64, e))
  -- tanh is odd: restore the sign of the input.
  %select(is_neg, %fsub(zero64, t), t)
-- Scalar instances: one primitive per method, except the hyperbolics, which
-- use the explicit helper definitions above.
instance Floating(Float64)
  def exp(x) = %exp(x)
  def exp2(x) = %exp2(x)
  def log(x) = %log(x)
  def log2(x) = %log2(x)
  def log10(x) = %log10(x)
  def log1p(x) = %log1p(x)
  def sin(x) = %sin(x)
  def cos(x) = %cos(x)
  def tan(x) = %tan(x)
  def sinh(x) = float64_sinh x
  def cosh(x) = float64_cosh x
  def tanh(x) = float64_tanh x
  def floor(x) = %floor(x)
  def ceil(x) = %ceil(x)
  def round(x) = %round(x)
  def sqrt(x) = %sqrt(x)
  def pow(base, expo) = %fpow(base, expo)
  def lgamma(x) = %lgamma(x)
  def erf(x) = %erf(x)
  def erfc(x) = %erfc(x)

instance Floating(Float32)
  def exp(x) = %exp(x)
  def exp2(x) = %exp2(x)
  def log(x) = %log(x)
  def log2(x) = %log2(x)
  def log10(x) = %log10(x)
  def log1p(x) = %log1p(x)
  def sin(x) = %sin(x)
  def cos(x) = %cos(x)
  def tan(x) = %tan(x)
  def sinh(x) = float32_sinh x
  def cosh(x) = float32_cosh x
  def tanh(x) = float32_tanh x
  def floor(x) = %floor(x)
  def ceil(x) = %ceil(x)
  def round(x) = %round(x)
  def sqrt(x) = %sqrt(x)
  def pow(base, expo) = %fpow(base, expo)
  def lgamma(x) = %lgamma(x)
  def erf(x) = %erf(x)
  def erfc(x) = %erfc(x)

Raw pointer operations

-- Typed wrapper around a RawPtr. The element type `a` appears only in the
-- type; the stored value is just the raw address.
struct Ptr(a:Type) =
  val : RawPtr

-- Reinterpret a pointer at another element type; the address is unchanged.
def cast_ptr(ptr: Ptr a) -> Ptr b given (a, b) =
  raw = ptr.val
  Ptr(raw)
-- Types that can be written to and read from raw memory through a `Ptr`.
-- `storage_size` is the number of bytes occupied by one element (`malloc`
-- and `+>>` scale by it).
interface Storable(a|Data)
  store : (Ptr a, a) -> {IO} ()
  load : (Ptr a) -> {IO} a
  storage_size : () -> Nat

-- Every instance casts the underlying RawPtr (`ptr.val`), not the `Ptr`
-- wrapper itself, to the primitive pointer type it needs.
instance Storable(Word8)
  def store(ptr, x) = %ptrStore(ptr.val, x)
  def load(ptr) = %ptrLoad(ptr.val)
  def storage_size() = 1

instance Storable(Int32)
  def store(ptr, x) = %ptrStore(internal_cast(to=%Int32Ptr(), ptr.val), x)
  def load(ptr) = %ptrLoad(internal_cast(to=%Int32Ptr(), ptr.val))
  def storage_size() = 4

instance Storable(Word32)
  def store(ptr, x) = %ptrStore(internal_cast(to=%Word32Ptr(), ptr.val), x)
  def load(ptr) = %ptrLoad(internal_cast(to=%Word32Ptr(), ptr.val))
  def storage_size() = 4

instance Storable(Float32)
  def store(ptr, x) = %ptrStore(internal_cast(to=%Float32Ptr(), ptr.val), x)
  def load(ptr) = %ptrLoad(internal_cast(to=%Float32Ptr(), ptr.val))
  def storage_size() = 4

-- Nats are stored as their NatRep representation.
instance Storable(Nat)
  def store(ptr, x) = store(Ptr(ptr.val), nat_to_rep x)
  def load(ptr) = rep_to_nat $ load(Ptr(ptr.val))
  def storage_size() = storage_size(a=NatRep)

-- Pointers are stored as raw addresses.
instance Storable(Ptr a) given (a)
  def store(ptr, x) = %ptrStore(internal_cast(to=%PtrPtr(), ptr.val), x.val)
  def load(ptr) = Ptr(%ptrLoad(internal_cast(to=%PtrPtr(), ptr.val)))
  -- TODO: something more portable?
  def storage_size() = 8

-- TODO: Storable instances for other types
-- Allocate uninitialised space for `n` elements of type `a`.
def malloc(n:Nat) -> {IO} (Ptr a) given (a|Storable) =
  byte_count = n * storage_size(a=a)
  Ptr(%alloc(nat_to_rep byte_count))

def free(ptr:Ptr a) -> {IO} () given (a) = %free(ptr.val)

-- Advance a pointer by `i` elements (not bytes).
def (+>>)(ptr:Ptr a, i:Nat) -> Ptr a given (a|Storable) =
  byte_offset = nat_to_rep(i * storage_size(a=a))
  Ptr(%ptrOffset(ptr.val, byte_offset))

-- Write the elements of `tab` to consecutive slots starting at `ptr`.
-- TODO: consider making a Storable instance for tables instead
def store_table(ptr: Ptr a, tab:n=>a) -> {IO} () given (a|Storable, n|Ix) =
  for_ i. store(ptr +>> ordinal i, tab[i])

-- Copy `n` elements from `src` to `dest`, one element at a time.
def memcpy(dest:Ptr a, src:Ptr a, n:Nat) -> {IO} () given (a|Storable) =
  for_ i:(Fin n).
    k = ordinal i
    store(dest +>> k, load(src +>> k))

-- TODO: generalize these brackets to allow other effects
-- TODO: make sure that freeing happens even if there are run-time errors
-- Run `action` on a fresh buffer of `n` elements, freeing it afterwards.
def with_alloc(
    n:Nat,
    action: (Ptr a) -> {IO} b
    ) -> {IO} b given (a|Storable, b) =
  buf = malloc n
  out = action buf
  free buf
  out

-- Run `action` on a temporary buffer holding a copy of `xs`.
def with_table_ptr(
    xs:n=>a,
    action: (Ptr a) -> {IO} b
    ) -> {IO} b given (a|Storable, b, n|Ix) =
  with_alloc(size n) \ptr.
    store_table(ptr, xs)
    action ptr

-- Read `size n` consecutive elements starting at `ptr`.
def table_from_ptr(ptr:Ptr a) -> {IO} n=>a given (a|Storable, n|Ix) =
  for i. load(ptr +>> ordinal i)

Miscellaneous common utilities

pi : Float = 3.141592653589793

def id(x:a) -> a given (a) = x
def dup(x:a) -> (a, a) given (a) = (x, x)

-- Apply `f` to every element; the table comes first, the function last.
def each(xs: n=>a, f:(a)->{|eff} b) -> {|eff} (n=>b) given (a, b, n|Ix, eff) =
  for k. f xs[k]
-- Same as `each`, with the function first.
def map(f:(a)->{|eff} b, xs: n=>a) -> {|eff} (n=>b) given (a, b, n|Ix, eff) =
  each(xs, f)

def zip(xs:n=>a, ys:n=>b) -> (n=>(a,b)) given (a, b, n|Ix) =
  for k. (xs[k], ys[k])
def unzip(xys:n=>(a,b)) -> (n=>a , n=>b) given (a, b, n|Ix) =
  (map(fst, xys), map(snd, xys))

-- Constant table filled with `x`.
def fanout(x:a) -> n=>a given (n|Ix, a) = for _. x

def sq(x:a) -> a given (a|Mul) = x * x

-- Absolute value: `x` if strictly above zero, otherwise `zero - x`.
def abs(x:a) -> a given (a|Sub|Ord) =
  if x > zero
    then x
    else zero - x

-- Modulus built from `rem`; assuming `rem` truncates toward zero, the result
-- takes the sign of `y`.
-- NOTE(review): `y + rem(x, y)` can overflow for fixed-width types when |y|
-- is near the type's maximum; avoiding that needs Eq/Ord on `a`, which would
-- change this signature.
def mod(x:a, y:a) -> a given (a|Add|Integral) =
  r = rem(x, y)
  rem(y + r, y)

-- Function composition: `f >>> g` runs f then g; `f <<< g` runs g then f.
def (>>>)(f:(a) -> b, g:(b) -> c) -> (a) -> c given (a, b, c) = \v. g(f(v))
def (<<<)(f:(b) -> c, g:(a) -> b) -> (a) -> c given (a, b, c) = g >>> f

Table Operations

-- Elementwise lifting of every Floating operation to tables.
-- `pow` pairs up the two tables index by index.
instance Floating(n=>a) given (a|Floating, n|Ix)
  def exp(x)    = map(exp, x)
  def exp2(x)   = map(exp2, x)
  def log(x)    = map(log, x)
  def log2(x)   = map(log2, x)
  def log10(x)  = map(log10, x)
  def log1p(x)  = map(log1p, x)
  def sin(x)    = map(sin, x)
  def cos(x)    = map(cos, x)
  def tan(x)    = map(tan, x)
  def sinh(x)   = map(sinh, x)
  def cosh(x)   = map(cosh, x)
  def tanh(x)   = map(tanh, x)
  def floor(x)  = map(floor, x)
  def ceil(x)   = map(ceil, x)
  def round(x)  = map(round, x)
  def sqrt(x)   = map(sqrt, x)
  def pow(x, y) = for i. pow(x[i], y[i])
  def lgamma(x) = map(lgamma, x)
  def erf(x)    = map(erf, x)
  def erfc(x)   = map(erfc, x)

Reductions

-- Combine all elements of `xs`, starting from `identity`, folding in index
-- order as `combine(accumulator, xs[i])`.
-- `combine` should be commutative and associative, and form a
-- commutative monoid with `identity`.
def reduce(identity:a, combine:(a,a)->a, xs:n=>a) -> a given (a|Data, n|Ix) =
  -- TODO: implement with the accumulator effect
  fold identity \i acc. combine(acc, xs[i])

-- Table of successive states produced by repeatedly applying `body`,
-- starting from `init` (each entry is the state *after* that step).
-- TODO: call this `scan` and call the current `scan` something else
def scan'(init:a, body:(n,a)->a) -> n=>a given (a|Data, n|Ix) =
  snd $ scan init \i carry. dup(body(i, carry))
-- Sum of a Float table computed with the accumulator effect.
def fsum(xs:n=>Float) -> Float given (n|Ix) =
  yield_accum(AddMonoid Float) \acc.
    for i. acc += xs[i]

-- Sum of all elements (monoid: `zero`, `+`).
def sum(xs:n=>v) -> v given (n|Ix, v|Add) = reduce(zero, (+), xs)

-- Product of all elements (monoid: `one`, `*`).
def prod(xs:n=>v) -> v given (n|Ix, v|Mul) = reduce(one , (*), xs)

-- Arithmetic mean: the sum divided by the number of elements.
def mean(xs:n=>v) -> v given (n|Ix, v|VSpace) =
  count = n_to_f (size n)
  sum xs / count

-- Population standard deviation, computed as sqrt(E[x^2] - E[x]^2).
-- NOTE(review): this one-pass formula may lose precision; if rounding makes
-- the difference slightly negative the result will be NaN.
def std(xs:n=>v) -> v given (n|Ix, v|Mul|Sub|VSpace|Floating) =
  mean_of_squares = mean (each xs sq)
  square_of_mean = sq (mean xs)
  sqrt (mean_of_squares - square_of_mean)

-- True iff at least one element is True.
def any(xs:n=>Bool) -> Bool given (n|Ix) = reduce(False, (||), xs)

-- True iff every element is True.
def all(xs:n=>Bool) -> Bool given (n|Ix) = reduce(True , (&&), xs)

apply_n

-- Apply `f` to `x` exactly `n` times: f(f(...f(x))).
def apply_n(n:Nat, x:a, f:(a) -> a) -> a given (a|Data) =
  yield_state x \st.
    for _:(Fin n).
      st := f (get st)

cumulative sum

TODO: Move this to be with reductions? It's a kind of scan.

-- Inclusive prefix sum: entry i holds xs[0] + ... + xs[i].
def cumsum(xs: n=>a) -> n=>a given (n|Ix, a|Add) =
  with_state zero \running.
    for i.
      updated = get running + xs[i]
      running := updated
      updated

-- Exclusive prefix sum: entry i holds xs[0] + ... + xs[i-1] (entry 0 is `zero`).
def cumsum_low(xs: n=>a) -> n=>a given (n|Ix, a|Add) =
  with_state zero \running.
    for i.
      before = get running
      running := before + xs[i]
      before

Automatic differentiation

AD operations

-- Forward-mode linearization: returns f(x) and the linear map df at x.
-- TODO: add vector space constraints
def linearize(f:(a)->b, x:a) -> (b, (a)->b) given (a, b) =
  %linearize(\z. f z, x)

-- Jacobian-vector product: the directional derivative of `f` at `x` along `t`.
def jvp(f:(a)->b, x:a, t:a) -> b given (a, b) =
  (_, tangent_map) = linearize(f, x)
  tangent_map(t)

-- Transpose a linear function (maps cotangents back to the input space).
def transpose_linear(f:(a)->b) -> (b)->a given (a, b) =
  \ct. %linearTranspose(\z. f z, ct)

-- Vector-Jacobian product: returns f(x) and the pullback of cotangents.
def vjp(f:(a)->b, x:a) -> (b, (b)->a) given (a, b) =
  (y, df) = linearize(f, x)
  (y, transpose_linear df)

-- Gradient of a scalar-valued function, via reverse mode.
def grad(f:(a)->Float, x:a) -> a given (a) =
  (_, pullback) = vjp(f, x)
  pullback(1.0)

-- Derivative of a scalar function, via forward mode.
def deriv(f:(Float)->Float, x:Float) -> Float =
  jvp(f, x, 1.0)

-- Derivative of a scalar function, via reverse mode.
def deriv_rev(f:(Float)->Float, x:Float) -> Float =
  (_, pullback) = vjp(f, x)
  pullback(1.0)
-- XXX: Watch out when editing this data type! We depend on its structure
-- deep inside the compiler (mostly in linearization and during rule registration).
-- A tangent that is either symbolically zero (`ZeroTangent`, carrying no
-- payload) or an explicit tangent value (`SomeTangent`).
data SymbolicTangent(a) =
  ZeroTangent
  SomeTangent(a)
-- Materialize a symbolic tangent: the payload of `SomeTangent`, or `zero`
-- for `ZeroTangent`.
def someTangent(x:SymbolicTangent a) -> a given (a|VSpace) =
  case x of
    SomeTangent(t) -> t
    ZeroTangent -> zero

Approximate Equality

TODO: move this outside the AD section to be with equality?

-- Approximate equality with explicit tolerances.
-- Argument order: (atol, rtol, x, y).
interface HasAllClose(a)
  allclose : (a, a, a, a) -> Bool

-- Default absolute and relative tolerances used by `~~`.
interface HasDefaultTolerance(a)
  default_atol : a
  default_rtol : a

-- Approximate equality using the type's default tolerances.
def (~~)(x:a, y:a) -> Bool given (a|HasAllClose|HasDefaultTolerance) =
  allclose(default_atol, default_rtol, x, y)
-- |x - y| <= atol + rtol * |y|  (the tolerance is relative to `y`).
instance HasAllClose(Float32)
  def allclose(atol, rtol, x, y) =
    tolerance = atol + rtol * abs y
    abs (x - y) <= tolerance

-- |x - y| <= atol + rtol * |y|  (the tolerance is relative to `y`).
instance HasAllClose(Float64)
  def allclose(atol, rtol, x, y) =
    tolerance = atol + rtol * abs y
    abs (x - y) <= tolerance

instance HasDefaultTolerance(Float32)
  default_atol = f_to_f32(0.00001)
  default_rtol = f_to_f32(0.0001)

instance HasDefaultTolerance(Float64)
  default_atol = f_to_f64(0.00000001)
  default_rtol = f_to_f64(0.00001)
-- Approximate equality on pairs. Each component is compared with the
-- matching component of the supplied tolerances, so
-- `allclose((atol1, atol2), (rtol1, rtol2), (x1, x2), (y1, y2))` checks
-- x1 against y1 with (atol1, rtol1) and x2 against y2 with (atol2, rtol2).
-- Fix: previously `atol`/`rtol` were ignored and each component was
-- compared with `~~`, so custom tolerances silently fell back to the
-- defaults. `~~` on pairs is unchanged, because the pair's default
-- tolerances are the componentwise defaults.
-- The HasDefaultTolerance constraints are no longer needed by this body.
-- They are kept so the instance head stays identical for existing code.
instance HasAllClose((a, b)) given (a|HasDefaultTolerance|HasAllClose, b|HasDefaultTolerance|HasAllClose)
  def allclose(atol, rtol, pair1, pair2) =
    (atol1, atol2) = atol
    (rtol1, rtol2) = rtol
    (x1, x2) = pair1
    (y1, y2) = pair2
    allclose(atol1, rtol1, x1, y1) && allclose(atol2, rtol2, x2, y2)
-- Default tolerances for a pair are the componentwise defaults.
instance HasDefaultTolerance((a, b)) given (a|HasDefaultTolerance,b|HasDefaultTolerance)
  default_atol = (default_atol, default_atol)
  default_rtol = (default_rtol, default_rtol)

-- Tables are close iff every entry is close, using per-entry tolerances.
instance HasAllClose(n=>t) given (n|Ix, t|HasAllClose)
  def allclose(atol, rtol, a, b) =
    entry_close = for i:n. allclose(atol[i], rtol[i], a[i], b[i])
    all entry_close

-- Default tolerances for a table repeat the element type's defaults.
instance HasDefaultTolerance(n=>t) given (n|Ix, t|HasDefaultTolerance)
  default_atol = for k. default_atol
  default_rtol = for k. default_rtol

AD Checking tools

-- Check that forward- and reverse-mode derivatives of `f` at `x` both
-- agree (via `~~`) with a central finite difference of step 0.01.
def check_deriv_base(f:(Float)->Float, x:Float) -> Bool =
  eps = 0.01
  fwd = deriv( f, x)
  rev = deriv_rev( f, x)
  numeric = (f (x + eps) - f (x - eps)) / (2. * eps)
  fwd ~~ numeric && rev ~~ numeric

-- Like `check_deriv_base`, but also checks the derivative of the derivative.
def check_deriv(f:(Float)->Float, x:Float) -> Bool =
  df = \y. deriv(f, y)
  check_deriv_base(f, x) && check_deriv_base(df, x)

Length-erased lists

-- A length-erased list: a table together with its runtime length.
data List(a) =
  AsList(n:Nat, elements:(Fin n => a))

-- Lists are equal iff they have the same length and equal elements.
instance Eq(List a) given (a|Eq)
  def (==)(xsList, ysList) =
    AsList(len_x, xs) = xsList
    AsList(len_y, ys) = ysList
    if len_x /= len_y
      then False
      else all for i:(Fin len_x). xs[i] == ys[unsafe_from_ordinal (ordinal i)]

-- Reinterpret a table over `from` as a table over `to` by ordinal, with no
-- size check (hence "unsafe").
def unsafe_cast_table(xs:from=>a) -> to=>a given (to|Ix, from|Ix, a) =
  for j. xs[unsafe_from_ordinal (ordinal j)]

-- Forget the index set of a table, keeping only its length.
def to_list(xs:n=>a) -> List a given (n|Ix, a) =
  len = size n
  AsList(_, unsafe_cast_table(to=Fin len, xs))
-- Lists form a monoid under concatenation, with the empty list as identity.
instance Monoid(List a) given (a|Data)
  mempty = AsList(_, [])
  def (<>)(x, y) =
    AsList(len_x, xs) = x
    AsList(len_y, ys) = y
    total = len_x + len_y
    to_list for i:(Fin total).
      k = ordinal i
      if k < len_x
        then xs[unsafe_from_ordinal k]
        else ys[unsafe_from_ordinal $ unsafe_nat_diff(k, len_x)]

-- Named version of the list monoid, for use where an explicit instance is needed.
named-instance ListMonoid (a|Data) -> Monoid(List a)
  mempty = mempty
  def (<>)(x, y) = x <> y
-- Accumulate a single element onto a list held in an Accum reference.
-- TODO Eliminate or reimplement this operation, since it costs O(n)
-- where n is the length of the list held in the reference.
def append(list: Ref(h, List a), x:a) -> {Accum h} () given (a|Data, h) (AccumMonoid(h, List a)) =
  singleton = to_list [x]
  list += singleton

-- Elements of `xs` between the posts `start` and `end` (start inclusive).
-- Assumes start <= end; `unsafe_nat_diff` does not check this.
-- TODO: replace `slice` with this?
def post_slice(xs:n=>a, start:Post n, end:Post n) -> List a given (n|Ix, a) =
  offset = ordinal start
  len = unsafe_nat_diff(ordinal end, offset)
  to_list for i:(Fin len).
    xs[unsafe_from_ordinal(n=n, ordinal i + offset)]

Strings and Characters

String : Type = List Char

-- Copy `n` characters starting at `ptr` into a String.
def string_from_char_ptr(n:Word32, ptr:Ptr Char) -> {IO} String =
  len = rep_to_nat n
  AsList(len, table_from_ptr ptr)

-- TODO. This is ASCII code point. It really should be Int32 for Unicode codepoint
def codepoint(c:Char) -> Int = c | w8_to_i

-- A raw pointer to a NUL-terminated C string.
struct CString =
  ptr : RawPtr

-- Copy `s` plus a trailing NUL into a temporary buffer and pass it to
-- `action` as a CString. The buffer is released by `with_table_ptr`.
-- TODO: check the string contains no nulls
def with_c_string(
    s:String,
    action: (CString) -> {IO} a
    ) -> {IO} a given (a) =
  AsList(n, chars) = s <> "\NUL"
  with_table_ptr chars \p. action CString(p.val)

Show interface

For things that can be shown. `show` gives a string representation of its input. No particular promises are made about exactly what that representation will contain. In particular, it is not promised to be parseable, nor does it promise a particular level of precision for numeric values.

-- Types whose values can be rendered as a `String`.
interface Show(a)
  show : (a) -> String

-- Strings are shown as themselves (no quoting or escaping).
instance Show(String)
  def show(x) = x

-- The numeric instances call foreign helpers returning (length, buffer).
-- The characters are copied into a Dex String via `string_from_char_ptr`.
-- NOTE(review): nothing here frees the returned buffer; confirm ownership.
foreign "showInt32" showInt32 : (Int32) -> {IO} (Word32, RawPtr)
instance Show(Int32)
  def show(x) = unsafe_io \.
    (len, raw) = showInt32 x
    string_from_char_ptr(len, Ptr raw)

foreign "showInt64" showInt64 : (Int64) -> {IO} (Word32, RawPtr)
instance Show(Int64)
  def show(x) = unsafe_io \.
    (len, raw) = showInt64 x
    string_from_char_ptr(len, Ptr raw)

-- Nat is shown through its Int64 conversion.
instance Show(Nat)
  def show(x) = show (n_to_i64 x)

foreign "showFloat32" showFloat32 : (Float32) -> {IO} (Word32, RawPtr)
instance Show(Float32)
  def show(x) = unsafe_io \.
    (len, raw) = showFloat32 x
    string_from_char_ptr(len, Ptr raw)

foreign "showFloat64" showFloat64 : (Float64) -> {IO} (Word32, RawPtr)
instance Show(Float64)
  def show(x) = unsafe_io \.
    (len, raw) = showFloat64 x
    string_from_char_ptr(len, Ptr raw)

instance Show(())
  def show(_) = "()"

-- Tuples are shown as "(x, y, ...)".
instance Show((a, b)) given (a|Show, b|Show)
  def show(tup) =
    (first, second) = tup
    "(" <> show first <> ", " <> show second <> ")"

instance Show((a, b, c)) given (a|Show, b|Show, c|Show)
  def show(tup) =
    (first, second, third) = tup
    "(" <> show first <> ", " <> show second <> ", " <> show third <> ")"

instance Show((a, b, c, d)) given (a|Show, b|Show, c|Show, d|Show)
  def show(tup) =
    (first, second, third, fourth) = tup
    "(" <> show first <> ", " <> show second <> ", " <> show third <> ", " <> show fourth <> ")"

Parse interface

For types that can be parsed from a String.

-- Types that can be parsed from a String (`Nothing` on failure).
interface Parse(a)
  parseString : (String) -> Maybe a

foreign "strtof" strtofFFI : (RawPtr, RawPtr) -> {IO} Float

-- Parse a Float with C `strtof`. This succeeds only when strtof consumes
-- exactly as many characters as the string has. The end pointer it reports
-- is compared against the start of the C string.
instance Parse(Float)
  def parseString(str) = unsafe_io \.
    AsList(str_len, _) = str
    with_c_string str \c_str.
      with_alloc 1 \end_slot:(Ptr (Ptr Char)).
        parsed = strtofFFI(c_str.ptr, end_slot.val)
        stop_ptr = load end_slot
        consumed = raw_ptr_to_i64 stop_ptr.val - raw_ptr_to_i64 c_str.ptr
        if consumed == (n_to_i64 str_len)
          then Just parsed
          else Nothing

Floating-point helper functions

TODO: Move these to be with Elementary/Special functions. Or move those to be here.

-- 1.0 for positive input and -1.0 for negative input. Any other input
-- (zero or NaN) is returned unchanged.
def sign(x:Float) -> Float =
  if x > 0.0
    then 1.0
    else if x < 0.0
      then -1.0
      else x

-- `a` if b > 0, `-a` if b < 0, and 0.0 otherwise (b zero or NaN).
-- NOTE: unlike IEEE copysign this does not take |a|, and it returns 0.0 for zero b.
def copysign(a:Float, b:Float) -> Float =
  if b > 0.0
    then a
    else if b < 0.0
      then (-a)
      else 0.0

-- Todo: use IEEE floating-point builtins.
infinity = 1.0 / 0.0
nan = 0.0 / 0.0

-- Todo: use IEEE floating-point builtins.
def isinf(x:Float) -> Bool =
  (x == infinity) || (x == -infinity)

-- NaN is the only value that is neither >= nor <= itself.
def isnan(x:Float) -> Bool =
  not (x >= x) || not (x <= x)

-- Todo: use IEEE-754R 5.11: Floating Point Comparison Relation cmpUnordered.
def either_is_nan(x:Float, y:Float) -> Bool =
  (isnan x) || (isnan y)

File system operations

FilePath : Type = String

-- True iff the raw pointer's address is 0.
def is_null_raw_ptr(ptr:RawPtr) -> Bool =
  raw_ptr_to_i64 ptr == 0

-- Wrap a raw pointer as `Just (Ptr ...)`, or `Nothing` when it is null.
def from_nullable_raw_ptr(ptr:RawPtr) -> Maybe (Ptr a) given (a) =
  if is_null_raw_ptr ptr
    then Nothing
    else Just (Ptr ptr)

def c_string_ptr(s:CString) -> Maybe (Ptr Char) =
  from_nullable_raw_ptr s.ptr

data StreamMode =
  ReadMode
  WriteMode

-- A C stream handle, tagged at the type level with its mode.
struct Stream(mode:StreamMode) =
  ptr : RawPtr

foreign "fopen" fopenFFI : (RawPtr, RawPtr) -> {IO} RawPtr
foreign "fclose" fcloseFFI : (RawPtr) -> {IO} Int64
foreign "fwrite" fwriteFFI : (RawPtr, Int64, Int64, RawPtr) -> {IO} Int64
foreign "fread" freadFFI : (RawPtr, Int64, Int64, RawPtr) -> {IO} Int64
foreign "fflush" fflushFFI : (RawPtr) -> {IO} Int64

-- Open `path` with C mode "r" (ReadMode) or "w" (WriteMode).
-- NOTE(review): the pointer returned by fopen is not checked for null here.
def fopen(path:String, mode:StreamMode) -> {IO} (Stream mode) =
  mode_str = case mode of
    ReadMode  -> "r"
    WriteMode -> "w"
  with_c_string path \c_path.
    with_c_string mode_str \c_mode.
      Stream $ fopenFFI(c_path.ptr, c_mode.ptr)

-- Close the stream, discarding fclose's return code.
def fclose(stream:Stream mode) -> {IO} () given (mode) =
  fcloseFFI stream.ptr
  ()

-- Write all characters of `s` to the stream and then flush it.
-- Return codes from fwrite/fflush are discarded.
def fwrite(stream:Stream WriteMode, s:String) -> {IO} () =
  AsList(n, chars) = s
  with_table_ptr chars \buf.
    fwriteFFI(buf.val, i_to_i64 1, n_to_i64 n, stream.ptr)
  fflushFFI stream.ptr
  ()

Iteration

TODO: move this out of the file-system section

-- Run `body` repeatedly until it returns False. The Bool is converted to a
-- Word8 for the `%while` primitive.
def while(body: () -> {|eff} Bool) -> {|eff} () given (eff) =
  step : () -> {|eff} Word8 = \. b_to_w8 (body())
  %while(step)

-- Result of one iteration step: keep going, or stop with a value.
data IterResult(a|Data) =
  Continue
  Done(a)

-- Call `f x` at a type that also mentions the State effect on `ref`.
-- TODO: can we improve effect inference so we don't need this?
def lift_state(ref: Ref(h, c), f:(a) -> {|eff} b, x:a) -> {State h|eff} b given (a, b, c, h, eff) =
  f x
-- A little iteration combinator
-- Calls `body` with 0, 1, 2, ... until it returns `Done(x)`, then returns x.
-- `resultRef` holds Nothing until some step returns Done; the counter `i`
-- is incremented only on `Continue`. The loop exits as soon as
-- `resultRef` becomes Just, so the final `Nothing` branch is unreachable.
def iter(body: (Nat) -> {|eff} IterResult a) -> {|eff} a given (a|Data, eff) =
  result = yield_state Nothing \resultRef.
    i <- with_state 0
    while \.
      continue = is_nothing $ get resultRef
      if continue then
        -- lift_state only widens the effect row so `body` can run here.
        case lift_state(resultRef, (\x. lift_state(i, body, x)), get i) of
          Continue -> i := get i + 1
          Done(result) -> resultRef := Just result
      continue
  case result of
    Just(ans) -> ans
    Nothing -> unreachable()
-- Like `iter`, but stops with `fallback` once `maxIters` steps have run
-- without `body` returning Done.
def bounded_iter(
    maxIters:Nat,
    fallback:a,
    body:(Nat) -> {|eff} IterResult a
    ) -> {|eff} a given (a|Data, eff) =
  iter \step.
    if step >= maxIters
      then Done fallback
      else body step

Print

-- The program's output stream, as provided by the `%outputStream` primitive.
def get_output_stream() -> {IO} Stream WriteMode =
  Stream (%outputStream())

-- Write `s` followed by a newline to the output stream.
@noinline
def print(s:String) -> {IO} () =
  out = get_output_stream()
  fwrite(out, s)
  fwrite(out, "\n")

Partial functions

A partial function in this context is a function that can error, i.e. a function that is not actually defined for all of its supposed domain. Not to be confused with a partially applied function.

Error throwing

-- Print `s` and then raise a runtime error. This can be used at any Data type.
@noinline
def error(s:String) -> a given (a|Data) =
  unsafe_io \.
    print s
    %throwError(a)

-- Placeholder for unimplemented code paths; always errors.
def todo() ->> a given (a|Data) = error "TODO: implement it!"

Table operations

-- Error message for an ordinal that is not below its index set's size.
@noinline
def from_ordinal_error(i:Nat, upper:Nat) -> String =
  "Ordinal index out of range:" <> show i <> " >= " <> show upper

-- Checked conversion from an ordinal to an index; errors if i >= size n.
def from_ordinal(i:Nat) -> n given (n|Ix) =
  limit = size n
  if i < limit
    then unsafe_from_ordinal i
    else error $ from_ordinal_error(i, limit)

-- Checked conversion that returns Nothing instead of erroring.
-- TODO: should this be called `from_ordinal`?
def to_ix(i:Nat) -> Maybe n given (n|Ix) =
  if i < size n
    then Just (unsafe_from_ordinal i)
    else Nothing

-- Reindex a table by ordinal, erroring unless both index sets have equal size.
-- TODO: could make an `unsafeCastIndex` and this could avoid the runtime copy
-- TODO: safe (runtime-checked) and unsafe versions
def cast_table(xs:to=>a) -> from=>a given (from|Ix, to|Ix, a|Data) =
  n_from = size from
  n_to = size to
  if n_from == n_to
    then unsafe_cast_table xs
    else error $ "Table size mismatch in cast: " <> show n_from <> " vs " <> show n_to

def asidx(i:Nat) -> n given (n|Ix) = from_ordinal i

-- `i@n` is the (checked) index of `n` with ordinal `i`.
def (@)(i:Nat, n|Ix) -> n = from_ordinal i

-- `m` consecutive elements of `xs` starting at ordinal `start` (bounds-checked).
def slice(xs:n=>a, start:Nat, m|Ix) -> m=>a given (n|Ix, a) =
  for j. xs[from_ordinal (ordinal j + start)]

-- First element (errors on an empty index set).
def head(xs:n=>a) -> a given (n|Ix, a) = xs[0@_]

-- Elements from ordinal `start` onwards; empty if start >= size n.
def tail(xs:n=>a, start:Nat) -> List a given (n|Ix, a) =
  remaining = size n -| start
  to_list (slice(xs, start, Fin remaining))

Pseudorandom number generator utilities

Dex does not use a stateful random number generator. Rather, it uses what is known as a splittable random number generator, which is based on a hash function. Dex's PRNG system is modelled directly after JAX's, which is based on a well-established but shockingly underused idea from the functional programming community: the splittable PRNG. It's a good idea for many reasons, but it's especially helpful in a parallel setting. If you want to read more, Splittable pseudorandom number generators using cryptographic hashing describes the splitting model itself and D.E. Shaw Research's counter-based PRNG proposes the particular hash function we use.

Key functions

-- TODO: newtype
Key = Word64

-- Threefry-2x32 block function: mixes the 64-bit `count` under the 64-bit
-- key `k` and returns a 64-bit result.
-- Both words are split into 32-bit halves. The loop runs 5 groups of 4
-- add/rotate/xor steps, alternating between the two rotation tables, with a
-- key injection after each group.
@noinline
def threefry_2x32(k:Word64, count:Word64) -> Word64 =
  -- Based on jax's threefry_2x32 by Matt Johnson and Peter Hawkins
  rotations1 = [13, 15, 26, 6]
  rotations2 = [17, 29, 16, 24]

  k0 = low_word k
  k1 = high_word k
  -- TODO: add a fromHex
  -- Third key word: the xor of both key halves with a fixed constant.
  k2 = k0 .^. k1 .^. (n_to_w32 466688986)  -- 0x1BD11BDA

  x = low_word count
  y = high_word count
  -- Initial key injection.
  x = x + k0
  y = y + k1
  rotations = [rotations1, rotations2]
  -- Key schedule, rotated so that group i injects ks[i mod 3] and ks[(i+1) mod 3].
  ks = [k1, k2, k0]
  (x, y) = yield_state (x, y) \ref.
    for i:(Fin 5).
      for j.
        (x, y) = get ref
        rotationIndex = unsafe_from_ordinal (ordinal i `mod` 2)
        rot = rotations[rotationIndex, j]
        x = x + y
        -- 32-bit rotate-left of y by `rot`.
        y = (y .<<. rot) .|. (y .>>. (32 - rot))
        y = x .^. y
        ref := (x, y)
      (x, y) = get ref
      -- Key injection after each group, plus the group counter (i+1).
      x = x + ks[unsafe_from_ordinal (ordinal i `mod` 3)]
      y = y + ks[unsafe_from_ordinal (((ordinal i)+1) `mod` 3)] + n_to_w32 ((ordinal i)+1)
      ref := (x, y)
  -- Reassemble the two 32-bit halves: x is the high word, y the low word.
  (w32_to_w64 x .<<. 32) .|. (w32_to_w64 y)
-- Derive a new key from key `x` and counter `y`.
def hash(x:Key, y:Nat) -> Key =
  threefry_2x32(x, n_to_w64 y)

-- A fresh key from a seed.
def new_key(x:Nat) -> Key = hash(0, x)

-- Call `f` with the key for index `i` derived from `k`.
def many(f:(Key)->a, k:Key, i:n) -> a given (a, n|Ix) =
  f (hash(k, ordinal i))

-- The subkey of `k` associated with index `i`.
def ixkey(k:Key, i:n) -> Key given (n|Ix) =
  hash(k, ordinal i)

-- Split `k` into `n` subkeys, one per index of `Fin n`.
def split_key(k:Key) -> Fin n => Key given (n:Nat) =
  for j. ixkey(k, j)

Sample Generators

These functions generate samples taken from different distributions, such as rand_mat, which samples floating-point matrices whose elements are drawn i.i.d. from a uniform distribution. Note that additional standard distributions are provided by the stats library.

-- Uniform float in [1, 2) built by bit-casting, then shifted down by 1.0.
-- The exponent bits are fixed to those of 1.0, and 23 bits of the key's
-- high word fill the mantissa.
def rand(k:Key) -> Float =
  one_exponent = 1065353216  -- 1065353216 = 127 << 23
  mantissa = high_word k .&. 8388607  -- 8388607 == (1 << 23) - 1
  float_bits = one_exponent .|. mantissa
  %bitcast(Float, float_bits) - 1.0

-- `n` independent samples of `f`, one subkey per entry.
def rand_vec(n:Nat, f: (Key) -> a, k: Key) -> Fin n => a given (a) =
  for row:(Fin n). f ixkey(k, row)

-- An n-by-m table of independent samples of `f`.
def rand_mat(n:Nat, m:Nat, f: (Key) -> a, k: Key) -> Fin n => Fin m => a given (a) =
  for row col. f ixkey(k, (row, col))

-- Standard normal sample via the Box-Muller transform.
def randn(k:Key) -> Float =
  [k1, k2] = split_key k
  -- rand is uniform between 0 and 1, but implemented such that it rounds to 0
  -- (in float32) once every few million draws, but never rounds to 1.
  u1 = 1.0 - (rand k1)
  u2 = rand k2
  radius = sqrt ((-2.0) * log u1)
  angle = 2.0 * pi * u2
  radius * cos angle

-- A Nat in [0, 2147483647) taken from the key's bits.
-- TODO: Make this better...
def rand_int(k:Key) -> Nat = w64_to_n k `mod` 2147483647

-- A table of independent standard normal samples.
def randn_vec(k:Key) -> n=>Float given (n|Ix) =
  for j. randn (ixkey(k, j))

-- An index of `n` chosen by scaling a uniform sample by `size n`.
def rand_idx(k:Key) -> n given (n|Ix) =
  scaled = rand k * n_to_f (size n)
  unsafe_from_ordinal (f_to_n (floor scaled))

Inner product typeclass

-- Vector spaces with a Float-valued inner product.
interface InnerProd(v|VSpace)
  inner_prod : (v, v) -> Float

instance InnerProd(Float)
  def inner_prod(x, y) = x * y

-- Tables: the sum of the elementwise inner products.
instance InnerProd(n=>a) given (a|InnerProd, n|Ix)
  def inner_prod(x, y) =
    parts = for i. inner_prod(x[i], y[i])
    sum parts

Arbitrary

Type class for generating example values

-- Types for which example values can be generated from a PRNG key.
interface Arbitrary(a)
  arb : (Key) -> a

-- True iff the key's lowest bit is 0.
instance Arbitrary(Bool)
  def arb(key) = (key .&. 1) == 0

instance Arbitrary(Float32)
  def arb(key) = randn key

instance Arbitrary(Int32)
  def arb(key) = f_to_i (randn key * 5.0)

instance Arbitrary(Nat)
  def arb(key) = f_to_n (randn key * 5.0)

-- Tables (including the triangular dependent ones): one subkey per outer index.
instance Arbitrary(n=>a) given (n|Ix, a|Arbitrary)
  def arb(key) = for j. arb (ixkey(key, j))

instance Arbitrary((i:n)=>(..<i) => a) given (n|Ix, a|Arbitrary)
  def arb(key) = for j. arb (ixkey(key, j))

instance Arbitrary((i:n)=>(..i) => a) given (n|Ix, a|Arbitrary)
  def arb(key) = for j. arb (ixkey(key, j))

instance Arbitrary((i:n)=>(i..) => a) given (n|Ix, a|Arbitrary)
  def arb(key) = for j. arb (ixkey(key, j))

instance Arbitrary((i:n)=>(i<..) => a) given (n|Ix, a|Arbitrary)
  def arb(key) = for j. arb (ixkey(key, j))

instance Arbitrary((a, b)) given (a|Arbitrary, b|Arbitrary)
  def arb(key) =
    [k1, k2] = split_key key
    (arb k1, arb k2)

instance Arbitrary(Fin n) given (n)
  def arb(key) = rand_idx key

Ord on Arrays

Searching

Returns the bucket of x, as a Post n, given the boundaries xs. The boundaries must already be sorted, and are inclusive on the left.

In other words, if there is an index i such that xs.i <= x, returns the right_post of the highest such index; otherwise returns first_ix : Post n, which is also the left_post of the minimum i.

This is equivalent to the right-biased formulation: if an index i exists such that x < xs.i, returns the left_post of the least such i, otherwise returns last_ix : Post n, i.e., the right_post of the maximum i.

-- Binary search over sorted boundaries `xs` (see the prose above).
-- After the early exits we know xs[0] <= x. The loop maintains:
--   * xs[low] <= x
--   * high == size n, or x < xs[high]
-- It halves the window [low, high) until one element remains, then
-- returns the post just to the right of `low`.
def search_sorted(xs:n=>a, x:a) -> Post n given (n|Ix, a|Ord) =
  if size n == 0
    then first_ix
    else if x < xs[from_ordinal 0]
      then first_ix
      else
        low <- with_state(0)
        high <- with_state(size n)
        _ <- iter
        -- Width of the current window; signed so the subtraction cannot underflow.
        numLeft = n_to_i (get high) - n_to_i (get low)
        if numLeft == 1
          then Done $ right_post $ from_ordinal $ get low
          else
            centerIx = get low + unsafe_i_to_n (numLeft `idiv` 2)
            if x < xs[from_ordinal centerIx]
              then high := centerIx
              else low := centerIx
            Continue

If i exists such that xs.i == x, returns Just of the largest such i, otherwise returns Nothing.

-- The largest index i with xs[i] == x, if any (xs must be sorted).
def search_sorted_exact(xs:n=>a, x:a) -> Maybe n given (n|Ix, a|Ord) =
  case left_fence(search_sorted(xs, x)) of
    Nothing -> Nothing
    Just(i) -> if xs[i] == x then Just i else Nothing

min / max etc

-- Whichever of x and y has the smaller key; y wins ties.
def min_by(f:(a)->o, x:a, y:a) -> a given (o|Ord, a) =
  fx = f x
  fy = f y
  select(fx < fy, x, y)

-- Whichever of x and y has the larger key; y wins ties.
def max_by(f:(a)->o, x:a, y:a) -> a given (o|Ord, a) =
  fx = f x
  fy = f y
  select(fx > fy, x, y)

def min(x1: o, x2: o) -> o given (o|Ord) = min_by(id, x1, x2)

def max(x1: o, x2: o) -> o given (o|Ord) = max_by(id, x1, x2)

-- Element with the smallest key (errors on an empty table via `0@_`).
def minimum_by(f:(a)->o, xs:n=>a) -> a given (a|Data, o|Ord, n|Ix) =
  reduce(xs[0@_], \acc cur. min_by(f, acc, cur), xs)

-- Element with the largest key (errors on an empty table via `0@_`).
def maximum_by(f:(a)->o, xs:n=>a) -> a given (a|Data, o|Ord, n|Ix) =
  reduce(xs[0@_], \acc cur. max_by(f, acc, cur), xs)

def minimum(xs:n=>o) -> o given (n|Ix, o|Ord) = minimum_by(id, xs)

def maximum(xs:n=>o) -> o given (n|Ix, o|Ord) = maximum_by(id, xs)

argmin/argmax

-- TODO: put in same section as searchsorted

-- Index of the element preferred by `comp`, folding left to right.
-- The running best is kept only when comp(best, candidate) is True, so on
-- ties the later index wins.
def argscan(comp:(o,o)->Bool, xs:n=>o) -> n given (o|Ord, n|Ix) =
  zeroth = (0@_, xs[0@_])
  keep_better = \best cand.
    (_, best_val) = best
    (_, cand_val) = cand
    select(comp(best_val, cand_val), best, cand)
  indexed = for i. (i, xs[i])
  fst (reduce(zeroth, keep_better, indexed))

def argmin(xs:n=>o) -> n given (n|Ix, o|Ord) = argscan((<), xs)

def argmax(xs:n=>o) -> n given (n|Ix, o|Ord) = argscan((>), xs)
-- Lexicographic comparison of two lists.
-- `compareElements` decides at the first position where the lists differ.
-- If one list is a prefix of the other (no differing position before the
-- shorter one ends), the result is `compareLengths(nx, ny)`.
def lexical_order(
    compareElements:(n,n)->Bool,
    compareLengths: (Nat,Nat)->Bool,
    xList:List n,
    yList:List n
    ) -> Bool given (n|Ord) =
  -- Orders Lists according to the order of their elements,
  -- in the same way a dictionary does.
  -- For example, this lets us sort Strings.
  --
  -- More precisely, it returns True iff compareElements xs.i ys.i is true
  -- at the first location they differ.
  --
  -- This function operates serially and short-circuits
  -- at the first difference.  One could also write this
  -- function as a parallel reduction, but it would be
  -- wasteful in the case where there is an early difference,
  -- because we can't short circuit.
  AsList(nx, xs) = xList
  AsList(ny, ys) = yList
  iter \i.
    case i == min(nx, ny) of
      True -> Done $ compareLengths(nx, ny)
      False ->
        xi = xs[unsafe_from_ordinal i]
        yi = ys[unsafe_from_ordinal i]
        case compareElements(xi, yi) of
          True -> Done True
          False -> case xi == yi of
            True -> Continue
            False -> Done False
-- Lexicographic ordering of lists; when one list is a prefix of the other,
-- the longer list is the greater one.
instance Ord(List n) given (n|Ord)
  def (<)(xs, ys) = lexical_order((<), (<), xs, ys)
  def (>)(xs, ys) = lexical_order((>), (>), xs, ys)

clip

-- Clamp `x` to the closed interval given by `bounds` = (low, high).
def clip(bounds:(a,a), x:a) -> a given (a|Ord) =
  (lo, hi) = bounds
  min(hi, max(lo, x))

Trigonometric functions

TODO: these should be with the other Elementary/Special Functions

atan/atan2

-- Polynomial approximation of atan(x), intended for |x| <= 1.
-- With s = x^2 it evaluates x + x * s * P(s) using Horner's rule, where P
-- is the degree-7 polynomial whose coefficients appear below.
def atan_inner(x:Float) -> Float =
  -- From "Computing accurate Horner form approximations to
  -- special functions in finite precision arithmetic"
  -- https://arxiv.org/abs/1508.03211
  -- Only accurate in the range [-1, 1]
  s = x * x
  r = 0.0027856871
  r = r * s - 0.0158660002
  r = r * s + 0.042472221
  r = r * s - 0.0749753043
  r = r * s + 0.106448799
  r = r * s - 0.142070308
  r = r * s + 0.199934542
  r = r * s - 0.333331466
  r = r * s
  r * x + x
-- Return (smaller, larger), using a single comparison. y is treated as the
-- smaller value on ties.
def min_and_max(x:a, y:a) -> (a, a) given (a|Ord) =
  if x < y then (x, y) else (y, x)  -- get both with one comparison.
-- Four-quadrant arctangent of y/x, in [-pi, pi].
-- Special cases handled below:
--   * y == 0: 0 for x >= 0 and pi for x < 0.
--   * both infinite: +-pi/4 or +-3pi/4.
--   * any NaN input: NaN.
-- Fix: the final sign step previously called `copysign(a, y)`, which
-- returns 0.0 whenever y == 0. That made atan2(0, -1) return 0 instead of
-- pi. The sign is now flipped only for negative y.
-- NOTE: -0.0 is treated like +0.0 here, so atan2(-0.0, -1.0) gives pi
-- rather than the IEEE -pi.
def atan2(y:Float, x:Float) -> Float =
  -- Based off of the Tensorflow implementation at
  -- github.com/tensorflow/mlir-hlo/blob/master/lib/
  -- Dialect/mhlo/transforms/legalize_trigonometric_to_approximation.cc#L147
  -- With a fix to the nan propagation.
  abs_x = abs x
  abs_y = abs y
  -- Keep the argument to atan_inner within [0, 1], where it is accurate.
  (min_abs_x_y, max_abs_x_y) = min_and_max(abs_x, abs_y)
  a = atan_inner (min_abs_x_y / max_abs_x_y)
  a = select(abs_x <= abs_y, (pi / 2.0) -a, a)
  a = select(x < 0.0, pi - a, a)
  t = select(x < 0.0, pi, 0.0)
  a = select(y == 0.0, t, a)
  t = select(x < 0.0, 3.0 * pi / 4.0, pi / 4.0)
  a = select(isinf x && isinf y, t, a)  -- Handle infinite inputs.
  -- `a` is now in [0, pi]; negate it only for strictly negative y.
  a = select(y < 0.0, (-a), a)
  select(either_is_nan(x, y), nan, a)  -- Propagate NaNs.

-- Arctangent, in (-pi/2, pi/2).
def atan(x:Float) -> Float = atan2(x, 1.0)

Miscellaneous utilities

TODO: all of these should be in some other section

-- The index at the mirrored position: ordinal i maps to size n - 1 - i.
def reflect(i:n) -> n given (n|Ix) =
  mirrored = unsafe_nat_diff(size n, ordinal i + 1)
  unsafe_from_ordinal mirrored

-- Reverse the order of a table's elements.
def reverse(x:n=>a) -> n=>a given (n|Ix, a) =
  for j. x[reflect j]

-- The index of `n` whose ordinal is `i mod size n`.
def wrap_periodic(n|Ix, i:Nat) -> n =
  unsafe_from_ordinal(n=n, i `mod` size n)

-- Copy `xs` into a table over `m`. Positions past the end of `xs` are
-- filled with `x`.
def pad_to(m|Ix, x:a, xs:n=>a) -> m=>a given (n|Ix, a) =
  n' = size n
  for i.
    i' = ordinal i
    if i' < n' then xs[i'@_] else x

-- Ceiling of x / y for Nats.
def idiv_ceil(x:Nat, y:Nat) -> Nat =
  quotient = x `idiv` y
  has_remainder = (x `rem` y) /= 0
  quotient + b_to_n has_remainder

-- x / 2, rounded down, via a right shift.
def intdiv2(x:Nat) -> Nat =
  rep_to_nat (%shr(nat_to_rep x, 1 :: NatRep))

-- 2 ^ power, via a left shift.
def intpow2(power:Nat) -> Nat =
  rep_to_nat (%shl(1 :: NatRep, nat_to_rep power))

def is_odd(x:Nat) -> Bool = rem(x, 2) == 1

def is_even(x:Nat) -> Bool = rem(x, 2) == 0

-- True iff x is a positive power of two: x has a single set bit exactly
-- when x AND (x - 1) is zero.
def is_power_of_2(x:Nat) -> Bool =
  -- Note: The bitwise and operator (.&.)
  -- is only defined for Byte, which is why
  -- we use %and here. TODO: Make (.&.) polymorphic.
  rep = nat_to_rep x
  if rep == 0
    then False
    else 0 == %and(rep, (%isub(rep, 1::NatRep)))
-- This computes the integer part of the binary logarithm of the input,
-- i.e. floor(log2 x) for x > 0.
-- It counts how many halvings it takes to bring x down to 0 (the bit
-- length of x) and subtracts one.
-- Fix: the previous version doubled a NatRep (Word32) comparison value
-- until it exceeded x. For x >= 2^31 that value wrapped to 0, so
-- `x >= cmp` stayed True and the loop never terminated. Halving x cannot
-- overflow, and the result is unchanged for every smaller x.
-- TODO: natlog2 0 should do something other than underflow the answer.
-- TODO: Use LLVM ctlz intrinsic instead. It needs a slightly new
-- code path in ImpToLLVM, because it's the first LLVM intrinsic
-- we have with a fixed-point argument.
-- https://llvm.org/docs/LangRef.html#llvm-ctlz-intrinsic
def natlog2(x:Nat) -> Nat =
  bit_length = yield_state 0 \ans.
    remaining <- with_state x
    while \.
      if (get remaining) > 0
        then
          ans := (get ans) + 1
          remaining := intdiv2 (get remaining)
          True
        else False
  unsafe_nat_diff(bit_length, 1)  -- TODO: something less horrible
-- Exponent of the smallest power of two >= x, i.e. ceil(log2 x) for x >= 1.
-- Note that this returns the exponent, not the power itself.
def nextpow2(x:Nat) -> Nat =
  floor_log = natlog2 x
  if is_power_of_2 x then floor_log else 1 + floor_log

-- Exponentiation by squaring for an arbitrary associative `times` with
-- identity `one`. It runs one step per bit of `power`.
def general_integer_power(
    times:(a,a)->a,
    one:a,
    base:a,
    power:Nat
    ) -> a given (a|Data) =
  n_steps = if power == 0 then 0 else 1 + natlog2 power
  -- Implements exponentiation by squaring.
  -- This could be nicer if there were a way to explicitly
  -- specify which typeclass instance to use for Mul.
  yield_state one \acc.
    bits_left <- with_state power
    square <- with_state base
    for _:(Fin n_steps).
      if is_odd (get bits_left) then acc := times(get acc, get square)
      square := times(get square, get square)
      bits_left := intdiv2 (get bits_left)

-- base ^ power using the type's Mul instance.
def intpow(base:a, power:Nat) -> a given (a|Mul) =
  general_integer_power((*), one, base, power)
-- Unwrap a `Just`; partial (no `Nothing` case).
def from_just(x:Maybe a) -> a given (a) =
  case x of
    Just(payload) -> payload

-- True iff `f` holds for at least one element.
def any_sat(f:(a)->Bool, xs:n=>a) -> Bool given (a, n|Ix) =
  any(map(f, xs))

-- `Just` the unwrapped table if every entry is `Just`, else `Nothing`.
def seq_maybes(xs: n=>Maybe a) -> Maybe (n => a) given (n|Ix, a) =
  -- is it possible to implement this safely? (i.e. without using partial
  -- functions)
  if any_sat(is_nothing, xs)
    then Nothing
    else Just (map(from_just, xs))

-- Index of an element equal to `query`, if any. Every match overwrites the
-- state, so this returns the *last* matching index.
def linear_search(xs:n=>a, query:a) -> Maybe n given (n|Ix, a|Eq) =
  yield_state Nothing \found.
    for i.
      if xs[i] == query then found := Just i

-- Number of elements in a list.
def list_length(l:List a) -> Nat given (a) =
  AsList(len, _) = l
  len
-- Concatenate a table of lists into one list.
-- It walks a (listIdx, eltIdx) cursor across the inputs. Before each
-- output element, the inner `while` advances past any exhausted lists
-- (including empty ones).
-- This is for efficiency (rather than using `<>` repeatedly)
-- TODO: we want this for any monoid but this implementation won't work.
def concat(lists:n=>(List a)) -> List a given (a, n|Ix) =
  totalSize = sum for i. list_length lists[i]
  to_list $ with_state 0 \listIdx.
    eltIdx <- with_state 0
    for i:(Fin totalSize).
      -- Skip to the next list that still has unread elements.
      while \.
        continue = get eltIdx >= list_length (lists[(get listIdx)@_])
        if continue
          then
            eltIdx := 0
            listIdx := get listIdx + 1
          else ()
        continue
      AsList(_, xs) = lists[(get listIdx)@_]
      eltIdxVal = get eltIdx
      eltIdx := eltIdxVal + 1
      xs[eltIdxVal@_]
-- Keep the payloads of the `Just` entries, in index order.
-- First pass: record, in order, the indices i where xs[i] is Just
-- (`ref.0` counts them, `ref.1` stores them). Second pass: read those
-- indices back and unwrap the corresponding entries.
def cat_maybes(xs:n=>Maybe a) -> List a given (n|Ix, a|Data) =
  (num_res, res_inds) = yield_state (0::Nat, for i:n. Nothing) \ref.
    for i. case xs[i] of
      Just(_) ->
        ix = get ref.0
        ref.1 ! (unsafe_from_ordinal ix) := Just i
        ref.0 := ix + 1
      Nothing -> ()
  to_list $ for i:(Fin num_res).
    case res_inds[unsafe_from_ordinal $ ordinal i] of
      Just(j) -> case xs[j] of
        Just(x) -> x
        Nothing -> todo -- Impossible
      Nothing -> todo -- Impossible
-- Elements of `xs` satisfying `condition`, in index order.
def filter(xs:n=>a, condition:(a)->Bool) -> List a given (a|Data, n|Ix) =
  kept = for i.
    elt = xs[i]
    if condition elt then Just elt else Nothing
  cat_maybes kept

-- Indices of the elements of `xs` satisfying `condition`, in index order.
def arg_filter(xs:n=>a, condition:(a)->Bool) -> List n given (a|Data, n|Ix) =
  kept = for i.
    if condition xs[i] then Just i else Nothing
  cat_maybes kept
-- The index just before `i`, or Nothing for the first index.
-- TODO: use `ix_offset : [Ix n] -> n -> Int -> Maybe n` instead
def prev_ix(i:n) -> Maybe n given (n|Ix) =
  case i_to_n (n_to_i (ordinal i) - 1) of
    Nothing -> Nothing
    Just(i_prev) -> Just (unsafe_from_ordinal i_prev)

-- Split on '\n'. Each returned line excludes its newline. There is one line
-- per newline character, so any text after the final '\n' is not returned.
def lines(source:String) -> List String =
  AsList(_, s) = source
  AsList(num_lines, newline_ixs) = arg_filter(s, \c. c == '\n')
  to_list for i_line:(Fin num_lines).
    start = case prev_ix i_line of
      Nothing -> first_ix
      Just(i) -> right_post newline_ixs[i]
    end = left_post newline_ixs[i_line]
    post_slice(s, start, end)

Probability

-- Sample an index by locating a uniform draw within `cdf`.
-- cdf should include 0.0 but not 1.0
def categorical_from_cdf(cdf: n=>Float, key: Key) -> n given (n|Ix) =
  u = rand key
  search_sorted(cdf, u) | left_fence | from_just

-- Scale a table so that it sums to one.
def normalize_pdf(xs: d=>Float) -> d=>Float given (d|Ix) =
  total = sum xs
  xs / total

-- Exclusive cumulative distribution (starting at 0.0) from unnormalized
-- log-probabilities. The max is subtracted before `exp` for numerical range.
def cdf_for_categorical(logprobs: n=>Float) -> n=>Float given (n|Ix) =
  peak = maximum logprobs
  weights = for i. exp(logprobs[i] - peak)
  cumsum_low (normalize_pdf weights)

-- Sample an index with probability proportional to exp(logprobs).
def categorical(logprobs: n=>Float, key: Key) -> n given (n|Ix) =
  cdf = cdf_for_categorical logprobs
  categorical_from_cdf(cdf, key)

-- batch variant to share the work of forming the cumsum
-- (alternatively we could rely on hoisting of loop constants)
def categorical_batch(logprobs: n=>Float, key: Key) -> m=>n given (n|Ix, m|Ix) =
  cdf = cdf_for_categorical logprobs
  for j. categorical_from_cdf(cdf, ixkey(key, j))

-- log(sum(exp x)), computed with the max subtracted before exponentiating.
def logsumexp(x: n=>Float) -> Float given (n|Ix) =
  peak = maximum x
  shifted = for i. exp (x[i] - peak)
  peak + log (sum shifted)

-- Log of the softmax: x - logsumexp(x).
def logsoftmax(x: n=>Float) -> n=>Float given (n|Ix) =
  norm = logsumexp x
  for i. x[i] - norm

-- exp(x) normalized to sum to one, with the max subtracted first.
def softmax(x: n=>Float) -> n=>Float given (n|Ix) =
  peak = maximum x
  weights = for i. exp (x[i] - peak)
  total = sum weights
  for i. weights[i] / total

Polynomials

TODO: Move this somewhere else

-- Evaluate a polynomial at x with Horner's rule. The first coefficient is
-- the highest-degree term. Same as Numpy's polyval.
def evalpoly(coefficients:n=>v, x:Float) -> v given (n|Ix, v|VSpace) =
  fold zero \i acc. coefficients[i] + x .* acc

Exception effect

-- TODO: move error and todo to here.

-- Run `f`, returning `Nothing` if it throws and `Just` its result otherwise.
def catch(f:() -> {Except|eff} a) -> {|eff} Maybe a given (a, eff) =
  thunk : (() -> {Except|eff} a) = \. f()
  %catchException(thunk)

-- Raise an exception in the Except effect.
def throw() -> {Except} a given (a) = %throwException(a)

-- Throw when `b` is False.
def assert(b:Bool) -> {Except} () =
  if b then () else throw()

Misc instances that require error

-- `a` embeds into `Either(a, b)` as the Left branch.
instance Subset(a, Either(a,b)) given (a|Data, b|Data)
  def inject'(x) = Left x
  def project'(x) = case x of
    Left(l) -> Just l
    Right(_) -> Nothing
  def unsafe_project'(x) = case x of
    Left(l) -> l
    Right(_) -> error "Can't project Right branch to Left branch"

-- `b` embeds into `Either(a, b)` as the Right branch.
instance Subset(b, Either(a,b)) given (a|Data, b|Data)
  def inject'(x) = Right x
  def project'(x) = case x of
    Left(_) -> Nothing
    Right(r) -> Just r
  def unsafe_project'(x) = case x of
    Left(_) -> error "Can't project Left branch to Right branch"
    Right(r) -> r

Index set for tables

-- Base-(size b) digits of k, least significant digit first (at 0@a).
def int_to_reversed_digits(k:Nat) -> a=>b given (a|Ix, b|Ix) =
  radix = size b
  snd $ scan k \_ remaining.
    rest = remaining `idiv` radix
    digit = remaining `mod` radix
    (rest, unsafe_from_ordinal(n=b, digit))

-- Inverse of `int_to_reversed_digits`: sum of digit[j] * radix^j.
def reversed_digits_to_int(digits: a=>b) -> Nat given (a|Ix, b|Ix) =
  radix = size b
  fst $ fold (0, 1) \j state.
    (total, place) = state
    new_total = total + ordinal digits[j] * place
    new_place = place * radix
    (new_total, new_place)

-- Tables as index sets: a table is read as a number in base (size b).
instance Ix(a=>b) given (a|Ix, b|Ix)
  -- 0@a is the least significant digit,
  -- while (size a - 1)@a is the most significant digit.
  def size'() = intpow(size b, size a)
  def ordinal(i) = reversed_digits_to_int i
  def unsafe_from_ordinal(i) = int_to_reversed_digits i

instance NonEmpty(a=>b) given (a|Ix, b|NonEmpty)
  first_ix = unsafe_from_ordinal 0

Stack

-- A growable array stored in heap `h`.
-- `size_ref` holds the logical number of elements. `buf_ref` holds the
-- backing buffer as a List, whose length is the capacity and may exceed
-- the logical size.
-- NOTE(review): several methods reinterpret the `List a` in buf_ref as a
-- (Nat, Ref to table) pair via unsafe_coerce; this relies on List's
-- in-memory layout, so confirm against the List definition before changing it.
struct Stack(h:Heap, a|Data) =
  size_ref : Ref h Nat
  buf_ref : Ref h (List a)

  -- Logical number of elements currently on the stack.
  def size() -> {State h} Nat = get self.size_ref

  -- Reference to the buffer's element storage, typed with a dummy `Fin 0`
  -- index set. Callers coerce or index it unsafely at the real size.
  def unsafe_get_buffer() -> {State h} (Ref(h, Fin 0 => a)) =
    get $ snd_ref $ unsafe_coerce(to=Ref h (Nat, Ref h (Fin 0 => a)), self.buf_ref)

  -- Capacity of the buffer, i.e. the length field of the List in buf_ref.
  def buf_size() -> {State h} Nat =
    get $ fst_ref $ unsafe_coerce(to=Ref h (Nat, Ref h (Fin 0 => a)), self.buf_ref)

  -- If the capacity is below req_size, reallocate with
  -- intpow2 (nextpow2 req_size) slots (presumably the next power of two
  -- >= req_size). The first size() elements are copied over and the
  -- remaining slots are left uninitialized.
  def ensure_size_at_least(req_size:Nat) -> {State h} () =
    if req_size > self.buf_size()
      then
        new_buf_size = intpow2 $ nextpow2 req_size
        buf = self.unsafe_get_buffer()
        logical_size = self.size()
        cur_data = get $ unsafe_coerce(to=Ref(h, Fin logical_size => a), buf)
        self.buf_ref := to_list for i:(Fin new_buf_size).
          case to_ix(n=Fin logical_size, ordinal i) of
            Just(i') -> cur_data[i']
            Nothing -> uninitialized_value()

  -- Copy of the current contents as a List of length size().
  def read() -> {State h} (List a) =
    n = self.size()
    buf = unsafe_coerce(to=Ref(h, Fin n => a), self.unsafe_get_buffer())
    AsList(n, get buf)

  -- Append one element, growing the buffer first if needed.
  @noinline
  def push(x:a) -> {State h} () =
    n_old = self.size()
    n_new = n_old + 1
    self.ensure_size_at_least(n_new)
    buf = self.unsafe_get_buffer()
    buf ! (unsafe_from_ordinal n_old) := x
    self.size_ref := n_new

  -- Append every element of `x` in index order, growing the buffer first if needed.
  @noinline
  def extend(x:n=>a) -> {State h} () given (n|Ix) =
    n_old = self.size()
    n_new = n_old + size n
    self.ensure_size_at_least(n_new)
    buf = self.unsafe_get_buffer()
    -- View the buffer, starting at slot n_old, as an n=>a table and write x into it.
    buf_slice = unsafe_coerce(to=Ref(h,n=>a), buf ! (unsafe_from_ordinal n_old))
    buf_slice := x
    self.size_ref := n_new

  -- Remove and return the last element, or Nothing when the stack is empty.
  -- The buffer is never shrunk.
  def pop() -> {State h} Maybe a =
    n_old = self.size()
    case n_old == 0 of
      True -> Nothing
      False ->
        n_new = unsafe_nat_diff(n_old, 1)
        buf = self.unsafe_get_buffer()
        self.size_ref := n_new
        Just $ get buf!(unsafe_from_ordinal n_new)
-- Initial capacity (number of slots) of the buffer created by with_stack.
stack_init_size = 16

-- Run `action` with a fresh, empty Stack of `a`s. The stack's state is local
-- to this call: logical size 0, buffer of stack_init_size uninitialized slots.
def with_stack(
    a|Data,
    action:(given (h:Heap), Stack(h, a)) -> {State h|eff} r
    ) -> {|eff} r given (eff, r) =
  init_stack = to_list for i:(Fin stack_init_size). uninitialized_value()
  -- The (size, buffer) state pair backs the two Refs of the Stack struct.
  with_state (0, init_stack) \ref . action(Stack(ref.0, ref.1))
-- Monomorphic (Char) wrappers around Stack methods.
-- NOTE(review): the `_internal` suffix suggests these are referenced by the
-- compiler rather than by user code; confirm before renaming or changing them.
def stack_extend_internal(stack:Stack(h, Char), x:Fin n=>Char) -> {State h} () given (n, h) =
  stack.extend(x)
def stack_push_internal(stack:Stack(h, Char), x:Char) -> {State h} () given (h) =
  stack.push(x)
-- Run `f` on a fresh Char stack and return the stack's final contents.
def with_stack_internal(f:(given (h:Heap), Stack(h, Char)) -> {State h} ()) -> List Char =
  with_stack Char \stack.
    f stack
    stack.read()

Environment Variables

-- Copy a NUL-terminated C string into a Dex String, one byte at a time,
-- stopping at (and excluding) the first '\NUL'.
-- Returns Nothing when `c_string_ptr s` is Nothing (presumably a null pointer).
def from_c_string(s:CString) -> {IO} (Maybe String) =
  case c_string_ptr s of
    Nothing -> Nothing
    Just(ptr) -> Just do
      stack <- with_stack Char
      i <- iter
      c = load $ ptr +>> i
      if c == '\NUL'
        then Done $ stack.read()
        else
          stack.push(c)
          Continue
-- Render any value using the %showAny primitive. The primitive's result is
-- unsafe_coerce'd to String.
def show_any(x:a) -> String given (a) = unsafe_coerce(to=String, %showAny(x))
def coerce_table(m|Ix, x:n=>a) -> m => a given (n|Ix, a|Data) =
  -- Reinterpret `x` as a table indexed by `m`. Calls `error` unless both
  -- index sets have the same size.
  if not (size m == size n)
    then error "mismatched sizes in table coercion"
    else unsafe_coerce(to=m=>a, x)

GetEnv

foreign "getenv" getenvFFI : (RawPtr) -> {IO} RawPtr

-- Look up the environment variable `name`. Returns Nothing when getenv(3)
-- returns a null pointer (the variable is unset).
def get_env(name:String) -> {IO} Maybe String =
  with_c_string name \c_name.
    from_c_string $ CString $ getenvFFI c_name.ptr

-- True when the environment variable `name` is set (to any value).
def check_env(name:String) -> {IO} Bool =
  maybe_value = get_env name
  is_just maybe_value

Testing Helpers

-- -- Reliably causes a segfault if pointers aren't initialized to zero.
-- -- TODO: add this test when we cache modules
-- justSomeDataToTestCaching = toList for i:(Fin 100).
-- if ordinal i == 0
-- then Left (toList [1,2,3])
-- else Right 1

TestMode

def dex_test_mode() -> Bool =
  -- True when DEX_TEST_MODE is set in the environment. The lookup goes
  -- through unsafe_io so it can be exposed as a plain Bool.
  unsafe_io \. is_just $ get_env "DEX_TEST_MODE"

More Stream IO

-- Read everything remaining in `stream` into a String. Data is read in
-- chunks of `n` bytes, stopping after the first read that returns fewer
-- than `n` bytes (end of file or error).
def fread(stream:Stream ReadMode) -> {IO} String =
  -- TODO(review): an older note here said "allow reading longer files!", but
  -- the loop below already reads any number of chunks; confirm it is obsolete.
  n = 4096
  ptr:(Ptr Char) <- with_alloc n
  stack <- with_stack Char
  iter \_.
    numRead = i_to_w32 $ i64_to_i $ freadFFI(ptr.val, 1, n_to_i64 n, stream.ptr)
    AsList(_, new_chars) = string_from_char_ptr(numRead, ptr)
    stack.extend(new_chars)
    -- A full chunk means there may be more data; a short one ends the loop.
    if numRead == n_to_w32 n
      then Continue
      else Done ()
  stack.read()

Shelling Out

-- libc entry points used by shell_out, delete_file and new_temp_file below.
foreign "popen" popenFFI : (RawPtr, RawPtr) -> {IO} RawPtr
-- NOTE(review): C `remove` returns `int`; declaring Int64 here means the
-- upper bits are unspecified. delete_file ignores the result, so this is
-- currently harmless, but confirm before relying on the return value.
foreign "remove" removeFFI : (RawPtr) -> {IO} Int64
foreign "mkstemp" mkstempFFI : (RawPtr) -> {IO} Int32
foreign "close" closeFFI : (Int32) -> {IO} Int32
foreign "pclose" pcloseFFI : (RawPtr) -> {IO} Int32

def shell_out(command:String) -> {IO} String =
  -- Run `command` through the shell via popen(3) in read mode and return
  -- everything it writes to stdout. Calls `error` if popen fails.
  modeStr = "r"
  with_c_string command \command'.
    with_c_string modeStr \modeStr'.
      pipe_ptr = popenFFI(command'.ptr, modeStr'.ptr)
      if is_null_raw_ptr pipe_ptr
        then error $ "Unable to run command: " <> command
        else
          output = fread (Stream pipe_ptr)
          -- A stream opened with popen must be closed with pclose, which also
          -- waits for and reaps the child process. Without it, both the
          -- FILE* and the child leak. The exit status is ignored here.
          pcloseFFI pipe_ptr
          output

File Operations

def delete_file(f:FilePath) -> {IO} () =
  -- Remove the file at path `f` using remove(3). The return code is ignored.
  with_c_string f \path.
    removeFFI path.ptr
    ()
def with_file(
    f:FilePath,
    mode:StreamMode,
    action:(Stream mode) -> {IO} a
    ) -> {IO} a given (a|Data) =
  -- Open `f` in `mode`, run `action` on the stream, close the stream and
  -- return the action's result. Calls `error` if the file cannot be opened.
  handle = fopen(f, mode)
  opened = not (is_null_raw_ptr handle.ptr)
  if opened
    then
      out = action handle
      fclose handle
      out
    else error $ "Unable to open file: " <> f
def write_file(f:FilePath, s:String) -> {IO} () =
  -- Write `s` to `f`, opened in WriteMode.
  with_file(f, WriteMode) \out_stream. fwrite(out_stream, s)

def read_file(f:FilePath) -> {IO} String =
  -- Return the full contents of `f`, opened in ReadMode.
  with_file(f, ReadMode) \in_stream. fread in_stream
def has_file(f:FilePath) -> {IO} Bool =
  -- True when `f` can be opened for reading. The probe stream is closed again.
  handle = fopen(f, ReadMode)
  if is_null_raw_ptr handle.ptr
    then False
    else
      fclose handle
      True

Temporary Files

def new_temp_file() -> {IO} FilePath =
  -- Create a uniquely named file from the template /tmp/dex-XXXXXX using
  -- mkstemp(3), close its descriptor and return the generated path.
  -- 15 is the length of the template (mkstemp keeps the length unchanged).
  -- NOTE(review): a mkstemp failure (-1) is not detected; the template is
  -- then returned unchanged.
  with_c_string "/tmp/dex-XXXXXX" \template.
    fd = mkstempFFI template.ptr
    closeFFI fd
    string_from_char_ptr(15, (Ptr template.ptr))
def with_temp_file(action: (FilePath) -> {IO} a) -> {IO} a given (a) =
  -- Run `action` on the path of a fresh temporary file, then delete the file.
  path = new_temp_file()
  out = action path
  delete_file path
  out

def with_temp_files(action: (n=>FilePath) -> {IO} a) -> {IO} a given (n|Ix, a) =
  -- Like with_temp_file, but with one fresh temporary file per element of n.
  paths = for i. new_temp_file()
  out = action paths
  for i. delete_file paths[i]
  out

Linear Algebra

def linspace(n|Ix, low:Float, high:Float) -> n=>Float =
  -- size n evenly spaced points starting at `low`, with step (high - low) / size n.
  -- `high` itself is excluded (like numpy.linspace with endpoint=False).
  step = (high - low) / n_to_f (size n)
  for k:n. low + n_to_f (ordinal k) * step
def transpose(x:n=>m=>a) -> m=>n=>a given (n|Ix, m|Ix, a) =
  -- Swap the two leading index sets.
  for col. for row. x[row, col]

def vdot(x:n=>Float, y:n=>Float) -> Float given (n|Ix) =
  -- Dot product of two Float vectors.
  products = for k. x[k] * y[k]
  fsum products

def dot(s:n=>Float, vs:n=>v) -> v given (n|Ix, v|VSpace) =
  -- Linear combination: sum over j of s[j] .* vs[j], in any vector space.
  scaled = for j. s[j] .* vs[j]
  sum scaled

def naive_matmul(x: l=>m=>Float, y: m=>n=>Float) -> (l=>n=>Float) given (l|Ix, m|Ix, n|Ix) =
  -- Textbook triple-loop matrix product, summing over the shared index set m.
  for row. for col. fsum for k. x[row, k] * y[k, col]
-- A `FullTileIx` type represents the `tile_ix`th full tile (of size
-- `tile_size`) iterating over the index set `n`.
-- This type is only well formed when (tile_ix + 1) * tile_size <= size n,
-- i.e. the whole tile, up to position tile_size * tile_ix + tile_size - 1,
-- lies inside `n`.
struct FullTileIx(n|Ix, tile_size:Nat, tile_ix:Nat) =
  unwrap : Fin tile_size
-- A full tile is indexed by its position 0..tile_size-1 within the tile.
instance Ix(FullTileIx(n, tile_size, tile_ix)) given (n|Ix, tile_size:Nat, tile_ix:Nat)
  def size'() = tile_size
  def ordinal(i) = ordinal i.unwrap
  def unsafe_from_ordinal(i) = FullTileIx $ unsafe_from_ordinal i

-- Injection into `n` offsets the position by tile_size * tile_ix.
-- Projection from `n` back into the tile is not implemented (`todo`).
instance Subset(FullTileIx(n, tile_size, tile_ix), n) given (n|Ix, tile_size:Nat, tile_ix:Nat)
  def inject'(i) = unsafe_from_ordinal $ tile_size * tile_ix + ordinal i.unwrap
  def project'(i) = todo
  def unsafe_project'(i) = todo
-- A `CodaIx` type represents the last few elements of the index set `n`,
-- as might be left over after iterating by tiles.
-- This type is only well formed when size n == coda_offset + coda_size,
-- i.e. the coda covers exactly the tail of `n` starting at coda_offset.
struct CodaIx(n|Ix, coda_offset:Nat, coda_size:Nat) =
  unwrap : Fin coda_size
-- A coda is indexed by its position 0..coda_size-1 within the coda.
instance Ix(CodaIx(n, coda_offset, coda_size)) given (n|Ix, coda_offset:Nat, coda_size:Nat)
  def size'() = coda_size
  def ordinal(i) = ordinal i.unwrap
  def unsafe_from_ordinal(i) = CodaIx $ unsafe_from_ordinal i

-- Injection into `n` offsets the position by coda_offset.
-- Projection from `n` back into the coda is not implemented (`todo`).
instance Subset(CodaIx(n, coda_offset, coda_size), n) given (n|Ix, coda_offset:Nat, coda_size:Nat)
  def inject'(i) = unsafe_from_ordinal $ coda_offset + ordinal i.unwrap
  def project'(i) = todo
  def unsafe_project'(i) = todo
-- Split `n` into (size n `idiv` tile_size) consecutive full tiles of
-- tile_size elements, plus a final coda of the remaining
-- (size n `rem` tile_size) elements.
-- `body` is called once per full tile, in order, and then once for the coda,
-- even when the coda is empty. Each call receives the tile's index type with
-- its Ix and Subset(_, n) instances, so `body` can iterate over the tile and
-- `inject` positions back into `n`.
-- NOTE(review): tile_size == 0 would divide by zero in `idiv`/`rem`.
def tile(
    n|Ix,
    tile_size: Nat,
    body:(m:Type, given () (Ix m, Subset(m, n))) -> {|eff} ()
    ) -> {|eff} () given (eff) =
  num_tiles = size n `idiv` tile_size
  coda_size = size n `rem` tile_size
  coda_offset = num_tiles * tile_size
  for_ tile_ix:(Fin num_tiles).
    tile_ix' = ordinal tile_ix
    body (FullTileIx(n, tile_size, tile_ix'))
  body (CodaIx(n, coda_offset, coda_size))
-- Matrix product of x and y, blocked with `tile` over l, n and m (nested in
-- that order). Within each block of tiles, the products x[i][k] * y[k][j]
-- are accumulated into the result through an AddMonoid accumulator.
@noinline
def tiled_matmul(
    x: l=>m=>Float,
    y: m=>n=>Float
    ) -> l=>n=>Float given (l|Ix, m|Ix, n|Ix) =
  -- Tile sizes picked for axch's laptop
  l_tile_size = 32
  n_tile_size = 128
  m_tile_size = 8
  yield_accum (AddMonoid Float) \result.
    tile(l, l_tile_size) \l_set.
      tile(n, n_tile_size) \n_set.
        tile(m, m_tile_size) \m_set.
          for_ l_offset:l_set.
            -- Map tile-local positions back to indices of the full index sets.
            l_ix = inject(to=l, l_offset)
            for_ m_offset:m_set.
              m_ix = inject m_offset
              for_ n_offset:n_set.
                n_ix = inject n_offset
                result!l_ix!n_ix += x[l_ix][m_ix] * y[m_ix][n_ix]
-- matmul. Better symbol to use? `@`?
-- Matrix-matrix product; delegates to tiled_matmul.
def (**)(
    x: l=>m=>Float,
    y: m=>n=>Float
    ) -> l=>n=>Float given (l|Ix, m|Ix, n|Ix) = tiled_matmul(x, y)
-- Custom linearization for tiled_matmul. Returns the primal value x ** y
-- together with the linear map (xt, yt) -> x ** yt + xt ** y
-- (the product rule), so differentiation does not trace through the
-- tiled loops.
def matmul_linearization(
    x: l=>m=>Float,
    y: m=>n=>Float
    ) -> _ given (l|Ix, m|Ix, n|Ix) =
  def lin(xt: l=>m=>Float, yt: m=>n=>Float) -> _ =
    x ** yt + xt ** y
  (x ** y, lin)

custom-linearization tiled_matmul matmul_linearization
def (**.)(mat: n=>m=>Float, v: m=>Float) -> (n=>Float) given (n|Ix, m|Ix) =
  -- Matrix-vector product: entry r is the dot product of row r with v.
  for r. vdot(mat[r], v)

def (.**)(v: n=>Float, mat: n=>m=>Float) -> (m=>Float) given (n|Ix, m|Ix) =
  -- Vector-matrix product, computed as (transpose mat) **. v.
  mat_t = transpose mat
  mat_t **. v
def inner(x:n=>Float, mat:n=>m=>Float, y:m=>Float) -> Float given (n|Ix, m|Ix) =
  -- Bilinear form x^T mat y, summed over all (i, j) index pairs in a single fsum.
  fsum for ij.
    (r, c) = ij
    x[r] * mat[r, c] * y[c]
def eye() ->> n=>n=>a given (n|Ix, a|Add|Mul) =
  -- Identity matrix: `one` where the row and column ordinals match, `zero` elsewhere.
  for r c.
    on_diagonal = ordinal r == ordinal c
    select(on_diagonal, one, zero)