Matlab: fast way to sum ones in binary numbers with Sparse structure?

Posted by 泪湿孤枕 on 2019-12-19 05:08:49

Question


Most answers address only the already-answered question about Hamming weights and ignore the point about find and dealing with the sparsity. Shai's answer here apparently addresses the point about find, but I have not yet been able to verify it. My own answer here does not use the ingenuity of the other answers, such as the bit shifting, but it is a good enough example.

Input

>> mlf=sparse([],[],[],2^31+1,1);mlf(1)=10;mlf(10)=111;mlf(77)=1010;  
>> transpose(dec2bin(find(mlf)))

ans =

001
000
000
011
001
010
101

Goal

1
0
0
2
1
1
2    

Is there a fast way to count the ones in binary numbers with the sparse structure?


Answer 1:


There are many ways to do this. The simplest, I think, is:

% Example data
F = [268469248 285213696 536904704 553649152];

% Solution 1
sum(dec2bin(F)-'0',2)

And the fastest (as found here):

% Solution 2
w = uint32(F');

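p1 = uint32(1431655765); % 0x55555555: every other bit
p2 = uint32(858993459);  % 0x33333333: pairs of bits
p3 = uint32(252645135);  % 0x0F0F0F0F: nibbles
p4 = uint32(16711935);   % 0x00FF00FF: bytes
p5 = uint32(65535);      % 0x0000FFFF: 16-bit halves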

w = bitand(bitshift(w, -1), p1) + bitand(w, p1);
w = bitand(bitshift(w, -2), p2) + bitand(w, p2);
w = bitand(bitshift(w, -4), p3) + bitand(w, p3);
w = bitand(bitshift(w, -8), p4) + bitand(w, p4);
w = bitand(bitshift(w,-16), p5) + bitand(w, p5);
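To see what the five constants are, it can help to write them in hexadecimal and to check the result against Solution 1. The following is a minimal sketch, not part of the original answer; the variable names masks, shifts and ref are made up for illustration.

```matlab
% Example data from above
F = [268469248 285213696 536904704 553649152];

% The same five constants, written in hex (assumed names: masks, shifts)
masks  = uint32(hex2dec({'55555555'; '33333333'; '0F0F0F0F'; '00FF00FF'; '0000FFFF'}));
shifts = [1 2 4 8 16];

% Each pass adds neighbouring bit fields of width shifts(k) in parallel
w = uint32(F');
for k = 1:5
    w = bitand(bitshift(w, -shifts(k)), masks(k)) + bitand(w, masks(k));
end

% Cross-check against the dec2bin-based Solution 1
ref = sum(dec2bin(F) - '0', 2);
assert(isequal(double(w), ref));
```

After the last pass, w holds one count per input number, as a uint32 column.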



Answer 2:


According to your comments, you convert a vector of numbers to binary string representations using dec2bin. Then you can achieve what you want as follows, where I'm using vector [10 11 12] as an example:

>> sum(dec2bin([10 11 12])=='1',2)

ans =

     2
     3
     2

Or equivalently,

>> sum(dec2bin([10 11 12])-'0',2)
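Both forms rely on dec2bin returning a character matrix with one row per number, padded to a common width. A small sketch of the intermediate values may make this clearer; the variable names are illustrative only.

```matlab
b = dec2bin([10 11 12]);   % 3-by-4 char matrix, one row per number
d = b - '0';               % char minus char gives doubles: '1' -> 1, '0' -> 0
m = (b == '1');            % logical matrix with the same pattern

w1 = sum(d, 2);            % ones per row, from the numeric form
w2 = sum(m, 2);            % same counts, from the logical form
assert(isequal(w1, w2));
```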

For speed, you could avoid dec2bin like this (uses modulo-2 operations, inspired by the dec2bin code):

>> sum(rem(floor(bsxfun(@times, [10 11 12].', pow2(1-N:0))),2),2)

ans =

     2
     3
     2

where N is the maximum number of binary digits you expect.
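The answer leaves the choice of N to the caller. One possible way to pick it, shown here only as a sketch, is the two-output form of log2, which returns the binary exponent of its argument; x and bits are illustrative names.

```matlab
x = [10 11 12].';

% [f, e] = log2(v) gives v = f * 2^e with 0.5 <= f < 1,
% so e is the number of binary digits of v
[~, N] = log2(max(x));

% Scale by 2^(1-N) ... 2^0, then take floor and rem to get the bits, msb first
bits = rem(floor(bsxfun(@times, x, pow2(1-N:0))), 2);
w = sum(bits, 2);
```

For this input the first row of bits is [1 0 1 0], the binary digits of 10.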




Answer 3:


If you really want speed, I think a lookup table would be handy. You can simply map each value in 0..255 to the number of ones it has. Do this once; after that you only need to decompose an integer into its bytes, look each byte up in the table, and add the results. There is no need to go through strings.


An example:

>> LUT = sum(dec2bin(0:255)-'0',2); % construct the look up table (only once)
>> ii = uint32( find( mlf ) ); % get the numbers
>> vals = LUT( bitand( ii, 255 ) + 1 ) + ... % lowest byte
          LUT( bitand( bitshift( ii,  -8 ), 255 ) + 1 ) + ... % bitshift, because ii/256 rounds for integer types
          LUT( bitand( bitshift( ii, -16 ), 255 ) + 1 ) + ...
          LUT( bitshift( ii, -24 ) + 1 );             % highest byte
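One detail to watch when splitting an integer into bytes in MATLAB: dividing an integer type rounds to the nearest integer rather than truncating, so shifts and masks are the safer way to isolate bytes. The sketch below uses made-up test values and illustrative variable names, and cross-checks the byte-wise lookup against dec2bin.

```matlab
LUT = sum(dec2bin(0:255) - '0', 2);     % same table as above

ii = uint32([200; 255; 256; 65535; 4294967295]);   % made-up test values

% Integer division rounds to nearest, so it is not a byte shift
rounded = ii(1) / 256;              % uint32(200)/256 gives 1, not 0
shifted = bitshift(ii(1), -8);      % drops the low byte, gives 0

% Byte-by-byte lookup using shifts and masks
vals = zeros(size(ii));
for b = 0:3
    byte = bitand(bitshift(ii, -8*b), uint32(255));
    vals = vals + LUT(double(byte) + 1);
end

% Cross-check against dec2bin
assert(isequal(vals, sum(dec2bin(double(ii)) - '0', 2)));
```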

Using typecast (as suggested by Amro):

>> vals = sum( reshape(LUT(double(typecast(ii,'uint8'))+1), 4, [] ), 1 )';
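In the typecast form, the conversion to double before adding 1 matters, because uint8 arithmetic saturates at 255 and a byte value of 255 would otherwise index the wrong table entry. A minimal sketch with made-up values (bytes, sat and idx are illustrative names):

```matlab
LUT = sum(dec2bin(0:255) - '0', 2);     % same table as above

ii = uint32([255; 77; 4294967295]);     % made-up test values
bytes = typecast(ii, 'uint8');          % 4 bytes per element, 12-by-1

sat = bytes + 1;                        % stays uint8: 255 + 1 saturates to 255
idx = double(bytes) + 1;                % 1..256, safe to use as an index

% Sum the four byte counts belonging to each number
vals = sum(reshape(LUT(idx), 4, []), 1)';
```

Because the four byte counts are simply added, the result does not depend on the byte order of the machine.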

Run time comparison

>> ii = uint32(randi(intmax('uint32'),100000,1));
>> tic; vals1 = sum( reshape(LUT(double(typecast(ii,'uint8'))+1), 4, [] ), 1 )'; toc
>> tic; vals2 = sum(dec2bin(ii)-'0',2); toc
>> dii = double(ii); % type issues
>> tic; vals3 = sum(rem(floor(bsxfun(@times, dii, pow2(1-32:0))),2),2); toc

Results:

Elapsed time is 0.006144 seconds.  <-- this answer
Elapsed time is 0.120216 seconds.  <-- using dec2bin
Elapsed time is 0.118009 seconds.  <-- using rem and bsxfun



Answer 4:


Here is an example to show @Shai's idea of using a lookup table:

% build lookup table for 8-bit integers
lut = sum(dec2bin(0:255)-'0', 2);

% get indices
idx = find(mlf);

% break indices into 8-bit integers and apply LUT
nbits = lut(double(typecast(uint32(idx),'uint8')) + 1);

% sum number of bits in each
s = sum(reshape(nbits,4,[]))

You might have to switch to uint64 instead if you have really large sparse arrays with indices outside the 32-bit range.
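A possible uint64 variant is sketched below. It is not from the original answer; the only changes are the cast and the reshape to 8 bytes per index, and the index values are made up for illustration.

```matlab
% build lookup table for 8-bit integers
lut = sum(dec2bin(0:255) - '0', 2);

% made-up indices, two of them beyond 2^31 (these stand in for find(mlf))
idx = [1; 10; 77; 2^31 + 1; 2^40 + 3];

% a uint64 has 8 bytes, so break each index into 8 bytes and apply the LUT
nbits = lut(double(typecast(uint64(idx), 'uint8')) + 1);

% sum the 8 byte counts belonging to each index
s = sum(reshape(nbits, 8, []), 1);
```

The doubles returned by find are exact integers up to 2^53, so the cast to uint64 loses nothing for any index MATLAB can actually produce.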


EDIT:

Here is another solution for you using Java:

idx = find(mlf);
s = arrayfun(@java.lang.Integer.bitCount, idx);

EDIT#2:

Here is yet another solution, implemented as a C++ MEX function. It relies on std::bitset::count:

bitset_count.cpp

#include "mex.h"
#include <bitset>

void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
    // validate input/output arguments
    if (nrhs != 1) {
        mexErrMsgTxt("One input argument required.");
    }
    if (!mxIsUint32(prhs[0]) || mxIsComplex(prhs[0]) || mxIsSparse(prhs[0])) {
        mexErrMsgTxt("Input must be a 32-bit integer dense matrix.");
    }
    if (nlhs > 1) {
        mexErrMsgTxt("Too many output arguments.");
    }

    // create output array
    mwSize N = mxGetNumberOfElements(prhs[0]);
    plhs[0] = mxCreateDoubleMatrix(N, 1, mxREAL);

    // get pointers to data
    double *counts = mxGetPr(plhs[0]);
    uint32_T *idx = reinterpret_cast<uint32_T*>(mxGetData(prhs[0]));

    // count bits set for each 32-bit integer number
    for(mwSize i=0; i<N; i++) {
        std::bitset<32> bs(idx[i]);
        counts[i] = bs.count();
    }
}

Compile the above function as mex -largeArrayDims bitset_count.cpp, then run it as usual:

idx = find(mlf);
s = bitset_count(uint32(idx))

I decided to compare all the solutions mentioned so far:

function [t,v] = testBitsetCount()
    % random data (uint32 vector)
    x = randi(intmax('uint32'), [1e5,1], 'uint32');

    % build lookup table (done once)
    LUT = sum(dec2bin(0:255,8)-'0', 2);

    % functions to compare
    f = {
        @() bit_twiddling(x)      % bit twiddling method
        @() lookup_table(x,LUT);  % lookup table method
        @() bitset_count(x);      % MEX-function (std::bitset::count)
        @() dec_to_bin(x);        % dec2bin
        @() java_bitcount(x);     % Java Integer.bitCount
    };

    % compare timings and check results are valid
    t = cellfun(@timeit, f, 'UniformOutput',true);
    v = cellfun(@feval, f, 'UniformOutput',false);
    assert(isequal(v{:}));
end

function s = lookup_table(x,LUT)
    s = sum(reshape(LUT(double(typecast(x,'uint8'))+1),4,[]))';
end

function s = dec_to_bin(x)
    s = sum(dec2bin(x,32)-'0', 2);
end

function s = java_bitcount(x)
    s = arrayfun(@java.lang.Integer.bitCount, x);
end

function s = bit_twiddling(x)
    p1 = uint32(1431655765);
    p2 = uint32(858993459);
    p3 = uint32(252645135);
    p4 = uint32(16711935);
    p5 = uint32(65535);

    s = x;
    s = bitand(bitshift(s, -1), p1) + bitand(s, p1);
    s = bitand(bitshift(s, -2), p2) + bitand(s, p2);
    s = bitand(bitshift(s, -4), p3) + bitand(s, p3);
    s = bitand(bitshift(s, -8), p4) + bitand(s, p4);
    s = bitand(bitshift(s,-16), p5) + bitand(s, p5);
end

The times elapsed in seconds:

t = 
    0.0009    % bit twiddling method
    0.0087    % lookup table method
    0.0134    % C++ std::bitset::count
    0.1946    % MATLAB dec2bin
    0.2343    % Java Integer.bitCount



Answer 5:


This gives you the row sums of the binary numbers from the sparse structure.

>> mlf=sparse([],[],[],2^31+1,1);mlf(1)=10;mlf(10)=111;mlf(77)=1010;  
>> transpose(dec2bin(find(mlf)))

ans =

001
000
000
011
001
010
101

>> sum(ismember(transpose(dec2bin(find(mlf))),'1'),2)

ans =

     1
     0
     0
     2
     1
     1
     2

I hope someone is able to find a faster way to do the row summation!
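One possible way to get the same row sums without building character arrays is to test each bit position directly with bitget. This is only a sketch: it assumes the indices fit in 32 bits, it has not been timed here, and idx stands in for find(mlf) from the example above.

```matlab
idx = [1; 10; 77];                 % stands in for find(mlf) from the example
[~, N] = log2(max(idx));           % number of binary digits needed (7 here)

counts = zeros(N, 1);
for b = 1:N
    % bitget numbers bits from 1 = least significant;
    % store msb first to match the transposed dec2bin layout
    counts(N - b + 1) = sum(double(bitget(uint32(idx), b)));
end
```

The loop runs over bit positions rather than over indices, so it takes at most 32 iterations however many nonzeros the sparse vector has.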




Answer 6:


Mex it!


Save this code as countTransBits.cpp:

#include "mex.h"

void mexFunction( int nout, mxArray* pout[], int nin, const mxArray* pin[] ) {
    // note: mxAssert is only active in debug builds (mex -g)
    mxAssert( nin == 1 && mxIsSparse(pin[0]) && mxGetN( pin[0] ) == 1,
            "expecting single sparse column vector input" );
    mxAssert( nout <= 1, "expecting at most one output" );

    // set output, assuming 32 bits, set to 64 if needed
    pout[0] = mxCreateNumericMatrix( 32, 1, mxUINT32_CLASS, mxREAL );
    unsigned int* counter = (unsigned int*)mxGetData( pout[0] );
    for ( int i = 0; i < 32; i++ ) {
        counter[i] = 0;
    }

    // start working
    mwIndex *pIr = mxGetIr( pin[0] );
    mwIndex* pJc = mxGetJc( pin[0] );
    double*  pr = mxGetPr( pin[0] );
    for ( mwSize i = pJc[0]; i < pJc[1]; i++ ) {
        if ( pr[i] != 0 ) {// make sure entry is non-zero
            unsigned int entry = pIr[i] + 1; // cast to unsigned int and add 1 for 1-based indexing in Matlab
            int bit = 0;
            while ( entry != 0 && bit < 32 ) {
                counter[bit] += ( entry & 0x1 ); // count the lsb
                bit++;
                entry >>= 1; // shift right
            }
        }
    }
}

Compile it in Matlab

>> mex -largeArrayDims -O countTransBits.cpp

Run the code

>> countTransBits( mlf )

Note that the output holds the counts in 32 bins, ordered from lsb to msb.
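To compare this output with the "Goal" layout in the question, which lists the most significant bit first and has no leading zero bins, the vector can be trimmed and flipped. The sketch below does not call the MEX function; c is a hand-computed stand-in (as doubles) for what countTransBits(mlf) should return on the example data, and last and goal are illustrative names.

```matlab
% Hand-computed stand-in for countTransBits(mlf) on the example data:
% indices 1, 10 and 77 only touch bins 1..7 (lsb first), the rest stay zero
c = zeros(32, 1);
c(1:7) = [2; 1; 1; 2; 0; 0; 1];

last = find(c, 1, 'last');     % highest bit position that occurs
goal = flipud(c(1:last));      % msb first, as in the question
```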




Answer 7:


The bitcount File Exchange (FEX) contribution offers a solution based on the lookup table approach, but better optimized. It runs more than twice as fast as the bit twiddling method (i.e. the fastest pure-MATLAB method reported by Amro) on a vector of 1 million uint32 values, using R2015a on my old laptop.



Source: https://stackoverflow.com/questions/19835495/matlab-fast-way-to-sum-ones-in-binary-numbers-with-sparse-structure
