Extracting Class Probabilities from SparkR ML Classification Functions

Posted by 天涯浪子 on 2020-01-01 12:14:07

Question


I'm wondering whether it's possible (using the built-in features of SparkR or any other workaround) to extract the class probabilities from some of the classification algorithms that are included in SparkR. The ones of particular interest are:

spark.gbt()
spark.mlp()
spark.randomForest()

Currently, when I use the predict function on these models I am able to extract the predictions, but not the actual probabilities or "confidence."

I've seen several other questions on similar topics, but none that are specific to SparkR, and many have no answers that reflect Spark's most recent updates.


Answer 1:


I ran into the same problem and, following this answer, now use SparkR:::callJMethod to convert the probability DenseVector (which R cannot deserialize) to an Array (which R reads as a list). It's neither elegant nor fast, but it does the job:

  denseVectorToArray <- function(dv) {
    SparkR:::callJMethod(dv, "toArray")
  }

For example, start your Spark session:

library(SparkR)
sparkR.session(master = "local")

Generate toy data:

data <- data.frame(clicked = base::sample(c(0,1),100,replace=TRUE),
                  someString = base::sample(c("this", "that"),
                                           100, replace=TRUE), 
                  stringsAsFactors=FALSE)

trainidxs <- base::sample(nrow(data), nrow(data)*0.7)
traindf <- as.DataFrame(data[trainidxs,])
testdf <- as.DataFrame(data[-trainidxs,])

Train a random forest and run predictions:

rf <- spark.randomForest(traindf, 
                        clicked~., 
                        type = "classification", 
                        maxDepth = 2, 
                        maxBins = 2,
                        numTrees = 100)

predictions <- predict(rf, testdf)

Collect your predictions:

collected = SparkR::collect(predictions)    

Now extract the probabilities:

# Each element of collected$probability is a reference to a JVM DenseVector
collected$probabilities <- lapply(collected$probability, denseVectorToArray)
str(collected$probabilities)

Of course, the function wrapper around SparkR:::callJMethod is a bit of overkill. You can also call it directly, e.g. with dplyr (assuming dplyr is loaded):

withprobs = collected %>%
            rowwise() %>%
            mutate("probabilities" = list(SparkR:::callJMethod(probability,"toArray"))) %>%
            mutate("prob0" = probabilities[[1]], "prob1" = probabilities[[2]])
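
If you would rather not depend on dplyr, the list column can also be expanded with base R. The following is only a sketch: the `probabilities` list is made-up stand-in data that mimics what `toArray` returns once R has read it as a list, and the `prob0`/`prob1` column names simply follow the position in the vector. Which class each position corresponds to is an assumption you should verify against your own labels.

```r
# Stand-in for collected$probabilities: one R list of doubles per row.
# These values are invented for illustration, not real model output.
probabilities <- list(list(0.7, 0.3), list(0.4, 0.6), list(0.55, 0.45))

# Flatten each per-row list to a numeric vector and stack the rows.
prob_matrix <- do.call(rbind, lapply(probabilities, unlist))

# Name the columns by vector position (prob0, prob1, ...).
colnames(prob_matrix) <- paste0("prob", seq_len(ncol(prob_matrix)) - 1)

prob_df <- as.data.frame(prob_matrix)
print(prob_df)
```

With real data you would pass `collected$probabilities` in place of the stand-in list and attach the result with `cbind(collected, prob_df)`.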


Source: https://stackoverflow.com/questions/41942974/extracting-class-probabilities-from-sparkr-ml-classification-functions
