Find nth SET bit in an int

栀梦 2020-12-14 18:14

Instead of just the lowest set bit, I want to find the position of the nth lowest set bit. (I'm NOT talking about value on the nt

11 answers
  •  挽巷 (original poster)
    2020-12-14 18:51

    My approach is to compute the population count of each 8-bit quarter of the 32-bit integer in parallel, then find which quarter contains the nth set bit. The population counts of the quarters below that one are summed to give the starting count for the rest of the calculation.

    After that, the set bits within that quarter are counted one by one until n is reached. Here is a branch-free example that uses a partial implementation of the parallel population count algorithm (it stops at per-byte counts); n is 0-based:

    #include <stdio.h>
    #include <stdint.h>
    
    int main() {
        uint32_t n = 10, test = 3124375902u; /* 10111010001110100011000101011110 */
        uint32_t index, popcnt, quarter = 0, q_popcnt;
    
        /* count set bits of each quarter of 32-bit integer in parallel */
        q_popcnt = test - ((test >> 1) & 0x55555555);
        q_popcnt = (q_popcnt & 0x33333333) + ((q_popcnt >> 2) & 0x33333333);
        q_popcnt = (q_popcnt + (q_popcnt >> 4)) & 0x0F0F0F0F;
    
        /* running totals: byte k of sums = set bits in quarters 0..k */
        uint32_t sums = q_popcnt * 0x01010101u;
    
        /* find the quarter that holds the nth (0-based) set bit */
        quarter += (n >= (sums & 0xff));
        quarter += (n >= ((sums >> 8) & 0xff));
        quarter += (n >= ((sums >> 16) & 0xff));
    
        /* set bits in the quarters below the selected one */
        popcnt = ((sums << 8) >> (8 * quarter)) & 0xff;
    
        /* find the index of nth bit in quarter where it should be */
        index = 8 * quarter;
        index += ((popcnt += (test >> index) & 1) <= n);
        index += ((popcnt += (test >> index) & 1) <= n);
        index += ((popcnt += (test >> index) & 1) <= n);
        index += ((popcnt += (test >> index) & 1) <= n);
        index += ((popcnt += (test >> index) & 1) <= n);
        index += ((popcnt += (test >> index) & 1) <= n);
        index += ((popcnt += (test >> index) & 1) <= n);
        index += ((popcnt += (test >> index) & 1) <= n);
    
        printf("index = %u\n", index);
        return 0;
    }
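
    The quarter-based idea described above can also be packaged as a reusable function. This is only a sketch under stated assumptions: the name nth_set_bit_quarters, the 0-based n, and the choice to return 32 when the value has too few set bits are illustrative, not part of the original example. The final stage uses a short loop instead of eight unrolled steps, so it is not branch-free.

```c
#include <stdint.h>

/* Hypothetical helper: 0-based position of the nth (0-based) lowest set bit
 * of x, or 32 if x has fewer than n + 1 set bits. */
uint32_t nth_set_bit_quarters(uint32_t x, uint32_t n)
{
    /* per-byte population counts, computed in parallel */
    uint32_t q = x - ((x >> 1) & 0x55555555u);
    q = (q & 0x33333333u) + ((q >> 2) & 0x33333333u);
    q = (q + (q >> 4)) & 0x0F0F0F0Fu;

    /* byte k of sums = number of set bits in bytes 0..k of x */
    uint32_t sums = q * 0x01010101u;
    if (n >= (sums >> 24))
        return 32;                      /* not enough set bits */

    /* count how many running totals are <= n to pick the byte */
    uint32_t quarter = (n >= (sums & 0xffu))
                     + (n >= ((sums >> 8) & 0xffu))
                     + (n >= ((sums >> 16) & 0xffu));

    /* set bits in the bytes below the selected one */
    uint32_t count = ((sums << 8) >> (8 * quarter)) & 0xffu;

    /* scan the selected byte until the bit that would exceed n */
    uint32_t index = 8 * quarter;
    while (count + ((x >> index) & 1u) <= n) {
        count += (x >> index) & 1u;
        ++index;
    }
    return index;
}
```

    With the test value from the example, nth_set_bit_quarters(3124375902u, 10) returns 20.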
    

    A simpler approach that uses a loop and conditionals:

    #include <stdio.h>
    #include <stdint.h>
    
    int main() {
        uint32_t n = 11, test = 3124375902u; /* 10111010001110100011000101011110 */
        uint32_t popcnt = 0, index = 0;
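        /* walk up from bit 0, counting set bits, until more than n are seen;
           index ends at 32 if there are not enough set bits */
        while (popcnt += ((test >> index) & 1), popcnt <= n && ++index < 32)
            ;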
    
        printf("index = %u\n", index);
        return 0;
    }
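
    The same loop may be easier to reuse and test as a function. A minimal sketch, assuming a 0-based n and a return value of 32 when no such bit exists; the name nth_set_bit_loop is made up for illustration:

```c
#include <stdint.h>

/* Hypothetical helper: 0-based position of the nth (0-based) lowest set bit
 * of x, or 32 if x has fewer than n + 1 set bits. */
uint32_t nth_set_bit_loop(uint32_t x, uint32_t n)
{
    uint32_t count = 0;
    for (uint32_t index = 0; index < 32; ++index) {
        count += (x >> index) & 1u;     /* count this bit if it is set */
        if (count > n)                  /* this is set bit number n */
            return index;
    }
    return 32;                          /* ran out of bits */
}
```

    For the value used above, nth_set_bit_loop(3124375902u, 11) returns 21.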
    
