Runs before every Dex program unless an alternative is provided with `--prelude`.
Essentials
Primitive Types
-- Untyped pointer, represented by the compiler's Word8-pointer primitive.
RawPtr : Type = %Word8Ptr()

-- Type ascription as a function: `the(T, x)` is `x`, checked at type `T`.
def the(a:Type, x:a) -> a = x

-- Marker interface. As the name of its only method says, it must not be
-- implemented by hand: the compiler relies on the invariant it protects.
interface Data(a:Type)
  do_not_implement_this_interface_for_the_compiler_relies_on_the_invariant_it_protects : (a) -> a

-- Representation-level cast from `from` to `to` via the `%cast` primitive.
def internal_cast(x:from) -> to given (from, to) = %cast(to, x)

-- Reinterprets `x` as a `to` via `%unsafeCoerce`; nothing is checked.
def unsafe_coerce(x:from) -> to given (from|Data, to|Data) =
  %unsafeCoerce(to, x)

-- A placeholder of type `a` produced by `%garbageVal`. Treat its contents as
-- unspecified (presumably it must be overwritten before being read).
def uninitialized_value() -> a given (a|Data) =
  %garbageVal(a)
-- Conversions between primitive numeric types. All are thin wrappers around
-- `internal_cast`.
def f64_to_f(x: Float64) -> Float = internal_cast(x)
def f32_to_f(x: Float32) -> Float = internal_cast(x)
def f_to_f64(x: Float) -> Float64 = internal_cast(x)
def f_to_f32(x: Float) -> Float32 = internal_cast(x)
def i64_to_i(x: Int64) -> Int = internal_cast(x)
def i32_to_i(x: Int32) -> Int = internal_cast(x)
def w8_to_i(x: Word8) -> Int = internal_cast(x)
def i_to_i64(x: Int) -> Int64 = internal_cast(x)
def i_to_i32(x: Int) -> Int32 = internal_cast(x)
def i_to_w8(x: Int) -> Word8 = internal_cast(x)
def i_to_w32(x: Int) -> Word32 = internal_cast(x)
def i_to_w64(x: Int) -> Word64 = internal_cast(x)
def w32_to_w64(x: Word32)-> Word64 = internal_cast(x)
def i_to_f(x:Int) -> Float = internal_cast(x)
def f_to_i(x:Float) -> Int = internal_cast(x)
def raw_ptr_to_i64(x:RawPtr) -> Int64 = internal_cast(x)

-- Nat is a newtype over NatRep: unwrap with %projNewtype, wrap with %NatCon.
def nat_to_rep(x : Nat) -> NatRep = %projNewtype(x)
def rep_to_nat(x : NatRep) -> Nat = %NatCon(x)

-- Nat -> primitive: unwrap to NatRep, then cast.
def n_to_w8(x: Nat) -> Word8 = internal_cast(nat_to_rep(x))
def n_to_w32(x: Nat) -> Word32 = internal_cast(nat_to_rep(x))
def n_to_w64(x: Nat) -> Word64 = internal_cast(nat_to_rep(x))
def n_to_i32(x: Nat) -> Int32 = internal_cast(nat_to_rep(x))
def n_to_i64(x: Nat) -> Int64 = internal_cast(nat_to_rep(x))
def n_to_f32(x: Nat) -> Float32 = internal_cast(nat_to_rep(x))
def n_to_f64(x: Nat) -> Float64 = internal_cast(nat_to_rep(x))
def n_to_f(x: Nat) -> Float = internal_cast(nat_to_rep(x))

-- primitive -> Nat: cast to NatRep, then wrap.
def w8_to_n(x : Word8) -> Nat = rep_to_nat(internal_cast(x))
def w32_to_n(x : Word32) -> Nat = rep_to_nat(internal_cast(x))
def w64_to_n(x : Word64) -> Nat = rep_to_nat(internal_cast(x))
def i32_to_n(x : Int32) -> Nat = rep_to_nat(internal_cast(x))
def i64_to_n(x : Int64) -> Nat = rep_to_nat(internal_cast(x))
def f32_to_n(x : Float32) -> Nat = rep_to_nat(internal_cast(x))
def f64_to_n(x : Float64) -> Nat = rep_to_nat(internal_cast(x))
def f_to_n(x : Float) -> Nat = rep_to_nat(internal_cast(x))
-- Types constructible from a Word64 (presumably how non-negative integer
-- literals are elaborated — confirm against the compiler).
interface FromUnsignedInteger(a:Type)
  from_unsigned_integer : (Word64) -> a

instance FromUnsignedInteger(Word8)
  def from_unsigned_integer(w) = internal_cast(w)
instance FromUnsignedInteger(Word32)
  def from_unsigned_integer(w) = internal_cast(w)
instance FromUnsignedInteger(Word64)
  -- Already the carrier type: identity.
  def from_unsigned_integer(w) = w
instance FromUnsignedInteger(Int32)
  def from_unsigned_integer(w) = internal_cast(w)
instance FromUnsignedInteger(Int64)
  def from_unsigned_integer(w) = internal_cast(w)
instance FromUnsignedInteger(Float32)
  def from_unsigned_integer(w) = internal_cast(w)
instance FromUnsignedInteger(Float64)
  def from_unsigned_integer(w) = internal_cast(w)
instance FromUnsignedInteger(Nat)
  def from_unsigned_integer(w) = w64_to_n w

-- Types constructible from an Int64 (presumably for signed integer literals).
interface FromInteger(a:Type)
  from_integer : (Int64) -> a

instance FromInteger(Float32)
  def from_integer(i) = internal_cast(i)
instance FromInteger(Int32)
  def from_integer(i) = internal_cast(i)
instance FromInteger(Float64)
  def from_integer(i) = internal_cast(i)
instance FromInteger(Int64)
  -- Already the carrier type: identity.
  def from_integer(i) = i
-- Bitwise operations. Shift amounts are given as Int and converted to the
-- operand's width before calling the shift primitive.
interface Bits(a:Type)
  (.<<.) : (a, Int) -> a
  (.>>.) : (a, Int) -> a
  (.|.) : (a, a) -> a
  (.&.) : (a, a) -> a
  (.^.) : (a, a) -> a

instance Bits(Word8)
  def (.<<.)(w, k) = %shl(w, i_to_w8(k))
  def (.>>.)(w, k) = %shr(w, i_to_w8(k))
  def (.|.)(p, q) = %or(p, q)
  def (.&.)(p, q) = %and(p, q)
  def (.^.)(p, q) = %xor(p, q)

instance Bits(Word32)
  def (.<<.)(w, k) = %shl(w, i_to_w32(k))
  def (.>>.)(w, k) = %shr(w, i_to_w32(k))
  def (.|.)(p, q) = %or(p, q)
  def (.&.)(p, q) = %and(p, q)
  def (.^.)(p, q) = %xor(p, q)

instance Bits(Word64)
  def (.<<.)(w, k) = %shl(w, i_to_w64(k))
  def (.>>.)(w, k) = %shr(w, i_to_w64(k))
  def (.|.)(p, q) = %or(p, q)
  def (.&.)(p, q) = %and(p, q)
  def (.^.)(p, q) = %xor(p, q)

-- NOTE(review): these names look swapped. `low_word` shifts right by 32
-- before casting, i.e. it yields bits 32..63, while `high_word` is a plain
-- cast (presumably truncating to bits 0..31). Behaviour is deliberately left
-- unchanged because existing callers depend on it; consider a rename.
def low_word( x : Word64) -> Word32 =
  upper = x .>>. 32
  internal_cast(upper)
def high_word(x : Word64) -> Word32 = internal_cast(x)
Basic Arithmetic
Add
Things that can be added.
This defines the `Add` interface (`+` with identity `zero`), the `Sub` interface that extends it with `-`, and their instances.
-- Addition with an identity element `zero`.
interface Add(a|Data)
  (+) : (a, a) -> a
  zero : a

-- Addable types that also support subtraction.
interface Sub(a|Add)
  (-) : (a, a) -> a

-- Floating-point types use the %fadd/%fsub primitives.
instance Add(Float64)
  def (+)(u, v) = %fadd(u, v)
  zero = 0
instance Sub(Float64)
  def (-)(u, v) = %fsub(u, v)

instance Add(Float32)
  def (+)(u, v) = %fadd(u, v)
  zero = 0
instance Sub(Float32)
  def (-)(u, v) = %fsub(u, v)

-- Integer and word types use the %iadd/%isub primitives.
instance Add(Int64)
  def (+)(u, v) = %iadd(u, v)
  zero = 0
instance Sub(Int64)
  def (-)(u, v) = %isub(u, v)

instance Add(Int32)
  def (+)(u, v) = %iadd(u, v)
  zero = 0
instance Sub(Int32)
  def (-)(u, v) = %isub(u, v)

instance Add(Word8)
  def (+)(u, v) = %iadd(u, v)
  zero = 0
instance Sub(Word8)
  def (-)(u, v) = %isub(u, v)

instance Add(Word32)
  def (+)(u, v) = %iadd(u, v)
  zero = 0
instance Sub(Word32)
  def (-)(u, v) = %isub(u, v)

instance Add(Word64)
  def (+)(u, v) = %iadd(u, v)
  zero = 0
instance Sub(Word64)
  def (-)(u, v) = %isub(u, v)

-- Nat addition is done on its NatRep representation.
instance Add(Nat)
  def (+)(u, v) = rep_to_nat %iadd(nat_to_rep u, nat_to_rep v)
  zero = 0

-- The unit type is trivially additive.
instance Add(())
  def (+)(_, _) = ()
  zero = ()
instance Sub(())
  def (-)(_, _) = ()
Mul
Things that can be multiplied.
This defines the `Mul` monoid (`*` with identity `one`) and its instances.
-- Multiplication with an identity element `one`.
interface Mul(a|Data)
  (*) : (a, a) -> a
  one : a

instance Mul(Float64)
  def (*)(u, v) = %fmul(u, v)
  one = f_to_f64 1.0
instance Mul(Float32)
  def (*)(u, v) = %fmul(u, v)
  one = f_to_f32 1.0

instance Mul(Int64)
  def (*)(u, v) = %imul(u, v)
  one = 1
instance Mul(Int32)
  def (*)(u, v) = %imul(u, v)
  one = 1
instance Mul(Word8)
  def (*)(u, v) = %imul(u, v)
  one = 1
instance Mul(Word32)
  def (*)(u, v) = %imul(u, v)
  one = 1
instance Mul(Word64)
  def (*)(u, v) = %imul(u, v)
  one = 1

-- Nat multiplication is done on its NatRep representation.
instance Mul(Nat)
  def (*)(u, v) = rep_to_nat %imul(nat_to_rep u, nat_to_rep v)
  one = 1

instance Mul(())
  def (*)(_, _) = ()
  one = ()
Integral
Integer-like things.
-- Integer division (`idiv`) and remainder (`rem`), via %idiv / %irem.
interface Integral(a)
  idiv : (a,a)->a
  rem : (a,a)->a

instance Integral(Int64)
  def idiv(num, den) = %idiv(num, den)
  def rem(num, den) = %irem(num, den)
instance Integral(Int32)
  def idiv(num, den) = %idiv(num, den)
  def rem(num, den) = %irem(num, den)
instance Integral(Word8)
  def idiv(num, den) = %idiv(num, den)
  def rem(num, den) = %irem(num, den)
instance Integral(Word32)
  def idiv(num, den) = %idiv(num, den)
  def rem(num, den) = %irem(num, den)
instance Integral(Word64)
  def idiv(num, den) = %idiv(num, den)
  def rem(num, den) = %irem(num, den)

-- Nat division works on the NatRep representation.
instance Integral(Nat)
  def idiv(num, den) = rep_to_nat %idiv(nat_to_rep num, nat_to_rep den)
  def rem(num, den) = rep_to_nat %irem(nat_to_rep num, nat_to_rep den)
Fractional
Rational-like things.
Includes floating point and two field rational representations.
-- True division, via the %fdiv primitive for the floating-point instances.
interface Fractional(a)
  divide : (a, a) -> a

instance Fractional(Float64)
  def divide(num, den) = %fdiv(num, den)
instance Fractional(Float32)
  def divide(num, den) = %fdiv(num, den)
Index set interface and instances
-- Index sets. Instances give each element an ordinal below `size'()`; the
-- instances in this file number elements consecutively from 0.
interface Ix(n|Data)
  size' : () -> Nat
  ordinal : (n) -> Nat
  unsafe_from_ordinal : (Nat) -> n

-- Number of elements in index set `n` (selects `size'` at that type).
def size(n|Ix) -> Nat =
  size'(n=n)

-- The built-in finite index set with `n` elements (%Fin primitive).
def Fin(n:Nat) -> Type =
  %Fin(n)
-- Truncated subtraction on Nat: x - y, or 0 when y > x.
def (-|)(x: Nat, y:Nat) -> Nat =
  xr = nat_to_rep x
  yr = nat_to_rep y
  rep_to_nat %select(%ilt(xr, yr), 0, (%isub(xr, yr)))

-- Nat subtraction with no underflow check; the caller must ensure x >= y.
def unsafe_nat_diff(x:Nat, y:Nat) -> Nat =
  rep_to_nat %isub(nat_to_rep x, nat_to_rep y)
-- Sub-ranges of index set `q` relative to the pivot `i`. Each value stores
-- its own ordinal within the sub-range in `val`.
struct RangeFrom(q:Type, i:q) = val : Nat
struct RangeFromExc(q:Type, i:q) = val : Nat
struct RangeTo(q:Type, i:q) = val : Nat
struct RangeToExc(q:Type, i:q) = val : Nat

-- Elements from `i` onwards, `i` included.
instance Ix(RangeFrom q i) given (q|Ix, i:q)
  def size'() = unsafe_nat_diff(size q, ordinal i)
  def ordinal(k) = k.val
  def unsafe_from_ordinal(k) = RangeFrom(k)

-- Elements strictly after `i`.
instance Ix(RangeFromExc q i) given (q|Ix, i:q)
  def size'() = unsafe_nat_diff(size q, ordinal i + 1)
  def ordinal(k) = k.val
  def unsafe_from_ordinal(k) = RangeFromExc(k)

-- Elements up to and including `i`.
instance Ix(RangeTo q i) given (q|Ix, i:q)
  def size'() = ordinal i + 1
  def ordinal(k) = k.val
  def unsafe_from_ordinal(k) = RangeTo(k)

-- Elements strictly before `i`.
instance Ix(RangeToExc q i) given (q|Ix, i:q)
  def size'() = ordinal i
  def ordinal(k) = k.val
  def unsafe_from_ordinal(k) = RangeToExc(k)

-- The unit type is a one-element index set.
instance Ix(())
  def size'() = 1
  def ordinal(_) = 0
  def unsafe_from_ordinal(_) = ()

-- Table mapping every index of `n` to its ordinal.
def iota(n|Ix) -> n=>Nat = for idx. ordinal idx
Arithmetic instances for table types
-- Element-wise arithmetic on tables, including the dependent (triangular)
-- table shapes.
instance Add(n=>a) given (a|Add, n|Ix)
  def (+)(xs, ys) = for k. xs[k] + ys[k]
  zero = for _. zero
instance Sub(n=>a) given (a|Sub, n|Ix)
  def (-)(xs, ys) = for k. xs[k] - ys[k]

instance Add((i:n) => (i..) => a) given (a|Add, n|Ix)
  def (+)(xs, ys) = for k. xs[k] + ys[k]
  zero = for _. zero
instance Sub((i:n) => (i..) => a) given (a|Sub, n|Ix)
  def (-)(xs, ys) = for k. xs[k] - ys[k]

instance Add((i:n) => (..i) => a) given (a|Add, n|Ix)
  def (+)(xs, ys) = for k. xs[k] + ys[k]
  zero = for _. zero
