Why is 2 * (i * i) faster than 2 * i * i in Java?

一生所求 2020-12-22 14:43

The following Java program takes on average between 0.50 secs and 0.55 secs to run:

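public static void main(String[] args) {
    long startTime = System.nanoTime();
    int n = 0;
    for (int i = 0; i < 1000000000; i++) {
        n += 2 * (i * i);
    }
    System.out.println((double) (System.nanoTime() - startTime) / 1000000000 + " s");
    System.out.println("n = " + n);
}
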
10 answers
  •  感情败类
    2020-12-22 14:45

    (Editor's note: this answer is contradicted by evidence from looking at the asm, as shown by another answer. This was a guess backed up by some experiments, but it turned out not to be correct.)


    When the multiplication is 2 * (i * i), the JVM is able to factor out the multiplication by 2 from the loop, resulting in this equivalent but more efficient code:

    int n = 0;
    for (int i = 0; i < 1000000000; i++) {
        n += i * i;
    }
    n *= 2;
    

    But when the expression is (2 * i) * i, which is how 2 * i * i is parsed, the JVM doesn't optimize it, since the multiplication by the constant is no longer the last operation before the n += addition.
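    Whatever the JIT actually does, the rewrite itself is value-preserving. Java int arithmetic wraps modulo 2^32, so 2 * i * i, 2 * (i * i) and the loop with the 2 factored out always agree, overflow included. Below is a minimal sketch that checks this over the same loop bound, assuming only standard Java int semantics. The class and method names are hypothetical and chosen for illustration.

```java
// Hypothetical sketch, not from the answer: checks that the rewrites discussed
// above never change the int result. Java int arithmetic wraps modulo 2^32, so
// multiplication is associative and distributes over addition even on overflow.
final class EquivalenceCheck {

    static int squareFirst(int i) {
        return 2 * (i * i);
    }

    static int leftToRight(int i) {
        return 2 * i * i; // parses as (2 * i) * i
    }

    // First i in [0, limit) where the two expressions differ, or -1 if none.
    static int firstMismatch(int limit) {
        for (int i = 0; i < limit; i++) {
            if (squareFirst(i) != leftToRight(i)) {
                return i;
            }
        }
        return -1;
    }

    // n += 2 * (i * i) on every iteration, as in fastVersion().
    static int doubledInLoop(int limit) {
        int n = 0;
        for (int i = 0; i < limit; i++) {
            n += 2 * (i * i);
        }
        return n;
    }

    // Sum the squares and double once at the end, as in optimizedVersion().
    static int doubledOnce(int limit) {
        int n = 0;
        for (int i = 0; i < limit; i++) {
            n += i * i;
        }
        return n * 2;
    }

    public static void main(String[] args) {
        int limit = 1_000_000_000; // the loop bound used in the answer
        System.out.println("first mismatch: " + firstMismatch(limit));
        System.out.println("same total: " + (doubledInLoop(limit) == doubledOnce(limit)));
    }
}
```

    Running main should report -1 for the first mismatch and true for the totals. This shows only that the transformation is legal. It does not show that the JIT performs it, and the editor's note above disputes that it does.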

    Here are a few reasons why I think this is the case:

    • Adding an if (n == 0) n = 1 statement at the start of the loop makes both versions run at about the same speed, because factoring out the multiplication would no longer guarantee the same result.
    • The optimized version (with the multiplication by 2 factored out by hand) runs about as fast as the 2 * (i * i) version, in fact slightly faster.
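    The guard matters because it adds a constant 1 that the original loop never doubles. A factored loop would double that 1 along with the squares. Below is a hypothetical sketch that compares the guarded loop with a naive factored version on a tiny input. The names are illustrative only.

```java
// Hypothetical sketch: with the "if (n == 0) n = 1" guard inside the loop,
// doubling once at the end no longer matches the per-iteration form.
final class GuardCheck {

    // The guarded loop from modifiedFastVersion(), returning n instead of a time.
    static int guardedInLoop(int limit) {
        int n = 0;
        for (int i = 0; i < limit; i++) {
            if (n == 0) n = 1;
            n += 2 * (i * i);
        }
        return n;
    }

    // A naive attempt to factor out the 2 while keeping the guard.
    static int guardedFactored(int limit) {
        int n = 0;
        for (int i = 0; i < limit; i++) {
            if (n == 0) n = 1;
            n += i * i;
        }
        return n * 2; // the guard's 1 gets doubled here as well
    }

    public static void main(String[] args) {
        System.out.println(guardedInLoop(3));   // 11
        System.out.println(guardedFactored(3)); // 12
    }
}
```

    With limit = 3 the guarded loop gives 1 + 2 * (0 + 1 + 4) = 11. The naive factored version gives (1 + 0 + 1 + 4) * 2 = 12.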

    Here is the test code that I used to draw these conclusions:

    public static void main(String[] args) {
        long fastVersion = 0;
        long slowVersion = 0;
        long optimizedVersion = 0;
        long modifiedFastVersion = 0;
        long modifiedSlowVersion = 0;
    
        // Call each variant 10 times; every call returns its own elapsed nanoseconds.
        for (int i = 0; i < 10; i++) {
            fastVersion += fastVersion();
            slowVersion += slowVersion();
            optimizedVersion += optimizedVersion();
            modifiedFastVersion += modifiedFastVersion();
            modifiedSlowVersion += modifiedSlowVersion();
        }
    
        // Totals are in nanoseconds; dividing by 1e9 gives seconds.
        System.out.println("Fast version: " + (double) fastVersion / 1000000000 + " s");
        System.out.println("Slow version: " + (double) slowVersion / 1000000000 + " s");
        System.out.println("Optimized version: " + (double) optimizedVersion / 1000000000 + " s");
        System.out.println("Modified fast version: " + (double) modifiedFastVersion / 1000000000 + " s");
        System.out.println("Modified slow version: " + (double) modifiedSlowVersion / 1000000000 + " s");
    }
    
    private static long fastVersion() {
        long startTime = System.nanoTime();
        int n = 0;
        for (int i = 0; i < 1000000000; i++) {
            n += 2 * (i * i);
        }
        return System.nanoTime() - startTime;
    }
    
    private static long slowVersion() {
        long startTime = System.nanoTime();
        int n = 0;
        for (int i = 0; i < 1000000000; i++) {
            n += 2 * i * i;
        }
        return System.nanoTime() - startTime;
    }
    
    private static long optimizedVersion() {
        long startTime = System.nanoTime();
        int n = 0;
        for (int i = 0; i < 1000000000; i++) {
            n += i * i;
        }
        n *= 2; // double once, after the loop
        return System.nanoTime() - startTime;
    }
    
    private static long modifiedFastVersion() {
        long startTime = System.nanoTime();
        int n = 0;
        for (int i = 0; i < 1000000000; i++) {
            if (n == 0) n = 1; // guard: factoring out the 2 would now change n
            n += 2 * (i * i);
        }
        return System.nanoTime() - startTime;
    }
    
    private static long modifiedSlowVersion() {
        long startTime = System.nanoTime();
        int n = 0;
        for (int i = 0; i < 1000000000; i++) {
            if (n == 0) n = 1; // guard: factoring out the 2 would now change n
            n += 2 * i * i;
        }
        return System.nanoTime() - startTime;
    }
    

    And here are the results (each time is the total over the 10 calls):

    Fast version: 5.7274411 s
    Slow version: 7.6190804 s
    Optimized version: 5.1348007 s
    Modified fast version: 7.1492705 s
    Modified slow version: 7.2952668 s
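    Each figure is a total over 10 calls, so the per-run times and relative speeds follow from simple division. Below is a small sketch of that arithmetic, using the totals copied from above. The helper names are hypothetical.

```java
// Hypothetical helper: turns the reported 10-run totals into per-run averages
// and ratios. The totals are copied from the results above.
final class ResultRatios {

    static final int RUNS = 10; // each method is called 10 times in main()

    static double perRun(double totalSeconds) {
        return totalSeconds / RUNS;
    }

    static double ratio(double a, double b) {
        return a / b;
    }

    public static void main(String[] args) {
        double fast = 5.7274411, slow = 7.6190804, optimized = 5.1348007;
        double modFast = 7.1492705, modSlow = 7.2952668;

        System.out.printf("fast per run: %.3f s%n", perRun(fast));
        System.out.printf("slow per run: %.3f s%n", perRun(slow));
        System.out.printf("slow / fast: %.2f%n", ratio(slow, fast));
        System.out.printf("optimized / fast: %.2f%n", ratio(optimized, fast));
        System.out.printf("modified slow / modified fast: %.2f%n", ratio(modSlow, modFast));
    }
}
```

    With these totals, the slow version takes about 1.33× as long as the fast one. The optimized version takes about 0.90× as long as the fast one. The two modified versions are within roughly 2% of each other.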
    
