NumPy dot too clever about symmetric multiplications

爱一瞬间的悲伤 2020-12-05 14:33

Anybody know about documentation for this behaviour?

import numpy as np
A  = np.random.uniform(0,1,(10,5))
w  = np.ones(5)
Aw = A*w
Sym1 = Aw.dot(Aw.T)
Sym2 = (A*w).dot((A*w).T)
diff = Sym1 - Sym2

2 Answers
  • 2020-12-05 14:56

    I suspect this has to do with intermediate floating-point values being promoted to 80-bit precision in registers. Somewhat supporting this hypothesis, if we use fewer floats the difference is consistently 0, for example:

    A  = np.random.uniform(0,1,(4,2))
    w  = np.ones(2)
    Aw = A*w
    Sym1 = Aw.dot(Aw.T)
    Sym2 = (A*w).dot((A*w).T)
    diff = Sym1 - Sym2
    # diff is all 0's (ymmv)
    
  • 2020-12-05 14:58

    This behaviour is the result of a change introduced for NumPy 1.11.0, in pull request #6932. From the release notes for 1.11.0:

    Previously, gemm BLAS operations were used for all matrix products. Now, if the matrix product is between a matrix and its transpose, it will use syrk BLAS operations for a performance boost. This optimization has been extended to @, numpy.dot, numpy.inner, and numpy.matmul.
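As a rough illustration of the entry points named in that release note, the sketch below computes the same matrix-times-transpose product through each of them and checks that the results agree to within rounding. Whether any given call actually reaches syrk depends on the NumPy version and the BLAS library it was built against, so this only demonstrates numerical agreement, not which routine ran. The `products` dictionary and the other variable names are my own.

```python
import numpy as np

rng = np.random.RandomState(12345)
A = rng.uniform(size=(10, 5))

# Four ways to spell A times its transpose.
products = {
    "@": A @ A.T,
    "np.dot": np.dot(A, A.T),
    "np.matmul": np.matmul(A, A.T),
    # For 2-D input, inner(A, A)[i, j] = sum_k A[i, k] * A[j, k], i.e. A A^T.
    "np.inner": np.inner(A, A),
}

reference = products["np.dot"]
for name, result in products.items():
    # Compare with a tolerance: bitwise equality is not guaranteed.
    print(name, result.shape, np.allclose(result, reference))
```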

    In the changes for that PR, one finds this comment:

    /*
     * Use syrk if we have a case of a matrix times its transpose.
     * Otherwise, use gemm for all other cases.
     */
    

    So NumPy is making an explicit check for the case of a matrix times its transpose, and calling a different underlying BLAS function in that case. As @hpaulj notes in a comment, such a check is cheap for NumPy, since a transposed 2d array is simply a view on the original array, with inverted shape and strides, so it suffices to check a few pieces of metadata on the arrays (rather than having to compare the actual array data).
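To make the "few pieces of metadata" point concrete, here is a minimal sketch of what such a check could look at. The helper `looks_like_transpose_pair` is hypothetical and written only for illustration; it is not NumPy's actual implementation, which lives in C and may test different conditions.

```python
import numpy as np

def looks_like_transpose_pair(a, b):
    """Hypothetical metadata-only test that b is the transposed view of a.

    No array elements are compared: only the data pointer, shape and strides.
    """
    return (
        a.ndim == 2
        and b.ndim == 2
        and a.__array_interface__["data"][0] == b.__array_interface__["data"][0]
        and a.shape == b.shape[::-1]
        and a.strides == b.strides[::-1]
    )

A = np.zeros((10, 5))

# A.T is a view: same buffer, reversed shape and strides.
print(A.shape, A.strides)
print(A.T.shape, A.T.strides)
print(np.shares_memory(A, A.T))

view_detected = looks_like_transpose_pair(A, A.T)
# A copy lives in a new buffer, so the metadata no longer matches.
copy_detected = looks_like_transpose_pair(A, A.T.copy())
print(view_detected, copy_detected)
```

The copy fails the test because it has its own buffer, which is consistent with the observation that a .copy on one argument is enough to avoid the special case.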

    Here's a slightly simpler case that shows the discrepancy. Note that using a .copy on one of the arguments to dot is enough to defeat NumPy's special-casing.

    import numpy as np
    random = np.random.RandomState(12345)
    A = random.uniform(size=(10, 5))
    Sym1 = A.dot(A.T)
    Sym2 = A.dot(A.T.copy())
    print(abs(Sym1 - Sym2).max())
    

    I guess one advantage of this special-casing, beyond the obvious potential for speed-up, is that you're guaranteed (I'd hope, but in practice it'll depend on the BLAS implementation) to get a perfectly symmetric result when syrk is used, rather than a matrix which is merely symmetric up to numerical error. As an (admittedly not very good) test for this, I tried:

    import numpy as np
    random = np.random.RandomState(12345)
    A = random.uniform(size=(100, 50))
    Sym1 = A.dot(A.T)
    Sym2 = A.dot(A.T.copy())
    print("Sym1 symmetric: ", (Sym1 == Sym1.T).all())
    print("Sym2 symmetric: ", (Sym2 == Sym2.T).all())
    

    Results on my machine:

    Sym1 symmetric:  True
    Sym2 symmetric:  False
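If code should not depend on which BLAS routine ends up being used, one option is to compare the two products with a tolerance and to symmetrize explicitly when exact symmetry matters. This is only a sketch of that idea, not something from the NumPy documentation; the name `Sym2_fixed` is my own.

```python
import numpy as np

rng = np.random.RandomState(12345)
A = rng.uniform(size=(100, 50))

Sym1 = A.dot(A.T)
Sym2 = A.dot(A.T.copy())

# The two results may differ in the last bits, so compare with a tolerance.
close = np.allclose(Sym1, Sym2)

# Averaging with the transpose gives an exactly symmetric matrix:
# entry [i, j] and entry [j, i] are both (S[i, j] + S[j, i]) / 2,
# and floating-point addition is commutative.
Sym2_fixed = (Sym2 + Sym2.T) / 2
exactly_symmetric = (Sym2_fixed == Sym2_fixed.T).all()

print(close, exactly_symmetric)
```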
    