Dex prelude

Runs before every Dex program unless an alternative is provided with --prelude.

Essentials

Primitive Types

-- Names for the compiler's built-in types. Every right-hand side here is
-- either a `%` primitive or an alias of a name bound earlier in this list.
Unit = %UnitType
Type = %TyKind
Effects = %EffKind
Fields = %LabeledRowKind
-- Fixed-width numeric primitives.
Int64 = %Int64
Int32 = %Int32
Float64 = %Float64
Float32 = %Float32
Word8 = %Word8
Word32 = %Word32
Word64 = %Word64
-- `Byte` and `Char` are plain aliases of `Word8` (no distinct type).
Byte = Word8
Char = Byte
-- Untyped pointer primitive.
RawPtr : Type = %Word8Ptr
-- Default widths used by the unqualified `Int` and `Float` names.
Int = Int32
Float = Float32

Casting

-- Convert `x` to type `b` with the `%cast` primitive. Every helper below is
-- this cast with its target type written out explicitly.
def internalCast (b:Type) (x:a) : b = %cast b x

-- Floating-point width conversions.
def F64ToF (x : Float64) : Float = internalCast Float x
def F32ToF (x : Float32) : Float = internalCast Float x
def FToF64 (x : Float) : Float64 = internalCast Float64 x
def FToF32 (x : Float) : Float32 = internalCast Float32 x

-- Integer / word conversions to and from the default `Int`.
def I64ToI (x : Int64) : Int = internalCast Int x
def I32ToI (x : Int32) : Int = internalCast Int x
def W8ToI (x : Word8) : Int = internalCast Int x
def IToI64 (x : Int) : Int64 = internalCast Int64 x
def IToI32 (x : Int) : Int32 = internalCast Int32 x
def IToW8 (x : Int) : Word8 = internalCast Word8 x
def IToW32 (x : Int) : Word32 = internalCast Word32 x
def IToW64 (x : Int) : Word64 = internalCast Word64 x
def W32ToW64 (x : Word32): Word64 = internalCast Word64 x

-- Conversions between the default integer and float types.
def IToF (x:Int) : Float = internalCast Float x
def FToI (x:Float) : Int = internalCast Int x

-- Reinterpret raw pointers as Int64 values and back.
def I64ToRawPtr (x:Int64 ) : RawPtr = internalCast RawPtr x
def RawPtrToI64 (x:RawPtr) : Int64 = internalCast Int64 x

Bitwise operations

-- Bit-level operations. Shift amounts are `Int`s, converted to the operand's
-- word type before being passed to the shift primitive.
interface Bits a
  (.<<.) : a -> Int -> a
  (.>>.) : a -> Int -> a
  (.|.) : a -> a -> a
  (.&.) : a -> a -> a
  (.^.) : a -> a -> a

instance Bits Word8
  (.<<.) = \w n. %shl w (IToW8 n)
  (.>>.) = \w n. %shr w (IToW8 n)
  (.|.) = \p q. %or p q
  (.&.) = \p q. %and p q
  (.^.) = \p q. %xor p q

instance Bits Word32
  (.<<.) = \w n. %shl w (IToW32 n)
  (.>>.) = \w n. %shr w (IToW32 n)
  (.|.) = \p q. %or p q
  (.&.) = \p q. %and p q
  (.^.) = \p q. %xor p q

instance Bits Word64
  (.<<.) = \w n. %shl w (IToW64 n)
  (.>>.) = \w n. %shr w (IToW64 n)
  (.|.) = \p q. %or p q
  (.&.) = \p q. %and p q
  (.^.) = \p q. %xor p q
-- Split a Word64 into 32-bit halves.
-- NOTE(review): the names appear swapped. `lowWord` shifts right by 32 before
-- casting, so (assuming the Word64 -> Word32 `%cast` keeps the low 32 bits)
-- it returns the *high* half, while `highWord` returns the *low* half.
-- Swapping them would change results for existing callers, so they are left
-- as-is pending a review of the call sites.
def lowWord (x : Word64) : Word32 = internalCast _ (x .>>. 32)
def highWord (x : Word64) : Word32 = internalCast _ x

Basic Arithmetic

Add

Things that can be added. This defines the Add group and its operators.

-- Additive group: `add`, `sub` and the identity `zero`.
interface Add a
  add : a -> a -> a
  sub : a -> a -> a
  zero : a

def (+) [Add a] : a -> a -> a = add
def (-) [Add a] : a -> a -> a = sub

instance Add Float64
  add = \u v. %fadd u v
  sub = \u v. %fsub u v
  zero = FToF64 0.0
instance Add Float32
  add = \u v. %fadd u v
  sub = \u v. %fsub u v
  zero = FToF32 0.0
instance Add Int64
  add = \u v. %iadd u v
  sub = \u v. %isub u v
  zero = IToI64 0
instance Add Int32
  add = \u v. %iadd u v
  sub = \u v. %isub u v
  zero = IToI32 0
instance Add Word8
  add = \u v. %iadd u v
  sub = \u v. %isub u v
  zero = IToW8 0
instance Add Word32
  add = \u v. %iadd u v
  sub = \u v. %isub u v
  zero = IToW32 0
instance Add Word64
  add = \u v. %iadd u v
  sub = \u v. %isub u v
  zero = IToW64 0
instance Add Unit
  add = \_ _. ()
  sub = \_ _. ()
  zero = ()

-- Tables add elementwise.
instance [Add a] Add (n=>a)
  add = \xs ys. for i. xs.i + ys.i
  sub = \xs ys. for i. xs.i - ys.i
  zero = for _. zero

-- Upper triangular tables (row `i` holds entries `i..`).
instance [Add a] Add (i:n => (i..) => a)
  add = \xs ys. for i. xs.i + ys.i
  sub = \xs ys. for i. xs.i - ys.i
  zero = for _. zero

-- Lower triangular tables (row `i` holds entries `..i`).
instance [Add a] Add (i:n => (..i) => a)
  add = \xs ys. for i. xs.i + ys.i
  sub = \xs ys. for i. xs.i - ys.i
  zero = for _. zero

Mul

Things that can be multiplied. This defines the Mul monoid and its operator.

-- Multiplicative monoid: `mul` with identity `one`.
interface Mul a
  mul : a -> a -> a
  one : a

def (*) [Mul a] : a -> a -> a = mul

instance Mul Float64
  mul = \u v. %fmul u v
  one = FToF64 1.0
instance Mul Float32
  mul = \u v. %fmul u v
  one = FToF32 1.0
instance Mul Int64
  mul = \u v. %imul u v
  one = IToI64 1
instance Mul Int32
  mul = \u v. %imul u v
  one = IToI32 1
instance Mul Word8
  mul = \u v. %imul u v
  one = IToW8 1
instance Mul Unit
  mul = \_ _. ()
  one = ()

-- Tables multiply elementwise.
instance [Mul a] Mul (n=>a)
  mul = \xs ys. for i. xs.i * ys.i
  one = for _. one

Integral

Integer-like things.

-- Integer division and remainder, delegating to the `%idiv`/`%irem` primitives.
interface Integral a
  idiv : a->a->a
  rem : a->a->a

instance Integral Int64
  idiv = \num den. %idiv num den
  rem = \num den. %irem num den
instance Integral Int32
  idiv = \num den. %idiv num den
  rem = \num den. %irem num den
instance Integral Word8
  idiv = \num den. %idiv num den
  rem = \num den. %irem num den

Fractional

Rational-like things. Includes floating-point and two-field rational representations.

-- True (non-truncating) division.
interface Fractional a
  divide : a -> a -> a

instance Fractional Float64
  divide = \num den. %fdiv num den
instance Fractional Float32
  divide = \num den. %fdiv num den

Basic polymorphic functions and types

-- The pair type and its constructor.
def (&) (a:Type) (b:Type) : Type = %PairType a b
def (,) (x:a) (y:b) : (a & b) = %pair x y

-- Pair projections and component swap.
def fst ((first, _): (a & b)) : a = first
def snd ((_, second): (a & b)) : b = second
def swap ((first, second):(a&b)) : (b&a) = (second, first)

-- Function composition: `<<<` applies right-to-left, `>>>` left-to-right.
def (<<<) (f: b -> c) (g: a -> b) : a -> c = \v. f (g v)
def (>>>) (g: a -> b) (f: b -> c) : a -> c = \v. f (g v)

-- Argument-plumbing combinators.
flip : (a -> b -> c) -> (b -> a -> c) = \fn p q. fn q p
uncurry : (a -> b -> c) -> (a & b) -> c = \fn (p,q). fn p q
const : a -> b -> a = \kept _. kept

Vector spaces

-- Vector spaces over `Float`: an additive group with scalar multiplication.
interface [Add a] VSpace a
  scaleVec : Float -> a -> a

def (.*) [VSpace a] : Float -> a -> a = scaleVec
def (*.) [VSpace a] : a -> Float -> a = flip scaleVec

-- Division by a scalar scales by its reciprocal.
def (/) [VSpace a] (v:a) (s:Float) : a =
  recipS = divide 1.0 s
  recipS .* v

def neg [VSpace a] (v:a) : a = (-1.0) .* v

instance VSpace Float
  scaleVec = \s x. s * x
instance [VSpace a] VSpace (n=>a)
  scaleVec = \s xs. for i. s .* xs.i
instance VSpace Unit
  scaleVec = \_ _. ()

Boolean type

data Bool =
  False
  True

-- Bools round-trip through their constructor tag as a Word8.
def BToW8 (x : Bool) : Word8 = %dataConTag x
def W8ToB (x : Word8) : Bool = %toEnum Bool x

-- Logical operators evaluate both operands and combine the tags bitwise.
def (&&) (x:Bool) (y:Bool) : Bool = W8ToB $ %and (BToW8 x) (BToW8 y)
def (||) (x:Bool) (y:Bool) : Bool = W8ToB $ %or (BToW8 x) (BToW8 y)
def not (x:Bool) : Bool = W8ToB $ %not (BToW8 x)

Sum types

A sum type, or tagged union, can hold values from a fixed set of types, distinguished by tags. For those familiar with the C language, they can be thought of as a combination of an enum with a union. Here we define several basic sum types, and some operators on them.

-- Optional values.
data Maybe a =
  Nothing
  Just a

def isNothing (x:Maybe a) : Bool = case x of
  Nothing -> True
  Just _ -> False

def isJust (x:Maybe a) : Bool = case x of
  Nothing -> False
  Just _ -> True

-- Eliminate a Maybe: the default `d` for Nothing, `f` applied to the payload otherwise.
def maybe (d: b) (f : (a -> b)) (x: Maybe a) : b = case x of
  Nothing -> d
  Just payload -> f payload

-- Binary sum type.
data (|) a b =
  Left a
  Right b

More Boolean operations

TODO: move these with the others?

-- Choose between two already-evaluated values based on `p`.
def select (p:Bool) (x:a) (y:a) : a = case p of
  True -> x
  False -> y

-- Bool as 0/1 in the default numeric types.
def BToI (x:Bool) : Int = W8ToI (BToW8 x)
def BToF (x:Bool) : Float = IToF $ BToI x

Ordering

TODO: move this down to with Ord?

-- Result of a three-way comparison (produced by `compare`).
data Ordering = LT EQ GT
-- Constructor tag of an `Ordering`; used to implement `Eq Ordering`.
def OToW8 (x : Ordering) : Word8 = %dataConTag x

Monoid

A monoid is a type that has an associative binary operator and an identity element. This is a very useful and general class of things. It includes:

  • Addition and Multiplication of Numbers
  • Boolean Logic
  • Concatenation of Lists (including strings)

Monoids support fold operations, and similar.
-- Associative `mcombine` with identity `mempty`.
interface Monoid a
  mempty : a
  mcombine : a -> a -> a  -- named `mcombine` since `<>` can't be used here (parser limitation?)

def (<>) [Monoid a] : a -> a -> a = mcombine

-- Tables combine elementwise.
instance [Monoid a] Monoid (n=>a)
  mempty = for _. mempty
  mcombine = \xs ys. for i. mcombine xs.i ys.i

-- Bool has two useful monoids, so both are named rather than global.
named-instance AndMonoid : Monoid Bool
  mempty = True
  mcombine = (&&)

named-instance OrMonoid : Monoid Bool
  mempty = False
  mcombine = (||)

