Tensor multiplication w/o looping in Matlab

﹥>﹥吖頭↗ submitted on 2019-12-11 16:11:47

Question


I have a 3D array A, e.g. A=rand(N,N,K).

I need an array B such that

B(n,m) = norm(A(:,:,n)*A(:,:,m)' - A(:,:,m)*A(:,:,n)','fro')^2 for all indices n,m in 1:K.
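Writing A_n for A(:,:,n) and assuming A is real (as with rand), the second product is just the transpose of the first. This is what makes a loop-free version possible. (MATLAB's ' is the conjugate transpose, so the same identity also holds for complex A with ^H in place of ^T.)

```latex
P_{nm} = A_n A_m^{\top}, \qquad A_m A_n^{\top} = \bigl(A_n A_m^{\top}\bigr)^{\top} = P_{nm}^{\top}

B(n,m) = \bigl\| P_{nm} - P_{nm}^{\top} \bigr\|_F^2
       = \sum_{i=1}^{N} \sum_{j=1}^{N} \bigl( (P_{nm})_{ij} - (P_{nm})_{ji} \bigr)^2

P_{mn} = P_{nm}^{\top} \;\Rightarrow\; B(m,n) = B(n,m), \qquad B(n,n) = 0
```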

Here's the looping code:

B = zeros(K,K);
for n=1:K
    for m=1:K
        B(n,m) = norm(A(:,:,n)*A(:,:,m)' - A(:,:,m)*A(:,:,n)','fro')^2;
    end
end
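Because swapping n and m only flips the sign of the difference, B is symmetric with a zero diagonal, so even the loop only needs the pairs with m > n. A minimal sketch with arbitrary example sizes (N = 4, K = 5 are assumptions); it still loops, so it only halves the work:

```matlab
% Sketch: same result as the full double loop, computing only m > n.
N = 4; K = 5;                          % arbitrary example sizes
A = rand(N,N,K);

B = zeros(K,K);                        % diagonal stays 0
for n = 1:K
    for m = n+1:K                      % upper triangle only
        P = A(:,:,n)*A(:,:,m)';        % A_n * A_m'
        B(n,m) = norm(P - P','fro')^2; % P' equals A_m * A_n'
        B(m,n) = B(n,m);               % mirror into the lower triangle
    end
end
```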

I don't want to loop through 1:K.

I can create an array An_x_mt of size NK x NK whose (n,m)-th N x N block equals A(:,:,n)*A(:,:,m)' for all n,m in 1:K, via

An_x_mt = Ar*Ac_t;

with

Ac_t=reshape(permute(A,[2 1 3]),size(A,1),[]); 
Ar=Ac_t';
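A quick check of that block layout, as a sketch with arbitrary sizes and an arbitrarily chosen pair (n,m) = (2,4): Ac_t stacks the transposed pages side by side, Ar stacks the pages on top of each other, so block (n,m) of their product is A(:,:,n)*A(:,:,m)'.

```matlab
% Sketch: verify that block (n,m) of Ar*Ac_t is A(:,:,n)*A(:,:,m)'.
N = 3; K = 4;                                    % arbitrary example sizes
A = rand(N,N,K);
Ac_t = reshape(permute(A,[2 1 3]),size(A,1),[]); % N x N*K: [A_1', A_2', ..., A_K']
Ar = Ac_t';                                      % N*K x N: [A_1; A_2; ...; A_K]
An_x_mt = Ar*Ac_t;                               % N*K x N*K

n = 2; m = 4;                                    % any pair of page indices
rows = (n-1)*N + (1:N);                          % rows of block row n
cols = (m-1)*N + (1:N);                          % columns of block column m
err = norm(An_x_mt(rows,cols) - A(:,:,n)*A(:,:,m)', 'fro');
```

Note that An_x_mt = Ar*Ar' is symmetric, so transposing the whole matrix does not give Am_x_nt; only the individual blocks (or, equivalently, the block positions) have to be swapped.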

How do I create an array Am_x_nt, also of size NK x NK, whose (n,m)-th N x N block equals A(:,:,m)*A(:,:,n)' for all n,m in 1:K,

so that I could do

B = An_x_mt - Am_x_nt;
B = reshape(B.^2,N,K,N,K); % B(i,n,j,m) = squared entry (i,j) of block (n,m)
B = reshape(sum(sum(B,1),3),K,K); % sum within each N x N block -> K x K
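One way to get Am_x_nt without loops or cells is to view the NK x NK matrix as a 4-D array P(i,n,j,m) (entry (i,j) of block (n,m)) and swap the two within-block indices, which transposes every block in place. A minimal sketch under arbitrary example sizes:

```matlab
% Sketch: Am_x_nt from An_x_mt using only reshape/permute.
N = 3; K = 4;                             % arbitrary example sizes
A = rand(N,N,K);
Ac_t = reshape(permute(A,[2 1 3]),N,[]);  % N x N*K
Ar = Ac_t';                               % N*K x N
An_x_mt = Ar*Ac_t;                        % block (n,m) = A(:,:,n)*A(:,:,m)'

P = reshape(An_x_mt, N, K, N, K);         % P(i,n,j,m) = entry (i,j) of block (n,m)
Q = permute(P, [3 2 1 4]);                % Q(i,n,j,m) = P(j,n,i,m): each block transposed
Am_x_nt = reshape(Q, N*K, N*K);           % block (n,m) = A(:,:,m)*A(:,:,n)'
```

Permuting with [1 4 3 2] instead swaps the block indices n and m rather than the entries within each block; because An_x_mt is symmetric, both give the same Am_x_nt.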

Thx


Answer 1:


For those who can't or won't use mmx and want to stick to pure MATLAB code, here's how you could do it; the mat2cell and cell2mat functions are your friends:

[N,~,nmat]=size(A);
Atc = reshape(permute(A,[2 1 3]),N,[]); % A', N x N*nmat
Ar = Atc'; % A, N*nmat x N
Anmt_2d = Ar*Atc; % An*Am'
Anmt_2d_cell = mat2cell(Anmt_2d,N*ones(nmat,1),N*ones(nmat,1));
Amnt_2d_cell = Anmt_2d_cell'; % swaps cell positions (n,m)<->(m,n); the blocks themselves are NOT transposed
Amnt_2d = cell2mat(Amnt_2d_cell); % Am*An'
Anm = Anmt_2d - Amnt_2d;
Anm = Anm.^2;
Anm_cell = mat2cell(Anm,N*ones(nmat,1),N*ones(nmat,1));
d = cellfun(@(c) sum(c(:)), Anm_cell); % squared Frobenius norm of each block difference; nmat x nmat
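The Amnt_2d_cell step works because transposing a cell array rearranges the cells without transposing the matrices stored in them. A small sketch on an arbitrary 4 x 4 matrix split into 2 x 2 blocks:

```matlab
% Sketch: a cell-array transpose moves cells; it does not transpose their contents.
M = reshape(1:16, 4, 4);                    % arbitrary 4x4 matrix
C = mat2cell(M, [2 2], [2 2]);              % 2x2 cell of 2x2 blocks
Ct = C';                                    % cell (n,m) of Ct is cell (m,n) of C
blockSwapped = isequal(Ct{1,2}, C{2,1});    % same block, new position
notTransposed = isequal(Ct{1,2}, C{2,1}.'); % contents keep their orientation
Mt = cell2mat(Ct);                          % block positions swapped, blocks intact
```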

Alternatively, after computing Anmt_2d_cell and Amnt_2d_cell, you could convert them to 3D arrays whose 3rd dimension encodes the (n,m) and (m,n) indices, and then do the rest of the computations in 3D. You would need the permn() utility from the File Exchange: https://www.mathworks.com/matlabcentral/fileexchange/7147-permn-v-n-k

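Anmt_3d = cat(3,Anmt_2d_cell{:}); % {:} is required; pages follow column-major cell order (1,1),(2,1),...
Amnt_3d = cat(3,Amnt_2d_cell{:});
Anm_3d = Anmt_3d - Amnt_3d;
Anm_3d = Anm_3d.^2;
Anm = squeeze(sum(sum(Anm_3d,1),2)); % squared Frobenius norm of each page; nmat^2 x 1
d = zeros(nmat,nmat);
nm=permn(1:nmat, 2); % all pairs with repetition, by-row order (2nd column varies fastest)
d(sub2ind([nmat,nmat],nm(:,2),nm(:,1))) = Anm; % page k is cell (nm(k,2),nm(k,1)); equivalently d = reshape(Anm,nmat,nmat)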
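If permn is not available, the same list of (n,m) pairs can be built with the built-in ndgrid. This sketch assumes permn(1:nmat,2) lists pairs in the by-row order described above (first column slowest, second column fastest); nmat = 3 is an arbitrary example size:

```matlab
% Sketch: (n,m) pairs in by-row order without permn.
nmat = 3;                                  % arbitrary example size
[second, first] = ndgrid(1:nmat, 1:nmat);  % 'second' varies fastest down the columns
nm = [first(:), second(:)];                % rows: [1 1; 1 2; 1 3; 2 1; ...]
```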

For some reason, the second option (3D arrays) is about twice as fast.

Hope this helps.



Source: https://stackoverflow.com/questions/51916104/tensor-multiplication-w-o-looping-in-matlab
