Why is this simple Haskell algorithm so slow?

Posted by 倖福魔咒の on 2019-12-05 12:54:53

Question


Spoiler alert: this is related to Problem 14 from Project Euler.

The following code takes around 15s to run. I have a non-recursive Java solution that runs in 1s. I think I should be able to get this code much closer to that.

import Data.List

collatz a 1  = a
collatz a x
  | even x    = collatz (a + 1) (x `div` 2)
  | otherwise = collatz (a + 1) (3 * x + 1)

main = do
  print ((foldl1' max) . map (collatz 1) $ [1..1000000])

I have profiled with +RTS -p and noticed that the allocated memory is large, and grows as the input grows. For n = 100,000, 1 GB is allocated(!); for n = 1,000,000, 13 GB(!!) is allocated.

Then again, -sstderr shows that although lots of bytes were allocated, total memory use was 1 MB and productivity was 95%+, so maybe the 13 GB is a red herring.

I can think of a few possibilities:

  1. Something isn't as strict as it needs to be. I've already discovered foldl1', but maybe I need to do more? Is it possible to mark collatz as strict (does that even make sense)?

  2. collatz isn't tail-call optimizing. I think it should be but don't know a way to confirm.

  3. The compiler isn't doing some optimizations I think it should. For instance, only two results of collatz need to be in memory at any one time (max and current).

Any suggestions?

This is pretty much a duplicate of "Why is this Haskell expression so slow?", though I will note that the fast Java solution does not have to perform any memoization. Are there any ways to speed this up without having to resort to it?

For reference, here is my profiling output:

  Wed Dec 28 09:33 2011 Time and Allocation Profiling Report  (Final)

     scratch +RTS -p -hc -RTS

  total time  =        5.12 secs   (256 ticks @ 20 ms)
  total alloc = 13,229,705,716 bytes  (excludes profiling overheads)

COST CENTRE                    MODULE               %time %alloc

collatz                        Main                  99.6   99.4


                                                                                               individual    inherited
COST CENTRE              MODULE                                               no.    entries  %time %alloc   %time %alloc

MAIN                     MAIN                                                   1           0   0.0    0.0   100.0  100.0
 CAF                     Main                                                 208          10   0.0    0.0   100.0  100.0
  collatz                Main                                                 215           1   0.0    0.0     0.0    0.0
  main                   Main                                                 214           1   0.4    0.6   100.0  100.0
   collatz               Main                                                 216           0  99.6   99.4    99.6   99.4
 CAF                     GHC.IO.Handle.FD                                     145           2   0.0    0.0     0.0    0.0
 CAF                     System.Posix.Internals                               144           1   0.0    0.0     0.0    0.0
 CAF                     GHC.Conc                                             128           1   0.0    0.0     0.0    0.0
 CAF                     GHC.IO.Handle.Internals                              119           1   0.0    0.0     0.0    0.0
 CAF                     GHC.IO.Encoding.Iconv                                113           5   0.0    0.0     0.0    0.0

And -sstderr:

./scratch +RTS -sstderr 
525
  21,085,474,908 bytes allocated in the heap
      87,799,504 bytes copied during GC
           9,420 bytes maximum residency (1 sample(s))          
          12,824 bytes maximum slop               
               1 MB total memory in use (0 MB lost due to fragmentation)  

  Generation 0: 40219 collections,     0 parallel,  0.40s,  0.51s elapsed
  Generation 1:     1 collections,     0 parallel,  0.00s,  0.00s elapsed

  INIT  time    0.00s  (  0.00s elapsed)
  MUT   time   35.38s  ( 36.37s elapsed)
  GC    time    0.40s  (  0.51s elapsed)
  RP    time    0.00s  (  0.00s elapsed)  PROF  time    0.00s  (  0.00s elapsed)
  EXIT  time    0.00s  (  0.00s elapsed)
  Total time   35.79s  ( 36.88s elapsed)  %GC time       1.1%  (1.4% elapsed)  Alloc rate    595,897,095 bytes per MUT second

  Productivity  98.9% of total user, 95.9% of total elapsed

And the Java solution (not mine, taken from the Project Euler forums with memoization removed):

public class Collatz {
  public int getChainLength( int n )
  {
    long num = n;
    int count = 1;
    while( num > 1 )
    {
      num = ( num%2 == 0 ) ? num >> 1 : 3*num+1;
      count++;
    }
    return count;
  }

  public static void main(String[] args) {
    Collatz obj = new Collatz();
    long tic = System.currentTimeMillis();
    int max = 0, len = 0, index = 0;
    for( int i = 3; i < 1000000; i++ )
    {
      len = obj.getChainLength(i);
      if( len > max )
      {
        max = len;
        index = i;
      }
    }
    long toc = System.currentTimeMillis();
    System.out.println(toc-tic);
    System.out.println( "Index: " + index + ", length = " + max );
  }
}

Answer 1:


At first, I thought you should try putting an exclamation mark before a in collatz:

collatz !a 1  = a
collatz !a x
  | even x    = collatz (a + 1) (x `div` 2)
  | otherwise = collatz (a + 1) (3 * x + 1)

(You'll need to put {-# LANGUAGE BangPatterns #-} at the top of your source file for this to work.)

My reasoning went as follows: the problem is that you're building up a massive thunk in the first argument to collatz: it starts off as 1, then becomes 1 + 1, then (1 + 1) + 1, and so on, all without ever being forced. The bang pattern makes collatz evaluate its first argument on every call, so it starts off as 1, then becomes 2, and so on, and never builds up a large unevaluated thunk: it just stays an integer.

Note that a bang pattern is just shorthand for using seq; in this case, we could rewrite collatz as follows:

collatz a _ | seq a False = undefined
collatz a 1  = a
collatz a x
  | even x    = collatz (a + 1) (x `div` 2)
  | otherwise = collatz (a + 1) (3 * x + 1)

The trick here is to force a in the guard, which then always evaluates to False (and so the body is irrelevant). Then evaluation continues with the next case, a having already been evaluated. However, a bang pattern is clearer.

Unfortunately, when compiled with -O2, this doesn't run any faster than the original! What else can we try? Well, one thing we can do is assume that the two numbers never overflow a machine-sized integer, and give collatz this type annotation:

collatz :: Int -> Int -> Int

We'll leave the bang pattern there, since we should still avoid building up thunks, even if they aren't the root of the performance problem. This brings the time down to 8.5 seconds on my (slow) computer.

The next step is to try bringing this closer to the Java solution. The first thing to realise is that in Haskell, div behaves in a more mathematically correct manner with respect to negative integers, but is slower than "normal" C division, which in Haskell is called quot. Replacing div with quot brought the runtime down to 5.2 seconds, and replacing x `quot` 2 with x `shiftR` 1 (importing Data.Bits) to match the Java solution brought it down to 4.9 seconds.

This is about as low as I can get it for now, but I think this is a pretty good result; since your computer is faster than mine, it should hopefully be even closer to the Java solution.
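
As a small illustration of the div/quot/shiftR point above, here is a minimal sketch. The helper names halveDiv, halveQuot and halveShift are made up for this example and are not part of the answer's code. It only shows how the three operations round; it says nothing about their speed.

```haskell
import Data.Bits (shiftR)

-- Three ways to halve an Int (hypothetical helper names).
halveDiv, halveQuot, halveShift :: Int -> Int
halveDiv   x = x `div` 2      -- rounds toward negative infinity
halveQuot  x = x `quot` 2     -- truncates toward zero, like C's `/`
halveShift x = x `shiftR` 1   -- arithmetic shift on Int; rounds like `div`

main :: IO ()
main = do
  -- For the non-negative values the Collatz loop sees, all three agree.
  print (map halveDiv [0, 7, 10], map halveQuot [0, 7, 10], map halveShift [0, 7, 10])
  -- They differ only on negative odd inputs: prints (-4,-3,-4).
  print (halveDiv (-7), halveQuot (-7), halveShift (-7))
```

Since x is always positive inside collatz, swapping one operation for another does not change the result here.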

Here's the final code (I did a little bit of clean-up on the way):

{-# LANGUAGE BangPatterns #-}

import Data.Bits
import Data.List

collatz :: Int -> Int
collatz = collatz' 1
  where collatz' :: Int -> Int -> Int
        collatz' !a 1 = a
        collatz' !a x
          | even x    = collatz' (a + 1) (x `shiftR` 1)
          | otherwise = collatz' (a + 1) (3 * x + 1)

main :: IO ()
main = print . foldl1' max . map collatz $ [1..1000000]

Looking at the GHC Core for this program (with ghc-core), I think this is probably about as good as it gets; the collatz loop uses unboxed integers and the rest of the program looks OK. The only improvement I can think of would be eliminating the boxing from the map collatz [1..1000000] iteration.
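
The final code above prints only the longest chain length, whereas the Java program in the question also reports which starting number produced it. A possible variation that tracks both is sketched below. longestChain is a hypothetical helper, not part of the original answer, and the sketch assumes a 64-bit Int so that 3 * x + 1 does not overflow.

```haskell
{-# LANGUAGE BangPatterns #-}

import Data.Bits (shiftR)
import Data.List (foldl1')

-- Chain length, same loop as in the answer's final code.
collatz :: Int -> Int
collatz = go 1
  where
    go :: Int -> Int -> Int
    go !a 1 = a
    go !a x
      | even x    = go (a + 1) (x `shiftR` 1)
      | otherwise = go (a + 1) (3 * x + 1)

-- Hypothetical helper: (longest chain length, starting number) over [1..limit].
-- Tuples compare lexicographically, so the length decides first;
-- ties go to the larger starting number.
longestChain :: Int -> (Int, Int)
longestChain limit = foldl1' max [ (collatz n, n) | n <- [1 .. limit] ]

main :: IO ()
main = print (longestChain 1000000)
```

For a quick sanity check, longestChain 10 is (20,9): the chain starting at 9 has 20 terms.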

By the way, don't worry about the "total alloc" figure; it's the total memory allocated over the lifetime of the program, and it never decreases even when the GC reclaims that memory. Figures of multiple terabytes are common.




Answer 2:


You could lose the list and the bang patterns and still get the same performance by using the stack instead.

import Data.List
import Data.Bits

coll :: Int -> Int
coll 0 = 0
coll 1 = 1
coll 2 = 2
coll n =
  let a = coll (n - 1)
      collatz a 1 = a
      collatz a x
        | even x    = collatz (a + 1) (x `shiftR` 1)
        | otherwise = collatz (a + 1) (3 * x + 1)
  in max a (collatz 1 n)


main = do
  print $ coll 100000

One problem with this is that you will have to increase the size of the stack for large inputs, like 1_000_000.

Update:

Here is a tail recursive version that doesn't suffer from the stack overflow problem.

import Data.Word
collatz :: Word -> Word -> (Word, Word)
collatz a x
  | x == 1    = (a,x)
  | even x    = collatz (a + 1) (x `quot` 2)
  | otherwise = collatz (a + 1) (3 * x + 1)

coll :: Word -> Word
coll n = collTail 0 n
  where
    collTail m 1 = m
    collTail m n = collTail (max (fst $ collatz 1 n) m) (n-1)

Notice the use of Word instead of Int. It makes a difference in performance. You could still use the bang patterns if you want, and that would nearly double the performance.
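
For reference, a sketch of how the tail-recursive version might look with bang patterns added. The names chainLength and longest are made up for this example, and the countdown stops at 0 rather than 1, so longest 1 gives 1 instead of 0. Treat it as one possible arrangement, not as the code that was benchmarked above.

```haskell
{-# LANGUAGE BangPatterns #-}

import Data.Word (Word)

-- Chain length for a starting value x >= 1, counting both endpoints.
chainLength :: Word -> Word
chainLength = go 1
  where
    go :: Word -> Word -> Word
    go !a x
      | x == 1    = a
      | even x    = go (a + 1) (x `quot` 2)
      | otherwise = go (a + 1) (3 * x + 1)

-- Longest chain length over the starting values 1..n, counting down from n.
-- The accumulator m is forced on every call, so no thunk chain builds up.
longest :: Word -> Word
longest = go 0
  where
    go :: Word -> Word -> Word
    go !m 0 = m
    go !m k = go (max (chainLength k) m) (k - 1)

main :: IO ()
main = print (longest 1000000)
```

chainLength is never called with 0, which matters because 0 would loop forever (0 is even and halves to 0).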




Answer 3:


One thing made a surprising difference in this problem. I stuck with the straight recurrence relation rather than folding (if you'll pardon the expression) the counting in with it. Rewriting

collatz n = if even n then n `div` 2 else 3 * n + 1

as

collatz n = case n `divMod` 2 of
            (n', 0) -> n'
            _       -> 3 * n + 1

took 1.2 seconds off the runtime for my program on a system with a 2.8 GHz Athlon II X4 430 CPU. My initial faster version (2.3 seconds after the use of divMod):

{-# LANGUAGE BangPatterns #-}

import Data.List
import Data.Ord

collatzChainLen :: Int -> Int
collatzChainLen n = collatzChainLen' n 1
    where collatzChainLen' n !l
            | n == 1    = l
            | otherwise = collatzChainLen' (collatz n) (l + 1)

collatz :: Int -> Int
collatz n = case n `divMod` 2 of
                 (n', 0) -> n'
                 _       -> 3 * n + 1

pairMap :: (a -> b) -> [a] -> [(a, b)]
pairMap f xs = [(x, f x) | x <- xs]

main :: IO ()
main = print $ fst (maximumBy (comparing snd) (pairMap collatzChainLen [1..999999]))
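
The two definitions of the step function compared at the start of this answer should give the same result for every positive n; only the generated code differs. A minimal check, using the made-up names stepEven and stepDivMod to keep both variants side by side:

```haskell
-- One Collatz step written with `even`, as in the first snippet.
stepEven :: Int -> Int
stepEven n = if even n then n `div` 2 else 3 * n + 1

-- The same step written with `divMod`, as in the second snippet.
stepDivMod :: Int -> Int
stepDivMod n = case n `divMod` 2 of
  (n', 0) -> n'
  _       -> 3 * n + 1

main :: IO ()
main = do
  -- prints [4,1,10,2,16,3,22,4,28,5]
  print (map stepEven [1 .. 10])
  -- prints True: both variants agree on this range
  print (all (\n -> stepEven n == stepDivMod n) [1 .. 100000])
```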

A perhaps more idiomatic Haskell version runs in about 9.7 seconds (8.5 with divMod); it's identical save for

collatzChainLen :: Int -> Int
collatzChainLen n = 1 + (length . takeWhile (/= 1) . (iterate collatz)) n

Using Data.List.Stream is supposed to allow stream fusion that would make this version run more like that with the explicit accumulation, but I can't find an Ubuntu libghc* package that has Data.List.Stream, so I can't yet verify that.



Source: https://stackoverflow.com/questions/8659345/why-is-this-simple-haskell-algorithm-so-slow
