Train Keras Stateful LSTM return_seq=true not learning


Question


Consider this minimal runnable example:

from keras.models import Sequential
from keras.layers import Dense
from keras.layers import LSTM
import numpy as np
import matplotlib.pyplot as plt


# decaying cosine; the target y is x shifted one step into the future
x_max = 30
step = 0.5
n_steps = int(x_max / step)

x = np.arange(0, x_max, step)
x = np.cos(x) * (x_max - x) / x_max

y = np.roll(x, -1)
y[-1] = x[-1]

# reshape to (samples, timesteps, features): each value is its own one-step sample
shape = (n_steps, 1, 1)
batch_shape = (1, 1, 1)

x = x.reshape(shape)
y = y.reshape(shape)

model = Sequential()
# two stacked stateful LSTMs returning full sequences, batch size fixed at 1
model.add(LSTM(50, return_sequences=True, stateful=True, batch_input_shape=batch_shape))
model.add(LSTM(50, return_sequences=True, stateful=True))
model.add(Dense(1))

model.compile(loss='mse', optimizer='rmsprop')

for i in range(1000):
    model.reset_states()
    model.fit(x, y, epochs=1, batch_size=1)
    p = model.predict(x, batch_size=1)
    # live plot: input (*), target (o), prediction (.)
    plt.clf()
    plt.axis([-1, 31, -1.1, 1.1])
    plt.plot(x[:, 0, 0], '*')
    plt.plot(y[:, 0, 0], 'o')
    plt.plot(p[:, 0, 0], '.')
    plt.draw()
    plt.pause(0.001)

As stated in the Keras API docs (https://keras.io/layers/recurrent/):

the last state for each sample at index i in a batch will be used as initial state for the sample of index i in the following batch

So I'm using batch_size = 1 and trying to predict the next value of the decaying cosine function at each timestep. The predictions (the red dots in the plot) should land on the green circles if the script is predicting correctly, but it doesn't converge... Any idea how to make it learn?
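To double-check my reading of that sentence, here is a minimal sketch (a hypothetical toy model, separate from the script above) showing that with stateful=True the hidden state carries across consecutive predict calls until reset_states() is invoked:

from keras.models import Sequential
from keras.layers import LSTM
import numpy as np

toy = Sequential()
toy.add(LSTM(4, stateful=True, batch_input_shape=(1, 1, 1)))
toy.compile(loss='mse', optimizer='rmsprop')

sample = np.ones((1, 1, 1))
a = toy.predict(sample, batch_size=1)  # state starts at zero
b = toy.predict(sample, batch_size=1)  # state carried over from the first call
toy.reset_states()
c = toy.predict(sample, batch_size=1)  # state reset, so this repeats the first call
print(np.allclose(a, c), np.allclose(a, b))  # expected: True False (weights are random, but the pattern holds)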


Answer 1:


The problem lay in calling model.fit separately for each epoch. When fit is called repeatedly like this, the optimizer's parameters are reset, which harms the training process. The other issue is that reset_states should also be called before prediction: if it isn't, the states left over from fit become the initial states for predict, which can also hurt. The final code is the following:

for epoch in range(1000):
    model.reset_states()
    tot_loss = 0
    for batch in range(n_steps):
        # train one timestep at a time so the optimizer state persists across epochs
        batch_loss = model.train_on_batch(x[batch:batch+1], y[batch:batch+1])
        tot_loss += batch_loss

    print("Loss: " + str(tot_loss / float(n_steps)))
    model.reset_states()
    p = model.predict(x, batch_size=1)
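As a variant (my addition, not part of the original answer): the same effect, a single optimizer living across all epochs with states reset at every epoch boundary, can be sketched with one fit call plus Keras's LambdaCallback, reusing the model, x, and y defined above and assuming shuffle=False to keep the timesteps in order:

from keras.callbacks import LambdaCallback

# reset the LSTM states at the start of every epoch
reset_cb = LambdaCallback(on_epoch_begin=lambda epoch, logs: model.reset_states())

model.fit(x, y, epochs=1000, batch_size=1, shuffle=False, callbacks=[reset_cb])
model.reset_states()
p = model.predict(x, batch_size=1)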


Source: https://stackoverflow.com/questions/42811746/train-keras-stateful-lstm-return-seq-true-not-learning
