Efficient 4x4 matrix multiplication (C vs assembly)

Z boson

4x4 matrix multiplication is 64 multiplications and 48 additions. Using SSE this can be reduced to 16 multiplications and 12 additions (and 16 broadcasts). The following code will do this for you. It only requires SSE (#include <xmmintrin.h>). The arrays A, B, and C need to be 16 byte aligned. Using horizontal instructions such as hadd (SSE3) and dpps (SSE4.1) will be less efficient (especially dpps). I don't know if loop unrolling will help.

void M4x4_SSE(float *A, float *B, float *C) {
    __m128 row1 = _mm_load_ps(&B[0]);
    __m128 row2 = _mm_load_ps(&B[4]);
    __m128 row3 = _mm_load_ps(&B[8]);
    __m128 row4 = _mm_load_ps(&B[12]);
    for(int i=0; i<4; i++) {
        __m128 brod1 = _mm_set1_ps(A[4*i + 0]);
        __m128 brod2 = _mm_set1_ps(A[4*i + 1]);
        __m128 brod3 = _mm_set1_ps(A[4*i + 2]);
        __m128 brod4 = _mm_set1_ps(A[4*i + 3]);
        __m128 row = _mm_add_ps(
                    _mm_add_ps(
                        _mm_mul_ps(brod1, row1),
                        _mm_mul_ps(brod2, row2)),
                    _mm_add_ps(
                        _mm_mul_ps(brod3, row3),
                        _mm_mul_ps(brod4, row4)));
        _mm_store_ps(&C[4*i], row);
    }
}
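
A minimal usage sketch (my addition, not from the answer above): the 16-byte alignment can be guaranteed with C11 _Alignas (or __attribute__((aligned(16)))), and multiplying by the identity makes the result easy to verify:

    #include <stdio.h>
    #include <xmmintrin.h>

    int main(void) {
        /* A is the identity, so C = A*B should reproduce B exactly */
        _Alignas(16) float A[16] = {1,0,0,0, 0,1,0,0, 0,0,1,0, 0,0,0,1};
        _Alignas(16) float B[16] = {1,2,3,4, 5,6,7,8, 9,10,11,12, 13,14,15,16};
        _Alignas(16) float C[16];
        M4x4_SSE(A, B, C);
        for (int i = 0; i < 16; i++)
            printf("%g%c", C[i], (i % 4 == 3) ? '\n' : ' ');
        return 0;
    }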
Krzysztof Abramowicz

There is a way to accelerate the code and outplay the compiler. It does not involve any sophisticated pipeline analysis or deep code micro-optimisation (which doesn't mean that it couldn't further benefit from these). The optimisation uses three simple tricks:

  1. The function is now 32-byte aligned (which significantly boosted performance),

  2. The main loop runs in reverse, which reduces the exit comparison to a zero test (based on EFLAGS),

  3. Instruction-level address arithmetic proved to be faster than the "external" pointer calculation (even though it requires twice as many additions in 3 out of 4 cases). It shortened the loop body by four instructions and reduced data dependencies within its execution path. See related question.

Additionally, the code uses a relative jump syntax (the local label 1: with the backward reference 1b), which suppresses the symbol-redefinition error that occurs when GCC tries to inline the function (after it has been placed within an asm statement and compiled with -O3).

    .text
    .align 32                           # 1. function entry alignment
    .globl matrixMultiplyASM            #    (for a faster call)
    .type matrixMultiplyASM, @function
matrixMultiplyASM:
    movaps   (%rdi), %xmm0              # cache all four rows of the
    movaps 16(%rdi), %xmm1              # first matrix (%rdi) in registers
    movaps 32(%rdi), %xmm2
    movaps 48(%rdi), %xmm3
    movq $48, %rcx                      # 2. loop reversal
1:                                      #    (for simpler exit condition)
    movss (%rsi, %rcx), %xmm4           # 3. extended address operands
    shufps $0, %xmm4, %xmm4             #    (faster than pointer calculation)
    mulps %xmm0, %xmm4
    movaps %xmm4, %xmm5
    movss 4(%rsi, %rcx), %xmm4
    shufps $0, %xmm4, %xmm4
    mulps %xmm1, %xmm4
    addps %xmm4, %xmm5
    movss 8(%rsi, %rcx), %xmm4
    shufps $0, %xmm4, %xmm4
    mulps %xmm2, %xmm4
    addps %xmm4, %xmm5
    movss 12(%rsi, %rcx), %xmm4
    shufps $0, %xmm4, %xmm4
    mulps %xmm3, %xmm4
    addps %xmm4, %xmm5
    movaps %xmm5, (%rdx, %rcx)
    subq $16, %rcx                      # one 'sub' (vs 'add' & 'cmp')
    jge 1b                              # SF=OF, i.e. loop while %rcx >= 0
    ret
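
For reference, a minimal sketch of the C side (my reading, not from the original post). Under the System V AMD64 calling convention the first argument arrives in %rdi, the second in %rsi and the third in %rdx; since the rows of the %rdi matrix are cached in xmm0-xmm3 and scalars of the %rsi matrix are broadcast, the routine computes out = M2 × M1 in row-major terms. The movaps accesses also assume 16-byte-aligned arrays:

    /* Sketch: prototype and call for the assembly routine above.
       System V AMD64: M1 -> %rdi, M2 -> %rsi, out -> %rdx. */
    extern void matrixMultiplyASM(const float *M1, const float *M2, float *out);

    void example(void) {
        _Alignas(16) float A[16] = {0}, B[16] = {0}, C[16];
        /* ... fill A and B (row-major) ... */
        matrixMultiplyASM(B, A, C);   /* my reading: stores A * B into C */
    }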

This is the fastest x86-64 implementation I have seen so far. I will appreciate, upvote, and accept any answer providing a faster piece of assembly for this purpose!

I wonder if transposing one of the matrices may be beneficial.

Consider how we multiply the following two matrices ...

A1 A2 A3 A4        W1 W2 W3 W4
B1 B2 B3 B4        X1 X2 X3 X4
C1 C2 C3 C4    *   Y1 Y2 Y3 Y4
D1 D2 D3 D4        Z1 Z2 Z3 Z4

This would result in ...

dot(A,?1) dot(A,?2) dot(A,?3) dot(A,?4)
dot(B,?1) dot(B,?2) dot(B,?3) dot(B,?4)
dot(C,?1) dot(C,?2) dot(C,?3) dot(C,?4)
dot(D,?1) dot(D,?2) dot(D,?3) dot(D,?4)

Doing the dot product of a row and a column is a pain.

What if we transposed the second matrix before we multiplied?

A1 A2 A3 A4        W1 X1 Y1 Z1
B1 B2 B3 B4        W2 X2 Y2 Z2
C1 C2 C3 C4    *   W3 X3 Y3 Z3
D1 D2 D3 D4        W4 X4 Y4 Z4

Now instead of doing the dot product of a row and a column, we are doing the dot product of two rows. This could lend itself to better use of the SIMD instructions.
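
A rough sketch of that idea (my code, not from the original answer): <xmmintrin.h> already ships the _MM_TRANSPOSE4_PS macro, after which every output element is a row-by-row dot product; the horizontal sums below use plain shuffles and adds, since (as noted above) dpps tends to be slow:

    #include <xmmintrin.h>

    /* C = A * B, implemented by transposing B first so that C[i][j]
       becomes a dot product of row i of A with row j of B^T.
       A, B and C must be 16-byte aligned. */
    void M4x4_SSE_transposed(const float *A, const float *B, float *C) {
        __m128 b0 = _mm_load_ps(&B[0]);
        __m128 b1 = _mm_load_ps(&B[4]);
        __m128 b2 = _mm_load_ps(&B[8]);
        __m128 b3 = _mm_load_ps(&B[12]);
        _MM_TRANSPOSE4_PS(b0, b1, b2, b3);      /* rows of B^T = columns of B */
        __m128 bt[4] = { b0, b1, b2, b3 };
        for (int i = 0; i < 4; i++) {
            __m128 a = _mm_load_ps(&A[4*i]);
            for (int j = 0; j < 4; j++) {
                __m128 p = _mm_mul_ps(a, bt[j]);               /* elementwise products */
                __m128 s = _mm_add_ps(p, _mm_movehl_ps(p, p)); /* p0+p2, p1+p3, ...    */
                s = _mm_add_ss(s, _mm_shuffle_ps(s, s, 1));    /* total in low lane    */
                _mm_store_ss(&C[4*i + j], s);
            }
        }
    }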

Hope this helps.

Sandy Bridge and above extend the instruction set (AVX) to support 8-element vector arithmetic. Consider this implementation.

struct MATRIX {
    union {
        float  f[4][4];
        __m128 m[4];
        __m256 n[2];
    };
};
MATRIX myMultiply(MATRIX M1, MATRIX M2) {
    // Perform a 4x4 matrix multiply by a 4x4 matrix 
    // Be sure to run in 64 bit mode and set right flags
    // Properties, C/C++, Enable Enhanced Instruction, /arch:AVX 
    // Having MATRIX on a 32 byte boundary does help performance
    MATRIX mResult;
    __m256 a0, a1, b0, b1;
    __m256 c0, c1, c2, c3, c4, c5, c6, c7;
    __m256 t0, t1, u0, u1;

    t0 = M1.n[0];                                                   // t0 = a00, a01, a02, a03, a10, a11, a12, a13
    t1 = M1.n[1];                                                   // t1 = a20, a21, a22, a23, a30, a31, a32, a33
    u0 = M2.n[0];                                                   // u0 = b00, b01, b02, b03, b10, b11, b12, b13
    u1 = M2.n[1];                                                   // u1 = b20, b21, b22, b23, b30, b31, b32, b33

    a0 = _mm256_shuffle_ps(t0, t0, _MM_SHUFFLE(0, 0, 0, 0));        // a0 = a00, a00, a00, a00, a10, a10, a10, a10
    a1 = _mm256_shuffle_ps(t1, t1, _MM_SHUFFLE(0, 0, 0, 0));        // a1 = a20, a20, a20, a20, a30, a30, a30, a30
    b0 = _mm256_permute2f128_ps(u0, u0, 0x00);                      // b0 = b00, b01, b02, b03, b00, b01, b02, b03  
    c0 = _mm256_mul_ps(a0, b0);                                     // c0 = a00*b00  a00*b01  a00*b02  a00*b03  a10*b00  a10*b01  a10*b02  a10*b03
    c1 = _mm256_mul_ps(a1, b0);                                     // c1 = a20*b00  a20*b01  a20*b02  a20*b03  a30*b00  a30*b01  a30*b02  a30*b03

    a0 = _mm256_shuffle_ps(t0, t0, _MM_SHUFFLE(1, 1, 1, 1));        // a0 = a01, a01, a01, a01, a11, a11, a11, a11
    a1 = _mm256_shuffle_ps(t1, t1, _MM_SHUFFLE(1, 1, 1, 1));        // a1 = a21, a21, a21, a21, a31, a31, a31, a31
    b0 = _mm256_permute2f128_ps(u0, u0, 0x11);                      // b0 = b10, b11, b12, b13, b10, b11, b12, b13
    c2 = _mm256_mul_ps(a0, b0);                                     // c2 = a01*b10  a01*b11  a01*b12  a01*b13  a11*b10  a11*b11  a11*b12  a11*b13
    c3 = _mm256_mul_ps(a1, b0);                                     // c3 = a21*b10  a21*b11  a21*b12  a21*b13  a31*b10  a31*b11  a31*b12  a31*b13

    a0 = _mm256_shuffle_ps(t0, t0, _MM_SHUFFLE(2, 2, 2, 2));        // a0 = a02, a02, a02, a02, a12, a12, a12, a12
    a1 = _mm256_shuffle_ps(t1, t1, _MM_SHUFFLE(2, 2, 2, 2));        // a1 = a22, a22, a22, a22, a32, a32, a32, a32
    b1 = _mm256_permute2f128_ps(u1, u1, 0x00);                      // b1 = b20, b21, b22, b23, b20, b21, b22, b23
    c4 = _mm256_mul_ps(a0, b1);                                     // c4 = a02*b20  a02*b21  a02*b22  a02*b23  a12*b20  a12*b21  a12*b22  a12*b23
    c5 = _mm256_mul_ps(a1, b1);                                     // c5 = a22*b20  a22*b21  a22*b22  a22*b23  a32*b20  a32*b21  a32*b22  a32*b23

    a0 = _mm256_shuffle_ps(t0, t0, _MM_SHUFFLE(3, 3, 3, 3));        // a0 = a03, a03, a03, a03, a13, a13, a13, a13
    a1 = _mm256_shuffle_ps(t1, t1, _MM_SHUFFLE(3, 3, 3, 3));        // a1 = a23, a23, a23, a23, a33, a33, a33, a33
    b1 = _mm256_permute2f128_ps(u1, u1, 0x11);                      // b1 = b30, b31, b32, b33, b30, b31, b32, b33
    c6 = _mm256_mul_ps(a0, b1);                                     // c6 = a03*b30  a03*b31  a03*b32  a03*b33  a13*b30  a13*b31  a13*b32  a13*b33
    c7 = _mm256_mul_ps(a1, b1);                                     // c7 = a23*b30  a23*b31  a23*b32  a23*b33  a33*b30  a33*b31  a33*b32  a33*b33

    c0 = _mm256_add_ps(c0, c2);                                     // c0 = c0 + c2 (two terms, first two rows)
    c4 = _mm256_add_ps(c4, c6);                                     // c4 = c4 + c6 (the other two terms, first two rows)
    c1 = _mm256_add_ps(c1, c3);                                     // c1 = c1 + c3 (two terms, second two rows)
    c5 = _mm256_add_ps(c5, c7);                                     // c5 = c5 + c7 (the other two terms, second two rows)

                                                                    // Finally complete addition of all four terms and return the results
    mResult.n[0] = _mm256_add_ps(c0, c4);       // n0 = a00*b00+a01*b10+a02*b20+a03*b30  a00*b01+a01*b11+a02*b21+a03*b31  a00*b02+a01*b12+a02*b22+a03*b32  a00*b03+a01*b13+a02*b23+a03*b33
                                                //      a10*b00+a11*b10+a12*b20+a13*b30  a10*b01+a11*b11+a12*b21+a13*b31  a10*b02+a11*b12+a12*b22+a13*b32  a10*b03+a11*b13+a12*b23+a13*b33
    mResult.n[1] = _mm256_add_ps(c1, c5);       // n1 = a20*b00+a21*b10+a22*b20+a23*b30  a20*b01+a21*b11+a22*b21+a23*b31  a20*b02+a21*b12+a22*b22+a23*b32  a20*b03+a21*b13+a22*b23+a23*b33
                                                //      a30*b00+a31*b10+a32*b20+a33*b30  a30*b01+a31*b11+a32*b21+a33*b31  a30*b02+a31*b12+a32*b22+a33*b32  a30*b03+a31*b13+a32*b23+a33*b33
    return mResult;
}
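
A minimal usage sketch, assuming the struct and function above (compile with AVX enabled, e.g. /arch:AVX on MSVC or -mavx on GCC/Clang; the fill values are only illustrative):

    MATRIX A, B, R;
    for (int i = 0; i < 4; i++)
        for (int j = 0; j < 4; j++) {
            A.f[i][j] = (i == j) ? 1.0f : 0.0f;   // identity matrix
            B.f[i][j] = (float)(4 * i + j);       // arbitrary test values
        }
    R = myMultiply(A, B);                         // R.f should now equal B.f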

Obviously you can fetch terms from four matrices at a time and multiply four matrices simultaneously using the same algorithm.
