How to improve the performance of the recursive method?

好久不见. Posted on 2020-01-02 03:16:48

Question


I'm learning data structures and algorithms, and here is a question that I'm stuck with.

I have to improve the performance of the recursive call by storing the value into memory.

But the problem is that the non-improved version seems faster than this.

Can someone help me out?

Syracuse numbers are a sequence of positive integers defined by the following rules:

syra(1) ≡ 1

syra(n) ≡ n + syra(n/2), if n mod 2 == 0

syra(n) ≡ n + syra((n*3)+1), otherwise

import java.util.HashMap;
import java.util.Map;

public class SyraLengthsEfficient {

    int counter = 0;
    public int syraLength(long n) {
        if (n < 1) {
            throw new IllegalArgumentException();
        }

        if (n < 500 && map.containsKey(n)) {
            counter += map.get(n);
            return map.get(n);
        } else if (n == 1) {
            counter++;
            return 1;
        } else if (n % 2 == 0) {
            counter++;
            return syraLength(n / 2);
        } else {
            counter++;
            return syraLength(n * 3 + 1);
        }
    }

    Map<Integer, Integer> map = new HashMap<Integer, Integer>();

    public int lengths(int n) {
        if (n < 1) {
            throw new IllegalArgumentException();
        }    
        for (int i = 1; i <= n; i++) {
            syraLength(i);
            if (i < 500 && !map.containsKey(i)) {
                map.put(i, counter);
            }
        }    
        return counter;
    }

    public static void main(String[] args) {
        System.out.println(new SyraLengthsEfficient().lengths(5000000));
    }
}

Here is the normal version that I wrote:

 public class SyraLengths{

        int total=1;
        public int syraLength(long n) {
            if (n < 1)
                throw new IllegalArgumentException();
            if (n == 1) {
                int temp=total;
                total=1;
                return temp;
            }
            else if (n % 2 == 0) {
                total++;
                return syraLength(n / 2);
            }
            else {
                total++;
                return syraLength(n * 3 + 1);
            }
        }

        public int lengths(int n){
            if(n<1){
                throw new IllegalArgumentException();
            }
            int total=0;
            for(int i=1;i<=n;i++){
                total+=syraLength(i);
            }

            return total;
        }

        public static void main(String[] args){
            System.out.println(new SyraLengths().lengths(5000000));
        }
       }

EDIT

This version is still slower than the non-enhanced version.

import java.util.HashMap;
import java.util.Map;

public class SyraLengthsEfficient {

    private Map<Long, Long> map = new HashMap<Long, Long>();

    public long syraLength(long n, long count) {

        if (n < 1)
            throw new IllegalArgumentException();

        if (!map.containsKey(n)) {
            if (n == 1) {
                count++;
                map.put(n, count);
            } else if (n % 2 == 0) {
                count++;
                map.put(n, count + syraLength(n / 2, 0));
            } else {
                count++;
                map.put(n, count + syraLength(3 * n + 1, 0));
            }
        }

        return map.get(n);

    }

    public int lengths(int n) {
        if (n < 1) {
            throw new IllegalArgumentException();
        }
        int total = 0;
        for (int i = 1; i <= n; i++) {
            // long temp = syraLength(i, 0);
            // System.out.println(i + " : " + temp);
            total += syraLength(i, 0);

        }
        return total;
    }

    public static void main(String[] args) {
        System.out.println(new SyraLengthsEfficient().lengths(50000000));
    }
}

FINAL SOLUTION (marked as correct by the school's auto-marking system)

public class SyraLengthsEfficient {

    private int[] values = new int[10 * 1024 * 1024];

    public int syraLength(long n, int count) {

        if (n <= values.length && values[(int) (n - 1)] != 0) {
            return count + values[(int) (n - 1)];
        } else if (n == 1) {
            count++;
            values[(int) (n - 1)] = 1;
            return count;
        } else if (n % 2 == 0) {
            count++;
            if (n <= values.length) {
                values[(int) (n - 1)] = count + syraLength(n / 2, 0);
                return values[(int) (n - 1)];
            } else {
                return count + syraLength(n / 2, 0);
            }
        } else {
            count++;
            if (n <= values.length) {
                values[(int) (n - 1)] = count + syraLength(n * 3 + 1, 0);
                return values[(int) (n - 1)];
            } else {
                return count + syraLength(n * 3 + 1, 0);
            }
        }

    }

    public int lengths(int n) {
        if (n < 1) {
            throw new IllegalArgumentException();
        }
        int total = 0;
        for (int i = 1; i <= n; i++) {
            total += syraLength(i, 0);
        }
        return total;
    }

    public static void main(String[] args) {
        SyraLengthsEfficient s = new SyraLengthsEfficient();
        System.out.println(s.lengths(50000000));
    }

}


Answer 1:


Ignore the answers that blame the Map for the slowdown. The real problem is that you only cache calculated values for n < 500. Once you remove that restriction, things become quite fast. Here's a proof of concept; fill in the details yourself:

private Map<Long, Long> map = new HashMap<Long, Long>();

public long syraLength(long n) {

    if (!map.containsKey(n)) {
        if (n == 1)
            map.put(n, 1L);
        else if (n % 2 == 0)
            map.put(n, n + syraLength(n/2));
        else
            map.put(n, n + syraLength(3*n+1));
    }

    return map.get(n);

}

If you want to read more about what's happening in the program and why it is so fast, take a look at the Wikipedia article on Memoization.
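To make the proof of concept concrete, here is a minimal sketch that memoizes the sequence length (what the question's code counts) instead of the sum. The class name `SyraMemo`, the `Map<Long, Integer>` cache, and the `long` return type of `lengths` are assumptions for illustration, not part of the original code.

```java
import java.util.HashMap;
import java.util.Map;

class SyraMemo {
    // Cache of sequence lengths. The key type is Long so that it matches
    // the boxed form of the long argument and lookups can actually hit.
    private final Map<Long, Integer> cache = new HashMap<>();

    // Number of terms in the sequence starting at n, counting n and the final 1.
    int syraLength(long n) {
        if (n < 1) {
            throw new IllegalArgumentException("n must be >= 1");
        }
        if (n == 1) {
            return 1;
        }
        Integer known = cache.get(n);
        if (known != null) {
            return known;
        }
        int length = 1 + syraLength(n % 2 == 0 ? n / 2 : 3 * n + 1);
        cache.put(n, length);
        return length;
    }

    // Sum of the lengths for 1..n, accumulated in a long because the
    // total can outgrow an int for large n.
    long lengths(int n) {
        long total = 0;
        for (int i = 1; i <= n; i++) {
            total += syraLength(i);
        }
        return total;
    }

    public static void main(String[] args) {
        System.out.println(new SyraMemo().lengths(10)); // prints 77
    }
}
```

Each value is computed once and every later request for it is a single lookup, which is where the speed-up comes from.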

Also, I think you're misusing the counter variable: you increment it (++) when a value is calculated for the first time, but you add to it (+=) when a value is found in the map. That doesn't seem right to me, and I doubt it gives the expected result.
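A related detail in the question's first listing may also be worth checking: the map is declared as Map<Integer, Integer>, but syraLength looks it up with a long. Java boxes that long to a Long, and a Long never equals an Integer, so the lookup cannot hit. The sketch below only demonstrates this behaviour; the class and method names are made up for illustration.

```java
import java.util.HashMap;
import java.util.Map;

class KeyTypeDemo {
    // The long argument is boxed to Long, so it never matches an Integer key.
    static boolean lookupWithLong(Map<Integer, Integer> map, long n) {
        return map.containsKey(n);
    }

    // Casting to int first makes the boxed key an Integer, so the lookup can hit.
    static boolean lookupWithInt(Map<Integer, Integer> map, long n) {
        return map.containsKey((int) n);
    }

    public static void main(String[] args) {
        Map<Integer, Integer> map = new HashMap<>();
        map.put(7, 17); // the sequence starting at 7 has 17 terms
        System.out.println(lookupWithLong(map, 7L)); // prints false
        System.out.println(lookupWithInt(map, 7L));  // prints true
    }
}
```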




Answer 2:


Don't use a map. Store the intermediate result in a field (it's called an accumulator) and iterate in a loop until n = 1. On each iteration the accumulator grows by n, and n either becomes 3n + 1 (if it is odd) or is halved (if it is even). Hope that helps you solve your homework.
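The loop described above might look like the following sketch. It uses a local variable rather than a field as the accumulator, and the class and method names are assumptions. `syraSum` follows the definition in the question (sum of the terms), while `syraLength` counts the terms as the question's code does.

```java
class SyraIterative {
    // Sum of all terms from n down to 1, following syra(n) = n + syra(next).
    static long syraSum(long n) {
        if (n < 1) {
            throw new IllegalArgumentException("n must be >= 1");
        }
        long acc = 0;
        while (n != 1) {
            acc += n;                             // accumulator grows by n
            n = (n % 2 == 0) ? n / 2 : 3 * n + 1; // halve, or triple and add one
        }
        return acc + 1;                           // syra(1) = 1
    }

    // Same loop, but counting terms instead of summing them.
    static int syraLength(long n) {
        if (n < 1) {
            throw new IllegalArgumentException("n must be >= 1");
        }
        int count = 1;                            // counts the final 1
        while (n != 1) {
            count++;
            n = (n % 2 == 0) ? n / 2 : 3 * n + 1;
        }
        return count;
    }

    public static void main(String[] args) {
        System.out.println(syraSum(3));    // 3+10+5+16+8+4+2+1, prints 49
        System.out.println(syraLength(3)); // prints 8
    }
}
```

This avoids recursion entirely, but it recomputes every sequence from scratch, so it does not get the benefit of caching.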




Answer 3:


Of course it doesn't do as well: you're adding a lot of overhead in the map.put and map.get calls (hashing, bucket creation, etc.). Plus you are autoboxing, which adds a messload of object creation. My guess is that the overhead of the map far outweighs the benefit.

Try using two arrays instead: one to hold values, and one to hold flags that tell you whether the value is set or not.

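// Bounded cache size (the value here is only an example). Arrays of size
// Integer.MAX_VALUE typically fail with OutOfMemoryError, so check n against the bound.
int CACHE_SIZE = 10 * 1024 * 1024;
int[] syr = new int[CACHE_SIZE];
boolean[] syrcomputed = new boolean[CACHE_SIZE];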

and use those instead of the map:

if (syrcomputed[n]) {
   return syr[n];
}
else {
    syrcomputed[n] = true;
    syr[n] = ....;
}

Also, I would think you might run into some overflow here with larger numbers (as n approaches Integer.MAX_VALUE / 3, you would definitely see this for odd n, because 3n + 1 no longer fits in an int).

As such, you should probably use long types for all your calculations as well.
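A small sketch of the overflow in question, using hypothetical helper names. 715827883 is the smallest odd value above Integer.MAX_VALUE / 3, so multiplying it by 3 already wraps around in int arithmetic:

```java
class OverflowDemo {
    // One odd step in int arithmetic: silently wraps around on overflow.
    static int stepInt(int n) {
        return n * 3 + 1;
    }

    // The same step in long arithmetic: has room for the result.
    static long stepLong(long n) {
        return n * 3 + 1;
    }

    public static void main(String[] args) {
        int n = 715_827_883; // odd, just above Integer.MAX_VALUE / 3
        System.out.println(stepInt(n));  // negative: the int result wrapped around
        System.out.println(stepLong(n)); // prints 2147483650
    }
}
```

A negative n would then trip the n < 1 check in the question's code, or loop forever in code without that check.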

PS: if your purpose is truly to understand recursion, you shouldn't store the values in an instance variable, but should pass them down as an accumulator:

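public int syr(int n) {
  // Bounded cache (example size): arrays of size Integer.MAX_VALUE typically
  // fail with OutOfMemoryError on the JVM.
  int size = 10 * 1024 * 1024;
  return syr(n, new int[size], new boolean[size]);
}

private int syr(int n, int[] syr, boolean[] syrcomputed) {
   if (syrcomputed[n]) {
     return syr[n];
   }
   else {
     int s = [ block for recursive computation ];
     syrcomputed[n] = true;
     syr[n] = s;   // store into the slot for n, not over the array reference
     return s;
   }
}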

In some functional languages (Scheme, Erlang, etc.) this actually gets optimized as a tail call (which avoids creating new stack frames). Even though the HotSpot JVM doesn't do this (at least to my knowledge), it's still an important concept.
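For illustration, a tail-recursive form with an explicit accumulator could be sketched as follows. It leaves out the caching arrays to keep the shape visible, and the class and method names are assumptions. The recursive call is the last thing the method does, which is what lets a language with tail-call optimization reuse the stack frame; on the JVM each call still gets its own frame.

```java
class SyraTailRecursive {
    // Public entry point: the sequence starting at n has at least one term, n itself.
    static int syraLength(long n) {
        if (n < 1) {
            throw new IllegalArgumentException("n must be >= 1");
        }
        return syraLength(n, 1);
    }

    // acc carries the number of terms seen so far, so nothing is left
    // to do after the recursive call returns.
    private static int syraLength(long n, int acc) {
        if (n == 1) {
            return acc;
        }
        long next = (n % 2 == 0) ? n / 2 : 3 * n + 1;
        return syraLength(next, acc + 1);
    }

    public static void main(String[] args) {
        System.out.println(syraLength(7)); // prints 17
    }
}
```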



Source: https://stackoverflow.com/questions/10884217/how-to-improve-the-performance-of-the-recursive-method
