Understanding Spark RandomForest featureImportances results

纵然是瞬间 submitted on 2019-11-27 15:25:08

Question


I'm using RandomForest.featureImportances, but I don't understand the output.

I have 12 features, and this is the output I get.

I realize this might not be an Apache Spark-specific question, but I can't find anything that explains the output.

// org.apache.spark.mllib.linalg.Vector = (12,[0,1,2,3,4,5,6,7,8,9,10,11],
 [0.1956128039688559,0.06863606797951556,0.11302128590305296,0.091986700351889,0.03430651625283274,0.05975817050022879,0.06929766152519388,0.052654922125615934,0.06437052114945474,0.1601713590349946,0.0324327322375338,0.057751258970832206])

Answer 1:


Given a tree ensemble model, RandomForest.featureImportances computes the importance of each feature.

This generalizes the idea of "Gini" importance to other losses, following the explanation of Gini importance in the "Random Forests" documentation by Leo Breiman and Adele Cutler, and following the implementation in scikit-learn.

For collections of trees, which include boosting and bagging, Hastie et al. suggest using the average of single-tree importances across all trees in the ensemble.

This feature importance is calculated as follows:

  • Average over trees:
    • importance(feature j) = sum (over nodes which split on feature j) of the gain, where the gain is scaled by the number of instances passing through the node.
    • Normalize the importances for the tree to sum to 1.
  • Normalize feature importance vector to sum to 1.
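The three steps above can be sketched in plain Scala with no Spark dependency. This is a minimal illustration under stated assumptions, not Spark's source: the `Node`, `Leaf`, and `Internal` types are hypothetical stand-ins for a trained tree, `gain` stands for the impurity decrease recorded at a split, and `count` for the number of training instances reaching that node. Trees with no splits contribute nothing, as in the per-tree normalization described above.

```scala
object FeatureImportanceSketch {
  // Hypothetical tree representation (Spark's own node classes differ).
  sealed trait Node
  case object Leaf extends Node
  final case class Internal(feature: Int, gain: Double, count: Long, left: Node, right: Node) extends Node

  // Step 1: for one tree, sum gain * instance count over every split on each feature.
  def treeImportances(root: Node, numFeatures: Int): Array[Double] = {
    val acc = Array.fill(numFeatures)(0.0)
    def visit(n: Node): Unit = n match {
      case Leaf => ()
      case Internal(f, g, c, l, r) =>
        acc(f) += g * c
        visit(l)
        visit(r)
    }
    visit(root)
    acc
  }

  // Scale a vector so its entries sum to 1 (all-zero vectors are returned unchanged).
  def normalize(v: Array[Double]): Array[Double] = {
    val total = v.sum
    if (total == 0.0) v.clone() else v.map(_ / total)
  }

  // Steps 2 and 3: normalize each tree, add the per-tree vectors, normalize the total.
  // Dividing the sum by its total is the same as averaging over the contributing trees.
  def ensembleImportances(trees: Seq[Node], numFeatures: Int): Array[Double] = {
    val summed = Array.fill(numFeatures)(0.0)
    trees.foreach { t =>
      val perTree = treeImportances(t, numFeatures)
      if (perTree.sum != 0.0) {
        val norm = normalize(perTree)
        for (i <- 0 until numFeatures) summed(i) += norm(i)
      }
    }
    normalize(summed)
  }
}
```

For example, one tree that splits on feature 0 (gain 0.5, 100 instances) and then on feature 1 (gain 0.2, 50 instances), plus a second tree with a single split on feature 2, gives per-tree vectors [5/6, 1/6, 0] and [0, 0, 1], and the ensemble vector [5/12, 1/12, 1/2].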

References: Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2009. Section 15.3.2, Variable Importance, page 593.

Let's go back to your importance vector:

import org.apache.spark.mllib.linalg.Vectors

val importanceVector = Vectors.sparse(12, Array(0,1,2,3,4,5,6,7,8,9,10,11), Array(0.1956128039688559,0.06863606797951556,0.11302128590305296,0.091986700351889,0.03430651625283274,0.05975817050022879,0.06929766152519388,0.052654922125615934,0.06437052114945474,0.1601713590349946,0.0324327322375338,0.057751258970832206))
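The printed form `(12,[0,1,...,11],[...])` is Spark's sparse-vector notation: the vector size, then the indices of the stored entries, then their values. Because every index from 0 to 11 is listed, nothing is implicitly zero here and the vector holds exactly one importance per feature. As a plain-Scala illustration (no Spark; `SparseToDense.toDense` is a hypothetical helper, while Spark's own `Vector.toArray` does this conversion for you), the notation expands like this:

```scala
object SparseToDense {
  // Expands a (size, indices, values) triple, the layout Spark prints for a sparse
  // vector, into a dense array; positions not listed in `indices` stay 0.0.
  def toDense(size: Int, indices: Array[Int], values: Array[Double]): Array[Double] = {
    require(indices.length == values.length, "indices and values must have the same length")
    val dense = Array.fill(size)(0.0)
    indices.zip(values).foreach { case (i, v) => dense(i) = v }
    dense
  }
}
```

For instance, `toDense(5, Array(1, 3), Array(0.25, 0.75))` yields `Array(0.0, 0.25, 0.0, 0.75, 0.0)`.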

First, let's sort these features by importance:

importanceVector.toArray.zipWithIndex
            .map(_.swap)
            .sortBy(-_._2)
            .foreach(x => println(x._1 + " -> " + x._2))
// 0 -> 0.1956128039688559
// 9 -> 0.1601713590349946
// 2 -> 0.11302128590305296
// 3 -> 0.091986700351889
// 6 -> 0.06929766152519388
// 1 -> 0.06863606797951556
// 8 -> 0.06437052114945474
// 5 -> 0.05975817050022879
// 11 -> 0.057751258970832206
// 7 -> 0.052654922125615934
// 4 -> 0.03430651625283274
// 10 -> 0.0324327322375338

So what does this mean?

It means that your first feature (index 0) is the most important feature in your model, with a weight of ~0.196, and your 11th feature (index 10) is the least important, with a weight of ~0.032.
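Since the vector is normalized to sum to 1, each value reads as that feature's share of the total importance in this model. A small plain-Scala check (no Spark; `ImportanceShares` and `topKShare` are hypothetical names) using the numbers from the question confirms the sum and shows, for example, that the top three features (0, 9, 2) together hold about 47% of the total and that feature 0 scores roughly six times feature 10.

```scala
object ImportanceShares {
  // The importances printed in the question, in feature-index order.
  val importances: Array[Double] = Array(
    0.1956128039688559, 0.06863606797951556, 0.11302128590305296, 0.091986700351889,
    0.03430651625283274, 0.05975817050022879, 0.06929766152519388, 0.052654922125615934,
    0.06437052114945474, 0.1601713590349946, 0.0324327322375338, 0.057751258970832206)

  // (feature index, importance) pairs ordered from most to least important.
  val ranked: Array[(Int, Double)] = importances.zipWithIndex.map(_.swap).sortBy(-_._2)

  // Cumulative share of the total importance held by the k most important features.
  def topKShare(k: Int): Double = ranked.take(k).map(_._2).sum
}
```

Because the values are normalized, they are relative shares within this one model rather than absolute effect sizes, and impurity-based importance carries no sign, so it does not say in which direction a feature pushes the predictions.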




Answer 2:


Adding on to the previous answer:

One problem I faced was dumping the result as (featureName, importance) pairs to a CSV file. You can get the metadata for the input feature vector as follows:

 val featureMetadata = predictions.schema("features").metadata

This is the JSON structure of this metadata:

{
  "ml_attr": {
    "attrs": {
      "numeric": [{"idx": I, "name": N}, ...],
      "nominal": [{"vals": V, "idx": I, "name": N}, ...]
    },
    "num_attrs": #Attr
  }
}

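Code for extracting the importances by feature name:

import org.apache.spark.sql.types.Metadata
import org.apache.spark.ml.classification.RandomForestClassificationModel

val attrs = featureMetadata.getMetadata("ml_attr").getMetadata("attrs")
// (index in the feature vector, feature name) for one attribute entry
val f: Metadata => (Long, String) = m => (m.getLong("idx"), m.getString("name"))
// Each group ("numeric", "nominal", "binary") is present only if the vector has such features
def group(key: String): Array[(Long, String)] =
  if (attrs.contains(key)) attrs.getMetadataArray(key).map(f) else Array.empty
val features = (group("numeric") ++ group("nominal") ++ group("binary")).sortBy(_._1)

// `pipeline` is the fitted PipelineModel (on an unfitted Pipeline, `stages` is a Param, not an array)
val rfModel = pipeline.stages.collectFirst { case m: RandomForestClassificationModel => m }.get
val fImportance = rfModel.featureImportances.toArray.zip(features)
  .map(x => (x._2._2, x._1)) // (featureName, importance)
  .sortBy(-_._2)

// Save it now as a single part file of "name,importance" lines under the fPath directory
sc.parallelize(fImportance.toSeq, 1).map(x => s"${x._1},${x._2}").saveAsTextFile(fPath)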
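The `zip(features)` step relies on the sorted (idx, name) list lining up position by position with the importance vector, so a missing metadata entry would silently shift every name after it. A plain-Scala sketch of the same pairing done by index lookup instead (no Spark; `NamedImportances`, the fallback label `f<idx>`, and the sample names `age` and `income` in the test data are hypothetical) keeps every importance attached to its own index:

```scala
object NamedImportances {
  // Pairs each importance with the name registered for its vector index,
  // falling back to a hypothetical "f<idx>" label, sorted by descending importance.
  def named(importances: Array[Double], names: Array[(Long, String)]): Seq[(String, Double)] = {
    val byIdx: Map[Int, String] = names.map { case (i, n) => i.toInt -> n }.toMap
    importances.toSeq.zipWithIndex
      .map { case (imp, i) => (byIdx.getOrElse(i, s"f$i"), imp) }
      .sortBy(-_._2)
  }

  // Renders (name, importance) pairs as CSV lines with a header row.
  // Names containing commas or quotes would need escaping, which this sketch does not do.
  def toCsv(rows: Seq[(String, Double)]): Seq[String] =
    "feature,importance" +: rows.map { case (n, imp) => s"$n,$imp" }
}
```

The `names` argument has the same `(Long, String)` shape as `features` above, so the lookup can replace the positional `zip` without changing how the metadata is read.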


Source: https://stackoverflow.com/questions/37878519/understanding-spark-randomforest-featureimportances-results