instance Sub((i:n) => (..i) => a) given (a|Sub, n|Ix)
  def (-)(xs, ys) = for k. xs[k] - ys[k]

instance Add((i:n) => (..<i) => a) given (a|Add, n|Ix)
  def (+)(xs, ys) = for k. xs[k] + ys[k]
  zero = for _. zero
instance Sub((i:n) => (..<i) => a) given (a|Sub, n|Ix)
  def (-)(xs, ys) = for k. xs[k] - ys[k]

instance Add((i:n) => (i<..) => a) given (a|Add, n|Ix)
  def (+)(xs, ys) = for k. xs[k] + ys[k]
  zero = for _. zero
instance Sub((i:n) => (i<..) => a) given (a|Sub, n|Ix)
  def (-)(xs, ys) = for k. xs[k] - ys[k]

instance Mul(n=>a) given (a|Mul, n|Ix)
  def (*)(xs, ys) = for k. xs[k] * ys[k]
  one = for _. one
Basic polymorphic functions and types
-- Pair projections and component swap.
def fst(pair:(a, b)) -> a given (a, b) = pair.0
def snd(pair:(a, b)) -> b given (a, b) = pair.1
def swap(pair:(a, b)) -> (b, a) given (a, b) = (pair.1, pair.0)
-- Component-wise addition and subtraction on pairs.
instance Add((a, b)) given (a|Add, b|Add)
  def (+)(x, y) = (x.0 + y.0, x.1 + y.1)
  zero = (zero, zero)

instance Sub((a, b)) given (a|Sub, b|Sub)
  def (-)(x, y) = (x.0 - y.0, x.1 - y.1)

-- Pairs are indexed row-major: the second component varies fastest.
instance Ix((a, b)) given (a|Ix, b|Ix)
  def size'() = size a * size b
  def ordinal(pair) = ordinal(pair.0) * size b + ordinal(pair.1)
  def unsafe_from_ordinal(o) =
    stride = size b
    outer = unsafe_from_ordinal(n=a, idiv(o, stride))
    inner = unsafe_from_ordinal(n=b, rem(o, stride))
    (outer, inner)

-- Triples and quadruples reuse the pair instance by nesting to the right.
instance Ix((a, b, c)) given (a|Ix, b|Ix, c|Ix)
  def size'() = size a * size b * size c
  def ordinal(tup) =
    (x, y, z) = tup
    ordinal((x, (y, z)))
  def unsafe_from_ordinal(o) =
    (x, (y, z)) = unsafe_from_ordinal o
    (x, y, z)

instance Ix((a, b, c, d)) given (a|Ix, b|Ix, c|Ix, d|Ix)
  def size'() = size a * size b * size c * size d
  def ordinal(tup) =
    (x, y, z, w) = tup
    ordinal((x, (y, (z, w))))
  def unsafe_from_ordinal(o) =
    (x, (y, (z, w))) = unsafe_from_ordinal o
    (x, y, z, w)
-- Vector spaces over Float: types that can be scaled by a scalar with `.*`.
interface VSpace(a|Add|Sub)
  (.*) : (Float, a) -> a

-- Scaling with the scalar on the right.
def (*.)(v:a, s:Float) -> a given (a|VSpace) = s .* v

-- Division by a scalar, done by scaling with its reciprocal.
def (/)( v:a, s:Float) -> a given (a|VSpace) =
  recip = divide(1.0, s)
  recip .* v

-- Additive inverse, done by scaling by -1.
def neg( v:a) -> a given (a|VSpace) = (-1.0) .* v

instance VSpace(Float)
  def (.*)(s, v) = s * v

instance VSpace(n=>a) given (a|VSpace, n|Ix)
  def (.*)(s, xs) = for k. s .* xs[k]

instance VSpace((a, b)) given (a|VSpace, b|VSpace)
  def (.*)(s, pair) = (s .* pair.0, s .* pair.1)

instance VSpace((i:n) => (..i) => a) given (n|Ix, a|VSpace)
  def (.*)(s, xs) = for k. s .* xs[k]
instance VSpace((i:n) => (i..) => a) given (n|Ix, a|VSpace)
  def (.*)(s, xs) = for k. s .* xs[k]
instance VSpace((i:n) => (..<i) => a) given (n|Ix, a|VSpace)
  def (.*)(s, xs) = for k. s .* xs[k]
instance VSpace((i:n) => (i<..) => a) given (n|Ix, a|VSpace)
  def (.*)(s, xs) = for k. s .* xs[k]

instance VSpace(())
  def (.*)(_, _) = ()
-- Bool <-> Word8 through the data-constructor tag.
def b_to_w8(x:Bool) -> Word8 = %dataConTag(x)
def w8_to_b(x:Word8) -> Bool = %toEnum(Bool, x)

-- Boolean connectives, computed on the Word8 tags with the bitwise
-- primitives %and / %or / %not. All arguments are evaluated; there is no
-- short-circuiting.
def (&&)(x:Bool, y:Bool) -> Bool = w8_to_b(%and(b_to_w8 x, b_to_w8 y))
def (||)(x:Bool, y:Bool) -> Bool = w8_to_b(%or(b_to_w8 x, b_to_w8 y))
def not(x:Bool) -> Bool = w8_to_b(%not(b_to_w8 x))
More Boolean operations
TODO: move these with the others?
-- Returns `x` when `p` is True, otherwise `y`.
def select(p:Bool, x:a, y:a) -> a given (a) =
  if p
    then x
    else y

-- Bool -> number conversions via the Word8 tag.
def b_to_i(x:Bool) -> Int = x | b_to_w8 | w8_to_i
def b_to_n(x:Bool) -> Nat = x | b_to_w8 | w8_to_n
def b_to_f(x:Bool) -> Float = x | b_to_i | i_to_f
Ordering
TODO: move this down to be with `Ord`?
-- Ordering -> Word8 through the data-constructor tag (Eq(Ordering) compares these tags).
def o_to_w8(x:Ordering) -> Word8 = %dataConTag(x)
Sum types
A sum type, or tagged union can hold values from a fixed set of types, distinguished by tags.
For those familiar with the C language, they can be thought of as a combination of an `enum` with a `union`.
Here we define several basic sum types, and some operations on them.
-- An optional value: either `Nothing` or `Just` a value.
data Maybe(a:Type) =
  Nothing
  Just(a)

def is_nothing(x:Maybe a) -> Bool given (a) =
  case x of
    Just(_) -> False
    Nothing -> True

def is_just(x:Maybe a) -> Bool given (a) =
  case x of
    Just(_) -> True
    Nothing -> False

-- Eliminator for Maybe: `f` applied to the payload, or the default `d`.
def maybe(d:b, f:(a)->b, x:Maybe a) -> b given (a, b) =
  case x of
    Just(v) -> f v
    Nothing -> d

-- A value of one of two types.
data Either(a:Type, b:Type) =
  Left(a)
  Right(b)

-- All `Left`s come first (ordinals 0 .. size a - 1), followed by the `Right`s.
instance Ix(Either(a, b)) given (a|Ix, b|Ix)
  def size'() = size a + size b
  def ordinal(i) = case i of
    Left(l) -> ordinal l
    Right(r) -> size a + ordinal r
  def unsafe_from_ordinal(o) =
    na = nat_to_rep $ size a
    orep = nat_to_rep o
    if w8_to_b $ %ilt(orep, na)
      then Left $ unsafe_from_ordinal(n=a, o)
      else Right $ unsafe_from_ordinal(n=b, rep_to_nat (%isub(orep, na)))
Subtraction on Nats
-- TODO: think more about the right API here

-- Reinterprets an Int as a Nat without checking its sign.
def unsafe_i_to_n(x:Int) -> Nat =
  rep_to_nat(internal_cast(x))

def n_to_i(x:Nat) -> Int =
  internal_cast(nat_to_rep(x))

-- Checked conversion: `Nothing` for negative inputs.
def i_to_n(x:Int) -> Maybe Nat =
  case w8_to_b $ %ilt(x, 0::Int) of
    True -> Nothing
    False -> Just $ unsafe_i_to_n x
Monoid
A monoid is a type with an associative binary operator and an identity element.
This is a very useful and general class of things.
It includes:
- Addition and Multiplication of Numbers
- Boolean Logic
- Concatenation of Lists (including strings)
Monoids support `fold` operations, and similar.
-- An associative operator `<>` with identity `mempty`.
interface Monoid(a|Data)
  mempty : a
  (<>) : (a, a) -> a

-- Tables combine element-wise.
instance Monoid(n=>a) given (a|Monoid, n|Ix)
  mempty = for _. mempty
  def (<>)(x, y) = for k. x[k] <> y[k]

-- Bool has two monoids, so both are named rather than default instances.
named-instance AndMonoid : Monoid(Bool)
  mempty = True
  def (<>)(p, q) = p && q

named-instance OrMonoid : Monoid(Bool)
  mempty = False
  def (<>)(p, q) = p || q

-- Monoids derived from Add and Mul.
named-instance AddMonoid(a|Add) -> Monoid(a)
  mempty = zero
  def (<>)(u, v) = u + v

named-instance MulMonoid(a|Mul) -> Monoid(a)
  mempty = one
  def (<>)(u, v) = u * v
-- Mutable reference to an `a` in heap `r` (the %Ref primitive type).
def Ref(r:Heap, a|Data) -> Type = %Ref(r, a)
-- Reads a reference under the State effect.
def get(ref:Ref h s) -> {State h} s given (h, s) = %get(ref)
-- Overwrites a reference under the State effect.
def (:=)(ref:Ref h s, x:s) -> {State h} () given (h, s) = %put(ref, x)
-- Reads a reference under the Read effect.
def ask(ref:Ref h r) -> {Read h} r given (h, r) = %ask(ref)

-- Packs a base type `b` together with a Monoid instance for it. Accumulation
-- into a `w` reference in heap `h` uses this base monoid.
-- NOTE(review): how `b` relates to `w` is fixed by the instances below and by
-- `run_accum`; confirm before relying on it.
data AccumMonoidData(h:Heap, w:Type) = UnsafeMkAccumMonoidData(b:Type, Monoid b)

-- Provides the AccumMonoidData used by `+=` on `w` references in heap `h`.
interface AccumMonoid(h:Heap, w)
  getAccumMonoidData : AccumMonoidData(h, w)

-- A table of `w`s accumulates with the same base monoid as a single `w`:
-- the element's data is unpacked and repacked unchanged.
instance AccumMonoid(h, n=>w) given (n|Ix, h, w) (am:AccumMonoid(h, w))
  getAccumMonoidData =
    UnsafeMkAccumMonoidData(b, bm) = %applyMethod0(am)
    UnsafeMkAccumMonoidData(b, bm)

-- Accumulates `x` into `ref` using the base monoid carried by `am`.
-- %applyMethod0/%applyMethod1 presumably select a dictionary's first/second
-- method. For `bm : Monoid b` those are `mempty` and `<>` (by declaration
-- order in Monoid) — confirm against the compiler.
def (+=)(ref:Ref h w, x:w) -> {Accum h} ()
    given (h, w) (am:AccumMonoid(h, w)) =
  UnsafeMkAccumMonoidData(b, bm) = %applyMethod0(am)
  empty = %applyMethod0(bm)
  %mextend(ref, empty, \x:b y:b. %applyMethod1(bm, x, y), x)

