Does MATLAB optimize diag(A*B)?

予麋鹿 2021-01-05 10:00

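Say I have two very big matrices A (M-by-N) and B (N-by-M). I need the diagonal of A*B. Computing the full A*B requires M*M*N multiplications, while computing only its diagonal requires M*N, since none of the off-diagonal elements are needed.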
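Written out entry by entry, the i-th diagonal element needs only row i of A and column i of B. That is why the diagonal costs M*N multiplications while the full product costs M*M*N:

```latex
\begin{aligned}
(AB)_{ii} &= \sum_{k=1}^{N} A_{ik}\,B_{ki}, \qquad i = 1,\dots,M,\\
\text{cost of the full } AB &= M \cdot M \cdot N \text{ multiplications},\\
\text{cost of } \operatorname{diag}(AB) \text{ alone} &= M \cdot N \text{ multiplications}.
\end{aligned}
```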

4 Answers
  •  时光取名叫无心
    2021-01-05 10:40

    One can also implement diag(A*B) as sum(A.*B.',2), using the plain transpose .' so that a complex B is not conjugated. Let's benchmark this along with the other solutions suggested for this question.
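    The identity holds because element (i,k) of A.*B.' is A(i,k)*B(k,i), so summing across each row gives exactly the i-th diagonal entry of A*B. Here is a hand-checkable 2-by-2 example. The matrices are arbitrary, chosen only for illustration:

    ```matlab
    % Arbitrary small matrices, picked so the result is easy to check by hand
    A = [1 2; 3 4];
    B = [5 6; 7 8];

    P = A .* B.';          % P(i,k) = A(i,k)*B(k,i), here [5 14; 18 32]
    d_sum  = sum(P, 2);    % row sums of P: [19; 50]
    d_full = diag(A*B);    % A*B = [19 22; 43 50], so the diagonal is [19; 50]

    disp(isequal(d_sum, d_full))   % displays 1 (true)
    ```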

    For benchmarking, each method is implemented as a separate function:

    1. Sum-multiplication method-1

      function out = sum_mult_method1(A,B)
      % out(i) = sum over k of A(i,k)*B(k,i); .' avoids conjugating a complex B
      out = sum(A.*B.',2);
      
    2. Sum-multiplication method-2

      function out = sum_mult_method2(A,B)
      % Same products, summed down the columns; the explicit dim 1 keeps N = 1 working
      out = sum(A.'.*B,1).';
      
    3. For-loop method

      function out = for_loop_method(A,B)
      % One dot product per diagonal entry: row i of A times column i of B
      M = size(A,1);
      out = zeros(M,1);
      for i=1:M
          out(i) = A(i,:) * B(:,i);
      end
      
    4. Full/Direct-multiplication method

      function out = direct_mult_method(A,B)
      % Forms the full M-by-M product, then keeps only its diagonal
      out = diag(A*B);
      
    5. Bsxfun-method

      function out = bsxfun_method(A,B)
      % Same element-wise product as method 1, computed through bsxfun
      out = sum(bsxfun(@times,A,B.'),2);
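    Before timing anything, it is worth confirming that the five implementations agree. The sketch below is an illustrative check and not part of the original benchmark. It re-creates the methods as anonymous functions. The names `impls` and `names` are made up for this snippet, and method 3 uses arrayfun because an anonymous function cannot contain a for loop. Each method is then compared against diag(A*B) on random complex data. Complex input is chosen deliberately, because that is where the conjugating transpose ' and the plain transpose .' give different answers.

    ```matlab
    % Hypothetical self-contained check: anonymous-function versions of the five methods
    impls = cell(5,1);
    impls{1} = @(A,B) sum(A.*B.', 2);                                % 1. sum-multiplication method-1
    impls{2} = @(A,B) sum(A.'.*B, 1).';                              % 2. sum-multiplication method-2
    impls{3} = @(A,B) arrayfun(@(i) A(i,:)*B(:,i), (1:size(A,1)).'); % 3. for-loop method, via arrayfun
    impls{4} = @(A,B) diag(A*B);                                     % 4. direct multiplication
    impls{5} = @(A,B) sum(bsxfun(@times, A, B.'), 2);                % 5. bsxfun method
    names = {'sum-mult-1', 'sum-mult-2', 'for-loop', 'direct', 'bsxfun'};

    % Complex test data, so that ' (conjugate transpose) and .' differ
    M = 6; N = 3;
    A = randn(M,N) + 1i*randn(M,N);
    B = randn(N,M) + 1i*randn(N,M);
    ref = diag(A*B);

    for k = 1:numel(impls)
        err = max(abs(impls{k}(A,B) - ref));
        fprintf('%-10s max abs error = %g\n', names{k}, err);
    end

    % Using B' instead of B.' conjugates B and gives a different result here
    err_ctranspose = max(abs(sum(A.*B', 2) - ref));
    fprintf('B'' variant  max abs error = %g\n', err_ctranspose);
    ```

    All five errors should be at round-off level, with exact values varying with the random data. The B' variant's error should be clearly nonzero.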
      

    Benchmarking Code

    num_runs = 1000;
    M_arr = [100 200 500 1000];
    N = 4;
    
    %// Warm up tic/toc.
    tic();
    elapsed = toc();
    tic();
    elapsed = toc();
    
    for k2 = 1:numel(M_arr)
        M = M_arr(k2);
    
        fprintf('\n')
        disp(strcat('*** Benchmarking sizes are M =',num2str(M),' and N = ',num2str(N)));
    
        A = randi(9,M,N);
        B = randi(9,N,M);
    
        disp('1. Sum-multiplication method-1');
        tic
        for k = 1:num_runs
            out1 = sum_mult_method1(A,B);
        end
        toc
        clear out1
    
        disp('2. Sum-multiplication method-2');
        tic
        for k = 1:num_runs
            out2 = sum_mult_method2(A,B);
        end
        toc
        clear out2
    
        disp('3. For-loop method');
        tic
        for k = 1:num_runs
            out3 = for_loop_method(A,B);
        end
        toc
        clear out3
    
        disp('4. Direct-multiplication method');
        tic
        for k = 1:num_runs
            out4 = direct_mult_method(A,B);
        end
        toc
        clear out4
    
        disp('5. Bsxfun method');
        tic
        for k = 1:num_runs
            out5 = bsxfun_method(A,B);
        end
        toc
        clear out5
    
    end
    

    Results

    *** Benchmarking sizes are M =100 and N =4
    1. Sum-multiplication method-1
    Elapsed time is 0.015242 seconds.
    2. Sum-multiplication method-2
    Elapsed time is 0.015180 seconds.
    3. For-loop method
    Elapsed time is 0.192021 seconds.
    4. Direct-multiplication method
    Elapsed time is 0.065543 seconds.
    5. Bsxfun method
    Elapsed time is 0.054149 seconds.
    
    *** Benchmarking sizes are M =200 and N =4
    1. Sum-multiplication method-1
    Elapsed time is 0.009138 seconds.
    2. Sum-multiplication method-2
    Elapsed time is 0.009428 seconds.
    3. For-loop method
    Elapsed time is 0.435735 seconds.
    4. Direct-multiplication method
    Elapsed time is 0.148908 seconds.
    5. Bsxfun method
    Elapsed time is 0.030946 seconds.
    
    *** Benchmarking sizes are M =500 and N =4
    1. Sum-multiplication method-1
    Elapsed time is 0.033287 seconds.
    2. Sum-multiplication method-2
    Elapsed time is 0.026405 seconds.
    3. For-loop method
    Elapsed time is 0.965260 seconds.
    4. Direct-multiplication method
    Elapsed time is 2.832855 seconds.
    5. Bsxfun method
    Elapsed time is 0.034923 seconds.
    
    *** Benchmarking sizes are M =1000 and N =4
    1. Sum-multiplication method-1
    Elapsed time is 0.026068 seconds.
    2. Sum-multiplication method-2
    Elapsed time is 0.032850 seconds.
    3. For-loop method
    Elapsed time is 1.775382 seconds.
    4. Direct-multiplication method
    Elapsed time is 13.764870 seconds.
    5. Bsxfun method
    Elapsed time is 0.044931 seconds.
    

    Intermediate Conclusions

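    The two sum-multiplication methods are the fastest at every size tested, and the direct diag(A*B) method scales the worst (about 13.8 seconds at M = 1000, against 0.026 seconds for method-1). The bsxfun approach, however, seems to be catching up with the sum methods as M increases from 100 to 1000.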
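    A back-of-the-envelope count helps explain why the direct method falls so far behind. The sketch below tabulates the multiplications and the size of the largest temporary array for the first set of sizes. It assumes double-precision data (8 bytes per element) and assumes MATLAB actually forms the full product before taking its diagonal, which the timings above suggest it does. The variable names are illustrative.

    ```matlab
    % Rough cost model; assumes double precision (8 bytes per element)
    M = [100 200 500 1000].';   % the sizes from the first benchmark
    N = 4;

    full_mults = M.^2 * N;      % diag(A*B): all M*M entries of A*B, N products each
    diag_mults = M * N;         % sum methods: only the M diagonal entries
    full_bytes = 8 * M.^2;      % M-by-M temporary produced by A*B
    sum_bytes  = 8 * M * N;     % M-by-N temporary produced by A.*B.'

    fprintf('%6s %12s %12s %12s %12s\n', 'M', 'full mults', 'diag mults', ...
            'A*B bytes', 'A.*B.'' bytes');
    fprintf('%6d %12d %12d %12d %12d\n', [M full_mults diag_mults full_bytes sum_bytes].');
    ```

    The ratio full_mults./diag_mults is simply M. At M = 1000, the direct method therefore performs 1000 times as many multiplications (4000000 versus 4000) and allocates an 8 MB temporary on every call.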

    Next, larger sizes were benchmarked with just the sum-multiplication and bsxfun methods. The sizes were:

    M_arr = [1000 2000 5000 10000 20000 50000];
    

    The results are:

    *** Benchmarking sizes are M =1000 and N =4
    1. Sum-multiplication method-1
    Elapsed time is 0.030390 seconds.
    2. Sum-multiplication method-2
    Elapsed time is 0.032334 seconds.
    5. Bsxfun method
    Elapsed time is 0.047377 seconds.
    
    *** Benchmarking sizes are M =2000 and N =4
    1. Sum-multiplication method-1
    Elapsed time is 0.040111 seconds.
    2. Sum-multiplication method-2
    Elapsed time is 0.045132 seconds.
    5. Bsxfun method
    Elapsed time is 0.060762 seconds.
    
    *** Benchmarking sizes are M =5000 and N =4
    1. Sum-multiplication method-1
    Elapsed time is 0.099986 seconds.
    2. Sum-multiplication method-2
    Elapsed time is 0.103213 seconds.
    5. Bsxfun method
    Elapsed time is 0.117650 seconds.
    
    *** Benchmarking sizes are M =10000 and N =4
    1. Sum-multiplication method-1
    Elapsed time is 0.375604 seconds.
    2. Sum-multiplication method-2
    Elapsed time is 0.273726 seconds.
    5. Bsxfun method
    Elapsed time is 0.226791 seconds.
    
    *** Benchmarking sizes are M =20000 and N =4
    1. Sum-multiplication method-1
    Elapsed time is 1.906839 seconds.
    2. Sum-multiplication method-2
    Elapsed time is 1.849166 seconds.
    5. Bsxfun method
    Elapsed time is 1.344905 seconds.
    
    *** Benchmarking sizes are M =50000 and N =4
    1. Sum-multiplication method-1
    Elapsed time is 5.159177 seconds.
    2. Sum-multiplication method-2
    Elapsed time is 5.081211 seconds.
    5. Bsxfun method
    Elapsed time is 3.866018 seconds.
    

    Alternate Benchmarking Code (with timeit)

    %// timeit decides how many times to call each function, so no run count is needed.
    M_arr = [1000 2000 5000 10000 20000 50000 100000 200000 500000 1000000];
    N = 4;
    
    timeall = zeros(5,numel(M_arr));   %// rows follow the method numbers; rows 3 and 4 are not timed here
    for k2 = 1:numel(M_arr)
        M = M_arr(k2);
    
        A = rand(M,N);
        B = rand(N,M);
    
        f = @() sum_mult_method1(A,B);
        timeall(1,k2) = timeit(f);
        clear f
    
        f = @() sum_mult_method2(A,B);
        timeall(2,k2) = timeit(f);
        clear f
    
        f = @() bsxfun_method(A,B);
        timeall(5,k2) = timeit(f);
        clear f
    
    end
    
    figure,
    hold on
    plot(M_arr,timeall(1,:),'-ro')
    plot(M_arr,timeall(2,:),'-ko')
    plot(M_arr,timeall(5,:),'-.b')
    legend('sum-method1','sum-method2','bsxfun-method')
    xlabel('M ->')
    ylabel('Time(sec) ->')
    

    Plot

    [Figure: Time (sec) vs. M for sum-method1, sum-method2 and bsxfun-method]

    Final Conclusions

    The sum-multiplication methods are fastest up to roughly M = 5000. Beyond that, the bsxfun method has a slight upper hand.
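    The crossover can also be read off programmatically. The sketch below uses the tic/toc totals (1000 runs each) for method-1 and the bsxfun method, copied from the larger-size results above. The variable names are illustrative. The same comparison works on rows 1 and 5 of timeall from the timeit script, i.e. speedup = timeall(1,:) ./ timeall(5,:).

    ```matlab
    % tic/toc totals for 1000 runs, copied from the larger-size results above
    M_arr  = [1000 2000 5000 10000 20000 50000];
    t_sum1 = [0.030390 0.040111 0.099986 0.375604 1.906839 5.159177];  % method-1
    t_bsx  = [0.047377 0.060762 0.117650 0.226791 1.344905 3.866018];  % bsxfun

    speedup = t_sum1 ./ t_bsx;           % > 1 means bsxfun is faster at that M
    idx = find(speedup > 1, 1);          % first tested size where bsxfun wins

    if isempty(idx)
        disp('bsxfun never beats method-1 at these sizes');
    elseif idx == 1
        fprintf('bsxfun already wins at the smallest size, M = %d\n', M_arr(1));
    else
        fprintf('Crossover between M = %d and M = %d\n', M_arr(idx-1), M_arr(idx));
    end
    ```

    With these numbers, it prints `Crossover between M = 5000 and M = 10000`. A single benchmark run is noisy, so a crossover located this way is only as reliable as the timings behind it.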

    Future Work

    One could also vary N and study how each implementation performs.
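    A possible starting point for that study is sketched below. It fixes M and sweeps N, using the same tic/toc pattern as the first benchmark. The values of M, N_arr and num_runs are arbitrary placeholders. The three faster methods are inlined as anonymous functions (the names `impls` and `t_all` are made up) so the snippet runs on its own. Timings will depend on the machine and MATLAB version.

    ```matlab
    % Sketch: fix M and sweep N (all sizes here are arbitrary placeholders)
    M = 5000;
    N_arr = [1 2 4 8 16 32];
    num_runs = 100;

    impls = cell(1,3);
    impls{1} = @(A,B) sum(A.*B.', 2);                  % sum-multiplication method-1
    impls{2} = @(A,B) sum(A.'.*B, 1).';                % sum-multiplication method-2
    impls{3} = @(A,B) sum(bsxfun(@times, A, B.'), 2);  % bsxfun method

    t_all = zeros(numel(impls), numel(N_arr));   % rows: methods, columns: N values
    for j = 1:numel(N_arr)
        N = N_arr(j);
        A = rand(M, N);
        B = rand(N, M);
        for k = 1:numel(impls)
            f = impls{k};
            tic;
            for r = 1:num_runs
                out = f(A, B); %#ok<NASGU>
            end
            t_all(k, j) = toc;
        end
    end
    disp(t_all)
    ```

    Plotting each row of t_all against N_arr, as the timeit script does for M, would show whether the ranking changes with N.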
