Question
Consider this minimal runnable example:
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import LSTM
import numpy as np
import matplotlib.pyplot as plt
max = 30
step = 0.5
n_steps = int(30/0.5)
x = np.arange(0,max,step)
x = np.cos(x)*(max-x)/max
y = np.roll(x,-1)
y[-1] = x[-1]
shape = (n_steps,1,1)
batch_shape = (1,1,1)
x = x.reshape(shape)
y = y.reshape(shape)
model = Sequential()
model.add(LSTM(50, return_sequences=True, stateful=True, batch_input_shape=batch_shape))
model.add(LSTM(50, return_sequences=True, stateful=True))
model.add(Dense(1))
model.compile(loss='mse', optimizer='rmsprop')
for i in range(1000):
    model.reset_states()
    model.fit(x,y,nb_epoch=1, batch_size=1)
    p = model.predict(x, batch_size=1)
    plt.clf()
    plt.axis([-1,31, -1.1, 1.1])
    plt.plot(x[:, 0, 0], '*')
    plt.plot(y[:,0,0],'o')
    plt.plot(p[:,0,0],'.')
    plt.draw()
    plt.pause(0.001)
As stated in the Keras API documentation (https://keras.io/layers/recurrent/):
"the last state for each sample at index i in a batch will be used as initial state for the sample of index i in the following batch"
So I'm using batch_size = 1 and trying to predict the next value of the decaying cosine function at each timestep. The predictions (the red dots in the picture below) should land on the green circles if the script were predicting correctly, but training doesn't converge. Any ideas on how to make it learn?
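To make the setup concrete, here is a minimal sketch of how the inputs and targets in the script pair up. It uses only the standard library, not NumPy or Keras. The name max_t is an assumption introduced here to avoid shadowing the built-in max; the constants are taken from the script above.

```python
import math

# Same constants as the script above; `max_t` is a renamed stand-in for `max`.
max_t = 30
step = 0.5
n_steps = int(max_t / step)

# Decaying cosine sampled every `step`.
t = [i * step for i in range(n_steps)]
x = [math.cos(ti) * (max_t - ti) / max_t for ti in t]

# The target at step i is the input at step i + 1;
# the last target repeats the last input, as in the script.
y = x[1:] + [x[-1]]

# With batch_size = 1, each (input, target) pair is one batch of shape (1, 1, 1).
pairs = list(zip(x, y))
print(len(pairs), pairs[0])
```

Because every batch holds a single timestep, the network can only learn the sequence if its state is carried from one batch to the next in order.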
Answer 1:
The problem lay in calling model.fit separately for each epoch. In that case the optimizer's parameters are reset, which harmed the training process. The other issue is that reset_states should also be called before prediction: if it isn't, the states left over from fit become the starting states for prediction, which may also be harmful. The final code is as follows:
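for epoch in range(1000):
    # Start every epoch from a clean state
    model.reset_states()
    tot_loss = 0
    # Feed the timesteps one by one, in order, so the state carries over
    for batch in range(n_steps):
        batch_loss = model.train_on_batch(x[batch:batch+1], y[batch:batch+1])
        tot_loss+=batch_loss
    print("Loss: " + str(tot_loss/float(n_steps)))
    # Clear the states left over from training before predicting
    model.reset_states()
    p = model.predict(x, batch_size=1)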
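The effect of leftover state on prediction can be shown with a toy example. The class below is a hypothetical stand-in written for this illustration; it is not part of Keras and only mimics the idea of a layer whose output depends on state carried over from earlier inputs.

```python
class ToyStatefulModel:
    """Hypothetical stand-in for a stateful layer; not part of Keras."""

    def __init__(self):
        self.state = 0.0

    def reset_states(self):
        self.state = 0.0

    def predict(self, xs):
        # The output at each step depends on the state left by all earlier steps.
        outputs = []
        for value in xs:
            self.state = 0.5 * self.state + value
            outputs.append(self.state)
        return outputs


model = ToyStatefulModel()
sequence = [1.0, 2.0, 3.0]

model.predict(sequence)          # stands in for a pass that leaves state behind
stale = model.predict(sequence)  # starts from the leftover state

model.reset_states()
fresh = model.predict(sequence)  # starts from the zero state

print(stale[0], fresh[0])
```

The two runs see the same inputs but give different outputs, because the first one starts from whatever state the previous pass left behind. Calling reset_states first makes the prediction start from the same initial state as each training epoch.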
Source: https://stackoverflow.com/questions/42811746/train-keras-stateful-lstm-return-seq-true-not-learning