Java 8 matrix * vector multiplication

最后都变了 - submitted on 2021-02-04 17:38:29

Question


I'm wondering if there is a more condensed way of doing the following in Java 8 with streams:

public static double[] multiply(double[][] matrix, double[] vector) {
    int rows = matrix.length;
    int columns = matrix[0].length;

    double[] result = new double[rows];

    for (int row = 0; row < rows; row++) {
        double sum = 0;
        for (int column = 0; column < columns; column++) {
            sum += matrix[row][column]
                    * vector[column];
        }
        result[row] = sum;
    }
    return result;
}

Edit: I received a very good answer; however, its performance is about 10x slower than the old implementation, so I'm adding the test code here in case anyone wants to investigate it:

@Test
public void profile() {
    long start;
    long stop;
    int tenmillion = 10000000;
    double[] vector = { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 };

    double[][] matrix = new double[tenmillion][10];

    for (int i = 0; i < tenmillion; i++) {
        matrix[i] = vector.clone();
    }
    start = System.currentTimeMillis();
    multiply(matrix, vector);
    stop = System.currentTimeMillis();
}

Answer 1:


A direct way using Stream would be the following:

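import java.util.Arrays;
import java.util.stream.IntStream;

public static double[] multiply(double[][] matrix, double[] vector) {
    return Arrays.stream(matrix)
                 .mapToDouble(row ->
                    // dot product of the current row with the vector
                    IntStream.range(0, row.length)
                             .mapToDouble(col -> row[col] * vector[col])
                             .sum()
                 ).toArray();
}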

This creates a Stream over the rows of the matrix (Stream<double[]>), then maps each row to the double value obtained by computing its dot product with the vector array.

We have to use a Stream over the indexes to calculate the product because there is unfortunately no built-in facility to zip two Streams together.
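
As a quick sanity check, the stream version above and the loop from the question can be run side by side on a small input. This is only a sketch: the class name MultiplyDemo and the sample values are made up for illustration and are not part of the original answer.

```java
import java.util.Arrays;
import java.util.stream.IntStream;

// Hypothetical demo class; the name and the sample data are assumptions.
public class MultiplyDemo {

    // Stream version: one dot product per row, driven by a stream of column indexes.
    public static double[] multiplyWithStreams(double[][] matrix, double[] vector) {
        return Arrays.stream(matrix)
                     .mapToDouble(row -> IntStream.range(0, row.length)
                                                  .mapToDouble(col -> row[col] * vector[col])
                                                  .sum())
                     .toArray();
    }

    // Loop version, as in the question.
    public static double[] multiplyWithLoops(double[][] matrix, double[] vector) {
        double[] result = new double[matrix.length];
        for (int row = 0; row < matrix.length; row++) {
            double sum = 0;
            for (int column = 0; column < matrix[row].length; column++) {
                sum += matrix[row][column] * vector[column];
            }
            result[row] = sum;
        }
        return result;
    }

    public static void main(String[] args) {
        double[][] matrix = { { 1, 2, 3 }, { 4, 5, 6 } };
        double[] vector = { 1, 0, 2 };

        // Row 0: 1*1 + 2*0 + 3*2 = 7; row 1: 4*1 + 5*0 + 6*2 = 16
        System.out.println(Arrays.toString(multiplyWithStreams(matrix, vector))); // [7.0, 16.0]
        System.out.println(Arrays.toString(multiplyWithLoops(matrix, vector)));   // [7.0, 16.0]
    }
}
```

Both methods assume that every row has at least as many usable columns as the vector has entries; neither validates the dimensions.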




Answer 2:


The way you are measuring performance is not very reliable, and it is usually not a good idea to write micro-benchmarks by hand. For example, while compiling the code, the JVM may reorder execution, so the start and stop variables might not be assigned where you expect them to be, which gives unexpected results in your measurements. It is also very important to warm up the JVM so that the JIT compiler can make all its optimizations. GC can also play a very big role in introducing variations in the throughput and response time of your application. I would strongly recommend using specialized tools such as JMH or Caliper for micro-benchmarking.

I also wrote some benchmarking code with JVM warm-up, a random data set and a higher number of iterations. It turns out that Java 8 streams give better results.

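import java.util.Random;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.BiFunction;
import java.util.stream.IntStream;

/**
 * Benchmarks matrix * vector multiplication implemented with streams and with for loops.
 */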
public class MatrixMultiplicationBenchmark {
    private static AtomicLong start = new AtomicLong();
    private static AtomicLong stop = new AtomicLong();
    private static Random random = new Random();

    /**
     * Main method that warms-up each implementation and then runs the benchmark.
     *
     * @param args main class args
     */
    public static void main(String[] args) {
        // Warming up with more iterations and smaller data set
        System.out.println("Warming up...");
        IntStream.range(0, 10_000_000).forEach(i -> run(10, MatrixMultiplicationBenchmark::multiplyWithStreams));
        IntStream.range(0, 10_000_000).forEach(i -> run(10, MatrixMultiplicationBenchmark::multiplyWithForLoops));

        // Running with fewer iterations and larger data set
        startWatch("Running MatrixMultiplicationBenchmark::multiplyWithForLoops...");
        IntStream.range(0, 10).forEach(i -> run(10_000_000, MatrixMultiplicationBenchmark::multiplyWithForLoops));
        endWatch("MatrixMultiplicationBenchmark::multiplyWithForLoops");

        startWatch("Running MatrixMultiplicationBenchmark::multiplyWithStreams...");
        IntStream.range(0, 10).forEach(i -> run(10_000_000, MatrixMultiplicationBenchmark::multiplyWithStreams));
        endWatch("MatrixMultiplicationBenchmark::multiplyWithStreams");
    }

    /**
     * Creates the random matrix and vector and applies them in the given implementation as BiFunction object.
     *
     * @param size number of rows in the generated matrix
     * @param multiplyImpl implementation to use.
     */
    public static void run(int size, BiFunction<double[][], double[], double[]> multiplyImpl) {
        // creating random matrix and vector
        double[][] matrix = new double[size][10];
        double[] vector = random.doubles(10, 0.0, 10.0).toArray();
        IntStream.range(0, size).forEach(i -> matrix[i] = random.doubles(10, 0.0, 10.0).toArray());

        // applying matrix and vector to the given implementation. Returned value should not be ignored in test cases.
        double[] result = multiplyImpl.apply(matrix, vector);
    }

    /**
     * Multiplies the given vector and matrix using Java 8 streams.
     *
     * @param matrix the matrix
     * @param vector the vector to multiply
     *
     * @return result after multiplication.
     */
    public static double[] multiplyWithStreams(final double[][] matrix, final double[] vector) {
        final int rows = matrix.length;
        final int columns = matrix[0].length;

        return IntStream.range(0, rows)
                .mapToDouble(row -> IntStream.range(0, columns)
                        .mapToDouble(col -> matrix[row][col] * vector[col])
                        .sum()).toArray();
    }

    /**
     * Multiplies the given vector and matrix using vanilla for loops.
     *
     * @param matrix the matrix
     * @param vector the vector to multiply
     *
     * @return result after multiplication.
     */
    public static double[] multiplyWithForLoops(double[][] matrix, double[] vector) {
        int rows = matrix.length;
        int columns = matrix[0].length;

        double[] result = new double[rows];

        for (int row = 0; row < rows; row++) {
            double sum = 0;
            for (int column = 0; column < columns; column++) {
                sum += matrix[row][column] * vector[column];
            }
            result[row] = sum;
        }
        return result;
    }

    private static void startWatch(String label) {
        System.out.println(label);
        start.set(System.currentTimeMillis());
    }

    private static void endWatch(String label) {
        stop.set(System.currentTimeMillis());
        System.out.println(label + " took " + ((stop.longValue() - start.longValue()) / 1000) + "s");
    }
}

Here is the output:

Warming up...
Running MatrixMultiplicationBenchmark::multiplyWithForLoops...
MatrixMultiplicationBenchmark::multiplyWithForLoops took 100s
Running MatrixMultiplicationBenchmark::multiplyWithStreams...
MatrixMultiplicationBenchmark::multiplyWithStreams took 89s
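
Note that run() above builds a new random matrix on every call, so the reported times include data generation as well as the multiplication. The following is a rough sketch of a variant that creates the data once, warms up both implementations, and times only the multiply calls. The class name TimedMultiply, the helpers randomMatrix and timeNanos, the fixed seeds and the iteration counts are all assumptions made for illustration. It is still a hand-written measurement, so a tool such as JMH remains the more trustworthy option.

```java
import java.util.Random;
import java.util.function.BiFunction;
import java.util.stream.IntStream;

// Hypothetical sketch; names, seeds and sizes are assumptions, not part of the original answer.
public class TimedMultiply {

    // Accumulates one element of every result so the computed values are actually used.
    public static double sink;

    public static double[] multiplyWithStreams(double[][] matrix, double[] vector) {
        return IntStream.range(0, matrix.length)
                .mapToDouble(row -> IntStream.range(0, vector.length)
                        .mapToDouble(col -> matrix[row][col] * vector[col])
                        .sum())
                .toArray();
    }

    public static double[] multiplyWithForLoops(double[][] matrix, double[] vector) {
        double[] result = new double[matrix.length];
        for (int row = 0; row < matrix.length; row++) {
            double sum = 0;
            for (int column = 0; column < vector.length; column++) {
                sum += matrix[row][column] * vector[column];
            }
            result[row] = sum;
        }
        return result;
    }

    // Builds a rows x 10 matrix of random values in [0, 10); the seed makes runs repeatable.
    public static double[][] randomMatrix(int rows, long seed) {
        Random random = new Random(seed);
        double[][] matrix = new double[rows][];
        for (int i = 0; i < rows; i++) {
            matrix[i] = random.doubles(10, 0.0, 10.0).toArray();
        }
        return matrix;
    }

    // Returns the elapsed nanoseconds for the given number of calls; data creation is not timed.
    public static long timeNanos(BiFunction<double[][], double[], double[]> impl,
                                 double[][] matrix, double[] vector, int iterations) {
        long start = System.nanoTime();
        for (int i = 0; i < iterations; i++) {
            sink += impl.apply(matrix, vector)[0];
        }
        return System.nanoTime() - start;
    }

    public static void main(String[] args) {
        double[][] matrix = randomMatrix(100_000, 42L);
        double[] vector = new Random(7L).doubles(10, 0.0, 10.0).toArray();

        // Warm-up passes; their timings are discarded.
        timeNanos(TimedMultiply::multiplyWithStreams, matrix, vector, 20);
        timeNanos(TimedMultiply::multiplyWithForLoops, matrix, vector, 20);

        long loops = timeNanos(TimedMultiply::multiplyWithForLoops, matrix, vector, 20);
        long streams = timeNanos(TimedMultiply::multiplyWithStreams, matrix, vector, 20);
        System.out.println("for loops: " + loops / 1_000_000 + " ms");
        System.out.println("streams:   " + streams / 1_000_000 + " ms");
    }
}
```

The numbers this prints will vary by machine and JVM, so they should be read as a rough comparison rather than a definitive result.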


Source: https://stackoverflow.com/questions/34519952/java-8-matrix-vector-multiplication
