Optimizing three nested loops with multiple calculations in MATLAB

Submitted by ﹥>﹥吖頭↗ on 2021-01-28 03:12:33

Question


I want to optimize the following code using the pattern introduced in this solution. However, I don't know how to express three nested loops in a single statement. Moreover, the condition here is quite different from the one in that post.

Hint: W and S are N×N sparse double matrices.

    temp=0;         % initializations the loops rely on
    sum_temp=0;
    B=zeros(N);
    for i=1:N
        for j=1:N
            for k=1:N
                if W(j,k)~=0
                    temp(k)=S(i,j)-S(i,k);
                end
            end
            sum_temp=max(temp)+sum_temp;
            temp=0;
        end
        B(i,i)=sum_temp;
        sum_temp=0;
    end

Answer 1:


In this situation I would advise against fully vectorizing your solution. Calculating S(i,j)-S(i,k) for every combination of i, j and k would require an intermediate result of size [N,N,N]. Instead, I went through your code and eliminated as much iteration as possible without increasing memory consumption. I did it step by step so you can follow how I ended up there.
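
For comparison, here is a sketch of the fully vectorized version advised against above. It assumes MATLAB R2016b or later (or Octave) for implicit expansion, and the test inputs and the names `D`, `mask` and `B_vec` are hypothetical. It materializes the N×N×N array D(i,j,k) = S(i,j) - S(i,k), which is exactly the memory cost in question: at N = 1000 that is 10^9 doubles, about 8 GB. Like step 1 below, it floors each max at 0.

```matlab
% Hypothetical dense test inputs
N = 30;
S = rand(N, N);
W = rand(N, N) < 0.1;

% D(i,j,k) = S(i,j) - S(i,k): N-by-N-by-N via implicit expansion
D = S - permute(S, [1 3 2]);

% mask(1,j,k) = W(j,k); sparse arrays cannot be N-D, hence full()
mask = permute(full(W) ~= 0, [3 1 2]);

% Zero out pairs with W(j,k) == 0, take the max over k, then floor at 0
temp = max(max(D .* mask, [], 3), 0);   % temp(i,j)
B_vec = diag(sum(temp, 2));             % B(i,i) = sum over j of temp(i,j)
```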

    % Test data: dense stand-ins for the sparse S and W from the question
    N=30;
    S=rand(N,N);
    W=rand(N,N)<.1;
    sum_temp=0;
    temp=0;
    B=zeros(N);
    % Your original code, for reference
    for i=1:N
        for j=1:N
            for k=1:N
                if W(j,k)~=0
                    temp(k)=S(i,j)-S(i,k);
                end
            end
            sum_temp=max(temp)+sum_temp;
            temp=0;
        end
        B(i,i)=sum_temp;
        sum_temp=0;
    end
    B_orig=B;
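    % 1) You only need the max, so temp does not have to be a vector.
    %    Starting temp at 0 floors each max at 0. The original temp also
    %    holds a 0 unless the nonzeros of W(j,:) are exactly columns 1..m,
    %    so in that corner case the two versions can differ.
    for i=1:N
        sum_temp=0;
        for j=1:N
            temp=0;
            for k=1:N
                if W(j,k)~=0
                    temp=max(temp,S(i,j)-S(i,k));
                end
            end
            sum_temp=temp+sum_temp;
        end
        B(i,i)=sum_temp;
    end
    assert(all(all(B==B_orig)))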
    % 2) Eliminate the outer loop over i by working on whole columns of S
    sum_temp=zeros(N,1);
    for j=1:N
        temp=zeros(N,1);
        for k=1:N
            if W(j,k)~=0
                temp=max(temp,S(:,j)-S(:,k));
            end
        end
        sum_temp=temp+sum_temp;
    end
    B=diag(sum_temp);
    assert(all(all(B==B_orig)))

    % 3) Fold the condition into the inner loop: iterate only over k with W(j,k)~=0
    sum_temp=zeros(N,1);
    for j=1:N
        temp=zeros(N,1);
        for k=find(W(j,:))
            temp=max(temp,S(:,j)-S(:,k));
        end
        sum_temp=temp+sum_temp;
    end
    B=diag(sum_temp);
    assert(all(all(B==B_orig)))
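
The inner loop can also be removed without building a 3-D array, using the identity `max_k (S(i,j) - S(i,k)) = S(i,j) - min_k S(i,k)`. The sketch below is a possible fourth step, not part of the original answer. It keeps the floor-at-0 behavior of steps 1 to 3 and uses hypothetical sparse test inputs made with `sprand`, following the question's hint. Indexing `S(:,ks)` copies N×nnz(W(j,:)) values per j, so peak memory is higher than in step 3 but still O(N^2) rather than O(N^3).

```matlab
% Hypothetical sparse test inputs, matching the question's hint
N = 30;
S = sprand(N, N, 0.3);
W = sprand(N, N, 0.1);

% Reference: step 1 of the answer, run on full copies of the inputs
Sf = full(S);
Wf = full(W);
B_ref = zeros(N);
for i = 1:N
    sum_temp = 0;
    for j = 1:N
        temp = 0;
        for k = 1:N
            if Wf(j,k) ~= 0
                temp = max(temp, Sf(i,j) - Sf(i,k));
            end
        end
        sum_temp = temp + sum_temp;
    end
    B_ref(i,i) = sum_temp;
end

% Step 4 sketch: max over k of S(:,j)-S(:,k) equals S(:,j) - min over k of S(:,k)
sum_temp = zeros(N, 1);
for j = 1:N
    ks = find(W(j,:));            % columns k with W(j,k) ~= 0
    if isempty(ks)
        continue;                 % temp would stay 0, so nothing to add
    end
    temp = max(0, full(S(:,j) - min(S(:,ks), [], 2)));
    sum_temp = temp + sum_temp;
end
B = diag(sum_temp);
```

The `isempty` guard matters: reducing an N×0 matrix with `min(..., [], 2)` yields an empty result rather than a column, which would break the running sum. If memory is tight, step 3 remains the safer choice, since its temporaries are single columns.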


Source: https://stackoverflow.com/questions/60323652/optimizing-three-nested-loops-with-multiple-calculation-in-matlab
