Can counting byte matches between two strings be optimized using SIMD?

前端 未结 3 2092
悲哀的现实
悲哀的现实 2021-01-05 01:34

Profiling suggests that this function here is a real bottleneck for my application:

static inline int countEqualChars(const char* string1, const char* strin         


        
相关标签:
3条回答
  • 2021-01-05 02:11

    Compiler flags for vectorization:

    -ftree-vectorize

    -ftree-vectorize -march=<your_architecture>: use all instruction-set extensions available on your computer, not just the baseline (e.g. SSE2 for x86-64). Use -march=native to optimize for the machine the compiler is running on. -march=<foo> also sets -mtune=<foo>, which is also a good thing.

    Using SSEx intrinsics:

    • Pad and align the buffers to 16 bytes (according to the vector size you're actually going to use)

    • Create an accumulator countU8 with _mm_set1_epi8(0)

    • For all n/16 input (sub) vectors, do:

      • Load 16 chars from both strings with _mm_load_si128 or _mm_loadu_si128 (for unaligned loads)

      • _mm_cmpeq_epi8 compare the octets in parallel. Each match yields 0xFF (-1), 0x00 otherwise.

      • Subtract the above result vector from countU8 using _mm_sub_epi8 (subtracting -1 adds 1 to each matching byte's counter)

      • After at most 255 iterations, the 16 8-bit counters must be extracted into a larger integer type to prevent overflow. See unpack and horizontal add in this nice answer for how to do that: https://stackoverflow.com/a/10930706/1175253

    Code:

    #include <iostream>
    #include <vector>
    
    #include <cassert>
    #include <cstdint>
    #include <climits>
    #include <cstring>
    
    #include <emmintrin.h>
    
    #ifdef __SSE2__
    
    // Classify the pointer width from UINTPTR_MAX: define PTR_64 on 64-bit targets,
    // PTR_32 on 32-bit targets, and stop compilation for any other width (or when
    // the <cstdint> limit macros are unavailable).
    #if !defined(UINTPTR_MAX) ||  !defined(UINT64_MAX) ||  !defined(UINT32_MAX)
    #  error "Limit macros are not defined"
    #endif
    
    #if UINTPTR_MAX == UINT64_MAX
        #define PTR_64
    #elif UINTPTR_MAX == UINT32_MAX
        #define PTR_32
    #else
    #  error "Current UINTPTR_MAX is not supported"
    #endif
    
    /// @brief Writes the lanes of an SSE register to `out` as "{hi,...,lo}" plus a newline.
    /// @tparam T   lane type used to view the 128-bit register (e.g. uint8_t, uint16_t)
    /// @param  out stream that receives the text (previously ignored in favour of std::cout)
    /// @param  vec register to print; lanes are listed from the highest-addressed element
    ///             down to element 0, i.e. in the same order as _mm_set_* arguments
    template<typename T>
    void print_vector(std::ostream& out,const __m128i& vec)
    {
        static_assert(sizeof(vec) % sizeof(T) == 0,"Invalid element size");
        constexpr size_t nLanes = sizeof(vec) / sizeof(T);

        // Copy the register into a plain array instead of walking a reinterpret_cast
        // pointer: avoids aliasing problems and forming a pointer before &vec.
        T lanes[nLanes];
        std::memcpy(lanes, &vec, sizeof(vec));

        out << '{';
        for(size_t i = nLanes; i-- > 0;)
        {
            out << +lanes[i]; // unary + promotes char-sized lanes so they print as numbers
            if(i != 0)
                out << ',';
        }
        out << '}' << std::endl;
    }

    // Debug helper: prints "<expression text> : {lanes}" to std::cout.
    #define PRINT_VECTOR(TYPE_,VEC_) do{  std::cout << #VEC_ << " : "; print_vector<TYPE_>(std::cout,VEC_);    } while(0)
    
    ///@brief Counts byte positions at which two 16-byte aligned buffers hold equal values.
    ///
    /// Simple variant: every compare result is reduced immediately with psadbw
    /// (_mm_sad_epu8) into two 64-bit lane totals, so the 8-bit match flags never
    /// need a periodic overflow flush. Accumulating up to 255 compares in 8-bit
    /// lanes before reducing would avoid one psadbw + paddq per vector.
    ///
    ///@param a_in  first buffer of `count` vectors; 16-byte aligned, may be null only if count == 0
    ///@param b_in  second buffer of `count` vectors; same requirements as a_in
    ///@param count number of __m128i vectors (16 bytes each) to compare
    ///@return number of equal bytes, in [0, 16 * count]
    ///@note SSE2 required (macro: __SSE2__)
    size_t counteq_epi8(const __m128i* a_in,const __m128i* b_in,size_t count)
    {
        // Null is acceptable for an empty range (e.g. data() of an empty std::vector).
        assert(count == 0 || (a_in != nullptr && b_in != nullptr));
        assert((uintptr_t(a_in) % sizeof(__m128i)) == 0);
        assert((uintptr_t(b_in) % sizeof(__m128i)) == 0);

        const __m128i zero = _mm_setzero_si128();
        __m128i sum2xU64   = zero;
        while(count--)
        {
            // cmpeq yields 0xFF (-1) for each equal byte; 0 - (-1) turns that into 1.
            const __m128i matches  = _mm_sub_epi8(zero,_mm_cmpeq_epi8(*a_in++,*b_in++));
            // Sum each 8-byte half into its 64-bit lane (at most 8 per lane per vector).
            const __m128i sum2xU16 = _mm_sad_epu8(matches,zero);
            sum2xU64 = _mm_add_epi64(sum2xU64,sum2xU16);
    #if !defined(NDEBUG) && defined(PRINT_VECTOR)
            PRINT_VECTOR(uint16_t,sum2xU64);
    #endif
        }

        // Add the upper 64-bit total onto the lower one (byte shift by 8 = 64 bits).
        sum2xU64 = _mm_add_epi64(sum2xU64,_mm_srli_si128(sum2xU64,64/8));
    #if !defined(NDEBUG) && defined(PRINT_VECTOR)
        std::cout << "----------------------------------------" << std::endl;
        PRINT_VECTOR(uint16_t,sum2xU64);
    #endif

        // Narrow the low 64-bit lane to size_t with the intrinsic matching the pointer width.
    #if UINTPTR_MAX == UINT64_MAX
        return static_cast<size_t>(_mm_cvtsi128_si64(sum2xU64));
    #elif UINTPTR_MAX == UINT32_MAX
        return static_cast<size_t>(_mm_cvtsi128_si32(sum2xU64));
    #else
    #  error "Current UINTPTR_MAX is not supported"
    #endif
    }
    
    #endif
    
    /// Demo driver: compares two zero-filled 1024-byte buffers whose last bytes are
    /// both set to 1, so every byte matches and the program prints "equalBytes = 1024".
    int main(int /*argc*/, char* /*argv*/[])
    {
        std::vector<__m128i> a(64); // 64 vectors * 16 bytes = 1024 bytes
        std::vector<__m128i> b(a.size());
        const size_t nBytes = a.size() * sizeof(std::vector<__m128i>::value_type);

        char* const a_out = reinterpret_cast<char*>(a.data());
        char* const b_out = reinterpret_cast<char*>(b.data());

        memset(a_out,0,nBytes);
        memset(b_out,0,nBytes);

        // Index derived from the buffer size rather than hard-coded as 1023, so
        // changing the vector length cannot turn this into an out-of-bounds write.
        a_out[nBytes - 1] = 1;
        b_out[nBytes - 1] = 1;

        const size_t equalBytes = counteq_epi8(a.data(),b.data(),a.size());

        std::cout << "equalBytes = " << equalBytes << '\n';

        return 0;
    }
    

    The fastest SSE implementation I got for large and small arrays:

    /// @brief Counts byte positions at which two 16-byte aligned buffers hold equal values.
    ///
    /// Matches are accumulated in sixteen 8-bit lanes. Each compare adds at most 1
    /// per lane, so after every 255 compares the lanes are flushed into two 64-bit
    /// totals with psadbw (_mm_sad_epu8) before they could wrap around.
    ///
    /// @param a_in  first buffer of `count` vectors; 16-byte aligned, may be null only if count == 0
    /// @param b_in  second buffer of `count` vectors; same requirements as a_in
    /// @param count number of __m128i vectors (16 bytes each) to compare
    /// @return number of equal bytes, in [0, 16 * count]
    /// @note Requires SSE2.
    size_t counteq_epi8(const __m128i* a_in,const __m128i* b_in,size_t count)
    {
        assert((count > 0 ? a_in != nullptr : true) && (uintptr_t(a_in) % sizeof(__m128i)) == 0);
        assert((count > 0 ? b_in != nullptr : true) && (uintptr_t(b_in) % sizeof(__m128i)) == 0);

        // 255 is the largest number of +1 steps an unsigned 8-bit lane can take without wrapping.
        const size_t maxInnerLoops    = 255;
        const size_t nNestedLoops     = count / maxInnerLoops;
        const size_t nRemainderLoops  = count % maxInnerLoops;

        const __m128i zero  = _mm_setzero_si128();
        __m128i sum16xU8    = zero;
        __m128i sum2xU64    = zero;

        for(size_t i = 0;i < nNestedLoops;++i)
        {
            for(size_t j = 0;j < maxInnerLoops;++j)
            {
                // cmpeq gives 0xFF (-1) for each equal byte; subtracting it increments that lane.
                sum16xU8 = _mm_sub_epi8(sum16xU8,_mm_cmpeq_epi8(*a_in++,*b_in++));
            }
            // Flush the 8-bit lanes into the 64-bit totals before they can overflow.
            sum2xU64 = _mm_add_epi64(sum2xU64,_mm_sad_epu8(sum16xU8,zero));
            sum16xU8 = zero;
        }

        for(size_t j = 0;j < nRemainderLoops;++j)
        {
            sum16xU8 = _mm_sub_epi8(sum16xU8,_mm_cmpeq_epi8(*a_in++,*b_in++));
        }
        sum2xU64 = _mm_add_epi64(sum2xU64,_mm_sad_epu8(sum16xU8,zero));

        // Add the upper 64-bit total onto the lower one (byte shift by 8 = 64 bits).
        sum2xU64 = _mm_add_epi64(sum2xU64,_mm_srli_si128(sum2xU64,64/8));

    #if UINTPTR_MAX == UINT64_MAX
        return static_cast<size_t>(_mm_cvtsi128_si64(sum2xU64));
    #elif UINTPTR_MAX == UINT32_MAX
        return static_cast<size_t>(_mm_cvtsi128_si32(sum2xU64));
    #else
    #  error "Current UINTPTR_MAX is not supported"
    #endif
    }
    
    0 讨论(0)
  • 2021-01-05 02:14

    Auto-vectorization in current gcc is a matter of helping the compiler understand that the code is easy to vectorize. In your case, it will recognize the vectorization opportunity if you remove the conditional and rewrite the code in a more imperative way:

        /// Returns how many indices j in [0, size) satisfy string1[j] == string2[j].
        /// The loop body is branch-free (the comparison result is added directly),
        /// which is an easier shape for compiler auto-vectorization.
        static inline int count(const char* string1, const char* string2, int size) {
                int matches = 0;
                for (int j = 0; j < size; ++j)
                        matches += static_cast<int>(string1[j] == string2[j]);
                return matches;
        }
    

    In this case, gcc generates vectorized code (note the pcmpeqb byte compare):

    movdqa  16(%rsp), %xmm1
    movl    $.LC2, %esi
    pxor    %xmm2, %xmm2
    movzbl  416(%rsp), %edx
    movdqa  .LC1(%rip), %xmm3
    pcmpeqb 224(%rsp), %xmm1
    cmpb    %dl, 208(%rsp)
    movzbl  417(%rsp), %eax
    movl    $1, %edi
    pand    %xmm3, %xmm1
    movdqa  %xmm1, %xmm5
    sete    %dl
    movdqa  %xmm1, %xmm4
    movzbl  %dl, %edx
    punpcklbw   %xmm2, %xmm5
    punpckhbw   %xmm2, %xmm4
    pxor    %xmm1, %xmm1
    movdqa  %xmm5, %xmm6
    movdqa  %xmm5, %xmm0
    movdqa  %xmm4, %xmm5
    punpcklwd   %xmm1, %xmm6
    

    (etc.)

    0 讨论(0)
  • 2021-01-05 02:26

    Of course it can.

    pcmpeqb compares two vectors of 16 bytes and produces a vector with zeros where they differ and -1 where they match. Use this to compare 16 bytes at a time, adding the result to an accumulator vector (make sure to accumulate the results of at most 255 vector compares to avoid overflow). When you're done, there are 16 results in the accumulator. Sum them and negate to get the number of equal elements.

    If the lengths are very short, it will be hard to get a significant speedup from this approach. If the lengths are long, then it will be worth pursuing.

    0 讨论(0)
提交回复
热议问题