numpy矩阵乘法的解惑

杀马特。学长 韩版系。学妹 提交于 2020-03-12 02:30:27

#源码如下: 小批量(mini-batch)梯度下降法

import numpy as np
# Setting a random seed, feel free to change it and see different solutions.
np.random.seed(42)


# TODO: Fill in code in the function below to implement a gradient descent
# step for linear regression, following a squared error rule. See the docstring
# for parameters and returned variables.
def MSEStep(X, y, W, b, learn_rate = 0.005):
    """
    Perform one gradient descent step for linear regression using squared
    error (MSE) as the performance metric.

    Parameters
    X : array of predictor features, shape (n_samples, n_features)
    y : array of outcome values, shape (n_samples,)
    W : predictor feature coefficients, shape (n_features,)
    b : regression function intercept (scalar)
    learn_rate : learning rate

    Returns
    W_new : predictor feature coefficients following gradient descent step
    b_new : intercept following gradient descent step
    """
    # X @ W collapses to shape (n_samples,); adding the scalar b broadcasts
    # it onto every sample's prediction.
    y_pred = np.matmul(X, W) + b

    # Residuals: positive where the prediction undershoots the target.
    error = y - y_pred

    # The MSE gradient w.r.t. W is proportional to -(error @ X), so stepping
    # in the +(error @ X) direction reduces the loss. error @ X has shape
    # (n_features,) because the 1-D error acts as a row vector on the left.
    W_new = W + learn_rate * np.matmul(error, X)
    b_new = b + learn_rate * error.sum()
    return W_new, b_new


# The parts of the script below will be run when you press the "Test Run"
# button. The gradient descent step will be performed multiple times on
# the provided dataset, and the returned list of regression coefficients
# will be plotted.
def miniBatchGD(X, y, batch_size = 20, learn_rate = 0.005, num_iter = 25):
    """
    Perform mini-batch gradient descent on a given dataset.

    Parameters
    X : array of predictor features, shape (n_points, n_features)
    y : array of outcome values, shape (n_points,)
    batch_size : how many data points will be sampled for each iteration
    learn_rate : learning rate
    num_iter : number of batches used

    Returns
    regression_coef : list of arrays, each np.hstack((W, b)) — the
      coefficients and intercept recorded before training and after each
      iteration (num_iter + 1 entries total)
    """
    n_points = X.shape[0]
    W = np.zeros(X.shape[1])  # one coefficient per feature, shape (n_features,)
    b = 0  # intercept

    # Record the initial all-zero parameters, then one snapshot per step.
    regression_coef = [np.hstack((W, b))]
    for _ in range(num_iter):
        # Sample batch_size indices (with replacement) for this mini-batch.
        batch = np.random.choice(range(n_points), batch_size)
        X_batch = X[batch, :]
        y_batch = y[batch]
        W, b = MSEStep(X_batch, y_batch, W, b, learn_rate)
        regression_coef.append(np.hstack((W, b)))

    return regression_coef


if __name__ == "__main__":
    # Load the dataset: every column except the last holds predictor
    # features; the final column is the target.
    data = np.loadtxt('data.csv', delimiter = ',')
    X = data[:, :-1]
    y = data[:, -1]

    regression_coef = miniBatchGD(X, y)

    # Plot every intermediate regression line; later iterations are drawn
    # progressively darker, ending with the final fit in near-black.
    import matplotlib.pyplot as plt

    plt.figure()
    X_min = X.min()
    X_max = X.max()
    total = len(regression_coef)
    for step, (W, b) in enumerate(regression_coef):
        # Same greyscale schedule as counting down from total: the last
        # line gets shade 1 - 0.92**0 = 0 (black).
        shade = 1 - 0.92 ** (total - 1 - step)
        plt.plot([X_min, X_max], [X_min * W + b, X_max * W + b],
                 color = [shade, shade, shade])
    plt.scatter(X, y, zorder = 3)
    plt.show()

 

#$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$

#有几点疑惑澄清:

1 矩阵相乘(np.matmul)遇到shape为(n,)的一维数组时,如果它是第一个操作数就按1×n的行向量处理,如果是第二个操作数就按n×1的列向量处理,计算完后结果会把补上的那一维去掉。它与shape为(1,n)的二维数组的区别在于:只需在一维数组最外层再加一层[],就可以显式获得shape=(1,n)。

