Writing a fast linear system solver in OpenCL C

Submitted by 浪尽此生 on 2020-08-27 22:28:23

Question


I'm writing an OpenCL kernel which will involve solving a linear system. Currently my kernel is simply too slow, and improving the performance of the linear system portion seemed like a good place to start.

I should also note that I'm not trying to make my linear solver parallel; the problem I'm working on is already embarrassingly parallel at a macroscopic level.

The following is C code I wrote for solving Ax=b using Gaussian elimination with partial pivoting,

#include <stdio.h>
#include <math.h>
#include <time.h>

#define K 50

// Solve the system Ax=b using Gaussian elimination with partial pivoting.
void linear_solve(float A[K * K], float b[K])
{
    for (long j=0; j<K; j++)
    {
        // Begin partial pivoting.
        float maxval = fabs(A[K * j + j]);

        long maxrow = j;

        for (long i=j+1; i<K; i++)
        {
            if (fabs(A[K * j + i]) > maxval)
            {
                maxval = fabs(A[K * j + i]);
                maxrow = i;
            }
        }
            
        for (long l=0; l<K; l++)
        {
            float A_temp = A[K * l + maxrow];
            A[K * l + maxrow] = A[K * l + j];
            A[K * l + j] = A_temp;
        }

        float b_temp = b[maxrow];
        b[maxrow] = b[j];
        b[j] = b_temp;
        // End partial pivoting.

        // Begin putting [A; b] into row echelon form.
        for (long i=j; i<K-1; i++)
        {
            float c = -A[K * j + (i + 1)] / A[K * j + j];
    
            for (long l=j+1; l<K; l++)
                A[K * l + (i + 1)] += c * A[K * l + j];

            b[i + 1] += c * b[j];
        }
        // End putting [A; b] into row echelon form.
    }

    // Begin backsolving for x (by overwriting b).
    for (long j=K-1; j>0; j--)
        for (long i=j-1; i>=0; i--)
            b[i] -= b[j] * A[K * j + i] / A[K * j + j];

    for (long j=0; j<K; j++)
        b[j] *= 1 / A[K * j + j];
    // End backsolving for x.
}

int main()
{
    int i, j;

    float A[K * K] = {38, 49, 38, 73, 70, 71, 33, 24, 14, 82, 46, 99, 82, 36, 21, 32, 48, 40, 27, 60, 31, 15, 38, 88, 95, 57, 36, 86, 42, 56, 1, 37, 73, 7, 92, 93, 16, 95, 59, 76, 18, 42, 57, 9, 14, 40, 68, 61, 8, 26, 90, 33, 95, 8, 5, 87, 66, 84, 45, 78, 27, 16, 9, 83, 46, 61, 74, 44, 17, 21, 21, 53, 96, 49, 58, 67, 73, 60, 18, 40, 32, 68, 68, 21, 57, 86, 69, 7, 80, 10, 36, 46, 94, 59, 41, 80, 70, 2, 90, 57, 92, 50, 92, 98, 88, 14, 39, 80, 68, 78, 49, 40, 54, 51, 68, 80, 95, 22, 37, 88, 10, 30, 54, 7, 84, 99, 42, 94, 75, 45, 22, 41, 75, 38, 54, 97, 64, 62, 6, 48, 92, 49, 72, 5, 75, 67, 24, 55, 76, 17, 62, 19, 75, 41, 63, 97, 19, 83, 69, 12, 43, 94, 48, 92, 94, 54, 76, 11, 99, 96, 20, 29, 43, 97, 86, 23, 55, 2, 75, 61, 17, 45, 88, 79, 9, 26, 1, 3, 10, 91, 94, 85, 13, 58, 3, 53, 24, 76, 9, 2, 33, 34, 51, 65, 100, 67, 84, 21, 77, 17, 88, 65, 2, 46, 1, 18, 15, 57, 1, 88, 60, 64, 39, 36, 79, 89, 51, 39, 98, 67, 62, 34, 56, 98, 74, 52, 93, 11, 87, 45, 48, 82, 87, 5, 97, 65, 1, 81, 39, 85, 33, 26, 24, 90, 41, 69, 74, 43, 21, 54, 91, 94, 78, 41, 17, 11, 71, 25, 72, 52, 36, 27, 100, 48, 67, 52, 94, 44, 94, 91, 83, 95, 76, 19, 70, 34, 87, 67, 62, 67, 81, 55, 81, 45, 68, 1, 56, 95, 76, 38, 72, 88, 37, 64, 29, 16, 19, 81, 36, 18, 25, 28, 21, 17, 57, 51, 22, 87, 61, 39, 56, 51, 65, 44, 59, 3, 75, 98, 5, 21, 48, 95, 53, 23, 96, 4, 11, 11, 77, 21, 58, 78, 9, 93, 81, 17, 77, 97, 97, 44, 96, 26, 35, 89, 73, 26, 37, 3, 51, 76, 14, 67, 45, 92, 52, 83, 43, 91, 20, 62, 4, 48, 75, 35, 17, 65, 6, 98, 2, 78, 69, 39, 30, 57, 27, 49, 8, 71, 46, 82, 16, 62, 57, 69, 94, 15, 56, 15, 29, 42, 93, 96, 57, 2, 63, 23, 57, 54, 47, 88, 40, 1, 90, 48, 1, 4, 26, 32, 12, 97, 38, 62, 72, 92, 71, 72, 34, 93, 84, 56, 20, 33, 53, 42, 7, 54, 98, 37, 27, 2, 13, 88, 30, 24, 91, 22, 95, 100, 53, 53, 31, 91, 95, 9, 36, 89, 25, 60, 28, 47, 61, 81, 41, 47, 88, 6, 46, 83, 4, 48, 73, 88, 8, 83, 78, 18, 21, 75, 6, 90, 87, 92, 18, 71, 5, 82, 36, 2, 50, 86, 49, 72, 92, 67, 41, 38, 81, 37, 67, 93, 99, 51, 79, 95, 76, 
85, 90, 27, 93, 44, 79, 97, 7, 11, 52, 76, 61, 23, 52, 97, 58, 74, 87, 58, 70, 77, 97, 74, 85, 65, 71, 79, 91, 36, 92, 35, 97, 9, 6, 38, 90, 46, 84, 98, 65, 4, 89, 9, 72, 55, 3, 21, 77, 43, 76, 83, 34, 16, 33, 21, 6, 28, 98, 27, 86, 93, 66, 55, 34, 76, 93, 42, 1, 36, 82, 82, 13, 45, 48, 8, 4, 66, 51, 32, 68, 81, 49, 70, 93, 73, 89, 16, 76, 95, 90, 37, 83, 28, 40, 14, 3, 18, 27, 34, 24, 53, 42, 24, 57, 93, 48, 43, 91, 28, 75, 86, 47, 40, 61, 20, 34, 81, 31, 62, 20, 75, 80, 81, 95, 75, 14, 8, 89, 13, 7, 9, 27, 80, 24, 52, 27, 75, 4, 58, 20, 82, 89, 31, 100, 48, 57, 73, 34, 52, 24, 26, 64, 18, 90, 74, 17, 58, 8, 44, 43, 56, 56, 51, 58, 56, 4, 87, 80, 24, 100, 47, 72, 60, 41, 2, 26, 81, 17, 57, 28, 6, 21, 4, 99, 92, 42, 37, 22, 45, 5, 93, 72, 27, 91, 13, 44, 93, 6, 100, 31, 17, 78, 16, 96, 32, 57, 45, 95, 76, 92, 3, 77, 84, 92, 87, 63, 42, 70, 79, 77, 90, 16, 100, 82, 61, 23, 67, 55, 45, 38, 27, 95, 19, 10, 4, 53, 75, 62, 1, 99, 62, 94, 30, 95, 65, 35, 62, 25, 59, 26, 62, 98, 50, 73, 31, 11, 89, 20, 1, 74, 45, 49, 55, 78, 49, 82, 35, 9, 45, 100, 99, 87, 10, 56, 79, 85, 89, 8, 9, 53, 87, 13, 27, 95, 81, 7, 71, 63, 44, 38, 84, 40, 87, 79, 54, 42, 58, 49, 85, 49, 6, 55, 83, 93, 52, 63, 76, 52, 40, 91, 36, 74, 70, 92, 92, 67, 57, 51, 74, 22, 35, 22, 48, 60, 86, 87, 79, 18, 65, 1, 36, 65, 91, 24, 33, 71, 52, 43, 20, 100, 94, 68, 19, 93, 66, 89, 45, 39, 97, 57, 67, 51, 92, 20, 97, 45, 32, 10, 82, 86, 2, 8, 27, 15, 60, 7, 6, 90, 71, 40, 91, 10, 16, 39, 40, 32, 2, 11, 5, 81, 31, 72, 41, 7, 89, 89, 85, 28, 67, 54, 44, 47, 26, 44, 51, 50, 65, 41, 68, 17, 88, 45, 43, 8, 11, 79, 10, 99, 58, 42, 75, 75, 86, 73, 24, 33, 15, 46, 84, 33, 27, 96, 14, 25, 11, 67, 48, 51, 85, 61, 87, 71, 85, 62, 32, 71, 15, 56, 6, 20, 43, 64, 97, 81, 94, 94, 61, 39, 46, 99, 37, 66, 40, 17, 74, 44, 6, 2, 11, 53, 44, 75, 29, 58, 77, 66, 96, 82, 13, 32, 43, 13, 36, 10, 39, 54, 39, 79, 22, 4, 41, 19, 44, 37, 73, 76, 84, 78, 94, 13, 98, 26, 56, 55, 51, 38, 37, 60, 55, 92, 19, 53, 48, 4, 7, 85, 82, 8, 60, 34, 
67, 98, 76, 38, 14, 20, 62, 41, 58, 29, 70, 71, 16, 60, 26, 8, 64, 92, 17, 26, 40, 12, 59, 69, 97, 63, 52, 81, 27, 10, 99, 73, 74, 68, 8, 44, 70, 38, 65, 3, 27, 80, 90, 8, 64, 98, 89, 10, 45, 42, 55, 61, 49, 45, 82, 48, 27, 22, 16, 50, 58, 41, 92, 64, 54, 35, 65, 23, 66, 22, 9, 68, 79, 45, 69, 71, 94, 24, 41, 55, 48, 84, 12, 80, 71, 41, 91, 77, 83, 2, 12, 55, 21, 100, 99, 65, 20, 77, 37, 29, 75, 6, 59, 84, 25, 70, 40, 31, 73, 26, 61, 77, 16, 73, 41, 5, 83, 51, 9, 60, 97, 44, 21, 21, 87, 20, 74, 91, 43, 10, 69, 67, 14, 30, 71, 31, 20, 21, 98, 58, 21, 51, 83, 20, 69, 70, 13, 8, 62, 66, 28, 46, 75, 66, 65, 21, 32, 83, 7, 62, 4, 46, 98, 89, 20, 11, 57, 93, 72, 14, 80, 57, 10, 53, 67, 52, 88, 21, 97, 67, 42, 14, 86, 5, 12, 44, 35, 82, 3, 69, 87, 32, 10, 15, 54, 40, 60, 11, 46, 23, 77, 97, 46, 61, 90, 74, 82, 50, 15, 73, 59, 83, 68, 52, 54, 54, 89, 99, 44, 7, 85, 29, 65, 87, 20, 57, 5, 45, 98, 36, 98, 36, 99, 3, 54, 78, 100, 91, 73, 77, 63, 30, 11, 31, 21, 12, 78, 66, 36, 6, 50, 27, 55, 97, 79, 85, 29, 91, 72, 64, 18, 78, 77, 93, 74, 76, 33, 68, 71, 48, 10, 4, 19, 32, 53, 87, 75, 11, 25, 71, 23, 55, 16, 74, 28, 66, 90, 49, 75, 95, 19, 50, 75, 49, 52, 28, 57, 90, 20, 77, 52, 9, 42, 4, 20, 49, 78, 99, 78, 38, 100, 90, 7, 12, 8, 35, 26, 49, 54, 78, 43, 86, 23, 55, 11, 79, 20, 56, 61, 26, 81, 42, 93, 4, 3, 84, 3, 55, 46, 27, 67, 74, 28, 100, 44, 5, 14, 65, 22, 71, 13, 61, 65, 53, 14, 44, 53, 67, 69, 2, 76, 76, 90, 63, 21, 46, 46, 96, 19, 40, 12, 22, 45, 98, 6, 81, 7, 70, 51, 16, 62, 66, 33, 21, 69, 34, 24, 92, 23, 14, 51, 84, 36, 73, 83, 45, 52, 93, 20, 21, 61, 58, 75, 85, 36, 92, 29, 26, 100, 86, 79, 46, 43, 95, 9, 8, 98, 29, 27, 70, 93, 60, 20, 14, 10, 77, 71, 12, 38, 91, 59, 57, 84, 77, 15, 81, 17, 10, 42, 89, 4, 72, 16, 85, 27, 80, 85, 85, 9, 94, 3, 59, 30, 43, 30, 87, 20, 19, 33, 92, 8, 52, 46, 67, 26, 76, 3, 21, 71, 10, 37, 49, 61, 15, 70, 57, 66, 55, 52, 87, 36, 18, 30, 69, 28, 68, 26, 82, 86, 87, 16, 15, 46, 92, 54, 100, 92, 89, 52, 97, 53, 21, 31, 51, 31, 17, 46, 
68, 53, 93, 64, 87, 43, 39, 94, 2, 38, 30, 87, 35, 53, 97, 28, 54, 58, 42, 55, 23, 27, 2, 27, 4, 78, 31, 14, 87, 21, 75, 26, 28, 67, 56, 65, 80, 10, 21, 48, 71, 52, 24, 67, 38, 62, 68, 93, 17, 56, 85, 87, 75, 62, 68, 45, 88, 49, 97, 78, 14, 94, 3, 67, 86, 9, 24, 92, 2, 12, 89, 73, 94, 63, 89, 65, 92, 61, 100, 90, 44, 57, 17, 74, 59, 5, 63, 5, 73, 46, 76, 69, 12, 97, 91, 9, 6, 61, 37, 5, 20, 39, 32, 19, 14, 46, 2, 46, 41, 28, 39, 29, 41, 59, 25, 97, 94, 63, 31, 64, 63, 72, 41, 46, 58, 79, 79, 35, 49, 42, 43, 82, 32, 41, 37, 84, 96, 100, 33, 87, 38, 89, 97, 25, 56, 61, 4, 100, 9, 83, 66, 77, 65, 22, 81, 52, 27, 6, 79, 29, 34, 15, 64, 22, 80, 61, 10, 74, 1, 68, 80, 74, 86, 98, 9, 24, 76, 57, 23, 5, 50, 7, 11, 80, 39, 10, 75, 38, 73, 8, 47, 3, 92, 90, 51, 42, 22, 45, 63, 27, 62, 78, 38, 5, 46, 46, 80, 51, 6, 43, 43, 7, 13, 50, 10, 64, 4, 67, 94, 69, 58, 58, 77, 71, 42, 80, 35, 15, 34, 65, 23, 43, 21, 24, 69, 24, 37, 68, 11, 38, 18, 12, 37, 41, 81, 12, 3, 91, 44, 98, 5, 1, 90, 53, 100, 90, 26, 36, 23, 14, 76, 23, 70, 58, 7, 35, 42, 11, 19, 48, 11, 24, 61, 49, 52, 69, 68, 82, 11, 57, 87, 65, 68, 54, 69, 39, 99, 1, 86, 44, 35, 36, 58, 73, 17, 14, 14, 87, 20, 57, 11, 65, 98, 77, 10, 51, 45, 50, 28, 56, 23, 64, 6, 11, 15, 93, 32, 77, 45, 57, 84, 49, 66, 98, 71, 8, 35, 62, 23, 82, 30, 75, 41, 15, 52, 22, 93, 68, 12, 83, 76, 19, 93, 67, 19, 35, 76, 49, 95, 40, 21, 78, 76, 86, 26, 31, 85, 15, 29, 82, 68, 54, 29, 70, 79, 93, 35, 2, 60, 78, 74, 32, 77, 94, 21, 21, 87, 48, 58, 76, 5, 87, 41, 6, 74, 83, 2, 56, 8, 2, 81, 3, 59, 7, 49, 62, 72, 98, 81, 68, 6, 82, 20, 97, 71, 16, 10, 58, 37, 98, 49, 23, 61, 80, 15, 77, 26, 56, 99, 21, 19, 60, 80, 61, 31, 6, 59, 70, 7, 87, 41, 9, 2, 34, 43, 84, 12, 24, 67, 63, 40, 78, 3, 100, 22, 100, 61, 59, 92, 26, 9, 39, 56, 93, 74, 47, 21, 71, 67, 81, 40, 74, 56, 34, 35, 82, 94, 35, 35, 15, 52, 44, 5, 83, 30, 10, 18, 65, 31, 45, 49, 100, 41, 26, 51, 3, 86, 17, 62, 13, 92, 58, 76, 53, 34, 81, 98, 57, 99, 81, 67, 23, 25, 99, 88, 62, 99, 37, 85, 17, 
60, 23, 56, 97, 65, 41, 91, 16, 90, 47, 86, 56, 99, 44, 28, 18, 89, 27, 43, 43, 14, 64, 96, 8, 92, 74, 65, 24, 26, 96, 92, 19, 57, 24, 25, 3, 80, 99, 89, 78, 78, 80, 89, 27, 6, 49, 78, 81, 75, 99, 21, 64, 51, 98, 32, 53, 59, 74, 33, 1, 93, 9, 1, 24, 15, 8, 55, 76, 51, 98, 41, 77, 48, 81, 47, 76, 47, 65, 25, 2, 80, 67, 9, 85, 18, 73, 35, 50, 69, 46, 33, 14, 47, 25, 93, 28, 39, 12, 87, 85, 81, 16, 51, 91, 93, 32, 60, 55, 43, 54, 32, 57, 4, 30, 20, 15, 96, 64, 3, 99, 41, 5, 78, 28, 52, 39, 45, 41, 54, 1, 13, 53, 84, 75, 24, 100, 44, 8, 18, 46, 42, 86, 65, 27, 74, 1, 75, 99, 90, 33, 31, 4, 22, 17, 30, 44, 36, 72, 47, 75, 100, 47, 85, 86, 59, 37, 32, 30, 67, 98, 94, 85, 93, 1, 81, 60, 33, 97, 88, 73, 68, 8, 35, 30, 83, 19, 99, 74, 21, 93, 42, 80, 95, 27, 65, 24, 73, 31, 43, 92, 81, 24, 70, 67, 78, 48, 47, 70, 76, 12, 79, 89, 7, 28, 83, 78, 22, 25, 32, 17, 4, 68, 42, 15, 1, 3, 18, 43, 75, 48, 84, 17, 60, 100, 73, 59, 80, 68, 13, 89, 7, 93, 16, 22, 1, 58, 92, 87, 90, 23, 95, 76, 67, 10, 14, 70, 17, 99, 77, 6, 63, 69, 2, 93, 27, 29, 88, 39, 35, 25, 50, 91, 13, 16, 91, 50, 53, 54, 12, 53, 25, 11, 6, 10, 44, 36, 87, 67, 69, 5, 5, 78, 25, 19, 24, 50, 88, 62, 24, 89, 39, 86, 6, 7, 70, 56, 92, 18, 76, 57, 50, 28, 71, 50, 74, 19, 89, 49, 8, 76, 92, 80, 41, 34, 33, 63, 88, 31, 95, 97, 71, 52, 36, 26, 99, 72, 50, 76, 33, 62, 79, 11, 76, 54, 64, 42, 76, 5, 45, 79, 61, 39, 66, 72, 74, 76, 25, 63, 35, 100, 42, 61, 12, 9, 41, 95, 90, 48, 24, 8, 66, 65, 29, 74, 97, 54, 51, 31, 31, 51, 30, 63, 32, 70, 79, 49, 7, 35, 53, 76, 83, 62, 20, 13, 92, 95, 40, 99, 10, 98, 13, 7, 88, 16, 40, 10, 22, 29, 88, 64, 39, 13, 26, 12, 27, 69, 70, 23, 41, 67, 50, 96, 24, 97, 29, 31, 42, 27, 90, 50, 69, 42, 92, 22, 88, 23, 35, 83, 82, 74, 50, 72, 98, 94, 94, 46, 82, 16, 35, 88, 46, 89, 77, 86, 19, 17, 20, 5, 13, 25, 69, 79, 90, 55, 88, 71, 13, 30};
    float b[K] = {66, 97, 50, 69, 24, 42, 23, 82, 25, 79, 66, 26, 76, 25, 75, 25, 43, 40, 55, 8, 20, 53, 66, 94, 57, 10, 39, 70, 5, 57, 22, 36, 45, 94, 24, 44, 89, 41, 14, 87, 9, 46, 74, 23, 72, 62, 52, 74, 36, 13};

    clock_t begin = clock();

    linear_solve(A, b);

    clock_t end = clock();
    double time_spent = (double)(end - begin) / CLOCKS_PER_SEC;

    printf("seconds: %f\n", time_spent);
    printf("Result vector is: ");
    for (i=0; i<K; i++)
    {
        printf("%f,", b[i]);
    }
    printf("\n");

    return 0;
}

The following is Julia code for solving Ax=b, by calling LAPACK (LU-decomp followed by upper/lower triangular solver),

using BenchmarkTools

A = reshape(Float64[38, 49, 38, 73, 70, 71, 33, 24, 14, 82, 46, 99, 82, 36, 21, 32, 48, 40, 27, 60, 31, 15, 38, 88, 95, 57, 36, 86, 42, 56, 1, 37, 73, 7, 92, 93, 16, 95, 59, 76, 18, 42, 57, 9, 14, 40, 68, 61, 8, 26, 90, 33, 95, 8, 5, 87, 66, 84, 45, 78, 27, 16, 9, 83, 46, 61, 74, 44, 17, 21, 21, 53, 96, 49, 58, 67, 73, 60, 18, 40, 32, 68, 68, 21, 57, 86, 69, 7, 80, 10, 36, 46, 94, 59, 41, 80, 70, 2, 90, 57, 92, 50, 92, 98, 88, 14, 39, 80, 68, 78, 49, 40, 54, 51, 68, 80, 95, 22, 37, 88, 10, 30, 54, 7, 84, 99, 42, 94, 75, 45, 22, 41, 75, 38, 54, 97, 64, 62, 6, 48, 92, 49, 72, 5, 75, 67, 24, 55, 76, 17, 62, 19, 75, 41, 63, 97, 19, 83, 69, 12, 43, 94, 48, 92, 94, 54, 76, 11, 99, 96, 20, 29, 43, 97, 86, 23, 55, 2, 75, 61, 17, 45, 88, 79, 9, 26, 1, 3, 10, 91, 94, 85, 13, 58, 3, 53, 24, 76, 9, 2, 33, 34, 51, 65, 100, 67, 84, 21, 77, 17, 88, 65, 2, 46, 1, 18, 15, 57, 1, 88, 60, 64, 39, 36, 79, 89, 51, 39, 98, 67, 62, 34, 56, 98, 74, 52, 93, 11, 87, 45, 48, 82, 87, 5, 97, 65, 1, 81, 39, 85, 33, 26, 24, 90, 41, 69, 74, 43, 21, 54, 91, 94, 78, 41, 17, 11, 71, 25, 72, 52, 36, 27, 100, 48, 67, 52, 94, 44, 94, 91, 83, 95, 76, 19, 70, 34, 87, 67, 62, 67, 81, 55, 81, 45, 68, 1, 56, 95, 76, 38, 72, 88, 37, 64, 29, 16, 19, 81, 36, 18, 25, 28, 21, 17, 57, 51, 22, 87, 61, 39, 56, 51, 65, 44, 59, 3, 75, 98, 5, 21, 48, 95, 53, 23, 96, 4, 11, 11, 77, 21, 58, 78, 9, 93, 81, 17, 77, 97, 97, 44, 96, 26, 35, 89, 73, 26, 37, 3, 51, 76, 14, 67, 45, 92, 52, 83, 43, 91, 20, 62, 4, 48, 75, 35, 17, 65, 6, 98, 2, 78, 69, 39, 30, 57, 27, 49, 8, 71, 46, 82, 16, 62, 57, 69, 94, 15, 56, 15, 29, 42, 93, 96, 57, 2, 63, 23, 57, 54, 47, 88, 40, 1, 90, 48, 1, 4, 26, 32, 12, 97, 38, 62, 72, 92, 71, 72, 34, 93, 84, 56, 20, 33, 53, 42, 7, 54, 98, 37, 27, 2, 13, 88, 30, 24, 91, 22, 95, 100, 53, 53, 31, 91, 95, 9, 36, 89, 25, 60, 28, 47, 61, 81, 41, 47, 88, 6, 46, 83, 4, 48, 73, 88, 8, 83, 78, 18, 21, 75, 6, 90, 87, 92, 18, 71, 5, 82, 36, 2, 50, 86, 49, 72, 92, 67, 41, 38, 81, 37, 67, 93, 99, 51, 79, 95, 76, 85, 
90, 27, 93, 44, 79, 97, 7, 11, 52, 76, 61, 23, 52, 97, 58, 74, 87, 58, 70, 77, 97, 74, 85, 65, 71, 79, 91, 36, 92, 35, 97, 9, 6, 38, 90, 46, 84, 98, 65, 4, 89, 9, 72, 55, 3, 21, 77, 43, 76, 83, 34, 16, 33, 21, 6, 28, 98, 27, 86, 93, 66, 55, 34, 76, 93, 42, 1, 36, 82, 82, 13, 45, 48, 8, 4, 66, 51, 32, 68, 81, 49, 70, 93, 73, 89, 16, 76, 95, 90, 37, 83, 28, 40, 14, 3, 18, 27, 34, 24, 53, 42, 24, 57, 93, 48, 43, 91, 28, 75, 86, 47, 40, 61, 20, 34, 81, 31, 62, 20, 75, 80, 81, 95, 75, 14, 8, 89, 13, 7, 9, 27, 80, 24, 52, 27, 75, 4, 58, 20, 82, 89, 31, 100, 48, 57, 73, 34, 52, 24, 26, 64, 18, 90, 74, 17, 58, 8, 44, 43, 56, 56, 51, 58, 56, 4, 87, 80, 24, 100, 47, 72, 60, 41, 2, 26, 81, 17, 57, 28, 6, 21, 4, 99, 92, 42, 37, 22, 45, 5, 93, 72, 27, 91, 13, 44, 93, 6, 100, 31, 17, 78, 16, 96, 32, 57, 45, 95, 76, 92, 3, 77, 84, 92, 87, 63, 42, 70, 79, 77, 90, 16, 100, 82, 61, 23, 67, 55, 45, 38, 27, 95, 19, 10, 4, 53, 75, 62, 1, 99, 62, 94, 30, 95, 65, 35, 62, 25, 59, 26, 62, 98, 50, 73, 31, 11, 89, 20, 1, 74, 45, 49, 55, 78, 49, 82, 35, 9, 45, 100, 99, 87, 10, 56, 79, 85, 89, 8, 9, 53, 87, 13, 27, 95, 81, 7, 71, 63, 44, 38, 84, 40, 87, 79, 54, 42, 58, 49, 85, 49, 6, 55, 83, 93, 52, 63, 76, 52, 40, 91, 36, 74, 70, 92, 92, 67, 57, 51, 74, 22, 35, 22, 48, 60, 86, 87, 79, 18, 65, 1, 36, 65, 91, 24, 33, 71, 52, 43, 20, 100, 94, 68, 19, 93, 66, 89, 45, 39, 97, 57, 67, 51, 92, 20, 97, 45, 32, 10, 82, 86, 2, 8, 27, 15, 60, 7, 6, 90, 71, 40, 91, 10, 16, 39, 40, 32, 2, 11, 5, 81, 31, 72, 41, 7, 89, 89, 85, 28, 67, 54, 44, 47, 26, 44, 51, 50, 65, 41, 68, 17, 88, 45, 43, 8, 11, 79, 10, 99, 58, 42, 75, 75, 86, 73, 24, 33, 15, 46, 84, 33, 27, 96, 14, 25, 11, 67, 48, 51, 85, 61, 87, 71, 85, 62, 32, 71, 15, 56, 6, 20, 43, 64, 97, 81, 94, 94, 61, 39, 46, 99, 37, 66, 40, 17, 74, 44, 6, 2, 11, 53, 44, 75, 29, 58, 77, 66, 96, 82, 13, 32, 43, 13, 36, 10, 39, 54, 39, 79, 22, 4, 41, 19, 44, 37, 73, 76, 84, 78, 94, 13, 98, 26, 56, 55, 51, 38, 37, 60, 55, 92, 19, 53, 48, 4, 7, 85, 82, 8, 60, 34, 67, 
98, 76, 38, 14, 20, 62, 41, 58, 29, 70, 71, 16, 60, 26, 8, 64, 92, 17, 26, 40, 12, 59, 69, 97, 63, 52, 81, 27, 10, 99, 73, 74, 68, 8, 44, 70, 38, 65, 3, 27, 80, 90, 8, 64, 98, 89, 10, 45, 42, 55, 61, 49, 45, 82, 48, 27, 22, 16, 50, 58, 41, 92, 64, 54, 35, 65, 23, 66, 22, 9, 68, 79, 45, 69, 71, 94, 24, 41, 55, 48, 84, 12, 80, 71, 41, 91, 77, 83, 2, 12, 55, 21, 100, 99, 65, 20, 77, 37, 29, 75, 6, 59, 84, 25, 70, 40, 31, 73, 26, 61, 77, 16, 73, 41, 5, 83, 51, 9, 60, 97, 44, 21, 21, 87, 20, 74, 91, 43, 10, 69, 67, 14, 30, 71, 31, 20, 21, 98, 58, 21, 51, 83, 20, 69, 70, 13, 8, 62, 66, 28, 46, 75, 66, 65, 21, 32, 83, 7, 62, 4, 46, 98, 89, 20, 11, 57, 93, 72, 14, 80, 57, 10, 53, 67, 52, 88, 21, 97, 67, 42, 14, 86, 5, 12, 44, 35, 82, 3, 69, 87, 32, 10, 15, 54, 40, 60, 11, 46, 23, 77, 97, 46, 61, 90, 74, 82, 50, 15, 73, 59, 83, 68, 52, 54, 54, 89, 99, 44, 7, 85, 29, 65, 87, 20, 57, 5, 45, 98, 36, 98, 36, 99, 3, 54, 78, 100, 91, 73, 77, 63, 30, 11, 31, 21, 12, 78, 66, 36, 6, 50, 27, 55, 97, 79, 85, 29, 91, 72, 64, 18, 78, 77, 93, 74, 76, 33, 68, 71, 48, 10, 4, 19, 32, 53, 87, 75, 11, 25, 71, 23, 55, 16, 74, 28, 66, 90, 49, 75, 95, 19, 50, 75, 49, 52, 28, 57, 90, 20, 77, 52, 9, 42, 4, 20, 49, 78, 99, 78, 38, 100, 90, 7, 12, 8, 35, 26, 49, 54, 78, 43, 86, 23, 55, 11, 79, 20, 56, 61, 26, 81, 42, 93, 4, 3, 84, 3, 55, 46, 27, 67, 74, 28, 100, 44, 5, 14, 65, 22, 71, 13, 61, 65, 53, 14, 44, 53, 67, 69, 2, 76, 76, 90, 63, 21, 46, 46, 96, 19, 40, 12, 22, 45, 98, 6, 81, 7, 70, 51, 16, 62, 66, 33, 21, 69, 34, 24, 92, 23, 14, 51, 84, 36, 73, 83, 45, 52, 93, 20, 21, 61, 58, 75, 85, 36, 92, 29, 26, 100, 86, 79, 46, 43, 95, 9, 8, 98, 29, 27, 70, 93, 60, 20, 14, 10, 77, 71, 12, 38, 91, 59, 57, 84, 77, 15, 81, 17, 10, 42, 89, 4, 72, 16, 85, 27, 80, 85, 85, 9, 94, 3, 59, 30, 43, 30, 87, 20, 19, 33, 92, 8, 52, 46, 67, 26, 76, 3, 21, 71, 10, 37, 49, 61, 15, 70, 57, 66, 55, 52, 87, 36, 18, 30, 69, 28, 68, 26, 82, 86, 87, 16, 15, 46, 92, 54, 100, 92, 89, 52, 97, 53, 21, 31, 51, 31, 17, 46, 68, 
53, 93, 64, 87, 43, 39, 94, 2, 38, 30, 87, 35, 53, 97, 28, 54, 58, 42, 55, 23, 27, 2, 27, 4, 78, 31, 14, 87, 21, 75, 26, 28, 67, 56, 65, 80, 10, 21, 48, 71, 52, 24, 67, 38, 62, 68, 93, 17, 56, 85, 87, 75, 62, 68, 45, 88, 49, 97, 78, 14, 94, 3, 67, 86, 9, 24, 92, 2, 12, 89, 73, 94, 63, 89, 65, 92, 61, 100, 90, 44, 57, 17, 74, 59, 5, 63, 5, 73, 46, 76, 69, 12, 97, 91, 9, 6, 61, 37, 5, 20, 39, 32, 19, 14, 46, 2, 46, 41, 28, 39, 29, 41, 59, 25, 97, 94, 63, 31, 64, 63, 72, 41, 46, 58, 79, 79, 35, 49, 42, 43, 82, 32, 41, 37, 84, 96, 100, 33, 87, 38, 89, 97, 25, 56, 61, 4, 100, 9, 83, 66, 77, 65, 22, 81, 52, 27, 6, 79, 29, 34, 15, 64, 22, 80, 61, 10, 74, 1, 68, 80, 74, 86, 98, 9, 24, 76, 57, 23, 5, 50, 7, 11, 80, 39, 10, 75, 38, 73, 8, 47, 3, 92, 90, 51, 42, 22, 45, 63, 27, 62, 78, 38, 5, 46, 46, 80, 51, 6, 43, 43, 7, 13, 50, 10, 64, 4, 67, 94, 69, 58, 58, 77, 71, 42, 80, 35, 15, 34, 65, 23, 43, 21, 24, 69, 24, 37, 68, 11, 38, 18, 12, 37, 41, 81, 12, 3, 91, 44, 98, 5, 1, 90, 53, 100, 90, 26, 36, 23, 14, 76, 23, 70, 58, 7, 35, 42, 11, 19, 48, 11, 24, 61, 49, 52, 69, 68, 82, 11, 57, 87, 65, 68, 54, 69, 39, 99, 1, 86, 44, 35, 36, 58, 73, 17, 14, 14, 87, 20, 57, 11, 65, 98, 77, 10, 51, 45, 50, 28, 56, 23, 64, 6, 11, 15, 93, 32, 77, 45, 57, 84, 49, 66, 98, 71, 8, 35, 62, 23, 82, 30, 75, 41, 15, 52, 22, 93, 68, 12, 83, 76, 19, 93, 67, 19, 35, 76, 49, 95, 40, 21, 78, 76, 86, 26, 31, 85, 15, 29, 82, 68, 54, 29, 70, 79, 93, 35, 2, 60, 78, 74, 32, 77, 94, 21, 21, 87, 48, 58, 76, 5, 87, 41, 6, 74, 83, 2, 56, 8, 2, 81, 3, 59, 7, 49, 62, 72, 98, 81, 68, 6, 82, 20, 97, 71, 16, 10, 58, 37, 98, 49, 23, 61, 80, 15, 77, 26, 56, 99, 21, 19, 60, 80, 61, 31, 6, 59, 70, 7, 87, 41, 9, 2, 34, 43, 84, 12, 24, 67, 63, 40, 78, 3, 100, 22, 100, 61, 59, 92, 26, 9, 39, 56, 93, 74, 47, 21, 71, 67, 81, 40, 74, 56, 34, 35, 82, 94, 35, 35, 15, 52, 44, 5, 83, 30, 10, 18, 65, 31, 45, 49, 100, 41, 26, 51, 3, 86, 17, 62, 13, 92, 58, 76, 53, 34, 81, 98, 57, 99, 81, 67, 23, 25, 99, 88, 62, 99, 37, 85, 17, 60, 
23, 56, 97, 65, 41, 91, 16, 90, 47, 86, 56, 99, 44, 28, 18, 89, 27, 43, 43, 14, 64, 96, 8, 92, 74, 65, 24, 26, 96, 92, 19, 57, 24, 25, 3, 80, 99, 89, 78, 78, 80, 89, 27, 6, 49, 78, 81, 75, 99, 21, 64, 51, 98, 32, 53, 59, 74, 33, 1, 93, 9, 1, 24, 15, 8, 55, 76, 51, 98, 41, 77, 48, 81, 47, 76, 47, 65, 25, 2, 80, 67, 9, 85, 18, 73, 35, 50, 69, 46, 33, 14, 47, 25, 93, 28, 39, 12, 87, 85, 81, 16, 51, 91, 93, 32, 60, 55, 43, 54, 32, 57, 4, 30, 20, 15, 96, 64, 3, 99, 41, 5, 78, 28, 52, 39, 45, 41, 54, 1, 13, 53, 84, 75, 24, 100, 44, 8, 18, 46, 42, 86, 65, 27, 74, 1, 75, 99, 90, 33, 31, 4, 22, 17, 30, 44, 36, 72, 47, 75, 100, 47, 85, 86, 59, 37, 32, 30, 67, 98, 94, 85, 93, 1, 81, 60, 33, 97, 88, 73, 68, 8, 35, 30, 83, 19, 99, 74, 21, 93, 42, 80, 95, 27, 65, 24, 73, 31, 43, 92, 81, 24, 70, 67, 78, 48, 47, 70, 76, 12, 79, 89, 7, 28, 83, 78, 22, 25, 32, 17, 4, 68, 42, 15, 1, 3, 18, 43, 75, 48, 84, 17, 60, 100, 73, 59, 80, 68, 13, 89, 7, 93, 16, 22, 1, 58, 92, 87, 90, 23, 95, 76, 67, 10, 14, 70, 17, 99, 77, 6, 63, 69, 2, 93, 27, 29, 88, 39, 35, 25, 50, 91, 13, 16, 91, 50, 53, 54, 12, 53, 25, 11, 6, 10, 44, 36, 87, 67, 69, 5, 5, 78, 25, 19, 24, 50, 88, 62, 24, 89, 39, 86, 6, 7, 70, 56, 92, 18, 76, 57, 50, 28, 71, 50, 74, 19, 89, 49, 8, 76, 92, 80, 41, 34, 33, 63, 88, 31, 95, 97, 71, 52, 36, 26, 99, 72, 50, 76, 33, 62, 79, 11, 76, 54, 64, 42, 76, 5, 45, 79, 61, 39, 66, 72, 74, 76, 25, 63, 35, 100, 42, 61, 12, 9, 41, 95, 90, 48, 24, 8, 66, 65, 29, 74, 97, 54, 51, 31, 31, 51, 30, 63, 32, 70, 79, 49, 7, 35, 53, 76, 83, 62, 20, 13, 92, 95, 40, 99, 10, 98, 13, 7, 88, 16, 40, 10, 22, 29, 88, 64, 39, 13, 26, 12, 27, 69, 70, 23, 41, 67, 50, 96, 24, 97, 29, 31, 42, 27, 90, 50, 69, 42, 92, 22, 88, 23, 35, 83, 82, 74, 50, 72, 98, 94, 94, 46, 82, 16, 35, 88, 46, 89, 77, 86, 19, 17, 20, 5, 13, 25, 69, 79, 90, 55, 88, 71, 13, 30], (50,50))

b = Float64[66, 97, 50, 69, 24, 42, 23, 82, 25, 79, 66, 26, 76, 25, 75, 25, 43, 40, 55, 8, 20, 53, 66, 94, 57, 10, 39, 70, 5, 57, 22, 36, 45, 94, 24, 44, 89, 41, 14, 87, 9, 46, 74, 23, 72, 62, 52, 74, 36, 13]

linear_solve(A, b) = A \ b

@benchmark linear_solve(A, b)

The C code runs in approximately 166 microseconds, while the LAPACK (via Julia) code runs in an average of 33 microseconds (5 times faster!).

I suppose this is a testament to the quality of LAPACK and the associated Julia wrapper.

Unfortunately, since this C code is to be part of an OpenCL kernel, I can't take advantage of either. Is there a way to make my C code more performant, so that it achieves performance closer to that of LAPACK?


Answer 1:


TL;DR: The current C code is inefficient on modern hardware. Moreover, using OpenCL on dedicated GPUs, or CUDA, will only be fast for quite big matrices (i.e. not 50x50 ones).

The biggest problem in the C code comes from the line A[K * l + (i + 1)] += c * A[K * l + j];. Because the loop iterator is l, consecutive iterations touch addresses K floats apart, so the memory access pattern is strided rather than contiguous. Strided memory access is much less efficient than contiguous access on modern hardware architectures (because of code vectorization, cache lines, memory prefetching, etc.). This is especially true on GPUs. You can fix this problem by transposing the A matrix. Here is the modified version:

// Naive (inefficient) transposition
// Please use the much faster BLAS function to do this (if possible)
void transpose(float A[K * K])
{
    for (long j=0; j<K; ++j)
    {
        for (long i=j+1; i<K; ++i)
        {
            float tmp = A[K * i + j];
            A[K * i + j] = A[K * j + i];
            A[K * j + i] = tmp;
        }
    }
}

// Solve the system Ax=b using Gaussian elimination with partial pivoting.
// Working directly on the transposed version of A, rather than transposing A on every call, should be much faster (especially for small matrices).
void fast_linear_solve(float A[K * K], float b[K])
{
    // Not useful if A is already transposed
    transpose(A);

    for (long j=0; j<K; j++)
    {
        // Begin partial pivoting.
        float maxval = fabs(A[K * j + j]);

        long maxrow = j;

        for (long i=j+1; i<K; i++)
        {
            if (fabs(A[K * i + j]) > maxval)
            {
                maxval = fabs(A[K * i + j]);
                maxrow = i;
            }
        }
            
        for (long l=0; l<K; l++)
        {
            float A_temp = A[K * maxrow + l];
            A[K * maxrow + l] = A[K * j + l];
            A[K * j + l] = A_temp;
        }

        float b_temp = b[maxrow];
        b[maxrow] = b[j];
        b[j] = b_temp;
        // End partial pivoting.

        // Begin putting [A; b] into row echelon form.
        for (long i=j; i<K-1; i++)
        {
            float c = -A[K * (i + 1) + j] / A[K * j + j];

            for (long l=j+1; l<K; l++)
                A[K * (i + 1) + l] += c * A[K * j + l];

            b[i + 1] += c * b[j];
        }
        // End putting [A; b] into row echelon form.
    }

    // Begin backsolving for x (by overwriting b).
    for (long j=K-1; j>0; j--)
        for (long i=j-1; i>=0; i--)
            b[i] -= b[j] * A[K * i + j] / A[K * j + j];

    for (long j=0; j<K; j++)
        b[j] *= 1 / A[K * j + j];
    // End backsolving for x.

    // Not useful if A is already transposed
    transpose(A);
}

Another problem is the way the benchmark is performed. Julia's @benchmark runs the code many times, while the C code is executed once and timed with the clock function. For a fairer comparison with the Julia implementation, the linear_solve function of the C implementation must be evaluated multiple times (by putting it in a loop and guarding against clever compiler optimizations that could bias the result). gettimeofday should be preferred over clock, as the former measures wall-clock time while the latter measures the sum of the user time and the system time.
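The repeated-measurement idea above could be sketched as follows. This is only a minimal sketch assuming a POSIX system (for gettimeofday); the names wall_seconds, average_solve_seconds, benchmark_sink and placeholder_solve are hypothetical and not part of the original code. placeholder_solve exists only so the sketch compiles on its own; in practice linear_solve or fast_linear_solve would be passed instead.

```c
#include <assert.h>
#include <stddef.h>
#include <string.h>
#include <sys/time.h>

#define K 50

typedef void (*solver_fn)(float A[K * K], float b[K]);

/* Written once after the loop so the compiler cannot discard the solves. */
static volatile float benchmark_sink;

/* Wall-clock time in seconds, based on gettimeofday (POSIX). */
static double wall_seconds(void)
{
    struct timeval tv;
    gettimeofday(&tv, NULL);
    return (double)tv.tv_sec + 1e-6 * (double)tv.tv_usec;
}

/* Hypothetical stand-in solver (diagonal systems only), used so that this
   sketch is self-contained. Replace it with the real solver under test. */
static void placeholder_solve(float A[K * K], float b[K])
{
    for (long j = 0; j < K; j++)
        b[j] /= A[K * j + j];
}

/* Average wall-clock seconds per call of `solve` over `repeats` runs.
   The solver overwrites A and b, so both are restored from pristine copies
   before every run. The copies are part of the timed region, which slightly
   inflates the reported average. */
static double average_solve_seconds(solver_fn solve, const float A0[K * K],
                                    const float b0[K], int repeats)
{
    static float A[K * K];
    static float b[K];
    float checksum = 0.0f;

    assert(repeats > 0);

    double begin = wall_seconds();
    for (int r = 0; r < repeats; r++)
    {
        memcpy(A, A0, sizeof A);
        memcpy(b, b0, sizeof b);
        solve(A, b);
        checksum += b[0]; /* consume the result of every run */
    }
    double end = wall_seconds();

    benchmark_sink = checksum;
    return (end - begin) / repeats;
}
```

A call such as average_solve_seconds(linear_solve, A, b, 10000) would then replace the single clock-based measurement in main. The repeat count of 10000 is an arbitrary illustration.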

Here are (average) results with a 50x50 matrix on my machine (with GCC 9.3 using -O3, Clang 9.0 using -O3 too, and with Julia 1.4):

Original C code (GCC):  25 us      |      Original C code (Clang):  25 us
New C code (GCC):       11 us      |      New C code (Clang):       12 us
Julia code:             80 us

Here are results with a 500x500 random matrix:

Original C code (GCC):  37.9 ms    |    Original C code (Clang):  38.8 ms
New C code (GCC):        6.7 ms    |    New C code (Clang):        6.1 ms
Julia code:              2.3 ms

There is still room for improvement for big matrices: the C code can be improved using loop tiling, for example (at the cost of decreasing the code's readability and maintainability).
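To illustrate what loop tiling looks like, here is a minimal sketch that applies it to the transposition helper rather than to the elimination itself, since tiling the solver is considerably more involved. The function name transpose_tiled and the tile size TILE are assumptions made for illustration. The version is out-of-place, unlike the in-place transpose above, and any benefit would be expected mainly for matrices much larger than 50x50.

```c
#include <assert.h>

#define K 50
#define TILE 8 /* hypothetical tile size; would need tuning per target */

/* Out-of-place transpose processed tile by tile, so that the reads from src
   and the writes to dst both stay within a small block of memory at a time. */
static void transpose_tiled(const float src[K * K], float dst[K * K])
{
    assert(src != dst); /* this sketch does not support in-place use */

    for (long jj = 0; jj < K; jj += TILE)
    {
        for (long ii = 0; ii < K; ii += TILE)
        {
            /* Clamp the tile bounds because K need not be a multiple of TILE. */
            long j_end = (jj + TILE < K) ? jj + TILE : K;
            long i_end = (ii + TILE < K) ? ii + TILE : K;

            for (long j = jj; j < j_end; j++)
                for (long i = ii; i < i_end; i++)
                    dst[K * i + j] = src[K * j + i];
        }
    }
}
```

The same pattern (an outer pair of loops over tiles and an inner pair over elements within a tile) is what a tiled elimination update would use, with extra care for the data dependencies between pivot steps.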

Keep in mind that although (dedicated) GPUs should improve performance for big matrices, this should not be the case for small matrices, because of the relatively high latency of GPUs (e.g. data transfers, synchronizations, memory latency), unless batch processing is used on many small matrices.
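The batch-processing idea can be sketched in plain C as below. This is an illustration under stated assumptions, not a tuned kernel: solve_one and solve_batch are hypothetical names, N is a small illustrative size, and the systems are assumed to be stored back to back in row-major order. In an OpenCL kernel, the loop in solve_batch would disappear and each work-item would pick its own system index with get_global_id(0).

```c
#include <assert.h>

#define N 3 /* illustrative system size, not the 50x50 case */

/* Solve one N x N system in place (row-major A, solution written to b),
   using Gaussian elimination with partial pivoting. */
static void solve_one(float A[N * N], float b[N])
{
    for (long j = 0; j < N; j++)
    {
        /* Partial pivoting: find the row with the largest |A[i][j]|, i >= j. */
        long maxrow = j;
        float maxval = A[N * j + j] < 0.0f ? -A[N * j + j] : A[N * j + j];
        for (long i = j + 1; i < N; i++)
        {
            float v = A[N * i + j] < 0.0f ? -A[N * i + j] : A[N * i + j];
            if (v > maxval) { maxval = v; maxrow = i; }
        }

        /* Swap rows j and maxrow of [A; b] (contiguous accesses). */
        for (long l = 0; l < N; l++)
        {
            float tmp = A[N * maxrow + l];
            A[N * maxrow + l] = A[N * j + l];
            A[N * j + l] = tmp;
        }
        float b_tmp = b[maxrow];
        b[maxrow] = b[j];
        b[j] = b_tmp;

        /* Eliminate column j from the rows below the pivot. */
        for (long i = j + 1; i < N; i++)
        {
            float c = A[N * i + j] / A[N * j + j];
            for (long l = j + 1; l < N; l++)
                A[N * i + l] -= c * A[N * j + l];
            b[i] -= c * b[j];
        }
    }

    /* Back substitution, overwriting b with x. */
    for (long i = N - 1; i >= 0; i--)
    {
        float sum = b[i];
        for (long l = i + 1; l < N; l++)
            sum -= A[N * i + l] * b[l];
        b[i] = sum / A[N * i + i];
    }
}

/* Solve `count` independent systems stored back to back. Each iteration is
   independent of the others, which is what makes the batch parallelizable. */
static void solve_batch(float *A_all, float *b_all, long count)
{
    assert(count >= 0);
    for (long s = 0; s < count; s++)
        solve_one(A_all + s * N * N, b_all + s * N);
}
```

Launching one such solve per work-item over thousands of systems is one way the fixed GPU latencies could be amortized. Whether it pays off depends on the device and on the memory layout of the batch.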



Source: https://stackoverflow.com/questions/62723424/writing-a-fast-linear-system-solver-in-opencl-c