-- The additive / multiplicative monoid of a type, as a first-class value.
def AddMonoid (a:Type) -> (_:Add a) ?=> : Monoid a =
  A = a  -- XXX: Typing `Monoid a` below would quantify it over a, which we don't want
  named-instance result : Monoid A
    mempty = zero
    mcombine = add
  result

def MulMonoid (a:Type) -> (_:Mul a) ?=> : Monoid a =
  A = a  -- XXX: Typing `Monoid a` below would quantify it over a, which we don't want
  named-instance result : Monoid A
    mempty = one
    mcombine = mul
  result

Effects

-- NOTE(review): several definitions in this section have been collapsed onto
-- single lines; their bodies originally spanned several indented lines. In
-- particular the `--` comment inside `runAccum` now comments out the rest of
-- its body. Restore the line breaks before compiling.
--
-- A mutable reference whose first parameter is the region tag that appears
-- in effect rows (`State h`, `Read h`, `Accum h`).
def Ref (r:Type) (a:Type) : Type = %Ref r a
-- Read / overwrite a State reference.
def get (ref:Ref h s) : {State h} s = %get ref
def (:=) (ref:Ref h s) (x:s) : {State h} Unit = %put ref x
-- Read a Reader reference.
def ask (ref:Ref h r) : {Read h} r = %ask ref
-- A Monoid on `w` tagged with accumulator region `h`. The "Unsafe" prefix
-- suggests it should only be constructed by the handlers below — confirm.
data AccumMonoid h w = UnsafeMkAccumMonoid (Monoid w)
-- Instance: lift an accumulator monoid on `w` to tables `n=>w`, by installing
-- `m` as a local instance hint and using the table `Monoid (n=>w)` instance.
@instance def tableAccumMonoid ((UnsafeMkAccumMonoid m):AccumMonoid h w) ?=> : AccumMonoid h (n=>w) = %instance mHint = m def liftTableMonoid (tm:Monoid (n=>w)) ?=> : Monoid (n=>w) = tm UnsafeMkAccumMonoid liftTableMonoid
-- Accumulate `x` into `ref`: the current value `v` becomes `mcombine v x`.
def (+=) (am:AccumMonoid h w) ?=> (ref:Ref h w) (x:w) : {Accum h} Unit = (UnsafeMkAccumMonoid m) = am %instance mHint = m updater = \v. mcombine v x %mextend ref updater
-- Sub-references: a table element, or a component of a pair.
def (!) (ref:Ref h (n=>a)) (i:n) : Ref h a = %indexRef ref i
def fstRef (ref: Ref h (a & b)) : Ref h a = %fstRef ref
def sndRef (ref: Ref h (a & b)) : Ref h b = %sndRef ref
-- Run `action` with a read-only reference initialised to `init`.
def runReader (init:r) (action: (h:Type ?-> Ref h r -> {Read h|eff} a)) : {|eff} a = def explicitAction (h':Type) (ref:Ref h' r) : {Read h'|eff} a = action ref %runReader init explicitAction
-- Alias of `runReader`.
def withReader (init:r) (action: (h:Type ?-> Ref h r -> {Read h|eff} a)) : {|eff} a = runReader init action
-- Functions that, for any region `h`, turn an accumulator monoid on `b` into
-- one on `w` (e.g. `tableAccumMonoid` with `w = n=>b`).
def MonoidLifter (b:Type) (w:Type) : Type = h:Type -> AccumMonoid h b ?=> AccumMonoid h w
-- Run `action` with an accumulator reference of type `w` whose base monoid is
-- `bm`; returns the action's result paired with the accumulated value.
-- `mempty`/`mcombine` come from `bm`, installed as a local hint (the inline
-- comment calls it `m`, but the binder is `bm`). `mlift` is only consumed by
-- dictionary synthesis, presumably to reach `w` from `b` — confirm.
def runAccum (mlift:MonoidLifter b w) ?=> (bm:Monoid b) (action: (h:Type ?-> AccumMonoid h b ?=> Ref h w -> {Accum h|eff} a)) : {|eff} (a & w) = -- Normally, only the ?=> lambda binders participate in dictionary synthesis, -- so we need to explicitly declare `m` as a hint. %instance bmHint = bm empty : b = mempty combine : b -> b -> b = mcombine def explicitAction (h':Type) (ref:Ref h' w) : {Accum h'|eff} a = %instance accumBaseMonoidHint : AccumMonoid h' b = UnsafeMkAccumMonoid bm action ref %runWriter empty combine explicitAction
-- Like `runAccum`, but returns only the accumulated value.
def yieldAccum (mlift:MonoidLifter b w) ?=> (m:Monoid b) (action: (h:Type ?-> AccumMonoid h b ?=> Ref h w -> {Accum h|eff} a)) : {|eff} w = snd $ runAccum m action
-- Run `action` with a State reference initialised to `init`; returns the
-- action's result paired with the final state.
def runState (init:s) (action: h:Type ?-> Ref h s -> {State h |eff} a) : {|eff} (a & s) = def explicitAction (h':Type) (ref:Ref h' s) : {State h'|eff} a = action ref %runState init explicitAction
-- `runState`, keeping only the result / only the final state.
def withState (init:s) (action: h:Type ?-> Ref h s -> {State h |eff} a) : {|eff} a = fst $ runState init action
def yieldState (init:s) (action: h:Type ?-> Ref h s -> {State h |eff} a) : {|eff} s = snd $ runState init action
-- Run an IO-effectful thunk and hide the IO effect from the type.
def unsafeIO (f: Unit -> {IO|eff} a) : {|eff} a = %runIO f
-- Signal an error at runtime (via `%throwError`); usable at any type `a`.
def unreachable (():Unit) : a = unsafeIO do %throwError a

Type classes

Eq and Ord

Eq

Equatable: things for which we can tell whether or not they are equal to other things.

interface Eq a
  (==) : a -> a -> Bool

-- Inequality is the negation of `==`.
def (/=) [Eq a] (x:a) (y:a) : Bool = not (x == y)

Ord

Orderable / Comparable: things that can be placed in a total order, i.e. things that can be compared to other things to find whether they are larger, smaller or equal in value.

We take the standard falsehood and pretend that this applies to Floats, even though strictly speaking this is not true: our floats follow IEEE 754, and thus have `NaN < 1.0 == false` and `1.0 < NaN == false`.

-- Strict comparisons; the non-strict ones are derived from these and `==`.
interface [Eq a] Ord a
  (>) : a -> a -> Bool
  (<) : a -> a -> Bool

def (<=) [Ord a] (x:a) (y:a) : Bool = (x < y) || (x == y)
def (>=) [Ord a] (x:a) (y:a) : Bool = (x > y) || (x == y)

-- Primitive comparisons return a Word8 that is converted back to Bool.
instance Eq Float64
  (==) = \u v. W8ToB $ %feq u v
instance Eq Float32
  (==) = \u v. W8ToB $ %feq u v
instance Eq Int64
  (==) = \u v. W8ToB $ %ieq u v
instance Eq Int32
  (==) = \u v. W8ToB $ %ieq u v
instance Eq Word8
  (==) = \u v. W8ToB $ %ieq u v
instance Eq Bool
  (==) = \u v. BToW8 u == BToW8 v
instance Eq Unit
  (==) = \_ _. True
instance Eq RawPtr
  (==) = \u v. RawPtrToI64 u == RawPtrToI64 v

instance Ord Float64
  (>) = \u v. W8ToB $ %fgt u v
  (<) = \u v. W8ToB $ %flt u v
instance Ord Float32
  (>) = \u v. W8ToB $ %fgt u v
  (<) = \u v. W8ToB $ %flt u v
instance Ord Int64
  (>) = \u v. W8ToB $ %igt u v
  (<) = \u v. W8ToB $ %ilt u v
instance Ord Int32
  (>) = \u v. W8ToB $ %igt u v
  (<) = \u v. W8ToB $ %ilt u v
instance Ord Word8
  (>) = \u v. W8ToB $ %igt u v
  (<) = \u v. W8ToB $ %ilt u v
instance Ord Unit
  (>) = \_ _. False
  (<) = \_ _. False

-- Pairs compare lexicographically: first components, then second.
instance [Eq a, Eq b] Eq (a & b)
  (==) = \(x1,x2) (y1,y2). (x1 == y1) && (x2 == y2)
instance [Ord a, Ord b] Ord (a & b)
  (>) = \(x1,x2) (y1,y2). (x1 > y1) || ((x1 == y1) && (x2 > y2))
  (<) = \(x1,x2) (y1,y2). (x1 < y1) || ((x1 == y1) && (x2 < y2))

instance Eq Ordering
  (==) = \u v. OToW8 u == OToW8 v
-- Thread an accumulator through a `for` loop over `n` using a State reference.
-- `body i acc` returns the next accumulator and an output for index `i`; the
-- result is the final accumulator paired with the table of outputs.
def scan (init:a) (body:n->a->(a&b)) : (a & n=>b) =
  swap $ runState init \accRef. for i.
    current = get accRef
    (updated, out) = body i current
    accRef := updated
    out

-- Like `scan`, but keeps only the final accumulator.
def fold (init:a) (body:(n->a->a)) : a =
  (final, _) = scan init \i acc. (body i acc, ())
  final

-- Three-way comparison built from `<` and `==`.
def compare [Ord a] (x:a) (y:a) : Ordering =
  select (x < y) LT (select (x == y) EQ GT)

-- Lexicographic combination: the first non-EQ result wins.
instance Monoid Ordering
  mempty = EQ
  mcombine = \first second. case first of
    LT -> LT
    GT -> GT
    EQ -> second

-- Tables are equal when every pair of elements is equal.
instance [Eq a] Eq (n=>a)
  (==) = \xs ys. yieldAccum AndMonoid \ref. for i. ref += (xs.i == ys.i)

-- Tables are ordered lexicographically by index.
instance [Ord a] Ord (n=>a)
  (>) = \xs ys.
    lexOrder : Ordering = fold EQ \i acc. acc <> (compare xs.i ys.i)
    lexOrder == GT
  (<) = \xs ys.
    lexOrder : Ordering = fold EQ \i acc. acc <> (compare xs.i ys.i)
    lexOrder == LT

Elementary/Special Functions

This is more or less the standard libm fare. Roughly, it lines up with some definitions of the set of Elementary and/or Special functions. In truth, nothing is elementary or special except that we humans have decided it is. Many, but not all, of these functions are transcendental.

-- The usual libm-style functions.
interface Floating a
  exp : a -> a
  exp2 : a -> a
  log : a -> a
  log2 : a -> a
  log10 : a -> a
  log1p : a -> a
  sin : a -> a
  cos : a -> a
  tan : a -> a
  sinh : a -> a
  cosh : a -> a
  tanh : a -> a
  floor : a -> a
  ceil : a -> a
  round : a -> a
  sqrt : a -> a
  pow : a -> a -> a
  lgamma : a -> a

-- Log of the Beta function: lgamma x + lgamma y - lgamma (x + y).
def lbeta [Add a, Floating a] : a -> a -> a = \p q.
  (lgamma p + lgamma q) - lgamma (p + q)
-- Hyperbolic helpers used by the `Floating Float32` / `Floating Float64`
-- instances. They call the `%exp` primitive directly, because the class
-- method `exp` is defined by those same instances (circular definition).
-- Todo: better numerics for sinh at very small |x| (e^x - e^-x cancels).
def float32_sinh (x:Float32) : Float32 = %fdiv (%fsub (%exp x) (%exp (%fsub 0.0 x))) 2.0
def float32_cosh (x:Float32) : Float32 = %fdiv ((%exp x) + (%exp (%fsub 0.0 x))) 2.0

-- tanh x = sign(x) * (1 - e^(-2|x|)) / (1 + e^(-2|x|)).
-- The exponent is never positive, so e^(-2|x|) stays in [0, 1] and cannot
-- overflow. The naive (e^x - e^-x) / (e^x + e^-x) gives inf/inf = NaN as soon
-- as e^|x| overflows (|x| > ~88 for Float32), instead of +/-1.
-- NaN inputs propagate: NaN < 0.0 is false, and NaN flows through %exp.
def float32_tanh (x:Float32) : Float32 =
  isNeg = x < 0.0
  ax = select isNeg (%fsub 0.0 x) x
  e = %exp (%fsub 0.0 (%fadd ax ax))
  t = %fdiv (%fsub 1.0 e) (%fadd 1.0 e)
  select isNeg (%fsub 0.0 t) t

-- Todo: unify this with float32 functions.
def float64_sinh (x:Float64) : Float64 = %fdiv (%fsub (%exp x) (%exp (%fsub (FToF64 0.0) x))) (FToF64 2.0)
def float64_cosh (x:Float64) : Float64 = %fdiv ((%exp x) + (%exp (%fsub (FToF64 0.0) x))) (FToF64 2.0)

-- Same overflow-free formulation as `float32_tanh` (the naive form gives NaN
-- once |x| > ~709 for Float64).
def float64_tanh (x:Float64) : Float64 =
  zero64 = FToF64 0.0
  one64 = FToF64 1.0
  isNeg = x < zero64
  ax = select isNeg (%fsub zero64 x) x
  e = %exp (%fsub zero64 (%fadd ax ax))
  t = %fdiv (%fsub one64 e) (%fadd one64 e)
  select isNeg (%fsub zero64 t) t
-- Scalar instances: thin wrappers over the corresponding primitives, plus
-- the hyperbolic helpers defined above.
instance Floating Float64
  exp = \v. %exp v
  exp2 = \v. %exp2 v
  log = \v. %log v
  log2 = \v. %log2 v
  log10 = \v. %log10 v
  log1p = \v. %log1p v
  sin = \v. %sin v
  cos = \v. %cos v
  tan = \v. %tan v
  sinh = float64_sinh
  cosh = float64_cosh
  tanh = float64_tanh
  floor = \v. %floor v
  ceil = \v. %ceil v
  round = \v. %round v
  sqrt = \v. %sqrt v
  pow = \base e. %fpow base e
  lgamma = \v. %lgamma v

instance Floating Float32
  exp = \v. %exp v
  exp2 = \v. %exp2 v
  log = \v. %log v
  log2 = \v. %log2 v
  log10 = \v. %log10 v
  log1p = \v. %log1p v
  sin = \v. %sin v
  cos = \v. %cos v
  tan = \v. %tan v
  sinh = float32_sinh
  cosh = float32_cosh
  tanh = float32_tanh
  floor = \v. %floor v
  ceil = \v. %ceil v
  round = \v. %round v
  sqrt = \v. %sqrt v
  pow = \base e. %fpow base e
  lgamma = \v. %lgamma v

Index set utilities

-- Integer index sets: `Range low high`, and `Fin n` = `Range 0 n`.
def Range (low:Int) (high:Int) : Type = %IntRange low high
def Fin (n:Int) : Type = Range 0 n

-- Position of an index within its index set, and the set's size.
def ordinal (i:a) : Int = %toOrdinal i
def size (n:Type) : Int = %idxSetSize n

-- Build an index from its ordinal without a bounds check.
def unsafeFromOrdinal (n : Type) (i : Int) : n = %unsafeFromOrdinal n i

-- The table mapping each index to its ordinal.
def iota (n:Type) : n=>Int = view idx. ordinal idx

-- TODO: we want Eq and Ord for all index sets, not just `Fin n`
instance Eq (Fin n)
  (==) = \p q. ordinal p == ordinal q
instance Ord (Fin n)
  (>) = \p q. ordinal p > ordinal q
  (<) = \p q. ordinal p < ordinal q

Raw pointer operations

-- A typed wrapper around a raw pointer.
data Ptr a = MkPtr RawPtr

-- Reinterpret a pointer at another element type (the address is unchanged).
def castPtr ((MkPtr rawPtr): Ptr a) : Ptr b = MkPtr rawPtr

-- Is there a better way to select the right instance for `storageSize`??
-- A value-less carrier that only fixes a type, used to pick an instance.
data TypeVehicle a = MkTypeVehicle
def typeVehicle (a:Type) : TypeVehicle a = MkTypeVehicle

-- Types that can be written to / read from memory, with their size in bytes.
interface Storable a
  store : Ptr a -> a -> {IO} Unit
  load : Ptr a -> {IO} a
  storageSize_ : TypeVehicle a -> Int

-- Size in bytes of one `a`.
def storageSize (a:Type) -> (d:Storable a) ?=> : Int = storageSize_ (typeVehicle a)

instance Storable Word8
  store = \(MkPtr raw) v. %ptrStore raw v
  load = \(MkPtr raw). %ptrLoad raw
  storageSize_ = const 1

instance Storable Int32
  store = \(MkPtr raw) v. %ptrStore (internalCast %Int32Ptr raw) v
  load = \(MkPtr raw). %ptrLoad (internalCast %Int32Ptr raw)
  storageSize_ = const 4

instance Storable Float32
  store = \(MkPtr raw) v. %ptrStore (internalCast %Float32Ptr raw) v
  load = \(MkPtr raw). %ptrLoad (internalCast %Float32Ptr raw)
  storageSize_ = const 4

instance Storable (Ptr a)
  store = \(MkPtr raw) (MkPtr target). %ptrStore (internalCast %PtrPtr raw) target
  load = \(MkPtr raw). MkPtr $ %ptrLoad (internalCast %PtrPtr raw)
  storageSize_ = const 8  -- TODO: something more portable?

-- TODO: Storable instances for other types
-- Allocate room for `n` values of type `a`.
def malloc [Storable a] (n:Int) : {IO} (Ptr a) = MkPtr $ %alloc (storageSize a * n)

def free (ptr:Ptr a) : {IO} Unit =
  (MkPtr raw) = ptr
  %free raw

-- Advance a pointer by `i` elements (not bytes).
def (+>>) [Storable a] (ptr:Ptr a) (i:Int) : Ptr a =
  (MkPtr raw) = ptr
  MkPtr $ %ptrOffset raw (i * storageSize a)

-- TODO: consider making a Storable instance for tables instead
-- Write `tab` to consecutive elements starting at `ptr`.
def storeTab [Storable a] (ptr: Ptr a) (tab:n=>a) : {IO} Unit =
  for_ i. store (ptr +>> ordinal i) tab.i

-- Copy `n` elements from `src` to `dest`, one element at a time.
def memcpy [Storable a] (dest:Ptr a) (src:Ptr a) (n:Int) : {IO} Unit =
  for_ i:(Fin n).
    offset = ordinal i
    store (dest +>> offset) (load $ src +>> offset)

-- TODO: generalize these brackets to allow other effects
-- TODO: make sure that freeing happens even if there are run-time errors
-- Allocate `n` elements, run `action` on them, then free the buffer.
def withAlloc [Storable a] (n:Int) (action: Ptr a -> {IO} b) : {IO} b =
  ptr = malloc n
  result = action ptr
  free ptr
  result

-- Copy `xs` into a temporary buffer and run `action` on it.
def withTabPtr [Storable a] (xs:n=>a) (action : Ptr a -> {IO} b) : {IO} b =
  withAlloc (size n) \ptr.
    storeTab ptr xs
    action ptr

-- Read `size n` consecutive elements starting at `ptr` into a table.
def tabFromPtr [Storable a] (n:Type) -> (ptr:Ptr a) : {IO} n=>a =
  for i. load (ptr +>> ordinal i)

Miscellaneous common utilities

pi : Float = 3.141592653589793

def id (x:a) : a = x
def dup (x:a) : (a & a) = (x, x)

-- Apply `f` to every element (effects allowed).
def map (f:a->{|eff} b) (xs: n=>a) : {|eff} (n=>b) = for i. f xs.i

def zip (xs:n=>a) (ys:n=>b) : (n=>(a&b)) = view i. (xs.i, ys.i)
def unzip (xys:n=>(a&b)) : (n=>a & n=>b) =
  firsts = map fst xys
  seconds = map snd xys
  (firsts, seconds)

-- Constant table over `n`.
def fanout (n:Type) (x:a) : n=>a = view _. x

def sq [Mul a] (x:a) : a = x * x
def abs [Add a, Ord a] (x:a) : a = select (x > zero) x (zero - x)

-- Remainder shifted toward a non-negative result: for positive `y` it lies
-- in [0, y), assuming `rem` truncates toward zero (confirm for negative `y`).
def mod (x:Int) (y:Int) : Int =
  r = rem x y
  rem (y + r) y

Table Operations

-- Tables of floating values apply each function elementwise.
instance [Floating a] Floating (n=>a)
  exp = map exp
  exp2 = map exp2
  log = map log
  log2 = map log2
  log10 = map log10
  log1p = map log1p
  sin = map sin
  cos = map cos
  tan = map tan
  sinh = map sinh
  cosh = map cosh
  tanh = map tanh
  floor = map floor
  ceil = map ceil
  round = map round
  sqrt = map sqrt
  pow = \bases es. for i. pow bases.i es.i
  lgamma = map lgamma

Axis Restructuring

-- Move the second axis to the front.
def axis1 (x : a => b => c) : b => a => c = for j i. x.i.j
-- Move the third axis to the front.
def axis2 (x : a => b => c => d) : c => a => b => d = for k i j. x.i.j.k
-- Re-index a table through `ixr`: the result at `i` is `tab.(ixr i)`.
def reindex (ixr: b -> a) (tab: a=>v) : b=>v = for j. tab.(ixr j)

Reductions

-- Fold every element of `xs` into `identity` with `combine`.
-- `combine` should be associative and commutative, and form a
-- commutative monoid with `identity`; otherwise the result may depend on
-- traversal order.
def reduce (identity:a) (combine:(a->a->a)) (xs:n=>a) : a =
  -- TODO: implement with the accumulator effect
  fold identity (\i c. combine c xs.i)

-- Like `scan`, but the per-index output is the running accumulator itself.
-- TODO: call this `scan` and call the current `scan` something else
def scan' (init:a) (body:n->a->a) : n=>a = snd $ scan init \i x. dup (body i x)

-- Float sum via the accumulator effect; elements are read with `xs.i`, like
-- every other table access in this file.
-- TODO: allow tables-via-lambda and get rid of this
def fsum (xs:n=>Float) : Float = yieldAccum (AddMonoid Float) \ref. for i. ref += xs.i

def sum [Add v] (xs:n=>v) : v = reduce zero (+) xs
def prod [Mul v] (xs:n=>v) : v = reduce one (*) xs

-- Arithmetic mean. NOTE: an empty index set divides by zero.
def mean [VSpace v] (xs:n=>v) : v = sum xs / IToF (size n)

-- Standard deviation as sqrt(E[x^2] - E[x]^2). NOTE: for near-constant inputs
-- the difference can round to a small negative value, giving NaN.
def std [Mul v, VSpace v, Floating v] (xs:n=>v) : v = sqrt $ mean (map sq xs) - sq (mean xs)

def any (xs:n=>Bool) : Bool = reduce False (||) xs
def all (xs:n=>Bool) : Bool = reduce True (&&) xs

ApplyN

-- Apply `f` to `x` exactly `n` times.
def applyN (n:Int) (x:a) (f:a -> a) : a =
  yieldState x \cur.
    for _:(Fin n).
      cur := f (get cur)

Linear Algebra

-- `size n` evenly spaced points: low, low + step, ... with
-- step = (high - low) / size n. Note that `high` itself is excluded.
def linspace (n:Type) (low:Float) (high:Float) : n=>Float =
  step = (high - low) / IToF (size n)
  for i:n. low + IToF (ordinal i) * step

def transpose (x:n=>m=>a) : m=>n=>a = view i j. x.j.i

-- Inner product of two Float vectors.
def vdot (x:n=>Float) (y:n=>Float) : Float = fsum view i. x.i * y.i

-- Linear combination of `vs` with coefficients `s`.
def dot [VSpace v] (s:n=>Float) (vs:n=>v) : v = sum for j. s.j .* vs.j

-- matmul. Better symbol to use? `@`?
(**) : (l=>m=>Float) -> (m=>n=>Float) -> (l=>n=>Float) = \x y.
  for i. for k. fsum view j. x.i.j * y.j.k
-- Matrix-vector and vector-matrix products.
(**.) : (n=>m=>Float) -> (m=>Float) -> (n=>Float) = \mat v. for i. vdot mat.i v
(.**) : (m=>Float) -> (n=>m=>Float) -> (n=>Float) = \v mat. mat **. v

-- Bilinear form x^T mat y.
def inner (x:n=>Float) (mat:n=>m=>Float) (y:m=>Float) : Float =
  fsum view (i,j). x.i * mat.i.j * y.j

-- Identity matrix.
def eye [Add a, Mul a] : n=>n=>a = for i j. select (ordinal i == ordinal j) one zero

cumSum

TODO: Move this to be with reductions? It's a kind of scan.

-- Inclusive running sum: element `i` is xs.0 + ... + xs.i, added in index order.
def cumSum (xs: n=>Float) : n=>Float = scan' 0.0 \i runningTotal. runningTotal + xs.i

Automatic differentiation

AD operations

-- TODO: add vector space constraints
-- Evaluate `f` at `x`, returning the value and the linear map given by the
-- `%linearize` primitive.
def linearize (f:a->b) (x:a) : (b & a --o b) = %linearize f x

-- Forward mode: just the linear map from `linearize`.
def jvp (f:a->b) (x:a) : a --o b = snd $ linearize f x

def transposeLinear (f:a --o b) : b --o a = %linearTranspose f

-- Reverse mode: the value together with the transposed linear map.
def vjp (f:a->b) (x:a) : (b & b --o a) =
  (primal, linMap) = linearize f x
  (primal, transposeLinear linMap)

-- Gradient of a scalar-valued function: pull back a cotangent of 1.0.
def grad (f:a->Float) (x:a) : a = (snd $ vjp f x) 1.0

-- Scalar derivative via forward mode and via reverse mode.
def deriv (f:Float->Float) (x:Float) : Float = (jvp f x) 1.0
def derivRev (f:Float->Float) (x:Float) : Float = (snd $ vjp f x) 1.0

Approximate Equality

TODO: move this outside the AD section to be with equality?

-- `allclose atol rtol x y` holds when |x - y| <= atol + rtol * |y|.
interface HasAllClose a
  allclose : a -> a -> a -> a -> Bool

-- Default absolute / relative tolerances for `~~`.
interface HasDefaultTolerance a
  atol : a
  rtol : a

def (~~) [HasAllClose a, HasDefaultTolerance a] : a -> a -> Bool = allclose atol rtol

instance HasAllClose Float32
  allclose = \absTol relTol x y. abs (x - y) <= (absTol + relTol * abs y)
instance HasAllClose Float64
  allclose = \absTol relTol x y. abs (x - y) <= (absTol + relTol * abs y)

instance HasDefaultTolerance Float32
  atol = FToF32 0.00001
  rtol = FToF32 0.0001
instance HasDefaultTolerance Float64
  atol = FToF64 0.00000001
  rtol = FToF64 0.00001

-- Tables: every element must be close, using per-element tolerances.
instance [HasAllClose t] HasAllClose (n=>t)
  allclose = \absTols relTols xs ys. all for i:n. allclose absTols.i relTols.i xs.i ys.i
instance [HasDefaultTolerance t] HasDefaultTolerance (n=>t)
  atol = for _. atol
  rtol = for _. rtol

AD Checking tools

-- Compare forward- and reverse-mode derivatives of `f` at `x` against a
-- central finite difference with step 0.01.
def checkDerivBase (f:Float->Float) (x:Float) : Bool =
  h = 0.01
  fwdDeriv = deriv f x
  revDeriv = derivRev f x
  numericDeriv = (f (x + h) - f (x - h)) / (2.0 * h)
  (fwdDeriv ~~ numericDeriv) && (revDeriv ~~ numericDeriv)

-- Check both the first and the second derivative of `f` at `x`.
def checkDeriv (f:Float->Float) (x:Float) : Bool =
  firstOk = checkDerivBase f x
  secondOk = checkDerivBase (deriv f) x
  firstOk && secondOk

Vector support

-- TODO: Reenable vector suport once fixed-width types are supported.
-- def UNSAFEFromOrdinal (n : Type) (i : Int) : n = %unsafeAsIndex n i
--
-- VectorWidth = 4 -- XXX: Keep this synced with the constant defined in Array.hs
-- VectorFloat = todo
--
-- def packVector (a : Float) (b : Float) (c : Float) (d : Float) : VectorFloat = %vectorPack a b c d
-- def indexVector (v : VectorFloat) (i : Fin VectorWidth) : Float = %vectorIndex v i
--
-- -- NB: Backends should be smart enough to optimize this to a vector load from v
-- def loadVector (v : (Fin VectorWidth)=>Float) : VectorFloat =
-- idx = Fin VectorWidth
-- (packVector v.(UNSAFEFromOrdinal idx 0)
-- v.(UNSAFEFromOrdinal idx 1)
-- v.(UNSAFEFromOrdinal idx 2)
-- v.(UNSAFEFromOrdinal idx 3))
-- def storeVector (v : VectorFloat) : (Fin VectorWidth)=>Float =
-- idx = Fin VectorWidth
-- [ indexVector v (UNSAFEFromOrdinal idx 0)
-- , indexVector v (UNSAFEFromOrdinal idx 1)
-- , indexVector v (UNSAFEFromOrdinal idx 2)
-- , indexVector v (UNSAFEFromOrdinal idx 3) ]
--
-- def broadcastVector (v : Float) : VectorFloat = packVector v v v v
--
-- @instance vectorFloatAdd : Add VectorFloat =
-- (MkAdd ( \x y. %vfadd x y )
-- ( \x y. %vfsub x y )
-- ( broadcastVector zero ))
-- @instance vectorFloatMul : Mul VectorFloat =
-- MkMul (\x y. %vfmul x y) $ packVector 1.0 1.0 1.0 1.0
-- @instance vectorFloatVSpace : VSpace VectorFloat =
-- MkVSpace vectorFloatAdd \x v. broadcastVector x * v

Tiling functions

-- A slice ("tile") of index set `n`, itself indexed by `m` (via `%IndexSlice`).
def Tile (n : Type) (m : Type) : Type = %IndexSlice n m
-- One can think of instances of `Tile n m` as injective functions `m -> n`,
-- with the special property that consecutive elements of m map to consecutive
-- elements of n. In this view (+>) is just function application, while ++>
-- is currying followed by function application. We cannot represent currying
-- in isolation, because `Tile n (Tile u v)` does not make sense, unlike `Tile n (u & v)`.
def (+>) (t:Tile n l) (i : l) : n = %sliceOffset t i
def (++>) (t : Tile n (u & v)) (i : u) : Tile n v = %sliceCurry t i
-- Build an `n=>a` table from `fTile`, which produces a whole `l`-indexed tile
-- at once, and `fScalar`, which produces one element. Presumably `%tiled`
-- uses `fScalar` for indices not covered by a full tile — TODO confirm.
def tile (fTile : (t:(Tile n l) -> {|eff} l=>a)) (fScalar : n -> {|eff} a) : {|eff} n=>a = %tiled fTile fScalar
-- Same idea, with the tiled index set `n` as the inner dimension of an
-- `m=>n=>a` result.
def tile1 (fTile : (t:(Tile n l) -> {|eff} m=>l=>a)) (fScalar : n -> {|eff} m=>a) : {|eff} m=>n=>a = %tiledd fTile fScalar
-- TODO: This should become just `loadVector $ for i. arr.(t +> i)`
-- once we are able to eliminate temporary arrays. Until then, we inline for performance...
--def loadTile (t : Tile n (Fin VectorWidth)) (arr : n=>Float) : VectorFloat =
-- idx = Fin VectorWidth
-- (packVector arr.(t +> UNSAFEFromOrdinal idx 0)
-- arr.(t +> UNSAFEFromOrdinal idx 1)
-- arr.(t +> UNSAFEFromOrdinal idx 2)
-- arr.(t +> UNSAFEFromOrdinal idx 3))

Length-erased lists

-- A list: a length `n` packed together with a table over `Fin n`.
data List a = AsList n:Int elements:(Fin n => a)

-- Lists are equal when the lengths agree and all elements match; elements
-- are only compared once the lengths are known to be equal.
instance [Eq a] Eq (List a)
  (==) = \(AsList nx xs) (AsList ny ys). case nx == ny of
    False -> False
    True -> all for i:(Fin nx). xs.i == ys.(unsafeFromOrdinal _ (ordinal i))

-- Reinterpret `xs` as a table over `m` by ordinal. No bounds check: the
-- caller must ensure `size m <= size n`.
def unsafeCastTable (m:Type) (xs:n=>a) : m=>a = for j. xs.(unsafeFromOrdinal n (ordinal j))

-- Forget the static index set of a table.
def toList (xs:n=>a) : List a =
  len = size n
  AsList _ $ unsafeCastTable (Fin len) xs

-- Concatenation monoid.
instance Monoid (List a)
  mempty = AsList _ []
  mcombine = \(AsList nx xs) (AsList ny ys).
    nz = nx + ny
    AsList _ $ for i:(Fin nz).
      k = ordinal i
      case k < nx of
        True -> xs.(unsafeFromOrdinal _ k)
        False -> ys.(unsafeFromOrdinal _ (k - nx))

def ListMonoid (a:Type) : Monoid (List a) =
  A = a  -- XXX: Typing `Monoid a` below would quantify it over a, which we don't want.
  named-instance result : Monoid (List A)
    mempty = mempty
    mcombine = mcombine
  result

-- Accumulate a single element onto a list reference.
def append [AccumMonoid h (List a)] (list: Ref h (List a)) (x:a) : {Accum h} Unit =
  list += AsList 1 [x]

-- Elements of `xs` satisfying `condition`, in index order.
def filter (condition:a->Bool) (xs:n=>a) : List a =
  yieldAccum (ListMonoid a) \acc.
    for i. if condition xs.i then append acc xs.i

-- Indices of `xs` where `condition` holds, in index order.
def argFilter (condition:a->Bool) (xs:n=>a) : List n =
  yieldAccum (ListMonoid n) \acc.
    for i. if condition xs.i then append acc i

Isomorphisms

-- An isomorphism: a pair of mutually inverse functions.
data Iso a b = MkIso { fwd: a -> b & bwd: b -> a }

-- Apply the forward direction.
def appIso (iso: Iso a b) (x:a) : b =
  (MkIso {fwd, bwd}) = iso
  fwd x

-- Swap the two directions.
def flipIso (iso: Iso a b) : Iso b a =
  (MkIso {fwd, bwd}) = iso
  MkIso {fwd=bwd, bwd=fwd}

-- Apply the backward direction.
def revIso (iso: Iso a b) (x:b) : a = appIso (flipIso iso) x

idIso : Iso a a = MkIso {fwd=id, bwd=id}

-- Compose: `iso1` first, then `iso2` (inverses run in the opposite order).
def (&>>) (iso1: Iso a b) (iso2: Iso b c) : Iso a c =
  (MkIso {fwd=fwd1, bwd=bwd1}) = iso1
  (MkIso {fwd=fwd2, bwd=bwd2}) = iso2
  MkIso {fwd=(fwd2 <<< fwd1), bwd=(bwd2 >>> bwd1)}

-- The same composition with the arguments in reverse order.
def (<<&) (iso2: Iso b c) (iso1: Iso a b) : Iso a c = iso1 &>> iso2

Lens-like accessors

Note: `#foo` is an `Iso {foo: a & ...b} (a & {&...b})`.

-- Read the focus / the remainder through a lens-like iso.
def getAt (iso: Iso a (b & c)) : a -> b = appIso iso >>> fst
def popAt (iso: Iso a (b & c)) : a -> c = appIso iso >>> snd
-- Rebuild a value from a focus and a remainder.
def pushAt (iso: Iso a (b & c)) (x:b) (r:c) : a = revIso iso (x, r)
-- Replace the focus of `r` with `x`.
def setAt (iso: Iso a (b & c)) (x:b) (r:a) : a =
  rest = popAt iso r
  pushAt iso x rest

Prism-like accessors

Note: `#?foo` is an `Iso {foo: a | ...b} (a | {|...b})`.

-- Extract the focused case of a prism-like iso, if `x` is in it.
def matchWith (iso: Iso a (b | c)) (x: a) : Maybe b = case appIso iso x of
  Left hit -> Just hit
  Right _ -> Nothing

-- Inject a value into the focused case.
def buildWith (iso: Iso a (b | c)) (x: b) : a = revIso iso (Left x)

swapPairIso : Iso (a & b) (b & a) = MkIso {fwd = \(p, q). (q, p), bwd = \(q, p). (p, q)}

Complement lens

Complement the focus of a lens-like isomorphism

exceptLens : Iso a (b & c) -> Iso a (c & b) = \lens. lens &>> swapPairIso

-- Exchange the two cases of a binary sum, in both directions.
swapEitherIso : Iso (a | b) (b | a) =
  toSwapped = \v. case v of
    Left l -> Right l
    Right r -> Left r
  fromSwapped = \v. case v of
    Left r -> Right r
    Right l -> Left l
  MkIso {fwd=toSwapped, bwd=fromSwapped}

Complement prism

Complement the focus of a prism-like isomorphism

exceptPrism : Iso a (b | c) -> Iso a (c | b) = \prism. prism &>> swapEitherIso

-- Use a lens-like iso to split a 1d table into a 2d table
def overLens (iso: Iso a (b & c)) (tab: a=>v) : (b=>c=>v) =
  for i. for j. tab.(revIso iso (i, j))

Zipper

Zipper isomorphisms to easily specify many record/variant fields:

#&foo is an Iso ({&...l} & {foo:a & ...r}) ({foo:a & ...l} & {&...r})
#|foo is an Iso ({|...l} | {foo:a | ...r}) ({foo:a | ...l} | {|...r})

Convert a record zipper isomorphism to a normal lens-like isomorphism by using splitR &>> iso

-- Pair any value with the empty record, so that record zippers apply to it.
splitR : Iso a ({&} & a) = MkIso {fwd=\x. ({}, x), bwd=\({}, x). x}

-- Split a table indexed by a record into a 2d table via a record zipper.
def overFields (iso: Iso ({&} & a) (b & c)) (tab: a=>v) : b=>c=>v =
  overLens (iso <<& splitR) tab

Convert a variant zipper isomorphism to a normal prism-like isomorphism by using splitV &>> iso

-- View any value as the right case of a sum with the empty variant, so that
-- variant zippers apply to it. The Left case is empty, so only Right is matched.
splitV : Iso a ({|} | a) = MkIso {fwd=\x. Right x, bwd=\v. case v of Right x -> x}

-- Restrict a table indexed by a variant to the fields picked out by `iso`.
def sliceFields (iso: Iso ({|} | a) (b | c)) (tab: a=>v) : b=>v =
  reindex (buildWith (iso <<& splitV)) tab

Dynamic buffer

-- A growable buffer: pointers to the element count, the capacity, and the
-- current storage pointer (the storage may be reallocated, so it is boxed).
-- TODO: would be nice to be able to use records here
data DynBuffer a = MkDynBuffer { size : Ptr Int & maxSize : Ptr Int & buffer : Ptr (Ptr a) }

-- Run `action` on a fresh, empty buffer with an initial capacity of 256
-- elements, then free the storage and return the action's result.
def withDynamicBuffer [Storable a] (action: DynBuffer a -> {IO} b) : {IO} b =
  initialCapacity = 256
  withAlloc 1 \sizePtr. withAlloc 1 \maxSizePtr. withAlloc 1 \bufferPtr.
    store sizePtr 0
    store maxSizePtr initialCapacity
    store bufferPtr $ malloc initialCapacity
    result = action $ MkDynBuffer { size = sizePtr, maxSize = maxSizePtr, buffer = bufferPtr }
    free $ load bufferPtr
    result

-- Ensure room for `sizeDelta` more elements. When that exceeds the capacity,
-- grow it to the next power of two and move the contents. The count is unchanged.
def maybeIncreaseBufferSize [Storable a] ((MkDynBuffer db): DynBuffer a) (sizeDelta:Int) : {IO} Unit =
  curSize = load $ getAt #size db
  capacity = load $ getAt #maxSize db
  oldBuf = load $ getAt #buffer db
  requiredSize = sizeDelta + curSize
  if requiredSize > capacity then
    -- TODO: maybe this should use integer arithmetic?
    newCapacity = FToI $ pow 2.0 (ceil $ log2 $ IToF requiredSize)
    newBuf = malloc newCapacity
    memcpy newBuf oldBuf curSize
    free oldBuf
    store (getAt #maxSize db) newCapacity
    store (getAt #buffer db) newBuf

-- Increment the Int stored at `ptr` by `n`.
def addAtIntPtr (ptr: Ptr Int) (n:Int) : {IO} Unit = store ptr (load ptr + n)

-- Append all elements of `new`. The storage pointer is read only after any
-- reallocation.
def extendDynBuffer [Storable a] (buf: DynBuffer a) (new:List a) : {IO} Unit =
  (AsList n xs) = new
  maybeIncreaseBufferSize buf n
  (MkDynBuffer db) = buf
  storage = load $ getAt #buffer db
  curSize = load $ getAt #size db
  storeTab (storage +>> curSize) xs
  addAtIntPtr (getAt #size db) n

-- Copy the current contents out as a List.
def loadDynBuffer [Storable a] (buf: DynBuffer a) : {IO} (List a) =
  (MkDynBuffer db) = buf
  storage = load $ getAt #buffer db
  curSize = load $ getAt #size db
  AsList curSize $ tabFromPtr _ storage

-- Append a single element.
def pushDynBuffer [Storable a] (buf: DynBuffer a) (x:a) : {IO} Unit =
  extendDynBuffer buf $ AsList _ [x]

Strings and Characters

-- Strings are lists of bytes.
String : Type = List Char

-- Read `n` bytes starting at `ptr` into a String.
def stringFromCharPtr (n:Int) (ptr:Ptr Char) : {IO} String =
  AsList n $ tabFromPtr (Fin n) ptr

-- The byte value of `c` as an Int (its ASCII code point when `c` is ASCII).
-- TODO: this should be an Int32 Unicode code point.
def codepoint (c:Char) : Int = W8ToI c

Show interface

For things that can be shown: `show` gives a string representation of its input. No particular promises are made as to exactly what that representation will contain. In particular, it is not promised to be parseable, nor does it promise a particular level of precision for numeric values.

interface Show a
  show : a -> String

instance Show String
  show = id

-- Numeric instances call into the runtime via FFI; each call returns a
-- (length, pointer) pair that is copied into a String.
-- NOTE(review): the returned buffer is not freed here; confirm the runtime owns it.
instance Show Int32
  show = \x. unsafeIO do
    (len, raw) = %ffi showInt32 (Int32 & RawPtr) x
    stringFromCharPtr len $ MkPtr raw

instance Show Int64
  show = \x. unsafeIO do
    (len, raw) = %ffi showInt64 (Int32 & RawPtr) x
    stringFromCharPtr len $ MkPtr raw

instance Show Float32
  show = \x. unsafeIO do
    (len, raw) = %ffi showFloat32 (Int32 & RawPtr) x
    stringFromCharPtr len $ MkPtr raw

instance Show Float64
  show = \x. unsafeIO do
    (len, raw) = %ffi showFloat64 (Int32 & RawPtr) x
    stringFromCharPtr len $ MkPtr raw

-- Pairs render as "(first, second)".
instance [Show a, Show b] Show (a & b)
  show = \(first, second). "(" <> show first <> ", " <> show second <> ")"

Pipe-like reverse function application

TODO: move this

-- Reverse application: `x |> f` means `f x`, so pipelines read left to right.
def (|>) (x:a) (f: a -> b) : b = f x

Floating-point helper functions

TODO: Move these to be with Elementary/Special functions. Or move those to be here.

-- 1.0 for positive x, -1.0 for negative x; otherwise x itself is returned,
-- so zeros and NaN pass through unchanged.
def sign (x:Float) : Float =
  if x > 0.0 then 1.0 else if x < 0.0 then -1.0 else x
-- IEEE-754 style copysign: the magnitude of `a` with the sign of `b`.
-- Zeros are signed. For a zero argument, 1.0 / v is +infinity for +0.0 and
-- -infinity for -0.0, which recovers the sign bit. A NaN `b` cannot be
-- classified this way and is treated as non-negative.
-- The previous version:
--   * returned `a` unchanged for b > 0 and `-a` for b < 0, which is wrong
--     whenever a < 0;
--   * returned 0.0 whenever b == 0.0. That made atan2 0.0 x return 0.0
--     instead of pi for x < 0.
def copysign (a:Float) (b:Float) : Float =
  isNegative = \v:Float. (v < 0.0) || ((v == 0.0) && ((1.0 / v) < 0.0))
  magnitude = if isNegative a then -a else a
  if isNegative b then -magnitude else magnitude
-- Todo: use IEEE floating-point builtins.
-- Positive infinity and a quiet NaN, produced by dividing by zero.
infinity : Float = 1.0 / 0.0
nan : Float = 0.0 / 0.0

-- Todo: use IEEE floating-point builtins.
-- True for either infinity.
def isinf (x:Float) : Bool = (x == -infinity) || (x == infinity)

-- NaN is the only value that is not comparable with itself.
def isnan (x:Float) : Bool = (not (x >= x)) || (not (x <= x))

-- Todo: use IEEE-754R 5.11: Floating Point Comparison Relation cmpUnordered.
-- True when at least one argument is NaN.
def either_is_nan (x:Float) (y:Float) : Bool = isnan x || isnan y

File system operations

FilePath : Type = String

-- A raw pointer to NUL-terminated character data.
data CString = MkCString RawPtr

-- A RawPtr holding address zero (C's NULL).
def nullRawPtr : RawPtr = I64ToRawPtr $ IToI64 0

-- Turn a raw pointer into a typed one, mapping NULL to Nothing.
def fromNullableRawPtr (ptr:RawPtr) : Maybe (Ptr a) =
  if ptr == nullRawPtr
    then Nothing
    else Just $ MkPtr ptr

def cStringPtr (s:CString) : Maybe (Ptr Char) =
  (MkCString raw) = s
  fromNullableRawPtr raw

data StreamMode =
  ReadMode
  WriteMode

-- A C stream, tagged at the type level with the mode it was opened in.
data Stream mode:StreamMode = MkStream RawPtr

-- TODO: check the string contains no nulls
-- Run `action` on a NUL-terminated copy of `s`. The pointer comes from
-- withTabPtr and is presumably only valid while `action` runs.
def withCString (s:String) (action: CString -> {IO} a) : {IO} a =
  (AsList _ chars) = s <> "\NUL"
  withTabPtr chars \(MkPtr ptr). action $ MkCString ptr

Stream IO

-- Open the file at `path` with C's fopen, using mode "r" or "w".
-- NOTE(review): a failed open yields a Stream wrapping NULL; withFile and
-- hasFile check for that.
def fopen (path:String) (mode:StreamMode) : {IO} (Stream mode) =
  cMode = case mode of
    ReadMode  -> "r"
    WriteMode -> "w"
  withCString path \(MkCString pathPtr).
    withCString cMode \(MkCString modePtr).
      MkStream $ %ffi fopen RawPtr pathPtr modePtr

-- Close a stream; the C return code is discarded.
def fclose (stream:Stream mode) : {IO} Unit =
  (MkStream handle) = stream
  %ffi fclose Int64 handle
  ()

-- Write all of `s` to `stream`, then flush it.
def fwrite (stream:Stream WriteMode) (s:String) : {IO} Unit =
  (MkStream handle) = stream
  (AsList len chars) = s
  withTabPtr chars \(MkPtr charPtr).
    %ffi fwrite Int64 charPtr (IToI64 1) (IToI64 len) handle
  %ffi fflush Int64 handle
  ()

Iteration

TODO: move this out of the file-system section

-- Run `body` repeatedly for as long as it returns True.
-- The Bool is converted to a Word8 because that is what %while consumes.
def while (body: Unit -> {|eff} Bool) : {|eff} Unit =
  asWord8 : Unit -> {|eff} Word8 = \_. BToW8 $ body ()
  %while asWord8

-- Result of one step of `iter`: keep going, or stop with a value.
data IterResult a =
  Continue
  Done a

-- TODO: can we improve effect inference so we don't need this?
-- Just `f x`, but typed with an extra `State h` effect so that it can be
-- called inside a state handler's body.
def liftState (ref: Ref h c) (f:a -> {|eff} b) (x:a) : {State h|eff} b = f x
-- A little iteration combinator
-- Calls `body` with i = 0, 1, 2, ... until it returns `Done result`, and
-- returns that result. It loops forever if `body` never returns Done.
def iter (body: Int -> {|eff} IterResult a) : {|eff} a =
  result = yieldState Nothing \resultRef.
    withState 0 \i.
      while do
        -- `continue` is read BEFORE running the body. After `Done`, the loop
        -- therefore makes one more pass, sees the stored result, and stops.
        continue = isNothing $ get resultRef
        if continue then
          case liftState resultRef (liftState i body) (get i) of
            Continue -> i := get i + 1
            Done result -> resultRef := Just result
        continue
  -- The loop only exits once resultRef holds a Just.
  case result of
    Just ans -> ans
    Nothing -> unreachable ()

-- Like `iter`, but returns `fallback` as soon as i reaches maxIters.
def boundedIter (maxIters:Int) (fallback:a) (body: Int -> {|eff} IterResult a) : {|eff} a =
  iter \i.
    if i >= maxIters
      then Done fallback
      else body i

Environment Variables

-- Copy a NUL-terminated C string into a String; Nothing for a NULL pointer.
def fromCString (s:CString) : {IO} (Maybe String) =
  case cStringPtr s of
    Nothing -> Nothing
    Just charPtr ->
      Just $ withDynamicBuffer \buf. iter \offset.
        ch = load $ charPtr +>> offset
        if ch == '\NUL'
          then Done $ loadDynBuffer buf
          else
            pushDynBuffer buf ch
            Continue

-- Value of environment variable `name`, or Nothing if getenv returns NULL.
def getEnv (name:String) : {IO} Maybe String =
  withCString name \(MkCString namePtr).
    fromCString $ MkCString $ %ffi getenv RawPtr namePtr

-- True when environment variable `name` is set.
def checkEnv (name:String) : {IO} Bool =
  found = getEnv name
  isJust found

More Stream IO

-- Read the rest of `stream` into a String.
-- Data is read in chunks of `chunkSize` bytes. A short read (fewer bytes than
-- requested) is taken as the end of the stream.
def fread (stream:Stream ReadMode) : {IO} String =
  (MkStream handle) = stream
  chunkSize = 4096
  withAlloc chunkSize \chunk:(Ptr Char).
    (MkPtr chunkRaw) = chunk
    withDynamicBuffer \buf.
      iter \_.
        numRead = I64ToI $ %ffi fread Int64 chunkRaw (IToI64 1) (IToI64 chunkSize) handle
        extendDynBuffer buf $ stringFromCharPtr numRead chunk
        if numRead == chunkSize
          then Continue
          else Done ()
      loadDynBuffer buf

Print

-- The stream that program output is written to, read through the
-- %outputStreamPtr primitive.
def getOutputStream (_:Unit) : {IO} Stream WriteMode =
  MkStream $ %ptrLoad %outputStreamPtr

-- Write `s` followed by a newline to the output stream.
def print (s:String) : {IO} Unit =
  line = s <> "\n"
  fwrite (getOutputStream ()) line

Shelling Out

-- Run `command` through the shell (popen in mode "r") and return everything
-- it writes to its standard output.
-- NOTE(review): popen can return NULL; that is not checked here.
def shellOut (command:String) : {IO} String =
  modeStr = "r"
  withCString command \(MkCString commandPtr).
    withCString modeStr \(MkCString modePtr).
      pipePtr = %ffi popen RawPtr commandPtr modePtr
      output = fread $ MkStream pipePtr
      -- A stream opened with popen must be released with pclose, not fclose.
      -- Previously it was never closed, which leaked the FILE* and left the
      -- child process unreaped.
      %ffi pclose Int64 pipePtr
      output

Partial functions

A partial function in this context is a function that can error, i.e. a function that is not actually defined for all of its supposed domain. This is not to be confused with a partially applied function.

Error throwing

-- Print `s`, then abort through %throwError. The result type is arbitrary
-- because, presumably, control never comes back from %throwError.
def error (s:String) : a = unsafeIO do
  print s
  %throwError a

-- Placeholder for unimplemented code; fails with a fixed message.
def todo : a = error "TODO: implement it!"

File Operations

-- Remove the file at `f` (C's remove); the return code is ignored.
def deleteFile (f:FilePath) : {IO} Unit =
  withCString f \(MkCString pathPtr).
    %ffi remove Int64 pathPtr
    ()

-- Open `f`, run `action` on the stream, close it, and return the result.
-- Fails with an error if the file cannot be opened.
def withFile (f:FilePath) (mode:StreamMode) (action: Stream mode -> {IO} a) : {IO} a =
  stream = fopen f mode
  (MkStream handle) = stream
  if handle == nullRawPtr
    then error $ "Unable to open file: " <> f
    else
      result = action stream
      fclose stream
      result

def writeFile (f:FilePath) (s:String) : {IO} Unit =
  withFile f WriteMode \out. fwrite out s

def readFile (f:FilePath) : {IO} String =
  withFile f ReadMode \inp. fread inp

-- True if `f` can be opened for reading (the stream is closed again).
def hasFile (f:FilePath) : {IO} Bool =
  stream = fopen f ReadMode
  (MkStream handle) = stream
  opened = handle /= nullRawPtr
  if opened then fclose stream
  opened

Temporary Files

-- Create a fresh file under /tmp with mkstemp and return its path.
-- mkstemp rewrites the trailing XXXXXX of the template in place, so the path
-- is read back from the same buffer; 15 is the template's length.
-- NOTE(review): mkstemp failure (-1) is not checked.
def newTempFile (_:Unit) : {IO} FilePath =
  withCString "/tmp/dex-XXXXXX" \(MkCString template).
    fd = %ffi mkstemp Int32 template
    %ffi close Int32 fd
    stringFromCharPtr 15 (MkPtr template)

-- Run `action` with a new temporary file path, deleting the file afterwards.
def withTempFile (action: FilePath -> {IO} a) : {IO} a =
  path = newTempFile ()
  result = action path
  deleteFile path
  result

-- Like withTempFile, but with one temporary file per element of `n`.
def withTempFiles (action: (n=>FilePath) -> {IO} a) : {IO} a =
  paths = for k. newTempFile ()
  result = action paths
  for k. deleteFile paths.k
  result

Table operations

-- Checked conversion from an ordinal to an index of type `n`; errors when i
-- is outside [0, size n).
def fromOrdinal (n:Type) (i:Int) : n =
  inRange = (0 <= i) && (i < size n)
  if inRange
    then unsafeFromOrdinal _ i
    else error $ "Ordinal index out of range:" <> show i <> " >= " <> show (size n)

-- TODO: could make an `unsafeCastIndex` and this could avoid the runtime copy
-- TODO: safe (runtime-checked) and unsafe versions
-- View `xs` as a table indexed by `m`; errors unless both index sets have the
-- same size.
def castTable (m:Type) (xs:n=>a) : m=>a =
  sameSize = size m == size n
  if sameSize
    then unsafeCastTable _ xs
    else error $ "Table size mismatch in cast: " <> show (size m) <> " vs " <> show (size n)

def asidx (i:Int) : n = fromOrdinal n i
def (@) (i:Int) (n:Type) : n = fromOrdinal n i

-- The elements of `xs` at ordinals start, start+1, ..., as a table over `m`.
def slice (xs:n=>a) (start:Int) (m:Type) : m=>a =
  for i. xs.((start + ordinal i)@_)

-- First element; errors on an empty table.
def head (xs:n=>a) : a = xs.(fromOrdinal _ 0)

-- Everything from ordinal `start` onwards, as a List.
def tail (xs:n=>a) (start:Int) : List a =
  remaining = size n - start
  toList $ slice xs start (Fin remaining)

Pseudorandom number generator utilities

Dex does not use a stateful random number generator. Rather, it uses what is known as a splittable random number generator, which is based on a hash function. Dex's PRNG system is modelled directly after JAX's, which is based on a well-established but shockingly underused idea from the functional programming community: the splittable PRNG. It's a good idea for many reasons, but it's especially helpful in a parallel setting. If you want to read more, Splittable pseudorandom number generators using cryptographic hashing describes the splitting model itself and D.E. Shaw Research's counter-based PRNG proposes the particular hash function we use.

Key functions

-- TODO: newtype
-- PRNG keys are currently plain 64-bit words.
Key : Type = Word64
-- Threefry-2x32 block cipher (20 rounds): encrypts the 64-bit `count` under
-- the 64-bit key `k`. Used as the hash behind Dex's splittable PRNG.
-- NOTE(review): per the definitions at the top of the file, `lowWord` takes
-- bits 32..63 and `highWord` takes bits 0..31 (the names look swapped). The
-- last line packs x into the upper half, which is consistent with that.
def threeFry2x32 (k:Word64) (count:Word64) : Word64 =
  -- Based on jax's threefry_2x32 by Matt Johnson and Peter Hawkins
  rotations1 = [13, 15, 26, 6]
  rotations2 = [17, 29, 16, 24]

  k0 = lowWord k
  k1 = highWord k
  -- TODO: add a fromHex
  -- Third key word: parity of the two key words and the Threefry constant.
  k2 = k0 .^. k1 .^. (IToW32 466688986) -- 0x1BD11BDA

  x = lowWord count
  y = highWord count

  -- Initial key injection.
  x = x + k0
  y = y + k1

  rotations = [rotations1, rotations2]
  -- Ordered so that after group i, x gets ks[i mod 3] and y gets
  -- ks[(i+1) mod 3] + (i+1). For i = 0, 1, 2 this adds (k1, k2+1),
  -- (k2, k0+2), (k0, k1+3), ... as in the Threefry key schedule.
  ks = [k1, k2, k0]

  -- 5 groups of 4 mix rounds, alternating the two rotation sets; each group
  -- ends with a key injection.
  (x, y) = yieldState (x, y) \ref.
    for i:(Fin 5).
      for j.
        (x, y) = get ref
        rotationIndex = unsafeFromOrdinal _ (mod (ordinal i) 2)
        rot = rotations.rotationIndex.j
        -- Mix: add, rotate y left by `rot` bits, xor.
        x = x + y
        y = (y .<<. rot) .|. (y .>>. (32 - rot))
        y = x .^. y
        ref := (x, y)
      (x, y) = get ref
      x = x + ks.(unsafeFromOrdinal _ (mod (ordinal i) 3))
      y = y + ks.(unsafeFromOrdinal _ (mod ((ordinal i)+1) 3)) + IToW32 ((ordinal i)+1)
      ref := (x, y)

  (W32ToW64 x .<<. 32) .|. (W32ToW64 y)
-- Derive a new key by hashing key `x` together with the integer `y`.
def hash (x:Key) (y:Int32) : Key =
  counter = IToW64 y
  threeFry2x32 x counter

-- A root key for seed `x`.
def newKey (x:Int) : Key = hash (IToW64 0) x

-- The key derived from `k` for index `i` (hash of `k` with i's ordinal).
def ixkey (k:Key) (i:n) : Key = hash k (ordinal i)

-- Two-level derivation: first by `i`, then by `j`.
def ixkey2 (k:Key) (i:n) (j:m) : Key = ixkey (ixkey k i) j

-- Apply the sampler `f` to the key derived from `k` for index `i`.
def many (f:Key->a) (k:Key) (i:n) : a = f (ixkey k i)

-- One derived key per element of `Fin n`.
def splitKey (k:Key) : Fin n => Key = for i. ixkey k i

Sample Generators

These functions generate samples taken from different distributions. For example, randMat produces a matrix by drawing each element with the sampler it is given; with rand as the sampler, this gives floating-point matrices whose elements are i.i.d. uniform samples.

-- A Float sample from the C helper `randunif`, seeded by key `k`.
-- NOTE(review): presumably uniform on [0, 1) — confirm against the runtime.
def rand (k:Key) : Float = unsafeIO do
  sample = %ffi randunif Float64 k
  F64ToF sample

-- `n` samples from `f`, one per derived key.
def randVec (n:Int) (f: Key -> a) (k: Key) : Fin n => a =
  for i:(Fin n). f (ixkey k i)

-- An n-by-m table of samples from `f`, one per derived key.
def randMat (n:Int) (m:Int) (f: Key -> a) (k: Key) : Fin n => Fin m => a =
  for i j. f (ixkey2 k i j)

-- Standard normal sample via the Box-Muller transform of two `rand` draws.
def randn (k:Key) : Float =
  [k1, k2] = splitKey k
  u1 = rand k1
  u2 = rand k2
  radius = sqrt ((-2.0) * log u1)
  angle = 2.0 * pi * u2
  radius * cos angle

-- TODO: Make this better...
def randInt (k:Key) : Int = (internalCast Int k) `mod` 2147483647

-- True with probability p (when rand is uniform on [0, 1)).
def bern (p:Float) (k:Key) : Bool = rand k < p

def randnVec (k:Key) : n=>Float = for i. randn $ ixkey k i

-- A random index of `n`, obtained by scaling a `rand` sample by size n.
def randIdx (k:Key) : n =
  scaled = rand k * IToF (size n)
  fromOrdinal n $ FToI $ floor scaled

Arbitrary

Type class for generating example values

-- Types that can produce an example value from a PRNG key.
interface Arbitrary a
  arb : Key -> a

instance Arbitrary Float32
  arb = randn

-- A normal sample scaled by 5 and converted to Int.
instance Arbitrary Int32
  arb = \k. FToI $ randn k * 5.0

-- Tables: each element gets its own derived key.
instance [Arbitrary a] Arbitrary (n=>a)
  arb = \k. for i. arb $ ixkey k i

instance [Arbitrary a, Arbitrary b] Arbitrary (a & b)
  arb = \k.
    [kLeft, kRight] = splitKey k
    (arb kLeft, arb kRight)

instance Arbitrary (Fin n)
  arb = randIdx

Ord on Arrays

Searching

Returns the highest index i such that xs.i <= x, or Nothing if there is no such index. xs is assumed to be sorted in ascending order.

-- Highest index i with xs.i <= x, or Nothing if xs is empty or x < xs.0.
-- NOTE(review): this is a binary search, so xs is assumed sorted ascending.
def searchSorted [Ord a] (xs:n=>a) (x:a) : Maybe n =
  if size n == 0
    then Nothing
    else if x < xs.(fromOrdinal _ 0)
      then Nothing
      else
        -- Invariant: xs.low <= x, and x < xs.high whenever high < size n.
        withState 0 \low.
          withState (size n) \high.
            iter \_.
              numLeft = get high - get low
              if numLeft == 1
                then Done $ Just $ fromOrdinal _ $ get low
                else
                  -- numLeft >= 2 here, so centerIx is strictly between low
                  -- and high and the range always shrinks.
                  centerIx = get low + idiv numLeft 2
                  if x < xs.(fromOrdinal _ centerIx)
                    then high := centerIx
                    else low := centerIx
                  Continue

min / max etc

-- The argument with the smaller key under `f`; ties go to `y`.
def minBy [Ord o] (f:a->o) (x:a) (y:a) : a =
  xWins = f x < f y
  select xWins x y

-- The argument with the larger key under `f`; ties go to `y`.
def maxBy [Ord o] (f:a->o) (x:a) (y:a) : a =
  xWins = f x > f y
  select xWins x y

def min [Ord o] (x1: o) -> (x2: o) : o = minBy id x1 x2
def max [Ord o] (x1: o) -> (x2: o) : o = maxBy id x1 x2

-- Element of `xs` with the smallest key. Reads xs.0 as the seed, so it
-- errors on an empty table.
def minimumBy [Ord o] (f:a->o) (xs:n=>a) : a =
  first = xs.(0@_)
  reduce first (minBy f) xs

-- Element of `xs` with the largest key; errors on an empty table.
def maximumBy [Ord o] (f:a->o) (xs:n=>a) : a =
  first = xs.(0@_)
  reduce first (maxBy f) xs

def minimum [Ord o] (xs:n=>o) : o = minimumBy id xs
def maximum [Ord o] (xs:n=>o) : o = maximumBy id xs

argmin/argmax

TODO: put in same section as searchsorted

-- Index of the element that wins under `comp`: when `comp a b` holds, the
-- element holding `a` is kept. Reads xs.0, so it errors on an empty table.
def argscan (comp:o->o->Bool) (xs:n=>o) : n =
  start = (0@_, xs.(0@_))
  pick = \(i1, v1) (i2, v2). select (comp v1 v2) (i1, v1) (i2, v2)
  indexed = for i. (i, xs.i)
  fst $ reduce start pick indexed

def argmin [Ord o] (xs:n=>o) : n = argscan (<) xs
def argmax [Ord o] (xs:n=>o) : n = argscan (>) xs

clip

-- Clamp x into the interval [low, high].
def clip [Ord a] ((low,high):(a&a)) (x:a) : a =
  atLeastLow = max low x
  min high atLeastLow

Trigonometric functions

TODO: these should be with the other Elementary/Special Functions

atan/atan2

-- Polynomial approximation of atan(x), evaluated by Horner's rule in s = x^2
-- and computed as x + x * s * P(s), so the result is odd in x.
def atan_inner (x:Float) : Float =
  -- From "Computing accurate Horner form approximations to
  -- special functions in finite precision arithmetic"
  -- https://arxiv.org/abs/1508.03211
  -- Only accurate in the range [-1, 1]
  s = x * x
  r = 0.0027856871
  r = r * s - 0.0158660002
  r = r * s + 0.042472221
  r = r * s - 0.0749753043
  r = r * s + 0.106448799
  r = r * s - 0.142070308
  r = r * s + 0.199934542
  r = r * s - 0.333331466
  r = r * s
  r * x + x
-- (smaller, larger) of the two arguments, using a single comparison.
-- When they compare equal the result is (y, x).
def min_and_max [Ord a] (x:a) (y:a) : (a & a) =
  xFirst = x < y
  select xFirst (x, y) (y, x)
-- Angle of the point (x, y), built from atan_inner on the ratio
-- min(|x|,|y|) / max(|x|,|y|), which lies in [0, 1] where atan_inner is
-- accurate. The result is then corrected for the octant and quadrant.
def atan2 (y:Float) (x:Float) : Float =
  -- Based off of the Tensorflow implementation at
  -- github.com/tensorflow/mlir-hlo/blob/master/lib/
  -- Dialect/mhlo/transforms/legalize_trigonometric_to_approximation.cc#L147
  -- With a fix to the nan propagation.
  abs_x = abs x
  abs_y = abs y
  (min_abs_x_y, max_abs_x_y) = min_and_max abs_x abs_y
  a = atan_inner (min_abs_x_y / max_abs_x_y)
  -- Reflect about pi/4 when |y| >= |x| (the ratio above was inverted).
  a = select (abs_x <= abs_y) ((pi / 2.0) -a) a
  -- Left half-plane: reflect about pi/2.
  a = select (x < 0.0) (pi - a) a
  -- y == 0 (which also covers x == y == 0, where the ratio is NaN).
  t = select (x < 0.0) pi 0.0
  a = select (y == 0.0) t a
  t = select (x < 0.0) (3.0 * pi / 4.0) (pi / 4.0)
  a = select (isinf x && isinf y) t a -- Handle infinite inputs.
  -- `a` is non-negative up to here; copysign transfers the sign of y (see
  -- copysign's definition for how a zero y is treated).
  a = copysign a y
  select (either_is_nan x y) nan a -- Propagate NaNs.
-- atan(x), computed as the angle of the point (1, x).
def atan (x:Float) : Float = atan2 x 1.0

Complex numbers

data Complex = MkComplex Float Float -- real, imaginary

-- Componentwise closeness. Real parts are compared using the real parts of
-- atol/rtol, and imaginary parts using their imaginary parts.
-- Previously the supplied tolerances were ignored and `~~` (the default Float
-- tolerances) was used instead.
instance HasAllClose Complex
  allclose = \atol rtol (MkComplex a b) (MkComplex c d).
    (MkComplex atolRe atolIm) = atol
    (MkComplex rtolRe rtolIm) = rtol
    (allclose atolRe rtolRe a c) && (allclose atolIm rtolIm b d)
-- Default tolerances: the Float defaults in both components.
instance HasDefaultTolerance Complex
  atol = MkComplex atol atol
  rtol = MkComplex rtol rtol

-- Exact componentwise equality.
instance Eq Complex
  (==) = \(MkComplex re1 im1) (MkComplex re2 im2). (re1 == re2) && (im1 == im2)

instance Add Complex
  add = \(MkComplex re1 im1) (MkComplex re2 im2). MkComplex (re1 + re2) (im1 + im2)
  sub = \(MkComplex re1 im1) (MkComplex re2 im2). MkComplex (re1 - re2) (im1 - im2)
  zero = MkComplex 0.0 0.0

-- (re1 + im1 i)(re2 + im2 i) = (re1 re2 - im1 im2) + (re1 im2 + im1 re2) i
instance Mul Complex
  mul = \(MkComplex re1 im1) (MkComplex re2 im2).
    MkComplex (re1 * re2 - im1 * im2) (re1 * im2 + im1 * re2)
  one = MkComplex 1.0 0.0

-- Scaling by a real number.
instance VSpace Complex
  scaleVec = \s:Float (MkComplex re im):Complex. MkComplex (s * re) (s * im)
-- Todo: Hook up to (/) operator. Might require two-parameter VSpace.
-- (a + bi) / (c + di), computed via multiplication by the conjugate.
def complex_division (MkComplex a b:Complex) (MkComplex c d:Complex): Complex =
  denom = c * c + d * d
  MkComplex ((a * c + b * d) / denom) ((b * c - a * d) / denom)

-- e^(a + bi) = e^a (cos b + i sin b)
def complex_exp (MkComplex a b:Complex) : Complex =
  scale = exp a
  MkComplex (scale * cos b) (scale * sin b)

-- 2^(a + bi) = 2^a (cos(b ln 2) + i sin(b ln 2))
def complex_exp2 (MkComplex a b:Complex) : Complex =
  scale = exp2 a
  angle = b * log 2.0
  MkComplex (scale * cos angle) (scale * sin angle)

def complex_conj (MkComplex re im:Complex) : Complex = MkComplex re (-im)

-- NOTE(review): despite its name this returns the SQUARED magnitude
-- re^2 + im^2. complex_mag is the one that takes the square root.
def complex_abs (MkComplex re im:Complex) : Float = re * re + im * im

-- |z| = sqrt(re^2 + im^2)
def complex_mag (MkComplex re im:Complex) : Float = sqrt (re * re + im * im)

-- Argument (phase angle) of z.
def complex_arg (MkComplex re im:Complex) : Float = atan2 im re
-- Principal logarithm: log|z| + i arg z.
complex_log = \z:Complex. MkComplex (log (complex_mag z)) (complex_arg z)
-- Base-2 and base-10 logs, rescaled from the natural log.
complex_log2 = \z:Complex. (complex_log z) / log 2.0
complex_log10 = \z:Complex. (complex_log z) / log 10.0
-- base^power = exp(power * log base), using the principal log.
complex_pow = \base:Complex power:Complex. complex_exp (power * complex_log base)
-- Principal square root:
--   sqrt(a + bi) = sqrt((|z| + a) / 2) + i * s * sqrt((|z| - a) / 2)
-- where s = -1 when b < 0 and +1 otherwise.
-- s must not be 0 when b == 0. The old `sign b` is 0.0 there, so e.g.
-- complex_sqrt (MkComplex (-4.0) 0.0) returned 0 instead of 2i.
def complex_sqrt (MkComplex a b:Complex) : Complex =
  m = complex_mag $ MkComplex a b
  imSign = if b < 0.0 then (-1.0) else 1.0
  MkComplex (sqrt ((a + m) / 2.0)) (imSign * sqrt ((m - a) / 2.0))
-- sin(re + im i) = sin re cosh im + i cos re sinh im
def complex_sin (MkComplex re im:Complex) : Complex =
  MkComplex (sin re * cosh im) (cos re * sinh im)

-- sinh(re + im i) = sinh re cos im + i cosh re sin im
def complex_sinh (MkComplex re im:Complex) : Complex =
  MkComplex (sinh re * cos im) (cosh re * sin im)

-- cos(re + im i) = cos re cosh im - i sin re sinh im
def complex_cos (MkComplex re im:Complex) : Complex =
  MkComplex (cos re * cosh im) (-sin re * sinh im)
-- cosh(a + bi) = cosh a cos b + i sinh a sin b
-- The imaginary part previously had the wrong sign (-sinh a sin b). This now
-- also matches the denominator used by complex_tanh.
def complex_cosh (MkComplex a b:Complex) : Complex =
  MkComplex (cosh a * cos b) (sinh a * sin b)
-- tan z = sin z / cos z
def complex_tan (x:Complex) : Complex =
  s = complex_sin x
  c = complex_cos x
  complex_division s c

-- tanh z = sinh z / cosh z, with both parts written out inline.
def complex_tanh (MkComplex a b:Complex) : Complex =
  sinhZ = MkComplex (sinh a * cos b) (cosh a * sin b)
  coshZ = MkComplex (cosh a * cos b) (sinh a * sin b)
  complex_division sinhZ coshZ

instance Fractional Complex
  divide = complex_division
-- McDonnell's complex floor. Floor both parts; then, if the fractional parts
-- sum to at least 1, bump whichever part had the larger fractional part.
def complex_floor (MkComplex re im:Complex) : Complex =
  -- from "Complex Floor" by Eugene McDonnell
  -- https://www.jsoftware.com/papers/eem/complexfloor.htm
  baseRe = floor re
  baseIm = floor im
  fracRe = re - baseRe
  fracIm = im - baseIm
  if (fracRe + fracIm) < 1.0
    then MkComplex baseRe baseIm
    else if fracRe >= fracIm then MkComplex (baseRe + 1.0) baseIm else MkComplex baseRe (baseIm + 1.0)
-- Ceiling via the identity ceil z = -floor(-z).
complex_ceil = \z:Complex. -(complex_floor (-z))
-- Round by taking the (McDonnell) complex floor of x + 0.5, offsetting only
-- the real part. Previously `x` was never used, so every input "rounded" to
-- complex_floor (0.5 + 0i).
complex_round = \x:Complex. complex_floor $ x + MkComplex 0.5 0.0
-- Not implemented yet (this one is pretty hairy); calling it hits `todo`.
-- See https://cs.uwaterloo.ca/research/tr/1994/23/CS-94-23.pdf
complex_lgamma : Complex -> Complex = \x:Complex. todo
-- log(1 + x), accurate for small x, using the classic trick
--   log1p(x) = log(u) * x / (u - 1),   u = 1 + x (rounded)
-- The rounding error in u cancels between log(u) and (u - 1).
-- Fixes relative to the previous version:
--   * it divided by x instead of (u - 1), which reduced to plain log(u) and
--     returned 0 for tiny real x;
--   * it tested only the real part for zero, so purely imaginary inputs
--     returned x itself (e.g. log1p i gave i instead of log(1 + i)).
def complex_log1p (x:Complex) : Complex =
  (MkComplex a _) = x
  u = x + MkComplex 1.0 0.0
  -- The amount that was actually added to 1 after rounding.
  uMinus1 = u - MkComplex 1.0 0.0
  case uMinus1 == MkComplex 0.0 0.0 of
    -- x vanished when added to 1, so log1p(x) ~= x.
    True -> x
    False ->
      case a <= -1.0 of
        True -> complex_log u
        False -> divide ((complex_log u) * x) uMinus1
-- Elementary functions on Complex, wired to the complex_* implementations.
instance Floating Complex
  exp    = complex_exp
  exp2   = complex_exp2
  log    = complex_log
  log2   = complex_log2
  log10  = complex_log10
  log1p  = complex_log1p
  sin    = complex_sin
  cos    = complex_cos
  tan    = complex_tan
  sinh   = complex_sinh
  cosh   = complex_cosh
  tanh   = complex_tanh
  floor  = complex_floor
  ceil   = complex_ceil
  round  = complex_round
  sqrt   = complex_sqrt
  pow    = complex_pow
  lgamma = complex_lgamma

Miscellaneous utilities

TODO: all of these should be in some other section

-- The index mirrored end-to-end: ordinal k maps to size n - 1 - k.
def reflect (i:n) : n =
  lastOrdinal = (size n) - 1
  unsafeFromOrdinal n (lastOrdinal - ordinal i)

def reverse (x:n=>a) : n=>a = for j. x.(reflect j)

-- A table over `m`: position k holds xs at ordinal k while k < size n, and
-- the filler `x` after that.
def padTo (m:Type) (x:a) (xs:n=>a) : (m=>a) =
  numOrig = size n
  for i.
    k = ordinal i
    if k < numOrig then xs.(k@_) else x
-- Ceiling of x / y.
-- NOTE(review): assumes idiv truncates toward zero and rem takes the sign of
-- x (C-style %idiv / %irem). The truncated quotient then needs +1 only when
-- the division is inexact AND the exact quotient is positive (x and y have
-- the same sign). The previous version also added 1 for negative quotients,
-- e.g. idivCeil (-7) 2 gave -2 instead of -3.
def idivCeil (x:Int) (y:Int) : Int =
  inexact = rem x y /= 0
  positiveQuotient = ((x > 0) && (y > 0)) || ((x < 0) && (y < 0))
  idiv x y + BToI (inexact && positiveQuotient)
-- x shifted right by one bit (x / 2 for non-negative x).
def intdiv2 (x:Int) : Int = %shr x 1

-- 2^power, as 1 shifted left by `power` bits.
def intpow2 (power:Int) : Int = %shl 1 power
-- Parity tests. `rem` takes the sign of x, so rem (-3) 2 == -1. The old
-- `rem x 2 == 1` test therefore reported negative odd numbers as not odd.
-- Comparing against 0 works for both signs.
def isOdd (x:Int) : Bool = rem x 2 /= 0
def isEven (x:Int) : Bool = rem x 2 == 0
def isPowerOf2 (x:Int) : Bool =
  -- A fast trick based on bitwise AND: a positive power of two has exactly
  -- one set bit, and subtracting 1 clears it while setting only lower bits.
  -- This works on integer types larger than 8 bits.
  -- Note: the Bits interface ((.&.) etc.) only has instances for the unsigned
  -- Word types, not Int, which is why we use %and here.
  -- TODO: Make (.&.) polymorphic.
  -- Non-positive inputs are never powers of two. Checking x <= 0 (not just
  -- x == 0) also rejects the most negative Int: x - 1 wraps to the largest
  -- Int there, so the AND test alone would wrongly report True.
  if x <= 0 then False else 0 == %and x (x - 1)
-- floor(log2 x) for x > 0, and -1 for x <= 0.
-- It counts how many halvings it takes to bring x down to 0. The previous
-- version doubled a comparison value until it exceeded x. For x >= 2^30 that
-- value overflowed (to negative, then 0) and the loop never terminated.
def intlog2 (x:Int) : Int =
  yieldState (-1) \ansRef.
    withState x \remainingRef.
      while do
        if get remainingRef > 0
          then
            ansRef := (get ansRef) + 1
            remainingRef := idiv (get remainingRef) 2
            True
          else
            False
-- For positive x, the smallest power of two that is >= x.
def nextpow2 (x:Int) : Int =
  if isPowerOf2 x then x else intpow2 (1 + intlog2 x)
-- base^power under the multiplication `times` with identity `one`.
-- Returns `one` when power <= 0, because the loop body never runs.
def generalIntegerPower (times:a->a->a) (one:a) (base:a) (power:Int) : a =
  -- Implements exponentiation by squaring.
  -- Todo: Make power a Nat when it's available.
  -- This could be nicer if there were a way to explicitly
  -- specify which typeclass instance to use for Mul.
  yieldState one \ans.
    withState power \pow.
      withState base \z.
        while do
          if get pow > 0
            then
              -- Multiply the current square into the answer for each set bit.
              if isOdd (get pow) then ans := times (get ans) (get z)
              z := times (get z) (get z)
              pow := intdiv2 (get pow)
              True
            else
              False

-- base^power using the type's own Mul instance.
def intpow [Mul a] (base:a) (power:Int) : a = generalIntegerPower (*) one base power
-- The value inside a `Just`. Partial: a `Nothing` raises an explicit error
-- instead of falling through a non-exhaustive match.
def fromJust (x:Maybe a) : a =
  case x of
    Just x' -> x'
    Nothing -> error "fromJust: got Nothing"
-- True if `f` holds for at least one element of `xs`.
def anySat (f:a -> Bool) (xs:n=>a) : Bool = any (for i. f xs.i)

-- Just the table of unwrapped values if every element is Just, else Nothing.
def seqMaybes (n:Type) ?-> (a:Type) ?-> (xs : n=>Maybe a) : Maybe (n => a) =
  -- is it possible to implement this safely? (i.e. without using partial
  -- functions)
  if anySat isNothing xs
    then Nothing
    else Just $ map fromJust xs

-- Index of a position holding `query`, or Nothing. Every match overwrites
-- the state, so with duplicates this yields the LAST matching index.
def linearSearch [Eq a] (xs:n=>a) (query:a) : Maybe n =
  yieldState Nothing \found.
    for i.
      if xs.i == query then found := Just i

def listLength ((AsList n _):List a) : Int = n
-- This is for efficiency (rather than using `<>` repeatedly)
-- TODO: we want this for any monoid but this implementation won't work.
-- Concatenate all lists in order into one List.
def concat (lists:n=>(List a)) : List a =
  totalSize = sum for i. listLength lists.i
  -- A cursor (listIdx, eltIdx) walks through the input lists.
  -- NOTE(review): relies on the `for` body running sequentially in order.
  AsList _ $ withState 0 \listIdx.
    withState 0 \eltIdx.
      for i:(Fin totalSize).
        -- Skip past exhausted (including empty) lists until the cursor
        -- points at an element that exists.
        while do
          continue = get eltIdx >= listLength (lists.((get listIdx)@_))
          if continue
            then
              eltIdx := 0
              listIdx := get listIdx + 1
            else ()
          continue
        -- Emit the element under the cursor, then advance it.
        (AsList _ xs) = lists.((get listIdx)@_)
        eltIdxVal = get eltIdx
        eltIdx := eltIdxVal + 1
        xs.(eltIdxVal@_)

Probability

-- Exclusive prefix sum: element i is the sum of the elements before i.
def cumSumLow (xs: n=>Float) : n=>Float =
  withState 0.0 \running.
    for i.
      before = get running
      running := before + xs.i
      before

-- cdf should include 0.0 but not 1.0
-- Draw an index by locating a `rand` sample within the CDF.
def categoricalFromCDF (cdf: n=>Float) (key: Key) : n =
  r = rand key
  case searchSorted cdf r of
    Just i -> i

-- Scale `xs` so that it sums to 1.
def normalizePdf (xs: d=>Float) : d=>Float = xs / sum xs

-- CDF (exclusive prefix sums) of the distribution with these log-probabilities.
-- The max is subtracted before exponentiating.
def cdfForCategorical (logprobs: n=>Float) : n=>Float =
  maxLogProb = maximum logprobs
  shifted = for i. logprobs.i - maxLogProb
  cumSumLow (normalizePdf (map exp shifted))

def categorical (logprobs: n=>Float) (key: Key) : n =
  cdf = cdfForCategorical logprobs
  categoricalFromCDF cdf key

-- batch variant to share the work of forming the cumsum
-- (alternatively we could rely on hoisting of loop constants)
def categoricalBatch (logprobs: n=>Float) (key: Key) : m=>n =
  cdf = cdfForCategorical logprobs
  for i. categoricalFromCDF cdf $ ixkey key i

-- log(sum(exp x)), shifted by the max for numerical stability.
def logsumexp (x: n=>Float) : Float =
  m = maximum x
  shiftedSum = sum for i. exp (x.i - m)
  m + (log shiftedSum)

def logsoftmax (x: n=>Float) : n=>Float =
  lse = logsumexp x
  for i. x.i - lse

-- exp(x) normalised to sum to 1, shifted by the max for stability.
def softmax (x: n=>Float) : n=>Float =
  m = maximum x
  weights = for i. exp (x.i - m)
  total = sum weights
  for i. weights.i / total

Polynomials

TODO: Move this somewhere else

-- Evaluate a polynomial at x by Horner's rule. Same as Numpy's polyval:
-- coefficients.0 multiplies the highest power of x.
def evalpoly [VSpace v] (coefficients:n=>v) (x:Float) : v =
  fold zero \i acc. coefficients.i + x .* acc

TestMode

TODO: move this to be in Testing Helpers

-- True when the DEX_TEST_MODE environment variable is set.
def dex_test_mode (():Unit) : Bool = unsafeIO do
  checkEnv "DEX_TEST_MODE"

Exception effect

TODO: move error and todo to here.

-- Run `f`, handling its Except effect: presumably Just the result if it
-- finished, and Nothing if it threw.
def catch (f:Unit -> {Except|eff} a) : {|eff} Maybe a =
  ans = %catchException f
  -- NOTE(review): the result is viewed as a variant. The unit-payload case is
  -- taken to mean "threw"; the other case carries the value.
  case %sumToVariant ans of
    {|c=() |} -> Nothing
    {|c|c=val|} -> Just val
-- Raise the Except effect.
def throw (_:Unit) : {Except} a = %throwException a

-- Raise the Except effect unless `b` holds.
def assert (b:Bool) : {Except} Unit = if b then () else throw ()

Testing Helpers

-- -- Reliably causes a segfault if pointers aren't initialized to zero.
-- -- TODO: add this test when we cache modules
-- justSomeDataToTestCaching = toList for i:(Fin 100).
-- if ordinal i == 0
-- then Left (toList [1,2,3])
-- else Right 1