Question
This implementation of Fibonacci is easy to understand but very slow:
fib 0 = 0
fib 1 = 1
fib n = fib (n-1) + fib (n-2)
The following implementation of Fibonacci is hard to understand but super fast. It calculates the 100,000th Fibonacci number instantly on my laptop.
fib = fastFib 1 1
fastFib _ _ 0 = 0
fastFib _ _ 1 = 1
fastFib _ _ 2 = 1
fastFib a b 3 = a + b
fastFib a b c = fastFib (a + b) a (c - 1)
What magic is happening in the latter implementation, and how does it work?
Answer 1:
Why is the first implementation slow?
Well, it's slow because each call to fib may result in up to two calls to fib (the average is more like 1.6), so to compute fib 5 you call fib 4 and fib 3, which respectively call fib 3 and fib 2, and fib 2 and fib 1. So we can see that each call to fib (n+1) results in something like twice as much work as calling fib n.
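To make the "about twice as much work" claim concrete, here is a small sketch that counts how many calls the naive definition makes. The helper names naiveFib and calls are hypothetical. The count obeys the same recurrence as fib plus one extra call per node, which works out to 2 * fib (n+1) - 1, so it grows by a factor of roughly 1.6 (the golden ratio) per step.

```haskell
-- The naive definition from the question.
naiveFib :: Int -> Integer
naiveFib 0 = 0
naiveFib 1 = 1
naiveFib n = naiveFib (n - 1) + naiveFib (n - 2)

-- Hypothetical helper: total number of calls naiveFib n makes, counting itself.
-- Base cases make exactly one call; otherwise one call plus both subtrees.
calls :: Int -> Integer
calls 0 = 1
calls 1 = 1
calls n = 1 + calls (n - 1) + calls (n - 2)
```

In GHCi, map calls [0 .. 6] gives [1,1,3,5,9,15,25], and each value equals 2 * naiveFib (n + 1) - 1.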
One thing we might observe is that we work out the same thing lots of times; e.g. above we work out fib 3 twice. That could take a long time if you had to work out, e.g., fib 100 twice.
How to do fib faster?
I think it's better to start with this than to jump straight into fastFib. If I asked you to compute the tenth Fibonacci number by hand, I expect you wouldn't be computing the third one dozens of times by applying the algorithm. You would probably remember what you had so far. Indeed, one could do the same in Haskell: just write a program to generate the list of Fibonacci numbers (lazily) and index into it:
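mediumFib :: Int -> Integer
mediumFib = (\n -> fibs !! n)
  where -- fibs is bound outside the lambda, so it is shared between calls
    fibs = 0 : 1 : [mediumFib (i - 1) + mediumFib (i - 2) | i <- [2 ..]]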
This is much faster, but it has drawbacks: it uses a lot of memory to store the list of Fibonacci numbers, and it is slow to find the nth element of a list because you have to follow a lot of pointers.
Computing a single Fibonacci number from scratch (i.e. with none computed already) takes quadratic time, because each new element indexes back into the list.
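For comparison, a common Haskell idiom (not part of the question or this answer) builds the same lazy list but defines each element directly from the two before it instead of re-indexing into the list. Generating the list up to n is then linear in the number of additions, although the memory cost of keeping the list and the pointer-chasing of !! remain. A sketch, with fibs and listFib as hypothetical names:

```haskell
-- Infinite lazy list: each element is the sum of the previous two,
-- obtained by zipping the list with its own tail.
fibs :: [Integer]
fibs = 0 : 1 : zipWith (+) fibs (tail fibs)

-- Index into the shared list; each element is computed at most once.
listFib :: Int -> Integer
listFib n = fibs !! n
```

For example, take 10 fibs evaluates to [0,1,1,2,3,5,8,13,21,34] and listFib 10 to 55.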
Another way you might compute the tenth Fibonacci number by hand is to write down the Fibonacci sequence until you get to the tenth element. Then you never need to look far into the past or remember everything you previously computed; you just need to look at the two previous elements. One can imagine an imperative algorithm to do this:
fib(n):
    if (n < 2) return n
    preprevious = 0
    previous = 1
    i = 2
    while true:
        current = preprevious + previous
        if (i == n) return current
        preprevious, previous = previous, current
        i = i + 1
This is just stepping through the recurrence relation:
f_n = f_(n-2) + f_(n-1)
Indeed we can write it in Haskell too:
fastFib n
  | n < 2     = n
  | otherwise = go 0 1 2
  where
    -- pp = fib (i-2), p = fib (i-1); stop when i reaches n
    go pp p i
      | i == n    = pp + p
      | otherwise = go p (pp + p) (i + 1)
This is pretty fast now, and we can transform it into your function too. Here are the steps:
- Swap the argument order of pp (preprevious) and p (previous).
- Instead of counting i up, start at n and count down.
- Extract go into a top-level function, because it no longer depends on n.
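Applying those three steps mechanically gives something like the sketch below; stepFib and derivedFib are hypothetical names. After the swap, a is the previous number and b the one before it, and the counter c starts at n + 1 and stops at 3, so it equals n - i + 3 for the original counter i. Taking one more step from the starting state (1, 0, n + 1) gives (1, 1, n), which is exactly where the question's fib = fastFib 1 1 starts.

```haskell
-- Steps 1-3 applied to go: arguments swapped (a = previous, b = preprevious),
-- counter c runs down from n + 1 and stops at 3 (c = n - i + 3).
stepFib :: Integer -> Integer -> Integer -> Integer
stepFib a b c
  | c == 3    = a + b
  | otherwise = stepFib (a + b) a (c - 1)

derivedFib :: Integer -> Integer
derivedFib n
  | n < 2     = n
  | otherwise = stepFib 1 0 (n + 1)

-- The question's version, for comparison: it simply starts one step later,
-- at (1, 1, n) instead of (1, 0, n + 1).
fastFib :: Integer -> Integer -> Integer -> Integer
fastFib _ _ 0 = 0
fastFib _ _ 1 = 1
fastFib _ _ 2 = 1
fastFib a b 3 = a + b
fastFib a b c = fastFib (a + b) a (c - 1)
```

In GHCi, derivedFib n and fastFib 1 1 n agree for every n >= 0.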
This algorithm only needs to do one sum in each step, so it takes linear time, and that's pretty fast. Computing fib (n+1) is only a small constant amount of work more than computing fib n. Compare this to above, where it was about 1.6 times as much work.
Is there a faster fib?
Sure there is. It turns out there's a clever way to express the Fibonacci sequence. We consider the transformation a,b -> a+b,a to be a special case of a family of transformations T_pq:
T_pq : a -> bq + aq + ap
       b -> bp + aq
Specifically, it is the special case where p = 0 and q = 1. We can now do some algebra to work out whether there is a simple way to express applying T_pq twice:
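T_pq T_pq : a -> (bp + aq)q + (bq + aq + ap)(q + p)
            b -> (bp + aq)p + (bq + aq + ap)q
          = a -> (b + a)(q^2 + 2pq) + a(q^2 + p^2)
            b -> b(q^2 + p^2) + a(q^2 + 2pq)
          = T_(q^2 + p^2),(q^2 + 2pq)
So applying T_pq twice is the same as applying T_p'q' once, with p' = q^2 + p^2 and q' = q^2 + 2pq.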
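A quick way to gain confidence in this algebra is to check the identity numerically. The sketch below tests it over a small grid of integers; it is a sanity check, not a proof. The names t (for T_pq acting on the pair (a, b)) and squaredMatches are hypothetical.

```haskell
-- Hypothetical encoding of T_pq acting on the pair (a, b).
t :: Integer -> Integer -> (Integer, Integer) -> (Integer, Integer)
t p q (a, b) = (b*q + a*q + a*p, b*p + a*q)

-- Claim: applying T_pq twice equals applying T_p'q' once,
-- with p' = p^2 + q^2 and q' = q^2 + 2pq.
squaredMatches :: Integer -> Integer -> Integer -> Integer -> Bool
squaredMatches p q a b =
  t p q (t p q (a, b)) == t (p*p + q*q) (q*q + 2*p*q) (a, b)
```

With p = 0 and q = 1, t 0 1 (a, b) is (a + b, a), the plain Fibonacci step; for instance t 0 1 (5, 3) is (8, 5).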
So now let's write a simple function to compute T_pq^n (a,b) and fib n:
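-- tPow p q a b n applies T_pq to (a, b) n times (n >= 1).
tPow p q a b n
  | n == 1    = (b*q + a*q + a*p, b*p + a*q)
  | otherwise = let (a', b') = tPow p q a b 1 in tPow p q a' b' (n - 1)

-- Starting from (fib 1, fib 0) = (1, 0), n - 1 steps give (fib n, fib (n-1)).
fib 0 = 0
fib 1 = 1
fib n = fst $ tPow 0 1 1 0 (n - 1)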
And now we can use our relation to make tPow faster:
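tPow p q a b n
  | n == 1    = (b*q + a*q + a*p, b*p + a*q)
  -- n odd: apply T_pq once so that the remaining count is even
  | odd n     = let (a', b') = tPow p q a b 1 in tPow p q a' b' (n - 1)
  -- n even: T_pq^n = (T_pq T_pq)^(n/2) = T_p'q'^(n/2)
  | otherwise = tPow (q*q + p*p) (q*q + 2*p*q) a b (n `div` 2)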
Why is this faster? Because computing fib (2*n) is now only a small constant amount of work more than computing fib n, whereas before it was twice as much work, before that it was four times as much work, and before that it was the square of the amount of work. Indeed, the number of steps is something like the number of bits of n in binary plus the number of 1s in the binary representation of n. Computing fib 1024 takes only a couple of dozen steps, whereas the previous algorithm took about 1000. Computing the billionth Fibonacci number takes only a few dozen steps, which is a lot less than a billion.
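To check the step counts, here is a sketch of the fast tPow instrumented to also return how many times it is called; tPowCount and fibSteps are hypothetical names, and counting calls (each odd step makes an extra tPow ... 1 call) is slightly more than counting halvings and decrements.

```haskell
-- Same recursion as the fast tPow, but also returns the number of tPow calls made.
tPowCount :: Integer -> Integer -> Integer -> Integer -> Integer -> ((Integer, Integer), Int)
tPowCount p q a b n
  | n == 1    = ((b*q + a*q + a*p, b*p + a*q), 1)
  | odd n     =
      let ((a', b'), c1) = tPowCount p q a b 1
          (r, c2)        = tPowCount p q a' b' (n - 1)
      in (r, 1 + c1 + c2)
  | otherwise =
      let (r, c) = tPowCount (q*q + p*p) (q*q + 2*p*q) a b (n `div` 2)
      in (r, 1 + c)

-- fib n together with the number of tPow calls used to compute it (for n >= 2).
fibSteps :: Integer -> (Integer, Int)
fibSteps n = let ((f, _), c) = tPowCount 0 1 1 0 (n - 1) in (f, c)
```

For example, fibSteps 10 is (55, 6), and snd (fibSteps 1024) is 28: fib 1024 needs tPow 1023, and 1023 is 1111111111 in binary, the worst case for its bit length.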
Answer 2:
The magic is the reflection, reification, and explication of the computational process described by the recursive formula:
fib 0 = 0  -- NB!
fib 1 = 1
fib n = fib (n-1) + fib (n-2)
      --   n1          n2
      = let {n1 = fib (n-1) ; n2 = fib (n-2)}
        in n1 + n2
      = let {n1 = fib (n-2) + fib (n-3) ; n2 = fib (n-2)}
        --            n2          n3
        in n1 + n2
      = let {n1 = n2+n3 ; n2 = fib (n-2) ; n3 = fib (n-3)}
        in n1 + n2
      = let {n1 = n2+n3 ; n2 = fib (n-3) + fib (n-4) ; n3 = fib (n-3)}
        --                         n3          n4
        in n1 + n2
      = let {n1 = n2+n3 ; n2 = n3+n4 ; n3 = fib (n-3) ; n4 = fib (n-4)}
        in n1 + n2
      = let {n1 = n2+n3 ; n2 = n3+n4 ; n3 = n4+n5 ; n4 = fib (n-4) ; n5 = fib (n-5)}
        in n1 + n2
      = .......
That is: see it through to the end case(s), then flip the time arrow (or just read it from right to left), and explicitly code what has been going on implicitly inside the let as part of the recursion's simulated "call stack" operations.
Most importantly, replace equals by equals, aka referential transparency: use n2 in place of each appearance of fib (n-2), etc.
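Flipping the time arrow on the let chain gives an explicit bottom-up loop. A minimal sketch, with letChainFib and go as hypothetical names: the chain bottoms out at n_(n-1) = fib 1 = 1 and n_n = fib 0 = 0, and each step back up computes n_(k-1) = n_k + n_(k+1) from the two bindings below it, just as the let bindings say.

```haskell
letChainFib :: Int -> Integer
letChainFib n
  | n < 2     = fromIntegral n
  | otherwise = go (n - 1) 1 0   -- start at the bottom: n_(n-1) = fib 1, n_n = fib 0
  where
    -- go k x y carries x = n_k and y = n_(k+1), where n_k stands for fib (n - k)
    go :: Int -> Integer -> Integer -> Integer
    go 1 x y = x + y                 -- top of the chain: n1 + n2
    go k x y = go (k - 1) (x + y) x  -- n_(k-1) = n_k + n_(k+1)
```

Its shape, two accumulators plus a counter that runs down to a fixed stopping value, is the same as the question's fastFib.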
Answer 3:
I just want to make it clear that tail recursion has nothing to do with what makes the second program fast. Below, I rewrite your first program to use proper tail calls (in continuation-passing style) and compare its execution time to the second program's. I also rewrote the second one because it can be simplified quite a bit:
fib1 n = slow n id
  where
    slow 0 k = k 0
    slow 1 k = k 1
    slow n k = slow (n - 1) (\a ->
                 slow (n - 2) (\b ->
                   k (a + b)))

fib2 n = fast n 0 1
  where
    fast 0 a _ = a
    fast n a b = fast (n - 1) b (a + b)
The impact on tiny numbers like n = 10 is negligible:
fib1 10
-- 55
-- (0.01 secs, 138,264 bytes)
fib2 10
-- 55
-- (0.01 secs, 71,440 bytes)
But even around n = 20 we notice a huge fall-off in fib1 performance:
fib1 20
-- 6765
-- (0.70 secs, 8,787,320 bytes)
fib2 20
-- 6765
-- (0.01 secs, 76,192 bytes)
At n = 30, the impact is laughable. Both programs still arrive at the same result, so that's good, but fib1 takes over 30 seconds, while fib2 still takes only a fraction of a second:
fib1 30
-- 832040
-- (32.91 secs, 1,072,371,488 bytes) LOL so bad
fib2 30
-- 832040
-- (0.09 secs, 80,944 bytes)
This is because the first program, fib1, makes two recursive calls per invocation. The process for this function takes exponential time (and allocates exponentially much memory in total) as n grows. At n = 30, the slow program makes 2,692,537 calls to slow (2 * fib 31 - 1); 2^30 = 1,073,741,824 is a loose upper bound. The fast program will only recur 30 times.
At n = 1000, we run into a serious problem with fib1. Extrapolating from the performance of fib1 30 using the 2^n bound, a rough estimate is 1.041082353242204e286 years to complete 2^1000 recursive calls. Meanwhile, fib2 1000 handles 1000 recursions effortlessly:
fib2 1000
-- 43466557686937456435688527675040625802564660517371780402481729089536555417949051890403879840079255169295922593080322634775209689623239873322471161642996440906533187938298969649928516003704476137795166849228875
-- (0.13 secs, 661,016 bytes)
The original rewrite of your first program might be hard to follow with the added k parameter. Using Cont lets us see a clear sequence of steps in Haskell's familiar do notation:
import Control.Monad.Cont  -- provided by the mtl package

fib1 n = runCont (slow n) id
  where
    slow 0 = return 0
    slow 1 = return 1
    slow n = do
      a <- slow (n - 1)
      b <- slow (n - 2)
      return (a + b)
Answer 4:
It's just obfuscation to hide the fact that the input number is being used as a counter. I would hope that if you saw something like this instead, you'd understand why:
fib2 n = fastFib2 0 1 0 n

fastFib2 current previous count 0 = 0
fastFib2 current previous count 1 = 1
fastFib2 current previous count n
  | count == n = current
  | otherwise  =
      fastFib2 (current + previous) current (count + 1) n
In the code above, we've made the counter explicit: when it equals our input, n, we return our accumulator, current; otherwise, we keep track in this "forward" recursion of the current and previous numbers (the "two preceding ones"), which is all that is needed to construct the Fibonacci sequence.
The code you shared does the same thing. The (c - 1) makes it look like a more traditional "backwards" recurrence, when it's actually starting off the accumulator in the first call and then adding to it.
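To see the accumulator at work, this sketch lists the arguments the question's fastFib passes through; fastFibArgs is a hypothetical helper. The first argument climbs through the Fibonacci numbers while c merely counts down to the base case.

```haskell
-- The question's fastFib, unchanged.
fastFib :: Integer -> Integer -> Integer -> Integer
fastFib _ _ 0 = 0
fastFib _ _ 1 = 1
fastFib _ _ 2 = 1
fastFib a b 3 = a + b
fastFib a b c = fastFib (a + b) a (c - 1)

-- Hypothetical helper: the (a, b, c) arguments that fastFib 1 1 n passes
-- through, ending with the call that hits a base case (c <= 3).
fastFibArgs :: Integer -> [(Integer, Integer, Integer)]
fastFibArgs n = go 1 1 n
  where
    go a b c
      | c <= 3    = [(a, b, c)]
      | otherwise = (a, b, c) : go (a + b) a (c - 1)
```

For fib 6, fastFibArgs 6 gives [(1,1,6),(2,1,5),(3,2,4),(5,3,3)], and the final call returns 5 + 3 = 8.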
Source: https://stackoverflow.com/questions/54843670/why-this-implementation-of-fibonacci-is-extremely-fast