Continuous optimization for discrete control
In this blog post I'd like to show how differentiable optimization can be used to learn Finite State Machines (FSMs) for solving toy string processing tasks. I'll show how simple regularization and initialization techniques can steer continuous optimization towards finding discrete deterministic solutions and increase the training success rate. I think the experiments shown here may have some educational value, e.g. in demonstrating less conventional (and perhaps unexpected) uses of differentiable programming and some elegant JAX tricks.
There exists a vast literature on various approaches to FSM synthesis from data, so it's highly likely that the particular method described here is already well known. A quick review of related literature will be given in the Related Work section.
Here is an example of a toy problem we are going to solve. Consider a pair of input-output strings:
input : 01010100100111111
output: 01000100000101010
We'd like to design an algorithm that would produce the expected output string if we feed it the input string character by character. There are many different ways to perform this task. The brute force way would be to store the expected output string and emit it character by character without paying any attention to the input. This is actually a pretty complex solution: implementing such an algorithm as a state machine would require a number of states equal to the string length. Intuitively, the ideas behind Occam's razor and Kolmogorov complexity suggest that simpler solutions are preferable. Simplicity gives us hope that the solution will be interpretable and will generalize to input sequences beyond the training set.
The simplest algorithm for the string pair above scans the input and replaces every second '1' character it encounters with '0'. A type of FSM called a Mealy Machine is a natural way of representing such algorithms:
| input | state | T: state'=A | T: state'=B | R: output='0' | R: output='1' | s0 |
|-------|-------|-------------|-------------|---------------|---------------|------|
| 0 | A | 1.00 | 0.00 | 1.00 | 0.00 | 1.00 |
| 0 | B | 0.00 | 1.00 | 1.00 | 0.00 | 0.00 |
| 1 | A | 0.00 | 1.00 | 0.00 | 1.00 | 1.00 |
| 1 | B | 1.00 | 0.00 | 1.00 | 0.00 | 0.00 |
This machine starts in state "A". The label on each edge has the form "<input>/<output>": the <input> character activates the transition along the given edge, and the <output> character is emitted to the output string during the transition. Transition tables are a common way to represent FSMs. Each row of the transition table above corresponds to a different combination of input and state, and represents the new state' and the output value, encoded as two one-hot vectors. The table also includes the column s0, a one-hot vector encoding the starting FSM state (its values are repeated for both input rows for completeness). Formally, this table can be represented by three tensors with the following axes:
$$ \begin{aligned} T&:[\mathsf{input}, \mathsf{state}, \mathsf{state'}] \\ R&:[\mathsf{input}, \mathsf{state}, \mathsf{output}] \\ s^0&:[\mathsf{state}] \end{aligned} $$
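To make this concrete, here is the two-state machine from the table above written out as explicit arrays (a plain NumPy sketch; the index conventions A=0, B=1 for states and '0'=0, '1'=1 for characters are my own labeling for illustration):
import numpy as np
# The "skip every second '1'" machine from the table above.
T = np.array([[[1., 0.],    # input '0', state A -> stay in A
               [0., 1.]],   # input '0', state B -> stay in B
              [[0., 1.],    # input '1', state A -> go to B
               [1., 0.]]])  # input '1', state B -> go to A
R = np.array([[[1., 0.],    # input '0', state A -> emit '0'
               [1., 0.]],   # input '0', state B -> emit '0'
              [[0., 1.],    # input '1', state A -> emit '1'
               [1., 0.]]])  # input '1', state B -> emit '0'
s0 = np.array([1., 0.])     # start in state A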
Now we can express the FSM output and state transition computation as the following sums:
$$ \begin{aligned} y^t_{\mathsf{output}} &= \sum_{\mathsf{input},\mathsf{state}} x^t_\mathsf{input} s^t_\mathsf{state} R_{\mathsf{input}, \mathsf{state}, \mathsf{output}} \\ s^{t+1}_{\mathsf{state'}} &= \sum_{\mathsf{input},\mathsf{state}} x^t_\mathsf{input} s^t_\mathsf{state} T_{\mathsf{input}, \mathsf{state}, \mathsf{state'}} \end{aligned}$$
where x, y, and s are sequences of one-hot vectors representing inputs, outputs, and FSM states. Luckily, modern differentiable programming frameworks (I'll be using JAX in this tutorial) provide the powerful 'einsum' function to compute these kinds of expressions. It's easy to combine it with the 'scan' primitive to compute FSM outputs and intermediate states for a given input.
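A minimal sketch of such a run_fsm function, assuming an FSM namedtuple holding (T, R, s0) (the implementation in the accompanying notebook may differ):
from collections import namedtuple
import jax
import jax.numpy as jp
from jax import nn

FSM = namedtuple('FSM', 'T R s0')

def run_fsm(fsm: FSM, inputs):
  # inputs: array of shape [steps, CHAR_N] holding one-hot (or probability) vectors
  def step(s, x):
    y  = jp.einsum('x,s,xsy->y', x, s, fsm.R)  # emitted character distribution
    s1 = jp.einsum('x,s,xst->t', x, s, fsm.T)  # next-state distribution
    return s1, (y, s1)
  _, (outputs, states) = jax.lax.scan(step, fsm.s0, inputs)
  return outputs, jp.concatenate([fsm.s0[None], states])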
So far we have been representing states and characters as one-hot vectors, but we can also think of them as probability distributions. This relaxes the discrete set of one-hot vectors into a continuous space of vectors of non-negative values that sum to one. The resulting model is called a probabilistic automaton, and the math above still holds.
Given that we have represented a state machine in terms of differentiable operations, why don't we try to use gradient descent to find a machine given an input-output pair? We need to make sure that we use a parameterization that constrains the rows of $T$ and $R$, and the vector $s^0$, to be probability distributions. Usually this is achieved with the 'softmax' function: $(T, R, s^0)$ are parametrized with real-valued tensors, and softmax is applied along the last dimension to turn them into probabilities. Our toy examples use an alphabet that consists of just two characters: '0' and '1'. We assume that we don't know the required number of FSM states in advance, so in the spirit of deep learning we are going to use an overparameterized FSM representation with a maximum of 8 states.
CHAR_N = 2   # alphabet size: characters '0' and '1'
STATE_N = 8  # overparameterized upper bound on the number of FSM states

Params = namedtuple('Params', 'T R s0')

def init_fsm(key, noise=1e-3) -> Params:
  k1, k2, k3 = jax.random.split(key, 3)
  T = jax.random.normal(k1, [CHAR_N, STATE_N, STATE_N]) * noise
  R = jax.random.normal(k2, [CHAR_N, STATE_N, CHAR_N]) * noise
  s0 = jax.random.normal(k3, [STATE_N]) * noise
  return Params(T, R, s0)

def hardmax(x):
  # collapse a distribution into a one-hot vector at its argmax
  return nn.one_hot(x.argmax(-1), x.shape[-1])

def decode_fsm(params: Params, hard=False) -> FSM:
  T, R, s0 = params
  f = hardmax if hard else nn.softmax
  return FSM(f(T), f(R), f(s0))
The 'init_fsm' function creates a randomly initialized FSM. Parameters are set to very small values, so all distributions are close to uniform. 'decode_fsm' converts the parameters into FSM matrices to be consumed by 'run_fsm'. It has two operation modes: 'soft', where softmax is used, and 'hard', where all distributions collapse into one-hot vectors and the stochastic FSM becomes deterministic.
Let's try to simply minimize the sum of squared differences between the FSM output and the reference output sequence:
def loss_fn(params, x, y0):
  fsm = decode_fsm(params)         # soft (probabilistic) FSM
  y, s = run_fsm(fsm, x)
  error = jp.square(y - y0).sum()  # squared difference with the reference output
  return error
Throwing the Adam optimizer at this loss indeed makes it go down. Getting there took a bit of fiddling with the optimizer parameters: motivated by the lack of loss stochasticity and the small number of training steps (400), I use a pretty high learning rate (0.25) and set beta1=beta2=0.5.
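For completeness, here is a rough sketch of what the training setup could look like, assuming the optax library and the loss_fn defined above (the actual notebook code may differ):
import optax  # assumption: optax provides the Adam optimizer used here

# One-hot encode the training strings before feeding them to the FSM.
def encode(string):
  return nn.one_hot(jp.array([int(c) for c in string]), CHAR_N)

x_oh = encode('01010100100111111')
y0_oh = encode('01000100000101010')

opt = optax.adam(learning_rate=0.25, b1=0.5, b2=0.5)

@jax.jit
def train_step(params, opt_state):
  loss, grads = jax.value_and_grad(loss_fn)(params, x_oh, y0_oh)
  updates, opt_state = opt.update(grads, opt_state)
  return optax.apply_updates(params, updates), opt_state, loss

params = init_fsm(jax.random.PRNGKey(1))
opt_state = opt.init(params)
for _ in range(400):
  params, opt_state, loss = train_step(params, opt_state)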
Here is the diagram showing input, output and state probability distributions produced by the resulting FSM at different time steps (circle size is proportional to the probability):
We see that although the model manages to perfectly reproduce the desired output for the given input string, it uses many more states than necessary, and at some time steps (e.g. 0 or 3) the probability is spread across a number of states, so the process doesn't look deterministic. In machine learning, the usual strategy to steer the optimization towards a more desirable region of the solution space is to introduce additional regularizing objectives. In our case the "desirable solutions" are the ones that behave deterministically and use the smallest number of states. I experimented with a few ways to achieve these goals. The simplest and one of the most effective methods I found was to penalize the entropy ($H$) of the average state probability across time steps: $$ p_i = {1 \over N_\mathsf{step}}\sum_t s^t_i $$ $$ H = -\sum_i p_i \log(p_i) $$
The total loss is therefore $$ L_\mathsf{total} = L_\mathsf{error} + w_H H $$
where $w_H$ is a regularization coefficient that is set by default to 0.01. This regularization happened to be sufficient to steer the optimization towards a deterministic solution that only uses two states:
We see that states 'A' and 'H' of the learned machine are equivalent to the states 'A' and 'B' of the expected minimal FSM.
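For reference, the regularized loss might be implemented along these lines (a sketch; the small epsilon inside the logarithm is an assumed numerical-stability detail, not necessarily present in the original code):
def loss_fn(params, x, y0, w_H=0.01):
  fsm = decode_fsm(params)
  y, s = run_fsm(fsm, x)
  error = jp.square(y - y0).sum()
  p = s.mean(0)                      # average state distribution over time steps
  H = -(p * jp.log(p + 1e-8)).sum()  # entropy of the average distribution
  return error + w_H * H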
Let's check how robust the solution is and how much it depends on the initialization. The init_fsm function uses the key argument to generate the initial FSM parameters for the optimization. I wrote a Trainer class that encapsulates the task training and evaluation, and can be used like this:
x = '01010100100111111'
y0 = '01000100000101010'
trainer = Trainer(x, y0)
key = jax.random.PRNGKey(1)
result = trainer.run(key)
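The Trainer class itself isn't shown in this post; a hypothetical sketch of what it might look like, reusing optax and the helpers defined above (the returned field names error_n and state_n are my own), is given below:
class Trainer:
  def __init__(self, x, y0, step_n=400, w_H=0.01):
    encode = lambda s: nn.one_hot(jp.array([int(c) for c in s]), CHAR_N)
    self.x, self.y0 = encode(x), encode(y0)
    self.step_n, self.w_H = step_n, w_H
    self.opt = optax.adam(learning_rate=0.25, b1=0.5, b2=0.5)

  def run(self, key):
    params = init_fsm(key)
    opt_state = self.opt.init(params)
    def step(carry, _):
      params, opt_state = carry
      loss, grads = jax.value_and_grad(loss_fn)(params, self.x, self.y0, self.w_H)
      updates, opt_state = self.opt.update(grads, opt_state)
      return (optax.apply_updates(params, updates), opt_state), loss
    (params, _), log = jax.lax.scan(step, (params, opt_state), None, self.step_n)
    # evaluate the deterministic (hard-decoded) version of the learned FSM
    y, s = run_fsm(decode_fsm(params, hard=True), self.x)
    error_n = (y.argmax(-1) != self.y0.argmax(-1)).sum()
    state_n = (s.sum(0) > 0).sum()  # number of distinct states visited
    return dict(params=params, log=log, error_n=error_n, state_n=state_n)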
The returned data contains the final learned parameters, the training log, and the evaluation result (number of errors and used states) of the deterministic (decoded with hard=True) version of the learned FSM. Here I'd like to give a great shout-out to the jax.vmap function, because with its help running 100 parallel experiments with different random keys boils down to the following lines:
keys = jax.random.split(key, 100)
results = jax.vmap(trainer.run)(keys)
By default vmap replicates the nested function by introducing an extra dimension to all input and output arrays. A single-line change creates many Adam optimizers and runs many training loops in parallel. The diagram below shows the results of 100 randomly initialized experiments for four different values of $w_H$. We see that stronger regularization increases the number of runs that converge to the minimal two-state FSM, but also increases the number of erroneous runs, i.e. those that produce a machine that can't correctly process the training sequence.
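As a usage sketch, relying on the hypothetical error_n/state_n fields from the Trainer sketch above, the outcome of these runs could be summarized like this:
ok = results['error_n'] == 0              # runs whose hard FSM makes no mistakes
minimal = ok & (results['state_n'] == 2)  # runs that found the minimal 2-state machine
print('error-free runs:', int(ok.sum()), ' minimal solutions:', int(minimal.sum()))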
Entropy regularization is just one of many possible ways to steer the optimization towards desired solutions. Intuitively, machines that make fewer state transitions while processing the training sequence exhibit simpler, and thus preferable, behavior. Of course we could introduce another loss term that penalizes state transitions, but I'd like to demonstrate a different approach here. What I tried is to initialize the transition tensor $T$ with a bias towards preserving the current state and disregarding the input. I do so by adding the identity matrix, which gets broadcast across all input characters:
# inside init_fsm (lazy_bias passed as an extra argument)
T = jax.random.normal(k1, [CHAR_N, STATE_N, STATE_N]) * noise
T += jp.eye(STATE_N) * lazy_bias  # the identity is broadcast over the input axis
Setting lazy_bias=1.0 happened to dramatically change the distribution of training outcomes, leading to a higher chance of finding the minimal solution for lower values of $w_H$.
My observations showed that for simple problems the optimization is not particularly sensitive to the choice of $w_H$, so I picked the value 0.01.
To see how the procedure described above handles different problems, I created a "dataset" of 10 toy tasks and ran 1000 training attempts for each of them (thanks to JAX this takes a few seconds per task, even on a CPU). Almost all training runs found solutions that correctly process the given training sequences; only a few erroneous state machines were produced for tasks 1 and 9. The table below shows the distribution of the number of states in the discovered state machines, along with the minimal FSM diagram(s).
I rename states, assigning letters in the order of their first visit during processing of the training input. This way we can identify solutions that differ only in the ordering of states. Note that these diagrams only show nodes and edges that were traversed during processing of the training sequence; missing edges are considered undefined (e.g. the task 3 FSM doesn't have an edge corresponding to input '1' from state 'B').
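For illustration, such a renaming can be done with a tiny helper like the hypothetical one below, which takes the sequence of integer state ids visited during a run (e.g. s.argmax(-1) from a hard-decoded FSM) and assigns letters in first-visit order:
def canonical_names(visited_states):
  # Map raw state ids to letters 'A', 'B', ... in the order they are first visited.
  order = {}
  for s in visited_states:
    s = int(s)
    if s not in order:
      order[s] = chr(ord('A') + len(order))
  return order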
Task 0: skip every second '1'. Training input: 01010100100111111. Expected output: 01000100000101010
Task 1: skip every third '1'. Training input: 0101010000111111. Expected output: 0101000000110110
Task 2: emit every third '1'. Training input: 01010100111111. Expected output: 00000100001001
Task 3: emit '0's; when a '1' is met, start emitting '1's. Training input: 00010000. Expected output: 00011111
Task 4: invert bits. Training input: 01001110. Expected output: 10110001
Task 5: generate a repeating '01' sequence. Training input: 00000000. Expected output: 01010101
Task 6: replace '1' -> '111'. Training input: 00100001000. Expected output: 00111001110
Task 7: shift right by 1 bit. Training input: 0110010011001. Expected output: 0011001001100
Task 8: shift right by 2 bits. Training input: 000111010001100100. Expected output: 000001110100011001
Task 9: skip every fourth '1'. Training input: 01010100100111111. Expected output: 01010100000111011
For most problems it's easy to validate that the intended solution was indeed found. The task 8 FSM ("shift the string right by two characters") may look a bit cumbersome, but it's easy to understand what the machine is doing by renaming the states: A->00, B->01, C->11, D->10. Now the label of each state reflects the two most recent input bits, stored by the FSM and ready to be emitted.
Finally, task 9 (replace every fourth '1' with '0') happened to be an interesting example of an insufficiently specified problem. I expected the optimization to find a solution similar to problems 0 and 1: a loop of 4 states that counts the '1' characters it encounters and replaces every fourth one with '0'. It turns out that there is more than one 4-state solution for the training sequence '01010100100111111'. Optimization found a 4-state FSM 217 times, and only 84 of those were the machine I expected. If you expand the last cell of the table, which shows all discovered FSMs and their counts, you will see that we happened to be quite fortunate, because the count of the second most popular machine is smaller by just one. The other state machines have overfitted to the particular training sequence and are not going to perform correctly on other inputs. The usual rule of machine learning, that low training loss doesn't guarantee generalization, applies here as well. Additional initialization and regularization strategies may increase the chance of finding well-generalizing solutions.
The amount of literature published on almost any topic is overwhelming these days. Finite State Machine learning is not an exception, especially given that the concept of state machines lies at the very foundation of computer science. Here are just some of the vast number of relevant articles, some of which are not directly related to FSM learning, but still conceptually close. Please feel free to leave additional pointers and comments in the article discussion.
LearnLib
State machines were also shown to be useful in more complex differentiable machine learning systems
On another end Neural Turing Machines
An interesting recent work
In this tutorial we've seen that simple deterministic string-processing finite state machines can be discovered as a limiting case of differentiable stochastic state machines. I've only covered simple toy problems here, but differentiability allows incorporating FSMs into more complex end-to-end pipelines, which I'm going to talk about in future publications.
I'd like to thank my colleagues Ettore Randazzo, Eyvind Niklasson, Dominik Roblek and David Ha for their feedback, suggestions and relevant work pointers. This article was prepared using the Distill template and customized Google Docs->HTML workflow.