Using SIMD/AVX/SSE for tree traversal

感动是毒 2020-12-25 10:23

I am currently researching whether it is possible to speed up the traversal of a van Emde Boas tree (or any other tree). Given a single search query as input, already having multip

2 answers
  •  难免孤独
    2020-12-25 11:03

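    I've used SSE2/AVX2 to help perform a B+tree search. Here's code that searches a full cache line of 16 sorted DWORDs with AVX2. It takes the place of a binary search within the node, but compares all 16 keys at once rather than halving the range: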

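    #include <immintrin.h> // AVX2 intrinsics; _tzcnt_u32 also needs BMI1
    #include <stdint.h>
    
    // perf-critical: ensure this is 64-byte aligned (a full cache line).
    union bnode
    {
        int32_t i32[16];
        __m256i m256[2];
    };
    
    // node->i32 must be sorted ascending, and value must hold the search
    // key broadcast to all eight lanes (e.g. via _mm256_set1_epi32).
    // returns from 0 (if value < i32[0]) to 16 (if value >= i32[15])
    unsigned bsearch_avx2(bnode const* const node, __m256i const value)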
    {
        __m256i const perm_mask = _mm256_set_epi32(7, 6, 3, 2, 5, 4, 1, 0);
    
        // compare the two halves of the cache line.
    
        __m256i cmp1 = _mm256_load_si256(&node->m256[0]);
        __m256i cmp2 = _mm256_load_si256(&node->m256[1]);
    
        cmp1 = _mm256_cmpgt_epi32(cmp1, value); // VPCMPGTD
        cmp2 = _mm256_cmpgt_epi32(cmp2, value); // VPCMPGTD
    
        // merge the comparisons back together.
        //
        // the 256-bit pack works on each 128-bit lane separately, so its
        // result comes out interleaved; a permute is required to get the
        // packed results back into key order.
        //
        // alternatively, you could pre-process your data to remove the
        // need for the permute.
    
        __m256i cmp = _mm256_packs_epi32(cmp1, cmp2); // VPACKSSDW
        cmp = _mm256_permutevar8x32_epi32(cmp, perm_mask); // VPERMD
    
        // finally create a move mask (two bits per key) and count
        // trailing zeroes to get an index to the next node.
    
        unsigned mask = _mm256_movemask_epi8(cmp); // VPMOVMSKB
        return _tzcnt_u32(mask) / 2; // TZCNT (BMI1); yields 32 when mask == 0
    }
    

    You'll end up with a single highly predictable branch per bnode, to test if the end of the tree has been reached.
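
    As a rough illustration of that per-node loop, here is a minimal sketch. The breadth-first 17-ary layout, the names `node16`, `position`, `descend` and `make_demo_tree`, and the scalar `search_node` are all assumptions made for this example, not part of the code above. In real use `search_node` would be replaced by the SIMD node search, which leaves the loop condition as the only branch per node.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical node type: 16 sorted keys, mirroring the 16-DWORD node above.
struct node16
{
    int32_t keys[16];
};

// Scalar stand-in for the SIMD node search: index of the first key > value
// (16 if there is none). This loop branches; the SIMD version does not.
unsigned search_node(node16 const& node, int32_t value)
{
    unsigned slot = 0;
    while (slot < 16 && node.keys[slot] <= value)
        ++slot;
    return slot;
}

// Where the descent stopped: the last node visited and the slot chosen in it.
struct position
{
    std::size_t node;
    unsigned slot;
};

// Assumed layout (not from the answer): a complete 17-ary tree stored
// breadth-first, so child `slot` of node n lives at index n * 17 + 1 + slot.
position descend(std::vector<node16> const& nodes, int32_t value)
{
    assert(!nodes.empty());
    position pos{0, 0};
    std::size_t next = 0;
    while (next < nodes.size()) // the one branch per node: past the last level?
    {
        pos.node = next;
        pos.slot = search_node(nodes[next], value);
        next = next * 17 + 1 + pos.slot;
    }
    return pos;
}

// Builds a two-level demo tree: root separators 100, 200, ..., 1600 and
// 17 leaves, where leaf c holds the keys c * 100 + 0 .. c * 100 + 15.
std::vector<node16> make_demo_tree()
{
    std::vector<node16> nodes(18);
    for (int i = 0; i < 16; ++i)
        nodes[0].keys[i] = (i + 1) * 100;
    for (int c = 0; c < 17; ++c)
        for (int j = 0; j < 16; ++j)
            nodes[1 + c].keys[j] = c * 100 + j;
    return nodes;
}
```

    With the demo tree, `descend(make_demo_tree(), 205)` stops in node 3 (the leaf holding 200..215) at slot 6, because six keys in that leaf are less than or equal to 205.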

    The node search should scale trivially to AVX-512, where one 512-bit register holds the whole cache line.
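
    A possible AVX-512 variant is sketched below, assuming GCC or Clang on x86-64. The names `bnode512`, `bsearch_avx512`, `bsearch_scalar` and `bsearch_node` are hypothetical; the `target` attribute, `__builtin_cpu_supports` and `__builtin_ctz` are compiler extensions rather than standard C++. The scalar fallback is only there so the sketch also runs on CPUs without AVX-512F.

```cpp
#include <cassert>
#include <cstdint>
#include <immintrin.h>

// Hypothetical node type for this sketch; alignas(64) puts the 16 keys
// on one cache line, which the aligned 512-bit load requires.
struct alignas(64) bnode512
{
    int32_t i32[16];
};

// AVX-512F version: one compare yields a 16-bit mask directly, so the
// pack, permute and movemask steps of the AVX2 code are not needed.
__attribute__((target("avx512f")))
unsigned bsearch_avx512(bnode512 const* const node, int32_t const value)
{
    __m512i const keys = _mm512_load_si512(node->i32);
    __m512i const needle = _mm512_set1_epi32(value);
    unsigned const mask = _mm512_cmpgt_epi32_mask(keys, needle); // bit i: key i > value

    // Bit 16 acts as a sentinel so that an all-zero mask yields 16.
    return static_cast<unsigned>(__builtin_ctz(mask | 0x10000u));
}

// Scalar fallback with the same result, for CPUs without AVX-512F.
unsigned bsearch_scalar(bnode512 const* const node, int32_t const value)
{
    unsigned index = 0;
    while (index < 16 && node->i32[index] <= value)
        ++index;
    return index;
}

// Picks the AVX-512 path only when the CPU reports support for it.
unsigned bsearch_node(bnode512 const* const node, int32_t const value)
{
    assert(reinterpret_cast<std::uintptr_t>(node) % 64 == 0);
    if (__builtin_cpu_supports("avx512f"))
        return bsearch_avx512(node, value);
    return bsearch_scalar(node, value);
}
```

    In a hot loop the CPU check would normally be done once up front rather than per node; it is kept inside `bsearch_node` here only to keep the sketch short.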

    To get rid of that slow VPERMD instruction, preprocess each node once with the following, which swaps i32[4..7] with i32[8..11]; the search can then skip the permute:

    void preprocess_avx2(bnode* const node)
    {
        __m256i const perm_mask = _mm256_set_epi32(3, 2, 1, 0, 7, 6, 5, 4);
        __m256i *const middle = (__m256i*)&node->i32[4];
    
        __m256i x = _mm256_loadu_si256(middle);
        x = _mm256_permutevar8x32_epi32(x, perm_mask);
        _mm256_storeu_si256(middle, x);
    }
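
    To see why that swap is enough, the following scalar model may help. It only tracks where each key ends up after the lane-wise pack (it does not model the 32-to-16-bit saturation), and the names `preprocess_scalar`, `pack_order` and `check_preprocessing` are hypothetical helpers for this illustration.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <cstdint>

// Scalar equivalent of the preprocessing step: swap keys 4..7 with 8..11.
void preprocess_scalar(std::array<int32_t, 16>& keys)
{
    std::swap_ranges(keys.begin() + 4, keys.begin() + 8, keys.begin() + 8);
}

// Models the element order produced by a 256-bit pack of a = keys[0..7]
// and b = keys[8..15]: each 128-bit lane emits four elements taken from a,
// followed by four elements taken from b.
std::array<int32_t, 16> pack_order(std::array<int32_t, 16> const& keys)
{
    std::array<int32_t, 16> out{};
    for (int lane = 0; lane < 2; ++lane)
    {
        for (int j = 0; j < 4; ++j)
        {
            out[lane * 8 + j] = keys[lane * 4 + j];         // from a
            out[lane * 8 + 4 + j] = keys[8 + lane * 4 + j]; // from b
        }
    }
    return out;
}

// Usage example: after preprocessing, the pack alone restores key order.
void check_preprocessing()
{
    std::array<int32_t, 16> keys{};
    for (int i = 0; i < 16; ++i)
        keys[i] = i;

    assert(pack_order(keys)[4] == 8); // natural layout comes out interleaved

    preprocess_scalar(keys);
    for (int i = 0; i < 16; ++i)
        assert(pack_order(keys)[i] == i); // preprocessed layout comes out sorted
}
```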
    
