How to one-hot encode several categorical variables in R

再見小時候 2020-12-01 06:55

I'm working on a prediction problem and I'm building a decision tree in R. I have several categorical variables and I'd like to one-hot encode them consistently in my tra…

5 answers
  •  情歌与酒
    2020-12-01 07:15

    I have a tidy solution that gives the user more control over the entire process. It has a JavaScript component that splits each cell on a separator and stores the resulting column names as keys of a JSON object. Then I use the tidyjson::spread_all function to spread that JSON into separate columns.

    The JavaScript component, which you need to save as encoder.js:

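    // encoder.js
    // For every element of inputStrArray, split the string on the regular
    // expression spliterRegExStr and return a JSON object string with one key
    // per piece, e.g. '{"prefix.red":1,"prefix.blue":1}'.
    // Non-string elements (such as R NA values) map to null.
    function oneHotSplitEncoder(inputStrArray, prefix, spliterRegExStr, spliterRegExStrOptions) {
      if (!Array.isArray(inputStrArray)) {
        console.warn("Error: oneHotSplitEncoder function needs array type inputs");
        return null;
      }
      return inputStrArray.map(function (str) {
        try {
          if (typeof str === 'string' && typeof spliterRegExStr === 'string' &&
              typeof spliterRegExStrOptions === 'string' && typeof prefix === 'string') {
            return JSON.stringify(
              str.split(new RegExp(spliterRegExStr, spliterRegExStrOptions))
                // skip empty pieces left by leading or trailing separators
                .filter(function (component) { return component.length > 0; })
                // one key per category; the value 1 marks it as present
                .reduce(function (p, component) {
                  p[prefix + component] = 1;
                  return p;
                }, {})
            );
          }
          return null; // missing value or wrong argument types
        } catch (e) {
          console.warn("\n" + e + "\n" + str + "\n" + spliterRegExStr + ' string expected');
          return null;
        }
      });
    }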
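
    To make the output format concrete, here is a condensed restatement of the same split-and-reduce logic, applied to a few sample cells. The function name encodeCell, the 'color.' prefix and the sample values are made up for illustration; the separator regex is the one used in the R call that follows.

```javascript
// Illustrative restatement of the encoder's core logic (hypothetical name):
// split a cell on the separator regex, then map each piece to the key
// `prefix + piece` with value 1, serialized as a JSON object string.
function encodeCell(str, prefix, separator, flags) {
  if (typeof str !== 'string') return null; // e.g. a missing value coming from R
  return JSON.stringify(
    str.split(new RegExp(separator, flags))
      .filter(function (piece) { return piece.length > 0; })
      .reduce(function (acc, piece) {
        acc[prefix + piece] = 1;
        return acc;
      }, {})
  );
}

const cells = ['red, blue', 'green;red', 'blue', null];
const encoded = cells.map(function (c) { return encodeCell(c, 'color.', ' *[,;] *', 'g'); });
// encoded[0] === '{"color.red":1,"color.blue":1}'
// encoded[1] === '{"color.green":1,"color.red":1}'
// encoded[2] === '{"color.blue":1}'
// encoded[3] === null
```

    Each JSON key later becomes a column name, which is why the prefix is applied here rather than in R.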
    

    R component:

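    library(dplyr)

    # Create a V8 JavaScript context and load the encoder from encoder.js
    js <- V8::v8()
    js$source("encoder.js")

    # R wrapper that passes a character vector to the JavaScript function
    oneHotSplitEncoder <- function(inputStrArray, prefix, spliterRegExStr, spliterRegExStrOptions) {
      js$call("oneHotSplitEncoder", inputStrArray, prefix, spliterRegExStr, spliterRegExStrOptions)
    }

    # `df` is your data frame and `fooColumn` the column to encode; a cell may
    # hold several categories separated by "," or ";" (e.g. "red, blue")
    df_one_hot <- df %>%
      # replace each cell with a JSON string such as {"prefix.red":1,"prefix.blue":1}
      mutate(
        fooColumn = oneHotSplitEncoder(fooColumn, 'prefix.', ' *[,;] *', 'g')
      ) %>%
      # spread the JSON into one column per key, drop tidyjson's document.id,
      # fill absent categories with 0 and append the columns to the data
      bind_cols(tidyjson::spread_all(.$fooColumn) %>% select(-document.id) %>% replace(is.na(.), 0))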
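
    The spread_all() plus replace(is.na(.), 0) step turns those JSON strings into one 0/1 column per key, with 0 wherever a key is absent. Below is a rough JavaScript analogue of that step, not what tidyjson does internally. It adds one thing that is not in the answer above: an optional fixed column list. Since the question asks for consistent encoding, passing the training set's columns when encoding the test set keeps both tables identical in shape (unseen test categories are dropped, missing ones become 0). spreadOneHot and its parameters are hypothetical names.

```javascript
// Illustrative sketch (hypothetical helper): spread JSON one-hot strings
// into a 0/1 table. `columns` (optional) fixes the output columns, e.g. to
// reuse the training set's columns for the test set; other keys are ignored.
function spreadOneHot(jsonStrings, columns) {
  // Parse each cell; missing cells (null) become an empty object.
  const parsed = jsonStrings.map(function (s) { return s === null ? {} : JSON.parse(s); });
  // Without a fixed list, use every key seen, sorted for a stable order.
  const cols = columns || Array.from(new Set(parsed.flatMap(Object.keys))).sort();
  const rows = parsed.map(function (obj) {
    // 1 if the category is present in this cell, otherwise 0
    return cols.map(function (c) { return obj[c] === 1 ? 1 : 0; });
  });
  return { columns: cols, rows: rows };
}

const train = spreadOneHot(['{"color.red":1,"color.blue":1}', '{"color.green":1}']);
// train.columns: ["color.blue", "color.green", "color.red"]
// train.rows:    [[1, 0, 1], [0, 1, 0]]
const test = spreadOneHot(['{"color.red":1,"color.purple":1}', null], train.columns);
// test.rows:     [[0, 0, 1], [0, 0, 0]]  ("color.purple" is not a training column)
```

    In R, the same idea would mean adding any training columns missing from the test data as 0 and dropping the extra ones before predicting.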
    
