What's the fastest stride-3 gather instruction sequence?

-上瘾入骨i 2020-12-09 19:03

The question:

What is the most efficient sequence to generate a stride-3 gather of 32-bit elements from memory? If the memory is arranged as:

MEM =         


        
2 Answers
  •  旧时难觅i
    2020-12-09 19:26

    This article from Intel describes how to do exactly the 3x8 case that you want.

    That article covers the float case. If you want int32, you'll need to cast the outputs (for example with _mm256_castps_si256()), since there's no two-source integer equivalent of _mm256_shuffle_ps().

    Copying their solution verbatim:

    float *p;  // address of first vector
    __m128 *m = (__m128*) p;
    __m256 m03;
    __m256 m14; 
    __m256 m25; 
    m03  = _mm256_castps128_ps256(m[0]); // load lower halves
    m14  = _mm256_castps128_ps256(m[1]);
    m25  = _mm256_castps128_ps256(m[2]);
    m03  = _mm256_insertf128_ps(m03 ,m[3],1);  // load upper halves
    m14  = _mm256_insertf128_ps(m14 ,m[4],1);
    m25  = _mm256_insertf128_ps(m25 ,m[5],1);
    
    __m256 xy = _mm256_shuffle_ps(m14, m25, _MM_SHUFFLE( 2,1,3,2)); // upper x's and y's 
    __m256 yz = _mm256_shuffle_ps(m03, m14, _MM_SHUFFLE( 1,0,2,1)); // lower y's and z's
    __m256 x  = _mm256_shuffle_ps(m03, xy , _MM_SHUFFLE( 2,0,3,0)); 
    __m256 y  = _mm256_shuffle_ps(yz , xy , _MM_SHUFFLE( 3,1,2,0)); 
    __m256 z  = _mm256_shuffle_ps(yz , m25, _MM_SHUFFLE( 3,0,3,1)); 
    

    So this is 11 instructions (6 loads, 5 shuffles).
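When checking a shuffle sequence like this one, a plain scalar version of the same stride-3 gather is handy to compare against. The following is only a minimal sketch: the names `Xyz8` and `gather_stride3_ref` are hypothetical and do not come from the Intel article, and it assumes the 24 floats are laid out as x0 y0 z0 x1 y1 z1 and so on.

```cpp
#include <array>
#include <cstddef>

// Hypothetical container for the three de-interleaved streams (8 elements each).
struct Xyz8 {
    std::array<float, 8> x, y, z;
};

// Hypothetical scalar reference: reads 24 interleaved floats starting at p
// and returns what the x, y and z vectors should contain after the shuffles.
inline Xyz8 gather_stride3_ref(const float* p) {
    Xyz8 out{};
    for (std::size_t i = 0; i < 8; ++i) {
        out.x[i] = p[3 * i + 0];  // elements 0, 3, 6, ...
        out.y[i] = p[3 * i + 1];  // elements 1, 4, 7, ...
        out.z[i] = p[3 * i + 2];  // elements 2, 5, 8, ...
    }
    return out;
}
```

Storing the vector results with _mm256_storeu_ps() and comparing them element by element against this reference should catch a wrong shuffle immediate quickly.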


    In the general case, it's possible to do an S x W transpose in O(S*log(W)) instructions, where:

    • S is the stride
    • W is the SIMD width

    Assuming the existence of 2-vector permutes and half-vector insert-loads, then the formula becomes:

    (S x W load-permute)  <=  S * (lg(W) + 1) instructions
    

    This count ignores reg-reg moves. For degenerate cases like the 3 x 4, it may be possible to do better.
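To make the bound concrete, here is a small sketch that just evaluates S * (lg(W) + 1). The helper name `transpose_instr_bound` is made up for illustration, and it assumes W is a power of two.

```cpp
// Hypothetical helper: upper bound S * (lg(W) + 1) on the instruction count
// of an S x W load-permute. W is the SIMD width in elements and is assumed
// to be a power of two.
inline unsigned transpose_instr_bound(unsigned S, unsigned W) {
    unsigned lg = 0;
    while ((1u << lg) < W) ++lg;  // lg = log2(W) for power-of-two W
    return S * (lg + 1);
}
```

For the 3 x 8 case this gives 3 * (3 + 1) = 12, so the 11-instruction AVX sequence sits one under the bound. For 3 x 16 it gives 3 * (4 + 1) = 15.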

    Here's the 3 x 16 load-transpose with AVX512: (6 loads, 3 shuffles, 6 blends)

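    #include <immintrin.h>  // needs AVX512F, plus AVX512DQ for _mm512_insertf32x8
    #ifndef FORCE_INLINE
    #define FORCE_INLINE inline  // assumption: the original macro definition is not shown
    #endif

    FORCE_INLINE void transpose_f32_16x3_forward_AVX512(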
        const float T[48],
        __m512& r0, __m512& r1, __m512& r2
    ){
        __m512 a0, a1, a2;
    
        //   0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15
        //  16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31
        //  32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47
    
        a0 = _mm512_castps256_ps512(_mm256_loadu_ps(T +  0));
        a1 = _mm512_castps256_ps512(_mm256_loadu_ps(T +  8));
        a2 = _mm512_castps256_ps512(_mm256_loadu_ps(T + 16));
        // Unaligned loads, matching the lower halves above (T need not be 32-byte aligned).
        a0 = _mm512_insertf32x8(a0, _mm256_loadu_ps(T + 24), 1);
        a1 = _mm512_insertf32x8(a1, _mm256_loadu_ps(T + 32), 1);
        a2 = _mm512_insertf32x8(a2, _mm256_loadu_ps(T + 40), 1);
    
        //   0  1  2  3  4  5  6  7 24 25 26 27 28 29 30 31
        //   8  9 10 11 12 13 14 15 32 33 34 35 36 37 38 39
        //  16 17 18 19 20 21 22 23 40 41 42 43 44 45 46 47
    
        r0 = _mm512_mask_blend_ps(0xf0f0, a0, a1);
        r1 = _mm512_permutex2var_ps(a0, _mm512_setr_epi32(  4,  5,  6,  7, 16, 17, 18, 19, 12, 13, 14, 15, 24, 25, 26, 27), a2);
        r2 = _mm512_mask_blend_ps(0xf0f0, a1, a2);
    
        //   0  1  2  3 12 13 14 15 24 25 26 27 36 37 38 39
        //   4  5  6  7 16 17 18 19 28 29 30 31 40 41 42 43
        //   8  9 10 11 20 21 22 23 32 33 34 35 44 45 46 47
    
        a0 = _mm512_mask_blend_ps(0xcccc, r0, r1);
        a1 = _mm512_shuffle_ps(r0, r2, 78);
        a2 = _mm512_mask_blend_ps(0xcccc, r1, r2);
    
        //   0  1  6  7 12 13 18 19 24 25 30 31 36 37 42 43
        //   2  3  8  9 14 15 20 21 26 27 32 33 38 39 44 45
        //   4  5 10 11 16 17 22 23 28 29 34 35 40 41 46 47
    
        r0 = _mm512_mask_blend_ps(0xaaaa, a0, a1);
        r1 = _mm512_permutex2var_ps(a0, _mm512_setr_epi32(  1,  16,  3, 18,  5, 20,  7, 22,  9, 24, 11, 26, 13, 28, 15, 30), a2);
        r2 = _mm512_mask_blend_ps(0xaaaa, a1, a2);
    
        //   0  3  6  9 12 15 18 21 24 27 30 33 36 39 42 45
        //   1  4  7 10 13 16 19 22 25 28 31 34 37 40 43 46
        //   2  5  8 11 14 17 20 23 26 29 32 35 38 41 44 47
    }
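
One way to read the index comments in the listing is that every stage does the same exchange at a halved block size: 8 (the insert-loads), then 4, 2 and 1. The sketch below models that reading on plain integer arrays instead of registers. It is an interpretation of the listing rather than part of the original answer, and the names `Row`, `stage` and `transpose_3x16_model` are hypothetical.

```cpp
#include <array>
#include <cstddef>

using Row = std::array<int, 16>;

// One exchange stage at block size B, applied to three 16-element rows.
// Positions in even-numbered blocks take from the first source, positions
// in odd-numbered blocks from the second. The middle row also moves each
// block to its neighbouring block position (the 2-source permute).
inline void stage(std::array<Row, 3>& r, std::size_t B) {
    const std::array<Row, 3> in = r;
    for (std::size_t i = 0; i < 16; ++i) {
        const bool odd = (i / B) % 2 != 0;
        r[0][i] = odd ? in[1][i]     : in[0][i];      // blend(row0, row1)
        r[1][i] = odd ? in[2][i - B] : in[0][i + B];  // permute(row0, row2)
        r[2][i] = odd ? in[2][i]     : in[1][i];      // blend(row1, row2)
    }
}

// Starts from the memory order 0..47 split into three rows and runs the
// four stages, B = 8, 4, 2, 1.
inline std::array<Row, 3> transpose_3x16_model() {
    std::array<Row, 3> r{};
    for (int k = 0; k < 3; ++k)
        for (int i = 0; i < 16; ++i)
            r[k][i] = 16 * k + i;
    for (std::size_t B = 8; B >= 1; B /= 2)
        stage(r, B);
    return r;
}
```

After the four stages, row k holds the indices k, k + 3, k + 6 and so on, which is the layout shown in the last comment block of the AVX512 listing. Printing the rows after each call to `stage` reproduces the intermediate comment blocks as well.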
    

    The inverse 3 x 16 transpose-store will be left as an exercise to the reader.

    The pattern is not at all trivial to see, since S = 3 is somewhat degenerate. But if you can see the pattern, you'll be able to generalize this to any odd integer S as well as any power-of-two W.
