Counting 1 bits (population count) on large data using AVX-512 or AVX-2

攒了一身酷 2020-12-06 18:24

I have a long chunk of memory, say, 256 KiB or longer. I want to count the number of 1 bits in this entire chunk, or in other words: Add up the "population count" values f

2 Answers
  • 2020-12-06 18:51

    Wojciech Muła's big-array popcnt functions look optimal except for the scalar cleanup loops. (See @einpoklum's answer for details on the main loops).

    A 256-entry LUT that you use only a couple of times at the end is likely to cache-miss, and isn't optimal for more than 1 byte even if the cache were hot. I believe all AVX2 CPUs have hardware popcnt, so we can easily isolate the last up-to-8 bytes that haven't been counted yet and handle them with a single popcnt.

    As usual with SIMD algorithms, it often works well to do a full-width load that ends at the last byte of the buffer. But unlike with a vector register, variable-count shifts of the full integer register are cheap (especially with BMI2). Popcnt doesn't care where the bits are, so we can just use a shift instead of needing to construct an AND mask or whatever.

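    // untested
    // ptr points at the first byte that hasn't been counted yet; requires 1 <= end-ptr <= 8
    uint64_t last8;
    memcpy(&last8, end - 8, sizeof(last8));  // unaligned load of the last 8 bytes (needs <cstring>)
    // x86 is little-endian: the 8-(end-ptr) already-counted bytes sit in the low bits, so shift them out
    uint64_t final_bytes = last8 >> (8*(8 - (end-ptr)));
    total += _mm_popcnt_u64( final_bytes );
    // Careful, this could read outside a small buffer.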
    

    Or even better, use more sophisticated logic to avoid page-crossing. This can avoid page-crossing for a 6-byte buffer at the start of a page, for example.
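
    As a minimal sketch of one way to handle small buffers (it does not implement the page-crossing check itself), the tail can fall back to copying the remaining bytes into a zeroed 64-bit word whenever the whole buffer is shorter than 8 bytes. The helper name popcount_tail and its buf_start parameter are hypothetical, and __builtin_popcountll (a GCC/Clang builtin) stands in for _mm_popcnt_u64 so that the sketch compiles without target flags:

    ```cpp
    #include <cstddef>
    #include <cstdint>
    #include <cstring>

    // Hypothetical helper (not from the answer): popcount of the bytes in
    // [ptr, end), assuming 0 <= end - ptr <= 8 and buf_start <= ptr.
    // buf_start is the first byte of the whole buffer.
    inline std::uint64_t popcount_tail(const std::uint8_t* buf_start,
                                       const std::uint8_t* ptr,
                                       const std::uint8_t* end) {
        const std::size_t remaining = static_cast<std::size_t>(end - ptr);
        if (remaining == 0) {
            return 0;
        }
        std::uint64_t word = 0;
        if (static_cast<std::size_t>(end - buf_start) >= 8) {
            // Full-width load that ends at the last byte of the buffer.
            std::memcpy(&word, end - 8, 8);
            // Little-endian assumption (true on x86): the already-counted
            // bytes occupy the low 8 - remaining bytes, so shift them out.
            word >>= 8 * (8 - remaining);
        } else {
            // Buffer shorter than 8 bytes: copy only what exists, so nothing
            // outside the buffer is read. Byte order does not matter here,
            // because popcount ignores bit positions.
            std::memcpy(&word, ptr, remaining);
        }
        return static_cast<std::uint64_t>(__builtin_popcountll(word));
    }
    ```

    The memcpy calls with small sizes are normally compiled to plain loads, and the branch is taken once per call, so the cost next to the main SIMD loop should be negligible.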

  • 2020-12-06 18:53

    AVX-2

    @HadiBreis' comment links to an article on fast population count with SSSE3 by Wojciech Muła; the article links to this GitHub repository, which contains the following AVX-2 implementation. It is based on a vectorized lookup instruction (vpshufb, exposed as _mm256_shuffle_epi8) and uses a 16-entry lookup table holding the bit counts of nibbles.

    #include <cstddef>      // size_t
    #include <cstdint>      // uint8_t, uint64_t
    #include <immintrin.h>
    #include <x86intrin.h>
    
    std::uint64_t popcnt_AVX2_lookup(const uint8_t* data, const size_t n) {
    
        size_t i = 0;
    
        const __m256i lookup = _mm256_setr_epi8(
            /* 0 */ 0, /* 1 */ 1, /* 2 */ 1, /* 3 */ 2,
            /* 4 */ 1, /* 5 */ 2, /* 6 */ 2, /* 7 */ 3,
            /* 8 */ 1, /* 9 */ 2, /* a */ 2, /* b */ 3,
            /* c */ 2, /* d */ 3, /* e */ 3, /* f */ 4,
    
            /* 0 */ 0, /* 1 */ 1, /* 2 */ 1, /* 3 */ 2,
            /* 4 */ 1, /* 5 */ 2, /* 6 */ 2, /* 7 */ 3,
            /* 8 */ 1, /* 9 */ 2, /* a */ 2, /* b */ 3,
            /* c */ 2, /* d */ 3, /* e */ 3, /* f */ 4
        );
    
        const __m256i low_mask = _mm256_set1_epi8(0x0f);
    
        __m256i acc = _mm256_setzero_si256();
    
    #define ITER { \
            const __m256i vec = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(data + i)); \
            const __m256i lo  = _mm256_and_si256(vec, low_mask); \
            const __m256i hi  = _mm256_and_si256(_mm256_srli_epi16(vec, 4), low_mask); \
            const __m256i popcnt1 = _mm256_shuffle_epi8(lookup, lo); \
            const __m256i popcnt2 = _mm256_shuffle_epi8(lookup, hi); \
            local = _mm256_add_epi8(local, popcnt1); \
            local = _mm256_add_epi8(local, popcnt2); \
            i += 32; \
        }
    
        while (i + 8*32 <= n) {
            __m256i local = _mm256_setzero_si256();
            ITER ITER ITER ITER
            ITER ITER ITER ITER
            acc = _mm256_add_epi64(acc, _mm256_sad_epu8(local, _mm256_setzero_si256()));
        }
    
        __m256i local = _mm256_setzero_si256();
    
        while (i + 32 <= n) {
            ITER;
        }
    
        acc = _mm256_add_epi64(acc, _mm256_sad_epu8(local, _mm256_setzero_si256()));
    
    #undef ITER
    
        uint64_t result = 0;
    
        result += static_cast<uint64_t>(_mm256_extract_epi64(acc, 0));
        result += static_cast<uint64_t>(_mm256_extract_epi64(acc, 1));
        result += static_cast<uint64_t>(_mm256_extract_epi64(acc, 2));
        result += static_cast<uint64_t>(_mm256_extract_epi64(acc, 3));
    
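        // scalar cleanup for the last 0..31 bytes; lookup8bit is the repository's
        // 256-entry byte popcount table (not shown here)
        for (/**/; i < n; i++) {
            result += lookup8bit[data[i]];
        }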
    
        return result;
    }
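
    To make the nibble-lookup idea in the function above easier to follow, here is a scalar sketch of what each byte lane does: split the byte into two 4-bit halves and add the table entries for both. The function name popcnt_nibble_lookup_scalar is hypothetical and not part of the repository; it is meant as an illustration and as a reference to test against, not as a fast implementation:

    ```cpp
    #include <cstddef>
    #include <cstdint>

    // Hypothetical scalar illustration of the nibble-lookup technique.
    // The AVX2 version performs the two table lookups for 32 bytes at once
    // with _mm256_shuffle_epi8.
    inline std::uint64_t popcnt_nibble_lookup_scalar(const std::uint8_t* data,
                                                     std::size_t n) {
        // Bit counts of the 16 possible nibble values 0x0 .. 0xf.
        static const std::uint8_t lookup[16] = {
            0, 1, 1, 2, 1, 2, 2, 3,
            1, 2, 2, 3, 2, 3, 3, 4
        };
        std::uint64_t total = 0;
        for (std::size_t i = 0; i < n; ++i) {
            const std::uint8_t lo = data[i] & 0x0f;  // low nibble
            const std::uint8_t hi = data[i] >> 4;    // high nibble
            total += static_cast<std::uint64_t>(lookup[lo]) + lookup[hi];
        }
        return total;
    }
    ```

    In the vectorized code, each ITER adds at most 8 to every 8-bit lane (4 per nibble), so 8 iterations add at most 64 before _mm256_sad_epu8 widens the sums to 64 bits, which keeps the byte counters from overflowing.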
    

    AVX-512

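    The same repository also has a VPOPCNT-based AVX-512 implementation. Note that _mm512_popcnt_epi64 requires the AVX512_VPOPCNTDQ extension, which not every AVX-512 CPU supports: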

    #include <cstddef>      // size_t
    #include <cstdint>      // uint8_t, uint64_t
    #include <immintrin.h>
    #include <x86intrin.h>
    
    uint64_t avx512_vpopcnt(const uint8_t* data, const size_t size) {
    
        const size_t chunks = size / 64;
    
        const uint8_t* ptr = data;  // read-only access, so no const_cast is needed
        const uint8_t* end = ptr + size;
    
        // count using AVX512 registers
        __m512i accumulator = _mm512_setzero_si512();
        for (size_t i=0; i < chunks; i++, ptr += 64) {
    
            // Note: a short chain of dependencies, likely unrolling will be needed.
            const __m512i v = _mm512_loadu_si512((const __m512i*)ptr);
            const __m512i p = _mm512_popcnt_epi64(v);
    
            accumulator = _mm512_add_epi64(accumulator, p);
        }
    
        // horizontal sum of a register
        uint64_t tmp[8] __attribute__((aligned(64)));
        _mm512_store_si512((__m512i*)tmp, accumulator);
    
        uint64_t total = 0;
        for (size_t i=0; i < 8; i++) {
            total += tmp[i];
        }
    
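        // popcount the tail, 8 bytes at a time
        while (ptr + 8 <= end) {
            total += _mm_popcnt_u64(*reinterpret_cast<const uint64_t*>(ptr));
            ptr += 8;
        }
    
        // remaining 0..7 bytes; lookup8bit is the repository's 256-entry byte
        // popcount table (not shown here)
        while (ptr < end) {
            total += lookup8bit[*ptr++];
        }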
    
        return total;
    }
    

    lookup8bit is a popcnt lookup table indexed by byte value (256 entries) rather than by bit, and is defined here. Edit: as commenters note, using an 8-bit lookup table for the tail is not a very good idea and can be improved on.
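
    The definition of lookup8bit is not reproduced here. As a rough sketch of what such a table contains, it could be built as follows; make_lookup8bit is a hypothetical name, and the repository's actual definition may look different (for example, a hard-coded array):

    ```cpp
    #include <array>
    #include <cstdint>

    // Hypothetical stand-in for a 256-entry byte popcount table.
    // Entry v holds the number of 1 bits in the byte value v.
    inline std::array<std::uint8_t, 256> make_lookup8bit() {
        std::array<std::uint8_t, 256> table{};  // table[0] == 0
        for (unsigned v = 1; v < 256; ++v) {
            // popcount(v) = popcount(v >> 1) + lowest bit of v
            table[v] = static_cast<std::uint8_t>(table[v >> 1] + (v & 1u));
        }
        return table;
    }
    ```

    Declaring something like `static const auto lookup8bit = make_lookup8bit();` before the two functions should be enough to get them to compile for experiments, although a hardware popcnt on the tail is the better choice, as noted above.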
