Imbalanced classes in multi-class classification problem

北战南征 submitted on 2020-06-25 09:18:28

Question


I'm trying to use TensorFlow's DNNClassifier for my multi-class (softmax) classification problem with 4 different classes. I have an imbalanced dataset with the following distribution:

  • Class 0: 14.8%
  • Class 1: 35.2%
  • Class 2: 27.8%
  • Class 3: 22.2%

How do I assign the weights for the DNNClassifier's weight_column for each class? I know how to code this, but I am wondering what values I should give each class.


Answer 1:


There are various ways to build weights for imbalanced classification problems. One of the most common is to use the class counts in the training set directly to estimate sample weights. This is easy to compute with sklearn's compute_sample_weight: its 'balanced' mode uses the values of y to automatically adjust weights inversely proportional to class frequencies.
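As a minimal sketch of what the 'balanced' mode computes, assuming only NumPy: sklearn documents the balanced class weights as n_samples / (n_classes * np.bincount(y)), and each sample then receives the weight of its own class. The helper name balanced_sample_weight is hypothetical, not an sklearn function; for integer labels 0..K-1 where every class occurs, its output should match compute_sample_weight(class_weight='balanced', y=y).

```python
import numpy as np

def balanced_sample_weight(y):
    """Per-sample weights inversely proportional to class frequency.

    Hypothetical helper mirroring the formula sklearn documents for
    class_weight='balanced': n_samples / (n_classes * np.bincount(y)).
    Assumes y holds integer labels 0..K-1 and every class occurs at least once.
    """
    y = np.asarray(y)
    counts = np.bincount(y)                     # number of samples per class
    class_weight = len(y) / (len(counts) * counts)
    # Each sample inherits the weight of its own class.
    return class_weight[y]

y = np.array([0, 1, 1, 1, 2, 2, 3, 3])
w = balanced_sample_weight(y)
```

In this toy example every class ends up with the same total weight (2.0), and the weights sum to the number of samples, so rare classes are scaled up and common ones scaled down without changing the overall loss scale.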

In the example below, we 'incorporate' the compute_sample_weight method into fitting our DNNClassifier. For the labels, I used roughly the same distribution as the one given in the question.

import numpy as np
import pandas as pd
import tensorflow as tf
from sklearn.utils.class_weight import compute_sample_weight

train_size = 1000
test_size = 200
columns = 30

## create train data
y_train = np.random.choice([0,1,2,3], train_size, p=[0.15, 0.35, 0.28, 0.22])
x_train = pd.DataFrame(np.random.uniform(0,1, (train_size,columns)).astype('float32'))
x_train.columns = [str(i) for i in range(columns)]

## create train weights
weight = compute_sample_weight(class_weight='balanced', y=y_train)
x_train['weight'] = weight.astype('float32')

## create test data
y_test = np.random.choice([0,1,2,3], test_size, p=[0.15, 0.35, 0.28, 0.22])
x_test = pd.DataFrame(np.random.uniform(0,1, (test_size,columns)).astype('float32'))
x_test.columns = [str(i) for i in range(columns)]

## create test weights
x_test['weight'] = np.ones(len(y_test)).astype('float32') ## set them all to 1

## utility functions to pass data to DNNClassifier
def train_input_fn():
    dataset = tf.data.Dataset.from_tensor_slices((dict(x_train), y_train))
    dataset = dataset.shuffle(1000).repeat().batch(10)
    return dataset

def eval_input_fn():
    ## no shuffle/repeat: evaluate each test example exactly once
    dataset = tf.data.Dataset.from_tensor_slices((dict(x_test), y_test))
    return dataset.batch(10)

## define DNNClassifier
classifier = tf.estimator.DNNClassifier(
    feature_columns=[tf.feature_column.numeric_column(str(i), shape=[1]) for i in range(columns)],
    weight_column=tf.feature_column.numeric_column('weight'),
    hidden_units=[10],
    n_classes=4,
)

## train DNNClassifier
classifier.train(input_fn=train_input_fn, steps=100)

## evaluate on the whole test set (steps=None runs until the input is exhausted)
eval_results = classifier.evaluate(input_fn=eval_input_fn)

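Because our weights are built as a function of the target, we set them all to 1 in the test data: at prediction time the labels are unknown, so label-based weights cannot be computed, and unit weights keep the evaluation metrics unweighted.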
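To see why per-example weights rebalance training, here is a NumPy sketch of a weighted softmax cross-entropy: each example's loss is multiplied by its weight before averaging, so examples from a rare class count for more. This assumes a weighted-mean reduction for illustration; the exact reduction DNNClassifier applies depends on its loss_reduction setting, so this is not the estimator's internal code. The name weighted_cross_entropy is hypothetical.

```python
import numpy as np

def weighted_cross_entropy(logits, labels, weights):
    """Weighted mean of per-example softmax cross-entropy (illustrative only)."""
    # Numerically stable log-softmax over the class axis.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Negative log-probability assigned to each example's true class.
    per_example = -log_probs[np.arange(len(labels)), labels]
    # Larger weights make an example's error contribute more to the loss.
    return np.sum(weights * per_example) / np.sum(weights)

logits = np.zeros((4, 4))                 # uniform predictions over 4 classes
labels = np.array([0, 1, 2, 3])
weights = np.array([1.7, 0.7, 0.9, 1.1])
loss = weighted_cross_entropy(logits, labels, weights)
```

With all weights equal to 1, as for the test data, this reduces to the ordinary mean cross-entropy.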




Answer 2:


I would strongly suggest using undersampling (if you have enough data to do so) or oversampling with SMOTE, which you can find in the imblearn library.

As an experienced ML engineer, I can say that none of the "weighting" methods will really work for you. XGBoost has a parameter called scale_pos_weight, and you could use logistic regression with class_weight="balanced", but their effect is quite small because the problem is not the estimator; it's your data. So I would strongly suggest working on your data instead of assigning weights.
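For intuition about the SMOTE oversampling suggested here, below is a simplified NumPy sketch of its core step: each synthetic sample is placed at a random point on the segment between a minority-class sample and one of its k nearest minority-class neighbours. smote_sketch is a hypothetical helper, not imblearn's API; in practice you would use imblearn.over_sampling.SMOTE and its fit_resample(X, y) method, which also handles multiple classes and the target sampling ratio.

```python
import numpy as np

def smote_sketch(X_min, n_new, k=5, seed=None):
    """Generate n_new synthetic rows from the minority-class rows X_min.

    Simplified illustration of SMOTE interpolation. Assumes X_min has more
    than k rows; uses brute-force distances, so only for small inputs.
    """
    rng = np.random.default_rng(seed)
    X_min = np.asarray(X_min, dtype=float)
    # Pairwise Euclidean distances within the minority class.
    dist = np.linalg.norm(X_min[:, None, :] - X_min[None, :, :], axis=-1)
    np.fill_diagonal(dist, np.inf)              # a point is not its own neighbour
    neighbours = np.argsort(dist, axis=1)[:, :k]
    # For each new row, pick a random base sample and one of its k neighbours.
    base = rng.integers(0, len(X_min), size=n_new)
    nbr = neighbours[base, rng.integers(0, k, size=n_new)]
    # Interpolate: base + gap * (neighbour - base), with gap drawn from [0, 1).
    gap = rng.random((n_new, 1))
    return X_min[base] + gap * (X_min[nbr] - X_min[base])

X_class0 = np.random.default_rng(0).uniform(0, 1, (20, 3))
synthetic = smote_sketch(X_class0, n_new=30, k=5, seed=1)
```

Because every synthetic row is a convex combination of two real minority rows, the new points stay inside the region the minority class already occupies, unlike plain duplication, which only repeats existing rows.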




Answer 3:


You can try the following formula to balance all classes:

weight_for_class_X = total_samples_size / size_of_class_X / num_classes

For example:

num_CLASS_0: 10000   
num_CLASS_1: 1000
num_CLASS_2: 100

wgt_for_0 = 11100 / 10000 / 3 = 0.37  
wgt_for_1 = 11100 / 1000 / 3 = 3.7
wgt_for_2 = 11100 / 100 / 3 = 37

# so after one epoch of training, the total weight of each class will be:
total_wgt_of_0 = 0.37 * 10000 = 3700
total_wgt_of_1 = 3.7 * 1000 = 3700
total_wgt_of_2 = 37 * 100 = 3700
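Applied to the distribution in the question, the formula only needs each class's share of the data, because total_samples_size / size_of_class_X equals 1 / share_of_class_X. This is the same formula sklearn's class_weight='balanced' uses. A short sketch, assuming the percentages in the question are exact (weights_from_shares is a hypothetical name):

```python
def weights_from_shares(shares):
    """Balanced class weights from class proportions: 1 / (num_classes * share)."""
    num_classes = len(shares)
    return [1.0 / (num_classes * s) for s in shares]

shares = [0.148, 0.352, 0.278, 0.222]   # classes 0..3, from the question
weights = weights_from_shares(shares)
for cls, (s, w) in enumerate(zip(shares, weights)):
    # Each class's total weight share, s * w, comes out to 1 / num_classes.
    print(f"class {cls}: weight {w:.3f}, share x weight {s * w:.3f}")
```

Rounded, this gives about 1.689, 0.710, 0.899 and 1.126 for classes 0 to 3, and each class contributes the same 0.25 of the total weight.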


Source: https://stackoverflow.com/questions/52383967/imbalanced-classes-in-multi-class-classification-problem
