The limits of linear-time machine learning

I recently enjoyed a talk by François Charton on “Transformers for maths, and maths for transformers” (recording). Charton investigates the application of transformer-based translation models to basic mathematical tasks, ranging from arithmetic to integer sequence completion and linear algebra. It is important to note that each of these problems is encoded symbolically, not numerically. For instance, a natural number is fed to the transformer as a sequence of symbols representing its digits (in some base), and these symbols have no prior meaning to the model. The task of the model is to translate the sequence of symbols that encodes the problem into the sequence of symbols that encodes the solution, and the model is trained on generated pairs of example input and output “sentences”. Remarkably, the transformer excels at some of these tasks, e.g. determining the eigenvalues of 5×5 matrices. For arithmetic tasks, the transformer has no trouble determining which of a pair of rational numbers is the larger, but it absolutely refuses to learn to compute their sum, or even to simplify a fraction. Charton further investigates the arithmetic inabilities of the transformer in “Can transformers learn the greatest common divisor?”.
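To make the symbolic encoding concrete, here is a minimal sketch of how a GCD training pair might be tokenized. The digit-per-token scheme follows the description above; the `<sep>` separator token and the function names are my own illustrative choices, not Charton’s actual vocabulary.

```python
import math

def encode_number(n: int, base: int = 10) -> list[str]:
    """Encode a natural number as a sequence of digit tokens, most significant first."""
    digits = []
    while n:
        digits.append(str(n % base))
        n //= base
    return list(reversed(digits)) or ["0"]

def encode_gcd_example(a: int, b: int) -> tuple[list[str], list[str]]:
    """Build one (input sentence, output sentence) training pair for a GCD task.
    The '<sep>' separator token is a hypothetical vocabulary choice."""
    source = encode_number(a) + ["<sep>"] + encode_number(b)
    target = encode_number(math.gcd(a, b))
    return source, target

# e.g. encode_gcd_example(48, 36) == (['4', '8', '<sep>', '3', '6'], ['1', '2'])
```

The model sees only these digit tokens; nothing in the architecture tells it that `'4'` followed by `'8'` denotes forty-eight.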

After the talk, it occurred to me that the inability of a transformer model to learn the GCD is a problem of expressibility (the model can’t express the algorithm), not of learnability (the model could express it, but can’t learn it). This is because any transformer-based algorithm runs in time linear in the input length plus the output length, whereas the run-time of the best known algorithms for computing the GCD is super-linear in the input length (quadratic for the Euclidean algorithm, quasilinear for the asymptotically fastest known methods).
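To state the mismatch explicitly (the fast-GCD bound is the standard $O(M(l)\log l)$ result, where $M(l)$ is the cost of multiplying two $l$-digit numbers):

$$
T_{\text{transformer}}(l) = O(l), \qquad
T_{\text{Euclid}}(l) = \Theta(l^2), \qquad
T_{\text{fast GCD}}(l) = O\!\left(M(l)\log l\right).
$$

Even with quasilinear multiplication, the right-hand bound remains super-linear, so no known GCD algorithm fits inside the transformer’s $O(l)$ budget.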

(Note that I’m adopting a computational model in which the transformer is allowed only finite memory, as it is in practice. Thus the size of the attention matrix is bounded, and so each step of the model is O(1). If unbounded memory were possible, each step would be linear!)
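To spell out the parenthetical, with the context-window size $W$ and model width $d$ introduced here purely for illustration: with bounded memory, the attention at generation step $t$ looks back at most $W$ tokens, whereas with unbounded memory it looks back at all $t$ of them:

$$
\text{cost of step } t \;\le\; C(W, d) = O(1) \quad \text{(bounded memory)}, \qquad
\text{cost of step } t \;=\; \Theta(t\,d) \quad \text{(unbounded memory)},
$$

so emitting $n$ tokens costs $O(n)$ in the first case and $\Theta(n^2 d)$ in the second.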

Consider the case of the Euclidean algorithm. In the worst case, the Euclidean algorithm runs in time quadratic in the length $l$ of each of its inputs (written as strings of digits in some base). But the run-time of the transformer is linear in $l$: it consumes the $\approx 2l$ input symbols, after which it must immediately proceed to produce output tokens. The sequence of output tokens (as defined by the training set) has length at most $l$ (being the digital representation of the GCD). So that’s $\approx 3l$ steps in total, each of them $O(1)$ work, which is too few. The model simply “doesn’t have enough time” to emulate the Euclidean algorithm.
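As a sanity check on this counting, here is a small Python sketch, under a crude cost model of my own (one unit of work per decimal digit touched in each division step, not a precise bit-complexity count), comparing Euclid’s work on its classical worst case, consecutive Fibonacci numbers, with the $\approx 3l$ token budget:

```python
def euclid_digit_ops(a: int, b: int) -> int:
    """Crude cost model: charge one unit per decimal digit of a at each division step
    (an illustrative assumption, not a precise bit count)."""
    ops = 0
    while b:
        ops += len(str(a))          # one long-division pass over a's digits
        a, b = b, a % b
    return ops

def fibonacci_pair(k: int) -> tuple[int, int]:
    """Consecutive Fibonacci numbers: the classical worst case for Euclid's algorithm."""
    a, b = 1, 1
    for _ in range(k):
        a, b = b, a + b
    return a, b

for k in (50, 100, 200, 400):
    a, b = fibonacci_pair(k)
    l = len(str(a))                 # input length l in decimal digits
    print(f"l = {l:3d}   Euclid work ≈ {euclid_digit_ops(a, b):6d}   token budget ≈ {3 * l:4d}")
```

The exact constants are not the point; what matters is that the first column grows roughly quadratically in $l$ while the token budget grows only linearly.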

It might be countered that we don’t really care about the Euclidean algorithm, and that we were rather interested in what the transformer would come up with; perhaps a linear-time algorithm could exist, after all. But zooming out from the GCD, it seems that it would be interesting, and not too difficult, to find other problems with proven super-linear lower bounds on their complexity, and to argue that it is therefore impossible for a transformer to solve them.

So, how would we solve it?

(The stimulus for the above train of thought came from a question that Dave Horsley asked of François after the talk.)