Reproduce _mm256_sllv_epi16 and _mm256_sllv_epi8 in AVX2


Question


I was surprised to see that _mm256_sllv_epi16/8(__m256i v1, __m256i v2) and _mm256_srlv_epi16/8(__m256i v1, __m256i v2) are not in the Intel Intrinsics Guide, and I can't find any way to recreate those AVX-512 intrinsics with only AVX2.

These functions shift each packed 16-bit/8-bit integer in v1 left by the count in the corresponding element of v2.

Example for epi16:

__m256i v1 = _mm256_set1_epi16(0b1111111111111111);
__m256i v2 = _mm256_setr_epi16(0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15);
v1 = _mm256_sllv_epi16(v1, v2);

Then v1 equals (1111111111111111, 1111111111111110, 1111111111111100, 1111111111111000, ................, 1000000000000000).


Answer 1:


In the _mm256_sllv_epi8 case, it isn't too difficult to replace the shifts by multiplications, using the pshufb instruction as a tiny lookup table. It is also possible to emulate the right shift of _mm256_srlv_epi8 with multiplications and quite a few other instructions; see the code below. I would expect that at least _mm256_sllv_epi8 is more efficient than Nyan's solution.


More or less the same idea can be used to emulate _mm256_sllv_epi16, but in that case it is less trivial to select the right multiplier (see also code below).

The solution _mm256_sllv_epi16_emu below is not necessarily faster, or better, than Nyan's solution. The performance depends on the surrounding code and on the CPU that is used. Nevertheless, the solution here might be of interest, at least on older computer systems. For example, the vpsllvd instruction is used twice in Nyan's solution. That instruction is fast on Intel Skylake and newer, but slow on Intel Broadwell and Haswell, where it decodes to 3 micro-ops. The solution here avoids this slow instruction.

It is possible to skip the two lines of code with mask_lt_15, if the shift counts are known to be less than or equal to 15.
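
As a minimal sketch (not part of the original answer), assuming all shift counts are already known to be in the range 0..15, the emulation reduces to the following; the full _mm256_sllv_epi16_emu, including the mask_lt_15 handling, is in the listing below. The helper name is hypothetical.

__m256i sllv_epi16_small_counts(__m256i a, __m256i count) {                       /* hypothetical name; assumes counts 0..15 */
    __m256i multiplier_lut = _mm256_set_epi8(0,0,0,0, 0,0,0,0, 128,64,32,16, 8,4,2,1, 0,0,0,0, 0,0,0,0, 128,64,32,16, 8,4,2,1);
    __m256i byte_shuf_mask = _mm256_set_epi8(14,14,12,12, 10,10,8,8, 6,6,4,4, 2,2,0,0, 14,14,12,12, 10,10,8,8, 6,6,4,4, 2,2,0,0);
            count          = _mm256_shuffle_epi8(count, byte_shuf_mask);          /* Broadcast the low byte of each 16-bit count to both bytes.      */
            count          = _mm256_sub_epi8(count, _mm256_set1_epi16(0x0800));   /* Subtract 8 at the odd byte positions.                           */
    __m256i multiplier     = _mm256_shuffle_epi8(multiplier_lut, count);          /* 2^count, in either the low or the high byte of each element.    */
    return _mm256_mullo_epi16(a, multiplier);
}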

The missing intrinsic _mm256_srlv_epi16 is left as an exercise for the reader.


/*     gcc -O3 -m64 -Wall -mavx2 -march=broadwell shift_v_epi8.c     */
#include <immintrin.h>
#include <stdio.h>
int print_epi8(__m256i  a);
int print_epi16(__m256i  a);

__m256i _mm256_sllv_epi8(__m256i a, __m256i count) {
    __m256i mask_hi        = _mm256_set1_epi32(0xFF00FF00);
    __m256i multiplier_lut = _mm256_set_epi8(0,0,0,0, 0,0,0,0, 128,64,32,16, 8,4,2,1, 0,0,0,0, 0,0,0,0, 128,64,32,16, 8,4,2,1);

    __m256i count_sat      = _mm256_min_epu8(count, _mm256_set1_epi8(8));     /* AVX shift counts are not masked. So a_i << n_i = 0 for n_i >= 8. count_sat is always less than 9.*/ 
    __m256i multiplier     = _mm256_shuffle_epi8(multiplier_lut, count_sat);  /* Select the right multiplication factor in the lookup table.                                      */
    __m256i x_lo           = _mm256_mullo_epi16(a, multiplier);               /* Unfortunately _mm256_mullo_epi8 doesn't exist. Split the 16 bit elements in a high and low part. */

    __m256i multiplier_hi  = _mm256_srli_epi16(multiplier, 8);                /* The multiplier of the high bits.                                                                 */
    __m256i a_hi           = _mm256_and_si256(a, mask_hi);                    /* Mask off the low bits.                                                                           */
    __m256i x_hi           = _mm256_mullo_epi16(a_hi, multiplier_hi);
    __m256i x              = _mm256_blendv_epi8(x_lo, x_hi, mask_hi);         /* Merge the high and low part.                                                                     */
            return x;
}


__m256i _mm256_srlv_epi8(__m256i a, __m256i count) {
    __m256i mask_hi        = _mm256_set1_epi32(0xFF00FF00);
    __m256i multiplier_lut = _mm256_set_epi8(0,0,0,0, 0,0,0,0, 1,2,4,8, 16,32,64,128, 0,0,0,0, 0,0,0,0, 1,2,4,8, 16,32,64,128);

    __m256i count_sat      = _mm256_min_epu8(count, _mm256_set1_epi8(8));     /* AVX shift counts are not masked. So a_i >> n_i = 0 for n_i >= 8. count_sat is always less than 9.*/ 
    __m256i multiplier     = _mm256_shuffle_epi8(multiplier_lut, count_sat);  /* Select the right multiplication factor in the lookup table.                                      */
    __m256i a_lo           = _mm256_andnot_si256(mask_hi, a);                 /* Mask off the high bits.                                                                          */
    __m256i multiplier_lo  = _mm256_andnot_si256(mask_hi, multiplier);        /* The multiplier of the low bits.                                                                  */
    __m256i x_lo           = _mm256_mullo_epi16(a_lo, multiplier_lo);         /* Shift left a_lo by multiplying.                                                                  */
            x_lo           = _mm256_srli_epi16(x_lo, 7);                      /* Shift right by 7 to get the low bits at the right position.                                      */

    __m256i multiplier_hi  = _mm256_and_si256(mask_hi, multiplier);           /* The multiplier of the high bits.                                                                 */
    __m256i x_hi           = _mm256_mulhi_epu16(a, multiplier_hi);            /* Variable shift left a_hi by multiplying. Use a instead of a_hi because the a_lo bits don't interfere */
            x_hi           = _mm256_slli_epi16(x_hi, 1);                      /* Shift left by 1 to get the high bits at the right position.                                      */
    __m256i x              = _mm256_blendv_epi8(x_lo, x_hi, mask_hi);         /* Merge the high and low part.                                                                     */
            return x;
}


__m256i _mm256_sllv_epi16_emu(__m256i a, __m256i count) {
    __m256i multiplier_lut = _mm256_set_epi8(0,0,0,0, 0,0,0,0, 128,64,32,16, 8,4,2,1, 0,0,0,0, 0,0,0,0, 128,64,32,16, 8,4,2,1);
    __m256i byte_shuf_mask = _mm256_set_epi8(14,14,12,12, 10,10,8,8, 6,6,4,4, 2,2,0,0, 14,14,12,12, 10,10,8,8, 6,6,4,4, 2,2,0,0);

    __m256i mask_lt_15     = _mm256_cmpgt_epi16(_mm256_set1_epi16(16), count);
            a              = _mm256_and_si256(mask_lt_15, a);                    /* Set a to zero if count > 15.                                                                      */
            count          = _mm256_shuffle_epi8(count, byte_shuf_mask);         /* Duplicate bytes from the even positions to bytes at the even and odd positions.                   */
            count          = _mm256_sub_epi8(count,_mm256_set1_epi16(0x0800));   /* Subtract 8 at the odd byte positions (the high byte of each 16-bit element). Note that the vpshufb instruction selects a zero byte if the shuffle control mask is negative. */
    __m256i multiplier     = _mm256_shuffle_epi8(multiplier_lut, count);         /* Select the right multiplication factor in the lookup table. Within the 16 bit elements, only the upper byte or the lower byte is nonzero. */
    __m256i x              = _mm256_mullo_epi16(a, multiplier);                  
            return x;
}


int main(){

    printf("Emulating _mm256_sllv_epi8:\n");
    __m256i a     = _mm256_set_epi8(32,31,30,29, 28,27,26,25, 24,23,22,21, 20,19,18,17, 16,15,14,13, 12,11,10,9, 8,7,6,5, 4,3,2,1);
    __m256i count = _mm256_set_epi8(7,6,5,4, 3,2,1,0,  11,10,9,8, 7,6,5,4, 3,2,1,0,  11,10,9,8, 7,6,5,4, 3,2,1,0);
    __m256i x     = _mm256_sllv_epi8(a, count);
    printf("a     = \n"); print_epi8(a    );
    printf("count = \n"); print_epi8(count);
    printf("x     = \n"); print_epi8(x    );
    printf("\n\n"); 


    printf("Emulating _mm256_srlv_epi8:\n");
            a     = _mm256_set_epi8(223,224,225,226, 227,228,229,230, 231,232,233,234, 235,236,237,238, 239,240,241,242, 243,244,245,246, 247,248,249,250, 251,252,253,254);
            count = _mm256_set_epi8(7,6,5,4, 3,2,1,0,  11,10,9,8, 7,6,5,4, 3,2,1,0,  11,10,9,8, 7,6,5,4, 3,2,1,0);
            x     = _mm256_srlv_epi8(a, count);
    printf("a     = \n"); print_epi8(a    );
    printf("count = \n"); print_epi8(count);
    printf("x     = \n"); print_epi8(x    );
    printf("\n\n"); 



    printf("Emulating _mm256_sllv_epi16:\n");
            a     = _mm256_set_epi16(1601,1501,1401,1301, 1200,1100,1000,900, 800,700,600,500, 400,300,200,100);
            count = _mm256_set_epi16(17,16,15,13,  11,10,9,8, 7,6,5,4, 3,2,1,0);
            x     = _mm256_sllv_epi16_emu(a, count);
    printf("a     = \n"); print_epi16(a    );
    printf("count = \n"); print_epi16(count);
    printf("x     = \n"); print_epi16(x    );
    printf("\n\n"); 

    return 0;
}


int print_epi8(__m256i  a){
  char v[32];
  int i;
  _mm256_storeu_si256((__m256i *)v,a);
  for (i = 0; i<32; i++) printf("%4hhu",v[i]);
  printf("\n");
  return 0;
}

int print_epi16(__m256i  a){
  unsigned short int  v[16];
  int i;
  _mm256_storeu_si256((__m256i *)v,a);
  for (i = 0; i<16; i++) printf("%6hu",v[i]);
  printf("\n");
  return 0;
}

The output is:

Emulating _mm256_sllv_epi8:
a     = 
   1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  16  17  18  19  20  21  22  23  24  25  26  27  28  29  30  31  32
count = 
   0   1   2   3   4   5   6   7   8   9  10  11   0   1   2   3   4   5   6   7   8   9  10  11   0   1   2   3   4   5   6   7
x     = 
   1   4  12  32  80 192 192   0   0   0   0   0  13  28  60 128  16  64 192   0   0   0   0   0  25  52 108 224 208 192 192   0


Emulating _mm256_srlv_epi8:
a     = 
 254 253 252 251 250 249 248 247 246 245 244 243 242 241 240 239 238 237 236 235 234 233 232 231 230 229 228 227 226 225 224 223
count = 
   0   1   2   3   4   5   6   7   8   9  10  11   0   1   2   3   4   5   6   7   8   9  10  11   0   1   2   3   4   5   6   7
x     = 
 254 126  63  31  15   7   3   1   0   0   0   0 242 120  60  29  14   7   3   1   0   0   0   0 230 114  57  28  14   7   3   1


Emulating _mm256_sllv_epi16:
a     = 
   100   200   300   400   500   600   700   800   900  1000  1100  1200  1301  1401  1501  1601
count = 
     0     1     2     3     4     5     6     7     8     9    10    11    13    15    16    17
x     = 
   100   400  1200  3200  8000 19200 44800 36864 33792 53248 12288 32768 40960 32768     0     0

Indeed, some AVX2 instructions are missing. However, note that it is not always a good idea to fill these gaps by emulating the 'missing' instructions. Sometimes it is more efficient to redesign your code in such a way that the emulated instructions are avoided, for example by working with wider vector elements (_epi32 instead of _epi16) that do have native support.
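
For instance, a minimal sketch (not part of the original answer, assuming the data and the shift counts already live in 32-bit lanes) where the native AVX2 instruction does the job directly:

__m256i a     = _mm256_setr_epi32(100, 200, 300, 400, 500, 600, 700, 800);
__m256i count = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
__m256i x     = _mm256_sllv_epi32(a, count);   /* vpsllvd: one native AVX2 instruction, no emulation needed */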




Answer 2:


It's strange that they missed that, though it seems many AVX integer instructions are only available for 32/64-bit widths. At least 16-bit got added in AVX512BW (though I still don't get why Intel refuses to add 8-bit shifts).

We can emulate 16-bit variable shifts using only AVX2 by using 32-bit variable shifts with some masking and blending.

We need the correct shift count at the bottom of the 32-bit element that contains each 16-bit element, which we can arrange with an AND (for the low element) and an immediate right shift (for the high half). (Unlike scalar shifts, x86 vector shifts saturate their count instead of wrapping/masking it.)

We also need to mask off the low 16 bits of the data before doing the high-half shift, so we aren't shifting garbage into the high 16-bit half of the containing 32-bit element.

__m256i _mm256_sllv_epi16(__m256i a, __m256i count) {
    const __m256i mask = _mm256_set1_epi32(0xffff0000); // alternating low/high words of a dword
    // shift low word of each dword: low_half = (a << (count & 0xffff)) [for each 32b element]
    // note that, because `a` isn't being masked here, we may get some "junk" bits, but these will get eliminated by the blend below
    __m256i low_half = _mm256_sllv_epi32(
        a,
        _mm256_andnot_si256(mask, count)
    );
    // shift high word of each dword: high_half = ((a & 0xffff0000) << (count >> 16)) [for each 32b element]
    __m256i high_half = _mm256_sllv_epi32(
        _mm256_and_si256(mask, a),     // make sure we shift in zeros
        _mm256_srli_epi32(count, 16)   // need the high-16 count at the bottom of a 32-bit element
    );
    // combine low and high words
    return _mm256_blend_epi16(low_half, high_half, 0xaa);
}

__m256i _mm256_srlv_epi16(__m256i a, __m256i count) {
    const __m256i mask = _mm256_set1_epi32(0x0000ffff);
    __m256i low_half = _mm256_srlv_epi32(
        _mm256_and_si256(mask, a),
        _mm256_and_si256(mask, count)
    );
    __m256i high_half = _mm256_srlv_epi32(
        a,
        _mm256_srli_epi32(count, 16)
    );
    return _mm256_blend_epi16(low_half, high_half, 0xaa);
}
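
As a quick usage sketch (not part of the original answer), the question's example can be fed straight into this emulation:

__m256i v1 = _mm256_set1_epi16(-1);                                        /* 0b1111111111111111 in every element */
__m256i v2 = _mm256_setr_epi16(0,1,2,3, 4,5,6,7, 8,9,10,11, 12,13,14,15);
        v1 = _mm256_sllv_epi16(v1, v2);                                    /* element i becomes 0xFFFF << i       */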

GCC 8.2 compiles this to more-or-less what you'd expect:

_mm256_srlv_epi16(long long __vector(4), long long __vector(4)):
        vmovdqa       ymm3, YMMWORD PTR .LC0[rip]
        vpand   ymm2, ymm0, ymm3
        vpand   ymm3, ymm1, ymm3
        vpsrld  ymm1, ymm1, 16
        vpsrlvd ymm2, ymm2, ymm3
        vpsrlvd ymm0, ymm0, ymm1
        vpblendw        ymm0, ymm2, ymm0, 170
        ret
_mm256_sllv_epi16(long long __vector(4), long long __vector(4)):
        vmovdqa       ymm3, YMMWORD PTR .LC1[rip]
        vpandn  ymm2, ymm3, ymm1
        vpsrld  ymm1, ymm1, 16
        vpsllvd ymm2, ymm0, ymm2
        vpand   ymm0, ymm0, ymm3
        vpsllvd ymm0, ymm0, ymm1
        vpblendw        ymm0, ymm2, ymm0, 170
        ret

Meaning that the emulation results in 1x load + 2x AND/ANDN + 2x variable-shift + 1x right-shift + 1x blend.

Clang 6.0 does something interesting - it eliminates the memory load (and corresponding masking) by using blends:

_mm256_sllv_epi16(long long __vector(4), long long __vector(4)):
        vpxor   xmm2, xmm2, xmm2
        vpblendw        ymm3, ymm1, ymm2, 170
        vpsllvd ymm3, ymm0, ymm3
        vpsrld  ymm1, ymm1, 16
        vpblendw        ymm0, ymm2, ymm0, 170
        vpsllvd ymm0, ymm0, ymm1
        vpblendw        ymm0, ymm3, ymm0, 170
        ret
_mm256_srlv_epi16(long long __vector(4), long long __vector(4)):
        vpxor   xmm2, xmm2, xmm2
        vpblendw        ymm3, ymm0, ymm2, 170
        vpblendw        ymm2, ymm1, ymm2, 170
        vpsrlvd ymm2, ymm3, ymm2
        vpsrld  ymm1, ymm1, 16
        vpsrlvd ymm0, ymm0, ymm1
        vpblendw        ymm0, ymm2, ymm0, 170
        ret

This results in: 1x clear + 3x blend + 2x variable-shift + 1x right-shift.

I haven't benchmarked which approach is faster, but I suspect it depends on the CPU, in particular on the cost of PBLENDW.

Of course, if your use case is a little more constrained, the above can be simplified. For example, if your shift amounts are all constants, you can remove the masking/shifting needed to make the emulation work (assuming the compiler doesn't do this automatically for you).
For a left shift by constant amounts, you could use _mm256_mullo_epi16 instead, converting the shift amounts into multipliers, e.g. for the example you gave:

__m256i v1 = _mm256_set1_epi16(0b1111111111111111);
__m256i v2 = _mm256_setr_epi16(1<<0,1<<1,1<<2,1<<3,1<<4,1<<5,1<<6,1<<7,1<<8,1<<9,1<<10,1<<11,1<<12,1<<13,1<<14,1<<15);
v1 = _mm256_mullo_epi16(v1, v2);

Update: Peter mentions in a comment that a right shift can also be implemented with _mm256_mulhi_epi16 (e.g. to perform v>>1, multiply v by 1<<15 and take the high word).
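
As a minimal sketch of that trick (not from the answer), here is a constant logical right shift by 3; the unsigned variant _mm256_mulhi_epu16 is used here so that zeros are shifted in:

/* Logical right shift by a constant n via mulhi: (v * 2^(16-n)) >> 16 == v >> n, for 1 <= n <= 15. */
__m256i v = _mm256_set1_epi16(0x00F0);
__m256i r = _mm256_mulhi_epu16(v, _mm256_set1_epi16(1 << (16 - 3)));   /* each element == 0x00F0 >> 3 == 0x001E */

A per-element constant multiplier (e.g. built with _mm256_setr_epi16) gives a different shift count in each lane, analogous to the mullo example above for left shifts.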


For 8-bit variable shifts, this doesn't exist in AVX512 either (again, I don't know why Intel doesn't have 8-bit SIMD shifts).
If AVX512BW is available, you could use a trick similar to the above, built on _mm256_sllv_epi16. For AVX2, I can't think of a particularly better approach than applying the 16-bit emulation a second time, as you ultimately have to do 4x the shifting of what the 32-bit shift gives you. See @wim's answer for a nice 8-bit solution in AVX2.

This is what I came up with (basically the 16-bit version adapted for 8-bit using AVX512BW):

__m256i _mm256_sllv_epi8(__m256i a, __m256i count) {
    const __m256i mask = _mm256_set1_epi16(0xff00);
    __m256i low_half = _mm256_sllv_epi16(
        a,
        _mm256_andnot_si256(mask, count)
    );
    __m256i high_half = _mm256_sllv_epi16(
        _mm256_and_si256(mask, a),
        _mm256_srli_epi16(count, 8)
    );
    return _mm256_blendv_epi8(low_half, high_half, _mm256_set1_epi16(0xff00));
}

__m256i _mm256_srlv_epi8(__m256i a, __m256i count) {
    const __m256i mask = _mm256_set1_epi16(0x00ff);
    __m256i low_half = _mm256_srlv_epi16(
        _mm256_and_si256(mask, a),
        _mm256_and_si256(mask, count)
    );
    __m256i high_half = _mm256_srlv_epi16(
        a,
        _mm256_srli_epi16(count, 8)
    );
    return _mm256_blendv_epi8(low_half, high_half, _mm256_set1_epi16(0xff00));
}

(Peter Cordes mentions below that _mm256_blendv_epi8(low_half, high_half, _mm256_set1_epi16(0xff00)) can be replaced with _mm256_mask_blend_epi8(0xaaaaaaaa, low_half, high_half) in a pure AVX512BW(+VL) implementation, which is likely faster)



Source: https://stackoverflow.com/questions/51789685/reproduce-mm256-sllv-epi16-and-mm256-sllv-epi8-in-avx2
