Strange loss curve while training LSTM with Keras


Question


I'm trying to train an LSTM for a binary classification problem. When I plot the loss curve after training, there are strange spikes in it. Here are some examples:

Here is the basic code:

from numpy import newaxis
from keras.models import Sequential
from keras.layers import LSTM, Dropout, Dense, Activation
from keras.callbacks import EarlyStopping, ModelCheckpoint

# Two stacked LSTM layers with dropout, ending in a single sigmoid unit
# for binary classification.
model = Sequential()
model.add(LSTM(128, input_shape=(columnCount, 1), return_sequences=True))
model.add(Dropout(0.5))
model.add(LSTM(128, return_sequences=False))
model.add(Dropout(0.5))
model.add(Dense(1))
model.add(Activation('sigmoid'))
model.compile(optimizer='adam',
              loss='binary_crossentropy',
              metrics=['accuracy'])

# Add a trailing feature axis so the input has shape (samples, timesteps, 1).
new_train = X_train[..., newaxis]

history = model.fit(new_train, y_train, epochs=500, batch_size=100,
                    callbacks=[EarlyStopping(monitor='val_loss', min_delta=0.0001,
                                             patience=2, verbose=0, mode='auto'),
                               ModelCheckpoint(filepath="model.h5", verbose=0,
                                               save_best_only=True)],
                    validation_split=0.1)

import matplotlib.pyplot as plt

# List all metrics recorded in the history
print(history.history.keys())
# Plot training and validation loss per epoch
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('model loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train', 'validation'], loc='upper left')
plt.show()

I don't understand why these spikes occur. Any ideas?


Answer 1:


There are many possible reasons why something like this occurs:

  1. Your parameter trajectory changed its basin of attraction - the optimizer left one stable trajectory and switched to another, most likely due to a source of randomness such as batch sampling or dropout.

  2. LSTM instability - LSTMs are believed to be extremely unstable to train, and it has been reported that they often take a long time to stabilize.

Based on recent research (e.g. from here), I would recommend decreasing the batch size and training for more epochs. I would also check whether the network topology is too complex (or too simple) for the number of patterns it needs to learn. You could also try switching to either GRU or SimpleRNN; a sketch of these changes follows.
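A minimal sketch of those suggestions, reusing columnCount, new_train, and y_train from the question; the smaller batch size and larger epoch budget are illustrative assumptions, not values the answer prescribes:

from keras.models import Sequential
from keras.layers import GRU, Dropout, Dense, Activation

# Same architecture as in the question, but with GRU in place of LSTM.
model = Sequential()
model.add(GRU(128, input_shape=(columnCount, 1), return_sequences=True))
model.add(Dropout(0.5))
model.add(GRU(128, return_sequences=False))
model.add(Dropout(0.5))
model.add(Dense(1))
model.add(Activation('sigmoid'))
model.compile(optimizer='adam',
              loss='binary_crossentropy',
              metrics=['accuracy'])

# Smaller batches and more epochs, as suggested above (values are
# illustrative; tune them for your data).
history = model.fit(new_train, y_train, epochs=1000, batch_size=32,
                    validation_split=0.1)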




Answer 2:


This question is old, but I've seen this happen before when restarting training from a checkpoint. If each spike corresponds to a break in training, you may be inadvertently resetting some of the weights.
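If that is the cause, one way to avoid it is to resume from the checkpoint file with load_model, which restores the weights along with the optimizer state, rather than rebuilding and recompiling the model. A minimal sketch, assuming the model.h5 checkpoint and data from the question:

from keras.models import load_model

# Restore the full model (architecture, weights, and optimizer state)
# from the checkpoint written by ModelCheckpoint. Rebuilding and
# recompiling instead would reinitialize the weights and reset Adam's
# moment estimates, which can produce exactly this kind of loss spike.
model = load_model("model.h5")

# Continue training from where the checkpoint left off.
history = model.fit(new_train, y_train, epochs=100, batch_size=100,
                    validation_split=0.1)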



Source: https://stackoverflow.com/questions/45027234/strange-loss-curve-while-training-lstm-with-keras
