Improving MATLAB Matrix Construction Code: Or, Code Vectorization for Beginners

迷失自我 2020-12-09 13:32

I wrote a program to construct a portion of a 3-band wavelet transform matrix. However, given that the matrix is 3^9 × 3^10, it takes a while for MATLAB…

3 Answers
  •  抹茶落季
    2020-12-09 14:08

    You can use a cunning trick to build the matrix in which each row is a copy of v shifted 3 columns to the right of the row above (wrapping around at the end):

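    >> n=4;                             % example size (assumed): B is 3^(n-1) x 3^n
    >> v=[-0.117377016134830 0.54433105395181 -0.0187057473531300 ...
              -0.699119564792890 -0.136082763487960 0.426954037816980];
    >> lendiff=length(v)-3;             % v overhangs a 3-column shift by 3 entries
    >> B=repmat([v zeros(1,3^n-lendiff)],3^(n-1),1);   % 3^(n-1) rows of length 3^n+3
    >> B=reshape(B',3^(n),3^(n-1)+1);   % column-major refill shifts each copy of v by 3
    >> B(:,end-1)=B(:,end-1)+B(:,end);  % wrap the tail of v from the extra column around
    >> B=B(:,1:end-1)';                 % drop the extra column and transpose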

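    Here, lendiff is chosen so that each of the 3^{n-1} copies of the row (v followed by zeros) has length 3^n+3, giving a matrix of size [3^{n-1} 3^n+3].

    That matrix is reshaped into size [3^n 3^{n-1}+1]. Because MATLAB fills column by column, each successive copy of v lands 3 rows further down, which creates the shifts. The extra last column holds the part of v that wraps around, so it is added to the second-to-last column and then dropped, and finally B is transposed.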
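    The index arithmetic behind the reshape can be written out as a worked check (writing v_1 to v_6 for the entries of v and k = 1, ..., 3^{n-1} for the rows of the repmat matrix):

```latex
\begin{aligned}
3^{n-1}\,(3^n + 3) &= 3^{2n-1} + 3^n = 3^n\,(3^{n-1} + 1)
  && \text{(same number of elements, so the reshape is valid)}\\
(k-1)(3^n + 3) &= (k-1)\,3^n + 3(k-1)
  && \text{(copy } k \text{ starts in column } k \text{, at row } 3(k-1)+1\text{)}\\
k = 3^{n-1}: \quad 3(k-1) + 1 &= 3^n - 2
  && \text{(so } v_4, v_5, v_6 \text{ spill into rows 1 to 3 of column } 3^{n-1}+1\text{)}
\end{aligned}
```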

    This should be much faster than the original code.
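    To check that the trick builds the intended matrix, here is a minimal comparison for a small n against an explicit loop. The loop, and the assumption that row k should hold v starting at column 3(k-1)+1 with wrap-around past the last column, are a reading of the reshape above, not code from the question:

```matlab
% Sketch: compare the reshape trick with a straightforward loop for a small n.
% Assumption: row k of the 3^(n-1) x 3^n target holds v starting at column
% 3*(k-1)+1, wrapping around past the last column.
n = 3;
v = [-0.117377016134830 0.54433105395181 -0.0187057473531300 ...
     -0.699119564792890 -0.136082763487960 0.426954037816980];

% Reshape trick from the answer
lendiff = length(v) - 3;
B = repmat([v zeros(1, 3^n - lendiff)], 3^(n-1), 1);
B = reshape(B', 3^n, 3^(n-1) + 1);
B(:, end-1) = B(:, end-1) + B(:, end);
B = B(:, 1:end-1)';

% Reference: explicit loop with modular (wrap-around) column indices
R = zeros(3^(n-1), 3^n);
for k = 1:3^(n-1)
    cols = mod(3*(k-1) + (0:length(v)-1), 3^n) + 1;
    R(k, cols) = v;
end

fprintf('max abs difference: %g\n', max(abs(B(:) - R(:))));
```

    Both versions copy the same values of v, so this should print a maximum difference of 0.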

    EDIT

    Seeing Darren's solution, and realising that reshape also works on sparse matrices, led me to the following version, which removes the for loops from the original solution entirely.

    First, the values to start with:

    >> n=4;                             % example size (assumed)
    >> v=[-0.117377016134830  ...
           0.54433105395181   ...
          -0.0187057473531300 ...
          -0.699119564792890  ...
          -0.136082763487960  ...
           0.426954037816980];
    >> rows = 3^(n-1);                  % same number of rows as before
    >> cols = 3^(n)+3;                  % 3 extra columns to implement the shifts
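    For the size in the question (3^9 × 3^10, i.e. n = 10), a quick count shows why storing B as a sparse matrix matters. This sketch only does the arithmetic; the 8 bytes per entry is the size of a MATLAB double, and the variable names are illustrative:

```matlab
% Sketch: how big is B for the question's size, dense vs. number of nonzeros?
n = 10;                                 % 3^(n-1) x 3^n = 3^9 x 3^10
numRows = 3^(n-1);                      % 19683
numCols = 3^n;                          % 59049
fullBytes = numRows * numCols * 8;      % dense storage, 8 bytes per double
nonzeros = numRows * 6;                 % 6 entries of v per row
fprintf('dense: %.2f GB, nonzeros: %d, density: %.4f%%\n', ...
        fullBytes / 1e9, nonzeros, 100 * nonzeros / (numRows * numCols));
```

    A dense double matrix of this size would need about 9.3 GB, while only 118098 of its entries are nonzero.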
    

    Then build the sparse matrix, with 3 extra columns per row:

    >> row=(1:rows)'*ones(1,length(v)); % row index of each copy of v
    >> col=ones(rows,1)*(1:length(v));  % place v in the first columns of each row
    >> val=ones(rows,1)*v;              % the values of v at those positions
    >> B=sparse(row,col,val,rows,cols); % B is [rows cols] (cols already includes the +3), stored sparse
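    The three matrices row, col and val are simply triplets that sparse pairs up element by element: entry (row(i,j), col(i,j)) receives val(i,j). A minimal sketch, for a small assumed n, showing that ndgrid and repmat produce the same triplets as the outer products above (the names row2, col2 and val2 are only for the comparison):

```matlab
% Sketch: the outer-product triplets vs. an ndgrid/repmat formulation (n = 2 assumed).
n = 2;
v = [-0.117377016134830 0.54433105395181 -0.0187057473531300 ...
     -0.699119564792890 -0.136082763487960 0.426954037816980];
rows = 3^(n-1);
cols = 3^n + 3;

row = (1:rows)' * ones(1, length(v));         % as in the answer
col = ones(rows, 1) * (1:length(v));
val = ones(rows, 1) * v;

[row2, col2] = ndgrid(1:rows, 1:length(v));   % row2(i,j) = i, col2(i,j) = j
val2 = repmat(v, rows, 1);                    % every row is a copy of v

B = sparse(row, col, val, rows, cols);
fprintf('same triplets: %d, nnz = %d\n', ...
        isequal(row, row2) && isequal(col, col2) && isequal(val, val2), nnz(B));
disp(full(B));                                % each row: v followed by 3^n - 3 zeros
```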
    

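    Then reshape to implement the shifts. This gives the right number of columns plus one extra row (an extra column before the final transpose), which holds the wrapped-around tail of v: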

    >> B=reshape(B',3^(n),rows+1);      % reshape into B[3^n rows+1]: each copy of v shifted by 3
    >> B(1:3,end-1)=B(1:3,end);         % the extra column holds the last 3 values of v; wrap them around
    >> B=B(:,1:end-1)';                 % delete the extra column after copying, then transpose
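    As an alternative (not the answer's method), the wrap-around can also be computed directly in the column indices with mod, so no reshape or extra column is needed. A minimal sketch under that assumption, checked against the reshape-based result (the names S, N and offset are illustrative):

```matlab
% Sketch of an alternative: compute each entry's final column directly with mod.
n = 4;                                          % example size (assumed)
v = [-0.117377016134830 0.54433105395181 -0.0187057473531300 ...
     -0.699119564792890 -0.136082763487960 0.426954037816980];
rows = 3^(n-1);
N = 3^n;                                        % final number of columns

[row, offset] = ndgrid(1:rows, 0:length(v)-1);  % row index, position within v
col = mod(3*(row - 1) + offset, N) + 1;         % shift 3 per row, wrap past column N
S = sparse(row, col, repmat(v, rows, 1), rows, N);

% The answer's reshape-based construction, for comparison
B = sparse((1:rows)' * ones(1, length(v)), ones(rows, 1) * (1:length(v)), ...
           ones(rows, 1) * v, rows, N + 3);
B = reshape(B', N, rows + 1);
B(1:3, end-1) = B(1:3, end);
B = B(:, 1:end-1)';

fprintf('identical: %d\n', isequal(S, B));
```

    Building the final indices directly avoids the intermediate (3^n+3)-column matrix and the copy step, at the cost of one mod per nonzero; it has not been timed against the reshape version here.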
    

    For n = 4, 5, 6, 7, the CPU times measured with the profiler are:

    n    original (s)    new version (s)
    4    0.033           0.000
    5    0.206           0.000
    6    1.906           0.000
    7    16.311          0.000

    I cannot run the original version for n > 7, but the new version gives:

    n    new version (s)
    8    0.002
    9    0.009
    10   0.022
    11   0.062
    12   0.187
    13   0.540
    14   1.529
    15   4.210
    

    n = 15 is as far as my RAM goes. :)
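    A minimal harness for timing the sparse version yourself might look like this. It uses tic/toc (wall-clock time) rather than the profiler, so its numbers will not match the tables above exactly, and the range of n is only an example:

```matlab
% Sketch: wall-clock timing of the sparse reshape construction for several n.
v = [-0.117377016134830 0.54433105395181 -0.0187057473531300 ...
     -0.699119564792890 -0.136082763487960 0.426954037816980];
for n = 4:8                                     % example range (assumed)
    t = tic;
    rows = 3^(n-1);
    cols = 3^n + 3;
    row = (1:rows)' * ones(1, length(v));
    col = ones(rows, 1) * (1:length(v));
    val = ones(rows, 1) * v;
    B = sparse(row, col, val, rows, cols);
    B = reshape(B', 3^n, rows + 1);
    B(1:3, end-1) = B(1:3, end);
    B = B(:, 1:end-1)';
    elapsed = toc(t);
    fprintf('n = %2d: %d x %d, nnz = %d, %.4f s\n', ...
            n, size(B, 1), size(B, 2), nnz(B), elapsed);
end
```

    B has 6·3^{n-1} nonzeros, so the work grows by a factor of 3 with each increment of n, which matches the roughly threefold growth in the timings above.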
