Fast way to remove bits from a ulong

太阳男子 2021-01-21 14:09

I want to remove bits from a 64-bit string (represented by an unsigned long). I could do this with a sequence of mask and shift operations, or iterate over each bit as in the cod

3 Answers
  •  轮回少年
    2021-01-21 14:41

    The bit twiddling hacks site doesn't have this particular operation, though it has the one that inspired this answer.

    The idea is to compute, offline, a list of magic numbers that can be plunked into the following template. The template consists of a basic step repeated 6 = lg 64 times: rectify the indexes of the output bits mod 2**k for k = 1, 2, ..., 6, assuming at the start of each step that the indexes are correct mod 2**(k-1).

    For example, suppose that we wish to transform

    x = a.b..c.d
        76543210
    

    into

    ....abcd
    76543210.
    

    Bit a is at position 7 and needs to go to 3 (correct position mod 2). Bit b is at position 5 and needs to go to 2 (incorrect position mod 2). Bit c is at position 2 and needs to go to 1 (incorrect position mod 2). Bit d is at position 0 and needs to stay (correct position mod 2). The first intermediate step is to move b and c like so.

    a..b..cd
    76543210
    

    This is accomplished with

    x = (x & 0b10000001) | ((x >>> 1) & 0b00010010);
             //76543210                 //76543210
    

    Here >>> denotes a logical shift and 0bxxxxxxxx denotes a big-endian binary literal. Now we're left with two problems: one on the odd-indexed bits and one on the even-. What makes this algorithm fast is that these now can be handled in parallel.
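To see where the two masks in this first step come from, here is a minimal Python sketch for the example above. The names `sources`, `targets`, `keep_mask` and `move_mask` are made up for illustration; the rule it applies (a bit stays if its current and final positions agree mod 2, otherwise it moves down by one) is the one described in the text.

```python
# Worked example from the text: keep bits 7, 5, 2, 0 of an 8-bit value.
sources = [0, 2, 5, 7]               # current positions of d, c, b, a
targets = list(range(len(sources)))  # final positions 0, 1, 2, 3

keep_mask = 0  # bits whose position is already correct mod 2: they stay put
move_mask = 0  # bits whose position is wrong mod 2: they move down by one;
               # the mask records the destination because it is applied
               # after the shift
for target, source in zip(targets, sources):
    if source % 2 == target % 2:
        keep_mask |= 1 << source
    else:
        move_mask |= 1 << (source - 1)

print(f"{keep_mask:08b} {move_mask:08b}")  # prints "10000001 00010010"
```

The printed pair matches the two literals used in the first step.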

    For completeness, the other two operations are as follows. Bit a is now at position 7 and needs to go to 3 (correct position mod 4). Bit b is now at position 4 and needs to go to 2 (incorrect position mod 4). Bits c and d need to stay (correct positions mod 4). To get

    a....bcd
    76543210,
    

    we do

    x = (x & 0b10000011) | ((x >>> 2) & 0b00000100);
             //76543210                 //76543210
    

    Bit a is now at position 7 and needs to go to 3 (incorrect position mod 8). Bits b, c, and d need to stay (correct positions mod 8). To get

    ....abcd
    76543210,
    

    we do

    x = (x & 0b00000111) | ((x >>> 4) & 0b00001000);
             //76543210                 //76543210
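As a sanity check on the three steps of this worked example, the following Python sketch chains them and compares the result with a bit-by-bit reference over all 8-bit inputs. `STEPS`, `compress_example` and `reference` are illustrative names, not part of the original answer; Python's `>>` on non-negative integers behaves like the logical shift `>>>` used above.

```python
# (shift, mask of bits that stay, mask of destinations for bits that move)
STEPS = [
    (1, 0b10000001, 0b00010010),
    (2, 0b10000011, 0b00000100),
    (4, 0b00000111, 0b00001000),
]

def compress_example(x):
    """Apply the three mask-and-shift steps from the worked example."""
    for shift, keep, move in STEPS:
        x = (x & keep) | ((x >> shift) & move)
    return x

def reference(x):
    """Pick bits 7, 5, 2, 0 (a, b, c, d) one at a time and pack them as abcd."""
    a, b, c, d = (x >> 7) & 1, (x >> 5) & 1, (x >> 2) & 1, x & 1
    return (a << 3) | (b << 2) | (c << 1) | d

# a=1, b=0, c=1, d=1, with every unwanted bit set to 1.
print(f"{compress_example(0b11011111):04b}")  # prints "1011"
assert all(compress_example(x) == reference(x) for x in range(256))
```

Note that the unwanted bits are already cleared by the first step, since neither of its masks includes them.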
    

    Here's some proof-of-concept Python (sorry).

    def compute_mask_pairs(retained_indexes):
        """Precompute one (mask0, mask1) pair per step; step k shifts by 2**k."""
        mask_pairs = []
        retained_indexes = sorted(retained_indexes)
        target = list(range(len(retained_indexes)))
        shift = 1
        # Emit at least one pair so that the removed bits are masked out even
        # when no retained bit has to move.
        while not mask_pairs or retained_indexes != target:
            mask0 = 0  # bits that stay where they are in this step
            mask1 = 0  # destinations of the bits that move down by `shift`
            for i, j in enumerate(retained_indexes):
                assert i <= j
                assert i % shift == j % shift
                if i % (shift * 2) != j % (shift * 2):
                    retained_indexes[i] = j - shift
                    mask1 |= 1 << (j - shift)  # mask1 is applied after the shift
                else:
                    mask0 |= 1 << j
            mask_pairs.append((mask0, mask1))
            shift *= 2
        return mask_pairs

    def remove_bits_fast(mask_pairs, x):
        for log_shift, (mask0, mask1) in enumerate(mask_pairs):
            x = (x & mask0) | ((x >> (2 ** log_shift)) & mask1)
        return x

    def remove_bits_slow(retained_indexes, x):
        # Reference implementation: move bit j of x to bit i, one bit at a time.
        return sum(((x >> j) & 1) << i for i, j in enumerate(sorted(retained_indexes)))

    def test():
        k = 8
        for mask in range(2 ** k):
            # A set bit in `mask` marks a position to remove.
            retained_indexes = {i for i in range(k) if ((mask >> i) & 1) == 0}
            mask_pairs = compute_mask_pairs(retained_indexes)
            for x in range(2 ** k):
                assert remove_bits_fast(mask_pairs, x) == remove_bits_slow(retained_indexes, x)

    test()
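
The "compute offline, then plunk into the template" workflow from the top of this answer could look roughly like the sketch below. It computes the magic numbers for a 64-bit value and prints the unrolled statements as text. Everything here is an assumption made for illustration: the helper names `mask_pairs_for`, `emit_template` and `apply_pairs`, the variable name `x`, and the C-style output with a `UL` suffix (a language such as Java would want `>>>` and a different literal suffix). The example retains the even-indexed bits.

```python
def mask_pairs_for(retained):
    """Return one (keep, move) mask pair per step; step k shifts by 2**k."""
    positions = sorted(retained)
    pairs = []
    shift = 1
    # Run at least once so that unwanted bits are always masked out.
    while not pairs or any(p != i for i, p in enumerate(positions)):
        keep = move = 0
        for i, p in enumerate(positions):
            if (p - i) & shift:            # remaining distance has this bit set
                positions[i] = p - shift   # so the bit moves down by `shift`
                move |= 1 << (p - shift)   # destination, applied after shifting
            else:
                keep |= 1 << p
        pairs.append((keep, move))
        shift *= 2
    return pairs

def emit_template(retained, var="x"):
    """Format the steps as C-style statements (illustrative syntax only)."""
    return [
        f"{var} = ({var} & 0x{keep:016X}UL) | (({var} >> {1 << k}) & 0x{move:016X}UL);"
        for k, (keep, move) in enumerate(mask_pairs_for(retained))
    ]

def apply_pairs(pairs, x):
    """Run the same steps directly in Python, to check the generated masks."""
    for k, (keep, move) in enumerate(pairs):
        x = (x & keep) | ((x >> (1 << k)) & move)
    return x

even_bits = range(0, 64, 2)
for line in emit_template(even_bits):
    print(line)
```

For this retained set, five statements come out, because no bit has to travel more than 31 positions; an arbitrary 64-bit pattern needs at most six.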
    
