How to set class_weight in keras package of R?

猫巷女王i 2020-12-30 13:44

I am using the keras package in R to train a deep learning model. My data set is highly imbalanced, so I want to set the class_weight argument in t

2 answers
  •  醉话见心
    2020-12-30 14:21

    class_weight needs to be a named list, so

        history <- model %>% fit(
            trainData, trainClass,
            epochs = 5, batch_size = 1000,
            # named list: class label -> loss weight (class "1" counts 30x)
            class_weight = list("0" = 1, "1" = 30),
            validation_split = 0.2
        )
    

    seems to work. Internally, the R keras package uses a function called as_class_weights to convert this list into a Python dictionary (see https://rdrr.io/cran/keras/src/R/model.R). The conversion can be reproduced with dict() from reticulate:

        library(reticulate)  # provides dict()
        class_weight <- dict(list('0' = 1, '1' = 10))
        class_weight
        #> {0: 1.0, 1: 10.0}
    

    This looks just like the Python dictionary you mentioned above.
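If you would rather derive the weights from the data than pick a ratio like 1:30 by hand, one common heuristic (not something keras itself prescribes) is inverse class frequency: each class gets max(count) / count, so the majority class gets weight 1. Below is a minimal base-R sketch. compute_class_weights is a hypothetical helper and train_class is made-up toy data; it assumes the labels are the integer class indices (0, 1, ...) that class_weight is keyed on.

```r
# Hypothetical helper: inverse-frequency class weights returned as a named
# list, the shape that fit(class_weight = ...) accepts in the R keras package.
compute_class_weights <- function(labels) {
  counts <- table(labels)                       # frequency of each class label
  weights <- max(counts) / as.numeric(counts)   # majority class gets weight 1
  as.list(setNames(weights, names(counts)))     # names are the class labels
}

# Made-up imbalanced labels for illustration: 90 negatives, 3 positives
train_class <- c(rep(0, 90), rep(1, 3))
cw <- compute_class_weights(train_class)
str(cw)
```

With 90 negatives and 3 positives this yields weights 1 and 30, the same ratio used in the fit() call above, and the list can be passed directly as class_weight = cw.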