-- Reference to element `i` of a table-valued reference.
def (!)(ref: Ref h (n=>a), i:n) -> Ref h a given (n|Ix, a|Data, h) = %indexRef(ref, i)
-- References to the first / second component of a pair-valued reference.
def fst_ref(ref: Ref h (a,b)) -> Ref h a given (b, a|Data, h) = ref.0
def snd_ref(ref: Ref h (a,b)) -> Ref h b given (a, b|Data, h) = ref.1
-- Runs `action` with a read-only reference initialised to `init` and
-- returns the action's result.
def run_reader(
    init:r,
    action:(given (h), Ref h r) -> {Read h|eff} a
    ) -> {|eff} a given (r|Data, a, eff) =
  -- %runReader supplies the heap explicitly; forward it to `action` implicitly.
  def with_heap(h':Heap, ref:Ref h' r) -> {Read h'|eff} a = action ref
  %runReader(init, with_heap)

-- Alias of run_reader.
def with_reader(
    init:r,
    action: (given (h), Ref(h,r)) -> {Read h|eff} a
    ) -> {|eff} a given (r|Data, a, eff) =
  run_reader(init, action)
-- Evidence that, for every heap `h`, an AccumMonoid on base type `b` can be
-- lifted to one on `w`.
def MonoidLifter(b:Type, w:Type) -> Type =
  (given (h) (AccumMonoid(h, b))) ->> AccumMonoid(h, w)

-- Wraps explicit AccumMonoidData as an AccumMonoid instance.
named-instance mk_accum_monoid (given (h, w), d:AccumMonoidData(h, w)) -> AccumMonoid(h, w)
  getAccumMonoidData = d

-- Runs `action` with an accumulator reference of type `w`, combined using
-- the monoid `bm` on base type `b`. Returns (action result, accumulated value).
def run_accum(
    bm:Monoid b,
    action: (given (h) (AccumMonoid(h, b)), Ref h w) -> {Accum h|eff} a
    ) -> {|eff} (a, w) given (a, b, w|Data, eff) (MonoidLifter(b,w)) =
  -- Presumably `mempty` of bm (Monoid's first method) — confirm.
  empty = %applyMethod0(bm)
  -- Builds an AccumMonoid(h', b) from `bm`, then passes heap, instance and ref
  -- to `action` explicitly.
  def explicitAction(h':Heap, ref:Ref h' w) -> {Accum h'|eff} a =
    accumMonoidData : AccumMonoidData h' b = UnsafeMkAccumMonoidData b bm
    accumBaseMonoid = mk_accum_monoid accumMonoidData
    %explicitApply(action, h', accumBaseMonoid, ref)
  %runWriter(empty, \x:b y:b. %applyMethod1(bm, x, y), explicitAction)

-- Like run_accum, but returns only the accumulated value.
def yield_accum(
    m:Monoid b,
    action: (given (h) (AccumMonoid(h, b)), Ref h w) -> {Accum h|eff} a
    ) -> {|eff} w given (a, b, w|Data, eff) (MonoidLifter b w) =
  snd $ run_accum(m, action)
-- Runs `action` with a mutable reference initialised to `init`.
-- Returns (action result, final state).
def run_state(
    init:s,
    action: (given (h), Ref h s) -> {State h |eff} a
    ) -> {|eff} (a,s) given (a, s|Data, eff) =
  -- %runState supplies the heap explicitly; forward it to `action` implicitly.
  def with_heap(h':Heap, ref:Ref h' s) -> {State h'|eff} a = action ref
  %runState(init, with_heap)

-- Like run_state, keeping only the action's result.
def with_state(
    init:s,
    action: (given (h), Ref h s) -> {State h |eff} a
    ) -> {|eff} a given (a, s|Data, eff) =
  fst(run_state(init, action))

-- Like run_state, keeping only the final state.
def yield_state(
    init:s,
    action: (given (h), Ref h s) -> {State h |eff} a
    ) -> {|eff} s given (a, s|Data, eff) =
  snd(run_state(init, action))
-- Discharges the IO effect of `f` with %runIO. Unsafe: the IO is hidden from
-- the caller's effect row.
def unsafe_io(
    f:()->{IO|eff} a
    ) -> {|eff} a given (a, eff) =
  thunk : (() -> {IO|eff} a) = \. f()
  %runIO(thunk)

-- Signals an error via %throwError inside unsafe_io; nominally returns an `a`.
def unreachable() -> a given (a|Data) =
  unsafe_io \. %throwError(a)
Eq
Equatable.
Things that we can tell if they are equal or not to other things.
-- Equality.
interface Eq(a|Data)
  (==) : (a, a) -> Bool

-- Inequality: the negation of `==`.
def (/=)(x:a, y:a) -> Bool given (a|Eq) = not(x == y)
Ord
Orderable / Comparable.
Things that can be placed in a total order.
i.e. things that can be compared to other things to find if larger, smaller or equal in value.
We take the standard falsehood and pretend that this applies to floats, even though strictly speaking this is not true: our floats follow IEEE 754, and thus have `NaN < 1.0 == false` and `1.0 < NaN == false`.
-- Strict comparisons; the non-strict ones are derived below.
interface Ord(a|Eq)
  (>) : (a, a) -> Bool
  (<) : (a, a) -> Bool

-- `||` is an ordinary function, so both sides are always evaluated.
def (<=)(x:a, y:a) -> Bool given (a|Ord) = x == y || x < y
def (>=)(x:a, y:a) -> Bool given (a|Ord) = x == y || x > y
-- Equality on primitive types via %feq / %ieq.
instance Eq(Float64)
  def (==)(u, v) = w8_to_b(%feq(u, v))
instance Eq(Float32)
  def (==)(u, v) = w8_to_b(%feq(u, v))
instance Eq(Int64)
  def (==)(u, v) = w8_to_b(%ieq(u, v))
instance Eq(Int32)
  def (==)(u, v) = w8_to_b(%ieq(u, v))
instance Eq(Word8)
  def (==)(u, v) = w8_to_b(%ieq(u, v))
instance Eq(Word32)
  def (==)(u, v) = w8_to_b(%ieq(u, v))
instance Eq(Word64)
  def (==)(u, v) = w8_to_b(%ieq(u, v))

-- Bools are compared by their Word8 tags.
instance Eq(Bool)
  def (==)(p, q) = b_to_w8 p == b_to_w8 q

instance Eq(())
  def (==)(_, _) = True

-- Equal only when both use the same constructor and the payloads are equal.
instance Eq(Either(a, b)) given (a|Eq, b|Eq)
  def (==)(l, r) = case l of
    Left(lv) -> case r of
      Left(rv) -> lv == rv
      Right(_) -> False
    Right(lv) -> case r of
      Left(_) -> False
      Right(rv) -> lv == rv

instance Eq(Maybe a) given (a|Eq)
  def (==)(l, r) = case l of
    Just(lv) -> case r of
      Just(rv) -> lv == rv
      Nothing -> False
    Nothing -> case r of
      Just(_) -> False
      Nothing -> True

-- Pointers are compared by address.
instance Eq(RawPtr)
  def (==)(p, q) = raw_ptr_to_i64 p == raw_ptr_to_i64 q
-- Ordering on primitive types via %fgt/%flt and %igt/%ilt.
instance Ord(Float64)
  def (>)(u, v) = w8_to_b(%fgt(u, v))
  def (<)(u, v) = w8_to_b(%flt(u, v))
instance Ord(Float32)
  def (>)(u, v) = w8_to_b(%fgt(u, v))
  def (<)(u, v) = w8_to_b(%flt(u, v))
instance Ord(Int64)
  def (>)(u, v) = w8_to_b(%igt(u, v))
  def (<)(u, v) = w8_to_b(%ilt(u, v))
instance Ord(Int32)
  def (>)(u, v) = w8_to_b(%igt(u, v))
  def (<)(u, v) = w8_to_b(%ilt(u, v))
instance Ord(Word8)
  def (>)(u, v) = w8_to_b(%igt(u, v))
  def (<)(u, v) = w8_to_b(%ilt(u, v))
instance Ord(Word32)
  def (>)(u, v) = w8_to_b(%igt(u, v))
  def (<)(u, v) = w8_to_b(%ilt(u, v))
instance Ord(Word64)
  def (>)(u, v) = w8_to_b(%igt(u, v))
  def (<)(u, v) = w8_to_b(%ilt(u, v))

-- The single unit value is never strictly less or greater than itself.
instance Ord(())
  def (>)(_, _) = False
  def (<)(_, _) = False
-- Pairs: component-wise equality and lexicographic order.
instance Eq((a, b)) given (a|Eq, b|Eq)
  def (==)(p1, p2) = p1.0 == p2.0 && p1.1 == p2.1

instance Ord((a, b)) given (a|Ord, b|Ord)
  def (>)(p1, p2) = p1.0 > p2.0 || (p1.0 == p2.0 && p1.1 > p2.1)
  def (<)(p1, p2) = p1.0 < p2.0 || (p1.0 == p2.0 && p1.1 < p2.1)

-- Orderings are compared by constructor tag.
instance Eq(Ordering)
  def (==)(l, r) = o_to_w8 l == o_to_w8 r

-- Nat compares through its NatRep representation.
instance Eq(Nat)
  def (==)(l, r) = nat_to_rep l == nat_to_rep r
instance Ord(Nat)
  def (>)(l, r) = nat_to_rep l > nat_to_rep r
  def (<)(l, r) = nat_to_rep l < nat_to_rep r

-- Fin elements compare by ordinal.
instance Eq(Fin n) given (n)
  def (==)(l, r) = ordinal l == ordinal r
instance Ord(Fin n) given (n)
  def (>)(l, r) = ordinal l > ordinal r
  def (<)(l, r) = ordinal l < ordinal r
-- False has ordinal 0 and True has ordinal 1.
instance Ix(Bool)
  def size'() = 2
  def ordinal(b) = case b of
    True -> 1
    False -> 0
  def unsafe_from_ordinal(k) = k > 0

-- The `Just`s come first (ordinals 0 .. size a - 1); `Nothing` is last.
instance Ix(Maybe a) given (a|Ix)
  def size'() = size a + 1
  def ordinal(m) = case m of
    Nothing -> size a
    Just(v) -> ordinal v
  def unsafe_from_ordinal(o) =
    if o == size a
      then Nothing
      else Just $ unsafe_from_ordinal o
-- Index sets with at least one element; `first_ix` is the element with ordinal 0.
interface NonEmpty(n|Ix)
  first_ix : n

-- For () and Bool, ordinal 0 is `()` and `False` respectively (see their Ix instances).
instance NonEmpty(())
  first_ix = ()
instance NonEmpty(Bool)
  first_ix = False
instance NonEmpty((a,b)) given (a|NonEmpty, b|NonEmpty)
  first_ix = unsafe_from_ordinal 0
instance NonEmpty(Either(a,b)) given (a|NonEmpty, b|Ix)
  first_ix = unsafe_from_ordinal 0
-- Maybe always contains Nothing, so it is non-empty for any index set `a`.
instance NonEmpty(Maybe a) given (a|Ix)
  first_ix = unsafe_from_ordinal 0
-- Fence posts around the elements of `segment`: one more post than elements.
struct Post(segment:Type) =
  val : Nat

instance Ix(Post segment) given (segment|Ix)
  def size'() = size segment + 1
  def ordinal(p) = p.val
  def unsafe_from_ordinal(k) = Post(k)

-- The posts immediately before / after element `i`.
def left_post(i:n) -> Post n given (n|Ix) =
  unsafe_from_ordinal(n=Post n, ordinal i)
def right_post(i:n) -> Post n given (n|Ix) =
  unsafe_from_ordinal(n=Post n, ordinal i + 1)

-- The element to the left of post `p`, or Nothing for the first post.
def left_fence(p:Post n) -> Maybe n given (n|Ix) =
  k = ordinal p
  case k == 0 of
    True -> Nothing
    False -> Just(unsafe_from_ordinal(k -| 1))

-- The element to the right of post `p`, or Nothing for the last post.
def right_fence(p:Post n) -> Maybe n given (n|Ix) =
  k = ordinal p
  case k == size n of
    True -> Nothing
    False -> Just(unsafe_from_ordinal k)

-- The element with the largest ordinal in a non-empty index set.
def last_ix() ->> n given (n|NonEmpty) =
  last_ord = n_to_i (size n) - 1
  unsafe_from_ordinal(unsafe_i_to_n last_ord)

instance NonEmpty(Post n) given (n|Ix)
  first_ix = unsafe_from_ordinal(n=Post n, 0)
-- Threads a carry through `n` in index order. At each index, `body` maps
-- (index, carry) to (new carry, output). Returns (final carry, outputs).
def scan(
    init:a,
    body:(n, a)->(a,b)
    ) -> (a, n=>b) given (a|Data, b, n|Ix) =
  (outs, final) = run_state(init) \st. for i.
    (next, out) = body(i, get st)
    st := next
    out
  (final, outs)

-- Left fold over `n` in index order.
def fold(init:a, body:(n,a)->a) -> a given (n|Ix, a|Data) =
  fst(scan(init, \i acc. (body(i, acc), ())))
-- Three-way comparison built from `<` and `==`.
def compare(x:a, y:a) -> Ordering given (a|Ord) =
  case x < y of
    True -> LT
    False -> case x == y of
      True -> EQ
      False -> GT

-- Lexicographic combination: the first non-EQ result wins.
instance Monoid(Ordering)
  mempty = EQ
  def (<>)(x, y) =
    case x of
      EQ -> y
      LT -> LT
      GT -> GT

-- Tables are equal when every element is equal.
instance Eq(n=>a) given (n|Ix, a|Eq)
  def (==)(xs, ys) =
    yield_accum AndMonoid \acc.
      for k. acc += (xs[k] == ys[k])

-- Tables are ordered lexicographically, in index order.
instance Ord(n=>a) given (n|Ix, a|Ord)
  def (>)(xs, ys) =
    verdict: Ordering = fold EQ \k acc. acc <> compare(xs[k], ys[k])
    verdict == GT
  def (<)(xs, ys) =
    verdict: Ordering = fold EQ \k acc. acc <> compare(xs[k], ys[k])
    verdict == LT
-- `subset` embeds into `superset`: `inject'` embeds, `project'` is the partial
-- inverse, and `unsafe_project'` assumes membership without checking it.
interface Subset(subset, superset)
  inject' : (subset) -> superset
  project' : (superset) -> Maybe subset
  unsafe_project' : (superset) -> subset

-- Call-site wrappers whose type parameters are named `to` and `from`.
def inject(x:from) -> to given (to, from) (Subset(from, to)) = inject'(x)
def project(x:from) -> Maybe to given (to, from) (Subset(to, from)) = project'(x)
def unsafe_project(x:from) -> to given (to, from) (Subset(to, from)) = unsafe_project'(x)

-- Transitivity: a subset of a subset is a subset, going through `b`.
instance Subset(a, c) given (a, b, c) (Subset(a, b), Subset(b, c))
  def inject'(x) = inject(inject(to=b, x))
  def project'(x) = case project(to=b, x) of
    Just(mid) -> project mid
    Nothing -> Nothing
  def unsafe_project'(x) = unsafe_project(unsafe_project(to=b, x))
-- Projects `j` into RangeFrom(q, i) without checking that j is at or after i.
def unsafe_project_rangefrom(j:q) -> RangeFrom(q, i) given (q|Ix, i:q) =
  RangeFrom(unsafe_nat_diff(ordinal j, ordinal i))

-- Each range stores its position relative to its first element. Injection
-- adds the range's starting ordinal back; projection checks membership first.
instance Subset(RangeFrom(q, i), q) given (q|Ix, i:q)
  def inject'(j) = unsafe_from_ordinal(j.val + ordinal i)
  def project'(j) =
    jo = ordinal j
    io = ordinal i
    case jo < io of
      True -> Nothing
      False -> Just(RangeFrom(unsafe_nat_diff(jo, io)))
  def unsafe_project'(j) = RangeFrom(unsafe_nat_diff(ordinal j, ordinal i))

instance Subset(RangeFromExc(q, i), q) given (q|Ix, i:q)
  def inject'(j) = unsafe_from_ordinal(j.val + ordinal i + 1)
  def project'(j) =
    jo = ordinal j
    io = ordinal i
    case jo <= io of
      True -> Nothing
      False -> Just(RangeFromExc(unsafe_nat_diff(jo, io + 1)))
  def unsafe_project'(j) =
    RangeFromExc(unsafe_nat_diff(ordinal j, ordinal i + 1))

-- Prefix ranges start at ordinal 0, so no offset is needed.
instance Subset(RangeTo(q, i), q) given (q|Ix, i:q)
  def inject'(j) = unsafe_from_ordinal(j.val)
  def project'(j) =
    jo = ordinal j
    io = ordinal i
    case jo > io of
      True -> Nothing
      False -> Just(RangeTo(jo))
  def unsafe_project'(j) = RangeTo(ordinal j)

instance Subset(RangeToExc(q, i), q) given (q|Ix, i:q)
  def inject'(j) = unsafe_from_ordinal(j.val)
  def project'(j) =
    jo = ordinal j
    io = ordinal i
    case jo >= io of
      True -> Nothing
      False -> Just(RangeToExc(jo))
  def unsafe_project'(j) = RangeToExc(ordinal j)

-- The exclusive prefix is contained in the inclusive one.
instance Subset(RangeToExc(q, i), RangeTo(q, i)) given (q|Ix, i:q)
  def inject'(j) = unsafe_from_ordinal(j.val)
  def project'(j) =
    jo = ordinal j
    io = ordinal i
    case jo >= io of
      True -> Nothing
      False -> Just(RangeToExc(jo))
  def unsafe_project'(j) = RangeToExc(ordinal j)
Elementary/Special Functions
This is more or less the standard LibM fare.
Roughly it lines up with some definitions of the set of elementary and/or special functions.
In truth, nothing is elementary or special except that we humans have decided it is.
Many, but not all of these functions are Transcendental.
-- Elementary and special functions, in the style of libm.
interface Floating(a:Type)
  exp : (a) -> a
  exp2 : (a) -> a
  log : (a) -> a
  log2 : (a) -> a
  log10 : (a) -> a
  log1p : (a) -> a
  sin : (a) -> a
  cos : (a) -> a
  tan : (a) -> a
  sinh : (a) -> a
  cosh : (a) -> a
  tanh : (a) -> a
  floor : (a) -> a
  ceil : (a) -> a
  round : (a) -> a
  sqrt : (a) -> a
  pow : (a, a) -> a
  lgamma : (a) -> a
  erf : (a) -> a
  erfc : (a) -> a

-- Log of the Beta function: lgamma x + lgamma y - lgamma (x + y).
def lbeta(x:a, y:a) -> a given (a|Sub|Floating) =
  numer = lgamma x + lgamma y
  numer - lgamma (x + y)
-- Hyperbolic functions built from %exp, used by the Floating instances.
-- sinh and cosh use the textbook (e^x -/+ e^-x) / 2.
-- tanh avoids the textbook (e^x - e^-x) / (e^x + e^-x). Once exp(|x|)
-- overflows (|x| above roughly 88.7 for Float32, 709.8 for Float64), that
-- form evaluates inf/inf = NaN instead of +/-1. Instead it uses
-- t = exp(-2|x|), which lies in (0, 1] and cannot overflow:
--   tanh(x) = sign(x) * (1 - t) / (1 + t)
def float32_sinh(x:Float32) -> Float32 =
  %fdiv(%fsub(%exp(x), %exp(%fsub(0.0, x))), 2.0)
def float32_cosh(x:Float32) -> Float32 =
  %fdiv(%fadd(%exp(x), %exp(%fsub(0.0, x))), 2.0)
def float32_tanh(x:Float32) -> Float32 =
  case w8_to_b $ %flt(x, 0.0) of
    True ->
      -- x < 0: t = exp(2x) = exp(-2|x|), and tanh(x) = (t - 1) / (t + 1).
      t = %exp(%fmul(2.0, x))
      %fdiv(%fsub(t, 1.0), %fadd(t, 1.0))
    False ->
      -- x >= 0 (or NaN, which propagates): t = exp(-2x).
      t = %exp(%fsub(0.0, %fmul(2.0, x)))
      %fdiv(%fsub(1.0, t), %fadd(1.0, t))

def float64_sinh(x:Float64) -> Float64 =
  %fdiv(%fsub(%exp(x), %exp(%fsub(f_to_f64 0.0, x))), f_to_f64 2.0)
def float64_cosh(x:Float64) -> Float64 =
  %fdiv(%fadd(%exp(x), %exp(%fsub(f_to_f64 0.0, x))), f_to_f64 2.0)
def float64_tanh(x:Float64) -> Float64 =
  case w8_to_b $ %flt(x, f_to_f64 0.0) of
    True ->
      -- x < 0: t = exp(2x) = exp(-2|x|), and tanh(x) = (t - 1) / (t + 1).
      t = %exp(%fmul(f_to_f64 2.0, x))
      %fdiv(%fsub(t, f_to_f64 1.0), %fadd(t, f_to_f64 1.0))
    False ->
      -- x >= 0 (or NaN, which propagates): t = exp(-2x).
      t = %exp(%fsub(f_to_f64 0.0, %fmul(f_to_f64 2.0, x)))
      %fdiv(%fsub(f_to_f64 1.0, t), %fadd(f_to_f64 1.0, t))
-- Float64 and Float32 map directly onto the compiler's math primitives.
-- The hyperbolic functions use the float64_* / float32_* helpers.
instance Floating(Float64)
  def exp(t) = %exp(t)
  def exp2(t) = %exp2(t)
  def log(t) = %log(t)
  def log2(t) = %log2(t)
  def log10(t) = %log10(t)
  def log1p(t) = %log1p(t)
  def sin(t) = %sin(t)
  def cos(t) = %cos(t)
  def tan(t) = %tan(t)
  def sinh(t) = float64_sinh t
  def cosh(t) = float64_cosh t
  def tanh(t) = float64_tanh t
  def floor(t) = %floor(t)
  def ceil(t) = %ceil(t)
  def round(t) = %round(t)
  def sqrt(t) = %sqrt(t)
  def pow(base, e) = %fpow(base, e)
  def lgamma(t) = %lgamma(t)
  def erf(t) = %erf(t)
  def erfc(t) = %erfc(t)

instance Floating(Float32)
  def exp(t) = %exp(t)
  def exp2(t) = %exp2(t)
  def log(t) = %log(t)
  def log2(t) = %log2(t)
  def log10(t) = %log10(t)
  def log1p(t) = %log1p(t)
  def sin(t) = %sin(t)
  def cos(t) = %cos(t)
  def tan(t) = %tan(t)
  def sinh(t) = float32_sinh t
  def cosh(t) = float32_cosh t
  def tanh(t) = float32_tanh t
  def floor(t) = %floor(t)
  def ceil(t) = %ceil(t)
  def round(t) = %round(t)
  def sqrt(t) = %sqrt(t)
  def pow(base, e) = %fpow(base, e)
  def lgamma(t) = %lgamma(t)
  def erf(t) = %erf(t)
  def erfc(t) = %erfc(t)
-- Typed pointer: a RawPtr tagged with the element type `a`.
struct Ptr(a:Type) =
  val : RawPtr

-- Reinterprets a pointer at another element type; the raw pointer is reused unchanged.
def cast_ptr(ptr: Ptr a) -> Ptr b given (a, b) = Ptr(ptr.val)

-- Types that can be written to and read from raw memory through a Ptr.
-- `storage_size` is the per-element size used for allocation and pointer
-- offsets (instances use 1 for Word8, 4 for 32-bit types, 8 for Ptr —
-- presumably bytes).
interface Storable(a|Data)
  store : (Ptr a, a) -> {IO} ()
  load : (Ptr a) -> {IO} a
  storage_size : () -> Nat
-- Storable instances. Each load/store casts the untyped RawPtr in `ptr.val`
-- to the element-specific pointer primitive before accessing memory.
instance Storable(Word8)
  -- RawPtr is already a Word8 pointer, so no cast is needed.
  def store(ptr, x) = %ptrStore(ptr.val, x)
  def load(ptr) = %ptrLoad(ptr.val)
  def storage_size() = 1

instance Storable(Int32)
  def store(ptr, x) = %ptrStore(internal_cast(to=%Int32Ptr(), ptr.val), x)
  def load(ptr) = %ptrLoad(internal_cast(to=%Int32Ptr(), ptr.val))
  def storage_size() = 4

instance Storable(Word32)
  -- Cast the raw pointer `ptr.val`, not the `Ptr` wrapper struct, as every
  -- other instance does. (Previously the struct itself was cast.)
  def store(ptr, x) = %ptrStore(internal_cast(to=%Word32Ptr(), ptr.val), x)
  def load(ptr) = %ptrLoad(internal_cast(to=%Word32Ptr(), ptr.val))
  def storage_size() = 4

instance Storable(Float32)
  def store(ptr, x) = %ptrStore(internal_cast(to=%Float32Ptr(), ptr.val), x)
  def load(ptr) = %ptrLoad(internal_cast(to=%Float32Ptr(), ptr.val))
  def storage_size() = 4

-- Nat is stored as its NatRep representation.
instance Storable(Nat)
  def store(ptr, x) = store(Ptr(ptr.val), nat_to_rep x)
  def load(ptr) = rep_to_nat $ load(Ptr(ptr.val))
  def storage_size() = storage_size(a=NatRep)

-- Pointers are stored as their RawPtr. `load` now casts `ptr.val`, matching
-- `store`; previously it cast the wrapper struct.
instance Storable(Ptr a) given (a)
  def store(ptr, x) = %ptrStore(internal_cast(to=%PtrPtr(), ptr.val), x.val)
  def load(ptr) = Ptr(%ptrLoad(internal_cast(to=%PtrPtr(), ptr.val)))
  def storage_size() = 8
-- Allocates room for `n` elements of type `a`.
def malloc(n:Nat) -> {IO} (Ptr a) given (a|Storable) =
  Ptr(%alloc(nat_to_rep(storage_size(a=a) * n)))

-- Releases memory obtained from `malloc`.
def free(ptr:Ptr a) -> {IO} () given (a) = %free(ptr.val)

-- Pointer arithmetic: advances `ptr` by `i` elements of type `a`.
def (+>>)(ptr:Ptr a, i:Nat) -> Ptr a given (a|Storable) =
  offset = nat_to_rep(i * storage_size(a=a))
  Ptr(%ptrOffset(ptr.val, offset))

-- Writes `tab[k]` to element `ordinal k` of `ptr`, for every index `k`.
def store_table(ptr: Ptr a, tab:n=>a) -> {IO} () given (a|Storable, n|Ix) =
  for_ k. store(ptr +>> ordinal k, tab[k])

-- Copies `n` elements from `src` to `dest`, one element at a time in index order.
def memcpy(dest:Ptr a, src:Ptr a, n:Nat) -> {IO} () given (a|Storable) =
  for_ k:(Fin n).
    off = ordinal k
    store(dest +>> off, load(src +>> off))

-- Allocates `n` elements, runs `action` on the buffer, frees the buffer, and
-- returns the action's result.
def with_alloc(
    n:Nat,
    action: (Ptr a) -> {IO} b
    ) -> {IO} b given (a|Storable, b) =
  buf = malloc n
  out = action buf
  free buf
  out

-- Copies `xs` into a temporary buffer and runs `action` on it.
def with_table_ptr(
    xs:n=>a,
    action: (Ptr a) -> {IO} b
    ) -> {IO} b given (a|Storable, b, n|Ix) =
  with_alloc(size n) \buf.
    store_table(buf, xs)
    action buf

-- Reads `size n` consecutive elements starting at `ptr` into a table.
def table_from_ptr(ptr:Ptr a) -> {IO} n=>a given (a|Storable, n|Ix) =
  for k. load(ptr +>> ordinal k)
Miscellaneous common utilities
pi : Float = 3.141592653589793

def id(x:a) -> a given (a) = x
def dup(x:a) -> (a, a) given (a) = (x, x)

-- Applies `f` to every element (argument order: function first).
def map(f:(a)->{|eff} b, xs: n=>a) -> {|eff} (n=>b) given (a, b, n|Ix, eff) =
  for k. f(xs[k])

-- Same as `map`, but with the table first.
def each(xs: n=>a, f:(a)->{|eff} b) -> {|eff} (n=>b) given (a, b, n|Ix, eff) =
  for k. f(xs[k])

def zip(xs:n=>a, ys:n=>b) -> (n=>(a,b)) given (a, b, n|Ix) = for k. (xs[k], ys[k])
def unzip(xys:n=>(a,b)) -> (n=>a , n=>b) given (a, b, n|Ix) = (map(fst, xys), map(snd, xys))

-- Constant table.
def fanout(x:a) -> n=>a given (n|Ix, a) = for _. x

def sq(x:a) -> a given (a|Mul) = x * x

-- Absolute value: `x` if positive, otherwise `zero - x`.
def abs(x:a) -> a given (a|Sub|Ord) = select(x > zero, x, zero - x)

-- Modulus whose sign follows `y`, unlike `rem`.
def mod(x:a, y:a) -> a given (a|Add|Integral) = rem(y + rem(x, y), y)

-- Left-to-right and right-to-left function composition.
def (>>>)(f:(a) -> b, g:(b) -> c) -> (a) -> c given (a, b, c) = \v. g(f(v))
def (<<<)(f:(b) -> c, g:(a) -> b) -> (a) -> c given (a, b, c) = \v. f(g(v))
-- Element-wise Floating on tables; `pow` pairs up elements of both tables.
instance Floating(n=>a) given (a|Floating, n|Ix)
  def exp(xs) = map(exp, xs)
  def exp2(xs) = map(exp2, xs)
  def log(xs) = map(log, xs)
  def log2(xs) = map(log2, xs)
  def log10(xs) = map(log10, xs)
  def log1p(xs) = map(log1p, xs)
  def sin(xs) = map(sin, xs)
  def cos(xs) = map(cos, xs)
  def tan(xs) = map(tan, xs)
  def sinh(xs) = map(sinh, xs)
  def cosh(xs) = map(cosh, xs)
  def tanh(xs) = map(tanh, xs)
  def floor(xs) = map(floor, xs)
  def ceil(xs) = map(ceil, xs)
  def round(xs) = map(round, xs)
  def sqrt(xs) = map(sqrt, xs)
  def pow(xs, ys) = for k. pow(xs[k], ys[k])
  def lgamma(xs) = map(lgamma, xs)
  def erf(xs) = map(erf, xs)
  def erfc(xs) = map(erfc, xs)
-- Folds `combine` over `xs` in index order, starting from `identity`.
def reduce(identity:a, combine:(a,a)->a, xs:n=>a) -> a given (a|Data, n|Ix) =
  fold(identity, \k acc. combine(acc, xs[k]))

-- Inclusive scan: the table of successive carries.
def scan'(init:a, body:(n,a)->a) -> n=>a given (a|Data, n|Ix) =
  snd(scan(init, \k acc. dup(body(k, acc))))

-- Float sum computed with an accumulator effect.
def fsum(xs:n=>Float) -> Float given (n|Ix) =
  yield_accum(AddMonoid Float) \acc. for k. acc += xs[k]

def sum(xs:n=>v) -> v given (n|Ix, v|Add) = reduce(zero, (+), xs)
def prod(xs:n=>v) -> v given (n|Ix, v|Mul) = reduce(one , (*), xs)

def mean(xs:n=>v) -> v given (n|Ix, v|VSpace) =
  total = sum xs
  total / n_to_f (size n)

-- Standard deviation: sqrt(E[x^2] - E[x]^2).
def std(xs:n=>v) -> v given (n|Ix, v|Mul|Sub|VSpace|Floating) =
  mean_sq = mean(map(sq, xs))
  sq_mean = sq(mean xs)
  sqrt(mean_sq - sq_mean)

def any(xs:n=>Bool) -> Bool given (n|Ix) = reduce(False, (||), xs)
def all(xs:n=>Bool) -> Bool given (n|Ix) = reduce(True , (&&), xs)

-- Applies `f` to `x` `n` times in a row.
def apply_n(n:Nat, x:a, f:(a) -> a) -> a given (a|Data) =
  yield_state x \cur. for _:(Fin n).
    cur := f(get cur)
cumulative sum
TODO: Move this to be with reductions?
It's a kind of scan.
-- Inclusive prefix sums: element k is xs[0] + ... + xs[k].
def cumsum(xs: n=>a) -> n=>a given (n|Ix, a|Add) =
  with_state zero \running. for k.
    updated = get running + xs[k]
    running := updated
    updated

-- Exclusive prefix sums: element k is xs[0] + ... + xs[k-1] (zero at k = 0).
def cumsum_low(xs: n=>a) -> n=>a given (n|Ix, a|Add) =
  with_state zero \running. for k.
    before = get running
    running := before + xs[k]
    before
Automatic differentiation
-- Returns f(x) together with the linear map (the derivative of f at x).
def linearize(f:(a)->b, x:a) -> (b, (a)->b) given (a, b) =
  %linearize(\v. f v, x)
-- Jacobian-vector product: the derivative of f at x applied to tangent t.
def jvp(f:(a)->b, x:a, t:a) -> b given (a, b) =
  df = snd $ linearize(f, x)
  df(t)
-- Transpose of a linear function.
def transpose_linear(f:(a)->b) -> (b)->a given (a, b) = \ct.
  %linearTranspose(\v. f v, ct)
-- Returns f(x) together with the transposed derivative (the pullback).
def vjp(f:(a)->b, x:a) -> (b, (b)->a) given (a, b) =
  (fx, linear) = linearize(f, x)
  (fx, transpose_linear linear)
-- Gradient of a scalar-valued function, via reverse mode.
def grad(f:(a)->Float, x:a) -> a given (a) =
  pullback = snd $ vjp(f, x)
  pullback(1.0)
-- Derivative of a scalar function, via forward mode.
def deriv(f:(Float)->Float, x:Float) -> Float = jvp(f, x, 1.0)
-- Derivative of a scalar function, via reverse mode.
def deriv_rev(f:(Float)->Float, x:Float) -> Float =
  pullback = snd $ vjp(f, x)
  pullback(1.0)
-- A tangent that may be known to be zero without materializing it.
data SymbolicTangent(a) =
  ZeroTangent
  SomeTangent(a)
-- Materializes a symbolic tangent. `ZeroTangent` becomes the vector-space zero.
def someTangent(x:SymbolicTangent a) -> a given (a|VSpace) =
  case x of
    SomeTangent(t) -> t
    ZeroTangent -> zero
Approximate Equality
TODO: move this outside the AD section to be with equality?
-- Approximate equality with explicit tolerances. The arguments are
-- (atol, rtol, x, y).
interface HasAllClose(a)
  allclose : (a, a, a, a) -> Bool
-- Default absolute and relative tolerances, used by `~~`.
interface HasDefaultTolerance(a)
  default_atol : a
  default_rtol : a
-- Approximate equality using the type's default tolerances.
def (~~)(x:a, y:a) -> Bool given (a|HasAllClose|HasDefaultTolerance) =
  allclose(default_atol, default_rtol, x, y)
-- |x - y| <= atol + rtol * |y|. The relative term scales with the second
-- argument, so the test is not symmetric.
instance HasAllClose(Float32)
  def allclose(abs_tol, rel_tol, actual, expected) =
    abs (actual - expected) <= (abs_tol + rel_tol * abs expected)
instance HasAllClose(Float64)
  def allclose(abs_tol, rel_tol, actual, expected) =
    abs (actual - expected) <= (abs_tol + rel_tol * abs expected)
instance HasDefaultTolerance(Float32)
  default_atol = f_to_f32 0.00001
  default_rtol = f_to_f32 0.0001
instance HasDefaultTolerance(Float64)
  default_atol = f_to_f64 0.00000001
  default_rtol = f_to_f64 0.00001
-- Componentwise approximate equality for pairs. The pair-valued tolerances
-- are split, and each component is compared with its own (atol, rtol).
-- With the default tolerances this coincides with `~~` on each component.
instance HasAllClose((a, b)) given ( a|HasDefaultTolerance|HasAllClose
                                   , b|HasDefaultTolerance|HasAllClose)
  def allclose(atol, rtol, pair1, pair2) =
    (atol1, atol2) = atol
    (rtol1, rtol2) = rtol
    (x1, x2) = pair1
    (y1, y2) = pair2
    allclose(atol1, rtol1, x1, y1) && allclose(atol2, rtol2, x2, y2)
-- Pair tolerances: the component types' defaults, side by side.
instance HasDefaultTolerance((a, b)) given (a|HasDefaultTolerance,b|HasDefaultTolerance)
  default_atol = (default_atol, default_atol)
  default_rtol = (default_rtol, default_rtol)
-- Tables are close when every entry is close under its own tolerances.
instance HasAllClose(n=>t) given (n|Ix, t|HasAllClose)
  def allclose(atol, rtol, xs, ys) =
    all for i:n. allclose(atol[i], rtol[i], xs[i], ys[i])
-- Table tolerances: the element default, repeated for every entry.
instance HasDefaultTolerance(n=>t) given (n|Ix, t|HasDefaultTolerance)
  default_atol = for i. default_atol
  default_rtol = for i. default_rtol
-- Compares the forward- and reverse-mode derivatives of `f` at `x` against
-- a central finite difference with step 0.01, using `~~`.
def check_deriv_base(f:(Float)->Float, x:Float) -> Bool =
  h = 0.01
  numeric = (f (x + h) - f (x - h)) / (2. * h)
  fwd_mode = deriv( f, x)
  rev_mode = deriv_rev( f, x)
  (fwd_mode ~~ numeric) && (rev_mode ~~ numeric)
-- Runs check_deriv_base on `f` and on its derivative, which checks second
-- derivatives as well.
def check_deriv(f:(Float)->Float, x:Float) -> Bool =
  df = \y. deriv(f, y)
  check_deriv_base(f, x) && check_deriv_base(df, x)
-- A table whose length is only known at runtime, packaged with that length.
data List(a)=
  AsList(n:Nat, elements:(Fin n => a))
-- Two lists are equal when they have the same length and equal elements.
instance Eq(List a) given (a|Eq)
  def (==)(xsList, ysList) =
    AsList(nx,xs) = xsList
    AsList(ny,ys) = ysList
    case nx == ny of
      False -> False
      True -> all for i:(Fin nx).
        xs[i] == ys[unsafe_from_ordinal (ordinal i)]
-- Reinterprets a table over `from` as a table over `to` by matching
-- ordinals. There is no size check, so callers must ensure that `to` is
-- no larger than `from`.
def unsafe_cast_table(xs:from=>a) -> to=>a given (to|Ix, from|Ix, a) =
  for i. xs[unsafe_from_ordinal (ordinal i)]
-- Packs a table into a List of the same length.
def to_list(xs:n=>a) -> List a given (n|Ix, a) =
  len = size n
  AsList(_, unsafe_cast_table(to=Fin len, xs))
-- Lists form a monoid under concatenation, with the empty list as identity.
instance Monoid(List a) given (a|Data)
  mempty = AsList(_, [])
  def (<>)(x, y) =
    AsList(nx,xs) = x
    AsList(ny,ys) = y
    total = nx + ny
    to_list for i:(Fin total).
      k = ordinal i
      -- The first nx positions come from x; the rest are offset into y.
      if k < nx
        then xs[unsafe_from_ordinal k]
        else ys[unsafe_from_ordinal $ unsafe_nat_diff(k, nx)]
-- The same concatenation monoid, exposed as a named instance.
named-instance ListMonoid (a|Data) -> Monoid(List a)
  mempty = mempty
  def (<>)(x, y) = x <> y
-- Appends a single element to a List accumulator.
def append(list: Ref(h, List a), x:a) -> {Accum h} ()
  given (a|Data, h) (AccumMonoid(h, List a)) =
  singleton = to_list [x]
  list += singleton
-- Copies the entries of `xs` between two posts (start inclusive, end
-- exclusive) into a List. It assumes start <= end and does not check.
def post_slice(xs:n=>a, start:Post n, end:Post n) -> List a given (n|Ix, a) =
  offset = ordinal start
  len = unsafe_nat_diff(ordinal end, offset)
  to_list for i:(Fin len).
    xs[unsafe_from_ordinal(n=n, offset + ordinal i)]
-- Strings are lists of characters.
String : Type = List Char
-- Builds a String by copying `n` characters starting at `ptr`.
def string_from_char_ptr(n:Word32, ptr:Ptr Char) -> {IO} String =
  len = rep_to_nat n
  AsList(len, table_from_ptr ptr)
-- Numeric code of a character.
def codepoint(c:Char) -> Int = w8_to_i c
-- Raw pointer to a NUL-terminated character buffer, as passed to C.
struct CString =
  ptr : RawPtr
-- Calls `action` with a NUL-terminated copy of `s`, valid only for the
-- duration of the call.
def with_c_string(
    s:String,
    action: (CString) -> {IO} a
    ) -> {IO} a given (a) =
  -- Append the terminator so the buffer can be handed to C as a char*.
  AsList(_, chars) = s <> "\NUL"
  with_table_ptr chars \buf. action CString(buf.val)
Show interface
For things that can be shown.
`show` gives a string representation of its input.
No particular promises are made as to exactly what that representation will contain.
In particular, it is not promised to be parseable.
Nor does it promise a particular level of precision for numeric values.
-- Types with a human-readable String rendering.
interface Show(a)
  show : (a) -> String
instance Show(String)
  def show(x) = x
-- Numeric instances format through C helpers. Each helper returns the
-- length of the text and a pointer to its characters.
foreign "showInt32" showInt32 : (Int32) -> {IO} (Word32, RawPtr)
instance Show(Int32)
  def show(x) = unsafe_io \.
    (len, buf) = showInt32 x
    string_from_char_ptr(len, Ptr buf)
foreign "showInt64" showInt64 : (Int64) -> {IO} (Word32, RawPtr)
instance Show(Int64)
  def show(x) = unsafe_io \.
    (len, buf) = showInt64 x
    string_from_char_ptr(len, Ptr buf)
-- Nat is shown by widening to Int64 first.
instance Show(Nat)
  def show(x) = show (n_to_i64 x)
foreign "showFloat32" showFloat32 : (Float32) -> {IO} (Word32, RawPtr)
instance Show(Float32)
  def show(x) = unsafe_io \.
    (len, buf) = showFloat32 x
    string_from_char_ptr(len, Ptr buf)
foreign "showFloat64" showFloat64 : (Float64) -> {IO} (Word32, RawPtr)
instance Show(Float64)
  def show(x) = unsafe_io \.
    (len, buf) = showFloat64 x
    string_from_char_ptr(len, Ptr buf)
-- Unit and tuples are shown as "(a, b, ...)".
instance Show(())
  def show(_) = "()"
instance Show((a, b)) given (a|Show, b|Show)
  def show(tup) =
    (first, second) = tup
    "(" <> show first <> ", " <> show second <> ")"
instance Show((a, b, c)) given (a|Show, b|Show, c|Show)
  def show(tup) =
    (first, second, third) = tup
    "(" <> show first <> ", " <> show second <> ", " <> show third <> ")"
instance Show((a, b, c, d)) given (a|Show, b|Show, c|Show, d|Show)
  def show(tup) =
    (first, second, third, fourth) = tup
    "(" <> show first <> ", " <> show second <> ", " <> show third <> ", " <> show fourth <> ")"
Parse interface
For types that can be parsed from a `String`.
-- Types that can be parsed from a String. `parseString` returns Nothing
-- when the input is not a valid representation.
interface Parse(a)
  parseString : (String) -> Maybe a
foreign "strtof" strtofFFI : (RawPtr, RawPtr) -> {IO} Float
-- Parses a Float with C's strtof and accepts the string only if strtof
-- consumed all of it.
-- The empty string is rejected explicitly. strtof consumes nothing from it,
-- so "0 consumed == length 0" would otherwise count as a full parse and
-- return Just 0.0.
instance Parse(Float)
  def parseString(str) = unsafe_io \.
    AsList(str_len, _) = str
    case str_len == 0 of
      True -> Nothing
      False ->
        with_c_string str \cStr.
          with_alloc 1 \end_ptr:(Ptr (Ptr Char)).
            -- strtof stores a pointer just past the last consumed character.
            result = strtofFFI(cStr.ptr, end_ptr.val)
            str_end_ptr = load end_ptr
            consumed = raw_ptr_to_i64 str_end_ptr.val - raw_ptr_to_i64 cStr.ptr
            if consumed == (n_to_i64 str_len) then Just result else Nothing
Floating-point helper functions
TODO: Move these to be with Elementary/Special functions. Or move those to be here.
-- Returns 1.0 for positive x and -1.0 for negative x. Any other x (zeros,
-- NaN) is returned unchanged.
def sign(x:Float) -> Float =
  if x > 0.0
    then 1.0
    else if x < 0.0
      then -1.0
      else x
-- NOTE: this is not IEEE copysign. It returns `a` as-is for b > 0 and its
-- negation for b < 0 (the magnitude of `a` is not taken), and 0.0 when b
-- is zero or NaN.
def copysign(a:Float, b:Float) -> Float =
  if b > 0.0
    then a
    else if b < 0.0
      then (-a)
      else 0.0
-- True for positive or negative infinity.
def isinf(x:Float) -> Bool = (x == infinity) || (x == -infinity)
-- NaN is the only value that fails an order comparison with itself.
def isnan(x:Float) -> Bool = (not (x >= x)) || (not (x <= x))
def either_is_nan(x:Float, y:Float) -> Bool = (isnan x) || (isnan y)
-- True when the pointer's address is zero.
def is_null_raw_ptr(ptr:RawPtr) -> Bool =
  raw_ptr_to_i64 ptr == 0
-- Wraps a non-null raw pointer as a typed pointer. Null gives Nothing.
def from_nullable_raw_ptr(ptr:RawPtr) -> Maybe (Ptr a) given (a) =
  case is_null_raw_ptr ptr of
    True -> Nothing
    False -> Just $ Ptr ptr
def c_string_ptr(s:CString) -> Maybe (Ptr Char) = from_nullable_raw_ptr s.ptr
-- Whether a stream was opened for reading or for writing.
data StreamMode =
  ReadMode
  WriteMode
-- A C stream pointer, tagged at the type level with its mode.
struct Stream(mode:StreamMode) =
  ptr : RawPtr
foreign "fopen" fopenFFI : (RawPtr, RawPtr) -> {IO} RawPtr
foreign "fclose" fcloseFFI : (RawPtr) -> {IO} Int64
foreign "fwrite" fwriteFFI : (RawPtr, Int64, Int64, RawPtr) -> {IO} Int64
foreign "fread" freadFFI : (RawPtr, Int64, Int64, RawPtr) -> {IO} Int64
foreign "fflush" fflushFFI : (RawPtr) -> {IO} Int64
-- Opens `path` with C fopen in mode "r" or "w". The stream's pointer is
-- null if the open failed, and callers are expected to check it.
def fopen(path:String, mode:StreamMode) -> {IO} (Stream mode) =
  c_mode_str = case mode of
    ReadMode -> "r"
    WriteMode -> "w"
  with_c_string path \c_path.
    with_c_string c_mode_str \c_mode.
      Stream $ fopenFFI(c_path.ptr, c_mode.ptr)
-- Closes the stream. fclose's return code is ignored.
def fclose(stream:Stream mode) -> {IO} () given (mode) =
  fcloseFFI stream.ptr
  ()
-- Writes all of `s` and then flushes. The return codes of fwrite and
-- fflush are ignored.
def fwrite(stream:Stream WriteMode, s:String) -> {IO} () =
  AsList(len, chars) = s
  with_table_ptr chars \buf.
    fwriteFFI(buf.val, i_to_i64 1, n_to_i64 len, stream.ptr)
  fflushFFI stream.ptr
  ()
Iteration
TODO: move this out of the file-system section
-- Runs `body` repeatedly until it returns False. `body` does the work and
-- also computes the loop condition, so it always runs at least once.
def while(body: () -> {|eff} Bool) -> {|eff} () given (eff) =
  -- The %while primitive takes a Word8 condition, so the Bool is converted.
  body' : () -> {|eff} Word8 = \. b_to_w8 $ body()
  %while(body')
-- Result of one step of `iter`: keep going, or stop with a value.
data IterResult(a|Data) =
  Continue
  Done(a)
-- Runs `f x` under an additional State effect on heap `h`. `ref` is not
-- read or written; it only names the heap at the type level.
def lift_state(ref: Ref(h, c), f:(a) -> {|eff} b, x:a) -> {State h|eff} b
  given (a, b, c, h, eff) =
  f x
-- Calls `body` with 0, 1, 2, ... until it returns `Done`, and returns that
-- value. It never terminates if `body` keeps returning `Continue`.
def iter(body: (Nat) -> {|eff} IterResult a) -> {|eff} a given (a|Data, eff) =
  result = yield_state Nothing \resultRef.
    i <- with_state 0
    while \.
      -- Keep looping until a result has been stored.
      continue = is_nothing $ get resultRef
      if continue then
        -- Lift body's effects past the two local State effects (resultRef, i).
        case lift_state(resultRef, (\x. lift_state(i, body, x)), get i) of
          Continue -> i := get i + 1
          Done(result) -> resultRef := Just result
      continue
  case result of
    Just(ans) -> ans
    Nothing -> unreachable()
-- Like `iter`, but gives up after `maxIters` steps and returns `fallback`.
def bounded_iter(
    maxIters:Nat,
    fallback:a,
    body:(Nat) -> {|eff} IterResult a
    ) -> {|eff} a given (a|Data, eff) =
  iter \step.
    if step < maxIters
      then body step
      else Done fallback
-- The runtime's output stream.
def get_output_stream() -> {IO} Stream WriteMode =
  Stream $ %outputStream()
-- Writes `s` followed by a newline to the output stream.
@noinline
def print(s:String) -> {IO} () =
  out = get_output_stream()
  fwrite(out, s)
  fwrite(out, "\n")
Partial functions
A partial function in this context is a function that can error,
i.e. a function that is not actually defined for all of its supposed domain.
Not to be confused with a partially applied function.
-- Prints `s` to the output stream and then aborts via %throwError. It can
-- stand in for a value of any Data type.
@noinline
def error(s:String) -> a given (a|Data) = unsafe_io \.
  print s
  %throwError(a)
-- Placeholder for unimplemented code. It fails with "TODO: implement it!".
def todo() ->> a given (a|Data) = error "TODO: implement it!"
-- Message used by `from_ordinal` when an ordinal is out of range.
@noinline
def from_ordinal_error(i:Nat, upper:Nat) -> String =
  "Ordinal index out of range:" <> show i <> " >= " <> show upper
-- Checked conversion from an ordinal to an index of `n`. Errors when
-- i >= size n.
def from_ordinal(i:Nat) -> n given (n|Ix) =
  limit = size n
  if i < limit
    then unsafe_from_ordinal i
    else error $ from_ordinal_error(i, limit)
-- Like `from_ordinal`, but returns Nothing instead of erroring.
def to_ix(i:Nat) -> Maybe n given (n|Ix) =
  if i < size n
    then Just $ unsafe_from_ordinal i
    else Nothing
-- Reinterprets a table over `to` as a table over `from` (the parameter
-- names are reversed relative to the direction of the cast). Errors unless
-- both index sets have the same size.
def cast_table(xs:to=>a) -> from=>a given (from|Ix, to|Ix, a|Data) =
  if size from == size to
    then unsafe_cast_table xs
    else error $ "Table size mismatch in cast: " <> show (size from) <> " vs " <> show (size to)
-- Alias for `from_ordinal`.
def asidx(i:Nat) -> n given (n|Ix) = from_ordinal i
-- `i@n` converts ordinal `i` into an index of `n`, with a bounds check.
def (@)(i:Nat, n|Ix) -> n = from_ordinal i
-- The `size m` consecutive entries of `xs` starting at ordinal `start`,
-- bounds-checked through from_ordinal.
def slice(xs:n=>a, start:Nat, m|Ix) -> m=>a given (n|Ix, a) =
  for k. xs[from_ordinal (start + ordinal k)]
-- First entry of `xs`. Errors when `xs` is empty.
def head(xs:n=>a) -> a given (n|Ix, a) = xs[from_ordinal 0]
-- The entries of `xs` from ordinal `start` onward, as a List (presumably
-- empty when start >= size n, assuming `-|` saturates at zero).
def tail(xs:n=>a, start:Nat) -> List a given (n|Ix, a) =
  remaining = size n -| start
  to_list $ slice(xs, start, Fin remaining)
Pseudorandom number generator utilities
Dex does not use a stateful random number generator.
Rather, it uses what is known as a splittable random number generator, which is based on a hash function.
Dex's PRNG system is modelled directly after JAX's, which is based on a well established but shockingly underused idea from the functional programming community: the splittable PRNG. It's a good idea for many reasons, but it's especially helpful in a parallel setting. If you want to read more, Splittable pseudorandom number generators using cryptographic hashing describes the splitting model itself and D.E. Shaw Research's counter-based PRNG proposes the particular hash function we use.
-- Threefry-2x32 block cipher used as the PRNG hash: it encrypts the 64-bit
-- `count` (split into two 32-bit words) under the 64-bit key `k`.
-- It runs 5 groups of 4 mix rounds (20 rounds), with a key injection
-- after each group.
@noinline
def threefry_2x32(k:Word64, count:Word64) -> Word64 =
  -- Rotation amounts; even and odd round groups alternate between them.
  rotations1 = [13, 15, 26, 6]
  rotations2 = [17, 29, 16, 24]
  k0 = low_word k
  k1 = high_word k
  -- Extended key word. 466688986 == 0x1BD11BDA, Threefry's key-schedule
  -- parity constant.
  k2 = k0 .^. k1 .^. (n_to_w32 466688986)
  x = low_word count
  y = high_word count
  -- Initial key injection.
  x = x + k0
  y = y + k1
  rotations = [rotations1, rotations2]
  ks = [k1, k2, k0]
  (x, y) = yield_state (x, y) \ref. for i:(Fin 5).
    -- Four mix rounds: add, rotate left by `rot`, xor.
    for j.
      (x, y) = get ref
      rotationIndex = unsafe_from_ordinal (ordinal i `mod` 2)
      rot = rotations[rotationIndex, j]
      x = x + y
      y = (y .<<. rot) .|. (y .>>. (32 - rot))
      y = x .^. y
      ref := (x, y)
    -- Key injection: rotate through the extended key, plus the group counter.
    (x, y) = get ref
    x = x + ks[unsafe_from_ordinal (ordinal i `mod` 3)]
    y = y + ks[unsafe_from_ordinal (((ordinal i)+1) `mod` 3)] + n_to_w32 ((ordinal i)+1)
    ref := (x, y)
  -- Pack the two output words, with x in the high half.
  (w32_to_w64 x .<<. 32) .|. (w32_to_w64 y)
-- Derives a new key by hashing key `x` together with counter `y`.
def hash(x:Key, y:Nat) -> Key =
  counter = n_to_w64 y
  threefry_2x32(x, counter)
-- Root key for seed `x`.
def new_key(x:Nat) -> Key = hash(0, x)
-- Applies the sampler `f` to the key derived for index `i`.
def many(f:(Key)->a, k:Key, i:n) -> a given (a, n|Ix) = f (hash(k, ordinal i))
-- Key for index `i`, derived from `k` and the ordinal of `i`.
def ixkey(k:Key, i:n) -> Key given (n|Ix) = hash(k, ordinal i)
-- Splits `k` into `n` independent keys, one per index of Fin n.
def split_key(k:Key) -> Fin n => Key given (n:Nat) = for i. ixkey(k, i)
Sample Generators
These functions generate samples taken from different distributions,
such as `rand_mat`,
which samples floating point matrices where each element is drawn i.i.d. from a uniform distribution. Note that additional standard distributions are provided by the `stats`
library.
-- Uniform sample in [0, 1). It puts 23 random bits from the key's high
-- word into the mantissa of a float in [1, 2) and then subtracts 1.
def rand(k:Key) -> Float =
  -- 1065353216 == 0x3F800000, the bit pattern of 1.0.
  one_bits = 1065353216
  -- 8388607 == 0x7FFFFF keeps the low 23 (mantissa) bits.
  random_mantissa = (high_word k .&. 8388607)
  %bitcast(Float, one_bits .|. random_mantissa) - 1.0
-- `n` samples drawn with sampler `f`, one key per index.
def rand_vec(n:Nat, f: (Key) -> a, k: Key) -> Fin n => a given (a) =
  for i:(Fin n). f (ixkey(k, i))
-- An n x m matrix of samples drawn with sampler `f`, one key per (row, column).
def rand_mat(n:Nat, m:Nat, f: (Key) -> a, k: Key) -> Fin n => Fin m => a given (a) =
  for row col. f (ixkey(k, (row, col)))
-- Standard normal sample via the Box-Muller transform.
def randn(k:Key) -> Float =
  [k1, k2] = split_key k
  -- 1 - rand is in (0, 1], which keeps the log finite.
  u1 = 1.0 - (rand k1)
  u2 = rand k2
  radius = sqrt ((-2.0) * log u1)
  radius * cos (2.0 * pi * u2)
-- Nonnegative integer below 2147483647 derived from the key's bits.
def rand_int(k:Key) -> Nat = w64_to_n k `mod` 2147483647
-- A table of independent standard normal samples.
def randn_vec(k:Key) -> n=>Float given (n|Ix) =
  for i. randn (ixkey(k, i))
-- Uniformly random index of `n`.
def rand_idx(k:Key) -> n given (n|Ix) =
  scaled = rand k * n_to_f (size n)
  unsafe_from_ordinal $ f_to_n $ floor scaled
-- Vector spaces with a Float-valued inner product.
interface InnerProd(v|VSpace)
  inner_prod : (v, v) -> Float
instance InnerProd(Float)
  def inner_prod(p, q) = p * q
-- Tables: the sum of the entrywise inner products.
instance InnerProd(n=>a) given (a|InnerProd, n|Ix)
  def inner_prod(p, q) = sum for i. inner_prod(p[i], q[i])
Arbitrary
Type class for generating example values
-- Types that can produce a pseudorandom example value from a Key.
interface Arbitrary(a)
  arb : (Key) -> a
-- Bool is decided by the lowest key bit.
instance Arbitrary(Bool)
  def arb(k) = (k .&. 1) == 0
instance Arbitrary(Float32)
  def arb(k) = randn k
-- Integers come from a normal sample with scale 5.
instance Arbitrary(Int32)
  def arb(k) = f_to_i (randn k * 5.0)
instance Arbitrary(Nat)
  def arb(k) = f_to_n (randn k * 5.0)
-- Tables (including dependent triangular tables) use one key per index.
instance Arbitrary(n=>a) given (n|Ix, a|Arbitrary)
  def arb(k) = for i. arb (ixkey(k, i))
instance Arbitrary((i:n)=>(..<i) => a) given (n|Ix, a|Arbitrary)
  def arb(k) = for i. arb (ixkey(k, i))
instance Arbitrary((i:n)=>(..i) => a) given (n|Ix, a|Arbitrary)
  def arb(k) = for i. arb (ixkey(k, i))
instance Arbitrary((i:n)=>(i..) => a) given (n|Ix, a|Arbitrary)
  def arb(k) = for i. arb (ixkey(k, i))
instance Arbitrary((i:n)=>(i<..) => a) given (n|Ix, a|Arbitrary)
  def arb(k) = for i. arb (ixkey(k, i))
instance Arbitrary((a, b)) given (a|Arbitrary, b|Arbitrary)
  def arb(k) =
    [ka, kb] = split_key k
    (arb ka, arb kb)
instance Arbitrary(Fin n) given (n)
  def arb(k) = rand_idx k
Returns the bucket of `x`, assuming boundaries `xs`, as a `Post n`.
The boundaries must already be sorted, and are inclusive on the left.
In other words, if there is an index `i` such that `xs[i] <= x`, returns the `right_post` of the highest such index; otherwise returns `first_ix : Post n`, which is also the `left_post` of the minimum `i`.
This is equivalent to the right-biased formulation: if an index `i` exists such that `x < xs[i]`, returns the `left_post` of the least such `i`; otherwise returns `last_ix : Post n`, i.e., the `right_post` of the maximum `i`.
-- Binary search over the sorted boundaries `xs`. Returns the right_post of
-- the largest i with xs[i] <= x, or first_ix when there is no such i
-- (including when `xs` is empty).
def search_sorted(xs:n=>a, x:a) -> Post n given (n|Ix, a|Ord) =
  if size n == 0
    then first_ix
    else if x < xs[from_ordinal 0]
      then first_ix
      else
        -- Invariant: xs[low] <= x, and either high == size n or x < xs[high].
        low <- with_state(0)
        high <- with_state(size n)
        _ <- iter
        numLeft = n_to_i (get high) - n_to_i (get low)
        if numLeft == 1
          then Done $ right_post $ from_ordinal $ get low
          else
            -- Halve the window [low, high) while keeping the invariant.
            centerIx = get low + unsafe_i_to_n (numLeft `idiv` 2)
            if x < xs[from_ordinal centerIx]
              then high := centerIx
              else low := centerIx
            Continue
If `i` exists such that `xs[i] == x`, returns `Just` of the largest such `i`; otherwise returns `Nothing`.
-- Exact lookup in the sorted table `xs`. Returns the largest index holding
-- `x`, or Nothing if `x` does not appear.
def search_sorted_exact(xs:n=>a, x:a) -> Maybe n given (n|Ix, a|Ord) =
  case left_fence(search_sorted(xs, x)) of
    Nothing -> Nothing
    Just(i) -> case xs[i] == x of
      True -> Just i
      False -> Nothing
-- Returns whichever of x and y has the smaller key under `f` (y on ties).
def min_by(f:(a)->o, x:a, y:a) -> a given (o|Ord, a) =
  select(f x < f y, x, y)
-- Returns whichever of x and y has the larger key under `f` (y on ties).
def max_by(f:(a)->o, x:a, y:a) -> a given (o|Ord, a) =
  select(f x > f y, x, y)
def min(x1: o, x2: o) -> o given (o|Ord) = min_by(id, x1, x2)
def max(x1: o, x2: o) -> o given (o|Ord) = max_by(id, x1, x2)
-- Entry of `xs` that minimizes `f`. Errors on an empty table.
def minimum_by(f:(a)->o, xs:n=>a) -> a given (a|Data, o|Ord, n|Ix) =
  seed = xs[0@_]
  reduce(seed, \acc v. min_by(f, acc, v), xs)
-- Entry of `xs` that maximizes `f`. Errors on an empty table.
def maximum_by(f:(a)->o, xs:n=>a) -> a given (a|Data, o|Ord, n|Ix) =
  seed = xs[0@_]
  reduce(seed, \acc v. max_by(f, acc, v), xs)
def minimum(xs:n=>o) -> o given (n|Ix, o|Ord) = minimum_by(id, xs)
def maximum(xs:n=>o) -> o given (n|Ix, o|Ord) = maximum_by(id, xs)
argmin/argmax
-- TODO: put in same section as searchsorted
-- Index of the entry that wins every pairwise `comp` against later
-- entries. When `comp` fails (including on ties), the later entry wins.
-- Errors on an empty table.
def argscan(comp:(o,o)->Bool, xs:n=>o) -> n given (o|Ord, n|Ix) =
  start = (0@_, xs[0@_])
  pick = \best cand. select(comp(snd best, snd cand), best, cand)
  indexed = for i. (i, xs[i])
  fst $ reduce(start, pick, indexed)
-- Index of a minimal entry (the last one on ties).
def argmin(xs:n=>o) -> n given (n|Ix, o|Ord) = argscan((<), xs)
-- Index of a maximal entry (the last one on ties).
def argmax(xs:n=>o) -> n given (n|Ix, o|Ord) = argscan((>), xs)
-- Lexicographic comparison of two lists. Elements are compared pairwise
-- from the front: `compareElements` holding at a position gives True,
-- equal elements move on, and any other pair gives False. If one list is
-- a prefix of the other, the answer is compareLengths(nx, ny).
def lexical_order(
    compareElements:(n,n)->Bool,
    compareLengths: (Nat,Nat)->Bool,
    xList:List n,
    yList:List n
    ) -> Bool given (n|Ord) =
  -- Walk both lists in lockstep until the shorter one runs out or a
  -- position decides the comparison.
  AsList(nx, xs) = xList
  AsList(ny, ys) = yList
  iter \i.
    case i == min(nx, ny) of
      True -> Done $ compareLengths(nx, ny)
      False ->
        xi = xs[unsafe_from_ordinal i]
        yi = ys[unsafe_from_ordinal i]
        case compareElements(xi, yi) of
          True -> Done True
          False -> case xi == yi of
            True -> Continue
            False -> Done False
-- Lexicographic ordering on lists. A proper prefix is the smaller list.
instance Ord(List n) given (n|Ord)
  def (>)(xs, ys) = lexical_order((>), (>), xs, ys)
  def (<)(xs, ys) = lexical_order((<), (<), xs, ys)
-- Clamps `x` into [low, high], where bounds = (low, high).
def clip(bounds:(a,a), x:a) -> a given (a|Ord) =
  (lo, hi) = bounds
  min(hi, max(lo, x))
Trigonometric functions
TODO: these should be with the other Elementary/Special Functions
atan/atan2
-- Polynomial approximation of atan(x). atan2 only calls it with
-- min(|x|,|y|) / max(|x|,|y|), i.e. with arguments in [0, 1].
-- It computes x + x * s * P(s), where s = x^2 and P is the degree-7
-- polynomial below, evaluated by Horner's rule.
def atan_inner(x:Float) -> Float =
  s = x * x
  r = 0.0027856871
  r = r * s - 0.0158660002
  r = r * s + 0.042472221
  r = r * s - 0.0749753043
  r = r * s + 0.106448799
  r = r * s - 0.142070308
  r = r * s + 0.199934542
  r = r * s - 0.333331466
  r = r * s
  r * x + x
-- Returns the pair sorted as (smaller, larger). Equal inputs come back as (y, x).
def min_and_max(x:a, y:a) -> (a, a) given (a|Ord) =
  case x < y of
    True -> (x, y)
    False -> (y, x)
-- Two-argument arctangent: the angle of the point (x, y), in radians, in
-- [-pi, pi]. atan is approximated on [0, 1] (`atan_inner`) and extended to
-- all octants by reflection.
-- Special cases:
--   * NaN in either argument gives NaN.
--   * y == 0 gives 0 for x >= 0 and pi for x < 0, negated when y is -0.0.
--   * When both arguments are infinite, the result is +-pi/4 or +-3pi/4.
def atan2(y:Float, x:Float) -> Float =
  abs_x = abs x
  abs_y = abs y
  (min_abs_x_y, max_abs_x_y) = min_and_max(abs_x, abs_y)
  a = atan_inner (min_abs_x_y / max_abs_x_y)
  -- Reflect across the diagonal when |y| >= |x|, then across the y axis when x < 0.
  a = select(abs_x <= abs_y, (pi / 2.0) -a, a)
  a = select(x < 0.0, pi - a, a)
  -- If both are infinite, the ratio above is NaN, so use the exact diagonal angle.
  t = select(x < 0.0, 3.0 * pi / 4.0, pi / 4.0)
  a = select(isinf x && isinf y, t, a)
  -- Transfer the sign of y. This prelude's copysign maps b == 0 to 0.0,
  -- which would erase the pi of the negative x axis, so the y == 0 case
  -- is resolved separately below.
  a = copysign(a, y)
  on_axis = select(x < 0.0, pi, 0.0)
  -- 1/y is +inf for +0.0 and -inf for -0.0, which recovers the sign of a zero y.
  on_axis = select(1.0 / y < 0.0, (-on_axis), on_axis)
  a = select(y == 0.0, on_axis, a)
  select(either_is_nan(x, y), nan, a)
-- One-argument arctangent.
def atan(x:Float) -> Float = atan2(x, 1.0)
Miscellaneous utilities
TODO: all of these should be in some other section
-- Mirror image of `i` within `n`: the first index maps to the last, and so on.
def reflect(i:n) -> n given (n|Ix) =
  mirrored = unsafe_nat_diff(size n, ordinal i + 1)
  unsafe_from_ordinal mirrored
-- `x` with its entries in reverse order.
def reverse(x:n=>a) -> n=>a given (n|Ix, a) =
  for k. x[reflect k]
-- Index of `n` at ordinal `i` modulo the size of `n`.
def wrap_periodic(n|Ix, i:Nat) -> n =
  unsafe_from_ordinal(n=n, i `mod` size n)
-- Copies `xs` into a table over `m`, filling positions beyond size n with `x`.
def pad_to(m|Ix, x:a, xs:n=>a) -> m=>a given (n|Ix, a) =
  src_len = size n
  for i.
    k = ordinal i
    if k < src_len
      then xs[k@_]
      else x
-- Integer division rounded up.
def idiv_ceil(x:Nat, y:Nat) -> Nat =
  has_remainder = x `rem` y /= 0
  x `idiv` y + b_to_n has_remainder
-- x / 2, rounded down, via a right shift.
def intdiv2(x:Nat) -> Nat = rep_to_nat $ %shr(nat_to_rep x, 1 :: NatRep)
-- 2^power, via a left shift.
def intpow2(power:Nat) -> Nat = rep_to_nat $ %shl(1 :: NatRep, nat_to_rep power)
def is_odd(x:Nat) -> Bool = rem(x, 2) == 1
def is_even(x:Nat) -> Bool = rem(x, 2) == 0
-- True for 1, 2, 4, 8, .... It uses the fact that x & (x - 1) clears the
-- lowest set bit.
def is_power_of_2(x:Nat) -> Bool =
  bits = nat_to_rep x
  case bits == 0 of
    True -> False
    False -> 0 == %and(bits, (%isub(bits, 1::NatRep)))
-- Floor of log base 2 of `x`, i.e. the position of its highest set bit.
-- It counts how many halvings bring `x` down to zero (the bit length) and
-- subtracts one. Shifting `x` right cannot overflow, unlike doubling a
-- comparison value, which wraps to 0 near the top bit of NatRep and then
-- never stops the loop.
-- x == 0 has no logarithm; for it the result is unsafe_nat_diff(0, 1).
def natlog2(x:Nat) -> Nat =
  bit_length = yield_state 0 \ans.
    rest <- with_state x
    while \.
      if (get rest) > 0
        then
          ans := (get ans) + 1
          rest := intdiv2 (get rest)
          True
        else
          False
  unsafe_nat_diff(bit_length, 1)
-- Exponent of the smallest power of two that is >= x, i.e. ceil(log2 x).
-- It is not meaningful for x == 0.
def nextpow2(x:Nat) -> Nat =
  floor_log = natlog2 x
  if is_power_of_2 x
    then floor_log
    else 1 + floor_log
-- Computes `base` raised to `power` under the operation `times` (assumed
-- associative) with identity `one`, using square-and-multiply.
-- It runs 1 + natlog2(power) iterations, or none when power == 0.
def general_integer_power(
    times:(a,a)->a,
    one:a, base:a,
    power:Nat
    ) -> a given (a|Data) =
  iters = if power == 0 then 0 else 1 + natlog2 power
  yield_state one \ans.
    pow <- with_state power
    z <- with_state base
    for _:(Fin iters).
      -- Multiply in the current square whenever the low bit is set.
      if is_odd (get pow)
        then ans := times(get ans, get z)
      z := times(get z, get z)
      pow := intdiv2 (get pow)
-- base^power for any Mul type.
def intpow(base:a, power:Nat) -> a given (a|Mul) =
  general_integer_power((*), one, base, power)
-- Unwraps a `Just`. It is partial: there is no case for `Nothing`.
def from_just(x:Maybe a) -> a given (a) =
  case x of
    Just(v) -> v
-- True if `f` holds for at least one entry.
def any_sat(f:(a)->Bool, xs:n=>a) -> Bool given (a, n|Ix) = any(for i. f xs[i])
-- Just the unwrapped table if every entry is `Just`, otherwise Nothing.
def seq_maybes(xs: n=>Maybe a) -> Maybe (n => a) given (n|Ix, a) =
  if any_sat(is_nothing, xs)
    then Nothing
    else Just $ each xs from_just
-- Index of the last entry equal to `query`, or Nothing. The whole table
-- is always scanned.
def linear_search(xs:n=>a, query:a) -> Maybe n given (n|Ix, a|Eq) =
  yield_state Nothing \found. for i.
    if xs[i] == query then found := Just i
-- Number of elements in a List.
def list_length(l:List a) -> Nat given (a) =
  AsList(len, _) = l
  len
-- Concatenates a table of lists into one list, in index order.
def concat(lists:n=>(List a)) -> List a given (a, n|Ix) =
  totalSize = sum for i. list_length lists[i]
  -- (listIdx, eltIdx) is a cursor into the source lists.
  to_list $ with_state 0 \listIdx.
    eltIdx <- with_state 0
    for i:(Fin totalSize).
      -- Move the cursor past exhausted lists, including empty ones.
      while \.
        continue = get eltIdx >= list_length (lists[(get listIdx)@_])
        if continue
          then
            eltIdx := 0
            listIdx := get listIdx + 1
          else ()
        continue
      AsList(_, xs) = lists[(get listIdx)@_]
      eltIdxVal = get eltIdx
      eltIdx := eltIdxVal + 1
      xs[eltIdxVal@_]
-- Collects the payloads of the `Just` entries of `xs`, in index order.
def cat_maybes(xs:n=>Maybe a) -> List a given (n|Ix, a|Data) =
  -- First pass: count the Justs and record the index of each one in order.
  (num_res, res_inds) = yield_state (0::Nat, for i:n. Nothing) \ref.
    for i. case xs[i] of
      Just(_) ->
        ix = get ref.0
        ref.1 ! (unsafe_from_ordinal ix) := Just i
        ref.0 := ix + 1
      Nothing -> ()
  -- Second pass: gather the payloads. The `todo` branches are unreachable,
  -- because each recorded index points at a `Just`.
  to_list for i:(Fin num_res).
    case res_inds[unsafe_from_ordinal $ ordinal i] of
      Just(j) -> case xs[j] of
        Just(x) -> x
        Nothing -> todo
      Nothing -> todo
-- The entries of `xs` that satisfy `condition`, in order.
def filter(xs:n=>a, condition:(a)->Bool) -> List a given (a|Data, n|Ix) =
  kept = for i.
    v = xs[i]
    if condition v then Just v else Nothing
  cat_maybes kept
-- The indices of the entries of `xs` that satisfy `condition`, in order.
def arg_filter(xs:n=>a, condition:(a)->Bool) -> List n given (a|Data, n|Ix) =
  kept = for i.
    if condition xs[i] then Just i else Nothing
  cat_maybes kept
-- The index just before `i`, or Nothing when `i` is the first index.
def prev_ix(i:n) -> Maybe n given (n|Ix) =
  predecessor = n_to_i (ordinal i) - 1
  case i_to_n predecessor of
    Just(k) -> Just $ unsafe_from_ordinal k
    Nothing -> Nothing
-- Splits `source` at '\n' characters. The newlines are not included in the
-- resulting lines. Text after the final newline forms a last line of its
-- own; a trailing newline does not produce an extra empty line.
def lines(source:String) -> List String =
  AsList(num_chars, chars) = source
  -- Terminate an unterminated final line so the split below keeps it.
  -- A `case` guards the index, so `chars` is never read when it is empty.
  needs_newline = case num_chars == 0 of
    True -> False
    False -> chars[unsafe_from_ordinal (unsafe_nat_diff(num_chars, 1))] /= '\n'
  terminated = if needs_newline then source <> "\n" else source
  AsList(_, s) = terminated
  AsList(num_lines, newline_ixs) = cat_maybes for i_char.
    if s[i_char] == '\n'
      then Just(i_char)
      else Nothing
  -- Line k runs from just after newline k-1 (or the start) to newline k.
  to_list for i_line:(Fin num_lines).
    start = case prev_ix i_line of
      Nothing -> first_ix
      Just(i) -> right_post newline_ixs[i]
    end = left_post newline_ixs[i_line]
    post_slice(s, start, end)
-- Draws an index given an exclusive CDF (as built by cdf_for_categorical).
-- It relies on cdf[0] <= u, which holds when the CDF starts at 0.
def categorical_from_cdf(cdf: n=>Float, key: Key) -> n given (n|Ix) =
  u = rand key
  bucket = search_sorted(cdf, u)
  from_just $ left_fence bucket
-- Scales `xs` so that its entries sum to 1.
def normalize_pdf(xs: d=>Float) -> d=>Float given (d|Ix) =
  total = sum xs
  xs / total
-- Exclusive CDF of the distribution with the given unnormalized log
-- probabilities. The max is subtracted before exponentiating to avoid overflow.
def cdf_for_categorical(logprobs: n=>Float) -> n=>Float given (n|Ix) =
  shift = maximum logprobs
  weights = for i. exp(logprobs[i] - shift)
  cumsum_low $ normalize_pdf weights
-- Samples one index from unnormalized log probabilities.
def categorical(logprobs: n=>Float, key: Key) -> n given (n|Ix) =
  categorical_from_cdf(cdf_for_categorical logprobs, key)
-- Samples `size m` independent indices, sharing a single CDF.
def categorical_batch(logprobs: n=>Float, key: Key) -> m=>n given (n|Ix, m|Ix) =
  cdf = cdf_for_categorical logprobs
  for draw. categorical_from_cdf(cdf, ixkey(key, draw))
-- log(sum(exp x)), shifted by the max for numerical stability.
def logsumexp(x: n=>Float) -> Float given (n|Ix) =
  shift = maximum x
  shift + (log $ sum for i. exp (x[i] - shift))
-- log(softmax x).
def logsoftmax(x: n=>Float) -> n=>Float given (n|Ix) =
  normalizer = logsumexp x
  for i. x[i] - normalizer
-- exp(x) normalized to sum to 1, with the max shifted out first.
def softmax(x: n=>Float) -> n=>Float given (n|Ix) =
  shift = maximum x
  weights = for i. exp (x[i] - shift)
  total = sum weights
  for i. weights[i] / total
Polynomials
TODO: Move this somewhere else
-- Evaluates a polynomial at `x` by Horner's rule. The coefficients are
-- ordered from the highest-degree term down to the constant term.
def evalpoly(coefficients:n=>v, x:Float) -> v given (n|Ix, v|VSpace) =
  fold zero \i acc. coefficients[i] + x .* acc
Exception effect
-- TODO: move `error` and `todo` to here.
-- Runs `f`, returning Nothing if it throws and Just its result otherwise.
def catch(f:() -> {Except|eff} a) -> {|eff} Maybe a given (a, eff)=
  thunk : (() -> {Except|eff} a) = \. f()
  %catchException(thunk)
-- Raises an exception, to be caught by an enclosing `catch`.
def throw() -> {Except} a given (a) =
  %throwException(a)
-- Throws when `b` is False.
def assert(b:Bool) -> {Except} () =
  case b of
    True -> ()
    False -> throw()
Misc instances that require error
-- `a` embeds into Either(a, b) as the Left branch.
instance Subset(a, Either(a,b)) given (a|Data, b|Data)
  def inject'(x) = Left x
  def project'(x) = case x of
    Left(v) -> Just v
    Right(_) -> Nothing
  def unsafe_project'(x) = case x of
    Left(v) -> v
    Right(_) -> error "Can't project Right branch to Left branch"
-- `b` embeds into Either(a, b) as the Right branch.
instance Subset(b, Either(a,b)) given (a|Data, b|Data)
  def inject'(x) = Right x
  def project'(x) = case x of
    Left(_) -> Nothing
    Right(v) -> Just v
  def unsafe_project'(x) = case x of
    Left(_) -> error "Can't project Left branch to Right branch"
    Right(v) -> v
-- Writes `k` in base `size b` as a table of digits over `a`, least
-- significant digit first. Digits beyond `size a` are dropped.
def int_to_reversed_digits(k:Nat) -> a=>b given (a|Ix, b|Ix) =
  base = size b
  snd $ scan k \_ remaining.
    digit = remaining `mod` base
    rest = remaining `idiv` base
    (rest, unsafe_from_ordinal(n=b, digit))
-- Inverse of int_to_reversed_digits: reads the table as base-`size b`
-- digits, least significant first.
def reversed_digits_to_int(digits: a=>b) -> Nat given (a|Ix, b|Ix) =
  base = size b
  fst $ fold (0, 1) \j acc.
    (value, place) = acc
    (value + ordinal digits[j] * place, place * base)
-- Tables a=>b are indexed by reading them as numbers in base `size b`.
instance Ix(a=>b) given (a|Ix, b|Ix)
  def size'() = size b `intpow` size a
  def ordinal(i) = reversed_digits_to_int i
  def unsafe_from_ordinal(i) = int_to_reversed_digits i
instance NonEmpty(a=>b) given (a|Ix, b|NonEmpty)
  first_ix = unsafe_from_ordinal 0
-- A growable stack in heap `h`, used through State effects.
-- `size_ref` holds the logical element count. `buf_ref` holds the backing
-- buffer, whose capacity can exceed the logical size and is grown to
-- powers of two (see ensure_size_at_least).
struct Stack(h:Heap, a|Data) =
  size_ref : Ref h Nat
  buf_ref : Ref h (List a)
  -- Number of elements currently on the stack.
  def size() -> {State h} Nat = get self.size_ref
  -- Reference to the buffer's table, with its length erased to Fin 0.
  -- The List ref is reinterpreted as a (length, table-ref) pair.
  -- NOTE(review): this relies on the runtime layout of List; confirm.
  def unsafe_get_buffer() -> {State h} (Ref(h, Fin 0 => a)) =
    get $ snd_ref $ unsafe_coerce(to=Ref h (Nat, Ref h (Fin 0 => a)), self.buf_ref)
  -- Capacity of the backing buffer (the List's length field).
  def buf_size() -> {State h} Nat =
    get $ fst_ref $ unsafe_coerce(to=Ref h (Nat, Ref h (Fin 0 => a)), self.buf_ref)
  -- If the capacity is below req_size, reallocates to the next power of
  -- two >= req_size. Existing elements are copied and new slots are left
  -- uninitialized.
  def ensure_size_at_least(req_size:Nat) -> {State h} () =
    if req_size > self.buf_size() then
      new_buf_size = intpow2 $ nextpow2 req_size
      buf = self.unsafe_get_buffer()
      logical_size = self.size()
      cur_data = get $ unsafe_coerce(to=Ref(h, Fin logical_size => a), buf)
      self.buf_ref := to_list for i:(Fin new_buf_size).
        case to_ix(n=Fin logical_size, ordinal i) of
          Just(i') -> cur_data[i']
          Nothing -> uninitialized_value()
  -- Snapshot of the current elements as a List, bottom of the stack first.
  def read() -> {State h} (List a) =
    n = self.size()
    buf = unsafe_coerce(to=Ref(h, Fin n => a), self.unsafe_get_buffer())
    AsList(n, get buf)
  -- Pushes one element, growing the buffer if needed.
  @noinline
  def push(x:a) -> {State h} () =
    n_old = self.size()
    n_new = n_old + 1
    self.ensure_size_at_least(n_new)
    buf = self.unsafe_get_buffer()
    buf ! (unsafe_from_ordinal n_old) := x
    self.size_ref := n_new
  -- Pushes all entries of `x` in order, with a single bulk write into the
  -- buffer starting at the old top.
  @noinline
  def extend(x:n=>a) -> {State h} () given (n|Ix) =
    n_old = self.size()
    n_new = n_old + size n
    self.ensure_size_at_least(n_new)
    buf = self.unsafe_get_buffer()
    buf_slice = unsafe_coerce(to=Ref(h,n=>a), buf ! (unsafe_from_ordinal n_old))
    buf_slice := x
    self.size_ref := n_new
  -- Removes and returns the top element, or Nothing if the stack is empty.
  def pop() -> {State h} Maybe a =
    n_old = self.size()
    case n_old == 0 of
      True -> Nothing
      False ->
        n_new = unsafe_nat_diff(n_old, 1)
        buf = self.unsafe_get_buffer()
        self.size_ref := n_new
        Just $ get buf!(unsafe_from_ordinal n_new)
-- Runs `action` with a fresh, empty Stack of element type `a` in a local
-- heap. The initial buffer holds `stack_init_size` uninitialized slots
-- (that constant is defined elsewhere).
def with_stack(
    a|Data,
    action:(given (h:Heap), Stack(h, a)) -> {State h|eff} r
    ) -> {|eff} r given (eff, r) =
  init_stack = to_list for i:(Fin stack_init_size). uninitialized_value()
  with_state (0, init_stack) \ref . action(Stack(ref.0, ref.1))
-- Monomorphic (Char) wrappers around the Stack methods and with_stack.
-- NOTE(review): presumably top-level so they can be used outside the prelude; confirm.
def stack_extend_internal(stack:Stack(h, Char), x:Fin n=>Char) -> {State h} () given (n, h) =
  stack.extend(x)
def stack_push_internal(stack:Stack(h, Char), x:Char) -> {State h} () given (h) =
  stack.push(x)
-- Runs `f` on a fresh Char stack and returns the stack's final contents.
def with_stack_internal(f:(given (h:Heap), Stack(h, Char)) -> {State h} ()) -> List Char =
  with_stack Char \stack.
    f stack
    stack.read()
-- Copies a NUL-terminated C string into a Dex String, without the NUL.
-- Returns Nothing for a null pointer.
def from_c_string(s:CString) -> {IO} (Maybe String) =
  case c_string_ptr s of
    Nothing -> Nothing
    Just(ptr) ->
      Just do
        stack <- with_stack Char
        -- Read one character at a time until the terminator.
        i <- iter
        c = load $ ptr +>> i
        if c == '\NUL'
          then Done $ stack.read()
          else
            stack.push(c)
            Continue
-- Debug rendering of any value via the %showAny primitive.
def show_any(x:a) -> String given (a) = unsafe_coerce(to=String, %showAny(x))
-- Reinterprets a table over `n` as a table over `m`. Errors unless both
-- index sets have the same size.
def coerce_table(m|Ix, x:n=>a) -> m => a given (n|Ix, a|Data) =
  case size m == size n of
    True -> unsafe_coerce(to=m=>a, x)
    False -> error "mismatched sizes in table coercion"
foreign "getenv" getenvFFI : (RawPtr) -> {IO} RawPtr
-- Value of environment variable `name`, or Nothing if it is unset.
def get_env(name:String) -> {IO} Maybe String =
  with_c_string name \c_name.
    raw = getenvFFI c_name.ptr
    from_c_string CString(raw)
-- True if environment variable `name` is set, to any value.
def check_env(name:String) -> {IO} Bool =
  found = get_env name
  is_just found
-- True when the DEX_TEST_MODE environment variable is set.
def dex_test_mode() -> Bool = unsafe_io \. check_env "DEX_TEST_MODE"
-- Reads the rest of `stream` into a String, in chunks of 4096 bytes.
-- It stops after the first short read (fewer bytes than requested), which
-- happens at end of file or on a read error.
def fread(stream:Stream ReadMode) -> {IO} String =
  n = 4096
  ptr:(Ptr Char) <- with_alloc n
  stack <- with_stack Char
  iter \_.
    numRead = i_to_w32 $ i64_to_i $ freadFFI(ptr.val, 1, n_to_i64 n, stream.ptr)
    AsList(_, new_chars) = string_from_char_ptr(numRead, ptr)
    stack.extend(new_chars)
    if numRead == n_to_w32 n
      then Continue
      else Done ()
  stack.read()
foreign "popen" popenFFI : (RawPtr, RawPtr) -> {IO} RawPtr
foreign "pclose" pcloseFFI : (RawPtr) -> {IO} Int32
foreign "remove" removeFFI : (RawPtr) -> {IO} Int64
foreign "mkstemp" mkstempFFI : (RawPtr) -> {IO} Int32
foreign "close" closeFFI : (Int32) -> {IO} Int32
-- Runs `command` through the shell (via popen) and returns everything it
-- writes to stdout. Raises an error if the command could not be started.
-- The pipe is then closed with pclose, which also waits for the command
-- to exit; the exit status is not reported.
def shell_out(command:String) -> {IO} String =
  modeStr = "r"
  with_c_string command \command'.
    with_c_string modeStr \modeStr'.
      pipe : Stream ReadMode = Stream $ popenFFI(command'.ptr, modeStr'.ptr)
      if is_null_raw_ptr pipe.ptr
        then
          error $ "Unable to run command: " <> command
        else
          output = fread pipe
          -- A popen'd stream must be closed with pclose, not fclose.
          pcloseFFI pipe.ptr
          output
-- Removes the file at `f`. remove's return code is ignored.
def delete_file(f:FilePath) -> {IO} () =
  with_c_string f \c_path.
    removeFFI c_path.ptr
    ()
-- Opens `f` in `mode`, runs `action` on the stream, closes it, and returns
-- the action's result. Errors if the file cannot be opened.
def with_file(
    f:FilePath,
    mode:StreamMode,
    action:(Stream mode) -> {IO} a
    ) -> {IO} a given (a|Data) =
  stream = fopen(f, mode)
  case is_null_raw_ptr stream.ptr of
    True -> error $ "Unable to open file: " <> f
    False ->
      result = action stream
      fclose stream
      result
-- Replaces the contents of `f` with `s`.
def write_file(f:FilePath, s:String) -> {IO} () =
  with_file(f, WriteMode) \out. fwrite(out, s)
-- Entire contents of `f`.
def read_file(f:FilePath) -> {IO} String =
  with_file(f, ReadMode) \input. fread input
-- True if `f` can be opened for reading.
def has_file(f:FilePath) -> {IO} Bool =
  stream = fopen(f, ReadMode)
  opened = not (is_null_raw_ptr stream.ptr)
  if opened then fclose stream
  opened
-- Creates a fresh, empty temporary file under /tmp (via mkstemp) and
-- returns its path. The descriptor is closed immediately and the file is
-- left on disk for the caller. Raises an error if mkstemp fails.
def new_temp_file() -> {IO} FilePath =
  template = "/tmp/dex-XXXXXX"
  s <- with_c_string template
  -- mkstemp fills in the trailing XXXXXX in place and returns an open
  -- descriptor, or -1 on failure.
  fd = mkstempFFI s.ptr
  if i32_to_i fd < 0 then error $ "Unable to create temporary file from template: " <> template
  closeFFI fd
  -- The generated path is exactly as long as the template.
  string_from_char_ptr(n_to_w32 (list_length template), (Ptr s.ptr))
-- Runs `action` with a fresh temporary file path and deletes the file
-- afterwards.
def with_temp_file(action: (FilePath) -> {IO} a) -> {IO} a given (a) =
  path = new_temp_file()
  answer = action path
  delete_file path
  answer
-- Like with_temp_file, but with one temporary file per index of `n`.
def with_temp_files(action: (n=>FilePath) -> {IO} a) -> {IO} a given (n|Ix, a) =
  paths = for i. new_temp_file()
  answer = action paths
  for i. delete_file paths[i]
  answer
-- `size n` evenly spaced points starting at `low`, with step
-- (high - low) / size n. `high` itself is excluded.
def linspace(n|Ix, low:Float, high:Float) -> n=>Float =
  step = (high - low) / n_to_f (size n)
  for i:n. low + n_to_f (ordinal i) * step
-- Swaps the two index sets of a matrix.
def transpose(x:n=>m=>a) -> m=>n=>a given (n|Ix, m|Ix, a) = for col row. x[row, col]
-- Dot product of two Float vectors.
def vdot(x:n=>Float, y:n=>Float) -> Float given (n|Ix) = fsum for k. x[k] * y[k]
-- Linear combination of the vectors `vs` with weights `s`.
def dot(s:n=>Float, vs:n=>v) -> v given (n|Ix, v|VSpace) = sum for k. s[k] .* vs[k]
-- Reference matrix product, with a triple loop and no tiling.
def naive_matmul(x: l=>m=>Float, y: m=>n=>Float) -> (l=>n=>Float) given (l|Ix, m|Ix, n|Ix) =
  for row col. fsum for k. x[row,k] * y[k,col]
-- Index set for one full tile of `n`: tile number `tile_ix` covers the
-- ordinals tile_size*tile_ix .. tile_size*(tile_ix+1)-1 of `n`.
struct FullTileIx(n|Ix, tile_size:Nat, tile_ix:Nat) =
  unwrap : Fin tile_size
instance Ix(FullTileIx(n, tile_size, tile_ix)) given (n|Ix, tile_size:Nat, tile_ix:Nat)
  def size'() = tile_size
  def ordinal(i) = ordinal i.unwrap
  def unsafe_from_ordinal(i) = FullTileIx $ unsafe_from_ordinal i
-- Embeds a tile position into `n` by offsetting into the tile.
-- Projection back from `n` is not implemented.
instance Subset(FullTileIx(n, tile_size, tile_ix), n) given (n|Ix, tile_size:Nat, tile_ix:Nat)
  def inject'(i) = unsafe_from_ordinal $ tile_size * tile_ix + ordinal i.unwrap
  def project'(i) = todo
  def unsafe_project'(i) = todo
-- Index set for the leftover tail ("coda") of `n` that does not fill a
-- whole tile: `coda_size` positions starting at ordinal `coda_offset`.
struct CodaIx(n|Ix, coda_offset:Nat, coda_size:Nat) =
  unwrap : Fin coda_size
instance Ix(CodaIx(n, coda_offset, coda_size)) given (n|Ix, coda_offset:Nat, coda_size:Nat)
  def size'() = coda_size
  def ordinal(i) = ordinal i.unwrap
  def unsafe_from_ordinal(i) = CodaIx $ unsafe_from_ordinal i
-- Embeds a coda position into `n`. Projection is not implemented.
instance Subset(CodaIx(n, coda_offset, coda_size), n) given (n|Ix, coda_offset:Nat, coda_size:Nat)
  def inject'(i) = unsafe_from_ordinal $ coda_offset + ordinal i.unwrap
  def project'(i) = todo
  def unsafe_project'(i) = todo
-- Splits index set `n` into consecutive tiles of `tile_size` elements.
-- `body` is called once per full tile and then once for the remaining
-- coda, which may be empty. Each call gets an index type `m` with Ix m and
-- Subset(m, n), so `inject` maps its indices back into `n`.
-- NOTE(review): tile_size == 0 would divide by zero; callers pass constants.
def tile(
    n|Ix,
    tile_size: Nat,
    body:(m:Type, given () (Ix m, Subset(m, n))) -> {|eff} ()
    ) -> {|eff} () given (eff) =
  num_tiles = size n `idiv` tile_size
  coda_size = size n `rem` tile_size
  coda_offset = num_tiles * tile_size
  for_ tile_ix:(Fin num_tiles).
    tile_ix' = ordinal tile_ix
    body (FullTileIx(n, tile_size, tile_ix'))
  body (CodaIx(n, coda_offset, coda_size))
-- Matrix product computed over blocks of 32 (l) x 8 (m) x 128 (n) indices.
-- Partial products are accumulated into the result through an AddMonoid
-- reference. The tile sizes are presumably tuned for cache locality.
@noinline
def tiled_matmul(
    x: l=>m=>Float,
    y: m=>n=>Float
    ) -> l=>n=>Float given (l|Ix, m|Ix, n|Ix) =
  l_tile_size = 32
  n_tile_size = 128
  m_tile_size = 8
  yield_accum (AddMonoid Float) \result.
    tile(l, l_tile_size) \l_set.
      tile(n, n_tile_size) \n_set.
        tile(m, m_tile_size) \m_set.
          -- Inner loops range over one block, mapped back to full indices.
          for_ l_offset:l_set.
            l_ix = inject(to=l, l_offset)
            for_ m_offset:m_set.
              m_ix = inject m_offset
              for_ n_offset:n_set.
                n_ix = inject n_offset
                result!l_ix!n_ix += x[l_ix][m_ix] * y[m_ix][n_ix]
-- Matrix product, using the tiled implementation.
def (**)(
    x: l=>m=>Float,
    y: m=>n=>Float
    ) -> l=>n=>Float given (l|Ix, m|Ix, n|Ix) =
  tiled_matmul(x, y)
-- Custom linearization for tiled_matmul: returns the product and its
-- linear map, by the product rule d(x ** y) = x ** dy + dx ** y.
def matmul_linearization(
    x: l=>m=>Float,
    y: m=>n=>Float
    ) -> _ given (l|Ix, m|Ix, n|Ix) =
  def lin(xt: l=>m=>Float, yt: m=>n=>Float) -> _ =
    x ** yt + xt ** y
  (x ** y, lin)
-- Registers the linearization above for tiled_matmul.
custom-linearization tiled_matmul matmul_linearization
-- Matrix-vector product: entry i is the dot product of row i with `v`.
def (**.)(mat: n=>m=>Float, v: m=>Float) -> (n=>Float) given (n|Ix, m|Ix) =
  for row. vdot(mat[row], v)
-- Vector-matrix product, computed as (transpose mat) **. v.
def(.**)(v: n=>Float, mat: n=>m=>Float) -> (m=>Float) given (n|Ix, m|Ix) =
  (transpose mat) **. v
-- Bilinear form x^T mat y, summed over all (i, j) pairs.
def inner(x:n=>Float, mat:n=>m=>Float, y:m=>Float) -> Float given (n|Ix, m|Ix) =
  fsum for pair.
    (i, j) = pair
    x[i] * mat[i,j] * y[j]
-- Identity matrix: `one` on the diagonal, `zero` elsewhere.
def eye() ->> n=>n=>a given (n|Ix, a|Add|Mul) =
  for row col. select(ordinal row == ordinal col, one, zero)