In this article, I'll explain why implementing numbers with just algebraic datatypes is desirable. I'll then talk about common implementations of FFT (Fast Fourier Transform) and why they hide inherent inefficiencies. I'll then show how to implement integers and complex numbers with just algebraic datatypes, in a way that is extremely simple and elegant. I'll conclude by deriving a pure functional implementation of complex FFT with just datatypes, no floats.
For most programmers, "real numbers" are a given: they just use floats and call it a day. But, in some cases, it doesn't work well, and I'm not talking about precision issues. When trying to prove statements on real numbers in proof assistants like Coq, we can't use doubles, we must formalize reals using datatypes, which can be very hard. For me, the real issue arises when I'm trying to implement optimal algorithms on HVM. To do that, I need functions to fuse, which is a fancy way of saying that (+ 2) . (+ 2) "morphs into" (+ 4) during execution - yet, machine floats block that mechanism. Because of that, I've been trying to come up with an elegant way to implement numbers with just datatypes. For natural numbers, it is easy: we can just implement them as bitstrings:
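A minimal sketch of such a bitstring type, assuming a fixed-width layout with the least significant bit first. The constructor names `E`, `O` and `I` follow the ones used later in this article; `times`, `add`, `toInt` and `fromInt` are illustrative names, not necessarily the original ones.

```haskell
-- Bitstring naturals, least significant bit first. E ends the string.
data Nat = E | O Nat | I Nat deriving (Eq, Show)

-- | Increment. The width is fixed, so an overflow is simply discarded.
inc :: Nat -> Nat
inc E     = E
inc (O x) = I x
inc (I x) = O (inc x)

-- | Applies f as many times as the bitstring says.
times :: Nat -> (a -> a) -> a -> a
times E     _ x = x
times (O n) f x = times n (f . f) x
times (I n) f x = times n (f . f) (f x)

-- | Addition by repeated increment.
add :: Nat -> Nat -> Nat
add a b = times a inc b

-- Conversion helpers, for inspection only.
toInt :: Nat -> Integer
toInt E     = 0
toInt (O x) = 2 * toInt x
toInt (I x) = 2 * toInt x + 1

fromInt :: Int -> Integer -> Nat
fromInt 0 _ = E
fromInt w n
  | even n    = O (fromInt (w - 1) (n `div` 2))
  | otherwise = I (fromInt (w - 1) (n `div` 2))
```

With 8-bit operands, `toInt (add (fromInt 8 20) (fromInt 8 22))` evaluates to 42. In Haskell, `add a b` ends up calling `inc` once per unit of `a`, so its cost grows exponentially with the number of bits in `a`.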
The program above implements addition by repeated increment, and is the simplest numerical example where fusion makes a big difference: it is exponential in Haskell and linear on HVM. But this isn't useful in practice, as we already have fast addition algorithms like add-carry. The question, then, becomes: is it possible to exploit optimality to implement numeric algorithms that actually beat existing ones in practice? The idea is that we could implement numbers with datatypes, which are usually much slower than floats, but, if the resulting asymptotics are improved, this would easily pay off. One of the most successful numeric algorithms ever is the FFT. What would an "optimal FFT" look like? Let's find out!
Designing optimal functions is extremely tricky, because it takes very little to interrupt fusion and get an exponential slowdown. For example, if we change the inc Haskell function above to let the Nat grow in size:
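One plausible reading of that change, as a sketch: the `E` case now returns `I E` instead of `E`, so a carry out of the top bit extends the bitstring. The name `incGrow` is mine, and the exact edit in the original snippet may differ.

```haskell
data Nat = E | O Nat | I Nat deriving (Eq, Show)

-- | Increment that lets the number grow by one bit on overflow.
incGrow :: Nat -> Nat
incGrow E     = I E            -- was: E (overflow discarded)
incGrow (O x) = I x
incGrow (I x) = O (incGrow x)
```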
Then, the fact that we're now using the "I" constructor twice will prevent us from fusing the corresponding function on the HVM. A single character will downgrade it from linear to exponential! There are a bazillion ways to ruin fusion: passing a state as an argument, placing a lambda in the wrong location, using an argument more than once... In fact, it feels like anything that deviates from a "direct linear recursion" will refuse to fuse. With that in mind, let's examine the textbook FFT:
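A sketch of what a textbook list-based FFT typically looks like in Haskell (radix-2 Cooley-Tukey over `Data.Complex`). This is a generic rendition written for illustration, assuming the input length is a power of two; it is not necessarily the exact snippet this article was written against.

```haskell
import Data.Complex

-- | Splits a list into its even-indexed and odd-indexed elements.
split :: [a] -> ([a], [a])
split (x : y : rest) = let (es, os) = split rest in (x : es, y : os)
split xs             = (xs, [])

-- | Radix-2 FFT. The input length must be a power of two.
fft :: [Complex Double] -> [Complex Double]
fft []  = []
fft [x] = [x]
fft xs  = zipWith (+) es ts ++ zipWith (-) es ts
  where
    n        = length xs                 -- full traversal just for the length
    (ev, od) = split xs                  -- copies the list into two halves
    es       = fft ev
    os       = fft od
    -- twiddle factors, regenerated on every call
    ws       = [cis (-2 * pi * fromIntegral k / fromIntegral n) | k <- [0 .. n `div` 2 - 1]]
    ts       = zipWith (*) ws os
```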
Note: you do not need to understand FFT to read this article - just view this as a random snippet we want to optimize.
So, what is wrong with this? Well, when it comes to optimality, everything. First, it traverses the list just to compute its length. Then, it copies the entire list 2 more times to split even/odd indices. Then, it traverses the odds list twice. Then it copies it. Then it re-generates the same twiddle factors over and over. Then it traverses everything a few more times with zips. All these things block fusion, and, if that wasn't bad enough, it performs arithmetic on machine floats in its inner loop, which removes any remaining hope we could have. So, how can it be improved?
The first insight is to replace lists by balanced binary trees, and store elements on nodes such that, by starting from any element of the tree, walking up to the root, and annotating the branches we passed through as 0 for left and 1 for right, we'll get the binary representation of the index. For example:
[0, 1, 2, 3, 4, 5, 6, 7]
Is represented as:
(B (B (B (L 0) (L 4)) (B (L 2) (L 6))) (B (B (L 1) (L 5)) (B (L 3) (L 7))))
If we start from 6, and move up until the root, we'll pass through right (1), right (1) and left (0) branches of a node, which agrees with the fact 110 is 6 in binary.
Storing elements that way has two benefits. First, we can now split a tree into even/odd indices in O(1) by just taking the left/right branches, removing the need to call the "split" function, which is O(N), on every recursive call. For example, by taking the left branch of that tree, we get:
(B (B (L 0) (L 4)) (B (L 2) (L 6)))
Which corresponds to the list [0, 2, 4, 6], i.e., the even indices of the original list. And if we take the left branch again, we get:
(B (L 0) (L 4))
Which corresponds to the list [0, 4], which is, once again, the even indices of the list above. The second benefit is that we're now able to replace the 3 calls to zipWith by a single call to mix, which combines evens, odds and the twiddle factors in a single pass. Here is the improved version:
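A sketch of a tree-based FFT along those lines, still over `Complex Double`. To keep it short I track the angle as a plain `Int` index `k` plus the weight `s` of the next branch, which is an assumption of this sketch; `fromList` and `toList` are hypothetical helpers that convert between lists and the tree layout described above, and the input size must be a power of two.

```haskell
import Data.Complex

data Tree a = L a | B (Tree a) (Tree a)

-- | FFT over a tree whose root branch selects the lowest bit of the index.
fft :: Tree (Complex Double) -> Tree (Complex Double)
fft (L x)   = L x
fft (B e o) = mix 0 1 (fft e) (fft o)

-- | Combines transformed evens and odds in one pass. k is the index
-- accumulated along the path, s is the weight of the next branch.
mix :: Int -> Int -> Tree (Complex Double) -> Tree (Complex Double) -> Tree (Complex Double)
mix k s (B a b) (B c d) = B (mix k (2 * s) a c) (mix (k + s) (2 * s) b d)
mix k s (L e)   (L o)   = B (L (e + t)) (L (e - t))
  where t = cis (-pi * fromIntegral k / fromIntegral s) * o  -- twiddle at the bottom
mix _ _ _       _       = error "mix: trees of different shape"

-- | Builds the tree layout from a non-empty list of power-of-two length.
fromList :: [a] -> Tree a
fromList []  = error "fromList: empty list"
fromList [x] = L x
fromList xs  = B (fromList (evens xs)) (fromList (odds xs))
  where
    evens (y : ys) = y : odds ys
    evens []       = []
    odds (_ : ys)  = evens ys
    odds []        = []

-- | Reads the elements back in index order.
toList :: Tree a -> [a]
toList (L x)   = [x]
toList (B a b) = interleave (toList a) (toList b)
  where
    interleave (y : ys) zs = y : interleave zs ys
    interleave []       zs = zs
```

The output tree uses the same layout as the input, so `toList (fft (fromList xs))` returns the transform in natural index order.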
The way it works is similar, with all unnecessary work removed. The fft function receives a tree with coefficients, and calls itself recursively on the left and right branches, which corresponds to even and odd coefficients on the original polynomial. Then, it combines the set of recursive points in a single pass by calling the mix function. Finally, it doesn't generate a list of twiddle factors. Instead, it stores an angle (as a natural number, ranging from 0 to N, where N is the number of points) which is then used to generate the twiddle factor at the bottom of the recursion.
While this algorithm is less pedagogical and readable, it is fully linear and way less contrived for the runtime. It is in a healthy shape for optimality, but there is still a problem: it represents complex numbers using floats, thus requires machine numbers that don't fuse. Of course, we could perform FFT over other fields, but in some cases we really need complex FFT.
Our ultimate goal is to come up with an implementation of Complex that can be used on FFT, and that is based purely on ADTs, i.e., no native floats, in such a way that is simple and direct enough to possibly fuse. Let's begin with a simpler goal: integers. Let's recall how we implemented Nat and inc:
So, what about Int? We could try implementing it by using a Nat and a sign:
But this is actually a bad idea since, by pattern-matching on the boolean to treat the cases separately, we inhibit fusion. Also, calling isMinusOne makes it non-linear, which also inhibits fusion. In general, this is exactly the shape of code that does not fuse. We're looking for something more uniform, that doesn't separate the sign.
Fortunately, there is a pretty elegant way to do it: balanced ternary. It is similar to ternary numbers, except its digits are not 0,1,2, but -1,0,1. The -1 digit is usually written as T, so, for example, the string 11T represents 9 + 3 - 1, which is 11. Amazingly, all integers can be uniquely represented in this system. Let's count:
num | trits
--- | -----
−13 | 𝖳𝖳𝖳
−12 | 𝖳𝖳0
−11 | 𝖳𝖳1
−10 | 𝖳0𝖳
−9 | 𝖳00
−8 | 𝖳01
−7 | 𝖳1𝖳
−6 | 𝖳10
−5 | 𝖳11
−4 | 𝖳𝖳
−3 | 𝖳0
−2 | 𝖳1
−1 | 𝖳
0 | 0
1 | 1
2 | 1𝖳
3 | 10
4 | 11
5 | 1𝖳𝖳
6 | 1𝖳0
7 | 1𝖳1
8 | 10𝖳
9 | 100
10 | 101
11 | 11𝖳
12 | 110
13 | 111
Beautiful, isn't it? Its symmetric properties and simple arithmetic are quite important for fusion. Here is an implementation of Int and inc:
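A minimal sketch of balanced ternary integers, assuming a fixed-width layout with the least significant trit first. I call the type `BInt` to avoid clashing with the Prelude's `Int`; the constructor names `M`, `Z`, `P` (for the digits T, 0, 1) and the helpers `dec`, `neg`, `zero` and `toInt` are illustrative.

```haskell
-- Balanced ternary, least significant trit first.
-- M, Z and P stand for the digits -1 (written T), 0 and +1.
data BInt = E | M BInt | Z BInt | P BInt deriving (Eq, Show)

-- | Increment. The width is fixed, so an overflow wraps around.
inc :: BInt -> BInt
inc E     = E
inc (M x) = Z x
inc (Z x) = P x
inc (P x) = M (inc x)   -- 1 + 1 = 1T: write T, carry 1

-- | Decrement, the mirror image of inc.
dec :: BInt -> BInt
dec E     = E
dec (M x) = P (dec x)   -- T - 1 = T1: write 1, borrow 1
dec (Z x) = M x
dec (P x) = Z x

-- | Negation just swaps T and 1; no sign bit is involved.
neg :: BInt -> BInt
neg E     = E
neg (M x) = P (neg x)
neg (Z x) = Z (neg x)
neg (P x) = M (neg x)

-- | Zero with the given number of trits.
zero :: Int -> BInt
zero w = iterate Z E !! w

-- | Conversion to a machine integer, for inspection only.
toInt :: BInt -> Integer
toInt E     = 0
toInt (M x) = 3 * toInt x - 1
toInt (Z x) = 3 * toInt x
toInt (P x) = 3 * toInt x + 1
```

For example, `map toInt (take 5 (iterate inc (zero 3)))` gives `[0,1,2,3,4]`. Every constructor is handled by one uniform clause, with no case split on a sign.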
As you can imagine, this representation of Int, and its arithmetic operations, can be implemented on HVM in a way that fuses nicely, solving part of the problem. But what about real and complex numbers?
Sadly, it took me a long time to figure out the proper way to do Int, which is astronomically simpler than these. Just thinking of the complexity of floats (which include sign, mantissa, exponent, NaNs...), Dedekind cuts, Cauchy sequences and the like gives the feeling that we'll never find a representation of real numbers that is as nice and uniform as the Int type above. Yet, for the sake of FFT, there is some light at the end of the tunnel.
The insight is that we don't actually need all complex numbers; we just need enough of them to represent roots of unity. In fact, if we wanted to perform FFT on lists of 2 elements, integers would be enough for us, since we only need to evaluate polynomials on 2 complex points, [1, -1], which are both just ints!
But what if there were 4 elements? In this case, we would actually need 4 points: [1, i, -1, -i]... which is beyond the set of integers. But there is a common extension of integers that could help: Gaussian Integers, which add the square root of -1, i, to Int. It forms a sort of quantized grid over the complex plane, as follows:
We could implement that type as follows:
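A minimal sketch of such a type, using the Prelude's `Int` for the two components. The names `Gauss`, `add`, `mul`, `mulI` and `roots4` are illustrative.

```haskell
-- | G a b stands for a + b*i.
data Gauss = G Int Int deriving (Eq, Show)

add :: Gauss -> Gauss -> Gauss
add (G a b) (G c d) = G (a + c) (b + d)

-- | (a + b*i) * (c + d*i), using i*i = -1.
mul :: Gauss -> Gauss -> Gauss
mul (G a b) (G c d) = G (a * c - b * d) (a * d + b * c)

-- | Multiplication by i is just a swap and a negation.
mulI :: Gauss -> Gauss
mulI (G a b) = G (negate b) a

-- | The four roots of unity: 1, i, -1, -i.
roots4 :: [Gauss]
roots4 = take 4 (iterate mulI (G 1 0))
```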
This would give us 4 roots of unity, and, thus, the ability to perform FFT on lists of 4 elements!
What about 8 elements though? Well, the idea here is to simply keep going, and add a new constant, j, which is the square root of i. Once we do that, we need 4 ints to represent a number; that's because ij is also part of the set, so any number can be represented as A + B*i + C*j + D*ij. We can implement this extended Gaussian Integer as:
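A sketch of that type, with the field order A, B, C, D from the formula above. `mulJ`, `roots8` and `toComplex` are hypothetical helpers; `toComplex` uses floats only to inspect values and assumes j is the principal square root of i, i.e. (1 + i)/√2.

```haskell
import Data.Complex

-- | G a b c d stands for a + b*i + c*j + d*ij, where j*j = i.
data ExtGauss = G Int Int Int Int deriving (Eq, Show)

-- | Multiplication by j:
-- j*(a + b*i + c*j + d*ij) = -d + c*i + a*j + b*ij.
mulJ :: ExtGauss -> ExtGauss
mulJ (G a b c d) = G (negate d) c a b

-- | The eight roots of unity: 1, j, i, ij, -1, -j, -i, -ij.
roots8 :: [ExtGauss]
roots8 = take 8 (iterate mulJ (G 1 0 0 0))

-- | Floating-point view, for inspection only.
toComplex :: ExtGauss -> Complex Double
toComplex (G a b c d) =
  fromIntegral a + fromIntegral b * i + fromIntegral c * j + fromIntegral d * i * j
  where
    i = 0 :+ 1
    j = cis (pi / 4)
```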
This would give us 8 roots of unity, and, thus, the ability to perform FFT on lists of 8 elements.
And it would allow us to represent "fractional" complex numbers with just integers. For example, taking j = (1 + i)/√2, the number 1 + 2*i + 3*j + 4*ij is approximately 0.293 + 6.950 * i, and it can be constructed as:
num :: ExtGauss
num = G 1 2 3 4
As you can see, we can keep going and generalize this to make the set G(N), with G(0) being Integers, G(1) being Gaussian Integers, G(2) being Extended Gaussian Integers, and so on. To implement it, we can use a tree:
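A sketch of one such tree, under the assumption that a node `C a b` stands for a + b*w, where w is the smallest base of that level (i for G(1), j for G(2), and so on). The constructor names `S` and `C` and the helpers `depth` and `toComplex` are mine; `toComplex` uses floats for inspection only and takes every base to be the principal square root of the previous one.

```haskell
import Data.Complex

-- | An element of G(N) is a perfect binary tree of depth N with Int leaves.
-- C a b stands for a + b*w, with w the smallest base of this level.
data GN = S Int | C GN GN deriving (Eq, Show)

depth :: GN -> Int
depth (S _)   = 0
depth (C a _) = 1 + depth a

-- | Floating-point view. The base of a depth-n tree is e^(i*pi/2^n),
-- and each level below uses the square of the base above it.
toComplex :: GN -> Complex Double
toComplex g = go (pi / 2 ^ depth g) g
  where
    go _ (S x)   = fromIntegral x
    go t (C a b) = go (2 * t) a + cis t * go (2 * t) b
```

Under this layout, the G(2) number 1 + 2*i + 3*j + 4*ij is the tree `C (C (S 1) (S 2)) (C (S 3) (S 4))`, since (1 + 2*i) + (3 + 4*i)*j expands to the same sum.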
For G(3), we have 16 roots of unity, and can perform FFT on lists of 16 elements:
To add two complex numbers of GN, we just do so pairwise:
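A sketch of pairwise addition, assuming a tree type with `Int` leaves. The constructor names `S` (scalar) and `C` (pair) are hypothetical.

```haskell
data GN = S Int | C GN GN deriving (Eq, Show)

-- | Adds two elements of the same G(N), leaf by leaf.
add :: GN -> GN -> GN
add (S x)   (S y)   = S (x + y)
add (C a b) (C c d) = C (add a c) (add b d)
add _       _       = error "add: operands belong to different G(N)"
```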
But what about multiplication? A very elegant algorithm, found by T6 on HOC's discord, allows us to multiply a GN number by bases like i, j, k, etc., with just rotations and negation:
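A sketch consistent with that description, assuming a node `C a b` stands for a + b*w with w*w equal to the smallest base one level down. It may differ from the original code in naming and details; `S` and `C` are hypothetical constructor names.

```haskell
data GN = S Int | C GN GN deriving (Eq, Show)

-- | Multiplies by the smallest base w of the level:
-- (a + b*w)*w = b*(w*w) + a*w, and w*w is the base one level down,
-- so the recursive call handles it. At the scalar level the base is -1.
rot :: GN -> GN
rot (S x)   = S (negate x)
rot (C a b) = C (rot b) a
```

On G(1) this gives `rot (C (S a) (S b)) = C (S (-b)) (S a)`, which is exactly multiplication by i.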
This simple function, if applied on a GN element, will multiply it by its smallest base. So, for example, on G(3), rot performs multiplication by k. On G(2), it multiplies by j. On G(1), it multiplies by i. On G(0), which is just a scalar, it negates the number. Very nice! We can use rot to multiply two GN elements as follows (once again, thanks T6):
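A sketch of such a multiplication, derived from (a + b*w)*(c + d*w) = (a*c + b*d*w*w) + (a*d + b*c)*w, where multiplying by w*w on the lower level is a single `rot`. The tree type and the names `S`, `C`, `add` and `mul` are assumptions of this sketch, and the original code may be organized differently.

```haskell
data GN = S Int | C GN GN deriving (Eq, Show)

add :: GN -> GN -> GN
add (S x)   (S y)   = S (x + y)
add (C a b) (C c d) = C (add a c) (add b d)
add _       _       = error "add: operands belong to different G(N)"

-- | Multiplies by the smallest base of the level.
rot :: GN -> GN
rot (S x)   = S (negate x)
rot (C a b) = C (rot b) a

-- | Product of two elements of the same G(N).
mul :: GN -> GN -> GN
mul (S x)   (S y)   = S (x * y)
mul (C a b) (C c d) = C (add (mul a c) (rot (mul b d)))
                        (add (mul a d) (mul b c))
mul _       _       = error "mul: operands belong to different G(N)"
```

As a check on G(1), `mul (C (S 1) (S 2)) (C (S 3) (S 4))` gives `C (S (-5)) (S 10)`, matching (1 + 2i)(3 + 4i) = -5 + 10i.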
Which is, again, very elegant. But, for the sake of FFT, we don't actually need to perform multiplication of two arbitrary numbers during the algorithm; we just need to multiply by twiddle factors on the upper side of the unit circle. As such, we can use this simplified algorithm, which receives an angle on the unit circle, represented by a Nat, and multiplies it by the respective point:
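A sketch of such a function, assuming the angle is a bitstring with the least significant bit first, where the first bit selects the smallest base of the level, the next bit selects the base one level down, and so on. The helper `lift` and the constructor names `S` and `C` are hypothetical.

```haskell
data Nat = E | O Nat | I Nat
data GN  = S Int | C GN GN deriving (Eq, Show)

-- | Multiplies by the smallest base of the level.
rot :: GN -> GN
rot (S x)   = S (negate x)
rot (C a b) = C (rot b) a

-- | Multiplies x by the point of the unit circle selected by the angle.
mul :: Nat -> GN -> GN
mul E     x = x
mul (O r) x = lift r x          -- bit 0: skip this level's base
mul (I r) x = rot (lift r x)    -- bit 1: also multiply by this level's base

-- | Applies the remaining bits one level down, on both halves.
lift :: Nat -> GN -> GN
lift _ (S x)   = S x
lift r (C a b) = C (mul r a) (mul r b)
```

On a G(3) element, the angle `I (I (I E))` therefore multiplies by k, then j, then i, which is the same as applying `rot` seven times.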
So, for example, mul (I (I (I E))) x will multiply x by ijk, which is approximately -0.923 + 0.382i. This completely removes the need for computing cos, sin, e^ix, and the like, simplifying the pt2 computation from:
To just:
In other words, all the complex floating point arithmetic, including trigonometric and exponentiation functions, are replaced by the mul function above, which is just a simple recursive pass over the tree structure!
After all this, we're finally able to implement a complete FFT algorithm, in 40 lines of Haskell code, with just plain ADTs:
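A sketch of how the pieces can fit together, written under my own assumptions rather than copied from the original file: values live in G(m) with 2^(m+1) equal to the number of points, twiddles use the upper half of the unit circle (so the kernel is e^(+2πi·jk/n)), and the angle is threaded through `mix` as a function `Nat -> Nat` so that bits can be appended at the most significant end. `embed`, `fromList`, `toList` and `toComplex` are hypothetical helpers for building inputs and inspecting results.

```haskell
import Data.Complex

data Nat    = E | O Nat | I Nat
data Tree a = L a | B (Tree a) (Tree a)
data GN     = S Int | C GN GN deriving (Eq, Show)

add, sub :: GN -> GN -> GN
add (S x)   (S y)   = S (x + y)
add (C a b) (C c d) = C (add a c) (add b d)
add _       _       = error "add: mismatched G(N)"
sub (S x)   (S y)   = S (x - y)
sub (C a b) (C c d) = C (sub a c) (sub b d)
sub _       _       = error "sub: mismatched G(N)"

rot :: GN -> GN                       -- multiply by the smallest base
rot (S x)   = S (negate x)
rot (C a b) = C (rot b) a

mul :: Nat -> GN -> GN                -- multiply by a twiddle factor
mul E     x = x
mul (O r) x = lift r x
mul (I r) x = rot (lift r x)

lift :: Nat -> GN -> GN
lift _ (S x)   = S x
lift r (C a b) = C (mul r a) (mul r b)

fft :: Tree GN -> Tree GN
fft = go id
  where
    go _   (L x)   = L x
    -- one level down, angles gain a leading zero bit
    go pre (B e o) = mix pre (go (pre . O) e) (go (pre . O) o)

mix :: (Nat -> Nat) -> Tree GN -> Tree GN -> Tree GN
mix ang (B a b) (B c d) = B (mix (ang . O) a c) (mix (ang . I) b d)
mix ang (L e)   (L o)   = B (L (add e t)) (L (sub e t)) where t = mul (ang E) o
mix _   _       _       = error "mix: mismatched trees"

-- Helpers for building inputs and viewing outputs.
embed :: Int -> Int -> GN             -- integer x as an element of G(n)
embed 0 x = S x
embed n x = C (embed (n - 1) x) (embed (n - 1) 0)

fromList :: [a] -> Tree a             -- non-empty, power-of-two length
fromList []  = error "fromList: empty list"
fromList [x] = L x
fromList xs  = B (fromList (evens xs)) (fromList (odds xs))
  where
    evens (y : ys) = y : odds ys
    evens []       = []
    odds (_ : ys)  = evens ys
    odds []        = []

toList :: Tree a -> [a]
toList (L x)   = [x]
toList (B a b) = interleave (toList a) (toList b)
  where
    interleave (y : ys) zs = y : interleave zs ys
    interleave []       zs = zs

toComplex :: GN -> Complex Double     -- floats for inspection only
toComplex g = go (pi / 2 ^ depth g) g
  where
    depth (S _)   = 0 :: Int
    depth (C a _) = 1 + depth a
    go _ (S x)   = fromIntegral x
    go t (C a b) = go (2 * t) a + cis t * go (2 * t) b
```

For 8 points, the inputs can be built with `fromList (map (embed 2) xs)` and the result inspected with `map toComplex (toList (fft tree))`. Because of the sign convention, the output is the conjugate-kernel (unnormalized inverse) DFT of the input.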
Yes, this is a complete implementation! Note that in this version I'm actually using the primitive Haskell Int, but, as we've established, it can be implemented with just ADTs via balanced ternary. Here is a complete Haskell file which performs FFT using GNs, with some helper functions to convert it from and to lists of normal complex numbers, for visualization.
So, how does this behave on the HVM? I don't know. I've just finished implementing it on Haskell and will spend some time trying to adjust it for the HVM in the future. There are many things to tweak. For example, sub isn't necessary and can easily be removed; Nat angles can be replaced by a twiddle multiplier; mix could be simplified, perhaps. But, as is, this gets rid of most of the inefficiencies with common FFT, and, as far as I know, is the cleanest version of FFT in a pure functional sense. If you like it, feel encouraged to join the Higher Order Community to talk about it and ask any questions. Thanks for reading!
Edit: read the comments below for a cool surprise :)