Wrong number of dimensions on model.fit

馋奶兔 submitted on 2019-12-04 02:46:50

You are trying to run an RNN. This means that you want to include previous time steps in your calculation. To do so, you have to preprocess your data before giving it to the SimpleRNN layer.
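To make "including previous time steps" concrete, here is a minimal NumPy sketch of the recurrence a simple RNN applies to one sequence. The sizes, the weight names (W_in, W_rec, bias) and the random initialisation are assumptions chosen for illustration; this is not the actual Keras implementation.

```python
import numpy as np

rng = np.random.default_rng(0)

# Illustrative sizes (assumptions, not taken from the question).
nb_timesteps, nb_features, nb_units = 3, 4, 5

# One sequence: 3 time steps with 4 features each.
sequence = rng.random((nb_timesteps, nb_features))

# Randomly initialised weights, for illustration only.
W_in = rng.normal(size=(nb_features, nb_units))   # input -> hidden
W_rec = rng.normal(size=(nb_units, nb_units))     # hidden -> hidden
bias = np.zeros(nb_units)

hidden = np.zeros(nb_units)  # state before the first time step
for x_t in sequence:
    # The new state depends on the current input AND on the previous state.
    hidden = np.tanh(x_t @ W_in + hidden @ W_rec + bias)

print(hidden.shape)  # (5,)
```

The loop runs over the time axis, which is why the layer needs that axis to exist in its input.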

For simplicity, let us assume that instead of 88 samples with 88 features each, you have 8 samples with 4 features each. When using an RNN, you have to decide on a maximum length for the backpropagation (i.e. the number of previous time steps that are included in the calculation). In this case, you could choose to include a maximum of 2 previous time steps. Therefore, to calculate the weights of the RNN, you have to provide at each time step the input of the current time step (with its 4 features) and the input of the 2 previous time steps (with 4 features each), just like in this visualization:

sequence    sample0  sample1  sample2  sample3  sample4  sample5  sample6 sample7       
   0        |-----------------------|
   1                 |-----------------------|
   2                          |-----------------------|
   3                                   |-----------------------|
   4                                             |----------------------|
   5                                                      |----------------------|

So instead of giving a (nb_samples, nb_features) matrix as input to the SimpleRNN, you have to give it an input of shape (nb_sequences, nb_timesteps, nb_features). In this example, it means that instead of an (8x4) input you give it a (6x3x4) input: 8 samples with windows of 3 time steps yield 8 - 3 + 1 = 6 sequences, matching rows 0 to 5 of the visualization.

The Keras Embedding layer will not do this job (it maps integer indices to dense vectors), but you can write a short piece of code for it:

import numpy as np

data = np.random.rand(8, 4)    # (nb_samples, nb_features)
nb_timesteps = 3    # 2 (previous) + 1 (current)
nb_sequences = data.shape[0] - nb_timesteps + 1    # 8-3+1=6

# Stack the overlapping windows into shape (nb_sequences, nb_timesteps, nb_features)
input_3D = np.array([data[i:i+nb_timesteps] for i in range(nb_sequences)])
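
The same windowing can also be written without a Python loop, using an index matrix. The sketch below is one possible alternative, not the only way to do it. It fills row i of the data with the value i (an assumption made purely so the rows are easy to recognise), so the printed index matrix can be compared directly with the visualization.

```python
import numpy as np

nb_samples, nb_features, nb_timesteps = 8, 4, 3

# Row i is filled with the value i, so every sample is easy to recognise.
data = np.repeat(np.arange(nb_samples), nb_features).reshape(nb_samples, nb_features)

nb_sequences = nb_samples - nb_timesteps + 1   # 8 - 3 + 1 = 6

# window_index[s, t] is the sample used at time step t of sequence s.
window_index = np.arange(nb_sequences)[:, None] + np.arange(nb_timesteps)[None, :]

# Fancy indexing with a (6, 3) index array on axis 0 gives shape (6, 3, 4).
windows = data[window_index]

print(window_index)    # first row [0 1 2], last row [5 6 7]
print(windows.shape)   # (6, 3, 4)
```

Each row of window_index corresponds to one bar in the visualization: sequence 0 covers samples 0 to 2, and sequence 5 covers samples 5 to 7.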

The error is probably because your input dimensions are not in the format of:

(nb_samples, timesteps, input_dim)

It is expecting 3 dimensions, and you're providing only 2 of them (88,88).
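
As a rough sketch of how an (88, 88) array could be given the missing axis, the NumPy snippet below shows two options. The random data and the window length of 3 are assumptions for illustration; which option makes sense depends on what the rows of your data actually represent.

```python
import numpy as np

x = np.random.rand(88, 88)   # 2D: (nb_samples, input_dim), as in the error

# Option A: treat every sample as a sequence of length 1 -> (88, 1, 88).
# This satisfies the shape requirement, but gives the RNN no history to use.
x_single_step = x[:, np.newaxis, :]

# Option B: overlapping windows of 3 consecutive rows -> (86, 3, 88).
nb_timesteps = 3  # assumed window length
x_windows = np.stack(
    [x[i:i + nb_timesteps] for i in range(len(x) - nb_timesteps + 1)]
)

for name, arr in [("original", x), ("single step", x_single_step), ("windows", x_windows)]:
    status = "3D, has a time axis" if arr.ndim == 3 else "2D, time axis missing"
    print(name, arr.shape, status)
```

With either option, the input_shape passed to the first Keras layer would be (timesteps, input_dim) without the sample dimension, i.e. (1, 88) for option A or (3, 88) for option B.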
