What is the difference between the predict and predict_on_batch methods of a Keras model?

被撕碎了的回忆 2020-12-17 09:52

According to the Keras documentation:

predict_on_batch(self, x)
Returns predictions for a single batch of samples.

However, there does not
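Roughly, in Keras 2.x `predict` splits `x` into slices of `batch_size` samples (32 when unspecified), runs the model on each slice and copies every result into arrays preallocated with one row per input sample. `predict_on_batch` runs a single forward pass over all of `x` and returns the result unchanged. The NumPy sketch below is a simplified emulation of that difference under this assumption, not Keras code; `model_fn`, `predict_like` and `predict_on_batch_like` are hypothetical names.

```python
import numpy as np


def model_fn(batch):
    # Hypothetical per-sample "model": doubles every value, keeps the batch axis.
    return batch * 2.0


def predict_on_batch_like(fn, x):
    # One forward pass over the whole input; the output is returned unchanged.
    return fn(x)


def predict_like(fn, x, batch_size=32):
    # Simplified emulation of Keras 2.x `predict` (an assumption, not the real
    # source): run fn on slices of `batch_size` samples and copy each result
    # into an array preallocated with one row per input sample.
    num_samples = x.shape[0]
    out = None
    for start in range(0, num_samples, batch_size):
        end = min(start + batch_size, num_samples)
        batch_out = fn(x[start:end])
        if out is None:
            shape = (num_samples,) + batch_out.shape[1:]
            out = np.zeros(shape, dtype=batch_out.dtype)
        # Fails with ValueError when batch_out cannot broadcast to this slice.
        out[start:end] = batch_out
    return out


x = np.arange(10, dtype=np.float32).reshape(5, 2)  # 5 samples of 2 features
# For a model that keeps the batch axis, both paths agree for any batch size.
print(np.array_equal(predict_like(model_fn, x, batch_size=2),
                     predict_on_batch_like(model_fn, x)))  # True
```

For such a model the main practical difference is that `batch_size` lets `predict` bound how much data goes through the model at once. The two calls diverge when an output's first dimension is not the number of samples, which is what the answer below runs into.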

3 Answers
  •  生来不讨喜
    2020-12-17 10:29

    I just want to add something that does not fit in a comment. It seems that predict checks the output shape carefully:

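    import keras
    import numpy as np

    # Toy layer that sums over axis 0, i.e. over the *batch* axis, while
    # compute_output_shape still reports the input shape unchanged.
    class ExtractShape(keras.layers.Layer):
        def call(self, x):
            return keras.backend.sum(x, axis=0)
        def compute_output_shape(self, input_shape):
            return input_shape
    
    a = keras.layers.Input((None, None))
    b = ExtractShape()(a)
    m = keras.Model(a, b)
    m.compile(optimizer=keras.optimizers.Adam(), loss='binary_crossentropy')
    A = np.ones((5,4,3))  # 5 samples, each of shape (4, 3)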
    

    Then:

    In [163]: m.predict_on_batch(A)
    Out[163]: 
    array([[5., 5., 5.],
           [5., 5., 5.],
           [5., 5., 5.],
           [5., 5., 5.]], dtype=float32)
    In [164]: m.predict_on_batch(A).shape
    Out[164]: (4, 3)
    

    But:

    In [165]: m.predict(A)
    ---------------------------------------------------------------------------
    ValueError                                Traceback (most recent call last)
     in ()
    
    ----> 1 m.predict(A)
    
    ~/miniconda3/envs/ccia/lib/python3.6/site-packages/keras/engine/training.py in predict(self, x, batch_size, verbose, steps)
       1746         f = self.predict_function
       1747         return self._predict_loop(f, ins, batch_size=batch_size,
    -> 1748                                   verbose=verbose, steps=steps)
       1749 
       1750     def train_on_batch(self, x, y,
    
    ~/miniconda3/envs/ccia/lib/python3.6/site-packages/keras/engine/training.py in _predict_loop(self, f, ins, batch_size, verbose, steps)
       1306                         outs.append(np.zeros(shape, dtype=batch_out.dtype))
       1307                 for i, batch_out in enumerate(batch_outs):
    -> 1308                     outs[i][batch_start:batch_end] = batch_out
       1309                 if verbose == 1:
       1310                     progbar.update(batch_end)
    
    ValueError: could not broadcast input array from shape (4,3) into shape (5,3)
    

    I am not sure whether this is really a bug.
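    The failure can be reproduced with NumPy alone. `ExtractShape.call` sums over axis 0, which is the batch axis, so the five `(4, 3)` samples in `A` collapse into a single `(4, 3)` array of fives, and `predict_on_batch` simply returns that. `predict` instead preallocates one output row per input sample, presumably with shape `(num_samples,) + batch_out.shape[1:]` (consistent with the `(5,3)` in the error), and then copies the batch into `outs[i][batch_start:batch_end]`. This sketch mimics only those two steps of `_predict_loop` under that assumption; it is not the Keras source.

```python
import numpy as np

A = np.ones((5, 4, 3))  # 5 samples, each of shape (4, 3)

# What ExtractShape.call computes: a sum over axis 0 (the batch axis),
# so the output's first dimension is 4, not the 5 samples that went in.
batch_out = A.sum(axis=0)
print(batch_out.shape)  # (4, 3)

# predict_on_batch would return batch_out as-is. predict preallocates one
# output row per input sample (shape rule assumed to mirror Keras 2.x).
num_samples = A.shape[0]
batch_start, batch_end = 0, num_samples  # a single batch: 5 < default batch_size of 32
out = np.zeros((num_samples,) + batch_out.shape[1:], dtype=batch_out.dtype)
print(out.shape)  # (5, 3)

# Copying the (4, 3) result into the (5, 3) slice cannot broadcast.
error_message = None
try:
    out[batch_start:batch_end] = batch_out
except ValueError as err:
    error_message = str(err)
print(error_message)
```

    Seen this way, the behaviour follows from `predict` assuming that every model output keeps the batch axis as its first dimension, an assumption that `predict_on_batch` never needs to make.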
