How to convert matrix to a stack of diagonal matrices based on every row?

既然无缘 2021-01-20 10:52

I have a matrix:

A = [1 1 1
     2 2 2
     3 3 3]

Is there a vectorized way of obtaining:

B = [1 0 0
     0 1 0
     0 0 1
     2 0 0
     0 2 0
     0 0 2
     3 0 0
     0 3 0
     0 0 3]

5 answers
  •  庸人自扰
    2021-01-20 11:35

    This is one solution using mod and sub2ind:

    %// example data
    data = reshape(1:9,3,3).' %'
    n = 3;  %// assumed to be known
    
    data =
    
         1     2     3
         4     5     6
         7     8     9
    

    %// row indices
    rows = 1:numel(data);
    %// column indices
    cols = mod(rows-1,n) + 1;
    %// pre-allocation: one output row per element of data
    out = zeros(numel(data),n);
    %// linear indices
    linIdx = sub2ind(size(out),rows,cols);
    %// assigning
    out(linIdx) = data.'
    

    out =
    
         1     0     0
         0     2     0
         0     0     3
         4     0     0
         0     5     0
         0     0     6
         7     0     0
         0     8     0
         0     0     9
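
The same mod/sub2ind indexing also works for a non-square m-by-n matrix, provided the output has numel(A) = m*n rows (one per element of A) rather than n*n. Here is a minimal sketch on an assumed 2-by-3 example; the variable names are illustrative only:

```matlab
%// Assumed example: a non-square 2x3 input
A = [1 2 3
     4 5 6];
n = size(A, 2);

rows = 1:numel(A);                           %// one output row per element of A
cols = mod(rows-1, n) + 1;                   %// cycles 1..n, once per row of A
out = zeros(numel(A), n);                    %// m*n-by-n, not n*n-by-n
out(sub2ind(size(out), rows, cols)) = A.';   %// A.' lists A row by row

disp(out)
%//  1 0 0
%//  0 2 0
%//  0 0 3
%//  4 0 0
%//  0 5 0
%//  0 0 6
```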
    

    Or, if you prefer fewer lines of code over readability:

    out = zeros(numel(data),n);
    out(sub2ind(size(out),1:numel(data),mod((1:numel(data))-1,n) + 1)) = data.'
    

    Two other solutions, which are fast but not faster than the approaches above:

    %// #1 (blockproc requires the Image Processing Toolbox)
    Z = blockproc(A,[1 size(A,2)],@(x) diag(x.data));
    
    %// #2
    [m, n] = size(A);
    Z = zeros(m*n,n);
    Z( repmat(logical(eye(n)),m,1) ) = A;
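
A further vectorized variant, not taken from the thread, reuses the stacked-identity idea from #2 with kron: kron(A, ones(n,1)) repeats each row of A n times, and multiplying element-wise by repmat(eye(n),m,1) zeros everything off the diagonal of each n-by-n block. This is a sketch on an assumed 2-by-3 example. Because it builds two m*n-by-n intermediates, it probably uses more memory than the indexing approaches.

```matlab
%// Assumed example: a non-square 2x3 input
A = [1 2 3
     4 5 6];
[m, n] = size(A);

R = kron(A, ones(n, 1));        %// row i of A repeated n times -> m*n-by-n
Z = R .* repmat(eye(n), m, 1);  %// keep only the diagonal of each n-by-n block

disp(Z)
```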
    

    For the sake of competition, here is a benchmark (timeit, accumulated over 10 runs on A = magic(200)):

    function [t] = bench()
        A = magic(200);
    
        % functions to compare
        fcns = {
            @() thewaywewalk(A);
            @() lhcgeneva(A);
            @() rayryeng(A);
            @() rlbond(A);
        };
    
        % time each function with timeit, summed over 10 repetitions
        t = zeros(numel(fcns),1);
        for ii = 1:10
            t = t + cellfun(@timeit, fcns);
        end
        format long
    end
    
    function Z = thewaywewalk(A) 
        n = size(A,2);
        rows = 1:numel(A);
        cols = mod(rows-1,n) + 1;
        Z = zeros(numel(A),n);
        linIdx = sub2ind(size(Z),rows,cols);
        Z(linIdx) = A.';
    end
    function Z = lhcgeneva(A) 
        sz = size(A);
        Z = zeros(sz(1)*sz(2), sz(2));
        for i = 1 : sz(1)
            Z((i-1)*sz(2)+1:i*sz(2), :) = diag(A(i, :));
        end
    end
    function Z = rayryeng(A)  
        [m, n] = size(A);
        A = A.';
        Z = full(sparse(1:numel(A), repmat(1:n,1,m), A(:)));
    end
    function Z = rlbond(A)  
        D = cellfun(@diag,mat2cell(A, ones(size(A,1), 1), size(A,2)), 'UniformOutput', false);
        Z = vertcat(D{:});
    end
    

    ans =
    
       0.322633905428601  %// thewaywewalk
       0.550931853207228  %// lhcgeneva
       0.254718792359946  %// rayryeng - Winner!
       0.898236688657039  %// rlbond
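
The benchmark input magic(200) is square, so it cannot catch indexing that silently assumes size(A,1) == size(A,2). A small cross-check, sketched here on an assumed 3-by-4 input, compares the sparse and logical-mask approaches (both written for m rows) against a loop reference that stacks diag(A(i,:)):

```matlab
%// Assumed test input: 3x4, deliberately non-square
A = magic(4);
A = A(1:3, :);
[m, n] = size(A);

%// reference: stack diag(A(i,:)) block by block with a loop
ref = zeros(m*n, n);
for i = 1:m
    ref((i-1)*n+1:i*n, :) = diag(A(i, :));
end

%// sparse approach: element k of A.' (A read row by row) goes to row k,
%// column mod(k-1,n)+1; the output size is passed explicitly
At = A.';
rowIdx = (1:m*n).';
colIdx = repmat((1:n).', m, 1);
Z = full(sparse(rowIdx, colIdx, At(:), m*n, n));

%// logical-mask approach (#2) with the identity stacked m times
Z2 = zeros(m*n, n);
Z2(repmat(logical(eye(n)), m, 1)) = A;

disp(isequal(Z, ref) && isequal(Z2, ref))   %// displays 1 if both match
```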
    
