I am trying to understand the role of the Flatten function in Keras. Below is my code, which is a simple two-layer network. It takes in 2-dimensional data of sh
Here I would like to present an alternative to the Flatten layer. It may help you understand what happens internally. The alternative takes three lines of code where Flatten needs only one. Instead of using
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

#==========================================Build a Model
model = tf.keras.models.Sequential()
model.add(keras.layers.Flatten(input_shape=(28, 28, 3)))  # reshapes each sample to (2352,) = 28x28x3
model.add(layers.Rescaling(1./255))  # normalize to [0, 1]; on TF < 2.6 use layers.experimental.preprocessing.Rescaling
model.add(keras.layers.Dense(128, activation=tf.nn.relu))
model.add(keras.layers.Dense(2, activation=tf.nn.softmax))
model.summary()  # summary of the model (no build() call needed: the input shape is already known)
we can use
#==========================================Build a Model
# Symbolic placeholder tensor for a batch of 28x28x3 inputs
# (tf.keras.backend.placeholder is not part of the TF 2.x public API, so keras.Input is used here).
tensor = keras.Input(shape=(28, 28, 3), dtype=tf.float32)
model = tf.keras.models.Sequential()
model.add(keras.layers.InputLayer(input_tensor=tensor))
model.add(keras.layers.Reshape((2352,)))  # flatten each sample: 28*28*3 = 2352
model.add(layers.Rescaling(1./255))  # normalize
model.add(keras.layers.Dense(128, activation=tf.nn.relu))
model.add(keras.layers.Dense(2, activation=tf.nn.softmax))
model.summary()  # summary of the model
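To make the reshape arithmetic concrete, here is a plain-NumPy sketch of what Reshape((2352,)) (and Flatten) compute on a channels-last batch. NumPy stands in for the Keras layers here, and the small made-up batch `batch` is only for illustration. The batch axis is kept and the remaining axes are merged in row-major order:

```python
import numpy as np

# Made-up batch of 2 "images", 28x28 pixels, 3 channels (channels last).
batch = np.arange(2 * 28 * 28 * 3, dtype=np.float32).reshape(2, 28, 28, 3)

# Keep the batch axis, merge the other three: 28 * 28 * 3 = 2352.
flat = batch.reshape(batch.shape[0], -1)
print(flat.shape)  # (2, 2352)

# Row-major order: element [n, i, j, c] lands at column (i * 28 + j) * 3 + c.
n, i, j, c = 1, 5, 7, 2
print(flat[n, (i * 28 + j) * 3 + c] == batch[n, i, j, c])  # True

# Rescaling(1./255) is then just an elementwise multiply.
scaled = flat * (1.0 / 255)
```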
In the second case, we first create a symbolic placeholder tensor for the input, then wrap it in an InputLayer, and finally reshape it into a flat vector of 2352 values per sample. So basically,
Create tensor->Create InputLayer->Reshape == Flatten
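The equivalence can also be checked directly. This is a minimal sketch, assuming TensorFlow 2.x with its bundled Keras; the random batch `x` is made up for illustration. It applies a standalone Flatten layer and a Reshape((2352,)) layer to the same batch and compares the outputs:

```python
import numpy as np
import tensorflow as tf

# Made-up batch of 4 RGB images of size 28x28 with values in [0, 255).
x = np.random.default_rng(0).uniform(0, 255, size=(4, 28, 28, 3)).astype("float32")

flatten_out = tf.keras.layers.Flatten()(x)
reshape_out = tf.keras.layers.Reshape((28 * 28 * 3,))(x)

print(flatten_out.shape)  # (4, 2352)
print(reshape_out.shape)  # (4, 2352)
# Same values in the same order.
print(np.array_equal(np.asarray(flatten_out), np.asarray(reshape_out)))  # True
```

Both layers keep the batch axis and merge the remaining axes in the same order, so the outputs match element for element.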
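Flatten is a convenience layer that does all of this for you: together with its input_shape argument, it replaces the placeholder tensor, the InputLayer and the Reshape with a single line, and it works out the flattened size (2352) itself instead of requiring you to hard-code it. Of course, both approaches have their own use cases. Keras is flexible enough to let you build a model the way you want.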