LSTM Keras API predicting multiple outputs

安稳与你 submitted on 2019-11-29 03:59:48

Question


I'm training an LSTM model whose input is a sequence of 50 time steps of 3 different features, laid out as below:

#x_train, shape (50 samples, 50 time steps, 3 features)
[[[a0,b0,c0], ..., [a49,b49,c49]],
 [[a1,b1,c1], ..., [a50,b50,c50]],
 ...
 [[a49,b49,c49], ..., [a98,b98,c98]]]

The target (dependent variable) is the next value of a after each window:

#y_train
[a50, a51, a52, ... a99]
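For reference, this x_train / y_train layout can be produced with a sliding window over a single (time steps, features) array. This is a minimal NumPy sketch under my own assumptions, not code from the question: `make_windows` and the synthetic `series` are hypothetical names, and the data is just a counter so the window boundaries are easy to check.

```python
import numpy as np

def make_windows(series, window=50):
    """Split a (time_steps, features) array into overlapping input windows.

    Sample i holds rows i .. i+window-1; its target is column 0 ("a")
    of row i+window, matching the x_train / y_train layout above.
    """
    n_samples = len(series) - window
    x = np.stack([series[i:i + window] for i in range(n_samples)])
    y = series[window:, 0]  # next value of feature "a" for each window
    return x, y

# Hypothetical data: 100 time steps of 3 features (a, b, c); a_t = 3 * t.
series = np.arange(300, dtype=float).reshape(100, 3)
x_train, y_train = make_windows(series)
print(x_train.shape, y_train.shape)  # (50, 50, 3) (50,)
```

With 100 time steps and a window of 50 you get 50 samples; the last window covers steps 49 to 98 and its target is a99.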

The code below works for predicting only a. How do I get it to predict and return the vector [a, b, c] for a given time step?

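from keras.models import Sequential
from keras.layers import LSTM, Dropout, Dense, Activation

def build_model():
    model = Sequential()

    # First LSTM returns the full sequence so the second LSTM gets 3D input.
    model.add(LSTM(
        input_shape=(50, 3),
        return_sequences=True, units=50))
    model.add(Dropout(0.2))

    # Second LSTM returns only its last output: shape (batch, 250).
    model.add(LSTM(
        250,
        return_sequences=False))
    model.add(Dropout(0.2))

    # A single unit, so the model predicts one value (feature "a").
    model.add(Dense(1))
    model.add(Activation("linear"))

    model.compile(loss="mse", optimizer="rmsprop")
    return model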

Answer 1:


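The size of the last dimension of a layer's output is set by how many units (cells/filters) the layer has.

Your output has 1 feature because Dense(1) has only one unit.

Changing it to Dense(3) solves the problem, provided your targets also have 3 columns: y_train must have shape (samples, 3), holding [a, b, c] for the step after each window.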
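To make the shape change concrete, here is a NumPy sketch that mimics what a Dense layer computes (a matrix product with a (inputs, units) kernel plus a bias) using random weights, instead of calling Keras. The data, the `dense_forward` helper and the batch of 50 are hypothetical; the point is only that the unit count becomes the last output dimension and that the targets need a matching 3-column shape.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical data: 100 time steps of 3 features (a, b, c).
series = rng.normal(size=(100, 3))
window = 50

# Targets for a 3-unit output: the full [a, b, c] row after each window.
y_train = series[window:]  # shape (50, 3) instead of (50,)

def dense_forward(h, units, rng):
    """Hypothetical stand-in for Dense(units): h @ kernel + bias."""
    kernel = rng.normal(size=(h.shape[-1], units))
    bias = np.zeros(units)
    return h @ kernel + bias

h = rng.normal(size=(50, 250))          # stand-in for the second LSTM's output
out1 = dense_forward(h, 1, rng)         # (50, 1): one predicted feature
out3 = dense_forward(h, 3, rng)         # (50, 3): predicts [a, b, c]
print(out1.shape, out3.shape, y_train.shape)
```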


Now, if you want the output to have the same number of time steps as the input, set return_sequences=True in all your LSTM layers.

The output shape of an LSTM layer is:

  • (Batch size, units) - with return_sequences=False
  • (Batch size, time steps, units) - with return_sequences=True
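The difference between the two shapes can be seen with a simplified LSTM forward pass in NumPy. This is a sketch using the standard gate equations with small random weights; it is not Keras's implementation (its weight layout, initializers and options differ), and `lstm_forward` plus the batch size of 8 are hypothetical. It shows that return_sequences=False simply keeps the hidden state of the last time step.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_forward(x, units, return_sequences, seed=0):
    """Simplified LSTM over x of shape (batch, time_steps, features)."""
    rng = np.random.default_rng(seed)
    batch, steps, features = x.shape
    # Stacked weights for the input, forget, candidate and output gates.
    W = rng.normal(scale=0.1, size=(features, 4 * units))
    U = rng.normal(scale=0.1, size=(units, 4 * units))
    b = np.zeros(4 * units)
    h = np.zeros((batch, units))
    c = np.zeros((batch, units))
    outputs = []
    for t in range(steps):
        z = x[:, t, :] @ W + h @ U + b
        i, f, g, o = np.split(z, 4, axis=-1)
        c = sigmoid(f) * c + sigmoid(i) * np.tanh(g)
        h = sigmoid(o) * np.tanh(c)
        outputs.append(h)
    if return_sequences:
        return np.stack(outputs, axis=1)  # (batch, time_steps, units)
    return h                              # (batch, units): last step only

x = np.random.default_rng(1).normal(size=(8, 50, 3))
seq = lstm_forward(x, units=50, return_sequences=True)
last = lstm_forward(x, units=50, return_sequences=False)
print(seq.shape, last.shape)             # (8, 50, 50) (8, 50)
print(np.allclose(seq[:, -1, :], last))  # True
```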

Then wrap the following layers in TimeDistributed so that each one is applied independently at every time step (the time dimension in the middle is preserved).
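What TimeDistributed(Dense(3)) computes can be sketched in NumPy: one shared kernel and bias applied to every time step. The random arrays below are hypothetical stand-ins for the LSTM output and the Dense weights; the sketch also shows that a single matrix product over the last axis gives the same result.

```python
import numpy as np

rng = np.random.default_rng(0)
seq = rng.normal(size=(8, 50, 250))  # (batch, time_steps, units) from the LSTM
kernel = rng.normal(size=(250, 3))   # one Dense(3) kernel shared by all steps
bias = rng.normal(size=3)

# TimeDistributed-style: apply the same Dense layer to each time step in turn.
per_step = np.stack(
    [seq[:, t, :] @ kernel + bias for t in range(seq.shape[1])], axis=1)

# Vectorized form: matmul acts on the last axis; batch and time axes are kept.
vectorized = seq @ kernel + bias

print(per_step.shape)                     # (8, 50, 3)
print(np.allclose(per_step, vectorized))  # True
```

This is also why, in Keras 2 and later, a plain Dense(3) applied to a 3D input behaves the same way here, since Dense operates on the last axis of its input.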

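from keras.models import Sequential
from keras.layers import LSTM, Dropout, Dense, Activation, TimeDistributed

def build_model():
    model = Sequential()

    # Both LSTMs return the full sequence: (batch, 50, units).
    model.add(LSTM(
        input_shape=(50, 3),
        return_sequences=True, units=50))
    model.add(Dropout(0.2))

    model.add(LSTM(
        250,
        return_sequences=True))
    model.add(Dropout(0.2))

    # Apply the same Dense(3) at every time step: output (batch, 50, 3).
    model.add(TimeDistributed(Dense(3)))
    model.add(Activation("linear"))

    # Targets must therefore have shape (samples, 50, 3).
    model.compile(loss="mse", optimizer="rmsprop")
    return model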
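Because this model outputs (samples, 50, 3), y_train needs that shape too. The answer does not say what each step's target should be; one common choice, assumed here, is to shift the window by one so every step predicts the next [a, b, c] row. `make_seq2seq_windows` and the counter data are hypothetical.

```python
import numpy as np

def make_seq2seq_windows(series, window=50):
    """Inputs: rows i .. i+window-1. Targets: rows i+1 .. i+window.

    Assumes each time step should predict the following [a, b, c] row.
    """
    n_samples = len(series) - window
    x = np.stack([series[i:i + window] for i in range(n_samples)])
    y = np.stack([series[i + 1:i + window + 1] for i in range(n_samples)])
    return x, y

series = np.arange(300, dtype=float).reshape(100, 3)  # hypothetical a, b, c
x_train, y_train = make_seq2seq_windows(series)
print(x_train.shape, y_train.shape)  # (50, 50, 3) (50, 50, 3)
```

These arrays can then be passed to model.fit(x_train, y_train, ...) for the model above.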


Source: https://stackoverflow.com/questions/46102332/lstm-keras-api-predicting-multiple-outputs
