What is the role of “Flatten” in Keras?

醉话见心 2020-12-07 07:51

I am trying to understand the role of the Flatten function in Keras. Below is my code, which is a simple two-layer network. It takes in 2-dimensional data of sh

6 answers
  •  没有蜡笔的小新
    2020-12-07 08:10

    If you read the Keras documentation entry for Dense, you will see that this call:

    Dense(16, input_shape=(5,3))
    

    would result in a Dense layer with 3 inputs and 16 outputs, applied independently to each of the 5 steps. So, if D(x) transforms a 3-dimensional vector into a 16-dimensional vector, the output of your layer is a sequence of vectors: [D(x[0,:]), D(x[1,:]), ..., D(x[4,:])] with shape (5, 16). To get the behavior you specify, first Flatten your input into a single vector (15-d for the (5, 3) example above) and then apply Dense:

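    from keras.models import Sequential
    from keras.layers import Activation, Dense, Flatten

    model = Sequential()
    # Flatten each (3, 2) sample into one 6-d vector so Dense sees a single vector per sample
    model.add(Flatten(input_shape=(3, 2)))
    model.add(Dense(16))
    model.add(Activation('relu'))
    model.add(Dense(4))
    model.compile(loss='mean_squared_error', optimizer='SGD')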
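To make the shape difference concrete, here is a minimal sketch in plain Python rather than Keras. The helpers `dense` and `make_params` are hypothetical names introduced only for this illustration, and the random weights stand in for whatever a trained layer would hold. It mimics what happens to one (5, 3) sample with and without flattening:

```python
import random

def dense(vec, weights, bias):
    """Affine map for one vector: out[j] = sum_i vec[i] * weights[i][j] + bias[j]."""
    return [sum(v * w_row[j] for v, w_row in zip(vec, weights)) + bias[j]
            for j in range(len(bias))]

def make_params(n_in, n_out, rng):
    """Hypothetical helper: random weights of shape (n_in, n_out) and a zero bias."""
    weights = [[rng.uniform(-1, 1) for _ in range(n_out)] for _ in range(n_in)]
    bias = [0.0] * n_out
    return weights, bias

rng = random.Random(0)
x = [[rng.uniform(-1, 1) for _ in range(3)] for _ in range(5)]  # one sample, shape (5, 3)

# Without flattening: the same 3 -> 16 map is applied to each of the 5 rows.
w, b = make_params(3, 16, rng)
per_step = [dense(row, w, b) for row in x]
per_step_shape = (len(per_step), len(per_step[0]))

# With flattening: the rows are concatenated into one 15-d vector,
# and a single 15 -> 16 map produces one output vector.
flat = [value for row in x for value in row]
w_flat, b_flat = make_params(len(flat), 16, rng)
flat_out = dense(flat, w_flat, b_flat)
flat_shape = (len(flat_out),)

print(per_step_shape, flat_shape)  # (5, 16) (16,)
```

Note that the per-step version needs only 3 * 16 weights shared across all rows, while the flattened version needs 15 * 16, because every input value gets its own connection to every output.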
    

    EDIT: Since some people struggled to understand this, here is an explanatory image: (image omitted)
