R: LIME returns an error about different feature numbers when that is not the case

Submitted by 心不动则不痛 on 2019-12-01 12:08:44

We can trace the error to predict_model, which calls predict.textmodel_nb_fitted (I used only the first 10 rows of train_raw to speed up computation):

traceback()
# 7: stop("feature set in newdata different from that in training set")
# 6: predict.textmodel_nb_fitted(x, newdata = newdata, type = type, 
#        ...)
# 5: predict(x, newdata = newdata, type = type, ...)
# 4: predict_model.default(explainer$model, case_perm, type = o_type)
# 3: predict_model(explainer$model, case_perm, type = o_type)
# 2: explain.data.frame(train_raw[1:10, 1:5], explainer, n_labels = 1, 
#        n_features = 5, cols = 2, verbose = 0)
# 1: lime::explain(train_raw[1:10, 1:5], explainer, n_labels = 1, 
#        n_features = 5, cols = 2, verbose = 0)

The problem is that predict.textmodel_nb_fitted expects a dfm, not a data frame. For example, predict(nb_model, test_raw[1:5]) gives you the same "feature set in newdata different from that in training set" error. However, explain takes a data frame as its x argument and, as frames 3 and 4 of the traceback show, passes permuted copies of it (case_perm) unchanged to predict_model.
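Why the feature sets disagree: the Naive Bayes model was fit on the features (columns) of the training dfm, and the error fires whenever newdata does not carry exactly that feature set. The columns of a raw data frame are not those features, so the data has to be turned into a document-feature matrix and then aligned with the training vocabulary. Below is a minimal base-R sketch of that alignment idea, using a plain count matrix in place of a dfm. The helper `align_features` and the toy vocabulary are hypothetical, and this illustrates the concept rather than quanteda's implementation:

```r
# Hypothetical helper: align the columns of a document-feature count matrix
# with the feature set (vocabulary) seen at training time.
align_features <- function(counts, train_features) {
  # Start from an all-zero matrix with exactly the training columns, in order
  out <- matrix(0, nrow = nrow(counts), ncol = length(train_features),
                dimnames = list(rownames(counts), train_features))
  # Copy over only the features that also occurred in training;
  # features never seen in training are dropped
  shared <- intersect(colnames(counts), train_features)
  out[, shared] <- counts[, shared, drop = FALSE]
  out
}

train_features <- c("good", "bad", "movie")
new_counts <- matrix(c(1, 0, 2,
                       0, 3, 1),
                     nrow = 2, byrow = TRUE,
                     dimnames = list(c("doc1", "doc2"),
                                     c("good", "plot", "movie")))

aligned <- align_features(new_counts, train_features)
aligned
#      good bad movie
# doc1    1   0     2
# doc2    0   0     1
```

The unseen feature (`plot`) is dropped and the missing one (`bad`) is filled with zeros, in training order; this is roughly the role `dfm_select(dfm(X), x$data$x)` plays in the custom method.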

A solution is to write a custom textmodel_nb_fitted method for predict_model that does the necessary object conversions before calling predict.textmodel_nb_fitted:

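library(quanteda)

# Custom lime predict_model method for quanteda Naive Bayes models:
# convert lime's data frame into a dfm before calling predict()
predict_model.textmodel_nb_fitted <- function(x, newdata, type, ...) {
  X <- corpus(newdata)                 # build a corpus from the data frame
  X <- dfm_select(dfm(X), x$data$x)    # match features to the training dfm
  res <- predict(x, newdata = X, ...)
  switch(
    type,
    # predicted class labels, in a single "Response" column
    raw  = data.frame(Response = res$nb.predicted, stringsAsFactors = FALSE),
    # posterior probabilities, one column per class
    prob = as.data.frame(res$posterior.prob, check.names = FALSE)
  )
}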
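The fix works because predict_model is an S3 generic: lime calls predict_model(explainer$model, ...) (frame 3 of the traceback), and R dispatches on the class of the model object. Once a method for textmodel_nb_fitted exists, it is chosen instead of predict_model.default (frame 4). A minimal sketch of that dispatch, using a hypothetical generic `predict_model_demo` and a hypothetical class `toy_nb` so that it runs without lime or quanteda:

```r
# Hypothetical stand-in for lime's predict_model() generic:
# UseMethod() picks a method based on the class of the model object `x`.
predict_model_demo <- function(x, newdata, type, ...) {
  UseMethod("predict_model_demo")
}

# Fallback for classes without their own method; plays the role of
# predict_model.default, which passed the data frame through unchanged.
predict_model_demo.default <- function(x, newdata, type, ...) {
  "default: newdata passed through unchanged"
}

# Method for the hypothetical class "toy_nb"; plays the role of the custom
# predict_model.textmodel_nb_fitted, which converts newdata first.
predict_model_demo.toy_nb <- function(x, newdata, type, ...) {
  "toy_nb: newdata converted before predicting"
}

other_model  <- structure(list(), class = "other_model")
toy_nb_model <- structure(list(), class = "toy_nb")

predict_model_demo(other_model, data.frame(), type = "raw")
# [1] "default: newdata passed through unchanged"
predict_model_demo(toy_nb_model, data.frame(), type = "raw")
# [1] "toy_nb: newdata converted before predicting"
```

Because the method is looked up by class name, simply defining predict_model.textmodel_nb_fitted in your session is enough; the call inside lime does not need to change.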

This gives us

explanation <- lime::explain(train_raw[1:10, 1:5], 
                              explainer,
                              n_labels = 1,
                              n_features = 5,
                              cols = 2,
                              verbose = 0)
explanation[1, 1:5]
#       model_type case label label_prob    model_r2
# 1 classification    1 FALSE  0.9999986 0.001693861