Multiplying two matrices in Java

Submitted by 夙愿已清 on 2019-12-12 17:13:16

Question


I am currently developing a class to represent matrices; it represents any general m x n matrix. I have worked out addition and scalar multiplication, but I am struggling to implement the multiplication of two matrices. The data of the matrix is held in a 2D array of doubles.

The method looks a little bit like this:

   public Matrix multiply(Matrix A) {
            ////code
   }

It will return the product matrix. This is multiplication on the right: if I call A.multiply(B), it returns the matrix AB, with B on the right.
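For illustration (the two matrices are chosen arbitrarily), the order matters because matrix multiplication is generally not commutative, so A.multiply(B) and B.multiply(A) can give different results:

```latex
A = \begin{pmatrix} 1 & 2 \\ 0 & 1 \end{pmatrix}, \quad
B = \begin{pmatrix} 0 & 1 \\ 1 & 0 \end{pmatrix}, \qquad
AB = \begin{pmatrix} 2 & 1 \\ 1 & 0 \end{pmatrix}
\neq
BA = \begin{pmatrix} 0 & 1 \\ 1 & 2 \end{pmatrix}
```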

I don't need to worry yet about checking whether the multiplication is defined for the given matrices; I can assume that I will be given matrices of the correct dimensions.

Does anyone know of an easy algorithm, possibly even in pseudocode, to carry out the multiplication process?

Thanks in advance.


Answer 1:


Mathematically, the product of matrices A (l x m) and B (m x n) is defined as a matrix C (l x n) consisting of the elements:

    $$ c_{ij} = \sum_{k=1}^{m} a_{ik} \, b_{kj} $$
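As a worked instance of the formula, take the 3 x 3 matrices from the first test case of the second answer below. Each entry of C is the dot product of a row of A with a column of B; the two values computed here match the 12.00 and -5.00 that appear in that answer's printed output:

```latex
A = \begin{pmatrix} 3 & -1 & 2 \\ 2 & 0 & 1 \\ 1 & 2 & 1 \end{pmatrix}, \quad
B = \begin{pmatrix} 2 & -1 & 1 \\ 0 & -2 & 3 \\ 3 & 0 & 1 \end{pmatrix}

c_{11} = \sum_{k=1}^{3} a_{1k} b_{k1} = 3 \cdot 2 + (-1) \cdot 0 + 2 \cdot 3 = 12

c_{32} = \sum_{k=1}^{3} a_{3k} b_{k2} = 1 \cdot (-1) + 2 \cdot (-2) + 1 \cdot 0 = -5
```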

So if you're not too concerned about speed, you might be happy with the straightforward O(n^3) implementation (O(l·m·n) for the dimensions above; note that Java indices start at 0):

  double[][] c = new double[l][n]; // Java initializes every entry to 0.0
  for (int i = 0; i < l; ++i)
    for (int j = 0; j < n; ++j)
      for (int k = 0; k < m; ++k)
        c[i][j] += a[i][k] * b[k][j];
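Applied to the question's `multiply(Matrix A)` signature, the same loop might look like the minimal sketch below. The field name `data`, the constructor, and the `get` accessor are assumptions (the question only says the entries live in a 2D `double` array), and, as the question allows, dimensions are not checked.

```java
public class Matrix {
    // Entries stored row by row; the field name "data" is a hypothetical stand-in
    private final double[][] data;

    public Matrix(double[][] data) {
        this.data = data;
    }

    public double get(int row, int col) {
        return data[row][col];
    }

    // Returns this * other (other on the right); assumes this.columns == other.rows
    public Matrix multiply(Matrix other) {
        int l = data.length;             // rows of this
        int m = other.data.length;       // columns of this == rows of other
        int n = other.data[0].length;    // columns of other
        double[][] c = new double[l][n]; // Java zero-initializes the array
        for (int i = 0; i < l; i++) {
            for (int j = 0; j < n; j++) {
                double sum = 0.0;
                for (int k = 0; k < m; k++) {
                    sum += data[i][k] * other.data[k][j];
                }
                c[i][j] = sum;
            }
        }
        return new Matrix(c);
    }

    public static void main(String[] args) {
        Matrix a = new Matrix(new double[][] {{1, 2}, {3, 4}});
        Matrix b = new Matrix(new double[][] {{5, 6}, {7, 8}});
        Matrix ab = a.multiply(b);
        System.out.println(ab.get(0, 0) + " " + ab.get(0, 1)); // prints "19.0 22.0"
        System.out.println(ab.get(1, 0) + " " + ab.get(1, 1)); // prints "43.0 50.0"
    }
}
```

Accumulating into a local `sum` instead of `c[i][j]` is a small stylistic choice; the result is the same.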

If speed matters, you might want to look at alternatives such as the Strassen algorithm.
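Strassen's algorithm replaces the eight block products of a 2 x 2 block split with seven. The following is a minimal sketch, assuming square inputs whose size is a power of two (other sizes would need zero padding, which is not shown) and an illustrative `cutoff` parameter (at least 1) below which it falls back to the plain triple loop. It allocates many temporary arrays, so it demonstrates the recursion rather than being a tuned implementation.

```java
public final class Strassen {

    // Plain O(n^3) product, used at or below the cutoff
    static double[][] naive(double[][] a, double[][] b) {
        int n = a.length;
        double[][] c = new double[n][n];
        for (int i = 0; i < n; i++)
            for (int k = 0; k < n; k++)
                for (int j = 0; j < n; j++)
                    c[i][j] += a[i][k] * b[k][j];
        return c;
    }

    // Element-wise x + sign * y (sign is 1 or -1)
    static double[][] addScaled(double[][] x, double[][] y, double sign) {
        int n = x.length;
        double[][] r = new double[n][n];
        for (int i = 0; i < n; i++)
            for (int j = 0; j < n; j++)
                r[i][j] = x[i][j] + sign * y[i][j];
        return r;
    }

    // Copies the h x h block of m whose top-left corner is (r0, c0)
    static double[][] block(double[][] m, int r0, int c0, int h) {
        double[][] r = new double[h][h];
        for (int i = 0; i < h; i++)
            System.arraycopy(m[r0 + i], c0, r[i], 0, h);
        return r;
    }

    // Writes src into dst with its top-left corner at (r0, c0)
    static void put(double[][] dst, double[][] src, int r0, int c0) {
        for (int i = 0; i < src.length; i++)
            System.arraycopy(src[i], 0, dst[r0 + i], c0, src.length);
    }

    // Assumes a and b are n x n with n a power of two, and cutoff >= 1
    public static double[][] multiply(double[][] a, double[][] b, int cutoff) {
        int n = a.length;
        if (n <= cutoff) return naive(a, b);
        int h = n / 2;
        double[][] a11 = block(a, 0, 0, h), a12 = block(a, 0, h, h);
        double[][] a21 = block(a, h, 0, h), a22 = block(a, h, h, h);
        double[][] b11 = block(b, 0, 0, h), b12 = block(b, 0, h, h);
        double[][] b21 = block(b, h, 0, h), b22 = block(b, h, h, h);

        // The seven Strassen products
        double[][] m1 = multiply(addScaled(a11, a22, 1), addScaled(b11, b22, 1), cutoff);
        double[][] m2 = multiply(addScaled(a21, a22, 1), b11, cutoff);
        double[][] m3 = multiply(a11, addScaled(b12, b22, -1), cutoff);
        double[][] m4 = multiply(a22, addScaled(b21, b11, -1), cutoff);
        double[][] m5 = multiply(addScaled(a11, a12, 1), b22, cutoff);
        double[][] m6 = multiply(addScaled(a21, a11, -1), addScaled(b11, b12, 1), cutoff);
        double[][] m7 = multiply(addScaled(a12, a22, -1), addScaled(b21, b22, 1), cutoff);

        double[][] c = new double[n][n];
        // C11 = M1 + M4 - M5 + M7
        put(c, addScaled(addScaled(m1, m4, 1), addScaled(m7, m5, -1), 1), 0, 0);
        // C12 = M3 + M5
        put(c, addScaled(m3, m5, 1), 0, h);
        // C21 = M2 + M4
        put(c, addScaled(m2, m4, 1), h, 0);
        // C22 = M1 - M2 + M3 + M6
        put(c, addScaled(addScaled(m1, m2, -1), addScaled(m3, m6, 1), 1), h, h);
        return c;
    }
}
```

The point at which this beats a simple triple loop depends on the hardware and JVM, so any real `cutoff` value would have to be measured rather than taken from this sketch.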

Be warned, though: especially when multiplying small matrices on modern processor architectures, speed depends heavily on arranging the matrix data and the multiplication order so as to make the best use of cache lines.

I strongly doubt there is much chance to influence this factor from within a VM, so I'm not sure whether it needs to be taken into consideration.
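For anyone who wants to experiment anyway, one variant often suggested for Java is to swap the two inner loops (i-k-j order). A Java `double[][]` is an array of row arrays, so in this order the innermost loop walks along `b[k]` and `c[i]` sequentially instead of jumping between rows of `b`. Each `c[i][j]` still receives its terms in the same order of `k`, so the result is the same; whether it is actually faster depends on the JVM, the hardware, and the matrix sizes, and should be measured. The class name `LoopOrder` is just for illustration.

```java
public class LoopOrder {
    // Computes a (l x m) times b (m x n) using i-k-j loop order
    public static double[][] multiplyIkj(double[][] a, double[][] b) {
        int l = a.length, m = b.length, n = b[0].length;
        double[][] c = new double[l][n];
        for (int i = 0; i < l; i++) {
            double[] ci = c[i];               // row i of the result
            for (int k = 0; k < m; k++) {
                double aik = a[i][k];         // reused across the whole inner loop
                double[] bk = b[k];           // row k of b, read sequentially
                for (int j = 0; j < n; j++) {
                    ci[j] += aik * bk[j];
                }
            }
        }
        return c;
    }
}
```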




Answer 2:


Java. Matrix multiplication.

Here is the "code to carry out the multiplication process", tested with matrices of different sizes.

public class Matrix {

    /**
     * Matrix multiplication method.
     * @param m1 Multiplicand
     * @param m2 Multiplier
     * @return Product, or null if the dimensions are incompatible
     */
    public static double[][] multiplyByMatrix(double[][] m1, double[][] m2) {
        int m1ColLength = m1[0].length; // m1 columns length
        int m2RowLength = m2.length;    // m2 rows length
        if(m1ColLength != m2RowLength) return null; // matrix multiplication is not possible
        int mRRowLength = m1.length;    // m result rows length
        int mRColLength = m2[0].length; // m result columns length
        double[][] mResult = new double[mRRowLength][mRColLength];
        for(int i = 0; i < mRRowLength; i++) {         // rows from m1
            for(int j = 0; j < mRColLength; j++) {     // columns from m2
                for(int k = 0; k < m1ColLength; k++) { // columns from m1
                    mResult[i][j] += m1[i][k] * m2[k][j];
                }
            }
        }
        return mResult;
    }

    public static String toString(double[][] m) {
        String result = "";
        for(int i = 0; i < m.length; i++) {
            for(int j = 0; j < m[i].length; j++) {
                result += String.format("%11.2f", m[i][j]);
            }
            result += "\n";
        }
        return result;
    }

    public static void main(String[] args) {
        // #1
        double[][] multiplicand = new double[][] {
                {3, -1, 2},
                {2,  0, 1},
                {1,  2, 1}
        };
        double[][] multiplier = new double[][] {
                {2, -1, 1},
                {0, -2, 3},
                {3,  0, 1}
        };
        System.out.println("#1\n" + toString(multiplyByMatrix(multiplicand, multiplier)));
        // #2
        multiplicand = new double[][] {
                {1, 2, 0},
                {-1, 3, 1},
                {2, -2, 1}
        };
        multiplier = new double[][] {
                {2},
                {-1},
                {1}
        };
        System.out.println("#2\n" + toString(multiplyByMatrix(multiplicand, multiplier)));
        // #3
        multiplicand = new double[][] {
                {1, 2, -1},
                {0,  1, 0}
        };
        multiplier = new double[][] {
                {1, 1, 0, 0},
                {0, 2, 1, 1},
                {1, 1, 2, 2}
        };
        System.out.println("#3\n" + toString(multiplyByMatrix(multiplicand, multiplier)));
    }
}

Output:

#1
      12.00      -1.00       2.00
       7.00      -2.00       3.00
       5.00      -5.00       8.00

#2
       0.00
      -4.00
       7.00

#3
       0.00       4.00       0.00       0.00
       0.00       2.00       1.00       1.00
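One design point worth noting: `multiplyByMatrix` returns `null` when the dimensions don't match, and passing that result to `toString` would then fail with a `NullPointerException`. A possible alternative, sketched below, validates up front and throws `IllegalArgumentException` instead. The class name, the messages, and the extra checks for empty or ragged arrays are my own additions, not part of the answer.

```java
public class CheckedMultiply {
    // Throws IllegalArgumentException if m has no rows or its rows differ in length
    private static void requireRectangular(double[][] m, String name) {
        if (m.length == 0) {
            throw new IllegalArgumentException(name + " has no rows");
        }
        for (int i = 1; i < m.length; i++) {
            if (m[i].length != m[0].length) {
                throw new IllegalArgumentException(name + " is ragged at row " + i);
            }
        }
    }

    // Multiplies m1 (r x c) by m2 (c x p); throws instead of returning null
    public static double[][] multiply(double[][] m1, double[][] m2) {
        requireRectangular(m1, "m1");
        requireRectangular(m2, "m2");
        int c = m1[0].length;
        if (c != m2.length) {
            throw new IllegalArgumentException(
                "m1 has " + c + " columns but m2 has " + m2.length + " rows");
        }
        int r = m1.length, p = m2[0].length;
        double[][] result = new double[r][p];
        for (int i = 0; i < r; i++)
            for (int j = 0; j < p; j++)
                for (int k = 0; k < c; k++)
                    result[i][j] += m1[i][k] * m2[k][j];
        return result;
    }
}
```

Throwing keeps the failure at the point where the bad input arrives, instead of surfacing later wherever the `null` is first used.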



Answer 3:


In this answer, I created a class named Matrix and another class named MatrixOperations, which defines the various operations that can be performed on matrices (except for row operations, of course). Here I only extract the multiplication code from MatrixOperations. The full project can be found on my GitHub page.

Below is the definition of the Matrix class.

package app.matrix;

import app.matrix.util.MatrixException;

public class Matrix {

private double[][] entries;

public void setEntries(double[][] entries) {
    this.entries = entries;
}

private String name;

public double[][] getEntries() {
    return entries;
}

public String getName() {
    return name;
}

public void setName(String name) {
    this.name = name;
}

public class Dimension {
    private int rows;
    private int columns;

    public int getRows() {
        return rows;
    }

    public void setRows(int rows) {
        this.rows = rows;
    }

    public int getColumns() {
        return columns;
    }

    public void setColumns(int columns) {
        this.columns = columns;
    }

    public Dimension(int rows, int columns) {
        this.setRows(rows);
        this.setColumns(columns);
    }

    @Override
    public boolean equals(Object obj) {
        if(obj instanceof Dimension){
            return (this.getColumns() == ((Dimension) obj).getColumns()) && (this.getRows() == ((Dimension) obj).getRows());
        }
        return false;
    }
}

private Dimension dimension;

public Dimension getDimension() {
    return dimension;
}

public void setDimension(Dimension dimension) {
    this.dimension = dimension;
}

public Matrix(int dimension, String name) throws MatrixException {
    if (dimension == 0) throw new MatrixException(ZERO_UNIT_DIMENSION);
    else this.setEntries(new double[Math.abs(dimension)][Math.abs(dimension)]);
    this.setDimension(new Dimension(dimension, dimension));
    this.setName(name);
}

public Matrix(int dimensionH, int dimensionV, String name) throws MatrixException {
    if (dimensionH == 0 || dimensionV == 0) throw new MatrixException(ZERO_UNIT_DIMENSION);
    else this.setEntries(new double[Math.abs(dimensionH)][Math.abs(dimensionV)]);
    this.setDimension(new Dimension(dimensionH, dimensionV));
    this.setName(name);

}

private static final String OVERFLOW_ITEMS_MSG = "The values are too many for the matrix's specified dimensions";
private static final String ZERO_UNIT_DIMENSION = "Zero cannot be a value for a dimension";

public Matrix(int dimensionH, int dimensionV, String name, double... values) throws MatrixException {
    if (dimensionH == 0 || dimensionV == 0) throw new MatrixException(ZERO_UNIT_DIMENSION);
    else if (values.length > dimensionH * dimensionV) throw new MatrixException(Matrix.OVERFLOW_ITEMS_MSG);
    else this.setEntries(new double[Math.abs(dimensionH)][Math.abs(dimensionV)]);
    this.setDimension(new Dimension(dimensionH, dimensionV));
    this.setName(name);

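    // Fill row by row; if fewer values than entries are supplied,
    // the remaining entries keep Java's default value of 0.0
    int iterator = 0;
    for (int i = 0; i < dimensionH && iterator < values.length; i++) {
        for (int j = 0; j < dimensionV && iterator < values.length; j++) {
            this.entries[i][j] = values[iterator++];
        }
    }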
}

public Matrix(Dimension dimension) throws MatrixException {
    this(dimension.getRows(), dimension.getColumns(), null);
}

public static Matrix identityMatrix(int dim) throws MatrixException {
    if (dim == 0) throw new MatrixException(ZERO_UNIT_DIMENSION);

    double[] i = new double[dim * dim];
    int constant = dim + 1;
    for (int j = 0; j < i.length; j = j + constant) {
        i[j] = 1.0;
    }

    return new Matrix(dim, dim, null, i);
}

public String toString() {

    StringBuilder builder = new StringBuilder("Matrix \"" + (this.getName() == null ? "Null Matrix" : this.getName()) + "\": {\n");

    for (int i = 0; i < this.getDimension().getRows(); i++) {
        for (int j = 0; j < this.getDimension().getColumns(); j++) {
            if (j == 0) builder.append("\t");
            builder.append(this.entries[i][j]);
            if (j != this.getDimension().getColumns() - 1)
                builder.append(", ");
        }
        builder.append("\n"); // end each row with a newline
    }

    builder.append("}");

    return builder.toString();
}

public boolean isSquare() {
    return this.getDimension().getColumns() == this.getDimension().getRows();
}

}

And here is the matrix multiplication method from MatrixOperations:

public static Matrix multiply(Matrix matrix1, Matrix matrix2) throws MatrixException {

    if (matrix1.getDimension().getColumns() != matrix2.getDimension().getRows())
        throw new MatrixException(MATRIX_MULTIPLICATION_ERROR_MSG);

    Matrix retVal = new Matrix(matrix1.getDimension().getRows(), matrix2.getDimension().getColumns(), matrix1.getName() + " x " + matrix2.getName());


    for (int i = 0; i < matrix1.getDimension().getRows(); i++) {
        for (int j = 0; j < matrix2.getDimension().getColumns(); j++) {
            retVal.getEntries()[i][j] = sum(arrayProduct(matrix1.getEntries()[i], getColumnMatrix(matrix2, j)));
        }
    }

    return retVal;
}

Finally, here are the helper methods sum, arrayProduct, and getColumnMatrix:

private static double sum(double... values) {
    double sum = 0;
    for (double value : values) {
        sum += value;
    }
    return sum;
}

private static double[] arrayProduct(double[] arr1, double[] arr2) throws MatrixException {
    if (arr1.length != arr2.length) throw new MatrixException("Array lengths must be the same");
    double[] retVal = new double[arr1.length];
    for (int i = 0; i < arr1.length; i++) {
        retVal[i] = arr1[i] * arr2[i];
    }

    return retVal;
}


private static double[] getColumnMatrix(Matrix matrix, int col) {
    double[] ret = new double[matrix.getDimension().getRows()];
    for (int i = 0; i < matrix.getDimension().getRows(); i++) {
        ret[i] = matrix.getEntries()[i][col];
    }
    return ret;
}
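The same row-times-column idea can be written without the two temporary arrays that `arrayProduct` and `getColumnMatrix` allocate for every entry of the result. The sketch below is an alternative formulation, not the answer's code: it works directly on `double[][]` (rather than the answer's `Matrix` class, to stay self-contained) and uses `IntStream` for each dot product. Note that `DoubleStream.sum()` may use compensated summation, so the last bits can differ slightly from a plain loop on inputs that are not exactly representable.

```java
import java.util.stream.IntStream;

public class StreamMultiply {
    // Entry (i, j) is the dot product of row i of a with column j of b
    public static double[][] multiply(double[][] a, double[][] b) {
        int rows = a.length, inner = b.length, cols = b[0].length;
        return IntStream.range(0, rows)
                .mapToObj(i -> IntStream.range(0, cols)
                        .mapToDouble(j -> IntStream.range(0, inner)
                                .mapToDouble(k -> a[i][k] * b[k][j])
                                .sum())
                        .toArray())
                .toArray(double[][]::new);
    }
}
```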


Source: https://stackoverflow.com/questions/15733829/multiplying-two-matrices-in-java
