Levenshtein Distance

May 13, 2022

The Levenshtein distance is the minimum edit distance between two sequences, counting insertions, deletions, and substitutions each as 1.

Let's see how well Dex handles the conventional dynamic program for this. For each pair of prefixes of the input strings, the Levenshtein distance between those prefixes is the minimum obtainable from the distances for prefixes one element shorter, by inserting, deleting, or substituting an element.
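Written out as a recurrence, the dynamic program looks like the following. The notation is introduced here for illustration and is not from the Dex code: D_{i,j} is the distance between the first i elements of x and the first j elements of y, elements are numbered from 1, and [P] is 1 when P holds and 0 otherwise.

```latex
\begin{aligned}
D_{i,0} &= i, \qquad D_{0,j} = j, \\
D_{i+1,\,j+1} &= \min\bigl( D_{i,j} + [x_{i+1} \neq y_{j+1}],\;
                            D_{i,\,j+1} + 1,\;
                            D_{i+1,\,j} + 1 \bigr).
\end{aligned}
```

For inputs of lengths n and m, the distance between the full sequences is D_{n,m}.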

Here is the helper that builds the dynamic program table. Dex's flexible index sets let us encode the fact that the table is 1 larger in each dimension than the inputs. By capturing the relationship statically we avoid both programmer off-by-one errors and runtime array bounds checks.

def levenshtein_table {n m a} [Eq a] (xs: n=>a) (ys: m=>a) : (Post n => Post m => Nat) =
  yield_state (for _ _. 0) \tab.
    for i:(Post n). tab!i!first_ix := ordinal i
    for j:(Post m). tab!first_ix!j := ordinal j
    for i:n j:m.
      subst_cost = if xs.i == ys.j then 0 else 1
      d_subst = get tab!(left_post i)!(left_post j) + subst_cost
      d_delete = get tab!(left_post i)!(right_post j) + 1
      d_insert = get tab!(right_post i)!(left_post j) + 1
      tab!(right_post i)!(right_post j) := minimum [d_subst, d_delete, d_insert]
%time levenshtein_table ['k', 'i', 't', 't', 'e', 'n'] ['s', 'i', 't', 't', 'i', 'n', 'g']
[ [0, 1, 2, 3, 4, 5, 6, 7]@(Post (Fin 7))
, [1, 1, 2, 3, 4, 5, 6, 7]@(Post (Fin 7))
, [2, 2, 1, 2, 3, 4, 5, 6]@(Post (Fin 7))
, [3, 3, 2, 1, 2, 3, 4, 5]@(Post (Fin 7))
, [4, 4, 3, 2, 1, 2, 3, 4]@(Post (Fin 7))
, [5, 5, 4, 3, 2, 2, 3, 4]@(Post (Fin 7))
, [6, 6, 5, 4, 3, 3, 2, 3]@(Post (Fin 7))
]@(Post (Fin 6))
Compile time: 122.989 ms Run time: 223.400 us
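For readers who want to cross-check the table outside Dex, here is a minimal Python sketch of the same dynamic program. It is a plain-list translation written for illustration, not part of the Dex example: the indices `i` and `i + 1` stand in for `left_post i` and `right_post i`, and nothing here is checked statically the way the `Post` index sets are.

```python
def levenshtein_table(xs, ys):
    """Return the (len(xs)+1) x (len(ys)+1) table of prefix edit distances."""
    n, m = len(xs), len(ys)
    tab = [[0] * (m + 1) for _ in range(n + 1)]
    # Distance from a prefix to the empty sequence is the prefix length.
    for i in range(n + 1):
        tab[i][0] = i
    for j in range(m + 1):
        tab[0][j] = j
    for i in range(n):
        for j in range(m):
            subst_cost = 0 if xs[i] == ys[j] else 1
            d_subst = tab[i][j] + subst_cost
            d_delete = tab[i][j + 1] + 1
            d_insert = tab[i + 1][j] + 1
            tab[i + 1][j + 1] = min(d_subst, d_delete, d_insert)
    return tab


table = levenshtein_table("kitten", "sitting")
for row in table:
    print(row)
```

The bottom-right entry of the printed table is 3, matching the Dex output above.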

The actual distance is of course just the last element of the table.

def levenshtein {n m a} [Eq a] (xs: n=>a) (ys: m=>a) : Nat =
  (levenshtein_table xs ys).last_ix.last_ix
%time levenshtein (iota $ Fin 100) (iota $ Fin 100)
0
Compile time: 105.583 ms Run time: 135.700 us

Speed

To check that we don't embarrass ourselves on performance, let's run the Sountsov benchmark: compute Levenshtein distances for all pairs of arange arrays with sizes from 0 through 99.

%bench "Sountsov Benchmark"
answer = for i:(Fin 100). for j:(Fin 100).
  iint = ordinal i
  jint = ordinal j
  levenshtein (iota $ Fin iint) (iota $ Fin jint)
Sountsov Benchmark Compile time: 225.188 ms Run time: 43.331 ms (based on 47 runs)
sum(sum(answer))
333300
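The total can be sanity-checked by hand. Each input is an arange array, so the shorter one is always a prefix of the longer one and the distance for sizes i and j is just |i - j|: that many insertions suffice, and the length difference is a lower bound. A small Python sketch of the arithmetic follows; the closed form n(n^2 - 1)/3 is a standard identity for this double sum and is not something the benchmark itself computes.

```python
n = 100

# Sum of |i - j| over all pairs of sizes 0..n-1.
brute_force = sum(abs(i - j) for i in range(n) for j in range(n))

# Closed form for the same double sum.
closed_form = n * (n * n - 1) // 3

print(brute_force, closed_form)
```

Both values come out to 333300, in agreement with the Dex result.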

The straightforward C++ program for this takes about 35 ms on my workstation, so Dex performance is in the right ballpark. (And we know several optimizations that should let us close the gap.) As of this writing, native JAX takes 15 minutes, due to tracing and compiling the body 10,000 times (once for each pair of input sizes).

Real Data

Just for fun, we can make a crude spelling corrector out of this distance function on words:

(AsList ct words) = lines $ unsafe_io do read_file "/usr/share/dict/words"
def closest_word (s:String) : String =
  (AsList _ s') = s
  fst $ minimum_by snd for i.
    (AsList _ word) = words.i
    (words.i, levenshtein word s')
%time closest_word "hello"
"hello"
Compile time: 233.633 ms Run time: 153.246 ms
closest_word "kitttens"
"kittens"
closest_word "functor"
"function"
closest_word "applicative"
"application"
closest_word "monoids"
"ovoids"
closest_word "semigroup"
"subgroup"
closest_word "paralllel"
"parallel"
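The same nearest-word search can be sketched in Python for comparison. This is an illustrative translation under stated assumptions, not the Dex code: it keeps only two rows of the table instead of the full one, it takes the word list as a parameter where the Dex version reads /usr/share/dict/words, and `sample_words` is a made-up list for the demonstration. Ties go to the first candidate in the list because of how Python's `min` behaves.

```python
def levenshtein(xs, ys):
    """Edit distance between xs and ys, keeping only two rows of the table."""
    prev = list(range(len(ys) + 1))  # distances from the empty prefix of xs
    for i, x in enumerate(xs):
        curr = [i + 1]
        for j, y in enumerate(ys):
            subst_cost = 0 if x == y else 1
            curr.append(min(prev[j] + subst_cost,  # substitute
                            prev[j + 1] + 1,       # delete x
                            curr[j] + 1))          # insert y
        prev = curr
    return prev[-1]


def closest_word(s, words):
    """Return the entry of words with the smallest edit distance to s."""
    return min(words, key=lambda w: levenshtein(w, s))


# Hypothetical miniature dictionary, for illustration only.
sample_words = ["kitten", "kittens", "mitten", "parallel"]
print(closest_word("kitttens", sample_words))
print(closest_word("paralllel", sample_words))
```

Like the Dex version, this scans every word in the dictionary, so the cost grows linearly with the size of the word list.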