I want to remove bits from a 64-bit string (represented by an unsigned long). I could do this with a sequence of mask and shift operations, or iterate over each bit as in the cod
The bit twiddling hacks site doesn't have this particular operation, though it has the one that inspired this answer.
The idea is to compute, offline, a list of magic numbers that can be plunked into the following template. The template consists of a basic step repeated 6 = lg 64 times: rectify the indexes of the output bits mod 2**k for k = 1, 2, ..., 6, assuming at the start of each step that the indexes are correct mod 2**(k-1).
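To make the invariant concrete, here is a minimal sketch that only tracks where the retained bits sit after each step, without computing any masks yet. The helper name `trace_positions` and the `width` parameter are made up for illustration; the sample input assumes an 8-bit word with bits retained at indexes 0, 2, 5 and 7.

```python
def trace_positions(retained_indexes, width=64):
    """Return the positions of the retained bits before and after each step.

    The i-th retained bit (counting from the least significant) has target
    index i. Before the step with a given shift, every position already
    agrees with its target mod shift; the step makes it agree mod 2 * shift.
    """
    positions = sorted(retained_indexes)
    history = [list(positions)]
    shift = 1
    while shift < width:
        for i, p in enumerate(positions):
            if p % (2 * shift) != i % (2 * shift):
                positions[i] = p - shift  # this bit moves down by `shift`
        history.append(list(positions))
        shift *= 2
    return history


for step, positions in enumerate(trace_positions([0, 2, 5, 7], width=8)):
    print(step, positions)
# 0 [0, 2, 5, 7]
# 1 [0, 1, 4, 7]
# 2 [0, 1, 2, 7]
# 3 [0, 1, 2, 3]
```

After the last step every position agrees with its target mod the word width, which forces it to be equal to the target.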
For example, suppose that we wish to transform
x = a.b..c.d
    76543210
into
    ....abcd
    76543210.
Bit a is at position 7 and needs to go to 3 (correct position mod 2). Bit b is at position 5 and needs to go to 2 (incorrect position mod 2). Bit c is at position 2 and needs to go to 1 (incorrect position mod 2). Bit d is at position 0 and needs to stay (correct position mod 2). The first intermediate step is to move b and c like so.
a..b..cd
76543210
This is accomplished with
x = (x & 0b10000001) | ((x >>> 1) & 0b00010010);
         //76543210                 //76543210
Here >>> denotes a logical shift and 0bxxxxxxxx denotes a binary literal written most significant bit first. Now we're left with two problems: one on the odd-indexed bits and one on the even-indexed bits. What makes this algorithm fast is that these can now be handled in parallel.
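A quick way to sanity-check a single step is to feed it one bit at a time and see where each labelled bit lands. This is only a sketch: `step1` transcribes the first step into Python (where `>>` on a non-negative int behaves like the logical shift above), and `render` is a hypothetical helper, not part of the algorithm.

```python
def step1(x):
    # First step from the example: a and d stay, b and c move down by one.
    return (x & 0b10000001) | ((x >> 1) & 0b00010010)


def render(step, labels, width=8):
    """Show where each labelled input bit ends up after applying `step`.

    `labels` maps an input bit position to a one-letter name.
    """
    out = ["."] * width
    for pos, letter in labels.items():
        y = step(1 << pos)
        if y:  # the bit survived; find its new position
            out[width - y.bit_length()] = letter
    return "".join(out)


print(render(step1, {7: "a", 5: "b", 2: "c", 0: "d"}))  # a..b..cd
```

The same helper can be pointed at the later steps to check their masks.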
For completeness, the other two operations are as follows. Bit a is now at position 7 and needs to go to 3 (correct position mod 4). Bit b is now at position 4 and needs to go to 2 (incorrect position mod 4). Bits c and d need to stay (correct positions mod 4). To get
a....bcd
76543210,
we do
x = (x & 0b10000011) | ((x >>> 2) & 0b00000100);
         //76543210                 //76543210
Bit a is now at position 7 and needs to go to 3 (incorrect position mod 8). Bits b, c, and d need to stay (correct positions mod 8). To get
....abcd
76543210,
we do
x = (x & 0b00000111) | ((x >>> 4) & 0b00001000);
         //76543210                 //76543210
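Chaining the three steps gives the whole transformation for this 8-bit example. The sketch below transcribes them into Python and compares the result against a bit-by-bit extraction for every 8-bit input; the names `STEPS`, `compress_example` and `extract_naive` are just for illustration.

```python
# (mask for bits that stay, mask for bits that move, shift) for each step,
# copied from the three statements above.
STEPS = [
    (0b10000001, 0b00010010, 1),
    (0b10000011, 0b00000100, 2),
    (0b00000111, 0b00001000, 4),
]


def compress_example(x):
    for stay_mask, move_mask, shift in STEPS:
        x = (x & stay_mask) | ((x >> shift) & move_mask)
    return x


def extract_naive(x):
    # Bits 0, 2, 5, 7 of x (d, c, b, a) go to positions 0, 1, 2, 3.
    result = 0
    for target, source in enumerate([0, 2, 5, 7]):
        result |= ((x >> source) & 1) << target
    return result


assert all(compress_example(x) == extract_naive(x) for x in range(256))
```

Note that the bits being removed are cleared by the very first step, since every step keeps only the bits covered by one of its two masks.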
Here's some proof of concept Python (sorry).
def compute_mask_pairs(retained_indexes):
    """Precompute one (mask0, mask1) pair per step for the retained bit indexes."""
    mask_pairs = []
    retained_indexes = sorted(retained_indexes)
    shift = 1
    # Always emit at least one pair so that the removed bits get masked off
    # even when the retained bits are already in their final positions.
    while not mask_pairs or retained_indexes != list(range(len(retained_indexes))):
        mask0 = 0  # bits that stay where they are in this step
        mask1 = 0  # bits that move down by `shift`, at their new positions
        for i, j in enumerate(retained_indexes):
            assert i <= j
            assert i % shift == j % shift
            if i % (shift * 2) != j % (shift * 2):
                retained_indexes[i] = j - shift
                # The mask is applied after the shift, so use the new position.
                mask1 |= 1 << (j - shift)
            else:
                mask0 |= 1 << j
        mask_pairs.append((mask0, mask1))
        shift *= 2
    return mask_pairs


def remove_bits_fast(mask_pairs, x):
    for log_shift, (mask0, mask1) in enumerate(mask_pairs):
        x = (x & mask0) | ((x >> (2 ** log_shift)) & mask1)
    return x


def remove_bits_slow(retained_indexes, x):
    # Reference implementation: move bit j of x to position i, one bit at a time.
    return sum(
        ((x // (2 ** j)) % 2) * (2 ** i)
        for i, j in enumerate(sorted(retained_indexes))
    )


def test():
    k = 8
    for mask in range(2 ** k):
        # Retain index i exactly when bit i of mask is clear.
        retained_indexes = {i for i in range(k) if (mask // (2 ** i)) % 2 == 0}
        mask_pairs = compute_mask_pairs(retained_indexes)
        for x in range(2 ** k):
            assert remove_bits_fast(mask_pairs, x) == remove_bits_slow(retained_indexes, x)


test()
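Since the masks are meant to be computed offline and pasted into the template, a small generator can print the statements directly. This is a sketch under assumptions: `emit_template` is a hypothetical helper, the output format assumes a Java-style target with 64-bit `long` hex literals, and the sample pairs are the ones from the 8-bit example rather than a real 64-bit mask.

```python
def emit_template(mask_pairs, var="x"):
    """Format each (mask0, mask1) pair as one statement of the template.

    The k-th pair (counting from 0) is paired with a shift of 2**k.
    """
    lines = []
    for log_shift, (mask0, mask1) in enumerate(mask_pairs):
        shift = 1 << log_shift
        lines.append(
            f"{var} = ({var} & {mask0:#018x}L) | (({var} >>> {shift}) & {mask1:#018x}L);"
        )
    return lines


example_pairs = [
    (0b10000001, 0b00010010),
    (0b10000011, 0b00000100),
    (0b00000111, 0b00001000),
]
for line in emit_template(example_pairs):
    print(line)
```

For a different target language only the format string should need changing.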