How to set class_weight in keras package of R?

猫巷女王i 2020-12-30 13:44

I am using the keras package in R to train a deep learning model. My data set is highly imbalanced. Therefore, I want to set the class_weight argument in t…

2 answers
  • 2020-12-30 14:21

    class_weight needs to be a named list, so

        history <- model %>% fit(
            trainData, trainClass, 
            epochs = 5, batch_size = 1000, 
            class_weight = list("0"=1,"1"=30),
            validation_split = 0.2
        )
    

    seems to work. Internally, the R keras package uses a function called as_class_weights to convert the list into a Python dictionary (see https://rdrr.io/cran/keras/src/R/model.R). You can reproduce the conversion yourself:

         library(reticulate)  # dict() comes from reticulate
         class_weight <- dict(list('0'=1,'1'=10))
         class_weight
         # {0: 1.0, 1: 10.0}
    

    This looks just like the Python dictionary mentioned in the question.
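The keras documentation describes class_weight as a named list mapping class indices (integers) to weights, so the names must be the 0-based class indices as strings, not the original label text. The sketch below shows one way to build that list from a factor of labels. The helper make_class_weight is hypothetical, and it assumes the labels are fed to the model as as.integer(factor) - 1, so index 0 is the first factor level.

```r
# Hypothetical helper: map per-level weights to the 0-based integer keys
# (as strings) that keras expects in class_weight.
# Assumption: labels are encoded for the model as as.integer(factor) - 1.
make_class_weight <- function(labels, weights) {
  labels <- as.factor(labels)
  lv <- levels(labels)                      # note: levels sort alphabetically by default
  stopifnot(all(lv %in% names(weights)))    # every class needs a weight
  out <- as.list(unname(weights[lv]))       # weights in level order
  names(out) <- as.character(seq_along(lv) - 1L)  # "0", "1", ...
  out
}

y <- factor(c("neg", "neg", "neg", "pos"))
cw <- make_class_weight(y, c(neg = 1, pos = 30))
str(cw)
```

The resulting list has the same shape as list("0"=1,"1"=30) in the fit() call above, so it can be passed directly as class_weight.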

  • 2020-12-30 14:37

    I found a generic solution in Python and converted it to R:

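        library(dplyr)        # select() and %>%
        library(funModeling)  # freq()

        # Count how often each class label occurs in the training labels
        counter <- funModeling::freq(Y_data_aux_tr, plot = FALSE) %>% select(var, frequency)
        majority <- max(counter$frequency)
        # Weight each class by (majority count / class count), rounded up
        counter$weight <- ceiling(majority / counter$frequency)

        l_weights <- setNames(as.list(counter$weight), as.character(counter$var))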
    

    Using it:

     fit(..., class_weight = l_weights)
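The same weighting can be sketched in base R with table(), which avoids the funModeling and dplyr dependencies. This is a minimal sketch, not the answer's exact code: the function name inverse_freq_weights is hypothetical, and it assumes the labels are already 0-based integer class indices, so the names produced by table() match the keys keras expects.

```r
# Hypothetical base-R version: each class gets
# ceiling(majority count / class count), so the majority class gets 1.
# Assumption: labels are already 0-based integer class indices.
inverse_freq_weights <- function(labels) {
  counts <- table(labels)                  # occurrences per class
  w <- ceiling(max(counts) / counts)       # round up, as in the answer
  setNames(as.list(as.numeric(w)), names(counts))
}

y_train <- c(0, 0, 0, 0, 0, 0, 1, 1, 2)
l_weights <- inverse_freq_weights(y_train)
str(l_weights)
```

With six samples of class 0, two of class 1 and one of class 2, the weights come out as 1, 3 and 6. Rounding up with ceiling() keeps the weights integral; dropping it would give exact inverse-frequency ratios instead.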
    

    A word of advice if you are using fit_generator: since the weights are computed from class frequencies, having different numbers of training and validation samples may bias the validation results, so the two sets should be equally sized.
