Multi-step Ray Tracer

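Based on Eric Jang's [JAX implementation](https://github.com/ericjang/pt-jax/blob/master/jaxpt_vmap.ipynb), described in [this blog post](https://blog.evjang.com/2019/11/jaxpt.html).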

import png
import plot

Generic Helper Functions

Some of these helpers should probably be moved into the prelude.

-- `Vec n` is a table of `n` Floats, indexed by `Fin n`.
def Vec (n:Nat) : Type = Fin n => Float
-- `Mat n m` is a table of `n` rows, each of which is a table of `m` Floats.
def Mat (n:Nat) (m:Nat) : Type = Fin n => Fin m => Float
-- Rectifier: the larger of `x` and zero.
def relu (x:Float) : Float =
  lowerBound = 0.0
  max x lowerBound
-- Euclidean norm of `x`: the square root of the sum of squared components.
def length {d} (x: d=>Float) : Float =
  squares = for i. x.i * x.i
  sqrt (sum squares)
-- TODO: make a newtype for normal vectors
-- Divides `x` by its Euclidean length. There is no guard for a zero-length
-- input: the division is performed unconditionally.
def normalize {d} (x: d=>Float) : d=>Float =
  norm = length x
  x / norm
-- Splits `x` into a pair of (x divided by its length, its length).
-- The length is computed once and reused for both components of the result.
-- As with `normalize`, a zero-length input is not guarded against.
def directionAndLength {d} (x: d=>Float) : (d=>Float & Float) =
  l = length x
  (x / l, l)
-- Maps the sample `rand k` affinely from the unit interval onto the interval
-- between `lower` and `upper`.
def randuniform (lower:Float) (upper:Float) (k:Key) : Float =
  u = rand k
  span = upper - lower
  lower + u * span
-- Averages `n` draws of `sample`. Draw number `i` uses the key `ixkey k i`;
-- each draw is divided by `n` before being added to the running total, which
-- starts at `zero`. With `n = 0` the loop body never runs and `zero` is returned.
def sampleAveraged {a} [VSpace a] (sample:Key -> a) (n:Nat) (k:Key) : a =
  yield_state zero \acc.
    for i:(Fin n).
      contribution = sample (ixkey k i) / n_to_f n
      acc := get acc + contribution
-- True exactly when the dot product of `x` and `y` is strictly positive.
def positiveProjection {n} (x:n=>Float) (y:n=>Float) : Bool =
  0.0 < dot x y

3D Helper Functions

-- Cross product of two 3-vectors, written out component by component.
def cross (a:Vec 3) (b:Vec 3) : Vec 3 =
  [ax, ay, az] = a
  [bx, by, bz] = b
  cx = ay * bz - az * by
  cy = az * bx - ax * bz
  cz = ax * by - ay * bx
  [cx, cy, cz]
-- TODO: Use `data Color = Red | Green | Blue` and ADTs for index sets
-- An image bundles its height and width with a height-by-width table of
-- `Color` values. `Color` is not defined in this file; presumably it comes
-- from one of the imports (verify).
data Image = MkImage height:Nat width:Nat (Fin height => Fin width => Color)
-- Unit vectors along the three coordinate axes.
xHat : Vec 3 = [1.0, 0.0, 0.0]
yHat : Vec 3 = [0.0, 1.0, 0.0]
zHat : Vec 3 = [0.0, 0.0, 1.0]

-- Angles are plain Floats, measured in radians.
Angle = Float
-- Rotates `p` about the x axis by `angle` radians; the x component is unchanged.
def rotateX (p:Vec 3) (angle:Angle) : Vec 3 =
  [x, y, z] = p
  cosA = cos angle
  sinA = sin angle
  [x, cosA*y - sinA*z, sinA*y + cosA*z]
-- Rotates `p` about the y axis by `angle` radians; the y component is unchanged.
def rotateY (p:Vec 3) (angle:Angle) : Vec 3 =
  [x, y, z] = p
  cosA = cos angle
  sinA = sin angle
  [cosA*x + sinA*z, y, cosA*z - sinA*x]
-- Rotates `p` about the z axis by `angle` radians; the z component is unchanged.
def rotateZ (p:Vec 3) (angle:Angle) : Vec 3 =
  [x, y, z] = p
  cosA = cos angle
  sinA = sin angle
  [cosA*x - sinA*y, sinA*x + cosA*y, z]
-- Draws a random unit direction around `normal` from two uniform samples.
-- A tangent frame (tangent, bitangent, normal) is built by crossing `normal`
-- with the fixed vector [0.0, 1.1, 1.1].
-- NOTE(review): if `normal` is parallel to that fixed vector the cross product
-- is zero and `normalize` divides by zero; no such case is handled here.
def sampleCosineWeightedHemisphere (normal: Vec 3) (k:Key) : Vec 3 =
  [kAngle, kRadius] = split_key k
  u1 = rand kAngle
  u2 = rand kRadius
  tangent = normalize $ cross normal [0.0, 1.1, 1.1]
  bitangent = cross tangent normal
  -- Point on the unit disk (radius sqrt u2, angle 2*pi*u1), lifted along the
  -- normal by sqrt (1 - u2).
  radius = sqrt u2
  phi = 2.0 * pi * u1
  dx = radius * cos phi
  dy = radius * sin phi
  dz = sqrt (1.0 - u2)
  combined = (dx .* tangent) + (dy .* bitangent) + (dz .* normal)
  normalize combined

Raytracer

-- Aliases naming the roles that Floats, 3-vectors and colours play in the
-- ray tracer. They are plain synonyms, so the type checker does not
-- distinguish, say, a `Position` from a `Direction`.
Distance = Float
Position = Vec 3
Direction = Vec 3 -- Should be normalized. TODO: use a newtype wrapper
BlockHalfWidths = Vec 3
Radius = Float
Radiance = Color
-- Shapes a passive object can take:
--   Wall:   a direction and a distance (the distance function for a wall is
--           `d + dot nor pos`)
--   Block:  centre position, per-axis half-widths, and an angle applied with
--           `rotateY` when measuring distance
--   Sphere: centre position and radius
data ObjectGeom =
  Wall Direction Distance
  Block Position BlockHalfWidths Angle
  Sphere Position Radius
-- Surface materials: a matte surface carrying a colour, or a mirror.
data Surface =
  Matte Color
  Mirror
-- A surface paired with a direction; the functions below bind that direction
-- as the surface normal (`nor`, `surfNor`, `surfNorm`).
OrientedSurface = (Direction & Surface)
-- A scene object is either passive geometry with a surface material, or a
-- light source.
data Object =
  PassiveObject ObjectGeom Surface
  -- Light fields: position, half-width, intensity (assumed to point down)
  Light Position Float Radiance
-- A ray is an origin position paired with a direction.
Ray = (Position & Direction)
-- A filter is a colour that gets multiplied component-wise into radiance
-- (see `applyFilter`).
Filter = Color
-- TODO: use a record
-- NOTE(review): the type below already is a record with named fields, so the
-- TODO above looks resolved — confirm and remove.
-- num samples, num bounces, share seed?
-- `numSamples` is the per-pixel sample count passed to `sampleAveraged`,
-- `maxBounces` bounds the loop in `trace`, and `shareSeed` makes every pixel
-- use the same root key in `takePicture`.
Params = { numSamples : Nat & maxBounces : Nat & shareSeed : Bool }
-- TODO: use a list instead, once they work
-- A scene is a table of objects indexed by an arbitrary index set `n`.
data Scene n:Type [Ix n] = MkScene (n=>Object)
-- Produces the outgoing ray after the incoming ray `(pos, dir)` hits a surface
-- with normal `nor`. The outgoing ray starts at the same position `pos`.
def sampleReflection ((nor, surf):OrientedSurface) ((pos, dir):Ray) (k:Key) : Ray =
  case surf of
    -- Matte: pick a random direction in the hemisphere around the normal.
    Matte _ -> (pos, sampleCosineWeightedHemisphere nor k)
    -- TODO: surely there's some change-of-solid-angle correction we need to
    -- consider when reflecting off a curved surface.
    -- Mirror: reflect the incoming direction about the normal.
    Mirror ->
      twiceProjection = 2.0 * dot dir nor
      (pos, dir - twiceProjection .* nor)
-- Weight for reflecting into the direction of the outgoing ray. The incoming
-- ray is ignored. Matte surfaces give the clamped cosine between the normal
-- and the outgoing direction; mirrors always give 0.0.
def probReflection ((nor, surf):OrientedSurface) (_:Ray) ((_, outRayDir):Ray) : Float =
  case surf of
    Mirror -> 0.0  -- TODO: this should be a delta function of some sort
    Matte _ -> relu (dot nor outRayDir)
-- Component-wise product of a filter and a radiance.
def applyFilter (filter:Filter) (radiance:Radiance) : Radiance =
  for channel. filter.channel * radiance.channel
-- Updates the running filter after bouncing off `surf`: a matte surface
-- multiplies its colour in component-wise, a mirror leaves the filter as is.
def surfaceFilter (filter:Filter) (surf:Surface) : Filter =
  case surf of
    Mirror -> filter
    Matte color -> for channel. filter.channel * color.channel
-- Distance from `pos` to `obj`.
--   Wall:   offset plus the projection of `pos` onto the wall direction.
--   Block:  `pos` is taken relative to the block centre and passed through
--           `rotateY`; the result is the length of the per-axis overshoot
--           beyond the half-widths, clamped at zero.
--   Sphere: distance to the centre minus the radius, clamped at zero.
--   Light:  treated like a block with half-widths [hw, 0.01, hw] and no rotation.
def sdObject (pos:Position) (obj:Object) : Distance =
  case obj of
    PassiveObject geom _ ->
      case geom of
        Wall nor d -> dot nor pos + d
        Block blockPos halfWidths angle ->
          local = rotateY (pos - blockPos) angle
          length $ for i. max ((abs local.i) - halfWidths.i) 0.0
        Sphere spherePos r ->
          offset = pos - spherePos
          max (length offset - r) 0.0
    Light squarePos hw _ ->
      local = pos - squarePos
      lightHalfWidths = [hw, 0.01, hw]
      length $ for i. max ((abs local.i) - lightHalfWidths.i) 0.0
-- Finds the object with the smallest `sdObject` distance from `pos` and
-- returns it together with that distance.
def sdScene {n} (scene:Scene n) (pos:Position) : (Object & Distance) =
  (MkScene objs) = scene
  candidates = for j. (j, sdObject pos objs.j)
  (nearest, dist) = minimum_by snd candidates
  (objs.nearest, dist)
-- Surface normal of `obj` at `pos`: the normalized gradient of the object's
-- distance function with respect to position.
def calcNormal (obj:Object) (pos:Position) : Direction =
  distanceTo = \p. sdObject p obj
  normalize $ grad distanceTo pos
-- Outcome of marching a single ray through the scene.
data RayMarchResult =
  -- Hit a passive object: incident ray, surface normal, surface properties
  HitObj Ray OrientedSurface
  -- Hit a light: the light's radiance
  HitLight Radiance
  -- Could refine with failure reason (beyond horizon, failed to converge etc)
  HitNothing
-- Marches `ray` through `scene` and reports what it hits.
-- Each of at most `maxIters` iterations evaluates the scene distance at the
-- current point and advances the ray by 0.9 times that distance. The march
-- ends when the distance falls below `tol` at a surface whose normal does not
-- have a positive projection onto the ray direction. If the iteration budget
-- runs out first, the result is `HitNothing`.
def raymarch {n} (scene:Scene n) (ray:Ray) : RayMarchResult =
  maxIters = 100
  tol = 0.01
  startLength = 10.0 * tol  -- trying to escape the current surface
  (rayOrigin, rayDir) = ray
  -- The march starts `startLength` along the ray rather than at its origin.
  with_state startLength \rayLength.
    bounded_iter maxIters HitNothing \_.
      rayPos = rayOrigin + get rayLength .* rayDir
      (obj, d) = sdScene scene $ rayPos
      -- 0.9 ensures we come close to the surface but don't touch it
      rayLength := get rayLength + 0.9 * d
      case d < tol of
        False -> Continue
        True ->
          surfNorm = calcNormal obj rayPos
          case positiveProjection rayDir surfNorm of
            True ->
              -- Oops, we didn't escape the surface we're leaving..
              -- (Is there a more standard way to do this?)
              Continue
            False ->
              -- We made it!
              Done $ case obj of
                PassiveObject _ surf -> HitObj (rayPos, rayDir) (surfNorm, surf)
                Light _ _ radiance   -> HitLight radiance
-- Radiance arriving directly along `ray`: the light's intensity if the ray
-- marches into a light, and zero if it hits a passive object or nothing.
def rayDirectRadiance {n} (scene:Scene n) (ray:Ray) : Radiance =
  outcome = raymarch scene ray
  case outcome of
    HitObj _ _ -> zero
    HitNothing -> zero
    HitLight intensity -> intensity
-- Uniformly samples a point on the square of half-width `hw` in the y = 0
-- plane, centred on the origin.
def sampleSquare (hw:Float) (k:Key) : Position =
  [keyX, keyZ] = split_key k
  px = randuniform (- hw) hw keyX
  pz = randuniform (- hw) hw keyZ
  [px, 0.0, pz]
-- Estimates the radiance that reaches the surface point at the origin of
-- `inRay` directly from the lights in the scene. For every light, one point
-- is sampled on the light's square and a ray is cast towards it. Passive
-- objects contribute nothing. The same key `k` is used for every light.
def sampleLightRadiance {n} (scene:Scene n) (osurf:OrientedSurface) (inRay:Ray) (k:Key) : Radiance =
  (surfNor, _) = osurf
  (rayPos, _) = inRay
  (MkScene objs) = scene
  yield_accum (AddMonoid Float) \radiance.
    for i. case objs.i of
      PassiveObject _ _ -> ()
      Light lightPos hw _ ->
        target = lightPos + sampleSquare hw k - rayPos
        (dirToLight, distToLight) = directionAndLength target
        -- Only count the light when the direction towards it has a positive
        -- projection onto the surface normal.
        if positiveProjection dirToLight surfNor then
          -- The cosine is taken against yHat, matching the "assumed to point
          -- down" note on the Light constructor.
          cosAtLight = relu $ dot dirToLight yHat
          fracSolidAngle = cosAtLight * sq hw / (pi * sq distToLight)
          outRay = (rayPos, dirToLight)
          coeff = fracSolidAngle * probReflection osurf inRay outRay
          radiance += coeff .* rayDirectRadiance scene outRay
-- Follows `initRay` through the scene for at most `maxBounces` iterations and
-- returns the accumulated colour.
-- Two pieces of state are threaded through the loop: the running colour
-- filter (starting at all ones) and the current ray. Per iteration:
--   * hitting nothing ends the path;
--   * hitting a light ends the path, adding the light's intensity only on
--     iteration 0;
--   * hitting an object samples the direct light at the hit point, replaces
--     the current ray with a sampled reflection, updates the filter with the
--     surface, and adds the filtered light sample to the result.
def trace {n} (params:Params) (scene:Scene n) (initRay:Ray) (k:Key) : Color =
  white = [1.0, 1.0, 1.0]
  bounceLimit = get_at #maxBounces params
  yield_accum (AddMonoid Float) \radiance.
    run_state white \throughput.
      run_state initRay \current.
        bounded_iter bounceLimit () \i.
          case raymarch scene (get current) of
            HitNothing -> Done ()
            HitLight intensity ->
              if i == 0 then radiance += intensity   -- TODO: scale etc
              Done ()
            HitObj incidentRay osurf ->
              -- Fresh keys per bounce: one for the light sample, one for the
              -- reflection direction.
              [lightKey, bounceKey] = split_key $ hash k i
              lightRadiance = sampleLightRadiance scene osurf incidentRay lightKey
              current := sampleReflection osurf incidentRay bounceKey
              throughput := surfaceFilter (get throughput) (snd osurf)
              radiance += applyFilter (get throughput) lightRadiance
              Continue
-- Pinhole camera description. Assumes we're looking towards -z.
Camera =
  { numPix     : Nat       -- image size, in pixels per side
  & pos        : Position  -- where the pinhole sits
  & halfWidth  : Float     -- half the width of the sensor
  & sensorDist : Float }   -- distance from the pinhole to the sensor
-- TODO: might be better with an anonymous dependent pair for the result
-- Builds an n-by-n table of ray generators, indexed from the top-left of the
-- image. Given a key, each generator jitters its pixel's coordinates by up to
-- half a pixel in x and y and returns the ray from the camera position
-- through that point.
def cameraRays (n:Nat) (camera:Camera) : Fin n => Fin n => (Key -> Ray) =
  sensorHalf = get_at #halfWidth camera
  jitter = sensorHalf / n_to_f n
  origin = get_at #pos camera
  depth = neg (get_at #sensorDist camera)
  xs = linspace (Fin n) (neg sensorHalf) sensorHalf
  -- Rows run top to bottom, so the y coordinates are reversed.
  ys = reverse $ linspace (Fin n) (neg sensorHalf) sensorHalf
  for i j. \key.
    [kx, ky] = split_key key
    x = xs.j + randuniform (-jitter) jitter kx
    y = ys.i + randuniform (-jitter) jitter ky
    (origin, normalize [x, y, depth])
-- Renders `scene` through `camera` into a square image of `numPix` pixels per
-- side. Each pixel averages `numSamples` traced rays. With `shareSeed` set,
-- every pixel uses the same root key; otherwise each pixel derives its own key
-- from its row and column. The result is divided by its mean channel value.
def takePicture {m} (params:Params) (scene:Scene m) (camera:Camera) : Image =
  n = get_at #numPix camera
  rays = cameraRays n camera
  rootKey = new_key 0
  image = for i j.
    pixKey = if get_at #shareSeed params
      then rootKey
      else ixkey (ixkey rootKey i) j
    sampleRayColor : Key -> Color = \k.
      [k1, k2] = split_key k
      trace params scene (rays.i.j k1) k2
    sampleAveraged sampleRayColor (get_at #numSamples params) pixKey
  avgBrightness = mean (for (i,j,k). image.i.j.k)
  -- Only normalise when the mean is strictly positive. An all-black or empty
  -- render would otherwise divide by zero and turn every pixel into NaN.
  normalized = if avgBrightness > 0.0
    then image / avgBrightness
    else image
  MkImage _ _ normalized

Define the scene and render it

-- Colours used by the scene below, as three-component Float tables.
lightColor = [0.2, 0.2, 0.2]
-- The two side-wall colours are scaled up by 1.5.
leftWallColor = 1.5 .* [0.611, 0.0555, 0.062]
rightWallColor = 1.5 .* [0.117, 0.4125, 0.115]
-- These two are written as 0-255 values and divided by 255.
whiteWallColor = [255.0, 239.0, 196.0] / 255.0
blockColor = [200.0, 200.0, 255.0] / 255.0
-- The scene: one light, five matte walls (left, right, two white walls along
-- the y axis, and one white wall along z), a matte block, a matte sphere and
-- a mirror sphere.
theScene = MkScene $
    [ Light (1.9 .* yHat) 0.5 lightColor
    , PassiveObject (Wall xHat 2.0) (Matte leftWallColor)
    , PassiveObject (Wall (neg xHat) 2.0) (Matte rightWallColor)
    , PassiveObject (Wall yHat 2.0) (Matte whiteWallColor)
    , PassiveObject (Wall (neg yHat) 2.0) (Matte whiteWallColor)
    , PassiveObject (Wall zHat 2.0) (Matte whiteWallColor)
    , PassiveObject (Block [ 1.0, -1.6, 1.2] [0.6, 0.8, 0.6] 0.5) (Matte blockColor)
    , PassiveObject (Sphere [-1.0, -1.2, 0.2] 0.8) (Matte (0.7.* whiteWallColor))
    , PassiveObject (Sphere [ 2.0, 2.0, -2.0] 1.5) (Mirror)
    ]
-- Default render settings: 50 samples per pixel, at most 10 bounces, and one
-- PRNG key shared by all pixels.
defaultParams = { numSamples = 50 , maxBounces = 10 , shareSeed = True }
-- Default camera: 250 pixels per side, placed 10 units along +z.
defaultCamera = { numPix = 250 , pos = 10.0 .* zHat , halfWidth = 0.3 , sensorDist = 1.0 }
-- We change to a small num pix here to reduce the compute needed for tests
params = defaultParams
-- In test mode the camera renders 10 pixels per side instead of the default.
camera = if dex_test_mode () then defaultCamera |> set_at #numPix 10 else defaultCamera
-- %time
-- Render the scene and unpack the pixel table from the resulting Image.
(MkImage _ _ image) = takePicture params theScene camera
-- Display the render (`imshow` presumably comes from the `plot` import — verify).
:html imshow image

Just for fun, here's what we get with a single sample per pixel (sharing the PRNG key among pixels):

-- Same settings as the defaults, but with a single sample per pixel.
-- `shareSeed` stays True, so every pixel uses the same key.
params2 = defaultParams |> set_at #numSamples 1
-- Render again with the same scene and camera, then display the result.
(MkImage _ _ image2) = takePicture params2 theScene camera
:html imshow image2