SparkR 2.0 Classification: how to get performance metrics?

僤鯓⒐⒋嵵緔 submitted on 2019-12-25 09:02:06

Question


How can I get performance metrics for a SparkR classification model, e.g., F1 score, precision, recall, and the confusion matrix?

# Load training data
df <- read.df("data/mllib/sample_libsvm_data.txt", source = "libsvm")
training <- df
testing <- df

# Fit a random forest classification model with spark.randomForest
model <- spark.randomForest(training, label ~ features, "classification", numTrees = 10)

# Model summary
summary(model)

# Prediction
predictions <- predict(model, testing)
head(predictions)

# Performance evaluation

I've tried caret::confusionMatrix(testing$label, testing$prediction), but it throws this error:

   Error in unique.default(x, nmax = nmax) : unique() applies only to vectors

Answer 1:


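caret's confusionMatrix will not work here, since it needs local R objects (vectors or factors), while your data live in Spark DataFrames: testing$label is a SparkR Column, not an R vector, hence the unique() error.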

One way (not recommended) to get your metrics is to collect your Spark DataFrames locally into R with as.data.frame and then use caret etc.; but this assumes that your data fit in the main memory of your driver machine, in which case, of course, you have no reason to use Spark in the first place...
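If the predictions are small enough to collect, the local route can be sketched as below. This is a minimal sketch under that assumption: the SparkR collection step (as.data.frame(predictions)) appears only as a comment, and the rest is base R, building the confusion matrix with table() instead of caret. Note that the prediction column lives in predictions, the output of predict(), not in testing. local_confusion() and the toy data are hypothetical.

```r
# Minimal sketch of the "collect locally" route; only sensible when the
# predictions fit in the driver's memory.
#
# With SparkR, the collection step would be (not run here):
#   local <- as.data.frame(predictions)   # predictions <- predict(model, testing)

# Hypothetical helper: cross-tabulate true vs. predicted labels of a plain
# R data.frame, using the same set of levels for both so that the result
# is always a square matrix (rows = actual, columns = predicted).
local_confusion <- function(local, label_col = "label", pred_col = "prediction") {
  actual    <- as.character(local[[label_col]])
  predicted <- as.character(local[[pred_col]])
  classes   <- sort(union(actual, predicted))
  table(actual    = factor(actual, levels = classes),
        predicted = factor(predicted, levels = classes))
}

# Usage on a toy data.frame standing in for as.data.frame(predictions):
toy <- data.frame(label      = c("a", "a", "b", "b", "b"),
                  prediction = c("a", "b", "b", "b", "a"),
                  stringsAsFactors = FALSE)
cm <- local_confusion(toy)
print(cm)

# Accuracy = correctly classified (diagonal) / all rows
accuracy <- sum(diag(cm)) / sum(cm)
```

If you prefer caret, its confusionMatrix(data, reference) expects the predicted classes first and the true classes second, both as factors with the same levels; the call in the question passes them the other way round.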

So, here is a way to get the accuracy in a distributed manner (i.e. without collecting data locally), using the iris data as an example:

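sparkR.version()
# "2.1.1"

df <- as.DataFrame(iris)   # iris's dotted column names become Sepal_Length etc.
model <- spark.randomForest(df, Species ~ ., "classification", numTrees = 10)
predictions <- predict(model, df)   # adds a "prediction" column
summary(predictions)
# SparkDataFrame[summary:string, Sepal_Length:string, Sepal_Width:string, Petal_Length:string, Petal_Width:string, Species:string, prediction:string]

# Register the predictions as a temporary view so they can be queried with Spark SQL
createOrReplaceTempView(predictions, "predictions")
# Keep only the rows where the predicted class equals the true class
correct <- sql("SELECT prediction, Species FROM predictions WHERE prediction = Species")
count(correct)
# 149
# Accuracy = correct predictions / all predictions (both counts are computed by Spark)
acc <- count(correct) / count(predictions)
acc
# 0.9933333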

(Regarding the 149 correct predictions out of 150 samples: if you run showDF(predictions, numRows = 150), you will indeed see a single virginica sample misclassified as versicolor.)
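To get the per-class metrics the question asks for without collecting the data, one option is to let Spark build the label x prediction frequency table and bring only that small k x k table to the driver. A minimal sketch, assuming SparkR's crosstab(x, col1, col2), which computes the pair-wise frequency table in Spark and returns it as a local R data.frame whose first column holds the values of col1 and whose remaining columns are named after the distinct values of col2 (verify the layout on your Spark version). The SparkR calls are shown only as comments; crosstab_to_matrix() and class_metrics() are hypothetical helpers, and the hand-built table mirrors the iris result described above (one virginica predicted as versicolor); it is not real Spark output.

```r
# Distributed route (sketch): Spark builds the contingency table, and the
# metrics are then computed locally from that small k x k table.
#
# With SparkR this would be (not run here):
#   ct <- crosstab(predictions, "Species", "prediction")
# Assumed layout of ct: first column = true labels, other columns = predicted
# labels, cells = counts (0 for pairs that never occur).
# To list the misclassified rows instead of scanning all 150:
#   showDF(where(predictions, predictions$prediction != predictions$Species))

# Hypothetical helper: turn the crosstab data.frame into a square matrix
# (rows = actual class, columns = predicted class). A class that is never
# predicted has no column in ct, so its cells stay at 0.
crosstab_to_matrix <- function(ct) {
  counts <- as.matrix(ct[, -1, drop = FALSE])
  rownames(counts) <- as.character(ct[[1]])
  classes <- sort(union(rownames(counts), colnames(counts)))
  cm <- matrix(0, nrow = length(classes), ncol = length(classes),
               dimnames = list(actual = classes, predicted = classes))
  cm[rownames(counts), colnames(counts)] <- counts
  cm
}

# Hypothetical helper: per-class precision, recall and F1.
#   precision_k = TP_k / (column total k)   (everything predicted as k)
#   recall_k    = TP_k / (row total k)      (everything that truly is k)
#   F1_k        = 2 * precision_k * recall_k / (precision_k + recall_k)
# A class that is never predicted gets precision NaN (0/0).
class_metrics <- function(cm) {
  tp <- diag(cm)
  precision <- tp / colSums(cm)
  recall    <- tp / rowSums(cm)
  f1        <- 2 * precision * recall / (precision + recall)
  data.frame(class = rownames(cm), precision = precision, recall = recall,
             f1 = f1, row.names = NULL, stringsAsFactors = FALSE)
}

# Usage on a hand-built table shaped like the assumed crosstab() output:
ct <- data.frame(Species_prediction = c("setosa", "versicolor", "virginica"),
                 setosa     = c(50, 0, 0),
                 versicolor = c(0, 50, 1),
                 virginica  = c(0, 0, 49),
                 stringsAsFactors = FALSE)
cm <- crosstab_to_matrix(ct)
metrics <- class_metrics(cm)
accuracy <- sum(diag(cm)) / sum(cm)
macro_f1 <- mean(metrics$f1)   # unweighted average over classes
print(cm)
print(metrics)
```

Precision divides by the column (predicted) totals and recall by the row (actual) totals; with these counts, versicolor gets precision 50/51 and virginica gets recall 49/50, while accuracy matches the 149/150 computed above.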



Source: https://stackoverflow.com/questions/45400833/sparkr-2-0-classification-how-to-get-performance-matrices
