How to use the classWeight option of model.fit in Tensorflow.js?

Submitted by 霸气de小男生 on 2019-12-24 11:30:14

Question


In my training set, my classes are not represented equally. Therefore, I'm trying to use the classWeight option of the model.fit function. Quote from the docs:

classWeight ({[classIndex: string]: number}) Optional dictionary mapping class indices (integers) to a weight (float) to apply to the model's loss for the samples from this class during training. This can be useful to tell the model to "pay more attention" to samples from an under-represented class.

So this seems to be exactly what I'm looking for, but I'm not able to figure out how to use it. The classIndex is not mentioned anywhere else in the documentation, and I don't understand how a class can be written as an index. As my labels are one-hot encoded, I even tried to use the position of the 1 in each vector as the index (0 for [1,0,0], etc.), but it did not work (see below).
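To make the index idea concrete, here is a minimal sketch in plain JavaScript of how one-hot rows could be mapped to integer class indices and counted per class. The helper names `toClassIndices` and `countPerClass` are hypothetical and not part of TensorFlow.js; on tensors, `tf.argMax` along axis 1 should yield the same indices.

```javascript
// Hypothetical helper (not a TensorFlow.js API): convert one-hot rows
// to integer class indices by taking the position of the largest entry.
function toClassIndices(oneHotRows) {
  return oneHotRows.map((row) => row.indexOf(Math.max(...row)));
}

// Hypothetical helper: count how many samples belong to each class index.
function countPerClass(classIndices, numClasses) {
  const counts = new Array(numClasses).fill(0);
  for (const idx of classIndices) {
    counts[idx] += 1;
  }
  return counts;
}

// Made-up labels for illustration only.
const labels = [[1, 0, 0], [0, 0, 1], [0, 0, 1], [0, 1, 0]];
const indices = toClassIndices(labels);   // [0, 2, 2, 1]
const counts = countPerClass(indices, 3); // [1, 1, 2]
```

Under this reading, the key 0 in a classWeight dictionary would refer to the class encoded as [1,0,0], the key 1 to [0,1,0], and so on.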

Here is a minimal code sample:

const xs = tf.tensor2d([ [1,2,3,4,5,6], /* ... */ ]); // shape: [1000,6]
const ys = tf.tensor2d([ [1,0,0], /* ... */ ]);       // shape: [1000,3] (one-hot encoded classes)

const model = tf.sequential();
// ... some layers

await model.fit(xs, ys, {
    /* ... */
    classWeight: {
        0: 10000, // doesn't do anything
        1: 1,
        2: 1,
    },
});

I would expect my model to pay "more attention" to my first class ([1,0,0]) and therefore predict it more often. But it seems TensorFlow.js is just ignoring the parameter.
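For reference, the effect the documentation describes can be sketched in plain JavaScript as a weighted categorical cross-entropy, where each sample's loss is scaled by the weight of its true class. This is only an illustration of the concept, not the TensorFlow.js implementation: `weightedCrossEntropy` is a hypothetical helper, and averaging over the number of samples is an assumption.

```javascript
// Hypothetical helper (not a TensorFlow.js API): categorical cross-entropy
// for one-hot labels, with each sample's loss scaled by its class weight.
function weightedCrossEntropy(yTrue, yPred, classWeight) {
  const eps = 1e-7; // guards against log(0)
  let total = 0;
  for (let i = 0; i < yTrue.length; i++) {
    // The class index is the position of the 1 in the one-hot row.
    const classIndex = yTrue[i].indexOf(1);
    // Classes without an entry keep a weight of 1 (an assumption of this sketch).
    const weight = classIndex in classWeight ? classWeight[classIndex] : 1;
    const loss = -Math.log(Math.max(yPred[i][classIndex], eps));
    total += weight * loss;
  }
  // Assumption: average over the number of samples.
  return total / yTrue.length;
}

// Made-up labels and predictions for illustration only.
const yTrue = [[1, 0, 0], [0, 1, 0]];
const yPred = [[0.5, 0.25, 0.25], [0.25, 0.5, 0.25]];

const unweighted = weightedCrossEntropy(yTrue, yPred, {});
const weighted = weightedCrossEntropy(yTrue, yPred, { 0: 3, 1: 1, 2: 1 });
```

With a weight of 3 on class 0, a mistake on a [1,0,0] sample contributes three times as much to the loss as the same mistake on another class, which is what should push the optimizer toward predicting that class more often.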

Source: https://stackoverflow.com/questions/56040846/how-to-use-the-classweight-option-of-model-fit-in-tensorflow-js
