Multiply columns of a matrix with 2d matrix slices of a 3d matrix in MatLab

Submitted by 核能气质少年 on 2019-12-10 06:40:30

Question


Basically, I want to perform the following computation:

    G is m x n x k
    S is n x k

    Answer=zeros(m,k)
    for Index=1:k
        Answer(:,Index)=G(:,:,Index)*S(:,Index)
    end

So Answer is a matrix whose columns are the result of multiplying each layer of a 3d matrix with the corresponding column of another matrix.

This really seems like a straightforward type of operation, and I was hoping to find out if there is a native or vectorized (or at least much faster) way of performing this type of computation in Matlab. Thanks.


Answer 1:


Try using mtimesx from the Matlab File Exchange. It's the best (fast/efficient) tool I've found so far to do this sort of n-dimensional array multiplication, since it uses mex. I think you could also use bsxfun, but my Matlab-fu is not enough for this sort of thing.

You have G, which is m x n x k, and S, which is n x k, and want to produce an m x k result.

mtimesx multiplies inputs like i x j x k and j x r x k to produce i x r x k.

To put your problem in mtimesx form, let G be m x n x k, and expand S to be n x 1 x k. Then mtimesx(G,S) would be m x 1 x k, which could then be flattened down to m x k.

    m=3;
    n=4;
    k=2;
    G=rand(m,n,k);
    S=rand(n,k);

    % reshape S from n x k to n x 1 x k
    S2=reshape(S,n,1,k);

    % do multiplication and flatten mx1xk to mxk
    Ans_mtimesx = reshape(mtimesx(G,S2),m,k)

    % try loop method to compare
    Answer=zeros(m,k);
    for Index=1:k
        Answer(:,Index)=G(:,:,Index)*S(:,Index);
    end

    % compare
    norm(Ans_mtimesx-Answer)
    % returns 0 (or a value on the order of machine precision)

So if you wanted a one-liner, you could do:

    Ans = reshape(mtimesx(G,reshape(S,n,1,k)),m,k)
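
If installing mtimesx is not an option, newer Matlab releases (R2020b and later) ship a built-in pagemtimes that performs the same page-wise product. The following is only a sketch under that version assumption; it swaps pagemtimes in for mtimesx and is not part of the original answer. The name Ans_page is illustrative.

```matlab
m = 3; n = 4; k = 2;
G = rand(m, n, k);
S = rand(n, k);

% Assumes R2020b or later: multiply each m x n page of G by the
% matching n x 1 page of the reshaped S, then flatten m x 1 x k to m x k.
Ans_page = reshape(pagemtimes(G, reshape(S, n, 1, k)), m, k);

% Reference loop from the question
Answer = zeros(m, k);
for Index = 1:k
    Answer(:, Index) = G(:, :, Index) * S(:, Index);
end

% Difference should be zero up to round-off
norm(Ans_page - Answer)
```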

By the way, if you post your question on the Matlab Newsreader forums there'll be plenty of gurus who compete to give you answers more elegant or efficient than mine!




Answer 2:


Here is the bsxfun() version. If A is an m-by-n matrix and x is an n-by-1 vector then A*x can be computed as

    sum(bsxfun(@times, A, x'), 2)
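
As a quick sanity check of this identity, here is a small worked example with made-up numbers. It uses x.' (the non-conjugating transpose), which is identical to x' for real x but keeps the identity valid for complex x as well.

```matlab
A = [1 2; 3 4];     % m = 2, n = 2
x = [5; 6];         % n-by-1

% Scale column j of A by x(j), then sum across each row
y_bsxfun = sum(bsxfun(@times, A, x.'), 2);

% Ordinary matrix-vector product for comparison
y_direct = A * x;

isequal(y_bsxfun, y_direct)   % both are [17; 39]
```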

The operation permute(S, [3 1 2]) will take the columns of S and distribute them along the 3rd dimension as rows. The [3 1 2] is a permutation of the dimensions of S.
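
To see what that permute does to the shape, here is a small illustration with a made-up 3-by-2 matrix (the name P is only for this example):

```matlab
S = [1 2; 3 4; 5 6];        % n = 3, k = 2
P = permute(S, [3 1 2]);    % old dim 3 -> dim 1, old dim 1 -> dim 2, old dim 2 -> dim 3

size(P)        % 1 3 2: one row of length n per page, k pages
P(:, :, 1)     % first column of S as a row:  1 3 5
P(:, :, 2)     % second column of S as a row: 2 4 6
```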

Thus sum(bsxfun(@times, G, permute(S, [3 1 2])), 2) achieves the answer but leaves the result in the 3rd dimension. In order to get it in the form you want another permute is required.

    permute(sum(bsxfun(@times, G, permute(S, [3 1 2])), 2), [1 3 2])
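
The expression can be checked against the loop from the question as sketched below. The second variant relies on implicit expansion, which assumes Matlab R2016b or later; the variable names Ans_bsxfun and Ans_implicit are illustrative.

```matlab
m = 3; n = 4; k = 2;
G = rand(m, n, k);
S = rand(n, k);

% bsxfun version: the intermediate product is m x n x k, the sum is m x 1 x k
Ans_bsxfun = permute(sum(bsxfun(@times, G, permute(S, [3 1 2])), 2), [1 3 2]);

% Same computation with implicit expansion (assumes R2016b or later)
Ans_implicit = permute(sum(G .* permute(S, [3 1 2]), 2), [1 3 2]);

% Reference loop from the question
Answer = zeros(m, k);
for Index = 1:k
    Answer(:, Index) = G(:, :, Index) * S(:, Index);
end

% Both differences should be zero up to round-off
norm(Ans_bsxfun - Answer)
norm(Ans_implicit - Answer)
```

Note that both variants build a temporary array the same size as G before summing, so they trade memory for the absence of a loop.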



Answer 3:


One thing you can do is to represent your 3d matrix as a 2d block diagonal matrix, with each layer being a diagonal block. The 2d matrix (S) should then be represented as a single vector containing its stacked columns, so that one matrix-vector product yields all the result columns stacked on top of each other. If the matrix is large, declare the block diagonal matrix as a sparse matrix.
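
A minimal sketch of the block diagonal idea, assuming the built-in blkdiag and sparse functions are used to assemble the matrix. The names blocks, B and Ans_blk are illustrative, and this is one possible way to build it rather than a definitive implementation.

```matlab
m = 3; n = 4; k = 2;
G = rand(m, n, k);
S = rand(n, k);

% Collect each layer of G as a sparse m x n block
blocks = cell(1, k);
for Index = 1:k
    blocks{Index} = sparse(G(:, :, Index));
end

% (m*k) x (n*k) sparse matrix with the layers on the diagonal
B = blkdiag(blocks{:});

% S(:) stacks the columns of S; the product stacks the result columns,
% which reshape then unstacks into an m x k matrix
Ans_blk = reshape(B * S(:), m, k);

% Reference loop from the question
Answer = zeros(m, k);
for Index = 1:k
    Answer(:, Index) = G(:, :, Index) * S(:, Index);
end

norm(Ans_blk - Answer)
```

Building B still loops over the layers, so this is most attractive when the same G is reused with many different S.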



Source: https://stackoverflow.com/questions/8698456/multiply-columns-of-a-matrix-with-2d-matrix-slices-of-a-3d-matrix-in-matlab
