What is the role of “Flatten” in Keras?

醉话见心 2020-12-07 07:51

I am trying to understand the role of the Flatten function in Keras. Below is my code, which is a simple two-layer network. It takes in 2-dimensional data of sh

6 answers
  •  醉酒成梦
    2020-12-07 08:12

    short read:

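    Flattening a tensor means collapsing all of its dimensions into a single one. This is exactly what the Flatten layer does, except that it leaves the batch dimension alone: an input of shape (batch, d1, d2, ...) becomes (batch, d1 * d2 * ...).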
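    As a small illustration of what this means for the data itself, the NumPy sketch below reproduces the shape behaviour of Flatten (default channels_last format) with a plain reshape. The helper name `flatten_keep_batch` is hypothetical and this is not the Keras implementation; it just shows that each sample's values are laid out in row-major order along one axis while the batch axis is kept.

```python
import numpy as np


def flatten_keep_batch(x: np.ndarray) -> np.ndarray:
    """Hypothetical helper: merge every axis except axis 0 (the batch axis).

    Illustrates the shape behaviour of a Keras Flatten layer; it is a
    NumPy sketch, not Keras code.
    """
    return x.reshape(x.shape[0], -1)


# A batch of 2 samples, each of shape (3, 4), filled with 0..23.
x = np.arange(2 * 3 * 4).reshape(2, 3, 4)
y = flatten_keep_batch(x)

print(x.shape)  # (2, 3, 4)
print(y.shape)  # (2, 12)
# y[0] holds the 12 values of the first sample (0..11) in row-major order.
print(y[0])
```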

    long read:

    If we take the original model (with the Flatten layer) into consideration, we get the following model summary:

    Layer (type)                 Output Shape              Param #   
    =================================================================
    D16 (Dense)                  (None, 3, 16)             48        
    _________________________________________________________________
    A (Activation)               (None, 3, 16)             0         
    _________________________________________________________________
    F (Flatten)                  (None, 48)                0         
    _________________________________________________________________
    D4 (Dense)                   (None, 4)                 196       
    =================================================================
    Total params: 244
    Trainable params: 244
    Non-trainable params: 0
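    The numbers in this summary can be checked by hand. The pure-Python sketch below assumes each sample has shape (3, 2). That shape is inferred from the summary (48 = 2 * 16 + 16 for D16, and the middle dimension 3 in its output), since the question text is truncated. It uses the standard Dense parameter count of one weight per (input, unit) pair plus one bias per unit; the helper name `dense_params` is hypothetical.

```python
def dense_params(input_dim: int, units: int) -> int:
    """Hypothetical helper: parameters of a Dense layer with a bias."""
    return input_dim * units + units


# Assumption inferred from the summary: each sample is a (3, 2) tensor.
timesteps, features = 3, 2

# D16: Dense acts on the last axis only, (None, 3, 2) -> (None, 3, 16).
d16 = dense_params(features, 16)  # 2*16 + 16 = 48

# A (Activation) and F (Flatten) have no weights.
# F turns (None, 3, 16) into (None, 3*16) = (None, 48).
flat_dim = timesteps * 16

# D4: (None, 48) -> (None, 4).
d4 = dense_params(flat_dim, 4)  # 48*4 + 4 = 196

total = d16 + d4
print(d16, d4, total)  # 48 196 244
```

    Note that without the Flatten layer, D4 would be applied to the last axis of (None, 3, 16) and produce (None, 3, 4) instead of one 4-vector per sample.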
    

    For this summary, the next image will hopefully give a little more insight into the input and output sizes of each layer.

    As you can see, the output shape of the Flatten layer is (None, 48). Here is the tip: read it as (1, 48) or (2, 48) or ... or (16, 48) ... or (32, 48), ...

    In fact, None in that position means any batch size. Recall that for a 2D tensor such as this one, the first dimension is the batch size and the second is the number of features, here the 48 features passed on to the next Dense layer.
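    To see why the batch dimension is left as None, the NumPy sketch below applies the same reshape that Flatten performs (keep axis 0, merge the rest) to inputs of several batch sizes. The per-sample shape (3, 16) is taken from the Activation output in the summary; only the batch size changes, and the feature dimension stays 48.

```python
import numpy as np

# Per-sample shape from the summary: the Activation layer outputs (None, 3, 16).
per_sample = (3, 16)

shapes = {}
for batch_size in (1, 2, 16, 32):
    x = np.zeros((batch_size, *per_sample))
    # Keep the batch axis, merge the remaining axes into one.
    shapes[batch_size] = x.reshape(batch_size, -1).shape

print(shapes)  # {1: (1, 48), 2: (2, 48), 16: (16, 48), 32: (32, 48)}
```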

    The role of the Flatten layer in Keras is super simple:

    A flatten operation reshapes a tensor to (batch size, N), where N is the number of elements in each sample, i.e. the product of all dimensions except the batch dimension. For example, (None, 3, 16) becomes (None, 48) because 3 * 16 = 48.
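    The same rule can be written as a shape computation. This is a minimal sketch with a hypothetical function name, not a Keras API: the batch entry (possibly None) is passed through, and the remaining dimensions, which must be known, are multiplied together.

```python
import math
from typing import Optional, Tuple

Shape = Tuple[Optional[int], ...]


def flatten_output_shape(input_shape: Shape) -> Shape:
    """Hypothetical helper: output shape of a flatten over non-batch axes.

    The batch entry (index 0) is kept as-is and may be None; the other
    entries must be integers and are multiplied together.
    """
    batch, *rest = input_shape
    return (batch, math.prod(rest))


print(flatten_output_shape((None, 3, 16)))      # (None, 48)
print(flatten_output_shape((None, 28, 28, 1)))  # (None, 784)
```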


    Note: I used the model.summary() method to provide the output shape and parameter details.
