Keras network producing inverse predictions

故里飘歌 2021-02-06 05:58

I have a time-series dataset and I am trying to train a network so that it overfits (obviously, that's just the first step; I will then battle the overfitting).

2 answers
  •  面向向阳花
    2021-02-06 06:36

    EDIT: After the author's comments, I no longer believe this is the correct answer, but I will keep it posted for posterity.

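    Great question! The answer lies in how TimeseriesGenerator works. Instead of pairing x and y at the same index (e.g. input x[0] with target y[0]), it pairs each input window with the target that follows it: sample i uses the window data[i-length:i] as input and targets[i] as target. With length=1, input x[0] is therefore paired with target y[1].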

    Thus, plotting the targets shifted by one step produces the expected fit.
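To make the pairing concrete, here is a minimal NumPy-only sketch of the indexing TimeseriesGenerator applies, assuming the settings used in this answer (sampling_rate=1, stride=1, start_index=0, shuffle=False). The function timeseries_windows is a hypothetical helper, not part of Keras; it returns all samples at once instead of batch by batch.

```python
import numpy as np


def timeseries_windows(data, targets, length):
    """Hypothetical helper mimicking TimeseriesGenerator's indexing for
    sampling_rate=1, stride=1, shuffle=False, start_index=0, end_index=None.
    Returns every sample at once rather than in batches."""
    # Target i is paired with the `length` rows that come *before* it,
    # so the first usable target index is `length`.
    rows = range(length, len(data))
    X = np.stack([data[i - length:i] for i in rows])  # (n, length, features)
    Y = np.stack([targets[i] for i in rows])          # (n, target_dims)
    return X, Y


x = np.arange(5, dtype=float).reshape(-1, 1)  # 0, 1, 2, 3, 4
y = 10 * x                                    # 0, 10, 20, 30, 40
X, Y = timeseries_windows(x, y, length=1)

print(X.shape)             # (4, 1, 1): matches LSTM input_shape=(length, features)
print(X[0].ravel(), Y[0])  # input x[0] is paired with target y[1]
```

The 3-D shape (samples, length, features) is also why the LSTM in the simulation below is declared with input_shape=(1, 1).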

    Code to reproduce this:

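    import numpy as np
    import matplotlib.pyplot as plt
    from tensorflow import keras  # assumes TF 2.x (<= 2.15); Keras 3 no longer ships TimeseriesGenerator
    
    # 41 random values with alternating sign (every other value negated)
    x = np.random.uniform(0, 10, size=41).reshape(-1, 1)
    x[::2] *= -1
    # Target for each step is the next value in the series
    y = x[1:]
    x = x[:-1]
    train_gen = keras.preprocessing.sequence.TimeseriesGenerator(
        x,
        y,
        length=1,
        sampling_rate=1,
        batch_size=1,
        shuffle=False,
    )
    
    model = keras.models.Sequential()
    model.add(keras.layers.LSTM(100, input_shape=(1, 1), return_sequences=False))
    model.add(keras.layers.Dense(1))
    
    model.compile(
        loss="mse",
        # 10x the default RMSprop learning rate (0.001), as the original `lr /= .1` did
        optimizer=keras.optimizers.RMSprop(learning_rate=0.01),
        metrics=[keras.metrics.mean_squared_error],
    )
    
    # fit_generator is deprecated; fit accepts a Sequence directly.
    # 50 epochs x 39 batches is roughly the original 20 epochs x 100 steps.
    history = model.fit(train_gen, epochs=50)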
    

    Plotting with the correct alignment:

    y_pred = model.predict(train_gen)  # predict_generator is deprecated
    plot_points = 39                   # len(train_gen): 40 rows minus length=1
    steps = range(1, plot_points + 1)  # time steps, not training epochs
    pred_points = y_pred[:plot_points].ravel()
    
    # targets[i] is paired with input x[i-1], so skip the first target
    target_points = train_gen.targets[1:plot_points + 1]  # NOTICE DIFFERENT INDEXING HERE
    
    plt.plot(steps, pred_points, 'b', label='Predictions')
    plt.plot(steps, target_points, 'r', label='Targets')
    plt.legend()
    plt.show()
    

    Output (notice that the fit is no longer inverted and is mostly very accurate):

    This is how it looks when the offset is incorrect:
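The inverted look of a misaligned plot can be reproduced without training anything. This minimal sketch assumes, as in the simulation above, that consecutive targets alternate in sign, and it pretends the model is perfect so that any error comes only from the slice used for plotting. The helper aligned_targets is hypothetical; it encodes the rule that prediction k corresponds to targets[length + k].

```python
import numpy as np


def aligned_targets(targets, length, n_preds):
    """Hypothetical helper: with shuffle=False, sampling_rate=1, stride=1 and
    start_index=0, prediction k corresponds to targets[length + k]."""
    return targets[length:length + n_preds]


# Alternating-sign targets, similar in spirit to the answer's x[::2] *= -1
targets = np.array([-1.0, 2.0, -3.0, 4.0, -5.0, 6.0])
length = 1
n_preds = len(targets) - length  # one prediction per generated sample

# Pretend the model is perfect: prediction k equals the target it was trained on
y_pred = targets[length:].copy()

good = aligned_targets(targets, length, n_preds)
bad = targets[:n_preds]  # naive, unshifted slice

print("aligned MSE:   ", np.mean((y_pred - good) ** 2))  # 0.0
print("misaligned MSE:", np.mean((y_pred - bad) ** 2))   # 57.0
```

Because neighbouring targets have opposite signs, the unshifted slice makes a perfect predictor look like a mirror image of the targets.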
