What is the role of “Flatten” in Keras?

醉话见心 2020-12-07 07:51

I am trying to understand the role of the Flatten function in Keras. Below is my code, which is a simple two-layer network. It takes in 2-dimensional data of sh

6 Answers
  •  余生分开走
    2020-12-07 08:26

    Here I would like to present an alternative to the Flatten layer, which may help clarify what it does internally. The alternative uses three lines of code in place of the single Flatten call. Instead of using

        # ========================================== Build a model
        import tensorflow as tf

        model = tf.keras.models.Sequential()

        # Flatten reshapes each (28, 28, 3) sample to (2352,), since 28 * 28 * 3 = 2352
        model.add(tf.keras.layers.Flatten(input_shape=(28, 28, 3)))
        # Rescale 0-255 pixel values to [0, 1]
        # (older TF releases expose this layer under layers.experimental.preprocessing)
        model.add(tf.keras.layers.Rescaling(1. / 255))
        model.add(tf.keras.layers.Dense(128, activation=tf.nn.relu))
        model.add(tf.keras.layers.Dense(2, activation=tf.nn.softmax))

        model.build()
        model.summary()  # print a summary of the model
    

    we can use

        # ========================================== Build a model
        import tensorflow as tf

        # Symbolic input tensor; None stands for the batch dimension.
        # Note: backend.placeholder is a legacy API and may be missing in newer Keras versions.
        tensor = tf.keras.backend.placeholder(dtype=tf.float32, shape=(None, 28, 28, 3))

        model = tf.keras.models.Sequential()

        model.add(tf.keras.layers.InputLayer(input_tensor=tensor))
        model.add(tf.keras.layers.Reshape([2352]))  # 28 * 28 * 3 = 2352
        # Rescale 0-255 pixel values to [0, 1]
        model.add(tf.keras.layers.Rescaling(1. / 255))
        model.add(tf.keras.layers.Dense(128, activation=tf.nn.relu))
        model.add(tf.keras.layers.Dense(2, activation=tf.nn.softmax))

        model.build()
        model.summary()  # print a summary of the model
    

    In the second case, we first create a tensor (using a placeholder) and then create an InputLayer from it. After that, we reshape the tensor into a flat form. So basically,

    Create tensor->Create InputLayer->Reshape == Flatten
    

    Flatten is a convenience layer that does all of this automatically. Of course, both approaches have their own use cases, and Keras is flexible enough to let you build the model whichever way you prefer.
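
    To make the equivalence concrete, here is a minimal sketch in plain NumPy rather than Keras. It assumes a made-up batch of 4 samples, and the helpers `flatten` and `reshape_to` are hypothetical stand-ins for what the Flatten and Reshape layers do to the shape of their input. They are not the Keras implementations.

```python
import numpy as np

# Hypothetical batch of 4 samples shaped (28, 28, 3), matching the models above.
batch = np.arange(4 * 28 * 28 * 3, dtype=np.float32).reshape(4, 28, 28, 3)


def flatten(x):
    """Collapse every axis except the first (batch) axis, in row-major order."""
    return x.reshape(x.shape[0], -1)


def reshape_to(x, target_shape):
    """Reshape each sample to target_shape, leaving the batch axis untouched."""
    return x.reshape((x.shape[0], *target_shape))


flat = flatten(batch)                 # size per sample is inferred
reshaped = reshape_to(batch, [2352])  # size per sample is spelled out

print(flat.shape)                      # (4, 2352)
print(np.array_equal(flat, reshaped))  # True
```

    The practical difference shown here is that the reshape version needs the flattened size (2352) written out by hand, while the flatten version works it out from the input shape. In both cases the batch axis is left alone.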
