Speedup a short to float cast?

后端 未结 7 2044
孤独总比滥情好
孤独总比滥情好 2020-12-17 01:39

I have a short to float cast in C++ that is bottlenecking my code.

The code translates from a hardware device buffer which is natively shorts, this represents the in

7条回答
  •  青春惊慌失措
    2020-12-17 01:51

    Here's a basic SSE4.1 implementation:

    //  Scale factor broadcast to all four float lanes.
    //  The loop below uses aligned loads/stores and consumes 8 samples per pass.
    __m128 factor = _mm_set1_ps(1.0f / value);
    for (int i = 0; i < W*H; i += 8)
    {
        //  Fetch 8 consecutive 16-bit ushorts.
        //  packed = {a,b,c,d,e,f,g,h}
        const __m128i packed = _mm_load_si128((const __m128i*)(source + i));

        //  Duplicate the upper four samples into the lower half.
        //  upper = {e,f,g,h,e,f,g,h}
        const __m128i upper = _mm_unpackhi_epi64(packed,packed);

        //  Zero-extend the low four 16-bit lanes of each register to 32-bit
        //  integers, convert them to float and apply the scale factor.
        const __m128 scaled_lo = _mm_mul_ps(_mm_cvtepi32_ps(_mm_cvtepu16_epi32(packed)),factor);
        const __m128 scaled_hi = _mm_mul_ps(_mm_cvtepi32_ps(_mm_cvtepu16_epi32(upper)),factor);

        //  Write the 8 results back.
        float *out = destination + i;
        _mm_store_ps(out + 0,scaled_lo);
        _mm_store_ps(out + 4,scaled_hi);
    }
    

    This assumes:

    1. source and destination are both aligned to 16 bytes.
    2. W*H is a multiple of 8.

    It's possible to do better by further unrolling this loop (see below).


    The idea here is as follows:

    1. Load 8 shorts into a single SSE register.
    2. Split the register into two: One with the bottom 4 shorts and the other with the top 4 shorts.
    3. Zero-extend both registers into 32-bit integers.
    4. Convert them both to floats.
    5. Multiply by the factor.
    6. Store them into destination.

    EDIT :

    It's been a while since I've done this type of optimization, so I went ahead and unrolled the loops.

    Core i7 920 @ 3.5 GHz
    Visual Studio 2012 - Release x64:

    Original Loop      : 4.374 seconds
    Vectorize no unroll: 1.665
    Vectorize unroll 2 : 1.416
    

    Further unrolling resulted in diminishing returns.

    Here's the test code:

    #include <smmintrin.h>
    #include <iostream>
    #include <ctime>
    #include <cstdlib>
    using namespace std;
    
    
    //  Scalar reference implementation.
    //  Converts each of the `size` samples in `source` to float, multiplies it
    //  by 1/value and writes the result to the same index of `destination`.
    void default_loop(float *destination,const short* source,float value,int size){
        const float scale = 1.0f / value;
        int idx = 0;
        while (idx < size)
        {
            //  Widen to int first, then let the int*float product yield a float.
            const int sample = source[idx];
            destination[idx] = sample * scale;
            ++idx;
        }
    }
    //  SSE4.1 conversion of 16-bit samples to scaled floats, 8 samples per pass.
    //
    //  destination : output buffer holding at least `size` floats; must be
    //                16-byte aligned (aligned stores are used).
    //  source      : input buffer holding at least `size` samples; must be
    //                16-byte aligned (aligned loads are used).
    //  value       : divisor; every sample is multiplied by 1.0f / value.
    //  size        : number of samples. Any count is accepted: samples left over
    //                after the last full group of 8 are converted one at a time,
    //                so nothing beyond `size` is read or written.
    //
    //  NOTE: samples are zero-extended, i.e. their bits are interpreted as
    //  unsigned 16-bit values. A negative short therefore converts differently
    //  here than in default_loop, which sign-extends.
    void vectorize8_unroll1(float *destination,const short* source,float value,int size){
        const float scale = 1.0f / value;
        __m128 factor = _mm_set1_ps(scale);

        int i = 0;

        //  Vector body: only runs while a complete group of 8 samples remains.
        for (; size - i >= 8; i += 8)
        {
            //  Load 8 16-bit ushorts.
            __m128i vi = _mm_load_si128(reinterpret_cast<const __m128i*>(source + i));

            //  Convert to 32-bit integers (low half, then high half moved down)
            __m128i vi0 = _mm_cvtepu16_epi32(vi);
            __m128i vi1 = _mm_cvtepu16_epi32(_mm_unpackhi_epi64(vi,vi));

            //  Convert to float
            __m128 vf0 = _mm_cvtepi32_ps(vi0);
            __m128 vf1 = _mm_cvtepi32_ps(vi1);

            //  Multiply
            vf0 = _mm_mul_ps(vf0,factor);
            vf1 = _mm_mul_ps(vf1,factor);

            //  Store
            _mm_store_ps(destination + i + 0,vf0);
            _mm_store_ps(destination + i + 4,vf1);
        }

        //  Scalar tail: same zero-extension as the vector body.
        for (; i < size; i++)
        {
            destination[i] = static_cast<float>(static_cast<unsigned short>(source[i])) * scale;
        }
    }
    //  SSE4.1 conversion of 16-bit samples to scaled floats, unrolled twice so
    //  that 16 samples are handled per pass.
    //
    //  destination : output buffer holding at least `size` floats; must be
    //                16-byte aligned (aligned stores are used).
    //  source      : input buffer holding at least `size` samples; must be
    //                16-byte aligned (aligned loads are used).
    //  value       : divisor; every sample is multiplied by 1.0f / value.
    //  size        : number of samples. Any count is accepted: samples left over
    //                after the last full group of 16 are converted one at a time,
    //                so nothing beyond `size` is read or written.
    //
    //  NOTE: samples are zero-extended, i.e. their bits are interpreted as
    //  unsigned 16-bit values. A negative short therefore converts differently
    //  here than in default_loop, which sign-extends.
    void vectorize8_unroll2(float *destination,const short* source,float value,int size){
        const float scale = 1.0f / value;
        __m128 factor = _mm_set1_ps(scale);

        int i = 0;

        //  Vector body: only runs while a complete group of 16 samples remains.
        for (; size - i >= 16; i += 16)
        {
            __m128i a0 = _mm_load_si128(reinterpret_cast<const __m128i*>(source + i + 0));
            __m128i a1 = _mm_load_si128(reinterpret_cast<const __m128i*>(source + i + 8));

            //  Split into two registers (upper four samples moved to the low half)
            __m128i b0 = _mm_unpackhi_epi64(a0,a0);
            __m128i b1 = _mm_unpackhi_epi64(a1,a1);

            //  Convert to 32-bit integers
            a0 = _mm_cvtepu16_epi32(a0);
            b0 = _mm_cvtepu16_epi32(b0);
            a1 = _mm_cvtepu16_epi32(a1);
            b1 = _mm_cvtepu16_epi32(b1);

            //  Convert to float
            __m128 c0 = _mm_cvtepi32_ps(a0);
            __m128 d0 = _mm_cvtepi32_ps(b0);
            __m128 c1 = _mm_cvtepi32_ps(a1);
            __m128 d1 = _mm_cvtepi32_ps(b1);

            //  Multiply
            c0 = _mm_mul_ps(c0,factor);
            d0 = _mm_mul_ps(d0,factor);
            c1 = _mm_mul_ps(c1,factor);
            d1 = _mm_mul_ps(d1,factor);

            //  Store
            _mm_store_ps(destination + i +  0,c0);
            _mm_store_ps(destination + i +  4,d0);
            _mm_store_ps(destination + i +  8,c1);
            _mm_store_ps(destination + i + 12,d1);
        }

        //  Scalar tail: same zero-extension as the vector body.
        for (; i < size; i++)
        {
            destination[i] = static_cast<float>(static_cast<unsigned short>(source[i])) * scale;
        }
    }
    //  Adds up the first `size` floats of `destination` in index order and
    //  prints the total on its own line (flushing the stream).
    void print_sum(const float *destination,int size){
        float total = 0;
        for (int remaining = size; remaining > 0; --remaining){
            total += *destination;
            ++destination;
        }
        std::cout << total << std::endl;
    }
    
    //  Benchmark driver.
    //  Fills an aligned buffer of shorts with 0..size-1, then runs each of the
    //  three conversion routines `iterations` times over it. For each routine it
    //  prints the elapsed clock() time in seconds followed by the sum of the
    //  output buffer.
    //  Returns 1 if the buffers cannot be allocated, 0 otherwise.
    int main(){

        //  8000 is a multiple of 16, so both vectorized routines run entirely in
        //  their vector loops.
        int size = 8000;

        //  16-byte alignment is required by the aligned loads/stores in the
        //  vectorized routines.
        short *source       = (short*)_mm_malloc(size * sizeof(short), 16);
        float *destination  = (float*)_mm_malloc(size * sizeof(float), 16);

        //  Bail out instead of dereferencing a null buffer.
        if (!source || !destination){
            cerr << "Failed to allocate benchmark buffers" << endl;
            if (source){
                _mm_free(source);
            }
            if (destination){
                _mm_free(destination);
            }
            return 1;
        }

        for (int i = 0; i < size; i++){
            source[i] = (short)i;
        }

        float value = 1.1f;

        int iterations = 1000000;
        clock_t start;

        //  The three timing loops are kept as direct calls (rather than going
        //  through a function pointer) so the compiler can optimize each one
        //  the same way.

        //  Default Loop
        start = clock();
        for (int it = 0; it < iterations; it++){
            default_loop(destination,source,value,size);
        }
        cout << (double)(clock() - start) / CLOCKS_PER_SEC << endl;
        print_sum(destination,size);

        //  Vectorize 8, no unroll
        start = clock();
        for (int it = 0; it < iterations; it++){
            vectorize8_unroll1(destination,source,value,size);
        }
        cout << (double)(clock() - start) / CLOCKS_PER_SEC << endl;
        print_sum(destination,size);

        //  Vectorize 8, unroll 2
        start = clock();
        for (int it = 0; it < iterations; it++){
            vectorize8_unroll2(destination,source,value,size);
        }
        cout << (double)(clock() - start) / CLOCKS_PER_SEC << endl;
        print_sum(destination,size);

        _mm_free(source);
        _mm_free(destination);

        //  Waits for a key press before exiting; `pause` is a Windows shell
        //  command, so this has no useful effect on other platforms.
        system("pause");
        return 0;
    }
    

提交回复
热议问题