Returning probabilities in a classification prediction in Keras?

情话喂你 2021-01-02 03:02

I am trying to make a simple proof-of-concept where I can see the probabilities of different classes for a given prediction.

However, everything I try seems to only return the predicted class, not the probabilities.

    慢半拍i
    2021-01-02 03:36

    Keras predict indeed returns probabilities, and not classes.
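    Assuming the model's last layer uses a softmax activation (the usual choice for multi-class classification, though the question does not show the model), each output row is a probability distribution over the classes. The minimal pure-Python sketch below, which does not use Keras, shows what softmax computes; the logits are hypothetical values for a single 3-class sample.

```python
import math

def softmax(logits):
    """Map raw scores (logits) to probabilities that are positive and sum to 1."""
    m = max(logits)  # subtracting the max avoids overflow and leaves the result unchanged
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical logits for one sample in a 3-class problem.
probs = softmax([2.0, 1.0, 0.1])
print([round(p, 3) for p in probs])  # [0.659, 0.242, 0.099]
```

    Because the output is already normalized, predict hands back these probabilities directly; turning them into a class label is a separate step (an argmax).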

    I cannot reproduce your issue with the following system configuration:

    Python version 2.7.12
    Tensorflow version 1.3.0
    Keras version 2.0.9
    Numpy version 1.13.3
    

    Here is my prediction output for your x_slice with the loaded model (trained for 20 epochs, as in your code):

    print(prev_model.predict(x_slice))
    # Result: 
    [[  1.00000000e+00   3.31656316e-37   1.07806675e-21   7.11765177e-30
        2.48000320e-31   5.34837679e-28   3.12470132e-24   4.65175406e-27
        8.66994134e-31   5.26426367e-24]
     [  0.00000000e+00   5.34361977e-30   3.91144999e-35   0.00000000e+00
        1.00000000e+00   0.00000000e+00   1.05583665e-36   1.01395577e-29
        0.00000000e+00   1.70868685e-29]
     [  3.99137559e-38   1.00000000e+00   1.76682222e-24   9.33333581e-31
        3.99846307e-15   1.17745576e-24   1.87529709e-26   2.18951752e-20
        3.57518280e-17   1.62027896e-28]
     [  6.48006586e-26   1.48974980e-17   5.60530329e-22   1.81973780e-14
        9.12573406e-10   1.95987500e-14   8.08566866e-27   1.17901132e-12
        7.33970447e-13   1.00000000e+00]
     [  2.01602060e-16   6.58242856e-14   1.00000000e+00   6.84244084e-09
        1.19809885e-16   7.94907624e-14   3.10690434e-19   8.02848586e-12
        4.68330721e-11   5.14736501e-15]
     [  2.31014903e-35   1.00000000e+00   6.02224725e-21   2.35928828e-23
        7.50006509e-15   4.06930881e-22   1.13288827e-24   4.20440718e-17
        4.95182972e-17   1.85492109e-18]
     [  0.00000000e+00   0.00000000e+00   0.00000000e+00   1.00000000e+00
        0.00000000e+00   6.30200370e-27   0.00000000e+00   5.19937755e-33
        1.63205659e-31   1.21508034e-20]
     [  1.44608573e-26   1.00000000e+00   1.78712268e-18   6.84598301e-19
        1.30042071e-11   2.53873986e-14   5.83169942e-17   1.20201071e-12
        2.21844570e-14   3.75015198e-15]
     [  0.00000000e+00   6.29184453e-34   9.22474943e-29   0.00000000e+00
        1.00000000e+00   3.05067233e-34   1.43097161e-28   1.34234082e-29
        4.28647272e-36   9.29760838e-34]
     [  4.68828449e-30   5.55172479e-20   3.26705529e-19   9.99999881e-01
        3.49577992e-22   1.27715460e-11   4.99185615e-36   1.19164204e-20
        4.21086124e-16   1.52631387e-07]]
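    A quick sanity check on the output above: if these are probabilities, each row should sum to 1 (up to float32 rounding), and the predicted class is the index of the largest entry. The sketch below copies the first and last rows of the 20-epoch output verbatim; argmax is a small helper written here for illustration (with NumPy the equivalent would be np.argmax(probs, axis=-1)).

```python
# First and last rows copied from the 20-epoch predict output above.
rows = [
    [1.00000000e+00, 3.31656316e-37, 1.07806675e-21, 7.11765177e-30,
     2.48000320e-31, 5.34837679e-28, 3.12470132e-24, 4.65175406e-27,
     8.66994134e-31, 5.26426367e-24],
    [4.68828449e-30, 5.55172479e-20, 3.26705529e-19, 9.99999881e-01,
     3.49577992e-22, 1.27715460e-11, 4.99185615e-36, 1.19164204e-20,
     4.21086124e-16, 1.52631387e-07],
]

def argmax(row):
    """Index of the largest value, i.e. the predicted class for one row."""
    return max(range(len(row)), key=row.__getitem__)

labels = [argmax(row) for row in rows]
sums = [sum(row) for row in rows]
print(labels)  # [0, 3]
# Each row sums to 1 within float32 precision (roughly 1e-7).
print(all(abs(s - 1.0) < 1e-6 for s in sums))  # True
```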
    

    I suspect a rounding issue when printing (or you have trained for many more epochs, and the probabilities on your training set have become very close to 1).
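    Both effects can be illustrated without Keras. Keras computes in float32 by default, and the largest float32 value below 1 is 1 - 2**-24 (about 0.99999994), so a probability much closer to 1 than that is stored as exactly 1.0; printing with 8 decimals in scientific notation hides even smaller gaps. The sketch below emulates float32 rounding with the standard struct module; the value 1 - 1e-10 is an arbitrary illustration, not taken from the question.

```python
import struct

def to_float32(x):
    """Round a Python float (float64) to the nearest IEEE 754 float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]

p = 1.0 - 1e-10               # close to, but not exactly, 1
print(p == 1.0)               # False: float64 can tell them apart
print(to_float32(p) == 1.0)   # True: the gap is below float32 resolution near 1
print(format(p, ".8e"))       # 1.00000000e+00: 8-decimal printing also hides it
```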

    To convince yourself that you do get probabilities and not class predictions, try getting predictions from a model trained for a single epoch; you should normally see far fewer 1.0s. Here is the output for a model trained with epochs=1:

    print(model.predict(x_slice))
    # Result: 
    
    [[  9.99916673e-01   5.36548761e-08   6.10747229e-05   8.21199933e-07
        6.64725164e-08   6.78853041e-07   9.09637220e-06   4.56192402e-06
        1.62688798e-06   5.23997733e-06]
     [  7.59836894e-07   1.78043920e-05   1.79073555e-04   2.95592145e-05
        9.98031914e-01   1.75839632e-05   5.90557102e-06   1.27705920e-03
        3.94643757e-06   4.36416740e-04]
     [  4.48473330e-08   9.99895334e-01   2.82608235e-05   5.33154832e-07
        9.78453227e-06   1.58954310e-06   3.38150176e-06   5.26260410e-05
        8.09341054e-06   3.28643267e-07]
     [  7.38236849e-07   4.80247072e-05   2.81726116e-05   4.77648537e-05
        7.21933879e-03   2.52177160e-05   3.88786475e-07   3.56770557e-04
        2.83472677e-04   9.91990149e-01]
     [  5.03611082e-05   2.69402866e-04   9.92011130e-01   4.68175858e-03
        9.57477605e-05   4.26214538e-04   7.66683661e-05   7.05923303e-04
        1.45670515e-03   2.26032615e-04]
     [  1.36330849e-10   9.99994516e-01   7.69141934e-07   1.44130311e-07
        9.52201333e-07   1.45219332e-07   4.43408908e-07   6.93398249e-07
        2.18685204e-06   1.50741769e-07]
     [  2.39427478e-09   3.75754922e-07   3.89349816e-06   9.99889374e-01
        1.85837867e-09   1.16176770e-05   1.89989760e-11   3.12301523e-07
        1.13220040e-05   8.29571582e-05]
     [  1.45760115e-08   9.99900222e-01   3.67058942e-06   4.04857201e-06
        1.97999962e-05   7.85745397e-06   8.13850420e-06   1.87294081e-05
        2.81870762e-05   9.38157609e-06]
     [  7.52560858e-09   8.84437856e-09   9.71140025e-07   5.20911703e-10
        9.99986649e-01   3.12135370e-07   1.06521384e-05   1.25693066e-06
        7.21853368e-08   5.21001624e-08]
     [  8.67672298e-08   2.17907742e-04   2.45352840e-06   9.95455265e-01
        1.43749105e-06   1.51766278e-03   1.83744309e-08   3.83995541e-07
        9.90309782e-05   2.70584645e-03]]
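    The effect of the number of epochs can be reproduced on a toy problem. The following is a minimal pure-Python softmax regression, not Keras, trained on a hypothetical linearly separable dataset; all names (DATA, train, mean_true_class_prob) are illustrative. As a simplification, one "epoch" here is a single full-batch gradient step, whereas a Keras epoch usually consists of many mini-batch updates. On separable training data, the average probability assigned to the true class keeps rising with more training, consistent with the 20-epoch output above being dominated by values that print as 1.0.

```python
import math

# Hypothetical, linearly separable toy data: ((x1, x2), class label).
DATA = [((2.0, 0.0), 0), ((3.0, 0.0), 0),
        ((0.0, 2.0), 1), ((0.0, 3.0), 1),
        ((-2.0, -2.0), 2), ((-3.0, -3.0), 2)]
NUM_CLASSES = 3
NUM_FEATURES = 3  # x1, x2 and a constant bias input

def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def features(point):
    """Append a constant 1.0 so each class row of weights has a bias term."""
    return list(point) + [1.0]

def predict(weights, point):
    """Class probabilities for one point under the current weights."""
    x = features(point)
    logits = [sum(w * xi for w, xi in zip(row, x)) for row in weights]
    return softmax(logits)

def train(epochs, lr=0.1):
    """Full-batch gradient descent on softmax cross-entropy from zero weights."""
    weights = [[0.0] * NUM_FEATURES for _ in range(NUM_CLASSES)]
    for _ in range(epochs):
        grads = [[0.0] * NUM_FEATURES for _ in range(NUM_CLASSES)]
        for point, label in DATA:
            probs = predict(weights, point)
            x = features(point)
            for k in range(NUM_CLASSES):
                # Gradient of cross-entropy w.r.t. logit k is p_k - 1[k == label].
                err = probs[k] - (1.0 if k == label else 0.0)
                for j in range(NUM_FEATURES):
                    grads[k][j] += err * x[j] / len(DATA)
        for k in range(NUM_CLASSES):
            for j in range(NUM_FEATURES):
                weights[k][j] -= lr * grads[k][j]
    return weights

def mean_true_class_prob(weights):
    """Average probability the model gives to each training point's true class."""
    return sum(predict(weights, p)[y] for p, y in DATA) / len(DATA)

for epochs in (1, 20, 200):
    print(epochs, round(mean_true_class_prob(train(epochs)), 4))
```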
    