numpy.zeros(n) 默认返回的是shape为(n,)的一维数组,并不是(1,n)的二维数组;只是在matmul中作为左操作数时会被当作行向量参与运算

  err.shape= (20,) X.shape= (20, 1)  np.matmul(error, X).shape= (1,)

2 向量+常量b时,numpy会把常量b广播(broadcast)到向量的每个元素上,即每个元素都加上该常量b

y_pred = np.matmul(X, W) + b

X.shape= (20, 1) 

W.shape= (1, ) 

np.matmul(X, W)= (20,)

Loading data... Performng gradient descent (default params)... typex= <class 'numpy.ndarray'> typeW= <class 'numpy.ndarray'> type(y)= <class 'numpy.ndarray'> typeB= <class 'int'> X.shape[0]= 100 X= [[-7.24070e-01] [-2.40724e+00] [ 2.64837e+00] [ 3.60920e-01] [ 6.73120e-01] [-4.54600e-01] [ 2.20168e+00] [ 1.15605e+00] [ 5.06940e-01] [-8.59520e-01] [-5.99700e-01] [ 1.46804e+00] [-1.05659e+00] [ 1.29177e+00] [-7.45650e-01] [ 1.50330e-01] [-1.49627e+00] [-7.20710e-01] [ 3.29240e-01] [-2.80530e-01] [-1.36115e+00] [ 7.46780e-01] [ 1.06210e-01] [ 3.25600e-02] [-9.82900e-01] [-1.15661e+00] [ 9.02400e-02] [-1.03816e+00] [-6.04000e-03] [ 1.62780e-01] [-6.98690e-01] [ 1.03857e+00] [-1.17830e-01] [-9.54090e-01] [-8.18390e-01] [-1.28802e+00] [ 6.28220e-01] [-2.29674e+00] [-8.56010e-01] [-1.75223e+00] [-1.19662e+00] [ 9.77810e-01] [-1.17110e+00] [ 1.58350e-01] [-5.89180e-01] [-1.79678e+00] [-9.57270e-01] [ 6.45560e-01] [ 2.46250e-01] [ 4.59170e-01] [ 1.21036e+00] [-6.01160e-01] [ 2.68510e-01] [ 4.95940e-01] [-2.67877e+00] [ 4.94020e-01] [ 1.18643e+00] [-1.77410e-01] [ 5.79380e-01] [-2.14926e+00] [ 2.27700e+00] [-1.05695e+00] [ 1.68288e+00] [-1.53513e+00] [ 9.90000e-04] [ 4.55200e-01] [-3.78550e-01] [ 1.35638e+00] [ 1.76300e-02] [ 2.21725e+00] [-4.44420e-01] [ 8.95830e-01] [ 1.30499e+00] [ 1.08830e-01] [ 1.79466e+00] [-7.33000e-03] [ 7.98620e-01] [-1.23530e-01] [-1.34999e+00] [-6.78250e-01] [-1.79010e-01] [ 1.25770e-01] [ 1.11943e+00] [-3.02296e+00] [ 6.49650e-01] [ 1.05994e+00] [ 5.33600e-01] [-7.35910e-01] [-9.56900e-02] [ 1.04694e+00] [ 4.65110e-01] [-7.54630e-01] [-9.41590e-01] [-9.31400e-02] [-9.86410e-01] [-9.21590e-01] [ 7.69530e-01] [ 3.28300e-02] [-1.07619e+00] [ 2.01740e-01]] w= [0.] type(batch) <class 'numpy.ndarray'> np.matmul(X, W) shi ge sha: [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 
np.matmul(X, W)= (20,) np.matmul(error, X).shape= (1,) W.shape= (1,) err.shape= (20,) X.shape= (20, 1) np.matmul(X, W) shi ge sha: [-0.29708128 -0.18945281 -0.26524356 -0.16798167 -0.01454159 -0.00090461 -0.02189445 0.0921613 -0.01180925 0.0303901 0.05739996 0.0715022 0.12067307 -0.09313008 -0.26524356 -0.0837039 -0.09202184 -0.13043986 -0.13043986 -0.1181382 ] np.matmul(X, W)= (20,) np.matmul(error, X).shape= (1,) W.shape= (1,) err.shape= (20,) X.shape= (20, 1) np.matmul(X, W) shi ge sha: [-0.0815292 0.09336268 -0.20663037 -0.11841423 0.20428543 0.09336268 0.16982942 -0.10499406 0.10066208 -0.06602943 -0.05559289 0.02784003 -0.00738117 -0.16578594 0.09964234 0.03910347 0.08075707 0.02071186 -0.18568556 -0.23317991] np.matmul(X, W)= (20,) np.matmul(error, X).shape= (1,) W.shape= (1,) err.shape= (20,) X.shape= (20, 1) np.matmul(X, W) shi ge sha: [-0.10558319 -0.30071126 -0.06218052 0.02215536 0.16174741 -0.13393534 -0.1145041 -0.01728356 -0.02504598 -0.18021185 0.06424425 0.05049771 -0.33680623 -0.06360484 0.06938888 0.05049771 0.06938888 -0.13174149 0.23545823 -0.10083732] np.matmul(X, W)= (20,) np.matmul(error, X).shape= (1,) W.shape= (1,) err.shape= (20,) X.shape= (20, 1) np.matmul(X, W) shi ge sha: [ 0.14675307 0.02219645 -0.13373798 0.01525506 -0.14815621 0.02827857 0.18107171 -0.13826838 0.0904903 -0.10452025 0.12557148 -0.01731561 0.07479649 -0.14815621 -0.24561593 0.09106361 -0.09507257 0.0176296 0.03763794 0.00456404] np.matmul(X, W)= (20,) np.matmul(error, X).shape= (1,) W.shape= (1,) err.shape= (20,) X.shape= (20, 1) np.matmul(X, W) shi ge sha: [-1.79445215e-01 -1.48460697e-02 -3.33452437e-01 -1.85652669e-01 -9.37091240e-04 -1.15685775e-01 -9.14098372e-02 1.53596081e-04 -1.48460697e-02 -6.89506770e-02 7.86505022e-02 -1.14174638e-01 -1.12337691e-01 1.79358332e-01 -1.14174638e-01 2.61094720e-01 -9.30419895e-02 -2.77729641e-02 1.79358332e-01 -1.26971209e-01] np.matmul(X, W)= (20,) np.matmul(error, X).shape= (1,) W.shape= (1,) err.shape= (20,) X.shape= (20, 
1) np.matmul(X, W) shi ge sha: [-0.1230817 -0.01772103 0.10123383 -0.17996557 -0.15613399 0.33112149 0.19626387 0.13472828 0.22078576 -0.14349029 -0.01772103 0.09708895 0.01597344 -0.15895991 -0.11067713 0.0944811 -0.16185351 0.02381504 0.15940959 0.06995018] np.matmul(X, W)= (20,) np.matmul(error, X).shape= (1,) W.shape= (1,) err.shape= (20,) X.shape= (20, 1) np.matmul(X, W) shi ge sha: [-1.62509877e-01 1.96586931e-04 -2.13701908e-01 -1.90087648e-01 -2.45296804e-02 5.25893869e-01 -1.43780504e-01 1.33663227e-01 2.07893658e-01 2.56510202e-01 1.79191966e-02 1.00664423e-01 -2.68071102e-01 -1.48065702e-01 2.07893658e-01 1.94166330e-01 1.58584096e-01 2.40344402e-01 3.34173954e-01 -1.83002575e-01] np.matmul(X, W)= (20,) np.matmul(error, X).shape= (1,) W.shape= (1,) err.shape= (20,) X.shape= (20, 1) np.matmul(X, W) shi ge sha: [-0.15895533 -0.24368162 0.09543243 -0.02462755 0.02808345 -0.19716056 -0.30965565 -0.00159706 -0.34057097 -0.27937756 0.27461281 -0.11751103 0.15319639 0.28026335 -0.2745044 0.12036141 0.258547 -0.15578765 -0.27947275 0.31370912] np.matmul(X, W)= (20,) np.matmul(error, X).shape= (1,) W.shape= (1,) err.shape= (20,) X.shape= (20, 1) np.matmul(X, W) shi ge sha: [-0.12939871 -0.29550498 -0.29550498 0.04507322 -0.86046441 0.04633419 -0.30085342 0.51083741 -0.21480015 -0.02723749 -0.30085342 0.21904133 -0.20610146 0.02568618 -0.30085342 0.22732159 0.75383998 0.63112469 0.25499174 0.02568618] np.matmul(X, W)= (20,) np.matmul(error, X).shape= (1,) W.shape= (1,) err.shape= (20,) X.shape= (20, 1) np.matmul(X, W) shi ge sha: [ 0.15803179 -0.32949007 0.19583921 0.23989072 0.3773136 0.0493635 0.01015015 -0.42084139 0.18061399 0.32376035 -0.2872934 -0.22941013 -0.18740362 -0.32949007 -0.0553052 -0.18740362 0.4576419 -0.26684971 -0.75042498 0.82559404] np.matmul(X, W)= (20,) np.matmul(error, X).shape= (1,) W.shape= (1,) err.shape= (20,) X.shape= (20, 1) np.matmul(X, W) shi ge sha: [ 0.16429693 -0.05953361 0.19268523 -0.80057921 -0.80057921 -0.25096837 0.16493547 
0.17746011 -0.30649449 0.25592368 -0.24080498 0.10949581 -0.80057921 0.08929875 0.05266268 0.34818232 0.34539869 0.73739397 0.34539869 0.45109321] np.matmul(X, W)= (20,) np.matmul(error, X).shape= (1,) W.shape= (1,) err.shape= (20,) X.shape= (20, 1) np.matmul(X, W) shi ge sha: [-1.01041465 0.6769341 0.18634114 -0.56438333 -0.86631542 0.01228142 0.00664992 0.01238326 0.8363323 0.39980249 -0.22620295 0.0567035 0.29026172 0.49223375 0.21853837 0.8363323 -0.2558315 -0.35516163 0.99894797 -0.10581409] np.matmul(X, W)= (20,) np.matmul(error, X).shape= (1,) W.shape= (1,) err.shape= (20,) X.shape= (20, 1) np.matmul(X, W) shi ge sha: [ 0.22397949 -0.49792892 0.12727917 0.40473107 -0.14634167 0.12727917 -0.10844863 -0.35627266 -0.17180601 -0.2323993 -0.04555128 -0.67738544 -0.33092043 0.04862077 -0.27991444 -0.23183489 -0.29172847 0.45865578 -0.0369923 0.17750813] np.matmul(X, W)= (20,) np.matmul(error, X).shape= (1,) W.shape= (1,) err.shape= (20,) X.shape= (20, 1) np.matmul(X, W) shi ge sha: [ 0.03997731 -0.26298605 -0.03505778 0.36804648 -0.40507659 0.82870964 0.05658403 0.39406693 -0.80897881 -0.90608217 -0.27253906 0.24298799 0.55256845 0.00663591 0.23646123 0.39091647 0.19081159 -0.40507659 0.12392553 0.24298799] np.matmul(X, W)= (20,) np.matmul(error, X).shape= (1,) W.shape= (1,) err.shape= (20,) X.shape= (20, 1) np.matmul(X, W) shi ge sha: [-0.27434967 1.07125609 -0.11347337 0.0131704 0.20060594 -0.04766181 0.0131704 0.72593348 0.36236 -0.52099943 -0.92902303 -1.22277639 -0.43531496 -0.03870626 -0.43531496 -0.3975795 -0.38086975 -0.29152459 0.05087351 0.18412675] np.matmul(X, W)= (20,) np.matmul(error, X).shape= (1,) W.shape= (1,) err.shape= (20,) X.shape= (20, 1) np.matmul(X, W) shi ge sha: [ 0.19103629 -0.31524416 -0.26126218 0.87710133 -0.46093851 0.07771033 -0.04538816 0.52247813 -0.04538816 0.49759033 -0.52431554 0.24866998 -0.10806027 0.44531093 0.84808803 -0.14581761 -0.57636381 -0.04538816 0.24866998 -0.00282352] np.matmul(X, W)= (20,) np.matmul(error, 
X).shape= (1,) W.shape= (1,) err.shape= (20,) X.shape= (20, 1) np.matmul(X, W) shi ge sha: [ 0.24169333 0.44216306 0.31152568 0.06790507 -0.95810478 0.50491205 0.20688562 0.48225617 0.03764439 0.03764439 0.01369532 -0.56781539 0.06790507 0.32101604 -0.43307734 -0.64039264 0.32101604 0.00735451 0.94987007 0.26930089] np.matmul(X, W)= (20,) np.matmul(error, X).shape= (1,) W.shape= (1,) err.shape= (20,) X.shape= (20, 1) np.matmul(X, W) shi ge sha: [ 0.13640584 0.14953103 -0.33906323 -0.63601233 0.10202265 -0.61991243 0.06560523 -0.31264713 0.0674406 -0.39010563 -0.74441528 -0.18834314 -0.44587111 0.26027481 0.01348978 -0.39010563 -0.74441528 0.11124509 -0.40867478 -0.44587111] np.matmul(X, W)= (20,) np.matmul(error, X).shape= (1,) W.shape= (1,) err.shape= (20,) X.shape= (20, 1) np.matmul(X, W) shi ge sha: [-0.85631074 0.30659706 0.67049506 0.25883433 0.4137883 0.21259755 -0.04694597 -0.15082234 -0.28714614 -0.39160819 -0.39300665 0.19759301 -0.0706839 -0.15082234 -0.71587523 0.01297259 0.4137883 -0.38139666 0.42230256 0.04231632] np.matmul(X, W)= (20,) np.matmul(error, X).shape= (1,) W.shape= (1,) err.shape= (20,) X.shape= (20, 1) np.matmul(X, W) shi ge sha: [ 0.17931287 0.03554744 -0.94826256 0.41241172 -0.58941228 -0.04641572 0.19969435 -0.46132096 0.25429969 -0.33720037 -0.37091214 0.38517996 -0.45561305 -0.42393392 0.18087674 -0.38718502 0.01282607 -0.41621306 -0.8466388 0.86728814] np.matmul(X, W)= (20,) np.matmul(error, X).shape= (1,) W.shape= (1,) err.shape= (20,) X.shape= (20, 1) np.matmul(X, W) shi ge sha: [ 4.99166150e-01 -5.41908065e-01 -2.47885432e-01 -1.18027258e-01 4.16522247e-04 4.86384387e-01 6.32482720e-02 5.43485800e-01 -3.08394754e-03 2.24501284e-01 -3.13716983e-01 -3.17495134e-01 1.38125509e-02 1.91516087e-01 4.36957080e-01 2.24501284e-01 7.08037332e-01 4.45948071e-01 5.09234208e-01 -4.13535067e-01] np.matmul(X, W)= (20,) np.matmul(error, X).shape= (1,) W.shape= (1,) err.shape= (20,) X.shape= (20, 1) np.matmul(X, W) shi ge sha: [-0.07580145 
0.71903918 -0.45159991 0.31907449 -0.07580145 -0.07580145 0.45287744 0.10521451 -0.25685586 0.417786 0.94735788 -0.31859168 0.21189882 -0.91830799 0.32879482 0.49394208 0.11472548 -0.91830799 0.287602 0.57953649] np.matmul(X, W)= (20,) np.matmul(error, X).shape= (1,) W.shape= (1,) err.shape= (20,) X.shape= (20, 1) np.matmul(X, W) shi ge sha: [-0.2075635 -0.42078408 -0.04252632 -0.43707503 -0.49137211 -1.22308595 -0.80004177 -0.27448058 0.06863841 -0.48242304 0.07432289 0.15032601 -0.6831743 0.7683776 0.15032601 -0.34455267 -0.08100273 -1.22308595 0.478017 0.478017 ] np.matmul(X, W)= (20,) np.matmul(error, X).shape= (1,) W.shape= (1,) err.shape= (20,) X.shape= (20, 1) np.matmul(X, W) shi ge sha: [-0.4766563 0.04789788 0.22861644 0.66204694 -0.32653629 -0.08000718 -0.32653629 -0.43026918 -0.41561254 0.29113037 -0.04315364 -0.32653629 0.06779483 1.02686635 -0.69230274 0.75893406 0.00795066 0.33677789 -0.42463201 -0.17071597] np.matmul(X, W)= (20,) np.matmul(error, X).shape= (1,) W.shape= (1,) err.shape= (20,) X.shape= (20, 1) Plotting the results... Regression lines start from the lightest line, with the darkest, black line as the last line. Do you see it getting closer to the data over each iteration?

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!