Speed-efficient classification in Matlab

走远了吗. submitted on 2019-11-26 17:51:43
Divakar

Approach #1

For an N x 2 points/pixels array, you can avoid the permute used in the other solution by Luis, which may slow things down a bit, by using a kind of "permute-unrolled" version of it. This also lets bsxfun work on a 2D array instead of a 3D array, which should be better for performance.

Thus, assuming clusters is an N x 2 array, you may try this bsxfun-based approach:

%// Get a's and b's
im_a = im(:,:,2);
im_b = im(:,:,3);

%// Get the minimum indices that correspond to the cluster IDs
[~,idx]  = min(bsxfun(@minus,im_a(:),clusters(:,1).').^2 + ...
    bsxfun(@minus,im_b(:),clusters(:,2).').^2,[],2);
idx = reshape(idx,size(im,1),[]);
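To check Approach #1 on its own, here is a minimal self-contained sketch. The data are hypothetical: a random 4 x 5 x 3 array stands in for the Lab image im, and three random rows stand in for the (a,b) cluster centres. The result is compared against a plain double loop that evaluates the same squared distances pixel by pixel.

```matlab
%// Hypothetical test data: a 4 x 5 x 3 array standing in for the Lab image,
%// and 3 made-up (a,b) cluster centres, one per row
im = rand(4, 5, 3);
clusters = rand(3, 2);

%// Approach #1 (bsxfun on 2D arrays)
im_a = im(:,:,2);
im_b = im(:,:,3);
[~,idx] = min(bsxfun(@minus,im_a(:),clusters(:,1).').^2 + ...
    bsxfun(@minus,im_b(:),clusters(:,2).').^2,[],2);
idx = reshape(idx,size(im,1),[]);

%// Brute-force reference: every pixel against every cluster
idx_loop = zeros(size(im,1), size(im,2));
for r = 1:size(im,1)
    for c = 1:size(im,2)
        d2 = (im(r,c,2) - clusters(:,1)).^2 + (im(r,c,3) - clusters(:,2)).^2;
        [~, k] = min(d2);
        idx_loop(r,c) = k;
    end
end

disp(isequal(idx, idx_loop))   %// expected to display 1
```

Both versions perform the same arithmetic in the same order, so the cluster IDs match exactly; the loop is only there as a readable reference.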

Approach #2

You can also try another approach that leverages MATLAB's fast matrix multiplication and is based on a smart matrix-multiplication trick for distance matrices:

d = 2; %// dimensionality of the points (the a and b channels)

im23 = reshape(im(:,:,2:3),[],2);

numA = size(im23,1);
numB = size(clusters,1);

A_ext = zeros(numA,3*d);
B_ext = zeros(numB,3*d);
for id = 1:d
    A_ext(:,3*id-2:3*id) = [ones(numA,1), -2*im23(:,id), im23(:,id).^2 ];
    B_ext(:,3*id-2:3*id) = [clusters(:,id).^2 ,  clusters(:,id), ones(numB,1)];
end
[~, idx] = min(A_ext * B_ext',[],2); %//'
idx = reshape(idx, size(im,1),[]); %// Desired IDs
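A quick way to convince yourself that Approach #2 computes the same thing as Approach #1 is to build both squared-distance matrices on the same data and compare them. The sketch below assumes hypothetical random data (a 6 x 7 x 3 image and 4 cluster centres). Because the matrix-multiplication form expands (a-b)^2 into a^2 - 2ab + b^2, the two results agree only up to floating-point rounding, so a tolerance is used rather than exact equality.

```matlab
%// Hypothetical test data
im = rand(6, 7, 3);
clusters = rand(4, 2);

%// Squared distances from Approach #1 (one row per pixel, one column per cluster)
im_a = im(:,:,2);
im_b = im(:,:,3);
D1 = bsxfun(@minus,im_a(:),clusters(:,1).').^2 + ...
     bsxfun(@minus,im_b(:),clusters(:,2).').^2;

%// Squared distances from Approach #2 (matrix multiplication)
d = 2;
im23 = reshape(im(:,:,2:3),[],2);
numA = size(im23,1);
numB = size(clusters,1);
A_ext = zeros(numA,3*d);
B_ext = zeros(numB,3*d);
for id = 1:d
    A_ext(:,3*id-2:3*id) = [ones(numA,1), -2*im23(:,id), im23(:,id).^2];
    B_ext(:,3*id-2:3*id) = [clusters(:,id).^2, clusters(:,id), ones(numB,1)];
end
D2 = A_ext * B_ext.';

%// The two agree up to floating-point rounding
disp(max(abs(D1(:) - D2(:))))   %// a value very close to zero
```

In practice the rounding differences are far smaller than the gaps between distances to different clusters, so the min step picks the same IDs.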

What’s going on with the matrix multiplication based distance matrix calculation?

Let us consider two matrices A and B between which we want to calculate the distance matrix. To make the explanation that follows easier, let us take A as a 3 x 2 array and B as a 4 x 2 array, meaning we are working with X-Y points. If A were N x 3 and B were M x 3, those would be X-Y-Z points.

Now, if we had to calculate the first element of the squared distance matrix by hand, it would look like this:

first_element = ( A(1,1) - B(1,1) )^2 + ( A(1,2) - B(1,2) )^2

which expands to:

first_element = A(1,1)^2 + B(1,1)^2 - 2*A(1,1)*B(1,1) + ...
                A(1,2)^2 + B(1,2)^2 - 2*A(1,2)*B(1,2)     %// Equation (1)
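As a small numeric sanity check of Equation (1), take two hypothetical X-Y points, A = [1 2] and B = [4 6]. The direct form gives (1-4)^2 + (2-6)^2 = 9 + 16 = 25, and the expanded form gives (1 + 16 - 8) + (4 + 36 - 24) = 9 + 16 = 25:

```matlab
%// Hypothetical X-Y points
A = [1 2];
B = [4 6];

direct   = (A(1,1) - B(1,1))^2 + (A(1,2) - B(1,2))^2;
expanded = A(1,1)^2 + B(1,1)^2 - 2*A(1,1)*B(1,1) + ...
           A(1,2)^2 + B(1,2)^2 - 2*A(1,2)*B(1,2);    %// Equation (1)

disp([direct expanded])   %// displays 25 25
```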

Now, with the proposed matrix multiplication (im23 playing the role of A and clusters that of B), if you inspect A_ext and B_ext after the loop in the earlier code ends, row i of A_ext is [1, -2*A(i,1), A(i,1)^2, 1, -2*A(i,2), A(i,2)^2] and row j of B_ext is [B(j,1)^2, B(j,1), 1, B(j,2)^2, B(j,2), 1].

So, if you multiply A_ext by the transpose of B_ext, the first element of the product is the sum of the elementwise products of the first rows of A_ext and B_ext, i.e. the sum of these terms:

1*B(1,1)^2,  -2*A(1,1)*B(1,1),  A(1,1)^2*1,  1*B(1,2)^2,  -2*A(1,2)*B(1,2),  A(1,2)^2*1

The result is identical to Equation (1). The same holds for every other element: entry (i,j) of the product is the squared distance between row i of A and row j of B. Thus, we end up with the complete squared distance matrix. That's all there is to it!
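To see the whole squared distance matrix come out of one product, here is a sketch using hypothetical integer points with the same sizes as in the explanation (A is 3 x 2, B is 4 x 2). It builds A_ext and B_ext with the same loop as Approach #2 and compares the product against an explicit double loop. Integer inputs keep every intermediate value exact in double precision, so the comparison can use exact equality.

```matlab
%// Hypothetical X-Y points: A is 3 x 2, B is 4 x 2
A = [0 0; 1 2; 3 1];
B = [1 1; 2 2; 0 3; 4 0];

d = size(A,2);
nA = size(A,1);
nB = size(B,1);
A_ext = zeros(nA,3*d);
B_ext = zeros(nB,3*d);
for id = 1:d
    A_ext(:,3*id-2:3*id) = [ones(nA,1), -2*A(:,id), A(:,id).^2];
    B_ext(:,3*id-2:3*id) = [B(:,id).^2, B(:,id), ones(nB,1)];
end
D2 = A_ext * B_ext.';   %// 3 x 4 squared distances

%// Reference: entry (i,j) is the squared distance between A(i,:) and B(j,:)
D2_ref = zeros(nA,nB);
for i = 1:nA
    for j = 1:nB
        D2_ref(i,j) = sum((A(i,:) - B(j,:)).^2);
    end
end

disp(D2)   %// [2 8 9 16; 1 1 2 13; 4 2 13 2]
```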

Vectorized Variations

Vectorized variations of the matrix-multiplication-based distance matrix calculation are possible, though they did not show any big performance improvements. Two such variations follow.

Variation #1

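[nA,dim] = size(A);   %// nA points of dimension dim
nB = size(B,1);

%// Each dimension contributes the triplet [1, -2*a, a.^2] to A_ext
A_ext = ones(nA,dim*3);
A_ext(:,2:3:end) = -2*A;
A_ext(:,3:3:end) = A.^2;

%// ... and the matching triplet [b.^2, b, 1] to B_ext
B_ext = ones(nB,dim*3);
B_ext(:,1:3:end) = B.^2;
B_ext(:,2:3:end) = B;

distmat = A_ext * B_ext.';   %// nA x nB squared distances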
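Variation #1 can be checked against a straightforward bsxfun computation over a 3D array. The sketch below reuses hypothetical integer points (so the comparison is exact) and forms every pairwise difference as an nA x nB x dim array before summing the squares along the third dimension.

```matlab
%// Hypothetical X-Y points
A = [0 0; 1 2; 3 1];
B = [1 1; 2 2; 0 3; 4 0];

%// Variation #1
[nA,dim] = size(A);
nB = size(B,1);
A_ext = ones(nA,dim*3);
A_ext(:,2:3:end) = -2*A;
A_ext(:,3:3:end) = A.^2;
B_ext = ones(nB,dim*3);
B_ext(:,1:3:end) = B.^2;
B_ext(:,2:3:end) = B;
distmat = A_ext * B_ext.';

%// Reference: nA x 1 x dim minus 1 x nB x dim, squared, summed over dim
ref = sum(bsxfun(@minus, permute(A,[1 3 2]), permute(B,[3 1 2])).^2, 3);

disp(distmat)   %// [2 8 9 16; 1 1 2 13; 4 2 13 2]
```

The reference trades memory (the full 3D difference array) for simplicity, which is exactly what the matrix-multiplication form avoids.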

Variation #2

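[nA,dim] = size(A);
nB = size(B,1);

%// Build the [1, -2*a, a.^2] and [b.^2, b, 1] triplets for all dimensions at once
A_ext = [ones(nA*dim,1) -2*A(:) A(:).^2];
B_ext = [B(:).^2 B(:) ones(nB*dim,1)];

%// Regroup so that each row holds one point's triplets, dimension by dimension
A_ext = reshape(permute(reshape(A_ext,nA,dim,[]),[1 3 2]),nA,[]);
B_ext = reshape(permute(reshape(B_ext,nB,dim,[]),[1 3 2]),nB,[]);

distmat = A_ext * B_ext.';   %// nA x nB squared distances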
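Variation #2 should build exactly the same A_ext and B_ext as Variation #1, only through reshape/permute instead of strided column assignment. The sketch below checks that on hypothetical random X-Y-Z points (dim = 3), which also shows that nothing in either variation is tied to 2D data. Since both build their entries with the same elementwise operations, the extended matrices, and hence the products, are identical.

```matlab
%// Hypothetical X-Y-Z points
A = rand(5,3);
B = rand(2,3);
[nA,dim] = size(A);
nB = size(B,1);

%// Variation #1
A1 = ones(nA,dim*3);  A1(:,2:3:end) = -2*A;  A1(:,3:3:end) = A.^2;
B1 = ones(nB,dim*3);  B1(:,1:3:end) = B.^2;  B1(:,2:3:end) = B;
dist1 = A1 * B1.';

%// Variation #2
A2 = [ones(nA*dim,1) -2*A(:) A(:).^2];
B2 = [B(:).^2 B(:) ones(nB*dim,1)];
A2 = reshape(permute(reshape(A2,nA,dim,[]),[1 3 2]),nA,[]);
B2 = reshape(permute(reshape(B2,nB,dim,[]),[1 3 2]),nB,[]);
dist2 = A2 * B2.';

disp(isequal(A1, A2) && isequal(B1, B2))   %// expected to display 1
```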

So, these could be considered experimental versions.

Luis Mendo

Use pdist2 (Statistics Toolbox) to compute the distances in a vectorized manner:

ab = im(:,:,2:3);                              %// get A, B components
ab = reshape(ab, [size(im,1)*size(im,2) 2]);   %// reshape into 2 columns
dist = pdist2(clusters, ab);                   %// compute distances
[~, idx] = min(dist);                          %// find minimizer for each pixel
idx = reshape(idx, size(im,1), size(im,2));    %// reshape result

If you don't have the Statistics Toolbox, you can replace the third line with:

dist = squeeze(sum(bsxfun(@minus, clusters, permute(ab, [3 2 1])).^2, 2));

This gives the squared distance instead of the distance, but since the square root is monotonically increasing, the minimizer is the same.
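To illustrate that minimizing the squared distance picks the same cluster as minimizing the distance itself, the sketch below uses the toolbox-free line on hypothetical random data (a 4 x 6 x 3 image and 3 clusters) and compares the argmin with and without taking the square root.

```matlab
%// Hypothetical test data
im = rand(4, 6, 3);
clusters = rand(3, 2);

ab = reshape(im(:,:,2:3), [], 2);   %// one (a,b) pair per row

%// Toolbox-free squared distances: K x N (clusters x pixels)
dist = squeeze(sum(bsxfun(@minus, clusters, permute(ab, [3 2 1])).^2, 2));

[~, idx_sq]   = min(dist);         %// minimize squared distance
[~, idx_true] = min(sqrt(dist));   %// minimize Euclidean distance

disp(isequal(idx_sq, idx_true))    %// expected to display 1
```

Skipping the sqrt saves one elementwise operation per pixel/cluster pair without changing the result.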
