Question
I'm following the MNIST tutorial here for recognizing handwritten characters.
I'm able to load and recognize handwritten digits without issue, but now I want to train the model again on new images (specifically one at a time).
For some reason, when I choose a training size equal to 1, all my predictions become NaN.
If I pick a value >=2, it works fine.
Train function:

    async function train(model, data) {
      const TRAIN_DATA_SIZE = 1; // WHEN THIS IS 1, CAUSES PREDICT TO OUTPUT NaN
      const [trainXs, trainYs] = tf.tidy(() => {
        const d = data.nextTrainBatch(TRAIN_DATA_SIZE);
        return [
          d.xs.reshape([TRAIN_DATA_SIZE, 28, 28, 1]),
          d.labels
        ];
      });
      console.log(trainXs.dataSync());
      console.log(trainYs.dataSync());
      return model.fit(trainXs, trainYs);
    }

The code for nextTrainBatch is here.
Prediction code and example output:

    // inputs holds the pixel values of one image; PIXELSSQUARED must be 28 * 28 = 784
    const currentTensor = tf.tensor2d(inputs, [1, PIXELSSQUARED]);
    const output = model.predict(currentTensor.reshape([1, 28, 28, 1]));
    const prediction_value = Array.from(output.argMax(1).dataSync());
    console.log(output.dataSync());

When training size is 2 or greater:

    Float32Array(10) [3.308702423154841e-9, 5.89648436744028e-8, 0.00005333929220796563, 0.8063259720802307, 7.401082784824764e-13, 1.1464327087651327e-7, 6.5924318955190575e-12, 0.1936144232749939, 0.000004253268798493082, 0.000001676815713835822]

When training size is 1:

    Float32Array(10) [NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN]
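All ten outputs turn NaN at once. When training works, the outputs sum to about 1, which suggests (but does not confirm) a softmax final layer. As a generic illustration only, here is a plain JavaScript sketch (not TensorFlow.js; naiveSoftmax and stableSoftmax are hypothetical helpers) of two ways a softmax output becomes NaN: exp() of a very large logit overflowing to Infinity, and a single NaN logit, which poisons every output because all outputs share one normalizing sum. It does not diagnose which, if either, happened in this model.

```javascript
// Hypothetical helper: softmax computed exactly as the formula is written.
function naiveSoftmax(logits) {
  const exps = logits.map((x) => Math.exp(x));
  const sum = exps.reduce((a, b) => a + b, 0);
  return exps.map((e) => e / sum);
}

// Hypothetical helper: numerically stable softmax. Subtracting the largest
// logit keeps every exp() argument <= 0, so exp() stays <= 1 and cannot
// overflow to Infinity; the result is mathematically the same.
function stableSoftmax(logits) {
  const max = Math.max(...logits);
  const exps = logits.map((x) => Math.exp(x - max));
  const sum = exps.reduce((a, b) => a + b, 0);
  return exps.map((e) => e / sum);
}

// Overflow: exp(1000) is Infinity, and Infinity / Infinity is NaN.
console.log(naiveSoftmax([1000, 0]));    // [ NaN, 0 ]
console.log(stableSoftmax([1000, 0]));   // [ 1, 0 ]

// A single NaN logit (e.g. produced by a NaN weight) makes the shared sum
// NaN, so every probability is NaN, even with the stable version.
console.log(stableSoftmax([1, 2, NaN])); // [ NaN, NaN, NaN ]
```

The max-subtraction trick only prevents overflow; once a NaN has reached the weights, it spreads to every output.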
Answer 1:
The model is running into numerical instability. Using an optimizer such as SGD might help. However, a batch size of 1 is not a good idea in practice, because the model may oscillate around the optimal values instead of converging to them.
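The oscillation point can be seen on a toy one-dimensional problem. This is a plain JavaScript sketch, not the MNIST model or TensorFlow.js, and all names are hypothetical: with a fixed learning rate, updating on one sample at a time keeps pulling the weight toward whichever sample was seen last, while averaging the gradient over both samples settles at the optimum.

```javascript
// Toy problem (not the MNIST model): fit a single weight w to minimise
// the mean of (w - t)^2 over targets [0, 2]. The optimum is w = 1.
const targets = [0, 2];
const learningRate = 0.25;
const steps = 100;

// Gradient of (w - t)^2 with respect to w.
const grad = (w, t) => 2 * (w - t);

// Batch size 1: update on one target at a time, alternating 0, 2, 0, 2, ...
let wSingle = 0;
for (let i = 0; i < steps; i++) {
  const t = targets[i % targets.length];
  wSingle -= learningRate * grad(wSingle, t);
}

// Full batch: update on the average gradient over both targets.
let wBatch = 0;
for (let i = 0; i < steps; i++) {
  const g = targets.reduce((sum, t) => sum + grad(wBatch, t), 0) / targets.length;
  wBatch -= learningRate * g;
}

console.log(wSingle.toFixed(4)); // 1.3333 (keeps bouncing between ~0.67 and ~1.33)
console.log(wBatch.toFixed(4));  // 1.0000 (settles at the optimum)
```

Decaying the learning rate would shrink the single-sample oscillation, but with a constant rate it never goes away.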
"I want the user to select the correct value after the model has made its prediction, e.g. make a prediction, select the correct output, retrain based on this information."
If you want to train further, you need data that matches the model's inputShape. The input the model made its prediction on and the result chosen by the user can be collected and used to train the model further.
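A minimal sketch of packaging one user-corrected example so it matches the shapes used in the question: 784 pixel values per image (reshaped to [1, 28, 28, 1] before predict or fit) and a label that is assumed to be a one-hot vector of length 10, like the training labels from nextTrainBatch appear to be. makeTrainingExample, IMAGE_SIZE and NUM_CLASSES are hypothetical names, and plain typed arrays are used so the sketch runs without TensorFlow.js.

```javascript
// Hypothetical helper: turn one user-corrected example into arrays matching
// the question's shapes: 28 * 28 = 784 pixels per image and (assumed)
// a one-hot label with one slot per digit.
const IMAGE_SIZE = 28 * 28;
const NUM_CLASSES = 10;

function makeTrainingExample(pixels, userLabel) {
  if (pixels.length !== IMAGE_SIZE) {
    throw new Error(`expected ${IMAGE_SIZE} pixels, got ${pixels.length}`);
  }
  if (!Number.isInteger(userLabel) || userLabel < 0 || userLabel >= NUM_CLASSES) {
    throw new Error(`label must be an integer in [0, ${NUM_CLASSES - 1}]`);
  }
  const label = new Float32Array(NUM_CLASSES); // all zeros
  label[userLabel] = 1;                        // 1 at the chosen digit
  return { xs: Float32Array.from(pixels), label };
}

// Usage: a blank image that the user says is a 7.
const example = makeTrainingExample(new Array(IMAGE_SIZE).fill(0), 7);
console.log(example.xs.length);        // 784
console.log(example.label.indexOf(1)); // 7
```

Validating the length up front catches a shape mismatch before it reaches reshape or fit.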
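
    // the model has already been trained
    const y = model.predict(x); // predict: class probabilities, shape [1, 10]

Suppose the user then validates the result y by selecting the correct digit, held here in userLabel (a hypothetical variable). To train further, turn that choice into a one-hot label with the same shape as the model output and fit on it (inside an async function):

    const label = new Array(10).fill(0);
    label[userLabel] = 1;               // one-hot vector for the user's choice
    const yTrue = tf.tensor2d([label]); // shape [1, 10]
    await model.fit(x, yTrue);

And the cycle continues.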
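Since a batch size of 1 is discouraged above, one option, sketched here as an assumption rather than as part of the original answer, is to buffer the user's corrections and retrain only once a few have accumulated. CorrectionBuffer and minBatchSize are hypothetical names; each returned batch would still have to be turned into tensors of shape [n, 28, 28, 1] and [n, 10] before calling model.fit.

```javascript
// Hypothetical buffer for user-corrected examples. Instead of calling
// model.fit on every single correction (batch size 1), corrections are
// collected and released only once at least minBatchSize have arrived.
class CorrectionBuffer {
  constructor(minBatchSize) {
    this.minBatchSize = minBatchSize;
    this.examples = [];
  }

  // Store one example, e.g. { xs: Float32Array(784), label: 7 }.
  add(example) {
    this.examples.push(example);
  }

  ready() {
    return this.examples.length >= this.minBatchSize;
  }

  // Return all buffered examples and empty the buffer, or null if there
  // are not yet enough of them to make a batch.
  takeBatch() {
    if (!this.ready()) return null;
    const batch = this.examples;
    this.examples = [];
    return batch;
  }
}

// Usage: only the second correction releases a batch (of 2) for retraining.
const buffer = new CorrectionBuffer(2);
buffer.add({ xs: new Float32Array(784), label: 3 });
console.log(buffer.takeBatch());        // null
buffer.add({ xs: new Float32Array(784), label: 7 });
console.log(buffer.takeBatch().length); // 2
```

The threshold is a tuning choice: a larger minBatchSize gives smoother updates, at the cost of the model reacting to corrections more slowly.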
Source: https://stackoverflow.com/questions/59331219/training-with-an-input-size-of-1-causes-nan-in-subsequent-predictions