Convert “regular” plot to ggplot object (and then plotly)

Submitted by 无人久伴 on 2021-01-07 02:49:57

Question


I am using the R programming language. I combined my own code with a lengthy tutorial (https://michael.hahsler.net/SMU/EMIS7332/R/viz_classifier.html) and ended up producing a plot (see "final_plot" at the end of this code):

library(cluster)
library(Rtsne)
library(dplyr)

library(randomForest)
library(caret)
library(ggplot2)
library(plotly)


#PART 1 : Create Data

#generate 4 random variables : response_variable ~ var_1 , var_2, var_3

var_1 <- rnorm(10000,1,4)
var_2<-rnorm(10000,10,5)
var_3 <- sample( LETTERS[1:4], 10000, replace=TRUE, prob=c(0.1, 0.2, 0.65, 0.05) )
response_variable <- sample( LETTERS[1:2], 10000, replace=TRUE, prob=c(0.4, 0.6) )


#put them into a data frame called "f"
f <- data.frame(var_1, var_2, var_3, response_variable)

#declare var_3 and response_variable as factors
f$response_variable = as.factor(f$response_variable)
f$var_3 = as.factor(f$var_3)

#create id
f$ID <- seq_along(f[,1])

#PART 2: random forest

#split data into train set and test set
index = createDataPartition(f$response_variable, p=0.7, list = FALSE)
train = f[index,]
test = f[-index,]

#create random forest statistical model
rf = randomForest(response_variable ~ var_1 + var_2 + var_3, data=train, ntree=20, mtry=2)

#have the model predict the test set
pred = predict(rf, test, type = "prob")
#column 2 of pred holds the probability of class "B"
labels = factor(ifelse(pred[,2]>0.5, "B", "A"), levels = levels(test$response_variable))
confusionMatrix(labels, test$response_variable)

#PART 3: Visualize in 2D (source: https://dpmartin42.github.io/posts/r/cluster-mixed-types)

gower_dist <- daisy(test[, -c(4,5)],
                    metric = "gower")

gower_mat <- as.matrix(gower_dist)

labels = data.frame(labels)
labels$ID = test$ID


tsne_obj <- Rtsne(gower_dist,  is_distance = TRUE)

tsne_data <- tsne_obj$Y %>%
    data.frame() %>%
    setNames(c("X", "Y")) %>%
    mutate(cluster = factor(labels$labels),
           name = labels$ID)

plot = ggplot(aes(x = X, y = Y), data = tsne_data) +
    geom_point(aes(color = cluster))

plotly_plot = ggplotly(plot)


a = tsne_obj$Y
a = data.frame(a)
data = a
data$class = labels$labels


decisionplot <- function(model, data, class = NULL, predict_type = "class",
                         resolution = 100, showgrid = TRUE, ...) {
    
    if(!is.null(class)) cl <- data[,class] else cl <- 1
    data <- data[,1:2]
    k <- length(unique(cl))
    
    plot(data, col = as.integer(cl)+1L, pch = as.integer(cl)+1L, ...)
    
    # make grid
    r <- sapply(data, range, na.rm = TRUE)
    xs <- seq(r[1,1], r[2,1], length.out = resolution)
    ys <- seq(r[1,2], r[2,2], length.out = resolution)
    g <- cbind(rep(xs, each = resolution), rep(ys, times = resolution))
    colnames(g) <- colnames(r)
    g <- as.data.frame(g)
    
    ### guess how to get class labels from predict
    ### (unfortunately not very consistent between models)
    p <- predict(model, g, type = predict_type)
    if(is.list(p)) p <- p$class
    p <- as.factor(p)
    
    if(showgrid) points(g, col = as.integer(p)+1L, pch = ".")
    
    z <- matrix(as.integer(p), nrow = resolution, byrow = TRUE)
    contour(xs, ys, z, add = TRUE, drawlabels = FALSE,
            lwd = 2, levels = (1:(k-1))+.5)
    
    invisible(z)
}


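# the randomForest argument is "ntree"; "ntrees" is silently ignored
model <- randomForest(class ~ ., data=data, mtry=2, ntree=500)
# decisionplot() draws the plot as a side effect and returns the matrix z invisibly
final_plot = decisionplot(model, data, class = "class", main = "rf (1)")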

Now, I would like to turn this into an "interactive" plot using the plotly library in R:

plotly_plot = ggplotly(final_plot)

But I got the following error:

Error in UseMethod("ggplotly", p) : 
  no applicable method for 'ggplotly' applied to an object of class "c('matrix', 'array', 'integer', 'numeric')"

Is there a way to convert "regular" (base R) plots to ggplot objects in R? Can my "final_plot" be turned into a plotly object?


Answer 1:


As @mischva11 commented, I think it is easier to create the ggplot from scratch. Your function actually returns a matrix (the invisible(z) at the end), not a plot object. The plot and contour functions draw directly on the active graphics device. I am not sure whether there is a way to convert these base plots to ggplot (maybe there is).
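
To illustrate the point about return values, here is a minimal sketch. The toy data and the names base_result and gg_result are made up for illustration, and pdf(NULL) is used only so that the base drawing is discarded instead of being written to a file.

```r
library(ggplot2)

# Base graphics draw as a side effect on the active device.
# Open a throwaway device so nothing is written to disk.
pdf(NULL)
base_result <- plot(1:10)   # draws the points, returns NULL
dev.off()

# ggplot() builds an object that can be stored, modified and converted later.
toy <- data.frame(x = 1:10, y = 1:10)   # hypothetical toy data
gg_result <- ggplot(toy, aes(x = x, y = y)) + geom_point()

is.null(base_result)            # TRUE: there is nothing to hand to ggplotly()
inherits(gg_result, "ggplot")   # TRUE: this is what ggplotly() expects
```

This is why ggplotly(final_plot) fails in the question: final_plot holds the integer matrix returned by the function, not a plot.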

Here is a way to create a plot similar to yours in ggplot and then convert it to plotly.

decisionplot <- function(model, data, class = NULL, predict_type = "class", resolution = 100, showgrid = TRUE) {
  
  # create ggplot with minimal theme and no grid lines
  g <- ggplot() + theme_minimal() + theme(panel.grid = element_blank())

  # make grid values for contour and grid points
  r <- sapply(data[ ,1:2], range, na.rm = TRUE)
  xs <- seq(r[1,1], r[2,1], length.out = resolution)
  ys <- seq(r[1,2], r[2,2], length.out = resolution)
  g1 <- cbind(rep(xs, each = resolution), rep(ys, times = resolution))
  colnames(g1) <- colnames(r)
  g1 <- as.data.frame(g1)
    
  ### guess how to get class labels from predict
  ### (unfortunately not very consistent between models)
  p <- predict(model, g1, type = predict_type)
  if(is.list(p)) p <- p$class
  g1$class <- as.factor(p)
  
  if(showgrid) {        
    # add labeled grid points to ggplot
    g <- g + geom_point(data=g1, aes(x=X1, y=X2, col = class), shape = ".") 
  }

  # add points to plot
  g <- g + geom_point(data=data, aes(x=X1, y=X2, col = class, shape = class)) 
  
  # add contour curves
  g <- g + geom_contour(data=g1, aes(x=X1, y=X2, z=as.integer(class)), colour='black', linetype=1, size=rel(0.2), bins=length(unique(g1$class)))
  
  # return ggplot object
  return(g)
}    

# get ggplot object
final_plot <- decisionplot(model, data, class = "class")

# convert to plotly
ggplotly(final_plot)

This works. The final plot is not especially polished, but you can adjust the parameters to improve it.

One thing that, in my opinion, could improve the final plot is to use geom_raster to draw the regions with different predicted labels (instead of plotting the small points). However, when I tried this, the conversion to plotly took so long that I gave up. I suspect there is an issue in the conversion to plotly when geom_raster is given discrete labels, because when I converted the discrete labels to numeric values, the conversion to plotly was very fast.
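
A rough sketch of the numeric-label workaround described above. The data frame grid stands in for the prediction grid g1 built inside decisionplot(), the decision rule is made up in place of predict(model, g1), and class_num is a hypothetical column name. It is an illustration under those assumptions, not the exact code that was tried.

```r
library(ggplot2)
library(plotly)

# Hypothetical stand-in for the prediction grid g1 from decisionplot()
grid <- expand.grid(X1 = seq(-3, 3, length.out = 50),
                    X2 = seq(-3, 3, length.out = 50))

# Made-up decision rule used here instead of predict(model, g1)
grid$class <- factor(ifelse(grid$X1 + grid$X2 > 0, "A", "B"))

# Numeric copy of the discrete labels: level "A" becomes 1, level "B" becomes 2
grid$class_num <- as.integer(grid$class)

# Fill the raster with the numeric codes rather than the factor
p <- ggplot(grid, aes(x = X1, y = X2)) +
  geom_raster(aes(fill = class_num)) +
  theme_minimal()

# Convert the ggplot object to an interactive plotly object
p_interactive <- ggplotly(p)
```

The trade-off is that the legend becomes a continuous colour bar, so the class names have to be recovered from the numeric codes.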

Another option is to work directly in plot_ly, but I don't have much experience with it.
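
For what it is worth, a minimal sketch of the plot_ly route might look like the following. This was not part of the tested answer: the grid, the decision rule and the points in pts are all made up, and a real version would fill z from predict(model, ...) on the grid instead.

```r
library(plotly)

# Hypothetical grid over the plotting region
xs <- seq(-3, 3, length.out = 50)
ys <- seq(-3, 3, length.out = 50)

# z[i, j] is the class code at y = ys[i], x = xs[j];
# a made-up rule stands in for the model predictions
z <- outer(ys, xs, function(y, x) ifelse(x + y > 0, 1, 2))

# Hypothetical data points to overlay on the regions
set.seed(1)
pts <- data.frame(X1 = rnorm(30), X2 = rnorm(30))

# Heatmap of the predicted regions, then the points drawn on top
fig <- plot_ly(x = xs, y = ys, z = z, type = "heatmap", showscale = FALSE)
fig <- add_markers(fig, x = pts$X1, y = pts$X2,
                   marker = list(color = "black"),
                   inherit = FALSE)   # do not reuse the heatmap's x/y/z
```

Printing fig in an interactive session shows the plot. Decision boundaries could be added in a similar way with a contour trace, but that is left out of this sketch.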

Hope this works.



Source: https://stackoverflow.com/questions/65406196/convert-regular-plot-to-ggplot-object-and-then-plotly
